Bug in TDL seems to be fixed. The issue was an incorrect calculation of the next-state reward in the epsilon-greedy learner: the current state was being reused instead of the next state.

2012-11-28 19:45:03 -05:00
parent 214bdcd032
commit d24e7aee97
6 changed files with 130 additions and 101 deletions
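The core fix (see the NeuralNetPolicy diff at the end) feeds the network the state produced by each candidate move instead of the state the move is taken from, so different moves get different value estimates. A minimal sketch of that idea, using hypothetical stand-in types rather than this project's State/NNDataSetFactory API:

	import java.util.List;
	import java.util.function.ToDoubleFunction;

	// Minimal sketch (hypothetical types, not this project's classes) of scoring
	// candidate moves on the state that RESULTS from each move.
	public class GreedyMoveSketch {

	    // Stand-ins for the project's Action/State abstractions.
	    interface Action {}

	    interface State {
	        State apply(Action action);   // position that results from the move
	        List<Action> validMoves();
	    }

	    // Greedy move selection against a learned value estimate.
	    static Action pickGreedy(State state, ToDoubleFunction<State> valueEstimate) {
	        Action best = null;
	        double bestScore = Double.NEGATIVE_INFINITY;
	        for (Action action : state.validMoves()) {
	            State nextState = state.apply(action);
	            // The buggy version evaluated the *current* state here, so every
	            // candidate move in a position received the same score.
	            double score = valueEstimate.applyAsDouble(nextState);
	            if (score > bestScore) {
	                bestScore = score;
	                best = action;
	            }
	        }
	        return best;
	    }
	}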

Connection.java

@@ -7,7 +7,7 @@ public class Connection {
 	private int src;
 	private int dest;
 	private double weight;
-	private transient double lastDelta = 0.0;
+	//private transient double lastDelta = 0.0;
 	private transient double trace = 0.0;

 	public Connection() {
@@ -23,7 +23,7 @@ public class Connection {
 	public void addDelta(double delta) {
 		this.trace = delta;
 		this.weight += delta;
-		this.lastDelta = delta;
+		//this.lastDelta = delta;
 	}

 	@XmlAttribute
@@ -31,10 +31,10 @@ public class Connection {
 		return dest;
 	}

-	@XmlTransient
-	public double getLastDelta() {
-		return lastDelta;
-	}
+	//@XmlTransient
+	//public double getLastDelta() {
+	//	return lastDelta;
+	//}

 	@XmlAttribute
 	public int getSrc() {
@@ -100,6 +100,6 @@ public class Connection {
 	@Override
 	public String toString() {
-		return "Connection(" + src + ", " + dest +"), weight: " + weight;
+		return "Connection(src: " + src + ",dest: " + dest + ", trace:" + trace +"), weight: " + weight;
 	}
 }

NNDataPair.java

@@ -16,4 +16,9 @@ public class NNDataPair {
 	public NNData getIdeal() {
 		return ideal;
 	}
+
+	@Override
+	public String toString() {
+		return input.toString() + " => " + ideal.toString();
+	}
 }

Neuron.java

@@ -57,7 +57,7 @@ public class Neuron {
 	public void setInput(double input) {
 		this.input = input;
 	}

 	@Override
 	public int hashCode() {
 		final int prime = 31;

TTTFilterTrainer.java

@@ -1,8 +1,6 @@
 package net.woodyfolsom.msproj.ann;

-import java.io.File;
 import java.io.FileNotFoundException;
-import java.io.FileOutputStream;
 import java.util.ArrayList;
 import java.util.List;
@@ -15,49 +13,74 @@ import net.woodyfolsom.msproj.tictactoe.Policy;
 import net.woodyfolsom.msproj.tictactoe.RandomPolicy;
 import net.woodyfolsom.msproj.tictactoe.State;

-public class TTTFilterTrainer { //implements epsilon-greedy trainer? online version of NeuralNetFilter
+public class TTTFilterTrainer { // implements epsilon-greedy trainer? online
+	// version of NeuralNetFilter

 	public static void main(String[] args) throws FileNotFoundException {
-		double alpha = 0.0;
-		double lambda = 0.9;
-		int maxGames = 15000;
+		double alpha = 0.15;
+		double lambda = .95;
+		int maxGames = 1000;
 		new TTTFilterTrainer().trainNetwork(alpha, lambda, maxGames);
 	}

-	public void trainNetwork(double alpha, double lambda, int maxGames) throws FileNotFoundException {
-		///
-		FeedforwardNetwork neuralNetwork = new MultiLayerPerceptron(true, 9,5,1);
+	public void trainNetwork(double alpha, double lambda, int maxGames)
+			throws FileNotFoundException {
+		FeedforwardNetwork neuralNetwork = new MultiLayerPerceptron(true, 9, 6,
+				1);
 		neuralNetwork.setName("TicTacToe");
 		neuralNetwork.initWeights();

-		TrainingMethod trainer = new TemporalDifference(0.5,0.5);
+		TrainingMethod trainer = new TemporalDifference(alpha, lambda);

 		System.out.println("Playing untrained games.");
 		for (int i = 0; i < 10; i++) {
-			System.out.println("" + (i+1) + ". " + playOptimal(neuralNetwork).getResult());
+			System.out.println("" + (i + 1) + ". "
+					+ playOptimal(neuralNetwork).getResult());
 		}

-		System.out.println("Learning from " + maxGames + " games of random self-play");
+		System.out.println("Learning from " + maxGames
+				+ " games of random self-play");
 		int gamesPlayed = 0;
 		List<RESULT> results = new ArrayList<RESULT>();
 		do {
-			GameRecord gameRecord = playEpsilonGreedy(0.90, neuralNetwork, trainer);
+			GameRecord gameRecord = playEpsilonGreedy(0.90, neuralNetwork,
+					trainer);
 			System.out.println("Winner: " + gameRecord.getResult());
 			gamesPlayed++;
 			results.add(gameRecord.getResult());
 		} while (gamesPlayed < maxGames);
-		///
-		System.out.println("Learned network after " + maxGames + " training games.");
-		double[][] validationSet = new double[8][];
+		System.out.println("Learned network after " + maxGames
+				+ " training games.");
 		for (int i = 0; i < results.size(); i++) {
 			if (i % 10 == 0) {
-				System.out.println("" + (i+1) + ". " + results.get(i));
+				System.out.println("" + (i + 1) + ". " + results.get(i));
 			}
 		}
+		evalTestCases(neuralNetwork);
+
+		System.out.println("Playing optimal games.");
+		for (int i = 0; i < 10; i++) {
+			System.out.println("" + (i + 1) + ". "
+					+ playOptimal(neuralNetwork).getResult());
+		}
+		/*
+		 * File output = new File("ttt.net");
+		 *
+		 * FileOutputStream fos = new FileOutputStream(output);
+		 *
+		 * neuralNetwork.save(fos);
+		 */
+	}
+
+	private void evalTestCases(FeedforwardNetwork neuralNetwork) {
+		double[][] validationSet = new double[8][];
 		// empty board
 		validationSet[0] = new double[] { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
 				0.0, 0.0 };
@@ -81,73 +104,63 @@ public class TTTFilterTrainer { //implements epsilon-greedy trainer? online vers
 				0.0, -1.0 };
 		// about to win
-		validationSet[7] = new double[] {
-				-1.0, 1.0, 1.0,
-				1.0, -1.0, 1.0,
-				-1.0, -1.0, 0.0 };
+		validationSet[7] = new double[] { -1.0, 1.0, 1.0, 1.0, -1.0, 1.0, -1.0,
+				-1.0, 0.0 };

 		String[] inputNames = new String[] { "00", "01", "02", "10", "11",
 				"12", "20", "21", "22" };
 		String[] outputNames = new String[] { "values" };

 		System.out.println("Output from eval set (learned network):");
 		testNetwork(neuralNetwork, validationSet, inputNames, outputNames);
-
-		System.out.println("Playing optimal games.");
-		for (int i = 0; i < 10; i++) {
-			System.out.println("" + (i+1) + ". " + playOptimal(neuralNetwork).getResult());
-		}
-		/*
-		File output = new File("ttt.net");
-		FileOutputStream fos = new FileOutputStream(output);
-		neuralNetwork.save(fos);*/
-	}
-
-	private GameRecord playOptimal(FeedforwardNetwork neuralNetwork) {
-		GameRecord gameRecord = new GameRecord();
-		Policy neuralNetPolicy = new NeuralNetPolicy(neuralNetwork);
-		State state = gameRecord.getState();
-		do {
-			Action action;
-			State nextState;
-			action = neuralNetPolicy.getAction(gameRecord.getState());
-			nextState = gameRecord.apply(action);
-			//System.out.println("Action " + action + " selected by policy " + selectedPolicy.getName());
-			//System.out.println("Next board state: " + nextState);
-			state = nextState;
-		} while (!state.isTerminal());
-		//finally, reinforce the actual reward
-		return gameRecord;
 	}

-	private GameRecord playEpsilonGreedy(double epsilon, FeedforwardNetwork neuralNetwork, TrainingMethod trainer) {
+	private GameRecord playOptimal(FeedforwardNetwork neuralNetwork) {
 		GameRecord gameRecord = new GameRecord();
-		Policy randomPolicy = new RandomPolicy();
 		Policy neuralNetPolicy = new NeuralNetPolicy(neuralNetwork);
-		//System.out.println("Playing epsilon-greedy game.");
 		State state = gameRecord.getState();
-		NNDataPair statePair;
-		Policy selectedPolicy;
-		trainer.zeroTraces(neuralNetwork);
-		do {
+		System.out.println("Playing optimal game:");
+		do {
 			Action action;
 			State nextState;
+			action = neuralNetPolicy.getAction(gameRecord.getState());
+			nextState = gameRecord.apply(action);
+			System.out.println("Action " + action + " selected by policy " +
+					neuralNetPolicy.getName());
+			System.out.println("Next board state: " + nextState);
+			state = nextState;
+		} while (!state.isTerminal());
+		// finally, reinforce the actual reward
+		return gameRecord;
+	}
+
+	private GameRecord playEpsilonGreedy(double epsilon,
+			FeedforwardNetwork neuralNetwork, TrainingMethod trainer) {
+		GameRecord gameRecord = new GameRecord();
+		Policy randomPolicy = new RandomPolicy();
+		Policy neuralNetPolicy = new NeuralNetPolicy(neuralNetwork);
+		// System.out.println("Playing epsilon-greedy game.");
+		State state = gameRecord.getState();
+		NNDataPair statePair;
+		Policy selectedPolicy;
+		trainer.zeroTraces(neuralNetwork);
+		do {
+			Action action;
+			State nextState;
 			if (Math.random() < epsilon) {
 				selectedPolicy = randomPolicy;
 				action = selectedPolicy.getAction(gameRecord.getState());
@@ -155,32 +168,35 @@ public class TTTFilterTrainer { //implements epsilon-greedy trainer? online vers
 			} else {
 				selectedPolicy = neuralNetPolicy;
 				action = selectedPolicy.getAction(gameRecord.getState());
 				nextState = gameRecord.apply(action);
 				statePair = NNDataSetFactory.createDataPair(state);
-				NNDataPair nextStatePair = NNDataSetFactory.createDataPair(nextState);
-				trainer.iteratePattern(neuralNetwork, statePair, nextStatePair.getIdeal());
+				NNDataPair nextStatePair = NNDataSetFactory
+						.createDataPair(nextState);
+				trainer.iteratePattern(neuralNetwork, statePair,
+						nextStatePair.getIdeal());
 			}
-			//System.out.println("Action " + action + " selected by policy " + selectedPolicy.getName());
-			//System.out.println("Next board state: " + nextState);
+			// System.out.println("Action " + action + " selected by policy " +
+			// selectedPolicy.getName());
+			// System.out.println("Next board state: " + nextState);
 			state = nextState;
 		} while (!state.isTerminal());
-		//finally, reinforce the actual reward
+		// finally, reinforce the actual reward
 		statePair = NNDataSetFactory.createDataPair(state);
 		trainer.iteratePattern(neuralNetwork, statePair, statePair.getIdeal());
 		return gameRecord;
 	}

 	private void testNetwork(FeedforwardNetwork neuralNetwork,
 			double[][] validationSet, String[] inputNames, String[] outputNames) {
 		for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
 			NNDataPair dp = new NNDataPair(new NNData(inputNames,
-					validationSet[valIndex]), new NNData(outputNames,
-					validationSet[valIndex]));
+					validationSet[valIndex]), new NNData(outputNames,
+					new double[] {0.0}));
 			System.out.println(dp + " => " + neuralNetwork.compute(dp));
 		}
 	}

TemporalDifference.java

@@ -83,15 +83,22 @@ public class TemporalDifference extends TrainingMethod {
 	private void updateWeights(FeedforwardNetwork neuralNetwork, double predictionError) {
 		for (Connection connection : neuralNetwork.getConnections()) {
-			Neuron srcNeuron = neuralNetwork.getNeuron(connection.getSrc());
+			/*Neuron srcNeuron = neuralNetwork.getNeuron(connection.getSrc());
 			Neuron destNeuron = neuralNetwork.getNeuron(connection.getDest());
 			double delta = alpha * srcNeuron.getOutput()
 					* destNeuron.getGradient() * predictionError + connection.getTrace() * lambda;
 			// TODO allow for momentum
 			// double lastDelta = connection.getLastDelta();
-			connection.addDelta(delta);
+			connection.addDelta(delta);*/
+			Neuron srcNeuron = neuralNetwork.getNeuron(connection.getSrc());
+			Neuron destNeuron = neuralNetwork.getNeuron(connection.getDest());
+			double delta = alpha * srcNeuron.getOutput()
+					* destNeuron.getGradient() + connection.getTrace() * lambda;
+			//TODO allow for momentum
+			//double lastDelta = connection.getLastDelta();
+			connection.addDelta(delta);
 		}
 	}
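For comparison, a textbook TD(λ) update keeps one eligibility trace per weight, decays it every step, and scales it by the prediction error when the weight change is applied. A minimal reference sketch under those standard assumptions (plain arrays, not this project's TrainingMethod/Connection API):

	// Reference sketch of a textbook TD(lambda) weight update with eligibility
	// traces; hypothetical helper, not part of this repository.
	public class TdLambdaSketch {

	    /**
	     * @param weights     network weights, updated in place
	     * @param traces      eligibility traces, one per weight, updated in place
	     * @param gradients   dV(s_t)/dw_i for the current prediction
	     * @param tdError     prediction error: r + gamma * V(s_{t+1}) - V(s_t)
	     * @param alpha       learning rate
	     * @param gammaLambda trace decay (gamma * lambda)
	     */
	    static void update(double[] weights, double[] traces, double[] gradients,
	                       double tdError, double alpha, double gammaLambda) {
	        for (int i = 0; i < weights.length; i++) {
	            // decay the old trace and add the gradient of the current prediction
	            traces[i] = gammaLambda * traces[i] + gradients[i];
	            // credit the TD error to each weight through its trace
	            weights[i] += alpha * tdError * traces[i];
	        }
	    }
	}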

NeuralNetPolicy.java

@@ -24,7 +24,8 @@ public class NeuralNetPolicy extends Policy {
 		for (Action action : validMoves) {
 			State nextState = state.apply(action);
-			NNDataPair dataPair = NNDataSetFactory.createDataPair(state);
+			//NNDataPair dataPair = NNDataSetFactory.createDataPair(state);
+			NNDataPair dataPair = NNDataSetFactory.createDataPair(nextState);
 			//estimated reward for X
 			scores.put(action, neuralNet.compute(dataPair).getValues()[0]);
 		}