Bug in TDL (temporal-difference learning) seems to be fixed. The issue was an incorrect calculation of the next-state reward (the current state was being reused) in the epsilon-greedy learner.
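The core of the fix is in NeuralNetPolicy.getAction(...): each candidate move must be scored by the value of the state it leads to, not the state the agent is already in. A minimal sketch of the corrected scoring loop, using only identifiers that appear in the diff below (illustrative, not a standalone program):

    for (Action action : validMoves) {
        State nextState = state.apply(action);

        // Before the fix, the value estimate was computed from the *current* state,
        // so every candidate action received the same score:
        //   NNDataPair dataPair = NNDataSetFactory.createDataPair(state);

        // After the fix, the successor state reached by the candidate action is
        // evaluated, which is what epsilon-greedy action selection needs:
        NNDataPair dataPair = NNDataSetFactory.createDataPair(nextState);

        // estimated reward for X
        scores.put(action, neuralNet.compute(dataPair).getValues()[0]);
    }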
@@ -7,7 +7,7 @@ public class Connection {
 	private int src;
 	private int dest;
 	private double weight;
-	private transient double lastDelta = 0.0;
+	//private transient double lastDelta = 0.0;
 	private transient double trace = 0.0;
 
 	public Connection() {
@@ -23,7 +23,7 @@ public class Connection {
 	public void addDelta(double delta) {
 		this.trace = delta;
 		this.weight += delta;
-		this.lastDelta = delta;
+		//this.lastDelta = delta;
 	}
 
 	@XmlAttribute
@@ -31,10 +31,10 @@ public class Connection {
 		return dest;
 	}
 
-	@XmlTransient
-	public double getLastDelta() {
-		return lastDelta;
-	}
+	//@XmlTransient
+	//public double getLastDelta() {
+	//	return lastDelta;
+	//}
 
 	@XmlAttribute
 	public int getSrc() {
@@ -100,6 +100,6 @@ public class Connection {
 
 	@Override
 	public String toString() {
-		return "Connection(" + src + ", " + dest +"), weight: " + weight;
+		return "Connection(src: " + src + ",dest: " + dest + ", trace:" + trace +"), weight: " + weight;
 	}
 }
@@ -16,4 +16,9 @@ public class NNDataPair {
 	public NNData getIdeal() {
 		return ideal;
 	}
+
+	@Override
+	public String toString() {
+		return input.toString() + " => " + ideal.toString();
+	}
 }
@@ -1,8 +1,6 @@
 package net.woodyfolsom.msproj.ann;
 
-import java.io.File;
 import java.io.FileNotFoundException;
-import java.io.FileOutputStream;
 import java.util.ArrayList;
 import java.util.List;
 
@@ -15,49 +13,74 @@ import net.woodyfolsom.msproj.tictactoe.Policy;
 import net.woodyfolsom.msproj.tictactoe.RandomPolicy;
 import net.woodyfolsom.msproj.tictactoe.State;
 
-public class TTTFilterTrainer { //implements epsilon-greedy trainer? online version of NeuralNetFilter
+public class TTTFilterTrainer { // implements epsilon-greedy trainer? online
+								// version of NeuralNetFilter
 
 	public static void main(String[] args) throws FileNotFoundException {
-		double alpha = 0.0;
-		double lambda = 0.9;
-		int maxGames = 15000;
+		double alpha = 0.15;
+		double lambda = .95;
+		int maxGames = 1000;
 
 		new TTTFilterTrainer().trainNetwork(alpha, lambda, maxGames);
 	}
 
-	public void trainNetwork(double alpha, double lambda, int maxGames) throws FileNotFoundException {
-		///
-		FeedforwardNetwork neuralNetwork = new MultiLayerPerceptron(true, 9,5,1);
+	public void trainNetwork(double alpha, double lambda, int maxGames)
+			throws FileNotFoundException {
+		FeedforwardNetwork neuralNetwork = new MultiLayerPerceptron(true, 9, 6,
+				1);
 		neuralNetwork.setName("TicTacToe");
 		neuralNetwork.initWeights();
-		TrainingMethod trainer = new TemporalDifference(0.5,0.5);
+		TrainingMethod trainer = new TemporalDifference(alpha, lambda);
 
 		System.out.println("Playing untrained games.");
 		for (int i = 0; i < 10; i++) {
-			System.out.println("" + (i+1) + ". " + playOptimal(neuralNetwork).getResult());
+			System.out.println("" + (i + 1) + ". "
+					+ playOptimal(neuralNetwork).getResult());
 		}
 
-		System.out.println("Learning from " + maxGames + " games of random self-play");
+		System.out.println("Learning from " + maxGames
+				+ " games of random self-play");
 
 		int gamesPlayed = 0;
 		List<RESULT> results = new ArrayList<RESULT>();
 		do {
-			GameRecord gameRecord = playEpsilonGreedy(0.90, neuralNetwork, trainer);
+			GameRecord gameRecord = playEpsilonGreedy(0.90, neuralNetwork,
+					trainer);
 			System.out.println("Winner: " + gameRecord.getResult());
 			gamesPlayed++;
 			results.add(gameRecord.getResult());
 		} while (gamesPlayed < maxGames);
-		///
 
-		System.out.println("Learned network after " + maxGames + " training games.");
-		double[][] validationSet = new double[8][];
+		System.out.println("Learned network after " + maxGames
+				+ " training games.");
 
 		for (int i = 0; i < results.size(); i++) {
 			if (i % 10 == 0) {
 				System.out.println("" + (i + 1) + ". " + results.get(i));
 			}
 		}
 
+		evalTestCases(neuralNetwork);
+
+		System.out.println("Playing optimal games.");
+		for (int i = 0; i < 10; i++) {
+			System.out.println("" + (i + 1) + ". "
+					+ playOptimal(neuralNetwork).getResult());
+		}
+
+		/*
+		 * File output = new File("ttt.net");
+		 *
+		 * FileOutputStream fos = new FileOutputStream(output);
+		 *
+		 * neuralNetwork.save(fos);
+		 */
+	}
+
+	private void evalTestCases(FeedforwardNetwork neuralNetwork) {
+		double[][] validationSet = new double[8][];
+
 		// empty board
 		validationSet[0] = new double[] { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
 				0.0, 0.0 };
@@ -81,10 +104,8 @@ public class TTTFilterTrainer { //implements epsilon-greedy trainer? online vers
 				0.0, -1.0 };
 
 		// about to win
-		validationSet[7] = new double[] {
-				-1.0, 1.0, 1.0,
-				1.0, -1.0, 1.0,
-				-1.0, -1.0, 0.0 };
+		validationSet[7] = new double[] { -1.0, 1.0, 1.0, 1.0, -1.0, 1.0, -1.0,
+				-1.0, 0.0 };
 
 		String[] inputNames = new String[] { "00", "01", "02", "10", "11",
 				"12", "20", "21", "22" };
@@ -92,18 +113,6 @@ public class TTTFilterTrainer { //implements epsilon-greedy trainer? online vers
 
 		System.out.println("Output from eval set (learned network):");
 		testNetwork(neuralNetwork, validationSet, inputNames, outputNames);
-
-		System.out.println("Playing optimal games.");
-		for (int i = 0; i < 10; i++) {
-			System.out.println("" + (i+1) + ". " + playOptimal(neuralNetwork).getResult());
-		}
-
-		/*
-		File output = new File("ttt.net");
-
-		FileOutputStream fos = new FileOutputStream(output);
-
-		neuralNetwork.save(fos);*/
 	}
 
 	private GameRecord playOptimal(FeedforwardNetwork neuralNetwork) {
@@ -113,6 +122,8 @@ public class TTTFilterTrainer { //implements epsilon-greedy trainer? online vers
 
 		State state = gameRecord.getState();
 
+		System.out.println("Playing optimal game:");
+
 		do {
 			Action action;
 			State nextState;
@@ -120,8 +131,9 @@ public class TTTFilterTrainer { //implements epsilon-greedy trainer? online vers
 			action = neuralNetPolicy.getAction(gameRecord.getState());
 
 			nextState = gameRecord.apply(action);
-			//System.out.println("Action " + action + " selected by policy " + selectedPolicy.getName());
-			//System.out.println("Next board state: " + nextState);
+			System.out.println("Action " + action + " selected by policy " +
+					neuralNetPolicy.getName());
+			System.out.println("Next board state: " + nextState);
 			state = nextState;
 		} while (!state.isTerminal());
 
@@ -130,7 +142,8 @@ public class TTTFilterTrainer { //implements epsilon-greedy trainer? online vers
 		return gameRecord;
 	}
 
-	private GameRecord playEpsilonGreedy(double epsilon, FeedforwardNetwork neuralNetwork, TrainingMethod trainer) {
+	private GameRecord playEpsilonGreedy(double epsilon,
+			FeedforwardNetwork neuralNetwork, TrainingMethod trainer) {
 		GameRecord gameRecord = new GameRecord();
 
 		Policy randomPolicy = new RandomPolicy();
@@ -158,10 +171,13 @@ public class TTTFilterTrainer { //implements epsilon-greedy trainer? online vers
 
 				nextState = gameRecord.apply(action);
 				statePair = NNDataSetFactory.createDataPair(state);
-				NNDataPair nextStatePair = NNDataSetFactory.createDataPair(nextState);
-				trainer.iteratePattern(neuralNetwork, statePair, nextStatePair.getIdeal());
+				NNDataPair nextStatePair = NNDataSetFactory
+						.createDataPair(nextState);
+				trainer.iteratePattern(neuralNetwork, statePair,
+						nextStatePair.getIdeal());
 			}
-			//System.out.println("Action " + action + " selected by policy " + selectedPolicy.getName());
+			// System.out.println("Action " + action + " selected by policy " +
+			// selectedPolicy.getName());
 
 			// System.out.println("Next board state: " + nextState);
 
@@ -180,7 +196,7 @@ public class TTTFilterTrainer { //implements epsilon-greedy trainer? online vers
 		for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
 			NNDataPair dp = new NNDataPair(new NNData(inputNames,
 					validationSet[valIndex]), new NNData(outputNames,
-					validationSet[valIndex]));
+					new double[] {0.0}));
 			System.out.println(dp + " => " + neuralNetwork.compute(dp));
 		}
 	}
@@ -83,7 +83,7 @@ public class TemporalDifference extends TrainingMethod {
 
 	private void updateWeights(FeedforwardNetwork neuralNetwork, double predictionError) {
 		for (Connection connection : neuralNetwork.getConnections()) {
-			Neuron srcNeuron = neuralNetwork.getNeuron(connection.getSrc());
+			/*Neuron srcNeuron = neuralNetwork.getNeuron(connection.getSrc());
 			Neuron destNeuron = neuralNetwork.getNeuron(connection.getDest());
 
 			double delta = alpha * srcNeuron.getOutput()
@@ -91,6 +91,13 @@ public class TemporalDifference extends TrainingMethod {
 
 			// TODO allow for momentum
 			// double lastDelta = connection.getLastDelta();
+			connection.addDelta(delta);*/
+			Neuron srcNeuron = neuralNetwork.getNeuron(connection.getSrc());
+			Neuron destNeuron = neuralNetwork.getNeuron(connection.getDest());
+			double delta = alpha * srcNeuron.getOutput()
+					* destNeuron.getGradient() + connection.getTrace() * lambda;
+			//TODO allow for momentum
+			//double lastDelta = connection.getLastDelta();
 			connection.addDelta(delta);
 		}
 	}
@@ -24,7 +24,8 @@ public class NeuralNetPolicy extends Policy {
 
 		for (Action action : validMoves) {
 			State nextState = state.apply(action);
-			NNDataPair dataPair = NNDataSetFactory.createDataPair(state);
+			//NNDataPair dataPair = NNDataSetFactory.createDataPair(state);
+			NNDataPair dataPair = NNDataSetFactory.createDataPair(nextState);
 			//estimated reward for X
 			scores.put(action, neuralNet.compute(dataPair).getValues()[0]);
 		}