Fixed AMAF, SMAF algorithms.

2012-12-02 18:37:51 -05:00
parent d24e7aee97
commit 28dc44b61e
16 changed files with 362 additions and 123 deletions

View File

@@ -1,4 +1,4 @@
-PlayerOne=ROOT_PAR
+PlayerOne=SMAF
 PlayerTwo=RANDOM
 GUIDelay=1000 //1 second
 BoardSize=9
@@ -7,4 +7,4 @@ NumGames=1 //Games for each color per player
 TurnTime=2000 //seconds per player per turn
 SpectatorBoardShown=true
 WhiteMoveLogged=false
-BlackMoveLogged=false
+BlackMoveLogged=true

View File

@@ -91,7 +91,6 @@ public class Referee {
             while (!gameRecord.isFinished()) {
                 GameState gameState = gameRecord.getGameState(gameRecord
                         .getNumTurns());
-                // System.out.println(gameState);
                 Player playerToMove = gameRecord.getPlayerToMove();
                 Policy policy = getPolicy(playerToMove);
@@ -108,6 +107,11 @@ public class Referee {
                 } else {
                     System.out.println("Move rejected - try again.");
                 }
+                if (policy.isLogging()) {
+                    System.out.println(gameState);
+                }
             }
         } catch (Exception ex) {
             System.out

View File

@@ -13,6 +13,7 @@ import net.woodyfolsom.msproj.gui.Goban;
 import net.woodyfolsom.msproj.policy.HumanGuiInput;
 import net.woodyfolsom.msproj.policy.HumanKeyboardInput;
 import net.woodyfolsom.msproj.policy.MonteCarloAMAF;
+import net.woodyfolsom.msproj.policy.MonteCarloSMAF;
 import net.woodyfolsom.msproj.policy.MonteCarloUCT;
 import net.woodyfolsom.msproj.policy.Policy;
 import net.woodyfolsom.msproj.policy.RandomMovePolicy;
@@ -26,7 +27,7 @@ public class StandAloneGame {
     private int gameNo = 0;
     enum PLAYER_TYPE {
-        HUMAN, HUMAN_GUI, ROOT_PAR, UCT, RANDOM, RAVE
+        HUMAN, HUMAN_GUI, ROOT_PAR, UCT, RANDOM, RAVE, SMAF
     };
     public static void main(String[] args) throws IOException {
@@ -41,7 +42,8 @@ public class StandAloneGame {
                     gameSettings.getBoardSize(), gameSettings.getKomi(),
                     gameSettings.getNumGames(), gameSettings.getTurnTime(),
                     gameSettings.isSpectatorBoardShown(),
-                    gameSettings.isBlackMoveLogged(), gameSettings.isWhiteMoveLogged());
+                    gameSettings.isBlackMoveLogged(),
+                    gameSettings.isWhiteMoveLogged());
             System.out.println("Press <Enter> or CTRL-C to exit");
             System.in.read(new byte[80]);
         } catch (IOException ioe) {
@@ -64,14 +66,17 @@ public class StandAloneGame {
             return PLAYER_TYPE.RANDOM;
         } else if ("RAVE".equalsIgnoreCase(playerTypeStr)) {
             return PLAYER_TYPE.RAVE;
+        } else if ("SMAF".equalsIgnoreCase(playerTypeStr)) {
+            return PLAYER_TYPE.SMAF;
         } else {
             throw new RuntimeException("Unknown player type: " + playerTypeStr);
         }
     }
     public void playGame(PLAYER_TYPE playerType1, PLAYER_TYPE playerType2,
-            int size, double komi, int rounds, long turnLength, boolean showSpectatorBoard,
-            boolean blackMoveLogged, boolean whiteMoveLogged) {
+            int size, double komi, int rounds, long turnLength,
+            boolean showSpectatorBoard, boolean blackMoveLogged,
+            boolean whiteMoveLogged) {
         long startTime = System.currentTimeMillis();
@@ -79,28 +84,38 @@ public class StandAloneGame {
         gameConfig.setKomi(komi);
         Referee referee = new Referee();
-        referee.setPolicy(Player.BLACK,
-                getPolicy(playerType1, gameConfig, Player.BLACK, turnLength, blackMoveLogged));
-        referee.setPolicy(Player.WHITE,
-                getPolicy(playerType2, gameConfig, Player.WHITE, turnLength, whiteMoveLogged));
+        referee.setPolicy(
+                Player.BLACK,
+                getPolicy(playerType1, gameConfig, Player.BLACK, turnLength,
+                        blackMoveLogged));
+        referee.setPolicy(
+                Player.WHITE,
+                getPolicy(playerType2, gameConfig, Player.WHITE, turnLength,
+                        whiteMoveLogged));
         List<GameResult> round1results = new ArrayList<GameResult>();
         boolean logGameRecords = rounds <= 50;
         for (int round = 0; round < rounds; round++) {
             gameNo++;
-            round1results.add(referee.play(gameConfig, gameNo, showSpectatorBoard, logGameRecords));
+            round1results.add(referee.play(gameConfig, gameNo,
+                    showSpectatorBoard, logGameRecords));
         }
         List<GameResult> round2results = new ArrayList<GameResult>();
-        referee.setPolicy(Player.BLACK,
-                getPolicy(playerType2, gameConfig, Player.BLACK, turnLength, blackMoveLogged));
-        referee.setPolicy(Player.WHITE,
-                getPolicy(playerType1, gameConfig, Player.WHITE, turnLength, whiteMoveLogged));
+        referee.setPolicy(
+                Player.BLACK,
+                getPolicy(playerType2, gameConfig, Player.BLACK, turnLength,
+                        blackMoveLogged));
+        referee.setPolicy(
+                Player.WHITE,
+                getPolicy(playerType1, gameConfig, Player.WHITE, turnLength,
+                        whiteMoveLogged));
         for (int round = 0; round < rounds; round++) {
             gameNo++;
-            round2results.add(referee.play(gameConfig, gameNo, showSpectatorBoard, logGameRecords));
+            round2results.add(referee.play(gameConfig, gameNo,
+                    showSpectatorBoard, logGameRecords));
         }
         long endTime = System.currentTimeMillis();
@@ -113,13 +128,14 @@ public class StandAloneGame {
         try {
             if (!logGameRecords) {
-                System.out.println("Each player is set to play more than 50 rounds as each color; omitting individual game .sgf log file output.");
+                System.out
+                        .println("Each player is set to play more than 50 rounds as each color; omitting individual game .sgf log file output.");
             }
             logResults(writer, round1results, playerType1.toString(),
                     playerType2.toString());
             logResults(writer, round2results, playerType2.toString(),
                     playerType1.toString());
             writer.write("Elapsed Time: " + (endTime - startTime) / 1000.0
                     + " seconds.");
             System.out.println("Game tournament saved as "
@@ -157,25 +173,38 @@ public class StandAloneGame {
     private Policy getPolicy(PLAYER_TYPE playerType, GameConfig gameConfig,
             Player player, long turnLength, boolean moveLogged) {
+        Policy policy;
         switch (playerType) {
         case HUMAN:
-            return new HumanKeyboardInput();
+            policy = new HumanKeyboardInput();
+            break;
        case HUMAN_GUI:
-            return new HumanGuiInput(new Goban(gameConfig, player,""));
+            policy = new HumanGuiInput(new Goban(gameConfig, player, ""));
+            break;
        case ROOT_PAR:
-            return new RootParallelization(4, turnLength);
+            policy = new RootParallelization(4, turnLength);
+            break;
        case UCT:
-            return new MonteCarloUCT(new RandomMovePolicy(), turnLength);
+            policy = new MonteCarloUCT(new RandomMovePolicy(), turnLength);
+            break;
+        case SMAF:
+            policy = new MonteCarloSMAF(new RandomMovePolicy(), turnLength, 0);
+            break;
        case RANDOM:
-            RandomMovePolicy randomMovePolicy = new RandomMovePolicy();
-            randomMovePolicy.setLogging(moveLogged);
-            return randomMovePolicy;
+            policy = new RandomMovePolicy();
+            break;
        case RAVE:
-            return new MonteCarloAMAF(new RandomMovePolicy(), turnLength);
+            policy = new MonteCarloAMAF(new RandomMovePolicy(), turnLength);
+            break;
        default:
            throw new IllegalArgumentException("Invalid PLAYER_TYPE: "
                    + playerType);
        }
+        policy.setLogging(moveLogged);
+        return policy;
    }
 }
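
With SMAF added to the PLAYER_TYPE enum, the player-type parser, and getPolicy above, the new policy can be selected purely from the settings file shown at the top of this commit. An illustrative configuration pitting it against the existing RAVE (AMAF) player follows; values other than the two player types are examples only, and NumGames stays at the 50-games-per-color threshold at or below which playGame still writes the per-game .sgf logs:

PlayerOne=SMAF
PlayerTwo=RAVE
GUIDelay=1000 //1 second
BoardSize=9
NumGames=50 //Games for each color per player
TurnTime=2000 //seconds per player per turn
SpectatorBoardShown=true
WhiteMoveLogged=false
BlackMoveLogged=false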

View File

@@ -1,6 +1,9 @@
 package net.woodyfolsom.msproj.ann;
-import java.io.FileNotFoundException;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
@@ -16,66 +19,99 @@ import net.woodyfolsom.msproj.tictactoe.State;
 public class TTTFilterTrainer { // implements epsilon-greedy trainer? online
                                 // version of NeuralNetFilter
-    public static void main(String[] args) throws FileNotFoundException {
-        double alpha = 0.15;
-        double lambda = .95;
-        int maxGames = 1000;
+    private boolean training = true;
+
+    public static void main(String[] args) throws IOException {
+        double alpha = 0.50;
+        double lambda = 0.90;
+        int maxGames = 100000;
         new TTTFilterTrainer().trainNetwork(alpha, lambda, maxGames);
     }
     public void trainNetwork(double alpha, double lambda, int maxGames)
-            throws FileNotFoundException {
-        FeedforwardNetwork neuralNetwork = new MultiLayerPerceptron(true, 9, 6,
-                1);
-        neuralNetwork.setName("TicTacToe");
-        neuralNetwork.initWeights();
-        TrainingMethod trainer = new TemporalDifference(alpha, lambda);
-        System.out.println("Playing untrained games.");
-        for (int i = 0; i < 10; i++) {
-            System.out.println("" + (i + 1) + ". "
-                    + playOptimal(neuralNetwork).getResult());
-        }
-        System.out.println("Learning from " + maxGames
-                + " games of random self-play");
-        int gamesPlayed = 0;
-        List<RESULT> results = new ArrayList<RESULT>();
-        do {
-            GameRecord gameRecord = playEpsilonGreedy(0.90, neuralNetwork,
-                    trainer);
-            System.out.println("Winner: " + gameRecord.getResult());
-            gamesPlayed++;
-            results.add(gameRecord.getResult());
-        } while (gamesPlayed < maxGames);
-        System.out.println("Learned network after " + maxGames
-                + " training games.");
-        for (int i = 0; i < results.size(); i++) {
-            if (i % 10 == 0) {
-                System.out.println("" + (i + 1) + ". " + results.get(i));
-            }
-        }
+            throws IOException {
+        FeedforwardNetwork neuralNetwork;
+        if (training) {
+            neuralNetwork = new MultiLayerPerceptron(true, 9, 9, 1);
+            neuralNetwork.setName("TicTacToe");
+            neuralNetwork.initWeights();
+            TrainingMethod trainer = new TemporalDifference(alpha, lambda);
+            System.out.println("Playing untrained games.");
+            for (int i = 0; i < 10; i++) {
+                System.out.println("" + (i + 1) + ". "
+                        + playOptimal(neuralNetwork).getResult());
+            }
+            System.out.println("Learning from " + maxGames
+                    + " games of random self-play");
+            int gamesPlayed = 0;
+            List<RESULT> results = new ArrayList<RESULT>();
+            do {
+                GameRecord gameRecord = playEpsilonGreedy(0.50, neuralNetwork,
+                        trainer);
+                System.out.println("Winner: " + gameRecord.getResult());
+                gamesPlayed++;
+                results.add(gameRecord.getResult());
+            } while (gamesPlayed < maxGames);
+            System.out.println("Results of every 10th training game:");
+            for (int i = 0; i < results.size(); i++) {
+                if (i % 10 == 0) {
+                    System.out.println("" + (i + 1) + ". " + results.get(i));
+                }
+            }
+            System.out.println("Learned network after " + maxGames
+                    + " training games.");
+        } else {
+            System.out.println("Loading TicTacToe network from file.");
+            neuralNetwork = new MultiLayerPerceptron();
+            FileInputStream fis = new FileInputStream(new File("ttt.net"));
+            if (!new MultiLayerPerceptron().load(fis)) {
+                System.out.println("Error loading ttt.net from file.");
+                return;
+            }
+            fis.close();
+        }
         evalTestCases(neuralNetwork);
         System.out.println("Playing optimal games.");
+        List<RESULT> gameResults = new ArrayList<RESULT>();
         for (int i = 0; i < 10; i++) {
-            System.out.println("" + (i + 1) + ". "
-                    + playOptimal(neuralNetwork).getResult());
+            gameResults.add(playOptimal(neuralNetwork).getResult());
         }
-        /*
-         * File output = new File("ttt.net");
-         *
-         * FileOutputStream fos = new FileOutputStream(output);
-         *
-         * neuralNetwork.save(fos);
-         */
+        boolean suboptimalPlay = false;
+        System.out.println("Optimal game summary: ");
+        for (int i = 0; i < gameResults.size(); i++) {
+            RESULT result = gameResults.get(i);
+            System.out.println("" + (i + 1) + ". " + result);
+            if (result != RESULT.X_WINS) {
+                suboptimalPlay = true;
+            }
+        }
+        File output = new File("ttt.net");
+        FileOutputStream fos = new FileOutputStream(output);
+        neuralNetwork.save(fos);
+        System.out.println("Playing optimal vs random games.");
+        for (int i = 0; i < 10; i++) {
+            System.out.println("" + (i + 1) + ". "
+                    + playOptimalVsRandom(neuralNetwork).getResult());
+        }
+        if (suboptimalPlay) {
+            System.out.println("Suboptimal play detected!");
+        }
     }
     private void evalTestCases(FeedforwardNetwork neuralNetwork) {
@@ -115,6 +151,32 @@ public class TTTFilterTrainer { // implements epsilon-greedy trainer? online
         testNetwork(neuralNetwork, validationSet, inputNames, outputNames);
     }
+    private GameRecord playOptimalVsRandom(FeedforwardNetwork neuralNetwork) {
+        GameRecord gameRecord = new GameRecord();
+        Policy neuralNetPolicy = new NeuralNetPolicy(neuralNetwork);
+        Policy randomPolicy = new RandomPolicy();
+        State state = gameRecord.getState();
+        Policy[] policies = new Policy[] { neuralNetPolicy, randomPolicy };
+        int turnNo = 0;
+        do {
+            Action action;
+            State nextState;
+            action = policies[turnNo % 2].getAction(gameRecord.getState());
+            nextState = gameRecord.apply(action);
+            System.out.println("Action " + action + " selected by policy "
+                    + policies[turnNo % 2].getName());
+            System.out.println("Next board state: " + nextState);
+            state = nextState;
+            turnNo++;
+        } while (!state.isTerminal());
+        return gameRecord;
+    }
     private GameRecord playOptimal(FeedforwardNetwork neuralNetwork) {
         GameRecord gameRecord = new GameRecord();
@@ -122,8 +184,6 @@ public class TTTFilterTrainer { // implements epsilon-greedy trainer? online
         State state = gameRecord.getState();
-        System.out.println("Playing optimal game:");
         do {
             Action action;
             State nextState;
@@ -131,14 +191,12 @@ public class TTTFilterTrainer { // implements epsilon-greedy trainer? online
             action = neuralNetPolicy.getAction(gameRecord.getState());
             nextState = gameRecord.apply(action);
-            System.out.println("Action " + action + " selected by policy " +
-                    neuralNetPolicy.getName());
+            System.out.println("Action " + action + " selected by policy "
+                    + neuralNetPolicy.getName());
             System.out.println("Next board state: " + nextState);
             state = nextState;
         } while (!state.isTerminal());
-        // finally, reinforce the actual reward
         return gameRecord;
     }
@@ -196,7 +254,7 @@ public class TTTFilterTrainer { // implements epsilon-greedy trainer? online
         for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
             NNDataPair dp = new NNDataPair(new NNData(inputNames,
                     validationSet[valIndex]), new NNData(outputNames,
-                    new double[] {0.0}));
+                    new double[] { 0.0 }));
             System.out.println(dp + " => " + neuralNetwork.compute(dp));
         }
     }

View File

@@ -4,7 +4,7 @@ import java.util.List;
 public class TemporalDifference extends TrainingMethod {
     private final double alpha;
-    private final double gamma = 1.0;
+    // private final double gamma = 1.0;
     private final double lambda;
     public TemporalDifference(double alpha, double lambda) {
@@ -81,23 +81,27 @@ public class TemporalDifference extends TrainingMethod {
         }
     }
-    private void updateWeights(FeedforwardNetwork neuralNetwork, double predictionError) {
+    private void updateWeights(FeedforwardNetwork neuralNetwork,
+            double predictionError) {
         for (Connection connection : neuralNetwork.getConnections()) {
-            /*Neuron srcNeuron = neuralNetwork.getNeuron(connection.getSrc());
-            Neuron destNeuron = neuralNetwork.getNeuron(connection.getDest());
-            double delta = alpha * srcNeuron.getOutput() *
-                    destNeuron.getGradient() * predictionError + connection.getTrace() * lambda;
-            // TODO allow for momentum
-            // double lastDelta = connection.getLastDelta();
-            connection.addDelta(delta);*/
+            /*
+             * Neuron srcNeuron = neuralNetwork.getNeuron(connection.getSrc());
+             * Neuron destNeuron =
+             * neuralNetwork.getNeuron(connection.getDest());
+             *
+             * double delta = alpha * srcNeuron.getOutput()
+             * * destNeuron.getGradient() * predictionError +
+             * connection.getTrace() * lambda;
+             *
+             * // TODO allow for momentum // double lastDelta =
+             * connection.getLastDelta(); connection.addDelta(delta);
+             */
             Neuron srcNeuron = neuralNetwork.getNeuron(connection.getSrc());
             Neuron destNeuron = neuralNetwork.getNeuron(connection.getDest());
             double delta = alpha * srcNeuron.getOutput()
                     * destNeuron.getGradient() + connection.getTrace() * lambda;
-            //TODO allow for momentum
-            //double lastDelta = connection.getLastDelta();
+            // TODO allow for momentum
+            // double lastDelta = connection.getLastDelta();
             connection.addDelta(delta);
         }
     }
@@ -117,23 +121,23 @@ public class TemporalDifference extends TrainingMethod {
     @Override
     protected void iteratePattern(FeedforwardNetwork neuralNetwork,
             NNDataPair statePair, NNData nextReward) {
-        //System.out.println("Learningrate: " + alpha);
+        // System.out.println("Learningrate: " + alpha);
         zeroGradients(neuralNetwork);
-        //System.out.println("Training with: " + statePair.getInput());
+        // System.out.println("Training with: " + statePair.getInput());
         NNData ideal = nextReward;
         NNData actual = neuralNetwork.compute(statePair);
-        //System.out.println("Updating weights. Ideal Output: " + ideal);
-        //System.out.println("Actual Output: " + actual);
+        // System.out.println("Updating weights. Ideal Output: " + ideal);
+        // System.out.println("Actual Output: " + actual);
         // backpropagate the gradients w.r.t. output error
         backPropagate(neuralNetwork, ideal);
         double predictionError = statePair.getIdeal().getValues()[0] // reward_t
                 + actual.getValues()[0] - nextReward.getValues()[0];
         updateWeights(neuralNetwork, predictionError);
     }

View File

@@ -16,6 +16,15 @@ public class AlphaBeta implements Policy {
     private final ValidMoveGenerator validMoveGenerator = new ValidMoveGenerator();
+    private boolean logging = false;
+
+    public boolean isLogging() {
+        return logging;
+    }
+
+    public void setLogging(boolean logging) {
+        this.logging = logging;
+    }
     private int lookAhead;
     private int numStateEvaluations = 0;

View File

@@ -9,6 +9,15 @@ import net.woodyfolsom.msproj.Player;
 import net.woodyfolsom.msproj.gui.Goban;
 public class HumanGuiInput implements Policy {
+    private boolean logging;
+
+    public boolean isLogging() {
+        return logging;
+    }
+
+    public void setLogging(boolean logging) {
+        this.logging = logging;
+    }
     private Goban goban;
     public HumanGuiInput(Goban goban) {

View File

@@ -9,6 +9,15 @@ import net.woodyfolsom.msproj.GameState;
 import net.woodyfolsom.msproj.Player;
 public class HumanKeyboardInput implements Policy {
+    private boolean logging = false;
+
+    public boolean isLogging() {
+        return logging;
+    }
+
+    public void setLogging(boolean logging) {
+        this.logging = logging;
+    }
     @Override
     public Action getAction(GameConfig gameConfig, GameState gameState,

View File

@@ -16,6 +16,15 @@ public class Minimax implements Policy {
     private final ValidMoveGenerator validMoveGenerator = new ValidMoveGenerator();
+    private boolean logging = false;
+
+    public boolean isLogging() {
+        return logging;
+    }
+
+    public void setLogging(boolean logging) {
+        this.logging = logging;
+    }
     private int lookAhead;
     private int numStateEvaluations = 0;

View File

@@ -15,6 +15,15 @@ import net.woodyfolsom.msproj.tree.MonteCarloProperties;
 public abstract class MonteCarlo implements Policy {
     protected static final int ROLLOUT_DEPTH_LIMIT = 250;
+    private boolean logging = false;
+
+    public boolean isLogging() {
+        return logging;
+    }
+
+    public void setLogging(boolean logging) {
+        this.logging = logging;
+    }
     protected int numStateEvaluations = 0;
     protected Policy movePolicy;

View File

@@ -63,6 +63,43 @@ public class MonteCarloAMAF extends MonteCarloUCT {
                 rootGameState, new AMAFProperties());
     }
+    @Override
+    public Action getBestAction(GameTreeNode<MonteCarloProperties> node) {
+        Action bestAction = Action.NONE;
+        double bestScore = Double.NEGATIVE_INFINITY;
+        GameTreeNode<MonteCarloProperties> bestChild = null;
+        for (Action action : node.getActions()) {
+            GameTreeNode<MonteCarloProperties> childNode = node
+                    .getChild(action);
+            AMAFProperties childProps = (AMAFProperties) childNode.getProperties();
+            double childScore = childProps.getAmafWins() / (double) childProps.getAmafVisits();
+            if (childScore >= bestScore) {
+                bestScore = childScore;
+                bestAction = action;
+                bestChild = childNode;
+            }
+        }
+        if (bestAction == Action.NONE) {
+            System.out
+                    .println("MonteCarloUCT failed - no actions were found for the current game state (not even PASS).");
+        } else {
+            System.out.println("Action " + bestAction + " selected for "
+                    + node.getGameState().getPlayerToMove()
+                    + " with simulated win ratio of "
+                    + (bestScore * 100.0 + "%"));
+            System.out.println("It was visited "
+                    + bestChild.getProperties().getVisits() + " times out of "
+                    + node.getProperties().getVisits() + " rollouts among "
+                    + node.getNumChildren()
+                    + " valid actions from the current state.");
+        }
+        return bestAction;
+    }
     @Override
     protected double getNodeScore(GameTreeNode<MonteCarloProperties> gameTreeNode) {
         //double nodeVisits = gameTreeNode.getParent().getProperties().getVisits();
@@ -72,16 +109,8 @@ public class MonteCarloAMAF extends MonteCarloUCT {
         if (gameTreeNode.getGameState().isTerminal()) {
             nodeScore = 0.0;
         } else {
-            /*
-            MonteCarloProperties properties = gameTreeNode.getProperties();
-            nodeScore = (double) (properties.getWins() / properties
-                    .getVisits())
-                    + (TUNING_CONSTANT * Math.sqrt(Math.log(nodeVisits)
-                            / gameTreeNode.getProperties().getVisits()));
-            *
-            */
             AMAFProperties properties = (AMAFProperties) gameTreeNode.getProperties();
-            nodeScore = (double) (properties.getAmafWins() / properties
+            nodeScore = (properties.getAmafWins() / (double) properties
                     .getAmafVisits())
                     + (TUNING_CONSTANT * Math.sqrt(Math.log(parentAmafVisits)
                             / properties.getAmafVisits()));
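
The getNodeScore change above fixes an integer-division problem: in the old expression the (double) cast was applied after getAmafWins() / getAmafVisits() had already been evaluated as integer division (the AMAF counters appear to be integral, given the int reward arithmetic in MonteCarloSMAF), so the win ratio truncated to 0 or 1 before the exploration term was added. Casting the divisor first, as the new code does, yields the intended fractional rate. A minimal standalone illustration with hypothetical counts, not code from the project:

public class DivisionDemo {
    public static void main(String[] args) {
        int amafWins = 3, amafVisits = 7;
        // Old pattern: the int division truncates first; the cast only converts the result.
        double truncated = (double) (amafWins / amafVisits); // 0.0
        // New pattern: casting one operand promotes the division to floating point.
        double ratio = amafWins / (double) amafVisits;       // ~0.4286
        System.out.println(truncated + " vs " + ratio);
    }
}

The same cast-the-divisor pattern appears in the getBestAction override above and in the MonteCarloUCT change later in this commit.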

View File

@@ -0,0 +1,59 @@
+package net.woodyfolsom.msproj.policy;
+
+import java.util.List;
+
+import net.woodyfolsom.msproj.Action;
+import net.woodyfolsom.msproj.Player;
+import net.woodyfolsom.msproj.tree.AMAFProperties;
+import net.woodyfolsom.msproj.tree.GameTreeNode;
+import net.woodyfolsom.msproj.tree.MonteCarloProperties;
+
+public class MonteCarloSMAF extends MonteCarloAMAF {
+    private int horizon;
+
+    public MonteCarloSMAF(Policy movePolicy, long searchTimeLimit, int horizon) {
+        super(movePolicy, searchTimeLimit);
+        this.horizon = horizon;
+    }
+
+    @Override
+    public void update(GameTreeNode<MonteCarloProperties> node, Rollout rollout) {
+        GameTreeNode<MonteCarloProperties> currentNode = node;
+        //List<Action> subTreeActions = new ArrayList<Action>(rollout.getPlayout());
+        List<Action> playout = rollout.getPlayout();
+        int reward = rollout.getReward();
+
+        while (currentNode != null) {
+            AMAFProperties nodeProperties = (AMAFProperties) currentNode.getProperties();
+            //Always update props for the current node
+            nodeProperties.setWins(nodeProperties.getWins() + reward);
+            nodeProperties.setVisits(nodeProperties.getVisits() + 1);
+            nodeProperties.setAmafWins(nodeProperties.getAmafWins() + reward);
+            nodeProperties.setAmafVisits(nodeProperties.getAmafVisits() + 1);
+
+            GameTreeNode<MonteCarloProperties> parentNode = currentNode.getParent();
+            if (parentNode != null) {
+                Player playerToMove = parentNode.getGameState().getPlayerToMove();
+                for (Action actionFromParent : parentNode.getActions()) {
+                    if (playout.subList(0, Math.max(horizon, playout.size())).contains(actionFromParent)) {
+                        GameTreeNode<MonteCarloProperties> subTreeChild = parentNode.getChild(actionFromParent);
+                        //Don't count AMAF properties for the current node twice
+                        if (subTreeChild == currentNode) {
+                            continue;
+                        }
+                        AMAFProperties siblingProperties = (AMAFProperties) subTreeChild.getProperties();
+                        //Only update AMAF properties if the sibling is reached by the same action with the same player to move
+                        if (rollout.hasPlay(playerToMove, actionFromParent)) {
+                            siblingProperties.setAmafWins(siblingProperties.getAmafWins() + reward);
+                            siblingProperties.setAmafVisits(siblingProperties.getAmafVisits() + 1);
+                        }
+                    }
+                }
+            }
+            currentNode = currentNode.getParent();
+        }
+    }
+}

View File

@@ -90,11 +90,8 @@ public class MonteCarloUCT extends MonteCarlo {
             GameTreeNode<MonteCarloProperties> childNode = node
                     .getChild(action);
-            //MonteCarloProperties properties = childNode.getProperties();
-            //double childScore = (double) properties.getWins()
-            //        / properties.getVisits();
-            double childScore = getNodeScore(childNode);
+            MonteCarloProperties childProps = childNode.getProperties();
+            double childScore = childProps.getWins() / (double) childProps.getVisits();
             if (childScore >= bestScore) {
                 bestScore = childScore;

View File

@@ -17,4 +17,8 @@ public interface Policy {
     public int getNumStateEvaluations();
     public void setState(GameState gameState);
+
+    boolean isLogging();
+
+    void setLogging(boolean logging);
 }

View File

@@ -110,6 +110,7 @@ public class RandomMovePolicy implements Policy, ActionGenerator {
         return randomAction;
     }
+    @Override
     public boolean isLogging() {
         return logging;
     }

View File

@@ -13,7 +13,16 @@ import net.woodyfolsom.msproj.Player;
 import net.woodyfolsom.msproj.tree.MonteCarloProperties;
 public class RootParallelization implements Policy {
+    private boolean logging = false;
     private int numTrees = 1;
+
+    public boolean isLogging() {
+        return logging;
+    }
+
+    public void setLogging(boolean logging) {
+        this.logging = logging;
+    }
     private long timeLimit = 1000L;
     public RootParallelization(int numTrees, long timeLimit) {