00001 //---------------------------------------------------------------------------- 00002 /** @file SgUctTree.cpp 00003 See SgUctTree.h 00004 */ 00005 //---------------------------------------------------------------------------- 00006 00007 #include "SgSystem.h" 00008 #include "SgUctTree.h" 00009 00010 #include <boost/format.hpp> 00011 #include "SgDebug.h" 00012 #include "SgTimer.h" 00013 00014 using namespace std; 00015 using boost::format; 00016 using boost::shared_ptr; 00017 00018 //---------------------------------------------------------------------------- 00019 00020 SgUctAllocator::~SgUctAllocator() 00021 { 00022 if (m_start != 0) 00023 { 00024 Clear(); 00025 std::free(m_start); 00026 } 00027 } 00028 00029 bool SgUctAllocator::Contains(const SgUctNode& node) const 00030 { 00031 return (&node >= m_start && &node < m_finish); 00032 } 00033 00034 void SgUctAllocator::Swap(SgUctAllocator& allocator) 00035 { 00036 swap(m_start, allocator.m_start); 00037 swap(m_finish, allocator.m_finish); 00038 swap(m_endOfStorage, allocator.m_endOfStorage); 00039 } 00040 00041 void SgUctAllocator::SetMaxNodes(std::size_t maxNodes) 00042 { 00043 if (m_start != 0) 00044 { 00045 Clear(); 00046 std::free(m_start); 00047 } 00048 void* ptr = std::malloc(maxNodes * sizeof(SgUctNode)); 00049 if (ptr == 0) 00050 throw std::bad_alloc(); 00051 m_start = static_cast<SgUctNode*>(ptr); 00052 m_finish = m_start; 00053 m_endOfStorage = m_start + maxNodes; 00054 } 00055 00056 //---------------------------------------------------------------------------- 00057 00058 SgUctTree::SgUctTree() 00059 : m_maxNodes(0), 00060 m_root(SG_NULLMOVE) 00061 { 00062 } 00063 00064 void SgUctTree::AddVirtualLoss(const std::vector<const SgUctNode*>& nodes) 00065 { 00066 for (size_t i = 0; i < nodes.size(); ++i) 00067 { 00068 const SgUctNode* father = (i > 0 ? nodes[i-1] : 0); 00069 AddGameResult(*nodes[i], father, 1.0); // loss for us = win for child 00070 AddRaveValue(*nodes[i], 0.0, 1.0); // loss for us 00071 } 00072 } 00073 00074 void SgUctTree::ApplyFilter(std::size_t allocatorId, const SgUctNode& node, 00075 const vector<SgMove>& rootFilter) 00076 { 00077 SG_ASSERT(Contains(node)); 00078 SG_ASSERT(Allocator(allocatorId).HasCapacity(node.NuChildren())); 00079 if (! node.HasChildren()) 00080 return; 00081 00082 SgUctAllocator& allocator = Allocator(allocatorId); 00083 const SgUctNode* firstChild = allocator.Finish(); 00084 00085 int nuChildren = 0; 00086 for (SgUctChildIterator it(*this, node); it; ++it) 00087 { 00088 SgMove move = (*it).Move(); 00089 if (find(rootFilter.begin(), rootFilter.end(), move) 00090 == rootFilter.end()) 00091 { 00092 SgUctNode* child = allocator.CreateOne(move); 00093 child->CopyDataFrom(*it); 00094 int childNuChildren = (*it).NuChildren(); 00095 child->SetNuChildren(childNuChildren); 00096 if (childNuChildren > 0) 00097 child->SetFirstChild((*it).FirstChild()); 00098 ++nuChildren; 00099 } 00100 } 00101 00102 SgUctNode& nonConstNode = const_cast<SgUctNode&>(node); 00103 // Write order dependency: SgUctSearch in lock-free mode assumes that 00104 // m_firstChild is valid if m_nuChildren is greater zero 00105 SgSynchronizeThreadMemory(); 00106 nonConstNode.SetFirstChild(firstChild); 00107 SgSynchronizeThreadMemory(); 00108 nonConstNode.SetNuChildren(nuChildren); 00109 } 00110 00111 void SgUctTree::SetChildren(std::size_t allocatorId, const SgUctNode& node, 00112 const vector<SgMove>& moves) 00113 { 00114 SG_ASSERT(Contains(node)); 00115 SG_ASSERT(Allocator(allocatorId).HasCapacity(moves.size())); 00116 SG_ASSERT(node.HasChildren()); 00117 00118 SgUctAllocator& allocator = Allocator(allocatorId); 00119 const SgUctNode* firstChild = allocator.Finish(); 00120 00121 int nuChildren = 0; 00122 for (size_t i = 0; i < moves.size(); ++i) 00123 { 00124 bool found = false; 00125 for (SgUctChildIterator it(*this, node); it; ++it) 00126 { 00127 SgMove move = (*it).Move(); 00128 if (move == moves[i]) 00129 { 00130 found = true; 00131 SgUctNode* child = allocator.CreateOne(move); 00132 child->CopyDataFrom(*it); 00133 int childNuChildren = (*it).NuChildren(); 00134 child->SetNuChildren(childNuChildren); 00135 if (childNuChildren > 0) 00136 child->SetFirstChild((*it).FirstChild()); 00137 ++nuChildren; 00138 break; 00139 } 00140 } 00141 if (! found) 00142 { 00143 allocator.CreateOne(moves[i]); 00144 ++nuChildren; 00145 } 00146 } 00147 SG_ASSERT((size_t)nuChildren == moves.size()); 00148 00149 SgUctNode& nonConstNode = const_cast<SgUctNode&>(node); 00150 // Write order dependency: SgUctSearch in lock-free mode assumes that 00151 // m_firstChild is valid if m_nuChildren is greater zero 00152 SgSynchronizeThreadMemory(); 00153 nonConstNode.SetFirstChild(firstChild); 00154 SgSynchronizeThreadMemory(); 00155 nonConstNode.SetNuChildren(nuChildren); 00156 } 00157 00158 void SgUctTree::CheckConsistency() const 00159 { 00160 for (SgUctTreeIterator it(*this); it; ++it) 00161 if (! Contains(*it)) 00162 ThrowConsistencyError(str(format("! Contains(%1%)") % &(*it))); 00163 } 00164 00165 void SgUctTree::Clear() 00166 { 00167 for (size_t i = 0; i < NuAllocators(); ++i) 00168 Allocator(i).Clear(); 00169 m_root = SgUctNode(SG_NULLMOVE); 00170 } 00171 00172 /** Check if node is in tree. 00173 Only used for assertions. May not be available in future implementations. 00174 */ 00175 bool SgUctTree::Contains(const SgUctNode& node) const 00176 { 00177 if (&node == &m_root) 00178 return true; 00179 for (size_t i = 0; i < NuAllocators(); ++i) 00180 if (Allocator(i).Contains(node)) 00181 return true; 00182 return false; 00183 } 00184 00185 void SgUctTree::CopyPruneLowCount(SgUctTree& target, std::size_t minCount, 00186 bool warnTruncate, double maxTime) const 00187 { 00188 size_t allocatorId = 0; 00189 SgTimer timer; 00190 bool abort = false; 00191 CopySubtree(target, target.m_root, m_root, minCount, allocatorId, 00192 warnTruncate, abort, timer, maxTime); 00193 } 00194 00195 /** Recursive function used by SgUctTree::ExtractSubtree and 00196 SgUctTree::CopyPruneLowCount. 00197 @param target The target tree. 00198 @param targetNode The target node; it is already created but the content 00199 not yet copied 00200 @param node The node in the source tree to be copied. 00201 @param minCount The minimum count (SgUctNode::MoveCount()) of a non-root 00202 node in the source tree to copy 00203 @param currentAllocatorId The current node allocator. Will be incremented 00204 in each call to CopySubtree to use node allocators of target tree evenly. 00205 @param warnTruncate Print warning to SgDebug() if tree was 00206 truncated (e.g due to reassigning nodes to different allocators) 00207 @param[in,out] abort Flag to abort copying. Must be initialized to false 00208 by top-level caller 00209 @param timer 00210 @param maxTime See ExtractSubtree() 00211 */ 00212 void SgUctTree::CopySubtree(SgUctTree& target, SgUctNode& targetNode, 00213 const SgUctNode& node, std::size_t minCount, 00214 std::size_t& currentAllocatorId, 00215 bool warnTruncate, bool& abort, SgTimer& timer, 00216 double maxTime) const 00217 { 00218 SG_ASSERT(Contains(node)); 00219 SG_ASSERT(target.Contains(targetNode)); 00220 targetNode.CopyDataFrom(node); 00221 00222 if (! node.HasChildren() || node.MoveCount() < minCount) 00223 return; 00224 00225 SgUctAllocator& targetAllocator = target.Allocator(currentAllocatorId); 00226 int nuChildren = node.NuChildren(); 00227 if (! abort) 00228 { 00229 if (! targetAllocator.HasCapacity(nuChildren)) 00230 { 00231 // This can happen even if target tree has same maximum number of 00232 // nodes, because allocators are used differently. 00233 if (warnTruncate) 00234 SgDebug() << 00235 "SgUctTree::CopySubtree: Truncated (allocator capacity)\n"; 00236 abort = true; 00237 } 00238 if (timer.IsTimeOut(maxTime, 10000)) 00239 { 00240 if (warnTruncate) 00241 SgDebug() << "SgUctTree::CopySubtree: Truncated (max time)\n"; 00242 abort = true; 00243 } 00244 if (SgUserAbort()) 00245 { 00246 if (warnTruncate) 00247 SgDebug() << "SgUctTree::CopySubtree: Truncated (aborted)\n"; 00248 abort = true; 00249 } 00250 } 00251 if (abort) 00252 { 00253 // Don't copy the children and set the pos count to zero (should 00254 // reflect the sum of children move counts) 00255 targetNode.SetPosCount(0); 00256 return; 00257 } 00258 00259 SgUctNode* firstTargetChild = targetAllocator.Finish(); 00260 targetNode.SetFirstChild(firstTargetChild); 00261 targetNode.SetNuChildren(nuChildren); 00262 00263 // Create target nodes first (must be contiguous in the target tree) 00264 targetAllocator.CreateN(nuChildren); 00265 00266 // Recurse 00267 SgUctNode* targetChild = firstTargetChild; 00268 for (SgUctChildIterator it(*this, node); it; ++it, ++targetChild) 00269 { 00270 const SgUctNode& child = *it; 00271 ++currentAllocatorId; // Cycle to use allocators uniformly 00272 if (currentAllocatorId >= target.NuAllocators()) 00273 currentAllocatorId = 0; 00274 CopySubtree(target, *targetChild, child, minCount, currentAllocatorId, 00275 warnTruncate, abort, timer, maxTime); 00276 } 00277 } 00278 00279 void SgUctTree::CreateAllocators(std::size_t nuThreads) 00280 { 00281 Clear(); 00282 m_allocators.clear(); 00283 for (size_t i = 0; i < nuThreads; ++i) 00284 { 00285 boost::shared_ptr<SgUctAllocator> allocator(new SgUctAllocator()); 00286 m_allocators.push_back(allocator); 00287 } 00288 } 00289 00290 void SgUctTree::DumpDebugInfo(std::ostream& out) const 00291 { 00292 out << "Root " << &m_root << '\n'; 00293 for (size_t i = 0; i < NuAllocators(); ++i) 00294 out << "Allocator " << i 00295 << " size=" << Allocator(i).NuNodes() 00296 << " start=" << Allocator(i).Start() 00297 << " finish=" << Allocator(i).Finish() << '\n'; 00298 } 00299 00300 void SgUctTree::ExtractSubtree(SgUctTree& target, const SgUctNode& node, 00301 bool warnTruncate, double maxTime) const 00302 { 00303 SG_ASSERT(Contains(node)); 00304 SG_ASSERT(&target != this); 00305 SG_ASSERT(target.MaxNodes() == MaxNodes()); 00306 target.Clear(); 00307 size_t allocatorId = 0; 00308 SgTimer timer; 00309 bool abort = false; 00310 CopySubtree(target, target.m_root, node, 0, allocatorId, warnTruncate, 00311 abort, timer, maxTime); 00312 } 00313 00314 void SgUctTree::MergeChildren(std::size_t allocatorId, const SgUctNode& node, 00315 const std::vector<SgMoveInfo>& moves, 00316 bool deleteChildTrees) 00317 { 00318 SG_ASSERT(Contains(node)); 00319 // Parameters are const-references, because only the tree is allowed 00320 // to modify nodes 00321 SgUctNode& nonConstNode = const_cast<SgUctNode&>(node); 00322 size_t nuNewChildren = moves.size(); 00323 00324 if (nuNewChildren == 0) 00325 { 00326 // Write order dependency 00327 nonConstNode.SetNuChildren(0); 00328 SgSynchronizeThreadMemory(); 00329 nonConstNode.SetFirstChild(0); 00330 return; 00331 } 00332 00333 SgUctAllocator& allocator = Allocator(allocatorId); 00334 SG_ASSERT(allocator.HasCapacity(nuNewChildren)); 00335 00336 const SgUctNode* newFirstChild = allocator.Finish(); 00337 std::size_t parentCount = allocator.Create(moves); 00338 00339 // Update new children with data in old children 00340 for (std::size_t i = 0; i < moves.size(); ++i) 00341 { 00342 SgUctNode* newChild = const_cast<SgUctNode*>(&newFirstChild[i]); 00343 for (SgUctChildIterator it(*this, node); it; ++it) 00344 { 00345 const SgUctNode& oldChild = *it; 00346 if (oldChild.Move() == moves[i].m_move) 00347 { 00348 newChild->MergeResults(oldChild); 00349 newChild->SetKnowledgeCount(oldChild.KnowledgeCount()); 00350 if (! deleteChildTrees) 00351 { 00352 newChild->SetPosCount(oldChild.PosCount()); 00353 parentCount += oldChild.MoveCount(); 00354 if (oldChild.HasChildren()) 00355 { 00356 newChild->SetFirstChild(oldChild.FirstChild()); 00357 newChild->SetNuChildren(oldChild.NuChildren()); 00358 } 00359 } 00360 break; 00361 } 00362 } 00363 } 00364 nonConstNode.SetPosCount(parentCount); 00365 00366 // Write order dependency: We do not want an SgUctChildIterator to 00367 // run past the end of a node's children, which can happen if one 00368 // is created between the two statements below. We modify node in 00369 // such a way so as to avoid that. 00370 SgSynchronizeThreadMemory(); 00371 if (nonConstNode.NuChildren() < (int)nuNewChildren) 00372 { 00373 nonConstNode.SetFirstChild(newFirstChild); 00374 SgSynchronizeThreadMemory(); 00375 nonConstNode.SetNuChildren(nuNewChildren); 00376 } 00377 else 00378 { 00379 nonConstNode.SetNuChildren(nuNewChildren); 00380 SgSynchronizeThreadMemory(); 00381 nonConstNode.SetFirstChild(newFirstChild); 00382 } 00383 } 00384 00385 std::size_t SgUctTree::NuNodes() const 00386 { 00387 size_t nuNodes = 1; // Count root node 00388 for (size_t i = 0; i < NuAllocators(); ++i) 00389 nuNodes += Allocator(i).NuNodes(); 00390 return nuNodes; 00391 } 00392 00393 void SgUctTree::RemoveVirtualLoss(const std::vector<const SgUctNode*>& nodes) 00394 { 00395 for (size_t i = 0; i < nodes.size(); ++i) 00396 { 00397 const SgUctNode* father = (i > 0 ? nodes[i-1] : 0); 00398 RemoveGameResult(*nodes[i], father, 1.0); // see AddVirtualLoss() 00399 RemoveRaveValue(*nodes[i], 0.0, 1.0); 00400 } 00401 } 00402 00403 void SgUctTree::SetMaxNodes(std::size_t maxNodes) 00404 { 00405 Clear(); 00406 size_t nuAllocators = NuAllocators(); 00407 if (nuAllocators == 0) 00408 { 00409 SgDebug() << "SgUctTree::SetMaxNodes: no allocators registered\n"; 00410 SG_ASSERT(false); 00411 return; 00412 } 00413 m_maxNodes = maxNodes; 00414 size_t maxNodesPerAlloc = maxNodes / nuAllocators; 00415 for (size_t i = 0; i < NuAllocators(); ++i) 00416 Allocator(i).SetMaxNodes(maxNodesPerAlloc); 00417 } 00418 00419 void SgUctTree::Swap(SgUctTree& tree) 00420 { 00421 SG_ASSERT(MaxNodes() == tree.MaxNodes()); 00422 SG_ASSERT(NuAllocators() == tree.NuAllocators()); 00423 swap(m_root, tree.m_root); 00424 for (size_t i = 0; i < NuAllocators(); ++i) 00425 Allocator(i).Swap(tree.Allocator(i)); 00426 } 00427 00428 void SgUctTree::ThrowConsistencyError(const string& message) const 00429 { 00430 DumpDebugInfo(SgDebug()); 00431 throw SgException("SgUctTree::ThrowConsistencyError: " + message); 00432 } 00433 00434 //---------------------------------------------------------------------------- 00435 00436 SgUctTreeIterator::SgUctTreeIterator(const SgUctTree& tree) 00437 : m_tree(tree), 00438 m_current(&tree.Root()) 00439 { 00440 } 00441 00442 const SgUctNode& SgUctTreeIterator::operator*() const 00443 { 00444 return *m_current; 00445 } 00446 00447 void SgUctTreeIterator::operator++() 00448 { 00449 if (m_current->HasChildren()) 00450 { 00451 SgUctChildIterator* it = new SgUctChildIterator(m_tree, *m_current); 00452 m_stack.push(shared_ptr<SgUctChildIterator>(it)); 00453 m_current = &(**it); 00454 return; 00455 } 00456 while (! m_stack.empty()) 00457 { 00458 SgUctChildIterator& it = *m_stack.top(); 00459 SG_ASSERT(it); 00460 ++it; 00461 if (it) 00462 { 00463 m_current = &(*it); 00464 return; 00465 } 00466 else 00467 { 00468 m_stack.pop(); 00469 m_current = 0; 00470 } 00471 } 00472 m_current = 0; 00473 } 00474 00475 SgUctTreeIterator::operator bool() const 00476 { 00477 return (m_current != 0); 00478 } 00479 00480 //----------------------------------------------------------------------------