Beginning to implement neural net training. Fixed bugs in SGF/LaTeX export.
This commit is contained in:
42
test/net/woodyfolsom/msproj/RefereeTest.java
Normal file
42
test/net/woodyfolsom/msproj/RefereeTest.java
Normal file
@@ -0,0 +1,42 @@
|
||||
package net.woodyfolsom.msproj;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
|
||||
import org.antlr.runtime.RecognitionException;
|
||||
import org.junit.Test;
|
||||
|
||||
public class RefereeTest {
|
||||
//private static final String FILENAME = "data/tourney1/gogame-121115115720-0500.sgf";
|
||||
private static final String FILENAME = "data/games/1334-gokifu-20120916-Gu_Li-Lee_Sedol.sgf";
|
||||
|
||||
@Test
|
||||
public void testReplay() throws IOException, RecognitionException {
|
||||
File sgfFile = new File(FILENAME);
|
||||
InputStream fis = new FileInputStream(sgfFile);
|
||||
|
||||
GameRecord gameRecord = Referee.replay(fis);
|
||||
|
||||
fis.close();
|
||||
|
||||
//assertEquals(9, gameRecord.getGameConfig().getSize());
|
||||
assertEquals(19, gameRecord.getGameConfig().getSize());
|
||||
|
||||
//assertEquals(5.5, gameRecord.getGameConfig().getKomi(), 0.1);
|
||||
assertEquals(7.5, gameRecord.getGameConfig().getKomi(), 0.1);
|
||||
|
||||
//assertEquals(74, gameRecord.getNumTurns());
|
||||
assertEquals(214, gameRecord.getNumTurns());
|
||||
|
||||
System.out.println("Final board state (LaTeX): ");
|
||||
LatexWriter.write(System.out, gameRecord, 214);
|
||||
|
||||
System.out.println("Final board state (SGF): ");
|
||||
SGFWriter.write(System.out, gameRecord);
|
||||
|
||||
}
|
||||
}
|
||||
86
test/net/woodyfolsom/msproj/ann/NeuralNetLearnerTest.java
Normal file
86
test/net/woodyfolsom/msproj/ann/NeuralNetLearnerTest.java
Normal file
@@ -0,0 +1,86 @@
|
||||
package net.woodyfolsom.msproj.ann;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.Arrays;
|
||||
|
||||
import org.junit.AfterClass;
|
||||
import org.junit.BeforeClass;
|
||||
import org.junit.Test;
|
||||
import org.neuroph.core.NeuralNetwork;
|
||||
import org.neuroph.core.learning.SupervisedTrainingElement;
|
||||
import org.neuroph.core.learning.TrainingSet;
|
||||
|
||||
public class NeuralNetLearnerTest {
|
||||
private static final String FILENAME = "myMlPerceptron.nnet";
|
||||
|
||||
@AfterClass
|
||||
public static void deleteNewNet() {
|
||||
File file = new File(FILENAME);
|
||||
if (file.exists()) {
|
||||
file.delete();
|
||||
}
|
||||
}
|
||||
|
||||
@BeforeClass
|
||||
public static void deleteSavedNet() {
|
||||
File file = new File(FILENAME);
|
||||
if (file.exists()) {
|
||||
file.delete();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLearnSaveLoad() {
|
||||
NeuralNetLearner nnLearner = new XORLearner();
|
||||
|
||||
// create training set (logical XOR function)
|
||||
TrainingSet<SupervisedTrainingElement> trainingSet = new TrainingSet<SupervisedTrainingElement>(
|
||||
2, 1);
|
||||
for (int x = 0; x < 1000; x++) {
|
||||
trainingSet.addElement(new SupervisedTrainingElement(new double[] { 0,
|
||||
0 }, new double[] { 0 }));
|
||||
trainingSet.addElement(new SupervisedTrainingElement(new double[] { 0,
|
||||
1 }, new double[] { 1 }));
|
||||
trainingSet.addElement(new SupervisedTrainingElement(new double[] { 1,
|
||||
0 }, new double[] { 1 }));
|
||||
trainingSet.addElement(new SupervisedTrainingElement(new double[] { 1,
|
||||
1 }, new double[] { 0 }));
|
||||
}
|
||||
|
||||
nnLearner.learn(trainingSet);
|
||||
NeuralNetwork nnet = nnLearner.getNeuralNetwork();
|
||||
|
||||
TrainingSet<SupervisedTrainingElement> valSet = new TrainingSet<SupervisedTrainingElement>(
|
||||
2, 1);
|
||||
valSet.addElement(new SupervisedTrainingElement(new double[] { 0,
|
||||
0 }, new double[] { 0 }));
|
||||
valSet.addElement(new SupervisedTrainingElement(new double[] { 0,
|
||||
1 }, new double[] { 1 }));
|
||||
valSet.addElement(new SupervisedTrainingElement(new double[] { 1,
|
||||
0 }, new double[] { 1 }));
|
||||
valSet.addElement(new SupervisedTrainingElement(new double[] { 1,
|
||||
1 }, new double[] { 0 }));
|
||||
|
||||
System.out.println("Output from eval set (learned network):");
|
||||
testNetwork(nnet, valSet);
|
||||
|
||||
nnet.save(FILENAME);
|
||||
nnet = NeuralNetwork.load(FILENAME);
|
||||
|
||||
System.out.println("Output from eval set (learned network):");
|
||||
testNetwork(nnet, valSet);
|
||||
}
|
||||
|
||||
private void testNetwork(NeuralNetwork nnet, TrainingSet<SupervisedTrainingElement> trainingSet) {
|
||||
for (SupervisedTrainingElement trainingElement : trainingSet.elements()) {
|
||||
|
||||
nnet.setInput(trainingElement.getInput());
|
||||
nnet.calculate();
|
||||
double[] networkOutput = nnet.getOutput();
|
||||
System.out.print("Input: "
|
||||
+ Arrays.toString(trainingElement.getInput()));
|
||||
System.out.println(" Output: " + Arrays.toString(networkOutput));
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package net.woodyfolsom.msproj.sgf;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.IOException;
|
||||
@@ -15,6 +16,7 @@ public class CollectionTest {
|
||||
public static final String TEST_SGF =
|
||||
"(;FF[4]SZ[9]KM[5.5]RE[W+6.5]"+
|
||||
";W[ee];B[];W[])";
|
||||
|
||||
public static final String TEST_LATEX =
|
||||
"\\white{e5}\n" +
|
||||
"\\begin{center}\n" +
|
||||
@@ -48,8 +50,6 @@ public class CollectionTest {
|
||||
CommonTokenStream tokens = new CommonTokenStream(lexer);
|
||||
SGFParser parser = new SGFParser(tokens);
|
||||
SGFNodeCollection nodeCollection = parser.collection();
|
||||
|
||||
String actualLaTeX = nodeCollection.toLateX();
|
||||
assertEquals(TEST_LATEX, actualLaTeX);
|
||||
assertNotNull(nodeCollection);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,11 +25,6 @@ public class SGFParserTest {
|
||||
System.out.println("To SGF:");
|
||||
System.out.println(nodeCollection.toSGF());
|
||||
System.out.println("");
|
||||
|
||||
System.out.println("To LaTeX:");
|
||||
System.out.println(nodeCollection.toLateX());
|
||||
System.out.println("");
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -45,10 +40,5 @@ public class SGFParserTest {
|
||||
System.out.println("To SGF:");
|
||||
System.out.println(nodeCollection.toSGF());
|
||||
System.out.println("");
|
||||
|
||||
System.out.println("To LaTeX:");
|
||||
System.out.println(nodeCollection.toLateX());
|
||||
System.out.println("");
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user