Functional MLP for XOR toy problem.
This commit is contained in:
@@ -16,7 +16,8 @@ import org.junit.Test;
|
||||
|
||||
public class MultiLayerPerceptronTest {
|
||||
static final File TEST_FILE = new File("data/test/mlp.net");
|
||||
|
||||
static final double EPS = 0.001;
|
||||
|
||||
@BeforeClass
|
||||
public static void setUp() {
|
||||
if (TEST_FILE.exists()) {
|
||||
@@ -49,14 +50,47 @@ public class MultiLayerPerceptronTest {
|
||||
|
||||
@Test
|
||||
public void testPersistence() throws JAXBException, IOException {
|
||||
NeuralNetwork mlp = new MultiLayerPerceptron(true, 2, 4, 1);
|
||||
FeedforwardNetwork mlp = new MultiLayerPerceptron(true, 2, 4, 1);
|
||||
FileOutputStream fos = new FileOutputStream(TEST_FILE);
|
||||
assertTrue(mlp.save(fos));
|
||||
fos.close();
|
||||
FileInputStream fis = new FileInputStream(TEST_FILE);
|
||||
NeuralNetwork mlp2 = new MultiLayerPerceptron();
|
||||
FeedforwardNetwork mlp2 = new MultiLayerPerceptron();
|
||||
assertTrue(mlp2.load(fis));
|
||||
assertEquals(mlp, mlp2);
|
||||
fis.close();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCompute() {
|
||||
FeedforwardNetwork mlp = new MultiLayerPerceptron(true, 2, 2, 1);
|
||||
NNDataPair expected = new NNDataPair(new NNData(new String[]{"x","y"}, new double[]{0.0,0.0}),new NNData(new String[]{"xor"}, new double[]{0.0}));
|
||||
NNDataPair actual = new NNDataPair(new NNData(new String[]{"x","y"}, new double[]{0.0,0.0}),new NNData(new String[]{"xor"}, new double[]{0.5}));
|
||||
NNData actualOutput = mlp.compute(actual);
|
||||
assertEquals(expected.getIdeal(), actualOutput);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testXORnetwork() {
|
||||
FeedforwardNetwork mlp = new MultiLayerPerceptron(true, 2, 2, 1);
|
||||
mlp.setWeights(new double[] {
|
||||
0.341232, 0.129952, -0.923123, //hidden neuron 1 from input0, input1, bias
|
||||
-0.115223, 0.570345, -0.328932, //hidden neuron 2 from input0, input1, bias
|
||||
-0.993423, 0.164732, 0.752621}); //output
|
||||
|
||||
for (Connection connection : mlp.getConnections()) {
|
||||
System.out.println(connection);
|
||||
}
|
||||
NNDataPair expected = new NNDataPair(new NNData(new String[]{"x","y"}, new double[]{0.0,0.0}),new NNData(new String[]{"xor"}, new double[]{0.367610}));
|
||||
NNDataPair actual = new NNDataPair(new NNData(new String[]{"x","y"}, new double[]{0.0,0.0}),new NNData(new String[]{"xor"}, new double[]{0.0}));
|
||||
NNData actualOutput = mlp.compute(actual);
|
||||
assertEquals(expected.getIdeal().getValues()[0], actualOutput.getValues()[0], EPS);
|
||||
}
|
||||
/**
|
||||
*
|
||||
Hidden Neuron 1: w2(0,1) = 0.341232 w2(1,1) = 0.129952 w2(2,1) =-0.923123
|
||||
Hidden Neuron 2: w2(0,2) =-0.115223 w2(1,2) = 0.570345 w2(2,2) =-0.328932
|
||||
Output Neuron: w3(0,1) =-0.993423 w3(1,1) = 0.164732 w3(2,1) = 0.752621
|
||||
|
||||
*/
|
||||
}
|
||||
|
||||
@@ -3,16 +3,27 @@ package net.woodyfolsom.msproj.ann2;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import net.woodyfolsom.msproj.ann2.math.ActivationFunction;
|
||||
import net.woodyfolsom.msproj.ann2.math.Sigmoid;
|
||||
import net.woodyfolsom.msproj.ann2.math.Tanh;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
public class SigmoidTest {
|
||||
static double EPS = 0.001;
|
||||
|
||||
@Test
|
||||
public void testCalculate() {
|
||||
double EPS = 0.001;
|
||||
|
||||
ActivationFunction sigmoid = Sigmoid.function;
|
||||
assertEquals(0.5,sigmoid.calculate(0.0),EPS);
|
||||
assertTrue(sigmoid.calculate(100.0) > 1.0 - EPS);
|
||||
assertTrue(sigmoid.calculate(-9000.0) < EPS);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDerivative() {
|
||||
ActivationFunction sigmoid = new Tanh();
|
||||
assertEquals(1.0,sigmoid.derivative(0.0), EPS);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,16 +3,26 @@ package net.woodyfolsom.msproj.ann2;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import net.woodyfolsom.msproj.ann2.math.ActivationFunction;
|
||||
import net.woodyfolsom.msproj.ann2.math.Tanh;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
public class TanhTest {
|
||||
static double EPS = 0.001;
|
||||
|
||||
@Test
|
||||
public void testCalculate() {
|
||||
double EPS = 0.001;
|
||||
|
||||
ActivationFunction sigmoid = new Tanh();
|
||||
assertEquals(0.0,sigmoid.calculate(0.0),EPS);
|
||||
assertTrue(sigmoid.calculate(100.0) > 0.5 - EPS);
|
||||
assertTrue(sigmoid.calculate(-9000.0) < -0.5+EPS);
|
||||
ActivationFunction tanh = new Tanh();
|
||||
assertEquals(0.0,tanh.calculate(0.0),EPS);
|
||||
assertTrue(tanh.calculate(100.0) > 0.5 - EPS);
|
||||
assertTrue(tanh.calculate(-9000.0) < -0.5 + EPS);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDerivative() {
|
||||
ActivationFunction tanh = new Tanh();
|
||||
assertEquals(1.0,tanh.derivative(0.0), EPS);
|
||||
}
|
||||
}
|
||||
|
||||
136
test/net/woodyfolsom/msproj/ann2/XORFilterTest.java
Normal file
136
test/net/woodyfolsom/msproj/ann2/XORFilterTest.java
Normal file
@@ -0,0 +1,136 @@
|
||||
package net.woodyfolsom.msproj.ann2;
|
||||
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import org.junit.AfterClass;
|
||||
import org.junit.BeforeClass;
|
||||
import org.junit.Test;
|
||||
|
||||
public class XORFilterTest {
|
||||
private static final String FILENAME = "xorPerceptron.net";
|
||||
|
||||
@AfterClass
|
||||
public static void deleteNewNet() {
|
||||
File file = new File(FILENAME);
|
||||
if (file.exists()) {
|
||||
file.delete();
|
||||
}
|
||||
}
|
||||
|
||||
@BeforeClass
|
||||
public static void deleteSavedNet() {
|
||||
File file = new File(FILENAME);
|
||||
if (file.exists()) {
|
||||
file.delete();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLearn() throws IOException {
|
||||
NeuralNetFilter nnLearner = new XORFilter(0.05,0.0);
|
||||
|
||||
// create training set (logical XOR function)
|
||||
int size = 1;
|
||||
double[][] trainingInput = new double[4 * size][];
|
||||
double[][] trainingOutput = new double[4 * size][];
|
||||
for (int i = 0; i < size; i++) {
|
||||
trainingInput[i * 4 + 0] = new double[] { 0, 0 };
|
||||
trainingInput[i * 4 + 1] = new double[] { 0, 1 };
|
||||
trainingInput[i * 4 + 2] = new double[] { 1, 0 };
|
||||
trainingInput[i * 4 + 3] = new double[] { 1, 1 };
|
||||
trainingOutput[i * 4 + 0] = new double[] { 0 };
|
||||
trainingOutput[i * 4 + 1] = new double[] { 1 };
|
||||
trainingOutput[i * 4 + 2] = new double[] { 1 };
|
||||
trainingOutput[i * 4 + 3] = new double[] { 0 };
|
||||
}
|
||||
|
||||
// create training data
|
||||
List<NNDataPair> trainingSet = new ArrayList<NNDataPair>();
|
||||
String[] inputNames = new String[] {"x","y"};
|
||||
String[] outputNames = new String[] {"XOR"};
|
||||
for (int i = 0; i < 4*size; i++) {
|
||||
trainingSet.add(new NNDataPair(new NNData(inputNames,trainingInput[i]),new NNData(outputNames,trainingOutput[i])));
|
||||
}
|
||||
|
||||
nnLearner.setMaxTrainingEpochs(20000);
|
||||
nnLearner.learn(trainingSet);
|
||||
System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");
|
||||
|
||||
double[][] validationSet = new double[4][2];
|
||||
|
||||
validationSet[0] = new double[] { 0, 0 };
|
||||
validationSet[1] = new double[] { 0, 1 };
|
||||
validationSet[2] = new double[] { 1, 0 };
|
||||
validationSet[3] = new double[] { 1, 1 };
|
||||
|
||||
System.out.println("Output from eval set (learned network):");
|
||||
testNetwork(nnLearner, validationSet, inputNames, outputNames);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLearnSaveLoad() throws IOException {
|
||||
NeuralNetFilter nnLearner = new XORFilter(0.5,0.0);
|
||||
|
||||
// create training set (logical XOR function)
|
||||
int size = 2;
|
||||
double[][] trainingInput = new double[4 * size][];
|
||||
double[][] trainingOutput = new double[4 * size][];
|
||||
for (int i = 0; i < size; i++) {
|
||||
trainingInput[i * 4 + 0] = new double[] { 0, 0 };
|
||||
trainingInput[i * 4 + 1] = new double[] { 0, 1 };
|
||||
trainingInput[i * 4 + 2] = new double[] { 1, 0 };
|
||||
trainingInput[i * 4 + 3] = new double[] { 1, 1 };
|
||||
trainingOutput[i * 4 + 0] = new double[] { 0 };
|
||||
trainingOutput[i * 4 + 1] = new double[] { 1 };
|
||||
trainingOutput[i * 4 + 2] = new double[] { 1 };
|
||||
trainingOutput[i * 4 + 3] = new double[] { 0 };
|
||||
}
|
||||
|
||||
// create training data
|
||||
List<NNDataPair> trainingSet = new ArrayList<NNDataPair>();
|
||||
String[] inputNames = new String[] {"x","y"};
|
||||
String[] outputNames = new String[] {"XOR"};
|
||||
for (int i = 0; i < 4*size; i++) {
|
||||
trainingSet.add(new NNDataPair(new NNData(inputNames,trainingInput[i]),new NNData(outputNames,trainingOutput[i])));
|
||||
}
|
||||
|
||||
nnLearner.setMaxTrainingEpochs(1);
|
||||
nnLearner.learn(trainingSet);
|
||||
System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");
|
||||
|
||||
double[][] validationSet = new double[4][2];
|
||||
|
||||
validationSet[0] = new double[] { 0, 0 };
|
||||
validationSet[1] = new double[] { 0, 1 };
|
||||
validationSet[2] = new double[] { 1, 0 };
|
||||
validationSet[3] = new double[] { 1, 1 };
|
||||
|
||||
System.out.println("Output from eval set (learned network, pre-serialization):");
|
||||
testNetwork(nnLearner, validationSet, inputNames, outputNames);
|
||||
|
||||
FileOutputStream fos = new FileOutputStream(FILENAME);
|
||||
assertTrue(nnLearner.save(fos));
|
||||
fos.close();
|
||||
|
||||
FileInputStream fis = new FileInputStream(FILENAME);
|
||||
assertTrue(nnLearner.load(fis));
|
||||
fis.close();
|
||||
|
||||
System.out.println("Output from eval set (learned network, post-serialization):");
|
||||
testNetwork(nnLearner, validationSet, inputNames, outputNames);
|
||||
}
|
||||
|
||||
private void testNetwork(NeuralNetFilter nnLearner, double[][] validationSet, String[] inputNames, String[] outputNames) {
|
||||
for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
|
||||
NNDataPair dp = new NNDataPair(new NNData(inputNames,validationSet[valIndex]), new NNData(outputNames,validationSet[valIndex]));
|
||||
System.out.println(dp + " => " + nnLearner.compute(dp));
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user