Implements temporal-difference learning, based heavily on the Encog framework.
Not yet functional — incremental work-in-progress commit.
This commit is contained in:
@@ -1,86 +0,0 @@
|
||||
package net.woodyfolsom.msproj.ann;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.Arrays;
|
||||
|
||||
import org.junit.AfterClass;
|
||||
import org.junit.BeforeClass;
|
||||
import org.junit.Test;
|
||||
import org.neuroph.core.NeuralNetwork;
|
||||
import org.neuroph.core.learning.SupervisedTrainingElement;
|
||||
import org.neuroph.core.learning.TrainingSet;
|
||||
|
||||
public class NeuralNetLearnerTest {
|
||||
private static final String FILENAME = "myMlPerceptron.nnet";
|
||||
|
||||
@AfterClass
|
||||
public static void deleteNewNet() {
|
||||
File file = new File(FILENAME);
|
||||
if (file.exists()) {
|
||||
file.delete();
|
||||
}
|
||||
}
|
||||
|
||||
@BeforeClass
|
||||
public static void deleteSavedNet() {
|
||||
File file = new File(FILENAME);
|
||||
if (file.exists()) {
|
||||
file.delete();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLearnSaveLoad() {
|
||||
NeuralNetLearner nnLearner = new XORLearner();
|
||||
|
||||
// create training set (logical XOR function)
|
||||
TrainingSet<SupervisedTrainingElement> trainingSet = new TrainingSet<SupervisedTrainingElement>(
|
||||
2, 1);
|
||||
for (int x = 0; x < 1000; x++) {
|
||||
trainingSet.addElement(new SupervisedTrainingElement(new double[] { 0,
|
||||
0 }, new double[] { 0 }));
|
||||
trainingSet.addElement(new SupervisedTrainingElement(new double[] { 0,
|
||||
1 }, new double[] { 1 }));
|
||||
trainingSet.addElement(new SupervisedTrainingElement(new double[] { 1,
|
||||
0 }, new double[] { 1 }));
|
||||
trainingSet.addElement(new SupervisedTrainingElement(new double[] { 1,
|
||||
1 }, new double[] { 0 }));
|
||||
}
|
||||
|
||||
nnLearner.learn(trainingSet);
|
||||
NeuralNetwork nnet = nnLearner.getNeuralNetwork();
|
||||
|
||||
TrainingSet<SupervisedTrainingElement> valSet = new TrainingSet<SupervisedTrainingElement>(
|
||||
2, 1);
|
||||
valSet.addElement(new SupervisedTrainingElement(new double[] { 0,
|
||||
0 }, new double[] { 0 }));
|
||||
valSet.addElement(new SupervisedTrainingElement(new double[] { 0,
|
||||
1 }, new double[] { 1 }));
|
||||
valSet.addElement(new SupervisedTrainingElement(new double[] { 1,
|
||||
0 }, new double[] { 1 }));
|
||||
valSet.addElement(new SupervisedTrainingElement(new double[] { 1,
|
||||
1 }, new double[] { 0 }));
|
||||
|
||||
System.out.println("Output from eval set (learned network):");
|
||||
testNetwork(nnet, valSet);
|
||||
|
||||
nnet.save(FILENAME);
|
||||
nnet = NeuralNetwork.load(FILENAME);
|
||||
|
||||
System.out.println("Output from eval set (learned network):");
|
||||
testNetwork(nnet, valSet);
|
||||
}
|
||||
|
||||
private void testNetwork(NeuralNetwork nnet, TrainingSet<SupervisedTrainingElement> trainingSet) {
|
||||
for (SupervisedTrainingElement trainingElement : trainingSet.elements()) {
|
||||
|
||||
nnet.setInput(trainingElement.getInput());
|
||||
nnet.calculate();
|
||||
double[] networkOutput = nnet.getOutput();
|
||||
System.out.print("Input: "
|
||||
+ Arrays.toString(trainingElement.getInput()));
|
||||
System.out.println(" Output: " + Arrays.toString(networkOutput));
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,51 +0,0 @@
|
||||
package net.woodyfolsom.msproj.ann;
|
||||
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.neuroph.core.NeuralNetwork;
|
||||
|
||||
public class PassNetworkTest {
|
||||
|
||||
@Test
|
||||
public void testSavedNetwork1() {
|
||||
NeuralNetwork passFilter = NeuralNetwork.load("data/networks/Pass1.nn");
|
||||
passFilter.setInput(0.75,0.25);
|
||||
passFilter.calculate();
|
||||
|
||||
PassData passData = new PassData();
|
||||
double[] output = passFilter.getOutput();
|
||||
System.out.println("Output: " + passData.getOutput(output));
|
||||
|
||||
assertTrue(output[0] > 0.50);
|
||||
assertTrue(output[1] < 0.50);
|
||||
|
||||
passFilter.setInput(0.25,0.50);
|
||||
passFilter.calculate();
|
||||
output = passFilter.getOutput();
|
||||
System.out.println("Output: " + passData.getOutput(output));
|
||||
assertTrue(output[0] < 0.50);
|
||||
assertTrue(output[1] > 0.50);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSavedNetwork2() {
|
||||
NeuralNetwork passFilter = NeuralNetwork.load("data/networks/Pass2.nn");
|
||||
passFilter.setInput(0.75,0.25);
|
||||
passFilter.calculate();
|
||||
|
||||
PassData passData = new PassData();
|
||||
double[] output = passFilter.getOutput();
|
||||
System.out.println("Output: " + passData.getOutput(output));
|
||||
|
||||
assertTrue(output[0] > 0.50);
|
||||
assertTrue(output[1] < 0.50);
|
||||
|
||||
passFilter.setInput(0.45,0.55);
|
||||
passFilter.calculate();
|
||||
output = passFilter.getOutput();
|
||||
System.out.println("Output: " + passData.getOutput(output));
|
||||
assertTrue(output[0] < 0.50);
|
||||
assertTrue(output[1] > 0.50);
|
||||
}
|
||||
}
|
||||
66
test/net/woodyfolsom/msproj/ann/WinFilterTest.java
Normal file
66
test/net/woodyfolsom/msproj/ann/WinFilterTest.java
Normal file
@@ -0,0 +1,66 @@
|
||||
package net.woodyfolsom.msproj.ann;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileFilter;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
import net.woodyfolsom.msproj.GameRecord;
|
||||
import net.woodyfolsom.msproj.Referee;
|
||||
|
||||
import org.antlr.runtime.RecognitionException;
|
||||
import org.encog.ml.data.MLData;
|
||||
import org.encog.ml.data.MLDataPair;
|
||||
import org.junit.Test;
|
||||
|
||||
public class WinFilterTest {
|
||||
|
||||
@Test
|
||||
public void testLearnSaveLoad() throws IOException, RecognitionException {
|
||||
File[] sgfFiles = new File("data/games/random_vs_random")
|
||||
.listFiles(new FileFilter() {
|
||||
@Override
|
||||
public boolean accept(File pathname) {
|
||||
return pathname.getName().endsWith(".sgf");
|
||||
}
|
||||
});
|
||||
|
||||
Set<List<MLDataPair>> trainingData = new HashSet<List<MLDataPair>>();
|
||||
|
||||
for (File file : sgfFiles) {
|
||||
FileInputStream fis = new FileInputStream(file);
|
||||
GameRecord gameRecord = Referee.replay(fis);
|
||||
|
||||
List<MLDataPair> gameData = new ArrayList<MLDataPair>();
|
||||
for (int i = 0; i <= gameRecord.getNumTurns(); i++) {
|
||||
gameData.add(new GameStateMLDataPair(gameRecord.getGameState(i)));
|
||||
}
|
||||
|
||||
trainingData.add(gameData);
|
||||
|
||||
fis.close();
|
||||
}
|
||||
|
||||
WinFilter winFilter = new WinFilter();
|
||||
|
||||
winFilter.learn(trainingData);
|
||||
|
||||
for (List<MLDataPair> trainingSequence : trainingData) {
|
||||
//for (MLDataPair mlDataPair : trainingSequence) {
|
||||
for (int stateIndex = 0; stateIndex < trainingSequence.size(); stateIndex++) {
|
||||
if (stateIndex > 0 && stateIndex < trainingSequence.size()-1) {
|
||||
continue;
|
||||
}
|
||||
MLData input = trainingSequence.get(stateIndex).getInput();
|
||||
|
||||
System.out.println("Turn " + stateIndex + ": " + input + " => "
|
||||
+ winFilter.computeValue(input));
|
||||
}
|
||||
//}
|
||||
}
|
||||
}
|
||||
}
|
||||
79
test/net/woodyfolsom/msproj/ann/XORFilterTest.java
Normal file
79
test/net/woodyfolsom/msproj/ann/XORFilterTest.java
Normal file
@@ -0,0 +1,79 @@
|
||||
package net.woodyfolsom.msproj.ann;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
|
||||
import org.encog.ml.data.MLDataSet;
|
||||
import org.encog.ml.data.basic.BasicMLDataSet;
|
||||
import org.junit.AfterClass;
|
||||
import org.junit.BeforeClass;
|
||||
import org.junit.Test;
|
||||
|
||||
public class XORFilterTest {
|
||||
private static final String FILENAME = "xorPerceptron.net";
|
||||
|
||||
@AfterClass
|
||||
public static void deleteNewNet() {
|
||||
File file = new File(FILENAME);
|
||||
if (file.exists()) {
|
||||
file.delete();
|
||||
}
|
||||
}
|
||||
|
||||
@BeforeClass
|
||||
public static void deleteSavedNet() {
|
||||
File file = new File(FILENAME);
|
||||
if (file.exists()) {
|
||||
file.delete();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLearnSaveLoad() throws IOException {
|
||||
NeuralNetFilter nnLearner = new XORFilter();
|
||||
System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");
|
||||
|
||||
// create training set (logical XOR function)
|
||||
int size = 1;
|
||||
double[][] trainingInput = new double[4 * size][];
|
||||
double[][] trainingOutput = new double[4 * size][];
|
||||
for (int i = 0; i < size; i++) {
|
||||
trainingInput[i * 4 + 0] = new double[] { 0, 0 };
|
||||
trainingInput[i * 4 + 1] = new double[] { 0, 1 };
|
||||
trainingInput[i * 4 + 2] = new double[] { 1, 0 };
|
||||
trainingInput[i * 4 + 3] = new double[] { 1, 1 };
|
||||
trainingOutput[i * 4 + 0] = new double[] { 0 };
|
||||
trainingOutput[i * 4 + 1] = new double[] { 1 };
|
||||
trainingOutput[i * 4 + 2] = new double[] { 1 };
|
||||
trainingOutput[i * 4 + 3] = new double[] { 0 };
|
||||
}
|
||||
|
||||
// create training data
|
||||
MLDataSet trainingSet = new BasicMLDataSet(trainingInput, trainingOutput);
|
||||
|
||||
nnLearner.learn(trainingSet);
|
||||
|
||||
double[][] validationSet = new double[4][2];
|
||||
|
||||
validationSet[0] = new double[] { 0, 0 };
|
||||
validationSet[1] = new double[] { 0, 1 };
|
||||
validationSet[2] = new double[] { 1, 0 };
|
||||
validationSet[3] = new double[] { 1, 1 };
|
||||
|
||||
System.out.println("Output from eval set (learned network, pre-serialization):");
|
||||
testNetwork(nnLearner, validationSet);
|
||||
|
||||
nnLearner.save(FILENAME);
|
||||
nnLearner.load(FILENAME);
|
||||
|
||||
System.out.println("Output from eval set (learned network, post-serialization):");
|
||||
testNetwork(nnLearner, validationSet);
|
||||
}
|
||||
|
||||
private void testNetwork(NeuralNetFilter nnLearner, double[][] validationSet) {
|
||||
for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
|
||||
DoublePair dp = new DoublePair(validationSet[valIndex][0],validationSet[valIndex][1]);
|
||||
System.out.println(dp + " => " + nnLearner.computeValue(dp));
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user