* rework of selection
* don't override nbmove to 0
* fix typo in simulation not playing
* fix win verification
* fix backpropagation
* use an iteration count instead of elapsed time to limit the algorithm (MEMORY LEAK?)
* remove debug prints
This commit is contained in:
parent
b5c53228fd
commit
343a6e8cf0
2 changed files with 46 additions and 54 deletions
|
@ -1,7 +1,6 @@
|
|||
#ifndef MYBT_H
|
||||
#define MYBT_H
|
||||
|
||||
#include <chrono>
|
||||
#include <random>
|
||||
|
||||
#define WHITE 0
|
||||
|
@ -103,8 +102,8 @@ struct bt_t {
|
|||
// MCTS
|
||||
bt_node_t *mcts_selection(bt_node_t *);
|
||||
void mcts_expansion(bt_node_t *);
|
||||
bool mcts_simulation(bt_node_t *);
|
||||
void mcts_back_propagation(bt_node_t *, bool);
|
||||
bool mcts_simulation();
|
||||
static void mcts_back_propagation(bt_node_t *, bool);
|
||||
|
||||
void add_move(int _li, int _ci, int _lf, int _cf) {
|
||||
moves[nb_moves].line_i = _li;
|
||||
|
|
|
@ -193,29 +193,39 @@ bt_move_t bt_t::get_rand_move() {
|
|||
}
|
||||
|
||||
bt_node_t *bt_t::mcts_selection(bt_node_t *node) {
|
||||
if (node->nb_simulation > 0) {
|
||||
bt_node_t *best = nullptr;
|
||||
float score = std::numeric_limits<float>::lowest();
|
||||
for (auto i : node->children) {
|
||||
// https://en.wikipedia.org/wiki/Monte_Carlo_tree_search#Exploration_and_exploitation
|
||||
float curr_score =
|
||||
static_cast<float>(i->wins) / i->nb_simulation +
|
||||
sqrt(2) * sqrt(log(node->nb_simulation) / i->nb_simulation);
|
||||
if (curr_score > score) {
|
||||
best = i;
|
||||
score = curr_score;
|
||||
}
|
||||
}
|
||||
// Return the element with the higher score
|
||||
return best;
|
||||
} else {
|
||||
// Return a random element if no simulation has been done yet
|
||||
return node
|
||||
->children[(static_cast<int>(rand())) % (node->children.size() - 1)];
|
||||
// Play the move if the game is over or on the first call
|
||||
if ((endgame() != EMPTY) ||
|
||||
((node->nb_simulation == 0) && (node->parent != nullptr))) {
|
||||
play(node->move);
|
||||
return node;
|
||||
}
|
||||
|
||||
bt_node_t *best = nullptr;
|
||||
float score = std::numeric_limits<float>::lowest();
|
||||
for (auto i : node->children) {
|
||||
// Play the move if never played
|
||||
if (i->nb_simulation == 0) {
|
||||
play(i->move);
|
||||
|
||||
return i;
|
||||
}
|
||||
|
||||
// https://en.wikipedia.org/wiki/Monte_Carlo_tree_search#Exploration_and_exploitation
|
||||
float curr_score =
|
||||
(static_cast<float>(i->wins) / i->nb_simulation) +
|
||||
((sqrt(2) * sqrt(log(node->nb_simulation)) / i->nb_simulation));
|
||||
|
||||
if (curr_score > score) {
|
||||
best = i;
|
||||
score = curr_score;
|
||||
}
|
||||
}
|
||||
|
||||
play(best->move);
|
||||
return mcts_selection(best);
|
||||
}
|
||||
|
||||
void bt_t::mcts_expansion(bt_node_t *root) {
|
||||
void bt_t::mcts_expansion(bt_node_t *node) {
|
||||
update_moves();
|
||||
|
||||
for (int i = 0; i < nb_moves; i++) {
|
||||
|
@ -223,32 +233,29 @@ void bt_t::mcts_expansion(bt_node_t *root) {
|
|||
bt_node_t *tmp = new bt_node_t;
|
||||
tmp->move = moves[i];
|
||||
tmp->wins = 0;
|
||||
tmp->parent = root;
|
||||
tmp->parent = node;
|
||||
tmp->nb_simulation = 0;
|
||||
|
||||
// Add child to root
|
||||
root->children.push_back(tmp);
|
||||
// Add child to selected node
|
||||
node->children.push_back(tmp);
|
||||
}
|
||||
|
||||
// Avoid re-add same moves
|
||||
nb_moves = 0;
|
||||
}
|
||||
|
||||
bool bt_t::mcts_simulation(bt_node_t *node) {
|
||||
bool bt_t::mcts_simulation(void) {
|
||||
int me = (turn % 2 == 0) ? WHITE : BLACK;
|
||||
|
||||
// then play randomly 'til the game is over
|
||||
while (endgame() != EMPTY) {
|
||||
while (endgame() == EMPTY) {
|
||||
play(get_rand_move());
|
||||
}
|
||||
|
||||
// if i won
|
||||
return ((turn % 2 == 0) ? WHITE : BLACK == me) ? true : false;
|
||||
return ((turn % 2 == 0) ? WHITE : BLACK) == me ? true : false;
|
||||
}
|
||||
|
||||
void bt_t::mcts_back_propagation(bt_node_t *simulated, bool won) {
|
||||
// propagate values to the top of the tree
|
||||
while (simulated->parent != nullptr) {
|
||||
while (simulated != nullptr) {
|
||||
if (won) {
|
||||
simulated->wins++;
|
||||
}
|
||||
|
@ -276,40 +283,26 @@ bt_move_t bt_t::get_mcts_move(double max_time) {
|
|||
tree->children = {};
|
||||
tree->wins = 0;
|
||||
tree->nb_simulation = 0;
|
||||
print_vec(*this, tree->children);
|
||||
mcts_expansion(tree);
|
||||
print_vec(*this, tree->children);
|
||||
|
||||
// Time constraint
|
||||
auto now = std::chrono::steady_clock::now();
|
||||
std::chrono::duration<double> elapsed{};
|
||||
// Populate the tree with the initial entries
|
||||
mcts_expansion(tree);
|
||||
|
||||
// MCTS
|
||||
while (elapsed.count() < max_time) {
|
||||
for (int it = 0; it < 200; ++it) {
|
||||
// Copy board
|
||||
bt_t copy_b = *this;
|
||||
|
||||
// Selection
|
||||
bt_node_t *selected = copy_b.mcts_selection(tree);
|
||||
|
||||
// Play the move
|
||||
if (can_play(selected->move)) {
|
||||
play(selected->move);
|
||||
} else {
|
||||
fprintf(stderr, "?????\n");
|
||||
}
|
||||
bt_node_t *result = copy_b.mcts_selection(tree);
|
||||
|
||||
// Expansion
|
||||
copy_b.mcts_expansion(selected);
|
||||
copy_b.mcts_expansion(result);
|
||||
|
||||
// Simulation
|
||||
bool is_win = copy_b.mcts_simulation(selected);
|
||||
bool is_win = copy_b.mcts_simulation();
|
||||
|
||||
// Update
|
||||
mcts_back_propagation(selected, is_win);
|
||||
|
||||
// Update elapsed time
|
||||
elapsed = std::chrono::steady_clock::now() - now;
|
||||
mcts_back_propagation(result, is_win);
|
||||
}
|
||||
|
||||
// Select best move
|
||||
|
|
Reference in a new issue