Beginning to implement neural net training. Fixed bugs in SGF/LaTeX export.

This commit is contained in:
2012-11-15 15:00:40 -05:00
parent e01060bc11
commit ca37280ed8
17 changed files with 470 additions and 156 deletions

View File

@@ -0,0 +1,42 @@
package net.woodyfolsom.msproj;
import static org.junit.Assert.assertEquals;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import org.antlr.runtime.RecognitionException;
import org.junit.Test;
public class RefereeTest {
//private static final String FILENAME = "data/tourney1/gogame-121115115720-0500.sgf";
private static final String FILENAME = "data/games/1334-gokifu-20120916-Gu_Li-Lee_Sedol.sgf";
@Test
public void testReplay() throws IOException, RecognitionException {
File sgfFile = new File(FILENAME);
InputStream fis = new FileInputStream(sgfFile);
GameRecord gameRecord = Referee.replay(fis);
fis.close();
//assertEquals(9, gameRecord.getGameConfig().getSize());
assertEquals(19, gameRecord.getGameConfig().getSize());
//assertEquals(5.5, gameRecord.getGameConfig().getKomi(), 0.1);
assertEquals(7.5, gameRecord.getGameConfig().getKomi(), 0.1);
//assertEquals(74, gameRecord.getNumTurns());
assertEquals(214, gameRecord.getNumTurns());
System.out.println("Final board state (LaTeX): ");
LatexWriter.write(System.out, gameRecord, 214);
System.out.println("Final board state (SGF): ");
SGFWriter.write(System.out, gameRecord);
}
}

View File

@@ -0,0 +1,86 @@
package net.woodyfolsom.msproj.ann;
import java.io.File;
import java.util.Arrays;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.learning.SupervisedTrainingElement;
import org.neuroph.core.learning.TrainingSet;
public class NeuralNetLearnerTest {
private static final String FILENAME = "myMlPerceptron.nnet";
@AfterClass
public static void deleteNewNet() {
File file = new File(FILENAME);
if (file.exists()) {
file.delete();
}
}
@BeforeClass
public static void deleteSavedNet() {
File file = new File(FILENAME);
if (file.exists()) {
file.delete();
}
}
@Test
public void testLearnSaveLoad() {
NeuralNetLearner nnLearner = new XORLearner();
// create training set (logical XOR function)
TrainingSet<SupervisedTrainingElement> trainingSet = new TrainingSet<SupervisedTrainingElement>(
2, 1);
for (int x = 0; x < 1000; x++) {
trainingSet.addElement(new SupervisedTrainingElement(new double[] { 0,
0 }, new double[] { 0 }));
trainingSet.addElement(new SupervisedTrainingElement(new double[] { 0,
1 }, new double[] { 1 }));
trainingSet.addElement(new SupervisedTrainingElement(new double[] { 1,
0 }, new double[] { 1 }));
trainingSet.addElement(new SupervisedTrainingElement(new double[] { 1,
1 }, new double[] { 0 }));
}
nnLearner.learn(trainingSet);
NeuralNetwork nnet = nnLearner.getNeuralNetwork();
TrainingSet<SupervisedTrainingElement> valSet = new TrainingSet<SupervisedTrainingElement>(
2, 1);
valSet.addElement(new SupervisedTrainingElement(new double[] { 0,
0 }, new double[] { 0 }));
valSet.addElement(new SupervisedTrainingElement(new double[] { 0,
1 }, new double[] { 1 }));
valSet.addElement(new SupervisedTrainingElement(new double[] { 1,
0 }, new double[] { 1 }));
valSet.addElement(new SupervisedTrainingElement(new double[] { 1,
1 }, new double[] { 0 }));
System.out.println("Output from eval set (learned network):");
testNetwork(nnet, valSet);
nnet.save(FILENAME);
nnet = NeuralNetwork.load(FILENAME);
System.out.println("Output from eval set (learned network):");
testNetwork(nnet, valSet);
}
private void testNetwork(NeuralNetwork nnet, TrainingSet<SupervisedTrainingElement> trainingSet) {
for (SupervisedTrainingElement trainingElement : trainingSet.elements()) {
nnet.setInput(trainingElement.getInput());
nnet.calculate();
double[] networkOutput = nnet.getOutput();
System.out.print("Input: "
+ Arrays.toString(trainingElement.getInput()));
System.out.println(" Output: " + Arrays.toString(networkOutput));
}
}
}

View File

@@ -1,6 +1,7 @@
package net.woodyfolsom.msproj.sgf;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import java.io.ByteArrayInputStream;
import java.io.IOException;
@@ -15,6 +16,7 @@ public class CollectionTest {
public static final String TEST_SGF =
"(;FF[4]SZ[9]KM[5.5]RE[W+6.5]"+
";W[ee];B[];W[])";
public static final String TEST_LATEX =
"\\white{e5}\n" +
"\\begin{center}\n" +
@@ -48,8 +50,6 @@ public class CollectionTest {
CommonTokenStream tokens = new CommonTokenStream(lexer);
SGFParser parser = new SGFParser(tokens);
SGFNodeCollection nodeCollection = parser.collection();
String actualLaTeX = nodeCollection.toLateX();
assertEquals(TEST_LATEX, actualLaTeX);
assertNotNull(nodeCollection);
}
}

View File

@@ -25,11 +25,6 @@ public class SGFParserTest {
System.out.println("To SGF:");
System.out.println(nodeCollection.toSGF());
System.out.println("");
System.out.println("To LaTeX:");
System.out.println(nodeCollection.toLateX());
System.out.println("");
}
@Test
@@ -45,10 +40,5 @@ public class SGFParserTest {
System.out.println("To SGF:");
System.out.println(nodeCollection.toSGF());
System.out.println("");
System.out.println("To LaTeX:");
System.out.println(nodeCollection.toLateX());
System.out.println("");
}
}