diff --git a/.classpath b/.classpath index 6f91090..7addda2 100644 --- a/.classpath +++ b/.classpath @@ -7,6 +7,5 @@ - diff --git a/lib/encog-java-core-javadoc.jar b/lib/encog-java-core-javadoc.jar deleted file mode 100644 index 2cf8d93..0000000 Binary files a/lib/encog-java-core-javadoc.jar and /dev/null differ diff --git a/lib/encog-java-core-sources.jar b/lib/encog-java-core-sources.jar deleted file mode 100644 index 0cb6b50..0000000 Binary files a/lib/encog-java-core-sources.jar and /dev/null differ diff --git a/lib/encog-java-core.jar b/lib/encog-java-core.jar deleted file mode 100644 index f846d91..0000000 Binary files a/lib/encog-java-core.jar and /dev/null differ diff --git a/src/net/woodyfolsom/msproj/ann/AbstractNeuralNetFilter.java b/src/net/woodyfolsom/msproj/ann/AbstractNeuralNetFilter.java index 865d6e9..8ecde7b 100644 --- a/src/net/woodyfolsom/msproj/ann/AbstractNeuralNetFilter.java +++ b/src/net/woodyfolsom/msproj/ann/AbstractNeuralNetFilter.java @@ -1,21 +1,26 @@ package net.woodyfolsom.msproj.ann; -import java.io.File; -import java.io.FileInputStream; -import java.io.FileOutputStream; -import java.io.IOException; - -import org.encog.ml.data.MLData; -import org.encog.neural.networks.BasicNetwork; -import org.encog.neural.networks.PersistBasicNetwork; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.List; public abstract class AbstractNeuralNetFilter implements NeuralNetFilter { - protected BasicNetwork neuralNetwork; - protected int actualTrainingEpochs = 0; - protected int maxTrainingEpochs = 1000; + private final FeedforwardNetwork neuralNetwork; + private final TrainingMethod trainingMethod; + + private double maxError; + private int actualTrainingEpochs = 0; + private int maxTrainingEpochs; + + AbstractNeuralNetFilter(FeedforwardNetwork neuralNetwork, TrainingMethod trainingMethod, int maxTrainingEpochs, double maxError) { + this.neuralNetwork = neuralNetwork; + this.trainingMethod = trainingMethod; + this.maxError = maxError; + this.maxTrainingEpochs = maxTrainingEpochs; + } @Override - public MLData compute(MLData input) { + public NNData compute(NNDataPair input) { return this.neuralNetwork.compute(input); } @@ -23,38 +28,83 @@ public abstract class AbstractNeuralNetFilter implements NeuralNetFilter { return actualTrainingEpochs; } + @Override + public int getInputSize() { + return 2; + } + public int getMaxTrainingEpochs() { return maxTrainingEpochs; } - @Override - public BasicNetwork getNeuralNetwork() { + protected FeedforwardNetwork getNeuralNetwork() { return neuralNetwork; } - public void load(String filename) throws IOException { - FileInputStream fis = new FileInputStream(new File(filename)); - neuralNetwork = (BasicNetwork) new PersistBasicNetwork().read(fis); - fis.close(); + @Override + public void learnPatterns(List trainingSet) { + actualTrainingEpochs = 0; + double error; + neuralNetwork.initWeights(); + + error = trainingMethod.computePatternError(neuralNetwork,trainingSet); + + if (error <= maxError) { + System.out.println("Initial error: " + error); + return; + } + + do { + trainingMethod.iteratePatterns(neuralNetwork,trainingSet); + error = trainingMethod.computePatternError(neuralNetwork,trainingSet); + System.out.println("Epoch #" + actualTrainingEpochs + " Error:" + + error); + actualTrainingEpochs++; + System.out.println("MSSE after epoch " + actualTrainingEpochs + ": " + error); + } while (error > maxError && actualTrainingEpochs < maxTrainingEpochs); } @Override - public void reset() { - neuralNetwork.reset(); + public void learnSequences(List> trainingSet) { + actualTrainingEpochs = 0; + double error; + neuralNetwork.initWeights(); + + error = trainingMethod.computeSequenceError(neuralNetwork,trainingSet); + + if (error <= maxError) { + System.out.println("Initial error: " + error); + return; + } + + do { + trainingMethod.iterateSequences(neuralNetwork,trainingSet); + error = trainingMethod.computeSequenceError(neuralNetwork,trainingSet); + if (Double.isNaN(error)) { + error = trainingMethod.computeSequenceError(neuralNetwork,trainingSet); + } + System.out.println("Epoch #" + actualTrainingEpochs + " Error:" + + error); + actualTrainingEpochs++; + System.out.println("MSSE after epoch " + actualTrainingEpochs + ": " + error); + } while (error > maxError && actualTrainingEpochs < maxTrainingEpochs); } @Override - public void reset(int seed) { - neuralNetwork.reset(seed); + public boolean load(InputStream input) { + return neuralNetwork.load(input); + } + + @Override + public boolean save(OutputStream output) { + return neuralNetwork.save(output); } - public void save(String filename) throws IOException { - FileOutputStream fos = new FileOutputStream(new File(filename)); - new PersistBasicNetwork().save(fos, getNeuralNetwork()); - fos.close(); + public void setMaxError(double maxError) { + this.maxError = maxError; } public void setMaxTrainingEpochs(int max) { this.maxTrainingEpochs = max; } -} +} \ No newline at end of file diff --git a/src/net/woodyfolsom/msproj/ann2/BackPropagation.java b/src/net/woodyfolsom/msproj/ann/BackPropagation.java similarity index 65% rename from src/net/woodyfolsom/msproj/ann2/BackPropagation.java rename to src/net/woodyfolsom/msproj/ann/BackPropagation.java index 8b40e1e..6be81cd 100644 --- a/src/net/woodyfolsom/msproj/ann2/BackPropagation.java +++ b/src/net/woodyfolsom/msproj/ann/BackPropagation.java @@ -1,11 +1,11 @@ -package net.woodyfolsom.msproj.ann2; +package net.woodyfolsom.msproj.ann; import java.util.List; -import net.woodyfolsom.msproj.ann2.math.ErrorFunction; -import net.woodyfolsom.msproj.ann2.math.MSSE; +import net.woodyfolsom.msproj.ann.math.ErrorFunction; +import net.woodyfolsom.msproj.ann.math.MSSE; -public class BackPropagation implements TrainingMethod { +public class BackPropagation extends TrainingMethod { private final ErrorFunction errorFunction; private final double learningRate; private final double momentum; @@ -17,15 +17,13 @@ public class BackPropagation implements TrainingMethod { } @Override - public void iterate(FeedforwardNetwork neuralNetwork, + public void iteratePatterns(FeedforwardNetwork neuralNetwork, List trainingSet) { System.out.println("Learningrate: " + learningRate); System.out.println("Momentum: " + momentum); - //zeroErrors(neuralNetwork); - for (NNDataPair trainingPair : trainingSet) { - zeroErrors(neuralNetwork); + zeroGradients(neuralNetwork); System.out.println("Training with: " + trainingPair.getInput()); @@ -35,16 +33,15 @@ public class BackPropagation implements TrainingMethod { System.out.println("Updating weights. Ideal Output: " + ideal); System.out.println("Actual Output: " + actual); - updateErrors(neuralNetwork, ideal); + //backpropagate the gradients w.r.t. output error + backPropagate(neuralNetwork, ideal); updateWeights(neuralNetwork); } - - //updateWeights(neuralNetwork); } @Override - public double computeError(FeedforwardNetwork neuralNetwork, + public double computePatternError(FeedforwardNetwork neuralNetwork, List trainingSet) { int numDataPairs = trainingSet.size(); int outputSize = neuralNetwork.getOutput().length; @@ -67,15 +64,17 @@ public class BackPropagation implements TrainingMethod { return MSSE; } - private void updateErrors(FeedforwardNetwork neuralNetwork, NNData ideal) { + @Override + protected + void backPropagate(FeedforwardNetwork neuralNetwork, NNData ideal) { Neuron[] outputNeurons = neuralNetwork.getOutputNeurons(); double[] idealValues = ideal.getValues(); for (int i = 0; i < idealValues.length; i++) { - double output = outputNeurons[i].getOutput(); + double input = outputNeurons[i].getInput(); double derivative = outputNeurons[i].getActivationFunction() - .derivative(output); - outputNeurons[i].setError(outputNeurons[i].getError() + derivative * (idealValues[i] - output)); + .derivative(input); + outputNeurons[i].setGradient(outputNeurons[i].getGradient() + derivative * (idealValues[i] - outputNeurons[i].getOutput())); } // walking down the list of Neurons in reverse order, propagate the // error @@ -84,19 +83,19 @@ public class BackPropagation implements TrainingMethod { for (int n = neurons.length - 1; n >= 0; n--) { Neuron neuron = neurons[n]; - double error = neuron.getError(); + double error = neuron.getGradient(); Connection[] connectionsFromN = neuralNetwork .getConnectionsFrom(neuron.getId()); if (connectionsFromN.length > 0) { double derivative = neuron.getActivationFunction().derivative( - neuron.getOutput()); + neuron.getInput()); for (Connection connection : connectionsFromN) { - error += derivative * connection.getWeight() * neuralNetwork.getNeuron(connection.getDest()).getError(); + error += derivative * connection.getWeight() * neuralNetwork.getNeuron(connection.getDest()).getGradient(); } } - neuron.setError(error); + neuron.setGradient(error); } } @@ -104,17 +103,30 @@ public class BackPropagation implements TrainingMethod { for (Connection connection : neuralNetwork.getConnections()) { Neuron srcNeuron = neuralNetwork.getNeuron(connection.getSrc()); Neuron destNeuron = neuralNetwork.getNeuron(connection.getDest()); - double delta = learningRate * srcNeuron.getOutput() * destNeuron.getError(); + double delta = learningRate * srcNeuron.getOutput() * destNeuron.getGradient(); //TODO allow for momentum //double lastDelta = connection.getLastDelta(); connection.addDelta(delta); } } - - private void zeroErrors(FeedforwardNetwork neuralNetwork) { - // Set output errors relative to ideals, all other errors to 0. - for (Neuron neuron : neuralNetwork.getNeurons()) { - neuron.setError(0.0); - } + + @Override + public void iterateSequences(FeedforwardNetwork neuralNetwork, + List> trainingSet) { + throw new UnsupportedOperationException(); } + + @Override + public double computeSequenceError(FeedforwardNetwork neuralNetwork, + List> trainingSet) { + throw new UnsupportedOperationException(); + } + + @Override + protected void iteratePattern(FeedforwardNetwork neuralNetwork, + NNDataPair statePair, NNData nextReward) { + throw new UnsupportedOperationException(); + } + + } \ No newline at end of file diff --git a/src/net/woodyfolsom/msproj/ann2/Connection.java b/src/net/woodyfolsom/msproj/ann/Connection.java similarity index 84% rename from src/net/woodyfolsom/msproj/ann2/Connection.java rename to src/net/woodyfolsom/msproj/ann/Connection.java index 8b38b1b..f2d2e5a 100644 --- a/src/net/woodyfolsom/msproj/ann2/Connection.java +++ b/src/net/woodyfolsom/msproj/ann/Connection.java @@ -1,4 +1,4 @@ -package net.woodyfolsom.msproj.ann2; +package net.woodyfolsom.msproj.ann; import javax.xml.bind.annotation.XmlAttribute; import javax.xml.bind.annotation.XmlTransient; @@ -8,6 +8,7 @@ public class Connection { private int dest; private double weight; private transient double lastDelta = 0.0; + private transient double trace = 0.0; public Connection() { //no-arg constructor for JAXB @@ -20,6 +21,7 @@ public class Connection { } public void addDelta(double delta) { + this.trace = delta; this.weight += delta; this.lastDelta = delta; } @@ -39,6 +41,10 @@ public class Connection { return src; } + public double getTrace() { + return trace; + } + @XmlAttribute public double getWeight() { return weight; @@ -52,6 +58,11 @@ public class Connection { this.src = src; } + @XmlTransient + public void setTrace(double trace) { + this.trace = trace; + } + public void setWeight(double weight) { this.weight = weight; } diff --git a/src/net/woodyfolsom/msproj/ann/DoublePair.java b/src/net/woodyfolsom/msproj/ann/DoublePair.java deleted file mode 100644 index cd0a00f..0000000 --- a/src/net/woodyfolsom/msproj/ann/DoublePair.java +++ /dev/null @@ -1,12 +0,0 @@ -package net.woodyfolsom.msproj.ann; - -import org.encog.ml.data.basic.BasicMLData; - -public class DoublePair extends BasicMLData { - - private static final long serialVersionUID = 1L; - - public DoublePair(double x, double y) { - super(new double[] { x, y }); - } -} \ No newline at end of file diff --git a/src/net/woodyfolsom/msproj/ann2/FeedforwardNetwork.java b/src/net/woodyfolsom/msproj/ann/FeedforwardNetwork.java similarity index 85% rename from src/net/woodyfolsom/msproj/ann2/FeedforwardNetwork.java rename to src/net/woodyfolsom/msproj/ann/FeedforwardNetwork.java index 29e1f10..c89eef9 100644 --- a/src/net/woodyfolsom/msproj/ann2/FeedforwardNetwork.java +++ b/src/net/woodyfolsom/msproj/ann/FeedforwardNetwork.java @@ -1,4 +1,4 @@ -package net.woodyfolsom.msproj.ann2; +package net.woodyfolsom.msproj.ann; import java.io.InputStream; import java.io.OutputStream; @@ -9,10 +9,11 @@ import java.util.Map; import javax.xml.bind.annotation.XmlAttribute; import javax.xml.bind.annotation.XmlElement; +import javax.xml.bind.annotation.XmlTransient; -import net.woodyfolsom.msproj.ann2.math.ActivationFunction; -import net.woodyfolsom.msproj.ann2.math.Linear; -import net.woodyfolsom.msproj.ann2.math.Sigmoid; +import net.woodyfolsom.msproj.ann.math.ActivationFunction; +import net.woodyfolsom.msproj.ann.math.Linear; +import net.woodyfolsom.msproj.ann.math.Sigmoid; public abstract class FeedforwardNetwork { private ActivationFunction activationFunction; @@ -83,12 +84,12 @@ public abstract class FeedforwardNetwork { * Adds a new neuron with a unique id to this FeedforwardNetwork. * @return */ - Neuron createNeuron(boolean input) { + Neuron createNeuron(boolean input, ActivationFunction afunc) { Neuron neuron; if (input) { neuron = new Neuron(Linear.function, neurons.size()); } else { - neuron = new Neuron(activationFunction, neurons.size()); + neuron = new Neuron(afunc, neurons.size()); } neurons.add(neuron); return neuron; @@ -153,6 +154,10 @@ public abstract class FeedforwardNetwork { return neurons.get(id); } + public Connection getConnection(int index) { + return connections.get(index); + } + @XmlElement protected Connection[] getConnections() { return connections.toArray(new Connection[connections.size()]); @@ -178,6 +183,22 @@ public abstract class FeedforwardNetwork { } } + public double[] getGradients() { + double[] gradients = new double[neurons.size()]; + for (int n = 0; n < gradients.length; n++) { + gradients[n] = neurons.get(n).getGradient(); + } + return gradients; + } + + public double[] getWeights() { + double[] weights = new double[connections.size()]; + for (int i = 0; i < connections.size(); i++) { + weights[i] = connections.get(i).getWeight(); + } + return weights; + } + @XmlAttribute public boolean isBiased() { return biased; @@ -226,7 +247,7 @@ public abstract class FeedforwardNetwork { this.biased = biased; if (biased) { - Neuron biasNeuron = createNeuron(true); + Neuron biasNeuron = createNeuron(true, activationFunction); biasNeuron.setInput(1.0); biasNeuronId = biasNeuron.getId(); } else { @@ -270,6 +291,7 @@ public abstract class FeedforwardNetwork { } } + @XmlTransient public void setWeights(double[] weights) { if (weights.length != connections.size()) { throw new IllegalArgumentException("# of weights must == # of connections"); diff --git a/src/net/woodyfolsom/msproj/ann/GameStateMLDataPair.java b/src/net/woodyfolsom/msproj/ann/GameStateMLDataPair.java deleted file mode 100644 index d7c1316..0000000 --- a/src/net/woodyfolsom/msproj/ann/GameStateMLDataPair.java +++ /dev/null @@ -1,117 +0,0 @@ -package net.woodyfolsom.msproj.ann; - -import net.woodyfolsom.msproj.GameResult; -import net.woodyfolsom.msproj.GameState; -import net.woodyfolsom.msproj.Player; - -import org.encog.ml.data.MLData; -import org.encog.ml.data.MLDataPair; -import org.encog.ml.data.basic.BasicMLData; -import org.encog.ml.data.basic.BasicMLDataPair; -import org.encog.util.kmeans.Centroid; - -public class GameStateMLDataPair implements MLDataPair { - private BasicMLDataPair mlDataPairDelegate; - private GameState gameState; - - public GameStateMLDataPair(GameState gameState) { - this.gameState = gameState; - mlDataPairDelegate = new BasicMLDataPair(new BasicMLData(createInput()), new BasicMLData(createIdeal())); - } - - public GameStateMLDataPair(GameStateMLDataPair that) { - this.gameState = new GameState(that.gameState); - mlDataPairDelegate = new BasicMLDataPair( - that.mlDataPairDelegate.getInput(), - that.mlDataPairDelegate.getIdeal()); - } - - @Override - public MLDataPair clone() { - return new GameStateMLDataPair(this); - } - - @Override - public Centroid createCentroid() { - return mlDataPairDelegate.createCentroid(); - } - - /** - * Creates a vector of normalized scores from GameState. - * - * @return - */ - private double[] createInput() { - - GameResult result = gameState.getResult(); - - double maxScore = gameState.getGameConfig().getSize() - * gameState.getGameConfig().getSize(); - - double whiteScore = Math.min(1.0, result.getWhiteScore() / maxScore); - double blackScore = Math.min(1.0, result.getBlackScore() / maxScore); - - return new double[] { blackScore, whiteScore }; - } - - /** - * Creates a vector of values indicating strength of black/white win output - * from network. - * - * @return - */ - private double[] createIdeal() { - GameResult result = gameState.getResult(); - - double blackWinner = result.isWinner(Player.BLACK) ? 1.0 : 0.0; - double whiteWinner = result.isWinner(Player.WHITE) ? 1.0 : 0.0; - - return new double[] { blackWinner, whiteWinner }; - } - - @Override - public MLData getIdeal() { - return mlDataPairDelegate.getIdeal(); - } - - @Override - public double[] getIdealArray() { - return mlDataPairDelegate.getIdealArray(); - } - - @Override - public MLData getInput() { - return mlDataPairDelegate.getInput(); - } - - @Override - public double[] getInputArray() { - return mlDataPairDelegate.getInputArray(); - } - - @Override - public double getSignificance() { - return mlDataPairDelegate.getSignificance(); - } - - @Override - public boolean isSupervised() { - return mlDataPairDelegate.isSupervised(); - } - - @Override - public void setIdealArray(double[] arg0) { - mlDataPairDelegate.setIdealArray(arg0); - } - - @Override - public void setInputArray(double[] arg0) { - mlDataPairDelegate.setInputArray(arg0); - } - - @Override - public void setSignificance(double arg0) { - mlDataPairDelegate.setSignificance(arg0); - } - -} \ No newline at end of file diff --git a/src/net/woodyfolsom/msproj/ann2/Layer.java b/src/net/woodyfolsom/msproj/ann/Layer.java similarity index 92% rename from src/net/woodyfolsom/msproj/ann2/Layer.java rename to src/net/woodyfolsom/msproj/ann/Layer.java index 88cb178..dc2866f 100644 --- a/src/net/woodyfolsom/msproj/ann2/Layer.java +++ b/src/net/woodyfolsom/msproj/ann/Layer.java @@ -1,4 +1,4 @@ -package net.woodyfolsom.msproj.ann2; +package net.woodyfolsom.msproj.ann; import java.util.Arrays; diff --git a/src/net/woodyfolsom/msproj/ann2/MultiLayerPerceptron.java b/src/net/woodyfolsom/msproj/ann/MultiLayerPerceptron.java similarity index 79% rename from src/net/woodyfolsom/msproj/ann2/MultiLayerPerceptron.java rename to src/net/woodyfolsom/msproj/ann/MultiLayerPerceptron.java index 23f3dca..6f9c8f4 100644 --- a/src/net/woodyfolsom/msproj/ann2/MultiLayerPerceptron.java +++ b/src/net/woodyfolsom/msproj/ann/MultiLayerPerceptron.java @@ -1,4 +1,4 @@ -package net.woodyfolsom.msproj.ann2; +package net.woodyfolsom.msproj.ann; import java.io.InputStream; import java.io.OutputStream; @@ -10,6 +10,10 @@ import javax.xml.bind.Unmarshaller; import javax.xml.bind.annotation.XmlElement; import javax.xml.bind.annotation.XmlRootElement; +import net.woodyfolsom.msproj.ann.math.ActivationFunction; +import net.woodyfolsom.msproj.ann.math.Sigmoid; +import net.woodyfolsom.msproj.ann.math.Tanh; + @XmlRootElement public class MultiLayerPerceptron extends FeedforwardNetwork { private boolean biased; @@ -37,7 +41,13 @@ public class MultiLayerPerceptron extends FeedforwardNetwork { throw new IllegalArgumentException("Layer size must be >= 1"); } - Layer newLayer = createNewLayer(layerIndex, layerSize); + + Layer newLayer; + if (layerIndex == numLayers - 1) { + newLayer = createNewLayer(layerIndex, layerSize, Sigmoid.function); + } else { + newLayer = createNewLayer(layerIndex, layerSize, Tanh.function); + } if (layerIndex > 0) { Layer prevLayer = layers[layerIndex - 1]; @@ -54,11 +64,11 @@ public class MultiLayerPerceptron extends FeedforwardNetwork { } } - private Layer createNewLayer(int layerIndex, int layerSize) { + private Layer createNewLayer(int layerIndex, int layerSize, ActivationFunction afunc) { Layer layer = new Layer(layerSize); layers[layerIndex] = layer; for (int n = 0; n < layerSize; n++) { - Neuron neuron = createNeuron(layerIndex == 0); + Neuron neuron = createNeuron(layerIndex == 0, afunc); layer.setNeuronId(n, neuron.getId()); } return layer; @@ -93,8 +103,13 @@ public class MultiLayerPerceptron extends FeedforwardNetwork { protected void setInput(double[] input) { Layer inputLayer = layers[0]; for (int n = 0; n < inputLayer.size(); n++) { - getNeuron(inputLayer.getNeuronId(n)).setInput(input[n]); + try { + getNeuron(inputLayer.getNeuronId(n)).setInput(input[n]); + } catch (NullPointerException npe) { + npe.printStackTrace(); + } } + } public void setLayers(Layer[] layers) { diff --git a/src/net/woodyfolsom/msproj/ann2/NNData.java b/src/net/woodyfolsom/msproj/ann/NNData.java similarity index 89% rename from src/net/woodyfolsom/msproj/ann2/NNData.java rename to src/net/woodyfolsom/msproj/ann/NNData.java index 12f2150..f5200f7 100644 --- a/src/net/woodyfolsom/msproj/ann2/NNData.java +++ b/src/net/woodyfolsom/msproj/ann/NNData.java @@ -1,4 +1,4 @@ -package net.woodyfolsom.msproj.ann2; +package net.woodyfolsom.msproj.ann; public class NNData { private final double[] values; diff --git a/src/net/woodyfolsom/msproj/ann2/NNDataPair.java b/src/net/woodyfolsom/msproj/ann/NNDataPair.java similarity index 83% rename from src/net/woodyfolsom/msproj/ann2/NNDataPair.java rename to src/net/woodyfolsom/msproj/ann/NNDataPair.java index 53d501d..2b2921a 100644 --- a/src/net/woodyfolsom/msproj/ann2/NNDataPair.java +++ b/src/net/woodyfolsom/msproj/ann/NNDataPair.java @@ -1,4 +1,4 @@ -package net.woodyfolsom.msproj.ann2; +package net.woodyfolsom.msproj.ann; public class NNDataPair { private final NNData input; diff --git a/src/net/woodyfolsom/msproj/ann/NeuralNetFilter.java b/src/net/woodyfolsom/msproj/ann/NeuralNetFilter.java index 73ce084..08a5c76 100644 --- a/src/net/woodyfolsom/msproj/ann/NeuralNetFilter.java +++ b/src/net/woodyfolsom/msproj/ann/NeuralNetFilter.java @@ -1,30 +1,29 @@ package net.woodyfolsom.msproj.ann; -import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; import java.util.List; -import java.util.Set; - -import org.encog.ml.data.MLData; -import org.encog.ml.data.MLDataPair; -import org.encog.ml.data.MLDataSet; -import org.encog.neural.networks.BasicNetwork; public interface NeuralNetFilter { - BasicNetwork getNeuralNetwork(); - int getActualTrainingEpochs(); + int getInputSize(); + int getMaxTrainingEpochs(); + int getOutputSize(); - - void learn(MLDataSet trainingSet); - void learn(Set> trainingSet); - - void load(String fileName) throws IOException; - void reset(); - void reset(int seed); - void save(String fileName) throws IOException; + + boolean load(InputStream input); + + boolean save(OutputStream output); + void setMaxTrainingEpochs(int max); - MLData compute(MLData input); + NNData compute(NNDataPair input); + + //Due to Java type erasure, overloading a method + //simply named 'learn' which takes Lists would be problematic + + void learnPatterns(List trainingSet); + void learnSequences(List> trainingSet); } \ No newline at end of file diff --git a/src/net/woodyfolsom/msproj/ann2/Neuron.java b/src/net/woodyfolsom/msproj/ann/Neuron.java similarity index 78% rename from src/net/woodyfolsom/msproj/ann2/Neuron.java rename to src/net/woodyfolsom/msproj/ann/Neuron.java index 4e76aa5..c5ad0c4 100644 --- a/src/net/woodyfolsom/msproj/ann2/Neuron.java +++ b/src/net/woodyfolsom/msproj/ann/Neuron.java @@ -1,17 +1,17 @@ -package net.woodyfolsom.msproj.ann2; +package net.woodyfolsom.msproj.ann; import javax.xml.bind.annotation.XmlAttribute; import javax.xml.bind.annotation.XmlElement; import javax.xml.bind.annotation.XmlTransient; -import net.woodyfolsom.msproj.ann2.math.ActivationFunction; -import net.woodyfolsom.msproj.ann2.math.Sigmoid; +import net.woodyfolsom.msproj.ann.math.ActivationFunction; +import net.woodyfolsom.msproj.ann.math.Sigmoid; public class Neuron { private ActivationFunction activationFunction; private int id; private transient double input = 0.0; - private transient double error = 0.0; + private transient double gradient = 0.0; public Neuron() { //no-arg constructor for JAXB @@ -37,8 +37,8 @@ public class Neuron { } @XmlTransient - public double getError() { - return error; + public double getGradient() { + return gradient; } @XmlTransient @@ -50,8 +50,8 @@ public class Neuron { return activationFunction.calculate(input); } - public void setError(double value) { - this.error = value; + public void setGradient(double value) { + this.gradient = value; } public void setInput(double input) { @@ -92,7 +92,7 @@ public class Neuron { @Override public String toString() { - return "Neuron #" + id +", input: " + input + ", error: " + error; + return "Neuron #" + id +", input: " + input + ", gradient: " + gradient; } } \ No newline at end of file diff --git a/src/net/woodyfolsom/msproj/ann/ObjectiveFunction.java b/src/net/woodyfolsom/msproj/ann/ObjectiveFunction.java new file mode 100644 index 0000000..9e6230a --- /dev/null +++ b/src/net/woodyfolsom/msproj/ann/ObjectiveFunction.java @@ -0,0 +1,5 @@ +package net.woodyfolsom.msproj.ann; + +public class ObjectiveFunction { + +} diff --git a/src/net/woodyfolsom/msproj/ann/TTTFilter.java b/src/net/woodyfolsom/msproj/ann/TTTFilter.java new file mode 100644 index 0000000..31f743c --- /dev/null +++ b/src/net/woodyfolsom/msproj/ann/TTTFilter.java @@ -0,0 +1,34 @@ +package net.woodyfolsom.msproj.ann; + +/** + * Based on sample code from http://neuroph.sourceforge.net + * + * @author Woody + * + */ +public class TTTFilter extends AbstractNeuralNetFilter implements + NeuralNetFilter { + + private static final int INPUT_SIZE = 9; + private static final int OUTPUT_SIZE = 1; + + public TTTFilter() { + this(0.5,0.0, 1000); + } + + public TTTFilter(double alpha, double lambda, int maxEpochs) { + super( new MultiLayerPerceptron(true, INPUT_SIZE, 5, OUTPUT_SIZE), + new TemporalDifference(0.5,0.0), maxEpochs, 0.05); + super.getNeuralNetwork().setName("XORFilter"); + } + + @Override + public int getInputSize() { + return INPUT_SIZE; + } + + @Override + public int getOutputSize() { + return OUTPUT_SIZE; + } +} \ No newline at end of file diff --git a/src/net/woodyfolsom/msproj/ann/TTTFilterTrainer.java b/src/net/woodyfolsom/msproj/ann/TTTFilterTrainer.java new file mode 100644 index 0000000..036e77a --- /dev/null +++ b/src/net/woodyfolsom/msproj/ann/TTTFilterTrainer.java @@ -0,0 +1,187 @@ +package net.woodyfolsom.msproj.ann; + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.FileOutputStream; +import java.util.ArrayList; +import java.util.List; + +import net.woodyfolsom.msproj.tictactoe.Action; +import net.woodyfolsom.msproj.tictactoe.GameRecord; +import net.woodyfolsom.msproj.tictactoe.GameRecord.RESULT; +import net.woodyfolsom.msproj.tictactoe.NNDataSetFactory; +import net.woodyfolsom.msproj.tictactoe.NeuralNetPolicy; +import net.woodyfolsom.msproj.tictactoe.Policy; +import net.woodyfolsom.msproj.tictactoe.RandomPolicy; +import net.woodyfolsom.msproj.tictactoe.State; + +public class TTTFilterTrainer { //implements epsilon-greedy trainer? online version of NeuralNetFilter + + public static void main(String[] args) throws FileNotFoundException { + double alpha = 0.0; + double lambda = 0.9; + int maxGames = 15000; + + new TTTFilterTrainer().trainNetwork(alpha, lambda, maxGames); + } + + public void trainNetwork(double alpha, double lambda, int maxGames) throws FileNotFoundException { + /// + FeedforwardNetwork neuralNetwork = new MultiLayerPerceptron(true, 9,5,1); + neuralNetwork.setName("TicTacToe"); + neuralNetwork.initWeights(); + TrainingMethod trainer = new TemporalDifference(0.5,0.5); + + System.out.println("Playing untrained games."); + for (int i = 0; i < 10; i++) { + System.out.println("" + (i+1) + ". " + playOptimal(neuralNetwork).getResult()); + } + + System.out.println("Learning from " + maxGames + " games of random self-play"); + + int gamesPlayed = 0; + List results = new ArrayList(); + do { + GameRecord gameRecord = playEpsilonGreedy(0.90, neuralNetwork, trainer); + System.out.println("Winner: " + gameRecord.getResult()); + gamesPlayed++; + results.add(gameRecord.getResult()); + } while (gamesPlayed < maxGames); + /// + + System.out.println("Learned network after " + maxGames + " training games."); + + double[][] validationSet = new double[8][]; + + for (int i = 0; i < results.size(); i++) { + if (i % 10 == 0) { + System.out.println("" + (i+1) + ". " + results.get(i)); + } + } + // empty board + validationSet[0] = new double[] { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0 }; + // center + validationSet[1] = new double[] { 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, + 0.0, 0.0 }; + // top edge + validationSet[2] = new double[] { 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0 }; + // left edge + validationSet[3] = new double[] { 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0 }; + // corner + validationSet[4] = new double[] { 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0 }; + // win + validationSet[5] = new double[] { 1.0, 1.0, 1.0, -1.0, -1.0, 0.0, 0.0, + -1.0, 0.0 }; + // loss + validationSet[6] = new double[] { -1.0, 1.0, 0.0, 1.0, -1.0, 1.0, 0.0, + 0.0, -1.0 }; + + // about to win + validationSet[7] = new double[] { + -1.0, 1.0, 1.0, + 1.0, -1.0, 1.0, + -1.0, -1.0, 0.0 }; + + String[] inputNames = new String[] { "00", "01", "02", "10", "11", + "12", "20", "21", "22" }; + String[] outputNames = new String[] { "values" }; + + System.out.println("Output from eval set (learned network):"); + testNetwork(neuralNetwork, validationSet, inputNames, outputNames); + + System.out.println("Playing optimal games."); + for (int i = 0; i < 10; i++) { + System.out.println("" + (i+1) + ". " + playOptimal(neuralNetwork).getResult()); + } + + /* + File output = new File("ttt.net"); + + FileOutputStream fos = new FileOutputStream(output); + + neuralNetwork.save(fos);*/ + } + + private GameRecord playOptimal(FeedforwardNetwork neuralNetwork) { + GameRecord gameRecord = new GameRecord(); + + Policy neuralNetPolicy = new NeuralNetPolicy(neuralNetwork); + + State state = gameRecord.getState(); + + do { + Action action; + State nextState; + + action = neuralNetPolicy.getAction(gameRecord.getState()); + + nextState = gameRecord.apply(action); + //System.out.println("Action " + action + " selected by policy " + selectedPolicy.getName()); + //System.out.println("Next board state: " + nextState); + state = nextState; + } while (!state.isTerminal()); + + //finally, reinforce the actual reward + + return gameRecord; + } + + private GameRecord playEpsilonGreedy(double epsilon, FeedforwardNetwork neuralNetwork, TrainingMethod trainer) { + GameRecord gameRecord = new GameRecord(); + + Policy randomPolicy = new RandomPolicy(); + Policy neuralNetPolicy = new NeuralNetPolicy(neuralNetwork); + + //System.out.println("Playing epsilon-greedy game."); + + State state = gameRecord.getState(); + NNDataPair statePair; + + Policy selectedPolicy; + trainer.zeroTraces(neuralNetwork); + + do { + Action action; + State nextState; + + if (Math.random() < epsilon) { + selectedPolicy = randomPolicy; + action = selectedPolicy.getAction(gameRecord.getState()); + nextState = gameRecord.apply(action); + } else { + selectedPolicy = neuralNetPolicy; + action = selectedPolicy.getAction(gameRecord.getState()); + + nextState = gameRecord.apply(action); + statePair = NNDataSetFactory.createDataPair(state); + NNDataPair nextStatePair = NNDataSetFactory.createDataPair(nextState); + trainer.iteratePattern(neuralNetwork, statePair, nextStatePair.getIdeal()); + } + //System.out.println("Action " + action + " selected by policy " + selectedPolicy.getName()); + + //System.out.println("Next board state: " + nextState); + + state = nextState; + } while (!state.isTerminal()); + + //finally, reinforce the actual reward + statePair = NNDataSetFactory.createDataPair(state); + trainer.iteratePattern(neuralNetwork, statePair, statePair.getIdeal()); + + return gameRecord; + } + + private void testNetwork(FeedforwardNetwork neuralNetwork, + double[][] validationSet, String[] inputNames, String[] outputNames) { + for (int valIndex = 0; valIndex < validationSet.length; valIndex++) { + NNDataPair dp = new NNDataPair(new NNData(inputNames, + validationSet[valIndex]), new NNData(outputNames, + validationSet[valIndex])); + System.out.println(dp + " => " + neuralNetwork.compute(dp)); + } + } +} \ No newline at end of file diff --git a/src/net/woodyfolsom/msproj/ann/TemporalDifference.java b/src/net/woodyfolsom/msproj/ann/TemporalDifference.java index 991608b..853fb9f 100644 --- a/src/net/woodyfolsom/msproj/ann/TemporalDifference.java +++ b/src/net/woodyfolsom/msproj/ann/TemporalDifference.java @@ -1,30 +1,133 @@ package net.woodyfolsom.msproj.ann; -import org.encog.ml.data.MLDataSet; -import org.encog.neural.networks.ContainsFlat; -import org.encog.neural.networks.training.propagation.back.Backpropagation; +import java.util.List; -public class TemporalDifference extends Backpropagation { +public class TemporalDifference extends TrainingMethod { + private final double alpha; + private final double gamma = 1.0; private final double lambda; - - public TemporalDifference(ContainsFlat network, MLDataSet training, - double theLearnRate, double theMomentum, double lambda) { - super(network, training, theLearnRate, theMomentum); + + public TemporalDifference(double alpha, double lambda) { + this.alpha = alpha; this.lambda = lambda; } - public double getLamdba() { - return lambda; + @Override + public void iteratePatterns(FeedforwardNetwork neuralNetwork, + List trainingSet) { + throw new UnsupportedOperationException(); } @Override - public double updateWeight(final double[] gradients, - final double[] lastGradient, final int index) { - double alpha = this.getLearningRate(); - - //TODO fill in weight update for TD(lambda) - - return 0.0; + public double computePatternError(FeedforwardNetwork neuralNetwork, + List trainingSet) { + int numDataPairs = trainingSet.size(); + int outputSize = neuralNetwork.getOutput().length; + int totalOutputSize = outputSize * numDataPairs; + + double[] actuals = new double[totalOutputSize]; + double[] ideals = new double[totalOutputSize]; + for (int dataPair = 0; dataPair < numDataPairs; dataPair++) { + NNDataPair nnDataPair = trainingSet.get(dataPair); + double[] actual = neuralNetwork.compute(nnDataPair.getInput() + .getValues()); + double[] ideal = nnDataPair.getIdeal().getValues(); + int offset = dataPair * outputSize; + + System.arraycopy(actual, 0, actuals, offset, outputSize); + System.arraycopy(ideal, 0, ideals, offset, outputSize); + } + + double MSSE = errorFunction.compute(ideals, actuals); + return MSSE; + } + + @Override + protected void backPropagate(FeedforwardNetwork neuralNetwork, NNData ideal) { + Neuron[] outputNeurons = neuralNetwork.getOutputNeurons(); + double[] idealValues = ideal.getValues(); + + for (int i = 0; i < idealValues.length; i++) { + double input = outputNeurons[i].getInput(); + double derivative = outputNeurons[i].getActivationFunction() + .derivative(input); + outputNeurons[i].setGradient(outputNeurons[i].getGradient() + + derivative + * (idealValues[i] - outputNeurons[i].getOutput())); + } + // walking down the list of Neurons in reverse order, propagate the + // error + Neuron[] neurons = neuralNetwork.getNeurons(); + + for (int n = neurons.length - 1; n >= 0; n--) { + + Neuron neuron = neurons[n]; + double error = neuron.getGradient(); + + Connection[] connectionsFromN = neuralNetwork + .getConnectionsFrom(neuron.getId()); + if (connectionsFromN.length > 0) { + + double derivative = neuron.getActivationFunction().derivative( + neuron.getInput()); + for (Connection connection : connectionsFromN) { + error += derivative + * connection.getWeight() + * neuralNetwork.getNeuron(connection.getDest()) + .getGradient(); + } + } + neuron.setGradient(error); + } + } + + private void updateWeights(FeedforwardNetwork neuralNetwork, double predictionError) { + for (Connection connection : neuralNetwork.getConnections()) { + Neuron srcNeuron = neuralNetwork.getNeuron(connection.getSrc()); + Neuron destNeuron = neuralNetwork.getNeuron(connection.getDest()); + + double delta = alpha * srcNeuron.getOutput() + * destNeuron.getGradient() * predictionError + connection.getTrace() * lambda; + + // TODO allow for momentum + // double lastDelta = connection.getLastDelta(); + connection.addDelta(delta); + } } + @Override + public void iterateSequences(FeedforwardNetwork neuralNetwork, + List> trainingSet) { + throw new UnsupportedOperationException(); + } + + @Override + public double computeSequenceError(FeedforwardNetwork neuralNetwork, + List> trainingSet) { + throw new UnsupportedOperationException(); + } + + @Override + protected void iteratePattern(FeedforwardNetwork neuralNetwork, + NNDataPair statePair, NNData nextReward) { + //System.out.println("Learningrate: " + alpha); + + zeroGradients(neuralNetwork); + + //System.out.println("Training with: " + statePair.getInput()); + + NNData ideal = nextReward; + NNData actual = neuralNetwork.compute(statePair); + + //System.out.println("Updating weights. Ideal Output: " + ideal); + //System.out.println("Actual Output: " + actual); + + // backpropagate the gradients w.r.t. output error + backPropagate(neuralNetwork, ideal); + + double predictionError = statePair.getIdeal().getValues()[0] // reward_t + + actual.getValues()[0] - nextReward.getValues()[0]; + + updateWeights(neuralNetwork, predictionError); + } } \ No newline at end of file diff --git a/src/net/woodyfolsom/msproj/ann/TrainingMethod.java b/src/net/woodyfolsom/msproj/ann/TrainingMethod.java new file mode 100644 index 0000000..ea6c051 --- /dev/null +++ b/src/net/woodyfolsom/msproj/ann/TrainingMethod.java @@ -0,0 +1,43 @@ +package net.woodyfolsom.msproj.ann; + +import java.util.List; + +import net.woodyfolsom.msproj.ann.math.ErrorFunction; +import net.woodyfolsom.msproj.ann.math.MSSE; + +public abstract class TrainingMethod { + protected final ErrorFunction errorFunction; + + public TrainingMethod() { + this.errorFunction = MSSE.function; + } + + protected abstract void iteratePattern(FeedforwardNetwork neuralNetwork, + NNDataPair statePair, NNData nextReward); + + protected abstract void iteratePatterns(FeedforwardNetwork neuralNetwork, + List trainingSet); + + protected abstract double computePatternError(FeedforwardNetwork neuralNetwork, + List trainingSet); + + protected abstract void iterateSequences(FeedforwardNetwork neuralNetwork, + List> trainingSet); + + protected abstract void backPropagate(FeedforwardNetwork neuralNetwork, NNData output); + + protected abstract double computeSequenceError(FeedforwardNetwork neuralNetwork, + List> trainingSet); + + protected void zeroGradients(FeedforwardNetwork neuralNetwork) { + for (Neuron neuron : neuralNetwork.getNeurons()) { + neuron.setGradient(0.0); + } + } + + protected void zeroTraces(FeedforwardNetwork neuralNetwork) { + for (Connection conn : neuralNetwork.getConnections()) { + conn.setTrace(0.0); + } + } +} diff --git a/src/net/woodyfolsom/msproj/ann/WinFilter.java b/src/net/woodyfolsom/msproj/ann/WinFilter.java deleted file mode 100644 index c042691..0000000 --- a/src/net/woodyfolsom/msproj/ann/WinFilter.java +++ /dev/null @@ -1,105 +0,0 @@ -package net.woodyfolsom.msproj.ann; - -import java.util.List; -import java.util.Set; - -import org.encog.engine.network.activation.ActivationSigmoid; -import org.encog.ml.data.MLDataPair; -import org.encog.ml.data.MLDataSet; -import org.encog.ml.data.basic.BasicMLDataSet; -import org.encog.ml.train.MLTrain; -import org.encog.neural.networks.BasicNetwork; -import org.encog.neural.networks.layers.BasicLayer; -import org.encog.neural.networks.training.propagation.back.Backpropagation; - -public class WinFilter extends AbstractNeuralNetFilter implements - NeuralNetFilter { - - public WinFilter() { - // create a neural network, without using a factory - BasicNetwork network = new BasicNetwork(); - network.addLayer(new BasicLayer(null, false, 2)); - network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 4)); - network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 2)); - network.getStructure().finalizeStructure(); - network.reset(); - - this.neuralNetwork = network; - } - - @Override - public void learn(MLDataSet trainingData) { - throw new UnsupportedOperationException( - "This filter learns a Set>, not an MLDataSet"); - } - - /** - * Method is necessary because with temporal difference learning, some of - * the MLDataPairs are related by being a sequence of moves within a - * particular game. - */ - @Override - public void learn(Set> trainingSet) { - MLDataSet mlDataset = new BasicMLDataSet(); - - for (List gameRecord : trainingSet) { - for (int t = 0; t < gameRecord.size() - 1; t++) { - mlDataset.add(gameRecord.get(t).getInput(), this.neuralNetwork.compute(gameRecord.get(t) - .getInput())); - } - mlDataset.add(gameRecord.get(gameRecord.size() - 1)); - } - - // train the neural network - final MLTrain train = new TemporalDifference(neuralNetwork, mlDataset, 0.7, 0.8, 0.25); - //final MLTrain train = new Backpropagation(neuralNetwork, mlDataset, 0.7, 0.8); - actualTrainingEpochs = 0; - - do { - if (actualTrainingEpochs > 0) { - int gameStateIndex = 0; - for (List gameRecord : trainingSet) { - for (int t = 0; t < gameRecord.size() - 1; t++) { - MLDataPair oldDataPair = mlDataset.get(gameStateIndex); - this.neuralNetwork.compute(oldDataPair.getInput()); - gameStateIndex++; - } - gameStateIndex++; - } - } - train.iteration(); - System.out.println("Epoch #" + actualTrainingEpochs + " Error:" - + train.getError()); - actualTrainingEpochs++; - } while (train.getError() > 0.01 - && actualTrainingEpochs <= maxTrainingEpochs); - } - - @Override - public void reset() { - neuralNetwork.reset(); - } - - @Override - public void reset(int seed) { - neuralNetwork.reset(seed); - } - - @Override - public BasicNetwork getNeuralNetwork() { - // TODO Auto-generated method stub - return null; - } - - @Override - public int getInputSize() { - // TODO Auto-generated method stub - return 0; - } - - @Override - public int getOutputSize() { - // TODO Auto-generated method stub - return 0; - } -} \ No newline at end of file diff --git a/src/net/woodyfolsom/msproj/ann/XORFilter.java b/src/net/woodyfolsom/msproj/ann/XORFilter.java index 1021dc5..19e15d4 100644 --- a/src/net/woodyfolsom/msproj/ann/XORFilter.java +++ b/src/net/woodyfolsom/msproj/ann/XORFilter.java @@ -1,18 +1,5 @@ package net.woodyfolsom.msproj.ann; -import java.util.List; -import java.util.Set; - -import org.encog.engine.network.activation.ActivationSigmoid; -import org.encog.ml.data.MLData; -import org.encog.ml.data.MLDataPair; -import org.encog.ml.data.MLDataSet; -import org.encog.ml.data.basic.BasicMLData; -import org.encog.ml.train.MLTrain; -import org.encog.neural.networks.BasicNetwork; -import org.encog.neural.networks.layers.BasicLayer; -import org.encog.neural.networks.training.propagation.back.Backpropagation; - /** * Based on sample code from http://neuroph.sourceforge.net * @@ -22,54 +9,30 @@ import org.encog.neural.networks.training.propagation.back.Backpropagation; public class XORFilter extends AbstractNeuralNetFilter implements NeuralNetFilter { + private static final int INPUT_SIZE = 2; + private static final int OUTPUT_SIZE = 1; + public XORFilter() { - // create a neural network, without using a factory - BasicNetwork network = new BasicNetwork(); - network.addLayer(new BasicLayer(null, false, 2)); - network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 3)); - network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 1)); - network.getStructure().finalizeStructure(); - network.reset(); - - this.neuralNetwork = network; + this(0.8,0.7); + } + + public XORFilter(double learningRate, double momentum) { + super( new MultiLayerPerceptron(true, INPUT_SIZE, 2, OUTPUT_SIZE), + new BackPropagation(learningRate, momentum), 1000, 0.01); + super.getNeuralNetwork().setName("XORFilter"); } public double compute(double x, double y) { - return compute(new BasicMLData(new double[]{x,y})).getData(0); + return getNeuralNetwork().compute(new double[]{x,y})[0]; } @Override public int getInputSize() { - return 2; + return INPUT_SIZE; } @Override public int getOutputSize() { - // TODO Auto-generated method stub - return 1; - } - - @Override - public void learn(MLDataSet trainingSet) { - - // train the neural network - final MLTrain train = new Backpropagation(neuralNetwork, trainingSet, - 0.7, 0.8); - - actualTrainingEpochs = 0; - - do { - train.iteration(); - System.out.println("Epoch #" + actualTrainingEpochs + " Error:" - + train.getError()); - actualTrainingEpochs++; - } while (train.getError() > 0.01 - && actualTrainingEpochs <= maxTrainingEpochs); - } - - @Override - public void learn(Set> trainingSet) { - throw new UnsupportedOperationException( - "This Filter learns an MLDataSet, not a Set>."); + return OUTPUT_SIZE; } } \ No newline at end of file diff --git a/src/net/woodyfolsom/msproj/ann2/math/ActivationFunction.java b/src/net/woodyfolsom/msproj/ann/math/ActivationFunction.java similarity index 91% rename from src/net/woodyfolsom/msproj/ann2/math/ActivationFunction.java rename to src/net/woodyfolsom/msproj/ann/math/ActivationFunction.java index c82a307..c941739 100644 --- a/src/net/woodyfolsom/msproj/ann2/math/ActivationFunction.java +++ b/src/net/woodyfolsom/msproj/ann/math/ActivationFunction.java @@ -1,4 +1,4 @@ -package net.woodyfolsom.msproj.ann2.math; +package net.woodyfolsom.msproj.ann.math; import javax.xml.bind.annotation.XmlAttribute; diff --git a/src/net/woodyfolsom/msproj/ann2/math/ErrorFunction.java b/src/net/woodyfolsom/msproj/ann/math/ErrorFunction.java similarity index 64% rename from src/net/woodyfolsom/msproj/ann2/math/ErrorFunction.java rename to src/net/woodyfolsom/msproj/ann/math/ErrorFunction.java index 789da3b..956533b 100644 --- a/src/net/woodyfolsom/msproj/ann2/math/ErrorFunction.java +++ b/src/net/woodyfolsom/msproj/ann/math/ErrorFunction.java @@ -1,4 +1,4 @@ -package net.woodyfolsom.msproj.ann2.math; +package net.woodyfolsom.msproj.ann.math; public interface ErrorFunction { double compute(double[] ideal, double[] actual); diff --git a/src/net/woodyfolsom/msproj/ann2/math/Linear.java b/src/net/woodyfolsom/msproj/ann/math/Linear.java similarity index 81% rename from src/net/woodyfolsom/msproj/ann2/math/Linear.java rename to src/net/woodyfolsom/msproj/ann/math/Linear.java index 3e9ab4a..c5de8e8 100644 --- a/src/net/woodyfolsom/msproj/ann2/math/Linear.java +++ b/src/net/woodyfolsom/msproj/ann/math/Linear.java @@ -1,4 +1,4 @@ -package net.woodyfolsom.msproj.ann2.math; +package net.woodyfolsom.msproj.ann.math; public class Linear extends ActivationFunction{ public static final Linear function = new Linear(); diff --git a/src/net/woodyfolsom/msproj/ann2/math/MSSE.java b/src/net/woodyfolsom/msproj/ann/math/MSSE.java similarity index 88% rename from src/net/woodyfolsom/msproj/ann2/math/MSSE.java rename to src/net/woodyfolsom/msproj/ann/math/MSSE.java index b4f85d1..960a5a9 100644 --- a/src/net/woodyfolsom/msproj/ann2/math/MSSE.java +++ b/src/net/woodyfolsom/msproj/ann/math/MSSE.java @@ -1,4 +1,4 @@ -package net.woodyfolsom.msproj.ann2.math; +package net.woodyfolsom.msproj.ann.math; public class MSSE implements ErrorFunction{ public static final ErrorFunction function = new MSSE(); diff --git a/src/net/woodyfolsom/msproj/ann2/math/Sigmoid.java b/src/net/woodyfolsom/msproj/ann/math/Sigmoid.java similarity index 55% rename from src/net/woodyfolsom/msproj/ann2/math/Sigmoid.java rename to src/net/woodyfolsom/msproj/ann/math/Sigmoid.java index 530bd0a..de1e168 100644 --- a/src/net/woodyfolsom/msproj/ann2/math/Sigmoid.java +++ b/src/net/woodyfolsom/msproj/ann/math/Sigmoid.java @@ -1,4 +1,4 @@ -package net.woodyfolsom.msproj.ann2.math; +package net.woodyfolsom.msproj.ann.math; public class Sigmoid extends ActivationFunction{ public static final Sigmoid function = new Sigmoid(); @@ -12,9 +12,9 @@ public class Sigmoid extends ActivationFunction{ } public double derivative(double arg) { - //lol wth? - //double eX = Math.exp(arg); - //return eX / (Math.pow((1+eX), 2)); - return arg - Math.pow(arg,2); + //lol wth? oh, the next derivative formula is a function of s(x), not x. + double eX = Math.exp(arg); + return eX / (Math.pow((1+eX), 2)); + //return arg - Math.pow(arg,2); } } \ No newline at end of file diff --git a/src/net/woodyfolsom/msproj/ann2/math/Tanh.java b/src/net/woodyfolsom/msproj/ann/math/Tanh.java similarity index 84% rename from src/net/woodyfolsom/msproj/ann2/math/Tanh.java rename to src/net/woodyfolsom/msproj/ann/math/Tanh.java index 16108f4..2ab3ec4 100644 --- a/src/net/woodyfolsom/msproj/ann2/math/Tanh.java +++ b/src/net/woodyfolsom/msproj/ann/math/Tanh.java @@ -1,4 +1,4 @@ -package net.woodyfolsom.msproj.ann2.math; +package net.woodyfolsom.msproj.ann.math; public class Tanh extends ActivationFunction{ public static final Tanh function = new Tanh(); diff --git a/src/net/woodyfolsom/msproj/ann2/AbstractNeuralNetFilter.java b/src/net/woodyfolsom/msproj/ann2/AbstractNeuralNetFilter.java deleted file mode 100644 index d24bd06..0000000 --- a/src/net/woodyfolsom/msproj/ann2/AbstractNeuralNetFilter.java +++ /dev/null @@ -1,84 +0,0 @@ -package net.woodyfolsom.msproj.ann2; - -import java.io.InputStream; -import java.io.OutputStream; -import java.util.List; - -public abstract class AbstractNeuralNetFilter implements NeuralNetFilter { - private final FeedforwardNetwork neuralNetwork; - private final TrainingMethod trainingMethod; - - private double maxError; - private int actualTrainingEpochs = 0; - private int maxTrainingEpochs; - - AbstractNeuralNetFilter(FeedforwardNetwork neuralNetwork, TrainingMethod trainingMethod, int maxTrainingEpochs, double maxError) { - this.neuralNetwork = neuralNetwork; - this.trainingMethod = trainingMethod; - this.maxError = maxError; - this.maxTrainingEpochs = maxTrainingEpochs; - } - - @Override - public NNData compute(NNDataPair input) { - return this.neuralNetwork.compute(input); - } - - public int getActualTrainingEpochs() { - return actualTrainingEpochs; - } - - @Override - public int getInputSize() { - return 2; - } - - public int getMaxTrainingEpochs() { - return maxTrainingEpochs; - } - - protected FeedforwardNetwork getNeuralNetwork() { - return neuralNetwork; - } - - @Override - public void learn(List trainingSet) { - actualTrainingEpochs = 0; - double error; - neuralNetwork.initWeights(); - - error = trainingMethod.computeError(neuralNetwork,trainingSet); - - if (error <= maxError) { - System.out.println("Initial error: " + error); - return; - } - - do { - trainingMethod.iterate(neuralNetwork,trainingSet); - error = trainingMethod.computeError(neuralNetwork,trainingSet); - System.out.println("Epoch #" + actualTrainingEpochs + " Error:" - + error); - actualTrainingEpochs++; - System.out.println("MSSE after epoch " + actualTrainingEpochs + ": " + error); - } while (error > maxError && actualTrainingEpochs < maxTrainingEpochs); - } - - @Override - public boolean load(InputStream input) { - return neuralNetwork.load(input); - } - - @Override - public boolean save(OutputStream output) { - return neuralNetwork.save(output); - } - - public void setMaxError(double maxError) { - this.maxError = maxError; - } - - public void setMaxTrainingEpochs(int max) { - this.maxTrainingEpochs = max; - } -} \ No newline at end of file diff --git a/src/net/woodyfolsom/msproj/ann2/NeuralNetFilter.java b/src/net/woodyfolsom/msproj/ann2/NeuralNetFilter.java deleted file mode 100644 index 71a2dde..0000000 --- a/src/net/woodyfolsom/msproj/ann2/NeuralNetFilter.java +++ /dev/null @@ -1,25 +0,0 @@ -package net.woodyfolsom.msproj.ann2; - -import java.io.InputStream; -import java.io.OutputStream; -import java.util.List; - -public interface NeuralNetFilter { - int getActualTrainingEpochs(); - - int getInputSize(); - - int getMaxTrainingEpochs(); - - int getOutputSize(); - - boolean load(InputStream input); - - boolean save(OutputStream output); - - void setMaxTrainingEpochs(int max); - - NNData compute(NNDataPair input); - - void learn(List trainingSet); -} \ No newline at end of file diff --git a/src/net/woodyfolsom/msproj/ann2/ObjectiveFunction.java b/src/net/woodyfolsom/msproj/ann2/ObjectiveFunction.java deleted file mode 100644 index 4eac691..0000000 --- a/src/net/woodyfolsom/msproj/ann2/ObjectiveFunction.java +++ /dev/null @@ -1,5 +0,0 @@ -package net.woodyfolsom.msproj.ann2; - -public class ObjectiveFunction { - -} diff --git a/src/net/woodyfolsom/msproj/ann2/TemporalDifference.java b/src/net/woodyfolsom/msproj/ann2/TemporalDifference.java deleted file mode 100644 index b588029..0000000 --- a/src/net/woodyfolsom/msproj/ann2/TemporalDifference.java +++ /dev/null @@ -1,19 +0,0 @@ -package net.woodyfolsom.msproj.ann2; - -import java.util.List; - -public class TemporalDifference implements TrainingMethod { - - @Override - public void iterate(FeedforwardNetwork neuralNetwork, - List trainingSet) { - throw new UnsupportedOperationException("Not implemented"); - } - - @Override - public double computeError(FeedforwardNetwork neuralNetwork, - List trainingSet) { - throw new UnsupportedOperationException("Not implemented"); - } - -} diff --git a/src/net/woodyfolsom/msproj/ann2/TrainingMethod.java b/src/net/woodyfolsom/msproj/ann2/TrainingMethod.java deleted file mode 100644 index b109034..0000000 --- a/src/net/woodyfolsom/msproj/ann2/TrainingMethod.java +++ /dev/null @@ -1,10 +0,0 @@ -package net.woodyfolsom.msproj.ann2; - -import java.util.List; - -public interface TrainingMethod { - - void iterate(FeedforwardNetwork neuralNetwork, List trainingSet); - double computeError(FeedforwardNetwork neuralNetwork, List trainingSet); - -} diff --git a/src/net/woodyfolsom/msproj/ann2/XORFilter.java b/src/net/woodyfolsom/msproj/ann2/XORFilter.java deleted file mode 100644 index 7ffd6c4..0000000 --- a/src/net/woodyfolsom/msproj/ann2/XORFilter.java +++ /dev/null @@ -1,44 +0,0 @@ -package net.woodyfolsom.msproj.ann2; - -/** - * Based on sample code from http://neuroph.sourceforge.net - * - * @author Woody - * - */ -public class XORFilter extends AbstractNeuralNetFilter implements - NeuralNetFilter { - - private static final int INPUT_SIZE = 2; - private static final int OUTPUT_SIZE = 1; - - public XORFilter() { - this(0.8,0.7); - } - - public XORFilter(double learningRate, double momentum) { - super( new MultiLayerPerceptron(true, INPUT_SIZE, 2, OUTPUT_SIZE), - new BackPropagation(learningRate, momentum), 1000, 0.01); - super.getNeuralNetwork().setName("XORFilter"); - - //TODO remove - //getNeuralNetwork().setWeights(new double[] { - // 0.341232, 0.129952, -0.923123, //hidden neuron 1 from input0, input1, bias - // -0.115223, 0.570345, -0.328932, //hidden neuron 2 from input0, input1, bias - // -0.993423, 0.164732, 0.752621}); //output - } - - public double compute(double x, double y) { - return getNeuralNetwork().compute(new double[]{x,y})[0]; - } - - @Override - public int getInputSize() { - return INPUT_SIZE; - } - - @Override - public int getOutputSize() { - return OUTPUT_SIZE; - } -} \ No newline at end of file diff --git a/src/net/woodyfolsom/msproj/tictactoe/Action.java b/src/net/woodyfolsom/msproj/tictactoe/Action.java new file mode 100644 index 0000000..8ad5ae6 --- /dev/null +++ b/src/net/woodyfolsom/msproj/tictactoe/Action.java @@ -0,0 +1,54 @@ +package net.woodyfolsom.msproj.tictactoe; + +import net.woodyfolsom.msproj.tictactoe.Game.PLAYER; + +public class Action { + public static final Action NONE = new Action(PLAYER.NONE, -1, -1); + + private Game.PLAYER player; + private int row; + private int column; + + public static Action getInstance(PLAYER player, int row, int column) { + return new Action(player,row,column); + } + + private Action(PLAYER player, int row, int column) { + this.player = player; + this.row = row; + this.column = column; + } + + public Game.PLAYER getPlayer() { + return player; + } + + public int getColumn() { + return column; + } + + public int getRow() { + return row; + } + + public boolean isNone() { + return this == Action.NONE; + } + + public void setPlayer(Game.PLAYER player) { + this.player = player; + } + + public void setRow(int row) { + this.row = row; + } + + public void setColumn(int column) { + this.column = column; + } + + @Override + public String toString() { + return player + "(" + row + ", " + column + ")"; + } +} \ No newline at end of file diff --git a/src/net/woodyfolsom/msproj/tictactoe/Game.java b/src/net/woodyfolsom/msproj/tictactoe/Game.java new file mode 100644 index 0000000..2635d57 --- /dev/null +++ b/src/net/woodyfolsom/msproj/tictactoe/Game.java @@ -0,0 +1,5 @@ +package net.woodyfolsom.msproj.tictactoe; + +public class Game { + public enum PLAYER {X,O,NONE} +} diff --git a/src/net/woodyfolsom/msproj/tictactoe/GameRecord.java b/src/net/woodyfolsom/msproj/tictactoe/GameRecord.java new file mode 100644 index 0000000..23ffdc6 --- /dev/null +++ b/src/net/woodyfolsom/msproj/tictactoe/GameRecord.java @@ -0,0 +1,63 @@ +package net.woodyfolsom.msproj.tictactoe; + +import java.util.ArrayList; +import java.util.List; + +import net.woodyfolsom.msproj.tictactoe.Game.PLAYER; + +public class GameRecord { + public enum RESULT {X_WINS, O_WINS, TIE_GAME, IN_PROGRESS} + + private List actions = new ArrayList(); + private List states = new ArrayList(); + + private RESULT result = RESULT.IN_PROGRESS; + + public GameRecord() { + actions.add(Action.NONE); + states.add(new State()); + } + + public void addState(State state) { + states.add(state); + } + + public State apply(Action action) { + State nextState = getState().apply(action); + if (nextState.isValid()) { + states.add(nextState); + actions.add(action); + } + + if (nextState.isTerminal()) { + if (nextState.isWinner(PLAYER.X)) { + result = RESULT.X_WINS; + } else if (nextState.isWinner(PLAYER.O)) { + result = RESULT.O_WINS; + } else { + result = RESULT.TIE_GAME; + } + } + return nextState; + } + + public int getNumStates() { + return states.size(); + } + + public RESULT getResult() { + return result; + } + + public void setResult(RESULT result) { + this.result = result; + } + + public State getState() { + return states.get(states.size()-1); + } + + public State getState(int index) { + return states.get(index); + } +} diff --git a/src/net/woodyfolsom/msproj/tictactoe/MoveGenerator.java b/src/net/woodyfolsom/msproj/tictactoe/MoveGenerator.java new file mode 100644 index 0000000..888963e --- /dev/null +++ b/src/net/woodyfolsom/msproj/tictactoe/MoveGenerator.java @@ -0,0 +1,20 @@ +package net.woodyfolsom.msproj.tictactoe; + +import java.util.ArrayList; +import java.util.List; + +import net.woodyfolsom.msproj.tictactoe.Game.PLAYER; + +public class MoveGenerator { + public List getValidActions(State state) { + PLAYER playerToMove = state.getPlayerToMove(); + List validActions = new ArrayList(); + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 3; j++) { + if (state.isEmpty(i,j)) + validActions.add(Action.getInstance(playerToMove, i, j)); + } + } + return validActions; + } +} diff --git a/src/net/woodyfolsom/msproj/tictactoe/NNDataSetFactory.java b/src/net/woodyfolsom/msproj/tictactoe/NNDataSetFactory.java new file mode 100644 index 0000000..efedfba --- /dev/null +++ b/src/net/woodyfolsom/msproj/tictactoe/NNDataSetFactory.java @@ -0,0 +1,81 @@ +package net.woodyfolsom.msproj.tictactoe; + +import java.util.ArrayList; +import java.util.List; + +import net.woodyfolsom.msproj.ann.NNData; +import net.woodyfolsom.msproj.ann.NNDataPair; +import net.woodyfolsom.msproj.tictactoe.Game.PLAYER; + +public class NNDataSetFactory { + public static final String[] TTT_INPUT_FIELDS = {"00","01","02","10","11","12","20","21","22"}; + public static final String[] TTT_OUTPUT_FIELDS = {"value"}; + + public static List> createDataSet(List tttGames) { + + List> nnDataSet = new ArrayList>(); + + for (GameRecord tttGame : tttGames) { + List gameData = createDataPairList(tttGame); + + + nnDataSet.add(gameData); + } + + return nnDataSet; + } + + public static List createDataPairList(GameRecord gameRecord) { + List gameData = new ArrayList(); + + for (int i = 0; i < gameRecord.getNumStates(); i++) { + gameData.add(createDataPair(gameRecord.getState(i))); + } + + return gameData; + } + + public static NNDataPair createDataPair(State tttState) { + double value; + if (tttState.isTerminal()) { + if (tttState.isWinner(PLAYER.X)) { + value = 1.0; // win for black + } else if (tttState.isWinner(PLAYER.O)) { + value = 0.0; // loss for black + //value = -1.0; + } else { + value = 0.5; + //value = 0.0; //tie + } + } else { + value = 0.0; + } + + double[] inputValues = new double[9]; + char[] boardCopy = tttState.getBoard(); + inputValues[0] = getTicTacToeInput(boardCopy, 0, 0); + inputValues[1] = getTicTacToeInput(boardCopy, 0, 1); + inputValues[2] = getTicTacToeInput(boardCopy, 0, 2); + inputValues[3] = getTicTacToeInput(boardCopy, 1, 0); + inputValues[4] = getTicTacToeInput(boardCopy, 1, 1); + inputValues[5] = getTicTacToeInput(boardCopy, 1, 2); + inputValues[6] = getTicTacToeInput(boardCopy, 2, 0); + inputValues[7] = getTicTacToeInput(boardCopy, 2, 1); + inputValues[8] = getTicTacToeInput(boardCopy, 2, 2); + + return new NNDataPair(new NNData(TTT_INPUT_FIELDS,inputValues),new NNData(TTT_OUTPUT_FIELDS,new double[]{value})); + } + + private static double getTicTacToeInput(char[] board, int row, int column) { + switch (board[row*3+column]) { + case 'X' : + return 1.0; + case 'O' : + return -1.0; + case '.' : + return 0.0; + default: + throw new RuntimeException("Invalid board symbol at " + row +", " + column); + } + } +} diff --git a/src/net/woodyfolsom/msproj/tictactoe/NeuralNetPolicy.java b/src/net/woodyfolsom/msproj/tictactoe/NeuralNetPolicy.java new file mode 100644 index 0000000..4bb0fb8 --- /dev/null +++ b/src/net/woodyfolsom/msproj/tictactoe/NeuralNetPolicy.java @@ -0,0 +1,67 @@ +package net.woodyfolsom.msproj.tictactoe; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import net.woodyfolsom.msproj.ann.FeedforwardNetwork; +import net.woodyfolsom.msproj.ann.NNDataPair; +import net.woodyfolsom.msproj.tictactoe.Game.PLAYER; + +public class NeuralNetPolicy extends Policy { + private FeedforwardNetwork neuralNet; + private MoveGenerator moveGenerator = new MoveGenerator(); + + public NeuralNetPolicy(FeedforwardNetwork neuralNet) { + super("NeuralNet-" + neuralNet.getName()); + this.neuralNet = neuralNet; + } + + @Override + public Action getAction(State state) { + List validMoves = moveGenerator.getValidActions(state); + Map scores = new HashMap(); + + for (Action action : validMoves) { + State nextState = state.apply(action); + NNDataPair dataPair = NNDataSetFactory.createDataPair(state); + //estimated reward for X + scores.put(action, neuralNet.compute(dataPair).getValues()[0]); + } + + PLAYER playerToMove = state.getPlayerToMove(); + + if (playerToMove == PLAYER.X) { + return returnMaxAction(scores); + } else if (playerToMove == PLAYER.O) { + return returnMinAction(scores); + } else { + throw new IllegalArgumentException("Invalid playerToMove: " + playerToMove); + } + //return validMoves.get((int)(Math.random() * validMoves.size())); + } + + private Action returnMaxAction(Map scores) { + Action bestAction = null; + Double bestScore = Double.NEGATIVE_INFINITY; + for (Map.Entry entry : scores.entrySet()) { + if (entry.getValue() > bestScore) { + bestScore = entry.getValue(); + bestAction = entry.getKey(); + } + } + return bestAction; + } + + private Action returnMinAction(Map scores) { + Action bestAction = null; + Double bestScore = Double.POSITIVE_INFINITY; + for (Map.Entry entry : scores.entrySet()) { + if (entry.getValue() < bestScore) { + bestScore = entry.getValue(); + bestAction = entry.getKey(); + } + } + return bestAction; + } +} diff --git a/src/net/woodyfolsom/msproj/tictactoe/Policy.java b/src/net/woodyfolsom/msproj/tictactoe/Policy.java new file mode 100644 index 0000000..c81ecb4 --- /dev/null +++ b/src/net/woodyfolsom/msproj/tictactoe/Policy.java @@ -0,0 +1,15 @@ +package net.woodyfolsom.msproj.tictactoe; + +public abstract class Policy { + private String name; + + protected Policy(String name) { + this.name = name; + } + + public abstract Action getAction(State state); + + public String getName() { + return name; + } +} diff --git a/src/net/woodyfolsom/msproj/tictactoe/RandomPolicy.java b/src/net/woodyfolsom/msproj/tictactoe/RandomPolicy.java new file mode 100644 index 0000000..3b7bb22 --- /dev/null +++ b/src/net/woodyfolsom/msproj/tictactoe/RandomPolicy.java @@ -0,0 +1,18 @@ +package net.woodyfolsom.msproj.tictactoe; + +import java.util.List; + +public class RandomPolicy extends Policy { + private MoveGenerator moveGenerator = new MoveGenerator(); + + public RandomPolicy() { + super("Random"); + } + + @Override + public Action getAction(State state) { + List validMoves = moveGenerator.getValidActions(state); + return validMoves.get((int)(Math.random() * validMoves.size())); + } + +} diff --git a/src/net/woodyfolsom/msproj/tictactoe/Referee.java b/src/net/woodyfolsom/msproj/tictactoe/Referee.java new file mode 100644 index 0000000..ed1cf4e --- /dev/null +++ b/src/net/woodyfolsom/msproj/tictactoe/Referee.java @@ -0,0 +1,43 @@ +package net.woodyfolsom.msproj.tictactoe; + +import java.util.ArrayList; +import java.util.List; + +public class Referee { + + public static void main(String[] args) { + new Referee().play(50); + } + + public List play(int nGames) { + Policy policy = new RandomPolicy(); + + List tournament = new ArrayList(); + + for (int i = 0; i < nGames; i++) { + GameRecord gameRecord = new GameRecord(); + + System.out.println("Playing game #" +(i+1)); + + State state; + do { + Action action = policy.getAction(gameRecord.getState()); + System.out.println("Action " + action + " selected by policy " + policy.getName()); + state = gameRecord.apply(action); + System.out.println("Next board state:"); + System.out.println(gameRecord.getState()); + } while (!state.isTerminal()); + System.out.println("Game #" + (i+1) + " is finished. Result: " + gameRecord.getResult()); + tournament.add(gameRecord); + } + + System.out.println("Played " + tournament.size() + " random games."); + System.out.println("Results:"); + for (int i = 0; i < tournament.size(); i++) { + GameRecord gameRecord = tournament.get(i); + System.out.println((i+1) + ". " + gameRecord.getResult()); + } + + return tournament; + } +} diff --git a/src/net/woodyfolsom/msproj/tictactoe/State.java b/src/net/woodyfolsom/msproj/tictactoe/State.java new file mode 100644 index 0000000..7d3ce6b --- /dev/null +++ b/src/net/woodyfolsom/msproj/tictactoe/State.java @@ -0,0 +1,116 @@ +package net.woodyfolsom.msproj.tictactoe; + +import java.util.Arrays; + +import net.woodyfolsom.msproj.tictactoe.Game.PLAYER; + +public class State { + public static final State INVALID = new State(); + public static char EMPTY_SQUARE = '.'; + + private char[] board; + private PLAYER playerToMove; + + public State() { + playerToMove = Game.PLAYER.X; + board = new char[9]; + Arrays.fill(board,'.'); + } + + private State(State that) { + this.board = Arrays.copyOf(that.board, that.board.length); + this.playerToMove = that.playerToMove; + } + + public State apply(Action action) { + if (action.getPlayer() != playerToMove) { + System.out.println("It is not " + action.getPlayer() +"'s turn."); + return State.INVALID; + } + State nextState = new State(this); + + int row = action.getRow(); + int column = action.getColumn(); + int dest = row * 3 + column; + + if (board[dest] != EMPTY_SQUARE) { + System.out.println("Invalid move " + action + ", coordinate not empty."); + return State.INVALID; + } + switch (playerToMove) { + case X : nextState.board[dest] = 'X'; + break; + case O : nextState.board[dest] = 'O'; + break; + default: + throw new RuntimeException("Invalid playerToMove"); + } + + if (playerToMove == PLAYER.X) { + nextState.playerToMove = PLAYER.O; + } else { + nextState.playerToMove = PLAYER.X; + } + return nextState; + } + + public char[] getBoard() { + return Arrays.copyOf(board, board.length); + } + + public PLAYER getPlayerToMove() { + return playerToMove; + } + + public boolean isEmpty(int row, int column) { + return board[row*3+column] == EMPTY_SQUARE; + } + + public boolean isFull(char mark1, char mark2, char mark3) { + return mark1 != '.' && mark2 != '.' && mark3 != '.'; + } + + public boolean isWinner(PLAYER player) { + return isWin(player,board[0],board[1],board[2]) || + isWin(player,board[3],board[4],board[5]) || + isWin(player,board[6],board[7],board[8]) || + isWin(player,board[0],board[3],board[6]) || + isWin(player,board[1],board[4],board[7]) || + isWin(player,board[2],board[5],board[8]) || + isWin(player,board[0],board[4],board[8]) || + isWin(player,board[2],board[4],board[6]); + } + + public boolean isWin(PLAYER player, char mark1, char mark2, char mark3) { + if (isFull(mark1,mark2,mark3)) { + switch (player) { + case X : return mark1 == 'X' && mark2 == 'X' && mark3 == 'X'; + case O : return mark1 == 'O' && mark2 == 'O' && mark3 == 'O'; + default : + return false; + } + } else { + return false; + } + } + + public boolean isTerminal() { + return isWinner(PLAYER.X) || isWinner(PLAYER.O) || + (isFull(board[0],board[1], board[2]) && + isFull(board[3],board[4], board[5]) && + isFull(board[6],board[7], board[8])); + } + + public boolean isValid() { + return this != INVALID; + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder("TicTacToe state ("+playerToMove + " to move):\n"); + sb.append(""+board[0] + board[1] + board[2] + "\n"); + sb.append(""+board[3] + board[4] + board[5] + "\n"); + sb.append(""+board[6] + board[7] + board[8] + "\n"); + return sb.toString(); + } +} diff --git a/test/net/woodyfolsom/msproj/ann2/MultiLayerPerceptronTest.java b/test/net/woodyfolsom/msproj/ann/MultiLayerPerceptronTest.java similarity index 89% rename from test/net/woodyfolsom/msproj/ann2/MultiLayerPerceptronTest.java rename to test/net/woodyfolsom/msproj/ann/MultiLayerPerceptronTest.java index 7dc8633..f6d840b 100644 --- a/test/net/woodyfolsom/msproj/ann2/MultiLayerPerceptronTest.java +++ b/test/net/woodyfolsom/msproj/ann/MultiLayerPerceptronTest.java @@ -1,4 +1,4 @@ -package net.woodyfolsom.msproj.ann2; +package net.woodyfolsom.msproj.ann; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -10,6 +10,12 @@ import java.io.IOException; import javax.xml.bind.JAXBException; +import net.woodyfolsom.msproj.ann.Connection; +import net.woodyfolsom.msproj.ann.FeedforwardNetwork; +import net.woodyfolsom.msproj.ann.MultiLayerPerceptron; +import net.woodyfolsom.msproj.ann.NNData; +import net.woodyfolsom.msproj.ann.NNDataPair; + import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; diff --git a/test/net/woodyfolsom/msproj/ann/TTTFilterTest.java b/test/net/woodyfolsom/msproj/ann/TTTFilterTest.java new file mode 100644 index 0000000..60af997 --- /dev/null +++ b/test/net/woodyfolsom/msproj/ann/TTTFilterTest.java @@ -0,0 +1,100 @@ +package net.woodyfolsom.msproj.ann; + +import java.io.File; +import java.io.IOException; +import java.util.List; + +import net.woodyfolsom.msproj.ann.NNData; +import net.woodyfolsom.msproj.ann.NNDataPair; +import net.woodyfolsom.msproj.ann.NeuralNetFilter; +import net.woodyfolsom.msproj.ann.TTTFilter; +import net.woodyfolsom.msproj.tictactoe.GameRecord; +import net.woodyfolsom.msproj.tictactoe.NNDataSetFactory; +import net.woodyfolsom.msproj.tictactoe.Referee; + +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +public class TTTFilterTest { + private static final String FILENAME = "tttPerceptron.net"; + + @AfterClass + public static void deleteNewNet() { + File file = new File(FILENAME); + if (file.exists()) { + file.delete(); + } + } + + @BeforeClass + public static void deleteSavedNet() { + File file = new File(FILENAME); + if (file.exists()) { + file.delete(); + } + } + + @Test + public void testLearn() throws IOException { + double alpha = 0.5; + double lambda = 0.0; + int maxEpochs = 1000; + + NeuralNetFilter nnLearner = new TTTFilter(alpha, lambda, maxEpochs); + + // Create trainingSet from a tournament of random games. + // Future iterations will use Epsilon-greedy play from a policy based on + // this network to generate additional datasets. + List tournament = new Referee().play(1); + List> trainingSet = NNDataSetFactory + .createDataSet(tournament); + + System.out.println("Generated " + trainingSet.size() + + " datasets from random self-play."); + nnLearner.learnSequences(trainingSet); + System.out.println("Learned network after " + + nnLearner.getActualTrainingEpochs() + " training epochs."); + + double[][] validationSet = new double[7][]; + + // empty board + validationSet[0] = new double[] { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0 }; + // center + validationSet[1] = new double[] { 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, + 0.0, 0.0 }; + // top edge + validationSet[2] = new double[] { 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0 }; + // left edge + validationSet[3] = new double[] { 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0 }; + // corner + validationSet[4] = new double[] { 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0 }; + // win + validationSet[5] = new double[] { 1.0, 1.0, 1.0, -1.0, -1.0, 0.0, 0.0, + -1.0, 0.0 }; + // loss + validationSet[6] = new double[] { -1.0, 1.0, 0.0, 1.0, -1.0, 1.0, 0.0, + 0.0, -1.0 }; + + String[] inputNames = new String[] { "00", "01", "02", "10", "11", + "12", "20", "21", "22" }; + String[] outputNames = new String[] { "values" }; + + System.out.println("Output from eval set (learned network):"); + testNetwork(nnLearner, validationSet, inputNames, outputNames); + } + + private void testNetwork(NeuralNetFilter nnLearner, + double[][] validationSet, String[] inputNames, String[] outputNames) { + for (int valIndex = 0; valIndex < validationSet.length; valIndex++) { + NNDataPair dp = new NNDataPair(new NNData(inputNames, + validationSet[valIndex]), new NNData(outputNames, + validationSet[valIndex])); + System.out.println(dp + " => " + nnLearner.compute(dp)); + } + } +} \ No newline at end of file diff --git a/test/net/woodyfolsom/msproj/ann/WinFilterTest.java b/test/net/woodyfolsom/msproj/ann/WinFilterTest.java deleted file mode 100644 index bcb0e18..0000000 --- a/test/net/woodyfolsom/msproj/ann/WinFilterTest.java +++ /dev/null @@ -1,64 +0,0 @@ -package net.woodyfolsom.msproj.ann; - -import java.io.File; -import java.io.FileFilter; -import java.io.FileInputStream; -import java.io.IOException; -import java.util.ArrayList; -import java.util.HashSet; -import java.util.List; -import java.util.Set; - -import net.woodyfolsom.msproj.GameRecord; -import net.woodyfolsom.msproj.Referee; - -import org.antlr.runtime.RecognitionException; -import org.encog.ml.data.MLData; -import org.encog.ml.data.MLDataPair; -import org.junit.Test; - -public class WinFilterTest { - - @Test - public void testLearnSaveLoad() throws IOException, RecognitionException { - File[] sgfFiles = new File("data/games/random_vs_random") - .listFiles(new FileFilter() { - @Override - public boolean accept(File pathname) { - return pathname.getName().endsWith(".sgf"); - } - }); - - Set> trainingData = new HashSet>(); - - for (File file : sgfFiles) { - FileInputStream fis = new FileInputStream(file); - GameRecord gameRecord = Referee.replay(fis); - - List gameData = new ArrayList(); - for (int i = 0; i <= gameRecord.getNumTurns(); i++) { - gameData.add(new GameStateMLDataPair(gameRecord.getGameState(i))); - } - - trainingData.add(gameData); - - fis.close(); - } - - WinFilter winFilter = new WinFilter(); - - winFilter.learn(trainingData); - - for (List trainingSequence : trainingData) { - for (int stateIndex = 0; stateIndex < trainingSequence.size(); stateIndex++) { - if (stateIndex > 0 && stateIndex < trainingSequence.size()-1) { - continue; - } - MLData input = trainingSequence.get(stateIndex).getInput(); - - System.out.println("Turn " + stateIndex + ": " + input + " => " - + winFilter.compute(input)); - } - } - } -} diff --git a/test/net/woodyfolsom/msproj/ann/XORFilterTest.java b/test/net/woodyfolsom/msproj/ann/XORFilterTest.java index 8a39c15..0ac82d8 100644 --- a/test/net/woodyfolsom/msproj/ann/XORFilterTest.java +++ b/test/net/woodyfolsom/msproj/ann/XORFilterTest.java @@ -1,10 +1,19 @@ package net.woodyfolsom.msproj.ann; -import java.io.File; -import java.io.IOException; +import static org.junit.Assert.assertTrue; + +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import net.woodyfolsom.msproj.ann.NNData; +import net.woodyfolsom.msproj.ann.NNDataPair; +import net.woodyfolsom.msproj.ann.NeuralNetFilter; +import net.woodyfolsom.msproj.ann.XORFilter; -import org.encog.ml.data.MLDataSet; -import org.encog.ml.data.basic.BasicMLDataSet; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; @@ -29,14 +38,55 @@ public class XORFilterTest { } @Test - public void testLearnSaveLoad() throws IOException { - NeuralNetFilter nnLearner = new XORFilter(); - System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs."); + public void testLearn() throws IOException { + NeuralNetFilter nnLearner = new XORFilter(0.5,0.0); // create training set (logical XOR function) int size = 1; double[][] trainingInput = new double[4 * size][]; double[][] trainingOutput = new double[4 * size][]; + for (int i = 0; i < size; i++) { + trainingInput[i * 4 + 0] = new double[] { 0, 0 }; + trainingInput[i * 4 + 1] = new double[] { 0, 1 }; + trainingInput[i * 4 + 2] = new double[] { 1, 0 }; + trainingInput[i * 4 + 3] = new double[] { 1, 1 }; + trainingOutput[i * 4 + 0] = new double[] { 0 }; + trainingOutput[i * 4 + 1] = new double[] { 1 }; + trainingOutput[i * 4 + 2] = new double[] { 1 }; + trainingOutput[i * 4 + 3] = new double[] { 0 }; + } + + // create training data + List trainingSet = new ArrayList(); + String[] inputNames = new String[] {"x","y"}; + String[] outputNames = new String[] {"XOR"}; + for (int i = 0; i < 4*size; i++) { + trainingSet.add(new NNDataPair(new NNData(inputNames,trainingInput[i]),new NNData(outputNames,trainingOutput[i]))); + } + + nnLearner.setMaxTrainingEpochs(20000); + nnLearner.learnPatterns(trainingSet); + System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs."); + + double[][] validationSet = new double[4][2]; + + validationSet[0] = new double[] { 0, 0 }; + validationSet[1] = new double[] { 0, 1 }; + validationSet[2] = new double[] { 1, 0 }; + validationSet[3] = new double[] { 1, 1 }; + + System.out.println("Output from eval set (learned network):"); + testNetwork(nnLearner, validationSet, inputNames, outputNames); + } + + @Test + public void testLearnSaveLoad() throws IOException { + NeuralNetFilter nnLearner = new XORFilter(0.5,0.0); + + // create training set (logical XOR function) + int size = 2; + double[][] trainingInput = new double[4 * size][]; + double[][] trainingOutput = new double[4 * size][]; for (int i = 0; i < size; i++) { trainingInput[i * 4 + 0] = new double[] { 0, 0 }; trainingInput[i * 4 + 1] = new double[] { 0, 1 }; @@ -49,10 +99,17 @@ public class XORFilterTest { } // create training data - MLDataSet trainingSet = new BasicMLDataSet(trainingInput, trainingOutput); + List trainingSet = new ArrayList(); + String[] inputNames = new String[] {"x","y"}; + String[] outputNames = new String[] {"XOR"}; + for (int i = 0; i < 4*size; i++) { + trainingSet.add(new NNDataPair(new NNData(inputNames,trainingInput[i]),new NNData(outputNames,trainingOutput[i]))); + } + + nnLearner.setMaxTrainingEpochs(1); + nnLearner.learnPatterns(trainingSet); + System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs."); - nnLearner.learn(trainingSet); - double[][] validationSet = new double[4][2]; validationSet[0] = new double[] { 0, 0 }; @@ -61,18 +118,23 @@ public class XORFilterTest { validationSet[3] = new double[] { 1, 1 }; System.out.println("Output from eval set (learned network, pre-serialization):"); - testNetwork(nnLearner, validationSet); - - nnLearner.save(FILENAME); - nnLearner.load(FILENAME); + testNetwork(nnLearner, validationSet, inputNames, outputNames); + FileOutputStream fos = new FileOutputStream(FILENAME); + assertTrue(nnLearner.save(fos)); + fos.close(); + + FileInputStream fis = new FileInputStream(FILENAME); + assertTrue(nnLearner.load(fis)); + fis.close(); + System.out.println("Output from eval set (learned network, post-serialization):"); - testNetwork(nnLearner, validationSet); + testNetwork(nnLearner, validationSet, inputNames, outputNames); } - private void testNetwork(NeuralNetFilter nnLearner, double[][] validationSet) { + private void testNetwork(NeuralNetFilter nnLearner, double[][] validationSet, String[] inputNames, String[] outputNames) { for (int valIndex = 0; valIndex < validationSet.length; valIndex++) { - DoublePair dp = new DoublePair(validationSet[valIndex][0],validationSet[valIndex][1]); + NNDataPair dp = new NNDataPair(new NNData(inputNames,validationSet[valIndex]), new NNData(outputNames,validationSet[valIndex])); System.out.println(dp + " => " + nnLearner.compute(dp)); } } diff --git a/test/net/woodyfolsom/msproj/ann2/SigmoidTest.java b/test/net/woodyfolsom/msproj/ann/math/SigmoidTest.java similarity index 71% rename from test/net/woodyfolsom/msproj/ann2/SigmoidTest.java rename to test/net/woodyfolsom/msproj/ann/math/SigmoidTest.java index 6cbaffe..8a4722b 100644 --- a/test/net/woodyfolsom/msproj/ann2/SigmoidTest.java +++ b/test/net/woodyfolsom/msproj/ann/math/SigmoidTest.java @@ -1,11 +1,11 @@ -package net.woodyfolsom.msproj.ann2; +package net.woodyfolsom.msproj.ann.math; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; -import net.woodyfolsom.msproj.ann2.math.ActivationFunction; -import net.woodyfolsom.msproj.ann2.math.Sigmoid; -import net.woodyfolsom.msproj.ann2.math.Tanh; +import net.woodyfolsom.msproj.ann.math.ActivationFunction; +import net.woodyfolsom.msproj.ann.math.Sigmoid; +import net.woodyfolsom.msproj.ann.math.Tanh; import org.junit.Test; diff --git a/test/net/woodyfolsom/msproj/ann2/TanhTest.java b/test/net/woodyfolsom/msproj/ann/math/TanhTest.java similarity index 75% rename from test/net/woodyfolsom/msproj/ann2/TanhTest.java rename to test/net/woodyfolsom/msproj/ann/math/TanhTest.java index 8429fa1..942ea85 100644 --- a/test/net/woodyfolsom/msproj/ann2/TanhTest.java +++ b/test/net/woodyfolsom/msproj/ann/math/TanhTest.java @@ -1,10 +1,10 @@ -package net.woodyfolsom.msproj.ann2; +package net.woodyfolsom.msproj.ann.math; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; -import net.woodyfolsom.msproj.ann2.math.ActivationFunction; -import net.woodyfolsom.msproj.ann2.math.Tanh; +import net.woodyfolsom.msproj.ann.math.ActivationFunction; +import net.woodyfolsom.msproj.ann.math.Tanh; import org.junit.Test; diff --git a/test/net/woodyfolsom/msproj/ann2/XORFilterTest.java b/test/net/woodyfolsom/msproj/ann2/XORFilterTest.java deleted file mode 100644 index b9d8887..0000000 --- a/test/net/woodyfolsom/msproj/ann2/XORFilterTest.java +++ /dev/null @@ -1,136 +0,0 @@ -package net.woodyfolsom.msproj.ann2; - -import static org.junit.Assert.assertTrue; - -import java.io.File; -import java.io.FileInputStream; -import java.io.FileOutputStream; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - -import org.junit.AfterClass; -import org.junit.BeforeClass; -import org.junit.Test; - -public class XORFilterTest { - private static final String FILENAME = "xorPerceptron.net"; - - @AfterClass - public static void deleteNewNet() { - File file = new File(FILENAME); - if (file.exists()) { - file.delete(); - } - } - - @BeforeClass - public static void deleteSavedNet() { - File file = new File(FILENAME); - if (file.exists()) { - file.delete(); - } - } - - @Test - public void testLearn() throws IOException { - NeuralNetFilter nnLearner = new XORFilter(0.05,0.0); - - // create training set (logical XOR function) - int size = 1; - double[][] trainingInput = new double[4 * size][]; - double[][] trainingOutput = new double[4 * size][]; - for (int i = 0; i < size; i++) { - trainingInput[i * 4 + 0] = new double[] { 0, 0 }; - trainingInput[i * 4 + 1] = new double[] { 0, 1 }; - trainingInput[i * 4 + 2] = new double[] { 1, 0 }; - trainingInput[i * 4 + 3] = new double[] { 1, 1 }; - trainingOutput[i * 4 + 0] = new double[] { 0 }; - trainingOutput[i * 4 + 1] = new double[] { 1 }; - trainingOutput[i * 4 + 2] = new double[] { 1 }; - trainingOutput[i * 4 + 3] = new double[] { 0 }; - } - - // create training data - List trainingSet = new ArrayList(); - String[] inputNames = new String[] {"x","y"}; - String[] outputNames = new String[] {"XOR"}; - for (int i = 0; i < 4*size; i++) { - trainingSet.add(new NNDataPair(new NNData(inputNames,trainingInput[i]),new NNData(outputNames,trainingOutput[i]))); - } - - nnLearner.setMaxTrainingEpochs(20000); - nnLearner.learn(trainingSet); - System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs."); - - double[][] validationSet = new double[4][2]; - - validationSet[0] = new double[] { 0, 0 }; - validationSet[1] = new double[] { 0, 1 }; - validationSet[2] = new double[] { 1, 0 }; - validationSet[3] = new double[] { 1, 1 }; - - System.out.println("Output from eval set (learned network):"); - testNetwork(nnLearner, validationSet, inputNames, outputNames); - } - - @Test - public void testLearnSaveLoad() throws IOException { - NeuralNetFilter nnLearner = new XORFilter(0.5,0.0); - - // create training set (logical XOR function) - int size = 2; - double[][] trainingInput = new double[4 * size][]; - double[][] trainingOutput = new double[4 * size][]; - for (int i = 0; i < size; i++) { - trainingInput[i * 4 + 0] = new double[] { 0, 0 }; - trainingInput[i * 4 + 1] = new double[] { 0, 1 }; - trainingInput[i * 4 + 2] = new double[] { 1, 0 }; - trainingInput[i * 4 + 3] = new double[] { 1, 1 }; - trainingOutput[i * 4 + 0] = new double[] { 0 }; - trainingOutput[i * 4 + 1] = new double[] { 1 }; - trainingOutput[i * 4 + 2] = new double[] { 1 }; - trainingOutput[i * 4 + 3] = new double[] { 0 }; - } - - // create training data - List trainingSet = new ArrayList(); - String[] inputNames = new String[] {"x","y"}; - String[] outputNames = new String[] {"XOR"}; - for (int i = 0; i < 4*size; i++) { - trainingSet.add(new NNDataPair(new NNData(inputNames,trainingInput[i]),new NNData(outputNames,trainingOutput[i]))); - } - - nnLearner.setMaxTrainingEpochs(1); - nnLearner.learn(trainingSet); - System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs."); - - double[][] validationSet = new double[4][2]; - - validationSet[0] = new double[] { 0, 0 }; - validationSet[1] = new double[] { 0, 1 }; - validationSet[2] = new double[] { 1, 0 }; - validationSet[3] = new double[] { 1, 1 }; - - System.out.println("Output from eval set (learned network, pre-serialization):"); - testNetwork(nnLearner, validationSet, inputNames, outputNames); - - FileOutputStream fos = new FileOutputStream(FILENAME); - assertTrue(nnLearner.save(fos)); - fos.close(); - - FileInputStream fis = new FileInputStream(FILENAME); - assertTrue(nnLearner.load(fis)); - fis.close(); - - System.out.println("Output from eval set (learned network, post-serialization):"); - testNetwork(nnLearner, validationSet, inputNames, outputNames); - } - - private void testNetwork(NeuralNetFilter nnLearner, double[][] validationSet, String[] inputNames, String[] outputNames) { - for (int valIndex = 0; valIndex < validationSet.length; valIndex++) { - NNDataPair dp = new NNDataPair(new NNData(inputNames,validationSet[valIndex]), new NNData(outputNames,validationSet[valIndex])); - System.out.println(dp + " => " + nnLearner.compute(dp)); - } - } -} \ No newline at end of file diff --git a/test/net/woodyfolsom/msproj/tictactoe/GameRecordTest.java b/test/net/woodyfolsom/msproj/tictactoe/GameRecordTest.java new file mode 100644 index 0000000..952040a --- /dev/null +++ b/test/net/woodyfolsom/msproj/tictactoe/GameRecordTest.java @@ -0,0 +1,73 @@ +package net.woodyfolsom.msproj.tictactoe; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import org.junit.Test; + +import net.woodyfolsom.msproj.tictactoe.Game.PLAYER; + +public class GameRecordTest { + + @Test + public void testGetResultXwins() { + GameRecord gameRecord = new GameRecord(); + gameRecord.apply(Action.getInstance(PLAYER.X, 1, 0)); + gameRecord.apply(Action.getInstance(PLAYER.O, 0, 0)); + gameRecord.apply(Action.getInstance(PLAYER.X, 1, 1)); + gameRecord.apply(Action.getInstance(PLAYER.O, 0, 1)); + gameRecord.apply(Action.getInstance(PLAYER.X, 1, 2)); + State finalState = gameRecord.getState(); + System.out.println("Final state:"); + System.out.println(finalState); + assertTrue(finalState.isValid()); + assertTrue(finalState.isTerminal()); + assertTrue(finalState.isWinner(PLAYER.X)); + assertEquals(GameRecord.RESULT.X_WINS,gameRecord.getResult()); + } + + @Test + public void testGetResultOwins() { + GameRecord gameRecord = new GameRecord(); + gameRecord.apply(Action.getInstance(PLAYER.X, 0, 0)); + gameRecord.apply(Action.getInstance(PLAYER.O, 0, 2)); + gameRecord.apply(Action.getInstance(PLAYER.X, 0, 1)); + gameRecord.apply(Action.getInstance(PLAYER.O, 1, 1)); + gameRecord.apply(Action.getInstance(PLAYER.X, 1, 0)); + gameRecord.apply(Action.getInstance(PLAYER.O, 2, 0)); + + State finalState = gameRecord.getState(); + System.out.println("Final state:"); + System.out.println(finalState); + assertTrue(finalState.isValid()); + assertTrue(finalState.isTerminal()); + assertTrue(finalState.isWinner(PLAYER.O)); + assertEquals(GameRecord.RESULT.O_WINS,gameRecord.getResult()); + } + + @Test + public void testGetResultTieGame() { + GameRecord gameRecord = new GameRecord(); + gameRecord.apply(Action.getInstance(PLAYER.X, 0, 0)); + gameRecord.apply(Action.getInstance(PLAYER.O, 0, 2)); + gameRecord.apply(Action.getInstance(PLAYER.X, 0, 1)); + + gameRecord.apply(Action.getInstance(PLAYER.O, 1, 0)); + gameRecord.apply(Action.getInstance(PLAYER.X, 1, 2)); + gameRecord.apply(Action.getInstance(PLAYER.O, 1, 1)); + + gameRecord.apply(Action.getInstance(PLAYER.X, 2, 0)); + gameRecord.apply(Action.getInstance(PLAYER.O, 2, 2)); + gameRecord.apply(Action.getInstance(PLAYER.X, 2, 1)); + + State finalState = gameRecord.getState(); + System.out.println("Final state:"); + System.out.println(finalState); + assertTrue(finalState.isValid()); + assertTrue(finalState.isTerminal()); + assertFalse(finalState.isWinner(PLAYER.X)); + assertFalse(finalState.isWinner(PLAYER.O)); + assertEquals(GameRecord.RESULT.TIE_GAME,gameRecord.getResult()); + } +} diff --git a/test/net/woodyfolsom/msproj/tictactoe/RefereeTest.java b/test/net/woodyfolsom/msproj/tictactoe/RefereeTest.java new file mode 100644 index 0000000..f2d0e87 --- /dev/null +++ b/test/net/woodyfolsom/msproj/tictactoe/RefereeTest.java @@ -0,0 +1,12 @@ +package net.woodyfolsom.msproj.tictactoe; + +import org.junit.Test; + +public class RefereeTest { + + @Test + public void testPlay100Games() { + new Referee().play(100); + } + +} diff --git a/ttt.net b/ttt.net new file mode 100644 index 0000000..7f9550d --- /dev/null +++ b/ttt.net @@ -0,0 +1,129 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 + + + 10 + 11 + 12 + 13 + 14 + + + 15 + +