00001 //---------------------------------------------------------------------------- 00002 /** @file SgUctSearch.cpp 00003 */ 00004 //---------------------------------------------------------------------------- 00005 00006 #include "SgSystem.h" 00007 #include "SgUctSearch.h" 00008 00009 #include <algorithm> 00010 #include <cmath> 00011 #include <iomanip> 00012 #include <boost/format.hpp> 00013 #include <boost/io/ios_state.hpp> 00014 #include <boost/version.hpp> 00015 #include "SgDebug.h" 00016 #include "SgHashTable.h" 00017 #include "SgMath.h" 00018 #include "SgWrite.h" 00019 00020 using namespace std; 00021 using boost::barrier; 00022 using boost::condition; 00023 using boost::format; 00024 using boost::mutex; 00025 using boost::shared_ptr; 00026 using boost::io::ios_all_saver; 00027 00028 #define BOOST_VERSION_MAJOR (BOOST_VERSION / 100000) 00029 #define BOOST_VERSION_MINOR (BOOST_VERSION / 100 % 1000) 00030 00031 //---------------------------------------------------------------------------- 00032 00033 namespace { 00034 00035 const bool DEBUG_THREADS = false; 00036 00037 void Notify(mutex& aMutex, condition& aCondition) 00038 { 00039 mutex::scoped_lock lock(aMutex); 00040 aCondition.notify_all(); 00041 } 00042 00043 } // namespace 00044 00045 //---------------------------------------------------------------------------- 00046 00047 void SgUctGameInfo::Clear(std::size_t numberPlayouts) 00048 { 00049 m_nodes.clear(); 00050 m_inTreeSequence.clear(); 00051 if (numberPlayouts != m_sequence.size()) 00052 { 00053 m_sequence.resize(numberPlayouts); 00054 m_skipRaveUpdate.resize(numberPlayouts); 00055 m_eval.resize(numberPlayouts); 00056 m_aborted.resize(numberPlayouts); 00057 } 00058 for (size_t i = 0; i < numberPlayouts; ++i) 00059 { 00060 m_sequence[i].clear(); 00061 m_skipRaveUpdate[i].clear(); 00062 } 00063 } 00064 00065 //---------------------------------------------------------------------------- 00066 00067 SgUctThreadState::SgUctThreadState(size_t threadId, int moveRange) 00068 : m_threadId(threadId), 00069 m_isSearchInitialized(false), 00070 m_isTreeOutOfMem(false), 00071 m_randomizeCounter(0) 00072 { 00073 if (moveRange > 0) 00074 { 00075 m_firstPlay.reset(new size_t[moveRange]); 00076 m_firstPlayOpp.reset(new size_t[moveRange]); 00077 } 00078 } 00079 00080 SgUctThreadState::~SgUctThreadState() 00081 { 00082 } 00083 00084 void SgUctThreadState::EndPlayout() 00085 { 00086 // Default implementation does nothing 00087 } 00088 00089 void SgUctThreadState::GameStart() 00090 { 00091 // Default implementation does nothing 00092 } 00093 00094 void SgUctThreadState::StartPlayout() 00095 { 00096 // Default implementation does nothing 00097 } 00098 00099 void SgUctThreadState::StartPlayouts() 00100 { 00101 // Default implementation does nothing 00102 } 00103 00104 //---------------------------------------------------------------------------- 00105 00106 SgUctThreadStateFactory::~SgUctThreadStateFactory() 00107 { 00108 } 00109 00110 //---------------------------------------------------------------------------- 00111 00112 SgUctSearch::Thread::Function::Function(Thread& thread) 00113 : m_thread(thread) 00114 { 00115 } 00116 00117 void SgUctSearch::Thread::Function::operator()() 00118 { 00119 m_thread(); 00120 } 00121 00122 SgUctSearch::Thread::Thread(SgUctSearch& search, 00123 auto_ptr<SgUctThreadState> state) 00124 : m_state(state), 00125 m_search(search), 00126 m_quit(false), 00127 m_threadReady(2), 00128 m_playFinishedLock(m_playFinishedMutex), 00129 #if BOOST_VERSION_MAJOR == 1 && BOOST_VERSION_MINOR <= 34 00130 m_globalLock(search.m_globalMutex, false), 00131 #else 00132 m_globalLock(search.m_globalMutex, boost::defer_lock), 00133 #endif 00134 m_thread(Function(*this)) 00135 { 00136 m_threadReady.wait(); 00137 } 00138 00139 SgUctSearch::Thread::~Thread() 00140 { 00141 m_quit = true; 00142 StartPlay(); 00143 m_thread.join(); 00144 } 00145 00146 void SgUctSearch::Thread::operator()() 00147 { 00148 if (DEBUG_THREADS) 00149 SgDebug() << "SgUctSearch::Thread: starting thread " 00150 << m_state->m_threadId << '\n'; 00151 mutex::scoped_lock lock(m_startPlayMutex); 00152 m_threadReady.wait(); 00153 while (true) 00154 { 00155 m_startPlay.wait(lock); 00156 if (m_quit) 00157 break; 00158 m_search.SearchLoop(*m_state, &m_globalLock); 00159 Notify(m_playFinishedMutex, m_playFinished); 00160 } 00161 if (DEBUG_THREADS) 00162 SgDebug() << "SgUctSearch::Thread: finishing thread " 00163 << m_state->m_threadId << '\n'; 00164 } 00165 00166 void SgUctSearch::Thread::StartPlay() 00167 { 00168 Notify(m_startPlayMutex, m_startPlay); 00169 } 00170 00171 void SgUctSearch::Thread::WaitPlayFinished() 00172 { 00173 m_playFinished.wait(m_playFinishedLock); 00174 } 00175 00176 //---------------------------------------------------------------------------- 00177 00178 void SgUctSearchStat::Clear() 00179 { 00180 m_time = 0; 00181 m_knowledge = 0; 00182 m_gamesPerSecond = 0; 00183 m_gameLength.Clear(); 00184 m_movesInTree.Clear(); 00185 m_aborted.Clear(); 00186 } 00187 00188 void SgUctSearchStat::Write(std::ostream& out) const 00189 { 00190 ios_all_saver saver(out); 00191 out << SgWriteLabel("Time") << setprecision(2) << m_time << '\n' 00192 << SgWriteLabel("GameLen") << fixed << setprecision(1); 00193 m_gameLength.Write(out); 00194 out << '\n' 00195 << SgWriteLabel("InTree"); 00196 m_movesInTree.Write(out); 00197 out << '\n' 00198 << SgWriteLabel("Aborted") 00199 << static_cast<int>(100 * m_aborted.Mean()) << "%\n" 00200 << SgWriteLabel("Games/s") << fixed << setprecision(1) 00201 << m_gamesPerSecond << '\n'; 00202 } 00203 00204 //---------------------------------------------------------------------------- 00205 00206 SgUctSearch::SgUctSearch(SgUctThreadStateFactory* threadStateFactory, 00207 int moveRange) 00208 : m_threadStateFactory(threadStateFactory), 00209 m_logGames(false), 00210 m_rave(false), 00211 m_knowledgeThreshold(), 00212 m_moveSelect(SG_UCTMOVESELECT_COUNT), 00213 m_raveCheckSame(false), 00214 m_randomizeRaveFrequency(20), 00215 m_lockFree(false), 00216 m_weightRaveUpdates(true), 00217 m_pruneFullTree(true), 00218 m_numberThreads(1), 00219 m_numberPlayouts(1), 00220 m_maxNodes(4000000), 00221 m_pruneMinCount(16), 00222 m_moveRange(moveRange), 00223 m_maxGameLength(numeric_limits<size_t>::max()), 00224 m_expandThreshold(1), 00225 m_biasTermConstant(0.7f), 00226 m_firstPlayUrgency(10000), 00227 m_raveWeightInitial(0.9f), 00228 m_raveWeightFinal(20000), 00229 m_virtualLoss(false), 00230 m_logFileName("uctsearch.log"), 00231 m_fastLog(10), 00232 m_mpiSynchronizer(SgMpiNullSynchronizer::Create()) 00233 { 00234 // Don't create thread states here, because the factory passes the search 00235 // (which is not fully constructed here, because the subclass constructors 00236 // are not called yet) as an argument to the Create() function 00237 } 00238 00239 SgUctSearch::~SgUctSearch() 00240 { 00241 DeleteThreads(); 00242 } 00243 00244 void SgUctSearch::ApplyRootFilter(vector<SgMoveInfo>& moves) 00245 { 00246 // Filter without changing the order of the unfiltered moves 00247 vector<SgMoveInfo> filteredMoves; 00248 for (vector<SgMoveInfo>::const_iterator it = moves.begin(); 00249 it != moves.end(); ++it) 00250 if (find(m_rootFilter.begin(), m_rootFilter.end(), it->m_move) 00251 == m_rootFilter.end()) 00252 filteredMoves.push_back(*it); 00253 moves = filteredMoves; 00254 } 00255 00256 size_t SgUctSearch::GamesPlayed() const 00257 { 00258 return m_tree.Root().MoveCount() - m_startRootMoveCount; 00259 } 00260 00261 bool SgUctSearch::CheckAbortSearch(SgUctThreadState& state) 00262 { 00263 size_t gamesPlayed = GamesPlayed(); 00264 if (SgUserAbort()) 00265 { 00266 Debug(state, "SgUctSearch: abort flag"); 00267 return true; 00268 } 00269 const bool isEarlyAbort = CheckEarlyAbort(); 00270 if (gamesPlayed >= m_maxGames) 00271 { 00272 Debug(state, "SgUctSearch: max games reached"); 00273 return true; 00274 } 00275 if ( isEarlyAbort 00276 && m_earlyAbort->m_reductionFactor * gamesPlayed >= m_maxGames 00277 ) 00278 { 00279 Debug(state, "SgUctSearch: max games reached (early abort)"); 00280 m_wasEarlyAbort = true; 00281 return true; 00282 } 00283 if (m_numberGames % m_checkTimeInterval == 0) 00284 { 00285 double time = m_timer.GetTime(); 00286 if (time > m_maxTime) 00287 { 00288 Debug(state, "SgUctSearch: max time reached"); 00289 return true; 00290 } 00291 if (isEarlyAbort 00292 && m_earlyAbort->m_reductionFactor * time > m_maxTime) 00293 { 00294 Debug(state, "SgUctSearch: max time reached (early abort)"); 00295 m_wasEarlyAbort = true; 00296 return true; 00297 } 00298 UpdateCheckTimeInterval(time); 00299 if (m_moveSelect == SG_UCTMOVESELECT_COUNT) 00300 { 00301 double remainingGamesDouble = m_maxGames - gamesPlayed - 1; 00302 // Use time based count abort, only if time > 1, otherwise 00303 // m_gamesPerSecond is unreliable 00304 if (time > 1.) 00305 { 00306 double remainingTime = m_maxTime - time; 00307 remainingGamesDouble = 00308 min(remainingGamesDouble, 00309 remainingTime * m_statistics.m_gamesPerSecond); 00310 } 00311 size_t sizeTypeMax = numeric_limits<size_t>::max(); 00312 size_t remainingGames; 00313 if (remainingGamesDouble > static_cast<double>(sizeTypeMax - 1)) 00314 remainingGames = sizeTypeMax; 00315 else 00316 remainingGames = static_cast<size_t>(remainingGamesDouble); 00317 if (CheckCountAbort(state, remainingGames)) 00318 { 00319 Debug(state, "SgUctSearch: move cannot change anymore"); 00320 return true; 00321 } 00322 } 00323 } 00324 return false; 00325 } 00326 00327 bool SgUctSearch::CheckCountAbort(SgUctThreadState& state, 00328 std::size_t remainingGames) const 00329 { 00330 const SgUctNode& root = m_tree.Root(); 00331 const SgUctNode* bestChild = FindBestChild(root); 00332 if (bestChild == 0) 00333 return false; 00334 size_t bestCount = bestChild->MoveCount(); 00335 vector<SgMove>& excludeMoves = state.m_excludeMoves; 00336 excludeMoves.clear(); 00337 excludeMoves.push_back(bestChild->Move()); 00338 const SgUctNode* secondBestChild = FindBestChild(root, &excludeMoves); 00339 if (secondBestChild == 0) 00340 return false; 00341 std::size_t secondBestCount = secondBestChild->MoveCount(); 00342 SG_ASSERT(secondBestCount <= bestCount || m_numberThreads > 1); 00343 return (secondBestCount + remainingGames <= bestCount); 00344 } 00345 00346 bool SgUctSearch::CheckEarlyAbort() const 00347 { 00348 const SgUctNode& root = m_tree.Root(); 00349 return m_earlyAbort.get() != 0 00350 && root.HasMean() 00351 && root.MoveCount() > m_earlyAbort->m_minGames 00352 && root.Mean() > m_earlyAbort->m_threshold; 00353 } 00354 00355 void SgUctSearch::CreateThreads() 00356 { 00357 DeleteThreads(); 00358 for (size_t i = 0; i < m_numberThreads; ++i) 00359 { 00360 auto_ptr<SgUctThreadState> state( 00361 m_threadStateFactory->Create(i, *this)); 00362 shared_ptr<Thread> thread(new Thread(*this, state)); 00363 m_threads.push_back(thread); 00364 } 00365 m_tree.CreateAllocators(m_numberThreads); 00366 m_tree.SetMaxNodes(m_maxNodes); 00367 00368 m_searchLoopFinished.reset(new barrier(m_numberThreads)); 00369 } 00370 00371 /** Write a debugging line of text from within a thread. 00372 Prepends the line with the thread number if number of threads is greater 00373 than one. Also ensures that the line is written as a single string to 00374 avoid intermingling of text lines from different threads. 00375 @param state The state of the thread (only used for state.m_threadId) 00376 @param textLine The line of text without trailing newline character. 00377 */ 00378 void SgUctSearch::Debug(const SgUctThreadState& state, 00379 const std::string& textLine) 00380 { 00381 if (m_numberThreads > 1) 00382 { 00383 // SgDebug() is not necessarily thread-safe 00384 GlobalLock lock(m_globalMutex); 00385 SgDebug() << (format("[%1%] %2%\n") % state.m_threadId % textLine); 00386 } 00387 else 00388 SgDebug() << (format("%1%\n") % textLine); 00389 } 00390 00391 void SgUctSearch::DeleteThreads() 00392 { 00393 m_threads.clear(); 00394 } 00395 00396 /** Expand a node. 00397 @param state The thread state with state.m_moves already computed. 00398 @param node The node to expand. 00399 */ 00400 void SgUctSearch::ExpandNode(SgUctThreadState& state, const SgUctNode& node) 00401 { 00402 size_t threadId = state.m_threadId; 00403 if (! m_tree.HasCapacity(threadId, state.m_moves.size())) 00404 { 00405 Debug(state, str(format("SgUctSearch: maximum tree size %1% reached") 00406 % m_tree.MaxNodes())); 00407 state.m_isTreeOutOfMem = true; 00408 m_isTreeOutOfMemory = true; 00409 SgSynchronizeThreadMemory(); 00410 return; 00411 } 00412 m_tree.CreateChildren(threadId, node, state.m_moves); 00413 } 00414 00415 const SgUctNode* 00416 SgUctSearch::FindBestChild(const SgUctNode& node, 00417 const vector<SgMove>* excludeMoves) const 00418 { 00419 if (! node.HasChildren()) 00420 return 0; 00421 const SgUctNode* bestChild = 0; 00422 float bestValue = 0; 00423 for (SgUctChildIterator it(m_tree, node); it; ++it) 00424 { 00425 const SgUctNode& child = *it; 00426 if (excludeMoves != 0) 00427 { 00428 vector<SgMove>::const_iterator begin = excludeMoves->begin(); 00429 vector<SgMove>::const_iterator end = excludeMoves->end(); 00430 if (find(begin, end, child.Move()) != end) 00431 continue; 00432 } 00433 if ( ! child.HasMean() 00434 && ! ( ( m_moveSelect == SG_UCTMOVESELECT_BOUND 00435 || m_moveSelect == SG_UCTMOVESELECT_ESTIMATE 00436 ) 00437 && m_rave 00438 && child.HasRaveValue() 00439 ) 00440 ) 00441 continue; 00442 float moveValue = InverseEval(child.Mean()); 00443 size_t moveCount = child.MoveCount(); 00444 float value; 00445 switch (m_moveSelect) 00446 { 00447 case SG_UCTMOVESELECT_VALUE: 00448 value = moveValue; 00449 break; 00450 case SG_UCTMOVESELECT_COUNT: 00451 value = moveCount; 00452 break; 00453 case SG_UCTMOVESELECT_BOUND: 00454 value = GetBound(m_rave, node, child); 00455 break; 00456 case SG_UCTMOVESELECT_ESTIMATE: 00457 value = GetValueEstimate(m_rave, child); 00458 break; 00459 default: 00460 SG_ASSERT(false); 00461 value = SG_UCTMOVESELECT_VALUE; 00462 } 00463 if (bestChild == 0 || value > bestValue) 00464 { 00465 bestChild = &child; 00466 bestValue = value; 00467 } 00468 } 00469 return bestChild; 00470 } 00471 00472 void SgUctSearch::FindBestSequence(vector<SgMove>& sequence) const 00473 { 00474 sequence.clear(); 00475 const SgUctNode* current = &m_tree.Root(); 00476 while (true) 00477 { 00478 current = FindBestChild(*current); 00479 if (current == 0) 00480 break; 00481 sequence.push_back(current->Move()); 00482 if (! current->HasChildren()) 00483 break; 00484 } 00485 } 00486 00487 void SgUctSearch::GenerateAllMoves(std::vector<SgMoveInfo>& moves) 00488 { 00489 if (m_threads.size() == 0) 00490 CreateThreads(); 00491 moves.clear(); 00492 OnStartSearch(); 00493 SgUctThreadState& state = ThreadState(0); 00494 state.StartSearch(); 00495 SgProvenNodeType type; 00496 state.GenerateAllMoves(0, moves, type); 00497 } 00498 00499 float SgUctSearch::GetBound(bool useRave, const SgUctNode& node, 00500 const SgUctNode& child) const 00501 { 00502 size_t posCount = node.PosCount(); 00503 return GetBound(useRave, Log(posCount), child); 00504 } 00505 00506 float SgUctSearch::GetBound(bool useRave, float logPosCount, 00507 const SgUctNode& child) const 00508 { 00509 float value; 00510 if (useRave) 00511 value = GetValueEstimateRave(child); 00512 else 00513 value = GetValueEstimate(false, child); 00514 if (m_biasTermConstant == 0.0) 00515 return value; 00516 else 00517 { 00518 float moveCount = static_cast<float>(child.MoveCount()); 00519 float bound = 00520 value + m_biasTermConstant * sqrt(logPosCount / (moveCount + 1)); 00521 return bound; 00522 } 00523 } 00524 00525 SgUctTree& SgUctSearch::GetTempTree() 00526 { 00527 m_tempTree.Clear(); 00528 // Use NumberThreads() (not m_tree.NuAllocators()) and MaxNodes() (not 00529 // m_tree.MaxNodes()), because of the delayed thread (and thereby 00530 // allocator) creation in SgUctSearch 00531 if (m_tempTree.NuAllocators() != NumberThreads()) 00532 { 00533 m_tempTree.CreateAllocators(NumberThreads()); 00534 m_tempTree.SetMaxNodes(MaxNodes()); 00535 } 00536 else if (m_tempTree.MaxNodes() != MaxNodes()) 00537 { 00538 m_tempTree.SetMaxNodes(MaxNodes()); 00539 } 00540 return m_tempTree; 00541 } 00542 00543 float SgUctSearch::GetValueEstimate(bool useRave, const SgUctNode& child) const 00544 { 00545 float value = 0.f; 00546 float weightSum = 0.f; 00547 bool hasValue = false; 00548 if (child.HasMean()) 00549 { 00550 float weight = static_cast<float>(child.MoveCount()); 00551 value += weight * InverseEval(child.Mean()); 00552 weightSum += weight; 00553 hasValue = true; 00554 } 00555 if (useRave && child.HasRaveValue()) 00556 { 00557 float raveCount = child.RaveCount(); 00558 float weight = 00559 raveCount 00560 / ( m_raveWeightParam1 00561 + m_raveWeightParam2 * raveCount 00562 ); 00563 value += weight * child.RaveValue(); 00564 weightSum += weight; 00565 hasValue = true; 00566 } 00567 if (hasValue) 00568 return value / weightSum; 00569 else 00570 return m_firstPlayUrgency; 00571 } 00572 00573 /** Optimized version of GetValueEstimate() if RAVE and not other 00574 estimators are used. 00575 Previously there were more estimators than move value and RAVE value, 00576 and in the future there may be again. GetValueEstimate() is easier to 00577 extend, this function is more optimized for the special case. 00578 */ 00579 float SgUctSearch::GetValueEstimateRave(const SgUctNode& child) const 00580 { 00581 SG_ASSERT(m_rave); 00582 bool hasRave = child.HasRaveValue(); 00583 float value; 00584 if (child.HasMean()) 00585 { 00586 float moveValue = InverseEval(child.Mean()); 00587 if (hasRave) 00588 { 00589 float moveCount = child.MoveCount(); 00590 float raveCount = child.RaveCount(); 00591 float weight = 00592 raveCount 00593 / (moveCount 00594 * (m_raveWeightParam1 + m_raveWeightParam2 * raveCount) 00595 + raveCount); 00596 value = weight * child.RaveValue() + (1.f - weight) * moveValue; 00597 } 00598 else 00599 { 00600 // This can happen only in lock-free multi-threading. Normally, 00601 // each move played in a position should also cause a RAVE value 00602 // to be added. But in lock-free multi-threading it can happen 00603 // that the move value was already updated but the RAVE value not 00604 SG_ASSERT(m_numberThreads > 1 && m_lockFree); 00605 value = moveValue; 00606 } 00607 } 00608 else if (hasRave) 00609 value = child.RaveValue(); 00610 else 00611 value = m_firstPlayUrgency; 00612 SG_ASSERT(m_numberThreads > 1 00613 || fabs(value - GetValueEstimate(m_rave, child)) < 1e-3/*epsilon*/); 00614 return value; 00615 } 00616 00617 string SgUctSearch::LastGameSummaryLine() const 00618 { 00619 return SummaryLine(LastGameInfo()); 00620 } 00621 00622 float SgUctSearch::Log(float x) const 00623 { 00624 // TODO: can we speed up the computation of the logarithm by taking 00625 // advantage of the fact that the argument is an integer type? 00626 // Maybe round result to integer (then it is simple the position of the 00627 // highest bit 00628 #if SG_UCTFASTLOG 00629 return m_fastLog.Log(x); 00630 #else 00631 return log(x); 00632 #endif 00633 } 00634 00635 /** Creates the children with the given moves and merges with existing 00636 children in the tree. */ 00637 void SgUctSearch::CreateChildren(SgUctThreadState& state, 00638 const SgUctNode& node, 00639 bool deleteChildTrees) 00640 { 00641 size_t threadId = state.m_threadId; 00642 if (! m_tree.HasCapacity(threadId, state.m_moves.size())) 00643 { 00644 Debug(state, str(format("SgUctSearch: maximum tree size %1% reached") 00645 % m_tree.MaxNodes())); 00646 state.m_isTreeOutOfMem = true; 00647 m_isTreeOutOfMemory = true; 00648 SgSynchronizeThreadMemory(); 00649 return; 00650 } 00651 m_tree.MergeChildren(threadId, node, state.m_moves, deleteChildTrees); 00652 } 00653 00654 bool SgUctSearch::NeedToComputeKnowledge(const SgUctNode* current) 00655 { 00656 if (m_knowledgeThreshold.empty()) 00657 return false; 00658 for (std::size_t i = 0; i < m_knowledgeThreshold.size(); ++i) 00659 { 00660 const std::size_t threshold = m_knowledgeThreshold[i]; 00661 if (current->KnowledgeCount() < threshold) 00662 { 00663 if (current->MoveCount() >= threshold) 00664 { 00665 // Mark knowledge computed immediately so other 00666 // threads fall through and do not waste time 00667 // re-computing this knowledge. 00668 SgUctNode* node = const_cast<SgUctNode*>(current); 00669 node->SetKnowledgeCount(threshold); 00670 SG_ASSERT(current->MoveCount()); 00671 return true; 00672 } 00673 return false; 00674 } 00675 } 00676 return false; 00677 } 00678 00679 void SgUctSearch::OnStartSearch() 00680 { 00681 m_mpiSynchronizer->OnStartSearch(*this); 00682 } 00683 00684 void SgUctSearch::OnEndSearch() 00685 { 00686 m_mpiSynchronizer->OnEndSearch(*this); 00687 } 00688 00689 /** Print time, mean, nodes searched, and PV */ 00690 void SgUctSearch::PrintSearchProgress(double currTime) const 00691 { 00692 const int MAX_SEQ_PRINT_LENGTH = 15; 00693 const size_t MIN_MOVE_COUNT = 10; 00694 size_t rootMoveCount = m_tree.Root().MoveCount(); 00695 float rootMean = m_tree.Root().Mean(); 00696 ostringstream out; 00697 const SgUctNode* current = &m_tree.Root(); 00698 out << fixed << setprecision(0) 00699 << SgTime::Format(currTime, true) << " | " 00700 << fixed << setprecision(3) << rootMean << " " 00701 << "| " << rootMoveCount << " "; 00702 for(int i = 0; i <= MAX_SEQ_PRINT_LENGTH && current->HasChildren(); ++i) 00703 { 00704 current = FindBestChild(*current); 00705 if (current == 0 || current->MoveCount() < MIN_MOVE_COUNT) 00706 break; 00707 if (i == 0) 00708 out << "|"; 00709 if (i < MAX_SEQ_PRINT_LENGTH) 00710 out << " " << SgWritePoint(current->Move()); 00711 else 00712 out << " *"; 00713 } 00714 SgDebug() << out.str() << endl; 00715 } 00716 00717 void SgUctSearch::OnSearchIteration(std::size_t gameNumber, int threadId, 00718 const SgUctGameInfo& info) 00719 { 00720 const int DISPLAY_INTERVAL = 5; 00721 00722 m_mpiSynchronizer->OnSearchIteration(*this, gameNumber, threadId, info); 00723 double currTime = m_timer.GetTime(); 00724 00725 if (threadId == 0 && currTime - m_lastScoreDisplayTime > DISPLAY_INTERVAL) 00726 { 00727 PrintSearchProgress(currTime); 00728 m_lastScoreDisplayTime = currTime; 00729 } 00730 } 00731 00732 void SgUctSearch::PlayGame(SgUctThreadState& state, GlobalLock* lock) 00733 { 00734 state.m_isTreeOutOfMem = false; 00735 state.GameStart(); 00736 SgUctGameInfo& info = state.m_gameInfo; 00737 info.Clear(m_numberPlayouts); 00738 bool isTerminal; 00739 bool abortInTree = ! PlayInTree(state, isTerminal); 00740 00741 // add a virtual loss to all nodes in path 00742 if (m_virtualLoss) 00743 m_tree.AddVirtualLoss(info.m_nodes); 00744 00745 // The playout phase is always unlocked 00746 if (lock != 0) 00747 lock->unlock(); 00748 00749 size_t nuMovesInTree = info.m_inTreeSequence.size(); 00750 00751 // Play some "fake" playouts if node is a proven node 00752 if (! info.m_nodes.empty() && info.m_nodes.back()->IsProven()) 00753 { 00754 for (size_t i = 0; i < m_numberPlayouts; ++i) 00755 { 00756 info.m_sequence[i] = info.m_inTreeSequence; 00757 info.m_skipRaveUpdate[i].assign(nuMovesInTree, false); 00758 float eval = info.m_nodes.back()->IsProvenWin() ? 1.0 : 0.0; 00759 size_t nuMoves = info.m_sequence[i].size(); 00760 if (nuMoves % 2 != 0) 00761 eval = InverseEval(eval); 00762 info.m_aborted[i] = abortInTree || state.m_isTreeOutOfMem; 00763 info.m_eval[i] = eval; 00764 } 00765 } 00766 else 00767 { 00768 state.StartPlayouts(); 00769 for (size_t i = 0; i < m_numberPlayouts; ++i) 00770 { 00771 state.StartPlayout(); 00772 info.m_sequence[i] = info.m_inTreeSequence; 00773 // skipRaveUpdate only used in playout phase 00774 info.m_skipRaveUpdate[i].assign(nuMovesInTree, false); 00775 bool abort = abortInTree || state.m_isTreeOutOfMem; 00776 if (! abort && ! isTerminal) 00777 abort = ! PlayoutGame(state, i); 00778 float eval; 00779 if (abort) 00780 eval = UnknownEval(); 00781 else 00782 eval = state.Evaluate(); 00783 size_t nuMoves = info.m_sequence[i].size(); 00784 if (nuMoves % 2 != 0) 00785 eval = InverseEval(eval); 00786 info.m_aborted[i] = abort; 00787 info.m_eval[i] = eval; 00788 state.EndPlayout(); 00789 state.TakeBackPlayout(nuMoves - nuMovesInTree); 00790 } 00791 } 00792 state.TakeBackInTree(nuMovesInTree); 00793 00794 // End of unlocked part if ! m_lockFree 00795 if (lock != 0) 00796 lock->lock(); 00797 00798 // Remove the virtual loss 00799 if (m_virtualLoss) 00800 m_tree.RemoveVirtualLoss(info.m_nodes); 00801 00802 UpdateTree(info); 00803 if (m_rave) 00804 UpdateRaveValues(state); 00805 UpdateStatistics(info); 00806 } 00807 00808 /** Play game until it leaves the tree. 00809 @param state 00810 @param[out] isTerminal Was the sequence terminated because of a real 00811 terminal position (GenerateAllMoves() returned an empty list)? 00812 @return @c false, if game was aborted due to maximum length 00813 */ 00814 bool SgUctSearch::PlayInTree(SgUctThreadState& state, bool& isTerminal) 00815 { 00816 vector<SgMove>& sequence = state.m_gameInfo.m_inTreeSequence; 00817 vector<const SgUctNode*>& nodes = state.m_gameInfo.m_nodes; 00818 const SgUctNode* root = &m_tree.Root(); 00819 const SgUctNode* current = root; 00820 nodes.push_back(current); 00821 bool breakAfterSelect = false; 00822 isTerminal = false; 00823 bool deepenTree = false; 00824 while (true) 00825 { 00826 if (sequence.size() == m_maxGameLength) 00827 return false; 00828 if (current->IsProven()) 00829 break; 00830 if (! current->HasChildren()) 00831 { 00832 state.m_moves.clear(); 00833 SgProvenNodeType provenType = SG_NOT_PROVEN; 00834 state.GenerateAllMoves(0, state.m_moves, provenType); 00835 if (current == root) 00836 ApplyRootFilter(state.m_moves); 00837 if (provenType != SG_NOT_PROVEN) 00838 { 00839 SgUctNode* node = const_cast<SgUctNode*>(current); 00840 node->SetProvenNodeType(provenType); 00841 break; 00842 } 00843 if (state.m_moves.empty()) 00844 { 00845 isTerminal = true; 00846 break; 00847 } 00848 if ( deepenTree 00849 || current->MoveCount() >= m_expandThreshold 00850 ) 00851 { 00852 deepenTree = false; 00853 ExpandNode(state, *current); 00854 if (state.m_isTreeOutOfMem) 00855 return true; 00856 if (! deepenTree) 00857 breakAfterSelect = true; 00858 } 00859 else 00860 break; 00861 } 00862 else if (NeedToComputeKnowledge(current)) 00863 { 00864 m_statistics.m_knowledge++; 00865 deepenTree = false; 00866 SgProvenNodeType provenType = SG_NOT_PROVEN; 00867 bool truncate = state.GenerateAllMoves(current->KnowledgeCount(), 00868 state.m_moves, 00869 provenType); 00870 if (current == root) 00871 ApplyRootFilter(state.m_moves); 00872 CreateChildren(state, *current, truncate); 00873 if (provenType != SG_NOT_PROVEN) 00874 { 00875 SgUctNode* node = const_cast<SgUctNode*>(current); 00876 node->SetProvenNodeType(provenType); 00877 break; 00878 } 00879 if (state.m_moves.empty()) 00880 { 00881 isTerminal = true; 00882 break; 00883 } 00884 if (state.m_isTreeOutOfMem) 00885 return true; 00886 if (! deepenTree) 00887 breakAfterSelect = true; 00888 } 00889 current = &SelectChild(state.m_randomizeCounter, *current); 00890 nodes.push_back(current); 00891 SgMove move = current->Move(); 00892 state.Execute(move); 00893 sequence.push_back(move); 00894 if (breakAfterSelect) 00895 break; 00896 } 00897 return true; 00898 } 00899 00900 /** Finish the game using GeneratePlayoutMove(). 00901 @param state The thread state. 00902 @param playout The number of the playout. 00903 @return @c false if game was aborted 00904 */ 00905 bool SgUctSearch::PlayoutGame(SgUctThreadState& state, std::size_t playout) 00906 { 00907 SgUctGameInfo& info = state.m_gameInfo; 00908 vector<SgMove>& sequence = info.m_sequence[playout]; 00909 vector<bool>& skipRaveUpdate = info.m_skipRaveUpdate[playout]; 00910 while (true) 00911 { 00912 if (sequence.size() == m_maxGameLength) 00913 return false; 00914 bool skipRave = false; 00915 SgMove move = state.GeneratePlayoutMove(skipRave); 00916 if (move == SG_NULLMOVE) 00917 break; 00918 state.ExecutePlayout(move); 00919 sequence.push_back(move); 00920 skipRaveUpdate.push_back(skipRave); 00921 } 00922 return true; 00923 } 00924 00925 float SgUctSearch::Search(std::size_t maxGames, double maxTime, 00926 vector<SgMove>& sequence, 00927 const vector<SgMove>& rootFilter, 00928 SgUctTree* initTree, 00929 SgUctEarlyAbortParam* earlyAbort) 00930 { 00931 m_timer.Start(); 00932 m_rootFilter = rootFilter; 00933 if (m_logGames) 00934 { 00935 m_log.open(m_mpiSynchronizer->ToNodeFilename(m_logFileName).c_str()); 00936 m_log << "StartSearch maxGames=" << maxGames << '\n'; 00937 } 00938 m_maxGames = maxGames; 00939 m_maxTime = maxTime; 00940 m_earlyAbort.reset(0); 00941 if (earlyAbort != 0) 00942 m_earlyAbort.reset(new SgUctEarlyAbortParam(*earlyAbort)); 00943 00944 for (size_t i = 0; i < m_threads.size(); ++i) 00945 { 00946 m_threads[i]->m_state->m_isSearchInitialized = false; 00947 } 00948 StartSearch(rootFilter, initTree); 00949 size_t pruneMinCount = m_pruneMinCount; 00950 while (true) 00951 { 00952 m_isTreeOutOfMemory = false; 00953 SgSynchronizeThreadMemory(); 00954 for (size_t i = 0; i < m_threads.size(); ++i) 00955 m_threads[i]->StartPlay(); 00956 for (size_t i = 0; i < m_threads.size(); ++i) 00957 m_threads[i]->WaitPlayFinished(); 00958 if (m_aborted || ! m_pruneFullTree) 00959 break; 00960 else 00961 { 00962 double startPruneTime = m_timer.GetTime(); 00963 SgDebug() << "SgUctSearch: pruning nodes with count < " 00964 << pruneMinCount << " (at time " << fixed << setprecision(1) 00965 << startPruneTime << ")\n"; 00966 SgUctTree& tempTree = GetTempTree(); 00967 m_tree.CopyPruneLowCount(tempTree, pruneMinCount, true); 00968 int prunedSizePercentage = 00969 static_cast<int>(tempTree.NuNodes() * 100 / m_tree.NuNodes()); 00970 SgDebug() << "SgUctSearch: pruned size: " << tempTree.NuNodes() 00971 << " (" << prunedSizePercentage << "%) time: " 00972 << (m_timer.GetTime() - startPruneTime) << "\n"; 00973 if (prunedSizePercentage > 50) 00974 pruneMinCount *= 2; 00975 else 00976 pruneMinCount = m_pruneMinCount; 00977 m_tree.Swap(tempTree); 00978 } 00979 } 00980 EndSearch(); 00981 m_statistics.m_time = m_timer.GetTime(); 00982 if (m_statistics.m_time > numeric_limits<double>::epsilon()) 00983 m_statistics.m_gamesPerSecond = GamesPlayed() / m_statistics.m_time; 00984 if (m_logGames) 00985 m_log.close(); 00986 FindBestSequence(sequence); 00987 return (m_tree.Root().MoveCount() > 0) ? m_tree.Root().Mean() : 0.5; 00988 } 00989 00990 /** Loop invoked by each thread for playing games. */ 00991 void SgUctSearch::SearchLoop(SgUctThreadState& state, GlobalLock* lock) 00992 { 00993 if (! state.m_isSearchInitialized) 00994 { 00995 OnThreadStartSearch(state); 00996 state.m_isSearchInitialized = true; 00997 } 00998 00999 if (NumberThreads() == 1 || m_lockFree) 01000 lock = 0; 01001 if (lock != 0) 01002 lock->lock(); 01003 state.m_isTreeOutOfMem = false; 01004 while (! state.m_isTreeOutOfMem) 01005 { 01006 PlayGame(state, lock); 01007 OnSearchIteration(m_numberGames + 1, state.m_threadId, 01008 state.m_gameInfo); 01009 if (m_logGames) 01010 m_log << SummaryLine(state.m_gameInfo) << '\n'; 01011 ++m_numberGames; 01012 if (m_isTreeOutOfMemory) 01013 break; 01014 if (m_aborted || CheckAbortSearch(state)) 01015 { 01016 m_aborted = true; 01017 SgSynchronizeThreadMemory(); 01018 break; 01019 } 01020 } 01021 if (lock != 0) 01022 lock->unlock(); 01023 01024 m_searchLoopFinished->wait(); 01025 if (m_aborted || ! m_pruneFullTree) 01026 OnThreadEndSearch(state); 01027 } 01028 01029 void SgUctSearch::OnThreadStartSearch(SgUctThreadState& state) 01030 { 01031 m_mpiSynchronizer->OnThreadStartSearch(*this, state); 01032 } 01033 01034 void SgUctSearch::OnThreadEndSearch(SgUctThreadState& state) 01035 { 01036 m_mpiSynchronizer->OnThreadEndSearch(*this, state); 01037 } 01038 01039 SgPoint SgUctSearch::SearchOnePly(size_t maxGames, double maxTime, 01040 float& value) 01041 { 01042 if (m_threads.size() == 0) 01043 CreateThreads(); 01044 OnStartSearch(); 01045 // SearchOnePly is not multi-threaded. 01046 // It uses the state of the first thread. 01047 SgUctThreadState& state = ThreadState(0); 01048 state.StartSearch(); 01049 vector<SgMoveInfo> moves; 01050 SgProvenNodeType provenType; 01051 state.GameStart(); 01052 state.GenerateAllMoves(0, moves, provenType); 01053 vector<SgUctStatisticsBase> statistics(moves.size()); 01054 size_t games = 0; 01055 m_timer.Start(); 01056 SgUctGameInfo& info = state.m_gameInfo; 01057 while (games < maxGames && m_timer.GetTime() < maxTime && ! SgUserAbort()) 01058 { 01059 for (size_t i = 0; i < moves.size(); ++i) 01060 { 01061 state.GameStart(); 01062 info.Clear(1); 01063 SgMove move = moves[i].m_move; 01064 state.Execute(move); 01065 info.m_inTreeSequence.push_back(move); 01066 info.m_sequence[0].push_back(move); 01067 info.m_skipRaveUpdate[0].push_back(false); 01068 state.StartPlayouts(); 01069 state.StartPlayout(); 01070 bool abortGame = ! PlayoutGame(state, 0); 01071 float eval; 01072 if (abortGame) 01073 eval = UnknownEval(); 01074 else 01075 eval = state.Evaluate(); 01076 state.EndPlayout(); 01077 state.TakeBackPlayout(info.m_sequence[0].size() - 1); 01078 state.TakeBackInTree(1); 01079 statistics[i].Add(info.m_sequence[0].size() % 2 == 0 ? 01080 eval : InverseEval(eval)); 01081 OnSearchIteration(games + 1, 0, info); 01082 ++games; 01083 } 01084 } 01085 SgMove bestMove = SG_NULLMOVE; 01086 for (size_t i = 0; i < moves.size(); ++i) 01087 { 01088 SgDebug() << SgWritePoint(moves[i].m_move) 01089 << ' ' << statistics[i].Mean() << '\n'; 01090 if (bestMove == SG_NULLMOVE || statistics[i].Mean() > value) 01091 { 01092 bestMove = moves[i].m_move; 01093 value = statistics[i].Mean(); 01094 } 01095 } 01096 return bestMove; 01097 } 01098 01099 const SgUctNode& SgUctSearch::SelectChild(int& randomizeCounter, 01100 const SgUctNode& node) 01101 { 01102 bool useRave = m_rave; 01103 if (m_randomizeRaveFrequency > 0) 01104 { 01105 ++randomizeCounter; 01106 if (randomizeCounter % m_randomizeRaveFrequency == 0) 01107 useRave = false; 01108 } 01109 SG_ASSERT(node.HasChildren()); 01110 size_t posCount = node.PosCount(); 01111 if (posCount == 0) 01112 // If position count is zero, return first child 01113 return *SgUctChildIterator(m_tree, node); 01114 float logPosCount = Log(posCount); 01115 const SgUctNode* bestChild = 0; 01116 float bestUpperBound = 0; 01117 for (SgUctChildIterator it(m_tree, node); it; ++it) 01118 { 01119 const SgUctNode& child = *it; 01120 float bound = GetBound(useRave, logPosCount, child); 01121 if (bestChild == 0 || bound > bestUpperBound) 01122 { 01123 bestChild = &child; 01124 bestUpperBound = bound; 01125 } 01126 } 01127 SG_ASSERT(bestChild != 0); 01128 return *bestChild; 01129 } 01130 01131 void SgUctSearch::SetNumberThreads(std::size_t n) 01132 { 01133 SG_ASSERT(n >= 1); 01134 if (m_numberThreads == n) 01135 return; 01136 m_numberThreads = n; 01137 CreateThreads(); 01138 } 01139 01140 void SgUctSearch::SetRave(bool enable) 01141 { 01142 if (enable && m_moveRange <= 0) 01143 throw SgException("RAVE not supported for this game"); 01144 m_rave = enable; 01145 } 01146 01147 void SgUctSearch::SetThreadStateFactory(SgUctThreadStateFactory* factory) 01148 { 01149 SG_ASSERT(m_threadStateFactory.get() == 0); 01150 m_threadStateFactory.reset(factory); 01151 DeleteThreads(); 01152 // Don't create states here, because this function could be called in the 01153 // constructor of the subclass, and the factory passes the search (which 01154 // is not fully constructed) as an argument to the Create() function 01155 } 01156 01157 void SgUctSearch::StartSearch(const vector<SgMove>& rootFilter, 01158 SgUctTree* initTree) 01159 { 01160 if (m_threads.size() == 0) 01161 CreateThreads(); 01162 if (m_numberThreads > 1 && SgTime::DefaultMode() == SG_TIME_CPU) 01163 // Using CPU time with multiple threads makes the measured time 01164 // and games/sec not very meaningful; the total cputime is not equal 01165 // to the total real time, even if there is no other load on the 01166 // machine, because the time, while threads are waiting for a lock 01167 // does not contribute to the cputime. 01168 SgWarning() << "SgUctSearch: using cpu time with multiple threads\n"; 01169 m_raveWeightParam1 = 1.f / m_raveWeightInitial; 01170 m_raveWeightParam2 = 1.f / m_raveWeightFinal; 01171 if (initTree == 0) 01172 m_tree.Clear(); 01173 else 01174 { 01175 m_tree.Swap(*initTree); 01176 if (m_tree.HasCapacity(0, m_tree.Root().NuChildren())) 01177 m_tree.ApplyFilter(0, m_tree.Root(), rootFilter); 01178 else 01179 SgWarning() << 01180 "SgUctSearch: " 01181 "root filter not applied (tree reached maximum size)\n"; 01182 } 01183 m_statistics.Clear(); 01184 m_aborted = false; 01185 m_wasEarlyAbort = false; 01186 m_checkTimeInterval = 1; 01187 m_numberGames = 0; 01188 m_lastScoreDisplayTime = m_timer.GetTime(); 01189 OnStartSearch(); 01190 m_startRootMoveCount = m_tree.Root().MoveCount(); 01191 for (size_t i = 0; i < m_threads.size(); ++i) 01192 ThreadState(i).StartSearch(); 01193 } 01194 01195 void SgUctSearch::EndSearch() 01196 { 01197 OnEndSearch(); 01198 } 01199 01200 string SgUctSearch::SummaryLine(const SgUctGameInfo& info) const 01201 { 01202 ostringstream buffer; 01203 const vector<const SgUctNode*>& nodes = info.m_nodes; 01204 for (size_t i = 1; i < nodes.size(); ++i) 01205 { 01206 const SgUctNode* node = nodes[i]; 01207 SgMove move = node->Move(); 01208 buffer << ' ' << MoveString(move) << " (" << fixed << setprecision(2) 01209 << node->Mean() << ',' << node->MoveCount() << ')'; 01210 } 01211 for (size_t i = 0; i < info.m_eval.size(); ++i) 01212 buffer << ' ' << fixed << setprecision(2) << info.m_eval[i]; 01213 return buffer.str(); 01214 } 01215 01216 void SgUctSearch::UpdateCheckTimeInterval(double time) 01217 { 01218 if (time < numeric_limits<double>::epsilon()) 01219 return; 01220 // Dynamically update m_checkTimeInterval (see comment at definition of 01221 // m_checkTimeInterval) 01222 float wantedTimeDiff = (m_maxTime > 1 ? 0.1 : 0.1 * m_maxTime); 01223 if (time < wantedTimeDiff / 10) 01224 { 01225 // Computing games per second might be unreliable for small times 01226 m_checkTimeInterval *= 2; 01227 return; 01228 } 01229 m_statistics.m_gamesPerSecond = GamesPlayed() / time; 01230 double gamesPerSecondPerThread = 01231 m_statistics.m_gamesPerSecond / m_numberThreads; 01232 m_checkTimeInterval = 01233 static_cast<size_t>(wantedTimeDiff * gamesPerSecondPerThread); 01234 if (m_checkTimeInterval == 0) 01235 m_checkTimeInterval = 1; 01236 } 01237 01238 /** Update the RAVE values in the tree for both players after a game was 01239 played. 01240 @see SgUctSearch::Rave() 01241 */ 01242 void SgUctSearch::UpdateRaveValues(SgUctThreadState& state) 01243 { 01244 for (size_t i = 0; i < m_numberPlayouts; ++i) 01245 UpdateRaveValues(state, i); 01246 } 01247 01248 void SgUctSearch::UpdateRaveValues(SgUctThreadState& state, 01249 std::size_t playout) 01250 { 01251 SgUctGameInfo& info = state.m_gameInfo; 01252 const vector<SgMove>& sequence = info.m_sequence[playout]; 01253 if (sequence.size() == 0) 01254 return; 01255 SG_ASSERT(m_moveRange > 0); 01256 size_t* firstPlay = state.m_firstPlay.get(); 01257 size_t* firstPlayOpp = state.m_firstPlayOpp.get(); 01258 fill_n(firstPlay, m_moveRange, numeric_limits<size_t>::max()); 01259 fill_n(firstPlayOpp, m_moveRange, numeric_limits<size_t>::max()); 01260 const vector<const SgUctNode*>& nodes = info.m_nodes; 01261 const vector<bool>& skipRaveUpdate = info.m_skipRaveUpdate[playout]; 01262 float eval = info.m_eval[playout]; 01263 float invEval = InverseEval(eval); 01264 size_t nuNodes = nodes.size(); 01265 size_t i = sequence.size() - 1; 01266 bool opp = (i % 2 != 0); 01267 01268 // Update firstPlay, firstPlayOpp arrays using playout moves 01269 for ( ; i >= nuNodes; --i) 01270 { 01271 SG_ASSERT(i < skipRaveUpdate.size()); 01272 SG_ASSERT(i < sequence.size()); 01273 if (! skipRaveUpdate[i]) 01274 { 01275 SgMove mv = sequence[i]; 01276 size_t& first = (opp ? firstPlayOpp[mv] : firstPlay[mv]); 01277 if (i < first) 01278 first = i; 01279 } 01280 opp = ! opp; 01281 } 01282 01283 while (true) 01284 { 01285 SG_ASSERT(i < skipRaveUpdate.size()); 01286 SG_ASSERT(i < sequence.size()); 01287 // skipRaveUpdate currently not used in in-tree phase 01288 SG_ASSERT(i >= info.m_inTreeSequence.size() || ! skipRaveUpdate[i]); 01289 if (! skipRaveUpdate[i]) 01290 { 01291 SgMove mv = sequence[i]; 01292 size_t& first = (opp ? firstPlayOpp[mv] : firstPlay[mv]); 01293 if (i < first) 01294 first = i; 01295 if (opp) 01296 UpdateRaveValues(state, playout, invEval, i, 01297 firstPlayOpp, firstPlay); 01298 else 01299 UpdateRaveValues(state, playout, eval, i, 01300 firstPlay, firstPlayOpp); 01301 } 01302 if (i == 0) 01303 break; 01304 --i; 01305 opp = ! opp; 01306 } 01307 } 01308 01309 void SgUctSearch::UpdateRaveValues(SgUctThreadState& state, 01310 std::size_t playout, float eval, 01311 std::size_t i, 01312 const std::size_t firstPlay[], 01313 const std::size_t firstPlayOpp[]) 01314 { 01315 SG_ASSERT(i < state.m_gameInfo.m_nodes.size()); 01316 const SgUctNode* node = state.m_gameInfo.m_nodes[i]; 01317 if (! node->HasChildren()) 01318 return; 01319 std::size_t len = state.m_gameInfo.m_sequence[playout].size(); 01320 for (SgUctChildIterator it(m_tree, *node); it; ++it) 01321 { 01322 const SgUctNode& child = *it; 01323 SgMove mv = child.Move(); 01324 size_t first = firstPlay[mv]; 01325 SG_ASSERT(first >= i); 01326 if (first == numeric_limits<size_t>::max()) 01327 continue; 01328 if (m_raveCheckSame && SgUtil::InRange(firstPlayOpp[mv], i, first)) 01329 continue; 01330 float weight; 01331 if (m_weightRaveUpdates) 01332 weight = 2.f - static_cast<float>(first - i) / (len - i); 01333 else 01334 weight = 1.f; 01335 m_tree.AddRaveValue(child, eval, weight); 01336 } 01337 } 01338 01339 void SgUctSearch::UpdateStatistics(const SgUctGameInfo& info) 01340 { 01341 m_statistics.m_movesInTree.Add( 01342 static_cast<float>(info.m_inTreeSequence.size())); 01343 for (size_t i = 0; i < m_numberPlayouts; ++i) 01344 { 01345 m_statistics.m_gameLength.Add( 01346 static_cast<float>(info.m_sequence[i].size())); 01347 m_statistics.m_aborted.Add(info.m_aborted[i] ? 1.f : 0.f); 01348 } 01349 } 01350 01351 void SgUctSearch::UpdateTree(const SgUctGameInfo& info) 01352 { 01353 float eval = 0; 01354 for (size_t i = 0; i < m_numberPlayouts; ++i) 01355 eval += info.m_eval[i]; 01356 eval /= m_numberPlayouts; 01357 float inverseEval = InverseEval(eval); 01358 const vector<const SgUctNode*>& nodes = info.m_nodes; 01359 01360 for (size_t i = 0; i < nodes.size(); ++i) 01361 { 01362 const SgUctNode& node = *nodes[i]; 01363 const SgUctNode* father = (i > 0 ? nodes[i - 1] : 0); 01364 m_tree.AddGameResults(node, father, i % 2 == 0 ? eval : inverseEval, m_numberPlayouts); 01365 } 01366 } 01367 01368 void SgUctSearch::WriteStatistics(ostream& out) const 01369 { 01370 out << SgWriteLabel("Count") << m_tree.Root().MoveCount() << '\n' 01371 << SgWriteLabel("Nodes") << m_tree.NuNodes() << '\n'; 01372 if (!m_knowledgeThreshold.empty()) 01373 out << SgWriteLabel("Knowledge") 01374 << m_statistics.m_knowledge << " (" << fixed << setprecision(1) 01375 << m_statistics.m_knowledge * 100.0 / m_tree.Root().MoveCount() 01376 << "%)\n"; 01377 m_statistics.Write(out); 01378 m_mpiSynchronizer->WriteStatistics(out); 01379 } 01380 01381 //----------------------------------------------------------------------------