From 343a6e8cf000b57d1e571da6d58689a13f6abc8f Mon Sep 17 00:00:00 2001 From: Mylloon Date: Fri, 31 Mar 2023 16:39:27 +0200 Subject: [PATCH] * rework of selection * don't override nbmove to 0 * fix typo in simulation not playing * fix win verification * fix backpropagation * use iterator instead of time for algo limitation (MEMORY LEAK?) * remove debug prints --- TP2/includes/mybt.h | 5 +-- TP2/src/mybt.cpp | 95 +++++++++++++++++++++------------------------ 2 files changed, 46 insertions(+), 54 deletions(-) diff --git a/TP2/includes/mybt.h b/TP2/includes/mybt.h index aec3449..7202ff0 100644 --- a/TP2/includes/mybt.h +++ b/TP2/includes/mybt.h @@ -1,7 +1,6 @@ #ifndef MYBT_H #define MYBT_H -#include #include #define WHITE 0 @@ -103,8 +102,8 @@ struct bt_t { // MCTS bt_node_t *mcts_selection(bt_node_t *); void mcts_expansion(bt_node_t *); - bool mcts_simulation(bt_node_t *); - void mcts_back_propagation(bt_node_t *, bool); + bool mcts_simulation(); + static void mcts_back_propagation(bt_node_t *, bool); void add_move(int _li, int _ci, int _lf, int _cf) { moves[nb_moves].line_i = _li; diff --git a/TP2/src/mybt.cpp b/TP2/src/mybt.cpp index 032c653..848a1e2 100644 --- a/TP2/src/mybt.cpp +++ b/TP2/src/mybt.cpp @@ -193,29 +193,39 @@ bt_move_t bt_t::get_rand_move() { } bt_node_t *bt_t::mcts_selection(bt_node_t *node) { - if (node->nb_simulation > 0) { - bt_node_t *best = nullptr; - float score = std::numeric_limits::lowest(); - for (auto i : node->children) { - // https://en.wikipedia.org/wiki/Monte_Carlo_tree_search#Exploration_and_exploitation - float curr_score = - static_cast(i->wins) / i->nb_simulation + - sqrt(2) * sqrt(log(node->nb_simulation) / i->nb_simulation); - if (curr_score > score) { - best = i; - score = curr_score; - } - } - // Return the element with the higher score - return best; - } else { - // Return random element if no simulation already done - return node - ->children[(static_cast(rand())) % (node->children.size() - 1)]; + // Play the move is game over or first call + if ((endgame() != EMPTY) || + ((node->nb_simulation == 0) && (node->parent != nullptr))) { + play(node->move); + return node; } + + bt_node_t *best = nullptr; + float score = std::numeric_limits::lowest(); + for (auto i : node->children) { + // Play the move if never played + if (i->nb_simulation == 0) { + play(i->move); + + return i; + } + + // https://en.wikipedia.org/wiki/Monte_Carlo_tree_search#Exploration_and_exploitation + float curr_score = + (static_cast(i->wins) / i->nb_simulation) + + ((sqrt(2) * sqrt(log(node->nb_simulation)) / i->nb_simulation)); + + if (curr_score > score) { + best = i; + score = curr_score; + } + } + + play(best->move); + return mcts_selection(best); } -void bt_t::mcts_expansion(bt_node_t *root) { +void bt_t::mcts_expansion(bt_node_t *node) { update_moves(); for (int i = 0; i < nb_moves; i++) { @@ -223,32 +233,29 @@ void bt_t::mcts_expansion(bt_node_t *root) { bt_node_t *tmp = new bt_node_t; tmp->move = moves[i]; tmp->wins = 0; - tmp->parent = root; + tmp->parent = node; tmp->nb_simulation = 0; - // Add child to root - root->children.push_back(tmp); + // Add child to selected node + node->children.push_back(tmp); } - - // Avoid re-add same moves - nb_moves = 0; } -bool bt_t::mcts_simulation(bt_node_t *node) { +bool bt_t::mcts_simulation(void) { int me = (turn % 2 == 0) ? WHITE : BLACK; // then play randomly 'til the game is over - while (endgame() != EMPTY) { + while (endgame() == EMPTY) { play(get_rand_move()); } // if i won - return ((turn % 2 == 0) ? WHITE : BLACK == me) ? true : false; + return ((turn % 2 == 0) ? WHITE : BLACK) == me ? true : false; } void bt_t::mcts_back_propagation(bt_node_t *simulated, bool won) { // propagate values to the top of the tree - while (simulated->parent != nullptr) { + while (simulated != nullptr) { if (won) { simulated->wins++; } @@ -276,40 +283,26 @@ bt_move_t bt_t::get_mcts_move(double max_time) { tree->children = {}; tree->wins = 0; tree->nb_simulation = 0; - print_vec(*this, tree->children); - mcts_expansion(tree); - print_vec(*this, tree->children); - // Time constraint - auto now = std::chrono::steady_clock::now(); - std::chrono::duration elapsed{}; + // On populise l'arbre avec les entrées de départ + mcts_expansion(tree); // MCTS - while (elapsed.count() < max_time) { + for (int it = 0; it < 200; ++it) { // Copy board bt_t copy_b = *this; // Selection - bt_node_t *selected = copy_b.mcts_selection(tree); - - // Play the move - if (can_play(selected->move)) { - play(selected->move); - } else { - fprintf(stderr, "?????\n"); - } + bt_node_t *result = copy_b.mcts_selection(tree); // Expansion - copy_b.mcts_expansion(selected); + copy_b.mcts_expansion(result); // Simulation - bool is_win = copy_b.mcts_simulation(selected); + bool is_win = copy_b.mcts_simulation(); // Update - mcts_back_propagation(selected, is_win); - - // Update elapsed time - elapsed = std::chrono::steady_clock::now() - now; + mcts_back_propagation(result, is_win); } // Select best move