Fixes for ANN serialization and unit tests.

2014-12-30 17:29:52 -05:00
parent 14bc769493
commit 70c7d7d9b1
8 changed files with 21 additions and 15 deletions

View File

@@ -122,7 +122,6 @@ public class MultiLayerPerceptron extends FeedforwardNetwork {
 JAXBContext jc = JAXBContext
         .newInstance(MultiLayerPerceptron.class);
-// unmarshal from foo.xml
 Unmarshaller u = jc.createUnmarshaller();
 MultiLayerPerceptron mlp = (MultiLayerPerceptron) u.unmarshal(is);
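
For context, this is the load path the commit touches. A minimal round-trip sketch, assuming MultiLayerPerceptron is annotated as a JAXB root element and reusing the (true, 2, 2, 1) constructor seen in the tests below; the class name and in-memory streams here are illustrative, not from this commit:

    import java.io.StringReader;
    import java.io.StringWriter;
    import javax.xml.bind.JAXBContext;
    import javax.xml.bind.JAXBException;
    import javax.xml.bind.Unmarshaller;

    public class RoundTripSketch {
        public static void main(String[] args) throws JAXBException {
            JAXBContext jc = JAXBContext.newInstance(MultiLayerPerceptron.class);

            // Marshal a freshly constructed network to XML in memory...
            MultiLayerPerceptron original = new MultiLayerPerceptron(true, 2, 2, 1);
            StringWriter xml = new StringWriter();
            jc.createMarshaller().marshal(original, xml);

            // ...then unmarshal it back, as the patched method does from its input stream.
            Unmarshaller u = jc.createUnmarshaller();
            MultiLayerPerceptron restored =
                    (MultiLayerPerceptron) u.unmarshal(new StringReader(xml.toString()));
            System.out.println("restored network: " + restored);
        }
    }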

View File

@@ -2,12 +2,16 @@ package net.woodyfolsom.msproj.ann;
 import javax.xml.bind.annotation.XmlAttribute;
 import javax.xml.bind.annotation.XmlElement;
+import javax.xml.bind.annotation.XmlElements;
 import javax.xml.bind.annotation.XmlTransient;
 import net.woodyfolsom.msproj.ann.math.ActivationFunction;
+import net.woodyfolsom.msproj.ann.math.Linear;
 import net.woodyfolsom.msproj.ann.math.Sigmoid;
+import net.woodyfolsom.msproj.ann.math.Tanh;
 public class Neuron {
     private ActivationFunction activationFunction;
     private int id;
     private transient double input = 0.0;
@@ -26,7 +30,12 @@ public class Neuron {
         input += value;
     }
-    @XmlElement(type=Sigmoid.class)
+    @XmlElements({
+        @XmlElement(name="LinearActivationFunction",type=Linear.class),
+        @XmlElement(name="SigmoidActivationFunction",type=Sigmoid.class),
+        @XmlElement(name="TanhActivationFunction",type=Tanh.class)
+    })
     public ActivationFunction getActivationFunction() {
         return activationFunction;
     }
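
This @XmlElements block is the core of the serialization fix: binding the property to Sigmoid alone meant Linear and Tanh neurons could not round-trip. With one @XmlElement entry per concrete subclass, JAXB writes a distinctly named element for each function type and knows which class to instantiate on unmarshalling. The XML for a sigmoid neuron would then look roughly like this (the element nesting and attribute names are an assumption based on the fields visible above, not captured output):

    <neuron id="3">
        <SigmoidActivationFunction name="sigmoid"/>
    </neuron>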

View File

@@ -1,7 +1,9 @@
 package net.woodyfolsom.msproj.ann.math;
 import javax.xml.bind.annotation.XmlAttribute;
+import javax.xml.bind.annotation.XmlTransient;
+@XmlTransient
 public abstract class ActivationFunction {
     private String name;
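
Marking the abstract base class @XmlTransient removes ActivationFunction itself from the binding, so JAXB never tries to map or instantiate the abstract type and instead folds its properties into each concrete subclass. A sketch of the presumed shape of the class after this hunk; the abstract method signatures are hypothetical, since only derivative(double) is visible anywhere in this commit:

    import javax.xml.bind.annotation.XmlAttribute;
    import javax.xml.bind.annotation.XmlTransient;

    // Excluded from JAXB binding; each concrete function is bound as its own type.
    @XmlTransient
    public abstract class ActivationFunction {
        private String name;

        @XmlAttribute
        public String getName() { return name; }
        public void setName(String name) { this.name = name; }

        // Hypothetical signatures, not taken from this diff.
        public abstract double activate(double arg);
        public abstract double derivative(double arg);
    }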

View File

@@ -1,5 +1,7 @@
 package net.woodyfolsom.msproj.ann.math;
+import javax.xml.bind.annotation.XmlRootElement;
 public class Linear extends ActivationFunction{
     public static final Linear function = new Linear();

View File

@@ -1,5 +1,7 @@
 package net.woodyfolsom.msproj.ann.math;
+import javax.xml.bind.annotation.XmlRootElement;
 public class Sigmoid extends ActivationFunction{
     public static final Sigmoid function = new Sigmoid();
@@ -12,9 +14,7 @@ public class Sigmoid extends ActivationFunction{
     }
     public double derivative(double arg) {
-        //lol wth? oh, the next derivative formula is a function of s(x), not x.
         double eX = Math.exp(arg);
         return eX / (Math.pow((1+eX), 2));
-        //return arg - Math.pow(arg,2);
     }
 }
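
The two deleted comments are worth decoding: the derivative of the sigmoid s(x) = 1/(1 + e^-x) can be written either as a function of x, namely e^x/(1 + e^x)^2 (the form this method keeps), or as s(x)(1 - s(x)), which is what the commented-out arg - arg^2 computed when arg already held s(x). A quick illustrative check that the two forms agree:

    public class SigmoidDerivativeCheck {
        public static void main(String[] args) {
            double x = 0.7;
            double s = 1.0 / (1.0 + Math.exp(-x));                       // s(x)
            double viaS = s * (1.0 - s);                                 // s(x) * (1 - s(x))
            double viaX = Math.exp(x) / Math.pow(1.0 + Math.exp(x), 2);  // form kept in the patch
            System.out.println(viaS + " == " + viaX);                    // both ~= 0.2217
        }
    }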

View File

@@ -1,5 +1,7 @@
 package net.woodyfolsom.msproj.ann.math;
+import javax.xml.bind.annotation.XmlRootElement;
 public class Tanh extends ActivationFunction{
     public static final Tanh function = new Tanh();
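
Linear, Sigmoid, and Tanh each gain the same import in this commit. Only the import is visible in these hunks, but the natural reading is that each concrete function is now annotated as a root element whose name matches the bindings declared on Neuron above. A sketch for Tanh under that assumption; the annotation's name value and the activate method are inferred, not shown in the diff:

    import javax.xml.bind.annotation.XmlRootElement;

    @XmlRootElement(name = "TanhActivationFunction")
    public class Tanh extends ActivationFunction {
        public static final Tanh function = new Tanh();

        public double activate(double arg) {
            return Math.tanh(arg);
        }

        public double derivative(double arg) {
            double t = Math.tanh(arg);
            return 1.0 - t * t; // d/dx tanh(x) = 1 - tanh^2(x)
        }
    }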

View File

@@ -4,7 +4,6 @@ import java.util.ArrayList;
 import java.util.List;
 import net.woodyfolsom.msproj.GameBoard;
-import net.woodyfolsom.msproj.GameConfig;
 import net.woodyfolsom.msproj.GameResult;
 import net.woodyfolsom.msproj.GameState;
 import net.woodyfolsom.msproj.Player;

View File

@@ -70,10 +70,10 @@ public class MultiLayerPerceptronTest {
     @Test
     public void testCompute() {
         FeedforwardNetwork mlp = new MultiLayerPerceptron(true, 2, 2, 1);
-        NNDataPair expected = new NNDataPair(new NNData(new String[]{"x","y"}, new double[]{0.0,0.0}),new NNData(new String[]{"xor"}, new double[]{0.0}));
+        NNDataPair expected = new NNDataPair(new NNData(new String[]{"x","y"}, new double[]{0.0,0.0}),new NNData(new String[]{"xor"}, new double[]{0.5}));
         NNDataPair actual = new NNDataPair(new NNData(new String[]{"x","y"}, new double[]{0.0,0.0}),new NNData(new String[]{"xor"}, new double[]{0.5}));
         NNData actualOutput = mlp.compute(actual);
-        assertEquals(expected.getIdeal(), actualOutput);
+        assertEquals(expected.getIdeal().getValues()[0], actualOutput.getValues()[0], EPS);
     }
     @Test
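
The new expected value of 0.5 makes sense if the untrained network starts from zero or symmetric weights, which is an assumption here: a sigmoid unit with zero net input emits s(0) = 1/(1 + e^0) = 0.5. The assertEquals change is the actual test fix, comparing output values with a tolerance instead of comparing an NNData object against a raw expectation. A one-line check of the arithmetic:

    public class SigmoidAtZero {
        public static void main(String[] args) {
            // s(0) = 1 / (1 + e^0) = 0.5, the output of a sigmoid unit with zero net input.
            System.out.println(1.0 / (1.0 + Math.exp(-0.0))); // prints 0.5
        }
    }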
@@ -87,16 +87,9 @@ public class MultiLayerPerceptronTest {
         for (Connection connection : mlp.getConnections()) {
             System.out.println(connection);
         }
-        NNDataPair expected = new NNDataPair(new NNData(new String[]{"x","y"}, new double[]{0.0,0.0}),new NNData(new String[]{"xor"}, new double[]{0.367610}));
+        NNDataPair expected = new NNDataPair(new NNData(new String[]{"x","y"}, new double[]{0.0,0.0}),new NNData(new String[]{"xor"}, new double[]{0.263932}));
         NNDataPair actual = new NNDataPair(new NNData(new String[]{"x","y"}, new double[]{0.0,0.0}),new NNData(new String[]{"xor"}, new double[]{0.0}));
         NNData actualOutput = mlp.compute(actual);
         assertEquals(expected.getIdeal().getValues()[0], actualOutput.getValues()[0], EPS);
     }
-    /**
-     *
-    Hidden Neuron 1: w2(0,1) = 0.341232 w2(1,1) = 0.129952 w2(2,1) =-0.923123
-    Hidden Neuron 2: w2(0,2) =-0.115223 w2(1,2) = 0.570345 w2(2,2) =-0.328932
-    Output Neuron: w3(0,1) =-0.993423 w3(1,1) = 0.164732 w3(2,1) = 0.752621
-     */
 }