Beginning to implement neural net training. Fixed bugs in SGF/LaTeX export.

This commit is contained in:
2012-11-15 15:00:40 -05:00
parent e01060bc11
commit ca37280ed8
17 changed files with 470 additions and 156 deletions

BIN
lib/encog-engine-2.5.0.jar Normal file

Binary file not shown.

View File

@@ -77,6 +77,48 @@ public class Action {
return this == Action.RESIGN;
}
public static Action parseSGF(String sgfCoord, int boardSize) {
if (sgfCoord.length() == 0) {
return Action.PASS;
}
StringBuilder sb = new StringBuilder();
char column = sgfCoord.charAt(0);
if (column >= 'i') {
sb.append((char) (column + 1));
} else {
sb.append(column);
}
char row = sgfCoord.charAt(1);
sb.append(boardSize - row + 'a');
return Action.getInstance(sb.toString().toUpperCase());
}
public String toLatex(int boardSize) {
if (isPass() || isNone() || isResign()) {
throw new UnsupportedOperationException("Invalid Action for toLatex() call: " + this);
}
String latex = new String(new char[] {column}).toLowerCase();
latex += row;
return latex;
}
public String toSGF(int boardSize) {
if (isNone() || isResign()) {
throw new UnsupportedOperationException("Invalid Action for toLatex() call: " + this);
}
if (isPass()) {
return ""; // or 'tt' in old format
}
char[] latex = new char[2];
if (column >= 'J') {
latex[0] = (char) (column - 1);
} else {
latex[0] = column;
}
latex[1] = (char)('a' + (boardSize - row));
return new String(latex).toLowerCase();
}
@Override
public String toString() {
return move;

View File

@@ -44,6 +44,29 @@ public class GameResult {
return blackScore;
}
public String getFullText() {
double blackScore = getBlackScore();
double whiteScore = getWhiteScore();
switch (resultType) {
case BLACK_BY_RESIGNATION:
return "Black wins by resignation";
case WHITE_BY_RESIGNATION:
return "White wins by resignation";
case IN_PROGRESS:
case VOID:
return "Game in progress";
case SCORED:
if (blackScore > whiteScore) {
return "Black wins by " + (blackScore - whiteScore);
} else {
return "White wins by " + (whiteScore - blackScore);
}
default:
return "Unknown game result";
}
}
public int getNormalizedZeroScore() {
return normalizedZeroScore;
}

View File

@@ -0,0 +1,46 @@
package net.woodyfolsom.msproj;
import java.io.IOException;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
public class LatexWriter {
public static void write(OutputStream os, GameRecord gameRecord,
int turnNumber) throws IOException {
OutputStreamWriter writer = new OutputStreamWriter(os);
int boardSize = gameRecord.getGameConfig().getSize();
writer.write("\\begin{figure}[h]\n");
writer.write("\\black{");
boolean firstMove = true;
for (int turn = 1; turn <= gameRecord.getNumTurns(); turn+=2) {
if (!firstMove) {
writer.write(",");
}
writer.write(gameRecord.getMove(turn).toLatex(boardSize));
firstMove = false;
}
writer.write("}\n");
writer.write("\\white{");
firstMove = true;
for (int turn = 2; turn <= gameRecord.getNumTurns(); turn+=2) {
if (!firstMove) {
writer.write(",");
}
writer.write(gameRecord.getMove(turn).toLatex(boardSize));
firstMove = false;
}
writer.write("}\n");
writer.write("\\begin{center}\n");
writer.write("\\gobansize{" + boardSize + "}\n");
writer.write("\\shortstack{\\showfullgoban\\\\"
+ gameRecord.getResult().getFullText() + "}\n");
writer.write("\\end{center}\n");
writer.write("\\end{figure}\n");
writer.flush();
}
}

View File

@@ -3,11 +3,20 @@ package net.woodyfolsom.msproj;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.Date;
import org.antlr.runtime.ANTLRInputStream;
import org.antlr.runtime.ANTLRStringStream;
import org.antlr.runtime.CommonTokenStream;
import org.antlr.runtime.RecognitionException;
import net.woodyfolsom.msproj.policy.Policy;
import net.woodyfolsom.msproj.sgf.SGFLexer;
import net.woodyfolsom.msproj.sgf.SGFNodeCollection;
import net.woodyfolsom.msproj.sgf.SGFParser;
public class Referee {
private Policy blackPolicy;
@@ -35,6 +44,25 @@ public class Referee {
}
}
public static GameRecord replay(InputStream sgfInputStream) throws IOException, RecognitionException {
ANTLRStringStream in = new ANTLRInputStream(sgfInputStream);
SGFLexer lexer = new SGFLexer(in);
CommonTokenStream tokens = new CommonTokenStream(lexer);
SGFParser parser = new SGFParser(tokens);
SGFNodeCollection nodeCollection = parser.collection();
//parse sgf header
GameConfig gameConfig = nodeCollection.getGameConfig();
GameRecord gameRecord = new GameRecord(gameConfig);
//replay sgf moves, throw exception if moves are illegal
for (Action action : nodeCollection.getMoves(gameConfig.getSize())) {
gameRecord.play(gameRecord.getPlayerToMove(), action);
}
return gameRecord;
}
public GameResult play(GameConfig gameConfig) {
GameRecord gameRecord = new GameRecord(gameConfig);

View File

@@ -33,9 +33,7 @@ public class SGFWriter {
if (action.isPass()) {
sgfCoord = "";
} else {
sgfCoord = action.toString().toLowerCase().substring(0,1);
char row = (char) ('a' + gameConfig.getSize() - Integer.valueOf(action.toString().substring(1)).intValue());
sgfCoord = sgfCoord + row;
sgfCoord = action.toSGF(gameConfig.getSize());
}
writer.write(sgfCoord + "]");
}

View File

@@ -1,6 +1,14 @@
package net.woodyfolsom.msproj;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import net.woodyfolsom.msproj.gui.Goban;
@@ -12,12 +20,15 @@ import net.woodyfolsom.msproj.policy.RandomMovePolicy;
import net.woodyfolsom.msproj.policy.RootParallelization;
public class StandAloneGame {
private static final int DEFAULT_TURN_LENGTH = 2; // default turn is 2
// seconds;
private static final double DEFAULT_KOMI = 5.5;
private static final int DEFAULT_NUM_GAMES = 1;
private static final int DEFAULT_SIZE = 9;
enum PLAYER_TYPE {
HUMAN, HUMAN_GUI, ROOT_PAR, UCT_FAST, UCT_SLOW
HUMAN, HUMAN_GUI, ROOT_PAR, UCT_FAST, UCT_SLOW, RANDOM
};
public static void main(String[] args) {
@@ -30,7 +41,7 @@ public class StandAloneGame {
int nGames = DEFAULT_NUM_GAMES;
int size = DEFAULT_SIZE;
double komi = DEFAULT_KOMI;
switch (args.length) {
case 5:
nGames = Integer.valueOf(args[4]);
@@ -40,7 +51,14 @@ public class StandAloneGame {
size = Integer.valueOf(args[2]);
break;
default:
System.out.println("Arguments #3-5 not specified. Using default size=" + size +", komi = " + komi +", nGames=" + nGames +".");
System.out
.println("Arguments #3-5 not specified. Using default size="
+ size
+ ", komi = "
+ komi
+ ", nGames="
+ nGames
+ ".");
}
new StandAloneGame().playGame(parsePlayerType(args[0]),
parsePlayerType(args[1]), size, komi, nGames);
@@ -58,6 +76,8 @@ public class StandAloneGame {
return PLAYER_TYPE.HUMAN;
} else if ("HUMAN_GUI".equalsIgnoreCase(playerTypeStr)) {
return PLAYER_TYPE.HUMAN_GUI;
} else if ("RANDOM".equalsIgnoreCase(playerTypeStr)) {
return PLAYER_TYPE.RANDOM;
} else {
throw new RuntimeException("Unknown player type: " + playerTypeStr);
}
@@ -70,34 +90,90 @@ public class StandAloneGame {
gameConfig.setKomi(komi);
Referee referee = new Referee();
referee.setPolicy(Player.BLACK, getPolicy(playerType1, gameConfig, Player.BLACK));
referee.setPolicy(Player.WHITE, getPolicy(playerType2, gameConfig, Player.WHITE));
referee.setPolicy(Player.BLACK,
getPolicy(playerType1, gameConfig, Player.BLACK));
referee.setPolicy(Player.WHITE,
getPolicy(playerType2, gameConfig, Player.WHITE));
List<GameResult> round1results = new ArrayList<GameResult>();
List<GameResult> results = new ArrayList<GameResult>();
for (int round = 0; round < rounds; round++) {
results.add(referee.play(gameConfig));
round1results.add(referee.play(gameConfig));
}
System.out.println("Cumulative results for " + rounds + " games (BLACK="
+ playerType1 + ", WHITE=" + playerType2 + ")");
for (int i = 0; i < rounds; i++) {
System.out.println(i + ". " + results.get(i));
List<GameResult> round2results = new ArrayList<GameResult>();
referee.setPolicy(Player.BLACK,
getPolicy(playerType2, gameConfig, Player.BLACK));
referee.setPolicy(Player.WHITE,
getPolicy(playerType1, gameConfig, Player.WHITE));
for (int round = 0; round < rounds; round++) {
round2results.add(referee.play(gameConfig));
}
DateFormat dateFormat = new SimpleDateFormat("yyMMddHHmmssZ");
try {
File txtFile = new File("gotournament-" + dateFormat.format(new Date())
+ ".txt");
FileWriter writer = new FileWriter(txtFile);
try {
logResults(writer, round1results, playerType1.toString(), playerType2.toString());
logResults(writer, round2results, playerType2.toString(), playerType1.toString());
System.out
.println("Game tournament saved as " + txtFile.getAbsolutePath());
} finally {
try {
writer.close();
} catch (IOException ioe) {
ioe.printStackTrace();
}
}
} catch (IOException ioe) {
System.out.println("Unable to save game file due to IOException: "
+ ioe.getMessage());
}
}
private Policy getPolicy(PLAYER_TYPE playerType, GameConfig gameConfig, Player player) {
private void logResults(FileWriter writer, List<GameResult> results, String player1, String player2) throws IOException {
String header = "Cumulative results for " + results.size()
+ " games (BLACK=" + player1 + ", WHITE=" + player2
+ ")";
System.out.println(header);
writer.write(header);
writer.write("\n");
for (int i = 0; i < results.size(); i++) {
String resultLine = (i+1) + ". " + results.get(i);
System.out.println(resultLine);
writer.write(resultLine);
writer.write("\n");
}
writer.flush();
}
private Policy getPolicy(PLAYER_TYPE playerType, GameConfig gameConfig,
Player player) {
switch (playerType) {
case HUMAN:
return new HumanKeyboardInput();
case HUMAN_GUI:
return new HumanGuiInput(new Goban(gameConfig, player));
case ROOT_PAR:
return new RootParallelization(4, 6000L);
return new RootParallelization(4, DEFAULT_TURN_LENGTH * 1000L * 3);
case UCT_SLOW:
return new MonteCarloUCT(new RandomMovePolicy(), 4000L);
return new MonteCarloUCT(new RandomMovePolicy(),
DEFAULT_TURN_LENGTH * 1000L * 3);
case UCT_FAST:
return new MonteCarloUCT(new RandomMovePolicy(), 1000L);
return new MonteCarloUCT(new RandomMovePolicy(),
DEFAULT_TURN_LENGTH * 1000L);
case RANDOM:
return new RandomMovePolicy();
default:
throw new IllegalArgumentException("Invalid PLAYER_TYPE: "
+ playerType);

View File

@@ -0,0 +1,15 @@
package net.woodyfolsom.msproj.ann;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.learning.SupervisedTrainingElement;
import org.neuroph.core.learning.TrainingSet;
public interface NeuralNetLearner {
void learn(TrainingSet<SupervisedTrainingElement> trainingSet);
void reset();
NeuralNetwork getNeuralNetwork();
void setNeuralNetwork(NeuralNetwork neuralNetwork);
}

View File

@@ -6,6 +6,8 @@ import java.io.FileNotFoundException;
import java.io.FilenameFilter;
import java.io.IOException;
import net.woodyfolsom.msproj.GameRecord;
import net.woodyfolsom.msproj.Referee;
import net.woodyfolsom.msproj.sgf.SGFLexer;
import net.woodyfolsom.msproj.sgf.SGFNodeCollection;
import net.woodyfolsom.msproj.sgf.SGFParser;
@@ -40,40 +42,18 @@ public class PassLearner {
}
public void parseSGF(File sgfFile) {
FileInputStream fis;
FileInputStream sgfInputStream;
try {
fis = new FileInputStream(sgfFile);
ANTLRStringStream in;
try {
in = new ANTLRInputStream(fis);
SGFLexer lexer = new SGFLexer(in);
CommonTokenStream tokens = new CommonTokenStream(lexer);
SGFParser parser = new SGFParser(tokens);
SGFNodeCollection nodeCollection;
try {
nodeCollection = parser.collection();
System.out.println("To SGF:");
System.out.println(nodeCollection.toSGF());
System.out.println("");
System.out.println("To LaTeX:");
System.out.println(nodeCollection.toLateX());
System.out.println("");
} catch (RecognitionException re) {
re.printStackTrace();
}
} catch (IOException ioe) {
ioe.printStackTrace();
}
try {
fis.close();
} catch (IOException ioe) {
System.out.println("Error closing input stream for file" + sgfFile.getPath());
}
sgfInputStream = new FileInputStream(sgfFile);
GameRecord gameRecord = Referee.replay(sgfInputStream);
//...
} catch (FileNotFoundException fnfe) {
fnfe.printStackTrace();
} catch (RecognitionException re) {
re.printStackTrace();
} catch (IOException ioe) {
ioe.printStackTrace();
}
}
}

View File

@@ -0,0 +1,42 @@
package net.woodyfolsom.msproj.ann;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.learning.SupervisedTrainingElement;
import org.neuroph.core.learning.TrainingSet;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.util.TransferFunctionType;
/**
* Based on sample code from http://neuroph.sourceforge.net
*
* @author Woody
*
*/
public class XORLearner implements NeuralNetLearner {
private NeuralNetwork neuralNetwork;
public XORLearner() {
reset();
}
@Override
public NeuralNetwork getNeuralNetwork() {
return neuralNetwork;
}
@Override
public void learn(TrainingSet<SupervisedTrainingElement> trainingSet) {
this.neuralNetwork.learn(trainingSet);
}
@Override
public void reset() {
this.neuralNetwork = new MultiLayerPerceptron(
TransferFunctionType.TANH, 2, 3, 1);
}
@Override
public void setNeuralNetwork(NeuralNetwork neuralNetwork) {
this.neuralNetwork = neuralNetwork;
}
}

View File

@@ -4,7 +4,7 @@ import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import net.woodyfolsom.msproj.Player;
import net.woodyfolsom.msproj.Action;
public class SGFGameTree {
private List<SGFNode> nodeSequence = new ArrayList<SGFNode>();
@@ -30,60 +30,32 @@ public class SGFGameTree {
subTrees.add(subTree);
}
public List<Action> getMoves(int boardSize) {
List<Action> moves = new ArrayList<Action>();
for (SGFNode node : nodeSequence) {
SGFValue<?> sgfValue;
switch (node.getType()) {
case MOVE_BLACK :
sgfValue = node.getFirstValue(SGFIdentifier.MOVE_BLACK);
break;
case MOVE_WHITE :
sgfValue = node.getFirstValue(SGFIdentifier.MOVE_WHITE);
break;
default :
continue;
}
moves.add(Action.parseSGF(sgfValue.getText(), boardSize));
}
return moves;
}
public int getSubTreeCount() {
return subTrees.size();
}
public String toLateXmoves(Player player, int boardSize) {
StringBuilder latexSB = new StringBuilder();
SGFNode.TYPE nodeType;
SGFIdentifier sgfIdent;
if (player == Player.WHITE) {
nodeType = SGFNode.TYPE.MOVE_WHITE;
sgfIdent = SGFIdentifier.MOVE_WHITE;
latexSB.append("\\white{");
} else if (player == Player.BLACK) {
nodeType = SGFNode.TYPE.MOVE_BLACK;
sgfIdent = SGFIdentifier.MOVE_BLACK;
latexSB.append("\\black{");
} else {
throw new RuntimeException("Invalid player: " + player);
}
boolean firstMove = true;
int nMoves = 0;
for (SGFNode node : nodeSequence) {
if (node.getType() != nodeType) {
continue;
}
SGFValue<?> sgfValue = node.getFirstValue(sgfIdent);
if (sgfValue.isEmpty()) {
// TODO later this will be the LaTeX igo code for 'Pass'?
continue;
}
if (firstMove) {
firstMove = false;
} else {
latexSB.append(",");
}
SGFCoord sgfCoord = (SGFCoord) sgfValue.getValue();
char column = sgfCoord.getColumn();
if (column >= 'i') {
latexSB.append((char) (column + 1));
} else {
latexSB.append(column);
}
char row = sgfCoord.getRow();
latexSB.append(boardSize - row + 'a');
nMoves++;
}
if (nMoves == 0) {
return "";
}
latexSB.append("}\n");
return latexSB.toString();
}
public int getBoardSize() {
for (SGFNode node : nodeSequence) {
SGFNode.TYPE nodeType = node.getType();
@@ -97,35 +69,19 @@ public class SGFGameTree {
throw new RuntimeException("Cannot get board size: SGFNode with identifier SZ was not found in any nodeSequence of type ROOT.");
}
public String toLateX(int boardSize) {
StringBuilder latexSB = new StringBuilder();
// Somewhat convoluted logic here because the grammar does not require
// all root
// properties to be included in the same node in the tree's node
// sequence, although they should
// each be unique among all node sequences in the tree.
public double getKomi() {
for (SGFNode node : nodeSequence) {
SGFNode.TYPE nodeType = node.getType();
switch (nodeType) {
case ROOT:
latexSB.append("\\gobansize");
latexSB.append("{");
latexSB.append(boardSize);
latexSB.append("}\n");
latexSB.append("\\shortstack{\\showfullgoban\\\\");
SGFResult result = (SGFResult) node.getFirstValue(
SGFIdentifier.RESULT).getValue();
latexSB.append(result.getFullText());
latexSB.append("}\n");
break;
return Double.parseDouble(node.getFirstValue(SGFIdentifier.KOMI).getText());
default:
// ignore
}
}
return latexSB.toString();
throw new RuntimeException("Cannot get board size: SGFNode with identifier KM was not found in any nodeSequence of type ROOT.");
}
public String toSGF() {
StringBuilder sgfFormatString = new StringBuilder("(");
for (SGFNode node : nodeSequence) {

View File

@@ -3,7 +3,8 @@ package net.woodyfolsom.msproj.sgf;
import java.util.ArrayList;
import java.util.List;
import net.woodyfolsom.msproj.Player;
import net.woodyfolsom.msproj.Action;
import net.woodyfolsom.msproj.GameConfig;
public class SGFNodeCollection {
private List<SGFGameTree> gameTrees = new ArrayList<SGFGameTree>();
@@ -16,22 +17,21 @@ public class SGFNodeCollection {
return gameTrees.get(index);
}
public GameConfig getGameConfig() {
SGFGameTree gameTree = gameTrees.get(0);
int boardSize = gameTree.getBoardSize();
double komi = gameTree.getKomi();
return new GameConfig(boardSize,komi);
}
public int getGameTreeCount() {
return gameTrees.size();
}
public String toLateX() {
StringBuilder latexFormatString = new StringBuilder("");
public List<Action> getMoves(int boardSize) {
SGFGameTree gameTree = gameTrees.get(0);
int boardSize = gameTree.getBoardSize();
latexFormatString.append(gameTree.toLateXmoves(Player.BLACK, boardSize));
latexFormatString.append(gameTree.toLateXmoves(Player.WHITE, boardSize));
latexFormatString.append("\\begin{center}\n");
latexFormatString.append(gameTree.toLateX(boardSize));
latexFormatString.append("\\end{center}");
return latexFormatString.toString();
return gameTree.getMoves(boardSize);
}
/**

View File

@@ -18,16 +18,6 @@ public class SGFResult {
}
}
public String getFullText() {
if (resignation == false && tie == false) {
return winner.getColor() + " wins by " + score;
} else if (resignation == true && tie == false) {
return winner.getColor() + " wins by resignation";
} else {
throw new UnsupportedOperationException("Not implemented");
}
}
@Override
public String toString() {
if (resignation == false && tie == false) {

View File

@@ -0,0 +1,42 @@
package net.woodyfolsom.msproj;
import static org.junit.Assert.assertEquals;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import org.antlr.runtime.RecognitionException;
import org.junit.Test;
public class RefereeTest {
//private static final String FILENAME = "data/tourney1/gogame-121115115720-0500.sgf";
private static final String FILENAME = "data/games/1334-gokifu-20120916-Gu_Li-Lee_Sedol.sgf";
@Test
public void testReplay() throws IOException, RecognitionException {
File sgfFile = new File(FILENAME);
InputStream fis = new FileInputStream(sgfFile);
GameRecord gameRecord = Referee.replay(fis);
fis.close();
//assertEquals(9, gameRecord.getGameConfig().getSize());
assertEquals(19, gameRecord.getGameConfig().getSize());
//assertEquals(5.5, gameRecord.getGameConfig().getKomi(), 0.1);
assertEquals(7.5, gameRecord.getGameConfig().getKomi(), 0.1);
//assertEquals(74, gameRecord.getNumTurns());
assertEquals(214, gameRecord.getNumTurns());
System.out.println("Final board state (LaTeX): ");
LatexWriter.write(System.out, gameRecord, 214);
System.out.println("Final board state (SGF): ");
SGFWriter.write(System.out, gameRecord);
}
}

View File

@@ -0,0 +1,86 @@
package net.woodyfolsom.msproj.ann;
import java.io.File;
import java.util.Arrays;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.learning.SupervisedTrainingElement;
import org.neuroph.core.learning.TrainingSet;
public class NeuralNetLearnerTest {
private static final String FILENAME = "myMlPerceptron.nnet";
@AfterClass
public static void deleteNewNet() {
File file = new File(FILENAME);
if (file.exists()) {
file.delete();
}
}
@BeforeClass
public static void deleteSavedNet() {
File file = new File(FILENAME);
if (file.exists()) {
file.delete();
}
}
@Test
public void testLearnSaveLoad() {
NeuralNetLearner nnLearner = new XORLearner();
// create training set (logical XOR function)
TrainingSet<SupervisedTrainingElement> trainingSet = new TrainingSet<SupervisedTrainingElement>(
2, 1);
for (int x = 0; x < 1000; x++) {
trainingSet.addElement(new SupervisedTrainingElement(new double[] { 0,
0 }, new double[] { 0 }));
trainingSet.addElement(new SupervisedTrainingElement(new double[] { 0,
1 }, new double[] { 1 }));
trainingSet.addElement(new SupervisedTrainingElement(new double[] { 1,
0 }, new double[] { 1 }));
trainingSet.addElement(new SupervisedTrainingElement(new double[] { 1,
1 }, new double[] { 0 }));
}
nnLearner.learn(trainingSet);
NeuralNetwork nnet = nnLearner.getNeuralNetwork();
TrainingSet<SupervisedTrainingElement> valSet = new TrainingSet<SupervisedTrainingElement>(
2, 1);
valSet.addElement(new SupervisedTrainingElement(new double[] { 0,
0 }, new double[] { 0 }));
valSet.addElement(new SupervisedTrainingElement(new double[] { 0,
1 }, new double[] { 1 }));
valSet.addElement(new SupervisedTrainingElement(new double[] { 1,
0 }, new double[] { 1 }));
valSet.addElement(new SupervisedTrainingElement(new double[] { 1,
1 }, new double[] { 0 }));
System.out.println("Output from eval set (learned network):");
testNetwork(nnet, valSet);
nnet.save(FILENAME);
nnet = NeuralNetwork.load(FILENAME);
System.out.println("Output from eval set (learned network):");
testNetwork(nnet, valSet);
}
private void testNetwork(NeuralNetwork nnet, TrainingSet<SupervisedTrainingElement> trainingSet) {
for (SupervisedTrainingElement trainingElement : trainingSet.elements()) {
nnet.setInput(trainingElement.getInput());
nnet.calculate();
double[] networkOutput = nnet.getOutput();
System.out.print("Input: "
+ Arrays.toString(trainingElement.getInput()));
System.out.println(" Output: " + Arrays.toString(networkOutput));
}
}
}

View File

@@ -1,6 +1,7 @@
package net.woodyfolsom.msproj.sgf;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import java.io.ByteArrayInputStream;
import java.io.IOException;
@@ -15,6 +16,7 @@ public class CollectionTest {
public static final String TEST_SGF =
"(;FF[4]SZ[9]KM[5.5]RE[W+6.5]"+
";W[ee];B[];W[])";
public static final String TEST_LATEX =
"\\white{e5}\n" +
"\\begin{center}\n" +
@@ -48,8 +50,6 @@ public class CollectionTest {
CommonTokenStream tokens = new CommonTokenStream(lexer);
SGFParser parser = new SGFParser(tokens);
SGFNodeCollection nodeCollection = parser.collection();
String actualLaTeX = nodeCollection.toLateX();
assertEquals(TEST_LATEX, actualLaTeX);
assertNotNull(nodeCollection);
}
}

View File

@@ -25,11 +25,6 @@ public class SGFParserTest {
System.out.println("To SGF:");
System.out.println(nodeCollection.toSGF());
System.out.println("");
System.out.println("To LaTeX:");
System.out.println(nodeCollection.toLateX());
System.out.println("");
}
@Test
@@ -45,10 +40,5 @@ public class SGFParserTest {
System.out.println("To SGF:");
System.out.println(nodeCollection.toSGF());
System.out.println("");
System.out.println("To LaTeX:");
System.out.println(nodeCollection.toLateX());
System.out.println("");
}
}