From 790b5666a87064c2da368b64e54d78fb812f9acc Mon Sep 17 00:00:00 2001 From: Woody Folsom Date: Sat, 24 Nov 2012 22:20:41 -0500 Subject: [PATCH] Functional MLP for XOR toy problem. --- .../msproj/ann2/AbstractNeuralNetFilter.java | 84 +++++ .../msproj/ann2/ActivationFunction.java | 5 - .../msproj/ann2/BackPropagation.java | 120 ++++++++ .../woodyfolsom/msproj/ann2/Connection.java | 94 ++++++ .../msproj/ann2/FeedforwardNetwork.java | 291 ++++++++++++++++++ src/net/woodyfolsom/msproj/ann2/Layer.java | 47 +-- .../msproj/ann2/MultiLayerPerceptron.java | 138 ++++----- src/net/woodyfolsom/msproj/ann2/NNData.java | 9 + .../woodyfolsom/msproj/ann2/NNDataPair.java | 8 +- .../msproj/ann2/NeuralNetFilter.java | 25 ++ .../msproj/ann2/NeuralNetwork.java | 53 ---- src/net/woodyfolsom/msproj/ann2/Neuron.java | 48 +-- src/net/woodyfolsom/msproj/ann2/Tanh.java | 10 - .../msproj/ann2/TemporalDifference.java | 19 ++ .../msproj/ann2/TrainingMethod.java | 10 + .../woodyfolsom/msproj/ann2/XORFilter.java | 44 +++ .../ActivationFunction.java} | 27 +- .../msproj/ann2/math/ErrorFunction.java | 5 + .../woodyfolsom/msproj/ann2/math/Linear.java | 17 + .../woodyfolsom/msproj/ann2/math/MSSE.java | 22 ++ .../woodyfolsom/msproj/ann2/math/Sigmoid.java | 20 ++ .../woodyfolsom/msproj/ann2/math/Tanh.java | 21 ++ .../msproj/ann2/MultiLayerPerceptronTest.java | 40 ++- .../woodyfolsom/msproj/ann2/SigmoidTest.java | 13 +- .../net/woodyfolsom/msproj/ann2/TanhTest.java | 20 +- .../msproj/ann2/XORFilterTest.java | 136 ++++++++ 26 files changed, 1109 insertions(+), 217 deletions(-) create mode 100644 src/net/woodyfolsom/msproj/ann2/AbstractNeuralNetFilter.java delete mode 100644 src/net/woodyfolsom/msproj/ann2/ActivationFunction.java create mode 100644 src/net/woodyfolsom/msproj/ann2/BackPropagation.java create mode 100644 src/net/woodyfolsom/msproj/ann2/Connection.java create mode 100644 src/net/woodyfolsom/msproj/ann2/FeedforwardNetwork.java create mode 100644 src/net/woodyfolsom/msproj/ann2/NeuralNetFilter.java delete mode 100644 src/net/woodyfolsom/msproj/ann2/NeuralNetwork.java delete mode 100644 src/net/woodyfolsom/msproj/ann2/Tanh.java create mode 100644 src/net/woodyfolsom/msproj/ann2/TemporalDifference.java create mode 100644 src/net/woodyfolsom/msproj/ann2/TrainingMethod.java create mode 100644 src/net/woodyfolsom/msproj/ann2/XORFilter.java rename src/net/woodyfolsom/msproj/ann2/{Sigmoid.java => math/ActivationFunction.java} (56%) create mode 100644 src/net/woodyfolsom/msproj/ann2/math/ErrorFunction.java create mode 100644 src/net/woodyfolsom/msproj/ann2/math/Linear.java create mode 100644 src/net/woodyfolsom/msproj/ann2/math/MSSE.java create mode 100644 src/net/woodyfolsom/msproj/ann2/math/Sigmoid.java create mode 100644 src/net/woodyfolsom/msproj/ann2/math/Tanh.java create mode 100644 test/net/woodyfolsom/msproj/ann2/XORFilterTest.java diff --git a/src/net/woodyfolsom/msproj/ann2/AbstractNeuralNetFilter.java b/src/net/woodyfolsom/msproj/ann2/AbstractNeuralNetFilter.java new file mode 100644 index 0000000..d24bd06 --- /dev/null +++ b/src/net/woodyfolsom/msproj/ann2/AbstractNeuralNetFilter.java @@ -0,0 +1,84 @@ +package net.woodyfolsom.msproj.ann2; + +import java.io.InputStream; +import java.io.OutputStream; +import java.util.List; + +public abstract class AbstractNeuralNetFilter implements NeuralNetFilter { + private final FeedforwardNetwork neuralNetwork; + private final TrainingMethod trainingMethod; + + private double maxError; + private int actualTrainingEpochs = 0; + private int maxTrainingEpochs; + + 
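+	/**
+	 * Wires a concrete FeedforwardNetwork to a TrainingMethod, recording the
+	 * stopping criteria: a target error and a maximum number of training epochs.
+	 */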
+	AbstractNeuralNetFilter(FeedforwardNetwork neuralNetwork, TrainingMethod trainingMethod, int maxTrainingEpochs, double maxError) {
+		this.neuralNetwork = neuralNetwork;
+		this.trainingMethod = trainingMethod;
+		this.maxError = maxError;
+		this.maxTrainingEpochs = maxTrainingEpochs;
+	}
+
+	@Override
+	public NNData compute(NNDataPair input) {
+		return this.neuralNetwork.compute(input);
+	}
+
+	public int getActualTrainingEpochs() {
+		return actualTrainingEpochs;
+	}
+
+	@Override
+	public int getInputSize() {
+		// Default input size; concrete filters override this as needed.
+		return 2;
+	}
+
+	public int getMaxTrainingEpochs() {
+		return maxTrainingEpochs;
+	}
+
+	protected FeedforwardNetwork getNeuralNetwork() {
+		return neuralNetwork;
+	}
+
+	@Override
+	public void learn(List<NNDataPair> trainingSet) {
+		actualTrainingEpochs = 0;
+		neuralNetwork.initWeights();
+
+		double error = trainingMethod.computeError(neuralNetwork, trainingSet);
+		System.out.println("Initial error: " + error);
+
+		if (error <= maxError) {
+			return;
+		}
+
+		do {
+			trainingMethod.iterate(neuralNetwork, trainingSet);
+			error = trainingMethod.computeError(neuralNetwork, trainingSet);
+			actualTrainingEpochs++;
+			System.out.println("MSSE after epoch " + actualTrainingEpochs + ": " + error);
+		} while (error > maxError && actualTrainingEpochs < maxTrainingEpochs);
+	}
+
+	@Override
+	public boolean load(InputStream input) {
+		return neuralNetwork.load(input);
+	}
+
+	@Override
+	public boolean save(OutputStream output) {
+		return neuralNetwork.save(output);
+	}
+
+	public void setMaxError(double maxError) {
+		this.maxError = maxError;
+	}
+
+	public void setMaxTrainingEpochs(int max) {
+		this.maxTrainingEpochs = max;
+	}
+}
\ No newline at end of file
diff --git a/src/net/woodyfolsom/msproj/ann2/ActivationFunction.java b/src/net/woodyfolsom/msproj/ann2/ActivationFunction.java
deleted file mode 100644
index ffe4454..0000000
--- a/src/net/woodyfolsom/msproj/ann2/ActivationFunction.java
+++ /dev/null
@@ -1,5 +0,0 @@
-package net.woodyfolsom.msproj.ann2;
-
-public interface ActivationFunction {
-	double calculate(double arg);
-}
diff --git a/src/net/woodyfolsom/msproj/ann2/BackPropagation.java b/src/net/woodyfolsom/msproj/ann2/BackPropagation.java
new file mode 100644
index 0000000..8b40e1e
--- /dev/null
+++ b/src/net/woodyfolsom/msproj/ann2/BackPropagation.java
@@ -0,0 +1,120 @@
+package net.woodyfolsom.msproj.ann2;
+
+import java.util.List;
+
+import net.woodyfolsom.msproj.ann2.math.ErrorFunction;
+import net.woodyfolsom.msproj.ann2.math.MSSE;
+
+public class BackPropagation implements TrainingMethod {
+	private final ErrorFunction errorFunction;
+	private final double learningRate;
+	private final double momentum;
+
+	public BackPropagation(double learningRate, double momentum) {
+		this.errorFunction = MSSE.function;
+		this.learningRate = learningRate;
+		this.momentum = momentum;
+	}
+
+	@Override
+	public void iterate(FeedforwardNetwork neuralNetwork,
+			List<NNDataPair> trainingSet) {
+		System.out.println("Learning rate: " + learningRate);
+		System.out.println("Momentum: " + momentum);
+
+		//zeroErrors(neuralNetwork);
+
+		for (NNDataPair trainingPair : trainingSet) {
+			zeroErrors(neuralNetwork);
+
+			System.out.println("Training with: " + trainingPair.getInput());
+
+			NNData ideal = trainingPair.getIdeal();
+			NNData actual = neuralNetwork.compute(trainingPair);
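+			// One online (per-pattern) update: backpropagate this pair's error
+			// and adjust the weights immediately, rather than batching deltas.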
+			System.out.println("Updating weights. Ideal Output: " + ideal);
+			System.out.println("Actual Output: " + actual);
+
+			updateErrors(neuralNetwork, ideal);
+
+			updateWeights(neuralNetwork);
+		}
+
+		//updateWeights(neuralNetwork);
+	}
+
+	@Override
+	public double computeError(FeedforwardNetwork neuralNetwork,
+			List<NNDataPair> trainingSet) {
+		int numDataPairs = trainingSet.size();
+		int outputSize = neuralNetwork.getOutput().length;
+		int totalOutputSize = outputSize * numDataPairs;
+
+		double[] actuals = new double[totalOutputSize];
+		double[] ideals = new double[totalOutputSize];
+		for (int dataPair = 0; dataPair < numDataPairs; dataPair++) {
+			NNDataPair nnDataPair = trainingSet.get(dataPair);
+			double[] actual = neuralNetwork.compute(nnDataPair.getInput()
+					.getValues());
+			double[] ideal = nnDataPair.getIdeal().getValues();
+			int offset = dataPair * outputSize;
+
+			System.arraycopy(actual, 0, actuals, offset, outputSize);
+			System.arraycopy(ideal, 0, ideals, offset, outputSize);
+		}
+
+		return errorFunction.compute(ideals, actuals);
+	}
+
+	private void updateErrors(FeedforwardNetwork neuralNetwork, NNData ideal) {
+		Neuron[] outputNeurons = neuralNetwork.getOutputNeurons();
+		double[] idealValues = ideal.getValues();
+
+		for (int i = 0; i < idealValues.length; i++) {
+			double output = outputNeurons[i].getOutput();
+			double derivative = outputNeurons[i].getActivationFunction()
+					.derivative(output);
+			outputNeurons[i].setError(outputNeurons[i].getError() + derivative * (idealValues[i] - output));
+		}
+		// Walking the list of Neurons in reverse (output-to-input) order,
+		// propagate each neuron's error back through its outgoing connections.
+		Neuron[] neurons = neuralNetwork.getNeurons();
+
+		for (int n = neurons.length - 1; n >= 0; n--) {
+
+			Neuron neuron = neurons[n];
+			double error = neuron.getError();
+
+			Connection[] connectionsFromN = neuralNetwork
+					.getConnectionsFrom(neuron.getId());
+			if (connectionsFromN.length > 0) {
+
+				double derivative = neuron.getActivationFunction().derivative(
+						neuron.getOutput());
+				for (Connection connection : connectionsFromN) {
+					error += derivative * connection.getWeight() * neuralNetwork.getNeuron(connection.getDest()).getError();
+				}
+			}
+			neuron.setError(error);
+		}
+	}
+
+	private void updateWeights(FeedforwardNetwork neuralNetwork) {
+		for (Connection connection : neuralNetwork.getConnections()) {
+			Neuron srcNeuron = neuralNetwork.getNeuron(connection.getSrc());
+			Neuron destNeuron = neuralNetwork.getNeuron(connection.getDest());
+			double delta = learningRate * srcNeuron.getOutput() * destNeuron.getError();
+			//TODO apply momentum here via connection.getLastDelta(); the momentum
+			//field is accepted by the constructor but not used yet.
+			connection.addDelta(delta);
+		}
+	}
+
+	private void zeroErrors(FeedforwardNetwork neuralNetwork) {
+		// Reset every neuron's accumulated error before backpropagating the next pair.
+		for (Neuron neuron : neuralNetwork.getNeurons()) {
+			neuron.setError(0.0);
+		}
+	}
+}
\ No newline at end of file
diff --git a/src/net/woodyfolsom/msproj/ann2/Connection.java b/src/net/woodyfolsom/msproj/ann2/Connection.java
new file mode 100644
index 0000000..8b38b1b
--- /dev/null
+++ b/src/net/woodyfolsom/msproj/ann2/Connection.java
@@ -0,0 +1,94 @@
+package net.woodyfolsom.msproj.ann2;
+
+import javax.xml.bind.annotation.XmlAttribute;
+import javax.xml.bind.annotation.XmlTransient;
+
+public class Connection {
+	private int src;
+	private int dest;
+	private double weight;
+	private transient double lastDelta = 0.0;
+
+	public Connection() {
+		//no-arg constructor for JAXB
+	}
+
+	public Connection(int src, int dest, double weight) {
+		this.src = src;
+		this.dest = dest;
+		this.weight = weight;
+	}
+
+	public void addDelta(double delta) {
+		this.weight += delta;
+		this.lastDelta = delta;
+	}
+
+	@XmlAttribute
+	public int getDest() {
+		return dest;
+	}
+
+	@XmlTransient
+	public double getLastDelta() {
+		return lastDelta;
+	}
+
+	@XmlAttribute
+	public int getSrc() {
+		return src;
+	}
+
+	@XmlAttribute
+	public double getWeight() {
+		return weight;
+	}
+
+	public void setDest(int dest) {
+		this.dest = dest;
+	}
+
+	public void setSrc(int src) {
+		this.src = src;
+	}
+
+	public void setWeight(double weight) {
+		this.weight = weight;
+	}
+
+	@Override
+	public int hashCode() {
+		final int prime = 31;
+		int result = 1;
+		result = prime * result + dest;
+		result = prime * result + src;
+		long temp;
+		temp = Double.doubleToLongBits(weight);
+		result = prime * result + (int) (temp ^ (temp >>> 32));
+		return result;
+	}
+
+	@Override
+	public boolean equals(Object obj) {
+		if (this == obj)
+			return true;
+		if (obj == null)
+			return false;
+		if (getClass() != obj.getClass())
+			return false;
+		Connection other = (Connection) obj;
+		if (dest != other.dest)
+			return false;
+		if (src != other.src)
+			return false;
+		if (Double.doubleToLongBits(weight) != Double
+				.doubleToLongBits(other.weight))
+			return false;
+		return true;
+	}
+
+	@Override
+	public String toString() {
+		return "Connection(" + src + ", " + dest + "), weight: " + weight;
+	}
+}
\ No newline at end of file
diff --git a/src/net/woodyfolsom/msproj/ann2/FeedforwardNetwork.java b/src/net/woodyfolsom/msproj/ann2/FeedforwardNetwork.java
new file mode 100644
index 0000000..29e1f10
--- /dev/null
+++ b/src/net/woodyfolsom/msproj/ann2/FeedforwardNetwork.java
@@ -0,0 +1,291 @@
+package net.woodyfolsom.msproj.ann2;
+
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import javax.xml.bind.annotation.XmlAttribute;
+import javax.xml.bind.annotation.XmlElement;
+
+import net.woodyfolsom.msproj.ann2.math.ActivationFunction;
+import net.woodyfolsom.msproj.ann2.math.Linear;
+import net.woodyfolsom.msproj.ann2.math.Sigmoid;
+
+public abstract class FeedforwardNetwork {
+	private ActivationFunction activationFunction;
+	private boolean biased;
+	private List<Connection> connections;
+	private List<Neuron> neurons;
+	private String name;
+
+	private transient int biasNeuronId;
+	private transient Map<Integer, List<Connection>> connectionsFrom;
+	private transient Map<Integer, List<Connection>> connectionsTo;
+
+	public FeedforwardNetwork() {
+		//no-arg constructor for JAXB
+		this(false);
+	}
+
+	public FeedforwardNetwork(boolean biased) {
+		this.activationFunction = Sigmoid.function;
+		this.connections = new ArrayList<Connection>();
+		this.connectionsFrom = new HashMap<Integer, List<Connection>>();
+		this.connectionsTo = new HashMap<Integer, List<Connection>>();
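+		// Two transient adjacency maps index each Connection by its source and
+		// destination neuron id, so feedforward and backpropagation can fetch a
+		// neuron's incoming/outgoing edges without scanning the whole list.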
+		this.neurons = new ArrayList<Neuron>();
+		this.name = "UNDEFINED";
+		this.biasNeuronId = -1;
+		setBiased(biased);
+	}
+
+	public void addConnection(Connection connection) {
+		connections.add(connection);
+
+		int src = connection.getSrc();
+		int dest = connection.getDest();
+
+		if (!connectionsFrom.containsKey(src)) {
+			connectionsFrom.put(src, new ArrayList<Connection>());
+		}
+
+		if (!connectionsTo.containsKey(dest)) {
+			connectionsTo.put(dest, new ArrayList<Connection>());
+		}
+
+		connectionsFrom.get(src).add(connection);
+		connectionsTo.get(dest).add(connection);
+	}
+
+	public NNData compute(NNDataPair nnDataPair) {
+		NNData actual = new NNData(nnDataPair.getIdeal().getFields(),
+				compute(nnDataPair.getInput().getValues()));
+		return actual;
+	}
+
+	public double[] compute(double[] input) {
+		zeroInputs();
+		setInput(input);
+		feedforward();
+		return getOutput();
+	}
+
+	void createBiasConnection(int neuronId, double weight) {
+		if (!biased) {
+			throw new UnsupportedOperationException("Not a biased network");
+		}
+		addConnection(new Connection(biasNeuronId, neuronId, weight));
+	}
+
+	/**
+	 * Adds a new neuron with a unique id to this FeedforwardNetwork.
+	 * Input (and bias) neurons get a Linear activation so they pass their
+	 * input through unchanged; all others use the network's activation function.
+	 * @param input true if the new neuron is an input or bias neuron
+	 * @return the newly created Neuron
+	 */
+	Neuron createNeuron(boolean input) {
+		Neuron neuron;
+		if (input) {
+			neuron = new Neuron(Linear.function, neurons.size());
+		} else {
+			neuron = new Neuron(activationFunction, neurons.size());
+		}
+		neurons.add(neuron);
+		return neuron;
+	}
+
+	@Override
+	public boolean equals(Object obj) {
+		if (this == obj)
+			return true;
+		if (obj == null)
+			return false;
+		if (getClass() != obj.getClass())
+			return false;
+		FeedforwardNetwork other = (FeedforwardNetwork) obj;
+		if (activationFunction == null) {
+			if (other.activationFunction != null)
+				return false;
+		} else if (!activationFunction.equals(other.activationFunction))
+			return false;
+		if (connections == null) {
+			if (other.connections != null)
+				return false;
+		} else if (!connections.equals(other.connections))
+			return false;
+		if (name == null) {
+			if (other.name != null)
+				return false;
+		} else if (!name.equals(other.name))
+			return false;
+		if (neurons == null) {
+			if (other.neurons != null)
+				return false;
+		} else if (!neurons.equals(other.neurons))
+			return false;
+		return true;
+	}
+
+	protected void feedforward() {
+		// Neurons are created in feedforward (input-to-output) order, so one
+		// pass over the list pushes each neuron's output along its connections.
+		for (int i = 0; i < neurons.size(); i++) {
+			Neuron src = neurons.get(i);
+			for (Connection connection : getConnectionsFrom(src.getId())) {
+				Neuron dest = getNeuron(connection.getDest());
+				dest.addInput(src.getOutput() * connection.getWeight());
+			}
+		}
+	}
+
+	@XmlElement(type = Sigmoid.class)
+	public ActivationFunction getActivationFunction() {
+		return activationFunction;
+	}
+
+	protected abstract double[] getOutput();
+	protected abstract Neuron[] getOutputNeurons();
+
+	@XmlAttribute
+	public String getName() {
+		return name;
+	}
+
+	protected Neuron getNeuron(int id) {
+		return neurons.get(id);
+	}
+
+	@XmlElement
+	protected Connection[] getConnections() {
+		return connections.toArray(new Connection[connections.size()]);
+	}
+
+	protected Connection[] getConnectionsFrom(int neuronId) {
+		List<Connection> connList = connectionsFrom.get(neuronId);
+
+		if (connList == null) {
+			return new Connection[0];
+		} else {
+			return connList.toArray(new Connection[connList.size()]);
+		}
+	}
+
+	protected Connection[] getConnectionsTo(int neuronId) {
+		List<Connection> connList = connectionsTo.get(neuronId);
+
+		if (connList == null) {
+			return new Connection[0];
+		} else {
+			return connList.toArray(new Connection[connList.size()]);
+		}
+	}
+
+	@XmlAttribute
+	public boolean isBiased() {
+		return biased;
+	}
+	@XmlElement
+	protected Neuron[] getNeurons() {
+		return neurons.toArray(new Neuron[neurons.size()]);
+	}
+
+	@Override
+	public int hashCode() {
+		final int prime = 31;
+		int result = 1;
+		result = prime
+				* result
+				+ ((activationFunction == null) ? 0 : activationFunction
+						.hashCode());
+		result = prime * result
+				+ ((connections == null) ? 0 : connections.hashCode());
+		result = prime * result + ((name == null) ? 0 : name.hashCode());
+		result = prime * result + ((neurons == null) ? 0 : neurons.hashCode());
+		return result;
+	}
+
+	public void initWeights() {
+		for (Connection connection : connections) {
+			// Random initial weight, uniform in (-1.0, 1.0].
+			connection.setWeight(1.0 - Math.random() * 2.0);
+		}
+	}
+
+	public abstract boolean load(InputStream is);
+
+	public abstract boolean save(OutputStream os);
+
+	public void setActivationFunction(ActivationFunction activationFunction) {
+		this.activationFunction = activationFunction;
+	}
+
+	public void setBiased(boolean biased) {
+
+		if (this.biased == biased) {
+			return;
+		}
+
+		this.biased = biased;
+
+		if (biased) {
+			Neuron biasNeuron = createNeuron(true);
+			biasNeuron.setInput(1.0);
+			biasNeuronId = biasNeuron.getId();
+		} else {
+			//This is an inefficient but concise way to remove all connections
+			//involving the bias Neuron from both indexes and the global list.
+
+			//Remove all connections from biasId from this index
+			List<Connection> connectionsFromBias = connectionsFrom.remove(biasNeuronId);
+
+			if (connectionsFromBias != null) {
+				//Remove all connections from biasId to all nodes from this index
+				for (Map.Entry<Integer, List<Connection>> mapEntry : connectionsTo.entrySet()) {
+					mapEntry.getValue().removeAll(connectionsFromBias);
+				}
+
+				//Finally, remove them from the (serialized) list of non-indexed connections
+				connections.removeAll(connectionsFromBias);
+			}
+
+			biasNeuronId = -1;
+		}
+	}
+
+	protected void setConnections(Connection[] connections) {
+		this.connections.clear();
+		this.connectionsFrom.clear();
+		this.connectionsTo.clear();
+		for (Connection connection : connections) {
+			addConnection(connection);
+		}
+	}
+
+	protected abstract void setInput(double[] input);
+
+	public void setName(String name) {
+		this.name = name;
+	}
+
+	protected void setNeurons(Neuron[] neurons) {
+		this.neurons.clear();
+		for (Neuron neuron : neurons) {
+			this.neurons.add(neuron);
+		}
+	}
+
+	public void setWeights(double[] weights) {
+		if (weights.length != connections.size()) {
+			throw new IllegalArgumentException("# of weights must == # of connections");
+		}
+
+		for (int i = 0; i < connections.size(); i++) {
+			connections.get(i).setWeight(weights[i]);
+		}
+	}
+
+	protected void zeroInputs() {
+		for (Neuron neuron : neurons) {
+			if (neuron.getId() != biasNeuronId) {
+				neuron.setInput(0.0);
+			}
+		}
+	}
+
+}
\ No newline at end of file
diff --git a/src/net/woodyfolsom/msproj/ann2/Layer.java b/src/net/woodyfolsom/msproj/ann2/Layer.java
index 3bd0830..88cb178 100644
--- a/src/net/woodyfolsom/msproj/ann2/Layer.java
+++ b/src/net/woodyfolsom/msproj/ann2/Layer.java
@@ -2,29 +2,48 @@ package net.woodyfolsom.msproj.ann2;
 
 import java.util.Arrays;
 
+import javax.xml.bind.annotation.XmlElement;
+
 public class Layer {
-	private Neuron[] neurons;
+	private int[] neuronIds;
 
 	public Layer() {
-		//default constructor for JAXB
+		neuronIds = new int[0];
 	}
 
-	public Layer(int numNeurons, int numWeights, ActivationFunction activationFunction) {
-		neurons = new Neuron[numNeurons];
-		for (int neuronIndex = 0; neuronIndex < numNeurons; neuronIndex++) {
-			neurons[neuronIndex] = new Neuron(activationFunction, numWeights);
-		}
+	public Layer(int numNeurons) {
+		neuronIds = new int[numNeurons];
 	}
 
 	public int size() {
-		return neurons.length;
+		return neuronIds.length;
+	}
+
+	public int getNeuronId(int index) {
+		return neuronIds[index];
+	}
+
+	@XmlElement
+	public int[] getNeuronIds() {
+		int[] safeCopy = new int[neuronIds.length];
+		System.arraycopy(neuronIds, 0, safeCopy, 0, neuronIds.length);
+		return safeCopy;
+	}
+
+	public void setNeuronId(int index, int id) {
+		neuronIds[index] = id;
+	}
+
+	public void setNeuronIds(int[] neuronIds) {
+		this.neuronIds = new int[neuronIds.length];
+		System.arraycopy(neuronIds, 0, this.neuronIds, 0, neuronIds.length);
 	}
 
 	@Override
 	public int hashCode() {
 		final int prime = 31;
 		int result = 1;
-		result = prime * result + Arrays.hashCode(neurons);
+		result = prime * result + Arrays.hashCode(neuronIds);
 		return result;
 	}
 
@@ -37,17 +56,9 @@ public class Layer {
 		if (getClass() != obj.getClass())
 			return false;
 		Layer other = (Layer) obj;
-		if (!Arrays.equals(neurons, other.neurons))
+		if (!Arrays.equals(neuronIds, other.neuronIds))
 			return false;
 		return true;
 	}
 
-	public Neuron[] getNeurons() {
-		return neurons;
-	}
-
-	public void setNeurons(Neuron[] neurons) {
-		this.neurons = neurons;
-	}
-
 }
\ No newline at end of file
diff --git a/src/net/woodyfolsom/msproj/ann2/MultiLayerPerceptron.java b/src/net/woodyfolsom/msproj/ann2/MultiLayerPerceptron.java
index 93ae15f..23f3dca 100644
--- a/src/net/woodyfolsom/msproj/ann2/MultiLayerPerceptron.java
+++ b/src/net/woodyfolsom/msproj/ann2/MultiLayerPerceptron.java
@@ -2,19 +2,16 @@ package net.woodyfolsom.msproj.ann2;
 
 import java.io.InputStream;
 import java.io.OutputStream;
-import java.util.Arrays;
 
 import javax.xml.bind.JAXBContext;
 import javax.xml.bind.JAXBException;
 import javax.xml.bind.Marshaller;
 import javax.xml.bind.Unmarshaller;
-import javax.xml.bind.annotation.XmlAttribute;
 import javax.xml.bind.annotation.XmlElement;
 import javax.xml.bind.annotation.XmlRootElement;
 
 @XmlRootElement
-public class MultiLayerPerceptron extends NeuralNetwork {
-	private ActivationFunction activationFunction;
-	private boolean biased;
+public class MultiLayerPerceptron extends FeedforwardNetwork {
+	// The biased flag and bias neuron are managed by FeedforwardNetwork.
 	private Layer[] layers;
 
@@ -23,17 +20,15 @@ public class MultiLayerPerceptron extends FeedforwardNetwork {
 	}
 
 	public MultiLayerPerceptron(boolean biased, int... layerSizes) {
+		super(biased);
+
 		int numLayers = layerSizes.length;
 
 		if (numLayers < 2) {
 			throw new IllegalArgumentException("# of layers must be >= 2");
 		}
-
-		this.activationFunction = Sigmoid.function;
-		this.biased = biased;
-		this.layers = new Layer[numLayers];
-		int numWeights;
 
+		this.layers = new Layer[numLayers];
 		for (int layerIndex = 0; layerIndex < numLayers; layerIndex++) {
 			int layerSize = layerSizes[layerIndex];
 
@@ -41,66 +36,71 @@ public class MultiLayerPerceptron extends FeedforwardNetwork {
 			if (layerSize < 1) {
 				throw new IllegalArgumentException("Layer size must be >= 1");
 			}
-
-			if (layerIndex == 0) {
-				numWeights = 0;
-				if (biased) {
-					layerSize++;
-				}
-			} else {
-				numWeights = layers[layerIndex - 1].size();
-			}
-			layers[layerIndex] = new Layer(layerSize, numWeights,
-					activationFunction);
+			Layer newLayer = createNewLayer(layerIndex, layerSize);
+
+			if (layerIndex > 0) {
+				Layer prevLayer = layers[layerIndex - 1];
+				for (int j = 0; j < newLayer.size(); j++) {
+					if (biased) {
+						createBiasConnection(newLayer.getNeuronId(j), 0.0);
+					}
+					for (int i = 0; i < prevLayer.size(); i++) {
+						addConnection(new Connection(prevLayer.getNeuronId(i),
+								newLayer.getNeuronId(j), 0.0));
+					}
+				}
+			}
 		}
 	}
 
-	@XmlElement(type=Sigmoid.class)
-	public ActivationFunction getActivationFunction() {
-		return activationFunction;
+	private Layer createNewLayer(int layerIndex, int layerSize) {
+		Layer layer = new Layer(layerSize);
+		layers[layerIndex] = layer;
+		for (int n = 0; n < layerSize; n++) {
+			Neuron neuron = createNeuron(layerIndex == 0);
+			layer.setNeuronId(n, neuron.getId());
+		}
+		return layer;
 	}
 
 	@XmlElement
 	public Layer[] getLayers() {
 		return layers;
 	}
 
 	@Override
 	protected double[] getOutput() {
-		// TODO Auto-generated method stub
-		return null;
+		Layer outputLayer = layers[layers.length - 1];
+		double[] output = new double[outputLayer.size()];
+		for (int n = 0; n < outputLayer.size(); n++) {
+			output[n] = getNeuron(outputLayer.getNeuronId(n)).getOutput();
+		}
+		return output;
+	}
+
+	@Override
+	public Neuron[] getOutputNeurons() {
+		Layer outputLayer = layers[layers.length - 1];
+		Neuron[] outputNeurons = new Neuron[outputLayer.size()];
+		for (int i = 0; i < outputLayer.size(); i++) {
+			outputNeurons[i] = getNeuron(outputLayer.getNeuronId(i));
+		}
+		return outputNeurons;
 	}
 
-	@Override
-	protected Neuron[] getNeurons() {
-		// TODO Auto-generated method stub
-		return null;
-	}
-
-	@XmlAttribute
-	public boolean isBiased() {
-		return biased;
-	}
-
-	public void setActivationFunction(ActivationFunction activationFunction) {
-		this.activationFunction = activationFunction;
-	}
-
 	@Override
 	protected void setInput(double[] input) {
-		// TODO Auto-generated method stub
-
+		Layer inputLayer = layers[0];
+		for (int n = 0; n < inputLayer.size(); n++) {
+			getNeuron(inputLayer.getNeuronId(n)).setInput(input[n]);
+		}
 	}
 
-	public void setBiased(boolean biased) {
-		this.biased = biased;
-	}
-
 	public void setLayers(Layer[] layers) {
 		this.layers = layers;
 	}
 
 	@Override
 	public boolean load(InputStream is) {
 		try {
@@ -111,7 +111,9 @@
 			Unmarshaller u = jc.createUnmarshaller();
 			MultiLayerPerceptron mlp = (MultiLayerPerceptron) u.unmarshal(is);
 
-			this.activationFunction = mlp.activationFunction;
+			super.setActivationFunction(mlp.getActivationFunction());
+			super.setConnections(mlp.getConnections());
+			super.setNeurons(mlp.getNeurons());
-			this.biased = mlp.biased;
 			this.layers = mlp.layers;
 
@@ -138,38 +140,4 @@ public class MultiLayerPerceptron extends NeuralNetwork {
 			return false;
 		}
 	}
-
-	@Override
-	public int hashCode() {
-		final int prime = 31;
-		int result = 1;
-		result = prime
-				* result
-				+ ((activationFunction == null) ? 0 : activationFunction
-						.hashCode());
-		result = prime * result + (biased ? 1231 : 1237);
-		result = prime * result + Arrays.hashCode(layers);
-		return result;
-	}
-
-	@Override
-	public boolean equals(Object obj) {
-		if (this == obj)
-			return true;
-		if (obj == null)
-			return false;
-		if (getClass() != obj.getClass())
-			return false;
-		MultiLayerPerceptron other = (MultiLayerPerceptron) obj;
-		if (activationFunction == null) {
-			if (other.activationFunction != null)
-				return false;
-		} else if (!activationFunction.equals(other.activationFunction))
-			return false;
-		if (biased != other.biased)
-			return false;
-		if (!Arrays.equals(layers, other.layers))
-			return false;
-		return true;
-	}
 }
\ No newline at end of file
diff --git a/src/net/woodyfolsom/msproj/ann2/NNData.java b/src/net/woodyfolsom/msproj/ann2/NNData.java
index b182765..12f2150 100644
--- a/src/net/woodyfolsom/msproj/ann2/NNData.java
+++ b/src/net/woodyfolsom/msproj/ann2/NNData.java
@@ -9,6 +9,15 @@ public class NNData {
 		this.values = values;
 	}
 
+	public NNData(NNData that) {
+		// Shallow copy: the new NNData shares the field-name and value arrays.
+		this.fields = that.fields;
+		this.values = that.values;
+	}
+
+	public String[] getFields() {
+		return fields;
+	}
+
 	public double[] getValues() {
 		return values;
 	}
diff --git a/src/net/woodyfolsom/msproj/ann2/NNDataPair.java b/src/net/woodyfolsom/msproj/ann2/NNDataPair.java
index b6f76ea..53d501d 100644
--- a/src/net/woodyfolsom/msproj/ann2/NNDataPair.java
+++ b/src/net/woodyfolsom/msproj/ann2/NNDataPair.java
@@ -1,16 +1,16 @@
 package net.woodyfolsom.msproj.ann2;
 
 public class NNDataPair {
-	private final NNData actual;
+	private final NNData input;
 	private final NNData ideal;
 
-	public NNDataPair(NNData actual, NNData ideal) {
-		this.actual = actual;
+	public NNDataPair(NNData input, NNData ideal) {
+		this.input = input;
 		this.ideal = ideal;
 	}
 
-	public NNData getActual() {
-		return actual;
+	public NNData getInput() {
+		return input;
 	}
 
 	public NNData getIdeal() {
diff --git a/src/net/woodyfolsom/msproj/ann2/NeuralNetFilter.java b/src/net/woodyfolsom/msproj/ann2/NeuralNetFilter.java
new file mode 100644
index 0000000..71a2dde
--- /dev/null
+++ b/src/net/woodyfolsom/msproj/ann2/NeuralNetFilter.java
@@ -0,0 +1,25 @@
+package net.woodyfolsom.msproj.ann2;
+
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.List;
+
+public interface NeuralNetFilter {
+	int getActualTrainingEpochs();
+
+	int getInputSize();
+
+	int getMaxTrainingEpochs();
+
+	int getOutputSize();
+
+	boolean load(InputStream input);
+
+	boolean save(OutputStream output);
+
+	void setMaxTrainingEpochs(int max);
+
+	NNData compute(NNDataPair input);
+
+	void learn(List<NNDataPair> trainingSet);
+}
\ No newline at end of file
diff --git a/src/net/woodyfolsom/msproj/ann2/NeuralNetwork.java b/src/net/woodyfolsom/msproj/ann2/NeuralNetwork.java
deleted file mode 100644
index 9679d6f..0000000
--- a/src/net/woodyfolsom/msproj/ann2/NeuralNetwork.java
+++ /dev/null
@@ -1,53 +0,0 @@
-package net.woodyfolsom.msproj.ann2;
-
-import java.io.InputStream;
-import java.io.OutputStream;
-
-import javax.xml.bind.JAXBException;
-
-/**
- * A NeuralNetwork is simply an ordered set of Neurons.
- *
- * Functions which rely on knowledge of input neurons, output neurons and layers
- * are delegated to MultiLayerPerception.
- *
- * The primary function implemented in this abstract class is feedfoward.
- * This function depends only on getNeurons() returning Neurons in feedforward order - * and the returned Neurons must have the correct number of weights for the NeuralNetwork - * configuration. - * - * @author Woody - * - */ -public abstract class NeuralNetwork { - public NeuralNetwork() { - } - - public double[] calculate(double[] input) { - zeroInputs(); - setInput(input); - feedforward(); - return getOutput(); - } - - protected void feedforward() { - Neuron[] neurons = getNeurons(); - } - - protected abstract double[] getOutput(); - - protected abstract Neuron[] getNeurons(); - - public abstract boolean load(InputStream is); - public abstract boolean save(OutputStream os); - - protected abstract void setInput(double[] input); - - protected void zeroInputs() { - for (Neuron neuron : getNeurons()) { - neuron.setInput(0.0); - } - } - - -} \ No newline at end of file diff --git a/src/net/woodyfolsom/msproj/ann2/Neuron.java b/src/net/woodyfolsom/msproj/ann2/Neuron.java index f32ea0a..4e76aa5 100644 --- a/src/net/woodyfolsom/msproj/ann2/Neuron.java +++ b/src/net/woodyfolsom/msproj/ann2/Neuron.java @@ -1,24 +1,29 @@ package net.woodyfolsom.msproj.ann2; -import java.util.Arrays; - -import javax.xml.bind.Unmarshaller; +import javax.xml.bind.annotation.XmlAttribute; import javax.xml.bind.annotation.XmlElement; import javax.xml.bind.annotation.XmlTransient; +import net.woodyfolsom.msproj.ann2.math.ActivationFunction; +import net.woodyfolsom.msproj.ann2.math.Sigmoid; + public class Neuron { private ActivationFunction activationFunction; - private double[] weights; - + private int id; private transient double input = 0.0; + private transient double error = 0.0; public Neuron() { //no-arg constructor for JAXB } - public Neuron(ActivationFunction activationFunction, int numWeights) { + public Neuron(ActivationFunction activationFunction, int id) { this.activationFunction = activationFunction; - this.weights = new double[numWeights]; + this.id = id; + } + + public void addInput(double value) { + input += value; } @XmlElement(type=Sigmoid.class) @@ -26,12 +31,15 @@ public class Neuron { return activationFunction; } - void afterUnmarshal(Unmarshaller aUnmarshaller, Object aParent) - { - if (weights == null) { - weights = new double[0]; - } - } + @XmlAttribute + public int getId() { + return id; + } + + @XmlTransient + public double getError() { + return error; + } @XmlTransient public double getInput() { @@ -42,9 +50,8 @@ public class Neuron { return activationFunction.calculate(input); } - @XmlElement - public double[] getWeights() { - return weights; + public void setError(double value) { + this.error = value; } public void setInput(double input) { @@ -59,7 +66,6 @@ public class Neuron { * result + ((activationFunction == null) ? 
0 : activationFunction
 						.hashCode());
-		result = prime * result + Arrays.hashCode(weights);
 		return result;
 	}
 
@@ -77,8 +83,6 @@ public class Neuron {
 			return false;
 		} else if (!activationFunction.equals(other.activationFunction))
 			return false;
-		if (!Arrays.equals(weights, other.weights))
-			return false;
 		return true;
 	}
 
@@ -86,7 +90,9 @@ public class Neuron {
 		this.activationFunction = activationFunction;
 	}
 
-	public void setWeights(double[] weights) {
-		this.weights = weights;
+	@Override
+	public String toString() {
+		return "Neuron #" + id + ", input: " + input + ", error: " + error;
 	}
+
 }
\ No newline at end of file
diff --git a/src/net/woodyfolsom/msproj/ann2/Tanh.java b/src/net/woodyfolsom/msproj/ann2/Tanh.java
deleted file mode 100644
index 7277d67..0000000
--- a/src/net/woodyfolsom/msproj/ann2/Tanh.java
+++ /dev/null
@@ -1,10 +0,0 @@
-package net.woodyfolsom.msproj.ann2;
-
-public class Tanh implements ActivationFunction{
-
-	@Override
-	public double calculate(double arg) {
-		return Math.tanh(arg);
-	}
-
-}
diff --git a/src/net/woodyfolsom/msproj/ann2/TemporalDifference.java b/src/net/woodyfolsom/msproj/ann2/TemporalDifference.java
new file mode 100644
index 0000000..b588029
--- /dev/null
+++ b/src/net/woodyfolsom/msproj/ann2/TemporalDifference.java
@@ -0,0 +1,19 @@
+package net.woodyfolsom.msproj.ann2;
+
+import java.util.List;
+
+public class TemporalDifference implements TrainingMethod {
+
+	@Override
+	public void iterate(FeedforwardNetwork neuralNetwork,
+			List<NNDataPair> trainingSet) {
+		throw new UnsupportedOperationException("Not implemented");
+	}
+
+	@Override
+	public double computeError(FeedforwardNetwork neuralNetwork,
+			List<NNDataPair> trainingSet) {
+		throw new UnsupportedOperationException("Not implemented");
+	}
+
+}
diff --git a/src/net/woodyfolsom/msproj/ann2/TrainingMethod.java b/src/net/woodyfolsom/msproj/ann2/TrainingMethod.java
new file mode 100644
index 0000000..b109034
--- /dev/null
+++ b/src/net/woodyfolsom/msproj/ann2/TrainingMethod.java
@@ -0,0 +1,10 @@
+package net.woodyfolsom.msproj.ann2;
+
+import java.util.List;
+
+public interface TrainingMethod {
+
+	void iterate(FeedforwardNetwork neuralNetwork, List<NNDataPair> trainingSet);
+	double computeError(FeedforwardNetwork neuralNetwork, List<NNDataPair> trainingSet);
+
+}
diff --git a/src/net/woodyfolsom/msproj/ann2/XORFilter.java b/src/net/woodyfolsom/msproj/ann2/XORFilter.java
new file mode 100644
index 0000000..7ffd6c4
--- /dev/null
+++ b/src/net/woodyfolsom/msproj/ann2/XORFilter.java
@@ -0,0 +1,44 @@
+package net.woodyfolsom.msproj.ann2;
+
+/**
+ * Based on sample code from http://neuroph.sourceforge.net
+ *
+ * @author Woody
+ *
+ */
+public class XORFilter extends AbstractNeuralNetFilter implements
+		NeuralNetFilter {
+
+	private static final int INPUT_SIZE = 2;
+	private static final int OUTPUT_SIZE = 1;
+
+	public XORFilter() {
+		this(0.8, 0.7);
+	}
+
+	public XORFilter(double learningRate, double momentum) {
+		super(new MultiLayerPerceptron(true, INPUT_SIZE, 2, OUTPUT_SIZE),
+				new BackPropagation(learningRate, momentum), 1000, 0.01);
+		super.getNeuralNetwork().setName("XORFilter");
+
+		//TODO remove
+		//getNeuralNetwork().setWeights(new double[] {
+		//	0.341232, 0.129952, -0.923123,   //hidden neuron 1 from bias, input0, input1
+		//	-0.115223, 0.570345, -0.328932,  //hidden neuron 2 from bias, input0, input1
+		//	-0.993423, 0.164732, 0.752621}); //output neuron from bias, hidden1, hidden2
+	}
+
+	public double compute(double x, double y) {
+		return getNeuralNetwork().compute(new double[]{x, y})[0];
+	}
+
+	@Override
+	public int getInputSize() {
+		return INPUT_SIZE;
+	}
+
+	@Override
+	public int getOutputSize() {
+		return OUTPUT_SIZE;
+	}
+}
\ No newline at end of file
diff --git a/src/net/woodyfolsom/msproj/ann2/Sigmoid.java b/src/net/woodyfolsom/msproj/ann2/math/ActivationFunction.java
similarity index 56%
rename from src/net/woodyfolsom/msproj/ann2/Sigmoid.java
rename to src/net/woodyfolsom/msproj/ann2/math/ActivationFunction.java
index 8629496..c82a307 100644
--- a/src/net/woodyfolsom/msproj/ann2/Sigmoid.java
+++ b/src/net/woodyfolsom/msproj/ann2/math/ActivationFunction.java
@@ -1,25 +1,29 @@
-package net.woodyfolsom.msproj.ann2;
+package net.woodyfolsom.msproj.ann2.math;
 
-public class Sigmoid implements ActivationFunction{
-	public static final Sigmoid function = new Sigmoid();
+import javax.xml.bind.annotation.XmlAttribute;
+
+public abstract class ActivationFunction {
 	private String name;
 
-	private Sigmoid() {
-		this.name = "Sigmoid";
+	public abstract double calculate(double arg);
+	public abstract double derivative(double arg);
+
+	public ActivationFunction() {
+		//no-arg constructor for JAXB
 	}
 
-	public double calculate(double arg) {
-		return 1.0 / (1 + Math.pow(Math.E, -1.0 * arg));
+	public ActivationFunction(String name) {
+		this.name = name;
 	}
-
+
+	@XmlAttribute
 	public String getName() {
 		return name;
 	}
-
+
 	public void setName(String name) {
 		this.name = name;
 	}
-
 	@Override
 	public int hashCode() {
 		final int prime = 31;
@@ -27,7 +31,6 @@ public class Sigmoid implements ActivationFunction{
 		result = prime * result + ((name == null) ? 0 : name.hashCode());
 		return result;
 	}
-
 	@Override
 	public boolean equals(Object obj) {
 		if (this == obj)
@@ -36,7 +39,7 @@ public class Sigmoid implements ActivationFunction{
 			return false;
 		if (getClass() != obj.getClass())
 			return false;
-		Sigmoid other = (Sigmoid) obj;
+		ActivationFunction other = (ActivationFunction) obj;
 		if (name == null) {
 			if (other.name != null)
 				return false;
diff --git a/src/net/woodyfolsom/msproj/ann2/math/ErrorFunction.java b/src/net/woodyfolsom/msproj/ann2/math/ErrorFunction.java
new file mode 100644
index 0000000..789da3b
--- /dev/null
+++ b/src/net/woodyfolsom/msproj/ann2/math/ErrorFunction.java
@@ -0,0 +1,5 @@
+package net.woodyfolsom.msproj.ann2.math;
+
+public interface ErrorFunction {
+	double compute(double[] ideal, double[] actual);
+}
diff --git a/src/net/woodyfolsom/msproj/ann2/math/Linear.java b/src/net/woodyfolsom/msproj/ann2/math/Linear.java
new file mode 100644
index 0000000..3e9ab4a
--- /dev/null
+++ b/src/net/woodyfolsom/msproj/ann2/math/Linear.java
@@ -0,0 +1,17 @@
+package net.woodyfolsom.msproj.ann2.math;
+
+public class Linear extends ActivationFunction {
+	public static final Linear function = new Linear();
+
+	private Linear() {
+		super("Linear");
+	}
+
+	public double calculate(double arg) {
+		return arg;
+	}
+
+	public double derivative(double arg) {
+		// The identity function has a constant slope of 1.
+		return 1;
+	}
+}
\ No newline at end of file
diff --git a/src/net/woodyfolsom/msproj/ann2/math/MSSE.java b/src/net/woodyfolsom/msproj/ann2/math/MSSE.java
new file mode 100644
index 0000000..b4f85d1
--- /dev/null
+++ b/src/net/woodyfolsom/msproj/ann2/math/MSSE.java
@@ -0,0 +1,22 @@
+package net.woodyfolsom.msproj.ann2.math;
+
+public class MSSE implements ErrorFunction {
+	public static final ErrorFunction function = new MSSE();
+
+	public double compute(double[] ideal, double[] actual) {
+		int idealSize = ideal.length;
+		int actualSize = actual.length;
+
+		if (idealSize != actualSize) {
+			throw new IllegalArgumentException("actualSize != idealSize");
+		}
+
+		double SSE = 0.0;
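+
+		// Mean of the squared errors over every output value of every pair:
+		// MSSE = (1/N) * sum_i (ideal[i] - actual[i])^2, where N = idealSize.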
+		for (int i = 0; i < idealSize; i++) {
+			SSE += Math.pow(ideal[i] - actual[i], 2);
+		}
+
+		return SSE / idealSize;
+	}
+}
diff --git a/src/net/woodyfolsom/msproj/ann2/math/Sigmoid.java b/src/net/woodyfolsom/msproj/ann2/math/Sigmoid.java
new file mode 100644
index 0000000..530bd0a
--- /dev/null
+++ b/src/net/woodyfolsom/msproj/ann2/math/Sigmoid.java
@@ -0,0 +1,20 @@
+package net.woodyfolsom.msproj.ann2.math;
+
+public class Sigmoid extends ActivationFunction {
+	public static final Sigmoid function = new Sigmoid();
+
+	private Sigmoid() {
+		super("Sigmoid");
+	}
+
+	public double calculate(double arg) {
+		return 1.0 / (1 + Math.pow(Math.E, -1.0 * arg));
+	}
+
+	public double derivative(double arg) {
+		// Callers pass the sigmoid's *output* y (see BackPropagation), and
+		// dy/dx = y * (1 - y) = y - y^2.
+		return arg - Math.pow(arg, 2);
+	}
+}
\ No newline at end of file
diff --git a/src/net/woodyfolsom/msproj/ann2/math/Tanh.java b/src/net/woodyfolsom/msproj/ann2/math/Tanh.java
new file mode 100644
index 0000000..16108f4
--- /dev/null
+++ b/src/net/woodyfolsom/msproj/ann2/math/Tanh.java
@@ -0,0 +1,21 @@
+package net.woodyfolsom.msproj.ann2.math;
+
+public class Tanh extends ActivationFunction {
+	public static final Tanh function = new Tanh();
+
+	public Tanh() {
+		super("Tanh");
+	}
+
+	@Override
+	public double calculate(double arg) {
+		return Math.tanh(arg);
+	}
+
+	@Override
+	public double derivative(double arg) {
+		// For consistency with Sigmoid, arg is the neuron's *output* y = tanh(x),
+		// so the derivative is 1 - y^2 (see BackPropagation.updateErrors).
+		return 1 - Math.pow(arg, 2);
+	}
+
+}
diff --git a/test/net/woodyfolsom/msproj/ann2/MultiLayerPerceptronTest.java b/test/net/woodyfolsom/msproj/ann2/MultiLayerPerceptronTest.java
index 344f026..7dc8633 100644
--- a/test/net/woodyfolsom/msproj/ann2/MultiLayerPerceptronTest.java
+++ b/test/net/woodyfolsom/msproj/ann2/MultiLayerPerceptronTest.java
@@ -16,7 +16,8 @@ public class MultiLayerPerceptronTest {
 
 	static final File TEST_FILE = new File("data/test/mlp.net");
-
+	static final double EPS = 0.001;
+
 	@BeforeClass
 	public static void setUp() {
 		if (TEST_FILE.exists()) {
@@ -49,14 +50,47 @@ public class MultiLayerPerceptronTest {
 	@Test
 	public void testPersistence() throws JAXBException, IOException {
-		NeuralNetwork mlp = new MultiLayerPerceptron(true, 2, 4, 1);
+		FeedforwardNetwork mlp = new MultiLayerPerceptron(true, 2, 4, 1);
 		FileOutputStream fos = new FileOutputStream(TEST_FILE);
 		assertTrue(mlp.save(fos));
 		fos.close();
 		FileInputStream fis = new FileInputStream(TEST_FILE);
-		NeuralNetwork mlp2 = new MultiLayerPerceptron();
+		FeedforwardNetwork mlp2 = new MultiLayerPerceptron();
 		assertTrue(mlp2.load(fis));
 		assertEquals(mlp, mlp2);
 		fis.close();
 	}
+
+	@Test
+	public void testCompute() {
+		FeedforwardNetwork mlp = new MultiLayerPerceptron(true, 2, 2, 1);
+		NNDataPair input = new NNDataPair(new NNData(new String[]{"x","y"}, new double[]{0.0,0.0}), new NNData(new String[]{"xor"}, new double[]{0.5}));
+		NNData actualOutput = mlp.compute(input);
+		// With every weight still at its default of 0.0, each non-input neuron
+		// sees a net input of 0, so the single output is sigmoid(0) = 0.5.
+		assertEquals(0.5, actualOutput.getValues()[0], EPS);
+	}
+
+	@Test
+	public void testXORnetwork() {
+		FeedforwardNetwork mlp = new MultiLayerPerceptron(true, 2, 2, 1);
+		mlp.setWeights(new double[] {
+				0.341232, 0.129952, -0.923123,   //hidden neuron 1 from bias, input0, input1
+				-0.115223, 0.570345, -0.328932,  //hidden neuron 2 from bias, input0, input1
+				-0.993423, 0.164732, 0.752621}); //output neuron from bias, hidden1, hidden2
+
+		for (Connection connection : mlp.getConnections()) {
+			System.out.println(connection);
+		}
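+		// Hand-computed for input (0,0): h1 = sigmoid(0.341232), h2 = sigmoid(-0.115223),
+		// out = sigmoid(-0.993423 + 0.164732*h1 + 0.752621*h2) = 0.367610 (to 6 places).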
+		NNDataPair expected = new NNDataPair(new NNData(new String[]{"x","y"}, new double[]{0.0,0.0}), new NNData(new String[]{"xor"}, new double[]{0.367610}));
+		NNDataPair actual = new NNDataPair(new NNData(new String[]{"x","y"}, new double[]{0.0,0.0}), new NNData(new String[]{"xor"}, new double[]{0.0}));
+		NNData actualOutput = mlp.compute(actual);
+		assertEquals(expected.getIdeal().getValues()[0], actualOutput.getValues()[0], EPS);
+	}
+
+	/**
+	 * Reference weights for the fixed XOR network above:
+	 * Hidden Neuron 1: w2(0,1) = 0.341232  w2(1,1) = 0.129952  w2(2,1) = -0.923123
+	 * Hidden Neuron 2: w2(0,2) = -0.115223 w2(1,2) = 0.570345  w2(2,2) = -0.328932
+	 * Output Neuron:   w3(0,1) = -0.993423 w3(1,1) = 0.164732  w3(2,1) = 0.752621
+	 */
 }
diff --git a/test/net/woodyfolsom/msproj/ann2/SigmoidTest.java b/test/net/woodyfolsom/msproj/ann2/SigmoidTest.java
index a8e3350..6cbaffe 100644
--- a/test/net/woodyfolsom/msproj/ann2/SigmoidTest.java
+++ b/test/net/woodyfolsom/msproj/ann2/SigmoidTest.java
@@ -3,16 +3,27 @@ package net.woodyfolsom.msproj.ann2;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
 
+import net.woodyfolsom.msproj.ann2.math.ActivationFunction;
+import net.woodyfolsom.msproj.ann2.math.Sigmoid;
+
 import org.junit.Test;
 
 public class SigmoidTest {
+	static double EPS = 0.001;
+
 	@Test
 	public void testCalculate() {
-		double EPS = 0.001;
 		ActivationFunction sigmoid = Sigmoid.function;
 		assertEquals(0.5, sigmoid.calculate(0.0), EPS);
 		assertTrue(sigmoid.calculate(100.0) > 1.0 - EPS);
 		assertTrue(sigmoid.calculate(-9000.0) < EPS);
 	}
+
+	@Test
+	public void testDerivative() {
+		ActivationFunction sigmoid = Sigmoid.function;
+		// derivative() takes the sigmoid's output y; at y = 0.5, y*(1-y) = 0.25.
+		assertEquals(0.25, sigmoid.derivative(0.5), EPS);
+	}
 }
diff --git a/test/net/woodyfolsom/msproj/ann2/TanhTest.java b/test/net/woodyfolsom/msproj/ann2/TanhTest.java
index abd3874..8429fa1 100644
--- a/test/net/woodyfolsom/msproj/ann2/TanhTest.java
+++ b/test/net/woodyfolsom/msproj/ann2/TanhTest.java
@@ -3,16 +3,26 @@ package net.woodyfolsom.msproj.ann2;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
 
+import net.woodyfolsom.msproj.ann2.math.ActivationFunction;
+import net.woodyfolsom.msproj.ann2.math.Tanh;
+
 import org.junit.Test;
 
 public class TanhTest {
+	static double EPS = 0.001;
+
 	@Test
 	public void testCalculate() {
-		double EPS = 0.001;
-		ActivationFunction sigmoid = new Tanh();
-		assertEquals(0.0,sigmoid.calculate(0.0),EPS);
-		assertTrue(sigmoid.calculate(100.0) > 0.5 - EPS);
-		assertTrue(sigmoid.calculate(-9000.0) < -0.5+EPS);
+		ActivationFunction tanh = new Tanh();
+		assertEquals(0.0, tanh.calculate(0.0), EPS);
+		assertTrue(tanh.calculate(100.0) > 0.5 - EPS);
+		assertTrue(tanh.calculate(-9000.0) < -0.5 + EPS);
+	}
+
+	@Test
+	public void testDerivative() {
+		ActivationFunction tanh = new Tanh();
+		assertEquals(1.0, tanh.derivative(0.0), EPS);
 	}
 }
diff --git a/test/net/woodyfolsom/msproj/ann2/XORFilterTest.java b/test/net/woodyfolsom/msproj/ann2/XORFilterTest.java
new file mode 100644
index 0000000..b9d8887
--- /dev/null
+++ b/test/net/woodyfolsom/msproj/ann2/XORFilterTest.java
@@ -0,0 +1,136 @@
+package net.woodyfolsom.msproj.ann2;
+
+import static org.junit.Assert.assertTrue;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class XORFilterTest {
+	private static final String FILENAME = "xorPerceptron.net";
+
+	@AfterClass
+	public static void deleteNewNet() {
+		File file = new File(FILENAME);
+		if (file.exists()) {
+			file.delete();
+		}
+	}
+
+	@BeforeClass
+	public static void deleteSavedNet() {
+		File file = new File(FILENAME);
+		if (file.exists()) {
+			file.delete();
+		}
+	}
+
+	@Test
+	public void testLearn() throws IOException {
+		NeuralNetFilter nnLearner = new XORFilter(0.05, 0.0);
+
+		// create training set (logical XOR function)
+		int size = 1;
+		double[][] trainingInput = new double[4 * size][];
+		double[][] trainingOutput = new double[4 * size][];
+		for (int i = 0; i < size; i++) {
+			trainingInput[i * 4 + 0] = new double[] { 0, 0 };
+			trainingInput[i * 4 + 1] = new double[] { 0, 1 };
+			trainingInput[i * 4 + 2] = new double[] { 1, 0 };
+			trainingInput[i * 4 + 3] = new double[] { 1, 1 };
+			trainingOutput[i * 4 + 0] = new double[] { 0 };
+			trainingOutput[i * 4 + 1] = new double[] { 1 };
+			trainingOutput[i * 4 + 2] = new double[] { 1 };
+			trainingOutput[i * 4 + 3] = new double[] { 0 };
+		}
+
+		// create training data
+		List<NNDataPair> trainingSet = new ArrayList<NNDataPair>();
+		String[] inputNames = new String[] {"x", "y"};
+		String[] outputNames = new String[] {"XOR"};
+		for (int i = 0; i < 4 * size; i++) {
+			trainingSet.add(new NNDataPair(new NNData(inputNames, trainingInput[i]), new NNData(outputNames, trainingOutput[i])));
+		}
+
+		nnLearner.setMaxTrainingEpochs(20000);
+		nnLearner.learn(trainingSet);
+		System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");
+
+		double[][] validationSet = new double[4][2];
+
+		validationSet[0] = new double[] { 0, 0 };
+		validationSet[1] = new double[] { 0, 1 };
+		validationSet[2] = new double[] { 1, 0 };
+		validationSet[3] = new double[] { 1, 1 };
+
+		System.out.println("Output from eval set (learned network):");
+		testNetwork(nnLearner, validationSet, inputNames, outputNames);
+	}
+
+	@Test
+	public void testLearnSaveLoad() throws IOException {
+		NeuralNetFilter nnLearner = new XORFilter(0.5, 0.0);
+
+		// create training set (logical XOR function)
+		int size = 2;
+		double[][] trainingInput = new double[4 * size][];
+		double[][] trainingOutput = new double[4 * size][];
+		for (int i = 0; i < size; i++) {
+			trainingInput[i * 4 + 0] = new double[] { 0, 0 };
+			trainingInput[i * 4 + 1] = new double[] { 0, 1 };
+			trainingInput[i * 4 + 2] = new double[] { 1, 0 };
+			trainingInput[i * 4 + 3] = new double[] { 1, 1 };
+			trainingOutput[i * 4 + 0] = new double[] { 0 };
+			trainingOutput[i * 4 + 1] = new double[] { 1 };
+			trainingOutput[i * 4 + 2] = new double[] { 1 };
+			trainingOutput[i * 4 + 3] = new double[] { 0 };
+		}
+
+		// create training data
+		List<NNDataPair> trainingSet = new ArrayList<NNDataPair>();
+		String[] inputNames = new String[] {"x", "y"};
+		String[] outputNames = new String[] {"XOR"};
+		for (int i = 0; i < 4 * size; i++) {
+			trainingSet.add(new NNDataPair(new NNData(inputNames, trainingInput[i]), new NNData(outputNames, trainingOutput[i])));
+		}
+
+		nnLearner.setMaxTrainingEpochs(1);
+		nnLearner.learn(trainingSet);
+		System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");
+
+		double[][] validationSet = new double[4][2];
+
+		validationSet[0] = new double[] { 0, 0 };
+		validationSet[1] = new double[] { 0, 1 };
+		validationSet[2] = new double[] { 1, 0 };
+		validationSet[3] = new double[] { 1, 1 };
+
+		System.out.println("Output from eval set (learned network, pre-serialization):");
+		testNetwork(nnLearner, validationSet, inputNames, outputNames);
+
+		FileOutputStream fos = new FileOutputStream(FILENAME);
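+		// Round-trip the trained network: save it, reload it into the same
+		// filter, and confirm the deserialized network computes the same outputs.
+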
assertTrue(nnLearner.save(fos)); + fos.close(); + + FileInputStream fis = new FileInputStream(FILENAME); + assertTrue(nnLearner.load(fis)); + fis.close(); + + System.out.println("Output from eval set (learned network, post-serialization):"); + testNetwork(nnLearner, validationSet, inputNames, outputNames); + } + + private void testNetwork(NeuralNetFilter nnLearner, double[][] validationSet, String[] inputNames, String[] outputNames) { + for (int valIndex = 0; valIndex < validationSet.length; valIndex++) { + NNDataPair dp = new NNDataPair(new NNData(inputNames,validationSet[valIndex]), new NNData(outputNames,validationSet[valIndex])); + System.out.println(dp + " => " + nnLearner.compute(dp)); + } + } +} \ No newline at end of file
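
---

Usage sketch for the API added by this patch, as a hypothetical XORDemo class
(not part of the commit; it assumes the ann2 classes above are compiled on the
classpath, and printed values are illustrative since weights are randomly
initialized):

    package net.woodyfolsom.msproj.ann2;

    import java.util.ArrayList;
    import java.util.List;

    public class XORDemo {
        public static void main(String[] args) {
            // Learning rate 0.5, no momentum (momentum is not applied yet).
            NeuralNetFilter filter = new XORFilter(0.5, 0.0);

            String[] in = { "x", "y" };
            String[] out = { "XOR" };
            double[][] inputs = { { 0, 0 }, { 0, 1 }, { 1, 0 }, { 1, 1 } };
            double[] ideals = { 0, 1, 1, 0 };

            // Build the four XOR patterns as NNDataPairs.
            List<NNDataPair> trainingSet = new ArrayList<NNDataPair>();
            for (int i = 0; i < inputs.length; i++) {
                trainingSet.add(new NNDataPair(
                        new NNData(in, inputs[i]),
                        new NNData(out, new double[] { ideals[i] })));
            }

            // Train until MSSE <= 0.01 (XORFilter's target) or the epoch limit.
            filter.setMaxTrainingEpochs(20000);
            filter.learn(trainingSet);

            // Query the trained network; the ideal half of the pair only
            // supplies the output field names here.
            for (int i = 0; i < inputs.length; i++) {
                NNData result = filter.compute(new NNDataPair(
                        new NNData(in, inputs[i]),
                        new NNData(out, new double[1])));
                System.out.println(inputs[i][0] + " XOR " + inputs[i][1]
                        + " => " + result.getValues()[0]);
            }
        }
    }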