00001
00002
00003
00004
00005
00006 #ifndef GOUCT_GLOBALSEARCH_H
00007 #define GOUCT_GLOBALSEARCH_H
00008
#include <algorithm>
#include <cstddef>
#include <limits>
#include <vector>
#include <boost/scoped_ptr.hpp>
#include "GoBoard.h"
#include "GoBoardUtil.h"
#include "GoEyeUtil.h"
#include "GoRegionBoard.h"
#include "GoSafetySolver.h"
#include "GoUctDefaultPriorKnowledge.h"
#include "GoUctSearch.h"
#include "GoUctUtil.h"
00018
00019
00020
00021
00022
00023
00024
00025
00026
/** Compile-time switch for the safety solver.
    GoUctGlobalSearch::OnStartSearch() only runs GoSafetySolver when this is
    true; with the current value the safe-point sets stay empty. */
static const bool GOUCT_USE_SAFETY_SOLVER = false;
00028
00029
00030
00031
/** Parameters read by every GoUctGlobalSearchState of a search.
    The thread states only hold a const reference to one instance of this
    struct (see GoUctGlobalSearch::m_param). */
struct GoUctGlobalSearchStateParam
{
    /** Enable the mercy rule.
        When set, GeneratePlayoutMove() ends a playout as soon as the
        stone difference tracked during the playout reaches the mercy-rule
        threshold. */
    bool m_mercyRule;

    /** Collect per-point territory statistics.
        When set, EvaluateBoard() records for every point whether it was
        scored for Black (1), White (0) or nobody (0.5). */
    bool m_territoryStatistics;

    /** Weight of the game-length modification of the result.
        EvaluateBoard() computes GameLength() times this value, caps it at
        0.5, subtracts it from the value of a win and adds it to the value
        of a loss. A value of 0 disables the modification. */
    float m_lengthModification;

    /** Weight of the score modification of the result.
        EvaluateBoard() maps a win to
        (1 - m_scoreModification) + m_scoreModification * score / maxScore
        and a loss to
        m_scoreModification + m_scoreModification * score / maxScore,
        so that larger margins are preferred among wins and smaller ones
        among losses. A value of 0 disables the modification. */
    float m_scoreModification;

    /** Constructor; defined outside this header. */
    GoUctGlobalSearchStateParam();
};
00072
00073
00074
00075
00076
00077
00078
00079
00080
00081
00082
00083
00084
00085
00086
00087
00088
00089
00090
00091
00092
00093
00094
00095
00096
00097
00098
00099
00100
00101
00102
00103
00104 template<class POLICY>
00105 class GoUctGlobalSearchState
00106 : public GoUctState
00107 {
00108 public:
00109 const SgBWSet& m_safe;
00110
00111 const SgPointArray<bool>& m_allSafe;
00112
00113
00114
00115
00116 SgPointArray<SgUctStatistics> m_territoryStatistics;
00117
00118
00119
00120
00121
00122
00123
00124
00125
00126
00127
00128
00129
00130
00131 GoUctGlobalSearchState(std::size_t threadId, const GoBoard& bd,
00132 POLICY* policy,
00133 const GoUctGlobalSearchStateParam& param,
00134 const GoUctPlayoutPolicyParam& policyParam,
00135 const SgBWSet& safe,
00136 const SgPointArray<bool>& allSafe);
00137
00138 ~GoUctGlobalSearchState();
00139
00140 float Evaluate();
00141
00142 bool GenerateAllMoves(std::size_t count, std::vector<SgMoveInfo>& moves,
00143 SgProvenNodeType& provenType);
00144
00145 SgMove GeneratePlayoutMove(bool& skipRaveUpdate);
00146
00147 void ExecutePlayout(SgMove move);
00148
00149 void GameStart();
00150
00151 void EndPlayout();
00152
00153 void StartPlayout();
00154
00155 void StartPlayouts();
00156
00157 void StartSearch();
00158
00159 POLICY* Policy();
00160
00161
00162
00163
00164 void SetPolicy(POLICY* policy);
00165
00166 void ClearTerritoryStatistics();
00167
00168 private:
00169 const GoUctGlobalSearchStateParam& m_param;
00170
00171 const GoUctPlayoutPolicyParam& m_policyParam;
00172
00173
00174 bool m_mercyRuleTriggered;
00175
00176
00177 int m_passMovesPlayoutPhase;
00178
00179
00180 int m_mercyRuleThreshold;
00181
00182
00183
00184
00185 int m_stoneDiff;
00186
00187
00188 int m_initialMoveNumber;
00189
00190
00191 GoPointList m_area;
00192
00193
00194 float m_mercyRuleResult;
00195
00196
00197
00198
00199 float m_invMaxScore;
00200
00201 SgRandom m_random;
00202
00203 GoUctDefaultPriorKnowledge m_priorKnowledge;
00204
00205 boost::scoped_ptr<POLICY> m_policy;
00206
00207
00208 GoUctGlobalSearchState(const GoUctGlobalSearchState& search);
00209
00210
00211 GoUctGlobalSearchState& operator=(const GoUctGlobalSearchState& search);
00212
00213 bool CheckMercyRule();
00214
00215 template<class BOARD>
00216 float EvaluateBoard(const BOARD& bd, float komi);
00217
00218
00219 void GenerateLegalMoves(std::vector<SgMoveInfo>& moves);
00220
00221 float GetKomi() const;
00222 };
00223
00224 template<class POLICY>
00225 GoUctGlobalSearchState<POLICY>::GoUctGlobalSearchState(std::size_t threadId,
00226 const GoBoard& bd, POLICY* policy,
00227 const GoUctGlobalSearchStateParam& param,
00228 const GoUctPlayoutPolicyParam& policyParam,
00229 const SgBWSet& safe, const SgPointArray<bool>& allSafe)
00230 : GoUctState(threadId, bd),
00231 m_safe(safe),
00232 m_allSafe(allSafe),
00233 m_param(param),
00234 m_policyParam(policyParam),
00235 m_priorKnowledge(Board(), m_policyParam),
00236 m_policy(policy)
00237 {
00238 ClearTerritoryStatistics();
00239 }
00240
00241 template<class POLICY>
00242 GoUctGlobalSearchState<POLICY>::~GoUctGlobalSearchState()
00243 {
00244 }
00245
00246
00247 template<class POLICY>
00248 bool GoUctGlobalSearchState<POLICY>::CheckMercyRule()
00249 {
00250 SG_ASSERT(m_param.m_mercyRule);
00251
00252 SG_ASSERT(IsInPlayout());
00253 if (m_stoneDiff >= m_mercyRuleThreshold)
00254 {
00255 m_mercyRuleTriggered = true;
00256 m_mercyRuleResult = (UctBoard().ToPlay() == SG_BLACK ? 1 : 0);
00257 }
00258 else if (m_stoneDiff <= -m_mercyRuleThreshold)
00259 {
00260 m_mercyRuleTriggered = true;
00261 m_mercyRuleResult = (UctBoard().ToPlay() == SG_WHITE ? 1 : 0);
00262 }
00263 else
00264 SG_ASSERT(! m_mercyRuleTriggered);
00265 return m_mercyRuleTriggered;
00266 }
00267
00268 template<class POLICY>
00269 void GoUctGlobalSearchState<POLICY>::ClearTerritoryStatistics()
00270 {
00271 for (SgPointArray<SgUctStatistics>::NonConstIterator
00272 it(m_territoryStatistics); it; ++it)
00273 (*it).Clear();
00274 }
00275
00276 template<class POLICY>
00277 void GoUctGlobalSearchState<POLICY>::EndPlayout()
00278 {
00279 GoUctState::EndPlayout();
00280 m_policy->EndPlayout();
00281 }
00282
00283 template<class POLICY>
00284 float GoUctGlobalSearchState<POLICY>::Evaluate()
00285 {
00286 float komi = GetKomi();
00287 if (IsInPlayout())
00288 return EvaluateBoard(UctBoard(), komi);
00289 else
00290 return EvaluateBoard(Board(), komi);
00291 }
00292
00293 template<class POLICY>
00294 template<class BOARD>
00295 float GoUctGlobalSearchState<POLICY>::EvaluateBoard(const BOARD& bd,
00296 float komi)
00297 {
00298 float score;
00299 SgPointArray<SgEmptyBlackWhite> scoreBoard;
00300 SgPointArray<SgEmptyBlackWhite>* scoreBoardPtr;
00301 if (m_param.m_territoryStatistics)
00302 scoreBoardPtr = &scoreBoard;
00303 else
00304 scoreBoardPtr = 0;
00305 if (m_passMovesPlayoutPhase < 2)
00306
00307 score = GoBoardUtil::TrompTaylorScore(bd, komi, scoreBoardPtr);
00308 else
00309 {
00310 if (m_param.m_mercyRule && m_mercyRuleTriggered)
00311 return m_mercyRuleResult;
00312 score = GoBoardUtil::ScoreSimpleEndPosition(bd, komi, m_safe,
00313 false, scoreBoardPtr);
00314 }
00315 if (m_param.m_territoryStatistics)
00316 for (typename BOARD::Iterator it(bd); it; ++it)
00317 switch (scoreBoard[*it])
00318 {
00319 case SG_BLACK:
00320 m_territoryStatistics[*it].Add(1);
00321 break;
00322 case SG_WHITE:
00323 m_territoryStatistics[*it].Add(0);
00324 break;
00325 case SG_EMPTY:
00326 m_territoryStatistics[*it].Add(0.5);
00327 break;
00328 }
00329 if (bd.ToPlay() != SG_BLACK)
00330 score *= -1;
00331 float lengthMod = min(GameLength() * m_param.m_lengthModification, 0.5f);
00332 if (score > std::numeric_limits<float>::epsilon())
00333 return
00334 (1 - m_param.m_scoreModification)
00335 + m_param.m_scoreModification * score * m_invMaxScore
00336 - lengthMod;
00337 else if (score < -std::numeric_limits<float>::epsilon())
00338 return
00339 m_param.m_scoreModification
00340 + m_param.m_scoreModification * score * m_invMaxScore
00341 + lengthMod;
00342 else
00343
00344 return 0;
00345 }
00346
00347 template<class POLICY>
00348 void GoUctGlobalSearchState<POLICY>::ExecutePlayout(SgMove move)
00349 {
00350 GoUctState::ExecutePlayout(move);
00351 const GoUctBoard& bd = UctBoard();
00352 if (bd.ToPlay() == SG_BLACK)
00353 m_stoneDiff -= bd.NuCapturedStones();
00354 else
00355 m_stoneDiff += bd.NuCapturedStones();
00356 m_policy->OnPlay();
00357 }
00358
00359 template<class POLICY>
00360 void GoUctGlobalSearchState<POLICY>::GameStart()
00361 {
00362 GoUctState::GameStart();
00363 }
00364
00365 template<class POLICY>
00366 void GoUctGlobalSearchState<POLICY>::GenerateLegalMoves(
00367 std::vector<SgMoveInfo>& moves)
00368 {
00369
00370 const GoBoard& bd = Board();
00371 SG_ASSERT(! bd.Rules().AllowSuicide());
00372
00373 if (GoBoardUtil::TwoPasses(bd))
00374
00375
00376
00377
00378
00379
00380
00381
00382
00383 if (bd.Rules().CaptureDead()
00384 || bd.MoveNumber() - m_initialMoveNumber >= 2)
00385 return;
00386
00387 SgBlackWhite toPlay = bd.ToPlay();
00388 for (GoBoard::Iterator it(bd); it; ++it)
00389 {
00390 SgPoint p = *it;
00391 if (bd.IsEmpty(p)
00392 && ! GoEyeUtil::IsSimpleEye(bd, p, toPlay)
00393 && ! m_allSafe[p]
00394 && bd.IsLegal(p, toPlay))
00395 moves.push_back(SgMoveInfo(p));
00396 }
00397
00398
00399
00400
00401
00402
00403 if (moves.size() > 1)
00404 std::swap(moves[0], moves[m_random.Int(moves.size())]);
00405 moves.push_back(SgMoveInfo(SG_PASS));
00406 }
00407
00408 template<class POLICY>
00409 bool GoUctGlobalSearchState<POLICY>::GenerateAllMoves(std::size_t count,
00410 std::vector<SgMoveInfo>& moves,
00411 SgProvenNodeType& provenType)
00412 {
00413 provenType = SG_NOT_PROVEN;
00414 moves.clear();
00415 GenerateLegalMoves(moves);
00416 if (! moves.empty())
00417 {
00418 if (count == 0)
00419 m_priorKnowledge.ProcessPosition(moves);
00420 }
00421 return false;
00422 }
00423
00424 template<class POLICY>
00425 SgMove GoUctGlobalSearchState<POLICY>::GeneratePlayoutMove(
00426 bool& skipRaveUpdate)
00427 {
00428 SG_ASSERT(IsInPlayout());
00429 if (m_param.m_mercyRule && CheckMercyRule())
00430 return SG_NULLMOVE;
00431 SgPoint move = m_policy->GenerateMove();
00432 SG_ASSERT(move != SG_NULLMOVE);
00433 #ifndef NDEBUG
00434
00435
00436 if (move == SG_PASS)
00437 {
00438 const GoUctBoard& bd = UctBoard();
00439 SgBalancer balancer(100);
00440 for (GoUctBoard::Iterator it(bd); it; ++it)
00441 SG_ASSERT( bd.Occupied(*it)
00442 || m_safe.OneContains(*it)
00443 || GoBoardUtil::SelfAtari(bd, *it)
00444 || ! GoUctUtil::GeneratePoint(bd, balancer,
00445 *it, bd.ToPlay())
00446 );
00447 }
00448 else
00449 SG_ASSERT(! m_safe.OneContains(move));
00450 #endif
00451
00452
00453
00454
00455 if (move == SG_PASS)
00456 {
00457 skipRaveUpdate = true;
00458 if (m_passMovesPlayoutPhase < 2)
00459 ++m_passMovesPlayoutPhase;
00460 else
00461 return SG_NULLMOVE;
00462 }
00463 else
00464 m_passMovesPlayoutPhase = 0;
00465 return move;
00466 }
00467
00468
00469 template<class POLICY>
00470 float GoUctGlobalSearchState<POLICY>::GetKomi() const
00471 {
00472 const GoRules& rules = Board().Rules();
00473 float komi = rules.Komi().ToFloat();
00474 if (rules.ExtraHandicapKomi())
00475 komi += rules.Handicap();
00476 return komi;
00477 }
00478
00479 template<class POLICY>
00480 inline POLICY* GoUctGlobalSearchState<POLICY>::Policy()
00481 {
00482 return m_policy.get();
00483 }
00484
00485 template<class POLICY>
00486 void GoUctGlobalSearchState<POLICY>::SetPolicy(POLICY* policy)
00487 {
00488 m_policy.reset(policy);
00489 }
00490
00491 template<class POLICY>
00492 void GoUctGlobalSearchState<POLICY>::StartPlayout()
00493 {
00494 GoUctState::StartPlayout();
00495 m_passMovesPlayoutPhase = 0;
00496 m_mercyRuleTriggered = false;
00497 const GoBoard& bd = Board();
00498 m_stoneDiff = bd.All(SG_BLACK).Size() - bd.All(SG_WHITE).Size();
00499 m_policy->StartPlayout();
00500 }
00501
00502 template<class POLICY>
00503 void GoUctGlobalSearchState<POLICY>::StartPlayouts()
00504 {
00505 GoUctState::StartPlayouts();
00506 }
00507
00508 template<class POLICY>
00509 void GoUctGlobalSearchState<POLICY>::StartSearch()
00510 {
00511 GoUctState::StartSearch();
00512 const GoBoard& bd = Board();
00513 int size = bd.Size();
00514 float maxScore = size * size + GetKomi();
00515 m_invMaxScore = 1 / maxScore;
00516 m_initialMoveNumber = bd.MoveNumber();
00517 m_mercyRuleThreshold = static_cast<int>(0.3 * size * size);
00518 ClearTerritoryStatistics();
00519 }
00520
00521
00522
00523
00524
00525
00526
00527 template<class POLICY, class FACTORY>
00528 class GoUctGlobalSearchStateFactory
00529 : public SgUctThreadStateFactory
00530 {
00531 public:
00532
00533
00534
00535
00536
00537
00538
00539
00540 GoUctGlobalSearchStateFactory(GoBoard& bd,
00541 FACTORY& playoutPolicyFactory,
00542 const GoUctPlayoutPolicyParam& policyParam,
00543 const SgBWSet& safe,
00544 const SgPointArray<bool>& allSafe);
00545
00546 SgUctThreadState* Create(std::size_t threadId, const SgUctSearch& search);
00547
00548 private:
00549 GoBoard& m_bd;
00550
00551 FACTORY& m_playoutPolicyFactory;
00552
00553 const GoUctPlayoutPolicyParam& m_policyParam;
00554
00555 const SgBWSet& m_safe;
00556
00557 const SgPointArray<bool>& m_allSafe;
00558 };
00559
00560 template<class POLICY, class FACTORY>
00561 GoUctGlobalSearchStateFactory<POLICY,FACTORY>
00562 ::GoUctGlobalSearchStateFactory(GoBoard& bd,
00563 FACTORY& playoutPolicyFactory,
00564 const GoUctPlayoutPolicyParam& policyParam,
00565 const SgBWSet& safe,
00566 const SgPointArray<bool>& allSafe)
00567 : m_bd(bd),
00568 m_playoutPolicyFactory(playoutPolicyFactory),
00569 m_policyParam(policyParam),
00570 m_safe(safe),
00571 m_allSafe(allSafe)
00572 {
00573 }
00574
00575
00576
00577
00578
00579
00580
00581 template<class POLICY, class FACTORY>
00582 class GoUctGlobalSearch
00583 : public GoUctSearch
00584 {
00585 public:
00586 GoUctGlobalSearchStateParam m_param;
00587
00588
00589
00590
00591
00592
00593
00594
00595 GoUctGlobalSearch(GoBoard& bd,
00596 FACTORY* playoutPolicyFactory,
00597 const GoUctPlayoutPolicyParam& policyParam);
00598
00599
00600
00601
00602 float UnknownEval() const;
00603
00604
00605
00606
00607
00608
00609
00610 void OnStartSearch();
00611
00612 void OnSearchIteration(std::size_t gameNumber, int threadId,
00613 const SgUctGameInfo& info);
00614
00615
00616
00617
00618
00619 void SetDefaultParameters(int boardSize);
00620
00621
00622
00623
00624
00625
00626
00627
00628 bool GlobalSearchLiveGfx() const;
00629
00630
00631 void SetGlobalSearchLiveGfx(bool enable);
00632
00633 private:
00634 SgBWSet m_safe;
00635
00636 SgPointArray<bool> m_allSafe;
00637
00638 boost::scoped_ptr<FACTORY> m_playoutPolicyFactory;
00639
00640 GoRegionBoard m_regions;
00641
00642
00643 bool m_globalSearchLiveGfx;
00644 };
00645
00646 template<class POLICY, class FACTORY>
00647 GoUctGlobalSearch<POLICY,FACTORY>::GoUctGlobalSearch(GoBoard& bd,
00648 FACTORY* playoutFactory,
00649 const GoUctPlayoutPolicyParam& policyParam)
00650 : GoUctSearch(bd, 0),
00651 m_playoutPolicyFactory(playoutFactory),
00652 m_regions(bd),
00653 m_globalSearchLiveGfx(GOUCT_LIVEGFX_NONE)
00654 {
00655 SgUctThreadStateFactory* stateFactory =
00656 new GoUctGlobalSearchStateFactory<POLICY,FACTORY>(bd,
00657 *playoutFactory,
00658 policyParam,
00659 m_safe, m_allSafe);
00660 SetThreadStateFactory(stateFactory);
00661 SetDefaultParameters(bd.Size());
00662 }
00663
00664 template<class POLICY, class FACTORY>
00665 inline bool GoUctGlobalSearch<POLICY,FACTORY>::GlobalSearchLiveGfx() const
00666 {
00667 return m_globalSearchLiveGfx;
00668 }
00669
00670 template<class POLICY, class FACTORY>
00671 void GoUctGlobalSearch<POLICY,FACTORY>::OnSearchIteration(
00672 std::size_t gameNumber,
00673 int threadId,
00674 const SgUctGameInfo& info)
00675 {
00676 GoUctSearch::OnSearchIteration(gameNumber, threadId, info);
00677 if (m_globalSearchLiveGfx && threadId == 0
00678 && gameNumber % LiveGfxInterval() == 0)
00679 {
00680 const GoUctGlobalSearchState<POLICY>& state =
00681 dynamic_cast<GoUctGlobalSearchState<POLICY>&>(ThreadState(0));
00682 SgDebug() << "gogui-gfx:\n";
00683 GoUctUtil::GfxBestMove(*this, ToPlay(), SgDebug());
00684 GoUctUtil::GfxTerritoryStatistics(state.m_territoryStatistics,
00685 Board(), SgDebug());
00686 GoUctUtil::GfxStatus(*this, SgDebug());
00687 SgDebug() << '\n';
00688 }
00689 }
00690
00691 template<class POLICY, class FACTORY>
00692 void GoUctGlobalSearch<POLICY,FACTORY>::OnStartSearch()
00693 {
00694 GoUctSearch::OnStartSearch();
00695 m_safe.Clear();
00696 m_allSafe.Fill(false);
00697 if (GOUCT_USE_SAFETY_SOLVER)
00698 {
00699 GoBoard& bd = Board();
00700 GoSafetySolver solver(bd, &m_regions);
00701 solver.FindSafePoints(&m_safe);
00702 for (GoBoard::Iterator it(bd); it; ++it)
00703 m_allSafe[*it] = m_safe.OneContains(*it);
00704 }
00705 if (m_globalSearchLiveGfx && ! m_param.m_territoryStatistics)
00706 SgWarning() <<
00707 "GoUctGlobalSearch: "
00708 "live graphics need territory statistics enabled\n";
00709 }
00710
00711 template<class POLICY, class FACTORY>
00712 void GoUctGlobalSearch<POLICY,FACTORY>::SetDefaultParameters(int boardSize)
00713 {
00714 SetFirstPlayUrgency(1);
00715 SetMoveSelect(SG_UCTMOVESELECT_COUNT);
00716 SetRave(true);
00717 SetExpandThreshold(1);
00718 SetVirtualLoss(true);
00719 SetBiasTermConstant(0.0);
00720 if (boardSize < 15)
00721 {
00722
00723
00724 SetRaveWeightInitial(1.0);
00725 SetRaveWeightFinal(5000);
00726 m_param.m_lengthModification = 0;
00727 }
00728 else
00729 {
00730
00731
00732 SetRaveWeightInitial(0.9);
00733 SetRaveWeightFinal(5000);
00734 m_param.m_lengthModification = 0.00028;
00735 }
00736 }
00737
00738 template<class POLICY, class FACTORY>
00739 inline void GoUctGlobalSearch<POLICY,FACTORY>::SetGlobalSearchLiveGfx(
00740 bool enable)
00741 {
00742 m_globalSearchLiveGfx = enable;
00743 }
00744
00745 template<class POLICY, class FACTORY>
00746 float GoUctGlobalSearch<POLICY,FACTORY>::UnknownEval() const
00747 {
00748
00749
00750 return 0.5;
00751 }
00752
00753
00754
00755 template<class POLICY, class FACTORY>
00756 SgUctThreadState* GoUctGlobalSearchStateFactory<POLICY,FACTORY>::Create(
00757 std::size_t threadId, const SgUctSearch& search)
00758 {
00759 const GoUctGlobalSearch<POLICY,FACTORY>& globalSearch =
00760 dynamic_cast<const GoUctGlobalSearch<POLICY,FACTORY>&>(search);
00761 GoUctGlobalSearchState<POLICY>* state =
00762 new GoUctGlobalSearchState<POLICY>(threadId, globalSearch.Board(), 0,
00763 globalSearch.m_param,
00764 m_policyParam,
00765 m_safe, m_allSafe);
00766 POLICY* policy = m_playoutPolicyFactory.Create(state->UctBoard());
00767 state->SetPolicy(policy);
00768 return state;
00769 }
00770
00771
00772
00773 #endif // GOUCT_GLOBALSEARCH_H