From 394bb035e2f785fa7524b44ee3996bb9e159e2ab Mon Sep 17 00:00:00 2001 From: Zyxxx_xxxyZ Date: Wed, 1 Jul 2026 10:16:10 +0800 Subject: [PATCH] Add persistent MCTS export/import support --- cpp/command/runtests.cpp | 23 +- cpp/main.cpp | 6 + cpp/main.h | 2 + cpp/neuralnet/nneval.cpp | 23 +- cpp/neuralnet/nneval.h | 1 + cpp/search/search.cpp | 758 +++++++++++++++++++++++++++- cpp/search/search.h | 44 ++ cpp/search/searchexplorehelpers.cpp | 51 +- cpp/search/searchnode.cpp | 121 ++++- cpp/search/searchnode.h | 14 + cpp/search/searchupdatehelpers.cpp | 289 +++++++++++ cpp/tests/tests.h | 2 + cpp/tests/testsearchnonn.cpp | 571 +++++++++++++++++++++ docs/persistent-mcts.md | 85 ++++ 14 files changed, 1961 insertions(+), 29 deletions(-) create mode 100644 docs/persistent-mcts.md diff --git a/cpp/command/runtests.cpp b/cpp/command/runtests.cpp index 663add137c..241b0c0f8e 100644 --- a/cpp/command/runtests.cpp +++ b/cpp/command/runtests.cpp @@ -92,6 +92,28 @@ int MainCmds::runoutputtests(const vector& args) { return 0; } +int MainCmds::runpersistentmctstests(const vector& args) { + (void)args; + Board::initHash(); + ScoreValue::initTables(); + + Tests::runPersistentMCTSTests(); + + ScoreValue::freeTables(); + return 0; +} + +int MainCmds::runpersistentmctsstricttests(const vector& args) { + (void)args; + Board::initHash(); + ScoreValue::initTables(); + + Tests::runPersistentMCTSStrictTests(); + + ScoreValue::freeTables(); + return 0; +} + int MainCmds::runsearchtests(const vector& args) { Board::initHash(); ScoreValue::initTables(); @@ -773,4 +795,3 @@ int MainCmds::runconfigtests(const vector& args) { Tests::runParseAllConfigsTest(); return 0; } - diff --git a/cpp/main.cpp b/cpp/main.cpp index 8e4cba8f31..f2dce1ceeb 100644 --- a/cpp/main.cpp +++ b/cpp/main.cpp @@ -57,6 +57,8 @@ runnnsymmetriestest : Run neural net on a hardcoded rectangle board and dump sym runownershiptests : Run neural net search on some hardcoded positions and print avg ownership runoutputtests : Run a bunch of things and dump details to stdout +runpersistentmctstests : Run persistent MCTS correctness tests +runpersistentmctsstricttests : Run strict persistent MCTS SGF alignment tests runsearchtests : Run a bunch of things using a neural net and dump details to stdout runsearchtestsv3 : Run a bunch more things using a neural net and dump details to stdout runsearchtestsv8 : Run a bunch more things using a neural net and dump details to stdout @@ -103,6 +105,10 @@ static int handleSubcommand(const string& subcommand, const vector& args return MainCmds::runownershiptests(subArgs); else if(subcommand == "runoutputtests") return MainCmds::runoutputtests(subArgs); + else if(subcommand == "runpersistentmctstests") + return MainCmds::runpersistentmctstests(subArgs); + else if(subcommand == "runpersistentmctsstricttests") + return MainCmds::runpersistentmctsstricttests(subArgs); else if(subcommand == "runsearchtests") return MainCmds::runsearchtests(subArgs); else if(subcommand == "runsearchtestsv3") diff --git a/cpp/main.h b/cpp/main.h index 15581c65e3..383d65f88c 100644 --- a/cpp/main.h +++ b/cpp/main.h @@ -19,6 +19,8 @@ namespace MainCmds { int runnnontinyboardtest(const std::vector& args); int runnnsymmetriestest(const std::vector& args); int runoutputtests(const std::vector& args); + int runpersistentmctstests(const std::vector& args); + int runpersistentmctsstricttests(const std::vector& args); int runsearchtests(const std::vector& args); int runsearchtestsv3(const std::vector& args); int runsearchtestsv8(const std::vector& args); diff --git a/cpp/neuralnet/nneval.cpp b/cpp/neuralnet/nneval.cpp index ca90e0598a..7b3a8ccceb 100644 --- a/cpp/neuralnet/nneval.cpp +++ b/cpp/neuralnet/nneval.cpp @@ -13,6 +13,7 @@ NNResultBuf::NNResultBuf() includeOwnerMap(false), boardXSizeForServer(0), boardYSizeForServer(0), + nnHashForServer(), rowSpatialBuf(), rowGlobalBuf(), rowMetaBuf(), @@ -613,6 +614,11 @@ void NNEvaluator::serve( testAssert(resultBuf->hasResult == false); resultBuf->result = std::make_shared(); + string deterministicSeedBase = randSeed + ":DebugSkip:" + resultBuf->nnHashForServer.toString(); + Rand policyRand(deterministicSeedBase + ":Policy"); + Rand ownerRand(deterministicSeedBase + ":Owner"); + Rand valueRand(deterministicSeedBase + ":Value"); + float* policyProbs = resultBuf->result->policyProbs; for(int i = 0; iresult->nnXLen = nnXLen; resultBuf->result->nnYLen = nnYLen; @@ -637,7 +643,7 @@ void NNEvaluator::serve( for(int y = 0; yresult->whiteOwnerMap = whiteOwnerMap; @@ -647,11 +653,11 @@ void NNEvaluator::serve( } // These aren't really probabilities. Win/Loss/NoResult will get softmaxed later - double whiteWinProb = 0.0 + rand.nextGaussian() * 0.20; - double whiteLossProb = 0.0 + rand.nextGaussian() * 0.20; - double whiteScoreMean = 0.0 + rand.nextGaussian() * 0.20; - double whiteScoreMeanSq = 0.0 + rand.nextGaussian() * 0.20; - double whiteNoResultProb = 0.0 + rand.nextGaussian() * 0.20; + double whiteWinProb = 0.0 + valueRand.nextGaussian() * 0.20; + double whiteLossProb = 0.0 + valueRand.nextGaussian() * 0.20; + double whiteScoreMean = 0.0 + valueRand.nextGaussian() * 0.20; + double whiteScoreMeanSq = 0.0 + valueRand.nextGaussian() * 0.20; + double whiteNoResultProb = 0.0 + valueRand.nextGaussian() * 0.20; double varTimeLeft = 0.5 * boardXSize * boardYSize; resultBuf->result->whiteWinProb = (float)whiteWinProb; resultBuf->result->whiteLossProb = (float)whiteLossProb; @@ -901,6 +907,7 @@ void NNEvaluator::evaluate( buf.boardXSizeForServer = board.x_size; buf.boardYSizeForServer = board.y_size; + buf.nnHashForServer = nnHash; if(!debugSkipNeuralNet) { fillRowBufs(board, history, nextPlayer, sgfMeta, nnInputParams, buf); diff --git a/cpp/neuralnet/nneval.h b/cpp/neuralnet/nneval.h index 0bbf167c95..19f32e64a8 100644 --- a/cpp/neuralnet/nneval.h +++ b/cpp/neuralnet/nneval.h @@ -51,6 +51,7 @@ struct NNResultBuf { bool includeOwnerMap; int boardXSizeForServer; int boardYSizeForServer; + Hash128 nnHashForServer; std::vector rowSpatialBuf; std::vector rowGlobalBuf; std::vector rowMetaBuf; diff --git a/cpp/search/search.cpp b/cpp/search/search.cpp index c7a6e4fa75..52fd8b47e3 100644 --- a/cpp/search/search.cpp +++ b/cpp/search/search.cpp @@ -6,7 +6,9 @@ #include "../search/search.h" #include +#include #include +#include #include "../core/fancymath.h" #include "../core/test.h" @@ -89,6 +91,14 @@ Search::Search(const SearchParams& params, NNEvaluator* nnEval, NNEvaluator* hum lastSearchNumPlayouts(0), effectiveSearchTimeCarriedOver(0.0), randSeed(rSeed), + persistentMCTSEnabled(false), + persistentCurrentRootKey(), + persistentCurrentRootAncestorKeys(), + persistentRootNodes(), + persistentPendingVisitsByRoot(), + persistentPendingVisitsMutex(), + persistentPropagateDescendantCredits(true), + persistentConsumingPendingVisits(false), rootKoHashTable(NULL), valueWeightDistribution(NULL), patternBonusTable(NULL), @@ -295,18 +305,454 @@ void Search::setNNEval(NNEvaluator* nnEval) { void Search::clearSearch() { effectiveSearchTimeCarriedOver = 0.0; + if(rootNode == NULL && persistentRootNodes.size() > 0) { + for(auto& entry: persistentRootNodes) + delete entry.second; + persistentRootNodes.clear(); + } if(rootNode != NULL) { deleteAllTableNodesMulithreaded(); - //Root is not stored in node table - if(rootNode != NULL) { - delete rootNode; + if(persistentRootNodes.size() > 0) { + for(auto& entry: persistentRootNodes) + delete entry.second; + persistentRootNodes.clear(); rootNode = NULL; } + else { + //Root is not stored in node table + if(rootNode != NULL) { + delete rootNode; + rootNode = NULL; + } + } } + persistentCurrentRootKey = Hash128(); + persistentCurrentRootAncestorKeys.clear(); + persistentPendingVisitsByRoot.clear(); + persistentPropagateDescendantCredits = true; + persistentConsumingPendingVisits = false; clearOldNNOutputs(); searchNodeAge = 0; } +void Search::setPersistentMCTSEnabled(bool enabled) { + if(persistentMCTSEnabled == enabled) + return; + clearSearch(); + persistentMCTSEnabled = enabled; +} + +bool Search::getPersistentMCTSEnabled() const { + return persistentMCTSEnabled; +} + +void Search::setPositionForMCTSPersistence(Player pla, const Board& board, const BoardHistory& history) { + if(!persistentMCTSEnabled) { + setPosition(pla,board,history); + return; + } + + rootPla = pla; + plaThatSearchIsFor = C_EMPTY; + rootBoard = board; + rootHistory = history; + rootKoHashTable->recompute(rootHistory); + avoidMoveUntilByLocBlack.clear(); + avoidMoveUntilByLocWhite.clear(); + + persistentCurrentRootKey = getPersistentRootKey(rootHistory,rootPla); + persistentCurrentRootAncestorKeys = getPersistentRootAncestorKeys(rootHistory,rootPla); + + const bool forceNonTerminal = rootHistory.isGameFinished; + rootNode = getOrCreatePersistentRootNode(persistentCurrentRootKey, forceNonTerminal); + rootGraphHash = persistentCurrentRootKey; + materializePersistentMCTS(false); + consumePersistentPendingVisits(); +} + +static nlohmann::json nodeStatsToJson(const NodeStats& stats) { + nlohmann::json data; + data["visits"] = stats.visits; + data["winLossValueAvg"] = stats.winLossValueAvg; + data["noResultValueAvg"] = stats.noResultValueAvg; + data["scoreMeanAvg"] = stats.scoreMeanAvg; + data["scoreMeanSqAvg"] = stats.scoreMeanSqAvg; + data["leadAvg"] = stats.leadAvg; + data["utilityAvg"] = stats.utilityAvg; + data["utilitySqAvg"] = stats.utilitySqAvg; + data["weightSum"] = stats.weightSum; + data["weightSqSum"] = stats.weightSqSum; + return data; +} + +static NodeStats nodeStatsOfJson(const nlohmann::json& data) { + NodeStats stats; + stats.visits = data["visits"].get(); + stats.winLossValueAvg = data["winLossValueAvg"].get(); + stats.noResultValueAvg = data["noResultValueAvg"].get(); + stats.scoreMeanAvg = data["scoreMeanAvg"].get(); + stats.scoreMeanSqAvg = data["scoreMeanSqAvg"].get(); + stats.leadAvg = data["leadAvg"].get(); + stats.utilityAvg = data["utilityAvg"].get(); + stats.utilitySqAvg = data["utilitySqAvg"].get(); + stats.weightSum = data["weightSum"].get(); + stats.weightSqSum = data["weightSqSum"].get(); + return stats; +} + +static nlohmann::json nnOutputToJson(const NNOutput* nnOutput) { + nlohmann::json data; + if(nnOutput == NULL) + return data; + data["nnHash"] = nnOutput->nnHash.toString(); + data["whiteWinProb"] = nnOutput->whiteWinProb; + data["whiteLossProb"] = nnOutput->whiteLossProb; + data["whiteNoResultProb"] = nnOutput->whiteNoResultProb; + data["whiteScoreMean"] = nnOutput->whiteScoreMean; + data["whiteScoreMeanSq"] = nnOutput->whiteScoreMeanSq; + data["whiteLead"] = nnOutput->whiteLead; + data["varTimeLeft"] = nnOutput->varTimeLeft; + data["shorttermWinlossError"] = nnOutput->shorttermWinlossError; + data["shorttermScoreError"] = nnOutput->shorttermScoreError; + data["policyOptimismUsed"] = nnOutput->policyOptimismUsed; + data["nnXLen"] = nnOutput->nnXLen; + data["nnYLen"] = nnOutput->nnYLen; + data["policyProbs"] = vector(nnOutput->policyProbs, nnOutput->policyProbs + NNPos::MAX_NN_POLICY_SIZE); + if(nnOutput->whiteOwnerMap != NULL) + data["whiteOwnerMap"] = vector(nnOutput->whiteOwnerMap, nnOutput->whiteOwnerMap + nnOutput->nnXLen * nnOutput->nnYLen); + if(nnOutput->noisedPolicyProbs != NULL) + data["noisedPolicyProbs"] = vector(nnOutput->noisedPolicyProbs, nnOutput->noisedPolicyProbs + NNPos::MAX_NN_POLICY_SIZE); + return data; +} + +static shared_ptr nnOutputOfJson(const nlohmann::json& data) { + if(data.is_null() || data.empty()) + return nullptr; + shared_ptr nnOutput = make_shared(); + nnOutput->nnHash = Hash128::ofString(data["nnHash"].get()); + nnOutput->whiteWinProb = data["whiteWinProb"].get(); + nnOutput->whiteLossProb = data["whiteLossProb"].get(); + nnOutput->whiteNoResultProb = data["whiteNoResultProb"].get(); + nnOutput->whiteScoreMean = data["whiteScoreMean"].get(); + nnOutput->whiteScoreMeanSq = data["whiteScoreMeanSq"].get(); + nnOutput->whiteLead = data["whiteLead"].get(); + nnOutput->varTimeLeft = data["varTimeLeft"].get(); + nnOutput->shorttermWinlossError = data["shorttermWinlossError"].get(); + nnOutput->shorttermScoreError = data["shorttermScoreError"].get(); + nnOutput->policyOptimismUsed = data["policyOptimismUsed"].get(); + nnOutput->nnXLen = data["nnXLen"].get(); + nnOutput->nnYLen = data["nnYLen"].get(); + vector policyProbs = data["policyProbs"].get>(); + testAssert(policyProbs.size() == NNPos::MAX_NN_POLICY_SIZE); + std::copy(policyProbs.begin(), policyProbs.end(), nnOutput->policyProbs); + if(data.contains("whiteOwnerMap")) { + vector whiteOwnerMap = data["whiteOwnerMap"].get>(); + testAssert(whiteOwnerMap.size() == nnOutput->nnXLen * nnOutput->nnYLen); + nnOutput->whiteOwnerMap = new float[whiteOwnerMap.size()]; + std::copy(whiteOwnerMap.begin(), whiteOwnerMap.end(), nnOutput->whiteOwnerMap); + } + if(data.contains("noisedPolicyProbs")) { + vector noisedPolicyProbs = data["noisedPolicyProbs"].get>(); + testAssert(noisedPolicyProbs.size() == NNPos::MAX_NN_POLICY_SIZE); + nnOutput->noisedPolicyProbs = new float[NNPos::MAX_NN_POLICY_SIZE]; + std::copy(noisedPolicyProbs.begin(), noisedPolicyProbs.end(), nnOutput->noisedPolicyProbs); + } + return nnOutput; +} + +static nlohmann::json positionToJson(const Board& board, Player pla, const BoardHistory& history) { + nlohmann::json data; + data["rootBoard"] = Board::toJson(board); + data["rootPla"] = PlayerIO::playerToString(pla); + data["initialBoard"] = Board::toJson(history.initialBoard); + data["initialPla"] = PlayerIO::playerToString(history.initialPla); + data["initialEncorePhase"] = history.initialEncorePhase; + data["initialTurnNumber"] = history.initialTurnNumber; + data["assumeMultipleStartingBlackMovesAreHandicap"] = history.assumeMultipleStartingBlackMovesAreHandicap; + data["overrideNumHandicapStones"] = history.overrideNumHandicapStones; + data["rules"] = history.rules.toJson(); + data["moves"] = nlohmann::json::array(); + for(size_t i = 0; i< history.moveHistory.size(); i++) { + nlohmann::json moveData; + moveData["loc"] = history.moveHistory[i].loc; + moveData["pla"] = PlayerIO::playerToString(history.moveHistory[i].pla); + moveData["preventEncore"] = (bool)history.preventEncoreHistory[i]; + data["moves"].push_back(moveData); + } + return data; +} + +static void positionOfJson( + const nlohmann::json& data, + Board& board, + Player& pla, + BoardHistory& history +) { + Board initialBoard = Board::ofJson(data["initialBoard"]); + Player initialPla = PlayerIO::parsePlayer(data["initialPla"].get()); + Rules rules = Rules::parseRules(data["rules"].dump()); + int initialEncorePhase = data["initialEncorePhase"].get(); + BoardHistory rebuiltHistory(initialBoard, initialPla, rules, initialEncorePhase); + rebuiltHistory.setInitialTurnNumber(data["initialTurnNumber"].get()); + rebuiltHistory.setAssumeMultipleStartingBlackMovesAreHandicap( + data["assumeMultipleStartingBlackMovesAreHandicap"].get() + ); + rebuiltHistory.setOverrideNumHandicapStones(data["overrideNumHandicapStones"].get()); + + Board rebuiltBoard = initialBoard; + for(const nlohmann::json& moveData: data["moves"]) { + Loc loc = (Loc)moveData["loc"].get(); + Player movePla = PlayerIO::parsePlayer(moveData["pla"].get()); + bool preventEncore = moveData.value("preventEncore", false); + bool suc = rebuiltHistory.makeBoardMoveTolerant(rebuiltBoard, loc, movePla, preventEncore); + if(!suc) + throw IOError("Could not replay move in persistent MCTS position import"); + } + + Board savedRootBoard = Board::ofJson(data["rootBoard"]); + if(!rebuiltBoard.isEqualForTesting(savedRootBoard)) + throw IOError("Persistent MCTS import position history does not reconstruct saved root board"); + + board = rebuiltBoard; + pla = PlayerIO::parsePlayer(data["rootPla"].get()); + if(rebuiltHistory.presumedNextMovePla != pla) + throw IOError("Persistent MCTS import position root player does not match replayed history"); + history = rebuiltHistory; +} + +void Search::exportPersistentMCTS(const string& path) const { + nlohmann::json data; + data["version"] = 1; + data["persistentMCTSEnabled"] = persistentMCTSEnabled; + data["rootPla"] = PlayerIO::playerToString(rootPla); + data["position"] = positionToJson(rootBoard,rootPla,rootHistory); + data["rootKey"] = persistentCurrentRootKey.toString(); + data["rootAncestorKeys"] = nlohmann::json::array(); + for(const Hash128& key: persistentCurrentRootAncestorKeys) + data["rootAncestorKeys"].push_back(key.toString()); + data["pendingVisits"] = nlohmann::json::array(); + { + std::lock_guard lock(persistentPendingVisitsMutex); + for(const auto& entry: persistentPendingVisitsByRoot) { + if(entry.second <= 0) + continue; + nlohmann::json pending; + pending["root"] = entry.first.toString(); + pending["visits"] = entry.second; + data["pendingVisits"].push_back(pending); + } + } + data["rootCopies"] = nlohmann::json::array(); + + unordered_map idByNode; + vector nodes; + vector inNodeTable; + auto addNode = [&](const SearchNode* node, bool isInNodeTable) { + auto iter = idByNode.find(node); + if(iter != idByNode.end()) { + if(isInNodeTable) + inNodeTable[iter->second] = true; + return iter->second; + } + int id = (int)nodes.size(); + idByNode[node] = id; + nodes.push_back(node); + inNodeTable.push_back(isInNodeTable); + return id; + }; + for(const auto& entry: persistentRootNodes) + addNode(entry.second,false); + for(const std::map& nodeMap: nodeTable->entries) { + for(const auto& entry: nodeMap) + addNode(entry.second,true); + } + for(size_t idx = 0; idxgetChildren(); + int childrenCapacity = children.getCapacity(); + for(int i = 0; inextPla; + nodeData["forceNonTerminal"] = node->forceNonTerminal; + nodeData["mutexIdx"] = node->mutexIdx; + nodeData["graphHash"] = node->graphHash.toString(); + nodeData["direct"] = nlohmann::json::array(); + const vector>* directStats = node->getPersistentDirectStatsByRoot(); + if(directStats != NULL) { + for(const pair& entry: *directStats) { + nlohmann::json direct; + direct["root"] = entry.first.toString(); + direct["stats"] = nodeStatsToJson(entry.second); + nodeData["direct"].push_back(direct); + } + } + nodeData["nnOutput"] = nnOutputToJson(node->getNNOutput()); + nodeData["humanOutput"] = nnOutputToJson(node->getHumanOutput()); + nodeData["children"] = nlohmann::json::array(); + ConstSearchNodeChildrenReference children = node->getChildren(); + int childrenCapacity = children.getCapacity(); + for(int i = 0; i>* edgeVisits = childPointer.getPersistentEdgeVisitsByRoot(); + if(edgeVisits != NULL) { + for(const pair& entry: *edgeVisits) { + nlohmann::json edge; + edge["root"] = entry.first.toString(); + edge["visits"] = entry.second; + childData["edgeVisits"].push_back(edge); + } + } + nodeData["children"].push_back(childData); + } + data["nodes"].push_back(nodeData); + } + + ofstream out; + out.open(path); + if(!out) + throw IOError("Could not open persistent MCTS export file for writing: " + path); + out << data.dump(2) << "\n"; +} + +void Search::importPersistentMCTS(const string& path) { + ifstream in; + in.open(path); + if(!in) + throw IOError("Could not open persistent MCTS import file for reading: " + path); + nlohmann::json data; + in >> data; + clearSearch(); + persistentMCTSEnabled = data.value("persistentMCTSEnabled", true); + bool hasStoredRootKey = data.contains("rootKey"); + Hash128 storedRootKey; + if(hasStoredRootKey) + storedRootKey = Hash128::ofString(data["rootKey"].get()); + + if(data.contains("position")) { + positionOfJson(data["position"],rootBoard,rootPla,rootHistory); + rootKoHashTable->recompute(rootHistory); + Hash128 recomputedRootKey = getPersistentRootKey(rootHistory,rootPla); + if(hasStoredRootKey && storedRootKey != recomputedRootKey) + throw IOError("Persistent MCTS import root key does not match saved position"); + persistentCurrentRootKey = recomputedRootKey; + persistentCurrentRootAncestorKeys = getPersistentRootAncestorKeys(rootHistory,rootPla); + rootGraphHash = recomputedRootKey; + } + else { + if(data.contains("rootPla")) + rootPla = PlayerIO::parsePlayer(data["rootPla"].get()); + if(hasStoredRootKey) + persistentCurrentRootKey = storedRootKey; + persistentCurrentRootAncestorKeys.clear(); + if(data.contains("rootAncestorKeys")) { + for(const nlohmann::json& key: data["rootAncestorKeys"]) + persistentCurrentRootAncestorKeys.push_back(Hash128::ofString(key.get())); + } + } + persistentPendingVisitsByRoot.clear(); + if(data.contains("pendingVisits")) { + for(const nlohmann::json& pending: data["pendingVisits"]) { + Hash128 key = Hash128::ofString(pending["root"].get()); + int64_t visits = pending["visits"].get(); + if(visits > 0) + persistentPendingVisitsByRoot[key] += visits; + } + } + vector nodes; + if(data.contains("nodes")) { + nodes.resize(data["nodes"].size(),NULL); + for(const nlohmann::json& nodeData: data["nodes"]) { + int id = nodeData["id"].get(); + Player nextPla = (Player)nodeData["nextPla"].get(); + bool forceNonTerminal = nodeData["forceNonTerminal"].get(); + uint32_t mutexIdx = nodeData["mutexIdx"].get(); + Hash128 graphHash = Hash128::ofString(nodeData["graphHash"].get()); + SearchNode* node = new SearchNode(nextPla,forceNonTerminal,mutexIdx,graphHash); + vector> directEntries; + for(const nlohmann::json& direct: nodeData["direct"]) + directEntries.push_back(make_pair(Hash128::ofString(direct["root"].get()), nodeStatsOfJson(direct["stats"]))); + node->setPersistentDirectStatsByRootForLoad(directEntries); + shared_ptr nnOutput = nnOutputOfJson(nodeData["nnOutput"]); + if(nnOutput != nullptr) + node->storeNNOutputIfNull(new shared_ptr(nnOutput)); + shared_ptr humanOutput = nnOutputOfJson(nodeData["humanOutput"]); + if(humanOutput != nullptr) + node->storeHumanOutputIfNull(new shared_ptr(humanOutput)); + nodes[id] = node; + } + for(const nlohmann::json& nodeData: data["nodes"]) { + int id = nodeData["id"].get(); + SearchNode* node = nodes[id]; + const nlohmann::json& childrenData = nodeData["children"]; + if(childrenData.size() > 0) { + node->initializeChildren(); + SearchNodeState stateValue = SearchNode::STATE_EXPANDED0; + node->state.store(stateValue,std::memory_order_release); + for(int i = 0; imaybeExpandChildrenCapacityForNewChild(stateValue,i+1); + testAssert(suc); + SearchNodeChildrenReference children = node->getChildren(stateValue); + SearchChildPointer& childPointer = children[i]; + childPointer.setMoveLoc((Loc)childrenData[i]["moveLoc"].get()); + childPointer.store(nodes[childrenData[i]["node"].get()]); + vector> edgeEntries; + for(const nlohmann::json& edge: childrenData[i]["edgeVisits"]) + edgeEntries.push_back(make_pair(Hash128::ofString(edge["root"].get()), edge["visits"].get())); + childPointer.setPersistentEdgeVisitsByRootForLoad(edgeEntries); + } + } + else if(node->getNNOutput() != NULL) { + node->initializeChildren(); + node->state.store(SearchNode::STATE_EXPANDED0,std::memory_order_release); + } + if(nodeData["inNodeTable"].get()) { + Hash128 childHash = node->graphHash; + if(node->forceNonTerminal) + childHash ^= Hash128(0xd4c31800cb8809e2ULL,0xf75f9d2083f2ffcaULL); + uint32_t nodeTableIdx = nodeTable->getIndex(childHash.hash0); + nodeTable->entries[nodeTableIdx][childHash] = node; + } + } + for(const nlohmann::json& rootCopy: data["rootCopies"]) { + Hash128 key = Hash128::ofString(rootCopy["key"].get()); + SearchNode* node = nodes[rootCopy["node"].get()]; + persistentRootNodes[key] = node; + } + auto rootIter = persistentRootNodes.find(persistentCurrentRootKey); + if(rootIter != persistentRootNodes.end()) + rootNode = rootIter->second; + else if(persistentRootNodes.size() > 0) + rootNode = persistentRootNodes.begin()->second; + materializePersistentMCTS(false); + consumePersistentPendingVisits(); + } +} + bool Search::isLegalTolerant(Loc moveLoc, Player movePla) const { return rootHistory.isLegalTolerant(rootBoard,moveLoc,movePla); } @@ -323,6 +769,32 @@ bool Search::makeMove(Loc moveLoc, Player movePla, bool preventEncore) { if(!isLegalTolerant(moveLoc,movePla)) return false; + if(persistentMCTSEnabled) { + if(movePla != rootPla) + setPlayerAndClearHistory(movePla); + + float oldWhiteHandicapBonusScore = rootHistory.whiteHandicapBonusScore; + Board newBoard = rootBoard; + BoardHistory newHistory = rootHistory; + KoHashTable newKoHashTable; + newKoHashTable.recompute(newHistory); + newHistory.makeBoardMoveAssumeLegal(newBoard,moveLoc,movePla,&newKoHashTable,preventEncore); + Player newPla = getOpp(movePla); + + setPositionForMCTSPersistence(newPla,newBoard,newHistory); + + avoidMoveUntilByLocBlack.clear(); + avoidMoveUntilByLocWhite.clear(); + + if(rootHistory.whiteHandicapBonusScore != oldWhiteHandicapBonusScore) + clearSearch(); + if(searchParams.conservativePass && rootHistory.passWouldEndGame(rootBoard,rootPla)) + clearSearch(); + if(preventEncore && rootHistory.passWouldEndPhase(rootBoard,rootPla)) + clearSearch(); + return true; + } + if(movePla != rootPla) setPlayerAndClearHistory(movePla); @@ -609,6 +1081,9 @@ void Search::runWholeSearch( } } + if(persistentMCTSEnabled && rootNode != NULL) + materializePersistentMCTS(true); + if(searchParams.useEvalCache && searchParams.useGraphSearch && evalCache != nullptr && rootNode != NULL && mirroringPla == C_EMPTY) { recursivelyRecordEvalCache(*rootNode); } @@ -663,6 +1138,13 @@ void Search::beginSearch(bool pondering) { //cout << "BEGINSEARCH " << PlayerIO::playerToString(rootPla) << " " << PlayerIO::playerToString(plaThatSearchIsFor) << endl; clearOldNNOutputs(); + if(persistentMCTSEnabled) { + persistentCurrentRootKey = getPersistentRootKey(rootHistory,rootPla); + persistentCurrentRootAncestorKeys = getPersistentRootAncestorKeys(rootHistory,rootPla); + const bool forceNonTerminal = rootHistory.isGameFinished; + rootNode = getOrCreatePersistentRootNode(persistentCurrentRootKey, forceNonTerminal); + materializePersistentMCTS(false); + } computeRootValues(); //Prepare value bias table if we need it @@ -707,10 +1189,13 @@ void Search::beginSearch(bool pondering) { rootSymmetries.push_back(0); } + if(persistentMCTSEnabled) + materializePersistentMCTS(true); + SearchThread dummyThread(-1, *this); //If we're using graph search, we recompute the graph hash from scratch at the start of search. - if(searchParams.useGraphSearch) + if(searchParams.useGraphSearch || persistentMCTSEnabled) rootGraphHash = GraphHash::getGraphHashFromScratch(rootHistory, rootPla, searchParams.graphSearchRepBound, searchParams.drawEquivalentWinsForWhite); else rootGraphHash = Hash128(); @@ -736,7 +1221,7 @@ void Search::beginSearch(bool pondering) { SearchNodeChildrenReference children = node.getChildren(); int childrenCapacity = children.getCapacity(); bool anyFiltered = false; - if(childrenCapacity > 0) { + if(childrenCapacity > 0 && !persistentMCTSEnabled) { //This filtering, by deleting children, doesn't conform to the normal invariants that hold during search. //However nothing else should be running at this time and the search hasn't actually started yet, so this is okay. @@ -823,6 +1308,9 @@ void Search::beginSearch(bool pondering) { } } + if(persistentMCTSEnabled) + consumePersistentPendingVisits(); + //Clear unused stuff in value bias table since we may have pruned rootNode stuff if(searchParams.subtreeValueBiasFactor != 0 && subtreeValueBiasTable != NULL) subtreeValueBiasTable->clearUnusedSynchronous(); @@ -838,12 +1326,220 @@ uint32_t Search::createMutexIdxForNode(SearchThread& thread) const { //Based on sha256 of "search.cpp FORCE_NON_TERMINAL_HASH" static const Hash128 FORCE_NON_TERMINAL_HASH = Hash128(0xd4c31800cb8809e2ULL,0xf75f9d2083f2ffcaULL); +Hash128 Search::getPersistentRootKey(const BoardHistory& history, Player pla) const { + return GraphHash::getGraphHashFromScratch( + history, pla, searchParams.graphSearchRepBound, searchParams.drawEquivalentWinsForWhite + ); +} + +vector Search::getPersistentRootAncestorKeys(const BoardHistory& historyOrig, Player pla) const { + vector ret; + BoardHistory history = historyOrig.copyToInitial(); + Board board = history.getRecentBoard(0); + Hash128 graphHash; + + for(size_t i = 0; i <= historyOrig.moveHistory.size(); i++) { + Player nextPla = (i < historyOrig.moveHistory.size()) ? historyOrig.moveHistory[i].pla : pla; + graphHash = GraphHash::getGraphHash( + graphHash, history, nextPla, searchParams.graphSearchRepBound, searchParams.drawEquivalentWinsForWhite + ); + ret.push_back(graphHash); + if(i < historyOrig.moveHistory.size()) { + bool suc = history.makeBoardMoveTolerant( + board, historyOrig.moveHistory[i].loc, historyOrig.moveHistory[i].pla, historyOrig.preventEncoreHistory[i] + ); + testAssert(suc); + } + } + return ret; +} + +SearchNode* Search::findPersistentNodeInTable(Hash128 graphHash, bool forceNonTerminal) const { + Hash128 childHash = graphHash; + if(forceNonTerminal) + childHash ^= FORCE_NON_TERMINAL_HASH; + uint32_t nodeTableIdx = nodeTable->getIndex(childHash.hash0); + const std::map& nodeMap = nodeTable->entries[nodeTableIdx]; + auto iter = nodeMap.find(childHash); + if(iter == nodeMap.end()) + return NULL; + return iter->second; +} + +SearchNode* Search::getOrCreatePersistentRootNode(Hash128 rootKey, bool forceNonTerminal) { + auto iter = persistentRootNodes.find(rootKey); + if(iter != persistentRootNodes.end()) + return iter->second; + + SearchNode* node = NULL; + SearchNode* tableNode = findPersistentNodeInTable(rootKey, forceNonTerminal); + if(tableNode != NULL) { + const bool copySubtreeValueBias = false; + node = new SearchNode(*tableNode, forceNonTerminal, copySubtreeValueBias); + } + else { + uint32_t mutexIdx = nonSearchRand.nextUInt() & (mutexPool->getNumMutexes()-1); + node = new SearchNode(rootPla, forceNonTerminal, mutexIdx, rootKey); + } + persistentRootNodes[rootKey] = node; + return node; +} + +void Search::setNodeStats(SearchNode& node, const NodeStats& stats) { + while(node.statsLock.test_and_set(std::memory_order_acquire)); + node.stats.winLossValueAvg.store(stats.winLossValueAvg,std::memory_order_release); + node.stats.noResultValueAvg.store(stats.noResultValueAvg,std::memory_order_release); + node.stats.scoreMeanAvg.store(stats.scoreMeanAvg,std::memory_order_release); + node.stats.scoreMeanSqAvg.store(stats.scoreMeanSqAvg,std::memory_order_release); + node.stats.leadAvg.store(stats.leadAvg,std::memory_order_release); + node.stats.utilityAvg.store(stats.utilityAvg,std::memory_order_release); + node.stats.utilitySqAvg.store(stats.utilitySqAvg,std::memory_order_release); + node.stats.weightSqSum.store(stats.weightSqSum,std::memory_order_release); + node.stats.weightSum.store(stats.weightSum,std::memory_order_release); + node.stats.visits.store(stats.visits,std::memory_order_release); + node.statsLock.clear(std::memory_order_release); +} + +void Search::materializePersistentNode(SearchNode& node, SearchThread& thread, bool filterRootChildren) { + SearchNodeChildrenReference children = node.getChildren(); + int childrenCapacity = children.getCapacity(); + bool hasVisibleChildren = false; + bool isRoot = (&node == rootNode); + vector visibleRootKeys = {persistentCurrentRootKey}; + for(int i = 0; i 0) + hasVisibleChildren = true; + } + + NodeStats directStats = node.getPersistentDirectStats(visibleRootKeys); + if(hasVisibleChildren) + recomputeNodeStatsFromPersistentStats(node,thread,directStats,isRoot); + else + setNodeStats(node,directStats); +} + +void Search::materializePersistentMCTS(bool filterRootChildren) { + if(!persistentMCTSEnabled || rootNode == NULL) + return; + + int numAdditionalThreads = numAdditionalThreadsToUseForTasks(); + vector dummyThreads(numAdditionalThreads+1, NULL); + for(int threadIdx = 0; threadIdx f = [&](SearchNode* node, int threadIdx) { + materializePersistentNode(*node, *dummyThreads[threadIdx], filterRootChildren); + }; + applyRecursivelyPostOrderMulithreaded({rootNode},&f); + + for(int threadIdx = 0; threadIdx& visited +) { + if(!persistentMCTSEnabled) + return; + auto result = visited.insert(&node); + if(!result.second) + return; + + SearchNodeChildrenReference children = node.getChildren(); + int childrenCapacity = children.getCapacity(); + for(int i = 0; i lock(persistentPendingVisitsMutex); + auto iter = persistentPendingVisitsByRoot.find(persistentCurrentRootKey); + if(iter != persistentPendingVisitsByRoot.end()) { + pendingVisits = iter->second; + persistentPendingVisitsByRoot.erase(iter); + } + } + if(pendingVisits <= 0) + return; + + rootGraphHash = getPersistentRootKey(rootHistory,rootPla); + if(!searchParams.rootSymmetryPruning) { + std::fill(rootSymDupLoc,rootSymDupLoc+Board::MAX_ARR_SIZE,false); + rootSymmetries.clear(); + rootSymmetries.push_back(0); + } + else { + const std::vector& avoidMoveUntilByLoc = rootPla == P_BLACK ? avoidMoveUntilByLocBlack : avoidMoveUntilByLocWhite; + if(rootPruneOnlySymmetries.size() > 0) + SymmetryHelpers::markDuplicateMoveLocs(rootBoard,rootHistory,&rootPruneOnlySymmetries,avoidMoveUntilByLoc,rootSymDupLoc,rootSymmetries); + else + SymmetryHelpers::markDuplicateMoveLocs(rootBoard,rootHistory,NULL,avoidMoveUntilByLoc,rootSymDupLoc,rootSymmetries); + } + computeRootValues(); + if(rootNode == NULL) { + const bool forceNonTerminal = rootHistory.isGameFinished; + rootNode = getOrCreatePersistentRootNode(persistentCurrentRootKey, forceNonTerminal); + } + materializePersistentMCTS(true); + + bool oldPropagateDescendantCredits = persistentPropagateDescendantCredits; + bool oldConsumingPendingVisits = persistentConsumingPendingVisits; + Player oldPlaThatSearchIsFor = plaThatSearchIsFor; + Player oldPlaThatSearchIsForLastSearch = plaThatSearchIsForLastSearch; + persistentPropagateDescendantCredits = false; + persistentConsumingPendingVisits = true; + if(plaThatSearchIsFor == C_EMPTY) + plaThatSearchIsFor = rootPla; + + SearchThread thread(0,*this); + int64_t consumedVisits = 0; + int64_t attempts = 0; + while(consumedVisits < pendingVisits) { + bool counted = runSinglePlayout(thread, (double)(pendingVisits - consumedVisits)); + if(counted) + consumedVisits++; + attempts++; + if(attempts > pendingVisits * 100 + 1000) + throw StringError("Persistent MCTS pending visit catch-up failed to make progress"); + } + + persistentPropagateDescendantCredits = oldPropagateDescendantCredits; + persistentConsumingPendingVisits = oldConsumingPendingVisits; + plaThatSearchIsFor = oldPlaThatSearchIsFor; + plaThatSearchIsForLastSearch = oldPlaThatSearchIsForLastSearch; + materializePersistentMCTS(true); +} + //Must be called AFTER making the bestChildMoveLoc in the thread board and hist. SearchNode* Search::allocateOrFindNode(SearchThread& thread, Player nextPla, Loc bestChildMoveLoc, bool forceNonTerminal, Hash128 graphHash) { //Hash to use as a unique id for this node in the table, for transposition detection. //If this collides, we will be sad, but it should be astronomically rare since our hash is 128 bits. Hash128 childHash; - if(searchParams.useGraphSearch) { + if(searchParams.useGraphSearch || persistentMCTSEnabled) { childHash = graphHash; if(forceNonTerminal) childHash ^= FORCE_NON_TERMINAL_HASH; @@ -1191,6 +1887,23 @@ bool Search::playoutDescend( } SearchNodeState nodeState = node.state.load(std::memory_order_acquire); + if(persistentMCTSEnabled && nodeState >= SearchNode::STATE_EXPANDED0) { + int64_t visibleVisits = node.stats.visits.load(std::memory_order_acquire); + double visibleWeight = node.stats.weightSum.load(std::memory_order_acquire); + if(visibleVisits <= 0 || visibleWeight <= 0.0) { + if(node.getNNOutput() == NULL) { + bool suc = initNodeNNOutput(thread,node,isRoot,false,false); + if(!suc) { + thread.shouldCountPlayout = false; + return false; + } + return true; + } + addCurrentNNOutputAsLeafValue(node,true); + return true; + } + ensurePersistentDirectStatsForCurrentRoot(node,thread); + } if(nodeState == SearchNode::STATE_UNEVALUATED) { //Always attempt to set a new nnOutput. That way, if some GPU is slow and malfunctioning, we don't get blocked by it. { @@ -1231,10 +1944,11 @@ bool Search::playoutDescend( int bestChildIdx; Loc bestChildMoveLoc; bool countEdgeVisit; + bool bestChildIsNew; SearchNode* child = NULL; while(true) { - selectBestChildToDescend(thread,node,nodeState,numChildrenFound,bestChildIdx,bestChildMoveLoc,countEdgeVisit,isRoot); + selectBestChildToDescend(thread,node,nodeState,numChildrenFound,bestChildIdx,bestChildMoveLoc,countEdgeVisit,bestChildIsNew,isRoot); //The absurdly rare case that the move chosen is not legal //(this should only happen either on a bug or where the nnHash doesn't have full legality information or when there's an actual hash collision). @@ -1293,8 +2007,8 @@ bool Search::playoutDescend( } //Do we think we are searching a new child for the first time? - if(bestChildIdx >= numChildrenFound) { - assert(bestChildIdx == numChildrenFound); + if(bestChildIsNew) { + assert(bestChildIdx >= numChildrenFound); assert(bestChildIdx < NNPos::MAX_NN_POLICY_SIZE); bool suc = node.maybeExpandChildrenCapacityForNewChild(nodeState, numChildrenFound+1); //Someone else is expanding. Loop again trying to select the best child to explore. @@ -1317,7 +2031,7 @@ bool Search::playoutDescend( //Make the move! We need to make the move before we create the node so we can see the new state and get the right graphHash. thread.history.makeBoardMoveAssumeLegal(thread.board,bestChildMoveLoc,thread.pla,rootKoHashTable); thread.pla = getOpp(thread.pla); - if(searchParams.useGraphSearch) + if(searchParams.useGraphSearch || persistentMCTSEnabled) thread.graphHash = GraphHash::getGraphHash( thread.graphHash, thread.history, thread.pla, searchParams.graphSearchRepBound, searchParams.drawEquivalentWinsForWhite ); @@ -1329,6 +2043,11 @@ bool Search::playoutDescend( canForceNonTerminalDueToFriendlyPass ); child = allocateOrFindNode(thread, thread.pla, bestChildMoveLoc, forceNonTerminal, thread.graphHash); + if(persistentMCTSEnabled) { + std::unordered_set visited; + materializePersistentSubtree(*child,thread,false,visited); + ensurePersistentDirectStatsForCurrentRoot(*child,thread); + } child->virtualLosses.fetch_add(1,std::memory_order_release); { @@ -1367,6 +2086,11 @@ bool Search::playoutDescend( child = children[bestChildIdx].getIfAllocated(); assert(child != NULL); + if(persistentMCTSEnabled) { + std::unordered_set visited; + materializePersistentSubtree(*child,thread,false,visited); + ensurePersistentDirectStatsForCurrentRoot(*child,thread); + } child->virtualLosses.fetch_add(1,std::memory_order_release); //If edge visits is too much smaller than the child's visits, we can avoid descending. @@ -1381,7 +2105,7 @@ bool Search::playoutDescend( //Make the move! thread.history.makeBoardMoveAssumeLegal(thread.board,bestChildMoveLoc,thread.pla,rootKoHashTable); thread.pla = getOpp(thread.pla); - if(searchParams.useGraphSearch) + if(searchParams.useGraphSearch || persistentMCTSEnabled) thread.graphHash = GraphHash::getGraphHash( thread.graphHash, thread.history, thread.pla, searchParams.graphSearchRepBound, searchParams.drawEquivalentWinsForWhite ); @@ -1400,7 +2124,7 @@ bool Search::playoutDescend( if(!result.second) { if(countEdgeVisit) { SearchNodeChildrenReference children = node.getChildren(nodeState); - children[bestChildIdx].addEdgeVisits(1); + addEdgeVisits(node,children[bestChildIdx],1); updateStatsAfterPlayout(node,thread,isRoot); } // Regardless of whether we count an edge visit or not here, we @@ -1421,7 +2145,7 @@ bool Search::playoutDescend( if(shouldUpdateChildAncestors) { nodeState = node.state.load(std::memory_order_acquire); SearchNodeChildrenReference children = node.getChildren(nodeState); - children[bestChildIdx].addEdgeVisits(1); + addEdgeVisits(node,children[bestChildIdx],1); updateStatsAfterPlayout(node,thread,isRoot); } child->virtualLosses.fetch_add(-1,std::memory_order_release); @@ -1468,5 +2192,13 @@ bool Search::maybeCatchUpEdgeVisits( // numToAdd = std::min((childVisits - edgeVisits + 3) / 4, maxNumToAdd); } while(!childPointer.compexweakEdgeVisits(edgeVisits, edgeVisits + numToAdd)); + if(persistentMCTSEnabled) { + { + std::lock_guard lock(mutexPool->getMutex(node.mutexIdx)); + childPointer.addPersistentEdgeVisits(persistentCurrentRootKey,numToAdd); + } + mirrorPersistentEdgeVisitsToRootCopy(node,childPointer,numToAdd); + } + return true; } diff --git a/cpp/search/search.h b/cpp/search/search.h index 30b3c7a0e9..49308660ec 100644 --- a/cpp/search/search.h +++ b/cpp/search/search.h @@ -29,6 +29,7 @@ typedef int SearchNodeState; // See SearchNode::STATE_* struct SearchNode; struct SearchThread; struct Search; +struct NodeStats; struct DistributionTable; struct PatternBonusTable; struct PolicySortEntry; @@ -127,6 +128,15 @@ struct Search { std::string randSeed; + bool persistentMCTSEnabled; + Hash128 persistentCurrentRootKey; + std::vector persistentCurrentRootAncestorKeys; + std::map persistentRootNodes; + std::map persistentPendingVisitsByRoot; + mutable std::mutex persistentPendingVisitsMutex; + bool persistentPropagateDescendantCredits; + bool persistentConsumingPendingVisits; + //Contains all koHashes of positions/situations up to and including the root KoHashTable* rootKoHashTable; @@ -247,6 +257,14 @@ struct Search { //Just directly clear search without changing anything void clearSearch(); + //Persistent-MCTS mode keeps root copies across arbitrary root changes and materializes + //the view for the current root from root-tagged primitive statistics. Off by default. + void setPersistentMCTSEnabled(bool enabled); + bool getPersistentMCTSEnabled() const; + void setPositionForMCTSPersistence(Player pla, const Board& board, const BoardHistory& history); + void exportPersistentMCTS(const std::string& path) const; + void importPersistentMCTS(const std::string& path); + //Updates position and preserves the relevant subtree of search //If the move is not legal for the specified player, returns false and does nothing, else returns true //In the case where the player was not the expected one moving next, also clears history. @@ -599,6 +617,7 @@ struct Search { void selectBestChildToDescend( SearchThread& thread, const SearchNode& node, SearchNodeState nodeState, int& numChildrenFound, int& bestChildIdx, Loc& bestChildMoveLoc, bool& countEdgeVisit, + bool& bestChildIsNew, bool isRoot ) const; @@ -618,12 +637,37 @@ struct Search { bool isTerminal, bool assumeNoExistingWeight ); + void addPersistentDirectLeafValue( + SearchNode& node, + double winLossValue, + double noResultValue, + double scoreMean, + double scoreMeanSq, + double lead, + double utility, + double weight + ); void addCurrentNNOutputAsLeafValue(SearchNode& node, bool assumeNoExistingWeight); + bool ensurePersistentDirectStatsForCurrentRoot(SearchNode& node, SearchThread& thread); double computeWeightFromNNOutput(const NNOutput* nnOutput) const; void updateStatsAfterPlayout(SearchNode& node, SearchThread& thread, bool isRoot); void recomputeNodeStats(SearchNode& node, SearchThread& thread, int32_t numVisitsToAdd, bool isRoot); + void addEdgeVisits(SearchNode& parent, SearchChildPointer& childPointer, int64_t delta); + void mirrorPersistentDirectStatsToRootCopy(const SearchNode& sourceNode, const NodeStats& stats); + void mirrorPersistentEdgeVisitsToRootCopy(const SearchNode& sourceParent, const SearchChildPointer& sourceChildPointer, int64_t delta); + void addPersistentDescendantCredit(const SearchNode& descendant); + void consumePersistentPendingVisits(); + void recomputeNodeStatsFromPersistentStats(SearchNode& node, SearchThread& thread, const NodeStats& directStats, bool isRoot); + void materializePersistentMCTS(bool filterRootChildren); + void materializePersistentNode(SearchNode& node, SearchThread& thread, bool filterRootChildren); + void materializePersistentSubtree(SearchNode& node, SearchThread& thread, bool filterRootChildren, std::unordered_set& visited); + Hash128 getPersistentRootKey(const BoardHistory& history, Player pla) const; + std::vector getPersistentRootAncestorKeys(const BoardHistory& history, Player pla) const; + SearchNode* findPersistentNodeInTable(Hash128 graphHash, bool forceNonTerminal) const; + SearchNode* getOrCreatePersistentRootNode(Hash128 rootKey, bool forceNonTerminal); + void setNodeStats(SearchNode& node, const NodeStats& stats); void adjustEvalsFromCacheHelper( const std::shared_ptr& evalCacheEntry, diff --git a/cpp/search/searchexplorehelpers.cpp b/cpp/search/searchexplorehelpers.cpp index 5b4acac010..7f7581b1b7 100644 --- a/cpp/search/searchexplorehelpers.cpp +++ b/cpp/search/searchexplorehelpers.cpp @@ -324,6 +324,7 @@ double Search::getFpuValueForChildrenAssumeVisited( void Search::selectBestChildToDescend( SearchThread& thread, const SearchNode& node, SearchNodeState nodeState, int& numChildrenFound, int& bestChildIdx, Loc& bestChildMoveLoc, bool& countEdgeVisit, + bool& bestChildIsNew, bool isRoot) const { assert(thread.pla == node.nextPla); @@ -332,6 +333,7 @@ void Search::selectBestChildToDescend( bestChildIdx = -1; bestChildMoveLoc = Board::NULL_LOC; countEdgeVisit = true; + bestChildIsNew = false; ConstSearchNodeChildrenReference children = node.getChildren(nodeState); int childrenCapacity = children.getCapacity(); @@ -348,6 +350,9 @@ void Search::selectBestChildToDescend( const SearchNode* child = childPointer.getIfAllocated(); if(child == NULL) break; + int64_t edgeVisits = childPointer.getEdgeVisits(); + if(persistentMCTSEnabled && edgeVisits <= 0) + continue; Loc moveLoc = childPointer.getMoveLocRelaxed(); int movePos = getPos(moveLoc); float nnPolicyProb = policyProbs[movePos]; @@ -355,7 +360,6 @@ void Search::selectBestChildToDescend( continue; policyProbMassVisited += nnPolicyProb; - int64_t edgeVisits = childPointer.getEdgeVisits(); double childWeight = child->stats.getChildWeight(edgeVisits); totalChildWeight += childWeight; @@ -409,6 +413,9 @@ void Search::selectBestChildToDescend( const SearchNode* child = childPointer.getIfAllocated(); if(child == NULL) break; + int64_t edgeVisits = childPointer.getEdgeVisits(); + if(persistentMCTSEnabled && edgeVisits <= 0) + continue; Loc moveLoc = childPointer.getMoveLocRelaxed(); int movePos = getPos(moveLoc); float nnPolicyProb = policyProbs[movePos]; @@ -432,6 +439,8 @@ void Search::selectBestChildToDescend( const SearchNode* child = childPointer.getIfAllocated(); if(child == NULL) break; + if(persistentMCTSEnabled && childPointer.getEdgeVisits() <= 0) + continue; double childWeight = child->stats.weightSum.load(std::memory_order_acquire); totalChildWeight += childWeight; if(childWeight > maxChildWeight) @@ -467,8 +476,38 @@ void Search::selectBestChildToDescend( break; numChildrenFound++; int64_t childEdgeVisits = childPointer.getEdgeVisits(); - Loc moveLoc = childPointer.getMoveLocRelaxed(); + posesWithChildBuf[getPos(moveLoc)] = true; + if(persistentMCTSEnabled && childEdgeVisits <= 0) { + int movePos = getPos(moveLoc); + float nnPolicyProb = policyProbs[movePos]; + if(nnPolicyProb < 0) + continue; + if(isRoot) { + assert(thread.board.pos_hash == rootBoard.pos_hash); + assert(thread.pla == rootPla); + if(!isAllowedRootMove(moveLoc)) + continue; + } + if(antiMirror) + maybeApplyAntiMirrorPolicy(nnPolicyProb, moveLoc, policyProbs, node.nextPla, &thread); + double selectionValue = getNewExploreSelectionValue( + node, + exploreScaling, + nnPolicyProb,fpuValue, + parentWeightPerVisit, + maxChildWeight, + countEdgeVisit, + &thread + ); + if(selectionValue > maxSelectionValue) { + maxSelectionValue = selectionValue; + bestChildIdx = i; + bestChildMoveLoc = moveLoc; + } + continue; + } + bool isDuringSearch = true; double selectionValue = getExploreSelectionValueOfChild( node,policyProbs,child, @@ -491,8 +530,6 @@ void Search::selectBestChildToDescend( bestChildIdx = i; bestChildMoveLoc = moveLoc; } - - posesWithChildBuf[getPos(moveLoc)] = true; } const std::vector& avoidMoveUntilByLoc = thread.pla == P_BLACK ? avoidMoveUntilByLocBlack : avoidMoveUntilByLocWhite; @@ -545,6 +582,7 @@ void Search::selectBestChildToDescend( maxSelectionValue = selectionValue; bestChildIdx = numChildrenFound; bestChildMoveLoc = moveLoc; + bestChildIsNew = true; } } } @@ -603,6 +641,7 @@ void Search::selectBestChildToDescend( maxSelectionValue = selectionValue; bestChildIdx = numChildrenFound; bestChildMoveLoc = bestNewMoveLoc; + bestChildIsNew = true; } } @@ -623,11 +662,14 @@ void Search::selectBestChildToDescend( hasPassMove = true; else hasNonPassMove = true; + if(persistentMCTSEnabled && childPointer.getEdgeVisits() <= 0) + continue; } if(!hasPassMove && bestChildMoveLoc != Board::PASS_LOC && bestChildMoveLoc != Board::NULL_LOC) { bestChildIdx = numChildrenFound; bestChildMoveLoc = Board::PASS_LOC; countEdgeVisit = false; + bestChildIsNew = true; // Specifically for these special extra-pass search playouts, we don't count them for the purpose of visit/playout limits. thread.shouldCountPlayout = false; } @@ -635,6 +677,7 @@ void Search::selectBestChildToDescend( bestChildIdx = numChildrenFound; bestChildMoveLoc = bestNewMoveLoc; countEdgeVisit = false; + bestChildIsNew = true; // Specifically for these special extra-pass search playouts, we don't count them for the purpose of visit/playout limits. thread.shouldCountPlayout = false; } diff --git a/cpp/search/searchnode.cpp b/cpp/search/searchnode.cpp index 54cb9977d5..21919eca6a 100644 --- a/cpp/search/searchnode.cpp +++ b/cpp/search/searchnode.cpp @@ -3,6 +3,8 @@ #include "../search/search.h" #include "../core/test.h" +using namespace std; + NodeStatsAtomic::NodeStatsAtomic() :visits(0), winLossValueAvg(0.0), @@ -77,9 +79,14 @@ MoreNodeStats::~MoreNodeStats() SearchChildPointer::SearchChildPointer(): data(NULL), edgeVisits(0), - moveLoc(Board::NULL_LOC) + moveLoc(Board::NULL_LOC), + persistentEdgeVisitsByRoot(NULL) {} +SearchChildPointer::~SearchChildPointer() { + delete persistentEdgeVisitsByRoot; +} + void SearchChildPointer::storeAll(const SearchChildPointer& other) { SearchNode* d = other.data.load(std::memory_order_acquire); int64_t e = other.edgeVisits.load(std::memory_order_acquire); @@ -87,6 +94,9 @@ void SearchChildPointer::storeAll(const SearchChildPointer& other) { moveLoc.store(m,std::memory_order_release); edgeVisits.store(e,std::memory_order_release); data.store(d,std::memory_order_release); + delete persistentEdgeVisitsByRoot; + persistentEdgeVisitsByRoot = + other.persistentEdgeVisitsByRoot == NULL ? NULL : new vector>(*other.persistentEdgeVisitsByRoot); } bool SearchChildPointer::storeIfNull(SearchNode* node) { @@ -97,6 +107,44 @@ bool SearchChildPointer::compexweakEdgeVisits(int64_t& expected, int64_t desired return edgeVisits.compare_exchange_weak(expected, desired, std::memory_order_acq_rel); } +void SearchChildPointer::addPersistentEdgeVisits(Hash128 rootKey, int64_t delta) { + if(delta == 0) + return; + if(persistentEdgeVisitsByRoot == NULL) + persistentEdgeVisitsByRoot = new vector>(); + for(pair& entry: *persistentEdgeVisitsByRoot) { + if(entry.first == rootKey) { + entry.second += delta; + return; + } + } + persistentEdgeVisitsByRoot->push_back(make_pair(rootKey,delta)); +} + +int64_t SearchChildPointer::getPersistentEdgeVisits(const vector& rootKeys) const { + if(persistentEdgeVisitsByRoot == NULL) + return 0; + int64_t ret = 0; + for(const pair& entry: *persistentEdgeVisitsByRoot) { + for(const Hash128& rootKey: rootKeys) { + if(entry.first == rootKey) { + ret += entry.second; + break; + } + } + } + return ret; +} + +const vector>* SearchChildPointer::getPersistentEdgeVisitsByRoot() const { + return persistentEdgeVisitsByRoot; +} + +void SearchChildPointer::setPersistentEdgeVisitsByRootForLoad(const vector>& entries) { + delete persistentEdgeVisitsByRoot; + persistentEdgeVisitsByRoot = entries.empty() ? NULL : new vector>(entries); +} + //----------------------------------------------------------------------------------------- @@ -121,7 +169,8 @@ SearchNode::SearchNode(Player pla, bool fnt, uint32_t mIdx, Hash128 gh) subtreeValueBiasTableEntry(), graphHash(gh), evalCacheEntry(nullptr), - dirtyCounter(0) + dirtyCounter(0), + persistentDirectStatsByRoot(NULL) { } @@ -144,7 +193,10 @@ SearchNode::SearchNode(const SearchNode& other, bool fnt, bool copySubtreeValueB subtreeValueBiasTableEntry(), graphHash(other.graphHash), evalCacheEntry(other.evalCacheEntry), - dirtyCounter(other.dirtyCounter.load(std::memory_order_acquire)) + dirtyCounter(other.dirtyCounter.load(std::memory_order_acquire)), + persistentDirectStatsByRoot( + other.persistentDirectStatsByRoot == NULL ? NULL : new vector>(*other.persistentDirectStatsByRoot) + ) { { std::shared_ptr* otherVal = other.nnOutput.load(std::memory_order_acquire); @@ -334,6 +386,68 @@ void SearchNode::collapseChildrenCapacity(int numGoodChildren) { } } +static void addNodeStatsInPlace(NodeStats& dst, const NodeStats& src) { + if(src.visits == 0 && src.weightSum == 0.0 && src.weightSqSum == 0.0) + return; + if(dst.visits == 0 && dst.weightSum == 0.0 && dst.weightSqSum == 0.0) { + dst = src; + return; + } + + double oldWeightSum = dst.weightSum; + double newWeightSum = oldWeightSum + src.weightSum; + if(newWeightSum > 0.0) { + dst.winLossValueAvg = (dst.winLossValueAvg * oldWeightSum + src.winLossValueAvg * src.weightSum) / newWeightSum; + dst.noResultValueAvg = (dst.noResultValueAvg * oldWeightSum + src.noResultValueAvg * src.weightSum) / newWeightSum; + dst.scoreMeanAvg = (dst.scoreMeanAvg * oldWeightSum + src.scoreMeanAvg * src.weightSum) / newWeightSum; + dst.scoreMeanSqAvg = (dst.scoreMeanSqAvg * oldWeightSum + src.scoreMeanSqAvg * src.weightSum) / newWeightSum; + dst.leadAvg = (dst.leadAvg * oldWeightSum + src.leadAvg * src.weightSum) / newWeightSum; + dst.utilityAvg = (dst.utilityAvg * oldWeightSum + src.utilityAvg * src.weightSum) / newWeightSum; + dst.utilitySqAvg = (dst.utilitySqAvg * oldWeightSum + src.utilitySqAvg * src.weightSum) / newWeightSum; + } + dst.visits += src.visits; + dst.weightSum = newWeightSum; + dst.weightSqSum += src.weightSqSum; +} + +void SearchNode::addPersistentDirectStats(Hash128 rootKey, const NodeStats& statsToAdd) { + if(statsToAdd.visits == 0 && statsToAdd.weightSum == 0.0 && statsToAdd.weightSqSum == 0.0) + return; + if(persistentDirectStatsByRoot == NULL) + persistentDirectStatsByRoot = new vector>(); + for(pair& entry: *persistentDirectStatsByRoot) { + if(entry.first == rootKey) { + addNodeStatsInPlace(entry.second, statsToAdd); + return; + } + } + persistentDirectStatsByRoot->push_back(make_pair(rootKey,statsToAdd)); +} + +NodeStats SearchNode::getPersistentDirectStats(const vector& rootKeys) const { + NodeStats ret; + if(persistentDirectStatsByRoot == NULL) + return ret; + for(const pair& entry: *persistentDirectStatsByRoot) { + for(const Hash128& rootKey: rootKeys) { + if(entry.first == rootKey) { + addNodeStatsInPlace(ret, entry.second); + break; + } + } + } + return ret; +} + +const vector>* SearchNode::getPersistentDirectStatsByRoot() const { + return persistentDirectStatsByRoot; +} + +void SearchNode::setPersistentDirectStatsByRootForLoad(const vector>& entries) { + delete persistentDirectStatsByRoot; + persistentDirectStatsByRoot = entries.empty() ? NULL : new vector>(entries); +} + NNOutput* SearchNode::getNNOutput() { std::shared_ptr* nn = nnOutput.load(std::memory_order_acquire); if(nn == NULL) @@ -406,4 +520,5 @@ SearchNode::~SearchNode() { delete nnOutput; if(humanOutput != NULL) delete humanOutput; + delete persistentDirectStatsByRoot; } diff --git a/cpp/search/searchnode.h b/cpp/search/searchnode.h index 0a0382bcfd..d2f25a453a 100644 --- a/cpp/search/searchnode.h +++ b/cpp/search/searchnode.h @@ -107,8 +107,10 @@ struct SearchChildPointer { std::atomic data; std::atomic edgeVisits; std::atomic moveLoc; // Generally this will be always guarded under release semantics of data or of the array itself. + std::vector>* persistentEdgeVisitsByRoot; public: SearchChildPointer(); + ~SearchChildPointer(); SearchChildPointer(const SearchChildPointer&) = delete; SearchChildPointer& operator=(const SearchChildPointer&) = delete; @@ -135,6 +137,11 @@ struct SearchChildPointer { inline Loc getMoveLocRelaxed() const { return moveLoc.load(std::memory_order_relaxed); } inline void setMoveLoc(Loc loc) { moveLoc.store(loc, std::memory_order_release); } inline void setMoveLocRelaxed(Loc loc) { moveLoc.store(loc, std::memory_order_relaxed); } + + void addPersistentEdgeVisits(Hash128 rootKey, int64_t delta); + int64_t getPersistentEdgeVisits(const std::vector& rootKeys) const; + const std::vector>* getPersistentEdgeVisitsByRoot() const; + void setPersistentEdgeVisitsByRootForLoad(const std::vector>& entries); }; namespace SearchChildrenSizes { @@ -234,6 +241,8 @@ struct SearchNode { std::atomic dirtyCounter; + std::vector>* persistentDirectStatsByRoot; + //-------------------------------------------------------------------------------- SearchNode(Player prevPla, bool forceNonTerminal, uint32_t mutexIdx, Hash128 graphHash); SearchNode(const SearchNode&, bool forceNonTerminal, bool copySubtreeValueBias); @@ -270,6 +279,11 @@ struct SearchNode { bool maybeExpandChildrenCapacityForNewChild(SearchNodeState& stateValue, int numChildrenFullPlusOne); void collapseChildrenCapacity(int numGoodChildren); + void addPersistentDirectStats(Hash128 rootKey, const NodeStats& stats); + NodeStats getPersistentDirectStats(const std::vector& rootKeys) const; + const std::vector>* getPersistentDirectStatsByRoot() const; + void setPersistentDirectStatsByRootForLoad(const std::vector>& entries); + private: bool tryExpandingChildrenCapacityAssumeFull(SearchNodeState& stateValue); }; diff --git a/cpp/search/searchupdatehelpers.cpp b/cpp/search/searchupdatehelpers.cpp index 22db5ce16b..23759d2f01 100644 --- a/cpp/search/searchupdatehelpers.cpp +++ b/cpp/search/searchupdatehelpers.cpp @@ -7,6 +7,30 @@ #include "../core/using.h" //------------------------ +static void copyNNOutputsAndExpandedStateIfPresent(const SearchNode& sourceNode, SearchNode& targetNode) { + std::shared_ptr* sourceNNOutput = sourceNode.nnOutput.load(std::memory_order_acquire); + if(sourceNNOutput != NULL && targetNode.getNNOutput() == NULL) { + std::shared_ptr* copy = new std::shared_ptr(*sourceNNOutput); + bool suc = targetNode.storeNNOutputIfNull(copy); + if(!suc) + delete copy; + } + std::shared_ptr* sourceHumanOutput = sourceNode.humanOutput.load(std::memory_order_acquire); + if(sourceHumanOutput != NULL && targetNode.getHumanOutput() == NULL) { + std::shared_ptr* copy = new std::shared_ptr(*sourceHumanOutput); + bool suc = targetNode.storeHumanOutputIfNull(copy); + if(!suc) + delete copy; + } + if(sourceNNOutput != NULL) { + SearchNodeState stateValue = targetNode.state.load(std::memory_order_acquire); + if(stateValue < SearchNode::STATE_EXPANDED0) { + targetNode.initializeChildren(); + targetNode.state.store(SearchNode::STATE_EXPANDED0,std::memory_order_release); + } + } +} + void Search::addLeafValue( SearchNode& node, @@ -37,6 +61,9 @@ void Search::addLeafValue( utility += getPatternBonus(node.patternBonusHash,getOpp(node.nextPla)); + if(persistentMCTSEnabled) + addPersistentDirectLeafValue(node,winLossValue,noResultValue,scoreMean,scoreMeanSq,lead,utility,weight); + double utilitySq = utility * utility; double weightSq = weight * weight; @@ -80,6 +107,147 @@ void Search::addLeafValue( } } +void Search::addPersistentDirectLeafValue( + SearchNode& node, + double winLossValue, + double noResultValue, + double scoreMean, + double scoreMeanSq, + double lead, + double utility, + double weight +) { + NodeStats stats; + stats.visits = 1; + stats.winLossValueAvg = winLossValue; + stats.noResultValueAvg = noResultValue; + stats.scoreMeanAvg = scoreMean; + stats.scoreMeanSqAvg = scoreMeanSq; + stats.leadAvg = lead; + stats.utilityAvg = utility; + stats.utilitySqAvg = utility * utility; + stats.weightSum = weight; + stats.weightSqSum = weight * weight; + while(node.statsLock.test_and_set(std::memory_order_acquire)); + node.addPersistentDirectStats(persistentCurrentRootKey,stats); + node.statsLock.clear(std::memory_order_release); + mirrorPersistentDirectStatsToRootCopy(node,stats); +} + +bool Search::ensurePersistentDirectStatsForCurrentRoot(SearchNode& node, SearchThread& thread) { + if(!persistentMCTSEnabled || node.getNNOutput() == NULL) + return false; + + vector visibleRootKeys = {persistentCurrentRootKey}; + NodeStats directStats = node.getPersistentDirectStats(visibleRootKeys); + if(directStats.visits > 0 && directStats.weightSum > 0.0) + return false; + + int64_t visibleVisits = node.stats.visits.load(std::memory_order_acquire); + double visibleWeight = node.stats.weightSum.load(std::memory_order_acquire); + bool assumeNoExistingWeight = visibleVisits <= 0 || visibleWeight <= 0.0; + addCurrentNNOutputAsLeafValue(node,assumeNoExistingWeight); + if(!assumeNoExistingWeight) + materializePersistentNode(node,thread,false); + return true; +} + +void Search::addEdgeVisits(SearchNode& parent, SearchChildPointer& childPointer, int64_t delta) { + childPointer.addEdgeVisits(delta); + if(persistentMCTSEnabled) { + { + std::lock_guard lock(mutexPool->getMutex(parent.mutexIdx)); + childPointer.addPersistentEdgeVisits(persistentCurrentRootKey,delta); + } + mirrorPersistentEdgeVisitsToRootCopy(parent,childPointer,delta); + SearchNode* child = childPointer.getIfAllocated(); + if(child != NULL) { + for(int64_t i = 0; isecond; + if(rootCopy == &sourceNode) + return; + { + std::lock_guard lock(mutexPool->getMutex(rootCopy->mutexIdx)); + copyNNOutputsAndExpandedStateIfPresent(sourceNode,*rootCopy); + } + while(rootCopy->statsLock.test_and_set(std::memory_order_acquire)); + rootCopy->addPersistentDirectStats(persistentCurrentRootKey,stats); + rootCopy->statsLock.clear(std::memory_order_release); +} + +void Search::mirrorPersistentEdgeVisitsToRootCopy( + const SearchNode& sourceParent, + const SearchChildPointer& sourceChildPointer, + int64_t delta +) { + if(!persistentMCTSEnabled || delta == 0) + return; + auto iter = persistentRootNodes.find(sourceParent.graphHash); + if(iter == persistentRootNodes.end()) + return; + SearchNode* rootCopy = iter->second; + if(rootCopy == &sourceParent) + return; + + Loc moveLoc = sourceChildPointer.getMoveLocRelaxed(); + SearchNode* child = const_cast(sourceChildPointer.getIfAllocated()); + if(child == NULL || moveLoc == Board::NULL_LOC) + return; + + std::lock_guard lock(mutexPool->getMutex(rootCopy->mutexIdx)); + copyNNOutputsAndExpandedStateIfPresent(sourceParent,*rootCopy); + SearchNodeState stateValue = rootCopy->state.load(std::memory_order_acquire); + if(stateValue < SearchNode::STATE_EXPANDED0) { + rootCopy->initializeChildren(); + rootCopy->state.store(SearchNode::STATE_EXPANDED0,std::memory_order_release); + stateValue = SearchNode::STATE_EXPANDED0; + } + + SearchNodeChildrenReference children = rootCopy->getChildren(stateValue); + int childrenCapacity = children.getCapacity(); + int numChildrenFound = 0; + for(; numChildrenFoundmaybeExpandChildrenCapacityForNewChild(stateValue,numChildrenFound+1); + assert(suc); + children = rootCopy->getChildren(stateValue); + SearchChildPointer& targetChildPointer = children[numChildrenFound]; + targetChildPointer.setMoveLocRelaxed(moveLoc); + targetChildPointer.store(child); + targetChildPointer.addPersistentEdgeVisits(persistentCurrentRootKey,delta); +} + +void Search::addPersistentDescendantCredit(const SearchNode& descendant) { + if( + !persistentMCTSEnabled || + !persistentPropagateDescendantCredits || + persistentConsumingPendingVisits || + descendant.graphHash == persistentCurrentRootKey + ) + return; + std::lock_guard lock(persistentPendingVisitsMutex); + persistentPendingVisitsByRoot[descendant.graphHash] += 1; +} + void Search::addCurrentNNOutputAsLeafValue(SearchNode& node, bool assumeNoExistingWeight) { const NNOutput* nnOutput = node.getNNOutput(); assert(nnOutput != NULL); @@ -161,6 +329,127 @@ void Search::updateStatsAfterPlayout(SearchNode& node, SearchThread& thread, boo } } +void Search::recomputeNodeStatsFromPersistentStats(SearchNode& node, SearchThread& thread, const NodeStats& directStats, bool isRoot) { + vector& statsBuf = thread.statsBuf; + int numGoodChildren = 0; + + ConstSearchNodeChildrenReference children = node.getChildren(); + int childrenCapacity = children.getCapacity(); + double origTotalChildWeight = 0.0; + int64_t thisNodeVisits = directStats.visits; + for(int i = 0; istats); + + if(stats.stats.visits <= 0 || stats.stats.weightSum <= 0.0 || edgeVisits <= 0) + continue; + + double childUtility = stats.stats.utilityAvg; + stats.selfUtility = node.nextPla == P_WHITE ? childUtility : -childUtility; + stats.weightAdjusted = stats.stats.getChildWeight(edgeVisits); + stats.prevMoveLoc = moveLoc; + + origTotalChildWeight += stats.weightAdjusted; + thisNodeVisits += edgeVisits; + numGoodChildren++; + } + + double currentTotalChildWeight = origTotalChildWeight; + + if(searchParams.useNoisePruning && numGoodChildren > 0 && !(searchParams.antiMirror && mirroringPla != C_EMPTY)) { + double policyProbsBuf[NNPos::MAX_NN_POLICY_SIZE]; + { + const NNOutput* nnOutput = node.getNNOutput(); + assert(nnOutput != NULL); + const float* policyProbs = nnOutput->getPolicyProbsMaybeNoised(); + for(int i = 0; i maxChildWeight) + maxChildWeight = statsBuf[i].weightAdjusted; + } + amountToSubtract = std::min(searchParams.chosenMoveSubtract, maxChildWeight/64.0); + amountToPrune = std::min(searchParams.chosenMovePrune, maxChildWeight/64.0); + } + + downweightBadChildrenAndNormalizeWeight( + numGoodChildren, currentTotalChildWeight, currentTotalChildWeight, + amountToSubtract, amountToPrune, statsBuf + ); + } + + double winLossValueSum = 0.0; + double noResultValueSum = 0.0; + double scoreMeanSum = 0.0; + double scoreMeanSqSum = 0.0; + double leadSum = 0.0; + double utilitySum = 0.0; + double utilitySqSum = 0.0; + double weightSqSum = 0.0; + double weightSum = currentTotalChildWeight; + for(int i = 0; i 0.0) { + winLossValueSum += directStats.winLossValueAvg * directStats.weightSum; + noResultValueSum += directStats.noResultValueAvg * directStats.weightSum; + scoreMeanSum += directStats.scoreMeanAvg * directStats.weightSum; + scoreMeanSqSum += directStats.scoreMeanSqAvg * directStats.weightSum; + leadSum += directStats.leadAvg * directStats.weightSum; + utilitySum += directStats.utilityAvg * directStats.weightSum; + utilitySqSum += directStats.utilitySqAvg * directStats.weightSum; + weightSum += directStats.weightSum; + weightSqSum += directStats.weightSqSum; + } + + if(weightSum <= 0.0) { + setNodeStats(node,directStats); + return; + } + + NodeStats newStats; + newStats.visits = thisNodeVisits; + newStats.winLossValueAvg = winLossValueSum / weightSum; + newStats.noResultValueAvg = noResultValueSum / weightSum; + newStats.scoreMeanAvg = scoreMeanSum / weightSum; + newStats.scoreMeanSqAvg = scoreMeanSqSum / weightSum; + newStats.leadAvg = leadSum / weightSum; + newStats.utilityAvg = utilitySum / weightSum; + newStats.utilitySqAvg = utilitySqSum / weightSum; + newStats.weightSum = weightSum; + newStats.weightSqSum = weightSqSum; + setNodeStats(node,newStats); +} + //Recompute all the stats of this node based on its children, except its visits and virtual losses, which are not child-dependent and //are updated in the manner specified. //Assumes this node has an nnOutput diff --git a/cpp/tests/tests.h b/cpp/tests/tests.h index 30bf5d9e96..ddd1bfddab 100644 --- a/cpp/tests/tests.h +++ b/cpp/tests/tests.h @@ -46,6 +46,8 @@ namespace Tests { //testsearchnonn.cpp void runNNLessSearchTests(); + void runPersistentMCTSTests(); + void runPersistentMCTSStrictTests(); //testsearch.cpp void runSearchTests(const std::string& modelFile, bool inputsNHWC, bool cudaNHWC, int symmetry, bool useFP16); //testsearchv3.cpp diff --git a/cpp/tests/testsearchnonn.cpp b/cpp/tests/testsearchnonn.cpp index 8c4fd5f12e..fbb2859bfa 100644 --- a/cpp/tests/testsearchnonn.cpp +++ b/cpp/tests/testsearchnonn.cpp @@ -1,10 +1,12 @@ #include "../tests/tests.h" #include +#include #include #include #include "../core/fileutils.h" +#include "../dataio/files.h" #include "../dataio/sgf.h" #include "../neuralnet/nninputs.h" #include "../search/asyncbot.h" @@ -2656,4 +2658,573 @@ x.x.x cout << "Done" << endl; } +namespace { + +struct PersistentMCTSTestPosition { + string sgfFile; + int64_t turnIdx; + Board board; + Player nextPla; + BoardHistory hist; +}; + +struct PersistentMCTSTestCase { + string sgfFile; + vector positions; + vector rootSequence; + vector exportSteps; +}; + +struct PersistentMCTSRootSnapshot { + int64_t visits; + map edgeVisitsByLoc; + ReportedSearchValues values; + vector ownership; +}; + +static string resolvePersistentMCTSTestPath(const string& path) { + if(FileUtils::exists(path)) + return path; + if(FileUtils::exists("cpp/" + path)) + return "cpp/" + path; + if(FileUtils::exists("../" + path)) + return "../" + path; + return path; +} + +static SearchParams makePersistentMCTSTestParams(int64_t maxPlayouts) { + SearchParams params; + params.dynamicScoreUtilityFactor = 0.0; + params.useGraphSearch = true; + params.useEvalCache = false; + params.useNoisePruning = false; + params.useUncertainty = false; + params.rootNoiseEnabled = false; + params.rootSymmetryPruning = false; + params.rootNumSymmetriesToSample = 1; + params.rootDesiredPerChildVisitsCoeff = 0.0; + params.rootPolicyOptimism = 0.0; + params.policyOptimism = 0.0; + params.wideRootNoise = 0.0; + params.enablePassingHacks = false; + params.enableMorePassingHacks = false; + params.playoutDoublingAdvantage = 0.0; + params.playoutDoublingAdvantagePla = C_EMPTY; + params.subtreeValueBiasFactor = 0.0; + params.numThreads = 1; + params.minPlayoutsPerThread = 0.0; + params.maxVisits = ((int64_t)1L << 50); + params.maxPlayouts = maxPlayouts; + params.maxTime = 1.0e20; + params.searchFactorAfterOnePass = 1.0; + params.searchFactorAfterTwoPass = 1.0; + return params; +} + +static bool persistentMCTSApproxEqual(double a, double b) { + const double atol = 1e-4; + const double rtol = 1e-3; + return std::fabs(a-b) <= atol + rtol * std::max(std::fabs(a), std::fabs(b)); +} + +static void assertPersistentMCTSApprox( + double actual, + double expected, + const string& label, + const string& context +) { + if(!persistentMCTSApproxEqual(actual,expected)) { + cout << "Persistent MCTS mismatch at " << context << " for " << label + << ": actual=" << actual << " expected=" << expected << endl; + testAssert(false); + } +} + +static PersistentMCTSRootSnapshot capturePersistentMCTSRootSnapshot(Search* search) { + PersistentMCTSRootSnapshot snapshot; + snapshot.visits = search->getRootVisits(); + snapshot.values = search->getRootValuesRequireSuccess(); + snapshot.ownership = search->getAverageTreeOwnership(); + + ConstSearchNodeChildrenReference children = search->rootNode->getChildren(); + int childrenCapacity = children.getCapacity(); + for(int i = 0; i 0) + snapshot.edgeVisitsByLoc[moveLoc] = edgeVisits; + } + return snapshot; +} + +static void comparePersistentMCTSRootSnapshots( + const PersistentMCTSRootSnapshot& actual, + const PersistentMCTSRootSnapshot& expected, + const Board& board, + const string& context +) { + if(actual.visits != expected.visits) { + cout << "Persistent MCTS visit mismatch at " << context + << ": actual=" << actual.visits << " expected=" << expected.visits << endl; + testAssert(false); + } + + set locs; + for(const auto& entry: actual.edgeVisitsByLoc) + locs.insert(entry.first); + for(const auto& entry: expected.edgeVisitsByLoc) + locs.insert(entry.first); + double actualVisitDenom = std::max((int64_t)1, actual.visits); + double expectedVisitDenom = std::max((int64_t)1, expected.visits); + for(Loc loc: locs) { + int64_t actualVisits = actual.edgeVisitsByLoc.count(loc) ? actual.edgeVisitsByLoc.at(loc) : 0; + int64_t expectedVisits = expected.edgeVisitsByLoc.count(loc) ? expected.edgeVisitsByLoc.at(loc) : 0; + double actualProp = (double)actualVisits / actualVisitDenom; + double expectedProp = (double)expectedVisits / expectedVisitDenom; + assertPersistentMCTSApprox( + actualProp, + expectedProp, + string("edge distribution ") + Location::toString(loc,board), + context + ); + } + + assertPersistentMCTSApprox(actual.values.winValue, expected.values.winValue, "root winValue", context); + assertPersistentMCTSApprox(actual.values.winLossValue, expected.values.winLossValue, "root winLossValue", context); + assertPersistentMCTSApprox(actual.values.expectedScore, expected.values.expectedScore, "root expectedScore", context); + assertPersistentMCTSApprox(actual.values.lead, expected.values.lead, "root lead", context); + + if(actual.ownership.size() != expected.ownership.size()) { + cout << "Persistent MCTS ownership size mismatch at " << context + << ": actual=" << actual.ownership.size() << " expected=" << expected.ownership.size() << endl; + testAssert(false); + } + for(size_t i = 0; i 0) { + int64_t chunk = std::min(64, visitsRemaining); + standard.searchParams.maxPlayouts = chunk; + standard.runWholeSearch(pos.nextPla); + visitsRemaining -= chunk; + } + PersistentMCTSRootSnapshot snapshot = capturePersistentMCTSRootSnapshot(&standard); + if(snapshot.visits != visits) { + cout << "Standard MCTS reference did not hit requested visits for " + << pos.sgfFile << " turn " << pos.turnIdx + << ": requested=" << visits << " actual=" << snapshot.visits << endl; + testAssert(false); + } + return snapshot; +} + +static vector collectPersistentMCTSSgfFiles() { + vector files; + vector fixedFiles = { + "tests/data/sampletest/sampletest7x7.sgf", + "tests/data/sampletest/sampletest9x9.sgf", + "tests/data/humanslbigdiff.sgf", + "tests/data/sampletest2/messy.sgf" + }; + for(const string& path: fixedFiles) { + string resolved = resolvePersistentMCTSTestPath(path); + if(FileUtils::exists(resolved)) + files.push_back(resolved); + } + + const char* includeExternalSgfs = std::getenv("KATAGO_PERSISTENT_MCTS_EXTERNAL_SGFS"); + if(includeExternalSgfs != NULL && string(includeExternalSgfs) == "1") { + vector dirs = { + "../sgfs", + "../../sgfs" + }; + for(const string& dir: dirs) { + if(FileUtils::exists(dir)) { + vector collected; + FileHelpers::collectSgfsFromDirOrFile(dir,collected); + files.insert(files.end(),collected.begin(),collected.end()); + } + } + } + + sort(files.begin(),files.end()); + files.erase(unique(files.begin(),files.end()),files.end()); + Rand rand("persistentMCTSSgfFileOrder"); + rand.shuffle(files); + return files; +} + +static bool makePersistentMCTSTestCase( + const string& sgfFile, + Rand& rand, + PersistentMCTSTestCase& testCase +) { + std::unique_ptr sgf; + try { + sgf = CompactSgf::loadFile(sgfFile); + } + catch(const StringError& e) { + cout << "Skipping SGF for persistent MCTS test due to parse error: " << sgfFile << " " << e.what() << endl; + return false; + } + if(sgf == nullptr || sgf->moves.size() <= 20) + return false; + if(sgf->xSize > NNPos::MAX_BOARD_LEN || sgf->ySize > NNPos::MAX_BOARD_LEN) + return false; + + Rules rules = Rules::getTrompTaylorish(); + try { + rules = sgf->getRulesOrFailAllowUnspecified(rules); + } + catch(const StringError& e) { + cout << "Skipping SGF for persistent MCTS test due to rules error: " << sgfFile << " " << e.what() << endl; + return false; + } + + int64_t latestStart = (int64_t)sgf->moves.size() - 20; + int64_t start = (int64_t)rand.nextUInt64((uint64_t)latestStart + 1); + vector offsets(20); + for(uint32_t i = 0; isetupBoardAndHistTolerant(rules,board,nextPla,hist,turnIdx,false); + } + catch(const StringError& e) { + continue; + } + if(hist.isGameFinished) + continue; + PersistentMCTSTestPosition pos; + pos.sgfFile = sgfFile; + pos.turnIdx = turnIdx; + pos.board = board; + pos.nextPla = nextPla; + pos.hist = hist; + testCase.positions.push_back(pos); + } + if(testCase.positions.size() < 15) + return false; + + testCase.rootSequence.clear(); + vector initialOrder(15); + for(int i = 0; i<15; i++) + initialOrder[i] = i; + rand.shuffle(initialOrder); + for(int i = 0; i<15; i++) + testCase.rootSequence.push_back(initialOrder[i]); + while(testCase.rootSequence.size() < 105) + testCase.rootSequence.push_back((int)rand.nextUInt(15)); + + vector exportCandidates(105); + for(int i = 0; i<105; i++) + exportCandidates[i] = i; + rand.shuffle(exportCandidates); + testCase.exportSteps.assign(exportCandidates.begin(),exportCandidates.begin()+7); + sort(testCase.exportSteps.begin(),testCase.exportSteps.end()); + return true; +} + +static vector makePersistentMCTSTestCases(int numCases) { + vector cases; + vector sgfFiles = collectPersistentMCTSSgfFiles(); + Rand rand("persistentMCTSSgfCaseSeed"); + for(const string& sgfFile: sgfFiles) { + PersistentMCTSTestCase testCase; + if(makePersistentMCTSTestCase(sgfFile,rand,testCase)) { + cases.push_back(testCase); + if(cases.size() >= (size_t)numCases) + break; + } + } + if(cases.size() < (size_t)numCases) { + cout << "Only found " << cases.size() << " persistent MCTS SGF cases" << endl; + } + testAssert(cases.size() > 0); + return cases; +} + +static void runPersistentMCTSSgfCorrectnessTest( + const SearchParams& baseParams, + NNEvaluator* nnEval, + Logger* logger +) { + vector cases = makePersistentMCTSTestCases(2); + for(size_t caseIdx = 0; caseIdx> data; + if(!data.contains("nodes")) + return 0; + return (int64_t)data["nodes"].size(); +} + +static void runPersistentMCTSSgfStressTest( + const SearchParams& baseParams, + NNEvaluator* nnEval, + Logger* logger +) { + vector cases = makePersistentMCTSTestCases(1); + const PersistentMCTSTestCase& testCase = cases[0]; + cout << "Persistent MCTS stress SGF case: " << testCase.sgfFile << endl; + + SearchParams params = baseParams; + params.maxPlayouts = 1024; + Search persistent(params, nnEval, logger, "persistentMCTSSgfStressSeed"); + persistent.setAlwaysIncludeOwnerMap(true); + persistent.setPersistentMCTSEnabled(true); + + int64_t totalRequestedPlayouts = 0; + size_t nextExportIdx = 0; + for(int step = 0; step 0); + + if(nextExportIdx < testCase.exportSteps.size() && step == testCase.exportSteps[nextExportIdx]) { + string exportPath = + "/tmp/katago_persistent_mcts_stress_" + Global::intToString(step) + ".json"; + PersistentMCTSRootSnapshot before = capturePersistentMCTSRootSnapshot(&persistent); + persistent.exportPersistentMCTS(exportPath); + persistent.clearSearch(); + persistent.importPersistentMCTS(exportPath); + PersistentMCTSRootSnapshot after = capturePersistentMCTSRootSnapshot(&persistent); + comparePersistentMCTSRootSnapshots( + after, + before, + pos.board, + testCase.sgfFile + " stress step " + Global::intToString(step) + ); + nextExportIdx++; + } + } + + string finalExportPath = "/tmp/katago_persistent_mcts_stress_final.json"; + persistent.exportPersistentMCTS(finalExportPath); + int64_t exportedNodes = countPersistentMCTSExportNodes(finalExportPath); + int64_t reasonableNodeBound = totalRequestedPlayouts + (int64_t)testCase.positions.size() + 4096; + cout << "Persistent MCTS stress nodes=" << exportedNodes + << " requestedPlayouts=" << totalRequestedPlayouts + << " bound=" << reasonableNodeBound << endl; + testAssert(exportedNodes > 0); + testAssert(exportedNodes <= reasonableNodeBound); +} + +} + +void Tests::runPersistentMCTSTests() { + cout << "Running persistent MCTS tests" << endl; + NeuralNet::globalInitialize(); + + string modelFile = "/dev/null"; + const bool logToStdout = false; + const bool logToStderr = false; + const bool logTime = false; + Logger logger(nullptr, logToStdout, logToStderr, logTime); + + NNEvaluator* nnEval = startNNEval(modelFile,logger,"persistent",NNPos::MAX_BOARD_LEN,NNPos::MAX_BOARD_LEN,0,true,false,false,true,false); + + SearchParams params = makePersistentMCTSTestParams(80); + params.maxVisits = 1000000; + params.maxPlayouts = 80; + Search* search = new Search(params, nnEval, &logger, "persistentMCTSTestSeed"); + search->setPersistentMCTSEnabled(true); + + Rules rules = Rules::getTrompTaylorish(); + Board board = Board::parseBoard(7,7,R"%%( +....... +..x.o.. +....... +...x... +....... +..o.x.. +....... +)%%"); + Player nextPla = P_BLACK; + BoardHistory hist(board,nextPla,rules,0); + + search->setPositionForMCTSPersistence(nextPla,board,hist); + search->runWholeSearch(nextPla); + + auto collectRootEdges = [&]() { + vector> rootEdges; + ConstSearchNodeChildrenReference children = search->rootNode->getChildren(); + int childrenCapacity = children.getCapacity(); + for(int i = 0; i 0) + rootEdges.push_back(make_pair(moveLoc,edgeVisits)); + } + return rootEdges; + }; + auto getRootEdgeVisits = [&] (Loc loc) { + ConstSearchNodeChildrenReference children = search->rootNode->getChildren(); + int childrenCapacity = children.getCapacity(); + for(int i = 0; i>& expectedEdges) { + for(const pair& expected: expectedEdges) + testAssert(getRootEdgeVisits(expected.first) == expected.second); + }; + + vector> rootEdgesBefore = collectRootEdges(); + Loc locToDescend = Board::NULL_LOC; + int64_t locToDescendEdgeVisits = -1; + for(const pair& edge: rootEdgesBefore) { + if(edge.first != Board::PASS_LOC && edge.second > locToDescendEdgeVisits) { + locToDescend = edge.first; + locToDescendEdgeVisits = edge.second; + } + } + testAssert(locToDescend != Board::NULL_LOC); + testAssert(locToDescendEdgeVisits > 0); + + Board childBoard = board; + BoardHistory childHist = hist; + childHist.makeBoardMoveAssumeLegal(childBoard,locToDescend,nextPla,NULL); + Player childPla = getOpp(nextPla); + + search->setPositionForMCTSPersistence(childPla,childBoard,childHist); + int64_t inheritedChildRootVisits = search->getRootVisits(); + testAssert(inheritedChildRootVisits >= locToDescendEdgeVisits); + + search->searchParams.maxPlayouts = 200; + search->runWholeSearch(childPla); + testAssert(search->getRootVisits() > inheritedChildRootVisits); + int64_t childRootVisitsAfterChildSearch = search->getRootVisits(); + + search->setPositionForMCTSPersistence(nextPla,board,hist); + assertRootEdgesMatch(rootEdgesBefore); + + search->searchParams.maxPlayouts = search->getRootVisits() + 240; + search->runWholeSearch(nextPla); + int64_t locToDescendEdgeVisitsAfterAncestorSearch = getRootEdgeVisits(locToDescend); + testAssert(locToDescendEdgeVisitsAfterAncestorSearch > locToDescendEdgeVisits); + + search->setPositionForMCTSPersistence(childPla,childBoard,childHist); + testAssert(search->getRootVisits() > childRootVisitsAfterChildSearch); + + search->setPositionForMCTSPersistence(nextPla,board,hist); + vector> rootEdgesBeforeExport = collectRootEdges(); + + string exportPath = "/tmp/katago_persistent_mcts_test.json"; + search->exportPersistentMCTS(exportPath); + search->clearSearch(); + search->importPersistentMCTS(exportPath); + testAssert(search->getPersistentMCTSEnabled()); + testAssert(search->rootNode != NULL); + assertRootEdgesMatch(rootEdgesBeforeExport); + + delete search; + delete nnEval; + + NeuralNet::globalCleanup(); + cout << "Persistent MCTS tests passed" << endl; +} + +void Tests::runPersistentMCTSStrictTests() { + cout << "Running strict persistent MCTS tests" << endl; + NeuralNet::globalInitialize(); + + string modelFile = "/dev/null"; + const bool logToStdout = false; + const bool logToStderr = false; + const bool logTime = false; + Logger logger(nullptr, logToStdout, logToStderr, logTime); + + NNEvaluator* nnEval = startNNEval(modelFile,logger,"persistent-strict",NNPos::MAX_BOARD_LEN,NNPos::MAX_BOARD_LEN,0,true,false,false,true,false); + + SearchParams sgfTestParams = makePersistentMCTSTestParams(64); + runPersistentMCTSSgfCorrectnessTest(sgfTestParams, nnEval, &logger); + runPersistentMCTSSgfStressTest(sgfTestParams, nnEval, &logger); + + delete nnEval; + + NeuralNet::globalCleanup(); + cout << "Strict persistent MCTS tests passed" << endl; +} diff --git a/docs/persistent-mcts.md b/docs/persistent-mcts.md new file mode 100644 index 0000000000..b327e151b9 --- /dev/null +++ b/docs/persistent-mcts.md @@ -0,0 +1,85 @@ +# Persistent MCTS model + +This note records the model implemented for persistent root switching, export, and import. + +## State + +Let game states be keyed by collision-free graph hashes. A node can be used as a root at different +times, so every persistent primitive contribution is tagged by the root hash that generated it. +The implementation stores only primitive facts: + +- root-tagged direct NN/terminal leaf statistics for a node; +- root-tagged counted edge visits for an edge; +- pending inherited visit counts keyed by possible future roots. + +Aggregated `NodeStats` are a cache. They are rebuilt by post-order materialization from the +currently visible primitive facts. + +For a current root `R`, the visible primitive facts are exactly the facts tagged by `R`. Ancestor +information is not read directly. Instead, when a real playout rooted at `A` traverses a descendant +state `D`, the implementation increments `pending[D]`. When `D` later becomes the current root, +those pending visits are consumed by running `D`'s own ordinary MCTS transition and writing +`D`-tagged primitive facts. Pending-consumption playouts do not create more pending visits; the +original ancestor playout already credited every descendant it actually traversed. + +Therefore each root `R` has a canonical root-local state + +```text +S_R = ordinary_mcts(root = R, samples = own(R) + inherited(R)) +``` + +where `own(R)` is the number of real searches rooted at `R`, and `inherited(R)` is the number of +real ancestor-root playouts whose path traversed `R`. + +## Isolation theorem + +For any root-switch sequence and any two states `X` and `Y`, searches rooted inside the strict +subtree of `Y` cannot affect the materialized MCTS result at `Y`. + +Proof: a real search rooted at `X` writes primitive facts only with tag `X`, and it increments +pending only for states traversed by that real playout. The materialized view at `Y` reads only +`Y`-tagged facts. If `X` is a strict descendant of `Y`, then `X != Y`, so `X`-tagged direct stats +and edge visits are invisible at `Y`. The playout rooted at `X` also never traverses its strict +ancestor `Y`, so it cannot increment `pending[Y]`. Thus neither the visible primitives nor the +pending count of `Y` can change. + +## Inheritance theorem + +For any state `C`, every real playout rooted at an ancestor of `C` that traverses `C` contributes +exactly one useful sample to `C`'s canonical MCTS state. + +Proof: when the ancestor playout traverses `C`, `addPersistentDescendantCredit` increments +`pending[C]` once. No descendant-root contribution is read while the ancestor is selecting moves, +because only the ancestor's tag is visible during that search. When `C` is later materialized, +`consumePersistentPendingVisits` removes those pending visits and runs exactly that many playouts +from root `C`, with descendant credit propagation disabled. Those playouts write only `C`-tagged +primitive facts, so they are visible when `C` is current and invisible to ancestors. Since they use +the same transition function as ordinary MCTS from `C`, `S_C` is exactly ordinary MCTS with +`own(C) + inherited(C)` samples. + +The statement is independent of switch order. A sequence such as `A, B, A, E, F, B, G, A, B` +only changes the counters `own(R)` and `inherited(R)` for each root `R`; materializing `R` depends +on those counters and `R`-tagged primitives, not on the previous root. + +## Why primitive facts are serialized + +KataGo parent stats are not linear deltas. They are recomputed from child edge visits, child +weights, value weighting, pruning, and direct NN values. Storing only aggregate `NodeStats` would +not be proof-preserving after root switches. The persistent layer serializes tagged direct stats, +tagged edge visits, NN outputs, root copies, graph-table nodes, the current root position/history, +and pending inherited visits. Import reconstructs the graph and materializes the current root from +those primitives. + +## Implementation invariants + +- Root copies are separate from graph-table nodes because KataGo normally keeps the root outside + the graph table. Direct stats and edge visits are mirrored into an existing root copy when an + ancestor search reaches the matching graph-table node. +- When a persistent search touches an already-expanded transposition node, it first ensures the + current root has a direct NN contribution for that node. This prevents import from rebuilding a + different state than the in-memory search. +- At the end of a persistent `runWholeSearch`, the current root subtree is materialized post-order. + This flushes graph-search parent caches so exported/imported state and reported root values agree. +- Root move filtering is non-destructive in persistent mode. Disallowed root children are + materialized with zero visible edge visits for the current root, while their tagged facts remain + available for future root switches.