Implementing temporal difference learning, based heavily on the Encog framework.

Not functional yet; this is an incremental update.
2012-11-21 10:03:56 -05:00
parent 49d3b2c242
commit b723e2666e
35 changed files with 1471 additions and 470 deletions
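
Since the commit message only names the technique, here is a minimal, self-contained TD(lambda) value-update sketch for orientation. Everything in it (the toy 5-state chain, alpha, gamma, lambda, the terminal reward) is an illustrative assumption and is not taken from this repository or from the Encog API.

    // Sketch only: tabular TD(lambda) on a toy 5-state chain.
    public final class TdLambdaSketch {
        public static void main(String[] args) {
            final double alpha = 0.1;   // learning rate (assumed value)
            final double gamma = 1.0;   // discount factor (episodic, undiscounted)
            final double lambda = 0.7;  // eligibility-trace decay (assumed value)

            double[] value = new double[5];     // V(s), initialized to 0
            int[] episode = { 0, 1, 2, 3, 4 };  // states visited; state 4 is terminal
            double terminalReward = 1.0;        // reward only on the final transition

            for (int ep = 0; ep < 100; ep++) {
                double[] trace = new double[5]; // eligibility traces, reset per episode
                for (int t = 0; t < episode.length - 1; t++) {
                    int s = episode[t];
                    int sNext = episode[t + 1];
                    boolean terminal = (t == episode.length - 2);
                    double reward = terminal ? terminalReward : 0.0;

                    // TD error: delta = r + gamma * V(s') - V(s), with V(terminal) = 0.
                    double nextValue = terminal ? 0.0 : value[sNext];
                    double delta = reward + gamma * nextValue - value[s];

                    trace[s] += 1.0; // accumulating trace for the visited state
                    for (int i = 0; i < value.length; i++) {
                        value[i] += alpha * delta * trace[i];
                        trace[i] *= gamma * lambda; // decay all traces each step
                    }
                }
            }
            // With gamma = 1, every non-terminal state converges toward the
            // terminal reward of 1.0.
            System.out.println(java.util.Arrays.toString(value));
        }
    }
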

View File

@@ -1,86 +0,0 @@
package net.woodyfolsom.msproj.ann;

import java.io.File;
import java.util.Arrays;

import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.learning.SupervisedTrainingElement;
import org.neuroph.core.learning.TrainingSet;

public class NeuralNetLearnerTest {
    private static final String FILENAME = "myMlPerceptron.nnet";

    @AfterClass
    public static void deleteNewNet() {
        File file = new File(FILENAME);
        if (file.exists()) {
            file.delete();
        }
    }

    @BeforeClass
    public static void deleteSavedNet() {
        File file = new File(FILENAME);
        if (file.exists()) {
            file.delete();
        }
    }

    @Test
    public void testLearnSaveLoad() {
        NeuralNetLearner nnLearner = new XORLearner();

        // Create the training set (logical XOR function), repeated 1000 times.
        TrainingSet<SupervisedTrainingElement> trainingSet =
                new TrainingSet<SupervisedTrainingElement>(2, 1);
        for (int x = 0; x < 1000; x++) {
            trainingSet.addElement(new SupervisedTrainingElement(
                    new double[] { 0, 0 }, new double[] { 0 }));
            trainingSet.addElement(new SupervisedTrainingElement(
                    new double[] { 0, 1 }, new double[] { 1 }));
            trainingSet.addElement(new SupervisedTrainingElement(
                    new double[] { 1, 0 }, new double[] { 1 }));
            trainingSet.addElement(new SupervisedTrainingElement(
                    new double[] { 1, 1 }, new double[] { 0 }));
        }
        nnLearner.learn(trainingSet);
        NeuralNetwork nnet = nnLearner.getNeuralNetwork();

        // One copy of each XOR case for evaluation.
        TrainingSet<SupervisedTrainingElement> valSet =
                new TrainingSet<SupervisedTrainingElement>(2, 1);
        valSet.addElement(new SupervisedTrainingElement(
                new double[] { 0, 0 }, new double[] { 0 }));
        valSet.addElement(new SupervisedTrainingElement(
                new double[] { 0, 1 }, new double[] { 1 }));
        valSet.addElement(new SupervisedTrainingElement(
                new double[] { 1, 0 }, new double[] { 1 }));
        valSet.addElement(new SupervisedTrainingElement(
                new double[] { 1, 1 }, new double[] { 0 }));

        System.out.println("Output from eval set (learned network):");
        testNetwork(nnet, valSet);

        // Round-trip the network through disk and evaluate again.
        nnet.save(FILENAME);
        nnet = NeuralNetwork.load(FILENAME);
        System.out.println("Output from eval set (loaded network):");
        testNetwork(nnet, valSet);
    }

    private void testNetwork(NeuralNetwork nnet,
            TrainingSet<SupervisedTrainingElement> trainingSet) {
        for (SupervisedTrainingElement trainingElement : trainingSet.elements()) {
            nnet.setInput(trainingElement.getInput());
            nnet.calculate();
            double[] networkOutput = nnet.getOutput();
            System.out.print("Input: " + Arrays.toString(trainingElement.getInput()));
            System.out.println(" Output: " + Arrays.toString(networkOutput));
        }
    }
}

View File

@@ -1,51 +0,0 @@
package net.woodyfolsom.msproj.ann;

import static org.junit.Assert.assertTrue;

import org.junit.Test;
import org.neuroph.core.NeuralNetwork;

public class PassNetworkTest {
    @Test
    public void testSavedNetwork1() {
        NeuralNetwork passFilter = NeuralNetwork.load("data/networks/Pass1.nn");
        passFilter.setInput(0.75, 0.25);
        passFilter.calculate();
        PassData passData = new PassData();
        double[] output = passFilter.getOutput();
        System.out.println("Output: " + passData.getOutput(output));
        assertTrue(output[0] > 0.50);
        assertTrue(output[1] < 0.50);

        passFilter.setInput(0.25, 0.50);
        passFilter.calculate();
        output = passFilter.getOutput();
        System.out.println("Output: " + passData.getOutput(output));
        assertTrue(output[0] < 0.50);
        assertTrue(output[1] > 0.50);
    }

    @Test
    public void testSavedNetwork2() {
        NeuralNetwork passFilter = NeuralNetwork.load("data/networks/Pass2.nn");
        passFilter.setInput(0.75, 0.25);
        passFilter.calculate();
        PassData passData = new PassData();
        double[] output = passFilter.getOutput();
        System.out.println("Output: " + passData.getOutput(output));
        assertTrue(output[0] > 0.50);
        assertTrue(output[1] < 0.50);

        passFilter.setInput(0.45, 0.55);
        passFilter.calculate();
        output = passFilter.getOutput();
        System.out.println("Output: " + passData.getOutput(output));
        assertTrue(output[0] < 0.50);
        assertTrue(output[1] > 0.50);
    }
}

View File

@@ -0,0 +1,66 @@
package net.woodyfolsom.msproj.ann;

import java.io.File;
import java.io.FileFilter;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import net.woodyfolsom.msproj.GameRecord;
import net.woodyfolsom.msproj.Referee;

import org.antlr.runtime.RecognitionException;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.junit.Test;

public class WinFilterTest {
    @Test
    public void testLearnSaveLoad() throws IOException, RecognitionException {
        File[] sgfFiles = new File("data/games/random_vs_random")
                .listFiles(new FileFilter() {
                    @Override
                    public boolean accept(File pathname) {
                        return pathname.getName().endsWith(".sgf");
                    }
                });

        // Replay each SGF game and collect one training sequence per game:
        // the initial position plus one state per turn (hence <=).
        Set<List<MLDataPair>> trainingData = new HashSet<List<MLDataPair>>();
        for (File file : sgfFiles) {
            FileInputStream fis = new FileInputStream(file);
            GameRecord gameRecord = Referee.replay(fis);
            List<MLDataPair> gameData = new ArrayList<MLDataPair>();
            for (int i = 0; i <= gameRecord.getNumTurns(); i++) {
                gameData.add(new GameStateMLDataPair(gameRecord.getGameState(i)));
            }
            trainingData.add(gameData);
            fis.close();
        }

        WinFilter winFilter = new WinFilter();
        winFilter.learn(trainingData);

        // Spot-check the learned value function on the first and last state
        // of each sequence; intermediate states are skipped.
        for (List<MLDataPair> trainingSequence : trainingData) {
            for (int stateIndex = 0; stateIndex < trainingSequence.size(); stateIndex++) {
                if (stateIndex > 0 && stateIndex < trainingSequence.size() - 1) {
                    continue;
                }
                MLData input = trainingSequence.get(stateIndex).getInput();
                System.out.println("Turn " + stateIndex + ": " + input + " => "
                        + winFilter.computeValue(input));
            }
        }
    }
}

View File

@@ -0,0 +1,79 @@
package net.woodyfolsom.msproj.ann;

import java.io.File;
import java.io.IOException;

import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;

public class XORFilterTest {
    private static final String FILENAME = "xorPerceptron.net";

    @AfterClass
    public static void deleteNewNet() {
        File file = new File(FILENAME);
        if (file.exists()) {
            file.delete();
        }
    }

    @BeforeClass
    public static void deleteSavedNet() {
        File file = new File(FILENAME);
        if (file.exists()) {
            file.delete();
        }
    }

    @Test
    public void testLearnSaveLoad() throws IOException {
        NeuralNetFilter nnLearner = new XORFilter();

        // Create the training set (logical XOR function); size controls how
        // many copies of the four cases are generated.
        int size = 1;
        double[][] trainingInput = new double[4 * size][];
        double[][] trainingOutput = new double[4 * size][];
        for (int i = 0; i < size; i++) {
            trainingInput[i * 4 + 0] = new double[] { 0, 0 };
            trainingInput[i * 4 + 1] = new double[] { 0, 1 };
            trainingInput[i * 4 + 2] = new double[] { 1, 0 };
            trainingInput[i * 4 + 3] = new double[] { 1, 1 };
            trainingOutput[i * 4 + 0] = new double[] { 0 };
            trainingOutput[i * 4 + 1] = new double[] { 1 };
            trainingOutput[i * 4 + 2] = new double[] { 1 };
            trainingOutput[i * 4 + 3] = new double[] { 0 };
        }

        MLDataSet trainingSet = new BasicMLDataSet(trainingInput, trainingOutput);
        nnLearner.learn(trainingSet);
        // Report epoch count after training, not before.
        System.out.println("Learned network after "
                + nnLearner.getActualTrainingEpochs() + " training epochs.");

        double[][] validationSet = new double[4][2];
        validationSet[0] = new double[] { 0, 0 };
        validationSet[1] = new double[] { 0, 1 };
        validationSet[2] = new double[] { 1, 0 };
        validationSet[3] = new double[] { 1, 1 };

        System.out.println("Output from eval set (learned network, pre-serialization):");
        testNetwork(nnLearner, validationSet);

        nnLearner.save(FILENAME);
        nnLearner.load(FILENAME);
        System.out.println("Output from eval set (learned network, post-serialization):");
        testNetwork(nnLearner, validationSet);
    }

    private void testNetwork(NeuralNetFilter nnLearner, double[][] validationSet) {
        for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
            DoublePair dp = new DoublePair(validationSet[valIndex][0], validationSet[valIndex][1]);
            System.out.println(dp + " => " + nnLearner.computeValue(dp));
        }
    }
}
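
The XORFilter and NeuralNetFilter classes themselves are not part of this diff. As a sketch of what such a filter could wrap, the following uses standard Encog 3 calls (BasicNetwork, ResilientPropagation, EncogDirectoryPersistence); the class name, layer sizes, stopping criterion, and computeValue signature are assumptions for illustration, not the repository's actual code.

    import java.io.File;

    import org.encog.engine.network.activation.ActivationSigmoid;
    import org.encog.ml.data.MLData;
    import org.encog.ml.data.MLDataSet;
    import org.encog.ml.data.basic.BasicMLData;
    import org.encog.neural.networks.BasicNetwork;
    import org.encog.neural.networks.layers.BasicLayer;
    import org.encog.neural.networks.training.propagation.resilient.ResilientPropagation;
    import org.encog.persist.EncogDirectoryPersistence;

    // Sketch only: the real XORFilter is not shown in this diff.
    public class XorFilterSketch {
        private BasicNetwork network;
        private int actualTrainingEpochs;

        public XorFilterSketch() {
            // 2 inputs -> 3 hidden sigmoid units -> 1 sigmoid output, with biases.
            network = new BasicNetwork();
            network.addLayer(new BasicLayer(null, true, 2));
            network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 3));
            network.addLayer(new BasicLayer(new ActivationSigmoid(), false, 1));
            network.getStructure().finalizeStructure();
            network.reset(); // randomize weights
        }

        public void learn(MLDataSet trainingSet) {
            ResilientPropagation train = new ResilientPropagation(network, trainingSet);
            do {
                train.iteration();
                actualTrainingEpochs++;
            } while (train.getError() > 0.01 && actualTrainingEpochs < 5000);
            train.finishTraining();
        }

        public int getActualTrainingEpochs() {
            return actualTrainingEpochs;
        }

        public double computeValue(double a, double b) {
            MLData output = network.compute(new BasicMLData(new double[] { a, b }));
            return output.getData(0);
        }

        public void save(String filename) {
            EncogDirectoryPersistence.saveObject(new File(filename), network);
        }

        public void load(String filename) {
            network = (BasicNetwork) EncogDirectoryPersistence.loadObject(new File(filename));
        }
    }

Resilient propagation is used here only because it is a common Encog default for small feedforward networks; the actual filter may train differently, and the DoublePair type the test passes to computeValue is likewise not shown in this commit.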