diff --git a/TP2/includes/mybt.h b/TP2/includes/mybt.h index f0c7590..ad82b2c 100644 --- a/TP2/includes/mybt.h +++ b/TP2/includes/mybt.h @@ -101,7 +101,7 @@ struct bt_t { long long int mkH3(); // MCTS - void mcts_selection(bt_node_t *); + bt_node_t *mcts_selection(bt_node_t *); void mcts_expansion(bt_node_t *); void mcts_simulation(); void mcts_back_propagation(); diff --git a/TP2/src/mybt.cpp b/TP2/src/mybt.cpp index 5fc00ca..5290415 100644 --- a/TP2/src/mybt.cpp +++ b/TP2/src/mybt.cpp @@ -193,7 +193,27 @@ bt_move_t bt_t::get_rand_move() { return moves[r]; } -void bt_t::mcts_selection(bt_node_t *node) { (void)node; } +bt_node_t *bt_t::mcts_selection(bt_node_t *node) { + if (node->nb_simulation) { + bt_node_t *the_chosen = NULL; + float score = 0.; + for (auto i : node->children) { + // https://en.wikipedia.org/wiki/Monte_Carlo_tree_search#Exploration_and_exploitation + float curr_score = + static_cast(i->wins) / i->nb_simulation + + sqrt(2) * sqrt(log(node->nb_simulation) / i->nb_simulation); + if (curr_score > score) { + the_chosen = i; + score = curr_score; + } + } + // Return the element with the higher score + return the_chosen; + } else { + // Return the first element if no simulation already done + return node->children[0]; + } +} void bt_t::mcts_expansion(bt_node_t *root) { update_moves(); @@ -234,15 +254,19 @@ bt_move_t bt_t::get_mcts_move(double max_time) { // MCTS while (elapsed.count() < max_time) { // Selection + auto xd = b_copy.mcts_selection(tree); std::cerr << "selection\n"; // Expansion + // mcts_expansion(selected); std::cerr << "expansion\n"; // Simulation + // bool is_win = mcts_simulation(selected); std::cerr << "simulation\n"; // Update + // mcts_back_propagation(selected, is_win); std::cerr << "update\n"; // Time constraint