Lots of neural network stuff.
@@ -7,6 +7,5 @@
 	<classpathentry kind="lib" path="lib/log4j-1.2.16.jar"/>
 	<classpathentry kind="lib" path="lib/kgsGtp.jar"/>
 	<classpathentry kind="lib" path="lib/antlrworks-1.4.3.jar"/>
-	<classpathentry kind="lib" path="lib/encog-java-core.jar" sourcepath="lib/encog-java-core-sources.jar"/>
 	<classpathentry kind="output" path="bin"/>
 </classpath>
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,21 +1,26 @@
 package net.woodyfolsom.msproj.ann;
 
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.FileOutputStream;
-import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.List;
 
-import org.encog.ml.data.MLData;
-import org.encog.neural.networks.BasicNetwork;
-import org.encog.neural.networks.PersistBasicNetwork;
 
 public abstract class AbstractNeuralNetFilter implements NeuralNetFilter {
-	protected BasicNetwork neuralNetwork;
-	protected int actualTrainingEpochs = 0;
-	protected int maxTrainingEpochs = 1000;
+	private final FeedforwardNetwork neuralNetwork;
+	private final TrainingMethod trainingMethod;
+
+	private double maxError;
+	private int actualTrainingEpochs = 0;
+	private int maxTrainingEpochs;
 
+	AbstractNeuralNetFilter(FeedforwardNetwork neuralNetwork, TrainingMethod trainingMethod, int maxTrainingEpochs, double maxError) {
+		this.neuralNetwork = neuralNetwork;
+		this.trainingMethod = trainingMethod;
+		this.maxError = maxError;
+		this.maxTrainingEpochs = maxTrainingEpochs;
+	}
 
 	@Override
-	public MLData compute(MLData input) {
+	public NNData compute(NNDataPair input) {
 		return this.neuralNetwork.compute(input);
 	}
 
@@ -23,38 +28,83 @@ public abstract class AbstractNeuralNetFilter implements NeuralNetFilter {
 		return actualTrainingEpochs;
 	}
 
+	@Override
+	public int getInputSize() {
+		return 2;
+	}
+
 	public int getMaxTrainingEpochs() {
 		return maxTrainingEpochs;
 	}
 
-	@Override
-	public BasicNetwork getNeuralNetwork() {
+	protected FeedforwardNetwork getNeuralNetwork() {
 		return neuralNetwork;
 	}
 
-	public void load(String filename) throws IOException {
-		FileInputStream fis = new FileInputStream(new File(filename));
-		neuralNetwork = (BasicNetwork) new PersistBasicNetwork().read(fis);
-		fis.close();
+	@Override
+	public void learnPatterns(List<NNDataPair> trainingSet) {
+		actualTrainingEpochs = 0;
+		double error;
+		neuralNetwork.initWeights();
+
+		error = trainingMethod.computePatternError(neuralNetwork,trainingSet);
+
+		if (error <= maxError) {
+			System.out.println("Initial error: " + error);
+			return;
+		}
+
+		do {
+			trainingMethod.iteratePatterns(neuralNetwork,trainingSet);
+			error = trainingMethod.computePatternError(neuralNetwork,trainingSet);
+			System.out.println("Epoch #" + actualTrainingEpochs + " Error:"
+					+ error);
+			actualTrainingEpochs++;
+			System.out.println("MSSE after epoch " + actualTrainingEpochs + ": " + error);
+		} while (error > maxError && actualTrainingEpochs < maxTrainingEpochs);
 	}
 
 	@Override
-	public void reset() {
-		neuralNetwork.reset();
+	public void learnSequences(List<List<NNDataPair>> trainingSet) {
+		actualTrainingEpochs = 0;
+		double error;
+		neuralNetwork.initWeights();
+
+		error = trainingMethod.computeSequenceError(neuralNetwork,trainingSet);
+
+		if (error <= maxError) {
+			System.out.println("Initial error: " + error);
+			return;
+		}
+
+		do {
+			trainingMethod.iterateSequences(neuralNetwork,trainingSet);
+			error = trainingMethod.computeSequenceError(neuralNetwork,trainingSet);
+			if (Double.isNaN(error)) {
+				error = trainingMethod.computeSequenceError(neuralNetwork,trainingSet);
+			}
+			System.out.println("Epoch #" + actualTrainingEpochs + " Error:"
+					+ error);
+			actualTrainingEpochs++;
+			System.out.println("MSSE after epoch " + actualTrainingEpochs + ": " + error);
+		} while (error > maxError && actualTrainingEpochs < maxTrainingEpochs);
 	}
 
 	@Override
-	public void reset(int seed) {
-		neuralNetwork.reset(seed);
+	public boolean load(InputStream input) {
+		return neuralNetwork.load(input);
+	}
+
+	@Override
+	public boolean save(OutputStream output) {
+		return neuralNetwork.save(output);
 	}
 
-	public void save(String filename) throws IOException {
-		FileOutputStream fos = new FileOutputStream(new File(filename));
-		new PersistBasicNetwork().save(fos, getNeuralNetwork());
-		fos.close();
+	public void setMaxError(double maxError) {
+		this.maxError = maxError;
 	}
 
 	public void setMaxTrainingEpochs(int max) {
 		this.maxTrainingEpochs = max;
 	}
 }
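Note: the new learnPatterns/learnSequences loops reinitialize the weights, measure the initial MSSE, and then iterate the injected TrainingMethod until the error falls to maxError or maxTrainingEpochs is reached. A minimal usage sketch of the refactored API follows; the XOR truth table and the demo class are my own illustration, while the types and the NNData(String[], double[]) constructor appear later in this commit.

	import java.util.ArrayList;
	import java.util.List;

	import net.woodyfolsom.msproj.ann.NNData;
	import net.woodyfolsom.msproj.ann.NNDataPair;
	import net.woodyfolsom.msproj.ann.XORFilter;

	public class XorDemo {
		public static void main(String[] args) {
			String[] in = { "x", "y" };
			String[] out = { "xor" };
			// the four rows of the XOR truth table
			List<NNDataPair> xor = new ArrayList<NNDataPair>();
			xor.add(new NNDataPair(new NNData(in, new double[] { 0, 0 }), new NNData(out, new double[] { 0 })));
			xor.add(new NNDataPair(new NNData(in, new double[] { 0, 1 }), new NNData(out, new double[] { 1 })));
			xor.add(new NNDataPair(new NNData(in, new double[] { 1, 0 }), new NNData(out, new double[] { 1 })));
			xor.add(new NNDataPair(new NNData(in, new double[] { 1, 1 }), new NNData(out, new double[] { 0 })));

			XORFilter filter = new XORFilter(0.8, 0.7);   // learning rate, momentum
			filter.learnPatterns(xor);                    // loops until MSSE <= 0.01 or 1000 epochs
			System.out.println(filter.compute(1.0, 0.0)); // should approach 1.0 once trained
		}
	}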
@@ -1,11 +1,11 @@
-package net.woodyfolsom.msproj.ann2;
+package net.woodyfolsom.msproj.ann;
 
 import java.util.List;
 
-import net.woodyfolsom.msproj.ann2.math.ErrorFunction;
-import net.woodyfolsom.msproj.ann2.math.MSSE;
+import net.woodyfolsom.msproj.ann.math.ErrorFunction;
+import net.woodyfolsom.msproj.ann.math.MSSE;
 
-public class BackPropagation implements TrainingMethod {
+public class BackPropagation extends TrainingMethod {
 	private final ErrorFunction errorFunction;
 	private final double learningRate;
 	private final double momentum;
@@ -17,15 +17,13 @@ public class BackPropagation implements TrainingMethod {
 	}
 
 	@Override
-	public void iterate(FeedforwardNetwork neuralNetwork,
+	public void iteratePatterns(FeedforwardNetwork neuralNetwork,
 			List<NNDataPair> trainingSet) {
 		System.out.println("Learningrate: " + learningRate);
 		System.out.println("Momentum: " + momentum);
 
-		//zeroErrors(neuralNetwork);
-
 		for (NNDataPair trainingPair : trainingSet) {
-			zeroErrors(neuralNetwork);
+			zeroGradients(neuralNetwork);
 
 			System.out.println("Training with: " + trainingPair.getInput());
 
@@ -35,16 +33,15 @@ public class BackPropagation implements TrainingMethod {
 			System.out.println("Updating weights. Ideal Output: " + ideal);
 			System.out.println("Actual Output: " + actual);
 
-			updateErrors(neuralNetwork, ideal);
+			//backpropagate the gradients w.r.t. output error
+			backPropagate(neuralNetwork, ideal);
 
 			updateWeights(neuralNetwork);
 		}
 
-		//updateWeights(neuralNetwork);
 	}
 
 	@Override
-	public double computeError(FeedforwardNetwork neuralNetwork,
+	public double computePatternError(FeedforwardNetwork neuralNetwork,
 			List<NNDataPair> trainingSet) {
 		int numDataPairs = trainingSet.size();
 		int outputSize = neuralNetwork.getOutput().length;
@@ -67,15 +64,17 @@ public class BackPropagation implements TrainingMethod {
 		return MSSE;
 	}
 
-	private void updateErrors(FeedforwardNetwork neuralNetwork, NNData ideal) {
+	@Override
+	protected
+	void backPropagate(FeedforwardNetwork neuralNetwork, NNData ideal) {
 		Neuron[] outputNeurons = neuralNetwork.getOutputNeurons();
 		double[] idealValues = ideal.getValues();
 
 		for (int i = 0; i < idealValues.length; i++) {
-			double output = outputNeurons[i].getOutput();
+			double input = outputNeurons[i].getInput();
 			double derivative = outputNeurons[i].getActivationFunction()
-					.derivative(output);
-			outputNeurons[i].setError(outputNeurons[i].getError() + derivative * (idealValues[i] - output));
+					.derivative(input);
+			outputNeurons[i].setGradient(outputNeurons[i].getGradient() + derivative * (idealValues[i] - outputNeurons[i].getOutput()));
 		}
 		// walking down the list of Neurons in reverse order, propagate the
 		// error
@@ -84,19 +83,19 @@ public class BackPropagation implements TrainingMethod {
 		for (int n = neurons.length - 1; n >= 0; n--) {
 
 			Neuron neuron = neurons[n];
-			double error = neuron.getError();
+			double error = neuron.getGradient();
 
 			Connection[] connectionsFromN = neuralNetwork
 					.getConnectionsFrom(neuron.getId());
 			if (connectionsFromN.length > 0) {
 
 				double derivative = neuron.getActivationFunction().derivative(
-						neuron.getOutput());
+						neuron.getInput());
 				for (Connection connection : connectionsFromN) {
-					error += derivative * connection.getWeight() * neuralNetwork.getNeuron(connection.getDest()).getError();
+					error += derivative * connection.getWeight() * neuralNetwork.getNeuron(connection.getDest()).getGradient();
 				}
 			}
-			neuron.setError(error);
+			neuron.setGradient(error);
 		}
 	}
 
@@ -104,17 +103,30 @@ public class BackPropagation implements TrainingMethod {
 		for (Connection connection : neuralNetwork.getConnections()) {
 			Neuron srcNeuron = neuralNetwork.getNeuron(connection.getSrc());
 			Neuron destNeuron = neuralNetwork.getNeuron(connection.getDest());
-			double delta = learningRate * srcNeuron.getOutput() * destNeuron.getError();
+			double delta = learningRate * srcNeuron.getOutput() * destNeuron.getGradient();
 			//TODO allow for momentum
 			//double lastDelta = connection.getLastDelta();
 			connection.addDelta(delta);
 		}
 	}
 
-	private void zeroErrors(FeedforwardNetwork neuralNetwork) {
-		// Set output errors relative to ideals, all other errors to 0.
-		for (Neuron neuron : neuralNetwork.getNeurons()) {
-			neuron.setError(0.0);
-		}
+	@Override
+	public void iterateSequences(FeedforwardNetwork neuralNetwork,
+			List<List<NNDataPair>> trainingSet) {
+		throw new UnsupportedOperationException();
 	}
 
+	@Override
+	public double computeSequenceError(FeedforwardNetwork neuralNetwork,
+			List<List<NNDataPair>> trainingSet) {
+		throw new UnsupportedOperationException();
+	}
+
+	@Override
+	protected void iteratePattern(FeedforwardNetwork neuralNetwork,
+			NNDataPair statePair, NNData nextReward) {
+		throw new UnsupportedOperationException();
+	}
+
 }
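Note: with the error-to-gradient renaming, backPropagate and updateWeights implement the textbook delta rule. For reference (the standard form, not copied verbatim from this code):

	\delta_i = f'(\mathrm{net}_i)\,(t_i - o_i)            % output neurons
	\delta_j = f'(\mathrm{net}_j)\sum_k w_{jk}\,\delta_k   % hidden neurons
	\Delta w_{ij} = \eta\, o_i\, \delta_j                  % weight update

where net is the neuron's pre-activation input, t the ideal value, and eta the learning rate; evaluating f' at the pre-activation input is why the callers above switched from derivative(getOutput()) to derivative(getInput()).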
@@ -1,4 +1,4 @@
-package net.woodyfolsom.msproj.ann2;
+package net.woodyfolsom.msproj.ann;
 
 import javax.xml.bind.annotation.XmlAttribute;
 import javax.xml.bind.annotation.XmlTransient;
@@ -8,6 +8,7 @@ public class Connection {
 	private int dest;
 	private double weight;
 	private transient double lastDelta = 0.0;
+	private transient double trace = 0.0;
 
 	public Connection() {
 		//no-arg constructor for JAXB
@@ -20,6 +21,7 @@ public class Connection {
 	}
 
 	public void addDelta(double delta) {
+		this.trace = delta;
 		this.weight += delta;
 		this.lastDelta = delta;
 	}
@@ -39,6 +41,10 @@ public class Connection {
 		return src;
 	}
 
+	public double getTrace() {
+		return trace;
+	}
+
 	@XmlAttribute
 	public double getWeight() {
 		return weight;
@@ -52,6 +58,11 @@ public class Connection {
 		this.src = src;
 	}
 
+	@XmlTransient
+	public void setTrace(double trace) {
+		this.trace = trace;
+	}
+
 	public void setWeight(double weight) {
 		this.weight = weight;
 	}
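Note: the new trace field gives each connection storage for an eligibility trace. For reference, the textbook TD(lambda) trace recursion is

	e_t = \gamma\lambda\, e_{t-1} + \nabla_w V(s_t), \qquad w \leftarrow w + \alpha\,\delta_t\, e_t

In this commit, addDelta records the last applied delta in trace, and the lambda discount is applied where the trace is consumed, in TemporalDifference.updateWeights further down.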
@@ -1,12 +0,0 @@
|
|||||||
package net.woodyfolsom.msproj.ann;
|
|
||||||
|
|
||||||
import org.encog.ml.data.basic.BasicMLData;
|
|
||||||
|
|
||||||
public class DoublePair extends BasicMLData {
|
|
||||||
|
|
||||||
private static final long serialVersionUID = 1L;
|
|
||||||
|
|
||||||
public DoublePair(double x, double y) {
|
|
||||||
super(new double[] { x, y });
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,4 +1,4 @@
-package net.woodyfolsom.msproj.ann2;
+package net.woodyfolsom.msproj.ann;
 
 import java.io.InputStream;
 import java.io.OutputStream;
@@ -9,10 +9,11 @@ import java.util.Map;
 
 import javax.xml.bind.annotation.XmlAttribute;
 import javax.xml.bind.annotation.XmlElement;
+import javax.xml.bind.annotation.XmlTransient;
 
-import net.woodyfolsom.msproj.ann2.math.ActivationFunction;
-import net.woodyfolsom.msproj.ann2.math.Linear;
-import net.woodyfolsom.msproj.ann2.math.Sigmoid;
+import net.woodyfolsom.msproj.ann.math.ActivationFunction;
+import net.woodyfolsom.msproj.ann.math.Linear;
+import net.woodyfolsom.msproj.ann.math.Sigmoid;
 
 public abstract class FeedforwardNetwork {
 	private ActivationFunction activationFunction;
@@ -83,12 +84,12 @@ public abstract class FeedforwardNetwork {
 	 * Adds a new neuron with a unique id to this FeedforwardNetwork.
 	 * @return
 	 */
-	Neuron createNeuron(boolean input) {
+	Neuron createNeuron(boolean input, ActivationFunction afunc) {
 		Neuron neuron;
 		if (input) {
 			neuron = new Neuron(Linear.function, neurons.size());
 		} else {
-			neuron = new Neuron(activationFunction, neurons.size());
+			neuron = new Neuron(afunc, neurons.size());
 		}
 		neurons.add(neuron);
 		return neuron;
@@ -153,6 +154,10 @@ public abstract class FeedforwardNetwork {
 		return neurons.get(id);
 	}
 
+	public Connection getConnection(int index) {
+		return connections.get(index);
+	}
+
 	@XmlElement
 	protected Connection[] getConnections() {
 		return connections.toArray(new Connection[connections.size()]);
@@ -178,6 +183,22 @@ public abstract class FeedforwardNetwork {
 		}
 	}
 
+	public double[] getGradients() {
+		double[] gradients = new double[neurons.size()];
+		for (int n = 0; n < gradients.length; n++) {
+			gradients[n] = neurons.get(n).getGradient();
+		}
+		return gradients;
+	}
+
+	public double[] getWeights() {
+		double[] weights = new double[connections.size()];
+		for (int i = 0; i < connections.size(); i++) {
+			weights[i] = connections.get(i).getWeight();
+		}
+		return weights;
+	}
+
 	@XmlAttribute
 	public boolean isBiased() {
 		return biased;
@@ -226,7 +247,7 @@ public abstract class FeedforwardNetwork {
 		this.biased = biased;
 
 		if (biased) {
-			Neuron biasNeuron = createNeuron(true);
+			Neuron biasNeuron = createNeuron(true, activationFunction);
 			biasNeuron.setInput(1.0);
 			biasNeuronId = biasNeuron.getId();
 		} else {
@@ -270,6 +291,7 @@ public abstract class FeedforwardNetwork {
 		}
 	}
 
+	@XmlTransient
 	public void setWeights(double[] weights) {
 		if (weights.length != connections.size()) {
 			throw new IllegalArgumentException("# of weights must == # of connections");
@@ -1,117 +0,0 @@
|
|||||||
package net.woodyfolsom.msproj.ann;
|
|
||||||
|
|
||||||
import net.woodyfolsom.msproj.GameResult;
|
|
||||||
import net.woodyfolsom.msproj.GameState;
|
|
||||||
import net.woodyfolsom.msproj.Player;
|
|
||||||
|
|
||||||
import org.encog.ml.data.MLData;
|
|
||||||
import org.encog.ml.data.MLDataPair;
|
|
||||||
import org.encog.ml.data.basic.BasicMLData;
|
|
||||||
import org.encog.ml.data.basic.BasicMLDataPair;
|
|
||||||
import org.encog.util.kmeans.Centroid;
|
|
||||||
|
|
||||||
public class GameStateMLDataPair implements MLDataPair {
|
|
||||||
private BasicMLDataPair mlDataPairDelegate;
|
|
||||||
private GameState gameState;
|
|
||||||
|
|
||||||
public GameStateMLDataPair(GameState gameState) {
|
|
||||||
this.gameState = gameState;
|
|
||||||
mlDataPairDelegate = new BasicMLDataPair(new BasicMLData(createInput()), new BasicMLData(createIdeal()));
|
|
||||||
}
|
|
||||||
|
|
||||||
public GameStateMLDataPair(GameStateMLDataPair that) {
|
|
||||||
this.gameState = new GameState(that.gameState);
|
|
||||||
mlDataPairDelegate = new BasicMLDataPair(
|
|
||||||
that.mlDataPairDelegate.getInput(),
|
|
||||||
that.mlDataPairDelegate.getIdeal());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public MLDataPair clone() {
|
|
||||||
return new GameStateMLDataPair(this);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Centroid<MLDataPair> createCentroid() {
|
|
||||||
return mlDataPairDelegate.createCentroid();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates a vector of normalized scores from GameState.
|
|
||||||
*
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
private double[] createInput() {
|
|
||||||
|
|
||||||
GameResult result = gameState.getResult();
|
|
||||||
|
|
||||||
double maxScore = gameState.getGameConfig().getSize()
|
|
||||||
* gameState.getGameConfig().getSize();
|
|
||||||
|
|
||||||
double whiteScore = Math.min(1.0, result.getWhiteScore() / maxScore);
|
|
||||||
double blackScore = Math.min(1.0, result.getBlackScore() / maxScore);
|
|
||||||
|
|
||||||
return new double[] { blackScore, whiteScore };
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates a vector of values indicating strength of black/white win output
|
|
||||||
* from network.
|
|
||||||
*
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
private double[] createIdeal() {
|
|
||||||
GameResult result = gameState.getResult();
|
|
||||||
|
|
||||||
double blackWinner = result.isWinner(Player.BLACK) ? 1.0 : 0.0;
|
|
||||||
double whiteWinner = result.isWinner(Player.WHITE) ? 1.0 : 0.0;
|
|
||||||
|
|
||||||
return new double[] { blackWinner, whiteWinner };
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public MLData getIdeal() {
|
|
||||||
return mlDataPairDelegate.getIdeal();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double[] getIdealArray() {
|
|
||||||
return mlDataPairDelegate.getIdealArray();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public MLData getInput() {
|
|
||||||
return mlDataPairDelegate.getInput();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double[] getInputArray() {
|
|
||||||
return mlDataPairDelegate.getInputArray();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double getSignificance() {
|
|
||||||
return mlDataPairDelegate.getSignificance();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isSupervised() {
|
|
||||||
return mlDataPairDelegate.isSupervised();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void setIdealArray(double[] arg0) {
|
|
||||||
mlDataPairDelegate.setIdealArray(arg0);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void setInputArray(double[] arg0) {
|
|
||||||
mlDataPairDelegate.setInputArray(arg0);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void setSignificance(double arg0) {
|
|
||||||
mlDataPairDelegate.setSignificance(arg0);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -1,4 +1,4 @@
-package net.woodyfolsom.msproj.ann2;
+package net.woodyfolsom.msproj.ann;
 
 import java.util.Arrays;
 
@@ -1,4 +1,4 @@
-package net.woodyfolsom.msproj.ann2;
+package net.woodyfolsom.msproj.ann;
 
 import java.io.InputStream;
 import java.io.OutputStream;
@@ -10,6 +10,10 @@ import javax.xml.bind.Unmarshaller;
 import javax.xml.bind.annotation.XmlElement;
 import javax.xml.bind.annotation.XmlRootElement;
 
+import net.woodyfolsom.msproj.ann.math.ActivationFunction;
+import net.woodyfolsom.msproj.ann.math.Sigmoid;
+import net.woodyfolsom.msproj.ann.math.Tanh;
+
 @XmlRootElement
 public class MultiLayerPerceptron extends FeedforwardNetwork {
 	private boolean biased;
@@ -37,7 +41,13 @@ public class MultiLayerPerceptron extends FeedforwardNetwork {
 			throw new IllegalArgumentException("Layer size must be >= 1");
 		}
 
-		Layer newLayer = createNewLayer(layerIndex, layerSize);
+		Layer newLayer;
+		if (layerIndex == numLayers - 1) {
+			newLayer = createNewLayer(layerIndex, layerSize, Sigmoid.function);
+		} else {
+			newLayer = createNewLayer(layerIndex, layerSize, Tanh.function);
+		}
 
 		if (layerIndex > 0) {
 			Layer prevLayer = layers[layerIndex - 1];
@@ -54,11 +64,11 @@ public class MultiLayerPerceptron extends FeedforwardNetwork {
 		}
 	}
 
-	private Layer createNewLayer(int layerIndex, int layerSize) {
+	private Layer createNewLayer(int layerIndex, int layerSize, ActivationFunction afunc) {
 		Layer layer = new Layer(layerSize);
 		layers[layerIndex] = layer;
 		for (int n = 0; n < layerSize; n++) {
-			Neuron neuron = createNeuron(layerIndex == 0);
+			Neuron neuron = createNeuron(layerIndex == 0, afunc);
 			layer.setNeuronId(n, neuron.getId());
 		}
 		return layer;
@@ -93,8 +103,13 @@ public class MultiLayerPerceptron extends FeedforwardNetwork {
 	protected void setInput(double[] input) {
 		Layer inputLayer = layers[0];
 		for (int n = 0; n < inputLayer.size(); n++) {
-			getNeuron(inputLayer.getNeuronId(n)).setInput(input[n]);
+			try {
+				getNeuron(inputLayer.getNeuronId(n)).setInput(input[n]);
+			} catch (NullPointerException npe) {
+				npe.printStackTrace();
+			}
 		}
 
 	}
 
 	public void setLayers(Layer[] layers) {
@@ -1,4 +1,4 @@
-package net.woodyfolsom.msproj.ann2;
+package net.woodyfolsom.msproj.ann;
 
 public class NNData {
 	private final double[] values;
@@ -1,4 +1,4 @@
-package net.woodyfolsom.msproj.ann2;
+package net.woodyfolsom.msproj.ann;
 
 public class NNDataPair {
 	private final NNData input;
@@ -1,30 +1,29 @@
 package net.woodyfolsom.msproj.ann;
 
-import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
 import java.util.List;
-import java.util.Set;
 
-import org.encog.ml.data.MLData;
-import org.encog.ml.data.MLDataPair;
-import org.encog.ml.data.MLDataSet;
-import org.encog.neural.networks.BasicNetwork;
 
 public interface NeuralNetFilter {
-	BasicNetwork getNeuralNetwork();
 
 	int getActualTrainingEpochs();
 
 	int getInputSize();
 
 	int getMaxTrainingEpochs();
 
 	int getOutputSize();
 
-	void learn(MLDataSet trainingSet);
-	void learn(Set<List<MLDataPair>> trainingSet);
-	void load(String fileName) throws IOException;
-	void reset();
-	void reset(int seed);
-	void save(String fileName) throws IOException;
+	boolean load(InputStream input);
+	boolean save(OutputStream output);
 	void setMaxTrainingEpochs(int max);
 
-	MLData compute(MLData input);
+	NNData compute(NNDataPair input);
 
+	//Due to Java type erasure, overloading a method
+	//simply named 'learn' which takes Lists would be problematic
+
+	void learnPatterns(List<NNDataPair> trainingSet);
+	void learnSequences(List<List<NNDataPair>> trainingSet);
 }
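Note: the type-erasure comment is accurate. A deliberately non-compiling illustration (my own, not part of the commit): both signatures below erase to learn(List), so javac rejects the pair with a name-clash error, which is why the interface uses the distinct names learnPatterns and learnSequences.

	import java.util.List;

	interface Learner {
		void learn(List<NNDataPair> patterns);        // erasure: learn(List)
		void learn(List<List<NNDataPair>> sequences); // erasure: learn(List) -- rejected as a clash
	}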
@@ -1,17 +1,17 @@
-package net.woodyfolsom.msproj.ann2;
+package net.woodyfolsom.msproj.ann;
 
 import javax.xml.bind.annotation.XmlAttribute;
 import javax.xml.bind.annotation.XmlElement;
 import javax.xml.bind.annotation.XmlTransient;
 
-import net.woodyfolsom.msproj.ann2.math.ActivationFunction;
-import net.woodyfolsom.msproj.ann2.math.Sigmoid;
+import net.woodyfolsom.msproj.ann.math.ActivationFunction;
+import net.woodyfolsom.msproj.ann.math.Sigmoid;
 
 public class Neuron {
 	private ActivationFunction activationFunction;
 	private int id;
 	private transient double input = 0.0;
-	private transient double error = 0.0;
+	private transient double gradient = 0.0;
 
 	public Neuron() {
 		//no-arg constructor for JAXB
@@ -37,8 +37,8 @@ public class Neuron {
 	}
 
 	@XmlTransient
-	public double getError() {
-		return error;
+	public double getGradient() {
+		return gradient;
 	}
 
 	@XmlTransient
@@ -50,8 +50,8 @@ public class Neuron {
 		return activationFunction.calculate(input);
 	}
 
-	public void setError(double value) {
-		this.error = value;
+	public void setGradient(double value) {
+		this.gradient = value;
 	}
 
 	public void setInput(double input) {
@@ -92,7 +92,7 @@ public class Neuron {
 
 	@Override
 	public String toString() {
-		return "Neuron #" + id +", input: " + input + ", error: " + error;
+		return "Neuron #" + id +", input: " + input + ", gradient: " + gradient;
 	}
 
 }
src/net/woodyfolsom/msproj/ann/ObjectiveFunction.java (new file, +5 lines)
@@ -0,0 +1,5 @@
+package net.woodyfolsom.msproj.ann;
+
+public class ObjectiveFunction {
+
+}
src/net/woodyfolsom/msproj/ann/TTTFilter.java (new file, +34 lines)
@@ -0,0 +1,34 @@
+package net.woodyfolsom.msproj.ann;
+
+/**
+ * Based on sample code from http://neuroph.sourceforge.net
+ *
+ * @author Woody
+ *
+ */
+public class TTTFilter extends AbstractNeuralNetFilter implements
+		NeuralNetFilter {
+
+	private static final int INPUT_SIZE = 9;
+	private static final int OUTPUT_SIZE = 1;
+
+	public TTTFilter() {
+		this(0.5,0.0, 1000);
+	}
+
+	public TTTFilter(double alpha, double lambda, int maxEpochs) {
+		super( new MultiLayerPerceptron(true, INPUT_SIZE, 5, OUTPUT_SIZE),
+				new TemporalDifference(0.5,0.0), maxEpochs, 0.05);
+		super.getNeuralNetwork().setName("XORFilter");
+	}
+
+	@Override
+	public int getInputSize() {
+		return INPUT_SIZE;
+	}
+
+	@Override
+	public int getOutputSize() {
+		return OUTPUT_SIZE;
+	}
+}
src/net/woodyfolsom/msproj/ann/TTTFilterTrainer.java (new file, +187 lines)
@@ -0,0 +1,187 @@
+package net.woodyfolsom.msproj.ann;
+
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.io.FileOutputStream;
+import java.util.ArrayList;
+import java.util.List;
+
+import net.woodyfolsom.msproj.tictactoe.Action;
+import net.woodyfolsom.msproj.tictactoe.GameRecord;
+import net.woodyfolsom.msproj.tictactoe.GameRecord.RESULT;
+import net.woodyfolsom.msproj.tictactoe.NNDataSetFactory;
+import net.woodyfolsom.msproj.tictactoe.NeuralNetPolicy;
+import net.woodyfolsom.msproj.tictactoe.Policy;
+import net.woodyfolsom.msproj.tictactoe.RandomPolicy;
+import net.woodyfolsom.msproj.tictactoe.State;
+
+public class TTTFilterTrainer { //implements epsilon-greedy trainer? online version of NeuralNetFilter
+
+	public static void main(String[] args) throws FileNotFoundException {
+		double alpha = 0.0;
+		double lambda = 0.9;
+		int maxGames = 15000;
+
+		new TTTFilterTrainer().trainNetwork(alpha, lambda, maxGames);
+	}
+
+	public void trainNetwork(double alpha, double lambda, int maxGames) throws FileNotFoundException {
+		///
+		FeedforwardNetwork neuralNetwork = new MultiLayerPerceptron(true, 9,5,1);
+		neuralNetwork.setName("TicTacToe");
+		neuralNetwork.initWeights();
+		TrainingMethod trainer = new TemporalDifference(0.5,0.5);
+
+		System.out.println("Playing untrained games.");
+		for (int i = 0; i < 10; i++) {
+			System.out.println("" + (i+1) + ". " + playOptimal(neuralNetwork).getResult());
+		}
+
+		System.out.println("Learning from " + maxGames + " games of random self-play");
+
+		int gamesPlayed = 0;
+		List<RESULT> results = new ArrayList<RESULT>();
+		do {
+			GameRecord gameRecord = playEpsilonGreedy(0.90, neuralNetwork, trainer);
+			System.out.println("Winner: " + gameRecord.getResult());
+			gamesPlayed++;
+			results.add(gameRecord.getResult());
+		} while (gamesPlayed < maxGames);
+		///
+
+		System.out.println("Learned network after " + maxGames + " training games.");
+
+		double[][] validationSet = new double[8][];
+
+		for (int i = 0; i < results.size(); i++) {
+			if (i % 10 == 0) {
+				System.out.println("" + (i+1) + ". " + results.get(i));
+			}
+		}
+		// empty board
+		validationSet[0] = new double[] { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+				0.0, 0.0 };
+		// center
+		validationSet[1] = new double[] { 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,
+				0.0, 0.0 };
+		// top edge
+		validationSet[2] = new double[] { 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+				0.0, 0.0 };
+		// left edge
+		validationSet[3] = new double[] { 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0,
+				0.0, 0.0 };
+		// corner
+		validationSet[4] = new double[] { 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+				0.0, 0.0 };
+		// win
+		validationSet[5] = new double[] { 1.0, 1.0, 1.0, -1.0, -1.0, 0.0, 0.0,
+				-1.0, 0.0 };
+		// loss
+		validationSet[6] = new double[] { -1.0, 1.0, 0.0, 1.0, -1.0, 1.0, 0.0,
+				0.0, -1.0 };
+
+		// about to win
+		validationSet[7] = new double[] {
+				-1.0, 1.0, 1.0,
+				1.0, -1.0, 1.0,
+				-1.0, -1.0, 0.0 };
+
+		String[] inputNames = new String[] { "00", "01", "02", "10", "11",
+				"12", "20", "21", "22" };
+		String[] outputNames = new String[] { "values" };
+
+		System.out.println("Output from eval set (learned network):");
+		testNetwork(neuralNetwork, validationSet, inputNames, outputNames);
+
+		System.out.println("Playing optimal games.");
+		for (int i = 0; i < 10; i++) {
+			System.out.println("" + (i+1) + ". " + playOptimal(neuralNetwork).getResult());
+		}
+
+		/*
+		File output = new File("ttt.net");
+
+		FileOutputStream fos = new FileOutputStream(output);
+
+		neuralNetwork.save(fos);*/
+	}
+
+	private GameRecord playOptimal(FeedforwardNetwork neuralNetwork) {
+		GameRecord gameRecord = new GameRecord();
+
+		Policy neuralNetPolicy = new NeuralNetPolicy(neuralNetwork);
+
+		State state = gameRecord.getState();
+
+		do {
+			Action action;
+			State nextState;
+
+			action = neuralNetPolicy.getAction(gameRecord.getState());
+
+			nextState = gameRecord.apply(action);
+			//System.out.println("Action " + action + " selected by policy " + selectedPolicy.getName());
+			//System.out.println("Next board state: " + nextState);
+			state = nextState;
+		} while (!state.isTerminal());
+
+		//finally, reinforce the actual reward
+
+		return gameRecord;
+	}
+
+	private GameRecord playEpsilonGreedy(double epsilon, FeedforwardNetwork neuralNetwork, TrainingMethod trainer) {
+		GameRecord gameRecord = new GameRecord();
+
+		Policy randomPolicy = new RandomPolicy();
+		Policy neuralNetPolicy = new NeuralNetPolicy(neuralNetwork);
+
+		//System.out.println("Playing epsilon-greedy game.");
+
+		State state = gameRecord.getState();
+		NNDataPair statePair;
+
+		Policy selectedPolicy;
+		trainer.zeroTraces(neuralNetwork);
+
+		do {
+			Action action;
+			State nextState;
+
+			if (Math.random() < epsilon) {
+				selectedPolicy = randomPolicy;
+				action = selectedPolicy.getAction(gameRecord.getState());
+				nextState = gameRecord.apply(action);
+			} else {
+				selectedPolicy = neuralNetPolicy;
+				action = selectedPolicy.getAction(gameRecord.getState());
+
+				nextState = gameRecord.apply(action);
+				statePair = NNDataSetFactory.createDataPair(state);
+				NNDataPair nextStatePair = NNDataSetFactory.createDataPair(nextState);
+				trainer.iteratePattern(neuralNetwork, statePair, nextStatePair.getIdeal());
+			}
+			//System.out.println("Action " + action + " selected by policy " + selectedPolicy.getName());
+
+			//System.out.println("Next board state: " + nextState);
+
+			state = nextState;
+		} while (!state.isTerminal());
+
+		//finally, reinforce the actual reward
+		statePair = NNDataSetFactory.createDataPair(state);
+		trainer.iteratePattern(neuralNetwork, statePair, statePair.getIdeal());
+
+		return gameRecord;
+	}
+
+	private void testNetwork(FeedforwardNetwork neuralNetwork,
+			double[][] validationSet, String[] inputNames, String[] outputNames) {
+		for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
+			NNDataPair dp = new NNDataPair(new NNData(inputNames,
+					validationSet[valIndex]), new NNData(outputNames,
+					validationSet[valIndex]));
+			System.out.println(dp + " => " + neuralNetwork.compute(dp));
+		}
+	}
+}
@@ -1,30 +1,133 @@
 package net.woodyfolsom.msproj.ann;
 
-import org.encog.ml.data.MLDataSet;
-import org.encog.neural.networks.ContainsFlat;
-import org.encog.neural.networks.training.propagation.back.Backpropagation;
+import java.util.List;
 
-public class TemporalDifference extends Backpropagation {
+public class TemporalDifference extends TrainingMethod {
+	private final double alpha;
+	private final double gamma = 1.0;
 	private final double lambda;
 
-	public TemporalDifference(ContainsFlat network, MLDataSet training,
-			double theLearnRate, double theMomentum, double lambda) {
-		super(network, training, theLearnRate, theMomentum);
+	public TemporalDifference(double alpha, double lambda) {
+		this.alpha = alpha;
 		this.lambda = lambda;
 	}
 
-	public double getLamdba() {
-		return lambda;
+	@Override
+	public void iteratePatterns(FeedforwardNetwork neuralNetwork,
+			List<NNDataPair> trainingSet) {
+		throw new UnsupportedOperationException();
 	}
 
 	@Override
-	public double updateWeight(final double[] gradients,
-			final double[] lastGradient, final int index) {
-		double alpha = this.getLearningRate();
-
-		//TODO fill in weight update for TD(lambda)
-
-		return 0.0;
+	public double computePatternError(FeedforwardNetwork neuralNetwork,
+			List<NNDataPair> trainingSet) {
+		int numDataPairs = trainingSet.size();
+		int outputSize = neuralNetwork.getOutput().length;
+		int totalOutputSize = outputSize * numDataPairs;
+
+		double[] actuals = new double[totalOutputSize];
+		double[] ideals = new double[totalOutputSize];
+		for (int dataPair = 0; dataPair < numDataPairs; dataPair++) {
+			NNDataPair nnDataPair = trainingSet.get(dataPair);
+			double[] actual = neuralNetwork.compute(nnDataPair.getInput()
+					.getValues());
+			double[] ideal = nnDataPair.getIdeal().getValues();
+			int offset = dataPair * outputSize;
+
+			System.arraycopy(actual, 0, actuals, offset, outputSize);
+			System.arraycopy(ideal, 0, ideals, offset, outputSize);
+		}
+
+		double MSSE = errorFunction.compute(ideals, actuals);
+		return MSSE;
+	}
+
+	@Override
+	protected void backPropagate(FeedforwardNetwork neuralNetwork, NNData ideal) {
+		Neuron[] outputNeurons = neuralNetwork.getOutputNeurons();
+		double[] idealValues = ideal.getValues();
+
+		for (int i = 0; i < idealValues.length; i++) {
+			double input = outputNeurons[i].getInput();
+			double derivative = outputNeurons[i].getActivationFunction()
+					.derivative(input);
+			outputNeurons[i].setGradient(outputNeurons[i].getGradient()
+					+ derivative
+					* (idealValues[i] - outputNeurons[i].getOutput()));
+		}
+		// walking down the list of Neurons in reverse order, propagate the
+		// error
+		Neuron[] neurons = neuralNetwork.getNeurons();
+
+		for (int n = neurons.length - 1; n >= 0; n--) {
+
+			Neuron neuron = neurons[n];
+			double error = neuron.getGradient();
+
+			Connection[] connectionsFromN = neuralNetwork
+					.getConnectionsFrom(neuron.getId());
+			if (connectionsFromN.length > 0) {
+
+				double derivative = neuron.getActivationFunction().derivative(
+						neuron.getInput());
+				for (Connection connection : connectionsFromN) {
+					error += derivative
+							* connection.getWeight()
+							* neuralNetwork.getNeuron(connection.getDest())
+									.getGradient();
+				}
+			}
+			neuron.setGradient(error);
+		}
+	}
+
+	private void updateWeights(FeedforwardNetwork neuralNetwork, double predictionError) {
+		for (Connection connection : neuralNetwork.getConnections()) {
+			Neuron srcNeuron = neuralNetwork.getNeuron(connection.getSrc());
+			Neuron destNeuron = neuralNetwork.getNeuron(connection.getDest());
+
+			double delta = alpha * srcNeuron.getOutput()
+					* destNeuron.getGradient() * predictionError + connection.getTrace() * lambda;
+
+			// TODO allow for momentum
+			// double lastDelta = connection.getLastDelta();
+			connection.addDelta(delta);
+		}
 	}
 
+	@Override
+	public void iterateSequences(FeedforwardNetwork neuralNetwork,
+			List<List<NNDataPair>> trainingSet) {
+		throw new UnsupportedOperationException();
+	}
+
+	@Override
+	public double computeSequenceError(FeedforwardNetwork neuralNetwork,
+			List<List<NNDataPair>> trainingSet) {
+		throw new UnsupportedOperationException();
+	}
+
+	@Override
+	protected void iteratePattern(FeedforwardNetwork neuralNetwork,
+			NNDataPair statePair, NNData nextReward) {
+		//System.out.println("Learningrate: " + alpha);
+
+		zeroGradients(neuralNetwork);
+
+		//System.out.println("Training with: " + statePair.getInput());
+
+		NNData ideal = nextReward;
+		NNData actual = neuralNetwork.compute(statePair);
+
+		//System.out.println("Updating weights. Ideal Output: " + ideal);
+		//System.out.println("Actual Output: " + actual);
+
+		// backpropagate the gradients w.r.t. output error
+		backPropagate(neuralNetwork, ideal);
+
+		double predictionError = statePair.getIdeal().getValues()[0] // reward_t
+				+ actual.getValues()[0] - nextReward.getValues()[0];
+
+		updateWeights(neuralNetwork, predictionError);
+	}
 }
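Note: for comparison, the standard TD(lambda) update for a value function V parameterized by weights w is

	\delta_t = r_{t+1} + \gamma V(s_{t+1}) - V(s_t), \qquad
	e_t = \gamma\lambda\, e_{t-1} + \nabla_w V(s_t), \qquad
	w \leftarrow w + \alpha\,\delta_t\, e_t

In this class, predictionError plays the role of delta_t (modulo sign convention), and the connection.getTrace() * lambda term stands in for the trace recursion by lambda-discounting the previously applied delta, since Connection.addDelta stores each delta in trace. The gamma field is declared at 1.0 but not yet referenced in this hunk.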
src/net/woodyfolsom/msproj/ann/TrainingMethod.java (new file, +43 lines)
@@ -0,0 +1,43 @@
+package net.woodyfolsom.msproj.ann;
+
+import java.util.List;
+
+import net.woodyfolsom.msproj.ann.math.ErrorFunction;
+import net.woodyfolsom.msproj.ann.math.MSSE;
+
+public abstract class TrainingMethod {
+	protected final ErrorFunction errorFunction;
+
+	public TrainingMethod() {
+		this.errorFunction = MSSE.function;
+	}
+
+	protected abstract void iteratePattern(FeedforwardNetwork neuralNetwork,
+			NNDataPair statePair, NNData nextReward);
+
+	protected abstract void iteratePatterns(FeedforwardNetwork neuralNetwork,
+			List<NNDataPair> trainingSet);
+
+	protected abstract double computePatternError(FeedforwardNetwork neuralNetwork,
+			List<NNDataPair> trainingSet);
+
+	protected abstract void iterateSequences(FeedforwardNetwork neuralNetwork,
+			List<List<NNDataPair>> trainingSet);
+
+	protected abstract void backPropagate(FeedforwardNetwork neuralNetwork, NNData output);
+
+	protected abstract double computeSequenceError(FeedforwardNetwork neuralNetwork,
+			List<List<NNDataPair>> trainingSet);
+
+	protected void zeroGradients(FeedforwardNetwork neuralNetwork) {
+		for (Neuron neuron : neuralNetwork.getNeurons()) {
+			neuron.setGradient(0.0);
+		}
+	}
+
+	protected void zeroTraces(FeedforwardNetwork neuralNetwork) {
+		for (Connection conn : neuralNetwork.getConnections()) {
+			conn.setTrace(0.0);
+		}
+	}
+}
@@ -1,105 +0,0 @@
|
|||||||
package net.woodyfolsom.msproj.ann;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Set;
|
|
||||||
|
|
||||||
import org.encog.engine.network.activation.ActivationSigmoid;
|
|
||||||
import org.encog.ml.data.MLDataPair;
|
|
||||||
import org.encog.ml.data.MLDataSet;
|
|
||||||
import org.encog.ml.data.basic.BasicMLDataSet;
|
|
||||||
import org.encog.ml.train.MLTrain;
|
|
||||||
import org.encog.neural.networks.BasicNetwork;
|
|
||||||
import org.encog.neural.networks.layers.BasicLayer;
|
|
||||||
import org.encog.neural.networks.training.propagation.back.Backpropagation;
|
|
||||||
|
|
||||||
public class WinFilter extends AbstractNeuralNetFilter implements
|
|
||||||
NeuralNetFilter {
|
|
||||||
|
|
||||||
public WinFilter() {
|
|
||||||
// create a neural network, without using a factory
|
|
||||||
BasicNetwork network = new BasicNetwork();
|
|
||||||
network.addLayer(new BasicLayer(null, false, 2));
|
|
||||||
network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 4));
|
|
||||||
network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 2));
|
|
||||||
network.getStructure().finalizeStructure();
|
|
||||||
network.reset();
|
|
||||||
|
|
||||||
this.neuralNetwork = network;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void learn(MLDataSet trainingData) {
|
|
||||||
throw new UnsupportedOperationException(
|
|
||||||
"This filter learns a Set<List<MLDataPair>>, not an MLDataSet");
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Method is necessary because with temporal difference learning, some of
|
|
||||||
* the MLDataPairs are related by being a sequence of moves within a
|
|
||||||
* particular game.
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public void learn(Set<List<MLDataPair>> trainingSet) {
|
|
||||||
MLDataSet mlDataset = new BasicMLDataSet();
|
|
||||||
|
|
||||||
for (List<MLDataPair> gameRecord : trainingSet) {
|
|
||||||
for (int t = 0; t < gameRecord.size() - 1; t++) {
|
|
||||||
mlDataset.add(gameRecord.get(t).getInput(), this.neuralNetwork.compute(gameRecord.get(t)
|
|
||||||
.getInput()));
|
|
||||||
}
|
|
||||||
mlDataset.add(gameRecord.get(gameRecord.size() - 1));
|
|
||||||
}
|
|
||||||
|
|
||||||
// train the neural network
|
|
||||||
final MLTrain train = new TemporalDifference(neuralNetwork, mlDataset, 0.7, 0.8, 0.25);
|
|
||||||
//final MLTrain train = new Backpropagation(neuralNetwork, mlDataset, 0.7, 0.8);
|
|
||||||
actualTrainingEpochs = 0;
|
|
||||||
|
|
||||||
do {
|
|
||||||
if (actualTrainingEpochs > 0) {
|
|
||||||
int gameStateIndex = 0;
|
|
||||||
for (List<MLDataPair> gameRecord : trainingSet) {
|
|
||||||
for (int t = 0; t < gameRecord.size() - 1; t++) {
|
|
||||||
MLDataPair oldDataPair = mlDataset.get(gameStateIndex);
|
|
||||||
this.neuralNetwork.compute(oldDataPair.getInput());
|
|
||||||
gameStateIndex++;
|
|
||||||
}
|
|
||||||
gameStateIndex++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
train.iteration();
|
|
||||||
System.out.println("Epoch #" + actualTrainingEpochs + " Error:"
|
|
||||||
+ train.getError());
|
|
||||||
actualTrainingEpochs++;
|
|
||||||
} while (train.getError() > 0.01
|
|
||||||
&& actualTrainingEpochs <= maxTrainingEpochs);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void reset() {
|
|
||||||
neuralNetwork.reset();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void reset(int seed) {
|
|
||||||
neuralNetwork.reset(seed);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public BasicNetwork getNeuralNetwork() {
|
|
||||||
// TODO Auto-generated method stub
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int getInputSize() {
|
|
||||||
// TODO Auto-generated method stub
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int getOutputSize() {
|
|
||||||
// TODO Auto-generated method stub
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,18 +1,5 @@
 package net.woodyfolsom.msproj.ann;
 
-import java.util.List;
-import java.util.Set;
-
-import org.encog.engine.network.activation.ActivationSigmoid;
-import org.encog.ml.data.MLData;
-import org.encog.ml.data.MLDataPair;
-import org.encog.ml.data.MLDataSet;
-import org.encog.ml.data.basic.BasicMLData;
-import org.encog.ml.train.MLTrain;
-import org.encog.neural.networks.BasicNetwork;
-import org.encog.neural.networks.layers.BasicLayer;
-import org.encog.neural.networks.training.propagation.back.Backpropagation;
-
 /**
  * Based on sample code from http://neuroph.sourceforge.net
  *
@@ -22,54 +9,30 @@ import org.encog.neural.networks.training.propagation.back.Backpropagation;
|
|||||||
public class XORFilter extends AbstractNeuralNetFilter implements
|
public class XORFilter extends AbstractNeuralNetFilter implements
|
||||||
NeuralNetFilter {
|
NeuralNetFilter {
|
||||||
|
|
||||||
|
private static final int INPUT_SIZE = 2;
|
||||||
|
private static final int OUTPUT_SIZE = 1;
|
||||||
|
|
||||||
public XORFilter() {
|
public XORFilter() {
|
||||||
// create a neural network, without using a factory
|
this(0.8,0.7);
|
||||||
BasicNetwork network = new BasicNetwork();
|
}
|
||||||
network.addLayer(new BasicLayer(null, false, 2));
|
|
||||||
network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 3));
|
public XORFilter(double learningRate, double momentum) {
|
||||||
network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 1));
|
super( new MultiLayerPerceptron(true, INPUT_SIZE, 2, OUTPUT_SIZE),
|
||||||
network.getStructure().finalizeStructure();
|
new BackPropagation(learningRate, momentum), 1000, 0.01);
|
||||||
network.reset();
|
super.getNeuralNetwork().setName("XORFilter");
|
||||||
|
|
||||||
this.neuralNetwork = network;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public double compute(double x, double y) {
|
public double compute(double x, double y) {
|
||||||
return compute(new BasicMLData(new double[]{x,y})).getData(0);
|
return getNeuralNetwork().compute(new double[]{x,y})[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int getInputSize() {
|
public int getInputSize() {
|
||||||
return 2;
|
return INPUT_SIZE;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int getOutputSize() {
|
public int getOutputSize() {
|
||||||
// TODO Auto-generated method stub
|
return OUTPUT_SIZE;
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void learn(MLDataSet trainingSet) {
|
|
||||||
|
|
||||||
// train the neural network
|
|
||||||
final MLTrain train = new Backpropagation(neuralNetwork, trainingSet,
|
|
||||||
0.7, 0.8);
|
|
||||||
|
|
||||||
actualTrainingEpochs = 0;
|
|
||||||
|
|
||||||
do {
|
|
||||||
train.iteration();
|
|
||||||
System.out.println("Epoch #" + actualTrainingEpochs + " Error:"
|
|
||||||
+ train.getError());
|
|
||||||
actualTrainingEpochs++;
|
|
||||||
} while (train.getError() > 0.01
|
|
||||||
&& actualTrainingEpochs <= maxTrainingEpochs);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void learn(Set<List<MLDataPair>> trainingSet) {
|
|
||||||
throw new UnsupportedOperationException(
|
|
||||||
"This Filter learns an MLDataSet, not a Set<List<MLData>>.");
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
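Note (editor's illustration, not part of the commit): the refactor above replaces the hand-built Encog BasicNetwork with the project's own MultiLayerPerceptron (biased, 2 inputs, 2 hidden, 1 output) trained by BackPropagation, all wired through the AbstractNeuralNetFilter constructor. A minimal usage sketch, assuming the NNData/NNDataPair API this commit introduces (plus java.util.List/ArrayList imports):

    NeuralNetFilter xor = new XORFilter(0.5, 0.0);   // learning rate, momentum
    String[] in  = {"x", "y"};
    String[] out = {"XOR"};
    List<NNDataPair> patterns = new ArrayList<NNDataPair>();
    double[][] xs = {{0, 0}, {0, 1}, {1, 0}, {1, 1}};
    double[]   ys = {0, 1, 1, 0};                    // XOR truth table
    for (int i = 0; i < 4; i++) {
        patterns.add(new NNDataPair(new NNData(in, xs[i]), new NNData(out, new double[]{ys[i]})));
    }
    xor.setMaxTrainingEpochs(20000);
    xor.learnPatterns(patterns);                     // trains until maxError (0.01) or the epoch cap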
@@ -1,4 +1,4 @@
-package net.woodyfolsom.msproj.ann2.math;
+package net.woodyfolsom.msproj.ann.math;
 
 import javax.xml.bind.annotation.XmlAttribute;
 
@@ -1,4 +1,4 @@
-package net.woodyfolsom.msproj.ann2.math;
+package net.woodyfolsom.msproj.ann.math;
 
 public interface ErrorFunction {
     double compute(double[] ideal, double[] actual);
@@ -1,4 +1,4 @@
-package net.woodyfolsom.msproj.ann2.math;
+package net.woodyfolsom.msproj.ann.math;
 
 public class Linear extends ActivationFunction{
     public static final Linear function = new Linear();
@@ -1,4 +1,4 @@
-package net.woodyfolsom.msproj.ann2.math;
+package net.woodyfolsom.msproj.ann.math;
 
 public class MSSE implements ErrorFunction{
     public static final ErrorFunction function = new MSSE();
@@ -1,4 +1,4 @@
-package net.woodyfolsom.msproj.ann2.math;
+package net.woodyfolsom.msproj.ann.math;
 
 public class Sigmoid extends ActivationFunction{
     public static final Sigmoid function = new Sigmoid();
@@ -12,9 +12,9 @@ public class Sigmoid extends ActivationFunction{
     }
 
     public double derivative(double arg) {
-        //lol wth?
-        //double eX = Math.exp(arg);
-        //return eX / (Math.pow((1+eX), 2));
-        return arg - Math.pow(arg,2);
+        //lol wth? oh, the next derivative formula is a function of s(x), not x.
+        double eX = Math.exp(arg);
+        return eX / (Math.pow((1+eX), 2));
+        //return arg - Math.pow(arg,2);
     }
 }
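A quick check on the Sigmoid change above: with

    s(x) = 1 / (1 + e^(-x)),    s'(x) = e^(-x) / (1 + e^(-x))^2 = e^x / (1 + e^x)^2 = s(x) * (1 - s(x)),

so `arg - Math.pow(arg,2)` (that is, s - s^2) is correct only when `arg` already holds the sigmoid output s(x), while the restored `eX / Math.pow(1+eX, 2)` form applies when `arg` is the raw pre-activation x, which is what the updated comment is getting at.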
@@ -1,4 +1,4 @@
-package net.woodyfolsom.msproj.ann2.math;
+package net.woodyfolsom.msproj.ann.math;
 
 public class Tanh extends ActivationFunction{
     public static final Tanh function = new Tanh();
@@ -1,84 +0,0 @@
-package net.woodyfolsom.msproj.ann2;
-
-import java.io.InputStream;
-import java.io.OutputStream;
-import java.util.List;
-
-public abstract class AbstractNeuralNetFilter implements NeuralNetFilter {
-    private final FeedforwardNetwork neuralNetwork;
-    private final TrainingMethod trainingMethod;
-
-    private double maxError;
-    private int actualTrainingEpochs = 0;
-    private int maxTrainingEpochs;
-
-    AbstractNeuralNetFilter(FeedforwardNetwork neuralNetwork, TrainingMethod trainingMethod, int maxTrainingEpochs, double maxError) {
-        this.neuralNetwork = neuralNetwork;
-        this.trainingMethod = trainingMethod;
-        this.maxError = maxError;
-        this.maxTrainingEpochs = maxTrainingEpochs;
-    }
-
-    @Override
-    public NNData compute(NNDataPair input) {
-        return this.neuralNetwork.compute(input);
-    }
-
-    public int getActualTrainingEpochs() {
-        return actualTrainingEpochs;
-    }
-
-    @Override
-    public int getInputSize() {
-        return 2;
-    }
-
-    public int getMaxTrainingEpochs() {
-        return maxTrainingEpochs;
-    }
-
-    protected FeedforwardNetwork getNeuralNetwork() {
-        return neuralNetwork;
-    }
-
-    @Override
-    public void learn(List<NNDataPair> trainingSet) {
-        actualTrainingEpochs = 0;
-        double error;
-        neuralNetwork.initWeights();
-
-        error = trainingMethod.computeError(neuralNetwork,trainingSet);
-
-        if (error <= maxError) {
-            System.out.println("Initial error: " + error);
-            return;
-        }
-
-        do {
-            trainingMethod.iterate(neuralNetwork,trainingSet);
-            error = trainingMethod.computeError(neuralNetwork,trainingSet);
-            System.out.println("Epoch #" + actualTrainingEpochs + " Error:"
-                    + error);
-            actualTrainingEpochs++;
-            System.out.println("MSSE after epoch " + actualTrainingEpochs + ": " + error);
-        } while (error > maxError && actualTrainingEpochs < maxTrainingEpochs);
-    }
-
-    @Override
-    public boolean load(InputStream input) {
-        return neuralNetwork.load(input);
-    }
-
-    @Override
-    public boolean save(OutputStream output) {
-        return neuralNetwork.save(output);
-    }
-
-    public void setMaxError(double maxError) {
-        this.maxError = maxError;
-    }
-
-    public void setMaxTrainingEpochs(int max) {
-        this.maxTrainingEpochs = max;
-    }
-}
@@ -1,25 +0,0 @@
-package net.woodyfolsom.msproj.ann2;
-
-import java.io.InputStream;
-import java.io.OutputStream;
-import java.util.List;
-
-public interface NeuralNetFilter {
-    int getActualTrainingEpochs();
-
-    int getInputSize();
-
-    int getMaxTrainingEpochs();
-
-    int getOutputSize();
-
-    boolean load(InputStream input);
-
-    boolean save(OutputStream output);
-
-    void setMaxTrainingEpochs(int max);
-
-    NNData compute(NNDataPair input);
-
-    void learn(List<NNDataPair> trainingSet);
-}
@@ -1,5 +0,0 @@
-package net.woodyfolsom.msproj.ann2;
-
-public class ObjectiveFunction {
-
-}
@@ -1,19 +0,0 @@
-package net.woodyfolsom.msproj.ann2;
-
-import java.util.List;
-
-public class TemporalDifference implements TrainingMethod {
-
-    @Override
-    public void iterate(FeedforwardNetwork neuralNetwork,
-            List<NNDataPair> trainingSet) {
-        throw new UnsupportedOperationException("Not implemented");
-    }
-
-    @Override
-    public double computeError(FeedforwardNetwork neuralNetwork,
-            List<NNDataPair> trainingSet) {
-        throw new UnsupportedOperationException("Not implemented");
-    }
-
-}
@@ -1,10 +0,0 @@
-package net.woodyfolsom.msproj.ann2;
-
-import java.util.List;
-
-public interface TrainingMethod {
-
-    void iterate(FeedforwardNetwork neuralNetwork, List<NNDataPair> trainingSet);
-    double computeError(FeedforwardNetwork neuralNetwork, List<NNDataPair> trainingSet);
-
-}
@@ -1,44 +0,0 @@
-package net.woodyfolsom.msproj.ann2;
-
-/**
- * Based on sample code from http://neuroph.sourceforge.net
- *
- * @author Woody
- *
- */
-public class XORFilter extends AbstractNeuralNetFilter implements
-        NeuralNetFilter {
-
-    private static final int INPUT_SIZE = 2;
-    private static final int OUTPUT_SIZE = 1;
-
-    public XORFilter() {
-        this(0.8,0.7);
-    }
-
-    public XORFilter(double learningRate, double momentum) {
-        super( new MultiLayerPerceptron(true, INPUT_SIZE, 2, OUTPUT_SIZE),
-                new BackPropagation(learningRate, momentum), 1000, 0.01);
-        super.getNeuralNetwork().setName("XORFilter");
-
-        //TODO remove
-        //getNeuralNetwork().setWeights(new double[] {
-        //    0.341232, 0.129952, -0.923123, //hidden neuron 1 from input0, input1, bias
-        //    -0.115223, 0.570345, -0.328932, //hidden neuron 2 from input0, input1, bias
-        //    -0.993423, 0.164732, 0.752621}); //output
-    }
-
-    public double compute(double x, double y) {
-        return getNeuralNetwork().compute(new double[]{x,y})[0];
-    }
-
-    @Override
-    public int getInputSize() {
-        return INPUT_SIZE;
-    }
-
-    @Override
-    public int getOutputSize() {
-        return OUTPUT_SIZE;
-    }
-}
src/net/woodyfolsom/msproj/tictactoe/Action.java (new file, 54 lines)
@@ -0,0 +1,54 @@
+package net.woodyfolsom.msproj.tictactoe;
+
+import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;
+
+public class Action {
+    public static final Action NONE = new Action(PLAYER.NONE, -1, -1);
+
+    private Game.PLAYER player;
+    private int row;
+    private int column;
+
+    public static Action getInstance(PLAYER player, int row, int column) {
+        return new Action(player,row,column);
+    }
+
+    private Action(PLAYER player, int row, int column) {
+        this.player = player;
+        this.row = row;
+        this.column = column;
+    }
+
+    public Game.PLAYER getPlayer() {
+        return player;
+    }
+
+    public int getColumn() {
+        return column;
+    }
+
+    public int getRow() {
+        return row;
+    }
+
+    public boolean isNone() {
+        return this == Action.NONE;
+    }
+
+    public void setPlayer(Game.PLAYER player) {
+        this.player = player;
+    }
+
+    public void setRow(int row) {
+        this.row = row;
+    }
+
+    public void setColumn(int column) {
+        this.column = column;
+    }
+
+    @Override
+    public String toString() {
+        return player + "(" + row + ", " + column + ")";
+    }
+}
src/net/woodyfolsom/msproj/tictactoe/Game.java (new file, 5 lines)
@@ -0,0 +1,5 @@
+package net.woodyfolsom.msproj.tictactoe;
+
+public class Game {
+    public enum PLAYER {X,O,NONE}
+}
src/net/woodyfolsom/msproj/tictactoe/GameRecord.java (new file, 63 lines)
@@ -0,0 +1,63 @@
+package net.woodyfolsom.msproj.tictactoe;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;
+
+public class GameRecord {
+    public enum RESULT {X_WINS, O_WINS, TIE_GAME, IN_PROGRESS}
+
+    private List<Action> actions = new ArrayList<Action>();
+    private List<State> states = new ArrayList<State>();
+
+    private RESULT result = RESULT.IN_PROGRESS;
+
+    public GameRecord() {
+        actions.add(Action.NONE);
+        states.add(new State());
+    }
+
+    public void addState(State state) {
+        states.add(state);
+    }
+
+    public State apply(Action action) {
+        State nextState = getState().apply(action);
+        if (nextState.isValid()) {
+            states.add(nextState);
+            actions.add(action);
+        }
+
+        if (nextState.isTerminal()) {
+            if (nextState.isWinner(PLAYER.X)) {
+                result = RESULT.X_WINS;
+            } else if (nextState.isWinner(PLAYER.O)) {
+                result = RESULT.O_WINS;
+            } else {
+                result = RESULT.TIE_GAME;
+            }
+        }
+        return nextState;
+    }
+
+    public int getNumStates() {
+        return states.size();
+    }
+
+    public RESULT getResult() {
+        return result;
+    }
+
+    public void setResult(RESULT result) {
+        this.result = result;
+    }
+
+    public State getState() {
+        return states.get(states.size()-1);
+    }
+
+    public State getState(int index) {
+        return states.get(index);
+    }
+}
src/net/woodyfolsom/msproj/tictactoe/MoveGenerator.java (new file, 20 lines)
@@ -0,0 +1,20 @@
+package net.woodyfolsom.msproj.tictactoe;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;
+
+public class MoveGenerator {
+    public List<Action> getValidActions(State state) {
+        PLAYER playerToMove = state.getPlayerToMove();
+        List<Action> validActions = new ArrayList<Action>();
+        for (int i = 0; i < 3; i++) {
+            for (int j = 0; j < 3; j++) {
+                if (state.isEmpty(i,j))
+                    validActions.add(Action.getInstance(playerToMove, i, j));
+            }
+        }
+        return validActions;
+    }
+}
src/net/woodyfolsom/msproj/tictactoe/NNDataSetFactory.java (new file, 81 lines)
@@ -0,0 +1,81 @@
+package net.woodyfolsom.msproj.tictactoe;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import net.woodyfolsom.msproj.ann.NNData;
+import net.woodyfolsom.msproj.ann.NNDataPair;
+import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;
+
+public class NNDataSetFactory {
+    public static final String[] TTT_INPUT_FIELDS = {"00","01","02","10","11","12","20","21","22"};
+    public static final String[] TTT_OUTPUT_FIELDS = {"value"};
+
+    public static List<List<NNDataPair>> createDataSet(List<GameRecord> tttGames) {
+
+        List<List<NNDataPair>> nnDataSet = new ArrayList<List<NNDataPair>>();
+
+        for (GameRecord tttGame : tttGames) {
+            List<NNDataPair> gameData = createDataPairList(tttGame);
+
+
+            nnDataSet.add(gameData);
+        }
+
+        return nnDataSet;
+    }
+
+    public static List<NNDataPair> createDataPairList(GameRecord gameRecord) {
+        List<NNDataPair> gameData = new ArrayList<NNDataPair>();
+
+        for (int i = 0; i < gameRecord.getNumStates(); i++) {
+            gameData.add(createDataPair(gameRecord.getState(i)));
+        }
+
+        return gameData;
+    }
+
+    public static NNDataPair createDataPair(State tttState) {
+        double value;
+        if (tttState.isTerminal()) {
+            if (tttState.isWinner(PLAYER.X)) {
+                value = 1.0; // win for black
+            } else if (tttState.isWinner(PLAYER.O)) {
+                value = 0.0; // loss for black
+                //value = -1.0;
+            } else {
+                value = 0.5;
+                //value = 0.0; //tie
+            }
+        } else {
+            value = 0.0;
+        }
+
+        double[] inputValues = new double[9];
+        char[] boardCopy = tttState.getBoard();
+        inputValues[0] = getTicTacToeInput(boardCopy, 0, 0);
+        inputValues[1] = getTicTacToeInput(boardCopy, 0, 1);
+        inputValues[2] = getTicTacToeInput(boardCopy, 0, 2);
+        inputValues[3] = getTicTacToeInput(boardCopy, 1, 0);
+        inputValues[4] = getTicTacToeInput(boardCopy, 1, 1);
+        inputValues[5] = getTicTacToeInput(boardCopy, 1, 2);
+        inputValues[6] = getTicTacToeInput(boardCopy, 2, 0);
+        inputValues[7] = getTicTacToeInput(boardCopy, 2, 1);
+        inputValues[8] = getTicTacToeInput(boardCopy, 2, 2);
+
+        return new NNDataPair(new NNData(TTT_INPUT_FIELDS,inputValues),new NNData(TTT_OUTPUT_FIELDS,new double[]{value}));
+    }
+
+    private static double getTicTacToeInput(char[] board, int row, int column) {
+        switch (board[row*3+column]) {
+        case 'X' :
+            return 1.0;
+        case 'O' :
+            return -1.0;
+        case '.' :
+            return 0.0;
+        default:
+            throw new RuntimeException("Invalid board symbol at " + row +", " + column);
+        }
+    }
+}
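NNDataSetFactory reduces to one symbol map per square (X -> 1.0, O -> -1.0, '.' -> 0.0) plus a scalar target: 1.0 for an X win, 0.0 for an O win, 0.5 for a tie, and 0.0 for every non-terminal position (a non-terminal state thus shares its label with a loss, presumably because bootstrapped values are meant to replace the non-terminal targets during training). A compact equivalent of the encoding, with `encode` as a hypothetical helper that is not in the commit:

    // Equivalent to getTicTacToeInput applied square by square (row-major board string).
    static double[] encode(String board) {
        double[] input = new double[9];
        for (int i = 0; i < 9; i++) {
            switch (board.charAt(i)) {
            case 'X': input[i] = 1.0; break;    // X stone
            case 'O': input[i] = -1.0; break;   // O stone
            case '.': input[i] = 0.0; break;    // empty square
            default: throw new IllegalArgumentException("Invalid board symbol at " + i);
            }
        }
        return input;
    }
    // encode("XO..X...O") => [1, -1, 0, 0, 1, 0, 0, 0, -1]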
src/net/woodyfolsom/msproj/tictactoe/NeuralNetPolicy.java (new file, 67 lines)
@@ -0,0 +1,67 @@
+package net.woodyfolsom.msproj.tictactoe;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import net.woodyfolsom.msproj.ann.FeedforwardNetwork;
+import net.woodyfolsom.msproj.ann.NNDataPair;
+import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;
+
+public class NeuralNetPolicy extends Policy {
+    private FeedforwardNetwork neuralNet;
+    private MoveGenerator moveGenerator = new MoveGenerator();
+
+    public NeuralNetPolicy(FeedforwardNetwork neuralNet) {
+        super("NeuralNet-" + neuralNet.getName());
+        this.neuralNet = neuralNet;
+    }
+
+    @Override
+    public Action getAction(State state) {
+        List<Action> validMoves = moveGenerator.getValidActions(state);
+        Map<Action, Double> scores = new HashMap<Action, Double>();
+
+        for (Action action : validMoves) {
+            State nextState = state.apply(action);
+            NNDataPair dataPair = NNDataSetFactory.createDataPair(state);
+            //estimated reward for X
+            scores.put(action, neuralNet.compute(dataPair).getValues()[0]);
+        }
+
+        PLAYER playerToMove = state.getPlayerToMove();
+
+        if (playerToMove == PLAYER.X) {
+            return returnMaxAction(scores);
+        } else if (playerToMove == PLAYER.O) {
+            return returnMinAction(scores);
+        } else {
+            throw new IllegalArgumentException("Invalid playerToMove: " + playerToMove);
+        }
+        //return validMoves.get((int)(Math.random() * validMoves.size()));
+    }
+
+    private Action returnMaxAction(Map<Action,Double> scores) {
+        Action bestAction = null;
+        Double bestScore = Double.NEGATIVE_INFINITY;
+        for (Map.Entry<Action,Double> entry : scores.entrySet()) {
+            if (entry.getValue() > bestScore) {
+                bestScore = entry.getValue();
+                bestAction = entry.getKey();
+            }
+        }
+        return bestAction;
+    }
+
+    private Action returnMinAction(Map<Action,Double> scores) {
+        Action bestAction = null;
+        Double bestScore = Double.POSITIVE_INFINITY;
+        for (Map.Entry<Action,Double> entry : scores.entrySet()) {
+            if (entry.getValue() < bestScore) {
+                bestScore = entry.getValue();
+                bestAction = entry.getKey();
+            }
+        }
+        return bestAction;
+    }
+}
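Two things worth noting about NeuralNetPolicy: it is a one-ply greedy policy over the value network (X takes the maximizing action, O the minimizing one, matching the 1.0-favors-X target encoding), and the scoring loop computes `nextState = state.apply(action)` but then builds the scored pair from `state`, so every candidate move is evaluated on the same pre-move position; from the surrounding intent, `NNDataSetFactory.createDataPair(nextState)` appears to be what was meant.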
src/net/woodyfolsom/msproj/tictactoe/Policy.java (new file, 15 lines)
@@ -0,0 +1,15 @@
+package net.woodyfolsom.msproj.tictactoe;
+
+public abstract class Policy {
+    private String name;
+
+    protected Policy(String name) {
+        this.name = name;
+    }
+
+    public abstract Action getAction(State state);
+
+    public String getName() {
+        return name;
+    }
+}
src/net/woodyfolsom/msproj/tictactoe/RandomPolicy.java (new file, 18 lines)
@@ -0,0 +1,18 @@
+package net.woodyfolsom.msproj.tictactoe;
+
+import java.util.List;
+
+public class RandomPolicy extends Policy {
+    private MoveGenerator moveGenerator = new MoveGenerator();
+
+    public RandomPolicy() {
+        super("Random");
+    }
+
+    @Override
+    public Action getAction(State state) {
+        List<Action> validMoves = moveGenerator.getValidActions(state);
+        return validMoves.get((int)(Math.random() * validMoves.size()));
+    }
+
+}
src/net/woodyfolsom/msproj/tictactoe/Referee.java (new file, 43 lines)
@@ -0,0 +1,43 @@
+package net.woodyfolsom.msproj.tictactoe;
+
+import java.util.ArrayList;
+import java.util.List;
+
+public class Referee {
+
+    public static void main(String[] args) {
+        new Referee().play(50);
+    }
+
+    public List<GameRecord> play(int nGames) {
+        Policy policy = new RandomPolicy();
+
+        List<GameRecord> tournament = new ArrayList<GameRecord>();
+
+        for (int i = 0; i < nGames; i++) {
+            GameRecord gameRecord = new GameRecord();
+
+            System.out.println("Playing game #" +(i+1));
+
+            State state;
+            do {
+                Action action = policy.getAction(gameRecord.getState());
+                System.out.println("Action " + action + " selected by policy " + policy.getName());
+                state = gameRecord.apply(action);
+                System.out.println("Next board state:");
+                System.out.println(gameRecord.getState());
+            } while (!state.isTerminal());
+            System.out.println("Game #" + (i+1) + " is finished. Result: " + gameRecord.getResult());
+            tournament.add(gameRecord);
+        }
+
+        System.out.println("Played " + tournament.size() + " random games.");
+        System.out.println("Results:");
+        for (int i = 0; i < tournament.size(); i++) {
+            GameRecord gameRecord = tournament.get(i);
+            System.out.println((i+1) + ". " + gameRecord.getResult());
+        }
+
+        return tournament;
+    }
+}
src/net/woodyfolsom/msproj/tictactoe/State.java (new file, 116 lines)
@@ -0,0 +1,116 @@
+package net.woodyfolsom.msproj.tictactoe;
+
+import java.util.Arrays;
+
+import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;
+
+public class State {
+    public static final State INVALID = new State();
+    public static char EMPTY_SQUARE = '.';
+
+    private char[] board;
+    private PLAYER playerToMove;
+
+    public State() {
+        playerToMove = Game.PLAYER.X;
+        board = new char[9];
+        Arrays.fill(board,'.');
+    }
+
+    private State(State that) {
+        this.board = Arrays.copyOf(that.board, that.board.length);
+        this.playerToMove = that.playerToMove;
+    }
+
+    public State apply(Action action) {
+        if (action.getPlayer() != playerToMove) {
+            System.out.println("It is not " + action.getPlayer() +"'s turn.");
+            return State.INVALID;
+        }
+        State nextState = new State(this);
+
+        int row = action.getRow();
+        int column = action.getColumn();
+        int dest = row * 3 + column;
+
+        if (board[dest] != EMPTY_SQUARE) {
+            System.out.println("Invalid move " + action + ", coordinate not empty.");
+            return State.INVALID;
+        }
+        switch (playerToMove) {
+        case X : nextState.board[dest] = 'X';
+            break;
+        case O : nextState.board[dest] = 'O';
+            break;
+        default:
+            throw new RuntimeException("Invalid playerToMove");
+        }
+
+        if (playerToMove == PLAYER.X) {
+            nextState.playerToMove = PLAYER.O;
+        } else {
+            nextState.playerToMove = PLAYER.X;
+        }
+        return nextState;
+    }
+
+    public char[] getBoard() {
+        return Arrays.copyOf(board, board.length);
+    }
+
+    public PLAYER getPlayerToMove() {
+        return playerToMove;
+    }
+
+    public boolean isEmpty(int row, int column) {
+        return board[row*3+column] == EMPTY_SQUARE;
+    }
+
+    public boolean isFull(char mark1, char mark2, char mark3) {
+        return mark1 != '.' && mark2 != '.' && mark3 != '.';
+    }
+
+    public boolean isWinner(PLAYER player) {
+        return isWin(player,board[0],board[1],board[2]) ||
+                isWin(player,board[3],board[4],board[5]) ||
+                isWin(player,board[6],board[7],board[8]) ||
+                isWin(player,board[0],board[3],board[6]) ||
+                isWin(player,board[1],board[4],board[7]) ||
+                isWin(player,board[2],board[5],board[8]) ||
+                isWin(player,board[0],board[4],board[8]) ||
+                isWin(player,board[2],board[4],board[6]);
+    }
+
+    public boolean isWin(PLAYER player, char mark1, char mark2, char mark3) {
+        if (isFull(mark1,mark2,mark3)) {
+            switch (player) {
+            case X : return mark1 == 'X' && mark2 == 'X' && mark3 == 'X';
+            case O : return mark1 == 'O' && mark2 == 'O' && mark3 == 'O';
+            default :
+                return false;
+            }
+        } else {
+            return false;
+        }
+    }
+
+    public boolean isTerminal() {
+        return isWinner(PLAYER.X) || isWinner(PLAYER.O) ||
+                (isFull(board[0],board[1], board[2]) &&
+                 isFull(board[3],board[4], board[5]) &&
+                 isFull(board[6],board[7], board[8]));
+    }
+
+    public boolean isValid() {
+        return this != INVALID;
+    }
+
+    @Override
+    public String toString() {
+        StringBuilder sb = new StringBuilder("TicTacToe state ("+playerToMove + " to move):\n");
+        sb.append(""+board[0] + board[1] + board[2] + "\n");
+        sb.append(""+board[3] + board[4] + board[5] + "\n");
+        sb.append(""+board[6] + board[7] + board[8] + "\n");
+        return sb.toString();
+    }
+}
@@ -1,4 +1,4 @@
-package net.woodyfolsom.msproj.ann2;
+package net.woodyfolsom.msproj.ann;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
@@ -10,6 +10,12 @@ import java.io.IOException;
 
 import javax.xml.bind.JAXBException;
 
+import net.woodyfolsom.msproj.ann.Connection;
+import net.woodyfolsom.msproj.ann.FeedforwardNetwork;
+import net.woodyfolsom.msproj.ann.MultiLayerPerceptron;
+import net.woodyfolsom.msproj.ann.NNData;
+import net.woodyfolsom.msproj.ann.NNDataPair;
+
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
 import org.junit.Test;
test/net/woodyfolsom/msproj/ann/TTTFilterTest.java (new file, 100 lines)
@@ -0,0 +1,100 @@
+package net.woodyfolsom.msproj.ann;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.List;
+
+import net.woodyfolsom.msproj.ann.NNData;
+import net.woodyfolsom.msproj.ann.NNDataPair;
+import net.woodyfolsom.msproj.ann.NeuralNetFilter;
+import net.woodyfolsom.msproj.ann.TTTFilter;
+import net.woodyfolsom.msproj.tictactoe.GameRecord;
+import net.woodyfolsom.msproj.tictactoe.NNDataSetFactory;
+import net.woodyfolsom.msproj.tictactoe.Referee;
+
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class TTTFilterTest {
+    private static final String FILENAME = "tttPerceptron.net";
+
+    @AfterClass
+    public static void deleteNewNet() {
+        File file = new File(FILENAME);
+        if (file.exists()) {
+            file.delete();
+        }
+    }
+
+    @BeforeClass
+    public static void deleteSavedNet() {
+        File file = new File(FILENAME);
+        if (file.exists()) {
+            file.delete();
+        }
+    }
+
+    @Test
+    public void testLearn() throws IOException {
+        double alpha = 0.5;
+        double lambda = 0.0;
+        int maxEpochs = 1000;
+
+        NeuralNetFilter nnLearner = new TTTFilter(alpha, lambda, maxEpochs);
+
+        // Create trainingSet from a tournament of random games.
+        // Future iterations will use Epsilon-greedy play from a policy based on
+        // this network to generate additional datasets.
+        List<GameRecord> tournament = new Referee().play(1);
+        List<List<NNDataPair>> trainingSet = NNDataSetFactory
+                .createDataSet(tournament);
+
+        System.out.println("Generated " + trainingSet.size()
+                + " datasets from random self-play.");
+        nnLearner.learnSequences(trainingSet);
+        System.out.println("Learned network after "
+                + nnLearner.getActualTrainingEpochs() + " training epochs.");
+
+        double[][] validationSet = new double[7][];
+
+        // empty board
+        validationSet[0] = new double[] { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+                0.0, 0.0 };
+        // center
+        validationSet[1] = new double[] { 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,
+                0.0, 0.0 };
+        // top edge
+        validationSet[2] = new double[] { 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+                0.0, 0.0 };
+        // left edge
+        validationSet[3] = new double[] { 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0,
+                0.0, 0.0 };
+        // corner
+        validationSet[4] = new double[] { 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+                0.0, 0.0 };
+        // win
+        validationSet[5] = new double[] { 1.0, 1.0, 1.0, -1.0, -1.0, 0.0, 0.0,
+                -1.0, 0.0 };
+        // loss
+        validationSet[6] = new double[] { -1.0, 1.0, 0.0, 1.0, -1.0, 1.0, 0.0,
+                0.0, -1.0 };
+
+        String[] inputNames = new String[] { "00", "01", "02", "10", "11",
+                "12", "20", "21", "22" };
+        String[] outputNames = new String[] { "values" };
+
+        System.out.println("Output from eval set (learned network):");
+        testNetwork(nnLearner, validationSet, inputNames, outputNames);
+    }
+
+    private void testNetwork(NeuralNetFilter nnLearner,
+            double[][] validationSet, String[] inputNames, String[] outputNames) {
+        for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
+            NNDataPair dp = new NNDataPair(new NNData(inputNames,
+                    validationSet[valIndex]), new NNData(outputNames,
+                    validationSet[valIndex]));
+            System.out.println(dp + " => " + nnLearner.compute(dp));
+        }
+    }
+}
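The `(alpha, lambda)` pair handed to TTTFilter reads as the step size and eligibility-trace decay of TD(lambda) sequence learning (the TemporalDifference stub deleted above points the same way). Assuming the standard formulation, the per-transition weight update is

    dw_t = alpha * (V(s_{t+1}) - V(s_t)) * sum_{k=1..t} lambda^(t-k) * grad_w V(s_k),

so the `lambda = 0.0` used in testLearn collapses the trace to the current gradient, i.e. plain one-step TD(0).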
@@ -1,64 +0,0 @@
-package net.woodyfolsom.msproj.ann;
-
-import java.io.File;
-import java.io.FileFilter;
-import java.io.FileInputStream;
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Set;
-
-import net.woodyfolsom.msproj.GameRecord;
-import net.woodyfolsom.msproj.Referee;
-
-import org.antlr.runtime.RecognitionException;
-import org.encog.ml.data.MLData;
-import org.encog.ml.data.MLDataPair;
-import org.junit.Test;
-
-public class WinFilterTest {
-
-    @Test
-    public void testLearnSaveLoad() throws IOException, RecognitionException {
-        File[] sgfFiles = new File("data/games/random_vs_random")
-                .listFiles(new FileFilter() {
-                    @Override
-                    public boolean accept(File pathname) {
-                        return pathname.getName().endsWith(".sgf");
-                    }
-                });
-
-        Set<List<MLDataPair>> trainingData = new HashSet<List<MLDataPair>>();
-
-        for (File file : sgfFiles) {
-            FileInputStream fis = new FileInputStream(file);
-            GameRecord gameRecord = Referee.replay(fis);
-
-            List<MLDataPair> gameData = new ArrayList<MLDataPair>();
-            for (int i = 0; i <= gameRecord.getNumTurns(); i++) {
-                gameData.add(new GameStateMLDataPair(gameRecord.getGameState(i)));
-            }
-
-            trainingData.add(gameData);
-
-            fis.close();
-        }
-
-        WinFilter winFilter = new WinFilter();
-
-        winFilter.learn(trainingData);
-
-        for (List<MLDataPair> trainingSequence : trainingData) {
-            for (int stateIndex = 0; stateIndex < trainingSequence.size(); stateIndex++) {
-                if (stateIndex > 0 && stateIndex < trainingSequence.size()-1) {
-                    continue;
-                }
-                MLData input = trainingSequence.get(stateIndex).getInput();
-
-                System.out.println("Turn " + stateIndex + ": " + input + " => "
-                        + winFilter.compute(input));
-            }
-        }
-    }
-}
@@ -1,10 +1,19 @@
 package net.woodyfolsom.msproj.ann;
 
+import static org.junit.Assert.assertTrue;
+
 import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+import net.woodyfolsom.msproj.ann.NNData;
+import net.woodyfolsom.msproj.ann.NNDataPair;
+import net.woodyfolsom.msproj.ann.NeuralNetFilter;
+import net.woodyfolsom.msproj.ann.XORFilter;
 
-import org.encog.ml.data.MLDataSet;
-import org.encog.ml.data.basic.BasicMLDataSet;
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
 import org.junit.Test;
@@ -29,14 +38,55 @@ public class XORFilterTest {
     }
 
     @Test
-    public void testLearnSaveLoad() throws IOException {
-        NeuralNetFilter nnLearner = new XORFilter();
-        System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");
+    public void testLearn() throws IOException {
+        NeuralNetFilter nnLearner = new XORFilter(0.5,0.0);
 
         // create training set (logical XOR function)
         int size = 1;
         double[][] trainingInput = new double[4 * size][];
         double[][] trainingOutput = new double[4 * size][];
+        for (int i = 0; i < size; i++) {
+            trainingInput[i * 4 + 0] = new double[] { 0, 0 };
+            trainingInput[i * 4 + 1] = new double[] { 0, 1 };
+            trainingInput[i * 4 + 2] = new double[] { 1, 0 };
+            trainingInput[i * 4 + 3] = new double[] { 1, 1 };
+            trainingOutput[i * 4 + 0] = new double[] { 0 };
+            trainingOutput[i * 4 + 1] = new double[] { 1 };
+            trainingOutput[i * 4 + 2] = new double[] { 1 };
+            trainingOutput[i * 4 + 3] = new double[] { 0 };
+        }
+
+        // create training data
+        List<NNDataPair> trainingSet = new ArrayList<NNDataPair>();
+        String[] inputNames = new String[] {"x","y"};
+        String[] outputNames = new String[] {"XOR"};
+        for (int i = 0; i < 4*size; i++) {
+            trainingSet.add(new NNDataPair(new NNData(inputNames,trainingInput[i]),new NNData(outputNames,trainingOutput[i])));
+        }
+
+        nnLearner.setMaxTrainingEpochs(20000);
+        nnLearner.learnPatterns(trainingSet);
+        System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");
+
+        double[][] validationSet = new double[4][2];
+
+        validationSet[0] = new double[] { 0, 0 };
+        validationSet[1] = new double[] { 0, 1 };
+        validationSet[2] = new double[] { 1, 0 };
+        validationSet[3] = new double[] { 1, 1 };
+
+        System.out.println("Output from eval set (learned network):");
+        testNetwork(nnLearner, validationSet, inputNames, outputNames);
+    }
+
+    @Test
+    public void testLearnSaveLoad() throws IOException {
+        NeuralNetFilter nnLearner = new XORFilter(0.5,0.0);
+
+        // create training set (logical XOR function)
+        int size = 2;
+        double[][] trainingInput = new double[4 * size][];
+        double[][] trainingOutput = new double[4 * size][];
         for (int i = 0; i < size; i++) {
             trainingInput[i * 4 + 0] = new double[] { 0, 0 };
             trainingInput[i * 4 + 1] = new double[] { 0, 1 };
@@ -49,10 +99,17 @@ public class XORFilterTest {
         }
 
         // create training data
-        MLDataSet trainingSet = new BasicMLDataSet(trainingInput, trainingOutput);
+        List<NNDataPair> trainingSet = new ArrayList<NNDataPair>();
+        String[] inputNames = new String[] {"x","y"};
+        String[] outputNames = new String[] {"XOR"};
+        for (int i = 0; i < 4*size; i++) {
+            trainingSet.add(new NNDataPair(new NNData(inputNames,trainingInput[i]),new NNData(outputNames,trainingOutput[i])));
+        }
 
-        nnLearner.learn(trainingSet);
+        nnLearner.setMaxTrainingEpochs(1);
+        nnLearner.learnPatterns(trainingSet);
+        System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");
 
         double[][] validationSet = new double[4][2];
 
@@ -61,18 +118,23 @@ public class XORFilterTest {
         validationSet[3] = new double[] { 1, 1 };
 
         System.out.println("Output from eval set (learned network, pre-serialization):");
-        testNetwork(nnLearner, validationSet);
+        testNetwork(nnLearner, validationSet, inputNames, outputNames);
 
-        nnLearner.save(FILENAME);
-        nnLearner.load(FILENAME);
+        FileOutputStream fos = new FileOutputStream(FILENAME);
+        assertTrue(nnLearner.save(fos));
+        fos.close();
+
+        FileInputStream fis = new FileInputStream(FILENAME);
+        assertTrue(nnLearner.load(fis));
+        fis.close();
 
         System.out.println("Output from eval set (learned network, post-serialization):");
-        testNetwork(nnLearner, validationSet);
+        testNetwork(nnLearner, validationSet, inputNames, outputNames);
     }
 
-    private void testNetwork(NeuralNetFilter nnLearner, double[][] validationSet) {
+    private void testNetwork(NeuralNetFilter nnLearner, double[][] validationSet, String[] inputNames, String[] outputNames) {
         for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
-            DoublePair dp = new DoublePair(validationSet[valIndex][0],validationSet[valIndex][1]);
+            NNDataPair dp = new NNDataPair(new NNData(inputNames,validationSet[valIndex]), new NNData(outputNames,validationSet[valIndex]));
             System.out.println(dp + " => " + nnLearner.compute(dp));
         }
     }
@@ -1,11 +1,11 @@
-package net.woodyfolsom.msproj.ann2;
+package net.woodyfolsom.msproj.ann.math;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
 
-import net.woodyfolsom.msproj.ann2.math.ActivationFunction;
-import net.woodyfolsom.msproj.ann2.math.Sigmoid;
-import net.woodyfolsom.msproj.ann2.math.Tanh;
+import net.woodyfolsom.msproj.ann.math.ActivationFunction;
+import net.woodyfolsom.msproj.ann.math.Sigmoid;
+import net.woodyfolsom.msproj.ann.math.Tanh;
 
 import org.junit.Test;
 
@@ -1,10 +1,10 @@
-package net.woodyfolsom.msproj.ann2;
+package net.woodyfolsom.msproj.ann.math;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
 
-import net.woodyfolsom.msproj.ann2.math.ActivationFunction;
-import net.woodyfolsom.msproj.ann2.math.Tanh;
+import net.woodyfolsom.msproj.ann.math.ActivationFunction;
+import net.woodyfolsom.msproj.ann.math.Tanh;
 
 import org.junit.Test;
 
@@ -1,136 +0,0 @@
-package net.woodyfolsom.msproj.ann2;
-
-import static org.junit.Assert.assertTrue;
-
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.FileOutputStream;
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
-
-import org.junit.AfterClass;
-import org.junit.BeforeClass;
-import org.junit.Test;
-
-public class XORFilterTest {
-    private static final String FILENAME = "xorPerceptron.net";
-
-    @AfterClass
-    public static void deleteNewNet() {
-        File file = new File(FILENAME);
-        if (file.exists()) {
-            file.delete();
-        }
-    }
-
-    @BeforeClass
-    public static void deleteSavedNet() {
-        File file = new File(FILENAME);
-        if (file.exists()) {
-            file.delete();
-        }
-    }
-
-    @Test
-    public void testLearn() throws IOException {
-        NeuralNetFilter nnLearner = new XORFilter(0.05,0.0);
-
-        // create training set (logical XOR function)
-        int size = 1;
-        double[][] trainingInput = new double[4 * size][];
-        double[][] trainingOutput = new double[4 * size][];
-        for (int i = 0; i < size; i++) {
-            trainingInput[i * 4 + 0] = new double[] { 0, 0 };
-            trainingInput[i * 4 + 1] = new double[] { 0, 1 };
-            trainingInput[i * 4 + 2] = new double[] { 1, 0 };
-            trainingInput[i * 4 + 3] = new double[] { 1, 1 };
-            trainingOutput[i * 4 + 0] = new double[] { 0 };
-            trainingOutput[i * 4 + 1] = new double[] { 1 };
-            trainingOutput[i * 4 + 2] = new double[] { 1 };
-            trainingOutput[i * 4 + 3] = new double[] { 0 };
-        }
-
-        // create training data
-        List<NNDataPair> trainingSet = new ArrayList<NNDataPair>();
-        String[] inputNames = new String[] {"x","y"};
-        String[] outputNames = new String[] {"XOR"};
-        for (int i = 0; i < 4*size; i++) {
-            trainingSet.add(new NNDataPair(new NNData(inputNames,trainingInput[i]),new NNData(outputNames,trainingOutput[i])));
-        }
-
-        nnLearner.setMaxTrainingEpochs(20000);
-        nnLearner.learn(trainingSet);
-        System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");
-
-        double[][] validationSet = new double[4][2];
-
-        validationSet[0] = new double[] { 0, 0 };
-        validationSet[1] = new double[] { 0, 1 };
-        validationSet[2] = new double[] { 1, 0 };
-        validationSet[3] = new double[] { 1, 1 };
-
-        System.out.println("Output from eval set (learned network):");
-        testNetwork(nnLearner, validationSet, inputNames, outputNames);
-    }
-
-    @Test
-    public void testLearnSaveLoad() throws IOException {
-        NeuralNetFilter nnLearner = new XORFilter(0.5,0.0);
-
-        // create training set (logical XOR function)
-        int size = 2;
-        double[][] trainingInput = new double[4 * size][];
-        double[][] trainingOutput = new double[4 * size][];
-        for (int i = 0; i < size; i++) {
-            trainingInput[i * 4 + 0] = new double[] { 0, 0 };
-            trainingInput[i * 4 + 1] = new double[] { 0, 1 };
-            trainingInput[i * 4 + 2] = new double[] { 1, 0 };
-            trainingInput[i * 4 + 3] = new double[] { 1, 1 };
-            trainingOutput[i * 4 + 0] = new double[] { 0 };
-            trainingOutput[i * 4 + 1] = new double[] { 1 };
-            trainingOutput[i * 4 + 2] = new double[] { 1 };
-            trainingOutput[i * 4 + 3] = new double[] { 0 };
-        }
-
-        // create training data
-        List<NNDataPair> trainingSet = new ArrayList<NNDataPair>();
-        String[] inputNames = new String[] {"x","y"};
-        String[] outputNames = new String[] {"XOR"};
-        for (int i = 0; i < 4*size; i++) {
-            trainingSet.add(new NNDataPair(new NNData(inputNames,trainingInput[i]),new NNData(outputNames,trainingOutput[i])));
-        }
-
-        nnLearner.setMaxTrainingEpochs(1);
-        nnLearner.learn(trainingSet);
-        System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");
-
-        double[][] validationSet = new double[4][2];
-
-        validationSet[0] = new double[] { 0, 0 };
-        validationSet[1] = new double[] { 0, 1 };
-        validationSet[2] = new double[] { 1, 0 };
-        validationSet[3] = new double[] { 1, 1 };
-
-        System.out.println("Output from eval set (learned network, pre-serialization):");
-        testNetwork(nnLearner, validationSet, inputNames, outputNames);
-
-        FileOutputStream fos = new FileOutputStream(FILENAME);
-        assertTrue(nnLearner.save(fos));
-        fos.close();
-
-        FileInputStream fis = new FileInputStream(FILENAME);
-        assertTrue(nnLearner.load(fis));
-        fis.close();
-
-        System.out.println("Output from eval set (learned network, post-serialization):");
-        testNetwork(nnLearner, validationSet, inputNames, outputNames);
-    }
-
-    private void testNetwork(NeuralNetFilter nnLearner, double[][] validationSet, String[] inputNames, String[] outputNames) {
-        for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
-            NNDataPair dp = new NNDataPair(new NNData(inputNames,validationSet[valIndex]), new NNData(outputNames,validationSet[valIndex]));
-            System.out.println(dp + " => " + nnLearner.compute(dp));
-        }
-    }
-}
test/net/woodyfolsom/msproj/tictactoe/GameRecordTest.java (new file, 73 lines)
@@ -0,0 +1,73 @@
+package net.woodyfolsom.msproj.tictactoe;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+import org.junit.Test;
+
+import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;
+
+public class GameRecordTest {
+
+    @Test
+    public void testGetResultXwins() {
+        GameRecord gameRecord = new GameRecord();
+        gameRecord.apply(Action.getInstance(PLAYER.X, 1, 0));
+        gameRecord.apply(Action.getInstance(PLAYER.O, 0, 0));
+        gameRecord.apply(Action.getInstance(PLAYER.X, 1, 1));
+        gameRecord.apply(Action.getInstance(PLAYER.O, 0, 1));
+        gameRecord.apply(Action.getInstance(PLAYER.X, 1, 2));
+        State finalState = gameRecord.getState();
+        System.out.println("Final state:");
+        System.out.println(finalState);
+        assertTrue(finalState.isValid());
+        assertTrue(finalState.isTerminal());
+        assertTrue(finalState.isWinner(PLAYER.X));
+        assertEquals(GameRecord.RESULT.X_WINS,gameRecord.getResult());
+    }
+
+    @Test
+    public void testGetResultOwins() {
+        GameRecord gameRecord = new GameRecord();
+        gameRecord.apply(Action.getInstance(PLAYER.X, 0, 0));
+        gameRecord.apply(Action.getInstance(PLAYER.O, 0, 2));
+        gameRecord.apply(Action.getInstance(PLAYER.X, 0, 1));
+        gameRecord.apply(Action.getInstance(PLAYER.O, 1, 1));
+        gameRecord.apply(Action.getInstance(PLAYER.X, 1, 0));
+        gameRecord.apply(Action.getInstance(PLAYER.O, 2, 0));
+
+        State finalState = gameRecord.getState();
+        System.out.println("Final state:");
+        System.out.println(finalState);
+        assertTrue(finalState.isValid());
+        assertTrue(finalState.isTerminal());
+        assertTrue(finalState.isWinner(PLAYER.O));
+        assertEquals(GameRecord.RESULT.O_WINS,gameRecord.getResult());
+    }
+
+    @Test
+    public void testGetResultTieGame() {
+        GameRecord gameRecord = new GameRecord();
+        gameRecord.apply(Action.getInstance(PLAYER.X, 0, 0));
+        gameRecord.apply(Action.getInstance(PLAYER.O, 0, 2));
+        gameRecord.apply(Action.getInstance(PLAYER.X, 0, 1));
+
+        gameRecord.apply(Action.getInstance(PLAYER.O, 1, 0));
+        gameRecord.apply(Action.getInstance(PLAYER.X, 1, 2));
+        gameRecord.apply(Action.getInstance(PLAYER.O, 1, 1));
+
+        gameRecord.apply(Action.getInstance(PLAYER.X, 2, 0));
+        gameRecord.apply(Action.getInstance(PLAYER.O, 2, 2));
+        gameRecord.apply(Action.getInstance(PLAYER.X, 2, 1));
+
+        State finalState = gameRecord.getState();
+        System.out.println("Final state:");
+        System.out.println(finalState);
+        assertTrue(finalState.isValid());
+        assertTrue(finalState.isTerminal());
+        assertFalse(finalState.isWinner(PLAYER.X));
+        assertFalse(finalState.isWinner(PLAYER.O));
+        assertEquals(GameRecord.RESULT.TIE_GAME,gameRecord.getResult());
+    }
+}
test/net/woodyfolsom/msproj/tictactoe/RefereeTest.java (new file, 12 lines)
@@ -0,0 +1,12 @@
+package net.woodyfolsom.msproj.tictactoe;
+
+import org.junit.Test;
+
+public class RefereeTest {
+
+    @Test
+    public void testPlay100Games() {
+        new Referee().play(100);
+    }
+
+}
129
ttt.net
Normal file
129
ttt.net
Normal file
@@ -0,0 +1,129 @@
<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<multiLayerPerceptron biased="true" name="TicTacToe">
    <activationFunction name="Sigmoid"/>
    <connections dest="10" src="0" weight="0.5827629317852295"/>
    <connections dest="10" src="1" weight="0.49198735902918994"/>
    <connections dest="10" src="2" weight="-0.3019566272377494"/>
    <connections dest="10" src="3" weight="0.42204442000472525"/>
    <connections dest="10" src="4" weight="-0.26015075178733194"/>
    <connections dest="10" src="5" weight="-0.001558299861060293"/>
    <connections dest="10" src="6" weight="0.07987916348233416"/>
    <connections dest="10" src="7" weight="0.07258122647153753"/>
    <connections dest="10" src="8" weight="-0.691045501522254"/>
    <connections dest="10" src="9" weight="0.7118463494749109"/>
    <connections dest="11" src="0" weight="-1.8387878977128804"/>
    <connections dest="11" src="1" weight="0.07066242812415906"/>
    <connections dest="11" src="2" weight="-0.2141385079779094"/>
    <connections dest="11" src="3" weight="0.02318115051417748"/>
    <connections dest="11" src="4" weight="-0.4940158494633454"/>
    <connections dest="11" src="5" weight="0.24951794707397953"/>
    <connections dest="11" src="6" weight="-0.3422002057868113"/>
    <connections dest="11" src="7" weight="-0.34896333718320666"/>
    <connections dest="11" src="8" weight="0.18236809262087086"/>
    <connections dest="11" src="9" weight="-0.39168932467050466"/>
    <connections dest="12" src="0" weight="1.5206290139263101"/>
    <connections dest="12" src="1" weight="-0.4806468102477885"/>
    <connections dest="12" src="2" weight="0.21439697155823853"/>
    <connections dest="12" src="3" weight="0.1226010537695569"/>
    <connections dest="12" src="4" weight="-0.2957055657777683"/>
    <connections dest="12" src="5" weight="0.6130228290778311"/>
    <connections dest="12" src="6" weight="0.36875530286236485"/>
    <connections dest="12" src="7" weight="-0.5171899914088294"/>
    <connections dest="12" src="8" weight="0.10837708801339006"/>
    <connections dest="12" src="9" weight="-0.7053746937035315"/>
    <connections dest="13" src="0" weight="0.002913660858364482"/>
    <connections dest="13" src="1" weight="-0.7651207747987173"/>
    <connections dest="13" src="2" weight="0.9715970070491731"/>
    <connections dest="13" src="3" weight="-0.9956453258174628"/>
    <connections dest="13" src="4" weight="-0.9408358352747842"/>
    <connections dest="13" src="5" weight="-1.008966493202113"/>
    <connections dest="13" src="6" weight="-0.672355054680489"/>
    <connections dest="13" src="7" weight="-0.3367206164565582"/>
    <connections dest="13" src="8" weight="0.7588693137687637"/>
    <connections dest="13" src="9" weight="-0.7196453490945308"/>
    <connections dest="14" src="0" weight="-1.9439726796836931"/>
    <connections dest="14" src="1" weight="-0.2894027034518325"/>
    <connections dest="14" src="2" weight="0.2110335238178935"/>
    <connections dest="14" src="3" weight="-0.009846640898758158"/>
    <connections dest="14" src="4" weight="0.1568088381509006"/>
    <connections dest="14" src="5" weight="-0.18073468038735682"/>
    <connections dest="14" src="6" weight="0.3823096688264287"/>
    <connections dest="14" src="7" weight="-0.21319807548539116"/>
    <connections dest="14" src="8" weight="-0.3736851760400955"/>
    <connections dest="14" src="9" weight="-0.10659568761110778"/>
    <connections dest="15" src="0" weight="-3.5802003342217197"/>
    <connections dest="15" src="10" weight="-0.520010988494904"/>
    <connections dest="15" src="11" weight="2.0607479402794953"/>
    <connections dest="15" src="12" weight="-1.3810086619100004"/>
    <connections dest="15" src="13" weight="-0.024645797466295187"/>
    <connections dest="15" src="14" weight="2.4372644169618125"/>
    <neurons id="0">
        <activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
    </neurons>
    <neurons id="1">
        <activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
    </neurons>
    <neurons id="2">
        <activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
    </neurons>
    <neurons id="3">
        <activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
    </neurons>
    <neurons id="4">
        <activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
    </neurons>
    <neurons id="5">
        <activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
    </neurons>
    <neurons id="6">
        <activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
    </neurons>
    <neurons id="7">
        <activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
    </neurons>
    <neurons id="8">
        <activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
    </neurons>
    <neurons id="9">
        <activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
    </neurons>
    <neurons id="10">
        <activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Tanh"/>
    </neurons>
    <neurons id="11">
        <activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Tanh"/>
    </neurons>
    <neurons id="12">
        <activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Tanh"/>
    </neurons>
    <neurons id="13">
        <activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Tanh"/>
    </neurons>
    <neurons id="14">
        <activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Tanh"/>
    </neurons>
    <neurons id="15">
        <activationFunction name="Sigmoid"/>
    </neurons>
    <layers>
        <neuronIds>1</neuronIds>
        <neuronIds>2</neuronIds>
        <neuronIds>3</neuronIds>
        <neuronIds>4</neuronIds>
        <neuronIds>5</neuronIds>
        <neuronIds>6</neuronIds>
        <neuronIds>7</neuronIds>
        <neuronIds>8</neuronIds>
        <neuronIds>9</neuronIds>
    </layers>
    <layers>
        <neuronIds>10</neuronIds>
        <neuronIds>11</neuronIds>
        <neuronIds>12</neuronIds>
        <neuronIds>13</neuronIds>
        <neuronIds>14</neuronIds>
    </layers>
    <layers>
        <neuronIds>15</neuronIds>
    </layers>
</multiLayerPerceptron>
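Note (reviewer aside, not part of the commit): ttt.net serializes a 9-5-1 feedforward network. Neurons 1-9 are Linear inputs, neurons 10-14 form a Tanh hidden layer, neuron 15 is the Sigmoid output, and neuron 0 is a bias unit (biased="true") wired to both the hidden and output layers; it appears in no <layers> block. A minimal sketch of a forward pass over this format follows. The TttNetSketch class and its helpers are hypothetical, written only to illustrate how the <connections> elements compose; they are not code from this repository.

// Hypothetical illustration only -- not part of this commit.
// Evaluates the ttt.net topology: bias neuron 0, linear inputs 1-9,
// tanh hidden neurons 10-14, sigmoid output neuron 15.
import java.util.HashMap;
import java.util.Map;

public class TttNetSketch {
	// weights.get(dest).get(src) -> weight, one entry per <connections> element
	private final Map<Integer, Map<Integer, Double>> weights = new HashMap<>();

	public void addConnection(int dest, int src, double weight) {
		weights.computeIfAbsent(dest, k -> new HashMap<>()).put(src, weight);
	}

	/** board holds the 9 input activations for neurons 1-9. */
	public double evaluate(double[] board) {
		double[] act = new double[16];
		act[0] = 1.0; // bias neuron 0 is always on
		System.arraycopy(board, 0, act, 1, 9); // linear input neurons pass through
		for (int hidden = 10; hidden <= 14; hidden++) {
			act[hidden] = Math.tanh(weightedSum(hidden, act)); // Tanh hidden layer
		}
		double out = weightedSum(15, act); // Sigmoid output neuron 15
		return 1.0 / (1.0 + Math.exp(-out));
	}

	private double weightedSum(int dest, double[] act) {
		double sum = 0.0;
		for (Map.Entry<Integer, Double> e : weights.get(dest).entrySet()) {
			sum += act[e.getKey()] * e.getValue();
		}
		return sum;
	}
}

Loading would simply walk the <connections> elements above, calling addConnection(dest, src, weight) for each; with all 56 connections registered, evaluate returns the network's score for a board position.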