From 70c7d7d9b1c35770120e83fe97ec1d80fd8c208e Mon Sep 17 00:00:00 2001
From: Woody Folsom
Date: Tue, 30 Dec 2014 17:29:52 -0500
Subject: [PATCH] Fixes for ANN serialization and unit tests.

---
 .../msproj/ann/MultiLayerPerceptron.java         |  1 -
 src/net/woodyfolsom/msproj/ann/Neuron.java       | 11 ++++++++++-
 .../msproj/ann/math/ActivationFunction.java      |  2 ++
 src/net/woodyfolsom/msproj/ann/math/Linear.java  |  2 ++
 src/net/woodyfolsom/msproj/ann/math/Sigmoid.java |  4 ++--
 src/net/woodyfolsom/msproj/ann/math/Tanh.java    |  2 ++
 .../msproj/tictactoe/NNDataSetFactory.java       |  1 -
 .../msproj/ann/MultiLayerPerceptronTest.java     | 13 +++----------
 8 files changed, 21 insertions(+), 15 deletions(-)

diff --git a/src/net/woodyfolsom/msproj/ann/MultiLayerPerceptron.java b/src/net/woodyfolsom/msproj/ann/MultiLayerPerceptron.java
index 1871a2f..d9ec824 100644
--- a/src/net/woodyfolsom/msproj/ann/MultiLayerPerceptron.java
+++ b/src/net/woodyfolsom/msproj/ann/MultiLayerPerceptron.java
@@ -122,7 +122,6 @@ public class MultiLayerPerceptron extends FeedforwardNetwork {
 			JAXBContext jc = JAXBContext
 					.newInstance(MultiLayerPerceptron.class);
 
-			// unmarshal from foo.xml
 			Unmarshaller u = jc.createUnmarshaller();
 			MultiLayerPerceptron mlp = (MultiLayerPerceptron) u.unmarshal(is);
 
diff --git a/src/net/woodyfolsom/msproj/ann/Neuron.java b/src/net/woodyfolsom/msproj/ann/Neuron.java
index c7f9c70..7d889f7 100644
--- a/src/net/woodyfolsom/msproj/ann/Neuron.java
+++ b/src/net/woodyfolsom/msproj/ann/Neuron.java
@@ -2,12 +2,16 @@ package net.woodyfolsom.msproj.ann;
 
 import javax.xml.bind.annotation.XmlAttribute;
 import javax.xml.bind.annotation.XmlElement;
+import javax.xml.bind.annotation.XmlElements;
 import javax.xml.bind.annotation.XmlTransient;
 
 import net.woodyfolsom.msproj.ann.math.ActivationFunction;
+import net.woodyfolsom.msproj.ann.math.Linear;
 import net.woodyfolsom.msproj.ann.math.Sigmoid;
+import net.woodyfolsom.msproj.ann.math.Tanh;
 
 public class Neuron {
+
 	private ActivationFunction activationFunction;
 	private int id;
 	private transient double input = 0.0;
@@ -26,7 +30,12 @@ public class Neuron {
 		input += value;
 	}
 
-	@XmlElement(type=Sigmoid.class)
+
+	@XmlElements({
+		@XmlElement(name="LinearActivationFunction",type=Linear.class),
+		@XmlElement(name="SigmoidActivationFunction",type=Sigmoid.class),
+		@XmlElement(name="TanhActivationFunction",type=Tanh.class)
+	})
 	public ActivationFunction getActivationFunction() {
 		return activationFunction;
 	}
diff --git a/src/net/woodyfolsom/msproj/ann/math/ActivationFunction.java b/src/net/woodyfolsom/msproj/ann/math/ActivationFunction.java
index c941739..55c71b6 100644
--- a/src/net/woodyfolsom/msproj/ann/math/ActivationFunction.java
+++ b/src/net/woodyfolsom/msproj/ann/math/ActivationFunction.java
@@ -1,7 +1,9 @@
 package net.woodyfolsom.msproj.ann.math;
 
 import javax.xml.bind.annotation.XmlAttribute;
+import javax.xml.bind.annotation.XmlTransient;
 
+@XmlTransient
 public abstract class ActivationFunction {
 	private String name;
 
diff --git a/src/net/woodyfolsom/msproj/ann/math/Linear.java b/src/net/woodyfolsom/msproj/ann/math/Linear.java
index c5de8e8..e652636 100644
--- a/src/net/woodyfolsom/msproj/ann/math/Linear.java
+++ b/src/net/woodyfolsom/msproj/ann/math/Linear.java
@@ -1,5 +1,7 @@
 package net.woodyfolsom.msproj.ann.math;
 
+import javax.xml.bind.annotation.XmlRootElement;
+
 public class Linear extends ActivationFunction{
 
 	public static final Linear function = new Linear();
diff --git a/src/net/woodyfolsom/msproj/ann/math/Sigmoid.java b/src/net/woodyfolsom/msproj/ann/math/Sigmoid.java
index de1e168..9400c5b 100644
--- a/src/net/woodyfolsom/msproj/ann/math/Sigmoid.java
+++ b/src/net/woodyfolsom/msproj/ann/math/Sigmoid.java
@@ -1,5 +1,7 @@
 package net.woodyfolsom.msproj.ann.math;
 
+import javax.xml.bind.annotation.XmlRootElement;
+
 public class Sigmoid extends ActivationFunction{
 
 	public static final Sigmoid function = new Sigmoid();
@@ -12,9 +14,7 @@ public class Sigmoid extends ActivationFunction{
 	}
 
 	public double derivative(double arg) {
-		//lol wth? oh, the next derivative formula is a function of s(x), not x.
 		double eX = Math.exp(arg);
 		return eX / (Math.pow((1+eX), 2));
-		//return arg - Math.pow(arg,2);
 	}
 }
\ No newline at end of file
diff --git a/src/net/woodyfolsom/msproj/ann/math/Tanh.java b/src/net/woodyfolsom/msproj/ann/math/Tanh.java
index 2ab3ec4..3f8ba0b 100644
--- a/src/net/woodyfolsom/msproj/ann/math/Tanh.java
+++ b/src/net/woodyfolsom/msproj/ann/math/Tanh.java
@@ -1,5 +1,7 @@
 package net.woodyfolsom.msproj.ann.math;
 
+import javax.xml.bind.annotation.XmlRootElement;
+
 public class Tanh extends ActivationFunction{
 
 	public static final Tanh function = new Tanh();
diff --git a/src/net/woodyfolsom/msproj/tictactoe/NNDataSetFactory.java b/src/net/woodyfolsom/msproj/tictactoe/NNDataSetFactory.java
index 1bb4b15..55f3ad4 100644
--- a/src/net/woodyfolsom/msproj/tictactoe/NNDataSetFactory.java
+++ b/src/net/woodyfolsom/msproj/tictactoe/NNDataSetFactory.java
@@ -4,7 +4,6 @@ import java.util.ArrayList;
 import java.util.List;
 
 import net.woodyfolsom.msproj.GameBoard;
-import net.woodyfolsom.msproj.GameConfig;
 import net.woodyfolsom.msproj.GameResult;
 import net.woodyfolsom.msproj.GameState;
 import net.woodyfolsom.msproj.Player;
diff --git a/test/net/woodyfolsom/msproj/ann/MultiLayerPerceptronTest.java b/test/net/woodyfolsom/msproj/ann/MultiLayerPerceptronTest.java
index f6d840b..b58d3ec 100644
--- a/test/net/woodyfolsom/msproj/ann/MultiLayerPerceptronTest.java
+++ b/test/net/woodyfolsom/msproj/ann/MultiLayerPerceptronTest.java
@@ -70,10 +70,10 @@ public class MultiLayerPerceptronTest {
 	@Test
 	public void testCompute() {
 		FeedforwardNetwork mlp = new MultiLayerPerceptron(true, 2, 2, 1);
-		NNDataPair expected = new NNDataPair(new NNData(new String[]{"x","y"}, new double[]{0.0,0.0}),new NNData(new String[]{"xor"}, new double[]{0.0}));
+		NNDataPair expected = new NNDataPair(new NNData(new String[]{"x","y"}, new double[]{0.0,0.0}),new NNData(new String[]{"xor"}, new double[]{0.5}));
 		NNDataPair actual = new NNDataPair(new NNData(new String[]{"x","y"}, new double[]{0.0,0.0}),new NNData(new String[]{"xor"}, new double[]{0.5}));
 		NNData actualOutput = mlp.compute(actual);
-		assertEquals(expected.getIdeal(), actualOutput);
+		assertEquals(expected.getIdeal().getValues()[0], actualOutput.getValues()[0], EPS);
 	}
 
 	@Test
@@ -87,16 +87,9 @@ public class MultiLayerPerceptronTest {
 		for (Connection connection : mlp.getConnections()) {
 			System.out.println(connection);
 		}
-		NNDataPair expected = new NNDataPair(new NNData(new String[]{"x","y"}, new double[]{0.0,0.0}),new NNData(new String[]{"xor"}, new double[]{0.367610}));
+		NNDataPair expected = new NNDataPair(new NNData(new String[]{"x","y"}, new double[]{0.0,0.0}),new NNData(new String[]{"xor"}, new double[]{0.263932}));
 		NNDataPair actual = new NNDataPair(new NNData(new String[]{"x","y"}, new double[]{0.0,0.0}),new NNData(new String[]{"xor"}, new double[]{0.0}));
 		NNData actualOutput = mlp.compute(actual);
 		assertEquals(expected.getIdeal().getValues()[0], actualOutput.getValues()[0], EPS);
 	}
-	/**
-	 * 
-Hidden Neuron 1: w2(0,1) = 0.341232 w2(1,1) = 0.129952 w2(2,1) =-0.923123
-Hidden Neuron 2: w2(0,2) =-0.115223 w2(1,2) = 0.570345 w2(2,2) =-0.328932
-Output Neuron: w3(0,1) =-0.993423 w3(1,1) = 0.164732 w3(2,1) = 0.752621
-
-	 */
 }
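
Note on the serialization fix (an illustrative sketch, not code from this
repository): the old binding, @XmlElement(type=Sigmoid.class), forced every
serialized activationFunction to be read back as a Sigmoid, so networks built
with Linear or Tanh units could not round-trip through XML. @XmlElements gives
the unmarshaller one element name per concrete subtype, and @XmlTransient
keeps the abstract base class itself out of the binding. A minimal,
self-contained JAXB example of the same pattern, using hypothetical class
names (Activation, LinearFn, TanhFn, Node, RoundTrip):

    import java.io.StringReader;
    import java.io.StringWriter;
    import javax.xml.bind.JAXBContext;
    import javax.xml.bind.annotation.*;

    public class RoundTrip {

        @XmlTransient                 // abstract base stays unbound
        public static abstract class Activation {}

        public static class LinearFn extends Activation {}
        public static class TanhFn extends Activation {}

        @XmlRootElement
        public static class Node {
            private Activation activation = new TanhFn();

            // One element name per concrete subtype lets the unmarshaller
            // instantiate the right class from the tag alone.
            @XmlElements({
                @XmlElement(name = "LinearActivationFunction", type = LinearFn.class),
                @XmlElement(name = "TanhActivationFunction", type = TanhFn.class)
            })
            public Activation getActivation() { return activation; }
            public void setActivation(Activation a) { this.activation = a; }
        }

        public static void main(String[] args) throws Exception {
            JAXBContext jc = JAXBContext.newInstance(Node.class);
            StringWriter xml = new StringWriter();
            jc.createMarshaller().marshal(new Node(), xml);
            Node copy = (Node) jc.createUnmarshaller()
                    .unmarshal(new StringReader(xml.toString()));
            // Prints "TanhFn": the subtype survived the round trip.
            System.out.println(copy.getActivation().getClass().getSimpleName());
        }
    }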
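
Note on the derivative cleanup in Sigmoid.java (a standalone check, not
project code): the retained formula e^x / (1 + e^x)^2 is the sigmoid
derivative expressed in terms of the raw input x. The deleted alternative,
arg - Math.pow(arg, 2), computes s - s^2 = s(1 - s), which is the same
quantity only when the argument is already the sigmoid output s(x) - the
confusion the removed "lol wth?" comment was working through. The two forms
agree because s(x) = e^x / (1 + e^x) and 1 - s(x) = 1 / (1 + e^x):

    public class SigmoidDerivativeCheck {
        static double sigmoid(double x) {
            return 1.0 / (1.0 + Math.exp(-x));
        }

        public static void main(String[] args) {
            for (double x : new double[] {-2.0, -0.5, 0.0, 1.0, 3.0}) {
                // Derivative in terms of the raw input x...
                double viaInput = Math.exp(x) / Math.pow(1.0 + Math.exp(x), 2);
                // ...and in terms of the output s(x); both columns match.
                double s = sigmoid(x);
                double viaOutput = s * (1.0 - s);
                System.out.printf("x=%5.2f  %.6f  %.6f%n", x, viaInput, viaOutput);
            }
        }
    }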