add simulation and backpropagation, some cleanup

This commit is contained in:
Mylloon 2023-03-29 18:25:39 +02:00
parent 91e6af8998
commit def372fcc6
Signed by: Anri
GPG key ID: A82D63DFF8D1317F
2 changed files with 37 additions and 17 deletions

View file

@ -103,8 +103,8 @@ struct bt_t {
// MCTS // MCTS
bt_node_t *mcts_selection(bt_node_t *); bt_node_t *mcts_selection(bt_node_t *);
void mcts_expansion(bt_node_t *); void mcts_expansion(bt_node_t *);
void mcts_simulation(); bool mcts_simulation(bt_node_t *);
void mcts_back_propagation(); void mcts_back_propagation(bt_node_t *, bool);
// déclarées mais non définies // déclarées mais non définies
double eval(); double eval();

View file

@ -195,7 +195,7 @@ bt_move_t bt_t::get_rand_move() {
bt_node_t *bt_t::mcts_selection(bt_node_t *node) { bt_node_t *bt_t::mcts_selection(bt_node_t *node) {
if (node->nb_simulation) { if (node->nb_simulation) {
bt_node_t *the_chosen = NULL; bt_node_t *the_chosen = nullptr;
float score = 0.; float score = 0.;
for (auto i : node->children) { for (auto i : node->children) {
// https://en.wikipedia.org/wiki/Monte_Carlo_tree_search#Exploration_and_exploitation // https://en.wikipedia.org/wiki/Monte_Carlo_tree_search#Exploration_and_exploitation
@ -231,22 +231,43 @@ void bt_t::mcts_expansion(bt_node_t *root) {
} }
} }
// Simulation step of MCTS: plays `node->move` on a copy of the current board,
// then finishes the game with random moves.
//
// @param node  node whose move is tried first; must not be null.
// @return true when the side that was to move before `node->move` was played
//         ends up winning the playout, false otherwise.
//
// NOTE(review): assumes endgame() returns EMPTY while nobody has won and the
// winner's colour (WHITE/BLACK) once the game is over — confirm against its
// definition.
bool bt_t::mcts_simulation(bt_node_t *node) {
  // Work on a copy so the real board is left untouched by the playout.
  bt_t b_copy = *this;

  // Side to move before the candidate move is played (even turn => WHITE).
  const int me = (b_copy.turn % 2 == 0) ? WHITE : BLACK;

  // Try to play my move.
  b_copy.play(node->move);

  // Then play randomly until the game is over. The previous condition
  // (`!= EMPTY`) was inverted: it skipped the playout on an unfinished game.
  int outcome = b_copy.endgame();
  while (outcome == EMPTY) {
    b_copy.play(b_copy.get_rand_move());
    outcome = b_copy.endgame();
  }

  // Did I win? The previous expression `c ? WHITE : BLACK == me` parsed as
  // `c ? WHITE : (BLACK == me)` because `==` binds tighter than `?:`, and it
  // looked at the side to move after the game instead of the reported winner.
  return outcome == me;
}
// Back-propagation step of MCTS: walks from `simulated` up to the root and
// records the playout result on every node of that path.
//
// @param simulated  node the playout was started from; a null pointer is a
//                   no-op.
// @param won        result of the playout as returned by mcts_simulation().
//
// NOTE(review): the same `won` flag is credited to every level, as before;
// whether it should alternate between plies depends on how mcts_selection
// interprets `wins` — confirm.
void bt_t::mcts_back_propagation(bt_node_t *simulated, bool won) {
  // Iterate until we step past the root. The former condition
  // (`simulated->parent != nullptr`) stopped one node early, so the root's
  // nb_simulation stayed at 0 although mcts_selection reads it.
  for (bt_node_t *current = simulated; current != nullptr;
       current = current->parent) {
    if (won) {
      current->wins++;
    }
    current->nb_simulation++;
  }
}
// Picks a move with Monte Carlo tree search, iterating for roughly `max_time`.
//
// @param max_time  search budget, compared against a std::chrono duration
//                  expressed in seconds (duration<double>).
// @return the move of the root child with the highest simulation count, or a
//         random move when the root has no children.
//
// NOTE(review): the tree allocated here is never released; freeing it needs
// knowledge of how mcts_expansion allocates children — confirm ownership.
bt_move_t bt_t::get_mcts_move(double max_time) {
  // Init tree
  bt_node_t *tree = new bt_node_t();
  tree->parent = nullptr;
  tree->children = {};
  tree->wins = 0;
  tree->nb_simulation = 0;
  mcts_expansion(tree);

  // Time constraint
  const auto start = std::chrono::steady_clock::now();
  std::chrono::duration<double> elapsed{};

  // MCTS
  while (elapsed.count() < max_time) {
    // Selection
    bt_node_t *selected = mcts_selection(tree);

    // Expansion
    mcts_expansion(selected);

    // Simulation
    const bool is_win = mcts_simulation(selected);

    // Update
    mcts_back_propagation(selected, is_win);

    // Update elapsed time
    elapsed = std::chrono::steady_clock::now() - start;
  }

  // Select best move: the most simulated child of the root.
  bt_node_t *best_node = nullptr;
  int best_score = -1;
  for (auto child : tree->children) {
    if (child->nb_simulation > best_score) {
      best_node = child;
      // Remember the new maximum; without this every child passed the test
      // and the last one was always returned.
      best_score = child->nb_simulation;
    }
  }

  // No child at all: fall back to a random move instead of dereferencing a
  // null pointer.
  if (best_node == nullptr) {
    return get_rand_move();
  }
  return best_node->move;
}
bool bt_t::can_play(bt_move_t _m) { bool bt_t::can_play(bt_move_t _m) {