00001 
00002 
00003 
00004 
00005 
00006 #include "SgSystem.h"
00007 #include "SgUctSearch.h"
00008 
00009 #include <algorithm>
00010 #include <cmath>
00011 #include <iomanip>
00012 #include <boost/format.hpp>
00013 #include <boost/io/ios_state.hpp>
00014 #include <boost/version.hpp>
00015 #include "SgDebug.h"
00016 #include "SgHashTable.h"
00017 #include "SgMath.h"
00018 #include "SgWrite.h"
00019 
00020 using namespace std;
00021 using boost::barrier;
00022 using boost::condition;
00023 using boost::format;
00024 using boost::mutex;
00025 using boost::shared_ptr;
00026 using boost::io::ios_all_saver;
00027 
00028 #define BOOST_VERSION_MAJOR (BOOST_VERSION / 100000)
00029 #define BOOST_VERSION_MINOR (BOOST_VERSION / 100 % 1000)
00030 
00031 
00032 
namespace {

/** Set to true to log thread start/finish messages from
    SgUctSearch::Thread::operator(). */
const bool DEBUG_THREADS = false;

/** Wake up all threads waiting on aCondition.
    aMutex is locked while notifying, so a thread that holds aMutex up to
    the point where it starts waiting on aCondition cannot miss the
    notification. */
void Notify(mutex& aMutex, condition& aCondition)
{
    mutex::scoped_lock lock(aMutex);
    aCondition.notify_all();
}

} // namespace
00044 
00045 
00046 
00047 void SgUctGameInfo::Clear(std::size_t numberPlayouts)
00048 {
00049     m_nodes.clear();
00050     m_inTreeSequence.clear();
00051     if (numberPlayouts != m_sequence.size())
00052     {
00053         m_sequence.resize(numberPlayouts);
00054         m_skipRaveUpdate.resize(numberPlayouts);
00055         m_eval.resize(numberPlayouts);
00056         m_aborted.resize(numberPlayouts);
00057     }
00058     for (size_t i = 0; i < numberPlayouts; ++i)
00059     {
00060         m_sequence[i].clear();
00061         m_skipRaveUpdate[i].clear();
00062     }
00063 }
00064 
00065 
00066 
00067 SgUctThreadState::SgUctThreadState(size_t threadId, int moveRange)
00068     : m_threadId(threadId),
00069       m_isSearchInitialized(false),
00070       m_isTreeOutOfMem(false),
00071       m_randomizeCounter(0)
00072 {
00073     if (moveRange > 0)
00074     {
00075         m_firstPlay.reset(new size_t[moveRange]);
00076         m_firstPlayOpp.reset(new size_t[moveRange]);
00077     }
00078 }
00079 
/** Destructor; nothing to release explicitly. */
SgUctThreadState::~SgUctThreadState()
{
}

/** Called by SgUctSearch::PlayGame() after each playout.
    The default implementation does nothing (presumably overridden by
    subclasses that need it). */
void SgUctThreadState::EndPlayout()
{
    // Default implementation does nothing.
}

/** Called by SgUctSearch::PlayGame() at the start of each game.
    The default implementation does nothing. */
void SgUctThreadState::GameStart()
{
    // Default implementation does nothing.
}

/** Called by SgUctSearch::PlayGame() before each playout.
    The default implementation does nothing. */
void SgUctThreadState::StartPlayout()
{
    // Default implementation does nothing.
}

/** Called by SgUctSearch::PlayGame() once before the playouts of a game
    (not called when the in-tree phase ended in a proven node).
    The default implementation does nothing. */
void SgUctThreadState::StartPlayouts()
{
    // Default implementation does nothing.
}
00103 
00104 
00105 
/** Destructor of the thread state factory interface; nothing to release. */
SgUctThreadStateFactory::~SgUctThreadStateFactory()
{
}
00109 
00110 
00111 
/** Callable used as the entry function of Thread::m_thread.
    Stores a reference to the owning Thread object. */
SgUctSearch::Thread::Function::Function(Thread& thread)
    : m_thread(thread)
{
}

/** Thread entry point: runs the owning Thread's operator(). */
void SgUctSearch::Thread::Function::operator()()
{
    m_thread();
}
00121 
/** Create a worker thread that owns the given thread state.
    The constructor waits on the m_threadReady barrier (two participants)
    until the worker has locked m_startPlayMutex and reached the barrier
    too. Because the worker keeps m_startPlayMutex locked until it waits on
    m_startPlay, a later StartPlay() cannot notify before the worker waits.
    The constructing thread locks m_playFinishedMutex via
    m_playFinishedLock; it is released only while waiting in
    WaitPlayFinished().
    NOTE(review): the worker starts running as soon as m_thread is
    constructed; this relies on m_thread being declared after the members
    the worker uses, since members are initialized in declaration order.
    Confirm in the header. */
SgUctSearch::Thread::Thread(SgUctSearch& search,
                            auto_ptr<SgUctThreadState> state)
    : m_state(state),
      m_search(search),
      m_quit(false),
      // Rendezvous of the constructing thread and the worker thread.
      m_threadReady(2),
      m_playFinishedLock(m_playFinishedMutex),
      // The global lock is created unlocked; SearchLoop() locks it.
#if BOOST_VERSION_MAJOR == 1 && BOOST_VERSION_MINOR <= 34
      m_globalLock(search.m_globalMutex, false),
#else
      m_globalLock(search.m_globalMutex, boost::defer_lock),
#endif
      m_thread(Function(*this))
{
    m_threadReady.wait();
}
00138 
/** Ask the worker thread to exit and wait for it to terminate.
    m_quit is set before StartPlay() wakes the worker; the worker checks
    m_quit right after waking and leaves its loop. */
SgUctSearch::Thread::~Thread()
{
    m_quit = true;
    StartPlay();
    m_thread.join();
}
00145 
/** Main loop of the worker thread.
    Holds m_startPlayMutex except while waiting on m_startPlay. After
    meeting the constructor at m_threadReady, it repeatedly waits for a
    StartPlay() notification, runs SearchLoop() with this thread's state
    and global lock, and then notifies m_playFinished. It exits when woken
    with m_quit set.
    NOTE(review): the wait on m_startPlay is not guarded by a predicate, so
    a spurious wakeup could start an unrequested SearchLoop(). Confirm
    whether this can happen with the condition implementation in use. */
void SgUctSearch::Thread::operator()()
{
    if (DEBUG_THREADS)
        SgDebug() << "SgUctSearch::Thread: starting thread "
                  << m_state->m_threadId << '\n';
    // Locked before the barrier so StartPlay() cannot notify before the
    // first wait below.
    mutex::scoped_lock lock(m_startPlayMutex);
    m_threadReady.wait();
    while (true)
    {
        m_startPlay.wait(lock);
        if (m_quit)
            break;
        m_search.SearchLoop(*m_state, &m_globalLock);
        Notify(m_playFinishedMutex, m_playFinished);
    }
    if (DEBUG_THREADS)
        SgDebug() << "SgUctSearch::Thread: finishing thread "
                  << m_state->m_threadId << '\n';
}
00165 
/** Wake up the worker thread (to run a search loop, or to quit if m_quit
    is set). */
void SgUctSearch::Thread::StartPlay()
{
    Notify(m_startPlayMutex, m_startPlay);
}

/** Block until the worker signals m_playFinished.
    Waits using m_playFinishedLock, which was acquired in the constructor.
    Presumably this must therefore be called from the thread that
    constructed this object. */
void SgUctSearch::Thread::WaitPlayFinished()
{
    m_playFinished.wait(m_playFinishedLock);
}
00175 
00176 
00177 
00178 void SgUctSearchStat::Clear()
00179 {
00180     m_time = 0;
00181     m_knowledge = 0;
00182     m_gamesPerSecond = 0;
00183     m_gameLength.Clear();
00184     m_movesInTree.Clear();
00185     m_aborted.Clear();
00186 }
00187 
00188 void SgUctSearchStat::Write(std::ostream& out) const
00189 {
00190     ios_all_saver saver(out);
00191     out << SgWriteLabel("Time") << setprecision(2) << m_time << '\n'
00192         << SgWriteLabel("GameLen") << fixed << setprecision(1);
00193     m_gameLength.Write(out);
00194     out << '\n'
00195         << SgWriteLabel("InTree");
00196     m_movesInTree.Write(out);
00197     out << '\n'
00198         << SgWriteLabel("Aborted")
00199         << static_cast<int>(100 * m_aborted.Mean()) << "%\n"
00200         << SgWriteLabel("Games/s") << fixed << setprecision(1)
00201         << m_gamesPerSecond << '\n';
00202 }
00203 
00204 
00205 
/** Construct a search with default parameters.
    @param threadStateFactory Used by CreateThreads() to create one
    SgUctThreadState per thread. NOTE(review): ownership is not visible
    here; confirm whether the search takes ownership.
    @param moveRange Stored in m_moveRange; presumably the size of the move
    encoding range. Confirm. */
SgUctSearch::SgUctSearch(SgUctThreadStateFactory* threadStateFactory,
                         int moveRange)
    : m_threadStateFactory(threadStateFactory),
      m_logGames(false),
      m_rave(false),
      m_knowledgeThreshold(),
      // Final move chosen by move count (see FindBestChild()).
      m_moveSelect(SG_UCTMOVESELECT_COUNT),
      m_raveCheckSame(false),
      m_randomizeRaveFrequency(20),
      m_lockFree(false),
      m_weightRaveUpdates(true),
      m_pruneFullTree(true),
      m_numberThreads(1),
      m_numberPlayouts(1),
      m_maxNodes(4000000),
      m_pruneMinCount(16),
      m_moveRange(moveRange),
      m_maxGameLength(numeric_limits<size_t>::max()),
      m_expandThreshold(1),
      m_biasTermConstant(0.7f),
      m_firstPlayUrgency(10000),
      m_raveWeightInitial(0.9f),
      m_raveWeightFinal(20000),
      m_virtualLoss(false),
      m_logFileName("uctsearch.log"),
      m_fastLog(10),
      m_mpiSynchronizer(SgMpiNullSynchronizer::Create())
{
    // Threads are not created here; CreateThreads() is called lazily
    // (e.g. by GenerateAllMoves()).
    // NOTE(review): m_raveWeightParam1/2 (used by GetValueEstimate()) are
    // not set in this initializer list. They are presumably derived from
    // m_raveWeightInitial/Final elsewhere. Confirm.
}
00238 
/** Destroy the worker threads (each Thread destructor stops and joins
    its thread). */
SgUctSearch::~SgUctSearch()
{
    DeleteThreads();
}
00243 
00244 void SgUctSearch::ApplyRootFilter(vector<SgMoveInfo>& moves)
00245 {
00246     
00247     vector<SgMoveInfo> filteredMoves;
00248     for (vector<SgMoveInfo>::const_iterator it = moves.begin();
00249          it != moves.end(); ++it)
00250         if (find(m_rootFilter.begin(), m_rootFilter.end(), it->m_move)
00251             == m_rootFilter.end())
00252             filteredMoves.push_back(*it);
00253     moves = filteredMoves;
00254 }
00255 
00256 size_t SgUctSearch::GamesPlayed() const
00257 {
00258     return m_tree.Root().MoveCount() - m_startRootMoveCount;
00259 }
00260 
00261 bool SgUctSearch::CheckAbortSearch(SgUctThreadState& state)
00262 {
00263     size_t gamesPlayed = GamesPlayed();
00264     if (SgUserAbort())
00265     {
00266         Debug(state, "SgUctSearch: abort flag");
00267         return true;
00268     }
00269     const bool isEarlyAbort = CheckEarlyAbort();
00270     if (gamesPlayed >= m_maxGames)
00271     {
00272         Debug(state, "SgUctSearch: max games reached");
00273         return true;
00274     }
00275     if (   isEarlyAbort
00276         && m_earlyAbort->m_reductionFactor * gamesPlayed >= m_maxGames
00277        )
00278     {
00279         Debug(state, "SgUctSearch: max games reached (early abort)");
00280         m_wasEarlyAbort = true;
00281         return true;
00282     }
00283     if (m_numberGames % m_checkTimeInterval == 0)
00284     {
00285         double time = m_timer.GetTime();
00286         if (time > m_maxTime)
00287         {
00288             Debug(state, "SgUctSearch: max time reached");
00289             return true;
00290         }
00291         if (isEarlyAbort
00292             && m_earlyAbort->m_reductionFactor * time > m_maxTime)
00293         {
00294             Debug(state, "SgUctSearch: max time reached (early abort)");
00295             m_wasEarlyAbort = true;
00296             return true;
00297         }
00298         UpdateCheckTimeInterval(time);
00299         if (m_moveSelect == SG_UCTMOVESELECT_COUNT)
00300         {
00301             double remainingGamesDouble = m_maxGames - gamesPlayed - 1;
00302             
00303             
00304             if (time > 1.)
00305             {
00306                 double remainingTime = m_maxTime - time;
00307                 remainingGamesDouble =
00308                     min(remainingGamesDouble,
00309                         remainingTime * m_statistics.m_gamesPerSecond);
00310             }
00311             size_t sizeTypeMax = numeric_limits<size_t>::max();
00312             size_t remainingGames;
00313             if (remainingGamesDouble > static_cast<double>(sizeTypeMax - 1))
00314                 remainingGames = sizeTypeMax;
00315             else
00316                 remainingGames = static_cast<size_t>(remainingGamesDouble);
00317             if (CheckCountAbort(state, remainingGames))
00318             {
00319                 Debug(state, "SgUctSearch: move cannot change anymore");
00320                 return true;
00321             }
00322         }
00323     }
00324     return false;
00325 }
00326 
00327 bool SgUctSearch::CheckCountAbort(SgUctThreadState& state,
00328                                   std::size_t remainingGames) const
00329 {
00330     const SgUctNode& root = m_tree.Root();
00331     const SgUctNode* bestChild = FindBestChild(root);
00332     if (bestChild == 0)
00333         return false;
00334     size_t bestCount = bestChild->MoveCount();
00335     vector<SgMove>& excludeMoves = state.m_excludeMoves;
00336     excludeMoves.clear();
00337     excludeMoves.push_back(bestChild->Move());
00338     const SgUctNode* secondBestChild = FindBestChild(root, &excludeMoves);
00339     if (secondBestChild == 0)
00340         return false;
00341     std::size_t secondBestCount = secondBestChild->MoveCount();
00342     SG_ASSERT(secondBestCount <= bestCount || m_numberThreads > 1);
00343     return (secondBestCount + remainingGames <= bestCount);
00344 }
00345 
00346 bool SgUctSearch::CheckEarlyAbort() const
00347 {
00348     const SgUctNode& root = m_tree.Root();
00349     return   m_earlyAbort.get() != 0
00350           && root.HasMean()
00351           && root.MoveCount() > m_earlyAbort->m_minGames
00352           && root.Mean() > m_earlyAbort->m_threshold;
00353 }
00354 
00355 void SgUctSearch::CreateThreads()
00356 {
00357     DeleteThreads();
00358     for (size_t i = 0; i < m_numberThreads; ++i)
00359     {
00360         auto_ptr<SgUctThreadState> state(
00361                                       m_threadStateFactory->Create(i, *this));
00362         shared_ptr<Thread> thread(new Thread(*this, state));
00363         m_threads.push_back(thread);
00364     }
00365     m_tree.CreateAllocators(m_numberThreads);
00366     m_tree.SetMaxNodes(m_maxNodes);
00367 
00368     m_searchLoopFinished.reset(new barrier(m_numberThreads));
00369 }
00370 
00371 
00372 
00373 
00374 
00375 
00376 
00377 
00378 void SgUctSearch::Debug(const SgUctThreadState& state,
00379                         const std::string& textLine)
00380 {
00381     if (m_numberThreads > 1)
00382     {
00383         
00384         GlobalLock lock(m_globalMutex);
00385         SgDebug() << (format("[%1%] %2%\n") % state.m_threadId % textLine);
00386     }
00387     else
00388         SgDebug() << (format("%1%\n") % textLine);
00389 }
00390 
/** Release all worker threads. Each Thread's destructor stops and joins
    its thread once the last shared_ptr to it is gone. */
void SgUctSearch::DeleteThreads()
{
    m_threads.clear();
}
00395 
00396 
00397 
00398 
00399 
00400 void SgUctSearch::ExpandNode(SgUctThreadState& state, const SgUctNode& node)
00401 {
00402     size_t threadId = state.m_threadId;
00403     if (! m_tree.HasCapacity(threadId, state.m_moves.size()))
00404     {
00405         Debug(state, str(format("SgUctSearch: maximum tree size %1% reached")
00406                          % m_tree.MaxNodes()));
00407         state.m_isTreeOutOfMem = true;
00408         m_isTreeOutOfMemory = true;
00409         SgSynchronizeThreadMemory();
00410         return;
00411     }
00412     m_tree.CreateChildren(threadId, node, state.m_moves);
00413 }
00414 
00415 const SgUctNode*
00416 SgUctSearch::FindBestChild(const SgUctNode& node,
00417                            const vector<SgMove>* excludeMoves) const
00418 {
00419     if (! node.HasChildren())
00420         return 0;
00421     const SgUctNode* bestChild = 0;
00422     float bestValue = 0;
00423     for (SgUctChildIterator it(m_tree, node); it; ++it)
00424     {
00425         const SgUctNode& child = *it;
00426         if (excludeMoves != 0)
00427         {
00428             vector<SgMove>::const_iterator begin = excludeMoves->begin();
00429             vector<SgMove>::const_iterator end = excludeMoves->end();
00430             if (find(begin, end, child.Move()) != end)
00431                 continue;
00432         }
00433         if (  ! child.HasMean()
00434            && ! (  (  m_moveSelect == SG_UCTMOVESELECT_BOUND
00435                    || m_moveSelect == SG_UCTMOVESELECT_ESTIMATE
00436                    )
00437                 && m_rave
00438                 && child.HasRaveValue()
00439                 )
00440             )
00441             continue;
00442         float moveValue = InverseEval(child.Mean());
00443         size_t moveCount = child.MoveCount();
00444         float value;
00445         switch (m_moveSelect)
00446         {
00447         case SG_UCTMOVESELECT_VALUE:
00448             value = moveValue;
00449             break;
00450         case SG_UCTMOVESELECT_COUNT:
00451             value = moveCount;
00452             break;
00453         case SG_UCTMOVESELECT_BOUND:
00454             value = GetBound(m_rave, node, child);
00455             break;
00456         case SG_UCTMOVESELECT_ESTIMATE:
00457             value = GetValueEstimate(m_rave, child);
00458             break;
00459         default:
00460             SG_ASSERT(false);
00461             value = SG_UCTMOVESELECT_VALUE;
00462         }
00463         if (bestChild == 0 || value > bestValue)
00464         {
00465             bestChild = &child;
00466             bestValue = value;
00467         }
00468     }
00469     return bestChild;
00470 }
00471 
00472 void SgUctSearch::FindBestSequence(vector<SgMove>& sequence) const
00473 {
00474     sequence.clear();
00475     const SgUctNode* current = &m_tree.Root();
00476     while (true)
00477     {
00478         current = FindBestChild(*current);
00479         if (current == 0)
00480             break;
00481         sequence.push_back(current->Move());
00482         if (! current->HasChildren())
00483             break;
00484     }
00485 }
00486 
00487 void SgUctSearch::GenerateAllMoves(std::vector<SgMoveInfo>& moves)
00488 {
00489     if (m_threads.size() == 0)
00490         CreateThreads();
00491     moves.clear();
00492     OnStartSearch();
00493     SgUctThreadState& state = ThreadState(0);
00494     state.StartSearch();
00495     SgProvenNodeType type;
00496     state.GenerateAllMoves(0, moves, type);
00497 }
00498 
00499 float SgUctSearch::GetBound(bool useRave, const SgUctNode& node,
00500                             const SgUctNode& child) const
00501 {
00502     size_t posCount = node.PosCount();
00503     return GetBound(useRave, Log(posCount), child);
00504 }
00505 
00506 float SgUctSearch::GetBound(bool useRave, float logPosCount, 
00507                             const SgUctNode& child) const
00508 {
00509     float value;
00510     if (useRave)
00511         value = GetValueEstimateRave(child);
00512     else
00513         value = GetValueEstimate(false, child);
00514     if (m_biasTermConstant == 0.0)
00515         return value;
00516     else
00517     {
00518         float moveCount = static_cast<float>(child.MoveCount());
00519         float bound =
00520             value + m_biasTermConstant * sqrt(logPosCount / (moveCount + 1));
00521         return bound;
00522     }
00523 }
00524 
00525 SgUctTree& SgUctSearch::GetTempTree()
00526 {
00527     m_tempTree.Clear();
00528     
00529     
00530     
00531     if (m_tempTree.NuAllocators() != NumberThreads())
00532     {
00533         m_tempTree.CreateAllocators(NumberThreads());
00534         m_tempTree.SetMaxNodes(MaxNodes());
00535     }
00536     else if (m_tempTree.MaxNodes() != MaxNodes())
00537     {
00538         m_tempTree.SetMaxNodes(MaxNodes());
00539     }
00540     return m_tempTree;
00541 }
00542 
00543 float SgUctSearch::GetValueEstimate(bool useRave, const SgUctNode& child) const
00544 {
00545     float value = 0.f;
00546     float weightSum = 0.f;
00547     bool hasValue = false;
00548     if (child.HasMean())
00549     {
00550         float weight = static_cast<float>(child.MoveCount());
00551         value += weight * InverseEval(child.Mean());
00552         weightSum += weight;
00553         hasValue = true;
00554     }
00555     if (useRave && child.HasRaveValue())
00556     {
00557         float raveCount = child.RaveCount();
00558         float weight =
00559                 raveCount
00560               / (  m_raveWeightParam1
00561                  + m_raveWeightParam2 * raveCount
00562                 );
00563         value += weight * child.RaveValue();
00564         weightSum += weight;
00565         hasValue = true;
00566     }
00567     if (hasValue)
00568         return value / weightSum;
00569     else
00570         return m_firstPlayUrgency;
00571 }
00572 
00573 
00574 
00575 
00576 
00577 
00578 
00579 float SgUctSearch::GetValueEstimateRave(const SgUctNode& child) const
00580 {
00581     SG_ASSERT(m_rave);
00582     bool hasRave = child.HasRaveValue();
00583     float value;
00584     if (child.HasMean())
00585     {
00586         float moveValue = InverseEval(child.Mean());
00587         if (hasRave)
00588         {
00589             float moveCount = child.MoveCount();
00590             float raveCount = child.RaveCount();
00591             float weight =
00592                 raveCount
00593                 / (moveCount
00594                    * (m_raveWeightParam1 + m_raveWeightParam2 * raveCount)
00595                    + raveCount);
00596             value = weight * child.RaveValue() + (1.f - weight) * moveValue;
00597         }
00598         else
00599         {
00600             
00601             
00602             
00603             
00604             SG_ASSERT(m_numberThreads > 1 && m_lockFree);
00605             value = moveValue;
00606         }
00607     }
00608     else if (hasRave)
00609         value = child.RaveValue();
00610     else
00611         value = m_firstPlayUrgency;
00612     SG_ASSERT(m_numberThreads > 1
00613               || fabs(value - GetValueEstimate(m_rave, child)) < 1e-3);
00614     return value;
00615 }
00616 
/** One-line summary (SummaryLine()) of the game returned by
    LastGameInfo(). */
string SgUctSearch::LastGameSummaryLine() const
{
    return SummaryLine(LastGameInfo());
}
00621 
/** Logarithm used for the bias term in GetBound().
    Uses the m_fastLog approximation when SG_UCTFASTLOG is enabled, and the
    standard log() otherwise. */
float SgUctSearch::Log(float x) const
{
    // The choice between the approximation and log() is made at compile
    // time by SG_UCTFASTLOG.
#if SG_UCTFASTLOG
    return m_fastLog.Log(x);
#else
    return log(x);
#endif
}
00634 
00635 
00636 
00637 void SgUctSearch::CreateChildren(SgUctThreadState& state, 
00638                                  const SgUctNode& node,
00639                                  bool deleteChildTrees)
00640 {
00641     size_t threadId = state.m_threadId;
00642     if (! m_tree.HasCapacity(threadId, state.m_moves.size()))
00643     {
00644         Debug(state, str(format("SgUctSearch: maximum tree size %1% reached")
00645                          % m_tree.MaxNodes()));
00646         state.m_isTreeOutOfMem = true;
00647         m_isTreeOutOfMemory = true;
00648         SgSynchronizeThreadMemory();
00649         return;
00650     }
00651     m_tree.MergeChildren(threadId, node, state.m_moves, deleteChildTrees);
00652 }
00653 
00654 bool SgUctSearch::NeedToComputeKnowledge(const SgUctNode* current)
00655 {
00656     if (m_knowledgeThreshold.empty())
00657         return false;
00658     for (std::size_t i = 0; i < m_knowledgeThreshold.size(); ++i)
00659     {
00660         const std::size_t threshold = m_knowledgeThreshold[i];
00661         if (current->KnowledgeCount() < threshold)
00662         {
00663             if (current->MoveCount() >= threshold)
00664             {
00665                 
00666                 
00667                 
00668                 SgUctNode* node = const_cast<SgUctNode*>(current);
00669                 node->SetKnowledgeCount(threshold);
00670                 SG_ASSERT(current->MoveCount());
00671                 return true;
00672             }
00673             return false;
00674         }
00675     }
00676     return false;
00677 }
00678 
/** Hook called at the start of a search; forwards to the MPI
    synchronizer. */
void SgUctSearch::OnStartSearch()
{
    m_mpiSynchronizer->OnStartSearch(*this);
}

/** Hook called at the end of a search; forwards to the MPI
    synchronizer. */
void SgUctSearch::OnEndSearch()
{
    m_mpiSynchronizer->OnEndSearch(*this);
}
00688 
00689 
00690 void SgUctSearch::PrintSearchProgress(double currTime) const
00691 {
00692     const int MAX_SEQ_PRINT_LENGTH = 15;
00693     const size_t MIN_MOVE_COUNT = 10;
00694     size_t rootMoveCount = m_tree.Root().MoveCount();
00695     float rootMean = m_tree.Root().Mean();
00696     ostringstream out;
00697     const SgUctNode* current = &m_tree.Root();
00698     out << fixed << setprecision(0)
00699         << SgTime::Format(currTime, true) << " | "
00700         << fixed << setprecision(3) << rootMean << " "
00701         << "| " << rootMoveCount << " ";
00702     for(int i = 0; i <= MAX_SEQ_PRINT_LENGTH && current->HasChildren(); ++i)
00703     {
00704         current = FindBestChild(*current);
00705         if (current == 0 || current->MoveCount() < MIN_MOVE_COUNT)
00706             break;
00707         if (i == 0)
00708             out << "|";
00709         if (i < MAX_SEQ_PRINT_LENGTH)
00710             out << " " << SgWritePoint(current->Move());
00711         else
00712             out << " *";
00713     }
00714     SgDebug() << out.str() << endl;
00715 }
00716 
00717 void SgUctSearch::OnSearchIteration(std::size_t gameNumber, int threadId,
00718                                     const SgUctGameInfo& info)
00719 {
00720     const int DISPLAY_INTERVAL = 5;
00721 
00722     m_mpiSynchronizer->OnSearchIteration(*this, gameNumber, threadId, info);
00723     double currTime = m_timer.GetTime();
00724 
00725     if (threadId == 0 && currTime - m_lastScoreDisplayTime > DISPLAY_INTERVAL)
00726     {
00727         PrintSearchProgress(currTime);
00728         m_lastScoreDisplayTime = currTime;
00729     }
00730 }
00731 
/** Play one game: the in-tree phase, then m_numberPlayouts playouts from
    the reached position, then update the tree, the RAVE values (if
    m_rave) and the statistics.
    @param state The thread's state; state.m_gameInfo receives the game.
    @param lock The global lock, held by the caller (SearchLoop()), or 0 in
    single-thread or lock-free mode. It is released while the playouts
    run and re-acquired before the tree is updated. */
void SgUctSearch::PlayGame(SgUctThreadState& state, GlobalLock* lock)
{
    state.m_isTreeOutOfMem = false;
    state.GameStart();
    SgUctGameInfo& info = state.m_gameInfo;
    info.Clear(m_numberPlayouts);
    bool isTerminal;
    bool abortInTree = ! PlayInTree(state, isTerminal);

    // Add a virtual loss to the nodes of the in-tree path; it is removed
    // again below, before the tree is updated.
    if (m_virtualLoss)
        m_tree.AddVirtualLoss(info.m_nodes);

    // The playout phase below does not access m_tree, so it runs without
    // the global lock.
    if (lock != 0)
        lock->unlock();

    size_t nuMovesInTree = info.m_inTreeSequence.size();

    // If the in-tree phase ended in a proven node, no playouts are run:
    // every playout gets the in-tree sequence and the proven result
    // (1 if IsProvenWin(), 0 otherwise). InverseEval is applied when the
    // sequence has an odd number of moves (presumably to express the
    // result from the root player's point of view).
    if (! info.m_nodes.empty() && info.m_nodes.back()->IsProven())
    {
        for (size_t i = 0; i < m_numberPlayouts; ++i)
        {
            info.m_sequence[i] = info.m_inTreeSequence;
            info.m_skipRaveUpdate[i].assign(nuMovesInTree, false);
            float eval = info.m_nodes.back()->IsProvenWin() ? 1.0 : 0.0;
            size_t nuMoves = info.m_sequence[i].size();
            if (nuMoves % 2 != 0)
                eval = InverseEval(eval);
            info.m_aborted[i] = abortInTree || state.m_isTreeOutOfMem;
            info.m_eval[i] = eval;
        }
    }
    else 
    {
        state.StartPlayouts();
        for (size_t i = 0; i < m_numberPlayouts; ++i)
        {
            state.StartPlayout();
            info.m_sequence[i] = info.m_inTreeSequence;
            // In-tree moves are never skipped for RAVE updates.
            info.m_skipRaveUpdate[i].assign(nuMovesInTree, false);
            bool abort = abortInTree || state.m_isTreeOutOfMem;
            if (! abort && ! isTerminal)
                abort = ! PlayoutGame(state, i);
            float eval;
            if (abort)
                eval = UnknownEval();
            else
                eval = state.Evaluate();
            size_t nuMoves = info.m_sequence[i].size();
            if (nuMoves % 2 != 0)
                eval = InverseEval(eval);
            info.m_aborted[i] = abort;
            info.m_eval[i] = eval;
            state.EndPlayout();
            // Undo the playout moves, back to the end of the in-tree phase.
            state.TakeBackPlayout(nuMoves - nuMovesInTree);
        }
    }
    state.TakeBackInTree(nuMovesInTree);

    // Re-acquire the global lock before touching the shared tree.
    if (lock != 0)
        lock->lock();

    // Remove the virtual loss added above.
    if (m_virtualLoss)
        m_tree.RemoveVirtualLoss(info.m_nodes);

    UpdateTree(info);
    if (m_rave)
        UpdateRaveValues(state);
    UpdateStatistics(info);
}
00807 
00808 
00809 
00810 
00811 
00812 
00813 
/** In-tree phase of a game: descend from the root, selecting and
    executing moves. Visited nodes (starting with the root) are appended
    to m_gameInfo.m_nodes and executed moves to m_gameInfo.m_inTreeSequence.
    The descent stops at:
    - a proven node,
    - a node whose move generation reports a proven type,
    - a node without moves (isTerminal is then set),
    - an unexpanded node with fewer than m_expandThreshold games,
    - the child selected right after expanding a node or recomputing its
      knowledge.
    Root moves are filtered with ApplyRootFilter().
    NOTE(review): deepenTree is never set to true in this function, so the
    "deepenTree ||" and "if (! deepenTree)" tests always take the same
    branch.
    @param[out] isTerminal Set to true if the last node has no moves.
    @return false if the sequence reached m_maxGameLength (the game is
    aborted); true otherwise, including when the tree ran out of memory
    (the caller checks state.m_isTreeOutOfMem). */
bool SgUctSearch::PlayInTree(SgUctThreadState& state, bool& isTerminal)
{
    vector<SgMove>& sequence = state.m_gameInfo.m_inTreeSequence;
    vector<const SgUctNode*>& nodes = state.m_gameInfo.m_nodes;
    const SgUctNode* root = &m_tree.Root();
    const SgUctNode* current = root;
    nodes.push_back(current);
    bool breakAfterSelect = false;
    isTerminal = false;
    bool deepenTree = false;
    while (true)
    {
        if (sequence.size() == m_maxGameLength)
            return false;
        if (current->IsProven())
            break;
        if (! current->HasChildren())
        {
            // Leaf: generate its moves, then either expand it or stop here.
            state.m_moves.clear();
            SgProvenNodeType provenType = SG_NOT_PROVEN;
            state.GenerateAllMoves(0, state.m_moves, provenType);
            if (current == root)
                ApplyRootFilter(state.m_moves);
            if (provenType != SG_NOT_PROVEN)
            {
                SgUctNode* node = const_cast<SgUctNode*>(current);
                node->SetProvenNodeType(provenType);
                break;
            }
            if (state.m_moves.empty())
            {
                isTerminal = true;
                break;
            }
            if (  deepenTree
               || current->MoveCount() >= m_expandThreshold
               )
            {
                deepenTree = false;
                ExpandNode(state, *current);
                if (state.m_isTreeOutOfMem)
                    return true;
                if (! deepenTree)
                    breakAfterSelect = true;
            }
            else
                break;
        }
        else if (NeedToComputeKnowledge(current))
        {
            // Regenerate moves with the node's new knowledge count and
            // merge them into the existing children.
            m_statistics.m_knowledge++;
            deepenTree = false;
            SgProvenNodeType provenType = SG_NOT_PROVEN;
            bool truncate = state.GenerateAllMoves(current->KnowledgeCount(), 
                                                   state.m_moves,
                                                   provenType);
            if (current == root)
                ApplyRootFilter(state.m_moves);
            CreateChildren(state, *current, truncate);
            if (provenType != SG_NOT_PROVEN)
            {
                SgUctNode* node = const_cast<SgUctNode*>(current);
                node->SetProvenNodeType(provenType);
                break;
            }
            if (state.m_moves.empty())
            {
                isTerminal = true;
                break;
            }
            if (state.m_isTreeOutOfMem)
                return true;
            if (! deepenTree)
                breakAfterSelect = true;
        }
        current = &SelectChild(state.m_randomizeCounter, *current);
        nodes.push_back(current);
        SgMove move = current->Move();
        state.Execute(move);
        sequence.push_back(move);
        if (breakAfterSelect)
            break;
    }
    return true;
}
00899 
00900 
00901 
00902 
00903 
00904 
00905 bool SgUctSearch::PlayoutGame(SgUctThreadState& state, std::size_t playout)
00906 {
00907     SgUctGameInfo& info = state.m_gameInfo;
00908     vector<SgMove>& sequence = info.m_sequence[playout];
00909     vector<bool>& skipRaveUpdate = info.m_skipRaveUpdate[playout];
00910     while (true)
00911     {
00912         if (sequence.size() == m_maxGameLength)
00913             return false;
00914         bool skipRave = false;
00915         SgMove move = state.GeneratePlayoutMove(skipRave);
00916         if (move == SG_NULLMOVE)
00917             break;
00918         state.ExecutePlayout(move);
00919         sequence.push_back(move);
00920         skipRaveUpdate.push_back(skipRave);
00921     }
00922     return true;
00923 }
00924 
/** Run a search with all worker threads.
    If the search loop stops because the tree ran out of memory (i.e. not
    aborted) and m_pruneFullTree is set, the tree is copied without nodes
    whose count is below pruneMinCount, and the search continues.
    pruneMinCount is doubled whenever more than 50% of the nodes survive
    pruning, and reset to m_pruneMinCount otherwise.
    @param maxGames Limit compared with GamesPlayed().
    @param maxTime Limit compared with m_timer.GetTime().
    @param[out] sequence Receives the best sequence (FindBestSequence()).
    @param rootFilter Moves excluded at the root (stored in m_rootFilter).
    @param initTree Passed on to StartSearch().
    @param earlyAbort Optional early-abort parameters (copied), or 0.
    @return The root's mean if the root has a positive move count, else
    0.5. */
float SgUctSearch::Search(std::size_t maxGames, double maxTime,
                          vector<SgMove>& sequence,
                          const vector<SgMove>& rootFilter,
                          SgUctTree* initTree,
                          SgUctEarlyAbortParam* earlyAbort)
{
    m_timer.Start();
    m_rootFilter = rootFilter;
    if (m_logGames)
    {
        m_log.open(m_mpiSynchronizer->ToNodeFilename(m_logFileName).c_str());
        m_log << "StartSearch maxGames=" << maxGames << '\n';
    }
    m_maxGames = maxGames;
    m_maxTime = maxTime;
    m_earlyAbort.reset(0);
    if (earlyAbort != 0)
        m_earlyAbort.reset(new SgUctEarlyAbortParam(*earlyAbort));

    // Each thread runs OnThreadStartSearch() again on its first
    // SearchLoop() of this search.
    for (size_t i = 0; i < m_threads.size(); ++i)
    {
        m_threads[i]->m_state->m_isSearchInitialized = false;
    }
    StartSearch(rootFilter, initTree);
    size_t pruneMinCount = m_pruneMinCount;
    while (true)
    {
        m_isTreeOutOfMemory = false;
        SgSynchronizeThreadMemory();
        // Start all workers, then wait until each has finished its loop.
        for (size_t i = 0; i < m_threads.size(); ++i)
            m_threads[i]->StartPlay();
        for (size_t i = 0; i < m_threads.size(); ++i)
            m_threads[i]->WaitPlayFinished();
        if (m_aborted || ! m_pruneFullTree)
            break;
        else
        {
            // Not aborted: the loop stopped because the tree is full.
            // Prune it into the temporary tree and swap.
            double startPruneTime = m_timer.GetTime();
            SgDebug() << "SgUctSearch: pruning nodes with count < "
                  << pruneMinCount << " (at time " << fixed << setprecision(1)
                  << startPruneTime << ")\n";
            SgUctTree& tempTree = GetTempTree();
            m_tree.CopyPruneLowCount(tempTree, pruneMinCount, true);
            int prunedSizePercentage =
                static_cast<int>(tempTree.NuNodes() * 100 / m_tree.NuNodes());
            SgDebug() << "SgUctSearch: pruned size: " << tempTree.NuNodes()
                      << " (" << prunedSizePercentage << "%) time: "
                      << (m_timer.GetTime() - startPruneTime) << "\n";
            if (prunedSizePercentage > 50)
                pruneMinCount *= 2;
            else
                pruneMinCount = m_pruneMinCount; 
            m_tree.Swap(tempTree);
        }
    }
    EndSearch();
    m_statistics.m_time = m_timer.GetTime();
    // Avoid dividing by a (near) zero elapsed time.
    if (m_statistics.m_time > numeric_limits<double>::epsilon())
        m_statistics.m_gamesPerSecond = GamesPlayed() / m_statistics.m_time;
    if (m_logGames)
        m_log.close();
    FindBestSequence(sequence);
    return (m_tree.Root().MoveCount() > 0) ? m_tree.Root().Mean() : 0.5;
}
00989 
00990 
// Play loop run by each search thread for one round of the search.
// On the first call of a search the per-thread start hook
// (OnThreadStartSearch) is run once. Then games are played until this
// thread's tree runs out of memory, the shared out-of-memory flag is set,
// or the search is aborted (by another thread or by CheckAbortSearch()).
// Afterwards every thread waits at m_searchLoopFinished. The per-thread end
// hook is only run if the search is really over, i.e. not when Search()
// will prune the tree and start another play round.
void SgUctSearch::SearchLoop(SgUctThreadState& state, GlobalLock* lock)
{
    if (! state.m_isSearchInitialized)
    {
        OnThreadStartSearch(state);
        state.m_isSearchInitialized = true;
    }

    // No global lock is used with a single thread or in lock-free mode.
    if (NumberThreads() == 1 || m_lockFree)
        lock = 0;
    // The lock (if any) is taken for the whole loop; NOTE(review): PlayGame
    // receives it and presumably releases it around parts of a game --
    // confirm.
    if (lock != 0)
        lock->lock();
    state.m_isTreeOutOfMem = false;
    while (! state.m_isTreeOutOfMem)
    {
        PlayGame(state, lock);
        OnSearchIteration(m_numberGames + 1, state.m_threadId,
                          state.m_gameInfo);
        // NOTE(review): in lock-free mode (lock == 0) several threads may
        // write m_log and increment m_numberGames concurrently -- confirm
        // this is acceptable / that the member types are thread-safe.
        if (m_logGames)
            m_log << SummaryLine(state.m_gameInfo) << '\n';
        ++m_numberGames;
        // Shared flag, reset by Search() before each play round.
        if (m_isTreeOutOfMemory)
            break;
        if (m_aborted || CheckAbortSearch(state))
        {
            // Make the abort visible to the other threads.
            m_aborted = true;
            SgSynchronizeThreadMemory();
            break;
        }
    }
    if (lock != 0)
        lock->unlock();

    // Rendezvous of all search threads (presumably a barrier).
    m_searchLoopFinished->wait();
    // With full-tree pruning and no abort, Search() prunes the tree and
    // starts another play round, so this thread's search is not over yet.
    if (m_aborted || ! m_pruneFullTree)
        OnThreadEndSearch(state);
}
01028 
01029 void SgUctSearch::OnThreadStartSearch(SgUctThreadState& state)
01030 {
01031     m_mpiSynchronizer->OnThreadStartSearch(*this, state);
01032 }
01033 
01034 void SgUctSearch::OnThreadEndSearch(SgUctThreadState& state)
01035 {
01036     m_mpiSynchronizer->OnThreadEndSearch(*this, state);
01037 }
01038 
01039 SgPoint SgUctSearch::SearchOnePly(size_t maxGames, double maxTime,
01040                                   float& value)
01041 {
01042     if (m_threads.size() == 0)
01043         CreateThreads();
01044     OnStartSearch();
01045     
01046     
01047     SgUctThreadState& state = ThreadState(0);
01048     state.StartSearch();
01049     vector<SgMoveInfo> moves;
01050     SgProvenNodeType provenType;
01051     state.GameStart();
01052     state.GenerateAllMoves(0, moves, provenType);
01053     vector<SgUctStatisticsBase> statistics(moves.size());
01054     size_t games = 0;
01055     m_timer.Start();
01056     SgUctGameInfo& info = state.m_gameInfo;
01057     while (games < maxGames && m_timer.GetTime() < maxTime && ! SgUserAbort())
01058     {
01059         for (size_t i = 0; i < moves.size(); ++i)
01060         {
01061             state.GameStart();
01062             info.Clear(1);
01063             SgMove move = moves[i].m_move;
01064             state.Execute(move);
01065             info.m_inTreeSequence.push_back(move);
01066             info.m_sequence[0].push_back(move);
01067             info.m_skipRaveUpdate[0].push_back(false);
01068             state.StartPlayouts();
01069             state.StartPlayout();
01070             bool abortGame = ! PlayoutGame(state, 0);
01071             float eval;
01072             if (abortGame)
01073                 eval = UnknownEval();
01074             else
01075                 eval = state.Evaluate();
01076             state.EndPlayout();
01077             state.TakeBackPlayout(info.m_sequence[0].size() - 1);
01078             state.TakeBackInTree(1);
01079             statistics[i].Add(info.m_sequence[0].size() % 2 == 0 ?
01080                               eval : InverseEval(eval));
01081             OnSearchIteration(games + 1, 0, info);
01082             ++games;
01083         }
01084     }
01085     SgMove bestMove = SG_NULLMOVE;
01086     for (size_t i = 0; i < moves.size(); ++i)
01087     {
01088         SgDebug() << SgWritePoint(moves[i].m_move) 
01089                   << ' ' << statistics[i].Mean() << '\n';
01090         if (bestMove == SG_NULLMOVE || statistics[i].Mean() > value)
01091         {
01092             bestMove = moves[i].m_move;
01093             value = statistics[i].Mean();
01094         }
01095     }
01096     return bestMove;
01097 }
01098 
01099 const SgUctNode& SgUctSearch::SelectChild(int& randomizeCounter, 
01100                                           const SgUctNode& node)
01101 {
01102     bool useRave = m_rave;
01103     if (m_randomizeRaveFrequency > 0)
01104     {
01105         ++randomizeCounter;
01106         if (randomizeCounter % m_randomizeRaveFrequency == 0)
01107             useRave = false;
01108     }
01109     SG_ASSERT(node.HasChildren());
01110     size_t posCount = node.PosCount();
01111     if (posCount == 0)
01112         
01113         return *SgUctChildIterator(m_tree, node);
01114     float logPosCount = Log(posCount);
01115     const SgUctNode* bestChild = 0;
01116     float bestUpperBound = 0;
01117     for (SgUctChildIterator it(m_tree, node); it; ++it)
01118     {
01119         const SgUctNode& child = *it;
01120         float bound = GetBound(useRave, logPosCount, child);
01121         if (bestChild == 0 || bound > bestUpperBound)
01122         {
01123             bestChild = &child;
01124             bestUpperBound = bound;
01125         }
01126     }
01127     SG_ASSERT(bestChild != 0);
01128     return *bestChild;
01129 }
01130 
01131 void SgUctSearch::SetNumberThreads(std::size_t n)
01132 {
01133     SG_ASSERT(n >= 1);
01134     if (m_numberThreads == n)
01135         return;
01136     m_numberThreads = n;
01137     CreateThreads();
01138 }
01139 
01140 void SgUctSearch::SetRave(bool enable)
01141 {
01142     if (enable && m_moveRange <= 0)
01143         throw SgException("RAVE not supported for this game");
01144     m_rave = enable;
01145 }
01146 
01147 void SgUctSearch::SetThreadStateFactory(SgUctThreadStateFactory* factory)
01148 {
01149     SG_ASSERT(m_threadStateFactory.get() == 0);
01150     m_threadStateFactory.reset(factory);
01151     DeleteThreads();
01152     
01153     
01154     
01155 }
01156 
01157 void SgUctSearch::StartSearch(const vector<SgMove>& rootFilter,
01158                               SgUctTree* initTree)
01159 {
01160     if (m_threads.size() == 0)
01161         CreateThreads();
01162     if (m_numberThreads > 1 && SgTime::DefaultMode() == SG_TIME_CPU)
01163         
01164         
01165         
01166         
01167         
01168         SgWarning() << "SgUctSearch: using cpu time with multiple threads\n";
01169     m_raveWeightParam1 = 1.f / m_raveWeightInitial;
01170     m_raveWeightParam2 = 1.f / m_raveWeightFinal;
01171     if (initTree == 0)
01172         m_tree.Clear();
01173     else
01174     {
01175         m_tree.Swap(*initTree);
01176         if (m_tree.HasCapacity(0, m_tree.Root().NuChildren()))
01177             m_tree.ApplyFilter(0, m_tree.Root(), rootFilter);
01178         else
01179             SgWarning() <<
01180                 "SgUctSearch: "
01181                 "root filter not applied (tree reached maximum size)\n";
01182     }
01183     m_statistics.Clear();
01184     m_aborted = false;
01185     m_wasEarlyAbort = false;
01186     m_checkTimeInterval = 1;
01187     m_numberGames = 0;
01188     m_lastScoreDisplayTime = m_timer.GetTime();
01189     OnStartSearch();
01190     m_startRootMoveCount = m_tree.Root().MoveCount();
01191     for (size_t i = 0; i < m_threads.size(); ++i)
01192         ThreadState(i).StartSearch();
01193 }
01194 
01195 void SgUctSearch::EndSearch()
01196 {
01197     OnEndSearch();
01198 }
01199 
01200 string SgUctSearch::SummaryLine(const SgUctGameInfo& info) const
01201 {
01202     ostringstream buffer;
01203     const vector<const SgUctNode*>& nodes = info.m_nodes;
01204     for (size_t i = 1; i < nodes.size(); ++i)
01205     {
01206         const SgUctNode* node = nodes[i];
01207         SgMove move = node->Move();
01208         buffer << ' ' << MoveString(move) << " (" << fixed << setprecision(2)
01209                << node->Mean() << ',' << node->MoveCount() << ')';
01210     }
01211     for (size_t i = 0; i < info.m_eval.size(); ++i)
01212         buffer << ' ' << fixed << setprecision(2) << info.m_eval[i];
01213     return buffer.str();
01214 }
01215 
01216 void SgUctSearch::UpdateCheckTimeInterval(double time)
01217 {
01218     if (time < numeric_limits<double>::epsilon())
01219         return;
01220     
01221     
01222     float wantedTimeDiff = (m_maxTime > 1 ? 0.1 : 0.1 * m_maxTime);
01223     if (time < wantedTimeDiff / 10)
01224     {
01225         
01226         m_checkTimeInterval *= 2;
01227         return;
01228     }
01229     m_statistics.m_gamesPerSecond = GamesPlayed() / time;
01230     double gamesPerSecondPerThread =
01231         m_statistics.m_gamesPerSecond / m_numberThreads;
01232     m_checkTimeInterval =
01233         static_cast<size_t>(wantedTimeDiff * gamesPerSecondPerThread);
01234     if (m_checkTimeInterval == 0)
01235         m_checkTimeInterval = 1;
01236 }
01237 
01238 
01239 
01240 
01241 
01242 void SgUctSearch::UpdateRaveValues(SgUctThreadState& state)
01243 {
01244     for (size_t i = 0; i < m_numberPlayouts; ++i)
01245         UpdateRaveValues(state, i);
01246 }
01247 
// Updates RAVE values along the game path for one playout of the game held
// in state.m_gameInfo.
//
// Walks the playout's move sequence backwards. firstPlay[mv] and
// firstPlayOpp[mv] record the smallest sequence index at which mv was
// played at an even or odd index respectively (opp is true exactly for odd
// i). Each position i that has a node in info.m_nodes (i < nuNodes) also
// gets its children's RAVE values updated by the per-node overload, using
// the playout's evaluation for even i and its inverse for odd i. Moves
// whose skipRaveUpdate flag is set are ignored entirely.
//
// NOTE(review): assumes info.m_nodes is non-empty (presumably it always
// holds the root). With nuNodes == 0 the unsigned test `i >= nuNodes`
// below could never fail and i would wrap around -- confirm callers.
void SgUctSearch::UpdateRaveValues(SgUctThreadState& state,
                                   std::size_t playout)
{
    SgUctGameInfo& info = state.m_gameInfo;
    const vector<SgMove>& sequence = info.m_sequence[playout];
    if (sequence.size() == 0)
        return;
    // The first-play arrays are indexed by move and hold m_moveRange entries.
    SG_ASSERT(m_moveRange > 0);
    size_t* firstPlay = state.m_firstPlay.get();
    size_t* firstPlayOpp = state.m_firstPlayOpp.get();
    // max() means "not played"; the per-node overload skips such moves.
    fill_n(firstPlay, m_moveRange, numeric_limits<size_t>::max());
    fill_n(firstPlayOpp, m_moveRange, numeric_limits<size_t>::max());
    const vector<const SgUctNode*>& nodes = info.m_nodes;
    const vector<bool>& skipRaveUpdate = info.m_skipRaveUpdate[playout];
    float eval = info.m_eval[playout];
    float invEval = InverseEval(eval);
    size_t nuNodes = nodes.size();
    size_t i = sequence.size() - 1;
    bool opp = (i % 2 != 0);

    // Positions beyond the in-tree path: only record first-play indices.
    for ( ; i >= nuNodes; --i)
    {
        SG_ASSERT(i < skipRaveUpdate.size());
        SG_ASSERT(i < sequence.size());
        if (! skipRaveUpdate[i])
        {
            SgMove mv = sequence[i];
            size_t& first = (opp ? firstPlayOpp[mv] : firstPlay[mv]);
            if (i < first)
                first = i;
        }
        opp = ! opp;
    }

    // Positions with a node (i < nuNodes), down to 0: record the move at i
    // first (so a child move played at i itself gets first == i), then
    // update the RAVE values of that node's children from the side-to-move
    // perspective.
    while (true)
    {
        SG_ASSERT(i < skipRaveUpdate.size());
        SG_ASSERT(i < sequence.size());
        // In-tree moves must never be marked as skipped.
        SG_ASSERT(i >= info.m_inTreeSequence.size() || ! skipRaveUpdate[i]);
        if (! skipRaveUpdate[i])
        {
            SgMove mv = sequence[i];
            size_t& first = (opp ? firstPlayOpp[mv] : firstPlay[mv]);
            if (i < first)
                first = i;
            if (opp)
                UpdateRaveValues(state, playout, invEval, i,
                                 firstPlayOpp, firstPlay);
            else
                UpdateRaveValues(state, playout, eval, i,
                                 firstPlay, firstPlayOpp);
        }
        // Explicit exit test, since i is unsigned and cannot go below 0.
        if (i == 0)
            break;
        --i;
        opp = ! opp;
    }
}
01308 
01309 void SgUctSearch::UpdateRaveValues(SgUctThreadState& state,
01310                                    std::size_t playout, float eval,
01311                                    std::size_t i,
01312                                    const std::size_t firstPlay[],
01313                                    const std::size_t firstPlayOpp[])
01314 {
01315     SG_ASSERT(i < state.m_gameInfo.m_nodes.size());
01316     const SgUctNode* node = state.m_gameInfo.m_nodes[i];
01317     if (! node->HasChildren())
01318         return;
01319     std::size_t len = state.m_gameInfo.m_sequence[playout].size();
01320     for (SgUctChildIterator it(m_tree, *node); it; ++it)
01321     {
01322         const SgUctNode& child = *it;
01323         SgMove mv = child.Move();
01324         size_t first = firstPlay[mv];
01325         SG_ASSERT(first >= i);
01326         if (first == numeric_limits<size_t>::max())
01327             continue;
01328         if  (m_raveCheckSame && SgUtil::InRange(firstPlayOpp[mv], i, first))
01329             continue;
01330         float weight;
01331         if (m_weightRaveUpdates)
01332             weight = 2.f - static_cast<float>(first - i) / (len - i);
01333         else
01334             weight = 1.f;
01335         m_tree.AddRaveValue(child, eval, weight);
01336     }
01337 }
01338 
01339 void SgUctSearch::UpdateStatistics(const SgUctGameInfo& info)
01340 {
01341     m_statistics.m_movesInTree.Add(
01342                             static_cast<float>(info.m_inTreeSequence.size()));
01343     for (size_t i = 0; i < m_numberPlayouts; ++i)
01344     {
01345         m_statistics.m_gameLength.Add(
01346                                static_cast<float>(info.m_sequence[i].size()));
01347         m_statistics.m_aborted.Add(info.m_aborted[i] ? 1.f : 0.f);
01348     }
01349 }
01350 
01351 void SgUctSearch::UpdateTree(const SgUctGameInfo& info)
01352 {
01353     float eval = 0;
01354     for (size_t i = 0; i < m_numberPlayouts; ++i)
01355         eval += info.m_eval[i];
01356     eval /= m_numberPlayouts;
01357     float inverseEval = InverseEval(eval);
01358     const vector<const SgUctNode*>& nodes = info.m_nodes;
01359 
01360     for (size_t i = 0; i < nodes.size(); ++i)
01361     {
01362         const SgUctNode& node = *nodes[i];
01363         const SgUctNode* father = (i > 0 ? nodes[i - 1] : 0);
01364         m_tree.AddGameResults(node, father, i % 2 == 0 ? eval : inverseEval, m_numberPlayouts);
01365     }
01366 }
01367 
01368 void SgUctSearch::WriteStatistics(ostream& out) const
01369 {
01370     out << SgWriteLabel("Count") << m_tree.Root().MoveCount() << '\n'
01371         << SgWriteLabel("Nodes") << m_tree.NuNodes() << '\n';
01372     if (!m_knowledgeThreshold.empty())
01373         out << SgWriteLabel("Knowledge") 
01374             << m_statistics.m_knowledge << " (" << fixed << setprecision(1) 
01375             << m_statistics.m_knowledge * 100.0 / m_tree.Root().MoveCount()
01376             << "%)\n";
01377     m_statistics.Write(out);
01378     m_mpiSynchronizer->WriteStatistics(out);
01379 }
01380 
01381