Lots of neural network stuff.

2012-11-27 17:33:09 -05:00
parent 790b5666a8
commit 214bdcd032
55 changed files with 1507 additions and 821 deletions

View File

@@ -7,6 +7,5 @@
<classpathentry kind="lib" path="lib/log4j-1.2.16.jar"/> <classpathentry kind="lib" path="lib/log4j-1.2.16.jar"/>
<classpathentry kind="lib" path="lib/kgsGtp.jar"/> <classpathentry kind="lib" path="lib/kgsGtp.jar"/>
<classpathentry kind="lib" path="lib/antlrworks-1.4.3.jar"/> <classpathentry kind="lib" path="lib/antlrworks-1.4.3.jar"/>
<classpathentry kind="lib" path="lib/encog-java-core.jar" sourcepath="lib/encog-java-core-sources.jar"/>
<classpathentry kind="output" path="bin"/> <classpathentry kind="output" path="bin"/>
</classpath> </classpath>

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -1,21 +1,26 @@
 package net.woodyfolsom.msproj.ann;

-import java.io.File;
-import java.io.FileInputStream;
-import java.io.FileOutputStream;
-import java.io.IOException;
-import org.encog.ml.data.MLData;
-import org.encog.neural.networks.BasicNetwork;
-import org.encog.neural.networks.PersistBasicNetwork;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.List;

 public abstract class AbstractNeuralNetFilter implements NeuralNetFilter {
-    protected BasicNetwork neuralNetwork;
-    protected int actualTrainingEpochs = 0;
-    protected int maxTrainingEpochs = 1000;
+    private final FeedforwardNetwork neuralNetwork;
+    private final TrainingMethod trainingMethod;
+    private double maxError;
+    private int actualTrainingEpochs = 0;
+    private int maxTrainingEpochs;
+
+    AbstractNeuralNetFilter(FeedforwardNetwork neuralNetwork, TrainingMethod trainingMethod, int maxTrainingEpochs, double maxError) {
+        this.neuralNetwork = neuralNetwork;
+        this.trainingMethod = trainingMethod;
+        this.maxError = maxError;
+        this.maxTrainingEpochs = maxTrainingEpochs;
+    }

     @Override
-    public MLData compute(MLData input) {
+    public NNData compute(NNDataPair input) {
         return this.neuralNetwork.compute(input);
     }
@@ -23,35 +28,80 @@ public abstract class AbstractNeuralNetFilter implements NeuralNetFilter {
         return actualTrainingEpochs;
     }

+    @Override
+    public int getInputSize() {
+        return 2;
+    }
+
     public int getMaxTrainingEpochs() {
         return maxTrainingEpochs;
     }

-    @Override
-    public BasicNetwork getNeuralNetwork() {
+    protected FeedforwardNetwork getNeuralNetwork() {
         return neuralNetwork;
     }

-    public void load(String filename) throws IOException {
-        FileInputStream fis = new FileInputStream(new File(filename));
-        neuralNetwork = (BasicNetwork) new PersistBasicNetwork().read(fis);
-        fis.close();
+    @Override
+    public void learnPatterns(List<NNDataPair> trainingSet) {
+        actualTrainingEpochs = 0;
+        double error;
+        neuralNetwork.initWeights();
+        error = trainingMethod.computePatternError(neuralNetwork, trainingSet);
+        if (error <= maxError) {
+            System.out.println("Initial error: " + error);
+            return;
+        }
+        do {
+            trainingMethod.iteratePatterns(neuralNetwork, trainingSet);
+            error = trainingMethod.computePatternError(neuralNetwork, trainingSet);
+            System.out.println("Epoch #" + actualTrainingEpochs + " Error:" + error);
+            actualTrainingEpochs++;
+            System.out.println("MSSE after epoch " + actualTrainingEpochs + ": " + error);
+        } while (error > maxError && actualTrainingEpochs < maxTrainingEpochs);
     }

     @Override
-    public void reset() {
-        neuralNetwork.reset();
+    public void learnSequences(List<List<NNDataPair>> trainingSet) {
+        actualTrainingEpochs = 0;
+        double error;
+        neuralNetwork.initWeights();
+        error = trainingMethod.computeSequenceError(neuralNetwork, trainingSet);
+        if (error <= maxError) {
+            System.out.println("Initial error: " + error);
+            return;
+        }
+        do {
+            trainingMethod.iterateSequences(neuralNetwork, trainingSet);
+            error = trainingMethod.computeSequenceError(neuralNetwork, trainingSet);
+            if (Double.isNaN(error)) {
+                error = trainingMethod.computeSequenceError(neuralNetwork, trainingSet);
+            }
+            System.out.println("Epoch #" + actualTrainingEpochs + " Error:" + error);
+            actualTrainingEpochs++;
+            System.out.println("MSSE after epoch " + actualTrainingEpochs + ": " + error);
+        } while (error > maxError && actualTrainingEpochs < maxTrainingEpochs);
     }

     @Override
-    public void reset(int seed) {
-        neuralNetwork.reset(seed);
+    public boolean load(InputStream input) {
+        return neuralNetwork.load(input);
     }

-    public void save(String filename) throws IOException {
-        FileOutputStream fos = new FileOutputStream(new File(filename));
-        new PersistBasicNetwork().save(fos, getNeuralNetwork());
-        fos.close();
+    @Override
+    public boolean save(OutputStream output) {
+        return neuralNetwork.save(output);
+    }
+
+    public void setMaxError(double maxError) {
+        this.maxError = maxError;
     }

     public void setMaxTrainingEpochs(int max) {

View File

@@ -1,11 +1,11 @@
-package net.woodyfolsom.msproj.ann2;
+package net.woodyfolsom.msproj.ann;

 import java.util.List;

-import net.woodyfolsom.msproj.ann2.math.ErrorFunction;
-import net.woodyfolsom.msproj.ann2.math.MSSE;
+import net.woodyfolsom.msproj.ann.math.ErrorFunction;
+import net.woodyfolsom.msproj.ann.math.MSSE;

-public class BackPropagation implements TrainingMethod {
+public class BackPropagation extends TrainingMethod {
     private final ErrorFunction errorFunction;
     private final double learningRate;
     private final double momentum;
@@ -17,15 +17,13 @@ public class BackPropagation implements TrainingMethod {
     }

     @Override
-    public void iterate(FeedforwardNetwork neuralNetwork,
+    public void iteratePatterns(FeedforwardNetwork neuralNetwork,
             List<NNDataPair> trainingSet) {
         System.out.println("Learningrate: " + learningRate);
         System.out.println("Momentum: " + momentum);
-        //zeroErrors(neuralNetwork);
         for (NNDataPair trainingPair : trainingSet) {
-            zeroErrors(neuralNetwork);
+            zeroGradients(neuralNetwork);
             System.out.println("Training with: " + trainingPair.getInput());
@@ -35,16 +33,15 @@ public class BackPropagation implements TrainingMethod {
             System.out.println("Updating weights. Ideal Output: " + ideal);
             System.out.println("Actual Output: " + actual);

-            updateErrors(neuralNetwork, ideal);
+            //backpropagate the gradients w.r.t. output error
+            backPropagate(neuralNetwork, ideal);
             updateWeights(neuralNetwork);
         }
-        //updateWeights(neuralNetwork);
     }

     @Override
-    public double computeError(FeedforwardNetwork neuralNetwork,
+    public double computePatternError(FeedforwardNetwork neuralNetwork,
             List<NNDataPair> trainingSet) {
         int numDataPairs = trainingSet.size();
         int outputSize = neuralNetwork.getOutput().length;
@@ -67,15 +64,17 @@ public class BackPropagation implements TrainingMethod {
         return MSSE;
     }

-    private void updateErrors(FeedforwardNetwork neuralNetwork, NNData ideal) {
+    @Override
+    protected void backPropagate(FeedforwardNetwork neuralNetwork, NNData ideal) {
         Neuron[] outputNeurons = neuralNetwork.getOutputNeurons();
         double[] idealValues = ideal.getValues();
         for (int i = 0; i < idealValues.length; i++) {
-            double output = outputNeurons[i].getOutput();
-            double derivative = outputNeurons[i].getActivationFunction()
-                    .derivative(output);
-            outputNeurons[i].setError(outputNeurons[i].getError() + derivative * (idealValues[i] - output));
+            double input = outputNeurons[i].getInput();
+            double derivative = outputNeurons[i].getActivationFunction()
+                    .derivative(input);
+            outputNeurons[i].setGradient(outputNeurons[i].getGradient() + derivative * (idealValues[i] - outputNeurons[i].getOutput()));
         }
         // walking down the list of Neurons in reverse order, propagate the
         // error
@@ -84,19 +83,19 @@ public class BackPropagation implements TrainingMethod {
         for (int n = neurons.length - 1; n >= 0; n--) {
             Neuron neuron = neurons[n];
-            double error = neuron.getError();
+            double error = neuron.getGradient();
             Connection[] connectionsFromN = neuralNetwork
                     .getConnectionsFrom(neuron.getId());
             if (connectionsFromN.length > 0) {
                 double derivative = neuron.getActivationFunction().derivative(
-                        neuron.getOutput());
+                        neuron.getInput());
                 for (Connection connection : connectionsFromN) {
-                    error += derivative * connection.getWeight() * neuralNetwork.getNeuron(connection.getDest()).getError();
+                    error += derivative * connection.getWeight() * neuralNetwork.getNeuron(connection.getDest()).getGradient();
                 }
             }
-            neuron.setError(error);
+            neuron.setGradient(error);
         }
     }
@@ -104,17 +103,30 @@ public class BackPropagation implements TrainingMethod {
         for (Connection connection : neuralNetwork.getConnections()) {
             Neuron srcNeuron = neuralNetwork.getNeuron(connection.getSrc());
             Neuron destNeuron = neuralNetwork.getNeuron(connection.getDest());
-            double delta = learningRate * srcNeuron.getOutput() * destNeuron.getError();
+            double delta = learningRate * srcNeuron.getOutput() * destNeuron.getGradient();
             //TODO allow for momentum
             //double lastDelta = connection.getLastDelta();
             connection.addDelta(delta);
         }
     }

-    private void zeroErrors(FeedforwardNetwork neuralNetwork) {
-        // Set output errors relative to ideals, all other errors to 0.
-        for (Neuron neuron : neuralNetwork.getNeurons()) {
-            neuron.setError(0.0);
-        }
+    @Override
+    public void iterateSequences(FeedforwardNetwork neuralNetwork,
+            List<List<NNDataPair>> trainingSet) {
+        throw new UnsupportedOperationException();
+    }
+
+    @Override
+    public double computeSequenceError(FeedforwardNetwork neuralNetwork,
+            List<List<NNDataPair>> trainingSet) {
+        throw new UnsupportedOperationException();
+    }
+
+    @Override
+    protected void iteratePattern(FeedforwardNetwork neuralNetwork,
+            NNDataPair statePair, NNData nextReward) {
+        throw new UnsupportedOperationException();
     }
 }
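The error-to-gradient rename above comes with one substantive change: the activation derivative is now evaluated at the neuron's input rather than its output, matching the Sigmoid.derivative rewrite later in this commit. What the renamed methods compute is the standard backpropagation delta rule; as a reading aid (not part of the commit), with f the activation function, net_j the weighted input of neuron j, o_j its output, t_j the target and eta the learning rate:

\delta_j = f'(\mathit{net}_j)\,(t_j - o_j) \qquad \text{(output neurons)}
\delta_j = f'(\mathit{net}_j) \sum_k w_{jk}\,\delta_k \qquad \text{(hidden neurons)}
\Delta w_{ij} = \eta\, o_i\, \delta_j \qquad \text{(weight update; momentum is still a TODO)}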

View File

@@ -1,4 +1,4 @@
-package net.woodyfolsom.msproj.ann2;
+package net.woodyfolsom.msproj.ann;

 import javax.xml.bind.annotation.XmlAttribute;
 import javax.xml.bind.annotation.XmlTransient;
@@ -8,6 +8,7 @@ public class Connection {
     private int dest;
     private double weight;
     private transient double lastDelta = 0.0;
+    private transient double trace = 0.0;

     public Connection() {
         //no-arg constructor for JAXB
@@ -20,6 +21,7 @@ public class Connection {
     }

     public void addDelta(double delta) {
+        this.trace = delta;
         this.weight += delta;
         this.lastDelta = delta;
     }
@@ -39,6 +41,10 @@ public class Connection {
         return src;
     }

+    public double getTrace() {
+        return trace;
+    }
+
     @XmlAttribute
     public double getWeight() {
         return weight;
@@ -52,6 +58,11 @@ public class Connection {
         this.src = src;
     }

+    @XmlTransient
+    public void setTrace(double trace) {
+        this.trace = trace;
+    }
+
     public void setWeight(double weight) {
         this.weight = weight;
     }

View File

@@ -1,12 +0,0 @@
package net.woodyfolsom.msproj.ann;
import org.encog.ml.data.basic.BasicMLData;
public class DoublePair extends BasicMLData {
private static final long serialVersionUID = 1L;
public DoublePair(double x, double y) {
super(new double[] { x, y });
}
}

View File

@@ -1,4 +1,4 @@
-package net.woodyfolsom.msproj.ann2;
+package net.woodyfolsom.msproj.ann;

 import java.io.InputStream;
 import java.io.OutputStream;
@@ -9,10 +9,11 @@ import java.util.Map;
 import javax.xml.bind.annotation.XmlAttribute;
 import javax.xml.bind.annotation.XmlElement;
+import javax.xml.bind.annotation.XmlTransient;

-import net.woodyfolsom.msproj.ann2.math.ActivationFunction;
-import net.woodyfolsom.msproj.ann2.math.Linear;
-import net.woodyfolsom.msproj.ann2.math.Sigmoid;
+import net.woodyfolsom.msproj.ann.math.ActivationFunction;
+import net.woodyfolsom.msproj.ann.math.Linear;
+import net.woodyfolsom.msproj.ann.math.Sigmoid;

 public abstract class FeedforwardNetwork {
     private ActivationFunction activationFunction;
@@ -83,12 +84,12 @@ public abstract class FeedforwardNetwork {
      * Adds a new neuron with a unique id to this FeedforwardNetwork.
      * @return
      */
-    Neuron createNeuron(boolean input) {
+    Neuron createNeuron(boolean input, ActivationFunction afunc) {
         Neuron neuron;
         if (input) {
             neuron = new Neuron(Linear.function, neurons.size());
         } else {
-            neuron = new Neuron(activationFunction, neurons.size());
+            neuron = new Neuron(afunc, neurons.size());
         }
         neurons.add(neuron);
         return neuron;
@@ -153,6 +154,10 @@ public abstract class FeedforwardNetwork {
         return neurons.get(id);
     }

+    public Connection getConnection(int index) {
+        return connections.get(index);
+    }
+
     @XmlElement
     protected Connection[] getConnections() {
         return connections.toArray(new Connection[connections.size()]);
@@ -178,6 +183,22 @@ public abstract class FeedforwardNetwork {
         }
     }

+    public double[] getGradients() {
+        double[] gradients = new double[neurons.size()];
+        for (int n = 0; n < gradients.length; n++) {
+            gradients[n] = neurons.get(n).getGradient();
+        }
+        return gradients;
+    }
+
+    public double[] getWeights() {
+        double[] weights = new double[connections.size()];
+        for (int i = 0; i < connections.size(); i++) {
+            weights[i] = connections.get(i).getWeight();
+        }
+        return weights;
+    }
+
     @XmlAttribute
     public boolean isBiased() {
         return biased;
@@ -226,7 +247,7 @@ public abstract class FeedforwardNetwork {
         this.biased = biased;
         if (biased) {
-            Neuron biasNeuron = createNeuron(true);
+            Neuron biasNeuron = createNeuron(true, activationFunction);
             biasNeuron.setInput(1.0);
             biasNeuronId = biasNeuron.getId();
         } else {
@@ -270,6 +291,7 @@ public abstract class FeedforwardNetwork {
         }
     }

+    @XmlTransient
     public void setWeights(double[] weights) {
         if (weights.length != connections.size()) {
             throw new IllegalArgumentException("# of weights must == # of connections");

View File

@@ -1,117 +0,0 @@
package net.woodyfolsom.msproj.ann;
import net.woodyfolsom.msproj.GameResult;
import net.woodyfolsom.msproj.GameState;
import net.woodyfolsom.msproj.Player;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.ml.data.basic.BasicMLDataPair;
import org.encog.util.kmeans.Centroid;
public class GameStateMLDataPair implements MLDataPair {
private BasicMLDataPair mlDataPairDelegate;
private GameState gameState;
public GameStateMLDataPair(GameState gameState) {
this.gameState = gameState;
mlDataPairDelegate = new BasicMLDataPair(new BasicMLData(createInput()), new BasicMLData(createIdeal()));
}
public GameStateMLDataPair(GameStateMLDataPair that) {
this.gameState = new GameState(that.gameState);
mlDataPairDelegate = new BasicMLDataPair(
that.mlDataPairDelegate.getInput(),
that.mlDataPairDelegate.getIdeal());
}
@Override
public MLDataPair clone() {
return new GameStateMLDataPair(this);
}
@Override
public Centroid<MLDataPair> createCentroid() {
return mlDataPairDelegate.createCentroid();
}
/**
* Creates a vector of normalized scores from GameState.
*
* @return
*/
private double[] createInput() {
GameResult result = gameState.getResult();
double maxScore = gameState.getGameConfig().getSize()
* gameState.getGameConfig().getSize();
double whiteScore = Math.min(1.0, result.getWhiteScore() / maxScore);
double blackScore = Math.min(1.0, result.getBlackScore() / maxScore);
return new double[] { blackScore, whiteScore };
}
/**
* Creates a vector of values indicating strength of black/white win output
* from network.
*
* @return
*/
private double[] createIdeal() {
GameResult result = gameState.getResult();
double blackWinner = result.isWinner(Player.BLACK) ? 1.0 : 0.0;
double whiteWinner = result.isWinner(Player.WHITE) ? 1.0 : 0.0;
return new double[] { blackWinner, whiteWinner };
}
@Override
public MLData getIdeal() {
return mlDataPairDelegate.getIdeal();
}
@Override
public double[] getIdealArray() {
return mlDataPairDelegate.getIdealArray();
}
@Override
public MLData getInput() {
return mlDataPairDelegate.getInput();
}
@Override
public double[] getInputArray() {
return mlDataPairDelegate.getInputArray();
}
@Override
public double getSignificance() {
return mlDataPairDelegate.getSignificance();
}
@Override
public boolean isSupervised() {
return mlDataPairDelegate.isSupervised();
}
@Override
public void setIdealArray(double[] arg0) {
mlDataPairDelegate.setIdealArray(arg0);
}
@Override
public void setInputArray(double[] arg0) {
mlDataPairDelegate.setInputArray(arg0);
}
@Override
public void setSignificance(double arg0) {
mlDataPairDelegate.setSignificance(arg0);
}
}

View File

@@ -1,4 +1,4 @@
-package net.woodyfolsom.msproj.ann2;
+package net.woodyfolsom.msproj.ann;

 import java.util.Arrays;

View File

@@ -1,4 +1,4 @@
-package net.woodyfolsom.msproj.ann2;
+package net.woodyfolsom.msproj.ann;

 import java.io.InputStream;
 import java.io.OutputStream;
@@ -10,6 +10,10 @@ import javax.xml.bind.Unmarshaller;
 import javax.xml.bind.annotation.XmlElement;
 import javax.xml.bind.annotation.XmlRootElement;

+import net.woodyfolsom.msproj.ann.math.ActivationFunction;
+import net.woodyfolsom.msproj.ann.math.Sigmoid;
+import net.woodyfolsom.msproj.ann.math.Tanh;
+
 @XmlRootElement
 public class MultiLayerPerceptron extends FeedforwardNetwork {
     private boolean biased;
@@ -37,7 +41,13 @@ public class MultiLayerPerceptron extends FeedforwardNetwork {
             throw new IllegalArgumentException("Layer size must be >= 1");
         }

-        Layer newLayer = createNewLayer(layerIndex, layerSize);
+        Layer newLayer;
+        if (layerIndex == numLayers - 1) {
+            newLayer = createNewLayer(layerIndex, layerSize, Sigmoid.function);
+        } else {
+            newLayer = createNewLayer(layerIndex, layerSize, Tanh.function);
+        }

         if (layerIndex > 0) {
             Layer prevLayer = layers[layerIndex - 1];
@@ -54,11 +64,11 @@ public class MultiLayerPerceptron extends FeedforwardNetwork {
         }
     }

-    private Layer createNewLayer(int layerIndex, int layerSize) {
+    private Layer createNewLayer(int layerIndex, int layerSize, ActivationFunction afunc) {
         Layer layer = new Layer(layerSize);
         layers[layerIndex] = layer;
         for (int n = 0; n < layerSize; n++) {
-            Neuron neuron = createNeuron(layerIndex == 0);
+            Neuron neuron = createNeuron(layerIndex == 0, afunc);
             layer.setNeuronId(n, neuron.getId());
         }
         return layer;
@@ -93,10 +103,15 @@ public class MultiLayerPerceptron extends FeedforwardNetwork {
     protected void setInput(double[] input) {
         Layer inputLayer = layers[0];
         for (int n = 0; n < inputLayer.size(); n++) {
+            try {
                 getNeuron(inputLayer.getNeuronId(n)).setInput(input[n]);
+            } catch (NullPointerException npe) {
+                npe.printStackTrace();
+            }
         }
     }

     public void setLayers(Layer[] layers) {
         this.layers = layers;
     }

View File

@@ -1,4 +1,4 @@
-package net.woodyfolsom.msproj.ann2;
+package net.woodyfolsom.msproj.ann;

 public class NNData {
     private final double[] values;

View File

@@ -1,4 +1,4 @@
-package net.woodyfolsom.msproj.ann2;
+package net.woodyfolsom.msproj.ann;

 public class NNDataPair {
     private final NNData input;

View File

@@ -1,30 +1,29 @@
 package net.woodyfolsom.msproj.ann;

-import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
 import java.util.List;
-import java.util.Set;
-
-import org.encog.ml.data.MLData;
-import org.encog.ml.data.MLDataPair;
-import org.encog.ml.data.MLDataSet;
-import org.encog.neural.networks.BasicNetwork;

 public interface NeuralNetFilter {
-    BasicNetwork getNeuralNetwork();
     int getActualTrainingEpochs();
     int getInputSize();
     int getMaxTrainingEpochs();
     int getOutputSize();
-    void learn(MLDataSet trainingSet);
-    void learn(Set<List<MLDataPair>> trainingSet);
-    void load(String fileName) throws IOException;
-    void reset();
-    void reset(int seed);
-    void save(String fileName) throws IOException;
+    boolean load(InputStream input);
+    boolean save(OutputStream output);
     void setMaxTrainingEpochs(int max);
-    MLData compute(MLData input);
+    NNData compute(NNDataPair input);
+    //Due to Java type erasure, overloading a method
+    //simply named 'learn' which takes Lists would be problematic
+    void learnPatterns(List<NNDataPair> trainingSet);
+    void learnSequences(List<List<NNDataPair>> trainingSet);
 }
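The type-erasure comment in the new interface refers to the fact that generic type arguments are erased at compile time, so two overloads named learn taking List<NNDataPair> and List<List<NNDataPair>> would have identical raw signatures. A minimal sketch of the clash (illustration only, using String in place of NNDataPair):

import java.util.List;

interface WontCompile {
    void learn(List<String> patterns);
    // javac rejects the pair: "name clash: learn(List<List<String>>) and
    // learn(List<String>) have the same erasure" - both erase to learn(List)
    void learn(List<List<String>> sequences);
}

Hence the distinct names learnPatterns and learnSequences.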

View File

@@ -1,17 +1,17 @@
-package net.woodyfolsom.msproj.ann2;
+package net.woodyfolsom.msproj.ann;

 import javax.xml.bind.annotation.XmlAttribute;
 import javax.xml.bind.annotation.XmlElement;
 import javax.xml.bind.annotation.XmlTransient;

-import net.woodyfolsom.msproj.ann2.math.ActivationFunction;
-import net.woodyfolsom.msproj.ann2.math.Sigmoid;
+import net.woodyfolsom.msproj.ann.math.ActivationFunction;
+import net.woodyfolsom.msproj.ann.math.Sigmoid;

 public class Neuron {
     private ActivationFunction activationFunction;
     private int id;
     private transient double input = 0.0;
-    private transient double error = 0.0;
+    private transient double gradient = 0.0;

     public Neuron() {
         //no-arg constructor for JAXB
@@ -37,8 +37,8 @@ public class Neuron {
     }

     @XmlTransient
-    public double getError() {
-        return error;
+    public double getGradient() {
+        return gradient;
     }

     @XmlTransient
@@ -50,8 +50,8 @@ public class Neuron {
         return activationFunction.calculate(input);
     }

-    public void setError(double value) {
-        this.error = value;
+    public void setGradient(double value) {
+        this.gradient = value;
     }

     public void setInput(double input) {
@@ -92,7 +92,7 @@ public class Neuron {
     @Override
     public String toString() {
-        return "Neuron #" + id +", input: " + input + ", error: " + error;
+        return "Neuron #" + id +", input: " + input + ", gradient: " + gradient;
     }
 }

View File

@@ -0,0 +1,5 @@
package net.woodyfolsom.msproj.ann;
public class ObjectiveFunction {
}

View File

@@ -0,0 +1,34 @@
package net.woodyfolsom.msproj.ann;
/**
* Based on sample code from http://neuroph.sourceforge.net
*
* @author Woody
*
*/
public class TTTFilter extends AbstractNeuralNetFilter implements
NeuralNetFilter {
private static final int INPUT_SIZE = 9;
private static final int OUTPUT_SIZE = 1;
public TTTFilter() {
this(0.5,0.0, 1000);
}
public TTTFilter(double alpha, double lambda, int maxEpochs) {
super( new MultiLayerPerceptron(true, INPUT_SIZE, 5, OUTPUT_SIZE),
new TemporalDifference(0.5,0.0), maxEpochs, 0.05);
super.getNeuralNetwork().setName("XORFilter");
}
@Override
public int getInputSize() {
return INPUT_SIZE;
}
@Override
public int getOutputSize() {
return OUTPUT_SIZE;
}
}

View File

@@ -0,0 +1,187 @@
package net.woodyfolsom.msproj.ann;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.util.ArrayList;
import java.util.List;
import net.woodyfolsom.msproj.tictactoe.Action;
import net.woodyfolsom.msproj.tictactoe.GameRecord;
import net.woodyfolsom.msproj.tictactoe.GameRecord.RESULT;
import net.woodyfolsom.msproj.tictactoe.NNDataSetFactory;
import net.woodyfolsom.msproj.tictactoe.NeuralNetPolicy;
import net.woodyfolsom.msproj.tictactoe.Policy;
import net.woodyfolsom.msproj.tictactoe.RandomPolicy;
import net.woodyfolsom.msproj.tictactoe.State;
public class TTTFilterTrainer { //implements epsilon-greedy trainer? online version of NeuralNetFilter
public static void main(String[] args) throws FileNotFoundException {
double alpha = 0.0;
double lambda = 0.9;
int maxGames = 15000;
new TTTFilterTrainer().trainNetwork(alpha, lambda, maxGames);
}
public void trainNetwork(double alpha, double lambda, int maxGames) throws FileNotFoundException {
///
FeedforwardNetwork neuralNetwork = new MultiLayerPerceptron(true, 9,5,1);
neuralNetwork.setName("TicTacToe");
neuralNetwork.initWeights();
TrainingMethod trainer = new TemporalDifference(0.5,0.5);
System.out.println("Playing untrained games.");
for (int i = 0; i < 10; i++) {
System.out.println("" + (i+1) + ". " + playOptimal(neuralNetwork).getResult());
}
System.out.println("Learning from " + maxGames + " games of random self-play");
int gamesPlayed = 0;
List<RESULT> results = new ArrayList<RESULT>();
do {
GameRecord gameRecord = playEpsilonGreedy(0.90, neuralNetwork, trainer);
System.out.println("Winner: " + gameRecord.getResult());
gamesPlayed++;
results.add(gameRecord.getResult());
} while (gamesPlayed < maxGames);
///
System.out.println("Learned network after " + maxGames + " training games.");
double[][] validationSet = new double[8][];
for (int i = 0; i < results.size(); i++) {
if (i % 10 == 0) {
System.out.println("" + (i+1) + ". " + results.get(i));
}
}
// empty board
validationSet[0] = new double[] { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0 };
// center
validationSet[1] = new double[] { 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,
0.0, 0.0 };
// top edge
validationSet[2] = new double[] { 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0 };
// left edge
validationSet[3] = new double[] { 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0,
0.0, 0.0 };
// corner
validationSet[4] = new double[] { 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0 };
// win
validationSet[5] = new double[] { 1.0, 1.0, 1.0, -1.0, -1.0, 0.0, 0.0,
-1.0, 0.0 };
// loss
validationSet[6] = new double[] { -1.0, 1.0, 0.0, 1.0, -1.0, 1.0, 0.0,
0.0, -1.0 };
// about to win
validationSet[7] = new double[] {
-1.0, 1.0, 1.0,
1.0, -1.0, 1.0,
-1.0, -1.0, 0.0 };
String[] inputNames = new String[] { "00", "01", "02", "10", "11",
"12", "20", "21", "22" };
String[] outputNames = new String[] { "values" };
System.out.println("Output from eval set (learned network):");
testNetwork(neuralNetwork, validationSet, inputNames, outputNames);
System.out.println("Playing optimal games.");
for (int i = 0; i < 10; i++) {
System.out.println("" + (i+1) + ". " + playOptimal(neuralNetwork).getResult());
}
/*
File output = new File("ttt.net");
FileOutputStream fos = new FileOutputStream(output);
neuralNetwork.save(fos);*/
}
private GameRecord playOptimal(FeedforwardNetwork neuralNetwork) {
GameRecord gameRecord = new GameRecord();
Policy neuralNetPolicy = new NeuralNetPolicy(neuralNetwork);
State state = gameRecord.getState();
do {
Action action;
State nextState;
action = neuralNetPolicy.getAction(gameRecord.getState());
nextState = gameRecord.apply(action);
//System.out.println("Action " + action + " selected by policy " + selectedPolicy.getName());
//System.out.println("Next board state: " + nextState);
state = nextState;
} while (!state.isTerminal());
//finally, reinforce the actual reward
return gameRecord;
}
private GameRecord playEpsilonGreedy(double epsilon, FeedforwardNetwork neuralNetwork, TrainingMethod trainer) {
GameRecord gameRecord = new GameRecord();
Policy randomPolicy = new RandomPolicy();
Policy neuralNetPolicy = new NeuralNetPolicy(neuralNetwork);
//System.out.println("Playing epsilon-greedy game.");
State state = gameRecord.getState();
NNDataPair statePair;
Policy selectedPolicy;
trainer.zeroTraces(neuralNetwork);
do {
Action action;
State nextState;
if (Math.random() < epsilon) {
selectedPolicy = randomPolicy;
action = selectedPolicy.getAction(gameRecord.getState());
nextState = gameRecord.apply(action);
} else {
selectedPolicy = neuralNetPolicy;
action = selectedPolicy.getAction(gameRecord.getState());
nextState = gameRecord.apply(action);
statePair = NNDataSetFactory.createDataPair(state);
NNDataPair nextStatePair = NNDataSetFactory.createDataPair(nextState);
trainer.iteratePattern(neuralNetwork, statePair, nextStatePair.getIdeal());
}
//System.out.println("Action " + action + " selected by policy " + selectedPolicy.getName());
//System.out.println("Next board state: " + nextState);
state = nextState;
} while (!state.isTerminal());
//finally, reinforce the actual reward
statePair = NNDataSetFactory.createDataPair(state);
trainer.iteratePattern(neuralNetwork, statePair, statePair.getIdeal());
return gameRecord;
}
private void testNetwork(FeedforwardNetwork neuralNetwork,
double[][] validationSet, String[] inputNames, String[] outputNames) {
for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
NNDataPair dp = new NNDataPair(new NNData(inputNames,
validationSet[valIndex]), new NNData(outputNames,
validationSet[valIndex]));
System.out.println(dp + " => " + neuralNetwork.compute(dp));
}
}
}
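Two reading notes on the trainer above. First, in playEpsilonGreedy the test Math.random() < epsilon selects the random policy, so the call playEpsilonGreedy(0.90, ...) plays 90% random moves and runs a TD update only on the greedy remainder; the usual epsilon-greedy convention makes epsilon the small exploration probability, i.e.:

// Conventional epsilon-greedy selection, shown for comparison only;
// Policy, State and Action are the classes introduced by this commit.
Action chooseAction(double epsilon, State state, Policy greedy, Policy random) {
    // explore with (small) probability epsilon, otherwise exploit
    return Math.random() < epsilon ? random.getAction(state)
                                   : greedy.getAction(state);
}

Second, trainNetwork's alpha and lambda parameters are never forwarded: the trainer is constructed as new TemporalDifference(0.5,0.5) regardless of what main passes in.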

View File

@@ -1,30 +1,133 @@
 package net.woodyfolsom.msproj.ann;

-import org.encog.ml.data.MLDataSet;
-import org.encog.neural.networks.ContainsFlat;
-import org.encog.neural.networks.training.propagation.back.Backpropagation;
+import java.util.List;

-public class TemporalDifference extends Backpropagation {
+public class TemporalDifference extends TrainingMethod {
+    private final double alpha;
+    private final double gamma = 1.0;
     private final double lambda;

-    public TemporalDifference(ContainsFlat network, MLDataSet training,
-            double theLearnRate, double theMomentum, double lambda) {
-        super(network, training, theLearnRate, theMomentum);
+    public TemporalDifference(double alpha, double lambda) {
+        this.alpha = alpha;
         this.lambda = lambda;
     }

-    public double getLamdba() {
-        return lambda;
+    @Override
+    public void iteratePatterns(FeedforwardNetwork neuralNetwork,
+            List<NNDataPair> trainingSet) {
+        throw new UnsupportedOperationException();
     }

     @Override
-    public double updateWeight(final double[] gradients,
-            final double[] lastGradient, final int index) {
-        double alpha = this.getLearningRate();
-
-        //TODO fill in weight update for TD(lambda)
-
-        return 0.0;
+    public double computePatternError(FeedforwardNetwork neuralNetwork,
+            List<NNDataPair> trainingSet) {
+        int numDataPairs = trainingSet.size();
+        int outputSize = neuralNetwork.getOutput().length;
+        int totalOutputSize = outputSize * numDataPairs;
+        double[] actuals = new double[totalOutputSize];
+        double[] ideals = new double[totalOutputSize];
+        for (int dataPair = 0; dataPair < numDataPairs; dataPair++) {
+            NNDataPair nnDataPair = trainingSet.get(dataPair);
+            double[] actual = neuralNetwork.compute(nnDataPair.getInput()
+                    .getValues());
+            double[] ideal = nnDataPair.getIdeal().getValues();
+            int offset = dataPair * outputSize;
+            System.arraycopy(actual, 0, actuals, offset, outputSize);
+            System.arraycopy(ideal, 0, ideals, offset, outputSize);
+        }
+        double MSSE = errorFunction.compute(ideals, actuals);
+        return MSSE;
+    }
+
+    @Override
+    protected void backPropagate(FeedforwardNetwork neuralNetwork, NNData ideal) {
+        Neuron[] outputNeurons = neuralNetwork.getOutputNeurons();
+        double[] idealValues = ideal.getValues();
+        for (int i = 0; i < idealValues.length; i++) {
+            double input = outputNeurons[i].getInput();
+            double derivative = outputNeurons[i].getActivationFunction()
+                    .derivative(input);
+            outputNeurons[i].setGradient(outputNeurons[i].getGradient()
+                    + derivative
+                    * (idealValues[i] - outputNeurons[i].getOutput()));
+        }
+        // walking down the list of Neurons in reverse order, propagate the
+        // error
+        Neuron[] neurons = neuralNetwork.getNeurons();
+        for (int n = neurons.length - 1; n >= 0; n--) {
+            Neuron neuron = neurons[n];
+            double error = neuron.getGradient();
+            Connection[] connectionsFromN = neuralNetwork
+                    .getConnectionsFrom(neuron.getId());
+            if (connectionsFromN.length > 0) {
+                double derivative = neuron.getActivationFunction().derivative(
+                        neuron.getInput());
+                for (Connection connection : connectionsFromN) {
+                    error += derivative
+                            * connection.getWeight()
+                            * neuralNetwork.getNeuron(connection.getDest())
+                                    .getGradient();
+                }
+            }
+            neuron.setGradient(error);
+        }
+    }
+
+    private void updateWeights(FeedforwardNetwork neuralNetwork, double predictionError) {
+        for (Connection connection : neuralNetwork.getConnections()) {
+            Neuron srcNeuron = neuralNetwork.getNeuron(connection.getSrc());
+            Neuron destNeuron = neuralNetwork.getNeuron(connection.getDest());
+            double delta = alpha * srcNeuron.getOutput()
+                    * destNeuron.getGradient() * predictionError + connection.getTrace() * lambda;
+            // TODO allow for momentum
+            // double lastDelta = connection.getLastDelta();
+            connection.addDelta(delta);
+        }
+    }
+
+    @Override
+    public void iterateSequences(FeedforwardNetwork neuralNetwork,
+            List<List<NNDataPair>> trainingSet) {
+        throw new UnsupportedOperationException();
+    }
+
+    @Override
+    public double computeSequenceError(FeedforwardNetwork neuralNetwork,
+            List<List<NNDataPair>> trainingSet) {
+        throw new UnsupportedOperationException();
+    }
+
+    @Override
+    protected void iteratePattern(FeedforwardNetwork neuralNetwork,
+            NNDataPair statePair, NNData nextReward) {
+        //System.out.println("Learningrate: " + alpha);
+        zeroGradients(neuralNetwork);
+        //System.out.println("Training with: " + statePair.getInput());
+        NNData ideal = nextReward;
+        NNData actual = neuralNetwork.compute(statePair);
+        //System.out.println("Updating weights. Ideal Output: " + ideal);
+        //System.out.println("Actual Output: " + actual);
+        // backpropagate the gradients w.r.t. output error
+        backPropagate(neuralNetwork, ideal);
+        double predictionError = statePair.getIdeal().getValues()[0] // reward_t
+                + actual.getValues()[0] - nextReward.getValues()[0];
+        updateWeights(neuralNetwork, predictionError);
+    }
 }
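For comparison with the update above (not part of the commit): the textbook TD(lambda) rule keeps one decaying eligibility trace per weight and scales the whole update by the TD error,

e_t = \gamma \lambda\, e_{t-1} + \nabla_w V(s_t)
\delta_t = r_{t+1} + \gamma\, V(s_{t+1}) - V(s_t)
w \leftarrow w + \alpha\, \delta_t\, e_t

The committed version visibly differs: gamma is declared but never used, Connection.addDelta overwrites the trace with the latest delta instead of decaying and accumulating it, and predictionError adds V(s_t) and subtracts V(s_{t+1}) (the reverse of the usual sign) while backPropagate has already folded (target - output) into the gradient. Together with the remaining TODOs this reads as work in progress.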

View File

@@ -0,0 +1,43 @@
package net.woodyfolsom.msproj.ann;
import java.util.List;
import net.woodyfolsom.msproj.ann.math.ErrorFunction;
import net.woodyfolsom.msproj.ann.math.MSSE;
public abstract class TrainingMethod {
protected final ErrorFunction errorFunction;
public TrainingMethod() {
this.errorFunction = MSSE.function;
}
protected abstract void iteratePattern(FeedforwardNetwork neuralNetwork,
NNDataPair statePair, NNData nextReward);
protected abstract void iteratePatterns(FeedforwardNetwork neuralNetwork,
List<NNDataPair> trainingSet);
protected abstract double computePatternError(FeedforwardNetwork neuralNetwork,
List<NNDataPair> trainingSet);
protected abstract void iterateSequences(FeedforwardNetwork neuralNetwork,
List<List<NNDataPair>> trainingSet);
protected abstract void backPropagate(FeedforwardNetwork neuralNetwork, NNData output);
protected abstract double computeSequenceError(FeedforwardNetwork neuralNetwork,
List<List<NNDataPair>> trainingSet);
protected void zeroGradients(FeedforwardNetwork neuralNetwork) {
for (Neuron neuron : neuralNetwork.getNeurons()) {
neuron.setGradient(0.0);
}
}
protected void zeroTraces(FeedforwardNetwork neuralNetwork) {
for (Connection conn : neuralNetwork.getConnections()) {
conn.setTrace(0.0);
}
}
}

View File

@@ -1,105 +0,0 @@
package net.woodyfolsom.msproj.ann;
import java.util.List;
import java.util.Set;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.ml.train.MLTrain;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;
import org.encog.neural.networks.training.propagation.back.Backpropagation;
public class WinFilter extends AbstractNeuralNetFilter implements
NeuralNetFilter {
public WinFilter() {
// create a neural network, without using a factory
BasicNetwork network = new BasicNetwork();
network.addLayer(new BasicLayer(null, false, 2));
network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 4));
network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 2));
network.getStructure().finalizeStructure();
network.reset();
this.neuralNetwork = network;
}
@Override
public void learn(MLDataSet trainingData) {
throw new UnsupportedOperationException(
"This filter learns a Set<List<MLDataPair>>, not an MLDataSet");
}
/**
* Method is necessary because with temporal difference learning, some of
* the MLDataPairs are related by being a sequence of moves within a
* particular game.
*/
@Override
public void learn(Set<List<MLDataPair>> trainingSet) {
MLDataSet mlDataset = new BasicMLDataSet();
for (List<MLDataPair> gameRecord : trainingSet) {
for (int t = 0; t < gameRecord.size() - 1; t++) {
mlDataset.add(gameRecord.get(t).getInput(), this.neuralNetwork.compute(gameRecord.get(t)
.getInput()));
}
mlDataset.add(gameRecord.get(gameRecord.size() - 1));
}
// train the neural network
final MLTrain train = new TemporalDifference(neuralNetwork, mlDataset, 0.7, 0.8, 0.25);
//final MLTrain train = new Backpropagation(neuralNetwork, mlDataset, 0.7, 0.8);
actualTrainingEpochs = 0;
do {
if (actualTrainingEpochs > 0) {
int gameStateIndex = 0;
for (List<MLDataPair> gameRecord : trainingSet) {
for (int t = 0; t < gameRecord.size() - 1; t++) {
MLDataPair oldDataPair = mlDataset.get(gameStateIndex);
this.neuralNetwork.compute(oldDataPair.getInput());
gameStateIndex++;
}
gameStateIndex++;
}
}
train.iteration();
System.out.println("Epoch #" + actualTrainingEpochs + " Error:"
+ train.getError());
actualTrainingEpochs++;
} while (train.getError() > 0.01
&& actualTrainingEpochs <= maxTrainingEpochs);
}
@Override
public void reset() {
neuralNetwork.reset();
}
@Override
public void reset(int seed) {
neuralNetwork.reset(seed);
}
@Override
public BasicNetwork getNeuralNetwork() {
// TODO Auto-generated method stub
return null;
}
@Override
public int getInputSize() {
// TODO Auto-generated method stub
return 0;
}
@Override
public int getOutputSize() {
// TODO Auto-generated method stub
return 0;
}
}

View File

@@ -1,18 +1,5 @@
 package net.woodyfolsom.msproj.ann;

-import java.util.List;
-import java.util.Set;
-
-import org.encog.engine.network.activation.ActivationSigmoid;
-import org.encog.ml.data.MLData;
-import org.encog.ml.data.MLDataPair;
-import org.encog.ml.data.MLDataSet;
-import org.encog.ml.data.basic.BasicMLData;
-import org.encog.ml.train.MLTrain;
-import org.encog.neural.networks.BasicNetwork;
-import org.encog.neural.networks.layers.BasicLayer;
-import org.encog.neural.networks.training.propagation.back.Backpropagation;
-
 /**
  * Based on sample code from http://neuroph.sourceforge.net
  *
@@ -22,54 +9,30 @@ import org.encog.neural.networks.training.propagation.back.Backpropagation;
 public class XORFilter extends AbstractNeuralNetFilter implements
         NeuralNetFilter {
-    public XORFilter() {
-        // create a neural network, without using a factory
-        BasicNetwork network = new BasicNetwork();
-        network.addLayer(new BasicLayer(null, false, 2));
-        network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 3));
-        network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 1));
-        network.getStructure().finalizeStructure();
-        network.reset();
-        this.neuralNetwork = network;
+    private static final int INPUT_SIZE = 2;
+    private static final int OUTPUT_SIZE = 1;
+
+    public XORFilter() {
+        this(0.8,0.7);
+    }
+
+    public XORFilter(double learningRate, double momentum) {
+        super( new MultiLayerPerceptron(true, INPUT_SIZE, 2, OUTPUT_SIZE),
+                new BackPropagation(learningRate, momentum), 1000, 0.01);
+        super.getNeuralNetwork().setName("XORFilter");
     }

     public double compute(double x, double y) {
-        return compute(new BasicMLData(new double[]{x,y})).getData(0);
+        return getNeuralNetwork().compute(new double[]{x,y})[0];
     }

     @Override
     public int getInputSize() {
-        return 2;
+        return INPUT_SIZE;
     }

     @Override
     public int getOutputSize() {
-        // TODO Auto-generated method stub
-        return 1;
-    }
-
-    @Override
-    public void learn(MLDataSet trainingSet) {
-        // train the neural network
-        final MLTrain train = new Backpropagation(neuralNetwork, trainingSet,
-                0.7, 0.8);
-        actualTrainingEpochs = 0;
-        do {
-            train.iteration();
-            System.out.println("Epoch #" + actualTrainingEpochs + " Error:"
-                    + train.getError());
-            actualTrainingEpochs++;
-        } while (train.getError() > 0.01
-                && actualTrainingEpochs <= maxTrainingEpochs);
-    }
-
-    @Override
-    public void learn(Set<List<MLDataPair>> trainingSet) {
-        throw new UnsupportedOperationException(
-                "This Filter learns an MLDataSet, not a Set<List<MLData>>.");
+        return OUTPUT_SIZE;
     }
 }
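A minimal smoke test for the rewritten filter (sketch only, not part of the commit; the NNData(String[], double[]) and NNDataPair(NNData, NNData) constructors are assumed from their uses in NNDataSetFactory and TTTFilterTrainer):

import java.util.ArrayList;
import java.util.List;

public class XorSmokeTest {
    public static void main(String[] args) {
        String[] in = { "x", "y" };
        String[] out = { "xor" };
        List<NNDataPair> xorSet = new ArrayList<NNDataPair>();
        xorSet.add(new NNDataPair(new NNData(in, new double[] { 0, 0 }),
                new NNData(out, new double[] { 0 })));
        xorSet.add(new NNDataPair(new NNData(in, new double[] { 0, 1 }),
                new NNData(out, new double[] { 1 })));
        xorSet.add(new NNDataPair(new NNData(in, new double[] { 1, 0 }),
                new NNData(out, new double[] { 1 })));
        xorSet.add(new NNDataPair(new NNData(in, new double[] { 1, 1 }),
                new NNData(out, new double[] { 0 })));

        XORFilter filter = new XORFilter(); // learning rate 0.8, momentum 0.7
        filter.learnPatterns(xorSet);       // trains until MSSE <= 0.01 or 1000 epochs
        System.out.println("0 XOR 1 -> " + filter.compute(0.0, 1.0));
    }
}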

View File

@@ -1,4 +1,4 @@
-package net.woodyfolsom.msproj.ann2.math;
+package net.woodyfolsom.msproj.ann.math;

 import javax.xml.bind.annotation.XmlAttribute;

View File

@@ -1,4 +1,4 @@
-package net.woodyfolsom.msproj.ann2.math;
+package net.woodyfolsom.msproj.ann.math;

 public interface ErrorFunction {
     double compute(double[] ideal, double[] actual);

View File

@@ -1,4 +1,4 @@
-package net.woodyfolsom.msproj.ann2.math;
+package net.woodyfolsom.msproj.ann.math;

 public class Linear extends ActivationFunction{
     public static final Linear function = new Linear();

View File

@@ -1,4 +1,4 @@
-package net.woodyfolsom.msproj.ann2.math;
+package net.woodyfolsom.msproj.ann.math;

 public class MSSE implements ErrorFunction{
     public static final ErrorFunction function = new MSSE();

View File

@@ -1,4 +1,4 @@
-package net.woodyfolsom.msproj.ann2.math;
+package net.woodyfolsom.msproj.ann.math;

 public class Sigmoid extends ActivationFunction{
     public static final Sigmoid function = new Sigmoid();
@@ -12,9 +12,9 @@ public class Sigmoid extends ActivationFunction{
     }

     public double derivative(double arg) {
-        //lol wth?
-        //double eX = Math.exp(arg);
-        //return eX / (Math.pow((1+eX), 2));
-        return arg - Math.pow(arg,2);
+        //lol wth? oh, the next derivative formula is a function of s(x), not x.
+        double eX = Math.exp(arg);
+        return eX / (Math.pow((1+eX), 2));
+        //return arg - Math.pow(arg,2);
     }
 }
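The restored formula resolves the old "lol wth?": arg - Math.pow(arg,2) is s - s^2 = s(1-s), which is the sigmoid derivative only when the argument is the sigmoid's output s = \sigma(x), whereas the restored lines evaluate the derivative directly at the input x. Both are the same function in different variables:

\sigma(x) = \frac{1}{1+e^{-x}} = \frac{e^{x}}{1+e^{x}}, \qquad
\sigma'(x) = \frac{e^{x}}{(1+e^{x})^{2}} = \sigma(x)\bigl(1-\sigma(x)\bigr)

This is also why BackPropagation now calls derivative(input) where it previously called derivative(output): the convention changed, not the math.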

View File

@@ -1,4 +1,4 @@
-package net.woodyfolsom.msproj.ann2.math;
+package net.woodyfolsom.msproj.ann.math;

 public class Tanh extends ActivationFunction{
     public static final Tanh function = new Tanh();

View File

@@ -1,84 +0,0 @@
package net.woodyfolsom.msproj.ann2;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.List;
public abstract class AbstractNeuralNetFilter implements NeuralNetFilter {
private final FeedforwardNetwork neuralNetwork;
private final TrainingMethod trainingMethod;
private double maxError;
private int actualTrainingEpochs = 0;
private int maxTrainingEpochs;
AbstractNeuralNetFilter(FeedforwardNetwork neuralNetwork, TrainingMethod trainingMethod, int maxTrainingEpochs, double maxError) {
this.neuralNetwork = neuralNetwork;
this.trainingMethod = trainingMethod;
this.maxError = maxError;
this.maxTrainingEpochs = maxTrainingEpochs;
}
@Override
public NNData compute(NNDataPair input) {
return this.neuralNetwork.compute(input);
}
public int getActualTrainingEpochs() {
return actualTrainingEpochs;
}
@Override
public int getInputSize() {
return 2;
}
public int getMaxTrainingEpochs() {
return maxTrainingEpochs;
}
protected FeedforwardNetwork getNeuralNetwork() {
return neuralNetwork;
}
@Override
public void learn(List<NNDataPair> trainingSet) {
actualTrainingEpochs = 0;
double error;
neuralNetwork.initWeights();
error = trainingMethod.computeError(neuralNetwork,trainingSet);
if (error <= maxError) {
System.out.println("Initial error: " + error);
return;
}
do {
trainingMethod.iterate(neuralNetwork,trainingSet);
error = trainingMethod.computeError(neuralNetwork,trainingSet);
System.out.println("Epoch #" + actualTrainingEpochs + " Error:"
+ error);
actualTrainingEpochs++;
System.out.println("MSSE after epoch " + actualTrainingEpochs + ": " + error);
} while (error > maxError && actualTrainingEpochs < maxTrainingEpochs);
}
@Override
public boolean load(InputStream input) {
return neuralNetwork.load(input);
}
@Override
public boolean save(OutputStream output) {
return neuralNetwork.save(output);
}
public void setMaxError(double maxError) {
this.maxError = maxError;
}
public void setMaxTrainingEpochs(int max) {
this.maxTrainingEpochs = max;
}
}

View File

@@ -1,25 +0,0 @@
package net.woodyfolsom.msproj.ann2;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.List;
public interface NeuralNetFilter {
int getActualTrainingEpochs();
int getInputSize();
int getMaxTrainingEpochs();
int getOutputSize();
boolean load(InputStream input);
boolean save(OutputStream output);
void setMaxTrainingEpochs(int max);
NNData compute(NNDataPair input);
void learn(List<NNDataPair> trainingSet);
}

View File

@@ -1,5 +0,0 @@
package net.woodyfolsom.msproj.ann2;
public class ObjectiveFunction {
}

View File

@@ -1,19 +0,0 @@
package net.woodyfolsom.msproj.ann2;
import java.util.List;
public class TemporalDifference implements TrainingMethod {
@Override
public void iterate(FeedforwardNetwork neuralNetwork,
List<NNDataPair> trainingSet) {
throw new UnsupportedOperationException("Not implemented");
}
@Override
public double computeError(FeedforwardNetwork neuralNetwork,
List<NNDataPair> trainingSet) {
throw new UnsupportedOperationException("Not implemented");
}
}

View File

@@ -1,10 +0,0 @@
package net.woodyfolsom.msproj.ann2;
import java.util.List;
public interface TrainingMethod {
void iterate(FeedforwardNetwork neuralNetwork, List<NNDataPair> trainingSet);
double computeError(FeedforwardNetwork neuralNetwork, List<NNDataPair> trainingSet);
}

View File

@@ -1,44 +0,0 @@
package net.woodyfolsom.msproj.ann2;
/**
* Based on sample code from http://neuroph.sourceforge.net
*
* @author Woody
*
*/
public class XORFilter extends AbstractNeuralNetFilter implements
NeuralNetFilter {
private static final int INPUT_SIZE = 2;
private static final int OUTPUT_SIZE = 1;
public XORFilter() {
this(0.8,0.7);
}
public XORFilter(double learningRate, double momentum) {
super( new MultiLayerPerceptron(true, INPUT_SIZE, 2, OUTPUT_SIZE),
new BackPropagation(learningRate, momentum), 1000, 0.01);
super.getNeuralNetwork().setName("XORFilter");
//TODO remove
//getNeuralNetwork().setWeights(new double[] {
// 0.341232, 0.129952, -0.923123, //hidden neuron 1 from input0, input1, bias
// -0.115223, 0.570345, -0.328932, //hidden neuron 2 from input0, input1, bias
// -0.993423, 0.164732, 0.752621}); //output
}
public double compute(double x, double y) {
return getNeuralNetwork().compute(new double[]{x,y})[0];
}
@Override
public int getInputSize() {
return INPUT_SIZE;
}
@Override
public int getOutputSize() {
return OUTPUT_SIZE;
}
}

View File

@@ -0,0 +1,54 @@
package net.woodyfolsom.msproj.tictactoe;
import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;
public class Action {
public static final Action NONE = new Action(PLAYER.NONE, -1, -1);
private Game.PLAYER player;
private int row;
private int column;
public static Action getInstance(PLAYER player, int row, int column) {
return new Action(player,row,column);
}
private Action(PLAYER player, int row, int column) {
this.player = player;
this.row = row;
this.column = column;
}
public Game.PLAYER getPlayer() {
return player;
}
public int getColumn() {
return column;
}
public int getRow() {
return row;
}
public boolean isNone() {
return this == Action.NONE;
}
public void setPlayer(Game.PLAYER player) {
this.player = player;
}
public void setRow(int row) {
this.row = row;
}
public void setColumn(int column) {
this.column = column;
}
@Override
public String toString() {
return player + "(" + row + ", " + column + ")";
}
}

View File

@@ -0,0 +1,5 @@
package net.woodyfolsom.msproj.tictactoe;
public class Game {
public enum PLAYER {X,O,NONE}
}

View File

@@ -0,0 +1,63 @@
package net.woodyfolsom.msproj.tictactoe;
import java.util.ArrayList;
import java.util.List;
import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;
public class GameRecord {
public enum RESULT {X_WINS, O_WINS, TIE_GAME, IN_PROGRESS}
private List<Action> actions = new ArrayList<Action>();
private List<State> states = new ArrayList<State>();
private RESULT result = RESULT.IN_PROGRESS;
public GameRecord() {
actions.add(Action.NONE);
states.add(new State());
}
public void addState(State state) {
states.add(state);
}
public State apply(Action action) {
State nextState = getState().apply(action);
if (nextState.isValid()) {
states.add(nextState);
actions.add(action);
}
if (nextState.isTerminal()) {
if (nextState.isWinner(PLAYER.X)) {
result = RESULT.X_WINS;
} else if (nextState.isWinner(PLAYER.O)) {
result = RESULT.O_WINS;
} else {
result = RESULT.TIE_GAME;
}
}
return nextState;
}
public int getNumStates() {
return states.size();
}
public RESULT getResult() {
return result;
}
public void setResult(RESULT result) {
this.result = result;
}
public State getState() {
return states.get(states.size()-1);
}
public State getState(int index) {
return states.get(index);
}
}

View File

@@ -0,0 +1,20 @@
package net.woodyfolsom.msproj.tictactoe;
import java.util.ArrayList;
import java.util.List;
import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;
public class MoveGenerator {
public List<Action> getValidActions(State state) {
PLAYER playerToMove = state.getPlayerToMove();
List<Action> validActions = new ArrayList<Action>();
for (int i = 0; i < 3; i++) {
for (int j = 0; j < 3; j++) {
if (state.isEmpty(i,j))
validActions.add(Action.getInstance(playerToMove, i, j));
}
}
return validActions;
}
}

View File

@@ -0,0 +1,81 @@
package net.woodyfolsom.msproj.tictactoe;
import java.util.ArrayList;
import java.util.List;
import net.woodyfolsom.msproj.ann.NNData;
import net.woodyfolsom.msproj.ann.NNDataPair;
import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;
public class NNDataSetFactory {
public static final String[] TTT_INPUT_FIELDS = {"00","01","02","10","11","12","20","21","22"};
public static final String[] TTT_OUTPUT_FIELDS = {"value"};
public static List<List<NNDataPair>> createDataSet(List<GameRecord> tttGames) {
List<List<NNDataPair>> nnDataSet = new ArrayList<List<NNDataPair>>();
for (GameRecord tttGame : tttGames) {
List<NNDataPair> gameData = createDataPairList(tttGame);
nnDataSet.add(gameData);
}
return nnDataSet;
}
public static List<NNDataPair> createDataPairList(GameRecord gameRecord) {
List<NNDataPair> gameData = new ArrayList<NNDataPair>();
for (int i = 0; i < gameRecord.getNumStates(); i++) {
gameData.add(createDataPair(gameRecord.getState(i)));
}
return gameData;
}
public static NNDataPair createDataPair(State tttState) {
double value;
if (tttState.isTerminal()) {
if (tttState.isWinner(PLAYER.X)) {
value = 1.0; // win for black
} else if (tttState.isWinner(PLAYER.O)) {
value = 0.0; // loss for black
//value = -1.0;
} else {
value = 0.5;
//value = 0.0; //tie
}
} else {
value = 0.0;
}
double[] inputValues = new double[9];
char[] boardCopy = tttState.getBoard();
inputValues[0] = getTicTacToeInput(boardCopy, 0, 0);
inputValues[1] = getTicTacToeInput(boardCopy, 0, 1);
inputValues[2] = getTicTacToeInput(boardCopy, 0, 2);
inputValues[3] = getTicTacToeInput(boardCopy, 1, 0);
inputValues[4] = getTicTacToeInput(boardCopy, 1, 1);
inputValues[5] = getTicTacToeInput(boardCopy, 1, 2);
inputValues[6] = getTicTacToeInput(boardCopy, 2, 0);
inputValues[7] = getTicTacToeInput(boardCopy, 2, 1);
inputValues[8] = getTicTacToeInput(boardCopy, 2, 2);
return new NNDataPair(new NNData(TTT_INPUT_FIELDS,inputValues),new NNData(TTT_OUTPUT_FIELDS,new double[]{value}));
}
private static double getTicTacToeInput(char[] board, int row, int column) {
switch (board[row*3+column]) {
case 'X' :
return 1.0;
case 'O' :
return -1.0;
case '.' :
return 0.0;
default:
throw new RuntimeException("Invalid board symbol at " + row +", " + column);
}
}
}
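A worked example of the encoding above (illustration only): for a terminal state where X has completed the top row against two O marks,

// Board:  X X X          input vector: {  1,  1,  1,
//         O O .    -->                   -1, -1,  0,
//         . . .                           0,  0,  0 }
// Terminal X win, so the ideal output is { 1.0 }; a tie would encode
// 0.5, an O win 0.0, and any non-terminal position also 0.0.

Note the comments in createDataPair say "win for black" and "loss for black"; in this package X plays the role of black.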

View File

@@ -0,0 +1,67 @@
package net.woodyfolsom.msproj.tictactoe;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import net.woodyfolsom.msproj.ann.FeedforwardNetwork;
import net.woodyfolsom.msproj.ann.NNDataPair;
import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;
public class NeuralNetPolicy extends Policy {
private FeedforwardNetwork neuralNet;
private MoveGenerator moveGenerator = new MoveGenerator();
public NeuralNetPolicy(FeedforwardNetwork neuralNet) {
super("NeuralNet-" + neuralNet.getName());
this.neuralNet = neuralNet;
}
@Override
public Action getAction(State state) {
List<Action> validMoves = moveGenerator.getValidActions(state);
Map<Action, Double> scores = new HashMap<Action, Double>();
for (Action action : validMoves) {
State nextState = state.apply(action);
NNDataPair dataPair = NNDataSetFactory.createDataPair(state);
//estimated reward for X
scores.put(action, neuralNet.compute(dataPair).getValues()[0]);
}
PLAYER playerToMove = state.getPlayerToMove();
if (playerToMove == PLAYER.X) {
return returnMaxAction(scores);
} else if (playerToMove == PLAYER.O) {
return returnMinAction(scores);
} else {
throw new IllegalArgumentException("Invalid playerToMove: " + playerToMove);
}
//return validMoves.get((int)(Math.random() * validMoves.size()));
}
private Action returnMaxAction(Map<Action,Double> scores) {
Action bestAction = null;
Double bestScore = Double.NEGATIVE_INFINITY;
for (Map.Entry<Action,Double> entry : scores.entrySet()) {
if (entry.getValue() > bestScore) {
bestScore = entry.getValue();
bestAction = entry.getKey();
}
}
return bestAction;
}
private Action returnMinAction(Map<Action,Double> scores) {
Action bestAction = null;
Double bestScore = Double.POSITIVE_INFINITY;
for (Map.Entry<Action,Double> entry : scores.entrySet()) {
if (entry.getValue() < bestScore) {
bestScore = entry.getValue();
bestAction = entry.getKey();
}
}
return bestAction;
}
}
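The TTTFilter test later in this commit plans to move from random play to epsilon-greedy play driven by a policy like the one above. A minimal sketch of such a wrapper, assuming only the Policy and MoveGenerator APIs shown in this commit; the EpsilonGreedyPolicy class and its epsilon parameter are hypothetical:
package net.woodyfolsom.msproj.tictactoe;

import java.util.List;

// Hypothetical wrapper: with probability epsilon play a random legal move,
// otherwise defer to the wrapped (e.g. neural-net) policy.
public class EpsilonGreedyPolicy extends Policy {
    private final Policy basePolicy;
    private final double epsilon;
    private final MoveGenerator moveGenerator = new MoveGenerator();

    public EpsilonGreedyPolicy(Policy basePolicy, double epsilon) {
        super("EpsilonGreedy-" + basePolicy.getName());
        this.basePolicy = basePolicy;
        this.epsilon = epsilon;
    }

    @Override
    public Action getAction(State state) {
        if (Math.random() < epsilon) {
            List<Action> validMoves = moveGenerator.getValidActions(state);
            return validMoves.get((int) (Math.random() * validMoves.size()));
        }
        return basePolicy.getAction(state);
    }
}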

View File

@@ -0,0 +1,15 @@
package net.woodyfolsom.msproj.tictactoe;
public abstract class Policy {
private String name;
protected Policy(String name) {
this.name = name;
}
public abstract Action getAction(State state);
public String getName() {
return name;
}
}

View File

@@ -0,0 +1,18 @@
package net.woodyfolsom.msproj.tictactoe;
import java.util.List;
public class RandomPolicy extends Policy {
private MoveGenerator moveGenerator = new MoveGenerator();
public RandomPolicy() {
super("Random");
}
@Override
public Action getAction(State state) {
List<Action> validMoves = moveGenerator.getValidActions(state);
return validMoves.get((int)(Math.random() * validMoves.size()));
}
}

View File

@@ -0,0 +1,43 @@
package net.woodyfolsom.msproj.tictactoe;
import java.util.ArrayList;
import java.util.List;
public class Referee {
public static void main(String[] args) {
new Referee().play(50);
}
public List<GameRecord> play(int nGames) {
Policy policy = new RandomPolicy();
List<GameRecord> tournament = new ArrayList<GameRecord>();
for (int i = 0; i < nGames; i++) {
GameRecord gameRecord = new GameRecord();
System.out.println("Playing game #" +(i+1));
State state;
do {
Action action = policy.getAction(gameRecord.getState());
System.out.println("Action " + action + " selected by policy " + policy.getName());
state = gameRecord.apply(action);
System.out.println("Next board state:");
System.out.println(gameRecord.getState());
} while (!state.isTerminal());
System.out.println("Game #" + (i+1) + " is finished. Result: " + gameRecord.getResult());
tournament.add(gameRecord);
}
System.out.println("Played " + tournament.size() + " random games.");
System.out.println("Results:");
for (int i = 0; i < tournament.size(); i++) {
GameRecord gameRecord = tournament.get(i);
System.out.println((i+1) + ". " + gameRecord.getResult());
}
return tournament;
}
}

View File

@@ -0,0 +1,116 @@
package net.woodyfolsom.msproj.tictactoe;
import java.util.Arrays;
import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;
public class State {
public static final State INVALID = new State();
public static final char EMPTY_SQUARE = '.';
private char[] board;
private PLAYER playerToMove;
public State() {
playerToMove = Game.PLAYER.X;
board = new char[9];
Arrays.fill(board, EMPTY_SQUARE);
}
private State(State that) {
this.board = Arrays.copyOf(that.board, that.board.length);
this.playerToMove = that.playerToMove;
}
public State apply(Action action) {
if (action.getPlayer() != playerToMove) {
System.out.println("It is not " + action.getPlayer() +"'s turn.");
return State.INVALID;
}
State nextState = new State(this);
int row = action.getRow();
int column = action.getColumn();
int dest = row * 3 + column;
if (board[dest] != EMPTY_SQUARE) {
System.out.println("Invalid move " + action + ", coordinate not empty.");
return State.INVALID;
}
switch (playerToMove) {
case X : nextState.board[dest] = 'X';
break;
case O : nextState.board[dest] = 'O';
break;
default:
throw new RuntimeException("Invalid playerToMove");
}
if (playerToMove == PLAYER.X) {
nextState.playerToMove = PLAYER.O;
} else {
nextState.playerToMove = PLAYER.X;
}
return nextState;
}
public char[] getBoard() {
return Arrays.copyOf(board, board.length);
}
public PLAYER getPlayerToMove() {
return playerToMove;
}
public boolean isEmpty(int row, int column) {
return board[row*3+column] == EMPTY_SQUARE;
}
public boolean isFull(char mark1, char mark2, char mark3) {
return mark1 != EMPTY_SQUARE && mark2 != EMPTY_SQUARE && mark3 != EMPTY_SQUARE;
}
public boolean isWinner(PLAYER player) {
return isWin(player,board[0],board[1],board[2]) ||
isWin(player,board[3],board[4],board[5]) ||
isWin(player,board[6],board[7],board[8]) ||
isWin(player,board[0],board[3],board[6]) ||
isWin(player,board[1],board[4],board[7]) ||
isWin(player,board[2],board[5],board[8]) ||
isWin(player,board[0],board[4],board[8]) ||
isWin(player,board[2],board[4],board[6]);
}
public boolean isWin(PLAYER player, char mark1, char mark2, char mark3) {
if (isFull(mark1,mark2,mark3)) {
switch (player) {
case X : return mark1 == 'X' && mark2 == 'X' && mark3 == 'X';
case O : return mark1 == 'O' && mark2 == 'O' && mark3 == 'O';
default :
return false;
}
} else {
return false;
}
}
public boolean isTerminal() {
return isWinner(PLAYER.X) || isWinner(PLAYER.O) ||
(isFull(board[0],board[1], board[2]) &&
isFull(board[3],board[4], board[5]) &&
isFull(board[6],board[7], board[8]));
}
public boolean isValid() {
return this != INVALID;
}
@Override
public String toString() {
StringBuilder sb = new StringBuilder("TicTacToe state (" + playerToMove + " to move):\n");
sb.append(""+board[0] + board[1] + board[2] + "\n");
sb.append(""+board[3] + board[4] + board[5] + "\n");
sb.append(""+board[6] + board[7] + board[8] + "\n");
return sb.toString();
}
}
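A short usage sketch of the immutable State API above, using the Action.getInstance factory that the tests in this commit rely on; the StateDemo class is hypothetical. Note that apply returns the State.INVALID sentinel instead of throwing on an illegal move:
package net.woodyfolsom.msproj.tictactoe;

// Hypothetical snippet: X takes the center, then the result is inspected.
public class StateDemo {
    public static void main(String[] args) {
        State state = new State();
        State next = state.apply(Action.getInstance(Game.PLAYER.X, 1, 1));
        if (next.isValid()) {
            System.out.println(next);              // X at (1,1), O to move
            System.out.println(next.isTerminal()); // false after one move
        }
    }
}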

View File

@@ -1,4 +1,4 @@
package net.woodyfolsom.msproj.ann;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
@@ -10,6 +10,12 @@ import java.io.IOException;
import javax.xml.bind.JAXBException;
import net.woodyfolsom.msproj.ann.Connection;
import net.woodyfolsom.msproj.ann.FeedforwardNetwork;
import net.woodyfolsom.msproj.ann.MultiLayerPerceptron;
import net.woodyfolsom.msproj.ann.NNData;
import net.woodyfolsom.msproj.ann.NNDataPair;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;

View File

@@ -0,0 +1,100 @@
package net.woodyfolsom.msproj.ann;
import java.io.File;
import java.io.IOException;
import java.util.List;
import net.woodyfolsom.msproj.ann.NNData;
import net.woodyfolsom.msproj.ann.NNDataPair;
import net.woodyfolsom.msproj.ann.NeuralNetFilter;
import net.woodyfolsom.msproj.ann.TTTFilter;
import net.woodyfolsom.msproj.tictactoe.GameRecord;
import net.woodyfolsom.msproj.tictactoe.NNDataSetFactory;
import net.woodyfolsom.msproj.tictactoe.Referee;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
public class TTTFilterTest {
private static final String FILENAME = "tttPerceptron.net";
@AfterClass
public static void deleteNewNet() {
File file = new File(FILENAME);
if (file.exists()) {
file.delete();
}
}
@BeforeClass
public static void deleteSavedNet() {
File file = new File(FILENAME);
if (file.exists()) {
file.delete();
}
}
@Test
public void testLearn() throws IOException {
double alpha = 0.5;
double lambda = 0.0;
int maxEpochs = 1000;
NeuralNetFilter nnLearner = new TTTFilter(alpha, lambda, maxEpochs);
// Create trainingSet from a tournament of random games.
// Future iterations will use Epsilon-greedy play from a policy based on
// this network to generate additional datasets.
List<GameRecord> tournament = new Referee().play(1);
List<List<NNDataPair>> trainingSet = NNDataSetFactory
.createDataSet(tournament);
System.out.println("Generated " + trainingSet.size()
+ " datasets from random self-play.");
nnLearner.learnSequences(trainingSet);
System.out.println("Learned network after "
+ nnLearner.getActualTrainingEpochs() + " training epochs.");
double[][] validationSet = new double[7][];
// empty board
validationSet[0] = new double[] { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0 };
// center
validationSet[1] = new double[] { 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,
0.0, 0.0 };
// top edge
validationSet[2] = new double[] { 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0 };
// left edge
validationSet[3] = new double[] { 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0,
0.0, 0.0 };
// corner
validationSet[4] = new double[] { 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0 };
// win
validationSet[5] = new double[] { 1.0, 1.0, 1.0, -1.0, -1.0, 0.0, 0.0,
-1.0, 0.0 };
// loss
validationSet[6] = new double[] { -1.0, 1.0, 0.0, 1.0, -1.0, 1.0, 0.0,
0.0, -1.0 };
String[] inputNames = new String[] { "00", "01", "02", "10", "11",
"12", "20", "21", "22" };
String[] outputNames = new String[] { "values" };
System.out.println("Output from eval set (learned network):");
testNetwork(nnLearner, validationSet, inputNames, outputNames);
}
private void testNetwork(NeuralNetFilter nnLearner,
double[][] validationSet, String[] inputNames, String[] outputNames) {
for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
NNDataPair dp = new NNDataPair(new NNData(inputNames,
validationSet[valIndex]), new NNData(outputNames,
validationSet[valIndex]));
System.out.println(dp + " => " + nnLearner.compute(dp));
}
}
}
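TTTFilter is constructed with a learning rate (alpha) and what looks like an eligibility-trace decay (lambda), which suggests TD(lambda)-style sequence learning. The sketch below illustrates only the bootstrapped TD(0) error computation under that assumption; the TDErrorSketch class and tdErrors method are hypothetical and are not the actual TTTFilter update:
package net.woodyfolsom.msproj.ann;

import java.util.List;

// Hypothetical illustration: each non-final state's prediction is pulled
// toward the next prediction; the last transition targets the true reward.
public class TDErrorSketch {
    static double[] tdErrors(FeedforwardNetwork net, List<NNDataPair> game,
            double terminalReward) {
        double[] errors = new double[game.size() - 1];
        for (int t = 0; t < game.size() - 1; t++) {
            double vNow = net.compute(game.get(t)).getValues()[0];
            double target = (t == game.size() - 2)
                    ? terminalReward // 1.0 X win, 0.0 O win, 0.5 tie
                    : net.compute(game.get(t + 1)).getValues()[0];
            errors[t] = target - vNow;
        }
        return errors;
    }
}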

View File

@@ -1,64 +0,0 @@
package net.woodyfolsom.msproj.ann;
import java.io.File;
import java.io.FileFilter;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import net.woodyfolsom.msproj.GameRecord;
import net.woodyfolsom.msproj.Referee;
import org.antlr.runtime.RecognitionException;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.junit.Test;
public class WinFilterTest {
@Test
public void testLearnSaveLoad() throws IOException, RecognitionException {
File[] sgfFiles = new File("data/games/random_vs_random")
.listFiles(new FileFilter() {
@Override
public boolean accept(File pathname) {
return pathname.getName().endsWith(".sgf");
}
});
Set<List<MLDataPair>> trainingData = new HashSet<List<MLDataPair>>();
for (File file : sgfFiles) {
FileInputStream fis = new FileInputStream(file);
GameRecord gameRecord = Referee.replay(fis);
List<MLDataPair> gameData = new ArrayList<MLDataPair>();
for (int i = 0; i <= gameRecord.getNumTurns(); i++) {
gameData.add(new GameStateMLDataPair(gameRecord.getGameState(i)));
}
trainingData.add(gameData);
fis.close();
}
WinFilter winFilter = new WinFilter();
winFilter.learn(trainingData);
for (List<MLDataPair> trainingSequence : trainingData) {
for (int stateIndex = 0; stateIndex < trainingSequence.size(); stateIndex++) {
if (stateIndex > 0 && stateIndex < trainingSequence.size()-1) {
continue;
}
MLData input = trainingSequence.get(stateIndex).getInput();
System.out.println("Turn " + stateIndex + ": " + input + " => "
+ winFilter.compute(input));
}
}
}
}

View File

@@ -1,10 +1,19 @@
package net.woodyfolsom.msproj.ann;
import static org.junit.Assert.assertTrue;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import net.woodyfolsom.msproj.ann.NNData;
import net.woodyfolsom.msproj.ann.NNDataPair;
import net.woodyfolsom.msproj.ann.NeuralNetFilter;
import net.woodyfolsom.msproj.ann.XORFilter;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
@@ -29,9 +38,8 @@ public class XORFilterTest {
}
@Test
public void testLearn() throws IOException {
NeuralNetFilter nnLearner = new XORFilter(0.5,0.0);
// create training set (logical XOR function)
int size = 1;
@@ -49,9 +57,58 @@ public class XORFilterTest {
}
// create training data
List<NNDataPair> trainingSet = new ArrayList<NNDataPair>();
String[] inputNames = new String[] {"x","y"};
String[] outputNames = new String[] {"XOR"};
for (int i = 0; i < 4*size; i++) {
trainingSet.add(new NNDataPair(new NNData(inputNames,trainingInput[i]),new NNData(outputNames,trainingOutput[i])));
}
nnLearner.setMaxTrainingEpochs(20000);
nnLearner.learnPatterns(trainingSet);
System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");
double[][] validationSet = new double[4][2];
validationSet[0] = new double[] { 0, 0 };
validationSet[1] = new double[] { 0, 1 };
validationSet[2] = new double[] { 1, 0 };
validationSet[3] = new double[] { 1, 1 };
System.out.println("Output from eval set (learned network):");
testNetwork(nnLearner, validationSet, inputNames, outputNames);
}
@Test
public void testLearnSaveLoad() throws IOException {
NeuralNetFilter nnLearner = new XORFilter(0.5,0.0);
// create training set (logical XOR function)
int size = 2;
double[][] trainingInput = new double[4 * size][];
double[][] trainingOutput = new double[4 * size][];
for (int i = 0; i < size; i++) {
trainingInput[i * 4 + 0] = new double[] { 0, 0 };
trainingInput[i * 4 + 1] = new double[] { 0, 1 };
trainingInput[i * 4 + 2] = new double[] { 1, 0 };
trainingInput[i * 4 + 3] = new double[] { 1, 1 };
trainingOutput[i * 4 + 0] = new double[] { 0 };
trainingOutput[i * 4 + 1] = new double[] { 1 };
trainingOutput[i * 4 + 2] = new double[] { 1 };
trainingOutput[i * 4 + 3] = new double[] { 0 };
}
// create training data
List<NNDataPair> trainingSet = new ArrayList<NNDataPair>();
String[] inputNames = new String[] {"x","y"};
String[] outputNames = new String[] {"XOR"};
for (int i = 0; i < 4*size; i++) {
trainingSet.add(new NNDataPair(new NNData(inputNames,trainingInput[i]),new NNData(outputNames,trainingOutput[i])));
}
nnLearner.setMaxTrainingEpochs(1);
nnLearner.learnPatterns(trainingSet);
System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");
double[][] validationSet = new double[4][2];
@@ -61,18 +118,23 @@ public class XORFilterTest {
validationSet[3] = new double[] { 1, 1 };
System.out.println("Output from eval set (learned network, pre-serialization):");
testNetwork(nnLearner, validationSet, inputNames, outputNames);
FileOutputStream fos = new FileOutputStream(FILENAME);
assertTrue(nnLearner.save(fos));
fos.close();
FileInputStream fis = new FileInputStream(FILENAME);
assertTrue(nnLearner.load(fis));
fis.close();
System.out.println("Output from eval set (learned network, post-serialization):");
testNetwork(nnLearner, validationSet, inputNames, outputNames);
}
private void testNetwork(NeuralNetFilter nnLearner, double[][] validationSet, String[] inputNames, String[] outputNames) {
for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
NNDataPair dp = new NNDataPair(new NNData(inputNames,validationSet[valIndex]), new NNData(outputNames,validationSet[valIndex]));
System.out.println(dp + " => " + nnLearner.compute(dp));
}
}

View File

@@ -1,11 +1,11 @@
package net.woodyfolsom.msproj.ann.math;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import net.woodyfolsom.msproj.ann.math.ActivationFunction;
import net.woodyfolsom.msproj.ann.math.Sigmoid;
import net.woodyfolsom.msproj.ann.math.Tanh;
import org.junit.Test;

View File

@@ -1,10 +1,10 @@
package net.woodyfolsom.msproj.ann.math;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import net.woodyfolsom.msproj.ann.math.ActivationFunction;
import net.woodyfolsom.msproj.ann.math.Tanh;
import org.junit.Test;

View File

@@ -1,136 +0,0 @@
package net.woodyfolsom.msproj.ann2;
import static org.junit.Assert.assertTrue;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
public class XORFilterTest {
private static final String FILENAME = "xorPerceptron.net";
@AfterClass
public static void deleteNewNet() {
File file = new File(FILENAME);
if (file.exists()) {
file.delete();
}
}
@BeforeClass
public static void deleteSavedNet() {
File file = new File(FILENAME);
if (file.exists()) {
file.delete();
}
}
@Test
public void testLearn() throws IOException {
NeuralNetFilter nnLearner = new XORFilter(0.05,0.0);
// create training set (logical XOR function)
int size = 1;
double[][] trainingInput = new double[4 * size][];
double[][] trainingOutput = new double[4 * size][];
for (int i = 0; i < size; i++) {
trainingInput[i * 4 + 0] = new double[] { 0, 0 };
trainingInput[i * 4 + 1] = new double[] { 0, 1 };
trainingInput[i * 4 + 2] = new double[] { 1, 0 };
trainingInput[i * 4 + 3] = new double[] { 1, 1 };
trainingOutput[i * 4 + 0] = new double[] { 0 };
trainingOutput[i * 4 + 1] = new double[] { 1 };
trainingOutput[i * 4 + 2] = new double[] { 1 };
trainingOutput[i * 4 + 3] = new double[] { 0 };
}
// create training data
List<NNDataPair> trainingSet = new ArrayList<NNDataPair>();
String[] inputNames = new String[] {"x","y"};
String[] outputNames = new String[] {"XOR"};
for (int i = 0; i < 4*size; i++) {
trainingSet.add(new NNDataPair(new NNData(inputNames,trainingInput[i]),new NNData(outputNames,trainingOutput[i])));
}
nnLearner.setMaxTrainingEpochs(20000);
nnLearner.learn(trainingSet);
System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");
double[][] validationSet = new double[4][2];
validationSet[0] = new double[] { 0, 0 };
validationSet[1] = new double[] { 0, 1 };
validationSet[2] = new double[] { 1, 0 };
validationSet[3] = new double[] { 1, 1 };
System.out.println("Output from eval set (learned network):");
testNetwork(nnLearner, validationSet, inputNames, outputNames);
}
@Test
public void testLearnSaveLoad() throws IOException {
NeuralNetFilter nnLearner = new XORFilter(0.5,0.0);
// create training set (logical XOR function)
int size = 2;
double[][] trainingInput = new double[4 * size][];
double[][] trainingOutput = new double[4 * size][];
for (int i = 0; i < size; i++) {
trainingInput[i * 4 + 0] = new double[] { 0, 0 };
trainingInput[i * 4 + 1] = new double[] { 0, 1 };
trainingInput[i * 4 + 2] = new double[] { 1, 0 };
trainingInput[i * 4 + 3] = new double[] { 1, 1 };
trainingOutput[i * 4 + 0] = new double[] { 0 };
trainingOutput[i * 4 + 1] = new double[] { 1 };
trainingOutput[i * 4 + 2] = new double[] { 1 };
trainingOutput[i * 4 + 3] = new double[] { 0 };
}
// create training data
List<NNDataPair> trainingSet = new ArrayList<NNDataPair>();
String[] inputNames = new String[] {"x","y"};
String[] outputNames = new String[] {"XOR"};
for (int i = 0; i < 4*size; i++) {
trainingSet.add(new NNDataPair(new NNData(inputNames,trainingInput[i]),new NNData(outputNames,trainingOutput[i])));
}
nnLearner.setMaxTrainingEpochs(1);
nnLearner.learn(trainingSet);
System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");
double[][] validationSet = new double[4][2];
validationSet[0] = new double[] { 0, 0 };
validationSet[1] = new double[] { 0, 1 };
validationSet[2] = new double[] { 1, 0 };
validationSet[3] = new double[] { 1, 1 };
System.out.println("Output from eval set (learned network, pre-serialization):");
testNetwork(nnLearner, validationSet, inputNames, outputNames);
FileOutputStream fos = new FileOutputStream(FILENAME);
assertTrue(nnLearner.save(fos));
fos.close();
FileInputStream fis = new FileInputStream(FILENAME);
assertTrue(nnLearner.load(fis));
fis.close();
System.out.println("Output from eval set (learned network, post-serialization):");
testNetwork(nnLearner, validationSet, inputNames, outputNames);
}
private void testNetwork(NeuralNetFilter nnLearner, double[][] validationSet, String[] inputNames, String[] outputNames) {
for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
NNDataPair dp = new NNDataPair(new NNData(inputNames,validationSet[valIndex]), new NNData(outputNames,validationSet[valIndex]));
System.out.println(dp + " => " + nnLearner.compute(dp));
}
}
}

View File

@@ -0,0 +1,73 @@
package net.woodyfolsom.msproj.tictactoe;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import org.junit.Test;
import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;
public class GameRecordTest {
@Test
public void testGetResultXwins() {
GameRecord gameRecord = new GameRecord();
gameRecord.apply(Action.getInstance(PLAYER.X, 1, 0));
gameRecord.apply(Action.getInstance(PLAYER.O, 0, 0));
gameRecord.apply(Action.getInstance(PLAYER.X, 1, 1));
gameRecord.apply(Action.getInstance(PLAYER.O, 0, 1));
gameRecord.apply(Action.getInstance(PLAYER.X, 1, 2));
State finalState = gameRecord.getState();
System.out.println("Final state:");
System.out.println(finalState);
assertTrue(finalState.isValid());
assertTrue(finalState.isTerminal());
assertTrue(finalState.isWinner(PLAYER.X));
assertEquals(GameRecord.RESULT.X_WINS,gameRecord.getResult());
}
@Test
public void testGetResultOwins() {
GameRecord gameRecord = new GameRecord();
gameRecord.apply(Action.getInstance(PLAYER.X, 0, 0));
gameRecord.apply(Action.getInstance(PLAYER.O, 0, 2));
gameRecord.apply(Action.getInstance(PLAYER.X, 0, 1));
gameRecord.apply(Action.getInstance(PLAYER.O, 1, 1));
gameRecord.apply(Action.getInstance(PLAYER.X, 1, 0));
gameRecord.apply(Action.getInstance(PLAYER.O, 2, 0));
State finalState = gameRecord.getState();
System.out.println("Final state:");
System.out.println(finalState);
assertTrue(finalState.isValid());
assertTrue(finalState.isTerminal());
assertTrue(finalState.isWinner(PLAYER.O));
assertEquals(GameRecord.RESULT.O_WINS,gameRecord.getResult());
}
@Test
public void testGetResultTieGame() {
GameRecord gameRecord = new GameRecord();
gameRecord.apply(Action.getInstance(PLAYER.X, 0, 0));
gameRecord.apply(Action.getInstance(PLAYER.O, 0, 2));
gameRecord.apply(Action.getInstance(PLAYER.X, 0, 1));
gameRecord.apply(Action.getInstance(PLAYER.O, 1, 0));
gameRecord.apply(Action.getInstance(PLAYER.X, 1, 2));
gameRecord.apply(Action.getInstance(PLAYER.O, 1, 1));
gameRecord.apply(Action.getInstance(PLAYER.X, 2, 0));
gameRecord.apply(Action.getInstance(PLAYER.O, 2, 2));
gameRecord.apply(Action.getInstance(PLAYER.X, 2, 1));
State finalState = gameRecord.getState();
System.out.println("Final state:");
System.out.println(finalState);
assertTrue(finalState.isValid());
assertTrue(finalState.isTerminal());
assertFalse(finalState.isWinner(PLAYER.X));
assertFalse(finalState.isWinner(PLAYER.O));
assertEquals(GameRecord.RESULT.TIE_GAME,gameRecord.getResult());
}
}

View File

@@ -0,0 +1,12 @@
package net.woodyfolsom.msproj.tictactoe;
import org.junit.Test;
public class RefereeTest {
@Test
public void testPlay100Games() {
new Referee().play(100);
}
}

ttt.net Normal file
View File

@@ -0,0 +1,129 @@
<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<multiLayerPerceptron biased="true" name="TicTacToe">
<activationFunction name="Sigmoid"/>
<connections dest="10" src="0" weight="0.5827629317852295"/>
<connections dest="10" src="1" weight="0.49198735902918994"/>
<connections dest="10" src="2" weight="-0.3019566272377494"/>
<connections dest="10" src="3" weight="0.42204442000472525"/>
<connections dest="10" src="4" weight="-0.26015075178733194"/>
<connections dest="10" src="5" weight="-0.001558299861060293"/>
<connections dest="10" src="6" weight="0.07987916348233416"/>
<connections dest="10" src="7" weight="0.07258122647153753"/>
<connections dest="10" src="8" weight="-0.691045501522254"/>
<connections dest="10" src="9" weight="0.7118463494749109"/>
<connections dest="11" src="0" weight="-1.8387878977128804"/>
<connections dest="11" src="1" weight="0.07066242812415906"/>
<connections dest="11" src="2" weight="-0.2141385079779094"/>
<connections dest="11" src="3" weight="0.02318115051417748"/>
<connections dest="11" src="4" weight="-0.4940158494633454"/>
<connections dest="11" src="5" weight="0.24951794707397953"/>
<connections dest="11" src="6" weight="-0.3422002057868113"/>
<connections dest="11" src="7" weight="-0.34896333718320666"/>
<connections dest="11" src="8" weight="0.18236809262087086"/>
<connections dest="11" src="9" weight="-0.39168932467050466"/>
<connections dest="12" src="0" weight="1.5206290139263101"/>
<connections dest="12" src="1" weight="-0.4806468102477885"/>
<connections dest="12" src="2" weight="0.21439697155823853"/>
<connections dest="12" src="3" weight="0.1226010537695569"/>
<connections dest="12" src="4" weight="-0.2957055657777683"/>
<connections dest="12" src="5" weight="0.6130228290778311"/>
<connections dest="12" src="6" weight="0.36875530286236485"/>
<connections dest="12" src="7" weight="-0.5171899914088294"/>
<connections dest="12" src="8" weight="0.10837708801339006"/>
<connections dest="12" src="9" weight="-0.7053746937035315"/>
<connections dest="13" src="0" weight="0.002913660858364482"/>
<connections dest="13" src="1" weight="-0.7651207747987173"/>
<connections dest="13" src="2" weight="0.9715970070491731"/>
<connections dest="13" src="3" weight="-0.9956453258174628"/>
<connections dest="13" src="4" weight="-0.9408358352747842"/>
<connections dest="13" src="5" weight="-1.008966493202113"/>
<connections dest="13" src="6" weight="-0.672355054680489"/>
<connections dest="13" src="7" weight="-0.3367206164565582"/>
<connections dest="13" src="8" weight="0.7588693137687637"/>
<connections dest="13" src="9" weight="-0.7196453490945308"/>
<connections dest="14" src="0" weight="-1.9439726796836931"/>
<connections dest="14" src="1" weight="-0.2894027034518325"/>
<connections dest="14" src="2" weight="0.2110335238178935"/>
<connections dest="14" src="3" weight="-0.009846640898758158"/>
<connections dest="14" src="4" weight="0.1568088381509006"/>
<connections dest="14" src="5" weight="-0.18073468038735682"/>
<connections dest="14" src="6" weight="0.3823096688264287"/>
<connections dest="14" src="7" weight="-0.21319807548539116"/>
<connections dest="14" src="8" weight="-0.3736851760400955"/>
<connections dest="14" src="9" weight="-0.10659568761110778"/>
<connections dest="15" src="0" weight="-3.5802003342217197"/>
<connections dest="15" src="10" weight="-0.520010988494904"/>
<connections dest="15" src="11" weight="2.0607479402794953"/>
<connections dest="15" src="12" weight="-1.3810086619100004"/>
<connections dest="15" src="13" weight="-0.024645797466295187"/>
<connections dest="15" src="14" weight="2.4372644169618125"/>
<neurons id="0">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
</neurons>
<neurons id="1">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
</neurons>
<neurons id="2">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
</neurons>
<neurons id="3">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
</neurons>
<neurons id="4">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
</neurons>
<neurons id="5">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
</neurons>
<neurons id="6">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
</neurons>
<neurons id="7">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
</neurons>
<neurons id="8">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
</neurons>
<neurons id="9">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
</neurons>
<neurons id="10">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Tanh"/>
</neurons>
<neurons id="11">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Tanh"/>
</neurons>
<neurons id="12">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Tanh"/>
</neurons>
<neurons id="13">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Tanh"/>
</neurons>
<neurons id="14">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Tanh"/>
</neurons>
<neurons id="15">
<activationFunction name="Sigmoid"/>
</neurons>
<layers>
<neuronIds>1</neuronIds>
<neuronIds>2</neuronIds>
<neuronIds>3</neuronIds>
<neuronIds>4</neuronIds>
<neuronIds>5</neuronIds>
<neuronIds>6</neuronIds>
<neuronIds>7</neuronIds>
<neuronIds>8</neuronIds>
<neuronIds>9</neuronIds>
</layers>
<layers>
<neuronIds>10</neuronIds>
<neuronIds>11</neuronIds>
<neuronIds>12</neuronIds>
<neuronIds>13</neuronIds>
<neuronIds>14</neuronIds>
</layers>
<layers>
<neuronIds>15</neuronIds>
</layers>
</multiLayerPerceptron>
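For reference, the serialized network above is a 9-5-1 multilayer perceptron: neuron 0 is the bias unit (biased="true"), neurons 1-9 are the linear board inputs, neurons 10-14 are Tanh hidden units, and neuron 15 is the Sigmoid output. A hand-rolled forward pass matching that layout, as a sketch; the ForwardPassSketch class and its weight-array layout are hypothetical stand-ins for the values in the XML:
// Hypothetical sketch of the forward pass implied by the XML above.
public class ForwardPassSketch {
    // hiddenW[h][0] is the bias weight (src=0) into hidden neuron h;
    // hiddenW[h][1..9] are the weights from input neurons 1-9.
    // outW[0] is the bias weight into the output neuron; outW[1..5] are the
    // weights from the five tanh hidden neurons 10-14.
    static double forward(double[][] hiddenW, double[] outW, double[] input) {
        double out = outW[0];
        for (int h = 0; h < hiddenW.length; h++) {
            double sum = hiddenW[h][0];
            for (int i = 0; i < 9; i++) {
                sum += hiddenW[h][i + 1] * input[i];
            }
            out += outW[h + 1] * Math.tanh(sum);
        }
        return 1.0 / (1.0 + Math.exp(-out)); // sigmoid output neuron 15
    }
}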