From d24e7aee97a6ce0db55343cb9faf0461ed90854a Mon Sep 17 00:00:00 2001 From: Woody Folsom Date: Wed, 28 Nov 2012 19:45:03 -0500 Subject: [PATCH] Bug in TDL seems to be fixed. Issue was incorrect calculation of next state reward (reuse of current state) in epsilon-greedy learner. --- .../woodyfolsom/msproj/ann/Connection.java | 14 +- .../woodyfolsom/msproj/ann/NNDataPair.java | 5 + src/net/woodyfolsom/msproj/ann/Neuron.java | 2 +- .../msproj/ann/TTTFilterTrainer.java | 194 ++++++++++-------- .../msproj/ann/TemporalDifference.java | 13 +- .../msproj/tictactoe/NeuralNetPolicy.java | 3 +- 6 files changed, 130 insertions(+), 101 deletions(-) diff --git a/src/net/woodyfolsom/msproj/ann/Connection.java b/src/net/woodyfolsom/msproj/ann/Connection.java index f2d2e5a..0cb4fd0 100644 --- a/src/net/woodyfolsom/msproj/ann/Connection.java +++ b/src/net/woodyfolsom/msproj/ann/Connection.java @@ -7,7 +7,7 @@ public class Connection { private int src; private int dest; private double weight; - private transient double lastDelta = 0.0; + //private transient double lastDelta = 0.0; private transient double trace = 0.0; public Connection() { @@ -23,7 +23,7 @@ public class Connection { public void addDelta(double delta) { this.trace = delta; this.weight += delta; - this.lastDelta = delta; + //this.lastDelta = delta; } @XmlAttribute @@ -31,10 +31,10 @@ public class Connection { return dest; } - @XmlTransient - public double getLastDelta() { - return lastDelta; - } + //@XmlTransient + //public double getLastDelta() { + // return lastDelta; + //} @XmlAttribute public int getSrc() { @@ -100,6 +100,6 @@ public class Connection { @Override public String toString() { - return "Connection(" + src + ", " + dest +"), weight: " + weight; + return "Connection(src: " + src + ",dest: " + dest + ", trace:" + trace +"), weight: " + weight; } } \ No newline at end of file diff --git a/src/net/woodyfolsom/msproj/ann/NNDataPair.java b/src/net/woodyfolsom/msproj/ann/NNDataPair.java index 2b2921a..fe6f5ce 100644 --- a/src/net/woodyfolsom/msproj/ann/NNDataPair.java +++ b/src/net/woodyfolsom/msproj/ann/NNDataPair.java @@ -16,4 +16,9 @@ public class NNDataPair { public NNData getIdeal() { return ideal; } + + @Override + public String toString() { + return input.toString() + " => " + ideal.toString(); + } } diff --git a/src/net/woodyfolsom/msproj/ann/Neuron.java b/src/net/woodyfolsom/msproj/ann/Neuron.java index c5ad0c4..c7f9c70 100644 --- a/src/net/woodyfolsom/msproj/ann/Neuron.java +++ b/src/net/woodyfolsom/msproj/ann/Neuron.java @@ -57,7 +57,7 @@ public class Neuron { public void setInput(double input) { this.input = input; } - + @Override public int hashCode() { final int prime = 31; diff --git a/src/net/woodyfolsom/msproj/ann/TTTFilterTrainer.java b/src/net/woodyfolsom/msproj/ann/TTTFilterTrainer.java index 036e77a..429833e 100644 --- a/src/net/woodyfolsom/msproj/ann/TTTFilterTrainer.java +++ b/src/net/woodyfolsom/msproj/ann/TTTFilterTrainer.java @@ -1,8 +1,6 @@ package net.woodyfolsom.msproj.ann; -import java.io.File; import java.io.FileNotFoundException; -import java.io.FileOutputStream; import java.util.ArrayList; import java.util.List; @@ -15,49 +13,74 @@ import net.woodyfolsom.msproj.tictactoe.Policy; import net.woodyfolsom.msproj.tictactoe.RandomPolicy; import net.woodyfolsom.msproj.tictactoe.State; -public class TTTFilterTrainer { //implements epsilon-greedy trainer? online version of NeuralNetFilter +public class TTTFilterTrainer { // implements epsilon-greedy trainer? 
online + // version of NeuralNetFilter public static void main(String[] args) throws FileNotFoundException { - double alpha = 0.0; - double lambda = 0.9; - int maxGames = 15000; - + double alpha = 0.15; + double lambda = .95; + int maxGames = 1000; + new TTTFilterTrainer().trainNetwork(alpha, lambda, maxGames); } - - public void trainNetwork(double alpha, double lambda, int maxGames) throws FileNotFoundException { - /// - FeedforwardNetwork neuralNetwork = new MultiLayerPerceptron(true, 9,5,1); + + public void trainNetwork(double alpha, double lambda, int maxGames) + throws FileNotFoundException { + + FeedforwardNetwork neuralNetwork = new MultiLayerPerceptron(true, 9, 6, + 1); neuralNetwork.setName("TicTacToe"); neuralNetwork.initWeights(); - TrainingMethod trainer = new TemporalDifference(0.5,0.5); - + TrainingMethod trainer = new TemporalDifference(alpha, lambda); + System.out.println("Playing untrained games."); for (int i = 0; i < 10; i++) { - System.out.println("" + (i+1) + ". " + playOptimal(neuralNetwork).getResult()); + System.out.println("" + (i + 1) + ". " + + playOptimal(neuralNetwork).getResult()); } - - System.out.println("Learning from " + maxGames + " games of random self-play"); - + + System.out.println("Learning from " + maxGames + + " games of random self-play"); + int gamesPlayed = 0; List results = new ArrayList(); do { - GameRecord gameRecord = playEpsilonGreedy(0.90, neuralNetwork, trainer); + GameRecord gameRecord = playEpsilonGreedy(0.90, neuralNetwork, + trainer); System.out.println("Winner: " + gameRecord.getResult()); gamesPlayed++; results.add(gameRecord.getResult()); } while (gamesPlayed < maxGames); - /// - - System.out.println("Learned network after " + maxGames + " training games."); - double[][] validationSet = new double[8][]; - + System.out.println("Learned network after " + maxGames + + " training games."); + for (int i = 0; i < results.size(); i++) { if (i % 10 == 0) { - System.out.println("" + (i+1) + ". " + results.get(i)); + System.out.println("" + (i + 1) + ". " + results.get(i)); } } + + evalTestCases(neuralNetwork); + + System.out.println("Playing optimal games."); + for (int i = 0; i < 10; i++) { + System.out.println("" + (i + 1) + ". " + + playOptimal(neuralNetwork).getResult()); + } + + /* + * File output = new File("ttt.net"); + * + * FileOutputStream fos = new FileOutputStream(output); + * + * neuralNetwork.save(fos); + */ + } + + private void evalTestCases(FeedforwardNetwork neuralNetwork) { + double[][] validationSet = new double[8][]; + // empty board validationSet[0] = new double[] { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 }; @@ -81,73 +104,63 @@ public class TTTFilterTrainer { //implements epsilon-greedy trainer? online vers 0.0, -1.0 }; // about to win - validationSet[7] = new double[] { - -1.0, 1.0, 1.0, - 1.0, -1.0, 1.0, - -1.0, -1.0, 0.0 }; - + validationSet[7] = new double[] { -1.0, 1.0, 1.0, 1.0, -1.0, 1.0, -1.0, + -1.0, 0.0 }; + String[] inputNames = new String[] { "00", "01", "02", "10", "11", "12", "20", "21", "22" }; String[] outputNames = new String[] { "values" }; System.out.println("Output from eval set (learned network):"); testNetwork(neuralNetwork, validationSet, inputNames, outputNames); - - System.out.println("Playing optimal games."); - for (int i = 0; i < 10; i++) { - System.out.println("" + (i+1) + ". 
" + playOptimal(neuralNetwork).getResult()); - } - - /* - File output = new File("ttt.net"); - - FileOutputStream fos = new FileOutputStream(output); - - neuralNetwork.save(fos);*/ - } - - private GameRecord playOptimal(FeedforwardNetwork neuralNetwork) { - GameRecord gameRecord = new GameRecord(); - - Policy neuralNetPolicy = new NeuralNetPolicy(neuralNetwork); - - State state = gameRecord.getState(); - - do { - Action action; - State nextState; - - action = neuralNetPolicy.getAction(gameRecord.getState()); - - nextState = gameRecord.apply(action); - //System.out.println("Action " + action + " selected by policy " + selectedPolicy.getName()); - //System.out.println("Next board state: " + nextState); - state = nextState; - } while (!state.isTerminal()); - - //finally, reinforce the actual reward - - return gameRecord; } - private GameRecord playEpsilonGreedy(double epsilon, FeedforwardNetwork neuralNetwork, TrainingMethod trainer) { + private GameRecord playOptimal(FeedforwardNetwork neuralNetwork) { GameRecord gameRecord = new GameRecord(); - - Policy randomPolicy = new RandomPolicy(); + Policy neuralNetPolicy = new NeuralNetPolicy(neuralNetwork); - - //System.out.println("Playing epsilon-greedy game."); - + State state = gameRecord.getState(); - NNDataPair statePair; + + System.out.println("Playing optimal game:"); - Policy selectedPolicy; - trainer.zeroTraces(neuralNetwork); - - do { + do { Action action; State nextState; - + + action = neuralNetPolicy.getAction(gameRecord.getState()); + + nextState = gameRecord.apply(action); + System.out.println("Action " + action + " selected by policy " + + neuralNetPolicy.getName()); + System.out.println("Next board state: " + nextState); + state = nextState; + } while (!state.isTerminal()); + + // finally, reinforce the actual reward + + return gameRecord; + } + + private GameRecord playEpsilonGreedy(double epsilon, + FeedforwardNetwork neuralNetwork, TrainingMethod trainer) { + GameRecord gameRecord = new GameRecord(); + + Policy randomPolicy = new RandomPolicy(); + Policy neuralNetPolicy = new NeuralNetPolicy(neuralNetwork); + + // System.out.println("Playing epsilon-greedy game."); + + State state = gameRecord.getState(); + NNDataPair statePair; + + Policy selectedPolicy; + trainer.zeroTraces(neuralNetwork); + + do { + Action action; + State nextState; + if (Math.random() < epsilon) { selectedPolicy = randomPolicy; action = selectedPolicy.getAction(gameRecord.getState()); @@ -155,32 +168,35 @@ public class TTTFilterTrainer { //implements epsilon-greedy trainer? 
online vers } else { selectedPolicy = neuralNetPolicy; action = selectedPolicy.getAction(gameRecord.getState()); - + nextState = gameRecord.apply(action); statePair = NNDataSetFactory.createDataPair(state); - NNDataPair nextStatePair = NNDataSetFactory.createDataPair(nextState); - trainer.iteratePattern(neuralNetwork, statePair, nextStatePair.getIdeal()); + NNDataPair nextStatePair = NNDataSetFactory + .createDataPair(nextState); + trainer.iteratePattern(neuralNetwork, statePair, + nextStatePair.getIdeal()); } - //System.out.println("Action " + action + " selected by policy " + selectedPolicy.getName()); - - //System.out.println("Next board state: " + nextState); - + // System.out.println("Action " + action + " selected by policy " + + // selectedPolicy.getName()); + + // System.out.println("Next board state: " + nextState); + state = nextState; } while (!state.isTerminal()); - - //finally, reinforce the actual reward + + // finally, reinforce the actual reward statePair = NNDataSetFactory.createDataPair(state); trainer.iteratePattern(neuralNetwork, statePair, statePair.getIdeal()); - + return gameRecord; } - + private void testNetwork(FeedforwardNetwork neuralNetwork, double[][] validationSet, String[] inputNames, String[] outputNames) { for (int valIndex = 0; valIndex < validationSet.length; valIndex++) { NNDataPair dp = new NNDataPair(new NNData(inputNames, validationSet[valIndex]), new NNData(outputNames, - validationSet[valIndex])); + new double[] {0.0})); System.out.println(dp + " => " + neuralNetwork.compute(dp)); } } diff --git a/src/net/woodyfolsom/msproj/ann/TemporalDifference.java b/src/net/woodyfolsom/msproj/ann/TemporalDifference.java index 853fb9f..cc69ebc 100644 --- a/src/net/woodyfolsom/msproj/ann/TemporalDifference.java +++ b/src/net/woodyfolsom/msproj/ann/TemporalDifference.java @@ -83,15 +83,22 @@ public class TemporalDifference extends TrainingMethod { private void updateWeights(FeedforwardNetwork neuralNetwork, double predictionError) { for (Connection connection : neuralNetwork.getConnections()) { - Neuron srcNeuron = neuralNetwork.getNeuron(connection.getSrc()); + /*Neuron srcNeuron = neuralNetwork.getNeuron(connection.getSrc()); Neuron destNeuron = neuralNetwork.getNeuron(connection.getDest()); - + double delta = alpha * srcNeuron.getOutput() * destNeuron.getGradient() * predictionError + connection.getTrace() * lambda; // TODO allow for momentum // double lastDelta = connection.getLastDelta(); - connection.addDelta(delta); + connection.addDelta(delta);*/ + Neuron srcNeuron = neuralNetwork.getNeuron(connection.getSrc()); + Neuron destNeuron = neuralNetwork.getNeuron(connection.getDest()); + double delta = alpha * srcNeuron.getOutput() + * destNeuron.getGradient() + connection.getTrace() * lambda; + //TODO allow for momentum + //double lastDelta = connection.getLastDelta(); + connection.addDelta(delta); } } diff --git a/src/net/woodyfolsom/msproj/tictactoe/NeuralNetPolicy.java b/src/net/woodyfolsom/msproj/tictactoe/NeuralNetPolicy.java index 4bb0fb8..51fde90 100644 --- a/src/net/woodyfolsom/msproj/tictactoe/NeuralNetPolicy.java +++ b/src/net/woodyfolsom/msproj/tictactoe/NeuralNetPolicy.java @@ -24,7 +24,8 @@ public class NeuralNetPolicy extends Policy { for (Action action : validMoves) { State nextState = state.apply(action); - NNDataPair dataPair = NNDataSetFactory.createDataPair(state); + //NNDataPair dataPair = NNDataSetFactory.createDataPair(state); + NNDataPair dataPair = NNDataSetFactory.createDataPair(nextState); //estimated reward for X scores.put(action, 
neuralNet.compute(dataPair).getValues()[0]); }
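
Note (not part of the patch): the core of the fix in NeuralNetPolicy.java is that each candidate move is now scored by evaluating the successor state it produces (`createDataPair(nextState)`) instead of re-evaluating the current state, which is the "reuse of current state" bug named in the commit message. The sketch below restates that pattern in isolation; it is illustrative only and uses hypothetical stand-in types (`Evaluator`, `Board`), not the project's actual State/Action/NNDataSetFactory API shown in the diff above.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Illustrative sketch only: greedy action selection that scores each
 * candidate move by the value of the successor state it leads to.
 * (The bug fixed in this patch was scoring the current state instead.)
 */
public class GreedyPolicySketch {

    /** Stand-in for the value network (hypothetical interface). */
    interface Evaluator {
        double value(double[] boardEncoding);
    }

    /** Stand-in for a game state with legal moves (hypothetical interface). */
    interface Board {
        List<Integer> validMoves();
        Board apply(int move);   // returns the successor state
        double[] encode();       // e.g. a 9-element board encoding, as in the patch
    }

    static int bestMove(Board current, Evaluator net) {
        Map<Integer, Double> scores = new HashMap<>();
        for (int move : current.validMoves()) {
            Board next = current.apply(move);
            // Key point of the fix: evaluate next.encode(), not current.encode().
            scores.put(move, net.value(next.encode()));
        }
        // Pick the move whose successor state the network values highest.
        return scores.entrySet().stream()
                .max(Map.Entry.comparingByValue())
                .map(Map.Entry::getKey)
                .orElseThrow(() -> new IllegalStateException("no valid moves"));
    }
}
```

Scoring successor states is what lets the epsilon-greedy loop in TTTFilterTrainer feed a meaningful next-state value into TemporalDifference.iteratePattern; with the old code every candidate move received the same score, so the TD error carried no information about the chosen action.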