Functional MLP for XOR toy problem.

This commit is contained in:
2012-11-24 22:20:41 -05:00
parent 874847f41b
commit 790b5666a8
26 changed files with 1109 additions and 217 deletions

View File

@@ -16,7 +16,8 @@ import org.junit.Test;
public class MultiLayerPerceptronTest {
static final File TEST_FILE = new File("data/test/mlp.net");
static final double EPS = 0.001;
@BeforeClass
public static void setUp() {
if (TEST_FILE.exists()) {
@@ -49,14 +50,47 @@ public class MultiLayerPerceptronTest {
@Test
public void testPersistence() throws JAXBException, IOException {
    // Round-trips a 2-4-1 perceptron through save()/load() and verifies the
    // reloaded network compares equal to the original.
    // FIX: removed leftover pre-refactor duplicate declarations of mlp/mlp2
    // (typed NeuralNetwork) that shadowed the FeedforwardNetwork versions and
    // made this method uncompilable.
    FeedforwardNetwork mlp = new MultiLayerPerceptron(true, 2, 4, 1);
    FileOutputStream fos = new FileOutputStream(TEST_FILE);
    assertTrue(mlp.save(fos));
    fos.close();
    FileInputStream fis = new FileInputStream(TEST_FILE);
    // No-arg constructor: empty network to be populated entirely by load().
    FeedforwardNetwork mlp2 = new MultiLayerPerceptron();
    assertTrue(mlp2.load(fis));
    assertEquals(mlp, mlp2);
    fis.close();
}
@Test
public void testCompute() {
    // Feed the input pattern (0,0) through a fresh 2-2-1 perceptron and
    // compare the computed output data against the expected ideal of 0.0.
    FeedforwardNetwork network = new MultiLayerPerceptron(true, 2, 2, 1);
    NNDataPair expected = new NNDataPair(
            new NNData(new String[]{"x","y"}, new double[]{0.0,0.0}),
            new NNData(new String[]{"xor"}, new double[]{0.0}));
    // Query pair: same input; its ideal value (0.5) is not used by compute().
    NNDataPair query = new NNDataPair(
            new NNData(new String[]{"x","y"}, new double[]{0.0,0.0}),
            new NNData(new String[]{"xor"}, new double[]{0.5}));
    NNData computed = network.compute(query);
    assertEquals(expected.getIdeal(), computed);
}
@Test
public void testXORnetwork() {
    // Build a 2-2-1 network with fixed, hand-picked weights and check that
    // compute() on input (0,0) matches the hand-calculated value 0.367610.
    FeedforwardNetwork network = new MultiLayerPerceptron(true, 2, 2, 1);
    double[] fixedWeights = {
            0.341232, 0.129952, -0.923123,  // hidden neuron 1 from input0, input1, bias
            -0.115223, 0.570345, -0.328932, // hidden neuron 2 from input0, input1, bias
            -0.993423, 0.164732, 0.752621   // output
    };
    network.setWeights(fixedWeights);
    // Debug dump of the wired connections.
    for (Connection link : network.getConnections()) {
        System.out.println(link);
    }
    NNDataPair expected = new NNDataPair(
            new NNData(new String[]{"x","y"}, new double[]{0.0,0.0}),
            new NNData(new String[]{"xor"}, new double[]{0.367610}));
    NNDataPair query = new NNDataPair(
            new NNData(new String[]{"x","y"}, new double[]{0.0,0.0}),
            new NNData(new String[]{"xor"}, new double[]{0.0}));
    NNData computed = network.compute(query);
    assertEquals(expected.getIdeal().getValues()[0], computed.getValues()[0], EPS);
}
/**
 * Reference weights used above (layer w2 = input-to-hidden, w3 = hidden-to-output):
 * Hidden Neuron 1: w2(0,1) =  0.341232, w2(1,1) =  0.129952, w2(2,1) = -0.923123
 * Hidden Neuron 2: w2(0,2) = -0.115223, w2(1,2) =  0.570345, w2(2,2) = -0.328932
 * Output Neuron:   w3(0,1) = -0.993423, w3(1,1) =  0.164732, w3(2,1) =  0.752621
 */
}

View File

@@ -3,16 +3,27 @@ package net.woodyfolsom.msproj.ann2;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import net.woodyfolsom.msproj.ann2.math.ActivationFunction;
import net.woodyfolsom.msproj.ann2.math.Sigmoid;
import net.woodyfolsom.msproj.ann2.math.Tanh;
import org.junit.Test;
public class SigmoidTest {
    /** Tolerance for floating-point comparisons. */
    static final double EPS = 0.001;

    @Test
    public void testCalculate() {
        // sigmoid(0) = 0.5; the function saturates toward 1 for large positive
        // inputs and toward 0 for large negative inputs.
        // FIX: removed a local EPS that shadowed the class constant.
        ActivationFunction sigmoid = Sigmoid.function;
        assertEquals(0.5, sigmoid.calculate(0.0), EPS);
        assertTrue(sigmoid.calculate(100.0) > 1.0 - EPS);
        assertTrue(sigmoid.calculate(-9000.0) < EPS);
    }

    @Test
    public void testDerivative() {
        // FIX: this test previously instantiated Tanh and asserted 1.0, which
        // is tanh'(0) -- the wrong function for SigmoidTest. The sigmoid
        // derivative is s(x) * (1 - s(x)), which equals 0.25 at x = 0.
        ActivationFunction sigmoid = Sigmoid.function;
        assertEquals(0.25, sigmoid.derivative(0.0), EPS);
    }
}

View File

@@ -3,16 +3,26 @@ package net.woodyfolsom.msproj.ann2;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import net.woodyfolsom.msproj.ann2.math.ActivationFunction;
import net.woodyfolsom.msproj.ann2.math.Tanh;
import org.junit.Test;
public class TanhTest {
    /** Tolerance for floating-point comparisons. */
    static final double EPS = 0.001;

    @Test
    public void testCalculate() {
        // tanh(0) = 0; the function saturates toward +1 for large positive
        // inputs and toward -1 for large negative inputs.
        // FIX: removed unmerged-diff leftovers -- a shadowing local EPS and a
        // duplicate set of identical assertions on a Tanh named "sigmoid".
        ActivationFunction tanh = new Tanh();
        assertEquals(0.0, tanh.calculate(0.0), EPS);
        assertTrue(tanh.calculate(100.0) > 0.5 - EPS);
        assertTrue(tanh.calculate(-9000.0) < -0.5 + EPS);
    }

    @Test
    public void testDerivative() {
        // tanh'(x) = 1 - tanh(x)^2, which equals 1.0 at x = 0.
        ActivationFunction tanh = new Tanh();
        assertEquals(1.0, tanh.derivative(0.0), EPS);
    }
}

View File

@@ -0,0 +1,136 @@
package net.woodyfolsom.msproj.ann2;
import static org.junit.Assert.assertTrue;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
public class XORFilterTest {
    private static final String FILENAME = "xorPerceptron.net";

    /** Removes the serialized network file left behind by a run, if present. */
    private static void deleteNetFile() {
        File file = new File(FILENAME);
        if (file.exists()) {
            file.delete();
        }
    }

    @AfterClass
    public static void deleteNewNet() {
        deleteNetFile();
    }

    @BeforeClass
    public static void deleteSavedNet() {
        deleteNetFile();
    }

    /**
     * Builds {@code size} repetitions of the four XOR truth-table rows as
     * training pairs. Each pair gets its own freshly allocated value arrays,
     * matching the original per-pair array construction.
     */
    private static List<NNDataPair> buildXorTrainingSet(int size, String[] inputNames, String[] outputNames) {
        double[][] rows = { { 0, 0 }, { 0, 1 }, { 1, 0 }, { 1, 1 } };
        double[][] ideals = { { 0 }, { 1 }, { 1 }, { 0 } };
        List<NNDataPair> trainingSet = new ArrayList<NNDataPair>();
        for (int copy = 0; copy < size; copy++) {
            for (int row = 0; row < 4; row++) {
                trainingSet.add(new NNDataPair(
                        new NNData(inputNames, rows[row].clone()),
                        new NNData(outputNames, ideals[row].clone())));
            }
        }
        return trainingSet;
    }

    /** The four XOR input patterns, used as the evaluation set. */
    private static double[][] xorValidationSet() {
        return new double[][] { { 0, 0 }, { 0, 1 }, { 1, 0 }, { 1, 1 } };
    }

    @Test
    public void testLearn() throws IOException {
        // Train an XOR filter (learning rate 0.05, momentum 0.0) on one copy
        // of the truth table, then print its outputs on the evaluation set.
        // NOTE(review): this test only prints; it asserts nothing about the
        // learned outputs -- consider adding tolerance checks.
        NeuralNetFilter nnLearner = new XORFilter(0.05, 0.0);
        String[] inputNames = new String[] { "x", "y" };
        String[] outputNames = new String[] { "XOR" };
        List<NNDataPair> trainingSet = buildXorTrainingSet(1, inputNames, outputNames);
        nnLearner.setMaxTrainingEpochs(20000);
        nnLearner.learn(trainingSet);
        System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");
        System.out.println("Output from eval set (learned network):");
        testNetwork(nnLearner, xorValidationSet(), inputNames, outputNames);
    }

    @Test
    public void testLearnSaveLoad() throws IOException {
        // Briefly train an XOR filter (learning rate 0.5, momentum 0.0, one
        // epoch, two copies of the truth table), then round-trip it through
        // save()/load() and print outputs before and after serialization.
        NeuralNetFilter nnLearner = new XORFilter(0.5, 0.0);
        String[] inputNames = new String[] { "x", "y" };
        String[] outputNames = new String[] { "XOR" };
        List<NNDataPair> trainingSet = buildXorTrainingSet(2, inputNames, outputNames);
        nnLearner.setMaxTrainingEpochs(1);
        nnLearner.learn(trainingSet);
        System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");
        double[][] validationSet = xorValidationSet();
        System.out.println("Output from eval set (learned network, pre-serialization):");
        testNetwork(nnLearner, validationSet, inputNames, outputNames);
        FileOutputStream fos = new FileOutputStream(FILENAME);
        assertTrue(nnLearner.save(fos));
        fos.close();
        FileInputStream fis = new FileInputStream(FILENAME);
        assertTrue(nnLearner.load(fis));
        fis.close();
        System.out.println("Output from eval set (learned network, post-serialization):");
        testNetwork(nnLearner, validationSet, inputNames, outputNames);
    }

    /** Prints the filter's computed output for each validation row. */
    private void testNetwork(NeuralNetFilter nnLearner, double[][] validationSet, String[] inputNames, String[] outputNames) {
        for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
            // NOTE(review): the "ideal" NNData reuses the 2-element input row
            // against the 1-element output name array -- looks inconsistent;
            // confirm NNData tolerates a names/values length mismatch.
            NNDataPair dp = new NNDataPair(
                    new NNData(inputNames, validationSet[valIndex]),
                    new NNData(outputNames, validationSet[valIndex]));
            System.out.println(dp + " => " + nnLearner.compute(dp));
        }
    }
}