diff --git a/.classpath b/.classpath
index 6f91090..7addda2 100644
--- a/.classpath
+++ b/.classpath
@@ -7,6 +7,5 @@
-
diff --git a/lib/encog-java-core-javadoc.jar b/lib/encog-java-core-javadoc.jar
deleted file mode 100644
index 2cf8d93..0000000
Binary files a/lib/encog-java-core-javadoc.jar and /dev/null differ
diff --git a/lib/encog-java-core-sources.jar b/lib/encog-java-core-sources.jar
deleted file mode 100644
index 0cb6b50..0000000
Binary files a/lib/encog-java-core-sources.jar and /dev/null differ
diff --git a/lib/encog-java-core.jar b/lib/encog-java-core.jar
deleted file mode 100644
index f846d91..0000000
Binary files a/lib/encog-java-core.jar and /dev/null differ
diff --git a/src/net/woodyfolsom/msproj/ann/AbstractNeuralNetFilter.java b/src/net/woodyfolsom/msproj/ann/AbstractNeuralNetFilter.java
index 865d6e9..8ecde7b 100644
--- a/src/net/woodyfolsom/msproj/ann/AbstractNeuralNetFilter.java
+++ b/src/net/woodyfolsom/msproj/ann/AbstractNeuralNetFilter.java
@@ -1,21 +1,26 @@
package net.woodyfolsom.msproj.ann;
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.FileOutputStream;
-import java.io.IOException;
-
-import org.encog.ml.data.MLData;
-import org.encog.neural.networks.BasicNetwork;
-import org.encog.neural.networks.PersistBasicNetwork;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.List;
public abstract class AbstractNeuralNetFilter implements NeuralNetFilter {
- protected BasicNetwork neuralNetwork;
- protected int actualTrainingEpochs = 0;
- protected int maxTrainingEpochs = 1000;
+ private final FeedforwardNetwork neuralNetwork;
+ private final TrainingMethod trainingMethod;
+
+ private double maxError;
+ private int actualTrainingEpochs = 0;
+ private int maxTrainingEpochs;
+
+ AbstractNeuralNetFilter(FeedforwardNetwork neuralNetwork, TrainingMethod trainingMethod, int maxTrainingEpochs, double maxError) {
+ this.neuralNetwork = neuralNetwork;
+ this.trainingMethod = trainingMethod;
+ this.maxError = maxError;
+ this.maxTrainingEpochs = maxTrainingEpochs;
+ }
@Override
- public MLData compute(MLData input) {
+ public NNData compute(NNDataPair input) {
return this.neuralNetwork.compute(input);
}
@@ -23,38 +28,83 @@ public abstract class AbstractNeuralNetFilter implements NeuralNetFilter {
return actualTrainingEpochs;
}
+ @Override
+ public int getInputSize() {
+ return 2;
+ }
+
public int getMaxTrainingEpochs() {
return maxTrainingEpochs;
}
- @Override
- public BasicNetwork getNeuralNetwork() {
+ protected FeedforwardNetwork getNeuralNetwork() {
return neuralNetwork;
}
- public void load(String filename) throws IOException {
- FileInputStream fis = new FileInputStream(new File(filename));
- neuralNetwork = (BasicNetwork) new PersistBasicNetwork().read(fis);
- fis.close();
+ @Override
+	public void learnPatterns(List<NNDataPair> trainingSet) {
+ actualTrainingEpochs = 0;
+ double error;
+ neuralNetwork.initWeights();
+
+ error = trainingMethod.computePatternError(neuralNetwork,trainingSet);
+
+ if (error <= maxError) {
+ System.out.println("Initial error: " + error);
+ return;
+ }
+
+ do {
+ trainingMethod.iteratePatterns(neuralNetwork,trainingSet);
+ error = trainingMethod.computePatternError(neuralNetwork,trainingSet);
+ System.out.println("Epoch #" + actualTrainingEpochs + " Error:"
+ + error);
+ actualTrainingEpochs++;
+ System.out.println("MSSE after epoch " + actualTrainingEpochs + ": " + error);
+ } while (error > maxError && actualTrainingEpochs < maxTrainingEpochs);
}
@Override
- public void reset() {
- neuralNetwork.reset();
+	public void learnSequences(List<List<NNDataPair>> trainingSet) {
+ actualTrainingEpochs = 0;
+ double error;
+ neuralNetwork.initWeights();
+
+ error = trainingMethod.computeSequenceError(neuralNetwork,trainingSet);
+
+ if (error <= maxError) {
+ System.out.println("Initial error: " + error);
+ return;
+ }
+
+ do {
+ trainingMethod.iterateSequences(neuralNetwork,trainingSet);
+ error = trainingMethod.computeSequenceError(neuralNetwork,trainingSet);
+			if (Double.isNaN(error)) {
+				// the error is a deterministic function of the weights, so
+				// recomputing unchanged would just return NaN again; recover
+				// by reinitializing the weights before recomputing
+				neuralNetwork.initWeights();
+				error = trainingMethod.computeSequenceError(neuralNetwork,trainingSet);
+			}
+ System.out.println("Epoch #" + actualTrainingEpochs + " Error:"
+ + error);
+ actualTrainingEpochs++;
+ System.out.println("MSSE after epoch " + actualTrainingEpochs + ": " + error);
+ } while (error > maxError && actualTrainingEpochs < maxTrainingEpochs);
}
@Override
- public void reset(int seed) {
- neuralNetwork.reset(seed);
+ public boolean load(InputStream input) {
+ return neuralNetwork.load(input);
+ }
+
+ @Override
+ public boolean save(OutputStream output) {
+ return neuralNetwork.save(output);
}
- public void save(String filename) throws IOException {
- FileOutputStream fos = new FileOutputStream(new File(filename));
- new PersistBasicNetwork().save(fos, getNeuralNetwork());
- fos.close();
+ public void setMaxError(double maxError) {
+ this.maxError = maxError;
}
public void setMaxTrainingEpochs(int max) {
this.maxTrainingEpochs = max;
}
-}
+}
\ No newline at end of file
diff --git a/src/net/woodyfolsom/msproj/ann2/BackPropagation.java b/src/net/woodyfolsom/msproj/ann/BackPropagation.java
similarity index 65%
rename from src/net/woodyfolsom/msproj/ann2/BackPropagation.java
rename to src/net/woodyfolsom/msproj/ann/BackPropagation.java
index 8b40e1e..6be81cd 100644
--- a/src/net/woodyfolsom/msproj/ann2/BackPropagation.java
+++ b/src/net/woodyfolsom/msproj/ann/BackPropagation.java
@@ -1,11 +1,11 @@
-package net.woodyfolsom.msproj.ann2;
+package net.woodyfolsom.msproj.ann;
import java.util.List;
-import net.woodyfolsom.msproj.ann2.math.ErrorFunction;
-import net.woodyfolsom.msproj.ann2.math.MSSE;
+import net.woodyfolsom.msproj.ann.math.ErrorFunction;
+import net.woodyfolsom.msproj.ann.math.MSSE;
-public class BackPropagation implements TrainingMethod {
+public class BackPropagation extends TrainingMethod {
private final ErrorFunction errorFunction;
private final double learningRate;
private final double momentum;
@@ -17,15 +17,13 @@ public class BackPropagation implements TrainingMethod {
}
@Override
- public void iterate(FeedforwardNetwork neuralNetwork,
+ public void iteratePatterns(FeedforwardNetwork neuralNetwork,
			List<NNDataPair> trainingSet) {
System.out.println("Learningrate: " + learningRate);
System.out.println("Momentum: " + momentum);
- //zeroErrors(neuralNetwork);
-
for (NNDataPair trainingPair : trainingSet) {
- zeroErrors(neuralNetwork);
+ zeroGradients(neuralNetwork);
System.out.println("Training with: " + trainingPair.getInput());
@@ -35,16 +33,15 @@ public class BackPropagation implements TrainingMethod {
System.out.println("Updating weights. Ideal Output: " + ideal);
System.out.println("Actual Output: " + actual);
- updateErrors(neuralNetwork, ideal);
+ //backpropagate the gradients w.r.t. output error
+ backPropagate(neuralNetwork, ideal);
updateWeights(neuralNetwork);
}
-
- //updateWeights(neuralNetwork);
}
@Override
- public double computeError(FeedforwardNetwork neuralNetwork,
+ public double computePatternError(FeedforwardNetwork neuralNetwork,
			List<NNDataPair> trainingSet) {
int numDataPairs = trainingSet.size();
int outputSize = neuralNetwork.getOutput().length;
@@ -67,15 +64,17 @@ public class BackPropagation implements TrainingMethod {
return MSSE;
}
- private void updateErrors(FeedforwardNetwork neuralNetwork, NNData ideal) {
+ @Override
+	protected void backPropagate(FeedforwardNetwork neuralNetwork, NNData ideal) {
Neuron[] outputNeurons = neuralNetwork.getOutputNeurons();
double[] idealValues = ideal.getValues();
for (int i = 0; i < idealValues.length; i++) {
- double output = outputNeurons[i].getOutput();
+ double input = outputNeurons[i].getInput();
double derivative = outputNeurons[i].getActivationFunction()
- .derivative(output);
- outputNeurons[i].setError(outputNeurons[i].getError() + derivative * (idealValues[i] - output));
+ .derivative(input);
+ outputNeurons[i].setGradient(outputNeurons[i].getGradient() + derivative * (idealValues[i] - outputNeurons[i].getOutput()));
}
// walking down the list of Neurons in reverse order, propagate the
// error
@@ -84,19 +83,19 @@ public class BackPropagation implements TrainingMethod {
for (int n = neurons.length - 1; n >= 0; n--) {
Neuron neuron = neurons[n];
- double error = neuron.getError();
+ double error = neuron.getGradient();
Connection[] connectionsFromN = neuralNetwork
.getConnectionsFrom(neuron.getId());
if (connectionsFromN.length > 0) {
double derivative = neuron.getActivationFunction().derivative(
- neuron.getOutput());
+ neuron.getInput());
for (Connection connection : connectionsFromN) {
- error += derivative * connection.getWeight() * neuralNetwork.getNeuron(connection.getDest()).getError();
+ error += derivative * connection.getWeight() * neuralNetwork.getNeuron(connection.getDest()).getGradient();
}
}
- neuron.setError(error);
+ neuron.setGradient(error);
}
}
@@ -104,17 +103,30 @@ public class BackPropagation implements TrainingMethod {
for (Connection connection : neuralNetwork.getConnections()) {
Neuron srcNeuron = neuralNetwork.getNeuron(connection.getSrc());
Neuron destNeuron = neuralNetwork.getNeuron(connection.getDest());
- double delta = learningRate * srcNeuron.getOutput() * destNeuron.getError();
+ double delta = learningRate * srcNeuron.getOutput() * destNeuron.getGradient();
//TODO allow for momentum
//double lastDelta = connection.getLastDelta();
connection.addDelta(delta);
}
}
-
- private void zeroErrors(FeedforwardNetwork neuralNetwork) {
- // Set output errors relative to ideals, all other errors to 0.
- for (Neuron neuron : neuralNetwork.getNeurons()) {
- neuron.setError(0.0);
- }
+
+ @Override
+ public void iterateSequences(FeedforwardNetwork neuralNetwork,
+			List<List<NNDataPair>> trainingSet) {
+ throw new UnsupportedOperationException();
}
+
+ @Override
+ public double computeSequenceError(FeedforwardNetwork neuralNetwork,
+			List<List<NNDataPair>> trainingSet) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ protected void iteratePattern(FeedforwardNetwork neuralNetwork,
+ NNDataPair statePair, NNData nextReward) {
+ throw new UnsupportedOperationException();
+ }
+
+
}
\ No newline at end of file
diff --git a/src/net/woodyfolsom/msproj/ann2/Connection.java b/src/net/woodyfolsom/msproj/ann/Connection.java
similarity index 84%
rename from src/net/woodyfolsom/msproj/ann2/Connection.java
rename to src/net/woodyfolsom/msproj/ann/Connection.java
index 8b38b1b..f2d2e5a 100644
--- a/src/net/woodyfolsom/msproj/ann2/Connection.java
+++ b/src/net/woodyfolsom/msproj/ann/Connection.java
@@ -1,4 +1,4 @@
-package net.woodyfolsom.msproj.ann2;
+package net.woodyfolsom.msproj.ann;
import javax.xml.bind.annotation.XmlAttribute;
import javax.xml.bind.annotation.XmlTransient;
@@ -8,6 +8,7 @@ public class Connection {
private int dest;
private double weight;
private transient double lastDelta = 0.0;
+ private transient double trace = 0.0;
public Connection() {
//no-arg constructor for JAXB
@@ -20,6 +21,7 @@ public class Connection {
}
public void addDelta(double delta) {
+ this.trace = delta;
this.weight += delta;
this.lastDelta = delta;
}
@@ -39,6 +41,10 @@ public class Connection {
return src;
}
+ public double getTrace() {
+ return trace;
+ }
+
@XmlAttribute
public double getWeight() {
return weight;
@@ -52,6 +58,11 @@ public class Connection {
this.src = src;
}
+ @XmlTransient
+ public void setTrace(double trace) {
+ this.trace = trace;
+ }
+
public void setWeight(double weight) {
this.weight = weight;
}
diff --git a/src/net/woodyfolsom/msproj/ann/DoublePair.java b/src/net/woodyfolsom/msproj/ann/DoublePair.java
deleted file mode 100644
index cd0a00f..0000000
--- a/src/net/woodyfolsom/msproj/ann/DoublePair.java
+++ /dev/null
@@ -1,12 +0,0 @@
-package net.woodyfolsom.msproj.ann;
-
-import org.encog.ml.data.basic.BasicMLData;
-
-public class DoublePair extends BasicMLData {
-
- private static final long serialVersionUID = 1L;
-
- public DoublePair(double x, double y) {
- super(new double[] { x, y });
- }
-}
\ No newline at end of file
diff --git a/src/net/woodyfolsom/msproj/ann2/FeedforwardNetwork.java b/src/net/woodyfolsom/msproj/ann/FeedforwardNetwork.java
similarity index 85%
rename from src/net/woodyfolsom/msproj/ann2/FeedforwardNetwork.java
rename to src/net/woodyfolsom/msproj/ann/FeedforwardNetwork.java
index 29e1f10..c89eef9 100644
--- a/src/net/woodyfolsom/msproj/ann2/FeedforwardNetwork.java
+++ b/src/net/woodyfolsom/msproj/ann/FeedforwardNetwork.java
@@ -1,4 +1,4 @@
-package net.woodyfolsom.msproj.ann2;
+package net.woodyfolsom.msproj.ann;
import java.io.InputStream;
import java.io.OutputStream;
@@ -9,10 +9,11 @@ import java.util.Map;
import javax.xml.bind.annotation.XmlAttribute;
import javax.xml.bind.annotation.XmlElement;
+import javax.xml.bind.annotation.XmlTransient;
-import net.woodyfolsom.msproj.ann2.math.ActivationFunction;
-import net.woodyfolsom.msproj.ann2.math.Linear;
-import net.woodyfolsom.msproj.ann2.math.Sigmoid;
+import net.woodyfolsom.msproj.ann.math.ActivationFunction;
+import net.woodyfolsom.msproj.ann.math.Linear;
+import net.woodyfolsom.msproj.ann.math.Sigmoid;
public abstract class FeedforwardNetwork {
private ActivationFunction activationFunction;
@@ -83,12 +84,12 @@ public abstract class FeedforwardNetwork {
* Adds a new neuron with a unique id to this FeedforwardNetwork.
* @return
*/
- Neuron createNeuron(boolean input) {
+ Neuron createNeuron(boolean input, ActivationFunction afunc) {
Neuron neuron;
if (input) {
neuron = new Neuron(Linear.function, neurons.size());
} else {
- neuron = new Neuron(activationFunction, neurons.size());
+ neuron = new Neuron(afunc, neurons.size());
}
neurons.add(neuron);
return neuron;
@@ -153,6 +154,10 @@ public abstract class FeedforwardNetwork {
return neurons.get(id);
}
+ public Connection getConnection(int index) {
+ return connections.get(index);
+ }
+
@XmlElement
protected Connection[] getConnections() {
return connections.toArray(new Connection[connections.size()]);
@@ -178,6 +183,22 @@ public abstract class FeedforwardNetwork {
}
}
+ public double[] getGradients() {
+ double[] gradients = new double[neurons.size()];
+ for (int n = 0; n < gradients.length; n++) {
+ gradients[n] = neurons.get(n).getGradient();
+ }
+ return gradients;
+ }
+
+ public double[] getWeights() {
+ double[] weights = new double[connections.size()];
+ for (int i = 0; i < connections.size(); i++) {
+ weights[i] = connections.get(i).getWeight();
+ }
+ return weights;
+ }
+
@XmlAttribute
public boolean isBiased() {
return biased;
@@ -226,7 +247,7 @@ public abstract class FeedforwardNetwork {
this.biased = biased;
if (biased) {
- Neuron biasNeuron = createNeuron(true);
+ Neuron biasNeuron = createNeuron(true, activationFunction);
biasNeuron.setInput(1.0);
biasNeuronId = biasNeuron.getId();
} else {
@@ -270,6 +291,7 @@ public abstract class FeedforwardNetwork {
}
}
+ @XmlTransient
public void setWeights(double[] weights) {
if (weights.length != connections.size()) {
throw new IllegalArgumentException("# of weights must == # of connections");
diff --git a/src/net/woodyfolsom/msproj/ann/GameStateMLDataPair.java b/src/net/woodyfolsom/msproj/ann/GameStateMLDataPair.java
deleted file mode 100644
index d7c1316..0000000
--- a/src/net/woodyfolsom/msproj/ann/GameStateMLDataPair.java
+++ /dev/null
@@ -1,117 +0,0 @@
-package net.woodyfolsom.msproj.ann;
-
-import net.woodyfolsom.msproj.GameResult;
-import net.woodyfolsom.msproj.GameState;
-import net.woodyfolsom.msproj.Player;
-
-import org.encog.ml.data.MLData;
-import org.encog.ml.data.MLDataPair;
-import org.encog.ml.data.basic.BasicMLData;
-import org.encog.ml.data.basic.BasicMLDataPair;
-import org.encog.util.kmeans.Centroid;
-
-public class GameStateMLDataPair implements MLDataPair {
- private BasicMLDataPair mlDataPairDelegate;
- private GameState gameState;
-
- public GameStateMLDataPair(GameState gameState) {
- this.gameState = gameState;
- mlDataPairDelegate = new BasicMLDataPair(new BasicMLData(createInput()), new BasicMLData(createIdeal()));
- }
-
- public GameStateMLDataPair(GameStateMLDataPair that) {
- this.gameState = new GameState(that.gameState);
- mlDataPairDelegate = new BasicMLDataPair(
- that.mlDataPairDelegate.getInput(),
- that.mlDataPairDelegate.getIdeal());
- }
-
- @Override
- public MLDataPair clone() {
- return new GameStateMLDataPair(this);
- }
-
- @Override
- public Centroid createCentroid() {
- return mlDataPairDelegate.createCentroid();
- }
-
- /**
- * Creates a vector of normalized scores from GameState.
- *
- * @return
- */
- private double[] createInput() {
-
- GameResult result = gameState.getResult();
-
- double maxScore = gameState.getGameConfig().getSize()
- * gameState.getGameConfig().getSize();
-
- double whiteScore = Math.min(1.0, result.getWhiteScore() / maxScore);
- double blackScore = Math.min(1.0, result.getBlackScore() / maxScore);
-
- return new double[] { blackScore, whiteScore };
- }
-
- /**
- * Creates a vector of values indicating strength of black/white win output
- * from network.
- *
- * @return
- */
- private double[] createIdeal() {
- GameResult result = gameState.getResult();
-
- double blackWinner = result.isWinner(Player.BLACK) ? 1.0 : 0.0;
- double whiteWinner = result.isWinner(Player.WHITE) ? 1.0 : 0.0;
-
- return new double[] { blackWinner, whiteWinner };
- }
-
- @Override
- public MLData getIdeal() {
- return mlDataPairDelegate.getIdeal();
- }
-
- @Override
- public double[] getIdealArray() {
- return mlDataPairDelegate.getIdealArray();
- }
-
- @Override
- public MLData getInput() {
- return mlDataPairDelegate.getInput();
- }
-
- @Override
- public double[] getInputArray() {
- return mlDataPairDelegate.getInputArray();
- }
-
- @Override
- public double getSignificance() {
- return mlDataPairDelegate.getSignificance();
- }
-
- @Override
- public boolean isSupervised() {
- return mlDataPairDelegate.isSupervised();
- }
-
- @Override
- public void setIdealArray(double[] arg0) {
- mlDataPairDelegate.setIdealArray(arg0);
- }
-
- @Override
- public void setInputArray(double[] arg0) {
- mlDataPairDelegate.setInputArray(arg0);
- }
-
- @Override
- public void setSignificance(double arg0) {
- mlDataPairDelegate.setSignificance(arg0);
- }
-
-}
\ No newline at end of file
diff --git a/src/net/woodyfolsom/msproj/ann2/Layer.java b/src/net/woodyfolsom/msproj/ann/Layer.java
similarity index 92%
rename from src/net/woodyfolsom/msproj/ann2/Layer.java
rename to src/net/woodyfolsom/msproj/ann/Layer.java
index 88cb178..dc2866f 100644
--- a/src/net/woodyfolsom/msproj/ann2/Layer.java
+++ b/src/net/woodyfolsom/msproj/ann/Layer.java
@@ -1,4 +1,4 @@
-package net.woodyfolsom.msproj.ann2;
+package net.woodyfolsom.msproj.ann;
import java.util.Arrays;
diff --git a/src/net/woodyfolsom/msproj/ann2/MultiLayerPerceptron.java b/src/net/woodyfolsom/msproj/ann/MultiLayerPerceptron.java
similarity index 79%
rename from src/net/woodyfolsom/msproj/ann2/MultiLayerPerceptron.java
rename to src/net/woodyfolsom/msproj/ann/MultiLayerPerceptron.java
index 23f3dca..6f9c8f4 100644
--- a/src/net/woodyfolsom/msproj/ann2/MultiLayerPerceptron.java
+++ b/src/net/woodyfolsom/msproj/ann/MultiLayerPerceptron.java
@@ -1,4 +1,4 @@
-package net.woodyfolsom.msproj.ann2;
+package net.woodyfolsom.msproj.ann;
import java.io.InputStream;
import java.io.OutputStream;
@@ -10,6 +10,10 @@ import javax.xml.bind.Unmarshaller;
import javax.xml.bind.annotation.XmlElement;
import javax.xml.bind.annotation.XmlRootElement;
+import net.woodyfolsom.msproj.ann.math.ActivationFunction;
+import net.woodyfolsom.msproj.ann.math.Sigmoid;
+import net.woodyfolsom.msproj.ann.math.Tanh;
+
@XmlRootElement
public class MultiLayerPerceptron extends FeedforwardNetwork {
private boolean biased;
@@ -37,7 +41,13 @@ public class MultiLayerPerceptron extends FeedforwardNetwork {
throw new IllegalArgumentException("Layer size must be >= 1");
}
- Layer newLayer = createNewLayer(layerIndex, layerSize);
+
+ Layer newLayer;
+ if (layerIndex == numLayers - 1) {
+ newLayer = createNewLayer(layerIndex, layerSize, Sigmoid.function);
+ } else {
+ newLayer = createNewLayer(layerIndex, layerSize, Tanh.function);
+ }
if (layerIndex > 0) {
Layer prevLayer = layers[layerIndex - 1];
@@ -54,11 +64,11 @@ public class MultiLayerPerceptron extends FeedforwardNetwork {
}
}
- private Layer createNewLayer(int layerIndex, int layerSize) {
+ private Layer createNewLayer(int layerIndex, int layerSize, ActivationFunction afunc) {
Layer layer = new Layer(layerSize);
layers[layerIndex] = layer;
for (int n = 0; n < layerSize; n++) {
- Neuron neuron = createNeuron(layerIndex == 0);
+ Neuron neuron = createNeuron(layerIndex == 0, afunc);
layer.setNeuronId(n, neuron.getId());
}
return layer;
@@ -93,8 +103,13 @@ public class MultiLayerPerceptron extends FeedforwardNetwork {
protected void setInput(double[] input) {
Layer inputLayer = layers[0];
for (int n = 0; n < inputLayer.size(); n++) {
- getNeuron(inputLayer.getNeuronId(n)).setInput(input[n]);
+ try {
+ getNeuron(inputLayer.getNeuronId(n)).setInput(input[n]);
+ } catch (NullPointerException npe) {
+ npe.printStackTrace();
+ }
}
+
}
public void setLayers(Layer[] layers) {
diff --git a/src/net/woodyfolsom/msproj/ann2/NNData.java b/src/net/woodyfolsom/msproj/ann/NNData.java
similarity index 89%
rename from src/net/woodyfolsom/msproj/ann2/NNData.java
rename to src/net/woodyfolsom/msproj/ann/NNData.java
index 12f2150..f5200f7 100644
--- a/src/net/woodyfolsom/msproj/ann2/NNData.java
+++ b/src/net/woodyfolsom/msproj/ann/NNData.java
@@ -1,4 +1,4 @@
-package net.woodyfolsom.msproj.ann2;
+package net.woodyfolsom.msproj.ann;
public class NNData {
private final double[] values;
diff --git a/src/net/woodyfolsom/msproj/ann2/NNDataPair.java b/src/net/woodyfolsom/msproj/ann/NNDataPair.java
similarity index 83%
rename from src/net/woodyfolsom/msproj/ann2/NNDataPair.java
rename to src/net/woodyfolsom/msproj/ann/NNDataPair.java
index 53d501d..2b2921a 100644
--- a/src/net/woodyfolsom/msproj/ann2/NNDataPair.java
+++ b/src/net/woodyfolsom/msproj/ann/NNDataPair.java
@@ -1,4 +1,4 @@
-package net.woodyfolsom.msproj.ann2;
+package net.woodyfolsom.msproj.ann;
public class NNDataPair {
private final NNData input;
diff --git a/src/net/woodyfolsom/msproj/ann/NeuralNetFilter.java b/src/net/woodyfolsom/msproj/ann/NeuralNetFilter.java
index 73ce084..08a5c76 100644
--- a/src/net/woodyfolsom/msproj/ann/NeuralNetFilter.java
+++ b/src/net/woodyfolsom/msproj/ann/NeuralNetFilter.java
@@ -1,30 +1,29 @@
package net.woodyfolsom.msproj.ann;
-import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
import java.util.List;
-import java.util.Set;
-
-import org.encog.ml.data.MLData;
-import org.encog.ml.data.MLDataPair;
-import org.encog.ml.data.MLDataSet;
-import org.encog.neural.networks.BasicNetwork;
public interface NeuralNetFilter {
- BasicNetwork getNeuralNetwork();
-
int getActualTrainingEpochs();
+
int getInputSize();
+
int getMaxTrainingEpochs();
+
int getOutputSize();
-
- void learn(MLDataSet trainingSet);
-	void learn(Set<List<MLDataPair>> trainingSet);
-
- void load(String fileName) throws IOException;
- void reset();
- void reset(int seed);
- void save(String fileName) throws IOException;
+
+ boolean load(InputStream input);
+
+ boolean save(OutputStream output);
+
void setMaxTrainingEpochs(int max);
- MLData compute(MLData input);
+ NNData compute(NNDataPair input);
+
+ //Due to Java type erasure, overloading a method
+ //simply named 'learn' which takes Lists would be problematic
+
+	void learnPatterns(List<NNDataPair> trainingSet);
+	void learnSequences(List<List<NNDataPair>> trainingSet);
}
\ No newline at end of file
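The type-erasure note above is why the interface now exposes distinct learnPatterns/learnSequences methods; a minimal sketch of the collision it avoids (illustrative, not part of the commit; assumes it sits in net.woodyfolsom.msproj.ann alongside NNDataPair):

```java
import java.util.List;

interface LearnOverloads {
    // Both of these erase to learn(List), so javac rejects the pair:
    //   void learn(List<NNDataPair> patterns);
    //   void learn(List<List<NNDataPair>> sequences);

    // Distinct names sidestep the clash, as NeuralNetFilter does:
    void learnPatterns(List<NNDataPair> patterns);
    void learnSequences(List<List<NNDataPair>> sequences);
}
```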
diff --git a/src/net/woodyfolsom/msproj/ann2/Neuron.java b/src/net/woodyfolsom/msproj/ann/Neuron.java
similarity index 78%
rename from src/net/woodyfolsom/msproj/ann2/Neuron.java
rename to src/net/woodyfolsom/msproj/ann/Neuron.java
index 4e76aa5..c5ad0c4 100644
--- a/src/net/woodyfolsom/msproj/ann2/Neuron.java
+++ b/src/net/woodyfolsom/msproj/ann/Neuron.java
@@ -1,17 +1,17 @@
-package net.woodyfolsom.msproj.ann2;
+package net.woodyfolsom.msproj.ann;
import javax.xml.bind.annotation.XmlAttribute;
import javax.xml.bind.annotation.XmlElement;
import javax.xml.bind.annotation.XmlTransient;
-import net.woodyfolsom.msproj.ann2.math.ActivationFunction;
-import net.woodyfolsom.msproj.ann2.math.Sigmoid;
+import net.woodyfolsom.msproj.ann.math.ActivationFunction;
+import net.woodyfolsom.msproj.ann.math.Sigmoid;
public class Neuron {
private ActivationFunction activationFunction;
private int id;
private transient double input = 0.0;
- private transient double error = 0.0;
+ private transient double gradient = 0.0;
public Neuron() {
//no-arg constructor for JAXB
@@ -37,8 +37,8 @@ public class Neuron {
}
@XmlTransient
- public double getError() {
- return error;
+ public double getGradient() {
+ return gradient;
}
@XmlTransient
@@ -50,8 +50,8 @@ public class Neuron {
return activationFunction.calculate(input);
}
- public void setError(double value) {
- this.error = value;
+ public void setGradient(double value) {
+ this.gradient = value;
}
public void setInput(double input) {
@@ -92,7 +92,7 @@ public class Neuron {
@Override
public String toString() {
- return "Neuron #" + id +", input: " + input + ", error: " + error;
+ return "Neuron #" + id +", input: " + input + ", gradient: " + gradient;
}
}
\ No newline at end of file
diff --git a/src/net/woodyfolsom/msproj/ann/ObjectiveFunction.java b/src/net/woodyfolsom/msproj/ann/ObjectiveFunction.java
new file mode 100644
index 0000000..9e6230a
--- /dev/null
+++ b/src/net/woodyfolsom/msproj/ann/ObjectiveFunction.java
@@ -0,0 +1,5 @@
+package net.woodyfolsom.msproj.ann;
+
+public class ObjectiveFunction {
+
+}
diff --git a/src/net/woodyfolsom/msproj/ann/TTTFilter.java b/src/net/woodyfolsom/msproj/ann/TTTFilter.java
new file mode 100644
index 0000000..31f743c
--- /dev/null
+++ b/src/net/woodyfolsom/msproj/ann/TTTFilter.java
@@ -0,0 +1,34 @@
+package net.woodyfolsom.msproj.ann;
+
+/**
+ * Based on sample code from http://neuroph.sourceforge.net
+ *
+ * @author Woody
+ *
+ */
+public class TTTFilter extends AbstractNeuralNetFilter implements
+ NeuralNetFilter {
+
+ private static final int INPUT_SIZE = 9;
+ private static final int OUTPUT_SIZE = 1;
+
+ public TTTFilter() {
+ this(0.5,0.0, 1000);
+ }
+
+	public TTTFilter(double alpha, double lambda, int maxEpochs) {
+		super( new MultiLayerPerceptron(true, INPUT_SIZE, 5, OUTPUT_SIZE),
+				new TemporalDifference(alpha, lambda), maxEpochs, 0.05);
+		super.getNeuralNetwork().setName("TTTFilter");
+ }
+
+ @Override
+ public int getInputSize() {
+ return INPUT_SIZE;
+ }
+
+ @Override
+ public int getOutputSize() {
+ return OUTPUT_SIZE;
+ }
+}
\ No newline at end of file
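The 9-5-1 topology above scores a board encoded as nine row-major cells ("00" through "22") in {-1, 0, 1}, matching the validation vectors in TTTFilterTrainer below; a hedged usage sketch (the player-to-sign mapping and the placeholder ideal are assumptions, not fixed by this diff):

```java
import net.woodyfolsom.msproj.ann.NNData;
import net.woodyfolsom.msproj.ann.NNDataPair;
import net.woodyfolsom.msproj.ann.TTTFilter;

public class TTTFilterSketch {
    public static void main(String[] args) {
        String[] in  = {"00","01","02","10","11","12","20","21","22"};
        String[] out = {"value"};
        // the "win" row from TTTFilterTrainer's validation set
        double[] board = { 1, 1, 1,  -1, -1, 0,  0, -1, 0 };
        TTTFilter filter = new TTTFilter();
        NNDataPair dp = new NNDataPair(new NNData(in, board),
                new NNData(out, new double[]{0.0})); // ideal is a dummy here
        System.out.println(filter.compute(dp)); // value estimate (untrained: arbitrary)
    }
}
```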
diff --git a/src/net/woodyfolsom/msproj/ann/TTTFilterTrainer.java b/src/net/woodyfolsom/msproj/ann/TTTFilterTrainer.java
new file mode 100644
index 0000000..036e77a
--- /dev/null
+++ b/src/net/woodyfolsom/msproj/ann/TTTFilterTrainer.java
@@ -0,0 +1,187 @@
+package net.woodyfolsom.msproj.ann;
+
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.io.FileOutputStream;
+import java.util.ArrayList;
+import java.util.List;
+
+import net.woodyfolsom.msproj.tictactoe.Action;
+import net.woodyfolsom.msproj.tictactoe.GameRecord;
+import net.woodyfolsom.msproj.tictactoe.GameRecord.RESULT;
+import net.woodyfolsom.msproj.tictactoe.NNDataSetFactory;
+import net.woodyfolsom.msproj.tictactoe.NeuralNetPolicy;
+import net.woodyfolsom.msproj.tictactoe.Policy;
+import net.woodyfolsom.msproj.tictactoe.RandomPolicy;
+import net.woodyfolsom.msproj.tictactoe.State;
+
+public class TTTFilterTrainer { //implements epsilon-greedy trainer? online version of NeuralNetFilter
+
+ public static void main(String[] args) throws FileNotFoundException {
+ double alpha = 0.0;
+ double lambda = 0.9;
+ int maxGames = 15000;
+
+ new TTTFilterTrainer().trainNetwork(alpha, lambda, maxGames);
+ }
+
+ public void trainNetwork(double alpha, double lambda, int maxGames) throws FileNotFoundException {
+ ///
+ FeedforwardNetwork neuralNetwork = new MultiLayerPerceptron(true, 9,5,1);
+ neuralNetwork.setName("TicTacToe");
+ neuralNetwork.initWeights();
+ TrainingMethod trainer = new TemporalDifference(0.5,0.5);
+
+ System.out.println("Playing untrained games.");
+ for (int i = 0; i < 10; i++) {
+ System.out.println("" + (i+1) + ". " + playOptimal(neuralNetwork).getResult());
+ }
+
+ System.out.println("Learning from " + maxGames + " games of random self-play");
+
+ int gamesPlayed = 0;
+		List<RESULT> results = new ArrayList<RESULT>();
+ do {
+ GameRecord gameRecord = playEpsilonGreedy(0.90, neuralNetwork, trainer);
+ System.out.println("Winner: " + gameRecord.getResult());
+ gamesPlayed++;
+ results.add(gameRecord.getResult());
+ } while (gamesPlayed < maxGames);
+ ///
+
+ System.out.println("Learned network after " + maxGames + " training games.");
+
+ double[][] validationSet = new double[8][];
+
+ for (int i = 0; i < results.size(); i++) {
+ if (i % 10 == 0) {
+ System.out.println("" + (i+1) + ". " + results.get(i));
+ }
+ }
+ // empty board
+ validationSet[0] = new double[] { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+ 0.0, 0.0 };
+ // center
+ validationSet[1] = new double[] { 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,
+ 0.0, 0.0 };
+ // top edge
+ validationSet[2] = new double[] { 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+ 0.0, 0.0 };
+ // left edge
+ validationSet[3] = new double[] { 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0,
+ 0.0, 0.0 };
+ // corner
+ validationSet[4] = new double[] { 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+ 0.0, 0.0 };
+ // win
+ validationSet[5] = new double[] { 1.0, 1.0, 1.0, -1.0, -1.0, 0.0, 0.0,
+ -1.0, 0.0 };
+ // loss
+ validationSet[6] = new double[] { -1.0, 1.0, 0.0, 1.0, -1.0, 1.0, 0.0,
+ 0.0, -1.0 };
+
+ // about to win
+ validationSet[7] = new double[] {
+ -1.0, 1.0, 1.0,
+ 1.0, -1.0, 1.0,
+ -1.0, -1.0, 0.0 };
+
+ String[] inputNames = new String[] { "00", "01", "02", "10", "11",
+ "12", "20", "21", "22" };
+ String[] outputNames = new String[] { "values" };
+
+ System.out.println("Output from eval set (learned network):");
+ testNetwork(neuralNetwork, validationSet, inputNames, outputNames);
+
+ System.out.println("Playing optimal games.");
+ for (int i = 0; i < 10; i++) {
+ System.out.println("" + (i+1) + ". " + playOptimal(neuralNetwork).getResult());
+ }
+
+ /*
+ File output = new File("ttt.net");
+
+ FileOutputStream fos = new FileOutputStream(output);
+
+ neuralNetwork.save(fos);*/
+ }
+
+ private GameRecord playOptimal(FeedforwardNetwork neuralNetwork) {
+ GameRecord gameRecord = new GameRecord();
+
+ Policy neuralNetPolicy = new NeuralNetPolicy(neuralNetwork);
+
+ State state = gameRecord.getState();
+
+ do {
+ Action action;
+ State nextState;
+
+ action = neuralNetPolicy.getAction(gameRecord.getState());
+
+ nextState = gameRecord.apply(action);
+ //System.out.println("Action " + action + " selected by policy " + selectedPolicy.getName());
+ //System.out.println("Next board state: " + nextState);
+ state = nextState;
+ } while (!state.isTerminal());
+
+ //finally, reinforce the actual reward
+
+ return gameRecord;
+ }
+
+ private GameRecord playEpsilonGreedy(double epsilon, FeedforwardNetwork neuralNetwork, TrainingMethod trainer) {
+ GameRecord gameRecord = new GameRecord();
+
+ Policy randomPolicy = new RandomPolicy();
+ Policy neuralNetPolicy = new NeuralNetPolicy(neuralNetwork);
+
+ //System.out.println("Playing epsilon-greedy game.");
+
+ State state = gameRecord.getState();
+ NNDataPair statePair;
+
+ Policy selectedPolicy;
+ trainer.zeroTraces(neuralNetwork);
+
+ do {
+ Action action;
+ State nextState;
+
+ if (Math.random() < epsilon) {
+ selectedPolicy = randomPolicy;
+ action = selectedPolicy.getAction(gameRecord.getState());
+ nextState = gameRecord.apply(action);
+ } else {
+ selectedPolicy = neuralNetPolicy;
+ action = selectedPolicy.getAction(gameRecord.getState());
+
+ nextState = gameRecord.apply(action);
+ statePair = NNDataSetFactory.createDataPair(state);
+ NNDataPair nextStatePair = NNDataSetFactory.createDataPair(nextState);
+ trainer.iteratePattern(neuralNetwork, statePair, nextStatePair.getIdeal());
+ }
+ //System.out.println("Action " + action + " selected by policy " + selectedPolicy.getName());
+
+ //System.out.println("Next board state: " + nextState);
+
+ state = nextState;
+ } while (!state.isTerminal());
+
+ //finally, reinforce the actual reward
+ statePair = NNDataSetFactory.createDataPair(state);
+ trainer.iteratePattern(neuralNetwork, statePair, statePair.getIdeal());
+
+ return gameRecord;
+ }
+
+ private void testNetwork(FeedforwardNetwork neuralNetwork,
+ double[][] validationSet, String[] inputNames, String[] outputNames) {
+ for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
+ NNDataPair dp = new NNDataPair(new NNData(inputNames,
+ validationSet[valIndex]), new NNData(outputNames,
+ validationSet[valIndex]));
+ System.out.println(dp + " => " + neuralNetwork.compute(dp));
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/net/woodyfolsom/msproj/ann/TemporalDifference.java b/src/net/woodyfolsom/msproj/ann/TemporalDifference.java
index 991608b..853fb9f 100644
--- a/src/net/woodyfolsom/msproj/ann/TemporalDifference.java
+++ b/src/net/woodyfolsom/msproj/ann/TemporalDifference.java
@@ -1,30 +1,133 @@
package net.woodyfolsom.msproj.ann;
-import org.encog.ml.data.MLDataSet;
-import org.encog.neural.networks.ContainsFlat;
-import org.encog.neural.networks.training.propagation.back.Backpropagation;
+import java.util.List;
-public class TemporalDifference extends Backpropagation {
+public class TemporalDifference extends TrainingMethod {
+ private final double alpha;
+ private final double gamma = 1.0;
private final double lambda;
-
- public TemporalDifference(ContainsFlat network, MLDataSet training,
- double theLearnRate, double theMomentum, double lambda) {
- super(network, training, theLearnRate, theMomentum);
+
+ public TemporalDifference(double alpha, double lambda) {
+ this.alpha = alpha;
this.lambda = lambda;
}
- public double getLamdba() {
- return lambda;
+ @Override
+ public void iteratePatterns(FeedforwardNetwork neuralNetwork,
+			List<NNDataPair> trainingSet) {
+ throw new UnsupportedOperationException();
}
@Override
- public double updateWeight(final double[] gradients,
- final double[] lastGradient, final int index) {
- double alpha = this.getLearningRate();
-
- //TODO fill in weight update for TD(lambda)
-
- return 0.0;
+ public double computePatternError(FeedforwardNetwork neuralNetwork,
+			List<NNDataPair> trainingSet) {
+ int numDataPairs = trainingSet.size();
+ int outputSize = neuralNetwork.getOutput().length;
+ int totalOutputSize = outputSize * numDataPairs;
+
+ double[] actuals = new double[totalOutputSize];
+ double[] ideals = new double[totalOutputSize];
+ for (int dataPair = 0; dataPair < numDataPairs; dataPair++) {
+ NNDataPair nnDataPair = trainingSet.get(dataPair);
+ double[] actual = neuralNetwork.compute(nnDataPair.getInput()
+ .getValues());
+ double[] ideal = nnDataPair.getIdeal().getValues();
+ int offset = dataPair * outputSize;
+
+ System.arraycopy(actual, 0, actuals, offset, outputSize);
+ System.arraycopy(ideal, 0, ideals, offset, outputSize);
+ }
+
+ double MSSE = errorFunction.compute(ideals, actuals);
+ return MSSE;
+ }
+
+ @Override
+ protected void backPropagate(FeedforwardNetwork neuralNetwork, NNData ideal) {
+ Neuron[] outputNeurons = neuralNetwork.getOutputNeurons();
+ double[] idealValues = ideal.getValues();
+
+ for (int i = 0; i < idealValues.length; i++) {
+ double input = outputNeurons[i].getInput();
+ double derivative = outputNeurons[i].getActivationFunction()
+ .derivative(input);
+ outputNeurons[i].setGradient(outputNeurons[i].getGradient()
+ + derivative
+ * (idealValues[i] - outputNeurons[i].getOutput()));
+ }
+ // walking down the list of Neurons in reverse order, propagate the
+ // error
+ Neuron[] neurons = neuralNetwork.getNeurons();
+
+ for (int n = neurons.length - 1; n >= 0; n--) {
+
+ Neuron neuron = neurons[n];
+ double error = neuron.getGradient();
+
+ Connection[] connectionsFromN = neuralNetwork
+ .getConnectionsFrom(neuron.getId());
+ if (connectionsFromN.length > 0) {
+
+ double derivative = neuron.getActivationFunction().derivative(
+ neuron.getInput());
+ for (Connection connection : connectionsFromN) {
+ error += derivative
+ * connection.getWeight()
+ * neuralNetwork.getNeuron(connection.getDest())
+ .getGradient();
+ }
+ }
+ neuron.setGradient(error);
+ }
+ }
+
+ private void updateWeights(FeedforwardNetwork neuralNetwork, double predictionError) {
+ for (Connection connection : neuralNetwork.getConnections()) {
+ Neuron srcNeuron = neuralNetwork.getNeuron(connection.getSrc());
+ Neuron destNeuron = neuralNetwork.getNeuron(connection.getDest());
+
+ double delta = alpha * srcNeuron.getOutput()
+ * destNeuron.getGradient() * predictionError + connection.getTrace() * lambda;
+
+ // TODO allow for momentum
+ // double lastDelta = connection.getLastDelta();
+ connection.addDelta(delta);
+ }
}
+ @Override
+ public void iterateSequences(FeedforwardNetwork neuralNetwork,
+			List<List<NNDataPair>> trainingSet) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public double computeSequenceError(FeedforwardNetwork neuralNetwork,
+			List<List<NNDataPair>> trainingSet) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ protected void iteratePattern(FeedforwardNetwork neuralNetwork,
+ NNDataPair statePair, NNData nextReward) {
+		//System.out.println("Learning rate: " + alpha);
+
+ zeroGradients(neuralNetwork);
+
+ //System.out.println("Training with: " + statePair.getInput());
+
+ NNData ideal = nextReward;
+ NNData actual = neuralNetwork.compute(statePair);
+
+ //System.out.println("Updating weights. Ideal Output: " + ideal);
+ //System.out.println("Actual Output: " + actual);
+
+ // backpropagate the gradients w.r.t. output error
+ backPropagate(neuralNetwork, ideal);
+
+ double predictionError = statePair.getIdeal().getValues()[0] // reward_t
+ + actual.getValues()[0] - nextReward.getValues()[0];
+
+ updateWeights(neuralNetwork, predictionError);
+ }
}
\ No newline at end of file
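For reference against updateWeights and iteratePattern above: in the textbook TD(λ) rule the eligibility trace accumulates gradients and decays by γλ, while the TD error scales the step separately; here Connection.addDelta instead stores the whole previous delta (with α and the TD error folded in) as the trace. A minimal sketch of the textbook form, with illustrative values not taken from this codebase:

```java
public class TDLambdaSketch {
    public static void main(String[] args) {
        double w = 0.1, trace = 0.0;                  // one weight, one trace
        double alpha = 0.5, gamma = 1.0, lambda = 0.9;
        double[] grad  = {0.2, -0.1, 0.3};            // dV(s_t)/dw from backprop, per step
        double[] tdErr = {0.05, -0.02, 0.10};         // r_t + gamma*V(s_{t+1}) - V(s_t)
        for (int t = 0; t < grad.length; t++) {
            trace = gamma * lambda * trace + grad[t]; // decay, then accumulate gradient
            w += alpha * tdErr[t] * trace;            // step scaled by the TD error
        }
        System.out.println("w after 3 steps: " + w);
    }
}
```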
diff --git a/src/net/woodyfolsom/msproj/ann/TrainingMethod.java b/src/net/woodyfolsom/msproj/ann/TrainingMethod.java
new file mode 100644
index 0000000..ea6c051
--- /dev/null
+++ b/src/net/woodyfolsom/msproj/ann/TrainingMethod.java
@@ -0,0 +1,43 @@
+package net.woodyfolsom.msproj.ann;
+
+import java.util.List;
+
+import net.woodyfolsom.msproj.ann.math.ErrorFunction;
+import net.woodyfolsom.msproj.ann.math.MSSE;
+
+public abstract class TrainingMethod {
+ protected final ErrorFunction errorFunction;
+
+ public TrainingMethod() {
+ this.errorFunction = MSSE.function;
+ }
+
+ protected abstract void iteratePattern(FeedforwardNetwork neuralNetwork,
+ NNDataPair statePair, NNData nextReward);
+
+ protected abstract void iteratePatterns(FeedforwardNetwork neuralNetwork,
+			List<NNDataPair> trainingSet);
+
+ protected abstract double computePatternError(FeedforwardNetwork neuralNetwork,
+			List<NNDataPair> trainingSet);
+
+ protected abstract void iterateSequences(FeedforwardNetwork neuralNetwork,
+			List<List<NNDataPair>> trainingSet);
+
+ protected abstract void backPropagate(FeedforwardNetwork neuralNetwork, NNData output);
+
+ protected abstract double computeSequenceError(FeedforwardNetwork neuralNetwork,
+			List<List<NNDataPair>> trainingSet);
+
+ protected void zeroGradients(FeedforwardNetwork neuralNetwork) {
+ for (Neuron neuron : neuralNetwork.getNeurons()) {
+ neuron.setGradient(0.0);
+ }
+ }
+
+ protected void zeroTraces(FeedforwardNetwork neuralNetwork) {
+ for (Connection conn : neuralNetwork.getConnections()) {
+ conn.setTrace(0.0);
+ }
+ }
+}
diff --git a/src/net/woodyfolsom/msproj/ann/WinFilter.java b/src/net/woodyfolsom/msproj/ann/WinFilter.java
deleted file mode 100644
index c042691..0000000
--- a/src/net/woodyfolsom/msproj/ann/WinFilter.java
+++ /dev/null
@@ -1,105 +0,0 @@
-package net.woodyfolsom.msproj.ann;
-
-import java.util.List;
-import java.util.Set;
-
-import org.encog.engine.network.activation.ActivationSigmoid;
-import org.encog.ml.data.MLDataPair;
-import org.encog.ml.data.MLDataSet;
-import org.encog.ml.data.basic.BasicMLDataSet;
-import org.encog.ml.train.MLTrain;
-import org.encog.neural.networks.BasicNetwork;
-import org.encog.neural.networks.layers.BasicLayer;
-import org.encog.neural.networks.training.propagation.back.Backpropagation;
-
-public class WinFilter extends AbstractNeuralNetFilter implements
- NeuralNetFilter {
-
- public WinFilter() {
- // create a neural network, without using a factory
- BasicNetwork network = new BasicNetwork();
- network.addLayer(new BasicLayer(null, false, 2));
- network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 4));
- network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 2));
- network.getStructure().finalizeStructure();
- network.reset();
-
- this.neuralNetwork = network;
- }
-
- @Override
- public void learn(MLDataSet trainingData) {
- throw new UnsupportedOperationException(
-				"This filter learns a Set<List<MLDataPair>>, not an MLDataSet");
- }
-
- /**
- * Method is necessary because with temporal difference learning, some of
- * the MLDataPairs are related by being a sequence of moves within a
- * particular game.
- */
- @Override
-	public void learn(Set<List<MLDataPair>> trainingSet) {
- MLDataSet mlDataset = new BasicMLDataSet();
-
-		for (List<MLDataPair> gameRecord : trainingSet) {
- for (int t = 0; t < gameRecord.size() - 1; t++) {
- mlDataset.add(gameRecord.get(t).getInput(), this.neuralNetwork.compute(gameRecord.get(t)
- .getInput()));
- }
- mlDataset.add(gameRecord.get(gameRecord.size() - 1));
- }
-
- // train the neural network
- final MLTrain train = new TemporalDifference(neuralNetwork, mlDataset, 0.7, 0.8, 0.25);
- //final MLTrain train = new Backpropagation(neuralNetwork, mlDataset, 0.7, 0.8);
- actualTrainingEpochs = 0;
-
- do {
- if (actualTrainingEpochs > 0) {
- int gameStateIndex = 0;
-			for (List<MLDataPair> gameRecord : trainingSet) {
- for (int t = 0; t < gameRecord.size() - 1; t++) {
- MLDataPair oldDataPair = mlDataset.get(gameStateIndex);
- this.neuralNetwork.compute(oldDataPair.getInput());
- gameStateIndex++;
- }
- gameStateIndex++;
- }
- }
- train.iteration();
- System.out.println("Epoch #" + actualTrainingEpochs + " Error:"
- + train.getError());
- actualTrainingEpochs++;
- } while (train.getError() > 0.01
- && actualTrainingEpochs <= maxTrainingEpochs);
- }
-
- @Override
- public void reset() {
- neuralNetwork.reset();
- }
-
- @Override
- public void reset(int seed) {
- neuralNetwork.reset(seed);
- }
-
- @Override
- public BasicNetwork getNeuralNetwork() {
- // TODO Auto-generated method stub
- return null;
- }
-
- @Override
- public int getInputSize() {
- // TODO Auto-generated method stub
- return 0;
- }
-
- @Override
- public int getOutputSize() {
- // TODO Auto-generated method stub
- return 0;
- }
-}
\ No newline at end of file
diff --git a/src/net/woodyfolsom/msproj/ann/XORFilter.java b/src/net/woodyfolsom/msproj/ann/XORFilter.java
index 1021dc5..19e15d4 100644
--- a/src/net/woodyfolsom/msproj/ann/XORFilter.java
+++ b/src/net/woodyfolsom/msproj/ann/XORFilter.java
@@ -1,18 +1,5 @@
package net.woodyfolsom.msproj.ann;
-import java.util.List;
-import java.util.Set;
-
-import org.encog.engine.network.activation.ActivationSigmoid;
-import org.encog.ml.data.MLData;
-import org.encog.ml.data.MLDataPair;
-import org.encog.ml.data.MLDataSet;
-import org.encog.ml.data.basic.BasicMLData;
-import org.encog.ml.train.MLTrain;
-import org.encog.neural.networks.BasicNetwork;
-import org.encog.neural.networks.layers.BasicLayer;
-import org.encog.neural.networks.training.propagation.back.Backpropagation;
-
/**
* Based on sample code from http://neuroph.sourceforge.net
*
@@ -22,54 +9,30 @@ import org.encog.neural.networks.training.propagation.back.Backpropagation;
public class XORFilter extends AbstractNeuralNetFilter implements
NeuralNetFilter {
+ private static final int INPUT_SIZE = 2;
+ private static final int OUTPUT_SIZE = 1;
+
public XORFilter() {
- // create a neural network, without using a factory
- BasicNetwork network = new BasicNetwork();
- network.addLayer(new BasicLayer(null, false, 2));
- network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 3));
- network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 1));
- network.getStructure().finalizeStructure();
- network.reset();
-
- this.neuralNetwork = network;
+ this(0.8,0.7);
+ }
+
+ public XORFilter(double learningRate, double momentum) {
+ super( new MultiLayerPerceptron(true, INPUT_SIZE, 2, OUTPUT_SIZE),
+ new BackPropagation(learningRate, momentum), 1000, 0.01);
+ super.getNeuralNetwork().setName("XORFilter");
}
public double compute(double x, double y) {
- return compute(new BasicMLData(new double[]{x,y})).getData(0);
+ return getNeuralNetwork().compute(new double[]{x,y})[0];
}
@Override
public int getInputSize() {
- return 2;
+ return INPUT_SIZE;
}
@Override
public int getOutputSize() {
- // TODO Auto-generated method stub
- return 1;
- }
-
- @Override
- public void learn(MLDataSet trainingSet) {
-
- // train the neural network
- final MLTrain train = new Backpropagation(neuralNetwork, trainingSet,
- 0.7, 0.8);
-
- actualTrainingEpochs = 0;
-
- do {
- train.iteration();
- System.out.println("Epoch #" + actualTrainingEpochs + " Error:"
- + train.getError());
- actualTrainingEpochs++;
- } while (train.getError() > 0.01
- && actualTrainingEpochs <= maxTrainingEpochs);
- }
-
- @Override
-	public void learn(Set<List<MLDataPair>> trainingSet) {
-		throw new UnsupportedOperationException(
-				"This Filter learns an MLDataSet, not a Set<List<MLDataPair>>.");
+ return OUTPUT_SIZE;
}
}
\ No newline at end of file
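With the rewrite above, training the XOR filter reduces to building NNDataPairs and calling learnPatterns; a minimal sketch, assuming the NNData(names, values) and NNDataPair(input, ideal) constructors used in TTTFilterTrainer.testNetwork:

```java
import java.util.ArrayList;
import java.util.List;

import net.woodyfolsom.msproj.ann.NNData;
import net.woodyfolsom.msproj.ann.NNDataPair;
import net.woodyfolsom.msproj.ann.XORFilter;

public class XORExample {
    public static void main(String[] args) {
        String[] in = {"x", "y"}, out = {"xor"};
        List<NNDataPair> data = new ArrayList<NNDataPair>();
        double[][] xs = {{0,0},{0,1},{1,0},{1,1}};
        double[]   ys = {0, 1, 1, 0};
        for (int i = 0; i < xs.length; i++) {
            data.add(new NNDataPair(new NNData(in, xs[i]),
                                    new NNData(out, new double[]{ys[i]})));
        }
        XORFilter filter = new XORFilter(0.8, 0.7); // learning rate, momentum
        filter.learnPatterns(data);                 // runs until MSSE <= 0.01 or 1000 epochs
        System.out.println(filter.compute(1.0, 0.0)); // expect a value near 1
    }
}
```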
diff --git a/src/net/woodyfolsom/msproj/ann2/math/ActivationFunction.java b/src/net/woodyfolsom/msproj/ann/math/ActivationFunction.java
similarity index 91%
rename from src/net/woodyfolsom/msproj/ann2/math/ActivationFunction.java
rename to src/net/woodyfolsom/msproj/ann/math/ActivationFunction.java
index c82a307..c941739 100644
--- a/src/net/woodyfolsom/msproj/ann2/math/ActivationFunction.java
+++ b/src/net/woodyfolsom/msproj/ann/math/ActivationFunction.java
@@ -1,4 +1,4 @@
-package net.woodyfolsom.msproj.ann2.math;
+package net.woodyfolsom.msproj.ann.math;
import javax.xml.bind.annotation.XmlAttribute;
diff --git a/src/net/woodyfolsom/msproj/ann2/math/ErrorFunction.java b/src/net/woodyfolsom/msproj/ann/math/ErrorFunction.java
similarity index 64%
rename from src/net/woodyfolsom/msproj/ann2/math/ErrorFunction.java
rename to src/net/woodyfolsom/msproj/ann/math/ErrorFunction.java
index 789da3b..956533b 100644
--- a/src/net/woodyfolsom/msproj/ann2/math/ErrorFunction.java
+++ b/src/net/woodyfolsom/msproj/ann/math/ErrorFunction.java
@@ -1,4 +1,4 @@
-package net.woodyfolsom.msproj.ann2.math;
+package net.woodyfolsom.msproj.ann.math;
public interface ErrorFunction {
double compute(double[] ideal, double[] actual);
diff --git a/src/net/woodyfolsom/msproj/ann2/math/Linear.java b/src/net/woodyfolsom/msproj/ann/math/Linear.java
similarity index 81%
rename from src/net/woodyfolsom/msproj/ann2/math/Linear.java
rename to src/net/woodyfolsom/msproj/ann/math/Linear.java
index 3e9ab4a..c5de8e8 100644
--- a/src/net/woodyfolsom/msproj/ann2/math/Linear.java
+++ b/src/net/woodyfolsom/msproj/ann/math/Linear.java
@@ -1,4 +1,4 @@
-package net.woodyfolsom.msproj.ann2.math;
+package net.woodyfolsom.msproj.ann.math;
public class Linear extends ActivationFunction{
public static final Linear function = new Linear();
diff --git a/src/net/woodyfolsom/msproj/ann2/math/MSSE.java b/src/net/woodyfolsom/msproj/ann/math/MSSE.java
similarity index 88%
rename from src/net/woodyfolsom/msproj/ann2/math/MSSE.java
rename to src/net/woodyfolsom/msproj/ann/math/MSSE.java
index b4f85d1..960a5a9 100644
--- a/src/net/woodyfolsom/msproj/ann2/math/MSSE.java
+++ b/src/net/woodyfolsom/msproj/ann/math/MSSE.java
@@ -1,4 +1,4 @@
-package net.woodyfolsom.msproj.ann2.math;
+package net.woodyfolsom.msproj.ann.math;
public class MSSE implements ErrorFunction{
public static final ErrorFunction function = new MSSE();
diff --git a/src/net/woodyfolsom/msproj/ann2/math/Sigmoid.java b/src/net/woodyfolsom/msproj/ann/math/Sigmoid.java
similarity index 55%
rename from src/net/woodyfolsom/msproj/ann2/math/Sigmoid.java
rename to src/net/woodyfolsom/msproj/ann/math/Sigmoid.java
index 530bd0a..de1e168 100644
--- a/src/net/woodyfolsom/msproj/ann2/math/Sigmoid.java
+++ b/src/net/woodyfolsom/msproj/ann/math/Sigmoid.java
@@ -1,4 +1,4 @@
-package net.woodyfolsom.msproj.ann2.math;
+package net.woodyfolsom.msproj.ann.math;
public class Sigmoid extends ActivationFunction{
public static final Sigmoid function = new Sigmoid();
@@ -12,9 +12,9 @@ public class Sigmoid extends ActivationFunction{
}
public double derivative(double arg) {
- //lol wth?
- //double eX = Math.exp(arg);
- //return eX / (Math.pow((1+eX), 2));
- return arg - Math.pow(arg,2);
+		// this form takes the raw input x; the commented-out form below,
+		// s - s^2, expects the already-activated output s(x) instead.
+ double eX = Math.exp(arg);
+ return eX / (Math.pow((1+eX), 2));
+ //return arg - Math.pow(arg,2);
}
}
\ No newline at end of file
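A quick numeric check of the two derivative forms touched above: they agree once the commented-out form is fed s(x) rather than x (standalone snippet, not part of the commit):

```java
public class SigmoidDerivativeCheck {
    public static void main(String[] args) {
        double x = 0.3;
        double s = 1.0 / (1.0 + Math.exp(-x));                        // s(x)
        double byInput  = Math.exp(x) / Math.pow(1 + Math.exp(x), 2); // new form, takes x
        double byOutput = s - s * s;                                  // old form, takes s(x)
        System.out.println(byInput + " == " + byOutput);              // both ~0.2445
    }
}
```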
diff --git a/src/net/woodyfolsom/msproj/ann2/math/Tanh.java b/src/net/woodyfolsom/msproj/ann/math/Tanh.java
similarity index 84%
rename from src/net/woodyfolsom/msproj/ann2/math/Tanh.java
rename to src/net/woodyfolsom/msproj/ann/math/Tanh.java
index 16108f4..2ab3ec4 100644
--- a/src/net/woodyfolsom/msproj/ann2/math/Tanh.java
+++ b/src/net/woodyfolsom/msproj/ann/math/Tanh.java
@@ -1,4 +1,4 @@
-package net.woodyfolsom.msproj.ann2.math;
+package net.woodyfolsom.msproj.ann.math;
public class Tanh extends ActivationFunction{
public static final Tanh function = new Tanh();
diff --git a/src/net/woodyfolsom/msproj/ann2/AbstractNeuralNetFilter.java b/src/net/woodyfolsom/msproj/ann2/AbstractNeuralNetFilter.java
deleted file mode 100644
index d24bd06..0000000
--- a/src/net/woodyfolsom/msproj/ann2/AbstractNeuralNetFilter.java
+++ /dev/null
@@ -1,84 +0,0 @@
-package net.woodyfolsom.msproj.ann2;
-
-import java.io.InputStream;
-import java.io.OutputStream;
-import java.util.List;
-
-public abstract class AbstractNeuralNetFilter implements NeuralNetFilter {
- private final FeedforwardNetwork neuralNetwork;
- private final TrainingMethod trainingMethod;
-
- private double maxError;
- private int actualTrainingEpochs = 0;
- private int maxTrainingEpochs;
-
- AbstractNeuralNetFilter(FeedforwardNetwork neuralNetwork, TrainingMethod trainingMethod, int maxTrainingEpochs, double maxError) {
- this.neuralNetwork = neuralNetwork;
- this.trainingMethod = trainingMethod;
- this.maxError = maxError;
- this.maxTrainingEpochs = maxTrainingEpochs;
- }
-
- @Override
- public NNData compute(NNDataPair input) {
- return this.neuralNetwork.compute(input);
- }
-
- public int getActualTrainingEpochs() {
- return actualTrainingEpochs;
- }
-
- @Override
- public int getInputSize() {
- return 2;
- }
-
- public int getMaxTrainingEpochs() {
- return maxTrainingEpochs;
- }
-
- protected FeedforwardNetwork getNeuralNetwork() {
- return neuralNetwork;
- }
-
- @Override
-	public void learn(List<NNDataPair> trainingSet) {
- actualTrainingEpochs = 0;
- double error;
- neuralNetwork.initWeights();
-
- error = trainingMethod.computeError(neuralNetwork,trainingSet);
-
- if (error <= maxError) {
- System.out.println("Initial error: " + error);
- return;
- }
-
- do {
- trainingMethod.iterate(neuralNetwork,trainingSet);
- error = trainingMethod.computeError(neuralNetwork,trainingSet);
- System.out.println("Epoch #" + actualTrainingEpochs + " Error:"
- + error);
- actualTrainingEpochs++;
- System.out.println("MSSE after epoch " + actualTrainingEpochs + ": " + error);
- } while (error > maxError && actualTrainingEpochs < maxTrainingEpochs);
- }
-
- @Override
- public boolean load(InputStream input) {
- return neuralNetwork.load(input);
- }
-
- @Override
- public boolean save(OutputStream output) {
- return neuralNetwork.save(output);
- }
-
- public void setMaxError(double maxError) {
- this.maxError = maxError;
- }
-
- public void setMaxTrainingEpochs(int max) {
- this.maxTrainingEpochs = max;
- }
-}
\ No newline at end of file
diff --git a/src/net/woodyfolsom/msproj/ann2/NeuralNetFilter.java b/src/net/woodyfolsom/msproj/ann2/NeuralNetFilter.java
deleted file mode 100644
index 71a2dde..0000000
--- a/src/net/woodyfolsom/msproj/ann2/NeuralNetFilter.java
+++ /dev/null
@@ -1,25 +0,0 @@
-package net.woodyfolsom.msproj.ann2;
-
-import java.io.InputStream;
-import java.io.OutputStream;
-import java.util.List;
-
-public interface NeuralNetFilter {
- int getActualTrainingEpochs();
-
- int getInputSize();
-
- int getMaxTrainingEpochs();
-
- int getOutputSize();
-
- boolean load(InputStream input);
-
- boolean save(OutputStream output);
-
- void setMaxTrainingEpochs(int max);
-
- NNData compute(NNDataPair input);
-
-	void learn(List<NNDataPair> trainingSet);
-}
\ No newline at end of file
diff --git a/src/net/woodyfolsom/msproj/ann2/ObjectiveFunction.java b/src/net/woodyfolsom/msproj/ann2/ObjectiveFunction.java
deleted file mode 100644
index 4eac691..0000000
--- a/src/net/woodyfolsom/msproj/ann2/ObjectiveFunction.java
+++ /dev/null
@@ -1,5 +0,0 @@
-package net.woodyfolsom.msproj.ann2;
-
-public class ObjectiveFunction {
-
-}
diff --git a/src/net/woodyfolsom/msproj/ann2/TemporalDifference.java b/src/net/woodyfolsom/msproj/ann2/TemporalDifference.java
deleted file mode 100644
index b588029..0000000
--- a/src/net/woodyfolsom/msproj/ann2/TemporalDifference.java
+++ /dev/null
@@ -1,19 +0,0 @@
-package net.woodyfolsom.msproj.ann2;
-
-import java.util.List;
-
-public class TemporalDifference implements TrainingMethod {
-
- @Override
- public void iterate(FeedforwardNetwork neuralNetwork,
-			List<NNDataPair> trainingSet) {
- throw new UnsupportedOperationException("Not implemented");
- }
-
- @Override
- public double computeError(FeedforwardNetwork neuralNetwork,
-			List<NNDataPair> trainingSet) {
- throw new UnsupportedOperationException("Not implemented");
- }
-
-}
diff --git a/src/net/woodyfolsom/msproj/ann2/TrainingMethod.java b/src/net/woodyfolsom/msproj/ann2/TrainingMethod.java
deleted file mode 100644
index b109034..0000000
--- a/src/net/woodyfolsom/msproj/ann2/TrainingMethod.java
+++ /dev/null
@@ -1,10 +0,0 @@
-package net.woodyfolsom.msproj.ann2;
-
-import java.util.List;
-
-public interface TrainingMethod {
-
-	void iterate(FeedforwardNetwork neuralNetwork, List<NNDataPair> trainingSet);
-	double computeError(FeedforwardNetwork neuralNetwork, List<NNDataPair> trainingSet);
-
-}
diff --git a/src/net/woodyfolsom/msproj/ann2/XORFilter.java b/src/net/woodyfolsom/msproj/ann2/XORFilter.java
deleted file mode 100644
index 7ffd6c4..0000000
--- a/src/net/woodyfolsom/msproj/ann2/XORFilter.java
+++ /dev/null
@@ -1,44 +0,0 @@
-package net.woodyfolsom.msproj.ann2;
-
-/**
- * Based on sample code from http://neuroph.sourceforge.net
- *
- * @author Woody
- *
- */
-public class XORFilter extends AbstractNeuralNetFilter implements
- NeuralNetFilter {
-
- private static final int INPUT_SIZE = 2;
- private static final int OUTPUT_SIZE = 1;
-
- public XORFilter() {
- this(0.8,0.7);
- }
-
- public XORFilter(double learningRate, double momentum) {
- super( new MultiLayerPerceptron(true, INPUT_SIZE, 2, OUTPUT_SIZE),
- new BackPropagation(learningRate, momentum), 1000, 0.01);
- super.getNeuralNetwork().setName("XORFilter");
-
- //TODO remove
- //getNeuralNetwork().setWeights(new double[] {
- // 0.341232, 0.129952, -0.923123, //hidden neuron 1 from input0, input1, bias
- // -0.115223, 0.570345, -0.328932, //hidden neuron 2 from input0, input1, bias
- // -0.993423, 0.164732, 0.752621}); //output
- }
-
- public double compute(double x, double y) {
- return getNeuralNetwork().compute(new double[]{x,y})[0];
- }
-
- @Override
- public int getInputSize() {
- return INPUT_SIZE;
- }
-
- @Override
- public int getOutputSize() {
- return OUTPUT_SIZE;
- }
-}
\ No newline at end of file
diff --git a/src/net/woodyfolsom/msproj/tictactoe/Action.java b/src/net/woodyfolsom/msproj/tictactoe/Action.java
new file mode 100644
index 0000000..8ad5ae6
--- /dev/null
+++ b/src/net/woodyfolsom/msproj/tictactoe/Action.java
@@ -0,0 +1,54 @@
+package net.woodyfolsom.msproj.tictactoe;
+
+import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;
+
+public class Action {
+ public static final Action NONE = new Action(PLAYER.NONE, -1, -1);
+
+ private Game.PLAYER player;
+ private int row;
+ private int column;
+
+ public static Action getInstance(PLAYER player, int row, int column) {
+ return new Action(player,row,column);
+ }
+
+ private Action(PLAYER player, int row, int column) {
+ this.player = player;
+ this.row = row;
+ this.column = column;
+ }
+
+ public Game.PLAYER getPlayer() {
+ return player;
+ }
+
+ public int getColumn() {
+ return column;
+ }
+
+ public int getRow() {
+ return row;
+ }
+
+ public boolean isNone() {
+ return this == Action.NONE;
+ }
+
+ public void setPlayer(Game.PLAYER player) {
+ this.player = player;
+ }
+
+ public void setRow(int row) {
+ this.row = row;
+ }
+
+ public void setColumn(int column) {
+ this.column = column;
+ }
+
+ @Override
+ public String toString() {
+ return player + "(" + row + ", " + column + ")";
+ }
+}
\ No newline at end of file
diff --git a/src/net/woodyfolsom/msproj/tictactoe/Game.java b/src/net/woodyfolsom/msproj/tictactoe/Game.java
new file mode 100644
index 0000000..2635d57
--- /dev/null
+++ b/src/net/woodyfolsom/msproj/tictactoe/Game.java
@@ -0,0 +1,5 @@
+package net.woodyfolsom.msproj.tictactoe;
+
+public class Game {
+ public enum PLAYER {X,O,NONE}
+}
diff --git a/src/net/woodyfolsom/msproj/tictactoe/GameRecord.java b/src/net/woodyfolsom/msproj/tictactoe/GameRecord.java
new file mode 100644
index 0000000..23ffdc6
--- /dev/null
+++ b/src/net/woodyfolsom/msproj/tictactoe/GameRecord.java
@@ -0,0 +1,63 @@
+package net.woodyfolsom.msproj.tictactoe;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;
+
+public class GameRecord {
+ public enum RESULT {X_WINS, O_WINS, TIE_GAME, IN_PROGRESS}
+
+	private List<Action> actions = new ArrayList<Action>();
+	private List<State> states = new ArrayList<State>();
+
+ private RESULT result = RESULT.IN_PROGRESS;
+
+ public GameRecord() {
+ actions.add(Action.NONE);
+ states.add(new State());
+ }
+
+ public void addState(State state) {
+ states.add(state);
+ }
+
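+	// Applies an action to the most recent state. A valid move appends the new
+	// state and the action to the record; an invalid move leaves the record
+	// unchanged and returns State.INVALID. Reaching a terminal state also sets
+	// the game result. Typical use:
+	//   GameRecord record = new GameRecord();
+	//   record.apply(Action.getInstance(PLAYER.X, 1, 1)); // X takes the center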
+ public State apply(Action action) {
+ State nextState = getState().apply(action);
+ if (nextState.isValid()) {
+ states.add(nextState);
+ actions.add(action);
+ }
+
+ if (nextState.isTerminal()) {
+ if (nextState.isWinner(PLAYER.X)) {
+ result = RESULT.X_WINS;
+ } else if (nextState.isWinner(PLAYER.O)) {
+ result = RESULT.O_WINS;
+ } else {
+ result = RESULT.TIE_GAME;
+ }
+ }
+ return nextState;
+ }
+
+ public int getNumStates() {
+ return states.size();
+ }
+
+ public RESULT getResult() {
+ return result;
+ }
+
+ public void setResult(RESULT result) {
+ this.result = result;
+ }
+
+ public State getState() {
+ return states.get(states.size()-1);
+ }
+
+ public State getState(int index) {
+ return states.get(index);
+ }
+}
diff --git a/src/net/woodyfolsom/msproj/tictactoe/MoveGenerator.java b/src/net/woodyfolsom/msproj/tictactoe/MoveGenerator.java
new file mode 100644
index 0000000..888963e
--- /dev/null
+++ b/src/net/woodyfolsom/msproj/tictactoe/MoveGenerator.java
@@ -0,0 +1,20 @@
+package net.woodyfolsom.msproj.tictactoe;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;
+
+public class MoveGenerator {
+	public List<Action> getValidActions(State state) {
+		PLAYER playerToMove = state.getPlayerToMove();
+		List<Action> validActions = new ArrayList<Action>();
+ for (int i = 0; i < 3; i++) {
+ for (int j = 0; j < 3; j++) {
+ if (state.isEmpty(i,j))
+ validActions.add(Action.getInstance(playerToMove, i, j));
+ }
+ }
+ return validActions;
+ }
+}
diff --git a/src/net/woodyfolsom/msproj/tictactoe/NNDataSetFactory.java b/src/net/woodyfolsom/msproj/tictactoe/NNDataSetFactory.java
new file mode 100644
index 0000000..efedfba
--- /dev/null
+++ b/src/net/woodyfolsom/msproj/tictactoe/NNDataSetFactory.java
@@ -0,0 +1,81 @@
+package net.woodyfolsom.msproj.tictactoe;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import net.woodyfolsom.msproj.ann.NNData;
+import net.woodyfolsom.msproj.ann.NNDataPair;
+import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;
+
+public class NNDataSetFactory {
+ public static final String[] TTT_INPUT_FIELDS = {"00","01","02","10","11","12","20","21","22"};
+ public static final String[] TTT_OUTPUT_FIELDS = {"value"};
+
+	public static List<List<NNDataPair>> createDataSet(List<GameRecord> tttGames) {
+
+		List<List<NNDataPair>> nnDataSet = new ArrayList<List<NNDataPair>>();
+
+ for (GameRecord tttGame : tttGames) {
+			List<NNDataPair> gameData = createDataPairList(tttGame);
+
+ nnDataSet.add(gameData);
+ }
+
+ return nnDataSet;
+ }
+
+	public static List<NNDataPair> createDataPairList(GameRecord gameRecord) {
+		List<NNDataPair> gameData = new ArrayList<NNDataPair>();
+
+ for (int i = 0; i < gameRecord.getNumStates(); i++) {
+ gameData.add(createDataPair(gameRecord.getState(i)));
+ }
+
+ return gameData;
+ }
+
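+	// Maps a board state to a training pair. Targets follow X's perspective:
+	// 1.0 for an X win, 0.0 for an O win, 0.5 for a tie, and 0.0 for
+	// non-terminal states. Inputs encode each square as X=1.0, O=-1.0,
+	// empty=0.0 in row-major order, e.g. the board
+	//   X.O
+	//   .X.
+	//   O..
+	// becomes {1,0,-1, 0,1,0, -1,0,0}.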
+ public static NNDataPair createDataPair(State tttState) {
+ double value;
+ if (tttState.isTerminal()) {
+			if (tttState.isWinner(PLAYER.X)) {
+				value = 1.0; // win for X
+			} else if (tttState.isWinner(PLAYER.O)) {
+				value = 0.0; // loss for X
+			} else {
+				value = 0.5; // tie
+			}
+		} else {
+			value = 0.0; // non-terminal states carry no reward signal here
+ }
+
+ double[] inputValues = new double[9];
+ char[] boardCopy = tttState.getBoard();
+ inputValues[0] = getTicTacToeInput(boardCopy, 0, 0);
+ inputValues[1] = getTicTacToeInput(boardCopy, 0, 1);
+ inputValues[2] = getTicTacToeInput(boardCopy, 0, 2);
+ inputValues[3] = getTicTacToeInput(boardCopy, 1, 0);
+ inputValues[4] = getTicTacToeInput(boardCopy, 1, 1);
+ inputValues[5] = getTicTacToeInput(boardCopy, 1, 2);
+ inputValues[6] = getTicTacToeInput(boardCopy, 2, 0);
+ inputValues[7] = getTicTacToeInput(boardCopy, 2, 1);
+ inputValues[8] = getTicTacToeInput(boardCopy, 2, 2);
+
+ return new NNDataPair(new NNData(TTT_INPUT_FIELDS,inputValues),new NNData(TTT_OUTPUT_FIELDS,new double[]{value}));
+ }
+
+ private static double getTicTacToeInput(char[] board, int row, int column) {
+ switch (board[row*3+column]) {
+ case 'X' :
+ return 1.0;
+ case 'O' :
+ return -1.0;
+ case '.' :
+ return 0.0;
+ default:
+ throw new RuntimeException("Invalid board symbol at " + row +", " + column);
+ }
+ }
+}
diff --git a/src/net/woodyfolsom/msproj/tictactoe/NeuralNetPolicy.java b/src/net/woodyfolsom/msproj/tictactoe/NeuralNetPolicy.java
new file mode 100644
index 0000000..4bb0fb8
--- /dev/null
+++ b/src/net/woodyfolsom/msproj/tictactoe/NeuralNetPolicy.java
@@ -0,0 +1,67 @@
+package net.woodyfolsom.msproj.tictactoe;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import net.woodyfolsom.msproj.ann.FeedforwardNetwork;
+import net.woodyfolsom.msproj.ann.NNDataPair;
+import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;
+
+public class NeuralNetPolicy extends Policy {
+ private FeedforwardNetwork neuralNet;
+ private MoveGenerator moveGenerator = new MoveGenerator();
+
+ public NeuralNetPolicy(FeedforwardNetwork neuralNet) {
+ super("NeuralNet-" + neuralNet.getName());
+ this.neuralNet = neuralNet;
+ }
+
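+	// One-ply greedy lookahead: score every legal afterstate with the value
+	// network, then let X pick the maximum estimate and O the minimum.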
+ @Override
+ public Action getAction(State state) {
+		List<Action> validMoves = moveGenerator.getValidActions(state);
+		Map<Action, Double> scores = new HashMap<Action, Double>();
+
+ for (Action action : validMoves) {
+ State nextState = state.apply(action);
+			// Score the afterstate produced by this action, not the current state.
+			NNDataPair dataPair = NNDataSetFactory.createDataPair(nextState);
+			// estimated reward for X
+			scores.put(action, neuralNet.compute(dataPair).getValues()[0]);
+ }
+
+ PLAYER playerToMove = state.getPlayerToMove();
+
+ if (playerToMove == PLAYER.X) {
+ return returnMaxAction(scores);
+ } else if (playerToMove == PLAYER.O) {
+ return returnMinAction(scores);
+ } else {
+ throw new IllegalArgumentException("Invalid playerToMove: " + playerToMove);
+ }
+ }
+
+	private Action returnMaxAction(Map<Action, Double> scores) {
+		Action bestAction = null;
+		double bestScore = Double.NEGATIVE_INFINITY;
+		for (Map.Entry<Action, Double> entry : scores.entrySet()) {
+ if (entry.getValue() > bestScore) {
+ bestScore = entry.getValue();
+ bestAction = entry.getKey();
+ }
+ }
+ return bestAction;
+ }
+
+	private Action returnMinAction(Map<Action, Double> scores) {
+		Action bestAction = null;
+		double bestScore = Double.POSITIVE_INFINITY;
+		for (Map.Entry<Action, Double> entry : scores.entrySet()) {
+ if (entry.getValue() < bestScore) {
+ bestScore = entry.getValue();
+ bestAction = entry.getKey();
+ }
+ }
+ return bestAction;
+ }
+}
diff --git a/src/net/woodyfolsom/msproj/tictactoe/Policy.java b/src/net/woodyfolsom/msproj/tictactoe/Policy.java
new file mode 100644
index 0000000..c81ecb4
--- /dev/null
+++ b/src/net/woodyfolsom/msproj/tictactoe/Policy.java
@@ -0,0 +1,15 @@
+package net.woodyfolsom.msproj.tictactoe;
+
+public abstract class Policy {
+ private String name;
+
+ protected Policy(String name) {
+ this.name = name;
+ }
+
+ public abstract Action getAction(State state);
+
+ public String getName() {
+ return name;
+ }
+}
diff --git a/src/net/woodyfolsom/msproj/tictactoe/RandomPolicy.java b/src/net/woodyfolsom/msproj/tictactoe/RandomPolicy.java
new file mode 100644
index 0000000..3b7bb22
--- /dev/null
+++ b/src/net/woodyfolsom/msproj/tictactoe/RandomPolicy.java
@@ -0,0 +1,18 @@
+package net.woodyfolsom.msproj.tictactoe;
+
+import java.util.List;
+
+public class RandomPolicy extends Policy {
+ private MoveGenerator moveGenerator = new MoveGenerator();
+
+ public RandomPolicy() {
+ super("Random");
+ }
+
+ @Override
+ public Action getAction(State state) {
+		List<Action> validMoves = moveGenerator.getValidActions(state);
+ return validMoves.get((int)(Math.random() * validMoves.size()));
+ }
+
+}
diff --git a/src/net/woodyfolsom/msproj/tictactoe/Referee.java b/src/net/woodyfolsom/msproj/tictactoe/Referee.java
new file mode 100644
index 0000000..ed1cf4e
--- /dev/null
+++ b/src/net/woodyfolsom/msproj/tictactoe/Referee.java
@@ -0,0 +1,43 @@
+package net.woodyfolsom.msproj.tictactoe;
+
+import java.util.ArrayList;
+import java.util.List;
+
+public class Referee {
+
+ public static void main(String[] args) {
+ new Referee().play(50);
+ }
+
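+	// Plays nGames of random self-play and returns the completed records,
+	// which serve as raw material for the neural-net training set.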
+	public List<GameRecord> play(int nGames) {
+		Policy policy = new RandomPolicy();
+
+		List<GameRecord> tournament = new ArrayList<GameRecord>();
+
+ for (int i = 0; i < nGames; i++) {
+ GameRecord gameRecord = new GameRecord();
+
+ System.out.println("Playing game #" +(i+1));
+
+ State state;
+ do {
+ Action action = policy.getAction(gameRecord.getState());
+ System.out.println("Action " + action + " selected by policy " + policy.getName());
+ state = gameRecord.apply(action);
+ System.out.println("Next board state:");
+ System.out.println(gameRecord.getState());
+ } while (!state.isTerminal());
+ System.out.println("Game #" + (i+1) + " is finished. Result: " + gameRecord.getResult());
+ tournament.add(gameRecord);
+ }
+
+ System.out.println("Played " + tournament.size() + " random games.");
+ System.out.println("Results:");
+ for (int i = 0; i < tournament.size(); i++) {
+ GameRecord gameRecord = tournament.get(i);
+ System.out.println((i+1) + ". " + gameRecord.getResult());
+ }
+
+ return tournament;
+ }
+}
diff --git a/src/net/woodyfolsom/msproj/tictactoe/State.java b/src/net/woodyfolsom/msproj/tictactoe/State.java
new file mode 100644
index 0000000..7d3ce6b
--- /dev/null
+++ b/src/net/woodyfolsom/msproj/tictactoe/State.java
@@ -0,0 +1,116 @@
+package net.woodyfolsom.msproj.tictactoe;
+
+import java.util.Arrays;
+
+import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;
+
+public class State {
+ public static final State INVALID = new State();
+	public static final char EMPTY_SQUARE = '.';
+
+ private char[] board;
+ private PLAYER playerToMove;
+
+ public State() {
+ playerToMove = Game.PLAYER.X;
+ board = new char[9];
+ Arrays.fill(board,'.');
+ }
+
+ private State(State that) {
+ this.board = Arrays.copyOf(that.board, that.board.length);
+ this.playerToMove = that.playerToMove;
+ }
+
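+	// Returns the successor state for a move, leaving this state untouched:
+	// the board is copied, the move is validated (correct turn, empty square),
+	// the mark is placed, and the turn passes to the other player.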
+ public State apply(Action action) {
+ if (action.getPlayer() != playerToMove) {
+ System.out.println("It is not " + action.getPlayer() +"'s turn.");
+ return State.INVALID;
+ }
+ State nextState = new State(this);
+
+ int row = action.getRow();
+ int column = action.getColumn();
+ int dest = row * 3 + column;
+
+ if (board[dest] != EMPTY_SQUARE) {
+ System.out.println("Invalid move " + action + ", coordinate not empty.");
+ return State.INVALID;
+ }
+ switch (playerToMove) {
+ case X : nextState.board[dest] = 'X';
+ break;
+ case O : nextState.board[dest] = 'O';
+ break;
+ default:
+ throw new RuntimeException("Invalid playerToMove");
+ }
+
+ if (playerToMove == PLAYER.X) {
+ nextState.playerToMove = PLAYER.O;
+ } else {
+ nextState.playerToMove = PLAYER.X;
+ }
+ return nextState;
+ }
+
+ public char[] getBoard() {
+ return Arrays.copyOf(board, board.length);
+ }
+
+ public PLAYER getPlayerToMove() {
+ return playerToMove;
+ }
+
+ public boolean isEmpty(int row, int column) {
+ return board[row*3+column] == EMPTY_SQUARE;
+ }
+
+ public boolean isFull(char mark1, char mark2, char mark3) {
+		return mark1 != EMPTY_SQUARE && mark2 != EMPTY_SQUARE && mark3 != EMPTY_SQUARE;
+ }
+
+ public boolean isWinner(PLAYER player) {
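+		// Checks the eight winning lines of the 3x3 board: three rows,
+		// three columns, and both diagonals.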
+ return isWin(player,board[0],board[1],board[2]) ||
+ isWin(player,board[3],board[4],board[5]) ||
+ isWin(player,board[6],board[7],board[8]) ||
+ isWin(player,board[0],board[3],board[6]) ||
+ isWin(player,board[1],board[4],board[7]) ||
+ isWin(player,board[2],board[5],board[8]) ||
+ isWin(player,board[0],board[4],board[8]) ||
+ isWin(player,board[2],board[4],board[6]);
+ }
+
+ public boolean isWin(PLAYER player, char mark1, char mark2, char mark3) {
+ if (isFull(mark1,mark2,mark3)) {
+ switch (player) {
+ case X : return mark1 == 'X' && mark2 == 'X' && mark3 == 'X';
+ case O : return mark1 == 'O' && mark2 == 'O' && mark3 == 'O';
+ default :
+ return false;
+ }
+ } else {
+ return false;
+ }
+ }
+
+ public boolean isTerminal() {
+ return isWinner(PLAYER.X) || isWinner(PLAYER.O) ||
+ (isFull(board[0],board[1], board[2]) &&
+ isFull(board[3],board[4], board[5]) &&
+ isFull(board[6],board[7], board[8]));
+ }
+
+ public boolean isValid() {
+ return this != INVALID;
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder sb = new StringBuilder("TicTacToe state ("+playerToMove + " to move):\n");
+ sb.append(""+board[0] + board[1] + board[2] + "\n");
+ sb.append(""+board[3] + board[4] + board[5] + "\n");
+ sb.append(""+board[6] + board[7] + board[8] + "\n");
+ return sb.toString();
+ }
+}
diff --git a/test/net/woodyfolsom/msproj/ann2/MultiLayerPerceptronTest.java b/test/net/woodyfolsom/msproj/ann/MultiLayerPerceptronTest.java
similarity index 89%
rename from test/net/woodyfolsom/msproj/ann2/MultiLayerPerceptronTest.java
rename to test/net/woodyfolsom/msproj/ann/MultiLayerPerceptronTest.java
index 7dc8633..f6d840b 100644
--- a/test/net/woodyfolsom/msproj/ann2/MultiLayerPerceptronTest.java
+++ b/test/net/woodyfolsom/msproj/ann/MultiLayerPerceptronTest.java
@@ -1,4 +1,4 @@
-package net.woodyfolsom.msproj.ann2;
+package net.woodyfolsom.msproj.ann;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
@@ -10,6 +10,12 @@ import java.io.IOException;
import javax.xml.bind.JAXBException;
+import net.woodyfolsom.msproj.ann.Connection;
+import net.woodyfolsom.msproj.ann.FeedforwardNetwork;
+import net.woodyfolsom.msproj.ann.MultiLayerPerceptron;
+import net.woodyfolsom.msproj.ann.NNData;
+import net.woodyfolsom.msproj.ann.NNDataPair;
+
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
diff --git a/test/net/woodyfolsom/msproj/ann/TTTFilterTest.java b/test/net/woodyfolsom/msproj/ann/TTTFilterTest.java
new file mode 100644
index 0000000..60af997
--- /dev/null
+++ b/test/net/woodyfolsom/msproj/ann/TTTFilterTest.java
@@ -0,0 +1,100 @@
+package net.woodyfolsom.msproj.ann;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.List;
+
+import net.woodyfolsom.msproj.ann.NNData;
+import net.woodyfolsom.msproj.ann.NNDataPair;
+import net.woodyfolsom.msproj.ann.NeuralNetFilter;
+import net.woodyfolsom.msproj.ann.TTTFilter;
+import net.woodyfolsom.msproj.tictactoe.GameRecord;
+import net.woodyfolsom.msproj.tictactoe.NNDataSetFactory;
+import net.woodyfolsom.msproj.tictactoe.Referee;
+
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class TTTFilterTest {
+ private static final String FILENAME = "tttPerceptron.net";
+
+ @AfterClass
+ public static void deleteNewNet() {
+ File file = new File(FILENAME);
+ if (file.exists()) {
+ file.delete();
+ }
+ }
+
+ @BeforeClass
+ public static void deleteSavedNet() {
+ File file = new File(FILENAME);
+ if (file.exists()) {
+ file.delete();
+ }
+ }
+
+ @Test
+ public void testLearn() throws IOException {
+ double alpha = 0.5;
+ double lambda = 0.0;
+ int maxEpochs = 1000;
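+		// Assuming TTTFilter follows the usual TD(lambda) convention: alpha is
+		// the learning rate and lambda the eligibility-trace decay, so
+		// lambda = 0.0 reduces to one-step TD(0).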
+
+ NeuralNetFilter nnLearner = new TTTFilter(alpha, lambda, maxEpochs);
+
+ // Create trainingSet from a tournament of random games.
+ // Future iterations will use Epsilon-greedy play from a policy based on
+ // this network to generate additional datasets.
+		List<GameRecord> tournament = new Referee().play(1);
+		List<List<NNDataPair>> trainingSet = NNDataSetFactory
+				.createDataSet(tournament);
+
+ System.out.println("Generated " + trainingSet.size()
+ + " datasets from random self-play.");
+ nnLearner.learnSequences(trainingSet);
+ System.out.println("Learned network after "
+ + nnLearner.getActualTrainingEpochs() + " training epochs.");
+
+ double[][] validationSet = new double[7][];
+
+ // empty board
+ validationSet[0] = new double[] { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+ 0.0, 0.0 };
+ // center
+ validationSet[1] = new double[] { 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,
+ 0.0, 0.0 };
+ // top edge
+ validationSet[2] = new double[] { 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+ 0.0, 0.0 };
+ // left edge
+ validationSet[3] = new double[] { 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0,
+ 0.0, 0.0 };
+ // corner
+ validationSet[4] = new double[] { 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+ 0.0, 0.0 };
+ // win
+ validationSet[5] = new double[] { 1.0, 1.0, 1.0, -1.0, -1.0, 0.0, 0.0,
+ -1.0, 0.0 };
+ // loss
+ validationSet[6] = new double[] { -1.0, 1.0, 0.0, 1.0, -1.0, 1.0, 0.0,
+ 0.0, -1.0 };
+
+ String[] inputNames = new String[] { "00", "01", "02", "10", "11",
+ "12", "20", "21", "22" };
+		String[] outputNames = new String[] { "value" };
+
+ System.out.println("Output from eval set (learned network):");
+ testNetwork(nnLearner, validationSet, inputNames, outputNames);
+ }
+
+ private void testNetwork(NeuralNetFilter nnLearner,
+ double[][] validationSet, String[] inputNames, String[] outputNames) {
+ for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
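+			// The target half of this pair is a dummy copy of the inputs;
+			// presumably only the input half is read by compute().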
+ NNDataPair dp = new NNDataPair(new NNData(inputNames,
+ validationSet[valIndex]), new NNData(outputNames,
+ validationSet[valIndex]));
+ System.out.println(dp + " => " + nnLearner.compute(dp));
+ }
+ }
+}
\ No newline at end of file
diff --git a/test/net/woodyfolsom/msproj/ann/WinFilterTest.java b/test/net/woodyfolsom/msproj/ann/WinFilterTest.java
deleted file mode 100644
index bcb0e18..0000000
--- a/test/net/woodyfolsom/msproj/ann/WinFilterTest.java
+++ /dev/null
@@ -1,64 +0,0 @@
-package net.woodyfolsom.msproj.ann;
-
-import java.io.File;
-import java.io.FileFilter;
-import java.io.FileInputStream;
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Set;
-
-import net.woodyfolsom.msproj.GameRecord;
-import net.woodyfolsom.msproj.Referee;
-
-import org.antlr.runtime.RecognitionException;
-import org.encog.ml.data.MLData;
-import org.encog.ml.data.MLDataPair;
-import org.junit.Test;
-
-public class WinFilterTest {
-
- @Test
- public void testLearnSaveLoad() throws IOException, RecognitionException {
- File[] sgfFiles = new File("data/games/random_vs_random")
- .listFiles(new FileFilter() {
- @Override
- public boolean accept(File pathname) {
- return pathname.getName().endsWith(".sgf");
- }
- });
-
- Set> trainingData = new HashSet>();
-
- for (File file : sgfFiles) {
- FileInputStream fis = new FileInputStream(file);
- GameRecord gameRecord = Referee.replay(fis);
-
- List gameData = new ArrayList();
- for (int i = 0; i <= gameRecord.getNumTurns(); i++) {
- gameData.add(new GameStateMLDataPair(gameRecord.getGameState(i)));
- }
-
- trainingData.add(gameData);
-
- fis.close();
- }
-
- WinFilter winFilter = new WinFilter();
-
- winFilter.learn(trainingData);
-
- for (List trainingSequence : trainingData) {
- for (int stateIndex = 0; stateIndex < trainingSequence.size(); stateIndex++) {
- if (stateIndex > 0 && stateIndex < trainingSequence.size()-1) {
- continue;
- }
- MLData input = trainingSequence.get(stateIndex).getInput();
-
- System.out.println("Turn " + stateIndex + ": " + input + " => "
- + winFilter.compute(input));
- }
- }
- }
-}
diff --git a/test/net/woodyfolsom/msproj/ann/XORFilterTest.java b/test/net/woodyfolsom/msproj/ann/XORFilterTest.java
index 8a39c15..0ac82d8 100644
--- a/test/net/woodyfolsom/msproj/ann/XORFilterTest.java
+++ b/test/net/woodyfolsom/msproj/ann/XORFilterTest.java
@@ -1,10 +1,19 @@
package net.woodyfolsom.msproj.ann;
-import java.io.File;
-import java.io.IOException;
+import static org.junit.Assert.assertTrue;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+import net.woodyfolsom.msproj.ann.NNData;
+import net.woodyfolsom.msproj.ann.NNDataPair;
+import net.woodyfolsom.msproj.ann.NeuralNetFilter;
+import net.woodyfolsom.msproj.ann.XORFilter;
-import org.encog.ml.data.MLDataSet;
-import org.encog.ml.data.basic.BasicMLDataSet;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
@@ -29,14 +38,55 @@ public class XORFilterTest {
}
@Test
- public void testLearnSaveLoad() throws IOException {
- NeuralNetFilter nnLearner = new XORFilter();
- System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");
+ public void testLearn() throws IOException {
+ NeuralNetFilter nnLearner = new XORFilter(0.5,0.0);
// create training set (logical XOR function)
int size = 1;
double[][] trainingInput = new double[4 * size][];
double[][] trainingOutput = new double[4 * size][];
+ for (int i = 0; i < size; i++) {
+ trainingInput[i * 4 + 0] = new double[] { 0, 0 };
+ trainingInput[i * 4 + 1] = new double[] { 0, 1 };
+ trainingInput[i * 4 + 2] = new double[] { 1, 0 };
+ trainingInput[i * 4 + 3] = new double[] { 1, 1 };
+ trainingOutput[i * 4 + 0] = new double[] { 0 };
+ trainingOutput[i * 4 + 1] = new double[] { 1 };
+ trainingOutput[i * 4 + 2] = new double[] { 1 };
+ trainingOutput[i * 4 + 3] = new double[] { 0 };
+ }
+
+ // create training data
+		List<NNDataPair> trainingSet = new ArrayList<NNDataPair>();
+ String[] inputNames = new String[] {"x","y"};
+ String[] outputNames = new String[] {"XOR"};
+ for (int i = 0; i < 4*size; i++) {
+ trainingSet.add(new NNDataPair(new NNData(inputNames,trainingInput[i]),new NNData(outputNames,trainingOutput[i])));
+ }
+
+ nnLearner.setMaxTrainingEpochs(20000);
+ nnLearner.learnPatterns(trainingSet);
+ System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");
+
+ double[][] validationSet = new double[4][2];
+
+ validationSet[0] = new double[] { 0, 0 };
+ validationSet[1] = new double[] { 0, 1 };
+ validationSet[2] = new double[] { 1, 0 };
+ validationSet[3] = new double[] { 1, 1 };
+
+ System.out.println("Output from eval set (learned network):");
+ testNetwork(nnLearner, validationSet, inputNames, outputNames);
+ }
+
+ @Test
+ public void testLearnSaveLoad() throws IOException {
+ NeuralNetFilter nnLearner = new XORFilter(0.5,0.0);
+
+ // create training set (logical XOR function)
+ int size = 2;
+ double[][] trainingInput = new double[4 * size][];
+ double[][] trainingOutput = new double[4 * size][];
for (int i = 0; i < size; i++) {
trainingInput[i * 4 + 0] = new double[] { 0, 0 };
trainingInput[i * 4 + 1] = new double[] { 0, 1 };
@@ -49,10 +99,17 @@ public class XORFilterTest {
}
// create training data
- MLDataSet trainingSet = new BasicMLDataSet(trainingInput, trainingOutput);
+		List<NNDataPair> trainingSet = new ArrayList<NNDataPair>();
+ String[] inputNames = new String[] {"x","y"};
+ String[] outputNames = new String[] {"XOR"};
+ for (int i = 0; i < 4*size; i++) {
+ trainingSet.add(new NNDataPair(new NNData(inputNames,trainingInput[i]),new NNData(outputNames,trainingOutput[i])));
+ }
+
+ nnLearner.setMaxTrainingEpochs(1);
+ nnLearner.learnPatterns(trainingSet);
+ System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");
- nnLearner.learn(trainingSet);
-
double[][] validationSet = new double[4][2];
validationSet[0] = new double[] { 0, 0 };
@@ -61,18 +118,23 @@ public class XORFilterTest {
validationSet[3] = new double[] { 1, 1 };
System.out.println("Output from eval set (learned network, pre-serialization):");
- testNetwork(nnLearner, validationSet);
-
- nnLearner.save(FILENAME);
- nnLearner.load(FILENAME);
+ testNetwork(nnLearner, validationSet, inputNames, outputNames);
+ FileOutputStream fos = new FileOutputStream(FILENAME);
+ assertTrue(nnLearner.save(fos));
+ fos.close();
+
+ FileInputStream fis = new FileInputStream(FILENAME);
+ assertTrue(nnLearner.load(fis));
+ fis.close();
+
System.out.println("Output from eval set (learned network, post-serialization):");
- testNetwork(nnLearner, validationSet);
+ testNetwork(nnLearner, validationSet, inputNames, outputNames);
}
- private void testNetwork(NeuralNetFilter nnLearner, double[][] validationSet) {
+ private void testNetwork(NeuralNetFilter nnLearner, double[][] validationSet, String[] inputNames, String[] outputNames) {
for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
- DoublePair dp = new DoublePair(validationSet[valIndex][0],validationSet[valIndex][1]);
+ NNDataPair dp = new NNDataPair(new NNData(inputNames,validationSet[valIndex]), new NNData(outputNames,validationSet[valIndex]));
System.out.println(dp + " => " + nnLearner.compute(dp));
}
}
diff --git a/test/net/woodyfolsom/msproj/ann2/SigmoidTest.java b/test/net/woodyfolsom/msproj/ann/math/SigmoidTest.java
similarity index 71%
rename from test/net/woodyfolsom/msproj/ann2/SigmoidTest.java
rename to test/net/woodyfolsom/msproj/ann/math/SigmoidTest.java
index 6cbaffe..8a4722b 100644
--- a/test/net/woodyfolsom/msproj/ann2/SigmoidTest.java
+++ b/test/net/woodyfolsom/msproj/ann/math/SigmoidTest.java
@@ -1,11 +1,11 @@
-package net.woodyfolsom.msproj.ann2;
+package net.woodyfolsom.msproj.ann.math;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
-import net.woodyfolsom.msproj.ann2.math.ActivationFunction;
-import net.woodyfolsom.msproj.ann2.math.Sigmoid;
-import net.woodyfolsom.msproj.ann2.math.Tanh;
+import net.woodyfolsom.msproj.ann.math.ActivationFunction;
+import net.woodyfolsom.msproj.ann.math.Sigmoid;
+import net.woodyfolsom.msproj.ann.math.Tanh;
import org.junit.Test;
diff --git a/test/net/woodyfolsom/msproj/ann2/TanhTest.java b/test/net/woodyfolsom/msproj/ann/math/TanhTest.java
similarity index 75%
rename from test/net/woodyfolsom/msproj/ann2/TanhTest.java
rename to test/net/woodyfolsom/msproj/ann/math/TanhTest.java
index 8429fa1..942ea85 100644
--- a/test/net/woodyfolsom/msproj/ann2/TanhTest.java
+++ b/test/net/woodyfolsom/msproj/ann/math/TanhTest.java
@@ -1,10 +1,10 @@
-package net.woodyfolsom.msproj.ann2;
+package net.woodyfolsom.msproj.ann.math;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
-import net.woodyfolsom.msproj.ann2.math.ActivationFunction;
-import net.woodyfolsom.msproj.ann2.math.Tanh;
+import net.woodyfolsom.msproj.ann.math.ActivationFunction;
+import net.woodyfolsom.msproj.ann.math.Tanh;
import org.junit.Test;
diff --git a/test/net/woodyfolsom/msproj/ann2/XORFilterTest.java b/test/net/woodyfolsom/msproj/ann2/XORFilterTest.java
deleted file mode 100644
index b9d8887..0000000
--- a/test/net/woodyfolsom/msproj/ann2/XORFilterTest.java
+++ /dev/null
@@ -1,136 +0,0 @@
-package net.woodyfolsom.msproj.ann2;
-
-import static org.junit.Assert.assertTrue;
-
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.FileOutputStream;
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
-
-import org.junit.AfterClass;
-import org.junit.BeforeClass;
-import org.junit.Test;
-
-public class XORFilterTest {
- private static final String FILENAME = "xorPerceptron.net";
-
- @AfterClass
- public static void deleteNewNet() {
- File file = new File(FILENAME);
- if (file.exists()) {
- file.delete();
- }
- }
-
- @BeforeClass
- public static void deleteSavedNet() {
- File file = new File(FILENAME);
- if (file.exists()) {
- file.delete();
- }
- }
-
- @Test
- public void testLearn() throws IOException {
- NeuralNetFilter nnLearner = new XORFilter(0.05,0.0);
-
- // create training set (logical XOR function)
- int size = 1;
- double[][] trainingInput = new double[4 * size][];
- double[][] trainingOutput = new double[4 * size][];
- for (int i = 0; i < size; i++) {
- trainingInput[i * 4 + 0] = new double[] { 0, 0 };
- trainingInput[i * 4 + 1] = new double[] { 0, 1 };
- trainingInput[i * 4 + 2] = new double[] { 1, 0 };
- trainingInput[i * 4 + 3] = new double[] { 1, 1 };
- trainingOutput[i * 4 + 0] = new double[] { 0 };
- trainingOutput[i * 4 + 1] = new double[] { 1 };
- trainingOutput[i * 4 + 2] = new double[] { 1 };
- trainingOutput[i * 4 + 3] = new double[] { 0 };
- }
-
- // create training data
- List trainingSet = new ArrayList();
- String[] inputNames = new String[] {"x","y"};
- String[] outputNames = new String[] {"XOR"};
- for (int i = 0; i < 4*size; i++) {
- trainingSet.add(new NNDataPair(new NNData(inputNames,trainingInput[i]),new NNData(outputNames,trainingOutput[i])));
- }
-
- nnLearner.setMaxTrainingEpochs(20000);
- nnLearner.learn(trainingSet);
- System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");
-
- double[][] validationSet = new double[4][2];
-
- validationSet[0] = new double[] { 0, 0 };
- validationSet[1] = new double[] { 0, 1 };
- validationSet[2] = new double[] { 1, 0 };
- validationSet[3] = new double[] { 1, 1 };
-
- System.out.println("Output from eval set (learned network):");
- testNetwork(nnLearner, validationSet, inputNames, outputNames);
- }
-
- @Test
- public void testLearnSaveLoad() throws IOException {
- NeuralNetFilter nnLearner = new XORFilter(0.5,0.0);
-
- // create training set (logical XOR function)
- int size = 2;
- double[][] trainingInput = new double[4 * size][];
- double[][] trainingOutput = new double[4 * size][];
- for (int i = 0; i < size; i++) {
- trainingInput[i * 4 + 0] = new double[] { 0, 0 };
- trainingInput[i * 4 + 1] = new double[] { 0, 1 };
- trainingInput[i * 4 + 2] = new double[] { 1, 0 };
- trainingInput[i * 4 + 3] = new double[] { 1, 1 };
- trainingOutput[i * 4 + 0] = new double[] { 0 };
- trainingOutput[i * 4 + 1] = new double[] { 1 };
- trainingOutput[i * 4 + 2] = new double[] { 1 };
- trainingOutput[i * 4 + 3] = new double[] { 0 };
- }
-
- // create training data
- List trainingSet = new ArrayList();
- String[] inputNames = new String[] {"x","y"};
- String[] outputNames = new String[] {"XOR"};
- for (int i = 0; i < 4*size; i++) {
- trainingSet.add(new NNDataPair(new NNData(inputNames,trainingInput[i]),new NNData(outputNames,trainingOutput[i])));
- }
-
- nnLearner.setMaxTrainingEpochs(1);
- nnLearner.learn(trainingSet);
- System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");
-
- double[][] validationSet = new double[4][2];
-
- validationSet[0] = new double[] { 0, 0 };
- validationSet[1] = new double[] { 0, 1 };
- validationSet[2] = new double[] { 1, 0 };
- validationSet[3] = new double[] { 1, 1 };
-
- System.out.println("Output from eval set (learned network, pre-serialization):");
- testNetwork(nnLearner, validationSet, inputNames, outputNames);
-
- FileOutputStream fos = new FileOutputStream(FILENAME);
- assertTrue(nnLearner.save(fos));
- fos.close();
-
- FileInputStream fis = new FileInputStream(FILENAME);
- assertTrue(nnLearner.load(fis));
- fis.close();
-
- System.out.println("Output from eval set (learned network, post-serialization):");
- testNetwork(nnLearner, validationSet, inputNames, outputNames);
- }
-
- private void testNetwork(NeuralNetFilter nnLearner, double[][] validationSet, String[] inputNames, String[] outputNames) {
- for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
- NNDataPair dp = new NNDataPair(new NNData(inputNames,validationSet[valIndex]), new NNData(outputNames,validationSet[valIndex]));
- System.out.println(dp + " => " + nnLearner.compute(dp));
- }
- }
-}
\ No newline at end of file
diff --git a/test/net/woodyfolsom/msproj/tictactoe/GameRecordTest.java b/test/net/woodyfolsom/msproj/tictactoe/GameRecordTest.java
new file mode 100644
index 0000000..952040a
--- /dev/null
+++ b/test/net/woodyfolsom/msproj/tictactoe/GameRecordTest.java
@@ -0,0 +1,73 @@
+package net.woodyfolsom.msproj.tictactoe;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+import org.junit.Test;
+
+import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;
+
+public class GameRecordTest {
+
+ @Test
+ public void testGetResultXwins() {
+ GameRecord gameRecord = new GameRecord();
+ gameRecord.apply(Action.getInstance(PLAYER.X, 1, 0));
+ gameRecord.apply(Action.getInstance(PLAYER.O, 0, 0));
+ gameRecord.apply(Action.getInstance(PLAYER.X, 1, 1));
+ gameRecord.apply(Action.getInstance(PLAYER.O, 0, 1));
+ gameRecord.apply(Action.getInstance(PLAYER.X, 1, 2));
+ State finalState = gameRecord.getState();
+ System.out.println("Final state:");
+ System.out.println(finalState);
+ assertTrue(finalState.isValid());
+ assertTrue(finalState.isTerminal());
+ assertTrue(finalState.isWinner(PLAYER.X));
+ assertEquals(GameRecord.RESULT.X_WINS,gameRecord.getResult());
+ }
+
+ @Test
+ public void testGetResultOwins() {
+ GameRecord gameRecord = new GameRecord();
+ gameRecord.apply(Action.getInstance(PLAYER.X, 0, 0));
+ gameRecord.apply(Action.getInstance(PLAYER.O, 0, 2));
+ gameRecord.apply(Action.getInstance(PLAYER.X, 0, 1));
+ gameRecord.apply(Action.getInstance(PLAYER.O, 1, 1));
+ gameRecord.apply(Action.getInstance(PLAYER.X, 1, 0));
+ gameRecord.apply(Action.getInstance(PLAYER.O, 2, 0));
+
+ State finalState = gameRecord.getState();
+ System.out.println("Final state:");
+ System.out.println(finalState);
+ assertTrue(finalState.isValid());
+ assertTrue(finalState.isTerminal());
+ assertTrue(finalState.isWinner(PLAYER.O));
+ assertEquals(GameRecord.RESULT.O_WINS,gameRecord.getResult());
+ }
+
+ @Test
+ public void testGetResultTieGame() {
+ GameRecord gameRecord = new GameRecord();
+ gameRecord.apply(Action.getInstance(PLAYER.X, 0, 0));
+ gameRecord.apply(Action.getInstance(PLAYER.O, 0, 2));
+ gameRecord.apply(Action.getInstance(PLAYER.X, 0, 1));
+
+ gameRecord.apply(Action.getInstance(PLAYER.O, 1, 0));
+ gameRecord.apply(Action.getInstance(PLAYER.X, 1, 2));
+ gameRecord.apply(Action.getInstance(PLAYER.O, 1, 1));
+
+ gameRecord.apply(Action.getInstance(PLAYER.X, 2, 0));
+ gameRecord.apply(Action.getInstance(PLAYER.O, 2, 2));
+ gameRecord.apply(Action.getInstance(PLAYER.X, 2, 1));
+
+ State finalState = gameRecord.getState();
+ System.out.println("Final state:");
+ System.out.println(finalState);
+ assertTrue(finalState.isValid());
+ assertTrue(finalState.isTerminal());
+ assertFalse(finalState.isWinner(PLAYER.X));
+ assertFalse(finalState.isWinner(PLAYER.O));
+ assertEquals(GameRecord.RESULT.TIE_GAME,gameRecord.getResult());
+ }
+}
diff --git a/test/net/woodyfolsom/msproj/tictactoe/RefereeTest.java b/test/net/woodyfolsom/msproj/tictactoe/RefereeTest.java
new file mode 100644
index 0000000..f2d0e87
--- /dev/null
+++ b/test/net/woodyfolsom/msproj/tictactoe/RefereeTest.java
@@ -0,0 +1,12 @@
+package net.woodyfolsom.msproj.tictactoe;
+
+import org.junit.Test;
+
+public class RefereeTest {
+
+ @Test
+ public void testPlay100Games() {
+ new Referee().play(100);
+ }
+
+}
diff --git a/ttt.net b/ttt.net
new file mode 100644
index 0000000..7f9550d
--- /dev/null
+++ b/ttt.net
@@ -0,0 +1,129 @@
+[ttt.net content elided: the serialized FeedforwardNetwork XML markup was
+lost in extraction. The surviving neuron ids 1-9, 10-14, and 15 suggest a
+9-5-1 topology (nine board inputs, five hidden units, one value output).]