Functional MLP for XOR toy problem.
This commit is contained in:
84
src/net/woodyfolsom/msproj/ann2/AbstractNeuralNetFilter.java
Normal file
84
src/net/woodyfolsom/msproj/ann2/AbstractNeuralNetFilter.java
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann2;
|
||||||
|
|
||||||
|
import java.io.InputStream;
|
||||||
|
import java.io.OutputStream;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public abstract class AbstractNeuralNetFilter implements NeuralNetFilter {
|
||||||
|
private final FeedforwardNetwork neuralNetwork;
|
||||||
|
private final TrainingMethod trainingMethod;
|
||||||
|
|
||||||
|
private double maxError;
|
||||||
|
private int actualTrainingEpochs = 0;
|
||||||
|
private int maxTrainingEpochs;
|
||||||
|
|
||||||
|
AbstractNeuralNetFilter(FeedforwardNetwork neuralNetwork, TrainingMethod trainingMethod, int maxTrainingEpochs, double maxError) {
|
||||||
|
this.neuralNetwork = neuralNetwork;
|
||||||
|
this.trainingMethod = trainingMethod;
|
||||||
|
this.maxError = maxError;
|
||||||
|
this.maxTrainingEpochs = maxTrainingEpochs;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public NNData compute(NNDataPair input) {
|
||||||
|
return this.neuralNetwork.compute(input);
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getActualTrainingEpochs() {
|
||||||
|
return actualTrainingEpochs;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getInputSize() {
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getMaxTrainingEpochs() {
|
||||||
|
return maxTrainingEpochs;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected FeedforwardNetwork getNeuralNetwork() {
|
||||||
|
return neuralNetwork;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void learn(List<NNDataPair> trainingSet) {
|
||||||
|
actualTrainingEpochs = 0;
|
||||||
|
double error;
|
||||||
|
neuralNetwork.initWeights();
|
||||||
|
|
||||||
|
error = trainingMethod.computeError(neuralNetwork,trainingSet);
|
||||||
|
|
||||||
|
if (error <= maxError) {
|
||||||
|
System.out.println("Initial error: " + error);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
do {
|
||||||
|
trainingMethod.iterate(neuralNetwork,trainingSet);
|
||||||
|
error = trainingMethod.computeError(neuralNetwork,trainingSet);
|
||||||
|
System.out.println("Epoch #" + actualTrainingEpochs + " Error:"
|
||||||
|
+ error);
|
||||||
|
actualTrainingEpochs++;
|
||||||
|
System.out.println("MSSE after epoch " + actualTrainingEpochs + ": " + error);
|
||||||
|
} while (error > maxError && actualTrainingEpochs < maxTrainingEpochs);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean load(InputStream input) {
|
||||||
|
return neuralNetwork.load(input);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean save(OutputStream output) {
|
||||||
|
return neuralNetwork.save(output);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setMaxError(double maxError) {
|
||||||
|
this.maxError = maxError;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setMaxTrainingEpochs(int max) {
|
||||||
|
this.maxTrainingEpochs = max;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
package net.woodyfolsom.msproj.ann2;
|
|
||||||
|
|
||||||
public interface ActivationFunction {
|
|
||||||
double calculate(double arg);
|
|
||||||
}
|
|
||||||
120
src/net/woodyfolsom/msproj/ann2/BackPropagation.java
Normal file
120
src/net/woodyfolsom/msproj/ann2/BackPropagation.java
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann2;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import net.woodyfolsom.msproj.ann2.math.ErrorFunction;
|
||||||
|
import net.woodyfolsom.msproj.ann2.math.MSSE;
|
||||||
|
|
||||||
|
public class BackPropagation implements TrainingMethod {
|
||||||
|
private final ErrorFunction errorFunction;
|
||||||
|
private final double learningRate;
|
||||||
|
private final double momentum;
|
||||||
|
|
||||||
|
public BackPropagation(double learningRate, double momentum) {
|
||||||
|
this.errorFunction = MSSE.function;
|
||||||
|
this.learningRate = learningRate;
|
||||||
|
this.momentum = momentum;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void iterate(FeedforwardNetwork neuralNetwork,
|
||||||
|
List<NNDataPair> trainingSet) {
|
||||||
|
System.out.println("Learningrate: " + learningRate);
|
||||||
|
System.out.println("Momentum: " + momentum);
|
||||||
|
|
||||||
|
//zeroErrors(neuralNetwork);
|
||||||
|
|
||||||
|
for (NNDataPair trainingPair : trainingSet) {
|
||||||
|
zeroErrors(neuralNetwork);
|
||||||
|
|
||||||
|
System.out.println("Training with: " + trainingPair.getInput());
|
||||||
|
|
||||||
|
NNData ideal = trainingPair.getIdeal();
|
||||||
|
NNData actual = neuralNetwork.compute(trainingPair);
|
||||||
|
|
||||||
|
System.out.println("Updating weights. Ideal Output: " + ideal);
|
||||||
|
System.out.println("Actual Output: " + actual);
|
||||||
|
|
||||||
|
updateErrors(neuralNetwork, ideal);
|
||||||
|
|
||||||
|
updateWeights(neuralNetwork);
|
||||||
|
}
|
||||||
|
|
||||||
|
//updateWeights(neuralNetwork);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double computeError(FeedforwardNetwork neuralNetwork,
|
||||||
|
List<NNDataPair> trainingSet) {
|
||||||
|
int numDataPairs = trainingSet.size();
|
||||||
|
int outputSize = neuralNetwork.getOutput().length;
|
||||||
|
int totalOutputSize = outputSize * numDataPairs;
|
||||||
|
|
||||||
|
double[] actuals = new double[totalOutputSize];
|
||||||
|
double[] ideals = new double[totalOutputSize];
|
||||||
|
for (int dataPair = 0; dataPair < numDataPairs; dataPair++) {
|
||||||
|
NNDataPair nnDataPair = trainingSet.get(dataPair);
|
||||||
|
double[] actual = neuralNetwork.compute(nnDataPair.getInput()
|
||||||
|
.getValues());
|
||||||
|
double[] ideal = nnDataPair.getIdeal().getValues();
|
||||||
|
int offset = dataPair * outputSize;
|
||||||
|
|
||||||
|
System.arraycopy(actual, 0, actuals, offset, outputSize);
|
||||||
|
System.arraycopy(ideal, 0, ideals, offset, outputSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
double MSSE = errorFunction.compute(ideals, actuals);
|
||||||
|
return MSSE;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void updateErrors(FeedforwardNetwork neuralNetwork, NNData ideal) {
|
||||||
|
Neuron[] outputNeurons = neuralNetwork.getOutputNeurons();
|
||||||
|
double[] idealValues = ideal.getValues();
|
||||||
|
|
||||||
|
for (int i = 0; i < idealValues.length; i++) {
|
||||||
|
double output = outputNeurons[i].getOutput();
|
||||||
|
double derivative = outputNeurons[i].getActivationFunction()
|
||||||
|
.derivative(output);
|
||||||
|
outputNeurons[i].setError(outputNeurons[i].getError() + derivative * (idealValues[i] - output));
|
||||||
|
}
|
||||||
|
// walking down the list of Neurons in reverse order, propagate the
|
||||||
|
// error
|
||||||
|
Neuron[] neurons = neuralNetwork.getNeurons();
|
||||||
|
|
||||||
|
for (int n = neurons.length - 1; n >= 0; n--) {
|
||||||
|
|
||||||
|
Neuron neuron = neurons[n];
|
||||||
|
double error = neuron.getError();
|
||||||
|
|
||||||
|
Connection[] connectionsFromN = neuralNetwork
|
||||||
|
.getConnectionsFrom(neuron.getId());
|
||||||
|
if (connectionsFromN.length > 0) {
|
||||||
|
|
||||||
|
double derivative = neuron.getActivationFunction().derivative(
|
||||||
|
neuron.getOutput());
|
||||||
|
for (Connection connection : connectionsFromN) {
|
||||||
|
error += derivative * connection.getWeight() * neuralNetwork.getNeuron(connection.getDest()).getError();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
neuron.setError(error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void updateWeights(FeedforwardNetwork neuralNetwork) {
|
||||||
|
for (Connection connection : neuralNetwork.getConnections()) {
|
||||||
|
Neuron srcNeuron = neuralNetwork.getNeuron(connection.getSrc());
|
||||||
|
Neuron destNeuron = neuralNetwork.getNeuron(connection.getDest());
|
||||||
|
double delta = learningRate * srcNeuron.getOutput() * destNeuron.getError();
|
||||||
|
//TODO allow for momentum
|
||||||
|
//double lastDelta = connection.getLastDelta();
|
||||||
|
connection.addDelta(delta);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void zeroErrors(FeedforwardNetwork neuralNetwork) {
|
||||||
|
// Set output errors relative to ideals, all other errors to 0.
|
||||||
|
for (Neuron neuron : neuralNetwork.getNeurons()) {
|
||||||
|
neuron.setError(0.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
94
src/net/woodyfolsom/msproj/ann2/Connection.java
Normal file
94
src/net/woodyfolsom/msproj/ann2/Connection.java
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann2;
|
||||||
|
|
||||||
|
import javax.xml.bind.annotation.XmlAttribute;
|
||||||
|
import javax.xml.bind.annotation.XmlTransient;
|
||||||
|
|
||||||
|
public class Connection {
|
||||||
|
private int src;
|
||||||
|
private int dest;
|
||||||
|
private double weight;
|
||||||
|
private transient double lastDelta = 0.0;
|
||||||
|
|
||||||
|
public Connection() {
|
||||||
|
//no-arg constructor for JAXB
|
||||||
|
}
|
||||||
|
|
||||||
|
public Connection(int src, int dest, double weight) {
|
||||||
|
this.src = src;
|
||||||
|
this.dest = dest;
|
||||||
|
this.weight = weight;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void addDelta(double delta) {
|
||||||
|
this.weight += delta;
|
||||||
|
this.lastDelta = delta;
|
||||||
|
}
|
||||||
|
|
||||||
|
@XmlAttribute
|
||||||
|
public int getDest() {
|
||||||
|
return dest;
|
||||||
|
}
|
||||||
|
|
||||||
|
@XmlTransient
|
||||||
|
public double getLastDelta() {
|
||||||
|
return lastDelta;
|
||||||
|
}
|
||||||
|
|
||||||
|
@XmlAttribute
|
||||||
|
public int getSrc() {
|
||||||
|
return src;
|
||||||
|
}
|
||||||
|
|
||||||
|
@XmlAttribute
|
||||||
|
public double getWeight() {
|
||||||
|
return weight;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setDest(int dest) {
|
||||||
|
this.dest = dest;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setSrc(int src) {
|
||||||
|
this.src = src;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setWeight(double weight) {
|
||||||
|
this.weight = weight;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
final int prime = 31;
|
||||||
|
int result = 1;
|
||||||
|
result = prime * result + dest;
|
||||||
|
result = prime * result + src;
|
||||||
|
long temp;
|
||||||
|
temp = Double.doubleToLongBits(weight);
|
||||||
|
result = prime * result + (int) (temp ^ (temp >>> 32));
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object obj) {
|
||||||
|
if (this == obj)
|
||||||
|
return true;
|
||||||
|
if (obj == null)
|
||||||
|
return false;
|
||||||
|
if (getClass() != obj.getClass())
|
||||||
|
return false;
|
||||||
|
Connection other = (Connection) obj;
|
||||||
|
if (dest != other.dest)
|
||||||
|
return false;
|
||||||
|
if (src != other.src)
|
||||||
|
return false;
|
||||||
|
if (Double.doubleToLongBits(weight) != Double
|
||||||
|
.doubleToLongBits(other.weight))
|
||||||
|
return false;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return "Connection(" + src + ", " + dest +"), weight: " + weight;
|
||||||
|
}
|
||||||
|
}
|
||||||
291
src/net/woodyfolsom/msproj/ann2/FeedforwardNetwork.java
Normal file
291
src/net/woodyfolsom/msproj/ann2/FeedforwardNetwork.java
Normal file
@@ -0,0 +1,291 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann2;
|
||||||
|
|
||||||
|
import java.io.InputStream;
|
||||||
|
import java.io.OutputStream;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import javax.xml.bind.annotation.XmlAttribute;
|
||||||
|
import javax.xml.bind.annotation.XmlElement;
|
||||||
|
|
||||||
|
import net.woodyfolsom.msproj.ann2.math.ActivationFunction;
|
||||||
|
import net.woodyfolsom.msproj.ann2.math.Linear;
|
||||||
|
import net.woodyfolsom.msproj.ann2.math.Sigmoid;
|
||||||
|
|
||||||
|
public abstract class FeedforwardNetwork {
|
||||||
|
private ActivationFunction activationFunction;
|
||||||
|
private boolean biased;
|
||||||
|
private List<Connection> connections;
|
||||||
|
private List<Neuron> neurons;
|
||||||
|
private String name;
|
||||||
|
|
||||||
|
private transient int biasNeuronId;
|
||||||
|
private transient Map<Integer, List<Connection>> connectionsFrom;
|
||||||
|
private transient Map<Integer, List<Connection>> connectionsTo;
|
||||||
|
|
||||||
|
public FeedforwardNetwork() {
|
||||||
|
this(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
public FeedforwardNetwork(boolean biased) {
|
||||||
|
//No-arg constructor for JAXB
|
||||||
|
this.activationFunction = Sigmoid.function;
|
||||||
|
this.connections = new ArrayList<Connection>();
|
||||||
|
this.connectionsFrom = new HashMap<Integer,List<Connection>>();
|
||||||
|
this.connectionsTo = new HashMap<Integer,List<Connection>>();
|
||||||
|
this.neurons = new ArrayList<Neuron>();
|
||||||
|
this.name = "UNDEFINED";
|
||||||
|
this.biasNeuronId = -1;
|
||||||
|
setBiased(biased);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void addConnection(Connection connection) {
|
||||||
|
connections.add(connection);
|
||||||
|
|
||||||
|
int src = connection.getSrc();
|
||||||
|
int dest = connection.getDest();
|
||||||
|
|
||||||
|
if (!connectionsFrom.containsKey(src)) {
|
||||||
|
connectionsFrom.put(src, new ArrayList<Connection>());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!connectionsTo.containsKey(dest)) {
|
||||||
|
connectionsTo.put(dest, new ArrayList<Connection>());
|
||||||
|
}
|
||||||
|
|
||||||
|
connectionsFrom.get(src).add(connection);
|
||||||
|
connectionsTo.get(dest).add(connection);
|
||||||
|
}
|
||||||
|
|
||||||
|
public NNData compute(NNDataPair nnDataPair) {
|
||||||
|
NNData actual = new NNData(nnDataPair.getIdeal().getFields(),
|
||||||
|
compute(nnDataPair.getInput().getValues()));
|
||||||
|
return actual;
|
||||||
|
}
|
||||||
|
|
||||||
|
public double[] compute(double[] input) {
|
||||||
|
zeroInputs();
|
||||||
|
setInput(input);
|
||||||
|
feedforward();
|
||||||
|
return getOutput();
|
||||||
|
}
|
||||||
|
|
||||||
|
void createBiasConnection(int neuronId, double weight) {
|
||||||
|
if (!biased) {
|
||||||
|
throw new UnsupportedOperationException("Not a biased network");
|
||||||
|
}
|
||||||
|
addConnection(new Connection(biasNeuronId, neuronId, weight));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Adds a new neuron with a unique id to this FeedforwardNetwork.
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
Neuron createNeuron(boolean input) {
|
||||||
|
Neuron neuron;
|
||||||
|
if (input) {
|
||||||
|
neuron = new Neuron(Linear.function, neurons.size());
|
||||||
|
} else {
|
||||||
|
neuron = new Neuron(activationFunction, neurons.size());
|
||||||
|
}
|
||||||
|
neurons.add(neuron);
|
||||||
|
return neuron;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object obj) {
|
||||||
|
if (this == obj)
|
||||||
|
return true;
|
||||||
|
if (obj == null)
|
||||||
|
return false;
|
||||||
|
if (getClass() != obj.getClass())
|
||||||
|
return false;
|
||||||
|
FeedforwardNetwork other = (FeedforwardNetwork) obj;
|
||||||
|
if (activationFunction == null) {
|
||||||
|
if (other.activationFunction != null)
|
||||||
|
return false;
|
||||||
|
} else if (!activationFunction.equals(other.activationFunction))
|
||||||
|
return false;
|
||||||
|
if (connections == null) {
|
||||||
|
if (other.connections != null)
|
||||||
|
return false;
|
||||||
|
} else if (!connections.equals(other.connections))
|
||||||
|
return false;
|
||||||
|
if (name == null) {
|
||||||
|
if (other.name != null)
|
||||||
|
return false;
|
||||||
|
} else if (!name.equals(other.name))
|
||||||
|
return false;
|
||||||
|
if (neurons == null) {
|
||||||
|
if (other.neurons != null)
|
||||||
|
return false;
|
||||||
|
} else if (!neurons.equals(other.neurons))
|
||||||
|
return false;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void feedforward() {
|
||||||
|
for (int i = 0; i < neurons.size(); i++) {
|
||||||
|
Neuron src = neurons.get(i);
|
||||||
|
for (Connection connection : getConnectionsFrom(src.getId())) {
|
||||||
|
Neuron dest = getNeuron(connection.getDest());
|
||||||
|
dest.addInput(src.getOutput() * connection.getWeight());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@XmlElement(type = Sigmoid.class)
|
||||||
|
public ActivationFunction getActivationFunction() {
|
||||||
|
return activationFunction;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected abstract double[] getOutput();
|
||||||
|
protected abstract Neuron[] getOutputNeurons();
|
||||||
|
|
||||||
|
@XmlAttribute
|
||||||
|
public String getName() {
|
||||||
|
return name;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected Neuron getNeuron(int id) {
|
||||||
|
return neurons.get(id);
|
||||||
|
}
|
||||||
|
|
||||||
|
@XmlElement
|
||||||
|
protected Connection[] getConnections() {
|
||||||
|
return connections.toArray(new Connection[connections.size()]);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected Connection[] getConnectionsFrom(int neuronId) {
|
||||||
|
List<Connection> connList = connectionsFrom.get(neuronId);
|
||||||
|
|
||||||
|
if (connList == null) {
|
||||||
|
return new Connection[0];
|
||||||
|
} else {
|
||||||
|
return connList.toArray(new Connection[connList.size()]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
protected Connection[] getConnectionsTo(int neuronId) {
|
||||||
|
List<Connection> connList = connectionsTo.get(neuronId);
|
||||||
|
|
||||||
|
if (connList == null) {
|
||||||
|
return new Connection[0];
|
||||||
|
} else {
|
||||||
|
return connList.toArray(new Connection[connList.size()]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@XmlAttribute
|
||||||
|
public boolean isBiased() {
|
||||||
|
return biased;
|
||||||
|
}
|
||||||
|
|
||||||
|
@XmlElement
|
||||||
|
protected Neuron[] getNeurons() {
|
||||||
|
return neurons.toArray(new Neuron[neurons.size()]);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
final int prime = 31;
|
||||||
|
int result = 1;
|
||||||
|
result = prime
|
||||||
|
* result
|
||||||
|
+ ((activationFunction == null) ? 0 : activationFunction
|
||||||
|
.hashCode());
|
||||||
|
result = prime * result
|
||||||
|
+ ((connections == null) ? 0 : connections.hashCode());
|
||||||
|
result = prime * result + ((name == null) ? 0 : name.hashCode());
|
||||||
|
result = prime * result + ((neurons == null) ? 0 : neurons.hashCode());
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void initWeights() {
|
||||||
|
for (Connection connection : connections) {
|
||||||
|
connection.setWeight(1.0-Math.random()*2.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public abstract boolean load(InputStream is);
|
||||||
|
|
||||||
|
public abstract boolean save(OutputStream os);
|
||||||
|
|
||||||
|
public void setActivationFunction(ActivationFunction activationFunction) {
|
||||||
|
this.activationFunction = activationFunction;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setBiased(boolean biased) {
|
||||||
|
|
||||||
|
if (this.biased == biased) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
this.biased = biased;
|
||||||
|
|
||||||
|
if (biased) {
|
||||||
|
Neuron biasNeuron = createNeuron(true);
|
||||||
|
biasNeuron.setInput(1.0);
|
||||||
|
biasNeuronId = biasNeuron.getId();
|
||||||
|
} else {
|
||||||
|
//This is an inefficient but concise way to remove all connections involving the bias Neuron
|
||||||
|
//from the global
|
||||||
|
|
||||||
|
//Remove all connections from biasId from this index
|
||||||
|
List<Connection> connectionsFromBias = connectionsFrom.remove(biasNeuronId);
|
||||||
|
|
||||||
|
//Remove all connections to all nodes from biasId from this index
|
||||||
|
for (Map.Entry<Integer,List<Connection>> mapEntry : connectionsTo.entrySet()) {
|
||||||
|
mapEntry.getValue().removeAll(connectionsFromBias);
|
||||||
|
}
|
||||||
|
|
||||||
|
//Finally, remove from the (serialized) list of non-indexed connections
|
||||||
|
connections.remove(connectionsFromBias);
|
||||||
|
|
||||||
|
biasNeuronId = -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void setConnections(Connection[] connections) {
|
||||||
|
this.connections.clear();
|
||||||
|
this.connectionsFrom.clear();
|
||||||
|
this.connectionsTo.clear();
|
||||||
|
for (Connection connection : connections) {
|
||||||
|
addConnection(connection);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
protected abstract void setInput(double[] input);
|
||||||
|
|
||||||
|
public void setName(String name) {
|
||||||
|
this.name = name;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void setNeurons(Neuron[] neurons) {
|
||||||
|
this.neurons.clear();
|
||||||
|
for (Neuron neuron : neurons) {
|
||||||
|
this.neurons.add(neuron);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setWeights(double[] weights) {
|
||||||
|
if (weights.length != connections.size()) {
|
||||||
|
throw new IllegalArgumentException("# of weights must == # of connections");
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < connections.size(); i++) {
|
||||||
|
connections.get(i).setWeight(weights[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void zeroInputs() {
|
||||||
|
for (Neuron neuron : neurons) {
|
||||||
|
if (neuron.getId() != biasNeuronId){
|
||||||
|
neuron.setInput(0.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -2,29 +2,48 @@ package net.woodyfolsom.msproj.ann2;
|
|||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
|
||||||
|
import javax.xml.bind.annotation.XmlElement;
|
||||||
|
|
||||||
public class Layer {
|
public class Layer {
|
||||||
private Neuron[] neurons;
|
private int[] neuronIds;
|
||||||
|
|
||||||
public Layer() {
|
public Layer() {
|
||||||
//default constructor for JAXB
|
neuronIds = new int[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
public Layer(int numNeurons, int numWeights, ActivationFunction activationFunction) {
|
public Layer(int numNeurons) {
|
||||||
neurons = new Neuron[numNeurons];
|
neuronIds = new int[numNeurons];
|
||||||
for (int neuronIndex = 0; neuronIndex < numNeurons; neuronIndex++) {
|
|
||||||
neurons[neuronIndex] = new Neuron(activationFunction, numWeights);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public int size() {
|
public int size() {
|
||||||
return neurons.length;
|
return neuronIds.length;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getNeuronId(int index) {
|
||||||
|
return neuronIds[index];
|
||||||
|
}
|
||||||
|
|
||||||
|
@XmlElement
|
||||||
|
public int[] getNeuronIds() {
|
||||||
|
int[] safeCopy = new int[neuronIds.length];
|
||||||
|
System.arraycopy(neuronIds, 0, safeCopy, 0, neuronIds.length);
|
||||||
|
return safeCopy;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setNeuronId(int index, int id) {
|
||||||
|
neuronIds[index] = id;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setNeuronIds(int[] neuronIds) {
|
||||||
|
this.neuronIds = new int[neuronIds.length];
|
||||||
|
System.arraycopy(neuronIds, 0, this.neuronIds, 0, neuronIds.length);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
final int prime = 31;
|
final int prime = 31;
|
||||||
int result = 1;
|
int result = 1;
|
||||||
result = prime * result + Arrays.hashCode(neurons);
|
result = prime * result + Arrays.hashCode(neuronIds);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -37,17 +56,9 @@ public class Layer {
|
|||||||
if (getClass() != obj.getClass())
|
if (getClass() != obj.getClass())
|
||||||
return false;
|
return false;
|
||||||
Layer other = (Layer) obj;
|
Layer other = (Layer) obj;
|
||||||
if (!Arrays.equals(neurons, other.neurons))
|
if (!Arrays.equals(neuronIds, other.neuronIds))
|
||||||
return false;
|
return false;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Neuron[] getNeurons() {
|
|
||||||
return neurons;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setNeurons(Neuron[] neurons) {
|
|
||||||
this.neurons = neurons;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -2,19 +2,16 @@ package net.woodyfolsom.msproj.ann2;
|
|||||||
|
|
||||||
import java.io.InputStream;
|
import java.io.InputStream;
|
||||||
import java.io.OutputStream;
|
import java.io.OutputStream;
|
||||||
import java.util.Arrays;
|
|
||||||
|
|
||||||
import javax.xml.bind.JAXBContext;
|
import javax.xml.bind.JAXBContext;
|
||||||
import javax.xml.bind.JAXBException;
|
import javax.xml.bind.JAXBException;
|
||||||
import javax.xml.bind.Marshaller;
|
import javax.xml.bind.Marshaller;
|
||||||
import javax.xml.bind.Unmarshaller;
|
import javax.xml.bind.Unmarshaller;
|
||||||
import javax.xml.bind.annotation.XmlAttribute;
|
|
||||||
import javax.xml.bind.annotation.XmlElement;
|
import javax.xml.bind.annotation.XmlElement;
|
||||||
import javax.xml.bind.annotation.XmlRootElement;
|
import javax.xml.bind.annotation.XmlRootElement;
|
||||||
|
|
||||||
@XmlRootElement
|
@XmlRootElement
|
||||||
public class MultiLayerPerceptron extends NeuralNetwork {
|
public class MultiLayerPerceptron extends FeedforwardNetwork {
|
||||||
private ActivationFunction activationFunction;
|
|
||||||
private boolean biased;
|
private boolean biased;
|
||||||
private Layer[] layers;
|
private Layer[] layers;
|
||||||
|
|
||||||
@@ -23,17 +20,15 @@ public class MultiLayerPerceptron extends NeuralNetwork {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public MultiLayerPerceptron(boolean biased, int... layerSizes) {
|
public MultiLayerPerceptron(boolean biased, int... layerSizes) {
|
||||||
|
super(biased);
|
||||||
|
|
||||||
int numLayers = layerSizes.length;
|
int numLayers = layerSizes.length;
|
||||||
|
|
||||||
if (numLayers < 2) {
|
if (numLayers < 2) {
|
||||||
throw new IllegalArgumentException("# of layers must be >= 2");
|
throw new IllegalArgumentException("# of layers must be >= 2");
|
||||||
}
|
}
|
||||||
|
|
||||||
this.activationFunction = Sigmoid.function;
|
|
||||||
this.biased = biased;
|
|
||||||
this.layers = new Layer[numLayers];
|
|
||||||
|
|
||||||
int numWeights;
|
this.layers = new Layer[numLayers];
|
||||||
|
|
||||||
for (int layerIndex = 0; layerIndex < numLayers; layerIndex++) {
|
for (int layerIndex = 0; layerIndex < numLayers; layerIndex++) {
|
||||||
int layerSize = layerSizes[layerIndex];
|
int layerSize = layerSizes[layerIndex];
|
||||||
@@ -41,66 +36,71 @@ public class MultiLayerPerceptron extends NeuralNetwork {
|
|||||||
if (layerSize < 1) {
|
if (layerSize < 1) {
|
||||||
throw new IllegalArgumentException("Layer size must be >= 1");
|
throw new IllegalArgumentException("Layer size must be >= 1");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (layerIndex == 0) {
|
|
||||||
numWeights = 0;
|
|
||||||
if (biased) {
|
|
||||||
layerSize++;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
numWeights = layers[layerIndex - 1].size();
|
|
||||||
}
|
|
||||||
|
|
||||||
layers[layerIndex] = new Layer(layerSize, numWeights,
|
Layer newLayer = createNewLayer(layerIndex, layerSize);
|
||||||
activationFunction);
|
|
||||||
|
if (layerIndex > 0) {
|
||||||
|
Layer prevLayer = layers[layerIndex - 1];
|
||||||
|
for (int j = 0; j < newLayer.size(); j++) {
|
||||||
|
if (biased) {
|
||||||
|
createBiasConnection(newLayer.getNeuronId(j),0.0);
|
||||||
|
}
|
||||||
|
for (int i = 0; i < prevLayer.size(); i++) {
|
||||||
|
addConnection(new Connection(prevLayer.getNeuronId(i),
|
||||||
|
newLayer.getNeuronId(j), 0.0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@XmlElement(type=Sigmoid.class)
|
private Layer createNewLayer(int layerIndex, int layerSize) {
|
||||||
public ActivationFunction getActivationFunction() {
|
Layer layer = new Layer(layerSize);
|
||||||
return activationFunction;
|
layers[layerIndex] = layer;
|
||||||
|
for (int n = 0; n < layerSize; n++) {
|
||||||
|
Neuron neuron = createNeuron(layerIndex == 0);
|
||||||
|
layer.setNeuronId(n, neuron.getId());
|
||||||
|
}
|
||||||
|
return layer;
|
||||||
}
|
}
|
||||||
|
|
||||||
@XmlElement
|
@XmlElement
|
||||||
public Layer[] getLayers() {
|
public Layer[] getLayers() {
|
||||||
return layers;
|
return layers;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected double[] getOutput() {
|
protected double[] getOutput() {
|
||||||
// TODO Auto-generated method stub
|
Layer outputLayer = layers[layers.length - 1];
|
||||||
return null;
|
double output[] = new double[outputLayer.size()];
|
||||||
|
for (int n = 0; n < outputLayer.size(); n++) {
|
||||||
|
output[n] = getNeuron(outputLayer.getNeuronId(n)).getOutput();
|
||||||
|
}
|
||||||
|
return output;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Neuron[] getOutputNeurons() {
|
||||||
|
Layer outputLayer = layers[layers.length - 1];
|
||||||
|
Neuron[] outputNeurons = new Neuron[outputLayer.size()];
|
||||||
|
for (int i = 0; i < outputLayer.size(); i++) {
|
||||||
|
outputNeurons[i] = getNeuron(outputLayer.getNeuronId(i));
|
||||||
|
}
|
||||||
|
return outputNeurons;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
protected Neuron[] getNeurons() {
|
|
||||||
// TODO Auto-generated method stub
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@XmlAttribute
|
|
||||||
public boolean isBiased() {
|
|
||||||
return biased;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setActivationFunction(ActivationFunction activationFunction) {
|
|
||||||
this.activationFunction = activationFunction;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void setInput(double[] input) {
|
protected void setInput(double[] input) {
|
||||||
// TODO Auto-generated method stub
|
Layer inputLayer = layers[0];
|
||||||
|
for (int n = 0; n < inputLayer.size(); n++) {
|
||||||
|
getNeuron(inputLayer.getNeuronId(n)).setInput(input[n]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setBiased(boolean biased) {
|
|
||||||
this.biased = biased;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setLayers(Layer[] layers) {
|
public void setLayers(Layer[] layers) {
|
||||||
this.layers = layers;
|
this.layers = layers;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean load(InputStream is) {
|
public boolean load(InputStream is) {
|
||||||
try {
|
try {
|
||||||
@@ -111,7 +111,9 @@ public class MultiLayerPerceptron extends NeuralNetwork {
|
|||||||
Unmarshaller u = jc.createUnmarshaller();
|
Unmarshaller u = jc.createUnmarshaller();
|
||||||
MultiLayerPerceptron mlp = (MultiLayerPerceptron) u.unmarshal(is);
|
MultiLayerPerceptron mlp = (MultiLayerPerceptron) u.unmarshal(is);
|
||||||
|
|
||||||
this.activationFunction = mlp.activationFunction;
|
super.setActivationFunction(mlp.getActivationFunction());
|
||||||
|
super.setConnections(mlp.getConnections());
|
||||||
|
super.setNeurons(mlp.getNeurons());
|
||||||
this.biased = mlp.biased;
|
this.biased = mlp.biased;
|
||||||
this.layers = mlp.layers;
|
this.layers = mlp.layers;
|
||||||
|
|
||||||
@@ -138,38 +140,4 @@ public class MultiLayerPerceptron extends NeuralNetwork {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public int hashCode() {
|
|
||||||
final int prime = 31;
|
|
||||||
int result = 1;
|
|
||||||
result = prime
|
|
||||||
* result
|
|
||||||
+ ((activationFunction == null) ? 0 : activationFunction
|
|
||||||
.hashCode());
|
|
||||||
result = prime * result + (biased ? 1231 : 1237);
|
|
||||||
result = prime * result + Arrays.hashCode(layers);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean equals(Object obj) {
|
|
||||||
if (this == obj)
|
|
||||||
return true;
|
|
||||||
if (obj == null)
|
|
||||||
return false;
|
|
||||||
if (getClass() != obj.getClass())
|
|
||||||
return false;
|
|
||||||
MultiLayerPerceptron other = (MultiLayerPerceptron) obj;
|
|
||||||
if (activationFunction == null) {
|
|
||||||
if (other.activationFunction != null)
|
|
||||||
return false;
|
|
||||||
} else if (!activationFunction.equals(other.activationFunction))
|
|
||||||
return false;
|
|
||||||
if (biased != other.biased)
|
|
||||||
return false;
|
|
||||||
if (!Arrays.equals(layers, other.layers))
|
|
||||||
return false;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
@@ -9,6 +9,15 @@ public class NNData {
|
|||||||
this.values = values;
|
this.values = values;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public NNData(NNData that) {
|
||||||
|
this.fields = that.fields;
|
||||||
|
this.values = that.values;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String[] getFields() {
|
||||||
|
return fields;
|
||||||
|
}
|
||||||
|
|
||||||
public double[] getValues() {
|
public double[] getValues() {
|
||||||
return values;
|
return values;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,16 +1,16 @@
|
|||||||
package net.woodyfolsom.msproj.ann2;
|
package net.woodyfolsom.msproj.ann2;
|
||||||
|
|
||||||
public class NNDataPair {
|
public class NNDataPair {
|
||||||
private final NNData actual;
|
private final NNData input;
|
||||||
private final NNData ideal;
|
private final NNData ideal;
|
||||||
|
|
||||||
public NNDataPair(NNData actual, NNData ideal) {
|
public NNDataPair(NNData actual, NNData ideal) {
|
||||||
this.actual = actual;
|
this.input = actual;
|
||||||
this.ideal = ideal;
|
this.ideal = ideal;
|
||||||
}
|
}
|
||||||
|
|
||||||
public NNData getActual() {
|
public NNData getInput() {
|
||||||
return actual;
|
return input;
|
||||||
}
|
}
|
||||||
|
|
||||||
public NNData getIdeal() {
|
public NNData getIdeal() {
|
||||||
|
|||||||
25
src/net/woodyfolsom/msproj/ann2/NeuralNetFilter.java
Normal file
25
src/net/woodyfolsom/msproj/ann2/NeuralNetFilter.java
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann2;
|
||||||
|
|
||||||
|
import java.io.InputStream;
|
||||||
|
import java.io.OutputStream;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public interface NeuralNetFilter {
|
||||||
|
int getActualTrainingEpochs();
|
||||||
|
|
||||||
|
int getInputSize();
|
||||||
|
|
||||||
|
int getMaxTrainingEpochs();
|
||||||
|
|
||||||
|
int getOutputSize();
|
||||||
|
|
||||||
|
boolean load(InputStream input);
|
||||||
|
|
||||||
|
boolean save(OutputStream output);
|
||||||
|
|
||||||
|
void setMaxTrainingEpochs(int max);
|
||||||
|
|
||||||
|
NNData compute(NNDataPair input);
|
||||||
|
|
||||||
|
void learn(List<NNDataPair> trainingSet);
|
||||||
|
}
|
||||||
@@ -1,53 +0,0 @@
|
|||||||
package net.woodyfolsom.msproj.ann2;
|
|
||||||
|
|
||||||
import java.io.InputStream;
|
|
||||||
import java.io.OutputStream;
|
|
||||||
|
|
||||||
import javax.xml.bind.JAXBException;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A NeuralNetwork is simply an ordered set of Neurons.
|
|
||||||
*
|
|
||||||
* Functions which rely on knowledge of input neurons, output neurons and layers
|
|
||||||
* are delegated to MultiLayerPerception.
|
|
||||||
*
|
|
||||||
* The primary function implemented in this abstract class is feedfoward.
|
|
||||||
* This function depends only on getNeurons() returning Neurons in feedforward order
|
|
||||||
* and the returned Neurons must have the correct number of weights for the NeuralNetwork
|
|
||||||
* configuration.
|
|
||||||
*
|
|
||||||
* @author Woody
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
public abstract class NeuralNetwork {
|
|
||||||
public NeuralNetwork() {
|
|
||||||
}
|
|
||||||
|
|
||||||
public double[] calculate(double[] input) {
|
|
||||||
zeroInputs();
|
|
||||||
setInput(input);
|
|
||||||
feedforward();
|
|
||||||
return getOutput();
|
|
||||||
}
|
|
||||||
|
|
||||||
protected void feedforward() {
|
|
||||||
Neuron[] neurons = getNeurons();
|
|
||||||
}
|
|
||||||
|
|
||||||
protected abstract double[] getOutput();
|
|
||||||
|
|
||||||
protected abstract Neuron[] getNeurons();
|
|
||||||
|
|
||||||
public abstract boolean load(InputStream is);
|
|
||||||
public abstract boolean save(OutputStream os);
|
|
||||||
|
|
||||||
protected abstract void setInput(double[] input);
|
|
||||||
|
|
||||||
protected void zeroInputs() {
|
|
||||||
for (Neuron neuron : getNeurons()) {
|
|
||||||
neuron.setInput(0.0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -1,24 +1,29 @@
|
|||||||
package net.woodyfolsom.msproj.ann2;
|
package net.woodyfolsom.msproj.ann2;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import javax.xml.bind.annotation.XmlAttribute;
|
||||||
|
|
||||||
import javax.xml.bind.Unmarshaller;
|
|
||||||
import javax.xml.bind.annotation.XmlElement;
|
import javax.xml.bind.annotation.XmlElement;
|
||||||
import javax.xml.bind.annotation.XmlTransient;
|
import javax.xml.bind.annotation.XmlTransient;
|
||||||
|
|
||||||
|
import net.woodyfolsom.msproj.ann2.math.ActivationFunction;
|
||||||
|
import net.woodyfolsom.msproj.ann2.math.Sigmoid;
|
||||||
|
|
||||||
public class Neuron {
|
public class Neuron {
|
||||||
private ActivationFunction activationFunction;
|
private ActivationFunction activationFunction;
|
||||||
private double[] weights;
|
private int id;
|
||||||
|
|
||||||
private transient double input = 0.0;
|
private transient double input = 0.0;
|
||||||
|
private transient double error = 0.0;
|
||||||
|
|
||||||
public Neuron() {
|
public Neuron() {
|
||||||
//no-arg constructor for JAXB
|
//no-arg constructor for JAXB
|
||||||
}
|
}
|
||||||
|
|
||||||
public Neuron(ActivationFunction activationFunction, int numWeights) {
|
public Neuron(ActivationFunction activationFunction, int id) {
|
||||||
this.activationFunction = activationFunction;
|
this.activationFunction = activationFunction;
|
||||||
this.weights = new double[numWeights];
|
this.id = id;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void addInput(double value) {
|
||||||
|
input += value;
|
||||||
}
|
}
|
||||||
|
|
||||||
@XmlElement(type=Sigmoid.class)
|
@XmlElement(type=Sigmoid.class)
|
||||||
@@ -26,12 +31,15 @@ public class Neuron {
|
|||||||
return activationFunction;
|
return activationFunction;
|
||||||
}
|
}
|
||||||
|
|
||||||
void afterUnmarshal(Unmarshaller aUnmarshaller, Object aParent)
|
@XmlAttribute
|
||||||
{
|
public int getId() {
|
||||||
if (weights == null) {
|
return id;
|
||||||
weights = new double[0];
|
}
|
||||||
}
|
|
||||||
}
|
@XmlTransient
|
||||||
|
public double getError() {
|
||||||
|
return error;
|
||||||
|
}
|
||||||
|
|
||||||
@XmlTransient
|
@XmlTransient
|
||||||
public double getInput() {
|
public double getInput() {
|
||||||
@@ -42,9 +50,8 @@ public class Neuron {
|
|||||||
return activationFunction.calculate(input);
|
return activationFunction.calculate(input);
|
||||||
}
|
}
|
||||||
|
|
||||||
@XmlElement
|
public void setError(double value) {
|
||||||
public double[] getWeights() {
|
this.error = value;
|
||||||
return weights;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setInput(double input) {
|
public void setInput(double input) {
|
||||||
@@ -59,7 +66,6 @@ public class Neuron {
|
|||||||
* result
|
* result
|
||||||
+ ((activationFunction == null) ? 0 : activationFunction
|
+ ((activationFunction == null) ? 0 : activationFunction
|
||||||
.hashCode());
|
.hashCode());
|
||||||
result = prime * result + Arrays.hashCode(weights);
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -77,8 +83,6 @@ public class Neuron {
|
|||||||
return false;
|
return false;
|
||||||
} else if (!activationFunction.equals(other.activationFunction))
|
} else if (!activationFunction.equals(other.activationFunction))
|
||||||
return false;
|
return false;
|
||||||
if (!Arrays.equals(weights, other.weights))
|
|
||||||
return false;
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -86,7 +90,9 @@ public class Neuron {
|
|||||||
this.activationFunction = activationFunction;
|
this.activationFunction = activationFunction;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setWeights(double[] weights) {
|
@Override
|
||||||
this.weights = weights;
|
public String toString() {
|
||||||
|
return "Neuron #" + id +", input: " + input + ", error: " + error;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -1,10 +0,0 @@
|
|||||||
package net.woodyfolsom.msproj.ann2;
|
|
||||||
|
|
||||||
public class Tanh implements ActivationFunction{
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double calculate(double arg) {
|
|
||||||
return Math.tanh(arg);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
19
src/net/woodyfolsom/msproj/ann2/TemporalDifference.java
Normal file
19
src/net/woodyfolsom/msproj/ann2/TemporalDifference.java
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann2;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class TemporalDifference implements TrainingMethod {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void iterate(FeedforwardNetwork neuralNetwork,
|
||||||
|
List<NNDataPair> trainingSet) {
|
||||||
|
throw new UnsupportedOperationException("Not implemented");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double computeError(FeedforwardNetwork neuralNetwork,
|
||||||
|
List<NNDataPair> trainingSet) {
|
||||||
|
throw new UnsupportedOperationException("Not implemented");
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
10
src/net/woodyfolsom/msproj/ann2/TrainingMethod.java
Normal file
10
src/net/woodyfolsom/msproj/ann2/TrainingMethod.java
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann2;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public interface TrainingMethod {
|
||||||
|
|
||||||
|
void iterate(FeedforwardNetwork neuralNetwork, List<NNDataPair> trainingSet);
|
||||||
|
double computeError(FeedforwardNetwork neuralNetwork, List<NNDataPair> trainingSet);
|
||||||
|
|
||||||
|
}
|
||||||
44
src/net/woodyfolsom/msproj/ann2/XORFilter.java
Normal file
44
src/net/woodyfolsom/msproj/ann2/XORFilter.java
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann2;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Based on sample code from http://neuroph.sourceforge.net
|
||||||
|
*
|
||||||
|
* @author Woody
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
public class XORFilter extends AbstractNeuralNetFilter implements
|
||||||
|
NeuralNetFilter {
|
||||||
|
|
||||||
|
private static final int INPUT_SIZE = 2;
|
||||||
|
private static final int OUTPUT_SIZE = 1;
|
||||||
|
|
||||||
|
public XORFilter() {
|
||||||
|
this(0.8,0.7);
|
||||||
|
}
|
||||||
|
|
||||||
|
public XORFilter(double learningRate, double momentum) {
|
||||||
|
super( new MultiLayerPerceptron(true, INPUT_SIZE, 2, OUTPUT_SIZE),
|
||||||
|
new BackPropagation(learningRate, momentum), 1000, 0.01);
|
||||||
|
super.getNeuralNetwork().setName("XORFilter");
|
||||||
|
|
||||||
|
//TODO remove
|
||||||
|
//getNeuralNetwork().setWeights(new double[] {
|
||||||
|
// 0.341232, 0.129952, -0.923123, //hidden neuron 1 from input0, input1, bias
|
||||||
|
// -0.115223, 0.570345, -0.328932, //hidden neuron 2 from input0, input1, bias
|
||||||
|
// -0.993423, 0.164732, 0.752621}); //output
|
||||||
|
}
|
||||||
|
|
||||||
|
public double compute(double x, double y) {
|
||||||
|
return getNeuralNetwork().compute(new double[]{x,y})[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getInputSize() {
|
||||||
|
return INPUT_SIZE;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getOutputSize() {
|
||||||
|
return OUTPUT_SIZE;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,25 +1,29 @@
|
|||||||
package net.woodyfolsom.msproj.ann2;
|
package net.woodyfolsom.msproj.ann2.math;
|
||||||
|
|
||||||
public class Sigmoid implements ActivationFunction{
|
import javax.xml.bind.annotation.XmlAttribute;
|
||||||
public static final Sigmoid function = new Sigmoid();
|
|
||||||
|
public abstract class ActivationFunction {
|
||||||
private String name;
|
private String name;
|
||||||
|
|
||||||
private Sigmoid() {
|
public abstract double calculate(double arg);
|
||||||
this.name = "Sigmoid";
|
public abstract double derivative(double arg);
|
||||||
|
|
||||||
|
public ActivationFunction() {
|
||||||
|
//no-arg constructor for JAXB
|
||||||
}
|
}
|
||||||
|
|
||||||
public double calculate(double arg) {
|
public ActivationFunction(String name) {
|
||||||
return 1.0 / (1 + Math.pow(Math.E, -1.0 * arg));
|
this.name = name;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@XmlAttribute
|
||||||
public String getName() {
|
public String getName() {
|
||||||
return name;
|
return name;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setName(String name) {
|
public void setName(String name) {
|
||||||
this.name = name;
|
this.name = name;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
final int prime = 31;
|
final int prime = 31;
|
||||||
@@ -27,7 +31,6 @@ public class Sigmoid implements ActivationFunction{
|
|||||||
result = prime * result + ((name == null) ? 0 : name.hashCode());
|
result = prime * result + ((name == null) ? 0 : name.hashCode());
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean equals(Object obj) {
|
public boolean equals(Object obj) {
|
||||||
if (this == obj)
|
if (this == obj)
|
||||||
@@ -36,7 +39,7 @@ public class Sigmoid implements ActivationFunction{
|
|||||||
return false;
|
return false;
|
||||||
if (getClass() != obj.getClass())
|
if (getClass() != obj.getClass())
|
||||||
return false;
|
return false;
|
||||||
Sigmoid other = (Sigmoid) obj;
|
ActivationFunction other = (ActivationFunction) obj;
|
||||||
if (name == null) {
|
if (name == null) {
|
||||||
if (other.name != null)
|
if (other.name != null)
|
||||||
return false;
|
return false;
|
||||||
5
src/net/woodyfolsom/msproj/ann2/math/ErrorFunction.java
Normal file
5
src/net/woodyfolsom/msproj/ann2/math/ErrorFunction.java
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann2.math;
|
||||||
|
|
||||||
|
public interface ErrorFunction {
|
||||||
|
double compute(double[] ideal, double[] actual);
|
||||||
|
}
|
||||||
17
src/net/woodyfolsom/msproj/ann2/math/Linear.java
Normal file
17
src/net/woodyfolsom/msproj/ann2/math/Linear.java
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann2.math;
|
||||||
|
|
||||||
|
public class Linear extends ActivationFunction{
|
||||||
|
public static final Linear function = new Linear();
|
||||||
|
|
||||||
|
private Linear() {
|
||||||
|
super("Linear");
|
||||||
|
}
|
||||||
|
|
||||||
|
public double calculate(double arg) {
|
||||||
|
return arg;
|
||||||
|
}
|
||||||
|
|
||||||
|
public double derivative(double arg) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
22
src/net/woodyfolsom/msproj/ann2/math/MSSE.java
Normal file
22
src/net/woodyfolsom/msproj/ann2/math/MSSE.java
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann2.math;
|
||||||
|
|
||||||
|
public class MSSE implements ErrorFunction{
|
||||||
|
public static final ErrorFunction function = new MSSE();
|
||||||
|
|
||||||
|
public double compute(double[] ideal, double[] actual) {
|
||||||
|
int idealSize = ideal.length;
|
||||||
|
int actualSize = actual.length;
|
||||||
|
|
||||||
|
if (idealSize != actualSize) {
|
||||||
|
throw new IllegalArgumentException("actualSize != idealSize");
|
||||||
|
}
|
||||||
|
|
||||||
|
double SSE = 0.0;
|
||||||
|
|
||||||
|
for (int i = 0; i < idealSize; i++) {
|
||||||
|
SSE += Math.pow(ideal[i] - actual[i], 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
return SSE / idealSize;
|
||||||
|
}
|
||||||
|
}
|
||||||
20
src/net/woodyfolsom/msproj/ann2/math/Sigmoid.java
Normal file
20
src/net/woodyfolsom/msproj/ann2/math/Sigmoid.java
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann2.math;
|
||||||
|
|
||||||
|
public class Sigmoid extends ActivationFunction{
|
||||||
|
public static final Sigmoid function = new Sigmoid();
|
||||||
|
|
||||||
|
private Sigmoid() {
|
||||||
|
super("Sigmoid");
|
||||||
|
}
|
||||||
|
|
||||||
|
public double calculate(double arg) {
|
||||||
|
return 1.0 / (1 + Math.pow(Math.E, -1.0 * arg));
|
||||||
|
}
|
||||||
|
|
||||||
|
public double derivative(double arg) {
|
||||||
|
//lol wth?
|
||||||
|
//double eX = Math.exp(arg);
|
||||||
|
//return eX / (Math.pow((1+eX), 2));
|
||||||
|
return arg - Math.pow(arg,2);
|
||||||
|
}
|
||||||
|
}
|
||||||
21
src/net/woodyfolsom/msproj/ann2/math/Tanh.java
Normal file
21
src/net/woodyfolsom/msproj/ann2/math/Tanh.java
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann2.math;
|
||||||
|
|
||||||
|
public class Tanh extends ActivationFunction{
|
||||||
|
public static final Tanh function = new Tanh();
|
||||||
|
|
||||||
|
public Tanh() {
|
||||||
|
super("Tanh");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double calculate(double arg) {
|
||||||
|
return Math.tanh(arg);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double derivative(double arg) {
|
||||||
|
double tanh = Math.tanh(arg);
|
||||||
|
return 1 - Math.pow(tanh, 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -16,7 +16,8 @@ import org.junit.Test;
|
|||||||
|
|
||||||
public class MultiLayerPerceptronTest {
|
public class MultiLayerPerceptronTest {
|
||||||
static final File TEST_FILE = new File("data/test/mlp.net");
|
static final File TEST_FILE = new File("data/test/mlp.net");
|
||||||
|
static final double EPS = 0.001;
|
||||||
|
|
||||||
@BeforeClass
|
@BeforeClass
|
||||||
public static void setUp() {
|
public static void setUp() {
|
||||||
if (TEST_FILE.exists()) {
|
if (TEST_FILE.exists()) {
|
||||||
@@ -49,14 +50,47 @@ public class MultiLayerPerceptronTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testPersistence() throws JAXBException, IOException {
|
public void testPersistence() throws JAXBException, IOException {
|
||||||
NeuralNetwork mlp = new MultiLayerPerceptron(true, 2, 4, 1);
|
FeedforwardNetwork mlp = new MultiLayerPerceptron(true, 2, 4, 1);
|
||||||
FileOutputStream fos = new FileOutputStream(TEST_FILE);
|
FileOutputStream fos = new FileOutputStream(TEST_FILE);
|
||||||
assertTrue(mlp.save(fos));
|
assertTrue(mlp.save(fos));
|
||||||
fos.close();
|
fos.close();
|
||||||
FileInputStream fis = new FileInputStream(TEST_FILE);
|
FileInputStream fis = new FileInputStream(TEST_FILE);
|
||||||
NeuralNetwork mlp2 = new MultiLayerPerceptron();
|
FeedforwardNetwork mlp2 = new MultiLayerPerceptron();
|
||||||
assertTrue(mlp2.load(fis));
|
assertTrue(mlp2.load(fis));
|
||||||
assertEquals(mlp, mlp2);
|
assertEquals(mlp, mlp2);
|
||||||
fis.close();
|
fis.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCompute() {
|
||||||
|
FeedforwardNetwork mlp = new MultiLayerPerceptron(true, 2, 2, 1);
|
||||||
|
NNDataPair expected = new NNDataPair(new NNData(new String[]{"x","y"}, new double[]{0.0,0.0}),new NNData(new String[]{"xor"}, new double[]{0.0}));
|
||||||
|
NNDataPair actual = new NNDataPair(new NNData(new String[]{"x","y"}, new double[]{0.0,0.0}),new NNData(new String[]{"xor"}, new double[]{0.5}));
|
||||||
|
NNData actualOutput = mlp.compute(actual);
|
||||||
|
assertEquals(expected.getIdeal(), actualOutput);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testXORnetwork() {
|
||||||
|
FeedforwardNetwork mlp = new MultiLayerPerceptron(true, 2, 2, 1);
|
||||||
|
mlp.setWeights(new double[] {
|
||||||
|
0.341232, 0.129952, -0.923123, //hidden neuron 1 from input0, input1, bias
|
||||||
|
-0.115223, 0.570345, -0.328932, //hidden neuron 2 from input0, input1, bias
|
||||||
|
-0.993423, 0.164732, 0.752621}); //output
|
||||||
|
|
||||||
|
for (Connection connection : mlp.getConnections()) {
|
||||||
|
System.out.println(connection);
|
||||||
|
}
|
||||||
|
NNDataPair expected = new NNDataPair(new NNData(new String[]{"x","y"}, new double[]{0.0,0.0}),new NNData(new String[]{"xor"}, new double[]{0.367610}));
|
||||||
|
NNDataPair actual = new NNDataPair(new NNData(new String[]{"x","y"}, new double[]{0.0,0.0}),new NNData(new String[]{"xor"}, new double[]{0.0}));
|
||||||
|
NNData actualOutput = mlp.compute(actual);
|
||||||
|
assertEquals(expected.getIdeal().getValues()[0], actualOutput.getValues()[0], EPS);
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
Hidden Neuron 1: w2(0,1) = 0.341232 w2(1,1) = 0.129952 w2(2,1) =-0.923123
|
||||||
|
Hidden Neuron 2: w2(0,2) =-0.115223 w2(1,2) = 0.570345 w2(2,2) =-0.328932
|
||||||
|
Output Neuron: w3(0,1) =-0.993423 w3(1,1) = 0.164732 w3(2,1) = 0.752621
|
||||||
|
|
||||||
|
*/
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,16 +3,27 @@ package net.woodyfolsom.msproj.ann2;
|
|||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
|
import net.woodyfolsom.msproj.ann2.math.ActivationFunction;
|
||||||
|
import net.woodyfolsom.msproj.ann2.math.Sigmoid;
|
||||||
|
import net.woodyfolsom.msproj.ann2.math.Tanh;
|
||||||
|
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
public class SigmoidTest {
|
public class SigmoidTest {
|
||||||
|
static double EPS = 0.001;
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCalculate() {
|
public void testCalculate() {
|
||||||
double EPS = 0.001;
|
|
||||||
|
|
||||||
ActivationFunction sigmoid = Sigmoid.function;
|
ActivationFunction sigmoid = Sigmoid.function;
|
||||||
assertEquals(0.5,sigmoid.calculate(0.0),EPS);
|
assertEquals(0.5,sigmoid.calculate(0.0),EPS);
|
||||||
assertTrue(sigmoid.calculate(100.0) > 1.0 - EPS);
|
assertTrue(sigmoid.calculate(100.0) > 1.0 - EPS);
|
||||||
assertTrue(sigmoid.calculate(-9000.0) < EPS);
|
assertTrue(sigmoid.calculate(-9000.0) < EPS);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testDerivative() {
|
||||||
|
ActivationFunction sigmoid = new Tanh();
|
||||||
|
assertEquals(1.0,sigmoid.derivative(0.0), EPS);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,16 +3,26 @@ package net.woodyfolsom.msproj.ann2;
|
|||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
|
import net.woodyfolsom.msproj.ann2.math.ActivationFunction;
|
||||||
|
import net.woodyfolsom.msproj.ann2.math.Tanh;
|
||||||
|
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
public class TanhTest {
|
public class TanhTest {
|
||||||
|
static double EPS = 0.001;
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCalculate() {
|
public void testCalculate() {
|
||||||
double EPS = 0.001;
|
|
||||||
|
|
||||||
ActivationFunction sigmoid = new Tanh();
|
ActivationFunction tanh = new Tanh();
|
||||||
assertEquals(0.0,sigmoid.calculate(0.0),EPS);
|
assertEquals(0.0,tanh.calculate(0.0),EPS);
|
||||||
assertTrue(sigmoid.calculate(100.0) > 0.5 - EPS);
|
assertTrue(tanh.calculate(100.0) > 0.5 - EPS);
|
||||||
assertTrue(sigmoid.calculate(-9000.0) < -0.5+EPS);
|
assertTrue(tanh.calculate(-9000.0) < -0.5 + EPS);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testDerivative() {
|
||||||
|
ActivationFunction tanh = new Tanh();
|
||||||
|
assertEquals(1.0,tanh.derivative(0.0), EPS);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
136
test/net/woodyfolsom/msproj/ann2/XORFilterTest.java
Normal file
136
test/net/woodyfolsom/msproj/ann2/XORFilterTest.java
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann2;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.FileInputStream;
|
||||||
|
import java.io.FileOutputStream;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import org.junit.AfterClass;
|
||||||
|
import org.junit.BeforeClass;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
public class XORFilterTest {
|
||||||
|
private static final String FILENAME = "xorPerceptron.net";
|
||||||
|
|
||||||
|
@AfterClass
|
||||||
|
public static void deleteNewNet() {
|
||||||
|
File file = new File(FILENAME);
|
||||||
|
if (file.exists()) {
|
||||||
|
file.delete();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@BeforeClass
|
||||||
|
public static void deleteSavedNet() {
|
||||||
|
File file = new File(FILENAME);
|
||||||
|
if (file.exists()) {
|
||||||
|
file.delete();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testLearn() throws IOException {
|
||||||
|
NeuralNetFilter nnLearner = new XORFilter(0.05,0.0);
|
||||||
|
|
||||||
|
// create training set (logical XOR function)
|
||||||
|
int size = 1;
|
||||||
|
double[][] trainingInput = new double[4 * size][];
|
||||||
|
double[][] trainingOutput = new double[4 * size][];
|
||||||
|
for (int i = 0; i < size; i++) {
|
||||||
|
trainingInput[i * 4 + 0] = new double[] { 0, 0 };
|
||||||
|
trainingInput[i * 4 + 1] = new double[] { 0, 1 };
|
||||||
|
trainingInput[i * 4 + 2] = new double[] { 1, 0 };
|
||||||
|
trainingInput[i * 4 + 3] = new double[] { 1, 1 };
|
||||||
|
trainingOutput[i * 4 + 0] = new double[] { 0 };
|
||||||
|
trainingOutput[i * 4 + 1] = new double[] { 1 };
|
||||||
|
trainingOutput[i * 4 + 2] = new double[] { 1 };
|
||||||
|
trainingOutput[i * 4 + 3] = new double[] { 0 };
|
||||||
|
}
|
||||||
|
|
||||||
|
// create training data
|
||||||
|
List<NNDataPair> trainingSet = new ArrayList<NNDataPair>();
|
||||||
|
String[] inputNames = new String[] {"x","y"};
|
||||||
|
String[] outputNames = new String[] {"XOR"};
|
||||||
|
for (int i = 0; i < 4*size; i++) {
|
||||||
|
trainingSet.add(new NNDataPair(new NNData(inputNames,trainingInput[i]),new NNData(outputNames,trainingOutput[i])));
|
||||||
|
}
|
||||||
|
|
||||||
|
nnLearner.setMaxTrainingEpochs(20000);
|
||||||
|
nnLearner.learn(trainingSet);
|
||||||
|
System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");
|
||||||
|
|
||||||
|
double[][] validationSet = new double[4][2];
|
||||||
|
|
||||||
|
validationSet[0] = new double[] { 0, 0 };
|
||||||
|
validationSet[1] = new double[] { 0, 1 };
|
||||||
|
validationSet[2] = new double[] { 1, 0 };
|
||||||
|
validationSet[3] = new double[] { 1, 1 };
|
||||||
|
|
||||||
|
System.out.println("Output from eval set (learned network):");
|
||||||
|
testNetwork(nnLearner, validationSet, inputNames, outputNames);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testLearnSaveLoad() throws IOException {
|
||||||
|
NeuralNetFilter nnLearner = new XORFilter(0.5,0.0);
|
||||||
|
|
||||||
|
// create training set (logical XOR function)
|
||||||
|
int size = 2;
|
||||||
|
double[][] trainingInput = new double[4 * size][];
|
||||||
|
double[][] trainingOutput = new double[4 * size][];
|
||||||
|
for (int i = 0; i < size; i++) {
|
||||||
|
trainingInput[i * 4 + 0] = new double[] { 0, 0 };
|
||||||
|
trainingInput[i * 4 + 1] = new double[] { 0, 1 };
|
||||||
|
trainingInput[i * 4 + 2] = new double[] { 1, 0 };
|
||||||
|
trainingInput[i * 4 + 3] = new double[] { 1, 1 };
|
||||||
|
trainingOutput[i * 4 + 0] = new double[] { 0 };
|
||||||
|
trainingOutput[i * 4 + 1] = new double[] { 1 };
|
||||||
|
trainingOutput[i * 4 + 2] = new double[] { 1 };
|
||||||
|
trainingOutput[i * 4 + 3] = new double[] { 0 };
|
||||||
|
}
|
||||||
|
|
||||||
|
// create training data
|
||||||
|
List<NNDataPair> trainingSet = new ArrayList<NNDataPair>();
|
||||||
|
String[] inputNames = new String[] {"x","y"};
|
||||||
|
String[] outputNames = new String[] {"XOR"};
|
||||||
|
for (int i = 0; i < 4*size; i++) {
|
||||||
|
trainingSet.add(new NNDataPair(new NNData(inputNames,trainingInput[i]),new NNData(outputNames,trainingOutput[i])));
|
||||||
|
}
|
||||||
|
|
||||||
|
nnLearner.setMaxTrainingEpochs(1);
|
||||||
|
nnLearner.learn(trainingSet);
|
||||||
|
System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");
|
||||||
|
|
||||||
|
double[][] validationSet = new double[4][2];
|
||||||
|
|
||||||
|
validationSet[0] = new double[] { 0, 0 };
|
||||||
|
validationSet[1] = new double[] { 0, 1 };
|
||||||
|
validationSet[2] = new double[] { 1, 0 };
|
||||||
|
validationSet[3] = new double[] { 1, 1 };
|
||||||
|
|
||||||
|
System.out.println("Output from eval set (learned network, pre-serialization):");
|
||||||
|
testNetwork(nnLearner, validationSet, inputNames, outputNames);
|
||||||
|
|
||||||
|
FileOutputStream fos = new FileOutputStream(FILENAME);
|
||||||
|
assertTrue(nnLearner.save(fos));
|
||||||
|
fos.close();
|
||||||
|
|
||||||
|
FileInputStream fis = new FileInputStream(FILENAME);
|
||||||
|
assertTrue(nnLearner.load(fis));
|
||||||
|
fis.close();
|
||||||
|
|
||||||
|
System.out.println("Output from eval set (learned network, post-serialization):");
|
||||||
|
testNetwork(nnLearner, validationSet, inputNames, outputNames);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void testNetwork(NeuralNetFilter nnLearner, double[][] validationSet, String[] inputNames, String[] outputNames) {
|
||||||
|
for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
|
||||||
|
NNDataPair dp = new NNDataPair(new NNData(inputNames,validationSet[valIndex]), new NNData(outputNames,validationSet[valIndex]));
|
||||||
|
System.out.println(dp + " => " + nnLearner.compute(dp));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user