Beginning to implement neural net training. Fixed bugs in SGF/LaTeX export.

This commit is contained in:
2012-11-15 15:00:40 -05:00
parent e01060bc11
commit ca37280ed8
17 changed files with 470 additions and 156 deletions

BIN
lib/encog-engine-2.5.0.jar Normal file

Binary file not shown.

View File

@@ -77,6 +77,48 @@ public class Action {
return this == Action.RESIGN; return this == Action.RESIGN;
} }
public static Action parseSGF(String sgfCoord, int boardSize) {
if (sgfCoord.length() == 0) {
return Action.PASS;
}
StringBuilder sb = new StringBuilder();
char column = sgfCoord.charAt(0);
if (column >= 'i') {
sb.append((char) (column + 1));
} else {
sb.append(column);
}
char row = sgfCoord.charAt(1);
sb.append(boardSize - row + 'a');
return Action.getInstance(sb.toString().toUpperCase());
}
public String toLatex(int boardSize) {
if (isPass() || isNone() || isResign()) {
throw new UnsupportedOperationException("Invalid Action for toLatex() call: " + this);
}
String latex = new String(new char[] {column}).toLowerCase();
latex += row;
return latex;
}
public String toSGF(int boardSize) {
if (isNone() || isResign()) {
throw new UnsupportedOperationException("Invalid Action for toLatex() call: " + this);
}
if (isPass()) {
return ""; // or 'tt' in old format
}
char[] latex = new char[2];
if (column >= 'J') {
latex[0] = (char) (column - 1);
} else {
latex[0] = column;
}
latex[1] = (char)('a' + (boardSize - row));
return new String(latex).toLowerCase();
}
@Override @Override
public String toString() { public String toString() {
return move; return move;

View File

@@ -44,6 +44,29 @@ public class GameResult {
return blackScore; return blackScore;
} }
public String getFullText() {
double blackScore = getBlackScore();
double whiteScore = getWhiteScore();
switch (resultType) {
case BLACK_BY_RESIGNATION:
return "Black wins by resignation";
case WHITE_BY_RESIGNATION:
return "White wins by resignation";
case IN_PROGRESS:
case VOID:
return "Game in progress";
case SCORED:
if (blackScore > whiteScore) {
return "Black wins by " + (blackScore - whiteScore);
} else {
return "White wins by " + (whiteScore - blackScore);
}
default:
return "Unknown game result";
}
}
public int getNormalizedZeroScore() { public int getNormalizedZeroScore() {
return normalizedZeroScore; return normalizedZeroScore;
} }

View File

@@ -0,0 +1,46 @@
package net.woodyfolsom.msproj;
import java.io.IOException;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
public class LatexWriter {
public static void write(OutputStream os, GameRecord gameRecord,
int turnNumber) throws IOException {
OutputStreamWriter writer = new OutputStreamWriter(os);
int boardSize = gameRecord.getGameConfig().getSize();
writer.write("\\begin{figure}[h]\n");
writer.write("\\black{");
boolean firstMove = true;
for (int turn = 1; turn <= gameRecord.getNumTurns(); turn+=2) {
if (!firstMove) {
writer.write(",");
}
writer.write(gameRecord.getMove(turn).toLatex(boardSize));
firstMove = false;
}
writer.write("}\n");
writer.write("\\white{");
firstMove = true;
for (int turn = 2; turn <= gameRecord.getNumTurns(); turn+=2) {
if (!firstMove) {
writer.write(",");
}
writer.write(gameRecord.getMove(turn).toLatex(boardSize));
firstMove = false;
}
writer.write("}\n");
writer.write("\\begin{center}\n");
writer.write("\\gobansize{" + boardSize + "}\n");
writer.write("\\shortstack{\\showfullgoban\\\\"
+ gameRecord.getResult().getFullText() + "}\n");
writer.write("\\end{center}\n");
writer.write("\\end{figure}\n");
writer.flush();
}
}

View File

@@ -3,11 +3,20 @@ package net.woodyfolsom.msproj;
import java.io.File; import java.io.File;
import java.io.FileOutputStream; import java.io.FileOutputStream;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream;
import java.text.DateFormat; import java.text.DateFormat;
import java.text.SimpleDateFormat; import java.text.SimpleDateFormat;
import java.util.Date; import java.util.Date;
import org.antlr.runtime.ANTLRInputStream;
import org.antlr.runtime.ANTLRStringStream;
import org.antlr.runtime.CommonTokenStream;
import org.antlr.runtime.RecognitionException;
import net.woodyfolsom.msproj.policy.Policy; import net.woodyfolsom.msproj.policy.Policy;
import net.woodyfolsom.msproj.sgf.SGFLexer;
import net.woodyfolsom.msproj.sgf.SGFNodeCollection;
import net.woodyfolsom.msproj.sgf.SGFParser;
public class Referee { public class Referee {
private Policy blackPolicy; private Policy blackPolicy;
@@ -35,6 +44,25 @@ public class Referee {
} }
} }
public static GameRecord replay(InputStream sgfInputStream) throws IOException, RecognitionException {
ANTLRStringStream in = new ANTLRInputStream(sgfInputStream);
SGFLexer lexer = new SGFLexer(in);
CommonTokenStream tokens = new CommonTokenStream(lexer);
SGFParser parser = new SGFParser(tokens);
SGFNodeCollection nodeCollection = parser.collection();
//parse sgf header
GameConfig gameConfig = nodeCollection.getGameConfig();
GameRecord gameRecord = new GameRecord(gameConfig);
//replay sgf moves, throw exception if moves are illegal
for (Action action : nodeCollection.getMoves(gameConfig.getSize())) {
gameRecord.play(gameRecord.getPlayerToMove(), action);
}
return gameRecord;
}
public GameResult play(GameConfig gameConfig) { public GameResult play(GameConfig gameConfig) {
GameRecord gameRecord = new GameRecord(gameConfig); GameRecord gameRecord = new GameRecord(gameConfig);

View File

@@ -33,9 +33,7 @@ public class SGFWriter {
if (action.isPass()) { if (action.isPass()) {
sgfCoord = ""; sgfCoord = "";
} else { } else {
sgfCoord = action.toString().toLowerCase().substring(0,1); sgfCoord = action.toSGF(gameConfig.getSize());
char row = (char) ('a' + gameConfig.getSize() - Integer.valueOf(action.toString().substring(1)).intValue());
sgfCoord = sgfCoord + row;
} }
writer.write(sgfCoord + "]"); writer.write(sgfCoord + "]");
} }

View File

@@ -1,6 +1,14 @@
package net.woodyfolsom.msproj; package net.woodyfolsom.msproj;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Date;
import java.util.List; import java.util.List;
import net.woodyfolsom.msproj.gui.Goban; import net.woodyfolsom.msproj.gui.Goban;
@@ -12,12 +20,15 @@ import net.woodyfolsom.msproj.policy.RandomMovePolicy;
import net.woodyfolsom.msproj.policy.RootParallelization; import net.woodyfolsom.msproj.policy.RootParallelization;
public class StandAloneGame { public class StandAloneGame {
private static final int DEFAULT_TURN_LENGTH = 2; // default turn is 2
// seconds;
private static final double DEFAULT_KOMI = 5.5; private static final double DEFAULT_KOMI = 5.5;
private static final int DEFAULT_NUM_GAMES = 1; private static final int DEFAULT_NUM_GAMES = 1;
private static final int DEFAULT_SIZE = 9; private static final int DEFAULT_SIZE = 9;
enum PLAYER_TYPE { enum PLAYER_TYPE {
HUMAN, HUMAN_GUI, ROOT_PAR, UCT_FAST, UCT_SLOW HUMAN, HUMAN_GUI, ROOT_PAR, UCT_FAST, UCT_SLOW, RANDOM
}; };
public static void main(String[] args) { public static void main(String[] args) {
@@ -30,7 +41,7 @@ public class StandAloneGame {
int nGames = DEFAULT_NUM_GAMES; int nGames = DEFAULT_NUM_GAMES;
int size = DEFAULT_SIZE; int size = DEFAULT_SIZE;
double komi = DEFAULT_KOMI; double komi = DEFAULT_KOMI;
switch (args.length) { switch (args.length) {
case 5: case 5:
nGames = Integer.valueOf(args[4]); nGames = Integer.valueOf(args[4]);
@@ -40,7 +51,14 @@ public class StandAloneGame {
size = Integer.valueOf(args[2]); size = Integer.valueOf(args[2]);
break; break;
default: default:
System.out.println("Arguments #3-5 not specified. Using default size=" + size +", komi = " + komi +", nGames=" + nGames +"."); System.out
.println("Arguments #3-5 not specified. Using default size="
+ size
+ ", komi = "
+ komi
+ ", nGames="
+ nGames
+ ".");
} }
new StandAloneGame().playGame(parsePlayerType(args[0]), new StandAloneGame().playGame(parsePlayerType(args[0]),
parsePlayerType(args[1]), size, komi, nGames); parsePlayerType(args[1]), size, komi, nGames);
@@ -58,6 +76,8 @@ public class StandAloneGame {
return PLAYER_TYPE.HUMAN; return PLAYER_TYPE.HUMAN;
} else if ("HUMAN_GUI".equalsIgnoreCase(playerTypeStr)) { } else if ("HUMAN_GUI".equalsIgnoreCase(playerTypeStr)) {
return PLAYER_TYPE.HUMAN_GUI; return PLAYER_TYPE.HUMAN_GUI;
} else if ("RANDOM".equalsIgnoreCase(playerTypeStr)) {
return PLAYER_TYPE.RANDOM;
} else { } else {
throw new RuntimeException("Unknown player type: " + playerTypeStr); throw new RuntimeException("Unknown player type: " + playerTypeStr);
} }
@@ -70,34 +90,90 @@ public class StandAloneGame {
gameConfig.setKomi(komi); gameConfig.setKomi(komi);
Referee referee = new Referee(); Referee referee = new Referee();
referee.setPolicy(Player.BLACK, getPolicy(playerType1, gameConfig, Player.BLACK)); referee.setPolicy(Player.BLACK,
referee.setPolicy(Player.WHITE, getPolicy(playerType2, gameConfig, Player.WHITE)); getPolicy(playerType1, gameConfig, Player.BLACK));
referee.setPolicy(Player.WHITE,
getPolicy(playerType2, gameConfig, Player.WHITE));
List<GameResult> round1results = new ArrayList<GameResult>();
List<GameResult> results = new ArrayList<GameResult>();
for (int round = 0; round < rounds; round++) { for (int round = 0; round < rounds; round++) {
results.add(referee.play(gameConfig)); round1results.add(referee.play(gameConfig));
} }
System.out.println("Cumulative results for " + rounds + " games (BLACK=" List<GameResult> round2results = new ArrayList<GameResult>();
+ playerType1 + ", WHITE=" + playerType2 + ")");
for (int i = 0; i < rounds; i++) { referee.setPolicy(Player.BLACK,
System.out.println(i + ". " + results.get(i)); getPolicy(playerType2, gameConfig, Player.BLACK));
referee.setPolicy(Player.WHITE,
getPolicy(playerType1, gameConfig, Player.WHITE));
for (int round = 0; round < rounds; round++) {
round2results.add(referee.play(gameConfig));
} }
DateFormat dateFormat = new SimpleDateFormat("yyMMddHHmmssZ");
try {
File txtFile = new File("gotournament-" + dateFormat.format(new Date())
+ ".txt");
FileWriter writer = new FileWriter(txtFile);
try {
logResults(writer, round1results, playerType1.toString(), playerType2.toString());
logResults(writer, round2results, playerType2.toString(), playerType1.toString());
System.out
.println("Game tournament saved as " + txtFile.getAbsolutePath());
} finally {
try {
writer.close();
} catch (IOException ioe) {
ioe.printStackTrace();
}
}
} catch (IOException ioe) {
System.out.println("Unable to save game file due to IOException: "
+ ioe.getMessage());
}
} }
private Policy getPolicy(PLAYER_TYPE playerType, GameConfig gameConfig, Player player) { private void logResults(FileWriter writer, List<GameResult> results, String player1, String player2) throws IOException {
String header = "Cumulative results for " + results.size()
+ " games (BLACK=" + player1 + ", WHITE=" + player2
+ ")";
System.out.println(header);
writer.write(header);
writer.write("\n");
for (int i = 0; i < results.size(); i++) {
String resultLine = (i+1) + ". " + results.get(i);
System.out.println(resultLine);
writer.write(resultLine);
writer.write("\n");
}
writer.flush();
}
private Policy getPolicy(PLAYER_TYPE playerType, GameConfig gameConfig,
Player player) {
switch (playerType) { switch (playerType) {
case HUMAN: case HUMAN:
return new HumanKeyboardInput(); return new HumanKeyboardInput();
case HUMAN_GUI: case HUMAN_GUI:
return new HumanGuiInput(new Goban(gameConfig, player)); return new HumanGuiInput(new Goban(gameConfig, player));
case ROOT_PAR: case ROOT_PAR:
return new RootParallelization(4, 6000L); return new RootParallelization(4, DEFAULT_TURN_LENGTH * 1000L * 3);
case UCT_SLOW: case UCT_SLOW:
return new MonteCarloUCT(new RandomMovePolicy(), 4000L); return new MonteCarloUCT(new RandomMovePolicy(),
DEFAULT_TURN_LENGTH * 1000L * 3);
case UCT_FAST: case UCT_FAST:
return new MonteCarloUCT(new RandomMovePolicy(), 1000L); return new MonteCarloUCT(new RandomMovePolicy(),
DEFAULT_TURN_LENGTH * 1000L);
case RANDOM:
return new RandomMovePolicy();
default: default:
throw new IllegalArgumentException("Invalid PLAYER_TYPE: " throw new IllegalArgumentException("Invalid PLAYER_TYPE: "
+ playerType); + playerType);

View File

@@ -0,0 +1,15 @@
package net.woodyfolsom.msproj.ann;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.learning.SupervisedTrainingElement;
import org.neuroph.core.learning.TrainingSet;
public interface NeuralNetLearner {
void learn(TrainingSet<SupervisedTrainingElement> trainingSet);
void reset();
NeuralNetwork getNeuralNetwork();
void setNeuralNetwork(NeuralNetwork neuralNetwork);
}

View File

@@ -6,6 +6,8 @@ import java.io.FileNotFoundException;
import java.io.FilenameFilter; import java.io.FilenameFilter;
import java.io.IOException; import java.io.IOException;
import net.woodyfolsom.msproj.GameRecord;
import net.woodyfolsom.msproj.Referee;
import net.woodyfolsom.msproj.sgf.SGFLexer; import net.woodyfolsom.msproj.sgf.SGFLexer;
import net.woodyfolsom.msproj.sgf.SGFNodeCollection; import net.woodyfolsom.msproj.sgf.SGFNodeCollection;
import net.woodyfolsom.msproj.sgf.SGFParser; import net.woodyfolsom.msproj.sgf.SGFParser;
@@ -40,40 +42,18 @@ public class PassLearner {
} }
public void parseSGF(File sgfFile) { public void parseSGF(File sgfFile) {
FileInputStream fis; FileInputStream sgfInputStream;
try { try {
fis = new FileInputStream(sgfFile); sgfInputStream = new FileInputStream(sgfFile);
ANTLRStringStream in; GameRecord gameRecord = Referee.replay(sgfInputStream);
try { //...
in = new ANTLRInputStream(fis);
SGFLexer lexer = new SGFLexer(in);
CommonTokenStream tokens = new CommonTokenStream(lexer);
SGFParser parser = new SGFParser(tokens);
SGFNodeCollection nodeCollection;
try {
nodeCollection = parser.collection();
System.out.println("To SGF:");
System.out.println(nodeCollection.toSGF());
System.out.println("");
System.out.println("To LaTeX:");
System.out.println(nodeCollection.toLateX());
System.out.println("");
} catch (RecognitionException re) {
re.printStackTrace();
}
} catch (IOException ioe) {
ioe.printStackTrace();
}
try {
fis.close();
} catch (IOException ioe) {
System.out.println("Error closing input stream for file" + sgfFile.getPath());
}
} catch (FileNotFoundException fnfe) { } catch (FileNotFoundException fnfe) {
fnfe.printStackTrace(); fnfe.printStackTrace();
} catch (RecognitionException re) {
re.printStackTrace();
} catch (IOException ioe) {
ioe.printStackTrace();
} }
} }
} }

View File

@@ -0,0 +1,42 @@
package net.woodyfolsom.msproj.ann;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.learning.SupervisedTrainingElement;
import org.neuroph.core.learning.TrainingSet;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.util.TransferFunctionType;
/**
* Based on sample code from http://neuroph.sourceforge.net
*
* @author Woody
*
*/
public class XORLearner implements NeuralNetLearner {
private NeuralNetwork neuralNetwork;
public XORLearner() {
reset();
}
@Override
public NeuralNetwork getNeuralNetwork() {
return neuralNetwork;
}
@Override
public void learn(TrainingSet<SupervisedTrainingElement> trainingSet) {
this.neuralNetwork.learn(trainingSet);
}
@Override
public void reset() {
this.neuralNetwork = new MultiLayerPerceptron(
TransferFunctionType.TANH, 2, 3, 1);
}
@Override
public void setNeuralNetwork(NeuralNetwork neuralNetwork) {
this.neuralNetwork = neuralNetwork;
}
}

View File

@@ -4,7 +4,7 @@ import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import net.woodyfolsom.msproj.Player; import net.woodyfolsom.msproj.Action;
public class SGFGameTree { public class SGFGameTree {
private List<SGFNode> nodeSequence = new ArrayList<SGFNode>(); private List<SGFNode> nodeSequence = new ArrayList<SGFNode>();
@@ -30,60 +30,32 @@ public class SGFGameTree {
subTrees.add(subTree); subTrees.add(subTree);
} }
public List<Action> getMoves(int boardSize) {
List<Action> moves = new ArrayList<Action>();
for (SGFNode node : nodeSequence) {
SGFValue<?> sgfValue;
switch (node.getType()) {
case MOVE_BLACK :
sgfValue = node.getFirstValue(SGFIdentifier.MOVE_BLACK);
break;
case MOVE_WHITE :
sgfValue = node.getFirstValue(SGFIdentifier.MOVE_WHITE);
break;
default :
continue;
}
moves.add(Action.parseSGF(sgfValue.getText(), boardSize));
}
return moves;
}
public int getSubTreeCount() { public int getSubTreeCount() {
return subTrees.size(); return subTrees.size();
} }
public String toLateXmoves(Player player, int boardSize) {
StringBuilder latexSB = new StringBuilder();
SGFNode.TYPE nodeType;
SGFIdentifier sgfIdent;
if (player == Player.WHITE) {
nodeType = SGFNode.TYPE.MOVE_WHITE;
sgfIdent = SGFIdentifier.MOVE_WHITE;
latexSB.append("\\white{");
} else if (player == Player.BLACK) {
nodeType = SGFNode.TYPE.MOVE_BLACK;
sgfIdent = SGFIdentifier.MOVE_BLACK;
latexSB.append("\\black{");
} else {
throw new RuntimeException("Invalid player: " + player);
}
boolean firstMove = true;
int nMoves = 0;
for (SGFNode node : nodeSequence) {
if (node.getType() != nodeType) {
continue;
}
SGFValue<?> sgfValue = node.getFirstValue(sgfIdent);
if (sgfValue.isEmpty()) {
// TODO later this will be the LaTeX igo code for 'Pass'?
continue;
}
if (firstMove) {
firstMove = false;
} else {
latexSB.append(",");
}
SGFCoord sgfCoord = (SGFCoord) sgfValue.getValue();
char column = sgfCoord.getColumn();
if (column >= 'i') {
latexSB.append((char) (column + 1));
} else {
latexSB.append(column);
}
char row = sgfCoord.getRow();
latexSB.append(boardSize - row + 'a');
nMoves++;
}
if (nMoves == 0) {
return "";
}
latexSB.append("}\n");
return latexSB.toString();
}
public int getBoardSize() { public int getBoardSize() {
for (SGFNode node : nodeSequence) { for (SGFNode node : nodeSequence) {
SGFNode.TYPE nodeType = node.getType(); SGFNode.TYPE nodeType = node.getType();
@@ -97,35 +69,19 @@ public class SGFGameTree {
throw new RuntimeException("Cannot get board size: SGFNode with identifier SZ was not found in any nodeSequence of type ROOT."); throw new RuntimeException("Cannot get board size: SGFNode with identifier SZ was not found in any nodeSequence of type ROOT.");
} }
public String toLateX(int boardSize) { public double getKomi() {
StringBuilder latexSB = new StringBuilder();
// Somewhat convoluted logic here because the grammar does not require
// all root
// properties to be included in the same node in the tree's node
// sequence, although they should
// each be unique among all node sequences in the tree.
for (SGFNode node : nodeSequence) { for (SGFNode node : nodeSequence) {
SGFNode.TYPE nodeType = node.getType(); SGFNode.TYPE nodeType = node.getType();
switch (nodeType) { switch (nodeType) {
case ROOT: case ROOT:
latexSB.append("\\gobansize"); return Double.parseDouble(node.getFirstValue(SGFIdentifier.KOMI).getText());
latexSB.append("{");
latexSB.append(boardSize);
latexSB.append("}\n");
latexSB.append("\\shortstack{\\showfullgoban\\\\");
SGFResult result = (SGFResult) node.getFirstValue(
SGFIdentifier.RESULT).getValue();
latexSB.append(result.getFullText());
latexSB.append("}\n");
break;
default: default:
// ignore // ignore
} }
} }
return latexSB.toString(); throw new RuntimeException("Cannot get board size: SGFNode with identifier KM was not found in any nodeSequence of type ROOT.");
} }
public String toSGF() { public String toSGF() {
StringBuilder sgfFormatString = new StringBuilder("("); StringBuilder sgfFormatString = new StringBuilder("(");
for (SGFNode node : nodeSequence) { for (SGFNode node : nodeSequence) {

View File

@@ -3,7 +3,8 @@ package net.woodyfolsom.msproj.sgf;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import net.woodyfolsom.msproj.Player; import net.woodyfolsom.msproj.Action;
import net.woodyfolsom.msproj.GameConfig;
public class SGFNodeCollection { public class SGFNodeCollection {
private List<SGFGameTree> gameTrees = new ArrayList<SGFGameTree>(); private List<SGFGameTree> gameTrees = new ArrayList<SGFGameTree>();
@@ -16,22 +17,21 @@ public class SGFNodeCollection {
return gameTrees.get(index); return gameTrees.get(index);
} }
public GameConfig getGameConfig() {
SGFGameTree gameTree = gameTrees.get(0);
int boardSize = gameTree.getBoardSize();
double komi = gameTree.getKomi();
return new GameConfig(boardSize,komi);
}
public int getGameTreeCount() { public int getGameTreeCount() {
return gameTrees.size(); return gameTrees.size();
} }
public String toLateX() { public List<Action> getMoves(int boardSize) {
StringBuilder latexFormatString = new StringBuilder("");
SGFGameTree gameTree = gameTrees.get(0); SGFGameTree gameTree = gameTrees.get(0);
return gameTree.getMoves(boardSize);
int boardSize = gameTree.getBoardSize();
latexFormatString.append(gameTree.toLateXmoves(Player.BLACK, boardSize));
latexFormatString.append(gameTree.toLateXmoves(Player.WHITE, boardSize));
latexFormatString.append("\\begin{center}\n");
latexFormatString.append(gameTree.toLateX(boardSize));
latexFormatString.append("\\end{center}");
return latexFormatString.toString();
} }
/** /**

View File

@@ -18,16 +18,6 @@ public class SGFResult {
} }
} }
public String getFullText() {
if (resignation == false && tie == false) {
return winner.getColor() + " wins by " + score;
} else if (resignation == true && tie == false) {
return winner.getColor() + " wins by resignation";
} else {
throw new UnsupportedOperationException("Not implemented");
}
}
@Override @Override
public String toString() { public String toString() {
if (resignation == false && tie == false) { if (resignation == false && tie == false) {

View File

@@ -0,0 +1,42 @@
package net.woodyfolsom.msproj;
import static org.junit.Assert.assertEquals;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import org.antlr.runtime.RecognitionException;
import org.junit.Test;
public class RefereeTest {
//private static final String FILENAME = "data/tourney1/gogame-121115115720-0500.sgf";
private static final String FILENAME = "data/games/1334-gokifu-20120916-Gu_Li-Lee_Sedol.sgf";
@Test
public void testReplay() throws IOException, RecognitionException {
File sgfFile = new File(FILENAME);
InputStream fis = new FileInputStream(sgfFile);
GameRecord gameRecord = Referee.replay(fis);
fis.close();
//assertEquals(9, gameRecord.getGameConfig().getSize());
assertEquals(19, gameRecord.getGameConfig().getSize());
//assertEquals(5.5, gameRecord.getGameConfig().getKomi(), 0.1);
assertEquals(7.5, gameRecord.getGameConfig().getKomi(), 0.1);
//assertEquals(74, gameRecord.getNumTurns());
assertEquals(214, gameRecord.getNumTurns());
System.out.println("Final board state (LaTeX): ");
LatexWriter.write(System.out, gameRecord, 214);
System.out.println("Final board state (SGF): ");
SGFWriter.write(System.out, gameRecord);
}
}

View File

@@ -0,0 +1,86 @@
package net.woodyfolsom.msproj.ann;
import java.io.File;
import java.util.Arrays;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.learning.SupervisedTrainingElement;
import org.neuroph.core.learning.TrainingSet;
public class NeuralNetLearnerTest {
private static final String FILENAME = "myMlPerceptron.nnet";
@AfterClass
public static void deleteNewNet() {
File file = new File(FILENAME);
if (file.exists()) {
file.delete();
}
}
@BeforeClass
public static void deleteSavedNet() {
File file = new File(FILENAME);
if (file.exists()) {
file.delete();
}
}
@Test
public void testLearnSaveLoad() {
NeuralNetLearner nnLearner = new XORLearner();
// create training set (logical XOR function)
TrainingSet<SupervisedTrainingElement> trainingSet = new TrainingSet<SupervisedTrainingElement>(
2, 1);
for (int x = 0; x < 1000; x++) {
trainingSet.addElement(new SupervisedTrainingElement(new double[] { 0,
0 }, new double[] { 0 }));
trainingSet.addElement(new SupervisedTrainingElement(new double[] { 0,
1 }, new double[] { 1 }));
trainingSet.addElement(new SupervisedTrainingElement(new double[] { 1,
0 }, new double[] { 1 }));
trainingSet.addElement(new SupervisedTrainingElement(new double[] { 1,
1 }, new double[] { 0 }));
}
nnLearner.learn(trainingSet);
NeuralNetwork nnet = nnLearner.getNeuralNetwork();
TrainingSet<SupervisedTrainingElement> valSet = new TrainingSet<SupervisedTrainingElement>(
2, 1);
valSet.addElement(new SupervisedTrainingElement(new double[] { 0,
0 }, new double[] { 0 }));
valSet.addElement(new SupervisedTrainingElement(new double[] { 0,
1 }, new double[] { 1 }));
valSet.addElement(new SupervisedTrainingElement(new double[] { 1,
0 }, new double[] { 1 }));
valSet.addElement(new SupervisedTrainingElement(new double[] { 1,
1 }, new double[] { 0 }));
System.out.println("Output from eval set (learned network):");
testNetwork(nnet, valSet);
nnet.save(FILENAME);
nnet = NeuralNetwork.load(FILENAME);
System.out.println("Output from eval set (learned network):");
testNetwork(nnet, valSet);
}
private void testNetwork(NeuralNetwork nnet, TrainingSet<SupervisedTrainingElement> trainingSet) {
for (SupervisedTrainingElement trainingElement : trainingSet.elements()) {
nnet.setInput(trainingElement.getInput());
nnet.calculate();
double[] networkOutput = nnet.getOutput();
System.out.print("Input: "
+ Arrays.toString(trainingElement.getInput()));
System.out.println(" Output: " + Arrays.toString(networkOutput));
}
}
}

View File

@@ -1,6 +1,7 @@
package net.woodyfolsom.msproj.sgf; package net.woodyfolsom.msproj.sgf;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.io.IOException; import java.io.IOException;
@@ -15,6 +16,7 @@ public class CollectionTest {
public static final String TEST_SGF = public static final String TEST_SGF =
"(;FF[4]SZ[9]KM[5.5]RE[W+6.5]"+ "(;FF[4]SZ[9]KM[5.5]RE[W+6.5]"+
";W[ee];B[];W[])"; ";W[ee];B[];W[])";
public static final String TEST_LATEX = public static final String TEST_LATEX =
"\\white{e5}\n" + "\\white{e5}\n" +
"\\begin{center}\n" + "\\begin{center}\n" +
@@ -48,8 +50,6 @@ public class CollectionTest {
CommonTokenStream tokens = new CommonTokenStream(lexer); CommonTokenStream tokens = new CommonTokenStream(lexer);
SGFParser parser = new SGFParser(tokens); SGFParser parser = new SGFParser(tokens);
SGFNodeCollection nodeCollection = parser.collection(); SGFNodeCollection nodeCollection = parser.collection();
assertNotNull(nodeCollection);
String actualLaTeX = nodeCollection.toLateX();
assertEquals(TEST_LATEX, actualLaTeX);
} }
} }

View File

@@ -25,11 +25,6 @@ public class SGFParserTest {
System.out.println("To SGF:"); System.out.println("To SGF:");
System.out.println(nodeCollection.toSGF()); System.out.println(nodeCollection.toSGF());
System.out.println(""); System.out.println("");
System.out.println("To LaTeX:");
System.out.println(nodeCollection.toLateX());
System.out.println("");
} }
@Test @Test
@@ -45,10 +40,5 @@ public class SGFParserTest {
System.out.println("To SGF:"); System.out.println("To SGF:");
System.out.println(nodeCollection.toSGF()); System.out.println(nodeCollection.toSGF());
System.out.println(""); System.out.println("");
System.out.println("To LaTeX:");
System.out.println(nodeCollection.toLateX());
System.out.println("");
} }
} }