* rework of selection
* don't override nbmove to 0 * fix typo in simulation not playing * fix win verification * fix backpropagation * use iterator instead of time for algo limitation (MEMORY LEAK?) * remove debug prints
This commit is contained in:
parent
b5c53228fd
commit
343a6e8cf0
2 changed files with 46 additions and 54 deletions
|
@ -1,7 +1,6 @@
|
||||||
#ifndef MYBT_H
|
#ifndef MYBT_H
|
||||||
#define MYBT_H
|
#define MYBT_H
|
||||||
|
|
||||||
#include <chrono>
|
|
||||||
#include <random>
|
#include <random>
|
||||||
|
|
||||||
#define WHITE 0
|
#define WHITE 0
|
||||||
|
@ -103,8 +102,8 @@ struct bt_t {
|
||||||
// MCTS
|
// MCTS
|
||||||
bt_node_t *mcts_selection(bt_node_t *);
|
bt_node_t *mcts_selection(bt_node_t *);
|
||||||
void mcts_expansion(bt_node_t *);
|
void mcts_expansion(bt_node_t *);
|
||||||
bool mcts_simulation(bt_node_t *);
|
bool mcts_simulation();
|
||||||
void mcts_back_propagation(bt_node_t *, bool);
|
static void mcts_back_propagation(bt_node_t *, bool);
|
||||||
|
|
||||||
void add_move(int _li, int _ci, int _lf, int _cf) {
|
void add_move(int _li, int _ci, int _lf, int _cf) {
|
||||||
moves[nb_moves].line_i = _li;
|
moves[nb_moves].line_i = _li;
|
||||||
|
|
|
@ -193,29 +193,39 @@ bt_move_t bt_t::get_rand_move() {
|
||||||
}
|
}
|
||||||
|
|
||||||
bt_node_t *bt_t::mcts_selection(bt_node_t *node) {
|
bt_node_t *bt_t::mcts_selection(bt_node_t *node) {
|
||||||
if (node->nb_simulation > 0) {
|
// Play the move is game over or first call
|
||||||
|
if ((endgame() != EMPTY) ||
|
||||||
|
((node->nb_simulation == 0) && (node->parent != nullptr))) {
|
||||||
|
play(node->move);
|
||||||
|
return node;
|
||||||
|
}
|
||||||
|
|
||||||
bt_node_t *best = nullptr;
|
bt_node_t *best = nullptr;
|
||||||
float score = std::numeric_limits<float>::lowest();
|
float score = std::numeric_limits<float>::lowest();
|
||||||
for (auto i : node->children) {
|
for (auto i : node->children) {
|
||||||
|
// Play the move if never played
|
||||||
|
if (i->nb_simulation == 0) {
|
||||||
|
play(i->move);
|
||||||
|
|
||||||
|
return i;
|
||||||
|
}
|
||||||
|
|
||||||
// https://en.wikipedia.org/wiki/Monte_Carlo_tree_search#Exploration_and_exploitation
|
// https://en.wikipedia.org/wiki/Monte_Carlo_tree_search#Exploration_and_exploitation
|
||||||
float curr_score =
|
float curr_score =
|
||||||
static_cast<float>(i->wins) / i->nb_simulation +
|
(static_cast<float>(i->wins) / i->nb_simulation) +
|
||||||
sqrt(2) * sqrt(log(node->nb_simulation) / i->nb_simulation);
|
((sqrt(2) * sqrt(log(node->nb_simulation)) / i->nb_simulation));
|
||||||
|
|
||||||
if (curr_score > score) {
|
if (curr_score > score) {
|
||||||
best = i;
|
best = i;
|
||||||
score = curr_score;
|
score = curr_score;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Return the element with the higher score
|
|
||||||
return best;
|
play(best->move);
|
||||||
} else {
|
return mcts_selection(best);
|
||||||
// Return random element if no simulation already done
|
|
||||||
return node
|
|
||||||
->children[(static_cast<int>(rand())) % (node->children.size() - 1)];
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void bt_t::mcts_expansion(bt_node_t *root) {
|
void bt_t::mcts_expansion(bt_node_t *node) {
|
||||||
update_moves();
|
update_moves();
|
||||||
|
|
||||||
for (int i = 0; i < nb_moves; i++) {
|
for (int i = 0; i < nb_moves; i++) {
|
||||||
|
@ -223,32 +233,29 @@ void bt_t::mcts_expansion(bt_node_t *root) {
|
||||||
bt_node_t *tmp = new bt_node_t;
|
bt_node_t *tmp = new bt_node_t;
|
||||||
tmp->move = moves[i];
|
tmp->move = moves[i];
|
||||||
tmp->wins = 0;
|
tmp->wins = 0;
|
||||||
tmp->parent = root;
|
tmp->parent = node;
|
||||||
tmp->nb_simulation = 0;
|
tmp->nb_simulation = 0;
|
||||||
|
|
||||||
// Add child to root
|
// Add child to selected node
|
||||||
root->children.push_back(tmp);
|
node->children.push_back(tmp);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Avoid re-add same moves
|
bool bt_t::mcts_simulation(void) {
|
||||||
nb_moves = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool bt_t::mcts_simulation(bt_node_t *node) {
|
|
||||||
int me = (turn % 2 == 0) ? WHITE : BLACK;
|
int me = (turn % 2 == 0) ? WHITE : BLACK;
|
||||||
|
|
||||||
// then play randomly 'til the game is over
|
// then play randomly 'til the game is over
|
||||||
while (endgame() != EMPTY) {
|
while (endgame() == EMPTY) {
|
||||||
play(get_rand_move());
|
play(get_rand_move());
|
||||||
}
|
}
|
||||||
|
|
||||||
// if i won
|
// if i won
|
||||||
return ((turn % 2 == 0) ? WHITE : BLACK == me) ? true : false;
|
return ((turn % 2 == 0) ? WHITE : BLACK) == me ? true : false;
|
||||||
}
|
}
|
||||||
|
|
||||||
void bt_t::mcts_back_propagation(bt_node_t *simulated, bool won) {
|
void bt_t::mcts_back_propagation(bt_node_t *simulated, bool won) {
|
||||||
// propagate values to the top of the tree
|
// propagate values to the top of the tree
|
||||||
while (simulated->parent != nullptr) {
|
while (simulated != nullptr) {
|
||||||
if (won) {
|
if (won) {
|
||||||
simulated->wins++;
|
simulated->wins++;
|
||||||
}
|
}
|
||||||
|
@ -276,40 +283,26 @@ bt_move_t bt_t::get_mcts_move(double max_time) {
|
||||||
tree->children = {};
|
tree->children = {};
|
||||||
tree->wins = 0;
|
tree->wins = 0;
|
||||||
tree->nb_simulation = 0;
|
tree->nb_simulation = 0;
|
||||||
print_vec(*this, tree->children);
|
|
||||||
mcts_expansion(tree);
|
|
||||||
print_vec(*this, tree->children);
|
|
||||||
|
|
||||||
// Time constraint
|
// On populise l'arbre avec les entrées de départ
|
||||||
auto now = std::chrono::steady_clock::now();
|
mcts_expansion(tree);
|
||||||
std::chrono::duration<double> elapsed{};
|
|
||||||
|
|
||||||
// MCTS
|
// MCTS
|
||||||
while (elapsed.count() < max_time) {
|
for (int it = 0; it < 200; ++it) {
|
||||||
// Copy board
|
// Copy board
|
||||||
bt_t copy_b = *this;
|
bt_t copy_b = *this;
|
||||||
|
|
||||||
// Selection
|
// Selection
|
||||||
bt_node_t *selected = copy_b.mcts_selection(tree);
|
bt_node_t *result = copy_b.mcts_selection(tree);
|
||||||
|
|
||||||
// Play the move
|
|
||||||
if (can_play(selected->move)) {
|
|
||||||
play(selected->move);
|
|
||||||
} else {
|
|
||||||
fprintf(stderr, "?????\n");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Expansion
|
// Expansion
|
||||||
copy_b.mcts_expansion(selected);
|
copy_b.mcts_expansion(result);
|
||||||
|
|
||||||
// Simulation
|
// Simulation
|
||||||
bool is_win = copy_b.mcts_simulation(selected);
|
bool is_win = copy_b.mcts_simulation();
|
||||||
|
|
||||||
// Update
|
// Update
|
||||||
mcts_back_propagation(selected, is_win);
|
mcts_back_propagation(result, is_win);
|
||||||
|
|
||||||
// Update elapsed time
|
|
||||||
elapsed = std::chrono::steady_clock::now() - now;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Select best move
|
// Select best move
|
||||||
|
|
Reference in a new issue