add simulation and backpropagation, some cleanup
This commit is contained in:
parent
91e6af8998
commit
def372fcc6
2 changed files with 37 additions and 17 deletions
|
@ -103,8 +103,8 @@ struct bt_t {
|
||||||
// MCTS
|
// MCTS
|
||||||
bt_node_t *mcts_selection(bt_node_t *);
|
bt_node_t *mcts_selection(bt_node_t *);
|
||||||
void mcts_expansion(bt_node_t *);
|
void mcts_expansion(bt_node_t *);
|
||||||
void mcts_simulation();
|
bool mcts_simulation(bt_node_t *);
|
||||||
void mcts_back_propagation();
|
void mcts_back_propagation(bt_node_t *, bool);
|
||||||
|
|
||||||
// déclarées mais non définies
|
// déclarées mais non définies
|
||||||
double eval();
|
double eval();
|
||||||
|
|
|
@ -195,7 +195,7 @@ bt_move_t bt_t::get_rand_move() {
|
||||||
|
|
||||||
bt_node_t *bt_t::mcts_selection(bt_node_t *node) {
|
bt_node_t *bt_t::mcts_selection(bt_node_t *node) {
|
||||||
if (node->nb_simulation) {
|
if (node->nb_simulation) {
|
||||||
bt_node_t *the_chosen = NULL;
|
bt_node_t *the_chosen = nullptr;
|
||||||
float score = 0.;
|
float score = 0.;
|
||||||
for (auto i : node->children) {
|
for (auto i : node->children) {
|
||||||
// https://en.wikipedia.org/wiki/Monte_Carlo_tree_search#Exploration_and_exploitation
|
// https://en.wikipedia.org/wiki/Monte_Carlo_tree_search#Exploration_and_exploitation
|
||||||
|
@ -231,22 +231,43 @@ void bt_t::mcts_expansion(bt_node_t *root) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void bt_t::mcts_simulation(void) {}
|
/**
 * Simulation step of MCTS: run one random playout opened by `node`'s move.
 *
 * The playout is performed on a copy of this board, so the board held by
 * *this is not modified. The move stored in `node` is played first, then
 * random moves (get_rand_move) are played until endgame() reports a result.
 *
 * NOTE(review): only node->move is replayed on the copy. If `node` lies more
 * than one ply below the position held by *this, the moves leading to it are
 * not applied here -- confirm against the caller.
 *
 * @param node  tree node whose move opens the playout (must not be null)
 * @return true if the side to move in *this (before node->move is played)
 *         is the winner of the playout, false otherwise
 */
bool bt_t::mcts_simulation(bt_node_t *node) {
  bt_t b_copy = *this;

  // Colour of the side to move before the playout starts (even turn -> WHITE).
  const int me = (b_copy.turn % 2 == 0) ? WHITE : BLACK;

  // Open the playout with the move carried by the node.
  b_copy.play(node->move);

  // Then play randomly until the game is over.
  // Assumes endgame() returns EMPTY while the game is undecided and the
  // winner's colour (WHITE/BLACK) otherwise -- TODO confirm against endgame().
  int winner = b_copy.endgame();
  while (winner == EMPTY) {
    b_copy.play(b_copy.get_rand_move());
    winner = b_copy.endgame();
  }

  // Compare the reported winner with our colour directly. Deriving the winner
  // from `turn` with `cond ? WHITE : BLACK == me` is a precedence trap:
  // `==` binds tighter than `?:`, so it evaluates `cond ? WHITE : (BLACK == me)`.
  return winner == me;
}
|
||||||
|
|
||||||
|
/**
 * Back-propagation step of MCTS: push one playout result up the tree.
 *
 * Starting at `simulated`, every node on the parent chain -- including the
 * root, whose `parent` is nullptr -- gets its simulation counter incremented,
 * and its win counter incremented as well when the playout was won.
 *
 * The root must be updated too: mcts_selection tests `node->nb_simulation`
 * on the node it is handed, so a root that is never counted always looks
 * unvisited.
 *
 * @param simulated  node the playout was run from; a null pointer is a no-op
 * @param won        result of the playout as returned by mcts_simulation
 */
void bt_t::mcts_back_propagation(bt_node_t *simulated, bool won) {
  // Walk up until we step past the root (parent == nullptr).
  for (bt_node_t *node = simulated; node != nullptr; node = node->parent) {
    if (won) {
      node->wins++;
    }
    node->nb_simulation++;
  }
}
|
||||||
|
|
||||||
bt_move_t bt_t::get_mcts_move(double max_time) {
|
bt_move_t bt_t::get_mcts_move(double max_time) {
|
||||||
// Init tree
|
// Init tree
|
||||||
bt_node_t *tree = new bt_node_t();
|
bt_node_t *tree = new bt_node_t();
|
||||||
tree->parent = NULL;
|
tree->parent = nullptr;
|
||||||
tree->children = {};
|
tree->children = {};
|
||||||
tree->wins = 0;
|
tree->wins = 0;
|
||||||
tree->nb_simulation = 0;
|
tree->nb_simulation = 0;
|
||||||
mcts_expansion(tree);
|
mcts_expansion(tree);
|
||||||
|
|
||||||
// Copy board
|
|
||||||
bt_t b_copy = *this;
|
|
||||||
|
|
||||||
// Time constraint
|
// Time constraint
|
||||||
auto now = std::chrono::steady_clock::now();
|
auto now = std::chrono::steady_clock::now();
|
||||||
std::chrono::duration<double> elapsed{};
|
std::chrono::duration<double> elapsed{};
|
||||||
|
@ -254,31 +275,30 @@ bt_move_t bt_t::get_mcts_move(double max_time) {
|
||||||
// MCTS
|
// MCTS
|
||||||
while (elapsed.count() < max_time) {
|
while (elapsed.count() < max_time) {
|
||||||
// Selection
|
// Selection
|
||||||
auto selected = b_copy.mcts_selection(tree);
|
auto selected = mcts_selection(tree);
|
||||||
|
|
||||||
// Expansion
|
// Expansion
|
||||||
b_copy.mcts_expansion(selected);
|
mcts_expansion(selected);
|
||||||
|
|
||||||
// Simulation
|
// Simulation
|
||||||
// bool is_win = mcts_simulation(selected);
|
bool is_win = mcts_simulation(selected);
|
||||||
|
|
||||||
// Update
|
// Update
|
||||||
// mcts_back_propagation(selected, is_win);
|
mcts_back_propagation(selected, is_win);
|
||||||
|
|
||||||
// Time constraint
|
// Update elapsed time
|
||||||
elapsed = std::chrono::steady_clock::now() - now;
|
elapsed = std::chrono::steady_clock::now() - now;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Select best move
|
// Select best move
|
||||||
/* bt_node_t *best_node = NULL;
|
bt_node_t *best_node = nullptr;
|
||||||
int best_score = -1;
|
int best_score = -1;
|
||||||
for (auto i : tree->children) {
|
for (auto i : tree->children) {
|
||||||
if (i->nb_simulation > best_score) {
|
if (i->nb_simulation > best_score) {
|
||||||
best_node = i;
|
best_node = i;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return best_node->move; */
|
return best_node->move;
|
||||||
return moves[(static_cast<int>(rand())) % nb_moves]; // TMP
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool bt_t::can_play(bt_move_t _m) {
|
bool bt_t::can_play(bt_move_t _m) {
|
||||||
|
|
Reference in a new issue