add selection
This commit is contained in:
parent
ffed7158f3
commit
881c609f74
2 changed files with 26 additions and 2 deletions
|
@ -101,7 +101,7 @@ struct bt_t {
|
|||
long long int mkH3();
|
||||
|
||||
// MCTS
|
||||
void mcts_selection(bt_node_t *);
|
||||
bt_node_t *mcts_selection(bt_node_t *);
|
||||
void mcts_expansion(bt_node_t *);
|
||||
void mcts_simulation();
|
||||
void mcts_back_propagation();
|
||||
|
|
|
@ -193,7 +193,27 @@ bt_move_t bt_t::get_rand_move() {
|
|||
return moves[r];
|
||||
}
|
||||
|
||||
// Selection step placeholder: performs no work. The parameter is referenced
// only so that it does not trigger an unused-parameter warning.
void bt_t::mcts_selection(bt_node_t *node) {
  static_cast<void>(node);
}
|
||||
/**
 * MCTS selection step: choose which child of `node` to descend into.
 *
 * Children are ranked with the UCB1 formula
 *   wins / n + sqrt(2 * ln(N) / n)
 * where n is the child's simulation count and N the parent's, see
 * https://en.wikipedia.org/wiki/Monte_Carlo_tree_search#Exploration_and_exploitation
 *
 * @param node Parent node whose children are the candidates.
 * @return The selected child, or nullptr when `node` has no children.
 */
bt_node_t *bt_t::mcts_selection(bt_node_t *node) {
  bt_node_t *the_chosen = nullptr;
  double best_score = 0.;

  for (bt_node_t *child : node->children) {
    // Without statistics on either side UCB1 is undefined (division by zero
    // yields NaN/inf), so such a child is explored immediately. When the
    // parent itself has no simulation this returns the first child.
    if (!node->nb_simulation || !child->nb_simulation) {
      return child;
    }

    const double visits = static_cast<double>(child->nb_simulation);
    const double exploitation = static_cast<double>(child->wins) / visits;
    const double exploration =
        sqrt(2. * log(static_cast<double>(node->nb_simulation)) / visits);
    const double curr_score = exploitation + exploration;

    // Always accept the first candidate so that a best score of zero (or
    // below) can never leave the result null while children exist.
    if (the_chosen == nullptr || curr_score > best_score) {
      the_chosen = child;
      best_score = curr_score;
    }
  }

  // Child with the highest score; nullptr only if there are no children.
  return the_chosen;
}
|
||||
|
||||
void bt_t::mcts_expansion(bt_node_t *root) {
|
||||
update_moves();
|
||||
|
@ -234,15 +254,19 @@ bt_move_t bt_t::get_mcts_move(double max_time) {
|
|||
// MCTS
|
||||
while (elapsed.count() < max_time) {
|
||||
// Selection
|
||||
auto xd = b_copy.mcts_selection(tree);
|
||||
std::cerr << "selection\n";
|
||||
|
||||
// Expansion
|
||||
// mcts_expansion(selected);
|
||||
std::cerr << "expansion\n";
|
||||
|
||||
// Simulation
|
||||
// bool is_win = mcts_simulation(selected);
|
||||
std::cerr << "simulation\n";
|
||||
|
||||
// Update
|
||||
// mcts_back_propagation(selected, is_win);
|
||||
std::cerr << "update\n";
|
||||
|
||||
// Time constraint
|
||||
|
|
Reference in a new issue