diff --git a/data/gogame.cfg b/data/gogame.cfg
index 26cb933..abcbf5c 100644
--- a/data/gogame.cfg
+++ b/data/gogame.cfg
@@ -1,4 +1,4 @@
-PlayerOne=ROOT_PAR
+PlayerOne=SMAF
 PlayerTwo=RANDOM
 GUIDelay=1000 //1 second
 BoardSize=9
@@ -7,4 +7,4 @@ NumGames=1 //Games for each color per player
 TurnTime=2000 //seconds per player per turn
 SpectatorBoardShown=true
 WhiteMoveLogged=false
-BlackMoveLogged=false
\ No newline at end of file
+BlackMoveLogged=true
\ No newline at end of file
diff --git a/src/net/woodyfolsom/msproj/Referee.java b/src/net/woodyfolsom/msproj/Referee.java
index 6299d53..5aacde3 100644
--- a/src/net/woodyfolsom/msproj/Referee.java
+++ b/src/net/woodyfolsom/msproj/Referee.java
@@ -91,8 +91,7 @@ public class Referee {
 		while (!gameRecord.isFinished()) {
 			GameState gameState = gameRecord.getGameState(gameRecord
 					.getNumTurns());
-			// System.out.println(gameState);
-
+
 			Player playerToMove = gameRecord.getPlayerToMove();
 			Policy policy = getPolicy(playerToMove);
 			Action action = policy.getAction(gameConfig, gameState,
@@ -108,6 +107,11 @@ public class Referee {
 			} else {
 				System.out.println("Move rejected - try again.");
 			}
+
+			if (policy.isLogging()) {
+				System.out.println(gameState);
+			}
+
 		}
 		} catch (Exception ex) {
 			System.out
diff --git a/src/net/woodyfolsom/msproj/StandAloneGame.java b/src/net/woodyfolsom/msproj/StandAloneGame.java
index 66ff681..cba3eb2 100644
--- a/src/net/woodyfolsom/msproj/StandAloneGame.java
+++ b/src/net/woodyfolsom/msproj/StandAloneGame.java
@@ -13,6 +13,7 @@ import net.woodyfolsom.msproj.gui.Goban;
 import net.woodyfolsom.msproj.policy.HumanGuiInput;
 import net.woodyfolsom.msproj.policy.HumanKeyboardInput;
 import net.woodyfolsom.msproj.policy.MonteCarloAMAF;
+import net.woodyfolsom.msproj.policy.MonteCarloSMAF;
 import net.woodyfolsom.msproj.policy.MonteCarloUCT;
 import net.woodyfolsom.msproj.policy.Policy;
 import net.woodyfolsom.msproj.policy.RandomMovePolicy;
@@ -22,11 +23,11 @@ public class StandAloneGame {
 	public static final int EXIT_USER_QUIT = 1;
 	public static final int EXIT_NOMINAL = 0;
 	public static final int EXIT_IO_EXCEPTION = -1;
-	
+
 	private int gameNo = 0;
-	
+
 	enum PLAYER_TYPE {
-		HUMAN, HUMAN_GUI, ROOT_PAR, UCT, RANDOM, RAVE
+		HUMAN, HUMAN_GUI, ROOT_PAR, UCT, RANDOM, RAVE, SMAF
 	};
 
 	public static void main(String[] args) throws IOException {
@@ -41,7 +42,8 @@ public class StandAloneGame {
 					gameSettings.getBoardSize(), gameSettings.getKomi(),
 					gameSettings.getNumGames(), gameSettings.getTurnTime(),
 					gameSettings.isSpectatorBoardShown(),
-					gameSettings.isBlackMoveLogged(), gameSettings.isWhiteMoveLogged());
+					gameSettings.isBlackMoveLogged(),
+					gameSettings.isWhiteMoveLogged());
 			System.out.println("Press or CTRL-C to exit");
 			System.in.read(new byte[80]);
 		} catch (IOException ioe) {
@@ -64,14 +66,17 @@ public class StandAloneGame {
 			return PLAYER_TYPE.RANDOM;
 		} else if ("RAVE".equalsIgnoreCase(playerTypeStr)) {
 			return PLAYER_TYPE.RAVE;
+		} else if ("SMAF".equalsIgnoreCase(playerTypeStr)) {
+			return PLAYER_TYPE.SMAF;
 		} else {
 			throw new RuntimeException("Unknown player type: " + playerTypeStr);
 		}
 	}
 
 	public void playGame(PLAYER_TYPE playerType1, PLAYER_TYPE playerType2,
-			int size, double komi, int rounds, long turnLength, boolean showSpectatorBoard,
-			boolean blackMoveLogged, boolean whiteMoveLogged) {
+			int size, double komi, int rounds, long turnLength,
+			boolean showSpectatorBoard, boolean blackMoveLogged,
+			boolean whiteMoveLogged) {
 
 		long startTime = System.currentTimeMillis();
 
@@ -79,28 +84,38 @@ public class StandAloneGame {
 		gameConfig.setKomi(komi);
 
 		Referee referee = new Referee();
-		referee.setPolicy(Player.BLACK,
-				getPolicy(playerType1, gameConfig, Player.BLACK, turnLength, blackMoveLogged));
-		referee.setPolicy(Player.WHITE,
-				getPolicy(playerType2, gameConfig, Player.WHITE, turnLength, whiteMoveLogged));
+		referee.setPolicy(
+				Player.BLACK,
+				getPolicy(playerType1, gameConfig, Player.BLACK, turnLength,
+						blackMoveLogged));
+		referee.setPolicy(
+				Player.WHITE,
+				getPolicy(playerType2, gameConfig, Player.WHITE, turnLength,
+						whiteMoveLogged));
 
 		List round1results = new ArrayList();
-		
+
 		boolean logGameRecords = rounds <= 50;
 		for (int round = 0; round < rounds; round++) {
 			gameNo++;
-			round1results.add(referee.play(gameConfig, gameNo, showSpectatorBoard, logGameRecords));
+			round1results.add(referee.play(gameConfig, gameNo,
+					showSpectatorBoard, logGameRecords));
 		}
 
 		List round2results = new ArrayList();
-		referee.setPolicy(Player.BLACK,
-				getPolicy(playerType2, gameConfig, Player.BLACK, turnLength, blackMoveLogged));
-		referee.setPolicy(Player.WHITE,
-				getPolicy(playerType1, gameConfig, Player.WHITE, turnLength, whiteMoveLogged));
+		referee.setPolicy(
+				Player.BLACK,
+				getPolicy(playerType2, gameConfig, Player.BLACK, turnLength,
+						blackMoveLogged));
+		referee.setPolicy(
+				Player.WHITE,
+				getPolicy(playerType1, gameConfig, Player.WHITE, turnLength,
+						whiteMoveLogged));
 
 		for (int round = 0; round < rounds; round++) {
 			gameNo++;
-			round2results.add(referee.play(gameConfig, gameNo, showSpectatorBoard, logGameRecords));
+			round2results.add(referee.play(gameConfig, gameNo,
+					showSpectatorBoard, logGameRecords));
 		}
 
 		long endTime = System.currentTimeMillis();
@@ -113,13 +128,14 @@ public class StandAloneGame {
 		try {
 			if (!logGameRecords) {
-				System.out.println("Each player is set to play more than 50 rounds as each color; omitting individual game .sgf log file output.");
+				System.out
+						.println("Each player is set to play more than 50 rounds as each color; omitting individual game .sgf log file output.");
 			}
-			
+
 			logResults(writer, round1results, playerType1.toString(),
-					playerType2.toString());
+					playerType2.toString());
 			logResults(writer, round2results, playerType2.toString(),
-					playerType1.toString());
+					playerType1.toString());
 			writer.write("Elapsed Time: " + (endTime - startTime) / 1000.0
 					+ " seconds.");
 			System.out.println("Game tournament saved as "
@@ -157,25 +173,38 @@ public class StandAloneGame {
 	private Policy getPolicy(PLAYER_TYPE playerType, GameConfig gameConfig,
 			Player player, long turnLength, boolean moveLogged) {
+
+		Policy policy;
+
 		switch (playerType) {
 		case HUMAN:
-			return new HumanKeyboardInput();
+			policy = new HumanKeyboardInput();
+			break;
 		case HUMAN_GUI:
-			return new HumanGuiInput(new Goban(gameConfig, player,""));
+			policy = new HumanGuiInput(new Goban(gameConfig, player, ""));
+			break;
 		case ROOT_PAR:
-			return new RootParallelization(4, turnLength);
+			policy = new RootParallelization(4, turnLength);
+			break;
 		case UCT:
-			return new MonteCarloUCT(new RandomMovePolicy(), turnLength);
+			policy = new MonteCarloUCT(new RandomMovePolicy(), turnLength);
+			break;
+		case SMAF:
+			policy = new MonteCarloSMAF(new RandomMovePolicy(), turnLength, 0);
+			break;
 		case RANDOM:
-			RandomMovePolicy randomMovePolicy = new RandomMovePolicy();
-			randomMovePolicy.setLogging(moveLogged);
-			return randomMovePolicy;
+			policy = new RandomMovePolicy();
+			break;
 		case RAVE:
-			return new MonteCarloAMAF(new RandomMovePolicy(), turnLength);
+			policy = new MonteCarloAMAF(new RandomMovePolicy(), turnLength);
+			break;
 		default:
IllegalArgumentException("Invalid PLAYER_TYPE: " + playerType); } + + policy.setLogging(moveLogged); + return policy; } } \ No newline at end of file diff --git a/src/net/woodyfolsom/msproj/ann/TTTFilterTrainer.java b/src/net/woodyfolsom/msproj/ann/TTTFilterTrainer.java index 429833e..7c39344 100644 --- a/src/net/woodyfolsom/msproj/ann/TTTFilterTrainer.java +++ b/src/net/woodyfolsom/msproj/ann/TTTFilterTrainer.java @@ -1,6 +1,9 @@ package net.woodyfolsom.msproj.ann; -import java.io.FileNotFoundException; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; import java.util.ArrayList; import java.util.List; @@ -16,66 +19,99 @@ import net.woodyfolsom.msproj.tictactoe.State; public class TTTFilterTrainer { // implements epsilon-greedy trainer? online // version of NeuralNetFilter - public static void main(String[] args) throws FileNotFoundException { - double alpha = 0.15; - double lambda = .95; - int maxGames = 1000; + private boolean training = true; + + public static void main(String[] args) throws IOException { + double alpha = 0.50; + double lambda = 0.90; + int maxGames = 100000; new TTTFilterTrainer().trainNetwork(alpha, lambda, maxGames); } public void trainNetwork(double alpha, double lambda, int maxGames) - throws FileNotFoundException { + throws IOException { - FeedforwardNetwork neuralNetwork = new MultiLayerPerceptron(true, 9, 6, - 1); - neuralNetwork.setName("TicTacToe"); - neuralNetwork.initWeights(); - TrainingMethod trainer = new TemporalDifference(alpha, lambda); + FeedforwardNetwork neuralNetwork; + if (training) { + neuralNetwork = new MultiLayerPerceptron(true, 9, 9, 1); + neuralNetwork.setName("TicTacToe"); + neuralNetwork.initWeights(); + TrainingMethod trainer = new TemporalDifference(alpha, lambda); - System.out.println("Playing untrained games."); - for (int i = 0; i < 10; i++) { - System.out.println("" + (i + 1) + ". " - + playOptimal(neuralNetwork).getResult()); - } - - System.out.println("Learning from " + maxGames - + " games of random self-play"); - - int gamesPlayed = 0; - List results = new ArrayList(); - do { - GameRecord gameRecord = playEpsilonGreedy(0.90, neuralNetwork, - trainer); - System.out.println("Winner: " + gameRecord.getResult()); - gamesPlayed++; - results.add(gameRecord.getResult()); - } while (gamesPlayed < maxGames); - - System.out.println("Learned network after " + maxGames - + " training games."); - - for (int i = 0; i < results.size(); i++) { - if (i % 10 == 0) { - System.out.println("" + (i + 1) + ". " + results.get(i)); + System.out.println("Playing untrained games."); + for (int i = 0; i < 10; i++) { + System.out.println("" + (i + 1) + ". " + + playOptimal(neuralNetwork).getResult()); } + + System.out.println("Learning from " + maxGames + + " games of random self-play"); + + int gamesPlayed = 0; + List results = new ArrayList(); + do { + GameRecord gameRecord = playEpsilonGreedy(0.50, neuralNetwork, + trainer); + System.out.println("Winner: " + gameRecord.getResult()); + gamesPlayed++; + results.add(gameRecord.getResult()); + } while (gamesPlayed < maxGames); + + System.out.println("Results of every 10th training game:"); + + for (int i = 0; i < results.size(); i++) { + if (i % 10 == 0) { + System.out.println("" + (i + 1) + ". 
" + results.get(i)); + } + } + + System.out.println("Learned network after " + maxGames + + " training games."); + } else { + System.out.println("Loading TicTacToe network from file."); + neuralNetwork = new MultiLayerPerceptron(); + FileInputStream fis = new FileInputStream(new File("ttt.net")); + if (!new MultiLayerPerceptron().load(fis)) { + System.out.println("Error loading ttt.net from file."); + return; + } + fis.close(); } - + evalTestCases(neuralNetwork); System.out.println("Playing optimal games."); + List gameResults = new ArrayList(); for (int i = 0; i < 10; i++) { - System.out.println("" + (i + 1) + ". " - + playOptimal(neuralNetwork).getResult()); + gameResults.add(playOptimal(neuralNetwork).getResult()); } - /* - * File output = new File("ttt.net"); - * - * FileOutputStream fos = new FileOutputStream(output); - * - * neuralNetwork.save(fos); - */ + boolean suboptimalPlay = false; + System.out.println("Optimal game summary: "); + for (int i = 0; i < gameResults.size(); i++) { + RESULT result = gameResults.get(i); + System.out.println("" + (i + 1) + ". " + result); + if (result != RESULT.X_WINS) { + suboptimalPlay = true; + } + } + + File output = new File("ttt.net"); + + FileOutputStream fos = new FileOutputStream(output); + + neuralNetwork.save(fos); + + System.out.println("Playing optimal vs random games."); + for (int i = 0; i < 10; i++) { + System.out.println("" + (i + 1) + ". " + + playOptimalVsRandom(neuralNetwork).getResult()); + } + + if (suboptimalPlay) { + System.out.println("Suboptimal play detected!"); + } } private void evalTestCases(FeedforwardNetwork neuralNetwork) { @@ -114,7 +150,33 @@ public class TTTFilterTrainer { // implements epsilon-greedy trainer? online System.out.println("Output from eval set (learned network):"); testNetwork(neuralNetwork, validationSet, inputNames, outputNames); } - + + private GameRecord playOptimalVsRandom(FeedforwardNetwork neuralNetwork) { + GameRecord gameRecord = new GameRecord(); + + Policy neuralNetPolicy = new NeuralNetPolicy(neuralNetwork); + Policy randomPolicy = new RandomPolicy(); + + State state = gameRecord.getState(); + + Policy[] policies = new Policy[] { neuralNetPolicy, randomPolicy }; + int turnNo = 0; + do { + Action action; + State nextState; + + action = policies[turnNo % 2].getAction(gameRecord.getState()); + + nextState = gameRecord.apply(action); + System.out.println("Action " + action + " selected by policy " + + policies[turnNo % 2].getName()); + System.out.println("Next board state: " + nextState); + state = nextState; + turnNo++; + } while (!state.isTerminal()); + return gameRecord; + } + private GameRecord playOptimal(FeedforwardNetwork neuralNetwork) { GameRecord gameRecord = new GameRecord(); @@ -122,8 +184,6 @@ public class TTTFilterTrainer { // implements epsilon-greedy trainer? online State state = gameRecord.getState(); - System.out.println("Playing optimal game:"); - do { Action action; State nextState; @@ -131,14 +191,12 @@ public class TTTFilterTrainer { // implements epsilon-greedy trainer? 
@@ -131,14 +191,12 @@ public class TTTFilterTrainer { // implements epsilon-greedy trainer? online
 			action = neuralNetPolicy.getAction(gameRecord.getState());
 
 			nextState = gameRecord.apply(action);
-			System.out.println("Action " + action + " selected by policy " +
-					neuralNetPolicy.getName());
+			System.out.println("Action " + action + " selected by policy "
+					+ neuralNetPolicy.getName());
 			System.out.println("Next board state: " + nextState);
 			state = nextState;
 		} while (!state.isTerminal());
 
-		// finally, reinforce the actual reward
-
 		return gameRecord;
 	}
 
@@ -196,7 +254,7 @@ public class TTTFilterTrainer { // implements epsilon-greedy trainer? online
 		for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
 			NNDataPair dp = new NNDataPair(new NNData(inputNames,
 					validationSet[valIndex]), new NNData(outputNames,
-					new double[] {0.0}));
+					new double[] { 0.0 }));
 			System.out.println(dp + " => " + neuralNetwork.compute(dp));
 		}
 	}
diff --git a/src/net/woodyfolsom/msproj/ann/TemporalDifference.java b/src/net/woodyfolsom/msproj/ann/TemporalDifference.java
index cc69ebc..e0587e7 100644
--- a/src/net/woodyfolsom/msproj/ann/TemporalDifference.java
+++ b/src/net/woodyfolsom/msproj/ann/TemporalDifference.java
@@ -4,7 +4,7 @@ import java.util.List;
 
 public class TemporalDifference extends TrainingMethod {
 	private final double alpha;
-	private final double gamma = 1.0;
+	// private final double gamma = 1.0;
 	private final double lambda;
 
 	public TemporalDifference(double alpha, double lambda) {
@@ -81,27 +81,31 @@ public class TemporalDifference extends TrainingMethod {
 		}
 	}
 
-	private void updateWeights(FeedforwardNetwork neuralNetwork, double predictionError) {
+	private void updateWeights(FeedforwardNetwork neuralNetwork,
+			double predictionError) {
 		for (Connection connection : neuralNetwork.getConnections()) {
-			/*Neuron srcNeuron = neuralNetwork.getNeuron(connection.getSrc());
-			Neuron destNeuron = neuralNetwork.getNeuron(connection.getDest());
-			
-			double delta = alpha * srcNeuron.getOutput()
-					* destNeuron.getGradient() * predictionError + connection.getTrace() * lambda;
-			
-			// TODO allow for momentum
-			// double lastDelta = connection.getLastDelta();
-			connection.addDelta(delta);*/
+			/*
+			 * Neuron srcNeuron = neuralNetwork.getNeuron(connection.getSrc());
+			 * Neuron destNeuron =
+			 * neuralNetwork.getNeuron(connection.getDest());
+			 * 
+			 * double delta = alpha * srcNeuron.getOutput() *
+			 * destNeuron.getGradient() * predictionError +
+			 * connection.getTrace() * lambda;
+			 * 
+			 * // TODO allow for momentum // double lastDelta =
+			 * connection.getLastDelta(); connection.addDelta(delta);
+			 */
 			Neuron srcNeuron = neuralNetwork.getNeuron(connection.getSrc());
 			Neuron destNeuron = neuralNetwork.getNeuron(connection.getDest());
 
-			double delta = alpha * srcNeuron.getOutput() 
+			double delta = alpha * srcNeuron.getOutput()
 					* destNeuron.getGradient() + connection.getTrace() * lambda;
 
-			//TODO allow for momentum
-			//double lastDelta = connection.getLastDelta();
-			connection.addDelta(delta);	
+			// TODO allow for momentum
+			// double lastDelta = connection.getLastDelta();
+			connection.addDelta(delta);
 		}
 	}
-	
+
 	@Override
 	public void iterateSequences(FeedforwardNetwork neuralNetwork,
 			List<List<NNDataPair>> trainingSet) {
@@ -117,24 +121,24 @@ public class TemporalDifference extends TrainingMethod {
 	@Override
 	protected void iteratePattern(FeedforwardNetwork neuralNetwork,
 			NNDataPair statePair, NNData nextReward) {
-		//System.out.println("Learningrate: " + alpha);
+		// System.out.println("Learningrate: " + alpha);
 		zeroGradients(neuralNetwork);
-		//System.out.println("Training with: " + statePair.getInput());
System.out.println("Training with: " + statePair.getInput()); NNData ideal = nextReward; NNData actual = neuralNetwork.compute(statePair); - //System.out.println("Updating weights. Ideal Output: " + ideal); - //System.out.println("Actual Output: " + actual); + // System.out.println("Updating weights. Ideal Output: " + ideal); + // System.out.println("Actual Output: " + actual); // backpropagate the gradients w.r.t. output error backPropagate(neuralNetwork, ideal); double predictionError = statePair.getIdeal().getValues()[0] // reward_t - + actual.getValues()[0] - nextReward.getValues()[0]; - + + actual.getValues()[0] - nextReward.getValues()[0]; + updateWeights(neuralNetwork, predictionError); } } \ No newline at end of file diff --git a/src/net/woodyfolsom/msproj/policy/AlphaBeta.java b/src/net/woodyfolsom/msproj/policy/AlphaBeta.java index b79b018..1d7f542 100644 --- a/src/net/woodyfolsom/msproj/policy/AlphaBeta.java +++ b/src/net/woodyfolsom/msproj/policy/AlphaBeta.java @@ -16,6 +16,15 @@ public class AlphaBeta implements Policy { private final ValidMoveGenerator validMoveGenerator = new ValidMoveGenerator(); + private boolean logging = false; + public boolean isLogging() { + return logging; + } + + public void setLogging(boolean logging) { + this.logging = logging; + } + private int lookAhead; private int numStateEvaluations = 0; diff --git a/src/net/woodyfolsom/msproj/policy/HumanGuiInput.java b/src/net/woodyfolsom/msproj/policy/HumanGuiInput.java index 08c159f..043cbca 100644 --- a/src/net/woodyfolsom/msproj/policy/HumanGuiInput.java +++ b/src/net/woodyfolsom/msproj/policy/HumanGuiInput.java @@ -9,6 +9,15 @@ import net.woodyfolsom.msproj.Player; import net.woodyfolsom.msproj.gui.Goban; public class HumanGuiInput implements Policy { + private boolean logging; + public boolean isLogging() { + return logging; + } + + public void setLogging(boolean logging) { + this.logging = logging; + } + private Goban goban; public HumanGuiInput(Goban goban) { diff --git a/src/net/woodyfolsom/msproj/policy/HumanKeyboardInput.java b/src/net/woodyfolsom/msproj/policy/HumanKeyboardInput.java index b6cfe9d..f6bb765 100644 --- a/src/net/woodyfolsom/msproj/policy/HumanKeyboardInput.java +++ b/src/net/woodyfolsom/msproj/policy/HumanKeyboardInput.java @@ -9,6 +9,15 @@ import net.woodyfolsom.msproj.GameState; import net.woodyfolsom.msproj.Player; public class HumanKeyboardInput implements Policy { + private boolean logging = false; + + public boolean isLogging() { + return logging; + } + + public void setLogging(boolean logging) { + this.logging = logging; + } @Override public Action getAction(GameConfig gameConfig, GameState gameState, diff --git a/src/net/woodyfolsom/msproj/policy/Minimax.java b/src/net/woodyfolsom/msproj/policy/Minimax.java index ca3e79a..0e8fd0b 100644 --- a/src/net/woodyfolsom/msproj/policy/Minimax.java +++ b/src/net/woodyfolsom/msproj/policy/Minimax.java @@ -16,6 +16,15 @@ public class Minimax implements Policy { private final ValidMoveGenerator validMoveGenerator = new ValidMoveGenerator(); + private boolean logging = false; + public boolean isLogging() { + return logging; + } + + public void setLogging(boolean logging) { + this.logging = logging; + } + private int lookAhead; private int numStateEvaluations = 0; diff --git a/src/net/woodyfolsom/msproj/policy/MonteCarlo.java b/src/net/woodyfolsom/msproj/policy/MonteCarlo.java index fb4d35a..3183393 100644 --- a/src/net/woodyfolsom/msproj/policy/MonteCarlo.java +++ b/src/net/woodyfolsom/msproj/policy/MonteCarlo.java @@ -15,6 +15,15 @@ 
 
 public abstract class MonteCarlo implements Policy {
 	protected static final int ROLLOUT_DEPTH_LIMIT = 250;
+	private boolean logging = false;
+	public boolean isLogging() {
+		return logging;
+	}
+
+	public void setLogging(boolean logging) {
+		this.logging = logging;
+	}
+
 	protected int numStateEvaluations = 0;
 
 	protected Policy movePolicy;
diff --git a/src/net/woodyfolsom/msproj/policy/MonteCarloAMAF.java b/src/net/woodyfolsom/msproj/policy/MonteCarloAMAF.java
index 3ade6e1..9d2246f 100644
--- a/src/net/woodyfolsom/msproj/policy/MonteCarloAMAF.java
+++ b/src/net/woodyfolsom/msproj/policy/MonteCarloAMAF.java
@@ -63,6 +63,43 @@ public class MonteCarloAMAF extends MonteCarloUCT {
 				rootGameState, new AMAFProperties());
 	}
 
+	@Override
+	public Action getBestAction(GameTreeNode node) {
+		Action bestAction = Action.NONE;
+		double bestScore = Double.NEGATIVE_INFINITY;
+		GameTreeNode bestChild = null;
+
+		for (Action action : node.getActions()) {
+			GameTreeNode childNode = node
+					.getChild(action);
+
+			AMAFProperties childProps = (AMAFProperties)childNode.getProperties();
+			double childScore = childProps.getAmafWins() / (double)childProps.getAmafVisits();
+
+			if (childScore >= bestScore) {
+				bestScore = childScore;
+				bestAction = action;
+				bestChild = childNode;
+			}
+		}
+
+		if (bestAction == Action.NONE) {
+			System.out
+					.println("MonteCarloUCT failed - no actions were found for the current game state (not even PASS).");
+		} else {
+			System.out.println("Action " + bestAction + " selected for "
+					+ node.getGameState().getPlayerToMove()
+					+ " with simulated win ratio of "
+					+ (bestScore * 100.0 + "%"));
+			System.out.println("It was visited "
+					+ bestChild.getProperties().getVisits() + " times out of "
+					+ node.getProperties().getVisits() + " rollouts among "
+					+ node.getNumChildren()
+					+ " valid actions from the current state.");
+		}
+		return bestAction;
+	}
+
 	@Override
 	protected double getNodeScore(GameTreeNode gameTreeNode) {
 		//double nodeVisits = gameTreeNode.getParent().getProperties().getVisits();
@@ -72,16 +109,8 @@ public class MonteCarloAMAF extends MonteCarloUCT {
 		if (gameTreeNode.getGameState().isTerminal()) {
 			nodeScore = 0.0;
 		} else {
-			/*
-			MonteCarloProperties properties = gameTreeNode.getProperties();
-			nodeScore = (double) (properties.getWins() / properties
-					.getVisits())
-					+ (TUNING_CONSTANT * Math.sqrt(Math.log(nodeVisits)
-							/ gameTreeNode.getProperties().getVisits()));
-			 *
-			 */
 			AMAFProperties properties = (AMAFProperties) gameTreeNode.getProperties();
-			nodeScore = (double) (properties.getAmafWins() / properties
+			nodeScore = (properties.getAmafWins() / (double) properties
 					.getAmafVisits())
 					+ (TUNING_CONSTANT * Math.sqrt(Math.log(parentAmafVisits)
 							/ properties.getAmafVisits()));
diff --git a/src/net/woodyfolsom/msproj/policy/MonteCarloSMAF.java b/src/net/woodyfolsom/msproj/policy/MonteCarloSMAF.java
new file mode 100644
index 0000000..dfa8cf9
--- /dev/null
+++ b/src/net/woodyfolsom/msproj/policy/MonteCarloSMAF.java
@@ -0,0 +1,59 @@
+package net.woodyfolsom.msproj.policy;
+
+import java.util.List;
+
+import net.woodyfolsom.msproj.Action;
+import net.woodyfolsom.msproj.Player;
+import net.woodyfolsom.msproj.tree.AMAFProperties;
+import net.woodyfolsom.msproj.tree.GameTreeNode;
+import net.woodyfolsom.msproj.tree.MonteCarloProperties;
+
+public class MonteCarloSMAF extends MonteCarloAMAF {
+	private int horizon;
+
+	public MonteCarloSMAF(Policy movePolicy, long searchTimeLimit, int horizon) {
+		super(movePolicy, searchTimeLimit);
+		this.horizon = horizon;
+	}
+
+	@Override
+	public void update(GameTreeNode node, Rollout rollout) {
+		GameTreeNode currentNode = node;
+		//List<Action> subTreeActions = new ArrayList<Action>(rollout.getPlayout());
+
+		List<Action> playout = rollout.getPlayout();
+		int reward = rollout.getReward();
+		while (currentNode != null) {
+			AMAFProperties nodeProperties = (AMAFProperties)currentNode.getProperties();
+
+			//Always update props for the current node
+			nodeProperties.setWins(nodeProperties.getWins() + reward);
+			nodeProperties.setVisits(nodeProperties.getVisits() + 1);
+			nodeProperties.setAmafWins(nodeProperties.getAmafWins() + reward);
+			nodeProperties.setAmafVisits(nodeProperties.getAmafVisits() + 1);
+
+			GameTreeNode parentNode = currentNode.getParent();
+			if (parentNode != null) {
+				Player playerToMove = parentNode.getGameState().getPlayerToMove();
+				for (Action actionFromParent : parentNode.getActions()) {
+					if (playout.subList(0, Math.max(horizon,playout.size())).contains(actionFromParent)) {
+						GameTreeNode subTreeChild = parentNode.getChild(actionFromParent);
+						//Don't count AMAF properties for the current node twice
+						if (subTreeChild == currentNode) {
+							continue;
+						}
+
+						AMAFProperties siblingProperties = (AMAFProperties)subTreeChild.getProperties();
+						//Only update AMAF properties if the sibling is reached by the same action with the same player to move
+						if (rollout.hasPlay(playerToMove,actionFromParent)) {
+							siblingProperties.setAmafWins(siblingProperties.getAmafWins() + reward);
+							siblingProperties.setAmafVisits(siblingProperties.getAmafVisits() + 1);
+						}
+					}
+				}
+			}
+
+			currentNode = currentNode.getParent();
+		}
+	}
+}
\ No newline at end of file
diff --git a/src/net/woodyfolsom/msproj/policy/MonteCarloUCT.java b/src/net/woodyfolsom/msproj/policy/MonteCarloUCT.java
index d0e6b80..b4a622d 100644
--- a/src/net/woodyfolsom/msproj/policy/MonteCarloUCT.java
+++ b/src/net/woodyfolsom/msproj/policy/MonteCarloUCT.java
@@ -90,11 +90,8 @@ public class MonteCarloUCT extends MonteCarlo {
 			GameTreeNode childNode = node
 					.getChild(action);
 
-			//MonteCarloProperties properties = childNode.getProperties();
-			//double childScore = (double) properties.getWins()
-			//		/ properties.getVisits();
-			
-			double childScore = getNodeScore(childNode);
+			MonteCarloProperties childProps = childNode.getProperties();
+			double childScore = childProps.getWins() / (double)childProps.getVisits();
 
 			if (childScore >= bestScore) {
 				bestScore = childScore;
diff --git a/src/net/woodyfolsom/msproj/policy/Policy.java b/src/net/woodyfolsom/msproj/policy/Policy.java
index 262aeb7..cc04120 100644
--- a/src/net/woodyfolsom/msproj/policy/Policy.java
+++ b/src/net/woodyfolsom/msproj/policy/Policy.java
@@ -17,4 +17,8 @@ public interface Policy {
 	public int getNumStateEvaluations();
 
 	public void setState(GameState gameState);
+
+	boolean isLogging();
+
+	void setLogging(boolean logging);
 }
\ No newline at end of file
diff --git a/src/net/woodyfolsom/msproj/policy/RandomMovePolicy.java b/src/net/woodyfolsom/msproj/policy/RandomMovePolicy.java
index 0677eb0..45209d7 100644
--- a/src/net/woodyfolsom/msproj/policy/RandomMovePolicy.java
+++ b/src/net/woodyfolsom/msproj/policy/RandomMovePolicy.java
@@ -110,6 +110,7 @@ public class RandomMovePolicy implements Policy, ActionGenerator {
 		return randomAction;
 	}
 
+	@Override
 	public boolean isLogging() {
 		return logging;
 	}
diff --git a/src/net/woodyfolsom/msproj/policy/RootParallelization.java b/src/net/woodyfolsom/msproj/policy/RootParallelization.java
index c9c9e27..a05936a 100644
--- a/src/net/woodyfolsom/msproj/policy/RootParallelization.java
+++ b/src/net/woodyfolsom/msproj/policy/RootParallelization.java
@@ -13,7 +13,16 @@ import net.woodyfolsom.msproj.Player;
 import net.woodyfolsom.msproj.tree.MonteCarloProperties;
 
 public class RootParallelization implements Policy {
+	private boolean logging = false;
 	private int numTrees = 1;
+	public boolean isLogging() {
+		return logging;
+	}
+
+	public void setLogging(boolean logging) {
+		this.logging = logging;
+	}
+
 	private long timeLimit = 1000L;
 
 	public RootParallelization(int numTrees, long timeLimit) {