00001
00002
00003
00004
00005
00006
#include "SgSystem.h"
#include "SgUctTree.h"

#include <functional>
#include <boost/format.hpp>
#include "SgDebug.h"
#include "SgTimer.h"
00013
00014 using namespace std;
00015 using boost::format;
00016 using boost::shared_ptr;
00017
00018
00019
00020 SgUctAllocator::~SgUctAllocator()
00021 {
00022 if (m_start != 0)
00023 {
00024 Clear();
00025 std::free(m_start);
00026 }
00027 }
00028
00029 bool SgUctAllocator::Contains(const SgUctNode& node) const
00030 {
00031 return (&node >= m_start && &node < m_finish);
00032 }
00033
00034 void SgUctAllocator::Swap(SgUctAllocator& allocator)
00035 {
00036 swap(m_start, allocator.m_start);
00037 swap(m_finish, allocator.m_finish);
00038 swap(m_endOfStorage, allocator.m_endOfStorage);
00039 }
00040
00041 void SgUctAllocator::SetMaxNodes(std::size_t maxNodes)
00042 {
00043 if (m_start != 0)
00044 {
00045 Clear();
00046 std::free(m_start);
00047 }
00048 void* ptr = std::malloc(maxNodes * sizeof(SgUctNode));
00049 if (ptr == 0)
00050 throw std::bad_alloc();
00051 m_start = static_cast<SgUctNode*>(ptr);
00052 m_finish = m_start;
00053 m_endOfStorage = m_start + maxNodes;
00054 }
00055
00056
00057
00058 SgUctTree::SgUctTree()
00059 : m_maxNodes(0),
00060 m_root(SG_NULLMOVE)
00061 {
00062 }
00063
00064 void SgUctTree::AddVirtualLoss(const std::vector<const SgUctNode*>& nodes)
00065 {
00066 for (size_t i = 0; i < nodes.size(); ++i)
00067 {
00068 const SgUctNode* father = (i > 0 ? nodes[i-1] : 0);
00069 AddGameResult(*nodes[i], father, 1.0);
00070 AddRaveValue(*nodes[i], 0.0, 1.0);
00071 }
00072 }
00073
00074 void SgUctTree::ApplyFilter(std::size_t allocatorId, const SgUctNode& node,
00075 const vector<SgMove>& rootFilter)
00076 {
00077 SG_ASSERT(Contains(node));
00078 SG_ASSERT(Allocator(allocatorId).HasCapacity(node.NuChildren()));
00079 if (! node.HasChildren())
00080 return;
00081
00082 SgUctAllocator& allocator = Allocator(allocatorId);
00083 const SgUctNode* firstChild = allocator.Finish();
00084
00085 int nuChildren = 0;
00086 for (SgUctChildIterator it(*this, node); it; ++it)
00087 {
00088 SgMove move = (*it).Move();
00089 if (find(rootFilter.begin(), rootFilter.end(), move)
00090 == rootFilter.end())
00091 {
00092 SgUctNode* child = allocator.CreateOne(move);
00093 child->CopyDataFrom(*it);
00094 int childNuChildren = (*it).NuChildren();
00095 child->SetNuChildren(childNuChildren);
00096 if (childNuChildren > 0)
00097 child->SetFirstChild((*it).FirstChild());
00098 ++nuChildren;
00099 }
00100 }
00101
00102 SgUctNode& nonConstNode = const_cast<SgUctNode&>(node);
00103
00104
00105 SgSynchronizeThreadMemory();
00106 nonConstNode.SetFirstChild(firstChild);
00107 SgSynchronizeThreadMemory();
00108 nonConstNode.SetNuChildren(nuChildren);
00109 }
00110
00111 void SgUctTree::SetChildren(std::size_t allocatorId, const SgUctNode& node,
00112 const vector<SgMove>& moves)
00113 {
00114 SG_ASSERT(Contains(node));
00115 SG_ASSERT(Allocator(allocatorId).HasCapacity(moves.size()));
00116 SG_ASSERT(node.HasChildren());
00117
00118 SgUctAllocator& allocator = Allocator(allocatorId);
00119 const SgUctNode* firstChild = allocator.Finish();
00120
00121 int nuChildren = 0;
00122 for (size_t i = 0; i < moves.size(); ++i)
00123 {
00124 bool found = false;
00125 for (SgUctChildIterator it(*this, node); it; ++it)
00126 {
00127 SgMove move = (*it).Move();
00128 if (move == moves[i])
00129 {
00130 found = true;
00131 SgUctNode* child = allocator.CreateOne(move);
00132 child->CopyDataFrom(*it);
00133 int childNuChildren = (*it).NuChildren();
00134 child->SetNuChildren(childNuChildren);
00135 if (childNuChildren > 0)
00136 child->SetFirstChild((*it).FirstChild());
00137 ++nuChildren;
00138 break;
00139 }
00140 }
00141 if (! found)
00142 {
00143 allocator.CreateOne(moves[i]);
00144 ++nuChildren;
00145 }
00146 }
00147 SG_ASSERT((size_t)nuChildren == moves.size());
00148
00149 SgUctNode& nonConstNode = const_cast<SgUctNode&>(node);
00150
00151
00152 SgSynchronizeThreadMemory();
00153 nonConstNode.SetFirstChild(firstChild);
00154 SgSynchronizeThreadMemory();
00155 nonConstNode.SetNuChildren(nuChildren);
00156 }
00157
00158 void SgUctTree::CheckConsistency() const
00159 {
00160 for (SgUctTreeIterator it(*this); it; ++it)
00161 if (! Contains(*it))
00162 ThrowConsistencyError(str(format("! Contains(%1%)") % &(*it)));
00163 }
00164
00165 void SgUctTree::Clear()
00166 {
00167 for (size_t i = 0; i < NuAllocators(); ++i)
00168 Allocator(i).Clear();
00169 m_root = SgUctNode(SG_NULLMOVE);
00170 }
00171
00172
00173
00174
00175 bool SgUctTree::Contains(const SgUctNode& node) const
00176 {
00177 if (&node == &m_root)
00178 return true;
00179 for (size_t i = 0; i < NuAllocators(); ++i)
00180 if (Allocator(i).Contains(node))
00181 return true;
00182 return false;
00183 }
00184
00185 void SgUctTree::CopyPruneLowCount(SgUctTree& target, std::size_t minCount,
00186 bool warnTruncate, double maxTime) const
00187 {
00188 size_t allocatorId = 0;
00189 SgTimer timer;
00190 bool abort = false;
00191 CopySubtree(target, target.m_root, m_root, minCount, allocatorId,
00192 warnTruncate, abort, timer, maxTime);
00193 }
00194
00195
00196
00197
00198
00199
00200
00201
00202
00203
00204
00205
00206
00207
00208
00209
00210
00211
00212 void SgUctTree::CopySubtree(SgUctTree& target, SgUctNode& targetNode,
00213 const SgUctNode& node, std::size_t minCount,
00214 std::size_t& currentAllocatorId,
00215 bool warnTruncate, bool& abort, SgTimer& timer,
00216 double maxTime) const
00217 {
00218 SG_ASSERT(Contains(node));
00219 SG_ASSERT(target.Contains(targetNode));
00220 targetNode.CopyDataFrom(node);
00221
00222 if (! node.HasChildren() || node.MoveCount() < minCount)
00223 return;
00224
00225 SgUctAllocator& targetAllocator = target.Allocator(currentAllocatorId);
00226 int nuChildren = node.NuChildren();
00227 if (! abort)
00228 {
00229 if (! targetAllocator.HasCapacity(nuChildren))
00230 {
00231
00232
00233 if (warnTruncate)
00234 SgDebug() <<
00235 "SgUctTree::CopySubtree: Truncated (allocator capacity)\n";
00236 abort = true;
00237 }
00238 if (timer.IsTimeOut(maxTime, 10000))
00239 {
00240 if (warnTruncate)
00241 SgDebug() << "SgUctTree::CopySubtree: Truncated (max time)\n";
00242 abort = true;
00243 }
00244 if (SgUserAbort())
00245 {
00246 if (warnTruncate)
00247 SgDebug() << "SgUctTree::CopySubtree: Truncated (aborted)\n";
00248 abort = true;
00249 }
00250 }
00251 if (abort)
00252 {
00253
00254
00255 targetNode.SetPosCount(0);
00256 return;
00257 }
00258
00259 SgUctNode* firstTargetChild = targetAllocator.Finish();
00260 targetNode.SetFirstChild(firstTargetChild);
00261 targetNode.SetNuChildren(nuChildren);
00262
00263
00264 targetAllocator.CreateN(nuChildren);
00265
00266
00267 SgUctNode* targetChild = firstTargetChild;
00268 for (SgUctChildIterator it(*this, node); it; ++it, ++targetChild)
00269 {
00270 const SgUctNode& child = *it;
00271 ++currentAllocatorId;
00272 if (currentAllocatorId >= target.NuAllocators())
00273 currentAllocatorId = 0;
00274 CopySubtree(target, *targetChild, child, minCount, currentAllocatorId,
00275 warnTruncate, abort, timer, maxTime);
00276 }
00277 }
00278
00279 void SgUctTree::CreateAllocators(std::size_t nuThreads)
00280 {
00281 Clear();
00282 m_allocators.clear();
00283 for (size_t i = 0; i < nuThreads; ++i)
00284 {
00285 boost::shared_ptr<SgUctAllocator> allocator(new SgUctAllocator());
00286 m_allocators.push_back(allocator);
00287 }
00288 }
00289
00290 void SgUctTree::DumpDebugInfo(std::ostream& out) const
00291 {
00292 out << "Root " << &m_root << '\n';
00293 for (size_t i = 0; i < NuAllocators(); ++i)
00294 out << "Allocator " << i
00295 << " size=" << Allocator(i).NuNodes()
00296 << " start=" << Allocator(i).Start()
00297 << " finish=" << Allocator(i).Finish() << '\n';
00298 }
00299
00300 void SgUctTree::ExtractSubtree(SgUctTree& target, const SgUctNode& node,
00301 bool warnTruncate, double maxTime) const
00302 {
00303 SG_ASSERT(Contains(node));
00304 SG_ASSERT(&target != this);
00305 SG_ASSERT(target.MaxNodes() == MaxNodes());
00306 target.Clear();
00307 size_t allocatorId = 0;
00308 SgTimer timer;
00309 bool abort = false;
00310 CopySubtree(target, target.m_root, node, 0, allocatorId, warnTruncate,
00311 abort, timer, maxTime);
00312 }
00313
00314 void SgUctTree::MergeChildren(std::size_t allocatorId, const SgUctNode& node,
00315 const std::vector<SgMoveInfo>& moves,
00316 bool deleteChildTrees)
00317 {
00318 SG_ASSERT(Contains(node));
00319
00320
00321 SgUctNode& nonConstNode = const_cast<SgUctNode&>(node);
00322 size_t nuNewChildren = moves.size();
00323
00324 if (nuNewChildren == 0)
00325 {
00326
00327 nonConstNode.SetNuChildren(0);
00328 SgSynchronizeThreadMemory();
00329 nonConstNode.SetFirstChild(0);
00330 return;
00331 }
00332
00333 SgUctAllocator& allocator = Allocator(allocatorId);
00334 SG_ASSERT(allocator.HasCapacity(nuNewChildren));
00335
00336 const SgUctNode* newFirstChild = allocator.Finish();
00337 std::size_t parentCount = allocator.Create(moves);
00338
00339
00340 for (std::size_t i = 0; i < moves.size(); ++i)
00341 {
00342 SgUctNode* newChild = const_cast<SgUctNode*>(&newFirstChild[i]);
00343 for (SgUctChildIterator it(*this, node); it; ++it)
00344 {
00345 const SgUctNode& oldChild = *it;
00346 if (oldChild.Move() == moves[i].m_move)
00347 {
00348 newChild->MergeResults(oldChild);
00349 newChild->SetKnowledgeCount(oldChild.KnowledgeCount());
00350 if (! deleteChildTrees)
00351 {
00352 newChild->SetPosCount(oldChild.PosCount());
00353 parentCount += oldChild.MoveCount();
00354 if (oldChild.HasChildren())
00355 {
00356 newChild->SetFirstChild(oldChild.FirstChild());
00357 newChild->SetNuChildren(oldChild.NuChildren());
00358 }
00359 }
00360 break;
00361 }
00362 }
00363 }
00364 nonConstNode.SetPosCount(parentCount);
00365
00366
00367
00368
00369
00370 SgSynchronizeThreadMemory();
00371 if (nonConstNode.NuChildren() < (int)nuNewChildren)
00372 {
00373 nonConstNode.SetFirstChild(newFirstChild);
00374 SgSynchronizeThreadMemory();
00375 nonConstNode.SetNuChildren(nuNewChildren);
00376 }
00377 else
00378 {
00379 nonConstNode.SetNuChildren(nuNewChildren);
00380 SgSynchronizeThreadMemory();
00381 nonConstNode.SetFirstChild(newFirstChild);
00382 }
00383 }
00384
00385 std::size_t SgUctTree::NuNodes() const
00386 {
00387 size_t nuNodes = 1;
00388 for (size_t i = 0; i < NuAllocators(); ++i)
00389 nuNodes += Allocator(i).NuNodes();
00390 return nuNodes;
00391 }
00392
00393 void SgUctTree::RemoveVirtualLoss(const std::vector<const SgUctNode*>& nodes)
00394 {
00395 for (size_t i = 0; i < nodes.size(); ++i)
00396 {
00397 const SgUctNode* father = (i > 0 ? nodes[i-1] : 0);
00398 RemoveGameResult(*nodes[i], father, 1.0);
00399 RemoveRaveValue(*nodes[i], 0.0, 1.0);
00400 }
00401 }
00402
00403 void SgUctTree::SetMaxNodes(std::size_t maxNodes)
00404 {
00405 Clear();
00406 size_t nuAllocators = NuAllocators();
00407 if (nuAllocators == 0)
00408 {
00409 SgDebug() << "SgUctTree::SetMaxNodes: no allocators registered\n";
00410 SG_ASSERT(false);
00411 return;
00412 }
00413 m_maxNodes = maxNodes;
00414 size_t maxNodesPerAlloc = maxNodes / nuAllocators;
00415 for (size_t i = 0; i < NuAllocators(); ++i)
00416 Allocator(i).SetMaxNodes(maxNodesPerAlloc);
00417 }
00418
00419 void SgUctTree::Swap(SgUctTree& tree)
00420 {
00421 SG_ASSERT(MaxNodes() == tree.MaxNodes());
00422 SG_ASSERT(NuAllocators() == tree.NuAllocators());
00423 swap(m_root, tree.m_root);
00424 for (size_t i = 0; i < NuAllocators(); ++i)
00425 Allocator(i).Swap(tree.Allocator(i));
00426 }
00427
00428 void SgUctTree::ThrowConsistencyError(const string& message) const
00429 {
00430 DumpDebugInfo(SgDebug());
00431 throw SgException("SgUctTree::ThrowConsistencyError: " + message);
00432 }
00433
00434
00435
00436 SgUctTreeIterator::SgUctTreeIterator(const SgUctTree& tree)
00437 : m_tree(tree),
00438 m_current(&tree.Root())
00439 {
00440 }
00441
00442 const SgUctNode& SgUctTreeIterator::operator*() const
00443 {
00444 return *m_current;
00445 }
00446
00447 void SgUctTreeIterator::operator++()
00448 {
00449 if (m_current->HasChildren())
00450 {
00451 SgUctChildIterator* it = new SgUctChildIterator(m_tree, *m_current);
00452 m_stack.push(shared_ptr<SgUctChildIterator>(it));
00453 m_current = &(**it);
00454 return;
00455 }
00456 while (! m_stack.empty())
00457 {
00458 SgUctChildIterator& it = *m_stack.top();
00459 SG_ASSERT(it);
00460 ++it;
00461 if (it)
00462 {
00463 m_current = &(*it);
00464 return;
00465 }
00466 else
00467 {
00468 m_stack.pop();
00469 m_current = 0;
00470 }
00471 }
00472 m_current = 0;
00473 }
00474
00475 SgUctTreeIterator::operator bool() const
00476 {
00477 return (m_current != 0);
00478 }
00479
00480