Functional MLP for XOR toy problem.

2012-11-24 22:20:41 -05:00
parent 874847f41b
commit 790b5666a8
26 changed files with 1109 additions and 217 deletions
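
In outline, the new classes compose as follows. A minimal sketch, mirroring XORFilterTest (all names are from this commit; the 0.01 MSSE threshold and the biased 2-2-1 topology are XORFilter's defaults):

List<NNDataPair> xor = new ArrayList<NNDataPair>();
double[][] in = { {0, 0}, {0, 1}, {1, 0}, {1, 1} };
double[] out = { 0, 1, 1, 0 };
for (int i = 0; i < 4; i++) {
    xor.add(new NNDataPair(new NNData(new String[] {"x", "y"}, in[i]),
            new NNData(new String[] {"xor"}, new double[] {out[i]})));
}
XORFilter filter = new XORFilter();   // biased 2-2-1 MultiLayerPerceptron trained by BackPropagation
filter.setMaxTrainingEpochs(20000);
filter.learn(xor);                    // iterates until MSSE <= 0.01 or the epoch cap is hit
double y = filter.compute(1.0, 0.0);  // approaches 1.0 once trained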

View File: net/woodyfolsom/msproj/ann2/AbstractNeuralNetFilter.java

@@ -0,0 +1,84 @@
package net.woodyfolsom.msproj.ann2;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.List;
public abstract class AbstractNeuralNetFilter implements NeuralNetFilter {
private final FeedforwardNetwork neuralNetwork;
private final TrainingMethod trainingMethod;
private double maxError;
private int actualTrainingEpochs = 0;
private int maxTrainingEpochs;
AbstractNeuralNetFilter(FeedforwardNetwork neuralNetwork, TrainingMethod trainingMethod, int maxTrainingEpochs, double maxError) {
this.neuralNetwork = neuralNetwork;
this.trainingMethod = trainingMethod;
this.maxError = maxError;
this.maxTrainingEpochs = maxTrainingEpochs;
}
@Override
public NNData compute(NNDataPair input) {
return this.neuralNetwork.compute(input);
}
public int getActualTrainingEpochs() {
return actualTrainingEpochs;
}
@Override
public int getInputSize() {
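// Default sized for the XOR toy problem; concrete filters override this.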
return 2;
}
public int getMaxTrainingEpochs() {
return maxTrainingEpochs;
}
protected FeedforwardNetwork getNeuralNetwork() {
return neuralNetwork;
}
@Override
public void learn(List<NNDataPair> trainingSet) {
actualTrainingEpochs = 0;
double error;
neuralNetwork.initWeights();
error = trainingMethod.computeError(neuralNetwork,trainingSet);
if (error <= maxError) {
System.out.println("Initial error " + error + " is already within tolerance; skipping training.");
return;
}
do {
trainingMethod.iterate(neuralNetwork, trainingSet);
error = trainingMethod.computeError(neuralNetwork, trainingSet);
actualTrainingEpochs++;
System.out.println("MSSE after epoch " + actualTrainingEpochs + ": " + error);
} while (error > maxError && actualTrainingEpochs < maxTrainingEpochs);
}
@Override
public boolean load(InputStream input) {
return neuralNetwork.load(input);
}
@Override
public boolean save(OutputStream output) {
return neuralNetwork.save(output);
}
public void setMaxError(double maxError) {
this.maxError = maxError;
}
public void setMaxTrainingEpochs(int max) {
this.maxTrainingEpochs = max;
}
}

View File: net/woodyfolsom/msproj/ann2/ActivationFunction.java

@@ -1,5 +0,0 @@
package net.woodyfolsom.msproj.ann2;
public interface ActivationFunction {
double calculate(double arg);
}

View File: net/woodyfolsom/msproj/ann2/BackPropagation.java

@@ -0,0 +1,120 @@
package net.woodyfolsom.msproj.ann2;
import java.util.List;
import net.woodyfolsom.msproj.ann2.math.ErrorFunction;
import net.woodyfolsom.msproj.ann2.math.MSSE;
public class BackPropagation implements TrainingMethod {
private final ErrorFunction errorFunction;
private final double learningRate;
private final double momentum;
public BackPropagation(double learningRate, double momentum) {
this.errorFunction = MSSE.function;
this.learningRate = learningRate;
this.momentum = momentum;
}
@Override
public void iterate(FeedforwardNetwork neuralNetwork,
List<NNDataPair> trainingSet) {
System.out.println("Learningrate: " + learningRate);
System.out.println("Momentum: " + momentum);
for (NNDataPair trainingPair : trainingSet) {
//online (per-pattern) training: reset errors, then update weights after each pair
zeroErrors(neuralNetwork);
System.out.println("Training with: " + trainingPair.getInput());
NNData ideal = trainingPair.getIdeal();
NNData actual = neuralNetwork.compute(trainingPair);
System.out.println("Updating weights. Ideal Output: " + ideal);
System.out.println("Actual Output: " + actual);
updateErrors(neuralNetwork, ideal);
updateWeights(neuralNetwork);
}
}
@Override
public double computeError(FeedforwardNetwork neuralNetwork,
List<NNDataPair> trainingSet) {
int numDataPairs = trainingSet.size();
int outputSize = neuralNetwork.getOutput().length;
int totalOutputSize = outputSize * numDataPairs;
double[] actuals = new double[totalOutputSize];
double[] ideals = new double[totalOutputSize];
for (int dataPair = 0; dataPair < numDataPairs; dataPair++) {
NNDataPair nnDataPair = trainingSet.get(dataPair);
double[] actual = neuralNetwork.compute(nnDataPair.getInput()
.getValues());
double[] ideal = nnDataPair.getIdeal().getValues();
int offset = dataPair * outputSize;
System.arraycopy(actual, 0, actuals, offset, outputSize);
System.arraycopy(ideal, 0, ideals, offset, outputSize);
}
return errorFunction.compute(ideals, actuals);
}
private void updateErrors(FeedforwardNetwork neuralNetwork, NNData ideal) {
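// Output-layer delta rule: error += f'(output) * (ideal - actual), with the
// derivative evaluated at the neuron's output value.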
Neuron[] outputNeurons = neuralNetwork.getOutputNeurons();
double[] idealValues = ideal.getValues();
for (int i = 0; i < idealValues.length; i++) {
double output = outputNeurons[i].getOutput();
double derivative = outputNeurons[i].getActivationFunction()
.derivative(output);
outputNeurons[i].setError(outputNeurons[i].getError() + derivative * (idealValues[i] - output));
}
// walking down the list of Neurons in reverse order, propagate the
// error
Neuron[] neurons = neuralNetwork.getNeurons();
for (int n = neurons.length - 1; n >= 0; n--) {
Neuron neuron = neurons[n];
double error = neuron.getError();
Connection[] connectionsFromN = neuralNetwork
.getConnectionsFrom(neuron.getId());
if (connectionsFromN.length > 0) {
double derivative = neuron.getActivationFunction().derivative(
neuron.getOutput());
for (Connection connection : connectionsFromN) {
error += derivative * connection.getWeight() * neuralNetwork.getNeuron(connection.getDest()).getError();
}
}
neuron.setError(error);
}
}
private void updateWeights(FeedforwardNetwork neuralNetwork) {
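// Gradient-descent delta rule: deltaW = learningRate * srcOutput * destError.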
for (Connection connection : neuralNetwork.getConnections()) {
Neuron srcNeuron = neuralNetwork.getNeuron(connection.getSrc());
Neuron destNeuron = neuralNetwork.getNeuron(connection.getDest());
double delta = learningRate * srcNeuron.getOutput() * destNeuron.getError();
//TODO allow for momentum
//double lastDelta = connection.getLastDelta();
connection.addDelta(delta);
}
}
private void zeroErrors(FeedforwardNetwork neuralNetwork) {
// Reset every neuron's error accumulator before processing the next training pair.
for (Neuron neuron : neuralNetwork.getNeurons()) {
neuron.setError(0.0);
}
}
}

View File: net/woodyfolsom/msproj/ann2/Connection.java

@@ -0,0 +1,94 @@
package net.woodyfolsom.msproj.ann2;
import javax.xml.bind.annotation.XmlAttribute;
import javax.xml.bind.annotation.XmlTransient;
public class Connection {
private int src;
private int dest;
private double weight;
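//most recent weight change, retained for a future momentum term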
private transient double lastDelta = 0.0;
public Connection() {
//no-arg constructor for JAXB
}
public Connection(int src, int dest, double weight) {
this.src = src;
this.dest = dest;
this.weight = weight;
}
public void addDelta(double delta) {
this.weight += delta;
this.lastDelta = delta;
}
@XmlAttribute
public int getDest() {
return dest;
}
@XmlTransient
public double getLastDelta() {
return lastDelta;
}
@XmlAttribute
public int getSrc() {
return src;
}
@XmlAttribute
public double getWeight() {
return weight;
}
public void setDest(int dest) {
this.dest = dest;
}
public void setSrc(int src) {
this.src = src;
}
public void setWeight(double weight) {
this.weight = weight;
}
@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + dest;
result = prime * result + src;
long temp;
temp = Double.doubleToLongBits(weight);
result = prime * result + (int) (temp ^ (temp >>> 32));
return result;
}
@Override
public boolean equals(Object obj) {
if (this == obj)
return true;
if (obj == null)
return false;
if (getClass() != obj.getClass())
return false;
Connection other = (Connection) obj;
if (dest != other.dest)
return false;
if (src != other.src)
return false;
if (Double.doubleToLongBits(weight) != Double
.doubleToLongBits(other.weight))
return false;
return true;
}
@Override
public String toString() {
return "Connection(" + src + ", " + dest +"), weight: " + weight;
}
}

View File: net/woodyfolsom/msproj/ann2/FeedforwardNetwork.java

@@ -0,0 +1,291 @@
package net.woodyfolsom.msproj.ann2;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.xml.bind.annotation.XmlAttribute;
import javax.xml.bind.annotation.XmlElement;
import net.woodyfolsom.msproj.ann2.math.ActivationFunction;
import net.woodyfolsom.msproj.ann2.math.Linear;
import net.woodyfolsom.msproj.ann2.math.Sigmoid;
public abstract class FeedforwardNetwork {
private ActivationFunction activationFunction;
private boolean biased;
private List<Connection> connections;
private List<Neuron> neurons;
private String name;
private transient int biasNeuronId;
private transient Map<Integer, List<Connection>> connectionsFrom;
private transient Map<Integer, List<Connection>> connectionsTo;
public FeedforwardNetwork() {
//no-arg constructor for JAXB
this(false);
}
public FeedforwardNetwork(boolean biased) {
this.activationFunction = Sigmoid.function;
this.connections = new ArrayList<Connection>();
this.connectionsFrom = new HashMap<Integer,List<Connection>>();
this.connectionsTo = new HashMap<Integer,List<Connection>>();
this.neurons = new ArrayList<Neuron>();
this.name = "UNDEFINED";
this.biasNeuronId = -1;
setBiased(biased);
}
public void addConnection(Connection connection) {
connections.add(connection);
int src = connection.getSrc();
int dest = connection.getDest();
if (!connectionsFrom.containsKey(src)) {
connectionsFrom.put(src, new ArrayList<Connection>());
}
if (!connectionsTo.containsKey(dest)) {
connectionsTo.put(dest, new ArrayList<Connection>());
}
connectionsFrom.get(src).add(connection);
connectionsTo.get(dest).add(connection);
}
public NNData compute(NNDataPair nnDataPair) {
NNData actual = new NNData(nnDataPair.getIdeal().getFields(),
compute(nnDataPair.getInput().getValues()));
return actual;
}
public double[] compute(double[] input) {
zeroInputs();
setInput(input);
feedforward();
return getOutput();
}
void createBiasConnection(int neuronId, double weight) {
if (!biased) {
throw new UnsupportedOperationException("Not a biased network");
}
addConnection(new Connection(biasNeuronId, neuronId, weight));
}
/**
* Adds a new neuron with a unique id to this FeedforwardNetwork.
* @param input true for an input (or bias) neuron, which gets a linear activation
* @return the newly created Neuron
*/
Neuron createNeuron(boolean input) {
Neuron neuron;
if (input) {
neuron = new Neuron(Linear.function, neurons.size());
} else {
neuron = new Neuron(activationFunction, neurons.size());
}
neurons.add(neuron);
return neuron;
}
@Override
public boolean equals(Object obj) {
if (this == obj)
return true;
if (obj == null)
return false;
if (getClass() != obj.getClass())
return false;
FeedforwardNetwork other = (FeedforwardNetwork) obj;
if (activationFunction == null) {
if (other.activationFunction != null)
return false;
} else if (!activationFunction.equals(other.activationFunction))
return false;
if (connections == null) {
if (other.connections != null)
return false;
} else if (!connections.equals(other.connections))
return false;
if (name == null) {
if (other.name != null)
return false;
} else if (!name.equals(other.name))
return false;
if (neurons == null) {
if (other.neurons != null)
return false;
} else if (!neurons.equals(other.neurons))
return false;
return true;
}
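// Assumes the neurons list is in feedforward (topological) order, so a
// neuron's inputs are fully accumulated before its output is read.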
protected void feedforward() {
for (int i = 0; i < neurons.size(); i++) {
Neuron src = neurons.get(i);
for (Connection connection : getConnectionsFrom(src.getId())) {
Neuron dest = getNeuron(connection.getDest());
dest.addInput(src.getOutput() * connection.getWeight());
}
}
}
@XmlElement(type = Sigmoid.class)
public ActivationFunction getActivationFunction() {
return activationFunction;
}
protected abstract double[] getOutput();
protected abstract Neuron[] getOutputNeurons();
@XmlAttribute
public String getName() {
return name;
}
protected Neuron getNeuron(int id) {
return neurons.get(id);
}
@XmlElement
protected Connection[] getConnections() {
return connections.toArray(new Connection[connections.size()]);
}
protected Connection[] getConnectionsFrom(int neuronId) {
List<Connection> connList = connectionsFrom.get(neuronId);
if (connList == null) {
return new Connection[0];
} else {
return connList.toArray(new Connection[connList.size()]);
}
}
protected Connection[] getConnectionsTo(int neuronId) {
List<Connection> connList = connectionsTo.get(neuronId);
if (connList == null) {
return new Connection[0];
} else {
return connList.toArray(new Connection[connList.size()]);
}
}
@XmlAttribute
public boolean isBiased() {
return biased;
}
@XmlElement
protected Neuron[] getNeurons() {
return neurons.toArray(new Neuron[neurons.size()]);
}
@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime
* result
+ ((activationFunction == null) ? 0 : activationFunction
.hashCode());
result = prime * result
+ ((connections == null) ? 0 : connections.hashCode());
result = prime * result + ((name == null) ? 0 : name.hashCode());
result = prime * result + ((neurons == null) ? 0 : neurons.hashCode());
return result;
}
public void initWeights() {
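// Randomize each weight uniformly in (-1.0, 1.0].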
for (Connection connection : connections) {
connection.setWeight(1.0-Math.random()*2.0);
}
}
public abstract boolean load(InputStream is);
public abstract boolean save(OutputStream os);
public void setActivationFunction(ActivationFunction activationFunction) {
this.activationFunction = activationFunction;
}
public void setBiased(boolean biased) {
if (this.biased == biased) {
return;
}
this.biased = biased;
if (biased) {
Neuron biasNeuron = createNeuron(true);
biasNeuron.setInput(1.0);
biasNeuronId = biasNeuron.getId();
} else {
//Inefficient but concise: remove every connection involving the bias neuron.
//Drop the bias neuron's entry from the connectionsFrom index...
List<Connection> connectionsFromBias = connectionsFrom.remove(biasNeuronId);
if (connectionsFromBias != null) {
//...purge those connections from the connectionsTo index...
for (Map.Entry<Integer,List<Connection>> mapEntry : connectionsTo.entrySet()) {
mapEntry.getValue().removeAll(connectionsFromBias);
}
//...and from the (serialized) master list of connections.
connections.removeAll(connectionsFromBias);
}
biasNeuronId = -1;
}
}
protected void setConnections(Connection[] connections) {
this.connections.clear();
this.connectionsFrom.clear();
this.connectionsTo.clear();
for (Connection connection : connections) {
addConnection(connection);
}
}
protected abstract void setInput(double[] input);
public void setName(String name) {
this.name = name;
}
protected void setNeurons(Neuron[] neurons) {
this.neurons.clear();
for (Neuron neuron : neurons) {
this.neurons.add(neuron);
}
}
public void setWeights(double[] weights) {
if (weights.length != connections.size()) {
throw new IllegalArgumentException("# of weights must == # of connections");
}
for (int i = 0; i < connections.size(); i++) {
connections.get(i).setWeight(weights[i]);
}
}
protected void zeroInputs() {
for (Neuron neuron : neurons) {
if (neuron.getId() != biasNeuronId){
neuron.setInput(0.0);
}
}
}
}

View File: net/woodyfolsom/msproj/ann2/Layer.java

@@ -2,29 +2,48 @@ package net.woodyfolsom.msproj.ann2;
 import java.util.Arrays;
+import javax.xml.bind.annotation.XmlElement;
 public class Layer {
-private Neuron[] neurons;
+private int[] neuronIds;
 public Layer() {
-//default constructor for JAXB
+neuronIds = new int[0];
 }
-public Layer(int numNeurons, int numWeights, ActivationFunction activationFunction) {
-neurons = new Neuron[numNeurons];
-for (int neuronIndex = 0; neuronIndex < numNeurons; neuronIndex++) {
-neurons[neuronIndex] = new Neuron(activationFunction, numWeights);
-}
+public Layer(int numNeurons) {
+neuronIds = new int[numNeurons];
 }
 public int size() {
-return neurons.length;
+return neuronIds.length;
 }
+public int getNeuronId(int index) {
+return neuronIds[index];
+}
+@XmlElement
+public int[] getNeuronIds() {
+int[] safeCopy = new int[neuronIds.length];
+System.arraycopy(neuronIds, 0, safeCopy, 0, neuronIds.length);
+return safeCopy;
+}
+public void setNeuronId(int index, int id) {
+neuronIds[index] = id;
+}
+public void setNeuronIds(int[] neuronIds) {
+this.neuronIds = new int[neuronIds.length];
+System.arraycopy(neuronIds, 0, this.neuronIds, 0, neuronIds.length);
+}
 @Override
 public int hashCode() {
 final int prime = 31;
 int result = 1;
-result = prime * result + Arrays.hashCode(neurons);
+result = prime * result + Arrays.hashCode(neuronIds);
 return result;
 }
@@ -37,17 +56,9 @@ public class Layer {
 if (getClass() != obj.getClass())
 return false;
 Layer other = (Layer) obj;
-if (!Arrays.equals(neurons, other.neurons))
+if (!Arrays.equals(neuronIds, other.neuronIds))
 return false;
 return true;
 }
-public Neuron[] getNeurons() {
-return neurons;
-}
-public void setNeurons(Neuron[] neurons) {
-this.neurons = neurons;
-}
 }

View File: net/woodyfolsom/msproj/ann2/MultiLayerPerceptron.java

@@ -2,19 +2,16 @@ package net.woodyfolsom.msproj.ann2;
 import java.io.InputStream;
 import java.io.OutputStream;
-import java.util.Arrays;
 import javax.xml.bind.JAXBContext;
 import javax.xml.bind.JAXBException;
 import javax.xml.bind.Marshaller;
 import javax.xml.bind.Unmarshaller;
-import javax.xml.bind.annotation.XmlAttribute;
 import javax.xml.bind.annotation.XmlElement;
 import javax.xml.bind.annotation.XmlRootElement;
 @XmlRootElement
-public class MultiLayerPerceptron extends NeuralNetwork {
-private ActivationFunction activationFunction;
+public class MultiLayerPerceptron extends FeedforwardNetwork {
 private boolean biased;
 private Layer[] layers;
@@ -23,18 +20,16 @@ public class MultiLayerPerceptron extends NeuralNetwork {
 }
 public MultiLayerPerceptron(boolean biased, int... layerSizes) {
+super(biased);
 int numLayers = layerSizes.length;
 if (numLayers < 2) {
 throw new IllegalArgumentException("# of layers must be >= 2");
 }
-this.activationFunction = Sigmoid.function;
-this.biased = biased;
 this.layers = new Layer[numLayers];
-int numWeights;
 for (int layerIndex = 0; layerIndex < numLayers; layerIndex++) {
 int layerSize = layerSizes[layerIndex];
@@ -42,23 +37,31 @@ public class MultiLayerPerceptron extends NeuralNetwork {
 throw new IllegalArgumentException("Layer size must be >= 1");
 }
-if (layerIndex == 0) {
-numWeights = 0;
+Layer newLayer = createNewLayer(layerIndex, layerSize);
+if (layerIndex > 0) {
+Layer prevLayer = layers[layerIndex - 1];
+for (int j = 0; j < newLayer.size(); j++) {
 if (biased) {
-layerSize++;
+createBiasConnection(newLayer.getNeuronId(j), 0.0);
+}
+for (int i = 0; i < prevLayer.size(); i++) {
+addConnection(new Connection(prevLayer.getNeuronId(i),
+newLayer.getNeuronId(j), 0.0));
+}
 }
-} else {
-numWeights = layers[layerIndex - 1].size();
 }
-layers[layerIndex] = new Layer(layerSize, numWeights,
-activationFunction);
 }
 }
-@XmlElement(type=Sigmoid.class)
-public ActivationFunction getActivationFunction() {
-return activationFunction;
+private Layer createNewLayer(int layerIndex, int layerSize) {
+Layer layer = new Layer(layerSize);
+layers[layerIndex] = layer;
+for (int n = 0; n < layerSize; n++) {
+Neuron neuron = createNeuron(layerIndex == 0);
+layer.setNeuronId(n, neuron.getId());
+}
+return layer;
 }
 @XmlElement
@@ -68,33 +71,30 @@ public class MultiLayerPerceptron extends NeuralNetwork {
 @Override
 protected double[] getOutput() {
-// TODO Auto-generated method stub
-return null;
+Layer outputLayer = layers[layers.length - 1];
+double[] output = new double[outputLayer.size()];
+for (int n = 0; n < outputLayer.size(); n++) {
+output[n] = getNeuron(outputLayer.getNeuronId(n)).getOutput();
+}
+return output;
 }
 @Override
-protected Neuron[] getNeurons() {
-// TODO Auto-generated method stub
-return null;
-}
-@XmlAttribute
-public boolean isBiased() {
-return biased;
-}
-public void setActivationFunction(ActivationFunction activationFunction) {
-this.activationFunction = activationFunction;
+public Neuron[] getOutputNeurons() {
+Layer outputLayer = layers[layers.length - 1];
+Neuron[] outputNeurons = new Neuron[outputLayer.size()];
+for (int i = 0; i < outputLayer.size(); i++) {
+outputNeurons[i] = getNeuron(outputLayer.getNeuronId(i));
+}
+return outputNeurons;
 }
 @Override
 protected void setInput(double[] input) {
-// TODO Auto-generated method stub
+Layer inputLayer = layers[0];
+for (int n = 0; n < inputLayer.size(); n++) {
+getNeuron(inputLayer.getNeuronId(n)).setInput(input[n]);
+}
 }
-public void setBiased(boolean biased) {
-this.biased = biased;
-}
 public void setLayers(Layer[] layers) {
@@ -111,7 +111,9 @@ public class MultiLayerPerceptron extends NeuralNetwork {
 Unmarshaller u = jc.createUnmarshaller();
 MultiLayerPerceptron mlp = (MultiLayerPerceptron) u.unmarshal(is);
-this.activationFunction = mlp.activationFunction;
+super.setActivationFunction(mlp.getActivationFunction());
+super.setConnections(mlp.getConnections());
+super.setNeurons(mlp.getNeurons());
 this.biased = mlp.biased;
 this.layers = mlp.layers;
@@ -138,38 +140,4 @@ public class MultiLayerPerceptron extends NeuralNetwork {
 return false;
 }
 }
-@Override
-public int hashCode() {
-final int prime = 31;
-int result = 1;
-result = prime
-* result
-+ ((activationFunction == null) ? 0 : activationFunction
-.hashCode());
-result = prime * result + (biased ? 1231 : 1237);
-result = prime * result + Arrays.hashCode(layers);
-return result;
-}
-@Override
-public boolean equals(Object obj) {
-if (this == obj)
-return true;
-if (obj == null)
-return false;
-if (getClass() != obj.getClass())
-return false;
-MultiLayerPerceptron other = (MultiLayerPerceptron) obj;
-if (activationFunction == null) {
-if (other.activationFunction != null)
-return false;
-} else if (!activationFunction.equals(other.activationFunction))
-return false;
-if (biased != other.biased)
-return false;
-if (!Arrays.equals(layers, other.layers))
-return false;
-return true;
-}
 }

View File: net/woodyfolsom/msproj/ann2/NNData.java

@@ -9,6 +9,15 @@ public class NNData {
 this.values = values;
 }
+public NNData(NNData that) {
+this.fields = that.fields;
+this.values = that.values;
+}
+public String[] getFields() {
+return fields;
+}
 public double[] getValues() {
 return values;
 }

View File: net/woodyfolsom/msproj/ann2/NNDataPair.java

@@ -1,16 +1,16 @@
 package net.woodyfolsom.msproj.ann2;
 public class NNDataPair {
-private final NNData actual;
+private final NNData input;
 private final NNData ideal;
 public NNDataPair(NNData actual, NNData ideal) {
-this.actual = actual;
+this.input = actual;
 this.ideal = ideal;
 }
-public NNData getActual() {
-return actual;
+public NNData getInput() {
+return input;
 }
 public NNData getIdeal() {

View File: net/woodyfolsom/msproj/ann2/NeuralNetFilter.java

@@ -0,0 +1,25 @@
package net.woodyfolsom.msproj.ann2;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.List;
public interface NeuralNetFilter {
int getActualTrainingEpochs();
int getInputSize();
int getMaxTrainingEpochs();
int getOutputSize();
boolean load(InputStream input);
boolean save(OutputStream output);
void setMaxTrainingEpochs(int max);
NNData compute(NNDataPair input);
void learn(List<NNDataPair> trainingSet);
}

View File: net/woodyfolsom/msproj/ann2/NeuralNetwork.java

@@ -1,53 +0,0 @@
package net.woodyfolsom.msproj.ann2;
import java.io.InputStream;
import java.io.OutputStream;
import javax.xml.bind.JAXBException;
/**
* A NeuralNetwork is simply an ordered set of Neurons.
*
* Functions which rely on knowledge of input neurons, output neurons and layers
* are delegated to MultiLayerPerceptron.
*
* The primary function implemented in this abstract class is feedforward.
* This function depends only on getNeurons() returning Neurons in feedforward order,
* and the returned Neurons must have the correct number of weights for the NeuralNetwork
* configuration.
*
* @author Woody
*
*/
public abstract class NeuralNetwork {
public NeuralNetwork() {
}
public double[] calculate(double[] input) {
zeroInputs();
setInput(input);
feedforward();
return getOutput();
}
protected void feedforward() {
Neuron[] neurons = getNeurons();
}
protected abstract double[] getOutput();
protected abstract Neuron[] getNeurons();
public abstract boolean load(InputStream is);
public abstract boolean save(OutputStream os);
protected abstract void setInput(double[] input);
protected void zeroInputs() {
for (Neuron neuron : getNeurons()) {
neuron.setInput(0.0);
}
}
}

View File: net/woodyfolsom/msproj/ann2/Neuron.java

@@ -1,24 +1,29 @@
 package net.woodyfolsom.msproj.ann2;
-import java.util.Arrays;
-import javax.xml.bind.Unmarshaller;
+import javax.xml.bind.annotation.XmlAttribute;
 import javax.xml.bind.annotation.XmlElement;
 import javax.xml.bind.annotation.XmlTransient;
+import net.woodyfolsom.msproj.ann2.math.ActivationFunction;
+import net.woodyfolsom.msproj.ann2.math.Sigmoid;
 public class Neuron {
 private ActivationFunction activationFunction;
-private double[] weights;
+private int id;
 private transient double input = 0.0;
+private transient double error = 0.0;
 public Neuron() {
 //no-arg constructor for JAXB
 }
-public Neuron(ActivationFunction activationFunction, int numWeights) {
+public Neuron(ActivationFunction activationFunction, int id) {
 this.activationFunction = activationFunction;
-this.weights = new double[numWeights];
+this.id = id;
+}
+public void addInput(double value) {
+input += value;
 }
 @XmlElement(type=Sigmoid.class)
@@ -26,11 +31,14 @@ public class Neuron {
 return activationFunction;
 }
-void afterUnmarshal(Unmarshaller aUnmarshaller, Object aParent)
-{
-if (weights == null) {
-weights = new double[0];
-}
+@XmlAttribute
+public int getId() {
+return id;
 }
+@XmlTransient
+public double getError() {
+return error;
+}
 @XmlTransient
@@ -42,9 +50,8 @@ public class Neuron {
 return activationFunction.calculate(input);
 }
-@XmlElement
-public double[] getWeights() {
-return weights;
+public void setError(double value) {
+this.error = value;
 }
 public void setInput(double input) {
@@ -59,7 +66,6 @@ public class Neuron {
 * result
 + ((activationFunction == null) ? 0 : activationFunction
 .hashCode());
-result = prime * result + Arrays.hashCode(weights);
 return result;
 }
@@ -77,8 +83,6 @@ public class Neuron {
 return false;
 } else if (!activationFunction.equals(other.activationFunction))
 return false;
-if (!Arrays.equals(weights, other.weights))
-return false;
 return true;
 }
@@ -86,7 +90,9 @@ public class Neuron {
 this.activationFunction = activationFunction;
 }
-public void setWeights(double[] weights) {
-this.weights = weights;
+@Override
+public String toString() {
+return "Neuron #" + id + ", input: " + input + ", error: " + error;
 }
 }

View File: net/woodyfolsom/msproj/ann2/Tanh.java

@@ -1,10 +0,0 @@
package net.woodyfolsom.msproj.ann2;
public class Tanh implements ActivationFunction{
@Override
public double calculate(double arg) {
return Math.tanh(arg);
}
}

View File: net/woodyfolsom/msproj/ann2/TemporalDifference.java

@@ -0,0 +1,19 @@
package net.woodyfolsom.msproj.ann2;
import java.util.List;
public class TemporalDifference implements TrainingMethod {
@Override
public void iterate(FeedforwardNetwork neuralNetwork,
List<NNDataPair> trainingSet) {
throw new UnsupportedOperationException("Not implemented");
}
@Override
public double computeError(FeedforwardNetwork neuralNetwork,
List<NNDataPair> trainingSet) {
throw new UnsupportedOperationException("Not implemented");
}
}

View File: net/woodyfolsom/msproj/ann2/TrainingMethod.java

@@ -0,0 +1,10 @@
package net.woodyfolsom.msproj.ann2;
import java.util.List;
public interface TrainingMethod {
void iterate(FeedforwardNetwork neuralNetwork, List<NNDataPair> trainingSet);
double computeError(FeedforwardNetwork neuralNetwork, List<NNDataPair> trainingSet);
}

View File: net/woodyfolsom/msproj/ann2/XORFilter.java

@@ -0,0 +1,44 @@
package net.woodyfolsom.msproj.ann2;
/**
* Based on sample code from http://neuroph.sourceforge.net
*
* @author Woody
*
*/
public class XORFilter extends AbstractNeuralNetFilter {
private static final int INPUT_SIZE = 2;
private static final int OUTPUT_SIZE = 1;
public XORFilter() {
this(0.8,0.7);
}
public XORFilter(double learningRate, double momentum) {
super( new MultiLayerPerceptron(true, INPUT_SIZE, 2, OUTPUT_SIZE),
new BackPropagation(learningRate, momentum), 1000, 0.01);
super.getNeuralNetwork().setName("XORFilter");
//Known-good weights for debugging (cf. MultiLayerPerceptronTest.testXORnetwork):
//getNeuralNetwork().setWeights(new double[] {
// 0.341232, 0.129952, -0.923123, //hidden neuron 1 from input0, input1, bias
// -0.115223, 0.570345, -0.328932, //hidden neuron 2 from input0, input1, bias
// -0.993423, 0.164732, 0.752621}); //output
}
public double compute(double x, double y) {
return getNeuralNetwork().compute(new double[]{x,y})[0];
}
@Override
public int getInputSize() {
return INPUT_SIZE;
}
@Override
public int getOutputSize() {
return OUTPUT_SIZE;
}
}

View File: net/woodyfolsom/msproj/ann2/Sigmoid.java → net/woodyfolsom/msproj/ann2/math/ActivationFunction.java

@@ -1,17 +1,22 @@
-package net.woodyfolsom.msproj.ann2;
-public class Sigmoid implements ActivationFunction{
-public static final Sigmoid function = new Sigmoid();
+package net.woodyfolsom.msproj.ann2.math;
+import javax.xml.bind.annotation.XmlAttribute;
+public abstract class ActivationFunction {
 private String name;
-private Sigmoid() {
-this.name = "Sigmoid";
+public abstract double calculate(double arg);
+public abstract double derivative(double arg);
+public ActivationFunction() {
+//no-arg constructor for JAXB
 }
-public double calculate(double arg) {
-return 1.0 / (1 + Math.pow(Math.E, -1.0 * arg));
+public ActivationFunction(String name) {
+this.name = name;
 }
+@XmlAttribute
 public String getName() {
 return name;
 }
@@ -19,7 +24,6 @@ public class Sigmoid implements ActivationFunction{
 public void setName(String name) {
 this.name = name;
 }
-
 @Override
 public int hashCode() {
 final int prime = 31;
@@ -27,7 +31,6 @@ public class Sigmoid implements ActivationFunction{
 result = prime * result + ((name == null) ? 0 : name.hashCode());
 return result;
 }
-
 @Override
 public boolean equals(Object obj) {
 if (this == obj)
@@ -36,7 +39,7 @@ public class Sigmoid implements ActivationFunction{
 return false;
 if (getClass() != obj.getClass())
 return false;
-Sigmoid other = (Sigmoid) obj;
+ActivationFunction other = (ActivationFunction) obj;
 if (name == null) {
 if (other.name != null)
 return false;

View File: net/woodyfolsom/msproj/ann2/math/ErrorFunction.java

@@ -0,0 +1,5 @@
package net.woodyfolsom.msproj.ann2.math;
public interface ErrorFunction {
double compute(double[] ideal, double[] actual);
}

View File: net/woodyfolsom/msproj/ann2/math/Linear.java

@@ -0,0 +1,17 @@
package net.woodyfolsom.msproj.ann2.math;
public class Linear extends ActivationFunction{
public static final Linear function = new Linear();
private Linear() {
super("Linear");
}
public double calculate(double arg) {
return arg;
}
public double derivative(double arg) {
//the identity function has a constant derivative of 1
return 1.0;
}
}

View File: net/woodyfolsom/msproj/ann2/math/MSSE.java

@@ -0,0 +1,22 @@
package net.woodyfolsom.msproj.ann2.math;
public class MSSE implements ErrorFunction{
public static final ErrorFunction function = new MSSE();
public double compute(double[] ideal, double[] actual) {
int idealSize = ideal.length;
int actualSize = actual.length;
if (idealSize != actualSize) {
throw new IllegalArgumentException("actualSize != idealSize");
}
double SSE = 0.0;
for (int i = 0; i < idealSize; i++) {
SSE += Math.pow(ideal[i] - actual[i], 2);
}
return SSE / idealSize;
}
}
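
Worked example: for ideal {1.0, 0.0} and actual {0.8, 0.2}, SSE = 0.2^2 + 0.2^2 = 0.08, so the returned mean is 0.08 / 2 = 0.04.

double err = MSSE.function.compute(new double[] {1.0, 0.0}, new double[] {0.8, 0.2}); // 0.04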

View File: net/woodyfolsom/msproj/ann2/math/Sigmoid.java

@@ -0,0 +1,20 @@
package net.woodyfolsom.msproj.ann2.math;
public class Sigmoid extends ActivationFunction{
public static final Sigmoid function = new Sigmoid();
private Sigmoid() {
super("Sigmoid");
}
public double calculate(double arg) {
return 1.0 / (1 + Math.pow(Math.E, -1.0 * arg));
}
public double derivative(double arg) {
//expressed in terms of the sigmoid's output y = calculate(x): dy/dx = y * (1 - y);
//callers such as BackPropagation pass the neuron's output, not its raw input
return arg - Math.pow(arg, 2);
}
}

View File: net/woodyfolsom/msproj/ann2/math/Tanh.java

@@ -0,0 +1,21 @@
package net.woodyfolsom.msproj.ann2.math;
public class Tanh extends ActivationFunction{
public static final Tanh function = new Tanh();
public Tanh() {
super("Tanh");
}
@Override
public double calculate(double arg) {
return Math.tanh(arg);
}
@Override
public double derivative(double arg) {
//for consistency with Sigmoid, arg is taken to be the neuron's output y = tanh(x), so dy/dx = 1 - y^2
return 1.0 - arg * arg;
}
}

View File: net/woodyfolsom/msproj/ann2/MultiLayerPerceptronTest.java

@@ -16,6 +16,7 @@ import org.junit.Test;
 public class MultiLayerPerceptronTest {
 static final File TEST_FILE = new File("data/test/mlp.net");
+static final double EPS = 0.001;
 @BeforeClass
 public static void setUp() {
@@ -49,14 +50,47 @@ public class MultiLayerPerceptronTest {
 @Test
 public void testPersistence() throws JAXBException, IOException {
-NeuralNetwork mlp = new MultiLayerPerceptron(true, 2, 4, 1);
+FeedforwardNetwork mlp = new MultiLayerPerceptron(true, 2, 4, 1);
 FileOutputStream fos = new FileOutputStream(TEST_FILE);
 assertTrue(mlp.save(fos));
 fos.close();
 FileInputStream fis = new FileInputStream(TEST_FILE);
-NeuralNetwork mlp2 = new MultiLayerPerceptron();
+FeedforwardNetwork mlp2 = new MultiLayerPerceptron();
 assertTrue(mlp2.load(fis));
 assertEquals(mlp, mlp2);
 fis.close();
 }
+@Test
+public void testCompute() {
+FeedforwardNetwork mlp = new MultiLayerPerceptron(true, 2, 2, 1);
+NNDataPair expected = new NNDataPair(new NNData(new String[]{"x","y"}, new double[]{0.0,0.0}), new NNData(new String[]{"xor"}, new double[]{0.0}));
+NNDataPair actual = new NNDataPair(new NNData(new String[]{"x","y"}, new double[]{0.0,0.0}), new NNData(new String[]{"xor"}, new double[]{0.5}));
+NNData actualOutput = mlp.compute(actual);
+assertEquals(expected.getIdeal(), actualOutput);
+}
+@Test
+public void testXORnetwork() {
+FeedforwardNetwork mlp = new MultiLayerPerceptron(true, 2, 2, 1);
+mlp.setWeights(new double[] {
+0.341232, 0.129952, -0.923123, //hidden neuron 1 from input0, input1, bias
+-0.115223, 0.570345, -0.328932, //hidden neuron 2 from input0, input1, bias
+-0.993423, 0.164732, 0.752621}); //output
+for (Connection connection : mlp.getConnections()) {
+System.out.println(connection);
+}
+NNDataPair expected = new NNDataPair(new NNData(new String[]{"x","y"}, new double[]{0.0,0.0}), new NNData(new String[]{"xor"}, new double[]{0.367610}));
+NNDataPair actual = new NNDataPair(new NNData(new String[]{"x","y"}, new double[]{0.0,0.0}), new NNData(new String[]{"xor"}, new double[]{0.0}));
+NNData actualOutput = mlp.compute(actual);
+assertEquals(expected.getIdeal().getValues()[0], actualOutput.getValues()[0], EPS);
+}
+/**
+* Reference weights used above:
+* Hidden Neuron 1: w2(0,1) = 0.341232 w2(1,1) = 0.129952 w2(2,1) = -0.923123
+* Hidden Neuron 2: w2(0,2) = -0.115223 w2(1,2) = 0.570345 w2(2,2) = -0.328932
+* Output Neuron: w3(0,1) = -0.993423 w3(1,1) = 0.164732 w3(2,1) = 0.752621
+*/
 }

View File: net/woodyfolsom/msproj/ann2/SigmoidTest.java

@@ -3,16 +3,27 @@ package net.woodyfolsom.msproj.ann2;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
+import net.woodyfolsom.msproj.ann2.math.ActivationFunction;
+import net.woodyfolsom.msproj.ann2.math.Sigmoid;
 import org.junit.Test;
 public class SigmoidTest {
+static double EPS = 0.001;
 @Test
 public void testCalculate() {
-double EPS = 0.001;
 ActivationFunction sigmoid = Sigmoid.function;
 assertEquals(0.5,sigmoid.calculate(0.0),EPS);
 assertTrue(sigmoid.calculate(100.0) > 1.0 - EPS);
 assertTrue(sigmoid.calculate(-9000.0) < EPS);
 }
+@Test
+public void testDerivative() {
+//derivative() takes the neuron's output: sigmoid(0) = 0.5, so the slope there is 0.5 * (1 - 0.5) = 0.25
+ActivationFunction sigmoid = Sigmoid.function;
+assertEquals(0.25, sigmoid.derivative(0.5), EPS);
+}
 }

View File: net/woodyfolsom/msproj/ann2/TanhTest.java

@@ -3,16 +3,26 @@ package net.woodyfolsom.msproj.ann2;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
+import net.woodyfolsom.msproj.ann2.math.ActivationFunction;
+import net.woodyfolsom.msproj.ann2.math.Tanh;
 import org.junit.Test;
 public class TanhTest {
+static double EPS = 0.001;
 @Test
 public void testCalculate() {
-double EPS = 0.001;
-ActivationFunction sigmoid = new Tanh();
-assertEquals(0.0,sigmoid.calculate(0.0),EPS);
-assertTrue(sigmoid.calculate(100.0) > 0.5 - EPS);
-assertTrue(sigmoid.calculate(-9000.0) < -0.5+EPS);
+ActivationFunction tanh = new Tanh();
+assertEquals(0.0,tanh.calculate(0.0),EPS);
+assertTrue(tanh.calculate(100.0) > 0.5 - EPS);
+assertTrue(tanh.calculate(-9000.0) < -0.5 + EPS);
+}
+@Test
+public void testDerivative() {
+ActivationFunction tanh = new Tanh();
+assertEquals(1.0,tanh.derivative(0.0), EPS);
+}
 }

View File: net/woodyfolsom/msproj/ann2/XORFilterTest.java

@@ -0,0 +1,136 @@
package net.woodyfolsom.msproj.ann2;
import static org.junit.Assert.assertTrue;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
public class XORFilterTest {
private static final String FILENAME = "xorPerceptron.net";
@AfterClass
public static void deleteNewNet() {
File file = new File(FILENAME);
if (file.exists()) {
file.delete();
}
}
@BeforeClass
public static void deleteSavedNet() {
File file = new File(FILENAME);
if (file.exists()) {
file.delete();
}
}
@Test
public void testLearn() throws IOException {
NeuralNetFilter nnLearner = new XORFilter(0.05,0.0);
// create training set (logical XOR function)
int size = 1;
double[][] trainingInput = new double[4 * size][];
double[][] trainingOutput = new double[4 * size][];
for (int i = 0; i < size; i++) {
trainingInput[i * 4 + 0] = new double[] { 0, 0 };
trainingInput[i * 4 + 1] = new double[] { 0, 1 };
trainingInput[i * 4 + 2] = new double[] { 1, 0 };
trainingInput[i * 4 + 3] = new double[] { 1, 1 };
trainingOutput[i * 4 + 0] = new double[] { 0 };
trainingOutput[i * 4 + 1] = new double[] { 1 };
trainingOutput[i * 4 + 2] = new double[] { 1 };
trainingOutput[i * 4 + 3] = new double[] { 0 };
}
// create training data
List<NNDataPair> trainingSet = new ArrayList<NNDataPair>();
String[] inputNames = new String[] {"x","y"};
String[] outputNames = new String[] {"XOR"};
for (int i = 0; i < 4*size; i++) {
trainingSet.add(new NNDataPair(new NNData(inputNames,trainingInput[i]),new NNData(outputNames,trainingOutput[i])));
}
nnLearner.setMaxTrainingEpochs(20000);
nnLearner.learn(trainingSet);
System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");
double[][] validationSet = new double[4][2];
validationSet[0] = new double[] { 0, 0 };
validationSet[1] = new double[] { 0, 1 };
validationSet[2] = new double[] { 1, 0 };
validationSet[3] = new double[] { 1, 1 };
System.out.println("Output from eval set (learned network):");
testNetwork(nnLearner, validationSet, inputNames, outputNames);
}
@Test
public void testLearnSaveLoad() throws IOException {
NeuralNetFilter nnLearner = new XORFilter(0.5,0.0);
// create training set (logical XOR function)
int size = 2;
double[][] trainingInput = new double[4 * size][];
double[][] trainingOutput = new double[4 * size][];
for (int i = 0; i < size; i++) {
trainingInput[i * 4 + 0] = new double[] { 0, 0 };
trainingInput[i * 4 + 1] = new double[] { 0, 1 };
trainingInput[i * 4 + 2] = new double[] { 1, 0 };
trainingInput[i * 4 + 3] = new double[] { 1, 1 };
trainingOutput[i * 4 + 0] = new double[] { 0 };
trainingOutput[i * 4 + 1] = new double[] { 1 };
trainingOutput[i * 4 + 2] = new double[] { 1 };
trainingOutput[i * 4 + 3] = new double[] { 0 };
}
// create training data
List<NNDataPair> trainingSet = new ArrayList<NNDataPair>();
String[] inputNames = new String[] {"x","y"};
String[] outputNames = new String[] {"XOR"};
for (int i = 0; i < 4*size; i++) {
trainingSet.add(new NNDataPair(new NNData(inputNames,trainingInput[i]),new NNData(outputNames,trainingOutput[i])));
}
nnLearner.setMaxTrainingEpochs(1);
nnLearner.learn(trainingSet);
System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");
double[][] validationSet = new double[4][2];
validationSet[0] = new double[] { 0, 0 };
validationSet[1] = new double[] { 0, 1 };
validationSet[2] = new double[] { 1, 0 };
validationSet[3] = new double[] { 1, 1 };
System.out.println("Output from eval set (learned network, pre-serialization):");
testNetwork(nnLearner, validationSet, inputNames, outputNames);
FileOutputStream fos = new FileOutputStream(FILENAME);
assertTrue(nnLearner.save(fos));
fos.close();
FileInputStream fis = new FileInputStream(FILENAME);
assertTrue(nnLearner.load(fis));
fis.close();
System.out.println("Output from eval set (learned network, post-serialization):");
testNetwork(nnLearner, validationSet, inputNames, outputNames);
}
private void testNetwork(NeuralNetFilter nnLearner, double[][] validationSet, String[] inputNames, String[] outputNames) {
for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
NNDataPair dp = new NNDataPair(new NNData(inputNames,validationSet[valIndex]), new NNData(outputNames,validationSet[valIndex]));
System.out.println(dp + " => " + nnLearner.compute(dp));
}
}
}