00001 //---------------------------------------------------------------------------- 00002 /** @file GoUctSearch.cpp 00003 */ 00004 //---------------------------------------------------------------------------- 00005 00006 #include "SgSystem.h" 00007 #include "GoUctSearch.h" 00008 00009 #include <fstream> 00010 #include <iostream> 00011 #include "GoBoardUtil.h" 00012 #include "GoNodeUtil.h" 00013 #include "GoUctUtil.h" 00014 #include "SgDebug.h" 00015 #include "SgGameWriter.h" 00016 #include "SgNode.h" 00017 #include "SgUctTreeUtil.h" 00018 00019 using namespace std; 00020 00021 //---------------------------------------------------------------------------- 00022 00023 namespace { 00024 00025 //---------------------------------------------------------------------------- 00026 00027 const int MOVERANGE = SG_PASS + 1; 00028 00029 SgNode* AppendChild(SgNode* node, const string& comment) 00030 { 00031 SgNode* child = node->NewRightMostSon(); 00032 child->AddComment(comment); 00033 return child; 00034 } 00035 00036 SgNode* AppendChild(SgNode* node, SgBlackWhite color, SgPoint move) 00037 { 00038 SgNode* child = node->NewRightMostSon(); 00039 SgPropID propId = 00040 (color == SG_BLACK ? SG_PROP_MOVE_BLACK : SG_PROP_MOVE_WHITE); 00041 child->Add(new SgPropMove(propId, move)); 00042 return child; 00043 } 00044 00045 /** Append game to saved simulations (used if m_keepGames is true) */ 00046 void AppendGame(SgNode* node, size_t gameNumber, int threadId, 00047 SgBlackWhite toPlay, const SgUctGameInfo& info) 00048 { 00049 SG_ASSERT(node != 0); 00050 { 00051 ostringstream comment; 00052 comment << "Thread " << threadId << '\n' 00053 << "Game " << gameNumber << '\n'; 00054 node = AppendChild(node, comment.str()); 00055 } 00056 size_t nuMovesInTree = info.m_inTreeSequence.size(); 00057 for (size_t i = 0; i < nuMovesInTree; ++i) 00058 { 00059 node = AppendChild(node, toPlay, info.m_inTreeSequence[i]); 00060 toPlay = SgOppBW(toPlay); 00061 } 00062 SgNode* lastInTreeNode = node; 00063 SgBlackWhite lastInTreeToPlay = toPlay; 00064 for (size_t i = 0; i < info.m_eval.size(); ++i) 00065 { 00066 node = lastInTreeNode; 00067 toPlay = lastInTreeToPlay; 00068 ostringstream comment; 00069 comment << "Playout " << i << '\n' 00070 << "Eval " << info.m_eval[i] << '\n' 00071 << "Aborted " << info.m_aborted[i] << '\n'; 00072 node = AppendChild(node, comment.str()); 00073 for (size_t j = nuMovesInTree; j < info.m_sequence[i].size(); ++j) 00074 { 00075 node = AppendChild(node, toPlay, info.m_sequence[i][j]); 00076 toPlay = SgOppBW(toPlay); 00077 } 00078 } 00079 } 00080 00081 //---------------------------------------------------------------------------- 00082 00083 } // namespace 00084 00085 //---------------------------------------------------------------------------- 00086 00087 GoUctState::AssertionHandler::AssertionHandler(const GoUctState& state) 00088 : m_state(state) 00089 { 00090 } 00091 00092 void GoUctState::AssertionHandler::Run() 00093 { 00094 m_state.Dump(SgDebug()); 00095 } 00096 00097 //---------------------------------------------------------------------------- 00098 00099 GoUctState::GoUctState(std::size_t threadId, const GoBoard& bd) 00100 : SgUctThreadState(threadId, MOVERANGE), 00101 m_assertionHandler(*this), 00102 m_uctBd(bd), 00103 m_synchronizer(bd) 00104 { 00105 m_synchronizer.SetSubscriber(m_bd); 00106 m_isInPlayout = false; 00107 } 00108 00109 void GoUctState::Dump(ostream& out) const 00110 { 00111 out << "GoUctState[" << m_threadId << "] "; 00112 if (m_isInPlayout) 00113 out << "playout board:\n" << m_uctBd; 00114 else 00115 out << "board:\n" << m_bd; 00116 } 00117 00118 void GoUctState::Execute(SgMove move) 00119 { 00120 SG_ASSERT(! m_isInPlayout); 00121 SG_ASSERT(move == SG_PASS || ! m_bd.Occupied(move)); 00122 // Temporarily switch ko rule to SIMPLEKO to avoid slow full board 00123 // repetition test in GoBoard::Play() 00124 GoRestoreKoRule restoreKoRule(m_bd); 00125 m_bd.Rules().SetKoRule(GoRules::SIMPLEKO); 00126 m_bd.Play(move); 00127 SG_ASSERT(! m_bd.LastMoveInfo(GO_MOVEFLAG_ILLEGAL)); 00128 ++m_gameLength; 00129 } 00130 00131 void GoUctState::ExecutePlayout(SgMove move) 00132 { 00133 SG_ASSERT(m_isInPlayout); 00134 SG_ASSERT(move == SG_PASS || ! m_uctBd.Occupied(move)); 00135 m_uctBd.Play(move); 00136 ++m_gameLength; 00137 } 00138 00139 void GoUctState::GameStart() 00140 { 00141 m_isInPlayout = false; 00142 m_gameLength = 0; 00143 } 00144 00145 void GoUctState::StartPlayout() 00146 { 00147 m_uctBd.Init(m_bd); 00148 } 00149 00150 void GoUctState::StartPlayouts() 00151 { 00152 m_isInPlayout = true; 00153 } 00154 00155 void GoUctState::StartSearch() 00156 { 00157 m_synchronizer.UpdateSubscriber(); 00158 } 00159 00160 void GoUctState::TakeBackInTree(std::size_t nuMoves) 00161 { 00162 for (size_t i = 0; i < nuMoves; ++i) 00163 m_bd.Undo(); 00164 } 00165 00166 void GoUctState::TakeBackPlayout(std::size_t nuMoves) 00167 { 00168 m_gameLength -= nuMoves; 00169 } 00170 00171 //---------------------------------------------------------------------------- 00172 00173 GoUctSearch::GoUctSearch(GoBoard& bd, SgUctThreadStateFactory* factory) 00174 : SgUctSearch(factory, MOVERANGE), 00175 m_keepGames(false), 00176 m_liveGfxInterval(5000), 00177 m_toPlay(SG_BLACK), 00178 m_bd(bd), 00179 m_root(0), 00180 m_liveGfx(GOUCT_LIVEGFX_NONE) 00181 { 00182 SetRaveCheckSame(true); 00183 } 00184 00185 GoUctSearch::~GoUctSearch() 00186 { 00187 if (m_root != 0) 00188 m_root->DeleteTree(); 00189 m_root = 0; 00190 } 00191 00192 std::string GoUctSearch::MoveString(SgMove move) const 00193 { 00194 return SgPointUtil::PointToString(move); 00195 } 00196 00197 void GoUctSearch::OnSearchIteration(std::size_t gameNumber, int threadId, 00198 const SgUctGameInfo& info) 00199 { 00200 SgUctSearch::OnSearchIteration(gameNumber, threadId, info); 00201 00202 if (m_liveGfx != GOUCT_LIVEGFX_NONE && threadId == 0 00203 && gameNumber % m_liveGfxInterval == 0) 00204 { 00205 SgDebug() << "gogui-gfx:\n"; 00206 switch (m_liveGfx) 00207 { 00208 case GOUCT_LIVEGFX_COUNTS: 00209 GoUctUtil::GfxBestMove(*this, m_toPlay, SgDebug()); 00210 GoUctUtil::GfxMoveValues(*this, m_toPlay, SgDebug()); 00211 GoUctUtil::GfxCounts(Tree(), SgDebug()); 00212 GoUctUtil::GfxStatus(*this, SgDebug()); 00213 break; 00214 case GOUCT_LIVEGFX_SEQUENCE: 00215 GoUctUtil::GfxSequence(*this, m_toPlay, SgDebug()); 00216 GoUctUtil::GfxStatus(*this, SgDebug()); 00217 break; 00218 case GOUCT_LIVEGFX_NONE: 00219 SG_ASSERT(false); // Already checked above 00220 break; 00221 } 00222 SgDebug() << '\n'; 00223 } 00224 if (! LockFree() && m_root != 0) 00225 AppendGame(m_root, gameNumber, threadId, m_toPlay, info); 00226 } 00227 00228 void GoUctSearch::OnStartSearch() 00229 { 00230 SgUctSearch::OnStartSearch(); 00231 00232 if (m_root != 0) 00233 { 00234 m_root->DeleteTree(); 00235 m_root = 0; 00236 } 00237 if (m_keepGames) 00238 { 00239 m_root = GoNodeUtil::CreateRoot(m_bd); 00240 if (LockFree()) 00241 SgWarning() << 00242 "GoUctSearch: keep games will be ignored" 00243 " in lock free search\n"; 00244 } 00245 m_toPlay = m_bd.ToPlay(); // Not needed if SetToPlay() was called 00246 for (SgBWIterator it; it; ++it) 00247 m_stones[*it] = m_bd.All(*it); 00248 int size = m_bd.Size(); 00249 // Limit to avoid very long games if m_simpleKo 00250 int maxGameLength = min(3 * size * size, 00251 GO_MAX_NUM_MOVES - m_bd.MoveNumber()); 00252 SetMaxGameLength(maxGameLength); 00253 m_boardHistory.SetFromBoard(m_bd); 00254 } 00255 00256 void GoUctSearch::SaveGames(const string& fileName) const 00257 { 00258 if (MpiSynchronizer()->IsRootProcess()) 00259 { 00260 if (m_root == 0) 00261 throw SgException("No games to save"); 00262 ofstream out(fileName.c_str()); 00263 SgGameWriter writer(out); 00264 writer.WriteGame(*m_root, true, 0, "", 1, 19); 00265 } 00266 } 00267 00268 void GoUctSearch::SaveTree(std::ostream& out, int maxDepth) const 00269 { 00270 GoUctUtil::SaveTree(Tree(), m_bd.Size(), m_stones, m_toPlay, out, 00271 maxDepth); 00272 } 00273 00274 SgBlackWhite GoUctSearch::ToPlay() const 00275 { 00276 return m_toPlay; 00277 } 00278 00279 //---------------------------------------------------------------------------- 00280 00281 SgPoint GoUctSearchUtil::TrompTaylorPassCheck(SgPoint move, 00282 const GoUctSearch& search) 00283 { 00284 const GoBoard& bd = search.Board(); 00285 bool isFirstPass = (bd.GetLastMove() != SG_PASS); 00286 bool isTrompTaylorRules = bd.Rules().CaptureDead(); 00287 if (move != SG_PASS || ! isTrompTaylorRules || ! isFirstPass) 00288 return move; 00289 float komi = bd.Rules().Komi().ToFloat(); 00290 float trompTaylorScore = GoBoardUtil::TrompTaylorScore(bd, komi); 00291 if (search.ToPlay() != SG_BLACK) 00292 trompTaylorScore *= -1; 00293 const SgUctTree& tree = search.Tree(); 00294 const SgUctNode& root = tree.Root(); 00295 float value = root.Mean(); 00296 float trompTaylorWinValue = (trompTaylorScore > 0 ? 1 : 0); 00297 if (value < trompTaylorWinValue) 00298 return move; 00299 SgDebug() << "GoUctSearchUtil::TrompTaylorPassCheck: bad pass move value=" 00300 << value << " trompTaylorScore=" << trompTaylorScore << '\n'; 00301 vector<SgMove> excludeMoves; 00302 excludeMoves.push_back(SG_PASS); 00303 const SgUctNode* bestChild = search.FindBestChild(root, &excludeMoves); 00304 if (bestChild == 0) 00305 { 00306 SgDebug() << 00307 "GoUctSearchUtil::TrompTaylorPassCheck: " 00308 "(no second best move found)\n"; 00309 return move; 00310 } 00311 return bestChild->Move(); 00312 } 00313 00314 //----------------------------------------------------------------------------