Starting my own ANN implementation.
This commit is contained in:
@@ -50,7 +50,6 @@ public class WinFilterTest {
|
||||
winFilter.learn(trainingData);
|
||||
|
||||
for (List<MLDataPair> trainingSequence : trainingData) {
|
||||
//for (MLDataPair mlDataPair : trainingSequence) {
|
||||
for (int stateIndex = 0; stateIndex < trainingSequence.size(); stateIndex++) {
|
||||
if (stateIndex > 0 && stateIndex < trainingSequence.size()-1) {
|
||||
continue;
|
||||
@@ -58,9 +57,8 @@ public class WinFilterTest {
|
||||
MLData input = trainingSequence.get(stateIndex).getInput();
|
||||
|
||||
System.out.println("Turn " + stateIndex + ": " + input + " => "
|
||||
+ winFilter.computeValue(input));
|
||||
+ winFilter.compute(input));
|
||||
}
|
||||
//}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -73,7 +73,7 @@ public class XORFilterTest {
|
||||
private void testNetwork(NeuralNetFilter nnLearner, double[][] validationSet) {
|
||||
for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
|
||||
DoublePair dp = new DoublePair(validationSet[valIndex][0],validationSet[valIndex][1]);
|
||||
System.out.println(dp + " => " + nnLearner.computeValue(dp));
|
||||
System.out.println(dp + " => " + nnLearner.compute(dp));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
package net.woodyfolsom.msproj.ann2;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
|
||||
import javax.xml.bind.JAXBException;
|
||||
|
||||
import org.junit.AfterClass;
|
||||
import org.junit.BeforeClass;
|
||||
import org.junit.Test;
|
||||
|
||||
public class MultiLayerPerceptronTest {
|
||||
static final File TEST_FILE = new File("data/test/mlp.net");
|
||||
|
||||
@BeforeClass
|
||||
public static void setUp() {
|
||||
if (TEST_FILE.exists()) {
|
||||
TEST_FILE.delete();
|
||||
}
|
||||
}
|
||||
|
||||
@AfterClass
|
||||
public static void tearDown() {
|
||||
if (TEST_FILE.exists()) {
|
||||
TEST_FILE.delete();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testConstructor() {
|
||||
new MultiLayerPerceptron(true, 2, 4, 1);
|
||||
new MultiLayerPerceptron(false, 2, 1);
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void testConstructorTooFewLayers() {
|
||||
new MultiLayerPerceptron(true, 2);
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void testConstructorTooFewNeurons() {
|
||||
new MultiLayerPerceptron(true, 2, 4, 0, 1);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPersistence() throws JAXBException, IOException {
|
||||
NeuralNetwork mlp = new MultiLayerPerceptron(true, 2, 4, 1);
|
||||
FileOutputStream fos = new FileOutputStream(TEST_FILE);
|
||||
assertTrue(mlp.save(fos));
|
||||
fos.close();
|
||||
FileInputStream fis = new FileInputStream(TEST_FILE);
|
||||
NeuralNetwork mlp2 = new MultiLayerPerceptron();
|
||||
assertTrue(mlp2.load(fis));
|
||||
assertEquals(mlp, mlp2);
|
||||
fis.close();
|
||||
}
|
||||
}
|
||||
18
test/net/woodyfolsom/msproj/ann2/SigmoidTest.java
Normal file
18
test/net/woodyfolsom/msproj/ann2/SigmoidTest.java
Normal file
@@ -0,0 +1,18 @@
|
||||
package net.woodyfolsom.msproj.ann2;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
public class SigmoidTest {
|
||||
@Test
|
||||
public void testCalculate() {
|
||||
double EPS = 0.001;
|
||||
|
||||
ActivationFunction sigmoid = Sigmoid.function;
|
||||
assertEquals(0.5,sigmoid.calculate(0.0),EPS);
|
||||
assertTrue(sigmoid.calculate(100.0) > 1.0 - EPS);
|
||||
assertTrue(sigmoid.calculate(-9000.0) < EPS);
|
||||
}
|
||||
}
|
||||
18
test/net/woodyfolsom/msproj/ann2/TanhTest.java
Normal file
18
test/net/woodyfolsom/msproj/ann2/TanhTest.java
Normal file
@@ -0,0 +1,18 @@
|
||||
package net.woodyfolsom.msproj.ann2;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
public class TanhTest {
|
||||
@Test
|
||||
public void testCalculate() {
|
||||
double EPS = 0.001;
|
||||
|
||||
ActivationFunction sigmoid = new Tanh();
|
||||
assertEquals(0.0,sigmoid.calculate(0.0),EPS);
|
||||
assertTrue(sigmoid.calculate(100.0) > 0.5 - EPS);
|
||||
assertTrue(sigmoid.calculate(-9000.0) < -0.5+EPS);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user