* rework of selection

* don't override nbmove to 0
* fix typo in simulation not playing
* fix win verification
* fix backpropagation
* use iterator instead of time for algo limitation (MEMORY LEAK?)
* remove debug prints
This commit is contained in:
Mylloon 2023-03-31 16:39:27 +02:00
parent b5c53228fd
commit 343a6e8cf0
Signed by: Anri
GPG key ID: A82D63DFF8D1317F
2 changed files with 46 additions and 54 deletions

View file

@ -1,7 +1,6 @@
#ifndef MYBT_H
#define MYBT_H
#include <chrono>
#include <random>
#define WHITE 0
@ -103,8 +102,8 @@ struct bt_t {
// MCTS
bt_node_t *mcts_selection(bt_node_t *);
void mcts_expansion(bt_node_t *);
bool mcts_simulation(bt_node_t *);
void mcts_back_propagation(bt_node_t *, bool);
bool mcts_simulation();
static void mcts_back_propagation(bt_node_t *, bool);
void add_move(int _li, int _ci, int _lf, int _cf) {
moves[nb_moves].line_i = _li;

View file

@ -193,29 +193,39 @@ bt_move_t bt_t::get_rand_move() {
}
bt_node_t *bt_t::mcts_selection(bt_node_t *node) {
if (node->nb_simulation > 0) {
bt_node_t *best = nullptr;
float score = std::numeric_limits<float>::lowest();
for (auto i : node->children) {
// https://en.wikipedia.org/wiki/Monte_Carlo_tree_search#Exploration_and_exploitation
float curr_score =
static_cast<float>(i->wins) / i->nb_simulation +
sqrt(2) * sqrt(log(node->nb_simulation) / i->nb_simulation);
if (curr_score > score) {
best = i;
score = curr_score;
}
}
// Return the element with the higher score
return best;
} else {
// Return random element if no simulation already done
return node
->children[(static_cast<int>(rand())) % (node->children.size() - 1)];
// Play the move is game over or first call
if ((endgame() != EMPTY) ||
((node->nb_simulation == 0) && (node->parent != nullptr))) {
play(node->move);
return node;
}
bt_node_t *best = nullptr;
float score = std::numeric_limits<float>::lowest();
for (auto i : node->children) {
// Play the move if never played
if (i->nb_simulation == 0) {
play(i->move);
return i;
}
// https://en.wikipedia.org/wiki/Monte_Carlo_tree_search#Exploration_and_exploitation
float curr_score =
(static_cast<float>(i->wins) / i->nb_simulation) +
((sqrt(2) * sqrt(log(node->nb_simulation)) / i->nb_simulation));
if (curr_score > score) {
best = i;
score = curr_score;
}
}
play(best->move);
return mcts_selection(best);
}
void bt_t::mcts_expansion(bt_node_t *root) {
void bt_t::mcts_expansion(bt_node_t *node) {
update_moves();
for (int i = 0; i < nb_moves; i++) {
@ -223,32 +233,29 @@ void bt_t::mcts_expansion(bt_node_t *root) {
bt_node_t *tmp = new bt_node_t;
tmp->move = moves[i];
tmp->wins = 0;
tmp->parent = root;
tmp->parent = node;
tmp->nb_simulation = 0;
// Add child to root
root->children.push_back(tmp);
// Add child to selected node
node->children.push_back(tmp);
}
// Avoid re-add same moves
nb_moves = 0;
}
bool bt_t::mcts_simulation(bt_node_t *node) {
bool bt_t::mcts_simulation(void) {
int me = (turn % 2 == 0) ? WHITE : BLACK;
// then play randomly 'til the game is over
while (endgame() != EMPTY) {
while (endgame() == EMPTY) {
play(get_rand_move());
}
// if i won
return ((turn % 2 == 0) ? WHITE : BLACK == me) ? true : false;
return ((turn % 2 == 0) ? WHITE : BLACK) == me ? true : false;
}
void bt_t::mcts_back_propagation(bt_node_t *simulated, bool won) {
// propagate values to the top of the tree
while (simulated->parent != nullptr) {
while (simulated != nullptr) {
if (won) {
simulated->wins++;
}
@ -276,40 +283,26 @@ bt_move_t bt_t::get_mcts_move(double max_time) {
tree->children = {};
tree->wins = 0;
tree->nb_simulation = 0;
print_vec(*this, tree->children);
mcts_expansion(tree);
print_vec(*this, tree->children);
// Time constraint
auto now = std::chrono::steady_clock::now();
std::chrono::duration<double> elapsed{};
// On populise l'arbre avec les entrées de départ
mcts_expansion(tree);
// MCTS
while (elapsed.count() < max_time) {
for (int it = 0; it < 200; ++it) {
// Copy board
bt_t copy_b = *this;
// Selection
bt_node_t *selected = copy_b.mcts_selection(tree);
// Play the move
if (can_play(selected->move)) {
play(selected->move);
} else {
fprintf(stderr, "?????\n");
}
bt_node_t *result = copy_b.mcts_selection(tree);
// Expansion
copy_b.mcts_expansion(selected);
copy_b.mcts_expansion(result);
// Simulation
bool is_win = copy_b.mcts_simulation(selected);
bool is_win = copy_b.mcts_simulation();
// Update
mcts_back_propagation(selected, is_win);
// Update elapsed time
elapsed = std::chrono::steady_clock::now() - now;
mcts_back_propagation(result, is_win);
}
// Select best move