WIP: add mcts basics

* expansion: ok
This commit is contained in:
Mylloon 2023-03-29 17:26:33 +02:00
parent d3bfc7d913
commit ffed7158f3
Signed by: Anri
GPG key ID: A82D63DFF8D1317F
2 changed files with 44 additions and 15 deletions

View file

@ -44,6 +44,15 @@ struct bt_move_t {
#define MAX_LINES 10 #define MAX_LINES 10
#define MAX_COLS 10 #define MAX_COLS 10
// MCTS tree node.
//
// Nodes are linked through raw pointers: `parent` points at the node this one
// was attached under, and `children` holds the nodes created below it (see
// bt_t::mcts_expansion, which allocates them with `new`).
// NOTE(review): no matching `delete` is visible in this part of the code, so
// whoever builds the tree is presumably responsible for freeing it - confirm.
//
// Every field has a default initializer so that a freshly created node is in
// a well-defined "unvisited, detached" state even when it is created with a
// plain `new bt_node_t`.
struct bt_node_t {
  bt_node_t *parent = nullptr;       // nullptr for a node without a parent
  std::vector<bt_node_t *> children; // empty until the node is expanded
  bt_move_t move{};                  // move stored for this node
  int wins = 0;                      // win counter, starts at 0
  int nb_simulation = 0;             // simulation counter, starts at 0
};
// rules reminder : // rules reminder :
// pieces moves from 1 square in diag and in front // pieces moves from 1 square in diag and in front
// pieces captures only in diag // pieces captures only in diag
@ -91,6 +100,12 @@ struct bt_t {
std::string mkH2(); std::string mkH2();
long long int mkH3(); long long int mkH3();
// MCTS
void mcts_selection(bt_node_t *);
void mcts_expansion(bt_node_t *);
void mcts_simulation();
void mcts_back_propagation();
// déclarées mais non définies // déclarées mais non définies
double eval(); double eval();
bt_move_t minimax(double _sec); bt_move_t minimax(double _sec);
@ -109,13 +124,4 @@ struct bt_t {
} }
}; };
// MCTS tree node.
//
// Nodes are linked through raw pointers: `parent` points at the node this one
// was attached under, and `children` holds the nodes created below it.
// NOTE(review): no matching `delete` is visible in this part of the code, so
// whoever builds the tree is presumably responsible for freeing it - confirm.
//
// Every field has a default initializer so that a freshly created node is in
// a well-defined "unvisited, detached" state even when it is created with a
// plain `new bt_node_t`.
struct bt_node_t {
  bt_node_t *parent = nullptr;       // nullptr for a node without a parent
  std::vector<bt_node_t *> children; // empty until the node is expanded
  bt_move_t move{};                  // move stored for this node
  int wins = 0;                      // win counter, starts at 0
  int nb_simulation = 0;             // simulation counter, starts at 0
};
#endif /* MYBT_H */ #endif /* MYBT_H */

View file

@ -193,22 +193,45 @@ bt_move_t bt_t::get_rand_move() {
return moves[r]; return moves[r];
} }
bt_move_t bt_t::get_mcts_move(double max_time) { void bt_t::mcts_selection(bt_node_t *node) { (void)node; }
// MCTS expansion step: appends one child to `root` for every entry
// moves[0 .. nb_moves-1] of this board.
//
// update_moves() is called first; presumably it refreshes `moves` and
// `nb_moves` for the current position (its body is not visible here).
//
// Each child is heap-allocated with `new`, starts with zero wins and zero
// simulations, and has its `parent` set to `root`. The raw pointers are stored
// in root->children; nothing in this function frees them.
//
// A null `root` is ignored, and no child is created when there are no moves.
//
// NOTE(review): the moves always come from this board's own state, not from a
// position derived from `root` - confirm this is intended for non-root nodes.
void bt_t::mcts_expansion(bt_node_t *root) {
  if (root == nullptr) {
    return;
  }

  update_moves();
  if (nb_moves <= 0) {
    return;
  }

  // One allocation for the pointer array instead of repeated regrowth.
  root->children.reserve(root->children.size() +
                         static_cast<std::size_t>(nb_moves));

  for (int i = 0; i < nb_moves; ++i) {
    // Value-initialize so no field is ever left indeterminate.
    bt_node_t *child = new bt_node_t{};
    child->parent = root;
    child->move = moves[i];
    child->wins = 0;
    child->nb_simulation = 0;

    root->children.push_back(child);
  }
}
// MCTS simulation step - placeholder: the body is intentionally empty for now.
void bt_t::mcts_simulation() {}
// MCTS back-propagation step - placeholder: the body is intentionally empty
// for now.
void bt_t::mcts_back_propagation() {}
bt_move_t bt_t::get_mcts_move(double max_time) {
// Init tree // Init tree
bt_node_t *tree = new bt_node_t(); bt_node_t *tree = new bt_node_t();
tree->children = {};
tree->parent = NULL; tree->parent = NULL;
tree->nb_simulation = 0; tree->children = {};
tree->wins = 0; tree->wins = 0;
tree->nb_simulation = 0;
mcts_expansion(tree);
// Copy board // Copy board
bt_t b_copy = *this; bt_t b_copy = *this;
// MCTS // Time constraint
auto now = std::chrono::steady_clock::now(); auto now = std::chrono::steady_clock::now();
std::chrono::duration<double> elapsed{}; std::chrono::duration<double> elapsed{};
// MCTS
while (elapsed.count() < max_time) { while (elapsed.count() < max_time) {
// Selection // Selection
std::cerr << "selection\n"; std::cerr << "selection\n";
@ -227,14 +250,14 @@ bt_move_t bt_t::get_mcts_move(double max_time) {
} }
// Select best move // Select best move
bt_node_t *best_node = NULL; /* bt_node_t *best_node = NULL;
int best_score = -1; int best_score = -1;
for (auto i : tree->children) { for (auto i : tree->children) {
if (i->nb_simulation > best_score) { if (i->nb_simulation > best_score) {
best_node = i; best_node = i;
} }
} }
// return best_node->move; return best_node->move; */
return moves[(static_cast<int>(rand())) % nb_moves]; // TMP return moves[(static_cast<int>(rand())) % nb_moves]; // TMP
} }