* rework of selection

* don't override nbmove to 0
* fix typo in simulation not playing
* fix win verification
* fix backpropagation
* use iterator instead of time for algo limitation (MEMORY LEAK?)
* remove debug prints
This commit is contained in:
Mylloon 2023-03-31 16:39:27 +02:00
parent b5c53228fd
commit 343a6e8cf0
Signed by: Anri
GPG key ID: A82D63DFF8D1317F
2 changed files with 46 additions and 54 deletions

View file

@ -1,7 +1,6 @@
#ifndef MYBT_H #ifndef MYBT_H
#define MYBT_H #define MYBT_H
#include <chrono>
#include <random> #include <random>
#define WHITE 0 #define WHITE 0
@ -103,8 +102,8 @@ struct bt_t {
// MCTS // MCTS
bt_node_t *mcts_selection(bt_node_t *); bt_node_t *mcts_selection(bt_node_t *);
void mcts_expansion(bt_node_t *); void mcts_expansion(bt_node_t *);
bool mcts_simulation(bt_node_t *); bool mcts_simulation();
void mcts_back_propagation(bt_node_t *, bool); static void mcts_back_propagation(bt_node_t *, bool);
void add_move(int _li, int _ci, int _lf, int _cf) { void add_move(int _li, int _ci, int _lf, int _cf) {
moves[nb_moves].line_i = _li; moves[nb_moves].line_i = _li;

View file

@ -193,29 +193,39 @@ bt_move_t bt_t::get_rand_move() {
} }
bt_node_t *bt_t::mcts_selection(bt_node_t *node) { bt_node_t *bt_t::mcts_selection(bt_node_t *node) {
if (node->nb_simulation > 0) { // Play the move is game over or first call
if ((endgame() != EMPTY) ||
((node->nb_simulation == 0) && (node->parent != nullptr))) {
play(node->move);
return node;
}
bt_node_t *best = nullptr; bt_node_t *best = nullptr;
float score = std::numeric_limits<float>::lowest(); float score = std::numeric_limits<float>::lowest();
for (auto i : node->children) { for (auto i : node->children) {
// Play the move if never played
if (i->nb_simulation == 0) {
play(i->move);
return i;
}
// https://en.wikipedia.org/wiki/Monte_Carlo_tree_search#Exploration_and_exploitation // https://en.wikipedia.org/wiki/Monte_Carlo_tree_search#Exploration_and_exploitation
float curr_score = float curr_score =
static_cast<float>(i->wins) / i->nb_simulation + (static_cast<float>(i->wins) / i->nb_simulation) +
sqrt(2) * sqrt(log(node->nb_simulation) / i->nb_simulation); ((sqrt(2) * sqrt(log(node->nb_simulation)) / i->nb_simulation));
if (curr_score > score) { if (curr_score > score) {
best = i; best = i;
score = curr_score; score = curr_score;
} }
} }
// Return the element with the higher score
return best; play(best->move);
} else { return mcts_selection(best);
// Return random element if no simulation already done
return node
->children[(static_cast<int>(rand())) % (node->children.size() - 1)];
}
} }
void bt_t::mcts_expansion(bt_node_t *root) { void bt_t::mcts_expansion(bt_node_t *node) {
update_moves(); update_moves();
for (int i = 0; i < nb_moves; i++) { for (int i = 0; i < nb_moves; i++) {
@ -223,32 +233,29 @@ void bt_t::mcts_expansion(bt_node_t *root) {
bt_node_t *tmp = new bt_node_t; bt_node_t *tmp = new bt_node_t;
tmp->move = moves[i]; tmp->move = moves[i];
tmp->wins = 0; tmp->wins = 0;
tmp->parent = root; tmp->parent = node;
tmp->nb_simulation = 0; tmp->nb_simulation = 0;
// Add child to root // Add child to selected node
root->children.push_back(tmp); node->children.push_back(tmp);
}
} }
// Avoid re-add same moves bool bt_t::mcts_simulation(void) {
nb_moves = 0;
}
bool bt_t::mcts_simulation(bt_node_t *node) {
int me = (turn % 2 == 0) ? WHITE : BLACK; int me = (turn % 2 == 0) ? WHITE : BLACK;
// then play randomly 'til the game is over // then play randomly 'til the game is over
while (endgame() != EMPTY) { while (endgame() == EMPTY) {
play(get_rand_move()); play(get_rand_move());
} }
// if i won // if i won
return ((turn % 2 == 0) ? WHITE : BLACK == me) ? true : false; return ((turn % 2 == 0) ? WHITE : BLACK) == me ? true : false;
} }
void bt_t::mcts_back_propagation(bt_node_t *simulated, bool won) { void bt_t::mcts_back_propagation(bt_node_t *simulated, bool won) {
// propagate values to the top of the tree // propagate values to the top of the tree
while (simulated->parent != nullptr) { while (simulated != nullptr) {
if (won) { if (won) {
simulated->wins++; simulated->wins++;
} }
@ -276,40 +283,26 @@ bt_move_t bt_t::get_mcts_move(double max_time) {
tree->children = {}; tree->children = {};
tree->wins = 0; tree->wins = 0;
tree->nb_simulation = 0; tree->nb_simulation = 0;
print_vec(*this, tree->children);
mcts_expansion(tree);
print_vec(*this, tree->children);
// Time constraint // On populise l'arbre avec les entrées de départ
auto now = std::chrono::steady_clock::now(); mcts_expansion(tree);
std::chrono::duration<double> elapsed{};
// MCTS // MCTS
while (elapsed.count() < max_time) { for (int it = 0; it < 200; ++it) {
// Copy board // Copy board
bt_t copy_b = *this; bt_t copy_b = *this;
// Selection // Selection
bt_node_t *selected = copy_b.mcts_selection(tree); bt_node_t *result = copy_b.mcts_selection(tree);
// Play the move
if (can_play(selected->move)) {
play(selected->move);
} else {
fprintf(stderr, "?????\n");
}
// Expansion // Expansion
copy_b.mcts_expansion(selected); copy_b.mcts_expansion(result);
// Simulation // Simulation
bool is_win = copy_b.mcts_simulation(selected); bool is_win = copy_b.mcts_simulation();
// Update // Update
mcts_back_propagation(selected, is_win); mcts_back_propagation(result, is_win);
// Update elapsed time
elapsed = std::chrono::steady_clock::now() - now;
} }
// Select best move // Select best move