some debug stuff

This commit is contained in:
Mylloon 2023-03-31 13:04:02 +02:00
parent 5bf573b4f3
commit b5c53228fd
Signed by: Anri
GPG key ID: A82D63DFF8D1317F

View file

@ -193,37 +193,31 @@ bt_move_t bt_t::get_rand_move() {
} }
/**
 * MCTS selection step: choose which child of `node` to explore next.
 *
 * - If `node` has no children, there is nothing to select and nullptr is
 *   returned (callers must check before dereferencing).
 * - If `node` has never been simulated, a uniformly random child is returned.
 * - Otherwise the child maximising the UCT score
 *       wins / n_child + sqrt(2) * sqrt(ln(n_parent) / n_child)
 *   is returned. A child that has never been simulated is returned
 *   immediately: its score would otherwise be a division by zero (NaN), which
 *   never compares greater than anything and would keep the child from ever
 *   being explored.
 *
 * @param node parent node whose children are candidates
 * @return the selected child, or nullptr when `node` has no children
 */
bt_node_t *bt_t::mcts_selection(bt_node_t *node) {
  const auto nb_children = node->children.size();
  if (nb_children == 0) {
    return nullptr;
  }

  if (node->nb_simulation > 0) {
    bt_node_t *best = nullptr;
    float score = std::numeric_limits<float>::lowest();

    // The parent's log term does not depend on the child: compute it once
    const double log_parent = log(node->nb_simulation);

    for (auto i : node->children) {
      // Unvisited child: explore it first instead of dividing by zero
      if (i->nb_simulation <= 0) {
        return i;
      }

      // https://en.wikipedia.org/wiki/Monte_Carlo_tree_search#Exploration_and_exploitation
      const float curr_score = static_cast<float>(
          static_cast<float>(i->wins) / i->nb_simulation +
          sqrt(2) * sqrt(log_parent / i->nb_simulation));

      if (best == nullptr || curr_score > score) {
        best = i;
        score = curr_score;
      }
    }

    // Return the element with the higher score
    return best;
  }

  // Return random element if no simulation already done.
  // Modulo by the full size so that every child (including the last one) can
  // be picked and a single-child node does not trigger a modulo by zero.
  return node->children[static_cast<unsigned>(rand()) % nb_children];
}
void bt_t::mcts_expansion(bt_node_t *root) { void bt_t::mcts_expansion(bt_node_t *root) {
update_moves(); update_moves();
// Play the move
if (can_play(root->move)) {
play(root->move);
} else {
fprintf(stderr, "?????");
}
for (int i = 0; i < nb_moves; i++) { for (int i = 0; i < nb_moves; i++) {
// Child // Child
bt_node_t *tmp = new bt_node_t; bt_node_t *tmp = new bt_node_t;
@ -263,6 +257,18 @@ void bt_t::mcts_back_propagation(bt_node_t *simulated, bool won) {
} }
} }
// Debug helper: dump a list of tree nodes to stderr.
// Prints "[]" when the list is empty; otherwise one line per node with its
// move, its wins/simulations label and the move stored in its parent node.
// Every node in `v` is dereferenced along with its `parent` pointer.
void print_vec(bt_t board, std::vector<bt_node_t *> v) {
  if (!v.empty()) {
    for (bt_node_t *node : v) {
      const auto move_txt = node->move.tostr(board.nbl);
      fprintf(stderr, "move(%s)", move_txt.c_str());

      fprintf(stderr, " label(%d/%d)", node->wins, node->nb_simulation);

      const auto parent_txt = node->parent->move.tostr(board.nbl);
      fprintf(stderr, " parent(%s)\n", parent_txt.c_str());
    }
    return;
  }

  fprintf(stderr, "[]\n");
}
bt_move_t bt_t::get_mcts_move(double max_time) { bt_move_t bt_t::get_mcts_move(double max_time) {
// Init tree // Init tree
bt_node_t *tree = new bt_node_t(); bt_node_t *tree = new bt_node_t();
@ -270,7 +276,9 @@ bt_move_t bt_t::get_mcts_move(double max_time) {
tree->children = {}; tree->children = {};
tree->wins = 0; tree->wins = 0;
tree->nb_simulation = 0; tree->nb_simulation = 0;
print_vec(*this, tree->children);
mcts_expansion(tree); mcts_expansion(tree);
print_vec(*this, tree->children);
// Time constraint // Time constraint
auto now = std::chrono::steady_clock::now(); auto now = std::chrono::steady_clock::now();
@ -284,6 +292,13 @@ bt_move_t bt_t::get_mcts_move(double max_time) {
// Selection // Selection
bt_node_t *selected = copy_b.mcts_selection(tree); bt_node_t *selected = copy_b.mcts_selection(tree);
// Play the move
if (can_play(selected->move)) {
play(selected->move);
} else {
fprintf(stderr, "?????\n");
}
// Expansion // Expansion
copy_b.mcts_expansion(selected); copy_b.mcts_expansion(selected);