Lots of neural network stuff.
@@ -7,6 +7,5 @@
<classpathentry kind="lib" path="lib/log4j-1.2.16.jar"/>
<classpathentry kind="lib" path="lib/kgsGtp.jar"/>
<classpathentry kind="lib" path="lib/antlrworks-1.4.3.jar"/>
<classpathentry kind="lib" path="lib/encog-java-core.jar" sourcepath="lib/encog-java-core-sources.jar"/>
<classpathentry kind="output" path="bin"/>
</classpath>

Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,21 +1,26 @@
package net.woodyfolsom.msproj.ann;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;

import org.encog.ml.data.MLData;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.PersistBasicNetwork;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.List;

public abstract class AbstractNeuralNetFilter implements NeuralNetFilter {
protected BasicNetwork neuralNetwork;
protected int actualTrainingEpochs = 0;
protected int maxTrainingEpochs = 1000;
private final FeedforwardNetwork neuralNetwork;
private final TrainingMethod trainingMethod;

private double maxError;
private int actualTrainingEpochs = 0;
private int maxTrainingEpochs;

AbstractNeuralNetFilter(FeedforwardNetwork neuralNetwork, TrainingMethod trainingMethod, int maxTrainingEpochs, double maxError) {
this.neuralNetwork = neuralNetwork;
this.trainingMethod = trainingMethod;
this.maxError = maxError;
this.maxTrainingEpochs = maxTrainingEpochs;
}

@Override
public MLData compute(MLData input) {
public NNData compute(NNDataPair input) {
return this.neuralNetwork.compute(input);
}

@@ -23,38 +28,83 @@ public abstract class AbstractNeuralNetFilter implements NeuralNetFilter {
return actualTrainingEpochs;
}

@Override
public int getInputSize() {
return 2;
}

public int getMaxTrainingEpochs() {
return maxTrainingEpochs;
}

@Override
public BasicNetwork getNeuralNetwork() {
protected FeedforwardNetwork getNeuralNetwork() {
return neuralNetwork;
}

public void load(String filename) throws IOException {
FileInputStream fis = new FileInputStream(new File(filename));
neuralNetwork = (BasicNetwork) new PersistBasicNetwork().read(fis);
fis.close();
@Override
public void learnPatterns(List<NNDataPair> trainingSet) {
actualTrainingEpochs = 0;
double error;
neuralNetwork.initWeights();

error = trainingMethod.computePatternError(neuralNetwork,trainingSet);

if (error <= maxError) {
System.out.println("Initial error: " + error);
return;
}

do {
trainingMethod.iteratePatterns(neuralNetwork,trainingSet);
error = trainingMethod.computePatternError(neuralNetwork,trainingSet);
System.out.println("Epoch #" + actualTrainingEpochs + " Error:"
+ error);
actualTrainingEpochs++;
System.out.println("MSSE after epoch " + actualTrainingEpochs + ": " + error);
} while (error > maxError && actualTrainingEpochs < maxTrainingEpochs);
}

@Override
public void reset() {
neuralNetwork.reset();
public void learnSequences(List<List<NNDataPair>> trainingSet) {
actualTrainingEpochs = 0;
double error;
neuralNetwork.initWeights();

error = trainingMethod.computeSequenceError(neuralNetwork,trainingSet);

if (error <= maxError) {
System.out.println("Initial error: " + error);
return;
}

do {
trainingMethod.iterateSequences(neuralNetwork,trainingSet);
error = trainingMethod.computeSequenceError(neuralNetwork,trainingSet);
if (Double.isNaN(error)) {
error = trainingMethod.computeSequenceError(neuralNetwork,trainingSet);
}
System.out.println("Epoch #" + actualTrainingEpochs + " Error:"
+ error);
actualTrainingEpochs++;
System.out.println("MSSE after epoch " + actualTrainingEpochs + ": " + error);
} while (error > maxError && actualTrainingEpochs < maxTrainingEpochs);
}

@Override
public void reset(int seed) {
neuralNetwork.reset(seed);
public boolean load(InputStream input) {
return neuralNetwork.load(input);
}

@Override
public boolean save(OutputStream output) {
return neuralNetwork.save(output);
}

public void save(String filename) throws IOException {
FileOutputStream fos = new FileOutputStream(new File(filename));
new PersistBasicNetwork().save(fos, getNeuralNetwork());
fos.close();
public void setMaxError(double maxError) {
this.maxError = maxError;
}

public void setMaxTrainingEpochs(int max) {
this.maxTrainingEpochs = max;
}
}
}
@@ -1,11 +1,11 @@
package net.woodyfolsom.msproj.ann2;
package net.woodyfolsom.msproj.ann;

import java.util.List;

import net.woodyfolsom.msproj.ann2.math.ErrorFunction;
import net.woodyfolsom.msproj.ann2.math.MSSE;
import net.woodyfolsom.msproj.ann.math.ErrorFunction;
import net.woodyfolsom.msproj.ann.math.MSSE;

public class BackPropagation implements TrainingMethod {
public class BackPropagation extends TrainingMethod {
private final ErrorFunction errorFunction;
private final double learningRate;
private final double momentum;
@@ -17,15 +17,13 @@ public class BackPropagation implements TrainingMethod {
}

@Override
public void iterate(FeedforwardNetwork neuralNetwork,
public void iteratePatterns(FeedforwardNetwork neuralNetwork,
List<NNDataPair> trainingSet) {
System.out.println("Learningrate: " + learningRate);
System.out.println("Momentum: " + momentum);

//zeroErrors(neuralNetwork);

for (NNDataPair trainingPair : trainingSet) {
zeroErrors(neuralNetwork);
zeroGradients(neuralNetwork);

System.out.println("Training with: " + trainingPair.getInput());

@@ -35,16 +33,15 @@ public class BackPropagation implements TrainingMethod {
System.out.println("Updating weights. Ideal Output: " + ideal);
System.out.println("Actual Output: " + actual);

updateErrors(neuralNetwork, ideal);
//backpropagate the gradients w.r.t. output error
backPropagate(neuralNetwork, ideal);

updateWeights(neuralNetwork);
}

//updateWeights(neuralNetwork);
}

@Override
public double computeError(FeedforwardNetwork neuralNetwork,
public double computePatternError(FeedforwardNetwork neuralNetwork,
List<NNDataPair> trainingSet) {
int numDataPairs = trainingSet.size();
int outputSize = neuralNetwork.getOutput().length;
@@ -67,15 +64,17 @@ public class BackPropagation implements TrainingMethod {
return MSSE;
}

private void updateErrors(FeedforwardNetwork neuralNetwork, NNData ideal) {
@Override
protected
void backPropagate(FeedforwardNetwork neuralNetwork, NNData ideal) {
Neuron[] outputNeurons = neuralNetwork.getOutputNeurons();
double[] idealValues = ideal.getValues();

for (int i = 0; i < idealValues.length; i++) {
double output = outputNeurons[i].getOutput();
double input = outputNeurons[i].getInput();
double derivative = outputNeurons[i].getActivationFunction()
.derivative(output);
outputNeurons[i].setError(outputNeurons[i].getError() + derivative * (idealValues[i] - output));
.derivative(input);
outputNeurons[i].setGradient(outputNeurons[i].getGradient() + derivative * (idealValues[i] - outputNeurons[i].getOutput()));
}
// walking down the list of Neurons in reverse order, propagate the
// error
@@ -84,19 +83,19 @@ public class BackPropagation implements TrainingMethod {
for (int n = neurons.length - 1; n >= 0; n--) {

Neuron neuron = neurons[n];
double error = neuron.getError();
double error = neuron.getGradient();

Connection[] connectionsFromN = neuralNetwork
.getConnectionsFrom(neuron.getId());
if (connectionsFromN.length > 0) {

double derivative = neuron.getActivationFunction().derivative(
neuron.getOutput());
neuron.getInput());
for (Connection connection : connectionsFromN) {
error += derivative * connection.getWeight() * neuralNetwork.getNeuron(connection.getDest()).getError();
error += derivative * connection.getWeight() * neuralNetwork.getNeuron(connection.getDest()).getGradient();
}
}
neuron.setError(error);
neuron.setGradient(error);
}
}

@@ -104,17 +103,30 @@ public class BackPropagation implements TrainingMethod {
for (Connection connection : neuralNetwork.getConnections()) {
Neuron srcNeuron = neuralNetwork.getNeuron(connection.getSrc());
Neuron destNeuron = neuralNetwork.getNeuron(connection.getDest());
double delta = learningRate * srcNeuron.getOutput() * destNeuron.getError();
double delta = learningRate * srcNeuron.getOutput() * destNeuron.getGradient();
//TODO allow for momentum
//double lastDelta = connection.getLastDelta();
connection.addDelta(delta);
}
}

private void zeroErrors(FeedforwardNetwork neuralNetwork) {
// Set output errors relative to ideals, all other errors to 0.
for (Neuron neuron : neuralNetwork.getNeurons()) {
neuron.setError(0.0);
}

@Override
public void iterateSequences(FeedforwardNetwork neuralNetwork,
List<List<NNDataPair>> trainingSet) {
throw new UnsupportedOperationException();
}

@Override
public double computeSequenceError(FeedforwardNetwork neuralNetwork,
List<List<NNDataPair>> trainingSet) {
throw new UnsupportedOperationException();
}

@Override
protected void iteratePattern(FeedforwardNetwork neuralNetwork,
NNDataPair statePair, NNData nextReward) {
throw new UnsupportedOperationException();
}

}
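The refactoring above renames each neuron's "error" to "gradient" and evaluates the activation derivative at the neuron's summed input rather than at its output. For reference, a minimal, self-contained sketch of the same output-layer delta rule; the names here are hypothetical stand-ins, not this project's API:

    // Sketch of the output-layer delta rule: delta_i = f'(net_i) * (ideal_i - out_i).
    public final class DeltaRuleSketch {
        interface Activation {
            double calculate(double net);
            double derivative(double net); // derivative as a function of the summed input
        }

        static double[] outputDeltas(double[] net, double[] out, double[] ideal, Activation f) {
            double[] delta = new double[out.length];
            for (int i = 0; i < out.length; i++) {
                delta[i] = f.derivative(net[i]) * (ideal[i] - out[i]);
            }
            return delta;
        }

        public static void main(String[] args) {
            Activation sigmoid = new Activation() {
                public double calculate(double net) { return 1.0 / (1.0 + Math.exp(-net)); }
                public double derivative(double net) {
                    double s = calculate(net);
                    return s * (1.0 - s);
                }
            };
            double[] net = { 0.3 };
            double[] out = { sigmoid.calculate(0.3) };
            double[] ideal = { 1.0 };
            System.out.println(outputDeltas(net, out, ideal, sigmoid)[0]);
        }
    }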
@@ -1,4 +1,4 @@
package net.woodyfolsom.msproj.ann2;
package net.woodyfolsom.msproj.ann;

import javax.xml.bind.annotation.XmlAttribute;
import javax.xml.bind.annotation.XmlTransient;

@@ -8,6 +8,7 @@ public class Connection {
private int dest;
private double weight;
private transient double lastDelta = 0.0;
private transient double trace = 0.0;

public Connection() {
//no-arg constructor for JAXB
@@ -20,6 +21,7 @@ public class Connection {
}

public void addDelta(double delta) {
this.trace = delta;
this.weight += delta;
this.lastDelta = delta;
}
@@ -39,6 +41,10 @@ public class Connection {
return src;
}

public double getTrace() {
return trace;
}

@XmlAttribute
public double getWeight() {
return weight;
@@ -52,6 +58,11 @@ public class Connection {
this.src = src;
}

@XmlTransient
public void setTrace(double trace) {
this.trace = trace;
}

public void setWeight(double weight) {
this.weight = weight;
}
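The new trace field, together with the existing lastDelta, gives each connection the per-weight state needed for eligibility traces and for the momentum term flagged as a TODO in BackPropagation. As a hypothetical sketch only (the committed code stores lastDelta but does not yet apply momentum), the classical momentum update would use it like this:

    // Classical momentum: blend the current gradient step with the previous delta.
    public final class MomentumSketch {
        public static void main(String[] args) {
            double learningRate = 0.1, momentum = 0.9;
            double gradient = 0.25;   // example dE/dw for one connection
            double lastDelta = 0.05;  // delta applied on the previous step

            double delta = learningRate * gradient + momentum * lastDelta;
            System.out.println(delta); // 0.07; apply to the weight, then lastDelta = delta
        }
    }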
@@ -1,12 +0,0 @@
package net.woodyfolsom.msproj.ann;

import org.encog.ml.data.basic.BasicMLData;

public class DoublePair extends BasicMLData {

private static final long serialVersionUID = 1L;

public DoublePair(double x, double y) {
super(new double[] { x, y });
}
}
@@ -1,4 +1,4 @@
package net.woodyfolsom.msproj.ann2;
package net.woodyfolsom.msproj.ann;

import java.io.InputStream;
import java.io.OutputStream;
@@ -9,10 +9,11 @@ import java.util.Map;

import javax.xml.bind.annotation.XmlAttribute;
import javax.xml.bind.annotation.XmlElement;
import javax.xml.bind.annotation.XmlTransient;

import net.woodyfolsom.msproj.ann2.math.ActivationFunction;
import net.woodyfolsom.msproj.ann2.math.Linear;
import net.woodyfolsom.msproj.ann2.math.Sigmoid;
import net.woodyfolsom.msproj.ann.math.ActivationFunction;
import net.woodyfolsom.msproj.ann.math.Linear;
import net.woodyfolsom.msproj.ann.math.Sigmoid;

public abstract class FeedforwardNetwork {
private ActivationFunction activationFunction;
@@ -83,12 +84,12 @@ public abstract class FeedforwardNetwork {
* Adds a new neuron with a unique id to this FeedforwardNetwork.
* @return
*/
Neuron createNeuron(boolean input) {
Neuron createNeuron(boolean input, ActivationFunction afunc) {
Neuron neuron;
if (input) {
neuron = new Neuron(Linear.function, neurons.size());
} else {
neuron = new Neuron(activationFunction, neurons.size());
neuron = new Neuron(afunc, neurons.size());
}
neurons.add(neuron);
return neuron;
@@ -153,6 +154,10 @@ public abstract class FeedforwardNetwork {
return neurons.get(id);
}

public Connection getConnection(int index) {
return connections.get(index);
}

@XmlElement
protected Connection[] getConnections() {
return connections.toArray(new Connection[connections.size()]);
@@ -178,6 +183,22 @@ public abstract class FeedforwardNetwork {
}
}

public double[] getGradients() {
double[] gradients = new double[neurons.size()];
for (int n = 0; n < gradients.length; n++) {
gradients[n] = neurons.get(n).getGradient();
}
return gradients;
}

public double[] getWeights() {
double[] weights = new double[connections.size()];
for (int i = 0; i < connections.size(); i++) {
weights[i] = connections.get(i).getWeight();
}
return weights;
}

@XmlAttribute
public boolean isBiased() {
return biased;
@@ -226,7 +247,7 @@ public abstract class FeedforwardNetwork {
this.biased = biased;

if (biased) {
Neuron biasNeuron = createNeuron(true);
Neuron biasNeuron = createNeuron(true, activationFunction);
biasNeuron.setInput(1.0);
biasNeuronId = biasNeuron.getId();
} else {
@@ -270,6 +291,7 @@ public abstract class FeedforwardNetwork {
}
}

@XmlTransient
public void setWeights(double[] weights) {
if (weights.length != connections.size()) {
throw new IllegalArgumentException("# of weights must == # of connections");
@@ -1,117 +0,0 @@
package net.woodyfolsom.msproj.ann;

import net.woodyfolsom.msproj.GameResult;
import net.woodyfolsom.msproj.GameState;
import net.woodyfolsom.msproj.Player;

import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.ml.data.basic.BasicMLDataPair;
import org.encog.util.kmeans.Centroid;

public class GameStateMLDataPair implements MLDataPair {
private BasicMLDataPair mlDataPairDelegate;
private GameState gameState;

public GameStateMLDataPair(GameState gameState) {
this.gameState = gameState;
mlDataPairDelegate = new BasicMLDataPair(new BasicMLData(createInput()), new BasicMLData(createIdeal()));
}

public GameStateMLDataPair(GameStateMLDataPair that) {
this.gameState = new GameState(that.gameState);
mlDataPairDelegate = new BasicMLDataPair(
that.mlDataPairDelegate.getInput(),
that.mlDataPairDelegate.getIdeal());
}

@Override
public MLDataPair clone() {
return new GameStateMLDataPair(this);
}

@Override
public Centroid<MLDataPair> createCentroid() {
return mlDataPairDelegate.createCentroid();
}

/**
* Creates a vector of normalized scores from GameState.
*
* @return
*/
private double[] createInput() {

GameResult result = gameState.getResult();

double maxScore = gameState.getGameConfig().getSize()
* gameState.getGameConfig().getSize();

double whiteScore = Math.min(1.0, result.getWhiteScore() / maxScore);
double blackScore = Math.min(1.0, result.getBlackScore() / maxScore);

return new double[] { blackScore, whiteScore };
}

/**
* Creates a vector of values indicating strength of black/white win output
* from network.
*
* @return
*/
private double[] createIdeal() {
GameResult result = gameState.getResult();

double blackWinner = result.isWinner(Player.BLACK) ? 1.0 : 0.0;
double whiteWinner = result.isWinner(Player.WHITE) ? 1.0 : 0.0;

return new double[] { blackWinner, whiteWinner };
}

@Override
public MLData getIdeal() {
return mlDataPairDelegate.getIdeal();
}

@Override
public double[] getIdealArray() {
return mlDataPairDelegate.getIdealArray();
}

@Override
public MLData getInput() {
return mlDataPairDelegate.getInput();
}

@Override
public double[] getInputArray() {
return mlDataPairDelegate.getInputArray();
}

@Override
public double getSignificance() {
return mlDataPairDelegate.getSignificance();
}

@Override
public boolean isSupervised() {
return mlDataPairDelegate.isSupervised();
}

@Override
public void setIdealArray(double[] arg0) {
mlDataPairDelegate.setIdealArray(arg0);
}

@Override
public void setInputArray(double[] arg0) {
mlDataPairDelegate.setInputArray(arg0);
}

@Override
public void setSignificance(double arg0) {
mlDataPairDelegate.setSignificance(arg0);
}

}
@@ -1,4 +1,4 @@
package net.woodyfolsom.msproj.ann2;
package net.woodyfolsom.msproj.ann;

import java.util.Arrays;

@@ -1,4 +1,4 @@
package net.woodyfolsom.msproj.ann2;
package net.woodyfolsom.msproj.ann;

import java.io.InputStream;
import java.io.OutputStream;
@@ -10,6 +10,10 @@ import javax.xml.bind.Unmarshaller;
import javax.xml.bind.annotation.XmlElement;
import javax.xml.bind.annotation.XmlRootElement;

import net.woodyfolsom.msproj.ann.math.ActivationFunction;
import net.woodyfolsom.msproj.ann.math.Sigmoid;
import net.woodyfolsom.msproj.ann.math.Tanh;

@XmlRootElement
public class MultiLayerPerceptron extends FeedforwardNetwork {
private boolean biased;
@@ -37,7 +41,13 @@ public class MultiLayerPerceptron extends FeedforwardNetwork {
throw new IllegalArgumentException("Layer size must be >= 1");
}

Layer newLayer = createNewLayer(layerIndex, layerSize);

Layer newLayer;
if (layerIndex == numLayers - 1) {
newLayer = createNewLayer(layerIndex, layerSize, Sigmoid.function);
} else {
newLayer = createNewLayer(layerIndex, layerSize, Tanh.function);
}

if (layerIndex > 0) {
Layer prevLayer = layers[layerIndex - 1];
@@ -54,11 +64,11 @@ public class MultiLayerPerceptron extends FeedforwardNetwork {
}
}

private Layer createNewLayer(int layerIndex, int layerSize) {
private Layer createNewLayer(int layerIndex, int layerSize, ActivationFunction afunc) {
Layer layer = new Layer(layerSize);
layers[layerIndex] = layer;
for (int n = 0; n < layerSize; n++) {
Neuron neuron = createNeuron(layerIndex == 0);
Neuron neuron = createNeuron(layerIndex == 0, afunc);
layer.setNeuronId(n, neuron.getId());
}
return layer;
@@ -93,8 +103,13 @@ public class MultiLayerPerceptron extends FeedforwardNetwork {
protected void setInput(double[] input) {
Layer inputLayer = layers[0];
for (int n = 0; n < inputLayer.size(); n++) {
getNeuron(inputLayer.getNeuronId(n)).setInput(input[n]);
try {
getNeuron(inputLayer.getNeuronId(n)).setInput(input[n]);
} catch (NullPointerException npe) {
npe.printStackTrace();
}
}

}

public void setLayers(Layer[] layers) {
@@ -1,4 +1,4 @@
package net.woodyfolsom.msproj.ann2;
package net.woodyfolsom.msproj.ann;

public class NNData {
private final double[] values;
@@ -1,4 +1,4 @@
package net.woodyfolsom.msproj.ann2;
package net.woodyfolsom.msproj.ann;

public class NNDataPair {
private final NNData input;
@@ -1,30 +1,29 @@
package net.woodyfolsom.msproj.ann;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.List;
import java.util.Set;

import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.neural.networks.BasicNetwork;

public interface NeuralNetFilter {
BasicNetwork getNeuralNetwork();

int getActualTrainingEpochs();

int getInputSize();

int getMaxTrainingEpochs();

int getOutputSize();

void learn(MLDataSet trainingSet);
void learn(Set<List<MLDataPair>> trainingSet);

void load(String fileName) throws IOException;
void reset();
void reset(int seed);
void save(String fileName) throws IOException;

boolean load(InputStream input);

boolean save(OutputStream output);

void setMaxTrainingEpochs(int max);

MLData compute(MLData input);
NNData compute(NNDataPair input);

//Due to Java type erasure, overloading a method
//simply named 'learn' which takes Lists would be problematic

void learnPatterns(List<NNDataPair> trainingSet);
void learnSequences(List<List<NNDataPair>> trainingSet);
}
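The type-erasure comment above is accurate: learn(List<NNDataPair>) and learn(List<List<NNDataPair>>) both erase to learn(List), so they cannot overload one another, and distinct names like learnPatterns/learnSequences are the standard workaround. A minimal illustration, using String as a stand-in for NNDataPair, that javac rejects:

    import java.util.List;

    // Rejected by javac ("name clash: both methods have the same erasure"):
    interface ErasureClash {
        void learn(List<String> patterns);        // erases to learn(List)
        void learn(List<List<String>> sequences); // erases to learn(List) as well
    }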
@@ -1,17 +1,17 @@
package net.woodyfolsom.msproj.ann2;
package net.woodyfolsom.msproj.ann;

import javax.xml.bind.annotation.XmlAttribute;
import javax.xml.bind.annotation.XmlElement;
import javax.xml.bind.annotation.XmlTransient;

import net.woodyfolsom.msproj.ann2.math.ActivationFunction;
import net.woodyfolsom.msproj.ann2.math.Sigmoid;
import net.woodyfolsom.msproj.ann.math.ActivationFunction;
import net.woodyfolsom.msproj.ann.math.Sigmoid;

public class Neuron {
private ActivationFunction activationFunction;
private int id;
private transient double input = 0.0;
private transient double error = 0.0;
private transient double gradient = 0.0;

public Neuron() {
//no-arg constructor for JAXB
@@ -37,8 +37,8 @@ public class Neuron {
}

@XmlTransient
public double getError() {
return error;
public double getGradient() {
return gradient;
}

@XmlTransient
@@ -50,8 +50,8 @@ public class Neuron {
return activationFunction.calculate(input);
}

public void setError(double value) {
this.error = value;
public void setGradient(double value) {
this.gradient = value;
}

public void setInput(double input) {
@@ -92,7 +92,7 @@ public class Neuron {

@Override
public String toString() {
return "Neuron #" + id +", input: " + input + ", error: " + error;
return "Neuron #" + id +", input: " + input + ", gradient: " + gradient;
}

}
5 src/net/woodyfolsom/msproj/ann/ObjectiveFunction.java Normal file
@@ -0,0 +1,5 @@
package net.woodyfolsom.msproj.ann;

public class ObjectiveFunction {

}
34 src/net/woodyfolsom/msproj/ann/TTTFilter.java Normal file
@@ -0,0 +1,34 @@
package net.woodyfolsom.msproj.ann;

/**
* Based on sample code from http://neuroph.sourceforge.net
*
* @author Woody
*
*/
public class TTTFilter extends AbstractNeuralNetFilter implements
NeuralNetFilter {

private static final int INPUT_SIZE = 9;
private static final int OUTPUT_SIZE = 1;

public TTTFilter() {
this(0.5,0.0, 1000);
}

public TTTFilter(double alpha, double lambda, int maxEpochs) {
super( new MultiLayerPerceptron(true, INPUT_SIZE, 5, OUTPUT_SIZE),
new TemporalDifference(alpha, lambda), maxEpochs, 0.05);
super.getNeuralNetwork().setName("TTTFilter");
}

@Override
public int getInputSize() {
return INPUT_SIZE;
}

@Override
public int getOutputSize() {
return OUTPUT_SIZE;
}
}
187 src/net/woodyfolsom/msproj/ann/TTTFilterTrainer.java Normal file
@@ -0,0 +1,187 @@
package net.woodyfolsom.msproj.ann;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.util.ArrayList;
import java.util.List;

import net.woodyfolsom.msproj.tictactoe.Action;
import net.woodyfolsom.msproj.tictactoe.GameRecord;
import net.woodyfolsom.msproj.tictactoe.GameRecord.RESULT;
import net.woodyfolsom.msproj.tictactoe.NNDataSetFactory;
import net.woodyfolsom.msproj.tictactoe.NeuralNetPolicy;
import net.woodyfolsom.msproj.tictactoe.Policy;
import net.woodyfolsom.msproj.tictactoe.RandomPolicy;
import net.woodyfolsom.msproj.tictactoe.State;

public class TTTFilterTrainer { //implements epsilon-greedy trainer? online version of NeuralNetFilter

public static void main(String[] args) throws FileNotFoundException {
double alpha = 0.0;
double lambda = 0.9;
int maxGames = 15000;

new TTTFilterTrainer().trainNetwork(alpha, lambda, maxGames);
}

public void trainNetwork(double alpha, double lambda, int maxGames) throws FileNotFoundException {
///
FeedforwardNetwork neuralNetwork = new MultiLayerPerceptron(true, 9,5,1);
neuralNetwork.setName("TicTacToe");
neuralNetwork.initWeights();
TrainingMethod trainer = new TemporalDifference(0.5,0.5);

System.out.println("Playing untrained games.");
for (int i = 0; i < 10; i++) {
System.out.println("" + (i+1) + ". " + playOptimal(neuralNetwork).getResult());
}

System.out.println("Learning from " + maxGames + " games of random self-play");

int gamesPlayed = 0;
List<RESULT> results = new ArrayList<RESULT>();
do {
GameRecord gameRecord = playEpsilonGreedy(0.90, neuralNetwork, trainer);
System.out.println("Winner: " + gameRecord.getResult());
gamesPlayed++;
results.add(gameRecord.getResult());
} while (gamesPlayed < maxGames);
///

System.out.println("Learned network after " + maxGames + " training games.");

double[][] validationSet = new double[8][];

for (int i = 0; i < results.size(); i++) {
if (i % 10 == 0) {
System.out.println("" + (i+1) + ". " + results.get(i));
}
}
// empty board
validationSet[0] = new double[] { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0 };
// center
validationSet[1] = new double[] { 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,
0.0, 0.0 };
// top edge
validationSet[2] = new double[] { 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0 };
// left edge
validationSet[3] = new double[] { 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0,
0.0, 0.0 };
// corner
validationSet[4] = new double[] { 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0 };
// win
validationSet[5] = new double[] { 1.0, 1.0, 1.0, -1.0, -1.0, 0.0, 0.0,
-1.0, 0.0 };
// loss
validationSet[6] = new double[] { -1.0, 1.0, 0.0, 1.0, -1.0, 1.0, 0.0,
0.0, -1.0 };

// about to win
validationSet[7] = new double[] {
-1.0, 1.0, 1.0,
1.0, -1.0, 1.0,
-1.0, -1.0, 0.0 };

String[] inputNames = new String[] { "00", "01", "02", "10", "11",
"12", "20", "21", "22" };
String[] outputNames = new String[] { "values" };

System.out.println("Output from eval set (learned network):");
testNetwork(neuralNetwork, validationSet, inputNames, outputNames);

System.out.println("Playing optimal games.");
for (int i = 0; i < 10; i++) {
System.out.println("" + (i+1) + ". " + playOptimal(neuralNetwork).getResult());
}

/*
File output = new File("ttt.net");

FileOutputStream fos = new FileOutputStream(output);

neuralNetwork.save(fos);*/
}

private GameRecord playOptimal(FeedforwardNetwork neuralNetwork) {
GameRecord gameRecord = new GameRecord();

Policy neuralNetPolicy = new NeuralNetPolicy(neuralNetwork);

State state = gameRecord.getState();

do {
Action action;
State nextState;

action = neuralNetPolicy.getAction(gameRecord.getState());

nextState = gameRecord.apply(action);
//System.out.println("Action " + action + " selected by policy " + selectedPolicy.getName());
//System.out.println("Next board state: " + nextState);
state = nextState;
} while (!state.isTerminal());

//finally, reinforce the actual reward

return gameRecord;
}

private GameRecord playEpsilonGreedy(double epsilon, FeedforwardNetwork neuralNetwork, TrainingMethod trainer) {
GameRecord gameRecord = new GameRecord();

Policy randomPolicy = new RandomPolicy();
Policy neuralNetPolicy = new NeuralNetPolicy(neuralNetwork);

//System.out.println("Playing epsilon-greedy game.");

State state = gameRecord.getState();
NNDataPair statePair;

Policy selectedPolicy;
trainer.zeroTraces(neuralNetwork);

do {
Action action;
State nextState;

if (Math.random() < epsilon) {
selectedPolicy = randomPolicy;
action = selectedPolicy.getAction(gameRecord.getState());
nextState = gameRecord.apply(action);
} else {
selectedPolicy = neuralNetPolicy;
action = selectedPolicy.getAction(gameRecord.getState());

nextState = gameRecord.apply(action);
statePair = NNDataSetFactory.createDataPair(state);
NNDataPair nextStatePair = NNDataSetFactory.createDataPair(nextState);
trainer.iteratePattern(neuralNetwork, statePair, nextStatePair.getIdeal());
}
//System.out.println("Action " + action + " selected by policy " + selectedPolicy.getName());

//System.out.println("Next board state: " + nextState);

state = nextState;
} while (!state.isTerminal());

//finally, reinforce the actual reward
statePair = NNDataSetFactory.createDataPair(state);
trainer.iteratePattern(neuralNetwork, statePair, statePair.getIdeal());

return gameRecord;
}

private void testNetwork(FeedforwardNetwork neuralNetwork,
double[][] validationSet, String[] inputNames, String[] outputNames) {
for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
NNDataPair dp = new NNDataPair(new NNData(inputNames,
validationSet[valIndex]), new NNData(outputNames,
validationSet[valIndex]));
System.out.println(dp + " => " + neuralNetwork.compute(dp));
}
}
}
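Note that playEpsilonGreedy is called with epsilon = 0.90, so roughly nine moves out of ten come from the random policy; the value network is only consulted (and only trained) on the remaining greedy moves and on the terminal reinforcement. For reference, a generic sketch of the epsilon-greedy choice the trainer implements; Valuer and the integer actions are hypothetical stand-ins, not this project's Policy API:

    import java.util.Arrays;
    import java.util.List;
    import java.util.Random;

    // Generic epsilon-greedy selection over a list of candidate actions.
    public final class EpsilonGreedySketch {
        interface Valuer<A> { double value(A action); }

        static <A> A choose(List<A> actions, double epsilon, Valuer<A> valuer, Random rng) {
            if (rng.nextDouble() < epsilon) {
                return actions.get(rng.nextInt(actions.size())); // explore
            }
            A best = actions.get(0); // exploit: highest estimated value
            for (A a : actions) {
                if (valuer.value(a) > valuer.value(best)) {
                    best = a;
                }
            }
            return best;
        }

        public static void main(String[] args) {
            List<Integer> actions = Arrays.asList(1, 2, 3);
            Valuer<Integer> identity = new Valuer<Integer>() {
                public double value(Integer action) { return action; }
            };
            System.out.println(choose(actions, 0.9, identity, new Random()));
        }
    }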
@@ -1,30 +1,133 @@
package net.woodyfolsom.msproj.ann;

import org.encog.ml.data.MLDataSet;
import org.encog.neural.networks.ContainsFlat;
import org.encog.neural.networks.training.propagation.back.Backpropagation;
import java.util.List;

public class TemporalDifference extends Backpropagation {
public class TemporalDifference extends TrainingMethod {
private final double alpha;
private final double gamma = 1.0;
private final double lambda;

public TemporalDifference(ContainsFlat network, MLDataSet training,
double theLearnRate, double theMomentum, double lambda) {
super(network, training, theLearnRate, theMomentum);

public TemporalDifference(double alpha, double lambda) {
this.alpha = alpha;
this.lambda = lambda;
}

public double getLamdba() {
return lambda;
@Override
public void iteratePatterns(FeedforwardNetwork neuralNetwork,
List<NNDataPair> trainingSet) {
throw new UnsupportedOperationException();
}

@Override
public double updateWeight(final double[] gradients,
final double[] lastGradient, final int index) {
double alpha = this.getLearningRate();

//TODO fill in weight update for TD(lambda)

return 0.0;
public double computePatternError(FeedforwardNetwork neuralNetwork,
List<NNDataPair> trainingSet) {
int numDataPairs = trainingSet.size();
int outputSize = neuralNetwork.getOutput().length;
int totalOutputSize = outputSize * numDataPairs;

double[] actuals = new double[totalOutputSize];
double[] ideals = new double[totalOutputSize];
for (int dataPair = 0; dataPair < numDataPairs; dataPair++) {
NNDataPair nnDataPair = trainingSet.get(dataPair);
double[] actual = neuralNetwork.compute(nnDataPair.getInput()
.getValues());
double[] ideal = nnDataPair.getIdeal().getValues();
int offset = dataPair * outputSize;

System.arraycopy(actual, 0, actuals, offset, outputSize);
System.arraycopy(ideal, 0, ideals, offset, outputSize);
}

double MSSE = errorFunction.compute(ideals, actuals);
return MSSE;
}

@Override
protected void backPropagate(FeedforwardNetwork neuralNetwork, NNData ideal) {
Neuron[] outputNeurons = neuralNetwork.getOutputNeurons();
double[] idealValues = ideal.getValues();

for (int i = 0; i < idealValues.length; i++) {
double input = outputNeurons[i].getInput();
double derivative = outputNeurons[i].getActivationFunction()
.derivative(input);
outputNeurons[i].setGradient(outputNeurons[i].getGradient()
+ derivative
* (idealValues[i] - outputNeurons[i].getOutput()));
}
// walking down the list of Neurons in reverse order, propagate the
// error
Neuron[] neurons = neuralNetwork.getNeurons();

for (int n = neurons.length - 1; n >= 0; n--) {

Neuron neuron = neurons[n];
double error = neuron.getGradient();

Connection[] connectionsFromN = neuralNetwork
.getConnectionsFrom(neuron.getId());
if (connectionsFromN.length > 0) {

double derivative = neuron.getActivationFunction().derivative(
neuron.getInput());
for (Connection connection : connectionsFromN) {
error += derivative
* connection.getWeight()
* neuralNetwork.getNeuron(connection.getDest())
.getGradient();
}
}
neuron.setGradient(error);
}
}

private void updateWeights(FeedforwardNetwork neuralNetwork, double predictionError) {
for (Connection connection : neuralNetwork.getConnections()) {
Neuron srcNeuron = neuralNetwork.getNeuron(connection.getSrc());
Neuron destNeuron = neuralNetwork.getNeuron(connection.getDest());

double delta = alpha * srcNeuron.getOutput()
* destNeuron.getGradient() * predictionError + connection.getTrace() * lambda;

// TODO allow for momentum
// double lastDelta = connection.getLastDelta();
connection.addDelta(delta);
}
}

@Override
public void iterateSequences(FeedforwardNetwork neuralNetwork,
List<List<NNDataPair>> trainingSet) {
throw new UnsupportedOperationException();
}

@Override
public double computeSequenceError(FeedforwardNetwork neuralNetwork,
List<List<NNDataPair>> trainingSet) {
throw new UnsupportedOperationException();
}

@Override
protected void iteratePattern(FeedforwardNetwork neuralNetwork,
NNDataPair statePair, NNData nextReward) {
//System.out.println("Learningrate: " + alpha);

zeroGradients(neuralNetwork);

//System.out.println("Training with: " + statePair.getInput());

NNData ideal = nextReward;
NNData actual = neuralNetwork.compute(statePair);

//System.out.println("Updating weights. Ideal Output: " + ideal);
//System.out.println("Actual Output: " + actual);

// backpropagate the gradients w.r.t. output error
backPropagate(neuralNetwork, ideal);

double predictionError = statePair.getIdeal().getValues()[0] // reward_t
+ actual.getValues()[0] - nextReward.getValues()[0];

updateWeights(neuralNetwork, predictionError);
}
}
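For reference, the textbook TD(lambda) update this class approximates is: delta_t = r_{t+1} + gamma * V(s_{t+1}) - V(s_t); e <- gamma * lambda * e + grad_w V(s_t); w <- w + alpha * delta_t * e. The committed updateWeights instead folds the previous trace in as trace * lambda inside a single delta and stores only the last applied delta as the trace, a simplification of the usual accumulating trace. A compact sketch of the standard form, with hypothetical arrays rather than the Connection API:

    // TD(lambda) weight update with accumulating eligibility traces (sketch).
    public final class TDLambdaSketch {
        public static void main(String[] args) {
            double alpha = 0.5, gamma = 1.0, lambda = 0.9;
            double[] w = { 0.1, -0.2 };      // weights
            double[] e = new double[2];      // eligibility traces, zeroed per episode
            double[] grad = { 0.4, 0.6 };    // gradient of V(s_t) w.r.t. w
            double reward = 0.0, vNow = 0.3, vNext = 0.5;

            double delta = reward + gamma * vNext - vNow; // TD error
            for (int i = 0; i < w.length; i++) {
                e[i] = gamma * lambda * e[i] + grad[i];   // decay and accumulate
                w[i] += alpha * delta * e[i];             // move along the trace
            }
            System.out.println(w[0] + ", " + w[1]);
        }
    }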
43 src/net/woodyfolsom/msproj/ann/TrainingMethod.java Normal file
@@ -0,0 +1,43 @@
package net.woodyfolsom.msproj.ann;

import java.util.List;

import net.woodyfolsom.msproj.ann.math.ErrorFunction;
import net.woodyfolsom.msproj.ann.math.MSSE;

public abstract class TrainingMethod {
protected final ErrorFunction errorFunction;

public TrainingMethod() {
this.errorFunction = MSSE.function;
}

protected abstract void iteratePattern(FeedforwardNetwork neuralNetwork,
NNDataPair statePair, NNData nextReward);

protected abstract void iteratePatterns(FeedforwardNetwork neuralNetwork,
List<NNDataPair> trainingSet);

protected abstract double computePatternError(FeedforwardNetwork neuralNetwork,
List<NNDataPair> trainingSet);

protected abstract void iterateSequences(FeedforwardNetwork neuralNetwork,
List<List<NNDataPair>> trainingSet);

protected abstract void backPropagate(FeedforwardNetwork neuralNetwork, NNData output);

protected abstract double computeSequenceError(FeedforwardNetwork neuralNetwork,
List<List<NNDataPair>> trainingSet);

protected void zeroGradients(FeedforwardNetwork neuralNetwork) {
for (Neuron neuron : neuralNetwork.getNeurons()) {
neuron.setGradient(0.0);
}
}

protected void zeroTraces(FeedforwardNetwork neuralNetwork) {
for (Connection conn : neuralNetwork.getConnections()) {
conn.setTrace(0.0);
}
}
}
@@ -1,105 +0,0 @@
package net.woodyfolsom.msproj.ann;

import java.util.List;
import java.util.Set;

import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.ml.train.MLTrain;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;
import org.encog.neural.networks.training.propagation.back.Backpropagation;

public class WinFilter extends AbstractNeuralNetFilter implements
NeuralNetFilter {

public WinFilter() {
// create a neural network, without using a factory
BasicNetwork network = new BasicNetwork();
network.addLayer(new BasicLayer(null, false, 2));
network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 4));
network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 2));
network.getStructure().finalizeStructure();
network.reset();

this.neuralNetwork = network;
}

@Override
public void learn(MLDataSet trainingData) {
throw new UnsupportedOperationException(
"This filter learns a Set<List<MLDataPair>>, not an MLDataSet");
}

/**
* Method is necessary because with temporal difference learning, some of
* the MLDataPairs are related by being a sequence of moves within a
* particular game.
*/
@Override
public void learn(Set<List<MLDataPair>> trainingSet) {
MLDataSet mlDataset = new BasicMLDataSet();

for (List<MLDataPair> gameRecord : trainingSet) {
for (int t = 0; t < gameRecord.size() - 1; t++) {
mlDataset.add(gameRecord.get(t).getInput(), this.neuralNetwork.compute(gameRecord.get(t)
.getInput()));
}
mlDataset.add(gameRecord.get(gameRecord.size() - 1));
}

// train the neural network
final MLTrain train = new TemporalDifference(neuralNetwork, mlDataset, 0.7, 0.8, 0.25);
//final MLTrain train = new Backpropagation(neuralNetwork, mlDataset, 0.7, 0.8);
actualTrainingEpochs = 0;

do {
if (actualTrainingEpochs > 0) {
int gameStateIndex = 0;
for (List<MLDataPair> gameRecord : trainingSet) {
for (int t = 0; t < gameRecord.size() - 1; t++) {
MLDataPair oldDataPair = mlDataset.get(gameStateIndex);
this.neuralNetwork.compute(oldDataPair.getInput());
gameStateIndex++;
}
gameStateIndex++;
}
}
train.iteration();
System.out.println("Epoch #" + actualTrainingEpochs + " Error:"
+ train.getError());
actualTrainingEpochs++;
} while (train.getError() > 0.01
&& actualTrainingEpochs <= maxTrainingEpochs);
}

@Override
public void reset() {
neuralNetwork.reset();
}

@Override
public void reset(int seed) {
neuralNetwork.reset(seed);
}

@Override
public BasicNetwork getNeuralNetwork() {
// TODO Auto-generated method stub
return null;
}

@Override
public int getInputSize() {
// TODO Auto-generated method stub
return 0;
}

@Override
public int getOutputSize() {
// TODO Auto-generated method stub
return 0;
}
}
@@ -1,18 +1,5 @@
package net.woodyfolsom.msproj.ann;

import java.util.List;
import java.util.Set;

import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.ml.train.MLTrain;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;
import org.encog.neural.networks.training.propagation.back.Backpropagation;

/**
* Based on sample code from http://neuroph.sourceforge.net
*
@@ -22,54 +9,30 @@ import org.encog.neural.networks.training.propagation.back.Backpropagation;
public class XORFilter extends AbstractNeuralNetFilter implements
NeuralNetFilter {

private static final int INPUT_SIZE = 2;
private static final int OUTPUT_SIZE = 1;

public XORFilter() {
// create a neural network, without using a factory
BasicNetwork network = new BasicNetwork();
network.addLayer(new BasicLayer(null, false, 2));
network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 3));
network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 1));
network.getStructure().finalizeStructure();
network.reset();

this.neuralNetwork = network;
this(0.8,0.7);
}

public XORFilter(double learningRate, double momentum) {
super( new MultiLayerPerceptron(true, INPUT_SIZE, 2, OUTPUT_SIZE),
new BackPropagation(learningRate, momentum), 1000, 0.01);
super.getNeuralNetwork().setName("XORFilter");
}

public double compute(double x, double y) {
return compute(new BasicMLData(new double[]{x,y})).getData(0);
return getNeuralNetwork().compute(new double[]{x,y})[0];
}

@Override
public int getInputSize() {
return 2;
return INPUT_SIZE;
}

@Override
public int getOutputSize() {
// TODO Auto-generated method stub
return 1;
}

@Override
public void learn(MLDataSet trainingSet) {

// train the neural network
final MLTrain train = new Backpropagation(neuralNetwork, trainingSet,
0.7, 0.8);

actualTrainingEpochs = 0;

do {
train.iteration();
System.out.println("Epoch #" + actualTrainingEpochs + " Error:"
+ train.getError());
actualTrainingEpochs++;
} while (train.getError() > 0.01
&& actualTrainingEpochs <= maxTrainingEpochs);
}

@Override
public void learn(Set<List<MLDataPair>> trainingSet) {
throw new UnsupportedOperationException(
"This Filter learns an MLDataSet, not a Set<List<MLData>>.");
return OUTPUT_SIZE;
}
}
@@ -1,4 +1,4 @@
package net.woodyfolsom.msproj.ann2.math;
package net.woodyfolsom.msproj.ann.math;

import javax.xml.bind.annotation.XmlAttribute;

@@ -1,4 +1,4 @@
package net.woodyfolsom.msproj.ann2.math;
package net.woodyfolsom.msproj.ann.math;

public interface ErrorFunction {
double compute(double[] ideal, double[] actual);
@@ -1,4 +1,4 @@
package net.woodyfolsom.msproj.ann2.math;
package net.woodyfolsom.msproj.ann.math;

public class Linear extends ActivationFunction{
public static final Linear function = new Linear();
@@ -1,4 +1,4 @@
package net.woodyfolsom.msproj.ann2.math;
package net.woodyfolsom.msproj.ann.math;

public class MSSE implements ErrorFunction{
public static final ErrorFunction function = new MSSE();
@@ -1,4 +1,4 @@
package net.woodyfolsom.msproj.ann2.math;
package net.woodyfolsom.msproj.ann.math;

public class Sigmoid extends ActivationFunction{
public static final Sigmoid function = new Sigmoid();
@@ -12,9 +12,9 @@ public class Sigmoid extends ActivationFunction{
}

public double derivative(double arg) {
//lol wth?
//double eX = Math.exp(arg);
//return eX / (Math.pow((1+eX), 2));
return arg - Math.pow(arg,2);
//lol wth? oh, the next derivative formula is a function of s(x), not x.
double eX = Math.exp(arg);
return eX / (Math.pow((1+eX), 2));
//return arg - Math.pow(arg,2);
}
}
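The corrected derivative is right because, for the logistic sigmoid s(x) = 1 / (1 + e^(-x)), the identity s'(x) = e^x / (1 + e^x)^2 = s(x) * (1 - s(x)) holds; the previously returned arg - arg^2 form is only valid when arg is already the sigmoid's output rather than its raw input, which is exactly what the restored comment notes. A quick numerical check of the identity:

    // Confirms e^x/(1+e^x)^2 == s(x)*(1-s(x)) for the logistic sigmoid.
    public final class SigmoidDerivativeCheck {
        public static void main(String[] args) {
            double x = 0.7;
            double s = 1.0 / (1.0 + Math.exp(-x));
            double byInput = Math.exp(x) / Math.pow(1.0 + Math.exp(x), 2);
            double byOutput = s * (1.0 - s);
            System.out.println(byInput + " == " + byOutput); // both ~0.2217
        }
    }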
@@ -1,4 +1,4 @@
package net.woodyfolsom.msproj.ann2.math;
package net.woodyfolsom.msproj.ann.math;

public class Tanh extends ActivationFunction{
public static final Tanh function = new Tanh();
@@ -1,84 +0,0 @@
package net.woodyfolsom.msproj.ann2;

import java.io.InputStream;
import java.io.OutputStream;
import java.util.List;

public abstract class AbstractNeuralNetFilter implements NeuralNetFilter {
private final FeedforwardNetwork neuralNetwork;
private final TrainingMethod trainingMethod;

private double maxError;
private int actualTrainingEpochs = 0;
private int maxTrainingEpochs;

AbstractNeuralNetFilter(FeedforwardNetwork neuralNetwork, TrainingMethod trainingMethod, int maxTrainingEpochs, double maxError) {
this.neuralNetwork = neuralNetwork;
this.trainingMethod = trainingMethod;
this.maxError = maxError;
this.maxTrainingEpochs = maxTrainingEpochs;
}

@Override
public NNData compute(NNDataPair input) {
return this.neuralNetwork.compute(input);
}

public int getActualTrainingEpochs() {
return actualTrainingEpochs;
}

@Override
public int getInputSize() {
return 2;
}

public int getMaxTrainingEpochs() {
return maxTrainingEpochs;
}

protected FeedforwardNetwork getNeuralNetwork() {
return neuralNetwork;
}

@Override
public void learn(List<NNDataPair> trainingSet) {
actualTrainingEpochs = 0;
double error;
neuralNetwork.initWeights();

error = trainingMethod.computeError(neuralNetwork,trainingSet);

if (error <= maxError) {
System.out.println("Initial error: " + error);
return;
}

do {
trainingMethod.iterate(neuralNetwork,trainingSet);
error = trainingMethod.computeError(neuralNetwork,trainingSet);
System.out.println("Epoch #" + actualTrainingEpochs + " Error:"
+ error);
actualTrainingEpochs++;
System.out.println("MSSE after epoch " + actualTrainingEpochs + ": " + error);
} while (error > maxError && actualTrainingEpochs < maxTrainingEpochs);
}

@Override
public boolean load(InputStream input) {
return neuralNetwork.load(input);
}

@Override
public boolean save(OutputStream output) {
return neuralNetwork.save(output);
}

public void setMaxError(double maxError) {
this.maxError = maxError;
}

public void setMaxTrainingEpochs(int max) {
this.maxTrainingEpochs = max;
}
}
@@ -1,25 +0,0 @@
package net.woodyfolsom.msproj.ann2;

import java.io.InputStream;
import java.io.OutputStream;
import java.util.List;

public interface NeuralNetFilter {
int getActualTrainingEpochs();

int getInputSize();

int getMaxTrainingEpochs();

int getOutputSize();

boolean load(InputStream input);

boolean save(OutputStream output);

void setMaxTrainingEpochs(int max);

NNData compute(NNDataPair input);

void learn(List<NNDataPair> trainingSet);
}
@@ -1,5 +0,0 @@
package net.woodyfolsom.msproj.ann2;

public class ObjectiveFunction {

}
@@ -1,19 +0,0 @@
package net.woodyfolsom.msproj.ann2;

import java.util.List;

public class TemporalDifference implements TrainingMethod {

@Override
public void iterate(FeedforwardNetwork neuralNetwork,
List<NNDataPair> trainingSet) {
throw new UnsupportedOperationException("Not implemented");
}

@Override
public double computeError(FeedforwardNetwork neuralNetwork,
List<NNDataPair> trainingSet) {
throw new UnsupportedOperationException("Not implemented");
}

}
@@ -1,10 +0,0 @@
package net.woodyfolsom.msproj.ann2;

import java.util.List;

public interface TrainingMethod {

void iterate(FeedforwardNetwork neuralNetwork, List<NNDataPair> trainingSet);
double computeError(FeedforwardNetwork neuralNetwork, List<NNDataPair> trainingSet);

}
@@ -1,44 +0,0 @@
package net.woodyfolsom.msproj.ann2;

/**
* Based on sample code from http://neuroph.sourceforge.net
*
* @author Woody
*
*/
public class XORFilter extends AbstractNeuralNetFilter implements
NeuralNetFilter {

private static final int INPUT_SIZE = 2;
private static final int OUTPUT_SIZE = 1;

public XORFilter() {
this(0.8,0.7);
}

public XORFilter(double learningRate, double momentum) {
super( new MultiLayerPerceptron(true, INPUT_SIZE, 2, OUTPUT_SIZE),
new BackPropagation(learningRate, momentum), 1000, 0.01);
super.getNeuralNetwork().setName("XORFilter");

//TODO remove
//getNeuralNetwork().setWeights(new double[] {
//	0.341232, 0.129952, -0.923123, //hidden neuron 1 from input0, input1, bias
//	-0.115223, 0.570345, -0.328932, //hidden neuron 2 from input0, input1, bias
//	-0.993423, 0.164732, 0.752621}); //output
}

public double compute(double x, double y) {
return getNeuralNetwork().compute(new double[]{x,y})[0];
}

@Override
public int getInputSize() {
return INPUT_SIZE;
}

@Override
public int getOutputSize() {
return OUTPUT_SIZE;
}
}
54 src/net/woodyfolsom/msproj/tictactoe/Action.java Normal file
@@ -0,0 +1,54 @@
package net.woodyfolsom.msproj.tictactoe;

import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;

public class Action {
public static final Action NONE = new Action(PLAYER.NONE, -1, -1);

private Game.PLAYER player;
private int row;
private int column;

public static Action getInstance(PLAYER player, int row, int column) {
return new Action(player,row,column);
}

private Action(PLAYER player, int row, int column) {
this.player = player;
this.row = row;
this.column = column;
}

public Game.PLAYER getPlayer() {
return player;
}

public int getColumn() {
return column;
}

public int getRow() {
return row;
}

public boolean isNone() {
return this == Action.NONE;
}

public void setPlayer(Game.PLAYER player) {
this.player = player;
}

public void setRow(int row) {
this.row = row;
}

public void setColumn(int column) {
this.column = column;
}

@Override
public String toString() {
return player + "(" + row + ", " + column + ")";
}
}
5 src/net/woodyfolsom/msproj/tictactoe/Game.java Normal file
@@ -0,0 +1,5 @@
package net.woodyfolsom.msproj.tictactoe;

public class Game {
public enum PLAYER {X,O,NONE}
}
63
src/net/woodyfolsom/msproj/tictactoe/GameRecord.java
Normal file
63
src/net/woodyfolsom/msproj/tictactoe/GameRecord.java
Normal file
@@ -0,0 +1,63 @@
|
||||
package net.woodyfolsom.msproj.tictactoe;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;
|
||||
|
||||
public class GameRecord {
|
||||
public enum RESULT {X_WINS, O_WINS, TIE_GAME, IN_PROGRESS}
|
||||
|
||||
private List<Action> actions = new ArrayList<Action>();
|
||||
private List<State> states = new ArrayList<State>();
|
||||
|
||||
private RESULT result = RESULT.IN_PROGRESS;
|
||||
|
||||
public GameRecord() {
|
||||
actions.add(Action.NONE);
|
||||
states.add(new State());
|
||||
}
|
||||
|
||||
public void addState(State state) {
|
||||
states.add(state);
|
||||
}
|
||||
|
||||
public State apply(Action action) {
|
||||
State nextState = getState().apply(action);
|
||||
if (nextState.isValid()) {
|
||||
states.add(nextState);
|
||||
actions.add(action);
|
||||
}
|
||||
|
||||
if (nextState.isTerminal()) {
|
||||
if (nextState.isWinner(PLAYER.X)) {
|
||||
result = RESULT.X_WINS;
|
||||
} else if (nextState.isWinner(PLAYER.O)) {
|
||||
result = RESULT.O_WINS;
|
||||
} else {
|
||||
result = RESULT.TIE_GAME;
|
||||
}
|
||||
}
|
||||
return nextState;
|
||||
}
|
||||
|
||||
public int getNumStates() {
|
||||
return states.size();
|
||||
}
|
||||
|
||||
public RESULT getResult() {
|
||||
return result;
|
||||
}
|
||||
|
||||
public void setResult(RESULT result) {
|
||||
this.result = result;
|
||||
}
|
||||
|
||||
public State getState() {
|
||||
return states.get(states.size()-1);
|
||||
}
|
||||
|
||||
public State getState(int index) {
|
||||
return states.get(index);
|
||||
}
|
||||
}
|
20  src/net/woodyfolsom/msproj/tictactoe/MoveGenerator.java  Normal file
@@ -0,0 +1,20 @@
package net.woodyfolsom.msproj.tictactoe;

import java.util.ArrayList;
import java.util.List;

import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;

public class MoveGenerator {
	public List<Action> getValidActions(State state) {
		PLAYER playerToMove = state.getPlayerToMove();
		List<Action> validActions = new ArrayList<Action>();
		for (int i = 0; i < 3; i++) {
			for (int j = 0; j < 3; j++) {
				if (state.isEmpty(i,j))
					validActions.add(Action.getInstance(playerToMove, i, j));
			}
		}
		return validActions;
	}
}
81  src/net/woodyfolsom/msproj/tictactoe/NNDataSetFactory.java  Normal file
@@ -0,0 +1,81 @@
package net.woodyfolsom.msproj.tictactoe;

import java.util.ArrayList;
import java.util.List;

import net.woodyfolsom.msproj.ann.NNData;
import net.woodyfolsom.msproj.ann.NNDataPair;
import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;

public class NNDataSetFactory {
	public static final String[] TTT_INPUT_FIELDS = {"00","01","02","10","11","12","20","21","22"};
	public static final String[] TTT_OUTPUT_FIELDS = {"value"};

	public static List<List<NNDataPair>> createDataSet(List<GameRecord> tttGames) {

		List<List<NNDataPair>> nnDataSet = new ArrayList<List<NNDataPair>>();

		for (GameRecord tttGame : tttGames) {
			List<NNDataPair> gameData = createDataPairList(tttGame);

			nnDataSet.add(gameData);
		}

		return nnDataSet;
	}

	public static List<NNDataPair> createDataPairList(GameRecord gameRecord) {
		List<NNDataPair> gameData = new ArrayList<NNDataPair>();

		for (int i = 0; i < gameRecord.getNumStates(); i++) {
			gameData.add(createDataPair(gameRecord.getState(i)));
		}

		return gameData;
	}

	public static NNDataPair createDataPair(State tttState) {
		double value;
		if (tttState.isTerminal()) {
			if (tttState.isWinner(PLAYER.X)) {
				value = 1.0; // win for X
			} else if (tttState.isWinner(PLAYER.O)) {
				value = 0.0; // win for O, i.e. a loss from X's point of view
				//value = -1.0;
			} else {
				value = 0.5; // tie
				//value = 0.0;
			}
		} else {
			value = 0.0;
		}

		double[] inputValues = new double[9];
		char[] boardCopy = tttState.getBoard();
		inputValues[0] = getTicTacToeInput(boardCopy, 0, 0);
		inputValues[1] = getTicTacToeInput(boardCopy, 0, 1);
		inputValues[2] = getTicTacToeInput(boardCopy, 0, 2);
		inputValues[3] = getTicTacToeInput(boardCopy, 1, 0);
		inputValues[4] = getTicTacToeInput(boardCopy, 1, 1);
		inputValues[5] = getTicTacToeInput(boardCopy, 1, 2);
		inputValues[6] = getTicTacToeInput(boardCopy, 2, 0);
		inputValues[7] = getTicTacToeInput(boardCopy, 2, 1);
		inputValues[8] = getTicTacToeInput(boardCopy, 2, 2);

		return new NNDataPair(new NNData(TTT_INPUT_FIELDS,inputValues), new NNData(TTT_OUTPUT_FIELDS,new double[]{value}));
	}

	private static double getTicTacToeInput(char[] board, int row, int column) {
		switch (board[row*3+column]) {
			case 'X' :
				return 1.0;
			case 'O' :
				return -1.0;
			case '.' :
				return 0.0;
			default:
				throw new RuntimeException("Invalid board symbol at " + row + ", " + column);
		}
	}
}
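To make the encoding concrete: with X mapped to 1.0, O to -1.0, and empty squares to 0.0, a terminal position in which X has completed the top row

	X X X
	O O .
	. O .

becomes the training pair input = {1.0, 1.0, 1.0, -1.0, -1.0, 0.0, 0.0, -1.0, 0.0}, output = {1.0}, since the state is a win for X.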
67  src/net/woodyfolsom/msproj/tictactoe/NeuralNetPolicy.java  Normal file
@@ -0,0 +1,67 @@
package net.woodyfolsom.msproj.tictactoe;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import net.woodyfolsom.msproj.ann.FeedforwardNetwork;
import net.woodyfolsom.msproj.ann.NNDataPair;
import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;

public class NeuralNetPolicy extends Policy {
	private FeedforwardNetwork neuralNet;
	private MoveGenerator moveGenerator = new MoveGenerator();

	public NeuralNetPolicy(FeedforwardNetwork neuralNet) {
		super("NeuralNet-" + neuralNet.getName());
		this.neuralNet = neuralNet;
	}

	@Override
	public Action getAction(State state) {
		List<Action> validMoves = moveGenerator.getValidActions(state);
		Map<Action, Double> scores = new HashMap<Action, Double>();

		for (Action action : validMoves) {
			State nextState = state.apply(action);
			// Score the successor state, not the current one; otherwise every
			// candidate action would receive the same value.
			NNDataPair dataPair = NNDataSetFactory.createDataPair(nextState);
			//estimated reward for X
			scores.put(action, neuralNet.compute(dataPair).getValues()[0]);
		}

		PLAYER playerToMove = state.getPlayerToMove();

		if (playerToMove == PLAYER.X) {
			return returnMaxAction(scores);
		} else if (playerToMove == PLAYER.O) {
			return returnMinAction(scores);
		} else {
			throw new IllegalArgumentException("Invalid playerToMove: " + playerToMove);
		}
		//return validMoves.get((int)(Math.random() * validMoves.size()));
	}

	private Action returnMaxAction(Map<Action,Double> scores) {
		Action bestAction = null;
		double bestScore = Double.NEGATIVE_INFINITY;
		for (Map.Entry<Action,Double> entry : scores.entrySet()) {
			if (entry.getValue() > bestScore) {
				bestScore = entry.getValue();
				bestAction = entry.getKey();
			}
		}
		return bestAction;
	}

	private Action returnMinAction(Map<Action,Double> scores) {
		Action bestAction = null;
		double bestScore = Double.POSITIVE_INFINITY;
		for (Map.Entry<Action,Double> entry : scores.entrySet()) {
			if (entry.getValue() < bestScore) {
				bestScore = entry.getValue();
				bestAction = entry.getKey();
			}
		}
		return bestAction;
	}
}
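A sketch of how this policy could drive a game once a network has been trained. This is hypothetical wiring -- Referee.play below hard-codes RandomPolicy, and obtaining a trained FeedforwardNetwork (e.g. from the serialized ttt.net at the end of this diff) is assumed rather than shown:

	FeedforwardNetwork net = ...;             // e.g. a network trained by TTTFilter
	Policy policy = new NeuralNetPolicy(net); // X maximizes, O minimizes the value output
	GameRecord game = new GameRecord();
	while (!game.getState().isTerminal()) {
		game.apply(policy.getAction(game.getState()));
	}
	System.out.println(game.getResult());     // X_WINS, O_WINS, or TIE_GAME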
15  src/net/woodyfolsom/msproj/tictactoe/Policy.java  Normal file
@@ -0,0 +1,15 @@
package net.woodyfolsom.msproj.tictactoe;

public abstract class Policy {
	private String name;

	protected Policy(String name) {
		this.name = name;
	}

	public abstract Action getAction(State state);

	public String getName() {
		return name;
	}
}
18  src/net/woodyfolsom/msproj/tictactoe/RandomPolicy.java  Normal file
@@ -0,0 +1,18 @@
package net.woodyfolsom.msproj.tictactoe;

import java.util.List;

public class RandomPolicy extends Policy {
	private MoveGenerator moveGenerator = new MoveGenerator();

	public RandomPolicy() {
		super("Random");
	}

	@Override
	public Action getAction(State state) {
		List<Action> validMoves = moveGenerator.getValidActions(state);
		return validMoves.get((int)(Math.random() * validMoves.size()));
	}

}
43  src/net/woodyfolsom/msproj/tictactoe/Referee.java  Normal file
@@ -0,0 +1,43 @@
package net.woodyfolsom.msproj.tictactoe;

import java.util.ArrayList;
import java.util.List;

public class Referee {

	public static void main(String[] args) {
		new Referee().play(50);
	}

	public List<GameRecord> play(int nGames) {
		Policy policy = new RandomPolicy();

		List<GameRecord> tournament = new ArrayList<GameRecord>();

		for (int i = 0; i < nGames; i++) {
			GameRecord gameRecord = new GameRecord();

			System.out.println("Playing game #" + (i+1));

			State state;
			do {
				Action action = policy.getAction(gameRecord.getState());
				System.out.println("Action " + action + " selected by policy " + policy.getName());
				state = gameRecord.apply(action);
				System.out.println("Next board state:");
				System.out.println(gameRecord.getState());
			} while (!state.isTerminal());
			System.out.println("Game #" + (i+1) + " is finished. Result: " + gameRecord.getResult());
			tournament.add(gameRecord);
		}

		System.out.println("Played " + tournament.size() + " random games.");
		System.out.println("Results:");
		for (int i = 0; i < tournament.size(); i++) {
			GameRecord gameRecord = tournament.get(i);
			System.out.println((i+1) + ". " + gameRecord.getResult());
		}

		return tournament;
	}
}
116  src/net/woodyfolsom/msproj/tictactoe/State.java  Normal file
@@ -0,0 +1,116 @@
package net.woodyfolsom.msproj.tictactoe;

import java.util.Arrays;

import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;

public class State {
	public static final State INVALID = new State();
	public static final char EMPTY_SQUARE = '.';

	private char[] board;
	private PLAYER playerToMove;

	public State() {
		playerToMove = Game.PLAYER.X;
		board = new char[9];
		Arrays.fill(board, EMPTY_SQUARE);
	}

	private State(State that) {
		this.board = Arrays.copyOf(that.board, that.board.length);
		this.playerToMove = that.playerToMove;
	}

	public State apply(Action action) {
		if (action.getPlayer() != playerToMove) {
			System.out.println("It is not " + action.getPlayer() + "'s turn.");
			return State.INVALID;
		}
		State nextState = new State(this);

		int row = action.getRow();
		int column = action.getColumn();
		int dest = row * 3 + column;

		if (board[dest] != EMPTY_SQUARE) {
			System.out.println("Invalid move " + action + ", coordinate not empty.");
			return State.INVALID;
		}
		switch (playerToMove) {
			case X : nextState.board[dest] = 'X';
				break;
			case O : nextState.board[dest] = 'O';
				break;
			default:
				throw new RuntimeException("Invalid playerToMove");
		}

		if (playerToMove == PLAYER.X) {
			nextState.playerToMove = PLAYER.O;
		} else {
			nextState.playerToMove = PLAYER.X;
		}
		return nextState;
	}

	public char[] getBoard() {
		return Arrays.copyOf(board, board.length);
	}

	public PLAYER getPlayerToMove() {
		return playerToMove;
	}

	public boolean isEmpty(int row, int column) {
		return board[row*3+column] == EMPTY_SQUARE;
	}

	public boolean isFull(char mark1, char mark2, char mark3) {
		return mark1 != EMPTY_SQUARE && mark2 != EMPTY_SQUARE && mark3 != EMPTY_SQUARE;
	}

	public boolean isWinner(PLAYER player) {
		// three rows, three columns, then the two diagonals
		return isWin(player,board[0],board[1],board[2]) ||
				isWin(player,board[3],board[4],board[5]) ||
				isWin(player,board[6],board[7],board[8]) ||
				isWin(player,board[0],board[3],board[6]) ||
				isWin(player,board[1],board[4],board[7]) ||
				isWin(player,board[2],board[5],board[8]) ||
				isWin(player,board[0],board[4],board[8]) ||
				isWin(player,board[2],board[4],board[6]);
	}

	public boolean isWin(PLAYER player, char mark1, char mark2, char mark3) {
		if (isFull(mark1,mark2,mark3)) {
			switch (player) {
				case X : return mark1 == 'X' && mark2 == 'X' && mark3 == 'X';
				case O : return mark1 == 'O' && mark2 == 'O' && mark3 == 'O';
				default :
					return false;
			}
		} else {
			return false;
		}
	}

	public boolean isTerminal() {
		return isWinner(PLAYER.X) || isWinner(PLAYER.O) ||
				(isFull(board[0],board[1],board[2]) &&
				 isFull(board[3],board[4],board[5]) &&
				 isFull(board[6],board[7],board[8]));
	}

	public boolean isValid() {
		return this != INVALID;
	}

	@Override
	public String toString() {
		StringBuilder sb = new StringBuilder("TicTacToe state (" + playerToMove + " to move):\n");
		sb.append("" + board[0] + board[1] + board[2] + "\n");
		sb.append("" + board[3] + board[4] + board[5] + "\n");
		sb.append("" + board[6] + board[7] + board[8] + "\n");
		return sb.toString();
	}
}
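The apply contract above is worth spelling out: every move yields a fresh State, and illegal moves come back as the State.INVALID sentinel instead of throwing. For example:

	State s = new State();                                           // empty board, X to move
	State s2 = s.apply(Action.getInstance(Game.PLAYER.X, 1, 1));     // X takes the center
	State bad = s2.apply(Action.getInstance(Game.PLAYER.O, 1, 1));   // square already occupied
	System.out.println(s2.isValid());   // true
	System.out.println(bad.isValid());  // false -- bad == State.INVALID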
@@ -1,4 +1,4 @@
package net.woodyfolsom.msproj.ann2;
package net.woodyfolsom.msproj.ann;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
@@ -10,6 +10,12 @@ import java.io.IOException;

import javax.xml.bind.JAXBException;

import net.woodyfolsom.msproj.ann.Connection;
import net.woodyfolsom.msproj.ann.FeedforwardNetwork;
import net.woodyfolsom.msproj.ann.MultiLayerPerceptron;
import net.woodyfolsom.msproj.ann.NNData;
import net.woodyfolsom.msproj.ann.NNDataPair;

import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
100  test/net/woodyfolsom/msproj/ann/TTTFilterTest.java  Normal file
@@ -0,0 +1,100 @@
package net.woodyfolsom.msproj.ann;

import java.io.File;
import java.io.IOException;
import java.util.List;

import net.woodyfolsom.msproj.ann.NNData;
import net.woodyfolsom.msproj.ann.NNDataPair;
import net.woodyfolsom.msproj.ann.NeuralNetFilter;
import net.woodyfolsom.msproj.ann.TTTFilter;
import net.woodyfolsom.msproj.tictactoe.GameRecord;
import net.woodyfolsom.msproj.tictactoe.NNDataSetFactory;
import net.woodyfolsom.msproj.tictactoe.Referee;

import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;

public class TTTFilterTest {
	private static final String FILENAME = "tttPerceptron.net";

	@AfterClass
	public static void deleteNewNet() {
		File file = new File(FILENAME);
		if (file.exists()) {
			file.delete();
		}
	}

	@BeforeClass
	public static void deleteSavedNet() {
		File file = new File(FILENAME);
		if (file.exists()) {
			file.delete();
		}
	}

	@Test
	public void testLearn() throws IOException {
		double alpha = 0.5;
		double lambda = 0.0;
		int maxEpochs = 1000;

		NeuralNetFilter nnLearner = new TTTFilter(alpha, lambda, maxEpochs);

		// Create trainingSet from a tournament of random games.
		// Future iterations will use epsilon-greedy play from a policy based on
		// this network to generate additional datasets.
		List<GameRecord> tournament = new Referee().play(1);
		List<List<NNDataPair>> trainingSet = NNDataSetFactory
				.createDataSet(tournament);

		System.out.println("Generated " + trainingSet.size()
				+ " datasets from random self-play.");
		nnLearner.learnSequences(trainingSet);
		System.out.println("Learned network after "
				+ nnLearner.getActualTrainingEpochs() + " training epochs.");

		double[][] validationSet = new double[7][];

		// empty board
		validationSet[0] = new double[] { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 };
		// center
		validationSet[1] = new double[] { 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0 };
		// top edge
		validationSet[2] = new double[] { 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 };
		// left edge
		validationSet[3] = new double[] { 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0 };
		// corner
		validationSet[4] = new double[] { 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 };
		// win for X (top row)
		validationSet[5] = new double[] { 1.0, 1.0, 1.0, -1.0, -1.0, 0.0, 0.0, -1.0, 0.0 };
		// loss for X (O completes the main diagonal)
		validationSet[6] = new double[] { -1.0, 1.0, 0.0, 1.0, -1.0, 1.0, 0.0, 0.0, -1.0 };

		String[] inputNames = new String[] { "00", "01", "02", "10", "11", "12", "20", "21", "22" };
		String[] outputNames = new String[] { "value" };

		System.out.println("Output from eval set (learned network):");
		testNetwork(nnLearner, validationSet, inputNames, outputNames);
	}

	private void testNetwork(NeuralNetFilter nnLearner,
			double[][] validationSet, String[] inputNames, String[] outputNames) {
		for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
			// The "ideal" half of the pair just reuses the input values as a
			// placeholder; only compute(...) is exercised here, not training.
			NNDataPair dp = new NNDataPair(new NNData(inputNames, validationSet[valIndex]),
					new NNData(outputNames, validationSet[valIndex]));
			System.out.println(dp + " => " + nnLearner.compute(dp));
		}
	}
}
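TTTFilter itself is not part of this diff, so the meaning of (alpha, lambda) is not confirmed; the parameter names and learnSequences suggest TD(lambda)-style updates over each game's state sequence, as in Tesauro's TD-Gammon. Purely as an illustration of that rule, here is a self-contained sketch for a linear value function (the real filter presumably backs a MultiLayerPerceptron instead):

	import java.util.List;

	/** TD(lambda) over one episode for a linear value function v(s) = w . s.
	 *  A guess at the learning rule behind TTTFilter(alpha, lambda, ...) --
	 *  learnSequences is not shown in this commit, so treat this as illustration. */
	final class TdLambdaSketch {
		final double[] w;

		TdLambdaSketch(int n) { w = new double[n]; }

		double value(double[] s) {
			double v = 0;
			for (int i = 0; i < s.length; i++) v += w[i] * s[i];
			return v;
		}

		/** states: the recorded board encodings in play order;
		 *  finalReward: 1.0 for an X win, 0.0 for an O win, 0.5 for a tie. */
		void learnEpisode(List<double[]> states, double finalReward, double alpha, double lambda) {
			double[] e = new double[w.length];            // eligibility traces
			for (int t = 0; t < states.size(); t++) {
				double[] s = states.get(t);
				boolean last = (t == states.size() - 1);
				double target = last ? finalReward : value(states.get(t + 1));
				double delta = target - value(s);         // TD error, no discounting
				for (int i = 0; i < w.length; i++) {
					e[i] = lambda * e[i] + s[i];          // gradient of a linear v is s
					w[i] += alpha * delta * e[i];
				}
			}
		}
	}

With lambda = 0.0, as in the test above, the traces decay instantly and this reduces to one-step TD(0) updates.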
@@ -1,64 +0,0 @@
package net.woodyfolsom.msproj.ann;

import java.io.File;
import java.io.FileFilter;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import net.woodyfolsom.msproj.GameRecord;
import net.woodyfolsom.msproj.Referee;

import org.antlr.runtime.RecognitionException;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.junit.Test;

public class WinFilterTest {

	@Test
	public void testLearnSaveLoad() throws IOException, RecognitionException {
		File[] sgfFiles = new File("data/games/random_vs_random")
				.listFiles(new FileFilter() {
					@Override
					public boolean accept(File pathname) {
						return pathname.getName().endsWith(".sgf");
					}
				});

		Set<List<MLDataPair>> trainingData = new HashSet<List<MLDataPair>>();

		for (File file : sgfFiles) {
			FileInputStream fis = new FileInputStream(file);
			GameRecord gameRecord = Referee.replay(fis);

			List<MLDataPair> gameData = new ArrayList<MLDataPair>();
			for (int i = 0; i <= gameRecord.getNumTurns(); i++) {
				gameData.add(new GameStateMLDataPair(gameRecord.getGameState(i)));
			}

			trainingData.add(gameData);

			fis.close();
		}

		WinFilter winFilter = new WinFilter();

		winFilter.learn(trainingData);

		for (List<MLDataPair> trainingSequence : trainingData) {
			for (int stateIndex = 0; stateIndex < trainingSequence.size(); stateIndex++) {
				if (stateIndex > 0 && stateIndex < trainingSequence.size()-1) {
					continue;
				}
				MLData input = trainingSequence.get(stateIndex).getInput();

				System.out.println("Turn " + stateIndex + ": " + input + " => "
						+ winFilter.compute(input));
			}
		}
	}
}
@@ -1,10 +1,19 @@
package net.woodyfolsom.msproj.ann;

import java.io.File;
import java.io.IOException;
import static org.junit.Assert.assertTrue;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import net.woodyfolsom.msproj.ann.NNData;
import net.woodyfolsom.msproj.ann.NNDataPair;
import net.woodyfolsom.msproj.ann.NeuralNetFilter;
import net.woodyfolsom.msproj.ann.XORFilter;

import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
@@ -29,14 +38,55 @@ public class XORFilterTest {
	}

	@Test
	public void testLearnSaveLoad() throws IOException {
	NeuralNetFilter nnLearner = new XORFilter();
	System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");
	public void testLearn() throws IOException {
		NeuralNetFilter nnLearner = new XORFilter(0.5,0.0);

		// create training set (logical XOR function)
		int size = 1;
		double[][] trainingInput = new double[4 * size][];
		double[][] trainingOutput = new double[4 * size][];
		for (int i = 0; i < size; i++) {
			trainingInput[i * 4 + 0] = new double[] { 0, 0 };
			trainingInput[i * 4 + 1] = new double[] { 0, 1 };
			trainingInput[i * 4 + 2] = new double[] { 1, 0 };
			trainingInput[i * 4 + 3] = new double[] { 1, 1 };
			trainingOutput[i * 4 + 0] = new double[] { 0 };
			trainingOutput[i * 4 + 1] = new double[] { 1 };
			trainingOutput[i * 4 + 2] = new double[] { 1 };
			trainingOutput[i * 4 + 3] = new double[] { 0 };
		}

		// create training data
		List<NNDataPair> trainingSet = new ArrayList<NNDataPair>();
		String[] inputNames = new String[] {"x","y"};
		String[] outputNames = new String[] {"XOR"};
		for (int i = 0; i < 4*size; i++) {
			trainingSet.add(new NNDataPair(new NNData(inputNames,trainingInput[i]),new NNData(outputNames,trainingOutput[i])));
		}

		nnLearner.setMaxTrainingEpochs(20000);
		nnLearner.learnPatterns(trainingSet);
		System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");

		double[][] validationSet = new double[4][2];

		validationSet[0] = new double[] { 0, 0 };
		validationSet[1] = new double[] { 0, 1 };
		validationSet[2] = new double[] { 1, 0 };
		validationSet[3] = new double[] { 1, 1 };

		System.out.println("Output from eval set (learned network):");
		testNetwork(nnLearner, validationSet, inputNames, outputNames);
	}

	@Test
	public void testLearnSaveLoad() throws IOException {
		NeuralNetFilter nnLearner = new XORFilter(0.5,0.0);

		// create training set (logical XOR function)
		int size = 2;
		double[][] trainingInput = new double[4 * size][];
		double[][] trainingOutput = new double[4 * size][];
		for (int i = 0; i < size; i++) {
			trainingInput[i * 4 + 0] = new double[] { 0, 0 };
			trainingInput[i * 4 + 1] = new double[] { 0, 1 };
@@ -49,10 +99,17 @@ public class XORFilterTest {
		}

		// create training data
		MLDataSet trainingSet = new BasicMLDataSet(trainingInput, trainingOutput);
		List<NNDataPair> trainingSet = new ArrayList<NNDataPair>();
		String[] inputNames = new String[] {"x","y"};
		String[] outputNames = new String[] {"XOR"};
		for (int i = 0; i < 4*size; i++) {
			trainingSet.add(new NNDataPair(new NNData(inputNames,trainingInput[i]),new NNData(outputNames,trainingOutput[i])));
		}

		nnLearner.setMaxTrainingEpochs(1);
		nnLearner.learnPatterns(trainingSet);
		System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");

		nnLearner.learn(trainingSet);

		double[][] validationSet = new double[4][2];

		validationSet[0] = new double[] { 0, 0 };
@@ -61,18 +118,23 @@ public class XORFilterTest {
		validationSet[3] = new double[] { 1, 1 };

		System.out.println("Output from eval set (learned network, pre-serialization):");
		testNetwork(nnLearner, validationSet);

		nnLearner.save(FILENAME);
		nnLearner.load(FILENAME);
		testNetwork(nnLearner, validationSet, inputNames, outputNames);

		FileOutputStream fos = new FileOutputStream(FILENAME);
		assertTrue(nnLearner.save(fos));
		fos.close();

		FileInputStream fis = new FileInputStream(FILENAME);
		assertTrue(nnLearner.load(fis));
		fis.close();

		System.out.println("Output from eval set (learned network, post-serialization):");
		testNetwork(nnLearner, validationSet);
		testNetwork(nnLearner, validationSet, inputNames, outputNames);
	}

	private void testNetwork(NeuralNetFilter nnLearner, double[][] validationSet) {
	private void testNetwork(NeuralNetFilter nnLearner, double[][] validationSet, String[] inputNames, String[] outputNames) {
		for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
			DoublePair dp = new DoublePair(validationSet[valIndex][0],validationSet[valIndex][1]);
			NNDataPair dp = new NNDataPair(new NNData(inputNames,validationSet[valIndex]), new NNData(outputNames,validationSet[valIndex]));
			System.out.println(dp + " => " + nnLearner.compute(dp));
		}
	}
@@ -1,11 +1,11 @@
package net.woodyfolsom.msproj.ann2;
package net.woodyfolsom.msproj.ann.math;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

import net.woodyfolsom.msproj.ann2.math.ActivationFunction;
import net.woodyfolsom.msproj.ann2.math.Sigmoid;
import net.woodyfolsom.msproj.ann2.math.Tanh;
import net.woodyfolsom.msproj.ann.math.ActivationFunction;
import net.woodyfolsom.msproj.ann.math.Sigmoid;
import net.woodyfolsom.msproj.ann.math.Tanh;

import org.junit.Test;

@@ -1,10 +1,10 @@
package net.woodyfolsom.msproj.ann2;
package net.woodyfolsom.msproj.ann.math;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

import net.woodyfolsom.msproj.ann2.math.ActivationFunction;
import net.woodyfolsom.msproj.ann2.math.Tanh;
import net.woodyfolsom.msproj.ann.math.ActivationFunction;
import net.woodyfolsom.msproj.ann.math.Tanh;

import org.junit.Test;
@@ -1,136 +0,0 @@
package net.woodyfolsom.msproj.ann2;

import static org.junit.Assert.assertTrue;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;

public class XORFilterTest {
	private static final String FILENAME = "xorPerceptron.net";

	@AfterClass
	public static void deleteNewNet() {
		File file = new File(FILENAME);
		if (file.exists()) {
			file.delete();
		}
	}

	@BeforeClass
	public static void deleteSavedNet() {
		File file = new File(FILENAME);
		if (file.exists()) {
			file.delete();
		}
	}

	@Test
	public void testLearn() throws IOException {
		NeuralNetFilter nnLearner = new XORFilter(0.05,0.0);

		// create training set (logical XOR function)
		int size = 1;
		double[][] trainingInput = new double[4 * size][];
		double[][] trainingOutput = new double[4 * size][];
		for (int i = 0; i < size; i++) {
			trainingInput[i * 4 + 0] = new double[] { 0, 0 };
			trainingInput[i * 4 + 1] = new double[] { 0, 1 };
			trainingInput[i * 4 + 2] = new double[] { 1, 0 };
			trainingInput[i * 4 + 3] = new double[] { 1, 1 };
			trainingOutput[i * 4 + 0] = new double[] { 0 };
			trainingOutput[i * 4 + 1] = new double[] { 1 };
			trainingOutput[i * 4 + 2] = new double[] { 1 };
			trainingOutput[i * 4 + 3] = new double[] { 0 };
		}

		// create training data
		List<NNDataPair> trainingSet = new ArrayList<NNDataPair>();
		String[] inputNames = new String[] {"x","y"};
		String[] outputNames = new String[] {"XOR"};
		for (int i = 0; i < 4*size; i++) {
			trainingSet.add(new NNDataPair(new NNData(inputNames,trainingInput[i]),new NNData(outputNames,trainingOutput[i])));
		}

		nnLearner.setMaxTrainingEpochs(20000);
		nnLearner.learn(trainingSet);
		System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");

		double[][] validationSet = new double[4][2];

		validationSet[0] = new double[] { 0, 0 };
		validationSet[1] = new double[] { 0, 1 };
		validationSet[2] = new double[] { 1, 0 };
		validationSet[3] = new double[] { 1, 1 };

		System.out.println("Output from eval set (learned network):");
		testNetwork(nnLearner, validationSet, inputNames, outputNames);
	}

	@Test
	public void testLearnSaveLoad() throws IOException {
		NeuralNetFilter nnLearner = new XORFilter(0.5,0.0);

		// create training set (logical XOR function)
		int size = 2;
		double[][] trainingInput = new double[4 * size][];
		double[][] trainingOutput = new double[4 * size][];
		for (int i = 0; i < size; i++) {
			trainingInput[i * 4 + 0] = new double[] { 0, 0 };
			trainingInput[i * 4 + 1] = new double[] { 0, 1 };
			trainingInput[i * 4 + 2] = new double[] { 1, 0 };
			trainingInput[i * 4 + 3] = new double[] { 1, 1 };
			trainingOutput[i * 4 + 0] = new double[] { 0 };
			trainingOutput[i * 4 + 1] = new double[] { 1 };
			trainingOutput[i * 4 + 2] = new double[] { 1 };
			trainingOutput[i * 4 + 3] = new double[] { 0 };
		}

		// create training data
		List<NNDataPair> trainingSet = new ArrayList<NNDataPair>();
		String[] inputNames = new String[] {"x","y"};
		String[] outputNames = new String[] {"XOR"};
		for (int i = 0; i < 4*size; i++) {
			trainingSet.add(new NNDataPair(new NNData(inputNames,trainingInput[i]),new NNData(outputNames,trainingOutput[i])));
		}

		nnLearner.setMaxTrainingEpochs(1);
		nnLearner.learn(trainingSet);
		System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");

		double[][] validationSet = new double[4][2];

		validationSet[0] = new double[] { 0, 0 };
		validationSet[1] = new double[] { 0, 1 };
		validationSet[2] = new double[] { 1, 0 };
		validationSet[3] = new double[] { 1, 1 };

		System.out.println("Output from eval set (learned network, pre-serialization):");
		testNetwork(nnLearner, validationSet, inputNames, outputNames);

		FileOutputStream fos = new FileOutputStream(FILENAME);
		assertTrue(nnLearner.save(fos));
		fos.close();

		FileInputStream fis = new FileInputStream(FILENAME);
		assertTrue(nnLearner.load(fis));
		fis.close();

		System.out.println("Output from eval set (learned network, post-serialization):");
		testNetwork(nnLearner, validationSet, inputNames, outputNames);
	}

	private void testNetwork(NeuralNetFilter nnLearner, double[][] validationSet, String[] inputNames, String[] outputNames) {
		for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
			NNDataPair dp = new NNDataPair(new NNData(inputNames,validationSet[valIndex]), new NNData(outputNames,validationSet[valIndex]));
			System.out.println(dp + " => " + nnLearner.compute(dp));
		}
	}
}
73  test/net/woodyfolsom/msproj/tictactoe/GameRecordTest.java  Normal file
@@ -0,0 +1,73 @@
package net.woodyfolsom.msproj.tictactoe;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;

import org.junit.Test;

import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;

public class GameRecordTest {

	@Test
	public void testGetResultXwins() {
		GameRecord gameRecord = new GameRecord();
		gameRecord.apply(Action.getInstance(PLAYER.X, 1, 0));
		gameRecord.apply(Action.getInstance(PLAYER.O, 0, 0));
		gameRecord.apply(Action.getInstance(PLAYER.X, 1, 1));
		gameRecord.apply(Action.getInstance(PLAYER.O, 0, 1));
		gameRecord.apply(Action.getInstance(PLAYER.X, 1, 2));
		State finalState = gameRecord.getState();
		System.out.println("Final state:");
		System.out.println(finalState);
		assertTrue(finalState.isValid());
		assertTrue(finalState.isTerminal());
		assertTrue(finalState.isWinner(PLAYER.X));
		assertEquals(GameRecord.RESULT.X_WINS, gameRecord.getResult());
	}

	@Test
	public void testGetResultOwins() {
		GameRecord gameRecord = new GameRecord();
		gameRecord.apply(Action.getInstance(PLAYER.X, 0, 0));
		gameRecord.apply(Action.getInstance(PLAYER.O, 0, 2));
		gameRecord.apply(Action.getInstance(PLAYER.X, 0, 1));
		gameRecord.apply(Action.getInstance(PLAYER.O, 1, 1));
		gameRecord.apply(Action.getInstance(PLAYER.X, 1, 0));
		gameRecord.apply(Action.getInstance(PLAYER.O, 2, 0));

		State finalState = gameRecord.getState();
		System.out.println("Final state:");
		System.out.println(finalState);
		assertTrue(finalState.isValid());
		assertTrue(finalState.isTerminal());
		assertTrue(finalState.isWinner(PLAYER.O));
		assertEquals(GameRecord.RESULT.O_WINS, gameRecord.getResult());
	}

	@Test
	public void testGetResultTieGame() {
		GameRecord gameRecord = new GameRecord();
		gameRecord.apply(Action.getInstance(PLAYER.X, 0, 0));
		gameRecord.apply(Action.getInstance(PLAYER.O, 0, 2));
		gameRecord.apply(Action.getInstance(PLAYER.X, 0, 1));

		gameRecord.apply(Action.getInstance(PLAYER.O, 1, 0));
		gameRecord.apply(Action.getInstance(PLAYER.X, 1, 2));
		gameRecord.apply(Action.getInstance(PLAYER.O, 1, 1));

		gameRecord.apply(Action.getInstance(PLAYER.X, 2, 0));
		gameRecord.apply(Action.getInstance(PLAYER.O, 2, 2));
		gameRecord.apply(Action.getInstance(PLAYER.X, 2, 1));

		State finalState = gameRecord.getState();
		System.out.println("Final state:");
		System.out.println(finalState);
		assertTrue(finalState.isValid());
		assertTrue(finalState.isTerminal());
		assertFalse(finalState.isWinner(PLAYER.X));
		assertFalse(finalState.isWinner(PLAYER.O));
		assertEquals(GameRecord.RESULT.TIE_GAME, gameRecord.getResult());
	}
}
12  test/net/woodyfolsom/msproj/tictactoe/RefereeTest.java  Normal file
@@ -0,0 +1,12 @@
package net.woodyfolsom.msproj.tictactoe;

import org.junit.Test;

public class RefereeTest {

	@Test
	public void testPlay100Games() {
		new Referee().play(100);
	}

}
129  ttt.net  Normal file
@@ -0,0 +1,129 @@
<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<multiLayerPerceptron biased="true" name="TicTacToe">
  <activationFunction name="Sigmoid"/>
  <connections dest="10" src="0" weight="0.5827629317852295"/>
  <connections dest="10" src="1" weight="0.49198735902918994"/>
  <connections dest="10" src="2" weight="-0.3019566272377494"/>
  <connections dest="10" src="3" weight="0.42204442000472525"/>
  <connections dest="10" src="4" weight="-0.26015075178733194"/>
  <connections dest="10" src="5" weight="-0.001558299861060293"/>
  <connections dest="10" src="6" weight="0.07987916348233416"/>
  <connections dest="10" src="7" weight="0.07258122647153753"/>
  <connections dest="10" src="8" weight="-0.691045501522254"/>
  <connections dest="10" src="9" weight="0.7118463494749109"/>
  <connections dest="11" src="0" weight="-1.8387878977128804"/>
  <connections dest="11" src="1" weight="0.07066242812415906"/>
  <connections dest="11" src="2" weight="-0.2141385079779094"/>
  <connections dest="11" src="3" weight="0.02318115051417748"/>
  <connections dest="11" src="4" weight="-0.4940158494633454"/>
  <connections dest="11" src="5" weight="0.24951794707397953"/>
  <connections dest="11" src="6" weight="-0.3422002057868113"/>
  <connections dest="11" src="7" weight="-0.34896333718320666"/>
  <connections dest="11" src="8" weight="0.18236809262087086"/>
  <connections dest="11" src="9" weight="-0.39168932467050466"/>
  <connections dest="12" src="0" weight="1.5206290139263101"/>
  <connections dest="12" src="1" weight="-0.4806468102477885"/>
  <connections dest="12" src="2" weight="0.21439697155823853"/>
  <connections dest="12" src="3" weight="0.1226010537695569"/>
  <connections dest="12" src="4" weight="-0.2957055657777683"/>
  <connections dest="12" src="5" weight="0.6130228290778311"/>
  <connections dest="12" src="6" weight="0.36875530286236485"/>
  <connections dest="12" src="7" weight="-0.5171899914088294"/>
  <connections dest="12" src="8" weight="0.10837708801339006"/>
  <connections dest="12" src="9" weight="-0.7053746937035315"/>
  <connections dest="13" src="0" weight="0.002913660858364482"/>
  <connections dest="13" src="1" weight="-0.7651207747987173"/>
  <connections dest="13" src="2" weight="0.9715970070491731"/>
  <connections dest="13" src="3" weight="-0.9956453258174628"/>
  <connections dest="13" src="4" weight="-0.9408358352747842"/>
  <connections dest="13" src="5" weight="-1.008966493202113"/>
  <connections dest="13" src="6" weight="-0.672355054680489"/>
  <connections dest="13" src="7" weight="-0.3367206164565582"/>
  <connections dest="13" src="8" weight="0.7588693137687637"/>
  <connections dest="13" src="9" weight="-0.7196453490945308"/>
  <connections dest="14" src="0" weight="-1.9439726796836931"/>
  <connections dest="14" src="1" weight="-0.2894027034518325"/>
  <connections dest="14" src="2" weight="0.2110335238178935"/>
  <connections dest="14" src="3" weight="-0.009846640898758158"/>
  <connections dest="14" src="4" weight="0.1568088381509006"/>
  <connections dest="14" src="5" weight="-0.18073468038735682"/>
  <connections dest="14" src="6" weight="0.3823096688264287"/>
  <connections dest="14" src="7" weight="-0.21319807548539116"/>
  <connections dest="14" src="8" weight="-0.3736851760400955"/>
  <connections dest="14" src="9" weight="-0.10659568761110778"/>
  <connections dest="15" src="0" weight="-3.5802003342217197"/>
  <connections dest="15" src="10" weight="-0.520010988494904"/>
  <connections dest="15" src="11" weight="2.0607479402794953"/>
  <connections dest="15" src="12" weight="-1.3810086619100004"/>
  <connections dest="15" src="13" weight="-0.024645797466295187"/>
  <connections dest="15" src="14" weight="2.4372644169618125"/>
  <neurons id="0">
    <activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
  </neurons>
  <neurons id="1">
    <activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
  </neurons>
  <neurons id="2">
    <activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
  </neurons>
  <neurons id="3">
    <activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
  </neurons>
  <neurons id="4">
    <activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
  </neurons>
  <neurons id="5">
    <activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
  </neurons>
  <neurons id="6">
    <activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
  </neurons>
  <neurons id="7">
    <activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
  </neurons>
  <neurons id="8">
    <activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
  </neurons>
  <neurons id="9">
    <activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
  </neurons>
  <neurons id="10">
    <activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Tanh"/>
  </neurons>
  <neurons id="11">
    <activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Tanh"/>
  </neurons>
  <neurons id="12">
    <activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Tanh"/>
  </neurons>
  <neurons id="13">
    <activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Tanh"/>
  </neurons>
  <neurons id="14">
    <activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Tanh"/>
  </neurons>
  <neurons id="15">
    <activationFunction name="Sigmoid"/>
  </neurons>
  <layers>
    <neuronIds>1</neuronIds>
    <neuronIds>2</neuronIds>
    <neuronIds>3</neuronIds>
    <neuronIds>4</neuronIds>
    <neuronIds>5</neuronIds>
    <neuronIds>6</neuronIds>
    <neuronIds>7</neuronIds>
    <neuronIds>8</neuronIds>
    <neuronIds>9</neuronIds>
  </layers>
  <layers>
    <neuronIds>10</neuronIds>
    <neuronIds>11</neuronIds>
    <neuronIds>12</neuronIds>
    <neuronIds>13</neuronIds>
    <neuronIds>14</neuronIds>
  </layers>
  <layers>
    <neuronIds>15</neuronIds>
  </layers>
</multiLayerPerceptron>