00001
00002
00003
00004
00005
00006 #include "SgSystem.h"
00007 #include "GoUctCommands.h"
00008
00009 #include <fstream>
00010 #include <boost/format.hpp>
00011 #include "GoEyeUtil.h"
00012 #include "GoGtpCommandUtil.h"
00013 #include "GoBoardUtil.h"
00014 #include "GoSafetySolver.h"
00015 #include "GoUctDefaultPriorKnowledge.h"
00016 #include "GoUctDefaultRootFilter.h"
00017 #include "GoUctEstimatorStat.h"
00018 #include "GoUctGlobalSearch.h"
00019 #include "GoUctPatterns.h"
00020 #include "GoUctPlayer.h"
00021 #include "GoUctPlayoutPolicy.h"
00022 #include "GoUctUtil.h"
00023 #include "GoUtil.h"
00024 #include "SgException.h"
00025 #include "SgPointSetUtil.h"
00026 #include "SgRestorer.h"
00027 #include "SgUctTreeUtil.h"
00028 #include "SgWrite.h"
00029
00030 using namespace std;
00031 using boost::format;
00032 using GoGtpCommandUtil::BlackWhiteArg;
00033 using GoGtpCommandUtil::EmptyPointArg;
00034 using GoGtpCommandUtil::PointArg;
00035
00036 typedef GoUctPlayer<GoUctGlobalSearch<GoUctPlayoutPolicy<GoUctBoard>,
00037 GoUctPlayoutPolicyFactory<GoUctBoard> >,
00038 GoUctGlobalSearchState<GoUctPlayoutPolicy<GoUctBoard> > >
00039 GoUctPlayerType;
00040
00041
00042
00043
00044 namespace {
00045
00046 GoUctLiveGfx LiveGfxArg(const GtpCommand& cmd, size_t number)
00047 {
00048 string arg = cmd.ArgToLower(number);
00049 if (arg == "none")
00050 return GOUCT_LIVEGFX_NONE;
00051 if (arg == "counts")
00052 return GOUCT_LIVEGFX_COUNTS;
00053 if (arg == "sequence")
00054 return GOUCT_LIVEGFX_SEQUENCE;
00055 throw GtpFailure() << "unknown live-gfx argument \"" << arg << '"';
00056 }
00057
00058 string LiveGfxToString(GoUctLiveGfx mode)
00059 {
00060 switch (mode)
00061 {
00062 case GOUCT_LIVEGFX_NONE:
00063 return "none";
00064 case GOUCT_LIVEGFX_COUNTS:
00065 return "counts";
00066 case GOUCT_LIVEGFX_SEQUENCE:
00067 return "sequence";
00068 default:
00069 SG_ASSERT(false);
00070 return "?";
00071 }
00072 }
00073
00074 SgUctMoveSelect MoveSelectArg(const GtpCommand& cmd, size_t number)
00075 {
00076 string arg = cmd.ArgToLower(number);
00077 if (arg == "value")
00078 return SG_UCTMOVESELECT_VALUE;
00079 if (arg == "count")
00080 return SG_UCTMOVESELECT_COUNT;
00081 if (arg == "bound")
00082 return SG_UCTMOVESELECT_BOUND;
00083 if (arg == "estimate")
00084 return SG_UCTMOVESELECT_ESTIMATE;
00085 throw GtpFailure() << "unknown move select argument \"" << arg << '"';
00086 }
00087
00088 string MoveSelectToString(SgUctMoveSelect moveSelect)
00089 {
00090 switch (moveSelect)
00091 {
00092 case SG_UCTMOVESELECT_VALUE:
00093 return "value";
00094 case SG_UCTMOVESELECT_COUNT:
00095 return "count";
00096 case SG_UCTMOVESELECT_BOUND:
00097 return "bound";
00098 case SG_UCTMOVESELECT_ESTIMATE:
00099 return "estimate";
00100 default:
00101 SG_ASSERT(false);
00102 return "?";
00103 }
00104 }
00105
00106 GoUctGlobalSearchMode SearchModeArg(const GtpCommand& cmd, size_t number)
00107 {
00108 string arg = cmd.ArgToLower(number);
00109 if (arg == "playout_policy")
00110 return GOUCT_SEARCHMODE_PLAYOUTPOLICY;
00111 if (arg == "uct")
00112 return GOUCT_SEARCHMODE_UCT;
00113 if (arg == "one_ply")
00114 return GOUCT_SEARCHMODE_ONEPLY;
00115 throw GtpFailure() << "unknown search mode argument \"" << arg << '"';
00116 }
00117
00118 string SearchModeToString(GoUctGlobalSearchMode mode)
00119 {
00120 switch (mode)
00121 {
00122 case GOUCT_SEARCHMODE_PLAYOUTPOLICY:
00123 return "playout_policy";
00124 case GOUCT_SEARCHMODE_UCT:
00125 return "uct";
00126 case GOUCT_SEARCHMODE_ONEPLY:
00127 return "one_ply";
00128 default:
00129 SG_ASSERT(false);
00130 return "?";
00131 }
00132 }
00133
00134 string KnowledgeThresholdToString(const std::vector<std::size_t>& t)
00135 {
00136 if (t.empty())
00137 return "0";
00138 std::ostringstream os;
00139 os << '\"';
00140 for (std::size_t i = 0; i < t.size(); ++i)
00141 {
00142 if (i > 0)
00143 os << ' ';
00144 os << t[i];
00145 }
00146 os << '\"';
00147 return os.str();
00148 }
00149
00150 std::vector<std::size_t> KnowledgeThresholdFromString(const std::string& val)
00151 {
00152 std::vector<std::size_t> v;
00153 std::istringstream is(val);
00154 std::size_t t;
00155 while (is >> t)
00156 v.push_back(t);
00157 if (v.size() == 1 && v[0] == 0)
00158 v.clear();
00159 return v;
00160 }
00161
00162 }
00163
00164
00165
00166 GoUctCommands::GoUctCommands(GoBoard& bd, GoPlayer*& player)
00167 : m_bd(bd),
00168 m_player(player)
00169 {
00170 }
00171
00172 void GoUctCommands::AddGoGuiAnalyzeCommands(GtpCommand& cmd)
00173 {
00174 cmd <<
00175 "gfx/Uct Bounds/uct_bounds\n"
00176 "plist/Uct Default Policy/uct_default_policy\n"
00177 "gfx/Uct Gfx/uct_gfx\n"
00178 "none/Uct Max Memory/uct_max_memory %s\n"
00179 "plist/Uct Moves/uct_moves\n"
00180 "param/Uct Param GlobalSearch/uct_param_globalsearch\n"
00181 "param/Uct Param Policy/uct_param_policy\n"
00182 "param/Uct Param Player/uct_param_player\n"
00183 "param/Uct Param RootFilter/uct_param_rootfilter\n"
00184 "param/Uct Param Search/uct_param_search\n"
00185 "plist/Uct Patterns/uct_patterns\n"
00186 "pstring/Uct Policy Moves/uct_policy_moves\n"
00187 "gfx/Uct Prior Knowledge/uct_prior_knowledge\n"
00188 "sboard/Uct Rave Values/uct_rave_values\n"
00189 "plist/Uct Root Filter/uct_root_filter\n"
00190 "none/Uct SaveGames/uct_savegames %w\n"
00191 "none/Uct SaveTree/uct_savetree %w\n"
00192 "gfx/Uct Sequence/uct_sequence\n"
00193 "hstring/Uct Stat Player/uct_stat_player\n"
00194 "none/Uct Stat Player Clear/uct_stat_player_clear\n"
00195 "hstring/Uct Stat Policy/uct_stat_policy\n"
00196 "none/Uct Stat Policy Clear/uct_stat_policy_clear\n"
00197 "hstring/Uct Stat Search/uct_stat_search\n"
00198 "dboard/Uct Stat Territory/uct_stat_territory\n";
00199 }
00200
00201
00202
00203
00204
00205
00206
00207 void GoUctCommands::CmdBounds(GtpCommand& cmd)
00208 {
00209 cmd.CheckArgNone();
00210 const GoUctSearch& search = Search();
00211 const SgUctTree& tree = search.Tree();
00212 const SgUctNode& root = tree.Root();
00213 bool hasPass = false;
00214 float passBound = 0;
00215 cmd << "LABEL";
00216 for (SgUctChildIterator it(tree, root); it; ++it)
00217 {
00218 const SgUctNode& child = *it;
00219 SgPoint move = child.Move();
00220 float bound = search.GetBound(search.Rave(), root, child);
00221 if (move == SG_PASS)
00222 {
00223 hasPass = true;
00224 passBound = bound;
00225 }
00226 else
00227 cmd << ' ' << SgWritePoint(move) << ' ' << fixed
00228 << setprecision(2) << bound;
00229 }
00230 cmd << '\n';
00231 if (hasPass)
00232 cmd << "TEXT PASS=" << fixed << setprecision(2) << passBound << '\n';
00233 }
00234
00235
00236 void GoUctCommands::CmdDefaultPolicy(GtpCommand& cmd)
00237 {
00238 cmd.CheckArgNone();
00239 GoUctDefaultPriorKnowledge knowledge(m_bd, GoUctPlayoutPolicyParam());
00240 SgPointSet pattern;
00241 SgPointSet atari;
00242 GoPointList empty;
00243 knowledge.FindGlobalPatternAndAtariMoves(pattern, atari, empty);
00244 cmd << SgWritePointSet(atari, "", false) << '\n';
00245 }
00246
00247
00248
00249
00250
00251 void GoUctCommands::CmdEstimatorStat(GtpCommand& cmd)
00252 {
00253 cmd.CheckNuArg(4);
00254 size_t trueValueMaxGames = cmd.SizeTypeArg(0);
00255 size_t maxGames = cmd.SizeTypeArg(1);
00256 size_t stepSize = cmd.SizeTypeArg(2);
00257 string fileName = cmd.Arg(3);
00258 GoUctEstimatorStat::Compute(Search(), trueValueMaxGames, maxGames,
00259 stepSize, fileName);
00260 }
00261
00262
00263
00264
00265
00266 void GoUctCommands::CmdFinalScore(GtpCommand& cmd)
00267 {
00268 cmd.CheckArgNone();
00269 SgPointSet deadStones = DoFinalStatusSearch();
00270 float score;
00271 if (! GoBoardUtil::ScorePosition(m_bd, deadStones, score))
00272 throw GtpFailure("cannot score");
00273 cmd << GoUtil::ScoreToString(score);
00274 }
00275
00276
00277
00278
00279
00280
00281
00282
00283 void GoUctCommands::CmdFinalStatusList(GtpCommand& cmd)
00284 {
00285 cmd.CheckNuArg(1);
00286 string arg = cmd.Arg(0);
00287 if (arg == "seki")
00288 return;
00289 bool getDead;
00290 if (arg == "alive")
00291 getDead = false;
00292 else if (arg == "dead")
00293 getDead = true;
00294 else
00295 throw GtpFailure("invalid final status argument");
00296 SgPointSet deadPoints = DoFinalStatusSearch();
00297
00298
00299 for (GoBlockIterator it(m_bd); it; ++it)
00300 {
00301 if ((getDead && deadPoints.Contains(*it))
00302 || (! getDead && ! deadPoints.Contains(*it)))
00303 {
00304 for (GoBoard::StoneIterator it2(m_bd, *it); it2; ++it2)
00305 cmd << SgWritePoint(*it2) << ' ';
00306 cmd << '\n';
00307 }
00308 }
00309 }
00310
00311
00312
00313
00314
00315 void GoUctCommands::CmdGfx(GtpCommand& cmd)
00316 {
00317 cmd.CheckArgNone();
00318 const GoUctSearch& s = Search();
00319 SgBlackWhite toPlay = s.ToPlay();
00320 GoUctUtil::GfxBestMove(s, toPlay, cmd);
00321 GoUctUtil::GfxMoveValues(s, toPlay, cmd);
00322 GoUctUtil::GfxCounts(s.Tree(), cmd);
00323 GoUctUtil::GfxStatus(s, cmd);
00324 }
00325
00326
00327
00328
00329
00330
00331 void GoUctCommands::CmdMaxMemory(GtpCommand& cmd)
00332 {
00333 cmd.CheckNuArgLessEqual(1);
00334 if (cmd.NuArg() == 0)
00335 cmd << Search().MaxNodes() * 2 * sizeof(SgUctNode);
00336 else
00337 {
00338 std::size_t memory = cmd.SizeTypeArg(0, 2*sizeof(SgUctNode));
00339 Search().SetMaxNodes(memory / 2 / sizeof(SgUctNode));
00340 }
00341 }
00342
00343
00344
00345
00346
00347
00348 void GoUctCommands::CmdMoves(GtpCommand& cmd)
00349 {
00350 cmd.CheckArgNone();
00351 vector<SgMoveInfo> moves;
00352 Search().GenerateAllMoves(moves);
00353 for (std::size_t i = 0; i < moves.size(); ++i)
00354 cmd << SgWritePoint(moves[i].m_move) << ' ';
00355 cmd << '\n';
00356 }
00357
00358
00359
00360
00361
00362
00363
00364
00365
00366
00367
00368
00369
00370
00371 void GoUctCommands::CmdParamGlobalSearch(GtpCommand& cmd)
00372 {
00373 cmd.CheckNuArgLessEqual(2);
00374 GoUctGlobalSearch<GoUctPlayoutPolicy<GoUctBoard>,
00375 GoUctPlayoutPolicyFactory<GoUctBoard> >&
00376 s = GlobalSearch();
00377 GoUctGlobalSearchStateParam& p = s.m_param;
00378 if (cmd.NuArg() == 0)
00379 {
00380
00381
00382 cmd << "[bool] live_gfx " << s.GlobalSearchLiveGfx() << '\n'
00383 << "[bool] mercy_rule " << p.m_mercyRule << '\n'
00384 << "[bool] territory_statistics " << p.m_territoryStatistics
00385 << '\n'
00386 << "[string] length_modification " << p.m_lengthModification
00387 << '\n'
00388 << "[string] score_modification " << p.m_scoreModification
00389 << '\n';
00390 }
00391 else if (cmd.NuArg() == 2)
00392 {
00393 string name = cmd.Arg(0);
00394 if (name == "live_gfx")
00395 s.SetGlobalSearchLiveGfx(cmd.BoolArg(1));
00396 else if (name == "mercy_rule")
00397 p.m_mercyRule = cmd.BoolArg(1);
00398 else if (name == "territory_statistics")
00399 p.m_territoryStatistics = cmd.BoolArg(1);
00400 else if (name == "length_modification")
00401 p.m_lengthModification = cmd.FloatArg(1);
00402 else if (name == "score_modification")
00403 p.m_scoreModification = cmd.FloatArg(1);
00404 else
00405 throw GtpFailure() << "unknown parameter: " << name;
00406 }
00407 else
00408 throw GtpFailure() << "need 0 or 2 arguments";
00409 }
00410
00411
00412
00413
00414
00415
00416
00417
00418
00419
00420
00421
00422
00423
00424
00425
00426 void GoUctCommands::CmdParamPlayer(GtpCommand& cmd)
00427 {
00428 cmd.CheckNuArgLessEqual(2);
00429 GoUctPlayerType& p = Player();
00430 if (cmd.NuArg() == 0)
00431 {
00432
00433
00434 cmd << "[bool] auto_param " << p.AutoParam() << '\n'
00435 << "[bool] early_pass " << p.EarlyPass() << '\n'
00436 << "[bool] forced_opening_moves " << p.ForcedOpeningMoves() << '\n'
00437 << "[bool] ignore_clock " << p.IgnoreClock() << '\n'
00438 << "[bool] ponder " << p.EnablePonder() << '\n'
00439 << "[bool] reuse_subtree " << p.ReuseSubtree() << '\n'
00440 << "[bool] use_root_filter " << p.UseRootFilter() << '\n'
00441 << "[string] max_games " << p.MaxGames() << '\n'
00442 << "[string] resign_min_games " << p.ResignMinGames() << '\n'
00443 << "[string] resign_threshold " << p.ResignThreshold() << '\n'
00444 << "[list/playout_policy/uct/one_ply] search_mode "
00445 << SearchModeToString(p.SearchMode()) << '\n';
00446 }
00447 else if (cmd.NuArg() >= 1 && cmd.NuArg() <= 2)
00448 {
00449 string name = cmd.Arg(0);
00450 if (name == "auto_param")
00451 p.SetAutoParam(cmd.BoolArg(1));
00452 else if (name == "early_pass")
00453 p.SetEarlyPass(cmd.BoolArg(1));
00454 else if (name == "forced_opening_moves")
00455 p.SetForcedOpeningMoves(cmd.BoolArg(1));
00456 else if (name == "ignore_clock")
00457 p.SetIgnoreClock(cmd.BoolArg(1));
00458 else if (name == "ponder")
00459 p.SetEnablePonder(cmd.BoolArg(1));
00460 else if (name == "reuse_subtree")
00461 p.SetReuseSubtree(cmd.BoolArg(1));
00462 else if (name == "use_root_filter")
00463 p.SetUseRootFilter(cmd.BoolArg(1));
00464 else if (name == "max_games")
00465 p.SetMaxGames(cmd.SizeTypeArg(1, 1));
00466 else if (name == "resign_min_games")
00467 p.SetResignMinGames(cmd.SizeTypeArg(1));
00468 else if (name == "resign_threshold")
00469 p.SetResignThreshold(cmd.FloatArg(1));
00470 else if (name == "search_mode")
00471 p.SetSearchMode(SearchModeArg(cmd, 1));
00472 else
00473 throw GtpFailure() << "unknown parameter: " << name;
00474 }
00475 else
00476 throw GtpFailure() << "need 0 or 2 arguments";
00477 }
00478
00479
00480
00481
00482
00483
00484
00485
00486
00487
00488
00489
00490 void GoUctCommands::CmdParamPolicy(GtpCommand& cmd)
00491 {
00492 cmd.CheckNuArgLessEqual(2);
00493 GoUctPlayoutPolicyParam& p = Player().m_playoutPolicyParam;
00494 if (cmd.NuArg() == 0)
00495 {
00496
00497
00498 cmd << "[bool] nakade_heuristic " << p.m_useNakadeHeuristic << '\n'
00499 << "[bool] statistics_enabled " << p.m_statisticsEnabled << '\n'
00500 << "fillboard_tries " << p.m_fillboardTries << '\n';
00501 }
00502 else if (cmd.NuArg() == 2)
00503 {
00504 string name = cmd.Arg(0);
00505 if (name == "nakade_heuristic")
00506 p.m_useNakadeHeuristic = cmd.BoolArg(1);
00507 else if (name == "statistics_enabled")
00508 p.m_statisticsEnabled = cmd.BoolArg(1);
00509 else if (name == "fillboard_tries")
00510 p.m_fillboardTries = cmd.IntArg(1);
00511 else
00512 throw GtpFailure() << "unknown parameter: " << name;
00513 }
00514 else
00515 throw GtpFailure() << "need 0 or 2 arguments";
00516 }
00517
00518
00519
00520
00521
00522
00523
00524 void GoUctCommands::CmdParamRootFilter(GtpCommand& cmd)
00525 {
00526 cmd.CheckNuArgLessEqual(2);
00527 GoUctDefaultRootFilter* f =
00528 dynamic_cast<GoUctDefaultRootFilter*>(&Player().RootFilter());
00529 if (f == 0)
00530 throw GtpFailure("root filter is not GoUctDefaultRootFilter");
00531 if (cmd.NuArg() == 0)
00532 {
00533
00534
00535 cmd << "[bool] check_ladders " << f->CheckLadders() << '\n';
00536 }
00537 else if (cmd.NuArg() == 2)
00538 {
00539 string name = cmd.Arg(0);
00540 if (name == "check_ladders")
00541 f->SetCheckLadders(cmd.BoolArg(1));
00542 else
00543 throw GtpFailure() << "unknown parameter: " << name;
00544 }
00545 else
00546 throw GtpFailure() << "need 0 or 2 arguments";
00547 }
00548
00549
00550
00551
00552
00553
00554
00555
00556
00557
00558
00559
00560
00561
00562
00563
00564
00565
00566
00567
00568
00569
00570
00571
00572
00573 void GoUctCommands::CmdParamSearch(GtpCommand& cmd)
00574 {
00575 cmd.CheckNuArgLessEqual(2);
00576 GoUctSearch& s = Search();
00577 if (cmd.NuArg() == 0)
00578 {
00579
00580
00581 cmd << "[bool] keep_games " << s.KeepGames() << '\n'
00582 << "[bool] lock_free " << s.LockFree() << '\n'
00583 << "[bool] log_games " << s.LogGames() << '\n'
00584 << "[bool] prune_full_tree " << s.PruneFullTree() << '\n'
00585 << "[bool] rave " << s.Rave() << '\n'
00586 << "[bool] virtual_loss " << s.VirtualLoss() << '\n'
00587 << "[bool] weight_rave_updates " << s.WeightRaveUpdates() << '\n'
00588 << "[string] bias_term_constant " << s.BiasTermConstant() << '\n'
00589 << "[string] expand_threshold " << s.ExpandThreshold() << '\n'
00590 << "[string] first_play_urgency " << s.FirstPlayUrgency() << '\n'
00591 << "[string] knowledge_threshold "
00592 << KnowledgeThresholdToString(s.KnowledgeThreshold()) << '\n'
00593 << "[list/none/counts/sequence] live_gfx "
00594 << LiveGfxToString(s.LiveGfx()) << '\n'
00595 << "[string] live_gfx_interval " << s.LiveGfxInterval() << '\n'
00596 << "[string] max_nodes " << s.MaxNodes() << '\n'
00597 << "[list/value/count/bound/estimate] move_select "
00598 << MoveSelectToString(s.MoveSelect()) << '\n'
00599 << "[string] number_threads " << s.NumberThreads() << '\n'
00600 << "[string] number_playouts " << s.NumberPlayouts() << '\n'
00601 << "[string] prune_min_count " << s.PruneMinCount() << '\n'
00602 << "[string] randomize_rave_frequency "
00603 << s.RandomizeRaveFrequency() << '\n'
00604 << "[string] rave_weight_final " << s.RaveWeightFinal() << '\n'
00605 << "[string] rave_weight_initial "
00606 << s.RaveWeightInitial() << '\n';
00607
00608 }
00609 else if (cmd.NuArg() == 2)
00610 {
00611 string name = cmd.Arg(0);
00612 if (name == "keep_games")
00613 s.SetKeepGames(cmd.BoolArg(1));
00614 else if (name == "knowledge_threshold")
00615 s.SetKnowledgeThreshold(KnowledgeThresholdFromString(cmd.Arg(1)));
00616 else if (name == "lock_free")
00617 s.SetLockFree(cmd.BoolArg(1));
00618 else if (name == "log_games")
00619 s.SetLogGames(cmd.BoolArg(1));
00620 else if (name == "prune_full_tree")
00621 s.SetPruneFullTree(cmd.BoolArg(1));
00622 else if (name == "randomize_rave_frequency")
00623 s.SetRandomizeRaveFrequency(cmd.IntArg(1, 0));
00624 else if (name == "rave")
00625 s.SetRave(cmd.BoolArg(1));
00626 else if (name == "weight_rave_updates")
00627 s.SetWeightRaveUpdates(cmd.BoolArg(1));
00628 else if (name == "virtual_loss")
00629 s.SetVirtualLoss(cmd.BoolArg(1));
00630 else if (name == "bias_term_constant")
00631 s.SetBiasTermConstant(cmd.FloatArg(1));
00632 else if (name == "expand_threshold")
00633 s.SetExpandThreshold(cmd.SizeTypeArg(1, 1));
00634 else if (name == "first_play_urgency")
00635 s.SetFirstPlayUrgency(cmd.FloatArg(1));
00636 else if (name == "live_gfx")
00637 s.SetLiveGfx(LiveGfxArg(cmd, 1));
00638 else if (name == "live_gfx_interval")
00639 s.SetLiveGfxInterval(cmd.IntArg(1, 1));
00640 else if (name == "max_nodes")
00641 s.SetMaxNodes(cmd.SizeTypeArg(1, 1));
00642 else if (name == "move_select")
00643 s.SetMoveSelect(MoveSelectArg(cmd, 1));
00644 else if (name == "number_threads")
00645 s.SetNumberThreads(cmd.SizeTypeArg(1, 1));
00646 else if (name == "number_playouts")
00647 s.SetNumberPlayouts(cmd.IntArg(1, 1));
00648 else if (name == "prune_min_count")
00649 s.SetPruneMinCount(cmd.SizeTypeArg(1, 1));
00650 else if (name == "rave_weight_final")
00651 s.SetRaveWeightFinal(cmd.FloatArg(1));
00652 else if (name == "rave_weight_initial")
00653 s.SetRaveWeightInitial(cmd.FloatArg(1));
00654 else
00655 throw GtpFailure() << "unknown parameter: " << name;
00656 }
00657 else
00658 throw GtpFailure() << "need 0 or 2 arguments";
00659 }
00660
00661
00662
00663
00664
00665 void GoUctCommands::CmdPatterns(GtpCommand& cmd)
00666 {
00667 cmd.CheckArgNone();
00668 GoUctPatterns<GoBoard> patterns(m_bd);
00669 for (GoBoard::Iterator it(m_bd); it; ++it)
00670 if (m_bd.IsEmpty(*it) && patterns.MatchAny(*it))
00671 cmd << SgWritePoint(*it) << ' ';
00672 }
00673
00674
00675
00676
00677
00678
00679 void GoUctCommands::CmdPolicyMoves(GtpCommand& cmd)
00680 {
00681 cmd.CheckArgNone();
00682 GoUctPlayoutPolicy<GoBoard> policy(m_bd, Player().m_playoutPolicyParam);
00683 policy.StartPlayout();
00684 policy.GenerateMove();
00685 cmd << GoUctPlayoutPolicyTypeStr(policy.MoveType());
00686 GoPointList moves = policy.GetEquivalentBestMoves();
00687
00688
00689
00690
00691 moves.Sort();
00692 for (int i = 0; i < moves.Length(); ++i)
00693 cmd << ' ' << SgWritePoint(moves[i]);
00694 }
00695
00696
00697
00698
00699
00700
00701 void GoUctCommands::CmdPriorKnowledge(GtpCommand& cmd)
00702 {
00703 cmd.CheckNuArgLessEqual(1);
00704 size_t count = 0;
00705 if (cmd.NuArg() == 1)
00706 count = cmd.SizeTypeArg(0, 0);
00707 GoUctGlobalSearchState<GoUctPlayoutPolicy<GoUctBoard> >& state
00708 = ThreadState(0);
00709 state.StartSearch();
00710 vector<SgMoveInfo> moves;
00711 SgProvenNodeType provenType;
00712 state.GenerateAllMoves(count, moves, provenType);
00713
00714 cmd << "INFLUENCE ";
00715 for (size_t i = 0; i < moves.size(); ++i)
00716 {
00717 SgMove move = moves[i].m_move;
00718 float value = SgUctSearch::InverseEval(moves[i].m_value);
00719
00720 size_t count = moves[i].m_count;
00721 if (count > 0)
00722 {
00723 float scaledValue = (value * 2 - 1);
00724 if (m_bd.ToPlay() != SG_BLACK)
00725 scaledValue *= -1;
00726 cmd << ' ' << SgWritePoint(move) << ' ' << scaledValue;
00727 }
00728 }
00729 cmd << "\nLABEL ";
00730 for (size_t i = 0; i < moves.size(); ++i)
00731 {
00732 SgMove move = moves[i].m_move;
00733 size_t count = moves[i].m_count;
00734 if (count > 0)
00735 cmd << ' ' << SgWritePoint(move) << ' ' << count;
00736 }
00737 cmd << '\n';
00738 }
00739
00740
00741
00742
00743
00744
00745 void GoUctCommands::CmdRaveValues(GtpCommand& cmd)
00746 {
00747 cmd.CheckArgNone();
00748 const GoUctSearch& search = Search();
00749 if (! search.Rave())
00750 throw GtpFailure("RAVE not enabled");
00751 SgPointArray<string> array("\"\"");
00752 const SgUctTree& tree = search.Tree();
00753 for (SgUctChildIterator it(tree, tree.Root()); it; ++it)
00754 {
00755 const SgUctNode& child = *it;
00756 SgPoint p = child.Move();
00757 if (p == SG_PASS || ! child.HasRaveValue())
00758 continue;
00759 ostringstream out;
00760 out << fixed << setprecision(2) << child.RaveValue();
00761 array[p] = out.str();
00762 }
00763 cmd << '\n'
00764 << SgWritePointArray<string>(array, m_bd.Size());
00765 }
00766
00767
00768
00769
00770 void GoUctCommands::CmdRootFilter(GtpCommand& cmd)
00771 {
00772 cmd.CheckArgNone();
00773 cmd << SgWritePointList(Player().RootFilter().Get(), "", false);
00774 }
00775
00776
00777
00778
00779
00780
00781
00782
00783 void GoUctCommands::CmdSaveTree(GtpCommand& cmd)
00784 {
00785 if (Search().MpiSynchronizer()->IsRootProcess())
00786 {
00787 cmd.CheckNuArgLessEqual(2);
00788 string fileName = cmd.Arg(0);
00789 int maxDepth = -1;
00790 if (cmd.NuArg() == 2)
00791 maxDepth = cmd.IntArg(1, 0);
00792 ofstream out(fileName.c_str());
00793 if (! out)
00794 throw GtpFailure() << "Could not open " << fileName;
00795 Search().SaveTree(out, maxDepth);
00796 }
00797 }
00798
00799
00800
00801
00802
00803 void GoUctCommands::CmdSaveGames(GtpCommand& cmd)
00804 {
00805 cmd.CheckNuArg(1);
00806 string fileName = cmd.Arg(0);
00807 try
00808 {
00809 Search().SaveGames(fileName);
00810 }
00811 catch (const SgException& e)
00812 {
00813 throw GtpFailure(e.what());
00814 }
00815 }
00816
00817
00818
00819
00820
00821
00822 void GoUctCommands::CmdScore(GtpCommand& cmd)
00823 {
00824 cmd.CheckArgNone();
00825 try
00826 {
00827 float komi = m_bd.Rules().Komi().ToFloat();
00828 cmd << GoBoardUtil::ScoreSimpleEndPosition(m_bd, komi);
00829 }
00830 catch (const SgException& e)
00831 {
00832 throw GtpFailure(e.what());
00833 }
00834 }
00835
00836
00837
00838
00839
00840
00841
00842
00843
00844 void GoUctCommands::CmdSequence(GtpCommand& cmd)
00845 {
00846 cmd.CheckArgNone();
00847 GoUctUtil::GfxSequence(Search(), Search().ToPlay(), cmd);
00848 }
00849
00850
00851
00852
00853
00854 void GoUctCommands::CmdStatPlayer(GtpCommand& cmd)
00855 {
00856 cmd.CheckArgNone();
00857 Player().GetStatistics().Write(cmd);
00858 }
00859
00860
00861
00862
00863
00864 void GoUctCommands::CmdStatPlayerClear(GtpCommand& cmd)
00865 {
00866 cmd.CheckArgNone();
00867 Player().ClearStatistics();
00868 }
00869
00870
00871
00872
00873
00874
00875
00876
00877 void GoUctCommands::CmdStatPolicy(GtpCommand& cmd)
00878 {
00879 cmd.CheckArgNone();
00880 if (! Player().m_playoutPolicyParam.m_statisticsEnabled)
00881 SgWarning() << "statistics not enabled in policy parameters\n";
00882 Policy(0).Statistics().Write(cmd);
00883 }
00884
00885
00886
00887
00888
00889
00890 void GoUctCommands::CmdStatPolicyClear(GtpCommand& cmd)
00891 {
00892 cmd.CheckArgNone();
00893 Policy(0).ClearStatistics();
00894 }
00895
00896
00897
00898
00899
00900 void GoUctCommands::CmdStatSearch(GtpCommand& cmd)
00901 {
00902 cmd.CheckArgNone();
00903 const GoUctSearch& search = Search();
00904 SgUctTreeStatistics treeStatistics;
00905 treeStatistics.Compute(search.Tree());
00906 cmd << "SearchStatistics:\n";
00907 search.WriteStatistics(cmd);
00908 cmd << "TreeStatistics:\n"
00909 << treeStatistics;
00910 }
00911
00912
00913
00914
00915
00916
00917
00918
00919 void GoUctCommands::CmdStatTerritory(GtpCommand& cmd)
00920 {
00921 cmd.CheckArgNone();
00922 SgPointArray<SgUctStatistics> territoryStatistics
00923 = ThreadState(0).m_territoryStatistics;
00924 SgPointArray<float> array;
00925 for (GoBoard::Iterator it(m_bd); it; ++it)
00926 {
00927 if (territoryStatistics[*it].Count() == 0)
00928 throw GtpFailure("no statistics available");
00929 array[*it] = territoryStatistics[*it].Mean() * 2 - 1;
00930 }
00931 cmd << '\n'
00932 << SgWritePointArrayFloat<float>(array, m_bd.Size(), true, 3);
00933 }
00934
00935
00936
00937
00938 void GoUctCommands::CmdValue(GtpCommand& cmd)
00939 {
00940 cmd.CheckArgNone();
00941 cmd << Search().Tree().Root().Mean();
00942 }
00943
00944
00945
00946
00947 void GoUctCommands::CmdValueBlack(GtpCommand& cmd)
00948 {
00949 cmd.CheckArgNone();
00950 float value = Search().Tree().Root().Mean();
00951 if (Search().ToPlay() == SG_WHITE)
00952 value = SgUctSearch::InverseEval(value);
00953 cmd << value;
00954 }
00955
00956
00957
00958
00959
00960 SgPointSet GoUctCommands::DoFinalStatusSearch()
00961 {
00962 SgPointSet deadStones;
00963 if (GoBoardUtil::TwoPasses(m_bd) && m_bd.Rules().CaptureDead())
00964
00965 return deadStones;
00966
00967 const size_t MAX_GAMES = 10000;
00968 SgDebug() << "GoUctCommands::DoFinalStatusSearch: doing a search with "
00969 << MAX_GAMES << " games to determine final status\n";
00970 GoUctGlobalSearch<GoUctPlayoutPolicy<GoUctBoard>,
00971 GoUctPlayoutPolicyFactory<GoUctBoard> >&
00972 search = GlobalSearch();
00973 SgRestorer<bool> restorer(&search.m_param.m_territoryStatistics);
00974 search.m_param.m_territoryStatistics = true;
00975
00976
00977 int nuUndoPass = 0;
00978 SgBlackWhite toPlay = m_bd.ToPlay();
00979 while (m_bd.GetLastMove() == SG_PASS)
00980 {
00981 m_bd.Undo();
00982 toPlay = SgOppBW(toPlay);
00983 ++nuUndoPass;
00984 }
00985 m_player->UpdateSubscriber();
00986 if (nuUndoPass > 0)
00987 SgDebug() << "Undoing " << nuUndoPass << " passes\n";
00988 vector<SgMove> sequence;
00989 search.Search(MAX_GAMES, numeric_limits<double>::max(), sequence);
00990 SgDebug() << SgWriteLabel("Sequence")
00991 << SgWritePointList(sequence, "", false);
00992 for (int i = 0; i < nuUndoPass; ++i)
00993 {
00994 m_bd.Play(SG_PASS, toPlay);
00995 toPlay = SgOppBW(toPlay);
00996 }
00997 m_player->UpdateSubscriber();
00998
00999 SgPointArray<SgUctStatistics> territoryStatistics =
01000 ThreadState(0).m_territoryStatistics;
01001 GoSafetySolver safetySolver(m_bd);
01002 SgBWSet safe;
01003 safetySolver.FindSafePoints(&safe);
01004 for (GoBlockIterator it(m_bd); it; ++it)
01005 {
01006 SgBlackWhite c = m_bd.GetStone(*it);
01007 bool isDead = safe[SgOppBW(c)].Contains(*it);
01008 if (! isDead && ! safe[c].Contains(*it))
01009 {
01010 SgStatistics<float,int> averageStatus;
01011 for (GoBoard::StoneIterator it2(m_bd, *it); it2; ++it2)
01012 {
01013 if (territoryStatistics[*it2].Count() == 0)
01014
01015
01016 return deadStones;
01017 averageStatus.Add(territoryStatistics[*it2].Mean());
01018 }
01019 const float threshold = 0.3;
01020 isDead =
01021 ((c == SG_BLACK && averageStatus.Mean() < threshold)
01022 || (c == SG_WHITE && averageStatus.Mean() > 1 - threshold));
01023 }
01024 if (isDead)
01025 for (GoBoard::StoneIterator it2(m_bd, *it); it2; ++it2)
01026 deadStones.Include(*it2);
01027 }
01028 return deadStones;
01029 }
01030
01031 GoUctGlobalSearch<GoUctPlayoutPolicy<GoUctBoard>,
01032 GoUctPlayoutPolicyFactory<GoUctBoard> >&
01033 GoUctCommands::GlobalSearch()
01034 {
01035 return Player().GlobalSearch();
01036 }
01037
01038 GoUctPlayerType& GoUctCommands::Player()
01039 {
01040 if (m_player == 0)
01041 throw GtpFailure("player not GoUctPlayer");
01042 try
01043 {
01044 return dynamic_cast<GoUctPlayerType&>(*m_player);
01045 }
01046 catch (const bad_cast&)
01047 {
01048 throw GtpFailure("player not GoUctPlayer");
01049 }
01050 }
01051
01052 GoUctPlayoutPolicy<GoUctBoard>&
01053 GoUctCommands::Policy(std::size_t threadId)
01054 {
01055 GoUctPlayoutPolicy<GoUctBoard>* policy =
01056 dynamic_cast<GoUctPlayoutPolicy<GoUctBoard>*>(
01057 ThreadState(threadId).Policy());
01058 if (policy == 0)
01059 throw GtpFailure("player has no GoUctPlayoutPolicy");
01060 return *policy;
01061 }
01062
01063 void GoUctCommands::Register(GtpEngine& e)
01064 {
01065 Register(e, "final_score", &GoUctCommands::CmdFinalScore);
01066 Register(e, "final_status_list", &GoUctCommands::CmdFinalStatusList);
01067 Register(e, "uct_bounds", &GoUctCommands::CmdBounds);
01068 Register(e, "uct_default_policy", &GoUctCommands::CmdDefaultPolicy);
01069 Register(e, "uct_estimator_stat", &GoUctCommands::CmdEstimatorStat);
01070 Register(e, "uct_gfx", &GoUctCommands::CmdGfx);
01071 Register(e, "uct_max_memory", &GoUctCommands::CmdMaxMemory);
01072 Register(e, "uct_moves", &GoUctCommands::CmdMoves);
01073 Register(e, "uct_param_globalsearch",
01074 &GoUctCommands::CmdParamGlobalSearch);
01075 Register(e, "uct_param_policy", &GoUctCommands::CmdParamPolicy);
01076 Register(e, "uct_param_player", &GoUctCommands::CmdParamPlayer);
01077 Register(e, "uct_param_rootfilter", &GoUctCommands::CmdParamRootFilter);
01078 Register(e, "uct_param_search", &GoUctCommands::CmdParamSearch);
01079 Register(e, "uct_patterns", &GoUctCommands::CmdPatterns);
01080 Register(e, "uct_policy_moves", &GoUctCommands::CmdPolicyMoves);
01081 Register(e, "uct_prior_knowledge", &GoUctCommands::CmdPriorKnowledge);
01082 Register(e, "uct_rave_values", &GoUctCommands::CmdRaveValues);
01083 Register(e, "uct_root_filter", &GoUctCommands::CmdRootFilter);
01084 Register(e, "uct_savegames", &GoUctCommands::CmdSaveGames);
01085 Register(e, "uct_savetree", &GoUctCommands::CmdSaveTree);
01086 Register(e, "uct_sequence", &GoUctCommands::CmdSequence);
01087 Register(e, "uct_score", &GoUctCommands::CmdScore);
01088 Register(e, "uct_stat_player", &GoUctCommands::CmdStatPlayer);
01089 Register(e, "uct_stat_player_clear", &GoUctCommands::CmdStatPlayerClear);
01090 Register(e, "uct_stat_policy", &GoUctCommands::CmdStatPolicy);
01091 Register(e, "uct_stat_policy_clear", &GoUctCommands::CmdStatPolicyClear);
01092 Register(e, "uct_stat_search", &GoUctCommands::CmdStatSearch);
01093 Register(e, "uct_stat_territory", &GoUctCommands::CmdStatTerritory);
01094 Register(e, "uct_value", &GoUctCommands::CmdValue);
01095 Register(e, "uct_value_black", &GoUctCommands::CmdValueBlack);
01096 }
01097
01098 void GoUctCommands::Register(GtpEngine& engine, const std::string& command,
01099 GtpCallback<GoUctCommands>::Method method)
01100 {
01101 engine.Register(command, new GtpCallback<GoUctCommands>(this, method));
01102 }
01103
01104 GoUctSearch& GoUctCommands::Search()
01105 {
01106 try
01107 {
01108 GoUctObjectWithSearch& object =
01109 dynamic_cast<GoUctObjectWithSearch&>(*m_player);
01110 return object.Search();
01111 }
01112 catch (const bad_cast&)
01113 {
01114 throw GtpFailure("player is not a GoUctObjectWithSearch");
01115 }
01116 }
01117
01118
01119
01120
01121
01122 GoUctGlobalSearchState<GoUctPlayoutPolicy<GoUctBoard> >&
01123 GoUctCommands::ThreadState(std::size_t threadId)
01124 {
01125 GoUctSearch& search = Search();
01126 if (! search.ThreadsCreated())
01127 search.CreateThreads();
01128 try
01129 {
01130 return dynamic_cast<
01131 GoUctGlobalSearchState<GoUctPlayoutPolicy<GoUctBoard> >&>(
01132 search.ThreadState(threadId));
01133 }
01134 catch (const bad_cast&)
01135 {
01136 throw GtpFailure("player has no GoUctGlobalSearchState");
01137 }
01138 }
01139
01140