some debug stuff
This commit is contained in:
parent
5bf573b4f3
commit
b5c53228fd
1 changed files with 29 additions and 14 deletions
|
@ -193,37 +193,31 @@ bt_move_t bt_t::get_rand_move() {
|
|||
}
|
||||
|
||||
/**
 * Selection step of the Monte Carlo tree search: picks which child of `node`
 * should be explored next.
 *
 * - `node` has no children: returns nullptr.
 * - `node` has no recorded simulation: returns a child picked with rand().
 * - Otherwise: returns the first child that has never been simulated, if
 *   there is one; else the child with the highest UCT score
 *   wins / n_i + sqrt(2) * sqrt(ln(N) / n_i).
 *
 * @param node Parent whose children are the candidates; must not be null.
 * @return The selected child, or nullptr when `node` has no children.
 */
bt_node_t *bt_t::mcts_selection(bt_node_t *node) {
  const auto child_count = node->children.size();
  if (child_count == 0) {
    // Nothing to select; also keeps the modulo below well defined.
    return nullptr;
  }

  if (node->nb_simulation <= 0) {
    // No statistics yet: pick any child. The modulus is the full child count
    // so that the last child is reachable and a single child does not cause
    // a division by zero.
    return node->children[rand() % child_count];
  }

  // ln(N) only depends on the parent, so compute it once.
  const double log_parent = log(node->nb_simulation);

  bt_node_t *best = nullptr;
  float score = std::numeric_limits<float>::lowest();
  for (bt_node_t *child : node->children) {
    if (child->nb_simulation <= 0) {
      // An unvisited child would make the formula divide by zero (NaN score,
      // which never compares greater); explore it before anything else.
      return child;
    }
    // https://en.wikipedia.org/wiki/Monte_Carlo_tree_search#Exploration_and_exploitation
    const float curr_score =
        static_cast<float>(child->wins) / child->nb_simulation +
        sqrt(2) * sqrt(log_parent / child->nb_simulation);
    if (curr_score > score) {
      best = child;
      score = curr_score;
    }
  }
  // Child with the highest score
  return best;
}
|
||||
|
||||
void bt_t::mcts_expansion(bt_node_t *root) {
|
||||
update_moves();
|
||||
|
||||
// Play the move
|
||||
if (can_play(root->move)) {
|
||||
play(root->move);
|
||||
} else {
|
||||
fprintf(stderr, "?????");
|
||||
}
|
||||
|
||||
for (int i = 0; i < nb_moves; i++) {
|
||||
// Child
|
||||
bt_node_t *tmp = new bt_node_t;
|
||||
|
@ -263,6 +257,18 @@ void bt_t::mcts_back_propagation(bt_node_t *simulated, bool won) {
|
|||
}
|
||||
}
|
||||
|
||||
// Debug helper: writes the content of `v` to stderr.
// An empty vector is printed as "[]". Otherwise each node gets one line of
// the form "move(<move>) label(<wins>/<nb_simulation>) parent(<parent move>)",
// where moves are rendered with tostr(board.nbl).
// Each node's parent is dereferenced, so every node in `v` must have one.
void print_vec(bt_t board, std::vector<bt_node_t *> v) {
  if (!v.empty()) {
    for (bt_node_t *entry : v) {
      // The node's own move
      const auto own_move = entry->move.tostr(board.nbl);
      fprintf(stderr, "move(%s)", own_move.c_str());

      // Win count over simulation count
      fprintf(stderr, " label(%d/%d)", entry->wins, entry->nb_simulation);

      // Move stored in the parent node; ends the line
      const auto parent_move = entry->parent->move.tostr(board.nbl);
      fprintf(stderr, " parent(%s)\n", parent_move.c_str());
    }
    return;
  }

  fprintf(stderr, "[]\n");
}
|
||||
|
||||
bt_move_t bt_t::get_mcts_move(double max_time) {
|
||||
// Init tree
|
||||
bt_node_t *tree = new bt_node_t();
|
||||
|
@ -270,7 +276,9 @@ bt_move_t bt_t::get_mcts_move(double max_time) {
|
|||
tree->children = {};
|
||||
tree->wins = 0;
|
||||
tree->nb_simulation = 0;
|
||||
print_vec(*this, tree->children);
|
||||
mcts_expansion(tree);
|
||||
print_vec(*this, tree->children);
|
||||
|
||||
// Time constraint
|
||||
auto now = std::chrono::steady_clock::now();
|
||||
|
@ -284,6 +292,13 @@ bt_move_t bt_t::get_mcts_move(double max_time) {
|
|||
// Selection
|
||||
bt_node_t *selected = copy_b.mcts_selection(tree);
|
||||
|
||||
// Play the move
|
||||
if (can_play(selected->move)) {
|
||||
play(selected->move);
|
||||
} else {
|
||||
fprintf(stderr, "?????\n");
|
||||
}
|
||||
|
||||
// Expansion
|
||||
copy_b.mcts_expansion(selected);
|
||||
|
||||
|
|
Reference in a new issue