add selection

This commit is contained in:
Mylloon 2023-03-29 17:56:34 +02:00
parent ffed7158f3
commit 881c609f74
Signed by: Anri
GPG key ID: A82D63DFF8D1317F
2 changed files with 26 additions and 2 deletions

View file

@ -101,7 +101,7 @@ struct bt_t {
long long int mkH3(); long long int mkH3();
// MCTS // MCTS
void mcts_selection(bt_node_t *); bt_node_t *mcts_selection(bt_node_t *);
void mcts_expansion(bt_node_t *); void mcts_expansion(bt_node_t *);
void mcts_simulation(); void mcts_simulation();
void mcts_back_propagation(); void mcts_back_propagation();

View file

@ -193,7 +193,27 @@ bt_move_t bt_t::get_rand_move() {
return moves[r]; return moves[r];
} }
/**
 * MCTS selection step: choose which child of `node` should be explored next.
 *
 * Children are ranked with the UCB1 formula
 *   wins / visits + sqrt(2 * ln(parent_visits) / visits)
 * https://en.wikipedia.org/wiki/Monte_Carlo_tree_search#Exploration_and_exploitation
 * (presumably `wins` counts won playouts and `nb_simulation` counts playouts
 * run through a node -- confirm against bt_node_t).
 *
 * @param node Parent node whose children are candidates; must not be null.
 * @return The selected child, or nullptr when `node` has no children.
 */
bt_node_t *bt_t::mcts_selection(bt_node_t *node) {
  // A node without children offers nothing to select; indexing children[0]
  // here would be undefined behaviour.
  if (node->children.empty()) {
    return nullptr;
  }

  // No simulation went through this node yet, so no statistics can rank the
  // children: return the first one.
  if (!node->nb_simulation) {
    return node->children[0];
  }

  // Loop-invariant part of the exploration term.
  const double log_parent = log(static_cast<double>(node->nb_simulation));

  bt_node_t *the_chosen = nullptr;
  double best_score = 0.;
  for (auto *child : node->children) {
    // An unvisited child would give 0/0 and a division by zero below (NaN),
    // which can never win a comparison. Its UCB1 value is conceptually
    // infinite, so it is selected immediately.
    if (!child->nb_simulation) {
      return child;
    }

    const double visits = static_cast<double>(child->nb_simulation);
    const double curr_score = static_cast<double>(child->wins) / visits +
                              sqrt(2. * log_parent / visits);

    // Always accept the first candidate so a best score of exactly 0 (e.g.
    // ln(1) == 0 and no wins) still yields a child instead of a null pointer.
    if (the_chosen == nullptr || curr_score > best_score) {
      the_chosen = child;
      best_score = curr_score;
    }
  }

  // Child with the highest UCB1 score.
  return the_chosen;
}
void bt_t::mcts_expansion(bt_node_t *root) { void bt_t::mcts_expansion(bt_node_t *root) {
update_moves(); update_moves();
@ -234,15 +254,19 @@ bt_move_t bt_t::get_mcts_move(double max_time) {
// MCTS // MCTS
while (elapsed.count() < max_time) { while (elapsed.count() < max_time) {
// Selection // Selection
auto xd = b_copy.mcts_selection(tree);
std::cerr << "selection\n"; std::cerr << "selection\n";
// Expansion // Expansion
// mcts_expansion(selected);
std::cerr << "expansion\n"; std::cerr << "expansion\n";
// Simulation // Simulation
// bool is_win = mcts_simulation(selected);
std::cerr << "simulation\n"; std::cerr << "simulation\n";
// Update // Update
// mcts_back_propagation(selected, is_win);
std::cerr << "update\n"; std::cerr << "update\n";
// Time constraint // Time constraint