WIP: add mcts basics
* expansion: ok
This commit is contained in:
parent
d3bfc7d913
commit
ffed7158f3
2 changed files with 44 additions and 15 deletions
|
@ -44,6 +44,15 @@ struct bt_move_t {
|
||||||
#define MAX_LINES 10
|
#define MAX_LINES 10
|
||||||
#define MAX_COLS 10
|
#define MAX_COLS 10
|
||||||
|
|
||||||
|
// Node MCTS
|
||||||
|
struct bt_node_t {
|
||||||
|
bt_node_t *parent;
|
||||||
|
std::vector<bt_node_t *> children;
|
||||||
|
bt_move_t move;
|
||||||
|
int wins;
|
||||||
|
int nb_simulation;
|
||||||
|
};
|
||||||
|
|
||||||
// rules reminder:
|
// rules reminder:
|
||||||
// pieces move one square, diagonally or straight ahead
|
// pieces move one square, diagonally or straight ahead
|
||||||
// pieces capture only diagonally
|
// pieces capture only diagonally
|
||||||
|
@ -91,6 +100,12 @@ struct bt_t {
|
||||||
std::string mkH2();
|
std::string mkH2();
|
||||||
long long int mkH3();
|
long long int mkH3();
|
||||||
|
|
||||||
|
// MCTS
|
||||||
|
void mcts_selection(bt_node_t *);
|
||||||
|
void mcts_expansion(bt_node_t *);
|
||||||
|
void mcts_simulation();
|
||||||
|
void mcts_back_propagation();
|
||||||
|
|
||||||
// declared but not defined
|
// declared but not defined
|
||||||
double eval();
|
double eval();
|
||||||
bt_move_t minimax(double _sec);
|
bt_move_t minimax(double _sec);
|
||||||
|
@ -109,13 +124,4 @@ struct bt_t {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Node MCTS
|
|
||||||
struct bt_node_t {
|
|
||||||
bt_node_t *parent;
|
|
||||||
std::vector<bt_node_t *> children;
|
|
||||||
bt_move_t move;
|
|
||||||
int wins;
|
|
||||||
int nb_simulation;
|
|
||||||
};
|
|
||||||
|
|
||||||
#endif /* MYBT_H */
|
#endif /* MYBT_H */
|
||||||
|
|
|
@ -193,22 +193,45 @@ bt_move_t bt_t::get_rand_move() {
|
||||||
return moves[r];
|
return moves[r];
|
||||||
}
|
}
|
||||||
|
|
||||||
bt_move_t bt_t::get_mcts_move(double max_time) {
|
// MCTS selection step. Placeholder: the node argument is accepted but
// deliberately ignored, and nothing else happens.
void bt_t::mcts_selection(bt_node_t *node) {
  static_cast<void>(node);  // keeps the unused-parameter warning quiet
}
|
||||||
|
|
||||||
|
// MCTS expansion step: attaches one new child node to `root` for each of the
// first `nb_moves` entries of `moves`.
//
// update_moves() is called first, unconditionally (presumably it refreshes
// `moves` / `nb_moves` for the current position -- confirm against its
// definition). A null `root` is then tolerated: no children are created.
//
// Every child is allocated with `new` and stored as a raw pointer in
// root->children. Nothing here releases them, so the code that tears the tree
// down is responsible for deleting them.
void bt_t::mcts_expansion(bt_node_t *root) {
  update_moves();

  // Nowhere to attach children: bail out instead of dereferencing null.
  if (root == nullptr) {
    return;
  }

  // One allocation for the child list instead of repeated regrowth.
  if (nb_moves > 0) {
    root->children.reserve(root->children.size() +
                           static_cast<std::size_t>(nb_moves));
  }

  for (int i = 0; i < nb_moves; i++) {
    // Value-initialize so no member is ever left indeterminate.
    bt_node_t *child = new bt_node_t();
    child->parent = root;
    child->move = moves[i];
    child->wins = 0;
    child->nb_simulation = 0;

    root->children.push_back(child);
  }
}
|
||||||
|
|
||||||
|
// MCTS simulation step. Placeholder with an empty body.
void bt_t::mcts_simulation() {
  // Intentionally empty for now.
}
|
||||||
|
|
||||||
|
// MCTS back-propagation step. Placeholder with an empty body.
void bt_t::mcts_back_propagation() {
  // Intentionally empty for now.
}
|
||||||
|
|
||||||
|
bt_move_t bt_t::get_mcts_move(double max_time) {
|
||||||
// Init tree
|
// Init tree
|
||||||
bt_node_t *tree = new bt_node_t();
|
bt_node_t *tree = new bt_node_t();
|
||||||
tree->children = {};
|
|
||||||
tree->parent = NULL;
|
tree->parent = NULL;
|
||||||
tree->nb_simulation = 0;
|
tree->children = {};
|
||||||
tree->wins = 0;
|
tree->wins = 0;
|
||||||
|
tree->nb_simulation = 0;
|
||||||
|
mcts_expansion(tree);
|
||||||
|
|
||||||
// Copy board
|
// Copy board
|
||||||
bt_t b_copy = *this;
|
bt_t b_copy = *this;
|
||||||
|
|
||||||
// MCTS
|
// Time constraint
|
||||||
auto now = std::chrono::steady_clock::now();
|
auto now = std::chrono::steady_clock::now();
|
||||||
std::chrono::duration<double> elapsed{};
|
std::chrono::duration<double> elapsed{};
|
||||||
|
|
||||||
|
// MCTS
|
||||||
while (elapsed.count() < max_time) {
|
while (elapsed.count() < max_time) {
|
||||||
// Selection
|
// Selection
|
||||||
std::cerr << "selection\n";
|
std::cerr << "selection\n";
|
||||||
|
@ -227,14 +250,14 @@ bt_move_t bt_t::get_mcts_move(double max_time) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Select best move
|
// Select best move
|
||||||
bt_node_t *best_node = NULL;
|
/* bt_node_t *best_node = NULL;
|
||||||
int best_score = -1;
|
int best_score = -1;
|
||||||
for (auto i : tree->children) {
|
for (auto i : tree->children) {
|
||||||
if (i->nb_simulation > best_score) {
|
if (i->nb_simulation > best_score) {
|
||||||
best_node = i;
|
best_node = i;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// return best_node->move;
|
return best_node->move; */
|
||||||
return moves[(static_cast<int>(rand())) % nb_moves]; // TMP
|
return moves[(static_cast<int>(rand())) % nb_moves]; // TMP
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Reference in a new issue