add selection
This commit is contained in:
parent
ffed7158f3
commit
881c609f74
2 changed files with 26 additions and 2 deletions
|
@ -101,7 +101,7 @@ struct bt_t {
|
||||||
long long int mkH3();
|
long long int mkH3();
|
||||||
|
|
||||||
// MCTS
|
// MCTS
|
||||||
void mcts_selection(bt_node_t *);
|
bt_node_t *mcts_selection(bt_node_t *);
|
||||||
void mcts_expansion(bt_node_t *);
|
void mcts_expansion(bt_node_t *);
|
||||||
void mcts_simulation();
|
void mcts_simulation();
|
||||||
void mcts_back_propagation();
|
void mcts_back_propagation();
|
||||||
|
|
|
@ -193,7 +193,27 @@ bt_move_t bt_t::get_rand_move() {
|
||||||
return moves[r];
|
return moves[r];
|
||||||
}
|
}
|
||||||
|
|
||||||
void bt_t::mcts_selection(bt_node_t *node) { (void)node; }
|
bt_node_t *bt_t::mcts_selection(bt_node_t *node) {
|
||||||
|
if (node->nb_simulation) {
|
||||||
|
bt_node_t *the_chosen = NULL;
|
||||||
|
float score = 0.;
|
||||||
|
for (auto i : node->children) {
|
||||||
|
// https://en.wikipedia.org/wiki/Monte_Carlo_tree_search#Exploration_and_exploitation
|
||||||
|
float curr_score =
|
||||||
|
static_cast<float>(i->wins) / i->nb_simulation +
|
||||||
|
sqrt(2) * sqrt(log(node->nb_simulation) / i->nb_simulation);
|
||||||
|
if (curr_score > score) {
|
||||||
|
the_chosen = i;
|
||||||
|
score = curr_score;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Return the element with the higher score
|
||||||
|
return the_chosen;
|
||||||
|
} else {
|
||||||
|
// Return the first element if no simulation already done
|
||||||
|
return node->children[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void bt_t::mcts_expansion(bt_node_t *root) {
|
void bt_t::mcts_expansion(bt_node_t *root) {
|
||||||
update_moves();
|
update_moves();
|
||||||
|
@ -234,15 +254,19 @@ bt_move_t bt_t::get_mcts_move(double max_time) {
|
||||||
// MCTS
|
// MCTS
|
||||||
while (elapsed.count() < max_time) {
|
while (elapsed.count() < max_time) {
|
||||||
// Selection
|
// Selection
|
||||||
|
auto xd = b_copy.mcts_selection(tree);
|
||||||
std::cerr << "selection\n";
|
std::cerr << "selection\n";
|
||||||
|
|
||||||
// Expansion
|
// Expansion
|
||||||
|
// mcts_expansion(selected);
|
||||||
std::cerr << "expansion\n";
|
std::cerr << "expansion\n";
|
||||||
|
|
||||||
// Simulation
|
// Simulation
|
||||||
|
// bool is_win = mcts_simulation(selected);
|
||||||
std::cerr << "simulation\n";
|
std::cerr << "simulation\n";
|
||||||
|
|
||||||
// Update
|
// Update
|
||||||
|
// mcts_back_propagation(selected, is_win);
|
||||||
std::cerr << "update\n";
|
std::cerr << "update\n";
|
||||||
|
|
||||||
// Time constraint
|
// Time constraint
|
||||||
|
|
Reference in a new issue