Neural network refactoring and new tests: merge the ann2 package into ann, convert XORFilterTest from Encog's MLDataSet to NNDataPair lists with stream-based save/load, drop the old Encog-based WinFilterTest, and add TTTFilterTest (trained from random tic-tac-toe self-play), GameRecordTest, and RefereeTest.

2012-11-27 17:33:09 -05:00
parent 790b5666a8
commit 214bdcd032
55 changed files with 1507 additions and 821 deletions

View File

@@ -1,4 +1,4 @@
-package net.woodyfolsom.msproj.ann2;
+package net.woodyfolsom.msproj.ann;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
@@ -10,6 +10,12 @@ import java.io.IOException;
import javax.xml.bind.JAXBException;
import net.woodyfolsom.msproj.ann.Connection;
import net.woodyfolsom.msproj.ann.FeedforwardNetwork;
import net.woodyfolsom.msproj.ann.MultiLayerPerceptron;
import net.woodyfolsom.msproj.ann.NNData;
import net.woodyfolsom.msproj.ann.NNDataPair;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;

View File

@@ -0,0 +1,100 @@
package net.woodyfolsom.msproj.ann;
import java.io.File;
import java.io.IOException;
import java.util.List;
import net.woodyfolsom.msproj.ann.NNData;
import net.woodyfolsom.msproj.ann.NNDataPair;
import net.woodyfolsom.msproj.ann.NeuralNetFilter;
import net.woodyfolsom.msproj.ann.TTTFilter;
import net.woodyfolsom.msproj.tictactoe.GameRecord;
import net.woodyfolsom.msproj.tictactoe.NNDataSetFactory;
import net.woodyfolsom.msproj.tictactoe.Referee;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
public class TTTFilterTest {
private static final String FILENAME = "tttPerceptron.net";
@AfterClass
public static void deleteNewNet() {
File file = new File(FILENAME);
if (file.exists()) {
file.delete();
}
}
@BeforeClass
public static void deleteSavedNet() {
File file = new File(FILENAME);
if (file.exists()) {
file.delete();
}
}
@Test
public void testLearn() throws IOException {
double alpha = 0.5;
double lambda = 0.0;
int maxEpochs = 1000;
NeuralNetFilter nnLearner = new TTTFilter(alpha, lambda, maxEpochs);
// Create trainingSet from a tournament of random games.
// Future iterations will use Epsilon-greedy play from a policy based on
// this network to generate additional datasets.
List<GameRecord> tournament = new Referee().play(1);
List<List<NNDataPair>> trainingSet = NNDataSetFactory
.createDataSet(tournament);
System.out.println("Generated " + trainingSet.size()
+ " datasets from random self-play.");
nnLearner.learnSequences(trainingSet);
System.out.println("Learned network after "
+ nnLearner.getActualTrainingEpochs() + " training epochs.");
double[][] validationSet = new double[7][];
// empty board
validationSet[0] = new double[] { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0 };
// center
validationSet[1] = new double[] { 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,
0.0, 0.0 };
// top edge
validationSet[2] = new double[] { 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0 };
// left edge
validationSet[3] = new double[] { 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0,
0.0, 0.0 };
// corner
validationSet[4] = new double[] { 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0 };
// win
validationSet[5] = new double[] { 1.0, 1.0, 1.0, -1.0, -1.0, 0.0, 0.0,
-1.0, 0.0 };
// loss
validationSet[6] = new double[] { -1.0, 1.0, 0.0, 1.0, -1.0, 1.0, 0.0,
0.0, -1.0 };
String[] inputNames = new String[] { "00", "01", "02", "10", "11",
"12", "20", "21", "22" };
String[] outputNames = new String[] { "values" };
System.out.println("Output from eval set (learned network):");
testNetwork(nnLearner, validationSet, inputNames, outputNames);
}
private void testNetwork(NeuralNetFilter nnLearner,
double[][] validationSet, String[] inputNames, String[] outputNames) {
for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
NNDataPair dp = new NNDataPair(new NNData(inputNames,
validationSet[valIndex]), new NNData(outputNames,
validationSet[valIndex]));
System.out.println(dp + " => " + nnLearner.compute(dp));
}
}
}
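A note on the conventions this test leans on: the nine inputs are the board cells in row-major order ("00" through "22"), and the validation vectors appear to encode an X as 1.0, an O as -1.0, and an empty cell as 0.0, with alpha and lambda presumably the TD learning rate and trace-decay parameters. A minimal sketch of that encoding, using a hypothetical BoardEncoding helper that is not part of this commit:

// Hypothetical helper (not in this commit) illustrating the encoding the validation
// vectors assume: X -> 1.0, O -> -1.0, empty -> 0.0, cells flattened row-major "00".."22".
final class BoardEncoding {
    static double[] encode(char[][] board) {
        double[] encoded = new double[9];
        for (int row = 0; row < 3; row++) {
            for (int col = 0; col < 3; col++) {
                char cell = board[row][col];
                encoded[row * 3 + col] = cell == 'X' ? 1.0 : cell == 'O' ? -1.0 : 0.0;
            }
        }
        return encoded;
    }
}

// For example, the "win" vector above decodes to the board
//   X X X
//   O O .
//   . O .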

View File

@@ -1,64 +0,0 @@
package net.woodyfolsom.msproj.ann;
import java.io.File;
import java.io.FileFilter;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import net.woodyfolsom.msproj.GameRecord;
import net.woodyfolsom.msproj.Referee;
import org.antlr.runtime.RecognitionException;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.junit.Test;
public class WinFilterTest {
@Test
public void testLearnSaveLoad() throws IOException, RecognitionException {
File[] sgfFiles = new File("data/games/random_vs_random")
.listFiles(new FileFilter() {
@Override
public boolean accept(File pathname) {
return pathname.getName().endsWith(".sgf");
}
});
Set<List<MLDataPair>> trainingData = new HashSet<List<MLDataPair>>();
for (File file : sgfFiles) {
FileInputStream fis = new FileInputStream(file);
GameRecord gameRecord = Referee.replay(fis);
List<MLDataPair> gameData = new ArrayList<MLDataPair>();
for (int i = 0; i <= gameRecord.getNumTurns(); i++) {
gameData.add(new GameStateMLDataPair(gameRecord.getGameState(i)));
}
trainingData.add(gameData);
fis.close();
}
WinFilter winFilter = new WinFilter();
winFilter.learn(trainingData);
for (List<MLDataPair> trainingSequence : trainingData) {
for (int stateIndex = 0; stateIndex < trainingSequence.size(); stateIndex++) {
if (stateIndex > 0 && stateIndex < trainingSequence.size()-1) {
continue;
}
MLData input = trainingSequence.get(stateIndex).getInput();
System.out.println("Turn " + stateIndex + ": " + input + " => "
+ winFilter.compute(input));
}
}
}
}

View File

@@ -1,10 +1,19 @@
package net.woodyfolsom.msproj.ann;
import java.io.File;
import java.io.IOException;
import static org.junit.Assert.assertTrue;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import net.woodyfolsom.msproj.ann.NNData;
import net.woodyfolsom.msproj.ann.NNDataPair;
import net.woodyfolsom.msproj.ann.NeuralNetFilter;
import net.woodyfolsom.msproj.ann.XORFilter;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
@@ -29,14 +38,55 @@ public class XORFilterTest {
}
@Test
public void testLearnSaveLoad() throws IOException {
NeuralNetFilter nnLearner = new XORFilter();
System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");
public void testLearn() throws IOException {
NeuralNetFilter nnLearner = new XORFilter(0.5,0.0);
// create training set (logical XOR function)
int size = 1;
double[][] trainingInput = new double[4 * size][];
double[][] trainingOutput = new double[4 * size][];
for (int i = 0; i < size; i++) {
trainingInput[i * 4 + 0] = new double[] { 0, 0 };
trainingInput[i * 4 + 1] = new double[] { 0, 1 };
trainingInput[i * 4 + 2] = new double[] { 1, 0 };
trainingInput[i * 4 + 3] = new double[] { 1, 1 };
trainingOutput[i * 4 + 0] = new double[] { 0 };
trainingOutput[i * 4 + 1] = new double[] { 1 };
trainingOutput[i * 4 + 2] = new double[] { 1 };
trainingOutput[i * 4 + 3] = new double[] { 0 };
}
// create training data
List<NNDataPair> trainingSet = new ArrayList<NNDataPair>();
String[] inputNames = new String[] {"x","y"};
String[] outputNames = new String[] {"XOR"};
for (int i = 0; i < 4*size; i++) {
trainingSet.add(new NNDataPair(new NNData(inputNames,trainingInput[i]),new NNData(outputNames,trainingOutput[i])));
}
nnLearner.setMaxTrainingEpochs(20000);
nnLearner.learnPatterns(trainingSet);
System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");
double[][] validationSet = new double[4][2];
validationSet[0] = new double[] { 0, 0 };
validationSet[1] = new double[] { 0, 1 };
validationSet[2] = new double[] { 1, 0 };
validationSet[3] = new double[] { 1, 1 };
System.out.println("Output from eval set (learned network):");
testNetwork(nnLearner, validationSet, inputNames, outputNames);
}
@Test
public void testLearnSaveLoad() throws IOException {
NeuralNetFilter nnLearner = new XORFilter(0.5,0.0);
// create training set (logical XOR function)
int size = 2;
double[][] trainingInput = new double[4 * size][];
double[][] trainingOutput = new double[4 * size][];
for (int i = 0; i < size; i++) {
trainingInput[i * 4 + 0] = new double[] { 0, 0 };
trainingInput[i * 4 + 1] = new double[] { 0, 1 };
@@ -49,10 +99,17 @@ public class XORFilterTest {
}
// create training data
MLDataSet trainingSet = new BasicMLDataSet(trainingInput, trainingOutput);
List<NNDataPair> trainingSet = new ArrayList<NNDataPair>();
String[] inputNames = new String[] {"x","y"};
String[] outputNames = new String[] {"XOR"};
for (int i = 0; i < 4*size; i++) {
trainingSet.add(new NNDataPair(new NNData(inputNames,trainingInput[i]),new NNData(outputNames,trainingOutput[i])));
}
nnLearner.setMaxTrainingEpochs(1);
nnLearner.learnPatterns(trainingSet);
System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");
nnLearner.learn(trainingSet);
double[][] validationSet = new double[4][2];
validationSet[0] = new double[] { 0, 0 };
@@ -61,18 +118,23 @@ public class XORFilterTest {
validationSet[3] = new double[] { 1, 1 };
System.out.println("Output from eval set (learned network, pre-serialization):");
testNetwork(nnLearner, validationSet);
nnLearner.save(FILENAME);
nnLearner.load(FILENAME);
testNetwork(nnLearner, validationSet, inputNames, outputNames);
FileOutputStream fos = new FileOutputStream(FILENAME);
assertTrue(nnLearner.save(fos));
fos.close();
FileInputStream fis = new FileInputStream(FILENAME);
assertTrue(nnLearner.load(fis));
fis.close();
System.out.println("Output from eval set (learned network, post-serialization):");
testNetwork(nnLearner, validationSet);
testNetwork(nnLearner, validationSet, inputNames, outputNames);
}
private void testNetwork(NeuralNetFilter nnLearner, double[][] validationSet) {
private void testNetwork(NeuralNetFilter nnLearner, double[][] validationSet, String[] inputNames, String[] outputNames) {
for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
DoublePair dp = new DoublePair(validationSet[valIndex][0],validationSet[valIndex][1]);
NNDataPair dp = new NNDataPair(new NNData(inputNames,validationSet[valIndex]), new NNData(outputNames,validationSet[valIndex]));
System.out.println(dp + " => " + nnLearner.compute(dp));
}
}
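The diff also replaces Encog's MLDataSet/BasicMLDataSet with the project's own NNDataPair lists and splits training into two entry points: learnPatterns for independent input/target pairs (used here), and learnSequences for ordered game sequences (used in TTTFilterTest above). A hypothetical usage sketch, assuming only the constructor and method signatures visible in these tests:

import java.io.IOException;
import java.util.List;

// Sketch only; contrasts the two training entry points visible in these tests.
class TrainingEntryPoints {
    static void example(List<NNDataPair> xorPairs, List<List<NNDataPair>> games) throws IOException {
        NeuralNetFilter xorFilter = new XORFilter(0.5, 0.0);
        xorFilter.setMaxTrainingEpochs(20000);
        xorFilter.learnPatterns(xorPairs);   // each pair is a standalone training example

        NeuralNetFilter tttFilter = new TTTFilter(0.5, 0.0, 1000);
        tttFilter.learnSequences(games);     // each inner list is one game, in move order
    }
}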

View File

@@ -1,11 +1,11 @@
-package net.woodyfolsom.msproj.ann2;
+package net.woodyfolsom.msproj.ann.math;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
-import net.woodyfolsom.msproj.ann2.math.ActivationFunction;
-import net.woodyfolsom.msproj.ann2.math.Sigmoid;
-import net.woodyfolsom.msproj.ann2.math.Tanh;
+import net.woodyfolsom.msproj.ann.math.ActivationFunction;
+import net.woodyfolsom.msproj.ann.math.Sigmoid;
+import net.woodyfolsom.msproj.ann.math.Tanh;
import org.junit.Test;

View File

@@ -1,10 +1,10 @@
-package net.woodyfolsom.msproj.ann2;
+package net.woodyfolsom.msproj.ann.math;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
-import net.woodyfolsom.msproj.ann2.math.ActivationFunction;
-import net.woodyfolsom.msproj.ann2.math.Tanh;
+import net.woodyfolsom.msproj.ann.math.ActivationFunction;
+import net.woodyfolsom.msproj.ann.math.Tanh;
import org.junit.Test;

View File

@@ -1,136 +0,0 @@
package net.woodyfolsom.msproj.ann2;
import static org.junit.Assert.assertTrue;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
public class XORFilterTest {
private static final String FILENAME = "xorPerceptron.net";
@AfterClass
public static void deleteNewNet() {
File file = new File(FILENAME);
if (file.exists()) {
file.delete();
}
}
@BeforeClass
public static void deleteSavedNet() {
File file = new File(FILENAME);
if (file.exists()) {
file.delete();
}
}
@Test
public void testLearn() throws IOException {
NeuralNetFilter nnLearner = new XORFilter(0.05,0.0);
// create training set (logical XOR function)
int size = 1;
double[][] trainingInput = new double[4 * size][];
double[][] trainingOutput = new double[4 * size][];
for (int i = 0; i < size; i++) {
trainingInput[i * 4 + 0] = new double[] { 0, 0 };
trainingInput[i * 4 + 1] = new double[] { 0, 1 };
trainingInput[i * 4 + 2] = new double[] { 1, 0 };
trainingInput[i * 4 + 3] = new double[] { 1, 1 };
trainingOutput[i * 4 + 0] = new double[] { 0 };
trainingOutput[i * 4 + 1] = new double[] { 1 };
trainingOutput[i * 4 + 2] = new double[] { 1 };
trainingOutput[i * 4 + 3] = new double[] { 0 };
}
// create training data
List<NNDataPair> trainingSet = new ArrayList<NNDataPair>();
String[] inputNames = new String[] {"x","y"};
String[] outputNames = new String[] {"XOR"};
for (int i = 0; i < 4*size; i++) {
trainingSet.add(new NNDataPair(new NNData(inputNames,trainingInput[i]),new NNData(outputNames,trainingOutput[i])));
}
nnLearner.setMaxTrainingEpochs(20000);
nnLearner.learn(trainingSet);
System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");
double[][] validationSet = new double[4][2];
validationSet[0] = new double[] { 0, 0 };
validationSet[1] = new double[] { 0, 1 };
validationSet[2] = new double[] { 1, 0 };
validationSet[3] = new double[] { 1, 1 };
System.out.println("Output from eval set (learned network):");
testNetwork(nnLearner, validationSet, inputNames, outputNames);
}
@Test
public void testLearnSaveLoad() throws IOException {
NeuralNetFilter nnLearner = new XORFilter(0.5,0.0);
// create training set (logical XOR function)
int size = 2;
double[][] trainingInput = new double[4 * size][];
double[][] trainingOutput = new double[4 * size][];
for (int i = 0; i < size; i++) {
trainingInput[i * 4 + 0] = new double[] { 0, 0 };
trainingInput[i * 4 + 1] = new double[] { 0, 1 };
trainingInput[i * 4 + 2] = new double[] { 1, 0 };
trainingInput[i * 4 + 3] = new double[] { 1, 1 };
trainingOutput[i * 4 + 0] = new double[] { 0 };
trainingOutput[i * 4 + 1] = new double[] { 1 };
trainingOutput[i * 4 + 2] = new double[] { 1 };
trainingOutput[i * 4 + 3] = new double[] { 0 };
}
// create training data
List<NNDataPair> trainingSet = new ArrayList<NNDataPair>();
String[] inputNames = new String[] {"x","y"};
String[] outputNames = new String[] {"XOR"};
for (int i = 0; i < 4*size; i++) {
trainingSet.add(new NNDataPair(new NNData(inputNames,trainingInput[i]),new NNData(outputNames,trainingOutput[i])));
}
nnLearner.setMaxTrainingEpochs(1);
nnLearner.learn(trainingSet);
System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");
double[][] validationSet = new double[4][2];
validationSet[0] = new double[] { 0, 0 };
validationSet[1] = new double[] { 0, 1 };
validationSet[2] = new double[] { 1, 0 };
validationSet[3] = new double[] { 1, 1 };
System.out.println("Output from eval set (learned network, pre-serialization):");
testNetwork(nnLearner, validationSet, inputNames, outputNames);
FileOutputStream fos = new FileOutputStream(FILENAME);
assertTrue(nnLearner.save(fos));
fos.close();
FileInputStream fis = new FileInputStream(FILENAME);
assertTrue(nnLearner.load(fis));
fis.close();
System.out.println("Output from eval set (learned network, post-serialization):");
testNetwork(nnLearner, validationSet, inputNames, outputNames);
}
private void testNetwork(NeuralNetFilter nnLearner, double[][] validationSet, String[] inputNames, String[] outputNames) {
for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
NNDataPair dp = new NNDataPair(new NNData(inputNames,validationSet[valIndex]), new NNData(outputNames,validationSet[valIndex]));
System.out.println(dp + " => " + nnLearner.compute(dp));
}
}
}

View File

@@ -0,0 +1,73 @@
package net.woodyfolsom.msproj.tictactoe;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import org.junit.Test;
import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;
public class GameRecordTest {
@Test
public void testGetResultXwins() {
GameRecord gameRecord = new GameRecord();
gameRecord.apply(Action.getInstance(PLAYER.X, 1, 0));
gameRecord.apply(Action.getInstance(PLAYER.O, 0, 0));
gameRecord.apply(Action.getInstance(PLAYER.X, 1, 1));
gameRecord.apply(Action.getInstance(PLAYER.O, 0, 1));
gameRecord.apply(Action.getInstance(PLAYER.X, 1, 2));
State finalState = gameRecord.getState();
System.out.println("Final state:");
System.out.println(finalState);
assertTrue(finalState.isValid());
assertTrue(finalState.isTerminal());
assertTrue(finalState.isWinner(PLAYER.X));
assertEquals(GameRecord.RESULT.X_WINS,gameRecord.getResult());
}
@Test
public void testGetResultOwins() {
GameRecord gameRecord = new GameRecord();
gameRecord.apply(Action.getInstance(PLAYER.X, 0, 0));
gameRecord.apply(Action.getInstance(PLAYER.O, 0, 2));
gameRecord.apply(Action.getInstance(PLAYER.X, 0, 1));
gameRecord.apply(Action.getInstance(PLAYER.O, 1, 1));
gameRecord.apply(Action.getInstance(PLAYER.X, 1, 0));
gameRecord.apply(Action.getInstance(PLAYER.O, 2, 0));
State finalState = gameRecord.getState();
System.out.println("Final state:");
System.out.println(finalState);
assertTrue(finalState.isValid());
assertTrue(finalState.isTerminal());
assertTrue(finalState.isWinner(PLAYER.O));
assertEquals(GameRecord.RESULT.O_WINS,gameRecord.getResult());
}
@Test
public void testGetResultTieGame() {
GameRecord gameRecord = new GameRecord();
gameRecord.apply(Action.getInstance(PLAYER.X, 0, 0));
gameRecord.apply(Action.getInstance(PLAYER.O, 0, 2));
gameRecord.apply(Action.getInstance(PLAYER.X, 0, 1));
gameRecord.apply(Action.getInstance(PLAYER.O, 1, 0));
gameRecord.apply(Action.getInstance(PLAYER.X, 1, 2));
gameRecord.apply(Action.getInstance(PLAYER.O, 1, 1));
gameRecord.apply(Action.getInstance(PLAYER.X, 2, 0));
gameRecord.apply(Action.getInstance(PLAYER.O, 2, 2));
gameRecord.apply(Action.getInstance(PLAYER.X, 2, 1));
State finalState = gameRecord.getState();
System.out.println("Final state:");
System.out.println(finalState);
assertTrue(finalState.isValid());
assertTrue(finalState.isTerminal());
assertFalse(finalState.isWinner(PLAYER.X));
assertFalse(finalState.isWinner(PLAYER.O));
assertEquals(GameRecord.RESULT.TIE_GAME,gameRecord.getResult());
}
}
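For reference, the three tests above drive the game to these final positions, assuming Action.getInstance(player, row, col) uses (row, column) indexing with (0,0) at the top left:

// testGetResultXwins        testGetResultOwins        testGetResultTieGame
//   O O .                     X X O                     X X O
//   X X X                     X O .                     O O X
//   . . .                     O . .                     X X O
// X wins the middle row.    O wins the anti-diagonal.  No three in a row; tie game.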

View File

@@ -0,0 +1,12 @@
package net.woodyfolsom.msproj.tictactoe;
import org.junit.Test;
public class RefereeTest {
@Test
public void testPlay100Games() {
new Referee().play(100);
}
}
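As written, testPlay100Games only verifies that 100 games of random self-play complete without throwing. A sketch of a slightly stronger assertion, assuming Referee.play(int) returns the List<GameRecord> used in TTTFilterTest (one record per game) and that GameRecord.getResult() returns the RESULT values exercised in GameRecordTest:

// Sketch only; relies on the assumptions stated above, plus static imports of
// org.junit.Assert.assertEquals/assertTrue and an import of java.util.List.
@Test
public void testPlay100GamesProducesResults() {
    List<GameRecord> games = new Referee().play(100);
    assertEquals(100, games.size());     // assumes one GameRecord per game played
    for (GameRecord game : games) {
        GameRecord.RESULT result = game.getResult();
        assertTrue(result == GameRecord.RESULT.X_WINS
                || result == GameRecord.RESULT.O_WINS
                || result == GameRecord.RESULT.TIE_GAME);
    }
}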