Index   Main   Namespaces   Classes   Hierarchy   Annotated   Files   Compound   Global   Pages  

SgUctSearch.cpp

Go to the documentation of this file.
00001 //----------------------------------------------------------------------------
00002 /** @file SgUctSearch.cpp
00003 */
00004 //----------------------------------------------------------------------------
00005 
00006 #include "SgSystem.h"
00007 #include "SgUctSearch.h"
00008 
00009 #include <algorithm>
00010 #include <cmath>
00011 #include <iomanip>
00012 #include <boost/format.hpp>
00013 #include <boost/io/ios_state.hpp>
00014 #include <boost/version.hpp>
00015 #include "SgDebug.h"
00016 #include "SgHashTable.h"
00017 #include "SgMath.h"
00018 #include "SgWrite.h"
00019 
00020 using namespace std;
00021 using boost::barrier;
00022 using boost::condition;
00023 using boost::format;
00024 using boost::mutex;
00025 using boost::shared_ptr;
00026 using boost::io::ios_all_saver;
00027 
00028 #define BOOST_VERSION_MAJOR (BOOST_VERSION / 100000)
00029 #define BOOST_VERSION_MINOR (BOOST_VERSION / 100 % 1000)
00030 
00031 //----------------------------------------------------------------------------
00032 
00033 namespace {
00034 
00035 const bool DEBUG_THREADS = false;
00036 
00037 void Notify(mutex& aMutex, condition& aCondition)
00038 {
00039     mutex::scoped_lock lock(aMutex);
00040     aCondition.notify_all();
00041 }
00042 
00043 } // namespace
00044 
00045 //----------------------------------------------------------------------------
00046 
00047 void SgUctGameInfo::Clear(std::size_t numberPlayouts)
00048 {
00049     m_nodes.clear();
00050     m_inTreeSequence.clear();
00051     if (numberPlayouts != m_sequence.size())
00052     {
00053         m_sequence.resize(numberPlayouts);
00054         m_skipRaveUpdate.resize(numberPlayouts);
00055         m_eval.resize(numberPlayouts);
00056         m_aborted.resize(numberPlayouts);
00057     }
00058     for (size_t i = 0; i < numberPlayouts; ++i)
00059     {
00060         m_sequence[i].clear();
00061         m_skipRaveUpdate[i].clear();
00062     }
00063 }
00064 
00065 //----------------------------------------------------------------------------
00066 
00067 SgUctThreadState::SgUctThreadState(size_t threadId, int moveRange)
00068     : m_threadId(threadId),
00069       m_isSearchInitialized(false),
00070       m_isTreeOutOfMem(false),
00071       m_randomizeCounter(0)
00072 {
00073     if (moveRange > 0)
00074     {
00075         m_firstPlay.reset(new size_t[moveRange]);
00076         m_firstPlayOpp.reset(new size_t[moveRange]);
00077     }
00078 }
00079 
/** Destructor; performs no explicit cleanup. */
SgUctThreadState::~SgUctThreadState()
{
}
00083 
/** Hook invoked by SgUctSearch::PlayGame() after a playout was evaluated
    and before it is taken back.
*/
void SgUctThreadState::EndPlayout()
{
    // Default implementation does nothing
}
00088 
/** Hook invoked by SgUctSearch::PlayGame() at the start of each simulated
    game.
*/
void SgUctThreadState::GameStart()
{
    // Default implementation does nothing
}
00093 
/** Hook invoked by SgUctSearch::PlayGame() before each individual
    playout.
*/
void SgUctThreadState::StartPlayout()
{
    // Default implementation does nothing
}
00098 
/** Hook invoked by SgUctSearch::PlayGame() once before the playout loop of
    a game (not called if the last in-tree node is proven).
*/
void SgUctThreadState::StartPlayouts()
{
    // Default implementation does nothing
}
00103 
00104 //----------------------------------------------------------------------------
00105 
/** Destructor; performs no explicit cleanup. */
SgUctThreadStateFactory::~SgUctThreadStateFactory()
{
}
00109 
00110 //----------------------------------------------------------------------------
00111 
/** Construct a callable that forwards to a Thread.
    @param thread The thread object to invoke; only a reference is stored.
*/
SgUctSearch::Thread::Function::Function(Thread& thread)
    : m_thread(thread)
{
}
00116 
/** Run the main function of the referenced Thread. */
void SgUctSearch::Thread::Function::operator()()
{
    m_thread();
}
00121 
/** Create a worker thread for the search.
    Starts the thread (which runs Thread::operator()) and then waits at the
    two-party barrier m_threadReady until the new thread has reached the
    barrier as well.
    @param search The search this thread belongs to.
    @param state The thread state; passed on to m_state.
*/
SgUctSearch::Thread::Thread(SgUctSearch& search,
                            auto_ptr<SgUctThreadState> state)
    : m_state(state),
      m_search(search),
      m_quit(false),
      // Two parties: the constructing thread and the new worker thread
      m_threadReady(2),
      m_playFinishedLock(m_playFinishedMutex),
      // The global lock is created without acquiring the mutex; the way to
      // request that differs between Boost <= 1.34 and later versions
#if BOOST_VERSION_MAJOR == 1 && BOOST_VERSION_MINOR <= 34
      m_globalLock(search.m_globalMutex, false),
#else
      m_globalLock(search.m_globalMutex, boost::defer_lock),
#endif
      m_thread(Function(*this))
{
    m_threadReady.wait();
}
00138 
/** Stop and join the worker thread.
    The quit flag must be set before the thread is woken up, because the
    thread checks it right after returning from its wait.
*/
SgUctSearch::Thread::~Thread()
{
    m_quit = true;
    StartPlay();
    m_thread.join();
}
00145 
/** Main function of the worker thread.
    Acquires m_startPlayMutex, meets the constructor at the m_threadReady
    barrier, and then repeatedly waits for m_startPlay. After each wake-up
    it either quits (if m_quit is set) or runs one SgUctSearch::SearchLoop()
    and notifies m_playFinished.
*/
void SgUctSearch::Thread::operator()()
{
    if (DEBUG_THREADS)
        SgDebug() << "SgUctSearch::Thread: starting thread "
                  << m_state->m_threadId << '\n';
    // The lock stays in scope for the whole function; StartPlay() needs the
    // same mutex, so it can only get through while this thread is waiting
    mutex::scoped_lock lock(m_startPlayMutex);
    m_threadReady.wait();
    while (true)
    {
        // NOTE(review): the wait has no predicate, so a spurious wake-up
        // would be treated as a request to start playing -- confirm whether
        // that is acceptable
        m_startPlay.wait(lock);
        if (m_quit)
            break;
        m_search.SearchLoop(*m_state, &m_globalLock);
        Notify(m_playFinishedMutex, m_playFinished);
    }
    if (DEBUG_THREADS)
        SgDebug() << "SgUctSearch::Thread: finishing thread "
                  << m_state->m_threadId << '\n';
}
00165 
/** Wake up the worker thread waiting for the start-play condition. */
void SgUctSearch::Thread::StartPlay()
{
    Notify(m_startPlayMutex, m_startPlay);
}
00170 
/** Block until the worker thread notifies the play-finished condition.
    Uses m_playFinishedLock, which is created in the constructor.
*/
void SgUctSearch::Thread::WaitPlayFinished()
{
    m_playFinished.wait(m_playFinishedLock);
}
00175 
00176 //----------------------------------------------------------------------------
00177 
00178 void SgUctSearchStat::Clear()
00179 {
00180     m_time = 0;
00181     m_knowledge = 0;
00182     m_gamesPerSecond = 0;
00183     m_gameLength.Clear();
00184     m_movesInTree.Clear();
00185     m_aborted.Clear();
00186 }
00187 
00188 void SgUctSearchStat::Write(std::ostream& out) const
00189 {
00190     ios_all_saver saver(out);
00191     out << SgWriteLabel("Time") << setprecision(2) << m_time << '\n'
00192         << SgWriteLabel("GameLen") << fixed << setprecision(1);
00193     m_gameLength.Write(out);
00194     out << '\n'
00195         << SgWriteLabel("InTree");
00196     m_movesInTree.Write(out);
00197     out << '\n'
00198         << SgWriteLabel("Aborted")
00199         << static_cast<int>(100 * m_aborted.Mean()) << "%\n"
00200         << SgWriteLabel("Games/s") << fixed << setprecision(1)
00201         << m_gamesPerSecond << '\n';
00202 }
00203 
00204 //----------------------------------------------------------------------------
00205 
/** Construct a search with default parameter values.
    No threads or thread states are created here; see the comment in the
    body. Threads are created later by CreateThreads().
    @param threadStateFactory The factory used by CreateThreads() to create
    the thread states.
    @param moveRange Stored in m_moveRange.
*/
SgUctSearch::SgUctSearch(SgUctThreadStateFactory* threadStateFactory,
                         int moveRange)
    : m_threadStateFactory(threadStateFactory),
      m_logGames(false),
      m_rave(false),
      m_knowledgeThreshold(),
      m_moveSelect(SG_UCTMOVESELECT_COUNT),
      m_raveCheckSame(false),
      m_randomizeRaveFrequency(20),
      m_lockFree(false),
      m_weightRaveUpdates(true),
      m_pruneFullTree(true),
      m_numberThreads(1),
      m_numberPlayouts(1),
      m_maxNodes(4000000),
      m_pruneMinCount(16),
      m_moveRange(moveRange),
      // By default, the game length is effectively unlimited
      m_maxGameLength(numeric_limits<size_t>::max()),
      m_expandThreshold(1),
      m_biasTermConstant(0.7f),
      m_firstPlayUrgency(10000),
      m_raveWeightInitial(0.9f),
      m_raveWeightFinal(20000),
      m_virtualLoss(false),
      m_logFileName("uctsearch.log"),
      m_fastLog(10),
      m_mpiSynchronizer(SgMpiNullSynchronizer::Create())
{
    // Don't create thread states here, because the factory passes the search
    // (which is not fully constructed here, because the subclass constructors
    // are not called yet) as an argument to the Create() function
}
00238 
/** Destructor; releases all worker threads. */
SgUctSearch::~SgUctSearch()
{
    DeleteThreads();
}
00243 
00244 void SgUctSearch::ApplyRootFilter(vector<SgMoveInfo>& moves)
00245 {
00246     // Filter without changing the order of the unfiltered moves
00247     vector<SgMoveInfo> filteredMoves;
00248     for (vector<SgMoveInfo>::const_iterator it = moves.begin();
00249          it != moves.end(); ++it)
00250         if (find(m_rootFilter.begin(), m_rootFilter.end(), it->m_move)
00251             == m_rootFilter.end())
00252             filteredMoves.push_back(*it);
00253     moves = filteredMoves;
00254 }
00255 
/** Number of games played in the current search.
    Computed as the move count of the root minus m_startRootMoveCount.
*/
size_t SgUctSearch::GamesPlayed() const
{
    return m_tree.Root().MoveCount() - m_startRootMoveCount;
}
00260 
00261 bool SgUctSearch::CheckAbortSearch(SgUctThreadState& state)
00262 {
00263     size_t gamesPlayed = GamesPlayed();
00264     if (SgUserAbort())
00265     {
00266         Debug(state, "SgUctSearch: abort flag");
00267         return true;
00268     }
00269     const bool isEarlyAbort = CheckEarlyAbort();
00270     if (gamesPlayed >= m_maxGames)
00271     {
00272         Debug(state, "SgUctSearch: max games reached");
00273         return true;
00274     }
00275     if (   isEarlyAbort
00276         && m_earlyAbort->m_reductionFactor * gamesPlayed >= m_maxGames
00277        )
00278     {
00279         Debug(state, "SgUctSearch: max games reached (early abort)");
00280         m_wasEarlyAbort = true;
00281         return true;
00282     }
00283     if (m_numberGames % m_checkTimeInterval == 0)
00284     {
00285         double time = m_timer.GetTime();
00286         if (time > m_maxTime)
00287         {
00288             Debug(state, "SgUctSearch: max time reached");
00289             return true;
00290         }
00291         if (isEarlyAbort
00292             && m_earlyAbort->m_reductionFactor * time > m_maxTime)
00293         {
00294             Debug(state, "SgUctSearch: max time reached (early abort)");
00295             m_wasEarlyAbort = true;
00296             return true;
00297         }
00298         UpdateCheckTimeInterval(time);
00299         if (m_moveSelect == SG_UCTMOVESELECT_COUNT)
00300         {
00301             double remainingGamesDouble = m_maxGames - gamesPlayed - 1;
00302             // Use time based count abort, only if time > 1, otherwise
00303             // m_gamesPerSecond is unreliable
00304             if (time > 1.)
00305             {
00306                 double remainingTime = m_maxTime - time;
00307                 remainingGamesDouble =
00308                     min(remainingGamesDouble,
00309                         remainingTime * m_statistics.m_gamesPerSecond);
00310             }
00311             size_t sizeTypeMax = numeric_limits<size_t>::max();
00312             size_t remainingGames;
00313             if (remainingGamesDouble > static_cast<double>(sizeTypeMax - 1))
00314                 remainingGames = sizeTypeMax;
00315             else
00316                 remainingGames = static_cast<size_t>(remainingGamesDouble);
00317             if (CheckCountAbort(state, remainingGames))
00318             {
00319                 Debug(state, "SgUctSearch: move cannot change anymore");
00320                 return true;
00321             }
00322         }
00323     }
00324     return false;
00325 }
00326 
00327 bool SgUctSearch::CheckCountAbort(SgUctThreadState& state,
00328                                   std::size_t remainingGames) const
00329 {
00330     const SgUctNode& root = m_tree.Root();
00331     const SgUctNode* bestChild = FindBestChild(root);
00332     if (bestChild == 0)
00333         return false;
00334     size_t bestCount = bestChild->MoveCount();
00335     vector<SgMove>& excludeMoves = state.m_excludeMoves;
00336     excludeMoves.clear();
00337     excludeMoves.push_back(bestChild->Move());
00338     const SgUctNode* secondBestChild = FindBestChild(root, &excludeMoves);
00339     if (secondBestChild == 0)
00340         return false;
00341     std::size_t secondBestCount = secondBestChild->MoveCount();
00342     SG_ASSERT(secondBestCount <= bestCount || m_numberThreads > 1);
00343     return (secondBestCount + remainingGames <= bestCount);
00344 }
00345 
00346 bool SgUctSearch::CheckEarlyAbort() const
00347 {
00348     const SgUctNode& root = m_tree.Root();
00349     return   m_earlyAbort.get() != 0
00350           && root.HasMean()
00351           && root.MoveCount() > m_earlyAbort->m_minGames
00352           && root.Mean() > m_earlyAbort->m_threshold;
00353 }
00354 
00355 void SgUctSearch::CreateThreads()
00356 {
00357     DeleteThreads();
00358     for (size_t i = 0; i < m_numberThreads; ++i)
00359     {
00360         auto_ptr<SgUctThreadState> state(
00361                                       m_threadStateFactory->Create(i, *this));
00362         shared_ptr<Thread> thread(new Thread(*this, state));
00363         m_threads.push_back(thread);
00364     }
00365     m_tree.CreateAllocators(m_numberThreads);
00366     m_tree.SetMaxNodes(m_maxNodes);
00367 
00368     m_searchLoopFinished.reset(new barrier(m_numberThreads));
00369 }
00370 
00371 /** Write a debugging line of text from within a thread.
00372     Prepends the line with the thread number if number of threads is greater
00373     than one. Also ensures that the line is written as a single string to
00374     avoid intermingling of text lines from different threads.
00375     @param state The state of the thread (only used for state.m_threadId)
00376     @param textLine The line of text without trailing newline character.
00377 */
00378 void SgUctSearch::Debug(const SgUctThreadState& state,
00379                         const std::string& textLine)
00380 {
00381     if (m_numberThreads > 1)
00382     {
00383         // SgDebug() is not necessarily thread-safe
00384         GlobalLock lock(m_globalMutex);
00385         SgDebug() << (format("[%1%] %2%\n") % state.m_threadId % textLine);
00386     }
00387     else
00388         SgDebug() << (format("%1%\n") % textLine);
00389 }
00390 
/** Release all worker threads by clearing the thread container. */
void SgUctSearch::DeleteThreads()
{
    m_threads.clear();
}
00395 
00396 /** Expand a node.
00397     @param state The thread state with state.m_moves already computed.
00398     @param node The node to expand.
00399 */
00400 void SgUctSearch::ExpandNode(SgUctThreadState& state, const SgUctNode& node)
00401 {
00402     size_t threadId = state.m_threadId;
00403     if (! m_tree.HasCapacity(threadId, state.m_moves.size()))
00404     {
00405         Debug(state, str(format("SgUctSearch: maximum tree size %1% reached")
00406                          % m_tree.MaxNodes()));
00407         state.m_isTreeOutOfMem = true;
00408         m_isTreeOutOfMemory = true;
00409         SgSynchronizeThreadMemory();
00410         return;
00411     }
00412     m_tree.CreateChildren(threadId, node, state.m_moves);
00413 }
00414 
00415 const SgUctNode*
00416 SgUctSearch::FindBestChild(const SgUctNode& node,
00417                            const vector<SgMove>* excludeMoves) const
00418 {
00419     if (! node.HasChildren())
00420         return 0;
00421     const SgUctNode* bestChild = 0;
00422     float bestValue = 0;
00423     for (SgUctChildIterator it(m_tree, node); it; ++it)
00424     {
00425         const SgUctNode& child = *it;
00426         if (excludeMoves != 0)
00427         {
00428             vector<SgMove>::const_iterator begin = excludeMoves->begin();
00429             vector<SgMove>::const_iterator end = excludeMoves->end();
00430             if (find(begin, end, child.Move()) != end)
00431                 continue;
00432         }
00433         if (  ! child.HasMean()
00434            && ! (  (  m_moveSelect == SG_UCTMOVESELECT_BOUND
00435                    || m_moveSelect == SG_UCTMOVESELECT_ESTIMATE
00436                    )
00437                 && m_rave
00438                 && child.HasRaveValue()
00439                 )
00440             )
00441             continue;
00442         float moveValue = InverseEval(child.Mean());
00443         size_t moveCount = child.MoveCount();
00444         float value;
00445         switch (m_moveSelect)
00446         {
00447         case SG_UCTMOVESELECT_VALUE:
00448             value = moveValue;
00449             break;
00450         case SG_UCTMOVESELECT_COUNT:
00451             value = moveCount;
00452             break;
00453         case SG_UCTMOVESELECT_BOUND:
00454             value = GetBound(m_rave, node, child);
00455             break;
00456         case SG_UCTMOVESELECT_ESTIMATE:
00457             value = GetValueEstimate(m_rave, child);
00458             break;
00459         default:
00460             SG_ASSERT(false);
00461             value = SG_UCTMOVESELECT_VALUE;
00462         }
00463         if (bestChild == 0 || value > bestValue)
00464         {
00465             bestChild = &child;
00466             bestValue = value;
00467         }
00468     }
00469     return bestChild;
00470 }
00471 
00472 void SgUctSearch::FindBestSequence(vector<SgMove>& sequence) const
00473 {
00474     sequence.clear();
00475     const SgUctNode* current = &m_tree.Root();
00476     while (true)
00477     {
00478         current = FindBestChild(*current);
00479         if (current == 0)
00480             break;
00481         sequence.push_back(current->Move());
00482         if (! current->HasChildren())
00483             break;
00484     }
00485 }
00486 
00487 void SgUctSearch::GenerateAllMoves(std::vector<SgMoveInfo>& moves)
00488 {
00489     if (m_threads.size() == 0)
00490         CreateThreads();
00491     moves.clear();
00492     OnStartSearch();
00493     SgUctThreadState& state = ThreadState(0);
00494     state.StartSearch();
00495     SgProvenNodeType type;
00496     state.GenerateAllMoves(0, moves, type);
00497 }
00498 
00499 float SgUctSearch::GetBound(bool useRave, const SgUctNode& node,
00500                             const SgUctNode& child) const
00501 {
00502     size_t posCount = node.PosCount();
00503     return GetBound(useRave, Log(posCount), child);
00504 }
00505 
00506 float SgUctSearch::GetBound(bool useRave, float logPosCount, 
00507                             const SgUctNode& child) const
00508 {
00509     float value;
00510     if (useRave)
00511         value = GetValueEstimateRave(child);
00512     else
00513         value = GetValueEstimate(false, child);
00514     if (m_biasTermConstant == 0.0)
00515         return value;
00516     else
00517     {
00518         float moveCount = static_cast<float>(child.MoveCount());
00519         float bound =
00520             value + m_biasTermConstant * sqrt(logPosCount / (moveCount + 1));
00521         return bound;
00522     }
00523 }
00524 
00525 SgUctTree& SgUctSearch::GetTempTree()
00526 {
00527     m_tempTree.Clear();
00528     // Use NumberThreads() (not m_tree.NuAllocators()) and MaxNodes() (not
00529     // m_tree.MaxNodes()), because of the delayed thread (and thereby
00530     // allocator) creation in SgUctSearch
00531     if (m_tempTree.NuAllocators() != NumberThreads())
00532     {
00533         m_tempTree.CreateAllocators(NumberThreads());
00534         m_tempTree.SetMaxNodes(MaxNodes());
00535     }
00536     else if (m_tempTree.MaxNodes() != MaxNodes())
00537     {
00538         m_tempTree.SetMaxNodes(MaxNodes());
00539     }
00540     return m_tempTree;
00541 }
00542 
00543 float SgUctSearch::GetValueEstimate(bool useRave, const SgUctNode& child) const
00544 {
00545     float value = 0.f;
00546     float weightSum = 0.f;
00547     bool hasValue = false;
00548     if (child.HasMean())
00549     {
00550         float weight = static_cast<float>(child.MoveCount());
00551         value += weight * InverseEval(child.Mean());
00552         weightSum += weight;
00553         hasValue = true;
00554     }
00555     if (useRave && child.HasRaveValue())
00556     {
00557         float raveCount = child.RaveCount();
00558         float weight =
00559                 raveCount
00560               / (  m_raveWeightParam1
00561                  + m_raveWeightParam2 * raveCount
00562                 );
00563         value += weight * child.RaveValue();
00564         weightSum += weight;
00565         hasValue = true;
00566     }
00567     if (hasValue)
00568         return value / weightSum;
00569     else
00570         return m_firstPlayUrgency;
00571 }
00572 
00573 /** Optimized version of GetValueEstimate() if RAVE and not other
00574     estimators are used.
00575     Previously there were more estimators than move value and RAVE value,
00576     and in the future there may be again. GetValueEstimate() is easier to
00577     extend, this function is more optimized for the special case.
00578 */
00579 float SgUctSearch::GetValueEstimateRave(const SgUctNode& child) const
00580 {
00581     SG_ASSERT(m_rave);
00582     bool hasRave = child.HasRaveValue();
00583     float value;
00584     if (child.HasMean())
00585     {
00586         float moveValue = InverseEval(child.Mean());
00587         if (hasRave)
00588         {
00589             float moveCount = child.MoveCount();
00590             float raveCount = child.RaveCount();
00591             float weight =
00592                 raveCount
00593                 / (moveCount
00594                    * (m_raveWeightParam1 + m_raveWeightParam2 * raveCount)
00595                    + raveCount);
00596             value = weight * child.RaveValue() + (1.f - weight) * moveValue;
00597         }
00598         else
00599         {
00600             // This can happen only in lock-free multi-threading. Normally,
00601             // each move played in a position should also cause a RAVE value
00602             // to be added. But in lock-free multi-threading it can happen
00603             // that the move value was already updated but the RAVE value not
00604             SG_ASSERT(m_numberThreads > 1 && m_lockFree);
00605             value = moveValue;
00606         }
00607     }
00608     else if (hasRave)
00609         value = child.RaveValue();
00610     else
00611         value = m_firstPlayUrgency;
00612     SG_ASSERT(m_numberThreads > 1
00613               || fabs(value - GetValueEstimate(m_rave, child)) < 1e-3/*epsilon*/);
00614     return value;
00615 }
00616 
/** Get the summary line (see SummaryLine()) for the last game info. */
string SgUctSearch::LastGameSummaryLine() const
{
    return SummaryLine(LastGameInfo());
}
00621 
/** Compute the logarithm used in the bias term.
    Uses m_fastLog if SG_UCTFASTLOG is enabled, otherwise log().
    @param x The argument.
*/
float SgUctSearch::Log(float x) const
{
    // TODO: can we speed up the computation of the logarithm by taking
    // advantage of the fact that the argument is an integer type?
    // Maybe round the result to an integer (then it is simply the position
    // of the highest bit)
#if SG_UCTFASTLOG
    return m_fastLog.Log(x);
#else
    return log(x);
#endif
}
00634 
00635 /** Creates the children with the given moves and merges with existing
00636     children in the tree. */
00637 void SgUctSearch::CreateChildren(SgUctThreadState& state, 
00638                                  const SgUctNode& node,
00639                                  bool deleteChildTrees)
00640 {
00641     size_t threadId = state.m_threadId;
00642     if (! m_tree.HasCapacity(threadId, state.m_moves.size()))
00643     {
00644         Debug(state, str(format("SgUctSearch: maximum tree size %1% reached")
00645                          % m_tree.MaxNodes()));
00646         state.m_isTreeOutOfMem = true;
00647         m_isTreeOutOfMemory = true;
00648         SgSynchronizeThreadMemory();
00649         return;
00650     }
00651     m_tree.MergeChildren(threadId, node, state.m_moves, deleteChildTrees);
00652 }
00653 
00654 bool SgUctSearch::NeedToComputeKnowledge(const SgUctNode* current)
00655 {
00656     if (m_knowledgeThreshold.empty())
00657         return false;
00658     for (std::size_t i = 0; i < m_knowledgeThreshold.size(); ++i)
00659     {
00660         const std::size_t threshold = m_knowledgeThreshold[i];
00661         if (current->KnowledgeCount() < threshold)
00662         {
00663             if (current->MoveCount() >= threshold)
00664             {
00665                 // Mark knowledge computed immediately so other
00666                 // threads fall through and do not waste time
00667                 // re-computing this knowledge.
00668                 SgUctNode* node = const_cast<SgUctNode*>(current);
00669                 node->SetKnowledgeCount(threshold);
00670                 SG_ASSERT(current->MoveCount());
00671                 return true;
00672             }
00673             return false;
00674         }
00675     }
00676     return false;
00677 }
00678 
/** Notify the MPI synchronizer that a search starts. */
void SgUctSearch::OnStartSearch()
{
    m_mpiSynchronizer->OnStartSearch(*this);
}
00683 
/** Notify the MPI synchronizer that a search has ended. */
void SgUctSearch::OnEndSearch()
{
    m_mpiSynchronizer->OnEndSearch(*this);
}
00688 
00689 /** Print time, mean, nodes searched, and PV */
00690 void SgUctSearch::PrintSearchProgress(double currTime) const
00691 {
00692     const int MAX_SEQ_PRINT_LENGTH = 15;
00693     const size_t MIN_MOVE_COUNT = 10;
00694     size_t rootMoveCount = m_tree.Root().MoveCount();
00695     float rootMean = m_tree.Root().Mean();
00696     ostringstream out;
00697     const SgUctNode* current = &m_tree.Root();
00698     out << fixed << setprecision(0)
00699         << SgTime::Format(currTime, true) << " | "
00700         << fixed << setprecision(3) << rootMean << " "
00701         << "| " << rootMoveCount << " ";
00702     for(int i = 0; i <= MAX_SEQ_PRINT_LENGTH && current->HasChildren(); ++i)
00703     {
00704         current = FindBestChild(*current);
00705         if (current == 0 || current->MoveCount() < MIN_MOVE_COUNT)
00706             break;
00707         if (i == 0)
00708             out << "|";
00709         if (i < MAX_SEQ_PRINT_LENGTH)
00710             out << " " << SgWritePoint(current->Move());
00711         else
00712             out << " *";
00713     }
00714     SgDebug() << out.str() << endl;
00715 }
00716 
00717 void SgUctSearch::OnSearchIteration(std::size_t gameNumber, int threadId,
00718                                     const SgUctGameInfo& info)
00719 {
00720     const int DISPLAY_INTERVAL = 5;
00721 
00722     m_mpiSynchronizer->OnSearchIteration(*this, gameNumber, threadId, info);
00723     double currTime = m_timer.GetTime();
00724 
00725     if (threadId == 0 && currTime - m_lastScoreDisplayTime > DISPLAY_INTERVAL)
00726     {
00727         PrintSearchProgress(currTime);
00728         m_lastScoreDisplayTime = currTime;
00729     }
00730 }
00731 
/** Play one simulated game and update the tree with its result.
    Plays the in-tree phase, then m_numberPlayouts playouts (or the same
    number of fake playouts with a fixed evaluation if the last in-tree
    node is proven), takes back all moves, and finally updates the tree,
    the RAVE values (if enabled) and the statistics.
    @param state The state of the calling thread.
    @param lock The global lock held by the caller, or null. If not null,
    it is unlocked during the playout phase and locked again before the
    tree is updated.
*/
void SgUctSearch::PlayGame(SgUctThreadState& state, GlobalLock* lock)
{
    state.m_isTreeOutOfMem = false;
    state.GameStart();
    SgUctGameInfo& info = state.m_gameInfo;
    info.Clear(m_numberPlayouts);
    bool isTerminal;
    bool abortInTree = ! PlayInTree(state, isTerminal);

    // add a virtual loss to all nodes in path
    if (m_virtualLoss)
        m_tree.AddVirtualLoss(info.m_nodes);

    // The playout phase is always unlocked
    if (lock != 0)
        lock->unlock();

    size_t nuMovesInTree = info.m_inTreeSequence.size();

    // Play some "fake" playouts if node is a proven node
    if (! info.m_nodes.empty() && info.m_nodes.back()->IsProven())
    {
        for (size_t i = 0; i < m_numberPlayouts; ++i)
        {
            info.m_sequence[i] = info.m_inTreeSequence;
            info.m_skipRaveUpdate[i].assign(nuMovesInTree, false);
            float eval = info.m_nodes.back()->IsProvenWin() ? 1.0 : 0.0;
            size_t nuMoves = info.m_sequence[i].size();
            // Invert the evaluation for sequences with an odd number of
            // moves
            if (nuMoves % 2 != 0)
                eval = InverseEval(eval);
            info.m_aborted[i] = abortInTree || state.m_isTreeOutOfMem;
            info.m_eval[i] = eval;
        }
    }
    else 
    {
        state.StartPlayouts();
        for (size_t i = 0; i < m_numberPlayouts; ++i)
        {
            state.StartPlayout();
            // Every playout starts with the moves of the in-tree phase
            info.m_sequence[i] = info.m_inTreeSequence;
            // skipRaveUpdate only used in playout phase
            info.m_skipRaveUpdate[i].assign(nuMovesInTree, false);
            bool abort = abortInTree || state.m_isTreeOutOfMem;
            // No playout moves are generated after an abort or in a
            // terminal position
            if (! abort && ! isTerminal)
                abort = ! PlayoutGame(state, i);
            float eval;
            if (abort)
                eval = UnknownEval();
            else
                eval = state.Evaluate();
            size_t nuMoves = info.m_sequence[i].size();
            // Invert the evaluation for sequences with an odd number of
            // moves
            if (nuMoves % 2 != 0)
                eval = InverseEval(eval);
            info.m_aborted[i] = abort;
            info.m_eval[i] = eval;
            state.EndPlayout();
            // Take back only the moves played in the playout phase
            state.TakeBackPlayout(nuMoves - nuMovesInTree);
        }
    }
    state.TakeBackInTree(nuMovesInTree);

    // End of unlocked part if ! m_lockFree
    if (lock != 0)
        lock->lock();

    // Remove the virtual loss
    if (m_virtualLoss)
        m_tree.RemoveVirtualLoss(info.m_nodes);

    UpdateTree(info);
    if (m_rave)
        UpdateRaveValues(state);
    UpdateStatistics(info);
}
00807 
/** Play game until it leaves the tree.
    Starting at the root, children are selected and their moves executed
    until a proven node, a terminal position or a node that is not expanded
    is reached. The visited nodes and the played moves are stored in
    state.m_gameInfo.
    @param state The thread state.
    @param[out] isTerminal Was the sequence terminated because of a real
    terminal position (GenerateAllMoves() returned an empty list)?
    @return @c false, if game was aborted due to maximum length. Note that
    @c true is also returned if the tree ran out of memory; that case is
    reported by state.m_isTreeOutOfMem.
 */
bool SgUctSearch::PlayInTree(SgUctThreadState& state, bool& isTerminal)
{
    vector<SgMove>& sequence = state.m_gameInfo.m_inTreeSequence;
    vector<const SgUctNode*>& nodes = state.m_gameInfo.m_nodes;
    const SgUctNode* root = &m_tree.Root();
    const SgUctNode* current = root;
    nodes.push_back(current);
    bool breakAfterSelect = false;
    isTerminal = false;
    // NOTE(review): deepenTree is never set to true in this function, so
    // the checks of it below have no effect
    bool deepenTree = false;
    while (true)
    {
        if (sequence.size() == m_maxGameLength)
            return false;
        if (current->IsProven())
            break;
        if (! current->HasChildren())
        {
            // Leaf of the tree: generate the moves and expand the node if
            // it was visited often enough
            state.m_moves.clear();
            SgProvenNodeType provenType = SG_NOT_PROVEN;
            state.GenerateAllMoves(0, state.m_moves, provenType);
            if (current == root)
                ApplyRootFilter(state.m_moves);
            if (provenType != SG_NOT_PROVEN)
            {
                SgUctNode* node = const_cast<SgUctNode*>(current);
                node->SetProvenNodeType(provenType);
                break;
            }
            if (state.m_moves.empty())
            {
                isTerminal = true;
                break;
            }
            if (  deepenTree
               || current->MoveCount() >= m_expandThreshold
               )
            {
                deepenTree = false;
                ExpandNode(state, *current);
                if (state.m_isTreeOutOfMem)
                    return true;
                // Play only one more move after expanding a node
                if (! deepenTree)
                    breakAfterSelect = true;
            }
            else
                break;
        }
        else if (NeedToComputeKnowledge(current))
        {
            // Inner node that reached a knowledge threshold: generate the
            // moves again and merge them with the existing children
            m_statistics.m_knowledge++;
            deepenTree = false;
            SgProvenNodeType provenType = SG_NOT_PROVEN;
            bool truncate = state.GenerateAllMoves(current->KnowledgeCount(), 
                                                   state.m_moves,
                                                   provenType);
            if (current == root)
                ApplyRootFilter(state.m_moves);
            CreateChildren(state, *current, truncate);
            if (provenType != SG_NOT_PROVEN)
            {
                SgUctNode* node = const_cast<SgUctNode*>(current);
                node->SetProvenNodeType(provenType);
                break;
            }
            if (state.m_moves.empty())
            {
                isTerminal = true;
                break;
            }
            if (state.m_isTreeOutOfMem)
                return true;
            if (! deepenTree)
                breakAfterSelect = true;
        }
        // Descend to the selected child and execute its move
        current = &SelectChild(state.m_randomizeCounter, *current);
        nodes.push_back(current);
        SgMove move = current->Move();
        state.Execute(move);
        sequence.push_back(move);
        if (breakAfterSelect)
            break;
    }
    return true;
}
00899 
00900 /** Finish the game using GeneratePlayoutMove().
00901     @param state The thread state.
00902     @param playout The number of the playout.
00903     @return @c false if game was aborted
00904 */
00905 bool SgUctSearch::PlayoutGame(SgUctThreadState& state, std::size_t playout)
00906 {
00907     SgUctGameInfo& info = state.m_gameInfo;
00908     vector<SgMove>& sequence = info.m_sequence[playout];
00909     vector<bool>& skipRaveUpdate = info.m_skipRaveUpdate[playout];
00910     while (true)
00911     {
00912         if (sequence.size() == m_maxGameLength)
00913             return false;
00914         bool skipRave = false;
00915         SgMove move = state.GeneratePlayoutMove(skipRave);
00916         if (move == SG_NULLMOVE)
00917             break;
00918         state.ExecutePlayout(move);
00919         sequence.push_back(move);
00920         skipRaveUpdate.push_back(skipRave);
00921     }
00922     return true;
00923 }
00924 
00925 float SgUctSearch::Search(std::size_t maxGames, double maxTime,
00926                           vector<SgMove>& sequence,
00927                           const vector<SgMove>& rootFilter,
00928                           SgUctTree* initTree,
00929                           SgUctEarlyAbortParam* earlyAbort)
00930 {
00931     m_timer.Start();
00932     m_rootFilter = rootFilter;
00933     if (m_logGames)
00934     {
00935         m_log.open(m_mpiSynchronizer->ToNodeFilename(m_logFileName).c_str());
00936         m_log << "StartSearch maxGames=" << maxGames << '\n';
00937     }
00938     m_maxGames = maxGames;
00939     m_maxTime = maxTime;
00940     m_earlyAbort.reset(0);
00941     if (earlyAbort != 0)
00942         m_earlyAbort.reset(new SgUctEarlyAbortParam(*earlyAbort));
00943 
00944     for (size_t i = 0; i < m_threads.size(); ++i)
00945     {
00946         m_threads[i]->m_state->m_isSearchInitialized = false;
00947     }
00948     StartSearch(rootFilter, initTree);
00949     size_t pruneMinCount = m_pruneMinCount;
00950     while (true)
00951     {
00952         m_isTreeOutOfMemory = false;
00953         SgSynchronizeThreadMemory();
00954         for (size_t i = 0; i < m_threads.size(); ++i)
00955             m_threads[i]->StartPlay();
00956         for (size_t i = 0; i < m_threads.size(); ++i)
00957             m_threads[i]->WaitPlayFinished();
00958         if (m_aborted || ! m_pruneFullTree)
00959             break;
00960         else
00961         {
00962             double startPruneTime = m_timer.GetTime();
00963             SgDebug() << "SgUctSearch: pruning nodes with count < "
00964                   << pruneMinCount << " (at time " << fixed << setprecision(1)
00965                   << startPruneTime << ")\n";
00966             SgUctTree& tempTree = GetTempTree();
00967             m_tree.CopyPruneLowCount(tempTree, pruneMinCount, true);
00968             int prunedSizePercentage =
00969                 static_cast<int>(tempTree.NuNodes() * 100 / m_tree.NuNodes());
00970             SgDebug() << "SgUctSearch: pruned size: " << tempTree.NuNodes()
00971                       << " (" << prunedSizePercentage << "%) time: "
00972                       << (m_timer.GetTime() - startPruneTime) << "\n";
00973             if (prunedSizePercentage > 50)
00974                 pruneMinCount *= 2;
00975             else
00976                 pruneMinCount = m_pruneMinCount; 
00977             m_tree.Swap(tempTree);
00978         }
00979     }
00980     EndSearch();
00981     m_statistics.m_time = m_timer.GetTime();
00982     if (m_statistics.m_time > numeric_limits<double>::epsilon())
00983         m_statistics.m_gamesPerSecond = GamesPlayed() / m_statistics.m_time;
00984     if (m_logGames)
00985         m_log.close();
00986     FindBestSequence(sequence);
00987     return (m_tree.Root().MoveCount() > 0) ? m_tree.Root().Mean() : 0.5;
00988 }
00989 
00990 /** Loop invoked by each thread for playing games. */
00991 void SgUctSearch::SearchLoop(SgUctThreadState& state, GlobalLock* lock)
00992 {
00993     if (! state.m_isSearchInitialized)
00994     {
00995         OnThreadStartSearch(state);
00996         state.m_isSearchInitialized = true;
00997     }
00998 
00999     if (NumberThreads() == 1 || m_lockFree)
01000         lock = 0;
01001     if (lock != 0)
01002         lock->lock();
01003     state.m_isTreeOutOfMem = false;
01004     while (! state.m_isTreeOutOfMem)
01005     {
01006         PlayGame(state, lock);
01007         OnSearchIteration(m_numberGames + 1, state.m_threadId,
01008                           state.m_gameInfo);
01009         if (m_logGames)
01010             m_log << SummaryLine(state.m_gameInfo) << '\n';
01011         ++m_numberGames;
01012         if (m_isTreeOutOfMemory)
01013             break;
01014         if (m_aborted || CheckAbortSearch(state))
01015         {
01016             m_aborted = true;
01017             SgSynchronizeThreadMemory();
01018             break;
01019         }
01020     }
01021     if (lock != 0)
01022         lock->unlock();
01023 
01024     m_searchLoopFinished->wait();
01025     if (m_aborted || ! m_pruneFullTree)
01026         OnThreadEndSearch(state);
01027 }
01028 
01029 void SgUctSearch::OnThreadStartSearch(SgUctThreadState& state)
01030 {
01031     m_mpiSynchronizer->OnThreadStartSearch(*this, state);
01032 }
01033 
01034 void SgUctSearch::OnThreadEndSearch(SgUctThreadState& state)
01035 {
01036     m_mpiSynchronizer->OnThreadEndSearch(*this, state);
01037 }
01038 
01039 SgPoint SgUctSearch::SearchOnePly(size_t maxGames, double maxTime,
01040                                   float& value)
01041 {
01042     if (m_threads.size() == 0)
01043         CreateThreads();
01044     OnStartSearch();
01045     // SearchOnePly is not multi-threaded.
01046     // It uses the state of the first thread.
01047     SgUctThreadState& state = ThreadState(0);
01048     state.StartSearch();
01049     vector<SgMoveInfo> moves;
01050     SgProvenNodeType provenType;
01051     state.GameStart();
01052     state.GenerateAllMoves(0, moves, provenType);
01053     vector<SgUctStatisticsBase> statistics(moves.size());
01054     size_t games = 0;
01055     m_timer.Start();
01056     SgUctGameInfo& info = state.m_gameInfo;
01057     while (games < maxGames && m_timer.GetTime() < maxTime && ! SgUserAbort())
01058     {
01059         for (size_t i = 0; i < moves.size(); ++i)
01060         {
01061             state.GameStart();
01062             info.Clear(1);
01063             SgMove move = moves[i].m_move;
01064             state.Execute(move);
01065             info.m_inTreeSequence.push_back(move);
01066             info.m_sequence[0].push_back(move);
01067             info.m_skipRaveUpdate[0].push_back(false);
01068             state.StartPlayouts();
01069             state.StartPlayout();
01070             bool abortGame = ! PlayoutGame(state, 0);
01071             float eval;
01072             if (abortGame)
01073                 eval = UnknownEval();
01074             else
01075                 eval = state.Evaluate();
01076             state.EndPlayout();
01077             state.TakeBackPlayout(info.m_sequence[0].size() - 1);
01078             state.TakeBackInTree(1);
01079             statistics[i].Add(info.m_sequence[0].size() % 2 == 0 ?
01080                               eval : InverseEval(eval));
01081             OnSearchIteration(games + 1, 0, info);
01082             ++games;
01083         }
01084     }
01085     SgMove bestMove = SG_NULLMOVE;
01086     for (size_t i = 0; i < moves.size(); ++i)
01087     {
01088         SgDebug() << SgWritePoint(moves[i].m_move) 
01089                   << ' ' << statistics[i].Mean() << '\n';
01090         if (bestMove == SG_NULLMOVE || statistics[i].Mean() > value)
01091         {
01092             bestMove = moves[i].m_move;
01093             value = statistics[i].Mean();
01094         }
01095     }
01096     return bestMove;
01097 }
01098 
01099 const SgUctNode& SgUctSearch::SelectChild(int& randomizeCounter, 
01100                                           const SgUctNode& node)
01101 {
01102     bool useRave = m_rave;
01103     if (m_randomizeRaveFrequency > 0)
01104     {
01105         ++randomizeCounter;
01106         if (randomizeCounter % m_randomizeRaveFrequency == 0)
01107             useRave = false;
01108     }
01109     SG_ASSERT(node.HasChildren());
01110     size_t posCount = node.PosCount();
01111     if (posCount == 0)
01112         // If position count is zero, return first child
01113         return *SgUctChildIterator(m_tree, node);
01114     float logPosCount = Log(posCount);
01115     const SgUctNode* bestChild = 0;
01116     float bestUpperBound = 0;
01117     for (SgUctChildIterator it(m_tree, node); it; ++it)
01118     {
01119         const SgUctNode& child = *it;
01120         float bound = GetBound(useRave, logPosCount, child);
01121         if (bestChild == 0 || bound > bestUpperBound)
01122         {
01123             bestChild = &child;
01124             bestUpperBound = bound;
01125         }
01126     }
01127     SG_ASSERT(bestChild != 0);
01128     return *bestChild;
01129 }
01130 
01131 void SgUctSearch::SetNumberThreads(std::size_t n)
01132 {
01133     SG_ASSERT(n >= 1);
01134     if (m_numberThreads == n)
01135         return;
01136     m_numberThreads = n;
01137     CreateThreads();
01138 }
01139 
01140 void SgUctSearch::SetRave(bool enable)
01141 {
01142     if (enable && m_moveRange <= 0)
01143         throw SgException("RAVE not supported for this game");
01144     m_rave = enable;
01145 }
01146 
01147 void SgUctSearch::SetThreadStateFactory(SgUctThreadStateFactory* factory)
01148 {
01149     SG_ASSERT(m_threadStateFactory.get() == 0);
01150     m_threadStateFactory.reset(factory);
01151     DeleteThreads();
01152     // Don't create states here, because this function could be called in the
01153     // constructor of the subclass, and the factory passes the search (which
01154     // is not fully constructed) as an argument to the Create() function
01155 }
01156 
01157 void SgUctSearch::StartSearch(const vector<SgMove>& rootFilter,
01158                               SgUctTree* initTree)
01159 {
01160     if (m_threads.size() == 0)
01161         CreateThreads();
01162     if (m_numberThreads > 1 && SgTime::DefaultMode() == SG_TIME_CPU)
01163         // Using CPU time with multiple threads makes the measured time
01164         // and games/sec not very meaningful; the total cputime is not equal
01165         // to the total real time, even if there is no other load on the
01166         // machine, because the time, while threads are waiting for a lock
01167         // does not contribute to the cputime.
01168         SgWarning() << "SgUctSearch: using cpu time with multiple threads\n";
01169     m_raveWeightParam1 = 1.f / m_raveWeightInitial;
01170     m_raveWeightParam2 = 1.f / m_raveWeightFinal;
01171     if (initTree == 0)
01172         m_tree.Clear();
01173     else
01174     {
01175         m_tree.Swap(*initTree);
01176         if (m_tree.HasCapacity(0, m_tree.Root().NuChildren()))
01177             m_tree.ApplyFilter(0, m_tree.Root(), rootFilter);
01178         else
01179             SgWarning() <<
01180                 "SgUctSearch: "
01181                 "root filter not applied (tree reached maximum size)\n";
01182     }
01183     m_statistics.Clear();
01184     m_aborted = false;
01185     m_wasEarlyAbort = false;
01186     m_checkTimeInterval = 1;
01187     m_numberGames = 0;
01188     m_lastScoreDisplayTime = m_timer.GetTime();
01189     OnStartSearch();
01190     m_startRootMoveCount = m_tree.Root().MoveCount();
01191     for (size_t i = 0; i < m_threads.size(); ++i)
01192         ThreadState(i).StartSearch();
01193 }
01194 
01195 void SgUctSearch::EndSearch()
01196 {
01197     OnEndSearch();
01198 }
01199 
01200 string SgUctSearch::SummaryLine(const SgUctGameInfo& info) const
01201 {
01202     ostringstream buffer;
01203     const vector<const SgUctNode*>& nodes = info.m_nodes;
01204     for (size_t i = 1; i < nodes.size(); ++i)
01205     {
01206         const SgUctNode* node = nodes[i];
01207         SgMove move = node->Move();
01208         buffer << ' ' << MoveString(move) << " (" << fixed << setprecision(2)
01209                << node->Mean() << ',' << node->MoveCount() << ')';
01210     }
01211     for (size_t i = 0; i < info.m_eval.size(); ++i)
01212         buffer << ' ' << fixed << setprecision(2) << info.m_eval[i];
01213     return buffer.str();
01214 }
01215 
01216 void SgUctSearch::UpdateCheckTimeInterval(double time)
01217 {
01218     if (time < numeric_limits<double>::epsilon())
01219         return;
01220     // Dynamically update m_checkTimeInterval (see comment at definition of
01221     // m_checkTimeInterval)
01222     float wantedTimeDiff = (m_maxTime > 1 ? 0.1 : 0.1 * m_maxTime);
01223     if (time < wantedTimeDiff / 10)
01224     {
01225         // Computing games per second might be unreliable for small times
01226         m_checkTimeInterval *= 2;
01227         return;
01228     }
01229     m_statistics.m_gamesPerSecond = GamesPlayed() / time;
01230     double gamesPerSecondPerThread =
01231         m_statistics.m_gamesPerSecond / m_numberThreads;
01232     m_checkTimeInterval =
01233         static_cast<size_t>(wantedTimeDiff * gamesPerSecondPerThread);
01234     if (m_checkTimeInterval == 0)
01235         m_checkTimeInterval = 1;
01236 }
01237 
01238 /** Update the RAVE values in the tree for both players after a game was
01239     played.
01240     @see SgUctSearch::Rave()
01241 */
01242 void SgUctSearch::UpdateRaveValues(SgUctThreadState& state)
01243 {
01244     for (size_t i = 0; i < m_numberPlayouts; ++i)
01245         UpdateRaveValues(state, i);
01246 }
01247 
01248 void SgUctSearch::UpdateRaveValues(SgUctThreadState& state,
01249                                    std::size_t playout)
01250 {
01251     SgUctGameInfo& info = state.m_gameInfo;
01252     const vector<SgMove>& sequence = info.m_sequence[playout];
01253     if (sequence.size() == 0)
01254         return;
01255     SG_ASSERT(m_moveRange > 0);
01256     size_t* firstPlay = state.m_firstPlay.get();
01257     size_t* firstPlayOpp = state.m_firstPlayOpp.get();
01258     fill_n(firstPlay, m_moveRange, numeric_limits<size_t>::max());
01259     fill_n(firstPlayOpp, m_moveRange, numeric_limits<size_t>::max());
01260     const vector<const SgUctNode*>& nodes = info.m_nodes;
01261     const vector<bool>& skipRaveUpdate = info.m_skipRaveUpdate[playout];
01262     float eval = info.m_eval[playout];
01263     float invEval = InverseEval(eval);
01264     size_t nuNodes = nodes.size();
01265     size_t i = sequence.size() - 1;
01266     bool opp = (i % 2 != 0);
01267 
01268     // Update firstPlay, firstPlayOpp arrays using playout moves
01269     for ( ; i >= nuNodes; --i)
01270     {
01271         SG_ASSERT(i < skipRaveUpdate.size());
01272         SG_ASSERT(i < sequence.size());
01273         if (! skipRaveUpdate[i])
01274         {
01275             SgMove mv = sequence[i];
01276             size_t& first = (opp ? firstPlayOpp[mv] : firstPlay[mv]);
01277             if (i < first)
01278                 first = i;
01279         }
01280         opp = ! opp;
01281     }
01282 
01283     while (true)
01284     {
01285         SG_ASSERT(i < skipRaveUpdate.size());
01286         SG_ASSERT(i < sequence.size());
01287         // skipRaveUpdate currently not used in in-tree phase
01288         SG_ASSERT(i >= info.m_inTreeSequence.size() || ! skipRaveUpdate[i]);
01289         if (! skipRaveUpdate[i])
01290         {
01291             SgMove mv = sequence[i];
01292             size_t& first = (opp ? firstPlayOpp[mv] : firstPlay[mv]);
01293             if (i < first)
01294                 first = i;
01295             if (opp)
01296                 UpdateRaveValues(state, playout, invEval, i,
01297                                  firstPlayOpp, firstPlay);
01298             else
01299                 UpdateRaveValues(state, playout, eval, i,
01300                                  firstPlay, firstPlayOpp);
01301         }
01302         if (i == 0)
01303             break;
01304         --i;
01305         opp = ! opp;
01306     }
01307 }
01308 
01309 void SgUctSearch::UpdateRaveValues(SgUctThreadState& state,
01310                                    std::size_t playout, float eval,
01311                                    std::size_t i,
01312                                    const std::size_t firstPlay[],
01313                                    const std::size_t firstPlayOpp[])
01314 {
01315     SG_ASSERT(i < state.m_gameInfo.m_nodes.size());
01316     const SgUctNode* node = state.m_gameInfo.m_nodes[i];
01317     if (! node->HasChildren())
01318         return;
01319     std::size_t len = state.m_gameInfo.m_sequence[playout].size();
01320     for (SgUctChildIterator it(m_tree, *node); it; ++it)
01321     {
01322         const SgUctNode& child = *it;
01323         SgMove mv = child.Move();
01324         size_t first = firstPlay[mv];
01325         SG_ASSERT(first >= i);
01326         if (first == numeric_limits<size_t>::max())
01327             continue;
01328         if  (m_raveCheckSame && SgUtil::InRange(firstPlayOpp[mv], i, first))
01329             continue;
01330         float weight;
01331         if (m_weightRaveUpdates)
01332             weight = 2.f - static_cast<float>(first - i) / (len - i);
01333         else
01334             weight = 1.f;
01335         m_tree.AddRaveValue(child, eval, weight);
01336     }
01337 }
01338 
01339 void SgUctSearch::UpdateStatistics(const SgUctGameInfo& info)
01340 {
01341     m_statistics.m_movesInTree.Add(
01342                             static_cast<float>(info.m_inTreeSequence.size()));
01343     for (size_t i = 0; i < m_numberPlayouts; ++i)
01344     {
01345         m_statistics.m_gameLength.Add(
01346                                static_cast<float>(info.m_sequence[i].size()));
01347         m_statistics.m_aborted.Add(info.m_aborted[i] ? 1.f : 0.f);
01348     }
01349 }
01350 
01351 void SgUctSearch::UpdateTree(const SgUctGameInfo& info)
01352 {
01353     float eval = 0;
01354     for (size_t i = 0; i < m_numberPlayouts; ++i)
01355         eval += info.m_eval[i];
01356     eval /= m_numberPlayouts;
01357     float inverseEval = InverseEval(eval);
01358     const vector<const SgUctNode*>& nodes = info.m_nodes;
01359 
01360     for (size_t i = 0; i < nodes.size(); ++i)
01361     {
01362         const SgUctNode& node = *nodes[i];
01363         const SgUctNode* father = (i > 0 ? nodes[i - 1] : 0);
01364         m_tree.AddGameResults(node, father, i % 2 == 0 ? eval : inverseEval, m_numberPlayouts);
01365     }
01366 }
01367 
01368 void SgUctSearch::WriteStatistics(ostream& out) const
01369 {
01370     out << SgWriteLabel("Count") << m_tree.Root().MoveCount() << '\n'
01371         << SgWriteLabel("Nodes") << m_tree.NuNodes() << '\n';
01372     if (!m_knowledgeThreshold.empty())
01373         out << SgWriteLabel("Knowledge") 
01374             << m_statistics.m_knowledge << " (" << fixed << setprecision(1) 
01375             << m_statistics.m_knowledge * 100.0 / m_tree.Root().MoveCount()
01376             << "%)\n";
01377     m_statistics.Write(out);
01378     m_mpiSynchronizer->WriteStatistics(out);
01379 }
01380 
01381 //----------------------------------------------------------------------------


17 Jun 2010 Doxygen 1.4.7