00001
00002
00003
00004
00005
00006 #include "SgSystem.h"
00007 #include "SgUctSearch.h"
00008
00009 #include <algorithm>
00010 #include <cmath>
00011 #include <iomanip>
00012 #include <boost/format.hpp>
00013 #include <boost/io/ios_state.hpp>
00014 #include <boost/version.hpp>
00015 #include "SgDebug.h"
00016 #include "SgHashTable.h"
00017 #include "SgMath.h"
00018 #include "SgWrite.h"
00019
00020 using namespace std;
00021 using boost::barrier;
00022 using boost::condition;
00023 using boost::format;
00024 using boost::mutex;
00025 using boost::shared_ptr;
00026 using boost::io::ios_all_saver;
00027
00028 #define BOOST_VERSION_MAJOR (BOOST_VERSION / 100000)
00029 #define BOOST_VERSION_MINOR (BOOST_VERSION / 100 % 1000)
00030
00031
00032
00033 namespace {
00034
00035 const bool DEBUG_THREADS = false;
00036
00037 void Notify(mutex& aMutex, condition& aCondition)
00038 {
00039 mutex::scoped_lock lock(aMutex);
00040 aCondition.notify_all();
00041 }
00042
00043 }
00044
00045
00046
00047 void SgUctGameInfo::Clear(std::size_t numberPlayouts)
00048 {
00049 m_nodes.clear();
00050 m_inTreeSequence.clear();
00051 if (numberPlayouts != m_sequence.size())
00052 {
00053 m_sequence.resize(numberPlayouts);
00054 m_skipRaveUpdate.resize(numberPlayouts);
00055 m_eval.resize(numberPlayouts);
00056 m_aborted.resize(numberPlayouts);
00057 }
00058 for (size_t i = 0; i < numberPlayouts; ++i)
00059 {
00060 m_sequence[i].clear();
00061 m_skipRaveUpdate[i].clear();
00062 }
00063 }
00064
00065
00066
00067 SgUctThreadState::SgUctThreadState(size_t threadId, int moveRange)
00068 : m_threadId(threadId),
00069 m_isSearchInitialized(false),
00070 m_isTreeOutOfMem(false),
00071 m_randomizeCounter(0)
00072 {
00073 if (moveRange > 0)
00074 {
00075 m_firstPlay.reset(new size_t[moveRange]);
00076 m_firstPlayOpp.reset(new size_t[moveRange]);
00077 }
00078 }
00079
// Destructor; no explicit cleanup is needed here.
SgUctThreadState::~SgUctThreadState()
{
}
00083
// Hook invoked after each playout (see SgUctSearch::PlayGame() and
// SgUctSearch::SearchOnePly()). Does nothing by default; presumably meant to
// be overridden by game-specific thread states.
void SgUctThreadState::EndPlayout()
{

}
00088
// Hook invoked at the start of every game (see SgUctSearch::PlayGame() and
// SgUctSearch::SearchOnePly()). Does nothing by default; presumably meant to
// be overridden by game-specific thread states.
void SgUctThreadState::GameStart()
{

}
00093
// Hook invoked before each individual playout. Does nothing by default;
// presumably meant to be overridden by game-specific thread states.
void SgUctThreadState::StartPlayout()
{

}
00098
// Hook invoked once before the playouts of a game are started (after the
// in-tree phase). Does nothing by default; presumably meant to be overridden
// by game-specific thread states.
void SgUctThreadState::StartPlayouts()
{

}
00103
00104
00105
// Out-of-line destructor of the thread-state factory interface.
SgUctThreadStateFactory::~SgUctThreadStateFactory()
{
}
00109
00110
00111
// Callable handed to the underlying thread object; it only keeps a reference
// to the owning Thread, which must outlive the running thread.
SgUctSearch::Thread::Function::Function(Thread& thread)
    : m_thread(thread)
{
}
00116
// Thread entry point: forwards to Thread::operator().
void SgUctSearch::Thread::Function::operator()()
{
    m_thread();
}
00121
// Create a search thread owning the given thread state.
// The worker thread is started by the m_thread initializer; it must come
// last (presumably it is also declared last in the header) because the new
// thread immediately uses the other members.
// m_threadReady is a two-party barrier: the constructor does not return
// until the worker has locked m_startPlayMutex (see operator()), so a later
// StartPlay() cannot be lost.
// The global lock is created unlocked (deferred); SearchLoop() locks it as
// needed. The Boost <= 1.34 branch uses the older bool-based API for this.
SgUctSearch::Thread::Thread(SgUctSearch& search,
                            auto_ptr<SgUctThreadState> state)
    : m_state(state),
      m_search(search),
      m_quit(false),
      m_threadReady(2),
      m_playFinishedLock(m_playFinishedMutex),
#if BOOST_VERSION_MAJOR == 1 && BOOST_VERSION_MINOR <= 34
      m_globalLock(search.m_globalMutex, false),
#else
      m_globalLock(search.m_globalMutex, boost::defer_lock),
#endif
      m_thread(Function(*this))
{
    m_threadReady.wait();
}
00138
// Stop the worker: set the quit flag, wake the worker so that it observes
// the flag (it checks m_quit right after waking), then wait for it to exit.
SgUctSearch::Thread::~Thread()
{
    m_quit = true;
    StartPlay();
    m_thread.join();
}
00145
// Body of the worker thread.
// The worker holds m_startPlayMutex for its whole lifetime except while it
// is blocked in m_startPlay.wait(). StartPlay() must acquire that mutex to
// notify, so a notification can only be delivered while the worker waits.
// Each wake-up either ends the thread (m_quit set by the destructor) or runs
// one SearchLoop() and then signals m_playFinished.
void SgUctSearch::Thread::operator()()
{
    if (DEBUG_THREADS)
        SgDebug() << "SgUctSearch::Thread: starting thread "
                  << m_state->m_threadId << '\n';
    mutex::scoped_lock lock(m_startPlayMutex);
    // Release the constructor only once the mutex is held (see above).
    m_threadReady.wait();
    while (true)
    {
        // NOTE(review): no predicate is checked after waking, so a spurious
        // wake-up would start an unrequested SearchLoop() — confirm.
        m_startPlay.wait(lock);
        if (m_quit)
            break;
        m_search.SearchLoop(*m_state, &m_globalLock);
        Notify(m_playFinishedMutex, m_playFinished);
    }
    if (DEBUG_THREADS)
        SgDebug() << "SgUctSearch::Thread: finishing thread "
                  << m_state->m_threadId << '\n';
}
00165
00166 void SgUctSearch::Thread::StartPlay()
00167 {
00168 Notify(m_startPlayMutex, m_startPlay);
00169 }
00170
// Block until the worker signals m_playFinished after finishing SearchLoop().
// m_playFinishedLock is created from m_playFinishedMutex in the constructor
// (presumably a lock that acquires the mutex on construction, on the thread
// that builds this object and later calls this function) and is released
// only inside this wait — confirm against the header.
void SgUctSearch::Thread::WaitPlayFinished()
{
    m_playFinished.wait(m_playFinishedLock);
}
00175
00176
00177
00178 void SgUctSearchStat::Clear()
00179 {
00180 m_time = 0;
00181 m_knowledge = 0;
00182 m_gamesPerSecond = 0;
00183 m_gameLength.Clear();
00184 m_movesInTree.Clear();
00185 m_aborted.Clear();
00186 }
00187
00188 void SgUctSearchStat::Write(std::ostream& out) const
00189 {
00190 ios_all_saver saver(out);
00191 out << SgWriteLabel("Time") << setprecision(2) << m_time << '\n'
00192 << SgWriteLabel("GameLen") << fixed << setprecision(1);
00193 m_gameLength.Write(out);
00194 out << '\n'
00195 << SgWriteLabel("InTree");
00196 m_movesInTree.Write(out);
00197 out << '\n'
00198 << SgWriteLabel("Aborted")
00199 << static_cast<int>(100 * m_aborted.Mean()) << "%\n"
00200 << SgWriteLabel("Games/s") << fixed << setprecision(1)
00201 << m_gamesPerSecond << '\n';
00202 }
00203
00204
00205
// Create a search with default parameters. No threads are created here;
// they are created lazily by CreateThreads() when a search starts.
// @param threadStateFactory Factory that CreateThreads() uses to create one
//   SgUctThreadState per thread. May be 0 and installed later with
//   SetThreadStateFactory().
// @param moveRange Size of the game's move range; RAVE can only be enabled
//   when it is positive (see SetRave()).
SgUctSearch::SgUctSearch(SgUctThreadStateFactory* threadStateFactory,
                         int moveRange)
    : m_threadStateFactory(threadStateFactory),
      m_logGames(false),
      m_rave(false),
      m_knowledgeThreshold(),
      m_moveSelect(SG_UCTMOVESELECT_COUNT),
      m_raveCheckSame(false),
      m_randomizeRaveFrequency(20),
      m_lockFree(false),
      m_weightRaveUpdates(true),
      m_pruneFullTree(true),
      m_numberThreads(1),
      m_numberPlayouts(1),
      m_maxNodes(4000000),
      m_pruneMinCount(16),
      m_moveRange(moveRange),
      m_maxGameLength(numeric_limits<size_t>::max()),
      m_expandThreshold(1),
      m_biasTermConstant(0.7f),
      m_firstPlayUrgency(10000),
      m_raveWeightInitial(0.9f),
      m_raveWeightFinal(20000),
      m_virtualLoss(false),
      m_logFileName("uctsearch.log"),
      m_fastLog(10),
      m_mpiSynchronizer(SgMpiNullSynchronizer::Create())
{
}
00238
00239 SgUctSearch::~SgUctSearch()
00240 {
00241 DeleteThreads();
00242 }
00243
00244 void SgUctSearch::ApplyRootFilter(vector<SgMoveInfo>& moves)
00245 {
00246
00247 vector<SgMoveInfo> filteredMoves;
00248 for (vector<SgMoveInfo>::const_iterator it = moves.begin();
00249 it != moves.end(); ++it)
00250 if (find(m_rootFilter.begin(), m_rootFilter.end(), it->m_move)
00251 == m_rootFilter.end())
00252 filteredMoves.push_back(*it);
00253 moves = filteredMoves;
00254 }
00255
00256 size_t SgUctSearch::GamesPlayed() const
00257 {
00258 return m_tree.Root().MoveCount() - m_startRootMoveCount;
00259 }
00260
00261 bool SgUctSearch::CheckAbortSearch(SgUctThreadState& state)
00262 {
00263 size_t gamesPlayed = GamesPlayed();
00264 if (SgUserAbort())
00265 {
00266 Debug(state, "SgUctSearch: abort flag");
00267 return true;
00268 }
00269 const bool isEarlyAbort = CheckEarlyAbort();
00270 if (gamesPlayed >= m_maxGames)
00271 {
00272 Debug(state, "SgUctSearch: max games reached");
00273 return true;
00274 }
00275 if ( isEarlyAbort
00276 && m_earlyAbort->m_reductionFactor * gamesPlayed >= m_maxGames
00277 )
00278 {
00279 Debug(state, "SgUctSearch: max games reached (early abort)");
00280 m_wasEarlyAbort = true;
00281 return true;
00282 }
00283 if (m_numberGames % m_checkTimeInterval == 0)
00284 {
00285 double time = m_timer.GetTime();
00286 if (time > m_maxTime)
00287 {
00288 Debug(state, "SgUctSearch: max time reached");
00289 return true;
00290 }
00291 if (isEarlyAbort
00292 && m_earlyAbort->m_reductionFactor * time > m_maxTime)
00293 {
00294 Debug(state, "SgUctSearch: max time reached (early abort)");
00295 m_wasEarlyAbort = true;
00296 return true;
00297 }
00298 UpdateCheckTimeInterval(time);
00299 if (m_moveSelect == SG_UCTMOVESELECT_COUNT)
00300 {
00301 double remainingGamesDouble = m_maxGames - gamesPlayed - 1;
00302
00303
00304 if (time > 1.)
00305 {
00306 double remainingTime = m_maxTime - time;
00307 remainingGamesDouble =
00308 min(remainingGamesDouble,
00309 remainingTime * m_statistics.m_gamesPerSecond);
00310 }
00311 size_t sizeTypeMax = numeric_limits<size_t>::max();
00312 size_t remainingGames;
00313 if (remainingGamesDouble > static_cast<double>(sizeTypeMax - 1))
00314 remainingGames = sizeTypeMax;
00315 else
00316 remainingGames = static_cast<size_t>(remainingGamesDouble);
00317 if (CheckCountAbort(state, remainingGames))
00318 {
00319 Debug(state, "SgUctSearch: move cannot change anymore");
00320 return true;
00321 }
00322 }
00323 }
00324 return false;
00325 }
00326
00327 bool SgUctSearch::CheckCountAbort(SgUctThreadState& state,
00328 std::size_t remainingGames) const
00329 {
00330 const SgUctNode& root = m_tree.Root();
00331 const SgUctNode* bestChild = FindBestChild(root);
00332 if (bestChild == 0)
00333 return false;
00334 size_t bestCount = bestChild->MoveCount();
00335 vector<SgMove>& excludeMoves = state.m_excludeMoves;
00336 excludeMoves.clear();
00337 excludeMoves.push_back(bestChild->Move());
00338 const SgUctNode* secondBestChild = FindBestChild(root, &excludeMoves);
00339 if (secondBestChild == 0)
00340 return false;
00341 std::size_t secondBestCount = secondBestChild->MoveCount();
00342 SG_ASSERT(secondBestCount <= bestCount || m_numberThreads > 1);
00343 return (secondBestCount + remainingGames <= bestCount);
00344 }
00345
00346 bool SgUctSearch::CheckEarlyAbort() const
00347 {
00348 const SgUctNode& root = m_tree.Root();
00349 return m_earlyAbort.get() != 0
00350 && root.HasMean()
00351 && root.MoveCount() > m_earlyAbort->m_minGames
00352 && root.Mean() > m_earlyAbort->m_threshold;
00353 }
00354
00355 void SgUctSearch::CreateThreads()
00356 {
00357 DeleteThreads();
00358 for (size_t i = 0; i < m_numberThreads; ++i)
00359 {
00360 auto_ptr<SgUctThreadState> state(
00361 m_threadStateFactory->Create(i, *this));
00362 shared_ptr<Thread> thread(new Thread(*this, state));
00363 m_threads.push_back(thread);
00364 }
00365 m_tree.CreateAllocators(m_numberThreads);
00366 m_tree.SetMaxNodes(m_maxNodes);
00367
00368 m_searchLoopFinished.reset(new barrier(m_numberThreads));
00369 }
00370
00371
00372
00373
00374
00375
00376
00377
00378 void SgUctSearch::Debug(const SgUctThreadState& state,
00379 const std::string& textLine)
00380 {
00381 if (m_numberThreads > 1)
00382 {
00383
00384 GlobalLock lock(m_globalMutex);
00385 SgDebug() << (format("[%1%] %2%\n") % state.m_threadId % textLine);
00386 }
00387 else
00388 SgDebug() << (format("%1%\n") % textLine);
00389 }
00390
// Release all worker threads. Destroying a Thread (when its last shared_ptr
// goes away) stops and joins the worker, see Thread::~Thread().
void SgUctSearch::DeleteThreads()
{
    m_threads.clear();
}
00395
00396
00397
00398
00399
00400 void SgUctSearch::ExpandNode(SgUctThreadState& state, const SgUctNode& node)
00401 {
00402 size_t threadId = state.m_threadId;
00403 if (! m_tree.HasCapacity(threadId, state.m_moves.size()))
00404 {
00405 Debug(state, str(format("SgUctSearch: maximum tree size %1% reached")
00406 % m_tree.MaxNodes()));
00407 state.m_isTreeOutOfMem = true;
00408 m_isTreeOutOfMemory = true;
00409 SgSynchronizeThreadMemory();
00410 return;
00411 }
00412 m_tree.CreateChildren(threadId, node, state.m_moves);
00413 }
00414
00415 const SgUctNode*
00416 SgUctSearch::FindBestChild(const SgUctNode& node,
00417 const vector<SgMove>* excludeMoves) const
00418 {
00419 if (! node.HasChildren())
00420 return 0;
00421 const SgUctNode* bestChild = 0;
00422 float bestValue = 0;
00423 for (SgUctChildIterator it(m_tree, node); it; ++it)
00424 {
00425 const SgUctNode& child = *it;
00426 if (excludeMoves != 0)
00427 {
00428 vector<SgMove>::const_iterator begin = excludeMoves->begin();
00429 vector<SgMove>::const_iterator end = excludeMoves->end();
00430 if (find(begin, end, child.Move()) != end)
00431 continue;
00432 }
00433 if ( ! child.HasMean()
00434 && ! ( ( m_moveSelect == SG_UCTMOVESELECT_BOUND
00435 || m_moveSelect == SG_UCTMOVESELECT_ESTIMATE
00436 )
00437 && m_rave
00438 && child.HasRaveValue()
00439 )
00440 )
00441 continue;
00442 float moveValue = InverseEval(child.Mean());
00443 size_t moveCount = child.MoveCount();
00444 float value;
00445 switch (m_moveSelect)
00446 {
00447 case SG_UCTMOVESELECT_VALUE:
00448 value = moveValue;
00449 break;
00450 case SG_UCTMOVESELECT_COUNT:
00451 value = moveCount;
00452 break;
00453 case SG_UCTMOVESELECT_BOUND:
00454 value = GetBound(m_rave, node, child);
00455 break;
00456 case SG_UCTMOVESELECT_ESTIMATE:
00457 value = GetValueEstimate(m_rave, child);
00458 break;
00459 default:
00460 SG_ASSERT(false);
00461 value = SG_UCTMOVESELECT_VALUE;
00462 }
00463 if (bestChild == 0 || value > bestValue)
00464 {
00465 bestChild = &child;
00466 bestValue = value;
00467 }
00468 }
00469 return bestChild;
00470 }
00471
00472 void SgUctSearch::FindBestSequence(vector<SgMove>& sequence) const
00473 {
00474 sequence.clear();
00475 const SgUctNode* current = &m_tree.Root();
00476 while (true)
00477 {
00478 current = FindBestChild(*current);
00479 if (current == 0)
00480 break;
00481 sequence.push_back(current->Move());
00482 if (! current->HasChildren())
00483 break;
00484 }
00485 }
00486
00487 void SgUctSearch::GenerateAllMoves(std::vector<SgMoveInfo>& moves)
00488 {
00489 if (m_threads.size() == 0)
00490 CreateThreads();
00491 moves.clear();
00492 OnStartSearch();
00493 SgUctThreadState& state = ThreadState(0);
00494 state.StartSearch();
00495 SgProvenNodeType type;
00496 state.GenerateAllMoves(0, moves, type);
00497 }
00498
00499 float SgUctSearch::GetBound(bool useRave, const SgUctNode& node,
00500 const SgUctNode& child) const
00501 {
00502 size_t posCount = node.PosCount();
00503 return GetBound(useRave, Log(posCount), child);
00504 }
00505
00506 float SgUctSearch::GetBound(bool useRave, float logPosCount,
00507 const SgUctNode& child) const
00508 {
00509 float value;
00510 if (useRave)
00511 value = GetValueEstimateRave(child);
00512 else
00513 value = GetValueEstimate(false, child);
00514 if (m_biasTermConstant == 0.0)
00515 return value;
00516 else
00517 {
00518 float moveCount = static_cast<float>(child.MoveCount());
00519 float bound =
00520 value + m_biasTermConstant * sqrt(logPosCount / (moveCount + 1));
00521 return bound;
00522 }
00523 }
00524
00525 SgUctTree& SgUctSearch::GetTempTree()
00526 {
00527 m_tempTree.Clear();
00528
00529
00530
00531 if (m_tempTree.NuAllocators() != NumberThreads())
00532 {
00533 m_tempTree.CreateAllocators(NumberThreads());
00534 m_tempTree.SetMaxNodes(MaxNodes());
00535 }
00536 else if (m_tempTree.MaxNodes() != MaxNodes())
00537 {
00538 m_tempTree.SetMaxNodes(MaxNodes());
00539 }
00540 return m_tempTree;
00541 }
00542
00543 float SgUctSearch::GetValueEstimate(bool useRave, const SgUctNode& child) const
00544 {
00545 float value = 0.f;
00546 float weightSum = 0.f;
00547 bool hasValue = false;
00548 if (child.HasMean())
00549 {
00550 float weight = static_cast<float>(child.MoveCount());
00551 value += weight * InverseEval(child.Mean());
00552 weightSum += weight;
00553 hasValue = true;
00554 }
00555 if (useRave && child.HasRaveValue())
00556 {
00557 float raveCount = child.RaveCount();
00558 float weight =
00559 raveCount
00560 / ( m_raveWeightParam1
00561 + m_raveWeightParam2 * raveCount
00562 );
00563 value += weight * child.RaveValue();
00564 weightSum += weight;
00565 hasValue = true;
00566 }
00567 if (hasValue)
00568 return value / weightSum;
00569 else
00570 return m_firstPlayUrgency;
00571 }
00572
00573
00574
00575
00576
00577
00578
00579 float SgUctSearch::GetValueEstimateRave(const SgUctNode& child) const
00580 {
00581 SG_ASSERT(m_rave);
00582 bool hasRave = child.HasRaveValue();
00583 float value;
00584 if (child.HasMean())
00585 {
00586 float moveValue = InverseEval(child.Mean());
00587 if (hasRave)
00588 {
00589 float moveCount = child.MoveCount();
00590 float raveCount = child.RaveCount();
00591 float weight =
00592 raveCount
00593 / (moveCount
00594 * (m_raveWeightParam1 + m_raveWeightParam2 * raveCount)
00595 + raveCount);
00596 value = weight * child.RaveValue() + (1.f - weight) * moveValue;
00597 }
00598 else
00599 {
00600
00601
00602
00603
00604 SG_ASSERT(m_numberThreads > 1 && m_lockFree);
00605 value = moveValue;
00606 }
00607 }
00608 else if (hasRave)
00609 value = child.RaveValue();
00610 else
00611 value = m_firstPlayUrgency;
00612 SG_ASSERT(m_numberThreads > 1
00613 || fabs(value - GetValueEstimate(m_rave, child)) < 1e-3);
00614 return value;
00615 }
00616
00617 string SgUctSearch::LastGameSummaryLine() const
00618 {
00619 return SummaryLine(LastGameInfo());
00620 }
00621
00622 float SgUctSearch::Log(float x) const
00623 {
00624
00625
00626
00627
00628 #if SG_UCTFASTLOG
00629 return m_fastLog.Log(x);
00630 #else
00631 return log(x);
00632 #endif
00633 }
00634
00635
00636
00637 void SgUctSearch::CreateChildren(SgUctThreadState& state,
00638 const SgUctNode& node,
00639 bool deleteChildTrees)
00640 {
00641 size_t threadId = state.m_threadId;
00642 if (! m_tree.HasCapacity(threadId, state.m_moves.size()))
00643 {
00644 Debug(state, str(format("SgUctSearch: maximum tree size %1% reached")
00645 % m_tree.MaxNodes()));
00646 state.m_isTreeOutOfMem = true;
00647 m_isTreeOutOfMemory = true;
00648 SgSynchronizeThreadMemory();
00649 return;
00650 }
00651 m_tree.MergeChildren(threadId, node, state.m_moves, deleteChildTrees);
00652 }
00653
00654 bool SgUctSearch::NeedToComputeKnowledge(const SgUctNode* current)
00655 {
00656 if (m_knowledgeThreshold.empty())
00657 return false;
00658 for (std::size_t i = 0; i < m_knowledgeThreshold.size(); ++i)
00659 {
00660 const std::size_t threshold = m_knowledgeThreshold[i];
00661 if (current->KnowledgeCount() < threshold)
00662 {
00663 if (current->MoveCount() >= threshold)
00664 {
00665
00666
00667
00668 SgUctNode* node = const_cast<SgUctNode*>(current);
00669 node->SetKnowledgeCount(threshold);
00670 SG_ASSERT(current->MoveCount());
00671 return true;
00672 }
00673 return false;
00674 }
00675 }
00676 return false;
00677 }
00678
// Notify the MPI synchronizer that a search is starting.
void SgUctSearch::OnStartSearch()
{
    m_mpiSynchronizer->OnStartSearch(*this);
}
00683
// Notify the MPI synchronizer that a search has ended.
void SgUctSearch::OnEndSearch()
{
    m_mpiSynchronizer->OnEndSearch(*this);
}
00688
00689
00690 void SgUctSearch::PrintSearchProgress(double currTime) const
00691 {
00692 const int MAX_SEQ_PRINT_LENGTH = 15;
00693 const size_t MIN_MOVE_COUNT = 10;
00694 size_t rootMoveCount = m_tree.Root().MoveCount();
00695 float rootMean = m_tree.Root().Mean();
00696 ostringstream out;
00697 const SgUctNode* current = &m_tree.Root();
00698 out << fixed << setprecision(0)
00699 << SgTime::Format(currTime, true) << " | "
00700 << fixed << setprecision(3) << rootMean << " "
00701 << "| " << rootMoveCount << " ";
00702 for(int i = 0; i <= MAX_SEQ_PRINT_LENGTH && current->HasChildren(); ++i)
00703 {
00704 current = FindBestChild(*current);
00705 if (current == 0 || current->MoveCount() < MIN_MOVE_COUNT)
00706 break;
00707 if (i == 0)
00708 out << "|";
00709 if (i < MAX_SEQ_PRINT_LENGTH)
00710 out << " " << SgWritePoint(current->Move());
00711 else
00712 out << " *";
00713 }
00714 SgDebug() << out.str() << endl;
00715 }
00716
00717 void SgUctSearch::OnSearchIteration(std::size_t gameNumber, int threadId,
00718 const SgUctGameInfo& info)
00719 {
00720 const int DISPLAY_INTERVAL = 5;
00721
00722 m_mpiSynchronizer->OnSearchIteration(*this, gameNumber, threadId, info);
00723 double currTime = m_timer.GetTime();
00724
00725 if (threadId == 0 && currTime - m_lastScoreDisplayTime > DISPLAY_INTERVAL)
00726 {
00727 PrintSearchProgress(currTime);
00728 m_lastScoreDisplayTime = currTime;
00729 }
00730 }
00731
// Play one complete game for the given thread:
//   1. descend the tree with PlayInTree();
//   2. if the reached node is proven, score every playout directly from
//      the proof; otherwise run m_numberPlayouts playouts from the reached
//      position;
//   3. undo all moves on the thread state;
//   4. update the tree, the RAVE values (if enabled) and the statistics.
// @param state Thread state the game is played on; its m_gameInfo receives
//   the nodes, sequences, evaluations and abort flags.
// @param lock Global lock held by the caller on entry, or 0 when running
//   without locking. It is released during the playouts and re-acquired
//   before the tree is updated.
void SgUctSearch::PlayGame(SgUctThreadState& state, GlobalLock* lock)
{
    state.m_isTreeOutOfMem = false;
    state.GameStart();
    SgUctGameInfo& info = state.m_gameInfo;
    info.Clear(m_numberPlayouts);
    bool isTerminal;
    bool abortInTree = ! PlayInTree(state, isTerminal);

    // Virtual loss on the selected path, presumably to steer concurrent
    // threads away from it until this game's result is in; removed again
    // below before the real update.
    if (m_virtualLoss)
        m_tree.AddVirtualLoss(info.m_nodes);

    // The playouts do not touch the tree, so they run without the lock.
    if (lock != 0)
        lock->unlock();

    size_t nuMovesInTree = info.m_inTreeSequence.size();

    // Proven node: every playout gets the proven result, no playout is run.
    if (! info.m_nodes.empty() && info.m_nodes.back()->IsProven())
    {
        for (size_t i = 0; i < m_numberPlayouts; ++i)
        {
            info.m_sequence[i] = info.m_inTreeSequence;
            info.m_skipRaveUpdate[i].assign(nuMovesInTree, false);
            float eval = info.m_nodes.back()->IsProvenWin() ? 1.0 : 0.0;
            size_t nuMoves = info.m_sequence[i].size();
            // Odd-length sequence: flip the evaluation (presumably to express
            // it from the root player's point of view).
            if (nuMoves % 2 != 0)
                eval = InverseEval(eval);
            info.m_aborted[i] = abortInTree || state.m_isTreeOutOfMem;
            info.m_eval[i] = eval;
        }
    }
    else
    {
        state.StartPlayouts();
        for (size_t i = 0; i < m_numberPlayouts; ++i)
        {
            state.StartPlayout();
            info.m_sequence[i] = info.m_inTreeSequence;

            info.m_skipRaveUpdate[i].assign(nuMovesInTree, false);
            bool abort = abortInTree || state.m_isTreeOutOfMem;
            if (! abort && ! isTerminal)
                abort = ! PlayoutGame(state, i);
            float eval;
            if (abort)
                eval = UnknownEval();
            else
                eval = state.Evaluate();
            size_t nuMoves = info.m_sequence[i].size();
            if (nuMoves % 2 != 0)
                eval = InverseEval(eval);
            info.m_aborted[i] = abort;
            info.m_eval[i] = eval;
            state.EndPlayout();
            // Undo only the playout moves; the in-tree moves are shared by
            // all playouts and undone once after the loop.
            state.TakeBackPlayout(nuMoves - nuMovesInTree);
        }
    }
    state.TakeBackInTree(nuMovesInTree);

    // Re-acquire the lock before touching the tree again.
    if (lock != 0)
        lock->lock();

    if (m_virtualLoss)
        m_tree.RemoveVirtualLoss(info.m_nodes);

    UpdateTree(info);
    if (m_rave)
        UpdateRaveValues(state);
    UpdateStatistics(info);
}
00807
00808
00809
00810
00811
00812
00813
// In-tree phase of a game: starting at the root, repeatedly select a child
// with SelectChild() and execute its move on the thread state. The visited
// nodes (root first) are appended to m_gameInfo.m_nodes and the moves to
// m_gameInfo.m_inTreeSequence.
// The descent stops at a proven node, at a terminal position, at a leaf
// that is not expanded yet, or right after selecting a child of a node that
// was just expanded or had its knowledge recomputed.
// @param[out] isTerminal Set to true if the reached position has no moves.
// @return false if the in-tree sequence reached m_maxGameLength; true
//   otherwise. Running out of tree memory also returns true, signalled
//   through state.m_isTreeOutOfMem.
bool SgUctSearch::PlayInTree(SgUctThreadState& state, bool& isTerminal)
{
    vector<SgMove>& sequence = state.m_gameInfo.m_inTreeSequence;
    vector<const SgUctNode*>& nodes = state.m_gameInfo.m_nodes;
    const SgUctNode* root = &m_tree.Root();
    const SgUctNode* current = root;
    nodes.push_back(current);
    bool breakAfterSelect = false;
    isTerminal = false;
    // NOTE(review): deepenTree is only ever assigned false in this function,
    // so the `deepenTree ||` and `! deepenTree` tests are always decided the
    // same way — confirm whether this is intended.
    bool deepenTree = false;
    while (true)
    {
        if (sequence.size() == m_maxGameLength)
            return false;
        if (current->IsProven())
            break;
        if (! current->HasChildren())
        {
            // Leaf: generate moves to decide whether to expand it.
            state.m_moves.clear();
            SgProvenNodeType provenType = SG_NOT_PROVEN;
            state.GenerateAllMoves(0, state.m_moves, provenType);
            if (current == root)
                ApplyRootFilter(state.m_moves);
            if (provenType != SG_NOT_PROVEN)
            {
                // The proven status is stored on a node the tree hands out
                // as const.
                SgUctNode* node = const_cast<SgUctNode*>(current);
                node->SetProvenNodeType(provenType);
                break;
            }
            if (state.m_moves.empty())
            {
                isTerminal = true;
                break;
            }
            // Expand only after m_expandThreshold visits.
            if (  deepenTree
               || current->MoveCount() >= m_expandThreshold
               )
            {
                deepenTree = false;
                ExpandNode(state, *current);
                if (state.m_isTreeOutOfMem)
                    return true;
                if (! deepenTree)
                    breakAfterSelect = true;
            }
            else
                break;
        }
        else if (NeedToComputeKnowledge(current))
        {
            // Interior node that reached a new knowledge threshold:
            // regenerate its moves and merge them into the children.
            m_statistics.m_knowledge++;
            deepenTree = false;
            SgProvenNodeType provenType = SG_NOT_PROVEN;
            bool truncate = state.GenerateAllMoves(current->KnowledgeCount(),
                                                   state.m_moves,
                                                   provenType);
            if (current == root)
                ApplyRootFilter(state.m_moves);
            CreateChildren(state, *current, truncate);
            if (provenType != SG_NOT_PROVEN)
            {
                SgUctNode* node = const_cast<SgUctNode*>(current);
                node->SetProvenNodeType(provenType);
                break;
            }
            if (state.m_moves.empty())
            {
                isTerminal = true;
                break;
            }
            if (state.m_isTreeOutOfMem)
                return true;
            if (! deepenTree)
                breakAfterSelect = true;
        }
        current = &SelectChild(state.m_randomizeCounter, *current);
        nodes.push_back(current);
        SgMove move = current->Move();
        state.Execute(move);
        sequence.push_back(move);
        if (breakAfterSelect)
            break;
    }
    return true;
}
00899
00900
00901
00902
00903
00904
00905 bool SgUctSearch::PlayoutGame(SgUctThreadState& state, std::size_t playout)
00906 {
00907 SgUctGameInfo& info = state.m_gameInfo;
00908 vector<SgMove>& sequence = info.m_sequence[playout];
00909 vector<bool>& skipRaveUpdate = info.m_skipRaveUpdate[playout];
00910 while (true)
00911 {
00912 if (sequence.size() == m_maxGameLength)
00913 return false;
00914 bool skipRave = false;
00915 SgMove move = state.GeneratePlayoutMove(skipRave);
00916 if (move == SG_NULLMOVE)
00917 break;
00918 state.ExecutePlayout(move);
00919 sequence.push_back(move);
00920 skipRaveUpdate.push_back(skipRave);
00921 }
00922 return true;
00923 }
00924
00925 float SgUctSearch::Search(std::size_t maxGames, double maxTime,
00926 vector<SgMove>& sequence,
00927 const vector<SgMove>& rootFilter,
00928 SgUctTree* initTree,
00929 SgUctEarlyAbortParam* earlyAbort)
00930 {
00931 m_timer.Start();
00932 m_rootFilter = rootFilter;
00933 if (m_logGames)
00934 {
00935 m_log.open(m_mpiSynchronizer->ToNodeFilename(m_logFileName).c_str());
00936 m_log << "StartSearch maxGames=" << maxGames << '\n';
00937 }
00938 m_maxGames = maxGames;
00939 m_maxTime = maxTime;
00940 m_earlyAbort.reset(0);
00941 if (earlyAbort != 0)
00942 m_earlyAbort.reset(new SgUctEarlyAbortParam(*earlyAbort));
00943
00944 for (size_t i = 0; i < m_threads.size(); ++i)
00945 {
00946 m_threads[i]->m_state->m_isSearchInitialized = false;
00947 }
00948 StartSearch(rootFilter, initTree);
00949 size_t pruneMinCount = m_pruneMinCount;
00950 while (true)
00951 {
00952 m_isTreeOutOfMemory = false;
00953 SgSynchronizeThreadMemory();
00954 for (size_t i = 0; i < m_threads.size(); ++i)
00955 m_threads[i]->StartPlay();
00956 for (size_t i = 0; i < m_threads.size(); ++i)
00957 m_threads[i]->WaitPlayFinished();
00958 if (m_aborted || ! m_pruneFullTree)
00959 break;
00960 else
00961 {
00962 double startPruneTime = m_timer.GetTime();
00963 SgDebug() << "SgUctSearch: pruning nodes with count < "
00964 << pruneMinCount << " (at time " << fixed << setprecision(1)
00965 << startPruneTime << ")\n";
00966 SgUctTree& tempTree = GetTempTree();
00967 m_tree.CopyPruneLowCount(tempTree, pruneMinCount, true);
00968 int prunedSizePercentage =
00969 static_cast<int>(tempTree.NuNodes() * 100 / m_tree.NuNodes());
00970 SgDebug() << "SgUctSearch: pruned size: " << tempTree.NuNodes()
00971 << " (" << prunedSizePercentage << "%) time: "
00972 << (m_timer.GetTime() - startPruneTime) << "\n";
00973 if (prunedSizePercentage > 50)
00974 pruneMinCount *= 2;
00975 else
00976 pruneMinCount = m_pruneMinCount;
00977 m_tree.Swap(tempTree);
00978 }
00979 }
00980 EndSearch();
00981 m_statistics.m_time = m_timer.GetTime();
00982 if (m_statistics.m_time > numeric_limits<double>::epsilon())
00983 m_statistics.m_gamesPerSecond = GamesPlayed() / m_statistics.m_time;
00984 if (m_logGames)
00985 m_log.close();
00986 FindBestSequence(sequence);
00987 return (m_tree.Root().MoveCount() > 0) ? m_tree.Root().Mean() : 0.5;
00988 }
00989
00990
// Per-thread search loop, run by each worker thread after StartPlay().
// Plays games until the thread's tree allocator is full, another thread
// reported the tree as full (m_isTreeOutOfMemory), or the search is aborted
// (m_aborted, or CheckAbortSearch()).
// @param lock The thread's global lock; ignored (treated as 0) with a
//   single thread or in lock-free mode. Otherwise it is held for the whole
//   loop except while PlayGame() runs its playouts.
// All threads then meet at m_searchLoopFinished. OnThreadEndSearch() is
// called only if the search will not be restarted after pruning (see
// Search()).
void SgUctSearch::SearchLoop(SgUctThreadState& state, GlobalLock* lock)
{
    if (! state.m_isSearchInitialized)
    {
        OnThreadStartSearch(state);
        state.m_isSearchInitialized = true;
    }

    if (NumberThreads() == 1 || m_lockFree)
        lock = 0;
    if (lock != 0)
        lock->lock();
    state.m_isTreeOutOfMem = false;
    while (! state.m_isTreeOutOfMem)
    {
        PlayGame(state, lock);
        OnSearchIteration(m_numberGames + 1, state.m_threadId,
                          state.m_gameInfo);
        if (m_logGames)
            m_log << SummaryLine(state.m_gameInfo) << '\n';
        ++m_numberGames;
        if (m_isTreeOutOfMemory)
            break;
        if (m_aborted || CheckAbortSearch(state))
        {
            // Publish the abort flag so the other threads stop as well.
            m_aborted = true;
            SgSynchronizeThreadMemory();
            break;
        }
    }
    if (lock != 0)
        lock->unlock();

    m_searchLoopFinished->wait();
    if (m_aborted || ! m_pruneFullTree)
        OnThreadEndSearch(state);
}
01028
// Notify the MPI synchronizer that a thread starts searching (called once
// per search from SearchLoop()).
void SgUctSearch::OnThreadStartSearch(SgUctThreadState& state)
{
    m_mpiSynchronizer->OnThreadStartSearch(*this, state);
}
01033
// Notify the MPI synchronizer that a thread finished searching (called from
// SearchLoop() when the search will not be restarted).
void SgUctSearch::OnThreadEndSearch(SgUctThreadState& state)
{
    m_mpiSynchronizer->OnThreadEndSearch(*this, state);
}
01038
01039 SgPoint SgUctSearch::SearchOnePly(size_t maxGames, double maxTime,
01040 float& value)
01041 {
01042 if (m_threads.size() == 0)
01043 CreateThreads();
01044 OnStartSearch();
01045
01046
01047 SgUctThreadState& state = ThreadState(0);
01048 state.StartSearch();
01049 vector<SgMoveInfo> moves;
01050 SgProvenNodeType provenType;
01051 state.GameStart();
01052 state.GenerateAllMoves(0, moves, provenType);
01053 vector<SgUctStatisticsBase> statistics(moves.size());
01054 size_t games = 0;
01055 m_timer.Start();
01056 SgUctGameInfo& info = state.m_gameInfo;
01057 while (games < maxGames && m_timer.GetTime() < maxTime && ! SgUserAbort())
01058 {
01059 for (size_t i = 0; i < moves.size(); ++i)
01060 {
01061 state.GameStart();
01062 info.Clear(1);
01063 SgMove move = moves[i].m_move;
01064 state.Execute(move);
01065 info.m_inTreeSequence.push_back(move);
01066 info.m_sequence[0].push_back(move);
01067 info.m_skipRaveUpdate[0].push_back(false);
01068 state.StartPlayouts();
01069 state.StartPlayout();
01070 bool abortGame = ! PlayoutGame(state, 0);
01071 float eval;
01072 if (abortGame)
01073 eval = UnknownEval();
01074 else
01075 eval = state.Evaluate();
01076 state.EndPlayout();
01077 state.TakeBackPlayout(info.m_sequence[0].size() - 1);
01078 state.TakeBackInTree(1);
01079 statistics[i].Add(info.m_sequence[0].size() % 2 == 0 ?
01080 eval : InverseEval(eval));
01081 OnSearchIteration(games + 1, 0, info);
01082 ++games;
01083 }
01084 }
01085 SgMove bestMove = SG_NULLMOVE;
01086 for (size_t i = 0; i < moves.size(); ++i)
01087 {
01088 SgDebug() << SgWritePoint(moves[i].m_move)
01089 << ' ' << statistics[i].Mean() << '\n';
01090 if (bestMove == SG_NULLMOVE || statistics[i].Mean() > value)
01091 {
01092 bestMove = moves[i].m_move;
01093 value = statistics[i].Mean();
01094 }
01095 }
01096 return bestMove;
01097 }
01098
01099 const SgUctNode& SgUctSearch::SelectChild(int& randomizeCounter,
01100 const SgUctNode& node)
01101 {
01102 bool useRave = m_rave;
01103 if (m_randomizeRaveFrequency > 0)
01104 {
01105 ++randomizeCounter;
01106 if (randomizeCounter % m_randomizeRaveFrequency == 0)
01107 useRave = false;
01108 }
01109 SG_ASSERT(node.HasChildren());
01110 size_t posCount = node.PosCount();
01111 if (posCount == 0)
01112
01113 return *SgUctChildIterator(m_tree, node);
01114 float logPosCount = Log(posCount);
01115 const SgUctNode* bestChild = 0;
01116 float bestUpperBound = 0;
01117 for (SgUctChildIterator it(m_tree, node); it; ++it)
01118 {
01119 const SgUctNode& child = *it;
01120 float bound = GetBound(useRave, logPosCount, child);
01121 if (bestChild == 0 || bound > bestUpperBound)
01122 {
01123 bestChild = &child;
01124 bestUpperBound = bound;
01125 }
01126 }
01127 SG_ASSERT(bestChild != 0);
01128 return *bestChild;
01129 }
01130
01131 void SgUctSearch::SetNumberThreads(std::size_t n)
01132 {
01133 SG_ASSERT(n >= 1);
01134 if (m_numberThreads == n)
01135 return;
01136 m_numberThreads = n;
01137 CreateThreads();
01138 }
01139
01140 void SgUctSearch::SetRave(bool enable)
01141 {
01142 if (enable && m_moveRange <= 0)
01143 throw SgException("RAVE not supported for this game");
01144 m_rave = enable;
01145 }
01146
01147 void SgUctSearch::SetThreadStateFactory(SgUctThreadStateFactory* factory)
01148 {
01149 SG_ASSERT(m_threadStateFactory.get() == 0);
01150 m_threadStateFactory.reset(factory);
01151 DeleteThreads();
01152
01153
01154
01155 }
01156
// Prepare all search state for a new search.
// Steps, in this order:
// - create the threads if none exist;
// - warn when CPU time is used with several threads;
// - cache the reciprocal RAVE weight parameters;
// - clear the tree, or adopt initTree;
// - reset per-search statistics and flags;
// - call the OnStartSearch() hook;
// - record the root move count;
// - call StartSearch() on every thread state.
// @param rootFilter Moves passed to ApplyFilter() on the root. Only used when
//   initTree is non-null and the tree has capacity for the filtering.
// @param initTree Optional tree to continue from. Its contents are exchanged
//   with m_tree via SgUctTree::Swap(). When null, m_tree is cleared instead.
void SgUctSearch::StartSearch(const vector<SgMove>& rootFilter,
                              SgUctTree* initTree)
{
    if (m_threads.size() == 0)
        CreateThreads();
    // Only a warning: the search still runs. Presumably CPU time would be
    // accumulated over all threads, making time limits unreliable.
    if (m_numberThreads > 1 && SgTime::DefaultMode() == SG_TIME_CPU)
        SgWarning() << "SgUctSearch: using cpu time with multiple threads\n";
    // Cache the reciprocals of the RAVE weight settings.
    // NOTE(review): a zero m_raveWeightInitial/m_raveWeightFinal would make
    // these infinite; confirm the corresponding setters reject zero.
    m_raveWeightParam1 = 1.f / m_raveWeightInitial;
    m_raveWeightParam2 = 1.f / m_raveWeightFinal;
    if (initTree == 0)
        m_tree.Clear();
    else
    {
        m_tree.Swap(*initTree);
        // The root filter is applied only if HasCapacity() allows it for the
        // root's number of children. Otherwise it is skipped with a warning.
        if (m_tree.HasCapacity(0, m_tree.Root().NuChildren()))
            m_tree.ApplyFilter(0, m_tree.Root(), rootFilter);
        else
            SgWarning() <<
                "SgUctSearch: "
                "root filter not applied (tree reached maximum size)\n";
    }
    // Reset per-search statistics and control state.
    m_statistics.Clear();
    m_aborted = false;
    m_wasEarlyAbort = false;
    m_checkTimeInterval = 1;
    m_numberGames = 0;
    m_lastScoreDisplayTime = m_timer.GetTime();
    OnStartSearch();
    // Taken after the OnStartSearch() hook has run.
    m_startRootMoveCount = m_tree.Root().MoveCount();
    for (size_t i = 0; i < m_threads.size(); ++i)
        ThreadState(i).StartSearch();
}
01194
01195 void SgUctSearch::EndSearch()
01196 {
01197 OnEndSearch();
01198 }
01199
01200 string SgUctSearch::SummaryLine(const SgUctGameInfo& info) const
01201 {
01202 ostringstream buffer;
01203 const vector<const SgUctNode*>& nodes = info.m_nodes;
01204 for (size_t i = 1; i < nodes.size(); ++i)
01205 {
01206 const SgUctNode* node = nodes[i];
01207 SgMove move = node->Move();
01208 buffer << ' ' << MoveString(move) << " (" << fixed << setprecision(2)
01209 << node->Mean() << ',' << node->MoveCount() << ')';
01210 }
01211 for (size_t i = 0; i < info.m_eval.size(); ++i)
01212 buffer << ' ' << fixed << setprecision(2) << info.m_eval[i];
01213 return buffer.str();
01214 }
01215
01216 void SgUctSearch::UpdateCheckTimeInterval(double time)
01217 {
01218 if (time < numeric_limits<double>::epsilon())
01219 return;
01220
01221
01222 float wantedTimeDiff = (m_maxTime > 1 ? 0.1 : 0.1 * m_maxTime);
01223 if (time < wantedTimeDiff / 10)
01224 {
01225
01226 m_checkTimeInterval *= 2;
01227 return;
01228 }
01229 m_statistics.m_gamesPerSecond = GamesPlayed() / time;
01230 double gamesPerSecondPerThread =
01231 m_statistics.m_gamesPerSecond / m_numberThreads;
01232 m_checkTimeInterval =
01233 static_cast<size_t>(wantedTimeDiff * gamesPerSecondPerThread);
01234 if (m_checkTimeInterval == 0)
01235 m_checkTimeInterval = 1;
01236 }
01237
01238
01239
01240
01241
01242 void SgUctSearch::UpdateRaveValues(SgUctThreadState& state)
01243 {
01244 for (size_t i = 0; i < m_numberPlayouts; ++i)
01245 UpdateRaveValues(state, i);
01246 }
01247
// Update RAVE values of in-tree nodes from one playout.
// Walks the playout's move sequence backwards, from the last index down to 0.
// For every move not flagged in skipRaveUpdate, it records in firstPlay or
// firstPlayOpp the smallest sequence index at which that move occurred:
// - even indices go to firstPlay and odd indices to firstPlayOpp (opp tracks
//   the parity of i);
// - indices >= nodes.size() only update these arrays;
// - for smaller indices, the per-node overload is also called for node i.
//   When i is odd, the two arrays are swapped and the inverted evaluation is
//   passed.
// Presumably this gives each node the evaluation and first-play indices from
// the side to move there; confirm this convention.
// Does nothing if the playout's sequence is empty.
// @param state Thread state providing the game info and the scratch arrays
//   m_firstPlay / m_firstPlayOpp (m_moveRange entries each)
// @param playout Index of the playout within state.m_gameInfo
void SgUctSearch::UpdateRaveValues(SgUctThreadState& state,
                                   std::size_t playout)
{
    SgUctGameInfo& info = state.m_gameInfo;
    const vector<SgMove>& sequence = info.m_sequence[playout];
    if (sequence.size() == 0)
        return;
    SG_ASSERT(m_moveRange > 0);
    size_t* firstPlay = state.m_firstPlay.get();
    size_t* firstPlayOpp = state.m_firstPlayOpp.get();
    // Reset the scratch arrays; max() marks a move as "not played".
    fill_n(firstPlay, m_moveRange, numeric_limits<size_t>::max());
    fill_n(firstPlayOpp, m_moveRange, numeric_limits<size_t>::max());
    const vector<const SgUctNode*>& nodes = info.m_nodes;
    const vector<bool>& skipRaveUpdate = info.m_skipRaveUpdate[playout];
    float eval = info.m_eval[playout];
    float invEval = InverseEval(eval);
    size_t nuNodes = nodes.size();
    size_t i = sequence.size() - 1;
    bool opp = (i % 2 != 0);
    // Moves that have no corresponding node: only record first-play indices.
    // NOTE(review): relies on nuNodes >= 1 (e.g. the root always present);
    // with nuNodes == 0 the unsigned condition i >= nuNodes would never fail.
    for ( ; i >= nuNodes; --i)
    {
        SG_ASSERT(i < skipRaveUpdate.size());
        SG_ASSERT(i < sequence.size());
        if (! skipRaveUpdate[i])
        {
            SgMove mv = sequence[i];
            size_t& first = (opp ? firstPlayOpp[mv] : firstPlay[mv]);
            if (i < first)
                first = i;
        }
        opp = ! opp;
    }
    // Remaining indices down to 0: record the first-play index, then update
    // the RAVE values of node i's children.
    while (true)
    {
        SG_ASSERT(i < skipRaveUpdate.size());
        SG_ASSERT(i < sequence.size());
        // Moves inside the in-tree part of the sequence must never be skipped.
        SG_ASSERT(i >= info.m_inTreeSequence.size() || ! skipRaveUpdate[i]);
        if (! skipRaveUpdate[i])
        {
            SgMove mv = sequence[i];
            size_t& first = (opp ? firstPlayOpp[mv] : firstPlay[mv]);
            if (i < first)
                first = i;
            if (opp)
                UpdateRaveValues(state, playout, invEval, i,
                                 firstPlayOpp, firstPlay);
            else
                UpdateRaveValues(state, playout, eval, i,
                                 firstPlay, firstPlayOpp);
        }
        // i is unsigned, so stop explicitly after handling index 0.
        if (i == 0)
            break;
        --i;
        opp = ! opp;
    }
}
01308
01309 void SgUctSearch::UpdateRaveValues(SgUctThreadState& state,
01310 std::size_t playout, float eval,
01311 std::size_t i,
01312 const std::size_t firstPlay[],
01313 const std::size_t firstPlayOpp[])
01314 {
01315 SG_ASSERT(i < state.m_gameInfo.m_nodes.size());
01316 const SgUctNode* node = state.m_gameInfo.m_nodes[i];
01317 if (! node->HasChildren())
01318 return;
01319 std::size_t len = state.m_gameInfo.m_sequence[playout].size();
01320 for (SgUctChildIterator it(m_tree, *node); it; ++it)
01321 {
01322 const SgUctNode& child = *it;
01323 SgMove mv = child.Move();
01324 size_t first = firstPlay[mv];
01325 SG_ASSERT(first >= i);
01326 if (first == numeric_limits<size_t>::max())
01327 continue;
01328 if (m_raveCheckSame && SgUtil::InRange(firstPlayOpp[mv], i, first))
01329 continue;
01330 float weight;
01331 if (m_weightRaveUpdates)
01332 weight = 2.f - static_cast<float>(first - i) / (len - i);
01333 else
01334 weight = 1.f;
01335 m_tree.AddRaveValue(child, eval, weight);
01336 }
01337 }
01338
01339 void SgUctSearch::UpdateStatistics(const SgUctGameInfo& info)
01340 {
01341 m_statistics.m_movesInTree.Add(
01342 static_cast<float>(info.m_inTreeSequence.size()));
01343 for (size_t i = 0; i < m_numberPlayouts; ++i)
01344 {
01345 m_statistics.m_gameLength.Add(
01346 static_cast<float>(info.m_sequence[i].size()));
01347 m_statistics.m_aborted.Add(info.m_aborted[i] ? 1.f : 0.f);
01348 }
01349 }
01350
01351 void SgUctSearch::UpdateTree(const SgUctGameInfo& info)
01352 {
01353 float eval = 0;
01354 for (size_t i = 0; i < m_numberPlayouts; ++i)
01355 eval += info.m_eval[i];
01356 eval /= m_numberPlayouts;
01357 float inverseEval = InverseEval(eval);
01358 const vector<const SgUctNode*>& nodes = info.m_nodes;
01359
01360 for (size_t i = 0; i < nodes.size(); ++i)
01361 {
01362 const SgUctNode& node = *nodes[i];
01363 const SgUctNode* father = (i > 0 ? nodes[i - 1] : 0);
01364 m_tree.AddGameResults(node, father, i % 2 == 0 ? eval : inverseEval, m_numberPlayouts);
01365 }
01366 }
01367
01368 void SgUctSearch::WriteStatistics(ostream& out) const
01369 {
01370 out << SgWriteLabel("Count") << m_tree.Root().MoveCount() << '\n'
01371 << SgWriteLabel("Nodes") << m_tree.NuNodes() << '\n';
01372 if (!m_knowledgeThreshold.empty())
01373 out << SgWriteLabel("Knowledge")
01374 << m_statistics.m_knowledge << " (" << fixed << setprecision(1)
01375 << m_statistics.m_knowledge * 100.0 / m_tree.Root().MoveCount()
01376 << "%)\n";
01377 m_statistics.Write(out);
01378 m_mpiSynchronizer->WriteStatistics(out);
01379 }
01380
01381