Beginning to implement neural net training. Fixed bugs in SGF/LaTeX export.
This commit is contained in:
BIN
lib/encog-engine-2.5.0.jar
Normal file
BIN
lib/encog-engine-2.5.0.jar
Normal file
Binary file not shown.
@@ -77,6 +77,48 @@ public class Action {
|
|||||||
return this == Action.RESIGN;
|
return this == Action.RESIGN;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static Action parseSGF(String sgfCoord, int boardSize) {
|
||||||
|
if (sgfCoord.length() == 0) {
|
||||||
|
return Action.PASS;
|
||||||
|
}
|
||||||
|
StringBuilder sb = new StringBuilder();
|
||||||
|
char column = sgfCoord.charAt(0);
|
||||||
|
if (column >= 'i') {
|
||||||
|
sb.append((char) (column + 1));
|
||||||
|
} else {
|
||||||
|
sb.append(column);
|
||||||
|
}
|
||||||
|
char row = sgfCoord.charAt(1);
|
||||||
|
sb.append(boardSize - row + 'a');
|
||||||
|
return Action.getInstance(sb.toString().toUpperCase());
|
||||||
|
}
|
||||||
|
|
||||||
|
public String toLatex(int boardSize) {
|
||||||
|
if (isPass() || isNone() || isResign()) {
|
||||||
|
throw new UnsupportedOperationException("Invalid Action for toLatex() call: " + this);
|
||||||
|
}
|
||||||
|
String latex = new String(new char[] {column}).toLowerCase();
|
||||||
|
latex += row;
|
||||||
|
return latex;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String toSGF(int boardSize) {
|
||||||
|
if (isNone() || isResign()) {
|
||||||
|
throw new UnsupportedOperationException("Invalid Action for toLatex() call: " + this);
|
||||||
|
}
|
||||||
|
if (isPass()) {
|
||||||
|
return ""; // or 'tt' in old format
|
||||||
|
}
|
||||||
|
char[] latex = new char[2];
|
||||||
|
if (column >= 'J') {
|
||||||
|
latex[0] = (char) (column - 1);
|
||||||
|
} else {
|
||||||
|
latex[0] = column;
|
||||||
|
}
|
||||||
|
latex[1] = (char)('a' + (boardSize - row));
|
||||||
|
return new String(latex).toLowerCase();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String toString() {
|
public String toString() {
|
||||||
return move;
|
return move;
|
||||||
|
|||||||
@@ -44,6 +44,29 @@ public class GameResult {
|
|||||||
return blackScore;
|
return blackScore;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public String getFullText() {
|
||||||
|
double blackScore = getBlackScore();
|
||||||
|
double whiteScore = getWhiteScore();
|
||||||
|
|
||||||
|
switch (resultType) {
|
||||||
|
case BLACK_BY_RESIGNATION:
|
||||||
|
return "Black wins by resignation";
|
||||||
|
case WHITE_BY_RESIGNATION:
|
||||||
|
return "White wins by resignation";
|
||||||
|
case IN_PROGRESS:
|
||||||
|
case VOID:
|
||||||
|
return "Game in progress";
|
||||||
|
case SCORED:
|
||||||
|
if (blackScore > whiteScore) {
|
||||||
|
return "Black wins by " + (blackScore - whiteScore);
|
||||||
|
} else {
|
||||||
|
return "White wins by " + (whiteScore - blackScore);
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return "Unknown game result";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
public int getNormalizedZeroScore() {
|
public int getNormalizedZeroScore() {
|
||||||
return normalizedZeroScore;
|
return normalizedZeroScore;
|
||||||
}
|
}
|
||||||
|
|||||||
46
src/net/woodyfolsom/msproj/LatexWriter.java
Normal file
46
src/net/woodyfolsom/msproj/LatexWriter.java
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
package net.woodyfolsom.msproj;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.OutputStream;
|
||||||
|
import java.io.OutputStreamWriter;
|
||||||
|
|
||||||
|
public class LatexWriter {
|
||||||
|
|
||||||
|
public static void write(OutputStream os, GameRecord gameRecord,
|
||||||
|
int turnNumber) throws IOException {
|
||||||
|
OutputStreamWriter writer = new OutputStreamWriter(os);
|
||||||
|
|
||||||
|
int boardSize = gameRecord.getGameConfig().getSize();
|
||||||
|
|
||||||
|
writer.write("\\begin{figure}[h]\n");
|
||||||
|
writer.write("\\black{");
|
||||||
|
boolean firstMove = true;
|
||||||
|
for (int turn = 1; turn <= gameRecord.getNumTurns(); turn+=2) {
|
||||||
|
if (!firstMove) {
|
||||||
|
writer.write(",");
|
||||||
|
}
|
||||||
|
writer.write(gameRecord.getMove(turn).toLatex(boardSize));
|
||||||
|
firstMove = false;
|
||||||
|
}
|
||||||
|
writer.write("}\n");
|
||||||
|
|
||||||
|
writer.write("\\white{");
|
||||||
|
firstMove = true;
|
||||||
|
for (int turn = 2; turn <= gameRecord.getNumTurns(); turn+=2) {
|
||||||
|
if (!firstMove) {
|
||||||
|
writer.write(",");
|
||||||
|
}
|
||||||
|
writer.write(gameRecord.getMove(turn).toLatex(boardSize));
|
||||||
|
firstMove = false;
|
||||||
|
}
|
||||||
|
writer.write("}\n");
|
||||||
|
|
||||||
|
writer.write("\\begin{center}\n");
|
||||||
|
writer.write("\\gobansize{" + boardSize + "}\n");
|
||||||
|
writer.write("\\shortstack{\\showfullgoban\\\\"
|
||||||
|
+ gameRecord.getResult().getFullText() + "}\n");
|
||||||
|
writer.write("\\end{center}\n");
|
||||||
|
writer.write("\\end{figure}\n");
|
||||||
|
writer.flush();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -3,11 +3,20 @@ package net.woodyfolsom.msproj;
|
|||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.io.FileOutputStream;
|
import java.io.FileOutputStream;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.io.InputStream;
|
||||||
import java.text.DateFormat;
|
import java.text.DateFormat;
|
||||||
import java.text.SimpleDateFormat;
|
import java.text.SimpleDateFormat;
|
||||||
import java.util.Date;
|
import java.util.Date;
|
||||||
|
|
||||||
|
import org.antlr.runtime.ANTLRInputStream;
|
||||||
|
import org.antlr.runtime.ANTLRStringStream;
|
||||||
|
import org.antlr.runtime.CommonTokenStream;
|
||||||
|
import org.antlr.runtime.RecognitionException;
|
||||||
|
|
||||||
import net.woodyfolsom.msproj.policy.Policy;
|
import net.woodyfolsom.msproj.policy.Policy;
|
||||||
|
import net.woodyfolsom.msproj.sgf.SGFLexer;
|
||||||
|
import net.woodyfolsom.msproj.sgf.SGFNodeCollection;
|
||||||
|
import net.woodyfolsom.msproj.sgf.SGFParser;
|
||||||
|
|
||||||
public class Referee {
|
public class Referee {
|
||||||
private Policy blackPolicy;
|
private Policy blackPolicy;
|
||||||
@@ -35,6 +44,25 @@ public class Referee {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static GameRecord replay(InputStream sgfInputStream) throws IOException, RecognitionException {
|
||||||
|
ANTLRStringStream in = new ANTLRInputStream(sgfInputStream);
|
||||||
|
SGFLexer lexer = new SGFLexer(in);
|
||||||
|
CommonTokenStream tokens = new CommonTokenStream(lexer);
|
||||||
|
SGFParser parser = new SGFParser(tokens);
|
||||||
|
SGFNodeCollection nodeCollection = parser.collection();
|
||||||
|
|
||||||
|
//parse sgf header
|
||||||
|
GameConfig gameConfig = nodeCollection.getGameConfig();
|
||||||
|
GameRecord gameRecord = new GameRecord(gameConfig);
|
||||||
|
|
||||||
|
//replay sgf moves, throw exception if moves are illegal
|
||||||
|
for (Action action : nodeCollection.getMoves(gameConfig.getSize())) {
|
||||||
|
gameRecord.play(gameRecord.getPlayerToMove(), action);
|
||||||
|
}
|
||||||
|
|
||||||
|
return gameRecord;
|
||||||
|
}
|
||||||
|
|
||||||
public GameResult play(GameConfig gameConfig) {
|
public GameResult play(GameConfig gameConfig) {
|
||||||
GameRecord gameRecord = new GameRecord(gameConfig);
|
GameRecord gameRecord = new GameRecord(gameConfig);
|
||||||
|
|
||||||
|
|||||||
@@ -33,9 +33,7 @@ public class SGFWriter {
|
|||||||
if (action.isPass()) {
|
if (action.isPass()) {
|
||||||
sgfCoord = "";
|
sgfCoord = "";
|
||||||
} else {
|
} else {
|
||||||
sgfCoord = action.toString().toLowerCase().substring(0,1);
|
sgfCoord = action.toSGF(gameConfig.getSize());
|
||||||
char row = (char) ('a' + gameConfig.getSize() - Integer.valueOf(action.toString().substring(1)).intValue());
|
|
||||||
sgfCoord = sgfCoord + row;
|
|
||||||
}
|
}
|
||||||
writer.write(sgfCoord + "]");
|
writer.write(sgfCoord + "]");
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,14 @@
|
|||||||
package net.woodyfolsom.msproj;
|
package net.woodyfolsom.msproj;
|
||||||
|
|
||||||
|
import java.io.BufferedWriter;
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.FileOutputStream;
|
||||||
|
import java.io.FileWriter;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.text.DateFormat;
|
||||||
|
import java.text.SimpleDateFormat;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
import java.util.Date;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import net.woodyfolsom.msproj.gui.Goban;
|
import net.woodyfolsom.msproj.gui.Goban;
|
||||||
@@ -12,12 +20,15 @@ import net.woodyfolsom.msproj.policy.RandomMovePolicy;
|
|||||||
import net.woodyfolsom.msproj.policy.RootParallelization;
|
import net.woodyfolsom.msproj.policy.RootParallelization;
|
||||||
|
|
||||||
public class StandAloneGame {
|
public class StandAloneGame {
|
||||||
|
private static final int DEFAULT_TURN_LENGTH = 2; // default turn is 2
|
||||||
|
// seconds;
|
||||||
|
|
||||||
private static final double DEFAULT_KOMI = 5.5;
|
private static final double DEFAULT_KOMI = 5.5;
|
||||||
private static final int DEFAULT_NUM_GAMES = 1;
|
private static final int DEFAULT_NUM_GAMES = 1;
|
||||||
private static final int DEFAULT_SIZE = 9;
|
private static final int DEFAULT_SIZE = 9;
|
||||||
|
|
||||||
enum PLAYER_TYPE {
|
enum PLAYER_TYPE {
|
||||||
HUMAN, HUMAN_GUI, ROOT_PAR, UCT_FAST, UCT_SLOW
|
HUMAN, HUMAN_GUI, ROOT_PAR, UCT_FAST, UCT_SLOW, RANDOM
|
||||||
};
|
};
|
||||||
|
|
||||||
public static void main(String[] args) {
|
public static void main(String[] args) {
|
||||||
@@ -40,7 +51,14 @@ public class StandAloneGame {
|
|||||||
size = Integer.valueOf(args[2]);
|
size = Integer.valueOf(args[2]);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
System.out.println("Arguments #3-5 not specified. Using default size=" + size +", komi = " + komi +", nGames=" + nGames +".");
|
System.out
|
||||||
|
.println("Arguments #3-5 not specified. Using default size="
|
||||||
|
+ size
|
||||||
|
+ ", komi = "
|
||||||
|
+ komi
|
||||||
|
+ ", nGames="
|
||||||
|
+ nGames
|
||||||
|
+ ".");
|
||||||
}
|
}
|
||||||
new StandAloneGame().playGame(parsePlayerType(args[0]),
|
new StandAloneGame().playGame(parsePlayerType(args[0]),
|
||||||
parsePlayerType(args[1]), size, komi, nGames);
|
parsePlayerType(args[1]), size, komi, nGames);
|
||||||
@@ -58,6 +76,8 @@ public class StandAloneGame {
|
|||||||
return PLAYER_TYPE.HUMAN;
|
return PLAYER_TYPE.HUMAN;
|
||||||
} else if ("HUMAN_GUI".equalsIgnoreCase(playerTypeStr)) {
|
} else if ("HUMAN_GUI".equalsIgnoreCase(playerTypeStr)) {
|
||||||
return PLAYER_TYPE.HUMAN_GUI;
|
return PLAYER_TYPE.HUMAN_GUI;
|
||||||
|
} else if ("RANDOM".equalsIgnoreCase(playerTypeStr)) {
|
||||||
|
return PLAYER_TYPE.RANDOM;
|
||||||
} else {
|
} else {
|
||||||
throw new RuntimeException("Unknown player type: " + playerTypeStr);
|
throw new RuntimeException("Unknown player type: " + playerTypeStr);
|
||||||
}
|
}
|
||||||
@@ -70,34 +90,90 @@ public class StandAloneGame {
|
|||||||
gameConfig.setKomi(komi);
|
gameConfig.setKomi(komi);
|
||||||
|
|
||||||
Referee referee = new Referee();
|
Referee referee = new Referee();
|
||||||
referee.setPolicy(Player.BLACK, getPolicy(playerType1, gameConfig, Player.BLACK));
|
referee.setPolicy(Player.BLACK,
|
||||||
referee.setPolicy(Player.WHITE, getPolicy(playerType2, gameConfig, Player.WHITE));
|
getPolicy(playerType1, gameConfig, Player.BLACK));
|
||||||
|
referee.setPolicy(Player.WHITE,
|
||||||
|
getPolicy(playerType2, gameConfig, Player.WHITE));
|
||||||
|
|
||||||
List<GameResult> results = new ArrayList<GameResult>();
|
List<GameResult> round1results = new ArrayList<GameResult>();
|
||||||
|
|
||||||
for (int round = 0; round < rounds; round++) {
|
for (int round = 0; round < rounds; round++) {
|
||||||
results.add(referee.play(gameConfig));
|
round1results.add(referee.play(gameConfig));
|
||||||
}
|
}
|
||||||
|
|
||||||
System.out.println("Cumulative results for " + rounds + " games (BLACK="
|
List<GameResult> round2results = new ArrayList<GameResult>();
|
||||||
+ playerType1 + ", WHITE=" + playerType2 + ")");
|
|
||||||
for (int i = 0; i < rounds; i++) {
|
referee.setPolicy(Player.BLACK,
|
||||||
System.out.println(i + ". " + results.get(i));
|
getPolicy(playerType2, gameConfig, Player.BLACK));
|
||||||
|
referee.setPolicy(Player.WHITE,
|
||||||
|
getPolicy(playerType1, gameConfig, Player.WHITE));
|
||||||
|
for (int round = 0; round < rounds; round++) {
|
||||||
|
round2results.add(referee.play(gameConfig));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DateFormat dateFormat = new SimpleDateFormat("yyMMddHHmmssZ");
|
||||||
|
|
||||||
|
try {
|
||||||
|
|
||||||
|
File txtFile = new File("gotournament-" + dateFormat.format(new Date())
|
||||||
|
+ ".txt");
|
||||||
|
FileWriter writer = new FileWriter(txtFile);
|
||||||
|
|
||||||
|
try {
|
||||||
|
logResults(writer, round1results, playerType1.toString(), playerType2.toString());
|
||||||
|
logResults(writer, round2results, playerType2.toString(), playerType1.toString());
|
||||||
|
|
||||||
|
System.out
|
||||||
|
.println("Game tournament saved as " + txtFile.getAbsolutePath());
|
||||||
|
} finally {
|
||||||
|
try {
|
||||||
|
writer.close();
|
||||||
|
} catch (IOException ioe) {
|
||||||
|
ioe.printStackTrace();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (IOException ioe) {
|
||||||
|
System.out.println("Unable to save game file due to IOException: "
|
||||||
|
+ ioe.getMessage());
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private Policy getPolicy(PLAYER_TYPE playerType, GameConfig gameConfig, Player player) {
|
private void logResults(FileWriter writer, List<GameResult> results, String player1, String player2) throws IOException {
|
||||||
|
String header = "Cumulative results for " + results.size()
|
||||||
|
+ " games (BLACK=" + player1 + ", WHITE=" + player2
|
||||||
|
+ ")";
|
||||||
|
|
||||||
|
System.out.println(header);
|
||||||
|
writer.write(header);
|
||||||
|
writer.write("\n");
|
||||||
|
|
||||||
|
for (int i = 0; i < results.size(); i++) {
|
||||||
|
String resultLine = (i+1) + ". " + results.get(i);
|
||||||
|
System.out.println(resultLine);
|
||||||
|
writer.write(resultLine);
|
||||||
|
writer.write("\n");
|
||||||
|
}
|
||||||
|
writer.flush();
|
||||||
|
}
|
||||||
|
|
||||||
|
private Policy getPolicy(PLAYER_TYPE playerType, GameConfig gameConfig,
|
||||||
|
Player player) {
|
||||||
switch (playerType) {
|
switch (playerType) {
|
||||||
case HUMAN:
|
case HUMAN:
|
||||||
return new HumanKeyboardInput();
|
return new HumanKeyboardInput();
|
||||||
case HUMAN_GUI:
|
case HUMAN_GUI:
|
||||||
return new HumanGuiInput(new Goban(gameConfig, player));
|
return new HumanGuiInput(new Goban(gameConfig, player));
|
||||||
case ROOT_PAR:
|
case ROOT_PAR:
|
||||||
return new RootParallelization(4, 6000L);
|
return new RootParallelization(4, DEFAULT_TURN_LENGTH * 1000L * 3);
|
||||||
case UCT_SLOW:
|
case UCT_SLOW:
|
||||||
return new MonteCarloUCT(new RandomMovePolicy(), 4000L);
|
return new MonteCarloUCT(new RandomMovePolicy(),
|
||||||
|
DEFAULT_TURN_LENGTH * 1000L * 3);
|
||||||
case UCT_FAST:
|
case UCT_FAST:
|
||||||
return new MonteCarloUCT(new RandomMovePolicy(), 1000L);
|
return new MonteCarloUCT(new RandomMovePolicy(),
|
||||||
|
DEFAULT_TURN_LENGTH * 1000L);
|
||||||
|
case RANDOM:
|
||||||
|
return new RandomMovePolicy();
|
||||||
default:
|
default:
|
||||||
throw new IllegalArgumentException("Invalid PLAYER_TYPE: "
|
throw new IllegalArgumentException("Invalid PLAYER_TYPE: "
|
||||||
+ playerType);
|
+ playerType);
|
||||||
|
|||||||
15
src/net/woodyfolsom/msproj/ann/NeuralNetLearner.java
Normal file
15
src/net/woodyfolsom/msproj/ann/NeuralNetLearner.java
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann;
|
||||||
|
|
||||||
|
import org.neuroph.core.NeuralNetwork;
|
||||||
|
import org.neuroph.core.learning.SupervisedTrainingElement;
|
||||||
|
import org.neuroph.core.learning.TrainingSet;
|
||||||
|
|
||||||
|
public interface NeuralNetLearner {
|
||||||
|
void learn(TrainingSet<SupervisedTrainingElement> trainingSet);
|
||||||
|
|
||||||
|
void reset();
|
||||||
|
|
||||||
|
NeuralNetwork getNeuralNetwork();
|
||||||
|
|
||||||
|
void setNeuralNetwork(NeuralNetwork neuralNetwork);
|
||||||
|
}
|
||||||
@@ -6,6 +6,8 @@ import java.io.FileNotFoundException;
|
|||||||
import java.io.FilenameFilter;
|
import java.io.FilenameFilter;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
|
||||||
|
import net.woodyfolsom.msproj.GameRecord;
|
||||||
|
import net.woodyfolsom.msproj.Referee;
|
||||||
import net.woodyfolsom.msproj.sgf.SGFLexer;
|
import net.woodyfolsom.msproj.sgf.SGFLexer;
|
||||||
import net.woodyfolsom.msproj.sgf.SGFNodeCollection;
|
import net.woodyfolsom.msproj.sgf.SGFNodeCollection;
|
||||||
import net.woodyfolsom.msproj.sgf.SGFParser;
|
import net.woodyfolsom.msproj.sgf.SGFParser;
|
||||||
@@ -40,40 +42,18 @@ public class PassLearner {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public void parseSGF(File sgfFile) {
|
public void parseSGF(File sgfFile) {
|
||||||
FileInputStream fis;
|
FileInputStream sgfInputStream;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
fis = new FileInputStream(sgfFile);
|
sgfInputStream = new FileInputStream(sgfFile);
|
||||||
ANTLRStringStream in;
|
GameRecord gameRecord = Referee.replay(sgfInputStream);
|
||||||
try {
|
//...
|
||||||
in = new ANTLRInputStream(fis);
|
|
||||||
SGFLexer lexer = new SGFLexer(in);
|
|
||||||
CommonTokenStream tokens = new CommonTokenStream(lexer);
|
|
||||||
SGFParser parser = new SGFParser(tokens);
|
|
||||||
SGFNodeCollection nodeCollection;
|
|
||||||
try {
|
|
||||||
nodeCollection = parser.collection();
|
|
||||||
|
|
||||||
System.out.println("To SGF:");
|
|
||||||
System.out.println(nodeCollection.toSGF());
|
|
||||||
System.out.println("");
|
|
||||||
|
|
||||||
System.out.println("To LaTeX:");
|
|
||||||
System.out.println(nodeCollection.toLateX());
|
|
||||||
System.out.println("");
|
|
||||||
} catch (RecognitionException re) {
|
|
||||||
re.printStackTrace();
|
|
||||||
}
|
|
||||||
} catch (IOException ioe) {
|
|
||||||
ioe.printStackTrace();
|
|
||||||
}
|
|
||||||
try {
|
|
||||||
fis.close();
|
|
||||||
} catch (IOException ioe) {
|
|
||||||
System.out.println("Error closing input stream for file" + sgfFile.getPath());
|
|
||||||
}
|
|
||||||
} catch (FileNotFoundException fnfe) {
|
} catch (FileNotFoundException fnfe) {
|
||||||
fnfe.printStackTrace();
|
fnfe.printStackTrace();
|
||||||
|
} catch (RecognitionException re) {
|
||||||
|
re.printStackTrace();
|
||||||
|
} catch (IOException ioe) {
|
||||||
|
ioe.printStackTrace();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
42
src/net/woodyfolsom/msproj/ann/XORLearner.java
Normal file
42
src/net/woodyfolsom/msproj/ann/XORLearner.java
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann;
|
||||||
|
|
||||||
|
import org.neuroph.core.NeuralNetwork;
|
||||||
|
import org.neuroph.core.learning.SupervisedTrainingElement;
|
||||||
|
import org.neuroph.core.learning.TrainingSet;
|
||||||
|
import org.neuroph.nnet.MultiLayerPerceptron;
|
||||||
|
import org.neuroph.util.TransferFunctionType;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Based on sample code from http://neuroph.sourceforge.net
|
||||||
|
*
|
||||||
|
* @author Woody
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
public class XORLearner implements NeuralNetLearner {
|
||||||
|
private NeuralNetwork neuralNetwork;
|
||||||
|
|
||||||
|
public XORLearner() {
|
||||||
|
reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public NeuralNetwork getNeuralNetwork() {
|
||||||
|
return neuralNetwork;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void learn(TrainingSet<SupervisedTrainingElement> trainingSet) {
|
||||||
|
this.neuralNetwork.learn(trainingSet);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void reset() {
|
||||||
|
this.neuralNetwork = new MultiLayerPerceptron(
|
||||||
|
TransferFunctionType.TANH, 2, 3, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setNeuralNetwork(NeuralNetwork neuralNetwork) {
|
||||||
|
this.neuralNetwork = neuralNetwork;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,7 +4,7 @@ import java.util.ArrayList;
|
|||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import net.woodyfolsom.msproj.Player;
|
import net.woodyfolsom.msproj.Action;
|
||||||
|
|
||||||
public class SGFGameTree {
|
public class SGFGameTree {
|
||||||
private List<SGFNode> nodeSequence = new ArrayList<SGFNode>();
|
private List<SGFNode> nodeSequence = new ArrayList<SGFNode>();
|
||||||
@@ -30,58 +30,30 @@ public class SGFGameTree {
|
|||||||
subTrees.add(subTree);
|
subTrees.add(subTree);
|
||||||
}
|
}
|
||||||
|
|
||||||
public int getSubTreeCount() {
|
public List<Action> getMoves(int boardSize) {
|
||||||
return subTrees.size();
|
List<Action> moves = new ArrayList<Action>();
|
||||||
|
|
||||||
|
for (SGFNode node : nodeSequence) {
|
||||||
|
SGFValue<?> sgfValue;
|
||||||
|
switch (node.getType()) {
|
||||||
|
case MOVE_BLACK :
|
||||||
|
sgfValue = node.getFirstValue(SGFIdentifier.MOVE_BLACK);
|
||||||
|
break;
|
||||||
|
case MOVE_WHITE :
|
||||||
|
sgfValue = node.getFirstValue(SGFIdentifier.MOVE_WHITE);
|
||||||
|
break;
|
||||||
|
default :
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
moves.add(Action.parseSGF(sgfValue.getText(), boardSize));
|
||||||
|
}
|
||||||
|
|
||||||
|
return moves;
|
||||||
}
|
}
|
||||||
|
|
||||||
public String toLateXmoves(Player player, int boardSize) {
|
public int getSubTreeCount() {
|
||||||
StringBuilder latexSB = new StringBuilder();
|
return subTrees.size();
|
||||||
SGFNode.TYPE nodeType;
|
|
||||||
SGFIdentifier sgfIdent;
|
|
||||||
if (player == Player.WHITE) {
|
|
||||||
nodeType = SGFNode.TYPE.MOVE_WHITE;
|
|
||||||
sgfIdent = SGFIdentifier.MOVE_WHITE;
|
|
||||||
latexSB.append("\\white{");
|
|
||||||
} else if (player == Player.BLACK) {
|
|
||||||
nodeType = SGFNode.TYPE.MOVE_BLACK;
|
|
||||||
sgfIdent = SGFIdentifier.MOVE_BLACK;
|
|
||||||
latexSB.append("\\black{");
|
|
||||||
} else {
|
|
||||||
throw new RuntimeException("Invalid player: " + player);
|
|
||||||
}
|
|
||||||
|
|
||||||
boolean firstMove = true;
|
|
||||||
int nMoves = 0;
|
|
||||||
for (SGFNode node : nodeSequence) {
|
|
||||||
if (node.getType() != nodeType) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
SGFValue<?> sgfValue = node.getFirstValue(sgfIdent);
|
|
||||||
if (sgfValue.isEmpty()) {
|
|
||||||
// TODO later this will be the LaTeX igo code for 'Pass'?
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (firstMove) {
|
|
||||||
firstMove = false;
|
|
||||||
} else {
|
|
||||||
latexSB.append(",");
|
|
||||||
}
|
|
||||||
SGFCoord sgfCoord = (SGFCoord) sgfValue.getValue();
|
|
||||||
char column = sgfCoord.getColumn();
|
|
||||||
if (column >= 'i') {
|
|
||||||
latexSB.append((char) (column + 1));
|
|
||||||
} else {
|
|
||||||
latexSB.append(column);
|
|
||||||
}
|
|
||||||
char row = sgfCoord.getRow();
|
|
||||||
latexSB.append(boardSize - row + 'a');
|
|
||||||
nMoves++;
|
|
||||||
}
|
|
||||||
if (nMoves == 0) {
|
|
||||||
return "";
|
|
||||||
}
|
|
||||||
latexSB.append("}\n");
|
|
||||||
return latexSB.toString();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public int getBoardSize() {
|
public int getBoardSize() {
|
||||||
@@ -97,33 +69,17 @@ public class SGFGameTree {
|
|||||||
throw new RuntimeException("Cannot get board size: SGFNode with identifier SZ was not found in any nodeSequence of type ROOT.");
|
throw new RuntimeException("Cannot get board size: SGFNode with identifier SZ was not found in any nodeSequence of type ROOT.");
|
||||||
}
|
}
|
||||||
|
|
||||||
public String toLateX(int boardSize) {
|
public double getKomi() {
|
||||||
StringBuilder latexSB = new StringBuilder();
|
|
||||||
|
|
||||||
// Somewhat convoluted logic here because the grammar does not require
|
|
||||||
// all root
|
|
||||||
// properties to be included in the same node in the tree's node
|
|
||||||
// sequence, although they should
|
|
||||||
// each be unique among all node sequences in the tree.
|
|
||||||
for (SGFNode node : nodeSequence) {
|
for (SGFNode node : nodeSequence) {
|
||||||
SGFNode.TYPE nodeType = node.getType();
|
SGFNode.TYPE nodeType = node.getType();
|
||||||
switch (nodeType) {
|
switch (nodeType) {
|
||||||
case ROOT:
|
case ROOT:
|
||||||
latexSB.append("\\gobansize");
|
return Double.parseDouble(node.getFirstValue(SGFIdentifier.KOMI).getText());
|
||||||
latexSB.append("{");
|
|
||||||
latexSB.append(boardSize);
|
|
||||||
latexSB.append("}\n");
|
|
||||||
latexSB.append("\\shortstack{\\showfullgoban\\\\");
|
|
||||||
SGFResult result = (SGFResult) node.getFirstValue(
|
|
||||||
SGFIdentifier.RESULT).getValue();
|
|
||||||
latexSB.append(result.getFullText());
|
|
||||||
latexSB.append("}\n");
|
|
||||||
break;
|
|
||||||
default:
|
default:
|
||||||
// ignore
|
// ignore
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return latexSB.toString();
|
throw new RuntimeException("Cannot get board size: SGFNode with identifier KM was not found in any nodeSequence of type ROOT.");
|
||||||
}
|
}
|
||||||
|
|
||||||
public String toSGF() {
|
public String toSGF() {
|
||||||
|
|||||||
@@ -3,7 +3,8 @@ package net.woodyfolsom.msproj.sgf;
|
|||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import net.woodyfolsom.msproj.Player;
|
import net.woodyfolsom.msproj.Action;
|
||||||
|
import net.woodyfolsom.msproj.GameConfig;
|
||||||
|
|
||||||
public class SGFNodeCollection {
|
public class SGFNodeCollection {
|
||||||
private List<SGFGameTree> gameTrees = new ArrayList<SGFGameTree>();
|
private List<SGFGameTree> gameTrees = new ArrayList<SGFGameTree>();
|
||||||
@@ -16,22 +17,21 @@ public class SGFNodeCollection {
|
|||||||
return gameTrees.get(index);
|
return gameTrees.get(index);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public GameConfig getGameConfig() {
|
||||||
|
SGFGameTree gameTree = gameTrees.get(0);
|
||||||
|
int boardSize = gameTree.getBoardSize();
|
||||||
|
double komi = gameTree.getKomi();
|
||||||
|
|
||||||
|
return new GameConfig(boardSize,komi);
|
||||||
|
}
|
||||||
|
|
||||||
public int getGameTreeCount() {
|
public int getGameTreeCount() {
|
||||||
return gameTrees.size();
|
return gameTrees.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
public String toLateX() {
|
public List<Action> getMoves(int boardSize) {
|
||||||
StringBuilder latexFormatString = new StringBuilder("");
|
|
||||||
SGFGameTree gameTree = gameTrees.get(0);
|
SGFGameTree gameTree = gameTrees.get(0);
|
||||||
|
return gameTree.getMoves(boardSize);
|
||||||
int boardSize = gameTree.getBoardSize();
|
|
||||||
latexFormatString.append(gameTree.toLateXmoves(Player.BLACK, boardSize));
|
|
||||||
latexFormatString.append(gameTree.toLateXmoves(Player.WHITE, boardSize));
|
|
||||||
latexFormatString.append("\\begin{center}\n");
|
|
||||||
latexFormatString.append(gameTree.toLateX(boardSize));
|
|
||||||
latexFormatString.append("\\end{center}");
|
|
||||||
|
|
||||||
return latexFormatString.toString();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@@ -18,16 +18,6 @@ public class SGFResult {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public String getFullText() {
|
|
||||||
if (resignation == false && tie == false) {
|
|
||||||
return winner.getColor() + " wins by " + score;
|
|
||||||
} else if (resignation == true && tie == false) {
|
|
||||||
return winner.getColor() + " wins by resignation";
|
|
||||||
} else {
|
|
||||||
throw new UnsupportedOperationException("Not implemented");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String toString() {
|
public String toString() {
|
||||||
if (resignation == false && tie == false) {
|
if (resignation == false && tie == false) {
|
||||||
|
|||||||
42
test/net/woodyfolsom/msproj/RefereeTest.java
Normal file
42
test/net/woodyfolsom/msproj/RefereeTest.java
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
package net.woodyfolsom.msproj;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.FileInputStream;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.InputStream;
|
||||||
|
|
||||||
|
import org.antlr.runtime.RecognitionException;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
public class RefereeTest {
|
||||||
|
//private static final String FILENAME = "data/tourney1/gogame-121115115720-0500.sgf";
|
||||||
|
private static final String FILENAME = "data/games/1334-gokifu-20120916-Gu_Li-Lee_Sedol.sgf";
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testReplay() throws IOException, RecognitionException {
|
||||||
|
File sgfFile = new File(FILENAME);
|
||||||
|
InputStream fis = new FileInputStream(sgfFile);
|
||||||
|
|
||||||
|
GameRecord gameRecord = Referee.replay(fis);
|
||||||
|
|
||||||
|
fis.close();
|
||||||
|
|
||||||
|
//assertEquals(9, gameRecord.getGameConfig().getSize());
|
||||||
|
assertEquals(19, gameRecord.getGameConfig().getSize());
|
||||||
|
|
||||||
|
//assertEquals(5.5, gameRecord.getGameConfig().getKomi(), 0.1);
|
||||||
|
assertEquals(7.5, gameRecord.getGameConfig().getKomi(), 0.1);
|
||||||
|
|
||||||
|
//assertEquals(74, gameRecord.getNumTurns());
|
||||||
|
assertEquals(214, gameRecord.getNumTurns());
|
||||||
|
|
||||||
|
System.out.println("Final board state (LaTeX): ");
|
||||||
|
LatexWriter.write(System.out, gameRecord, 214);
|
||||||
|
|
||||||
|
System.out.println("Final board state (SGF): ");
|
||||||
|
SGFWriter.write(System.out, gameRecord);
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
86
test/net/woodyfolsom/msproj/ann/NeuralNetLearnerTest.java
Normal file
86
test/net/woodyfolsom/msproj/ann/NeuralNetLearnerTest.java
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.util.Arrays;
|
||||||
|
|
||||||
|
import org.junit.AfterClass;
|
||||||
|
import org.junit.BeforeClass;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.neuroph.core.NeuralNetwork;
|
||||||
|
import org.neuroph.core.learning.SupervisedTrainingElement;
|
||||||
|
import org.neuroph.core.learning.TrainingSet;
|
||||||
|
|
||||||
|
public class NeuralNetLearnerTest {
|
||||||
|
private static final String FILENAME = "myMlPerceptron.nnet";
|
||||||
|
|
||||||
|
@AfterClass
|
||||||
|
public static void deleteNewNet() {
|
||||||
|
File file = new File(FILENAME);
|
||||||
|
if (file.exists()) {
|
||||||
|
file.delete();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@BeforeClass
|
||||||
|
public static void deleteSavedNet() {
|
||||||
|
File file = new File(FILENAME);
|
||||||
|
if (file.exists()) {
|
||||||
|
file.delete();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testLearnSaveLoad() {
|
||||||
|
NeuralNetLearner nnLearner = new XORLearner();
|
||||||
|
|
||||||
|
// create training set (logical XOR function)
|
||||||
|
TrainingSet<SupervisedTrainingElement> trainingSet = new TrainingSet<SupervisedTrainingElement>(
|
||||||
|
2, 1);
|
||||||
|
for (int x = 0; x < 1000; x++) {
|
||||||
|
trainingSet.addElement(new SupervisedTrainingElement(new double[] { 0,
|
||||||
|
0 }, new double[] { 0 }));
|
||||||
|
trainingSet.addElement(new SupervisedTrainingElement(new double[] { 0,
|
||||||
|
1 }, new double[] { 1 }));
|
||||||
|
trainingSet.addElement(new SupervisedTrainingElement(new double[] { 1,
|
||||||
|
0 }, new double[] { 1 }));
|
||||||
|
trainingSet.addElement(new SupervisedTrainingElement(new double[] { 1,
|
||||||
|
1 }, new double[] { 0 }));
|
||||||
|
}
|
||||||
|
|
||||||
|
nnLearner.learn(trainingSet);
|
||||||
|
NeuralNetwork nnet = nnLearner.getNeuralNetwork();
|
||||||
|
|
||||||
|
TrainingSet<SupervisedTrainingElement> valSet = new TrainingSet<SupervisedTrainingElement>(
|
||||||
|
2, 1);
|
||||||
|
valSet.addElement(new SupervisedTrainingElement(new double[] { 0,
|
||||||
|
0 }, new double[] { 0 }));
|
||||||
|
valSet.addElement(new SupervisedTrainingElement(new double[] { 0,
|
||||||
|
1 }, new double[] { 1 }));
|
||||||
|
valSet.addElement(new SupervisedTrainingElement(new double[] { 1,
|
||||||
|
0 }, new double[] { 1 }));
|
||||||
|
valSet.addElement(new SupervisedTrainingElement(new double[] { 1,
|
||||||
|
1 }, new double[] { 0 }));
|
||||||
|
|
||||||
|
System.out.println("Output from eval set (learned network):");
|
||||||
|
testNetwork(nnet, valSet);
|
||||||
|
|
||||||
|
nnet.save(FILENAME);
|
||||||
|
nnet = NeuralNetwork.load(FILENAME);
|
||||||
|
|
||||||
|
System.out.println("Output from eval set (learned network):");
|
||||||
|
testNetwork(nnet, valSet);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void testNetwork(NeuralNetwork nnet, TrainingSet<SupervisedTrainingElement> trainingSet) {
|
||||||
|
for (SupervisedTrainingElement trainingElement : trainingSet.elements()) {
|
||||||
|
|
||||||
|
nnet.setInput(trainingElement.getInput());
|
||||||
|
nnet.calculate();
|
||||||
|
double[] networkOutput = nnet.getOutput();
|
||||||
|
System.out.print("Input: "
|
||||||
|
+ Arrays.toString(trainingElement.getInput()));
|
||||||
|
System.out.println(" Output: " + Arrays.toString(networkOutput));
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
package net.woodyfolsom.msproj.sgf;
|
package net.woodyfolsom.msproj.sgf;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
|
import static org.junit.Assert.assertNotNull;
|
||||||
|
|
||||||
import java.io.ByteArrayInputStream;
|
import java.io.ByteArrayInputStream;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
@@ -15,6 +16,7 @@ public class CollectionTest {
|
|||||||
public static final String TEST_SGF =
|
public static final String TEST_SGF =
|
||||||
"(;FF[4]SZ[9]KM[5.5]RE[W+6.5]"+
|
"(;FF[4]SZ[9]KM[5.5]RE[W+6.5]"+
|
||||||
";W[ee];B[];W[])";
|
";W[ee];B[];W[])";
|
||||||
|
|
||||||
public static final String TEST_LATEX =
|
public static final String TEST_LATEX =
|
||||||
"\\white{e5}\n" +
|
"\\white{e5}\n" +
|
||||||
"\\begin{center}\n" +
|
"\\begin{center}\n" +
|
||||||
@@ -48,8 +50,6 @@ public class CollectionTest {
|
|||||||
CommonTokenStream tokens = new CommonTokenStream(lexer);
|
CommonTokenStream tokens = new CommonTokenStream(lexer);
|
||||||
SGFParser parser = new SGFParser(tokens);
|
SGFParser parser = new SGFParser(tokens);
|
||||||
SGFNodeCollection nodeCollection = parser.collection();
|
SGFNodeCollection nodeCollection = parser.collection();
|
||||||
|
assertNotNull(nodeCollection);
|
||||||
String actualLaTeX = nodeCollection.toLateX();
|
|
||||||
assertEquals(TEST_LATEX, actualLaTeX);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,11 +25,6 @@ public class SGFParserTest {
|
|||||||
System.out.println("To SGF:");
|
System.out.println("To SGF:");
|
||||||
System.out.println(nodeCollection.toSGF());
|
System.out.println(nodeCollection.toSGF());
|
||||||
System.out.println("");
|
System.out.println("");
|
||||||
|
|
||||||
System.out.println("To LaTeX:");
|
|
||||||
System.out.println(nodeCollection.toLateX());
|
|
||||||
System.out.println("");
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@@ -45,10 +40,5 @@ public class SGFParserTest {
|
|||||||
System.out.println("To SGF:");
|
System.out.println("To SGF:");
|
||||||
System.out.println(nodeCollection.toSGF());
|
System.out.println(nodeCollection.toSGF());
|
||||||
System.out.println("");
|
System.out.println("");
|
||||||
|
|
||||||
System.out.println("To LaTeX:");
|
|
||||||
System.out.println(nodeCollection.toLateX());
|
|
||||||
System.out.println("");
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user