Index   Main   Namespaces   Classes   Hierarchy   Annotated   Files   Compound   Global   Pages  

SgUctTree.cpp

Go to the documentation of this file.
00001 //----------------------------------------------------------------------------
00002 /** @file SgUctTree.cpp
00003     See SgUctTree.h
00004 */
00005 //----------------------------------------------------------------------------
00006 
#include "SgSystem.h"
#include "SgUctTree.h"

#include <algorithm>
#include <cstdlib>
#include <functional>
#include <new>
#include <boost/format.hpp>
#include "SgDebug.h"
#include "SgTimer.h"

using namespace std;
using boost::format;
using boost::shared_ptr;
00017 
00018 //----------------------------------------------------------------------------
00019 
00020 SgUctAllocator::~SgUctAllocator()
00021 {
00022     if (m_start != 0)
00023     {
00024         Clear();
00025         std::free(m_start);
00026     }
00027 }
00028 
00029 bool SgUctAllocator::Contains(const SgUctNode& node) const
00030 {
00031     return (&node >= m_start && &node < m_finish);
00032 }
00033 
00034 void SgUctAllocator::Swap(SgUctAllocator& allocator)
00035 {
00036     swap(m_start, allocator.m_start);
00037     swap(m_finish, allocator.m_finish);
00038     swap(m_endOfStorage, allocator.m_endOfStorage);
00039 }
00040 
00041 void SgUctAllocator::SetMaxNodes(std::size_t maxNodes)
00042 {
00043     if (m_start != 0)
00044     {
00045         Clear();
00046         std::free(m_start);
00047     }
00048     void* ptr = std::malloc(maxNodes * sizeof(SgUctNode));
00049     if (ptr == 0)
00050         throw std::bad_alloc();
00051     m_start = static_cast<SgUctNode*>(ptr);
00052     m_finish = m_start;
00053     m_endOfStorage = m_start + maxNodes;
00054 }
00055 
00056 //----------------------------------------------------------------------------
00057 
00058 SgUctTree::SgUctTree()
00059     : m_maxNodes(0),
00060       m_root(SG_NULLMOVE)
00061 {
00062 }
00063 
00064 void SgUctTree::AddVirtualLoss(const std::vector<const SgUctNode*>& nodes)
00065 {
00066     for (size_t i = 0; i < nodes.size(); ++i)
00067     {
00068         const SgUctNode* father = (i > 0 ? nodes[i-1] : 0);
00069         AddGameResult(*nodes[i], father, 1.0); // loss for us = win for child
00070         AddRaveValue(*nodes[i], 0.0, 1.0); // loss for us
00071     }
00072 }
00073 
00074 void SgUctTree::ApplyFilter(std::size_t allocatorId, const SgUctNode& node,
00075                             const vector<SgMove>& rootFilter)
00076 {
00077     SG_ASSERT(Contains(node));
00078     SG_ASSERT(Allocator(allocatorId).HasCapacity(node.NuChildren()));
00079     if (! node.HasChildren())
00080         return;
00081 
00082     SgUctAllocator& allocator = Allocator(allocatorId);
00083     const SgUctNode* firstChild = allocator.Finish();
00084 
00085     int nuChildren = 0;
00086     for (SgUctChildIterator it(*this, node); it; ++it)
00087     {
00088         SgMove move = (*it).Move();
00089         if (find(rootFilter.begin(), rootFilter.end(), move)
00090             == rootFilter.end())
00091         {
00092             SgUctNode* child = allocator.CreateOne(move);
00093             child->CopyDataFrom(*it);
00094             int childNuChildren = (*it).NuChildren();
00095             child->SetNuChildren(childNuChildren);
00096             if (childNuChildren > 0)
00097                 child->SetFirstChild((*it).FirstChild());
00098             ++nuChildren;
00099         }
00100     }
00101 
00102     SgUctNode& nonConstNode = const_cast<SgUctNode&>(node);
00103     // Write order dependency: SgUctSearch in lock-free mode assumes that
00104     // m_firstChild is valid if m_nuChildren is greater zero
00105     SgSynchronizeThreadMemory();
00106     nonConstNode.SetFirstChild(firstChild);
00107     SgSynchronizeThreadMemory();
00108     nonConstNode.SetNuChildren(nuChildren);
00109 }
00110 
00111 void SgUctTree::SetChildren(std::size_t allocatorId, const SgUctNode& node,
00112                             const vector<SgMove>& moves)
00113 {
00114     SG_ASSERT(Contains(node));
00115     SG_ASSERT(Allocator(allocatorId).HasCapacity(moves.size()));
00116     SG_ASSERT(node.HasChildren());
00117 
00118     SgUctAllocator& allocator = Allocator(allocatorId);
00119     const SgUctNode* firstChild = allocator.Finish();
00120 
00121     int nuChildren = 0;
00122     for (size_t i = 0; i < moves.size(); ++i)
00123     {
00124         bool found = false;
00125         for (SgUctChildIterator it(*this, node); it; ++it)
00126         {
00127             SgMove move = (*it).Move();
00128             if (move == moves[i])
00129             {
00130                 found = true;
00131                 SgUctNode* child = allocator.CreateOne(move);
00132                 child->CopyDataFrom(*it);
00133                 int childNuChildren = (*it).NuChildren();
00134                 child->SetNuChildren(childNuChildren);
00135                 if (childNuChildren > 0)
00136                     child->SetFirstChild((*it).FirstChild());
00137                 ++nuChildren;
00138                 break;
00139             }
00140         }
00141         if (! found)
00142         {
00143             allocator.CreateOne(moves[i]);
00144             ++nuChildren;
00145         }
00146     }
00147     SG_ASSERT((size_t)nuChildren == moves.size());
00148 
00149     SgUctNode& nonConstNode = const_cast<SgUctNode&>(node);
00150     // Write order dependency: SgUctSearch in lock-free mode assumes that
00151     // m_firstChild is valid if m_nuChildren is greater zero
00152     SgSynchronizeThreadMemory();
00153     nonConstNode.SetFirstChild(firstChild);
00154     SgSynchronizeThreadMemory();
00155     nonConstNode.SetNuChildren(nuChildren);
00156 }
00157 
00158 void SgUctTree::CheckConsistency() const
00159 {
00160     for (SgUctTreeIterator it(*this); it; ++it)
00161         if (! Contains(*it))
00162             ThrowConsistencyError(str(format("! Contains(%1%)") % &(*it)));
00163 }
00164 
00165 void SgUctTree::Clear()
00166 {
00167     for (size_t i = 0; i < NuAllocators(); ++i)
00168         Allocator(i).Clear();
00169     m_root = SgUctNode(SG_NULLMOVE);
00170 }
00171 
00172 /** Check if node is in tree.
00173     Only used for assertions. May not be available in future implementations.
00174 */
00175 bool SgUctTree::Contains(const SgUctNode& node) const
00176 {
00177     if (&node == &m_root)
00178         return true;
00179     for (size_t i = 0; i < NuAllocators(); ++i)
00180         if (Allocator(i).Contains(node))
00181             return true;
00182     return false;
00183 }
00184 
00185 void SgUctTree::CopyPruneLowCount(SgUctTree& target, std::size_t minCount,
00186                                   bool warnTruncate, double maxTime) const
00187 {
00188     size_t allocatorId = 0;
00189     SgTimer timer;
00190     bool abort = false;
00191     CopySubtree(target, target.m_root, m_root, minCount, allocatorId,
00192                 warnTruncate, abort, timer, maxTime);
00193 }
00194 
00195 /** Recursive function used by SgUctTree::ExtractSubtree and
00196     SgUctTree::CopyPruneLowCount.
00197     @param target The target tree.
00198     @param targetNode The target node; it is already created but the content
00199     not yet copied
00200     @param node The node in the source tree to be copied.
00201     @param minCount The minimum count (SgUctNode::MoveCount()) of a non-root
00202     node in the source tree to copy
00203     @param currentAllocatorId The current node allocator. Will be incremented
00204     in each call to CopySubtree to use node allocators of target tree evenly.
00205     @param warnTruncate Print warning to SgDebug() if tree was
00206     truncated (e.g due to reassigning nodes to different allocators)
00207     @param[in,out] abort Flag to abort copying. Must be initialized to false
00208     by top-level caller
00209     @param timer
00210     @param maxTime See ExtractSubtree()
00211 */
00212 void SgUctTree::CopySubtree(SgUctTree& target, SgUctNode& targetNode,
00213                             const SgUctNode& node, std::size_t minCount,
00214                             std::size_t& currentAllocatorId,
00215                             bool warnTruncate, bool& abort, SgTimer& timer,
00216                             double maxTime) const
00217 {
00218     SG_ASSERT(Contains(node));
00219     SG_ASSERT(target.Contains(targetNode));
00220     targetNode.CopyDataFrom(node);
00221 
00222     if (! node.HasChildren() || node.MoveCount() < minCount)
00223         return;
00224 
00225     SgUctAllocator& targetAllocator = target.Allocator(currentAllocatorId);
00226     int nuChildren = node.NuChildren();
00227     if (! abort)
00228     {
00229         if (! targetAllocator.HasCapacity(nuChildren))
00230         {
00231             // This can happen even if target tree has same maximum number of
00232             // nodes, because allocators are used differently.
00233             if (warnTruncate)
00234                 SgDebug() <<
00235                 "SgUctTree::CopySubtree: Truncated (allocator capacity)\n";
00236             abort = true;
00237         }
00238         if (timer.IsTimeOut(maxTime, 10000))
00239         {
00240             if (warnTruncate)
00241                 SgDebug() << "SgUctTree::CopySubtree: Truncated (max time)\n";
00242             abort = true;
00243         }
00244         if (SgUserAbort())
00245         {
00246             if (warnTruncate)
00247                 SgDebug() << "SgUctTree::CopySubtree: Truncated (aborted)\n";
00248             abort = true;
00249         }
00250     }
00251     if (abort)
00252     {
00253         // Don't copy the children and set the pos count to zero (should
00254         // reflect the sum of children move counts)
00255         targetNode.SetPosCount(0);
00256         return;
00257     }
00258 
00259     SgUctNode* firstTargetChild = targetAllocator.Finish();
00260     targetNode.SetFirstChild(firstTargetChild);
00261     targetNode.SetNuChildren(nuChildren);
00262 
00263     // Create target nodes first (must be contiguous in the target tree)
00264     targetAllocator.CreateN(nuChildren);
00265 
00266     // Recurse
00267     SgUctNode* targetChild = firstTargetChild;
00268     for (SgUctChildIterator it(*this, node); it; ++it, ++targetChild)
00269     {
00270         const SgUctNode& child = *it;
00271         ++currentAllocatorId; // Cycle to use allocators uniformly
00272         if (currentAllocatorId >= target.NuAllocators())
00273             currentAllocatorId = 0;
00274         CopySubtree(target, *targetChild, child, minCount, currentAllocatorId,
00275                     warnTruncate, abort, timer, maxTime);
00276     }
00277 }
00278 
00279 void SgUctTree::CreateAllocators(std::size_t nuThreads)
00280 {
00281     Clear();
00282     m_allocators.clear();
00283     for (size_t i = 0; i < nuThreads; ++i)
00284     {
00285         boost::shared_ptr<SgUctAllocator> allocator(new SgUctAllocator());
00286         m_allocators.push_back(allocator);
00287     }
00288 }
00289 
00290 void SgUctTree::DumpDebugInfo(std::ostream& out) const
00291 {
00292     out << "Root " << &m_root << '\n';
00293     for (size_t i = 0; i < NuAllocators(); ++i)
00294         out << "Allocator " << i
00295             << " size=" << Allocator(i).NuNodes()
00296             << " start=" << Allocator(i).Start()
00297             << " finish=" << Allocator(i).Finish() << '\n';
00298 }
00299 
00300 void SgUctTree::ExtractSubtree(SgUctTree& target, const SgUctNode& node,
00301                                bool warnTruncate, double maxTime) const
00302 {
00303     SG_ASSERT(Contains(node));
00304     SG_ASSERT(&target != this);
00305     SG_ASSERT(target.MaxNodes() == MaxNodes());
00306     target.Clear();
00307     size_t allocatorId = 0;
00308     SgTimer timer;
00309     bool abort = false;
00310     CopySubtree(target, target.m_root, node, 0, allocatorId, warnTruncate,
00311                 abort, timer, maxTime);
00312 }
00313 
00314 void SgUctTree::MergeChildren(std::size_t allocatorId, const SgUctNode& node,
00315                               const std::vector<SgMoveInfo>& moves,
00316                               bool deleteChildTrees)
00317 {
00318     SG_ASSERT(Contains(node));
00319     // Parameters are const-references, because only the tree is allowed
00320     // to modify nodes
00321     SgUctNode& nonConstNode = const_cast<SgUctNode&>(node);
00322     size_t nuNewChildren = moves.size();
00323 
00324     if (nuNewChildren == 0)
00325     {
00326         // Write order dependency
00327         nonConstNode.SetNuChildren(0);
00328         SgSynchronizeThreadMemory();
00329         nonConstNode.SetFirstChild(0);
00330         return;
00331     }
00332 
00333     SgUctAllocator& allocator = Allocator(allocatorId);
00334     SG_ASSERT(allocator.HasCapacity(nuNewChildren));
00335 
00336     const SgUctNode* newFirstChild = allocator.Finish();
00337     std::size_t parentCount = allocator.Create(moves);
00338     
00339     // Update new children with data in old children
00340     for (std::size_t i = 0; i < moves.size(); ++i) 
00341     {
00342         SgUctNode* newChild = const_cast<SgUctNode*>(&newFirstChild[i]);
00343         for (SgUctChildIterator it(*this, node); it; ++it)
00344         {
00345             const SgUctNode& oldChild = *it;
00346             if (oldChild.Move() == moves[i].m_move)
00347             {
00348                 newChild->MergeResults(oldChild);
00349                 newChild->SetKnowledgeCount(oldChild.KnowledgeCount());
00350                 if (! deleteChildTrees)
00351                 {
00352                     newChild->SetPosCount(oldChild.PosCount());
00353                     parentCount += oldChild.MoveCount();
00354                     if (oldChild.HasChildren())
00355                     {
00356                         newChild->SetFirstChild(oldChild.FirstChild());
00357                         newChild->SetNuChildren(oldChild.NuChildren());
00358                     }
00359                 }
00360                 break;
00361             }
00362         }
00363     }
00364     nonConstNode.SetPosCount(parentCount);
00365 
00366     // Write order dependency: We do not want an SgUctChildIterator to
00367     // run past the end of a node's children, which can happen if one
00368     // is created between the two statements below. We modify node in
00369     // such a way so as to avoid that.
00370     SgSynchronizeThreadMemory();
00371     if (nonConstNode.NuChildren() < (int)nuNewChildren)
00372     {
00373         nonConstNode.SetFirstChild(newFirstChild);
00374         SgSynchronizeThreadMemory();
00375         nonConstNode.SetNuChildren(nuNewChildren);
00376     }
00377     else
00378     {
00379         nonConstNode.SetNuChildren(nuNewChildren);
00380         SgSynchronizeThreadMemory();
00381         nonConstNode.SetFirstChild(newFirstChild);
00382     }
00383 }
00384 
00385 std::size_t SgUctTree::NuNodes() const
00386 {
00387     size_t nuNodes = 1; // Count root node
00388     for (size_t i = 0; i < NuAllocators(); ++i)
00389         nuNodes += Allocator(i).NuNodes();
00390     return nuNodes;
00391 }
00392 
00393 void SgUctTree::RemoveVirtualLoss(const std::vector<const SgUctNode*>& nodes)
00394 {
00395     for (size_t i = 0; i < nodes.size(); ++i)
00396     {
00397         const SgUctNode* father = (i > 0 ? nodes[i-1] : 0);
00398         RemoveGameResult(*nodes[i], father, 1.0); // see AddVirtualLoss()
00399         RemoveRaveValue(*nodes[i], 0.0, 1.0);
00400     }
00401 }
00402 
00403 void SgUctTree::SetMaxNodes(std::size_t maxNodes)
00404 {
00405     Clear();
00406     size_t nuAllocators = NuAllocators();
00407     if (nuAllocators == 0)
00408     {
00409         SgDebug() << "SgUctTree::SetMaxNodes: no allocators registered\n";
00410         SG_ASSERT(false);
00411         return;
00412     }
00413     m_maxNodes = maxNodes;
00414     size_t maxNodesPerAlloc = maxNodes / nuAllocators;
00415     for (size_t i = 0; i < NuAllocators(); ++i)
00416         Allocator(i).SetMaxNodes(maxNodesPerAlloc);
00417 }
00418 
00419 void SgUctTree::Swap(SgUctTree& tree)
00420 {
00421     SG_ASSERT(MaxNodes() == tree.MaxNodes());
00422     SG_ASSERT(NuAllocators() == tree.NuAllocators());
00423     swap(m_root, tree.m_root);
00424     for (size_t i = 0; i < NuAllocators(); ++i)
00425         Allocator(i).Swap(tree.Allocator(i));
00426 }
00427 
00428 void SgUctTree::ThrowConsistencyError(const string& message) const
00429 {
00430     DumpDebugInfo(SgDebug());
00431     throw SgException("SgUctTree::ThrowConsistencyError: " + message);
00432 }
00433 
00434 //----------------------------------------------------------------------------
00435 
00436 SgUctTreeIterator::SgUctTreeIterator(const SgUctTree& tree)
00437     : m_tree(tree),
00438       m_current(&tree.Root())
00439 {
00440 }
00441 
00442 const SgUctNode& SgUctTreeIterator::operator*() const
00443 {
00444     return *m_current;
00445 }
00446 
00447 void SgUctTreeIterator::operator++()
00448 {
00449     if (m_current->HasChildren())
00450     {
00451         SgUctChildIterator* it = new SgUctChildIterator(m_tree, *m_current);
00452         m_stack.push(shared_ptr<SgUctChildIterator>(it));
00453         m_current = &(**it);
00454         return;
00455     }
00456     while (! m_stack.empty())
00457     {
00458         SgUctChildIterator& it = *m_stack.top();
00459         SG_ASSERT(it);
00460         ++it;
00461         if (it)
00462         {
00463             m_current = &(*it);
00464             return;
00465         }
00466         else
00467         {
00468             m_stack.pop();
00469             m_current = 0;
00470         }
00471     }
00472     m_current = 0;
00473 }
00474 
00475 SgUctTreeIterator::operator bool() const
00476 {
00477     return (m_current != 0);
00478 }
00479 
00480 //----------------------------------------------------------------------------


17 Jun 2010 Doxygen 1.4.7