Skip to content

Commit

Permalink
changed float into uint32_t for visits counter
Browse files Browse the repository at this point in the history
fixes #39
reformulated connecting new node to search tree
  • Loading branch information
QueensGambit committed May 27, 2020
1 parent 724b79b commit a6e3628
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 100 deletions.
2 changes: 1 addition & 1 deletion engine/src/evalinfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,6 @@ void update_eval_info(EvalInfo& evalInfo, Node* rootNode, size_t tbHits, size_t
evalInfo.bestMoveQ = rootNode->get_q_value(bestMoveIdx);
evalInfo.centipawns = value_to_centipawn(evalInfo.bestMoveQ);
}
evalInfo.nodes = size_t(get_node_count(rootNode));
evalInfo.nodes = get_node_count(rootNode);
evalInfo.tbHits = tbHits;
}
45 changes: 19 additions & 26 deletions engine/src/node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ bool Node::at_least_one_drawn_child() const
bool Node::only_won_child_nodes() const
{
for (Node* childNode : d->childNodes) {
if (!childNode->is_playout_node() || childNode->d->nodeType != SOLVED_WIN) {
if (childNode->d->nodeType != SOLVED_WIN) {
return false;
}
}
Expand Down Expand Up @@ -247,16 +247,6 @@ void Node::mcts_policy_based_on_q_n(DynamicVector<float>& mctsPolicy, float qVal

void Node::solve_for_terminal(const Node* childNode)
{
if (childNode == nullptr || !childNode->is_playout_node()) {
info_string("nullptr as child node backup!");
return;
}
if (d == nullptr) {
init_node_data();
}
if (parentNode->d == nullptr) {
parentNode->init_node_data();
}
if (d->nodeType != UNSOLVED) {
// already solved
return;
Expand Down Expand Up @@ -360,7 +350,7 @@ void Node::apply_virtual_loss_to_child(size_t childIdx, float virtualLoss)
// temporarily reduce the attraction of this node by applying a virtual loss /
// the effect of virtual loss will be undone if the playout is over
// virtual increase the number of visits
d->childNumberVisits[childIdx] += virtualLoss;
d->childNumberVisits[childIdx] += size_t(virtualLoss);
// make it look like if one has lost X games from this node forward where X is the virtual loss value
d->actionValues[childIdx] -= virtualLoss;
d->qValues[childIdx] = d->actionValues[childIdx] / d->childNumberVisits[childIdx];
Expand All @@ -371,7 +361,7 @@ Node *Node::get_parent_node() const
return parentNode;
}

void Node::increment_visits(float numberVisits)
void Node::increment_visits(size_t numberVisits)
{
parentNode->lock();
parentNode->d->childNumberVisits[childIdxForParent] += numberVisits;
Expand Down Expand Up @@ -401,15 +391,13 @@ void Node::reserve_full_memory()

void Node::increment_no_visit_idx()
{
lock();
if (d->noVisitIdx < get_number_child_nodes()) {
++d->noVisitIdx;
if (d->noVisitIdx == PRESERVED_ITEMS) {
reserve_full_memory();
}
d->add_empty_node();
}
unlock();
}

void Node::fully_expand_node()
Expand Down Expand Up @@ -446,7 +434,7 @@ void Node::prepare_node_for_visits()
init_node_data();
}

float Node::get_visits() const
uint32_t Node::get_visits() const
{
return parentNode->d->childNumberVisits[childIdxForParent];
}
Expand All @@ -469,8 +457,13 @@ void Node::backup_value(size_t childIdx, float value, float virtualLoss)
void Node::revert_virtual_loss_and_update(size_t childIdx, float value, float virtualLoss)
{
lock();
d->childNumberVisits[childIdx] -= virtualLoss - 1;
d->actionValues[childIdx] += virtualLoss + value;
if (virtualLoss != 1) {
d->childNumberVisits[childIdx] -= size_t(virtualLoss) - 1;
d->actionValues[childIdx] += virtualLoss + value;
}
else {
d->actionValues[childIdx] += 1 + value;
}
d->qValues[childIdx] = d->actionValues[childIdx] / d->childNumberVisits[childIdx];
if (is_terminal_value(value)) {
++d->terminalVisits;
Expand Down Expand Up @@ -607,7 +600,7 @@ int Node::get_checkmate_idx() const
return d->checkmateIdx;
}

DynamicVector<float> Node::get_child_number_visits() const
DynamicVector<uint32_t> Node::get_child_number_visits() const
{
return d->childNumberVisits;
}
Expand Down Expand Up @@ -637,7 +630,7 @@ uint16_t Node::get_end_in_ply() const
return d->endInPly;
}

float Node::get_terminal_visits() const
uint32_t Node::get_terminal_visits() const
{
return d->terminalVisits;
}
Expand Down Expand Up @@ -842,7 +835,7 @@ bool is_capture(const Board* pos, Move move)

DynamicVector<float> Node::get_current_u_values(const SearchSettings* searchSettings)
{
return get_current_cput(get_visits(), searchSettings) * blaze::subvector(policyProbSmall, 0, d->noVisitIdx) * (sqrt(get_visits()) / (d->childNumberVisits + 1.f));
return get_current_cput(get_visits(), searchSettings) * blaze::subvector(policyProbSmall, 0, d->noVisitIdx) * (sqrt(get_visits()) / (d->childNumberVisits + 1.0));
}

Node *Node::get_child_node(size_t childIdx)
Expand Down Expand Up @@ -966,10 +959,10 @@ ostream& operator<<(ostream &os, const Node *node)
<< "-----+-------+-------------+-----------+------------+-------------" << endl;

for (size_t childIdx = 0; childIdx < node->get_number_child_nodes(); ++childIdx) {
int n = 0;
size_t n = 0;
float q = Q_INIT;
if (childIdx < node->d->noVisitIdx) {
n = int(node->d->childNumberVisits[childIdx]);
n = node->d->childNumberVisits[childIdx];
q = max(node->d->qValues[childIdx], -1.0f);
}

Expand All @@ -993,8 +986,8 @@ ostream& operator<<(ostream &os, const Node *node)
<< "isTerminal:\t" << node->is_terminal() << endl
<< "isTablebase:\t" << node->is_tablebase() << endl
<< "unsolvedNodes:\t" << node->d->numberUnsolvedChildNodes << endl
<< "Visits:\t\t" << int(node->get_visits()) << endl
<< "terminalVisits:\t" << int(node->get_terminal_visits()) << endl;
<< "Visits:\t\t" << node->get_visits() << endl
<< "terminalVisits:\t" << node->get_terminal_visits() << endl;
return os;
}

Expand Down Expand Up @@ -1072,7 +1065,7 @@ bool is_terminal_value(float value)
return (value == WIN || value == DRAW || value == LOSS);
}

float get_node_count(const Node *node)
size_t get_node_count(const Node *node)
{
return node->get_visits() - node->get_terminal_visits();
}
10 changes: 5 additions & 5 deletions engine/src/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ class Node

void revert_virtual_loss_and_update(float value);
Node* get_parent_node() const;
void increment_visits(float numberVisits);
void increment_visits(size_t numberVisits);
void subtract_visits(size_t numberVisits);
void increment_no_visit_idx();
void fully_expand_node();
Expand All @@ -192,7 +192,7 @@ class Node
*/
void make_to_root();

float get_visits() const;
uint32_t get_visits() const;

void lock();
void unlock();
Expand Down Expand Up @@ -312,13 +312,13 @@ class Node
* @return ostream
*/
friend std::ostream& operator<<(std::ostream& os, const Node* node);
DynamicVector<float> get_child_number_visits() const;
DynamicVector<uint32_t> get_child_number_visits() const;
void enable_has_nn_results();
int plies_from_null() const;
bool is_tablebase() const;
uint8_t get_node_type() const;
uint16_t get_end_in_ply() const;
float get_terminal_visits() const;
uint32_t get_terminal_visits() const;

void init_node_data(size_t numberNodes);
void init_node_data();
Expand Down Expand Up @@ -562,6 +562,6 @@ bool is_terminal_value(float value);
* @param node Given node
* @return Number of subnodes for thhe given node
*/
float get_node_count(const Node* node);
size_t get_node_count(const Node* node);

#endif // NODE_H
2 changes: 1 addition & 1 deletion engine/src/nodedata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

void NodeData::add_empty_node()
{
append(childNumberVisits, 0.0f);
append(childNumberVisits, 0U);
append(actionValues, 0.0f);
append(qValues, Q_INIT);
childNodes.emplace_back(nullptr);
Expand Down
4 changes: 2 additions & 2 deletions engine/src/nodedata.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,12 @@ class Node;
*/
struct NodeData
{
DynamicVector<float> childNumberVisits;
DynamicVector<uint32_t> childNumberVisits;
DynamicVector<float> actionValues;
DynamicVector<float> qValues;
vector<Node*> childNodes;

float terminalVisits;
uint32_t terminalVisits;

uint16_t checkmateIdx;
uint16_t endInPly;
Expand Down
79 changes: 38 additions & 41 deletions engine/src/searchthread.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,35 +97,23 @@ void SearchThread::set_is_running(bool value)
isRunning = value;
}

void SearchThread::add_new_node_to_tree(Board* newPos, Node* parentNode, size_t childIdx, bool inCheck)
NodeBackup SearchThread::add_new_node_to_tree(Board* newPos, Node* parentNode, size_t childIdx, bool inCheck)
{
mapWithMutex->mtx.lock();
unordered_map<Key, Node*>::const_iterator it = mapWithMutex->hashTable.find(newPos->hash_key());
if(searchSettings->useTranspositionTable && it != mapWithMutex->hashTable.end() &&
is_transposition_verified(it, newPos->get_state_info())) {
mapWithMutex->mtx.unlock();
parentNode->increment_no_visit_idx();
Node *newNode = new Node(*it->second);
parentNode->add_transposition_child_node(newNode, childIdx);
transpositionNodes->add_element(newNode);
}
else {
mapWithMutex->mtx.unlock();
parentNode->increment_no_visit_idx();
assert(parentNode != nullptr);
Node *newNode = new Node(newPos, inCheck, parentNode, childIdx, searchSettings);
// fill a new board in the input_planes vector
// we shift the index by NB_VALUES_TOTAL each time
board_to_planes(newPos, newPos->number_repetitions(), true, inputPlanes+newNodes->size()*NB_VALUES_TOTAL);

// connect the Node to the parent
parentNode->add_new_child_node(newNode, childIdx);

// save a reference newly created list in the temporary list for node creation
// it will later be updated with the evaluation of the NN
newNodes->add_element(newNode);
newNodeSideToMove->add_element(newPos->side_to_move());
return NODE_TRANSPOSITION;
}
mapWithMutex->mtx.unlock();
assert(parentNode != nullptr);
Node *newNode = new Node(newPos, inCheck, parentNode, childIdx, searchSettings);
// connect the Node to the parent
parentNode->add_new_child_node(newNode, childIdx);
return NODE_NEW_NODE;
}

void SearchThread::stop()
Expand Down Expand Up @@ -157,12 +145,16 @@ void random_root_playout(NodeDescription& description, Node* currentNode, size_t
return;
}
}
childIdx = min(currentNode->get_no_visit_idx(), currentNode->get_number_child_nodes()-1);
currentNode->increment_no_visit_idx();
else {
childIdx = min(currentNode->get_no_visit_idx(), currentNode->get_number_child_nodes()-1);
currentNode->lock();
currentNode->increment_no_visit_idx();
currentNode->unlock();
}
}
}

Node* get_new_child_to_evaluate(Board* pos, Node* rootNode, size_t& childIdx, NodeDescription& description, bool& inCheck, StateListPtr& states, const SearchSettings* searchSettings)
Node* SearchThread::get_new_child_to_evaluate(Board* pos, size_t& childIdx, NodeDescription& description) //, bool& inCheck)
{
rootNode->increment_visits(searchSettings->virtualLoss);
description.depth = 0;
Expand All @@ -183,24 +175,22 @@ Node* get_new_child_to_evaluate(Board* pos, Node* rootNode, size_t& childIdx, No
Node* nextNode = currentNode->get_child_node(childIdx);
description.depth++;
if (nextNode == nullptr) {
description.isCollision = false;
description.isTerminal = false;
currentNode->unlock();
inCheck = pos->gives_check(currentNode->get_move(childIdx));
const bool inCheck = pos->gives_check(currentNode->get_move(childIdx));
// this new StateInfo will be freed from memory when 'pos' is freed
pos->do_move(currentNode->get_move(childIdx), *(new StateInfo));
description.type = add_new_node_to_tree(pos, currentNode, childIdx, inCheck);
currentNode->increment_no_visit_idx();
currentNode->unlock();
return currentNode;
}
if (nextNode->is_terminal()) {
description.isCollision = false;
description.isTerminal = true;
description.type = NODE_TERMINAL;
currentNode->unlock();
pos->do_move(currentNode->get_move(childIdx), *(new StateInfo));
return currentNode;
}
if (!nextNode->has_nn_results()) {
description.isCollision = true;
description.isTerminal = false;
description.type = NODE_COLLISION;
currentNode->unlock();
pos->do_move(currentNode->get_move(childIdx), *(new StateInfo));
return currentNode;
Expand Down Expand Up @@ -295,23 +285,30 @@ void SearchThread::create_mini_batch()
numTerminalNodes < TERMINAL_NODE_CACHE) {

Board newPos = Board(*rootPos);
bool inCheck;
parentNode = get_new_child_to_evaluate(&newPos, rootNode, childIdx, description, inCheck, states, searchSettings);
parentNode = get_new_child_to_evaluate(&newPos, childIdx, description);
Node* newNode = parentNode->get_child_node(childIdx);
depthSum += description.depth;
depthMax = max(depthMax, description.depth);

get_avg_depth();

if(description.isTerminal) {
if(description.type == NODE_TERMINAL) {
++numTerminalNodes;
parentNode->backup_value(childIdx, -parentNode->get_child_node(childIdx)->get_value(), searchSettings->virtualLoss);
parentNode->backup_value(childIdx, -newNode->get_value(), searchSettings->virtualLoss);
}
else if (description.isCollision) {
else if (description.type == NODE_COLLISION) {
// store a pointer to the collision node in order to revert the virtual loss of the forward propagation
collisionNodes->add_element(parentNode->get_child_node(childIdx));
collisionNodes->add_element(newNode);
}
else {
add_new_node_to_tree(&newPos, parentNode, childIdx, inCheck);
else if (description.type == NODE_TRANSPOSITION) {
transpositionNodes->add_element(newNode);
}
else { // NODE_NEW_NODE
// fill a new board in the input_planes vector
// we shift the index by NB_VALUES_TOTAL each time
board_to_planes(&newPos, newPos.number_repetitions(), true, inputPlanes+newNodes->size()*NB_VALUES_TOTAL);
// save a reference newly created list in the temporary list for node creation
// it will later be updated with the evaluation of the NN
newNodes->add_element(newNode);
newNodeSideToMove->add_element(newPos.side_to_move());
}
}
}
Expand Down
Loading

0 comments on commit a6e3628

Please sign in to comment.