00001 //---------------------------------------------------------------------------- 00002 /** @file SgUctTree.h 00003 Class SgUctTree and strongly related classes. 00004 */ 00005 //---------------------------------------------------------------------------- 00006 00007 #ifndef SG_UCTTREE_H 00008 #define SG_UCTTREE_H 00009 00010 #include <stack> 00011 #include <boost/shared_ptr.hpp> 00012 #include "SgMove.h" 00013 #include "SgStatistics.h" 00014 00015 class SgTimer; 00016 00017 //---------------------------------------------------------------------------- 00018 00019 typedef SgStatisticsBase<float,std::size_t> SgUctStatisticsBase; 00020 00021 typedef SgStatisticsBase<volatile float,volatile std::size_t> 00022 SgUctStatisticsBaseVolatile; 00023 00024 //---------------------------------------------------------------------------- 00025 00026 /** Used for node creation. */ 00027 struct SgMoveInfo 00028 { 00029 /** Move for the child. */ 00030 SgMove m_move; 00031 00032 /** Value of node after node is created. 00033 Value is from child's perspective, so the value stored here 00034 must be the inverse of the evaluation from the parent's 00035 perspective. 00036 */ 00037 float m_value; 00038 00039 /** Count of node after node is created. */ 00040 std::size_t m_count; 00041 00042 /** Rave value of move after node is created from viewpoint of 00043 parent node. 00044 Value should not be inverted to child's perspective. 00045 */ 00046 float m_raveValue; 00047 00048 /** Rave count of move after node is created. */ 00049 float m_raveCount; 00050 00051 SgMoveInfo(); 00052 00053 SgMoveInfo(SgMove move); 00054 00055 SgMoveInfo(SgMove move, float value, std::size_t count, 00056 float raveValue, float raveCount); 00057 }; 00058 00059 inline SgMoveInfo::SgMoveInfo() 00060 : m_value(0.0), 00061 m_count(0), 00062 m_raveValue(0.0), 00063 m_raveCount(0.0) 00064 { 00065 } 00066 00067 inline SgMoveInfo::SgMoveInfo(SgMove move) 00068 : m_move(move), 00069 m_value(0.0), 00070 m_count(0), 00071 m_raveValue(0.0), 00072 m_raveCount(0.0) 00073 { 00074 } 00075 00076 inline SgMoveInfo::SgMoveInfo(SgMove move, float value, std::size_t count, 00077 float raveValue, float raveCount) 00078 : m_move(move), 00079 m_value(value), 00080 m_count(count), 00081 m_raveValue(raveValue), 00082 m_raveCount(raveCount) 00083 { 00084 } 00085 00086 //---------------------------------------------------------------------------- 00087 00088 /** Types of proven nodes. */ 00089 typedef enum 00090 { 00091 /** Node is not a proven win or loss. */ 00092 SG_NOT_PROVEN, 00093 00094 /** Node is a proven win. */ 00095 SG_PROVEN_WIN, 00096 00097 /** Node is a proven loss. */ 00098 SG_PROVEN_LOSS 00099 00100 } SgProvenNodeType; 00101 00102 //---------------------------------------------------------------------------- 00103 00104 /** Node used in SgUctTree. 00105 All data members are declared as volatile to avoid that the compiler 00106 re-orders writes, which can break assumptions made by SgUctSearch in 00107 lock-free mode (see @ref sguctsearchlockfree). For example, the search 00108 relies on the fact that m_firstChild is valid, if m_nuChildren is greater 00109 zero or that the mean value of the move and RAVE value statistics is valid 00110 if the corresponding count is greater zero. 00111 @ingroup sguctgroup 00112 */ 00113 class SgUctNode 00114 { 00115 public: 00116 /** Initializes node with given move, value and count. */ 00117 SgUctNode(const SgMoveInfo& info); 00118 00119 /** Add game result. 00120 @param eval The game result (e.g. score or 0/1 for win loss) 00121 */ 00122 void AddGameResult(float eval); 00123 00124 /** Adds a game result count times. */ 00125 void AddGameResults(float eval, std::size_t count); 00126 00127 /** Add other nodes results to this node's. */ 00128 void MergeResults(const SgUctNode& node); 00129 00130 /** Removes a game result. 00131 @param eval The game result (e.g. score or 0/1 for win loss) 00132 */ 00133 void RemoveGameResult(float eval); 00134 00135 /** Removes a game result count times. */ 00136 void RemoveGameResults(float eval, std::size_t count); 00137 00138 /** Number of times this node was visited. 00139 This corresponds to the sum of MoveCount() of all children. 00140 It can be different from MoveCount() of this position, if prior 00141 knowledge initialization of the children is used. 00142 */ 00143 std::size_t PosCount() const; 00144 00145 /** Number of times the move leading to this position was chosen. 00146 This count will be different from PosCount(), if prior knowledge 00147 initialization is used. 00148 */ 00149 std::size_t MoveCount() const; 00150 00151 /** Get first child. 00152 @note This information is an implementation detail of how SgUctTree 00153 manages nodes. Use SgUctChildIterator to access children nodes. 00154 */ 00155 const SgUctNode* FirstChild() const; 00156 00157 /** Does the node have at least one child? */ 00158 bool HasChildren() const; 00159 00160 /** Average game result. 00161 Requires: HasMean() 00162 */ 00163 float Mean() const; 00164 00165 /** True, if mean value is defined (move count not zero) */ 00166 bool HasMean() const; 00167 00168 /** Get number of children. 00169 @note This information is an implementation detail of how SgUctTree 00170 manages nodes. Use SgUctChildIterator to access children nodes. 00171 */ 00172 int NuChildren() const; 00173 00174 /** See FirstChild() */ 00175 void SetFirstChild(const SgUctNode* child); 00176 00177 /** See NuChildren() */ 00178 void SetNuChildren(int nuChildren); 00179 00180 /** Increment the position count. 00181 See PosCount() 00182 */ 00183 void IncPosCount(); 00184 00185 /** Decrement the position count. 00186 See PosCount() 00187 */ 00188 void DecPosCount(); 00189 00190 void SetPosCount(std::size_t value); 00191 00192 /** Initialize value with prior knowledge. */ 00193 void InitializeValue(float value, std::size_t count); 00194 00195 /** Copy data from other node. 00196 Copies all data, apart from the children information (first child 00197 and number of children). 00198 */ 00199 void CopyDataFrom(const SgUctNode& node); 00200 00201 /** Get move. 00202 Requires: Node has a move (is not root node) 00203 */ 00204 SgMove Move() const; 00205 00206 /** Get RAVE count. 00207 @see SgUctSearch::Rave(). 00208 */ 00209 float RaveCount() const; 00210 00211 /** Get RAVE mean value. 00212 Requires: HasRaveValue() 00213 @see SgUctSearch::Rave(). 00214 */ 00215 float RaveValue() const; 00216 00217 bool HasRaveValue() const; 00218 00219 /** Add a game result value to the RAVE value. 00220 @see SgUctSearch::Rave(). 00221 */ 00222 void AddRaveValue(float value, float weight); 00223 00224 /** Removes a rave result. */ 00225 void RemoveRaveValue(float value); 00226 00227 void RemoveRaveValue(float value, float weight); 00228 00229 /** Initialize RAVE value with prior knowledge. */ 00230 void InitializeRaveValue(float value, float count); 00231 00232 /** Returns the last time knowledge was computed. */ 00233 std::size_t KnowledgeCount() const; 00234 00235 /** Set that knowledge has been computed at count. */ 00236 void SetKnowledgeCount(std::size_t count); 00237 00238 /** Returns true if node is a proven node. */ 00239 bool IsProven() const; 00240 00241 bool IsProvenWin() const; 00242 00243 bool IsProvenLoss() const; 00244 00245 SgProvenNodeType ProvenNodeType() const; 00246 00247 void SetProvenNodeType(SgProvenNodeType type); 00248 00249 private: 00250 SgStatisticsBase<volatile float,volatile std::size_t> m_statistics; 00251 00252 const SgUctNode* volatile m_firstChild; 00253 00254 volatile int m_nuChildren; 00255 00256 volatile SgMove m_move; 00257 00258 /** RAVE statistics. 00259 Uses double for count to allow adding fractional values if RAVE 00260 updates are weighted. 00261 */ 00262 SgStatisticsBase<volatile float,volatile float> m_raveValue; 00263 00264 volatile std::size_t m_posCount; 00265 00266 volatile std::size_t m_knowledgeCount; 00267 00268 volatile SgProvenNodeType m_provenType; 00269 }; 00270 00271 inline SgUctNode::SgUctNode(const SgMoveInfo& info) 00272 : m_statistics(info.m_value, info.m_count), 00273 m_nuChildren(0), 00274 m_move(info.m_move), 00275 m_raveValue(info.m_raveValue, info.m_raveCount), 00276 m_posCount(0), 00277 m_knowledgeCount(0), 00278 m_provenType(SG_NOT_PROVEN) 00279 { 00280 // m_firstChild is not initialized, only defined if m_nuChildren > 0 00281 } 00282 00283 inline void SgUctNode::AddGameResult(float eval) 00284 { 00285 m_statistics.Add(eval); 00286 } 00287 00288 inline void SgUctNode::AddGameResults(float eval, std::size_t count) 00289 { 00290 m_statistics.Add(eval, count); 00291 } 00292 00293 inline void SgUctNode::MergeResults(const SgUctNode& node) 00294 { 00295 if (node.m_statistics.IsDefined()) 00296 m_statistics.Add(node.m_statistics.Mean(), node.m_statistics.Count()); 00297 if (node.m_raveValue.IsDefined()) 00298 m_raveValue.Add(node.m_raveValue.Mean(), node.m_raveValue.Count()); 00299 } 00300 00301 inline void SgUctNode::RemoveGameResult(float eval) 00302 { 00303 m_statistics.Remove(eval); 00304 } 00305 00306 inline void SgUctNode::RemoveGameResults(float eval, std::size_t count) 00307 { 00308 m_statistics.Remove(eval, count); 00309 } 00310 00311 inline void SgUctNode::AddRaveValue(float value, float weight) 00312 { 00313 m_raveValue.Add(value, weight); 00314 } 00315 00316 inline void SgUctNode::RemoveRaveValue(float value) 00317 { 00318 m_raveValue.Remove(value); 00319 } 00320 00321 inline void SgUctNode::RemoveRaveValue(float value, float weight) 00322 { 00323 m_raveValue.Remove(value, weight); 00324 } 00325 00326 inline void SgUctNode::CopyDataFrom(const SgUctNode& node) 00327 { 00328 m_statistics = node.m_statistics; 00329 m_move = node.m_move; 00330 m_raveValue = node.m_raveValue; 00331 m_posCount = node.m_posCount; 00332 m_knowledgeCount = node.m_knowledgeCount; 00333 m_provenType = node.m_provenType; 00334 } 00335 00336 inline const SgUctNode* SgUctNode::FirstChild() const 00337 { 00338 SG_ASSERT(HasChildren()); // Otherwise m_firstChild is undefined 00339 return m_firstChild; 00340 } 00341 00342 inline bool SgUctNode::HasChildren() const 00343 { 00344 return (m_nuChildren > 0); 00345 } 00346 00347 inline bool SgUctNode::HasMean() const 00348 { 00349 return m_statistics.IsDefined(); 00350 } 00351 00352 inline bool SgUctNode::HasRaveValue() const 00353 { 00354 return m_raveValue.IsDefined(); 00355 } 00356 00357 inline void SgUctNode::IncPosCount() 00358 { 00359 ++m_posCount; 00360 } 00361 00362 inline void SgUctNode::DecPosCount() 00363 { 00364 --m_posCount; 00365 } 00366 00367 inline void SgUctNode::InitializeValue(float value, std::size_t count) 00368 { 00369 m_statistics.Initialize(value, count); 00370 } 00371 00372 inline void SgUctNode::InitializeRaveValue(float value, float count) 00373 { 00374 m_raveValue.Initialize(value, count); 00375 } 00376 00377 inline float SgUctNode::Mean() const 00378 { 00379 return m_statistics.Mean(); 00380 } 00381 00382 inline SgMove SgUctNode::Move() const 00383 { 00384 SG_ASSERT(m_move != SG_NULLMOVE); 00385 return m_move; 00386 } 00387 00388 inline std::size_t SgUctNode::MoveCount() const 00389 { 00390 return m_statistics.Count(); 00391 } 00392 00393 inline int SgUctNode::NuChildren() const 00394 { 00395 return m_nuChildren; 00396 } 00397 00398 inline std::size_t SgUctNode::PosCount() const 00399 { 00400 return m_posCount; 00401 } 00402 00403 inline float SgUctNode::RaveCount() const 00404 { 00405 return m_raveValue.Count(); 00406 } 00407 00408 inline float SgUctNode::RaveValue() const 00409 { 00410 return m_raveValue.Mean(); 00411 } 00412 00413 inline void SgUctNode::SetFirstChild(const SgUctNode* child) 00414 { 00415 m_firstChild = child; 00416 } 00417 00418 inline void SgUctNode::SetNuChildren(int nuChildren) 00419 { 00420 SG_ASSERT(nuChildren >= 0); 00421 m_nuChildren = nuChildren; 00422 } 00423 00424 inline void SgUctNode::SetPosCount(std::size_t value) 00425 { 00426 m_posCount = value; 00427 } 00428 00429 inline std::size_t SgUctNode::KnowledgeCount() const 00430 { 00431 return m_knowledgeCount; 00432 } 00433 00434 inline void SgUctNode::SetKnowledgeCount(std::size_t count) 00435 { 00436 m_knowledgeCount = count; 00437 } 00438 00439 inline bool SgUctNode::IsProven() const 00440 { 00441 return m_provenType != SG_NOT_PROVEN; 00442 } 00443 00444 inline bool SgUctNode::IsProvenWin() const 00445 { 00446 return m_provenType == SG_PROVEN_WIN; 00447 } 00448 00449 inline bool SgUctNode::IsProvenLoss() const 00450 { 00451 return m_provenType == SG_PROVEN_LOSS; 00452 } 00453 00454 inline SgProvenNodeType SgUctNode::ProvenNodeType() const 00455 { 00456 return m_provenType; 00457 } 00458 00459 inline void SgUctNode::SetProvenNodeType(SgProvenNodeType type) 00460 { 00461 m_provenType = type; 00462 } 00463 00464 //---------------------------------------------------------------------------- 00465 00466 /** Allocater for nodes used in the implementation of SgUctTree. 00467 Each thread has its own node allocator to allow lock-free usage of 00468 SgUctTree. 00469 @ingroup sguctgroup 00470 */ 00471 class SgUctAllocator 00472 { 00473 public: 00474 SgUctAllocator(); 00475 00476 ~SgUctAllocator(); 00477 00478 void Clear(); 00479 00480 /** Does the allocator have the capacity for n more nodes? */ 00481 bool HasCapacity(std::size_t n) const; 00482 00483 std::size_t NuNodes() const; 00484 00485 std::size_t MaxNodes() const; 00486 00487 void SetMaxNodes(std::size_t maxNodes); 00488 00489 /** Check if allocator contains node. 00490 This function uses pointer comparisons. Since the result of 00491 comparisons for pointers to elements in different containers 00492 is platform-dependent, it is only guaranteed that it returns true, 00493 if not node belongs to the allocator, but not that it returns false 00494 for nodes not in the allocator. 00495 */ 00496 bool Contains(const SgUctNode& node) const; 00497 00498 const SgUctNode* Start() const; 00499 00500 SgUctNode* Finish(); 00501 00502 const SgUctNode* Finish() const; 00503 00504 /** Create a new node at the end of the storage. 00505 REQUIRES: HasCapacity(1) 00506 @param move The constructor argument. 00507 @return A pointer to new newly created node. 00508 */ 00509 SgUctNode* CreateOne(SgMove move); 00510 00511 /** Create a number of new nodes with a given list of moves at the end of 00512 the storage. Returns the sum of counts of moves. 00513 REQUIRES: HasCapacity(moves.size()) 00514 @param moves The list of moves. 00515 */ 00516 std::size_t Create(const std::vector<SgMoveInfo>& moves); 00517 00518 /** Create a number of new nodes at the end of the storage. 00519 REQUIRES: HasCapacity(n) 00520 @param n The number of nodes to create. 00521 */ 00522 void CreateN(std::size_t n); 00523 00524 void Swap(SgUctAllocator& allocator); 00525 00526 private: 00527 SgUctNode* m_start; 00528 00529 SgUctNode* m_finish; 00530 00531 SgUctNode* m_endOfStorage; 00532 00533 /** Not implemented. 00534 Cannot be copied because array contains pointers to elements. 00535 Use Swap() instead. 00536 */ 00537 SgUctAllocator& operator=(const SgUctAllocator& tree); 00538 }; 00539 00540 inline SgUctAllocator::SgUctAllocator() 00541 { 00542 m_start = 0; 00543 } 00544 00545 inline void SgUctAllocator::Clear() 00546 { 00547 if (m_start != 0) 00548 { 00549 for (SgUctNode* it = m_start; it != m_finish; ++it) 00550 it->~SgUctNode(); 00551 m_finish = m_start; 00552 } 00553 } 00554 00555 inline SgUctNode* SgUctAllocator::CreateOne(SgMove move) 00556 { 00557 SG_ASSERT(HasCapacity(1)); 00558 new(m_finish) SgUctNode(move); 00559 return (m_finish++); 00560 } 00561 00562 inline std::size_t SgUctAllocator::Create( 00563 const std::vector<SgMoveInfo>& moves) 00564 { 00565 SG_ASSERT(HasCapacity(moves.size())); 00566 std::size_t count = 0; 00567 for (std::vector<SgMoveInfo>::const_iterator it = moves.begin(); 00568 it != moves.end(); ++it, ++m_finish) 00569 { 00570 new(m_finish) SgUctNode(*it); 00571 count += it->m_count; 00572 } 00573 return count; 00574 } 00575 00576 inline void SgUctAllocator::CreateN(std::size_t n) 00577 { 00578 SG_ASSERT(HasCapacity(n)); 00579 SgUctNode* newFinish = m_finish + n; 00580 for ( ; m_finish != newFinish; ++m_finish) 00581 new(m_finish) SgUctNode(SG_NULLMOVE); 00582 } 00583 00584 inline SgUctNode* SgUctAllocator::Finish() 00585 { 00586 return m_finish; 00587 } 00588 00589 inline const SgUctNode* SgUctAllocator::Finish() const 00590 { 00591 return m_finish; 00592 } 00593 00594 inline bool SgUctAllocator::HasCapacity(std::size_t n) const 00595 { 00596 return (m_finish + n <= m_endOfStorage); 00597 } 00598 00599 inline std::size_t SgUctAllocator::MaxNodes() const 00600 { 00601 return m_endOfStorage - m_start; 00602 } 00603 00604 inline std::size_t SgUctAllocator::NuNodes() const 00605 { 00606 return m_finish - m_start; 00607 } 00608 00609 inline const SgUctNode* SgUctAllocator::Start() const 00610 { 00611 return m_start; 00612 } 00613 00614 //---------------------------------------------------------------------------- 00615 00616 /** Tree used in SgUctSearch. 00617 The nodes can be accessed only by getting non-const references or modified 00618 through accessor functions of SgUctTree, therefore SgUctTree can guarantee 00619 the integrity of the tree structure. 00620 The tree can be used in a lock-free way during a search (see 00621 @ref sguctsearchlockfree). 00622 @ingroup sguctgroup 00623 */ 00624 class SgUctTree 00625 { 00626 public: 00627 friend class SgUctChildIterator; 00628 00629 /** Constructor. 00630 Construct a tree. Before using the tree, CreateAllocators() and 00631 SetMaxNodes() must be called (in this order). 00632 */ 00633 SgUctTree(); 00634 00635 /** Create node allocators for threads. */ 00636 void CreateAllocators(std::size_t nuThreads); 00637 00638 /** Add a game result. 00639 @param node The node. 00640 @param father The father (if not root) to update the position count. 00641 @param eval 00642 */ 00643 void AddGameResult(const SgUctNode& node, const SgUctNode* father, 00644 float eval); 00645 00646 /** Adds a game result count times. */ 00647 void AddGameResults(const SgUctNode& node, const SgUctNode* father, 00648 float eval, std::size_t count); 00649 00650 /** Removes a game result. 00651 @param node The node. 00652 @param father The father (if not root) to update the position count. 00653 @param eval 00654 */ 00655 void RemoveGameResult(const SgUctNode& node, const SgUctNode* father, 00656 float eval); 00657 00658 /** Removes a game result count times. */ 00659 void RemoveGameResults(const SgUctNode& node, const SgUctNode* father, 00660 float eval, std::size_t count); 00661 00662 /** Adds a virtual loss to the given nodes. */ 00663 void AddVirtualLoss(const std::vector<const SgUctNode*>& nodes); 00664 00665 /** Removes a virtual loss to the given nodes. */ 00666 void RemoveVirtualLoss(const std::vector<const SgUctNode*>& nodes); 00667 00668 void Clear(); 00669 00670 /** Return the current maximum number of nodes. 00671 This returns the maximum number of nodes as set by SetMaxNodes(). 00672 See SetMaxNodes() why the real maximum number of nodes can be higher 00673 or lower. 00674 */ 00675 std::size_t MaxNodes() const; 00676 00677 /** Change maximum number of nodes. 00678 Also clears the tree. This will call SetMaxNodes() at each registered 00679 allocator with maxNodes / numberAllocators as an argument. The real 00680 maximum number of nodes can be higher (because the root node is 00681 owned by this class, not an allocator) or lower (if maxNodes is not 00682 a multiple of the number of allocators). 00683 @param maxNodes Maximum number of nodes 00684 */ 00685 void SetMaxNodes(std::size_t maxNodes); 00686 00687 /** Swap content with another tree. 00688 The other tree must have the same number of allocators and 00689 the same maximum number of nodes. 00690 */ 00691 void Swap(SgUctTree& tree); 00692 00693 bool HasCapacity(std::size_t allocatorId, std::size_t n) const; 00694 00695 /** Create children nodes. 00696 Requires: Allocator(allocatorId).HasCapacity(moves.size()) 00697 */ 00698 void CreateChildren(std::size_t allocatorId, const SgUctNode& node, 00699 const std::vector<SgMoveInfo>& moves); 00700 00701 /** Merge new children with old. 00702 Requires: Allocator(allocatorId).HasCapacity(moves.size()) 00703 */ 00704 void MergeChildren(std::size_t allocatorId, const SgUctNode& node, 00705 const std::vector<SgMoveInfo>& moves, 00706 bool deleteChildTrees); 00707 00708 /** Extract subtree to a different tree. 00709 The tree will be truncated if one of the allocators overflows (can 00710 happen due to reassigning nodes to different allocators), the given 00711 max time is exceeded or on SgUserAbort(). 00712 @param[out] target The resulting subtree. Must have the same maximum 00713 number of nodes. Will be cleared before using. 00714 @param node The start node of the subtree. 00715 @param warnTruncate Print warning to SgDebug() if tree was truncated 00716 @param maxTime Truncate the tree, if the extraction takes longer than 00717 the given time 00718 */ 00719 void ExtractSubtree(SgUctTree& target, const SgUctNode& node, 00720 bool warnTruncate, 00721 double maxTime = std::numeric_limits<double>::max()) const; 00722 00723 /** Get a copy of the tree with low count nodes pruned. 00724 The tree will be truncated if one of the allocators overflows (can 00725 happen due to reassigning nodes to different allocators), the given 00726 max time is exceeded or on SgUserAbort(). 00727 @param[out] target The resulting tree. Must have the same maximum 00728 number of nodes. Will be cleared before using. 00729 @param minCount The minimum count (SgUctNode::MoveCount()) 00730 @param warnTruncate Print warning to SgDebug() if tree was truncated 00731 @param maxTime Truncate the tree, if the extraction takes longer than 00732 the given time 00733 */ 00734 void CopyPruneLowCount(SgUctTree& target, std::size_t minCount, 00735 bool warnTruncate, 00736 double maxTime = std::numeric_limits<double>::max()) const; 00737 00738 const SgUctNode& Root() const; 00739 00740 std::size_t NuAllocators() const; 00741 00742 /** Total number of nodes. 00743 Includes the sum of nodes in all allocators plus the root node. 00744 */ 00745 std::size_t NuNodes() const; 00746 00747 /** Number of nodes in one of the allocators. */ 00748 std::size_t NuNodes(std::size_t allocatorId) const; 00749 00750 /** Add a game result value to the RAVE value of a node. 00751 @param node The node with the move 00752 @param value 00753 @param weight 00754 @see SgUctSearch::Rave(). 00755 */ 00756 void AddRaveValue(const SgUctNode& node, float value, float weight); 00757 00758 /** Remove a game result from the RAVE value of a node. 00759 @param node The node with the move 00760 @param value 00761 @param weight 00762 @see SgUctSearch::Rave(). 00763 */ 00764 void RemoveRaveValue(const SgUctNode& node, float value, float weight); 00765 00766 /** Initialize the value and count of a node. */ 00767 void InitializeValue(const SgUctNode& node, float value, 00768 std::size_t count); 00769 00770 void SetPosCount(const SgUctNode& node, std::size_t posCount); 00771 00772 /** Initialize the rave value and count of a move node with prior 00773 knowledge. 00774 */ 00775 void InitializeRaveValue(const SgUctNode& node, float value, float count); 00776 00777 /** Remove some children of a node according to a list of filtered moves. 00778 Requires: Allocator(allocatorId).HasCapacity(node.NuChildren()) <br> 00779 For efficiency, no reorganization of the tree is done to remove 00780 the dead subtrees (and NuNodes() will not report the real number of 00781 nodes in the tree). This function can be used in lock-free mode. 00782 */ 00783 void ApplyFilter(std::size_t allocatorId, const SgUctNode& node, 00784 const std::vector<SgMove>& rootFilter); 00785 00786 /** Sets the children under node to be exactly those in moves, 00787 reusing the old children if possible. Children not in moves 00788 are pruned, children missing from moves are added as leaves. 00789 Requires: Allocator(allocatorId).HasCapacity(moves.size()) 00790 */ 00791 void SetChildren(std::size_t allocatorId, const SgUctNode& node, 00792 const vector<SgMove>& moves); 00793 00794 /** @name Functions for debugging */ 00795 // @{ 00796 00797 /** Do some consistency checks. 00798 @throws SgException if inconsistencies are detected. 00799 */ 00800 void CheckConsistency() const; 00801 00802 /** Check if tree contains node. 00803 This function uses pointer comparisons. Since the result of 00804 comparisons for pointers to elements in different containers 00805 is platform-dependent, it is only guaranteed that it returns true, 00806 if not node belongs to the allocator, but not that it returns false 00807 for nodes not in the tree. 00808 */ 00809 bool Contains(const SgUctNode& node) const; 00810 00811 void DumpDebugInfo(std::ostream& out) const; 00812 00813 // @} // @name 00814 00815 private: 00816 std::size_t m_maxNodes; 00817 00818 SgUctNode m_root; 00819 00820 /** Allocators. 00821 The elements are owned by the vector (shared_ptr is only used because 00822 auto_ptr should not be used with standard containers) 00823 */ 00824 std::vector<boost::shared_ptr<SgUctAllocator> > m_allocators; 00825 00826 /** Not implemented. 00827 Cannot be copied because allocators contain pointers to elements. 00828 Use SgUctTree::Swap instead. 00829 */ 00830 SgUctTree& operator=(const SgUctTree& tree); 00831 00832 SgUctAllocator& Allocator(std::size_t i); 00833 00834 const SgUctAllocator& Allocator(std::size_t i) const; 00835 00836 void CopySubtree(SgUctTree& target, SgUctNode& targetNode, 00837 const SgUctNode& node, std::size_t minCount, 00838 std::size_t& currentAllocatorId, bool warnTruncate, 00839 bool& abort, SgTimer& timer, double maxTime) const; 00840 00841 void ThrowConsistencyError(const std::string& message) const; 00842 }; 00843 00844 inline void SgUctTree::AddGameResult(const SgUctNode& node, 00845 const SgUctNode* father, float eval) 00846 { 00847 SG_ASSERT(Contains(node)); 00848 // Parameters are const-references, because only the tree is allowed 00849 // to modify nodes 00850 if (father != 0) 00851 const_cast<SgUctNode*>(father)->IncPosCount(); 00852 const_cast<SgUctNode&>(node).AddGameResult(eval); 00853 } 00854 00855 inline void SgUctTree::AddGameResults(const SgUctNode& node, 00856 const SgUctNode* father, float eval, 00857 std::size_t count) 00858 { 00859 00860 SG_ASSERT(Contains(node)); 00861 // Parameters are const-references, because only the tree is allowed 00862 // to modify nodes 00863 if (father != 0) 00864 const_cast<SgUctNode*>(father)->SetPosCount(father->PosCount() 00865 + count); 00866 const_cast<SgUctNode&>(node).AddGameResults(eval, count); 00867 } 00868 00869 inline void SgUctTree::CreateChildren(std::size_t allocatorId, 00870 const SgUctNode& node, 00871 const std::vector<SgMoveInfo>& moves) 00872 { 00873 SG_ASSERT(Contains(node)); 00874 // Parameters are const-references, because only the tree is allowed 00875 // to modify nodes 00876 SgUctNode& nonConstNode = const_cast<SgUctNode&>(node); 00877 size_t nuChildren = moves.size(); 00878 SG_ASSERT(nuChildren > 0); 00879 SgUctAllocator& allocator = Allocator(allocatorId); 00880 SG_ASSERT(allocator.HasCapacity(nuChildren)); 00881 00882 // In lock-free multi-threading, a node can be expanded multiple times 00883 // (the later thread overwrites the children information of the previous 00884 // thread) 00885 SG_ASSERT(NuAllocators() > 1 || ! node.HasChildren()); 00886 00887 const SgUctNode* firstChild = allocator.Finish(); 00888 00889 std::size_t parentCount = allocator.Create(moves); 00890 00891 // Write order dependency: SgUctSearch in lock-free mode assumes that 00892 // m_firstChild is valid if m_nuChildren is greater zero 00893 nonConstNode.SetPosCount(parentCount); 00894 SgSynchronizeThreadMemory(); 00895 nonConstNode.SetFirstChild(firstChild); 00896 SgSynchronizeThreadMemory(); 00897 nonConstNode.SetNuChildren(nuChildren); 00898 } 00899 00900 inline void SgUctTree::RemoveGameResult(const SgUctNode& node, 00901 const SgUctNode* father, float eval) 00902 { 00903 SG_ASSERT(Contains(node)); 00904 // Parameters are const-references, because only the tree is allowed 00905 // to modify nodes 00906 if (father != 0) 00907 const_cast<SgUctNode*>(father)->DecPosCount(); 00908 const_cast<SgUctNode&>(node).RemoveGameResult(eval); 00909 } 00910 00911 inline void SgUctTree::RemoveGameResults(const SgUctNode& node, 00912 const SgUctNode* father, float eval, 00913 std::size_t count) 00914 { 00915 SG_ASSERT(Contains(node)); 00916 // Parameters are const-references, because only the tree is allowed 00917 // to modify nodes 00918 if (father != 0) 00919 const_cast<SgUctNode*>(father)->SetPosCount(father->PosCount() 00920 - count); 00921 const_cast<SgUctNode&>(node).RemoveGameResults(eval, count); 00922 } 00923 00924 inline void SgUctTree::AddRaveValue(const SgUctNode& node, float value, 00925 float weight) 00926 { 00927 SG_ASSERT(Contains(node)); 00928 // Parameters are const-references, because only the tree is allowed 00929 // to modify nodes 00930 const_cast<SgUctNode&>(node).AddRaveValue(value, weight); 00931 } 00932 00933 inline void SgUctTree::RemoveRaveValue(const SgUctNode& node, float value, 00934 float weight) 00935 { 00936 SG_UNUSED(weight); 00937 SG_ASSERT(Contains(node)); 00938 // Parameters are const-references, because only the tree is allowed 00939 // to modify nodes 00940 const_cast<SgUctNode&>(node).RemoveRaveValue(value, weight); 00941 } 00942 00943 inline SgUctAllocator& SgUctTree::Allocator(std::size_t i) 00944 { 00945 SG_ASSERT(i < m_allocators.size()); 00946 return *m_allocators[i]; 00947 } 00948 00949 inline const SgUctAllocator& SgUctTree::Allocator(std::size_t i) const 00950 { 00951 SG_ASSERT(i < m_allocators.size()); 00952 return *m_allocators[i]; 00953 } 00954 00955 inline bool SgUctTree::HasCapacity(std::size_t allocatorId, 00956 std::size_t n) const 00957 { 00958 return Allocator(allocatorId).HasCapacity(n); 00959 } 00960 00961 inline void SgUctTree::InitializeValue(const SgUctNode& node, 00962 float value, std::size_t count) 00963 { 00964 SG_ASSERT(Contains(node)); 00965 // Parameter is const-reference, because only the tree is allowed 00966 // to modify nodes 00967 const_cast<SgUctNode&>(node).InitializeValue(value, count); 00968 } 00969 00970 inline void SgUctTree::InitializeRaveValue(const SgUctNode& node, 00971 float value, float count) 00972 { 00973 SG_ASSERT(Contains(node)); 00974 // Parameters are const-references, because only the tree is allowed 00975 // to modify nodes 00976 const_cast<SgUctNode&>(node).InitializeRaveValue(value, count); 00977 } 00978 00979 inline std::size_t SgUctTree::MaxNodes() const 00980 { 00981 return m_maxNodes; 00982 } 00983 00984 inline std::size_t SgUctTree::NuAllocators() const 00985 { 00986 return m_allocators.size(); 00987 } 00988 00989 inline std::size_t SgUctTree::NuNodes(std::size_t allocatorId) const 00990 { 00991 return Allocator(allocatorId).NuNodes(); 00992 } 00993 00994 inline const SgUctNode& SgUctTree::Root() const 00995 { 00996 return m_root; 00997 } 00998 00999 inline void SgUctTree::SetPosCount(const SgUctNode& node, 01000 std::size_t posCount) 01001 { 01002 SG_ASSERT(Contains(node)); 01003 // Parameters are const-references, because only the tree is allowed 01004 // to modify nodes 01005 const_cast<SgUctNode&>(node).SetPosCount(posCount); 01006 } 01007 01008 //---------------------------------------------------------------------------- 01009 01010 /** Iterator over all children of a node. 01011 It was intentionally implemented to be used only, if at least one child 01012 exists (checked with an assertion), since in many use cases, the case 01013 of no children needs to be handled specially and should be checked 01014 before doing a loop over all children. 01015 @ingroup sguctgroup 01016 */ 01017 class SgUctChildIterator 01018 { 01019 public: 01020 /** Constructor. 01021 Requires: node.HasChildren() 01022 */ 01023 SgUctChildIterator(const SgUctTree& tree, const SgUctNode& node); 01024 01025 const SgUctNode& operator*() const; 01026 01027 void operator++(); 01028 01029 operator bool() const; 01030 01031 private: 01032 const SgUctNode* m_current; 01033 01034 const SgUctNode* m_last; 01035 }; 01036 01037 inline SgUctChildIterator::SgUctChildIterator(const SgUctTree& tree, 01038 const SgUctNode& node) 01039 { 01040 SG_DEBUG_ONLY(tree); 01041 SG_ASSERT(tree.Contains(node)); 01042 SG_ASSERT(node.HasChildren()); 01043 m_current = node.FirstChild(); 01044 m_last = m_current + node.NuChildren(); 01045 } 01046 01047 inline const SgUctNode& SgUctChildIterator::operator*() const 01048 { 01049 return *m_current; 01050 } 01051 01052 inline void SgUctChildIterator::operator++() 01053 { 01054 ++m_current; 01055 } 01056 01057 inline SgUctChildIterator::operator bool() const 01058 { 01059 return (m_current < m_last); 01060 } 01061 01062 //---------------------------------------------------------------------------- 01063 01064 /** Iterator for traversing a tree depth-first. 01065 @ingroup sguctgroup 01066 */ 01067 class SgUctTreeIterator 01068 { 01069 public: 01070 SgUctTreeIterator(const SgUctTree& tree); 01071 01072 const SgUctNode& operator*() const; 01073 01074 void operator++(); 01075 01076 operator bool() const; 01077 01078 private: 01079 const SgUctTree& m_tree; 01080 01081 const SgUctNode* m_current; 01082 01083 /** Stack of child iterators. 01084 The elements are owned by the stack (shared_ptr is only used because 01085 auto_ptr should not be used with standard containers) 01086 */ 01087 std::stack<boost::shared_ptr<SgUctChildIterator> > m_stack; 01088 }; 01089 01090 //---------------------------------------------------------------------------- 01091 01092 #endif // SG_UCTTREE_H