diff --git a/.classpath b/.classpath
index d25bc2e..6f91090 100644
--- a/.classpath
+++ b/.classpath
@@ -7,7 +7,6 @@
-
-
+
diff --git a/build.xml b/build.xml
index b641d96..19cde56 100644
--- a/build.xml
+++ b/build.xml
@@ -23,7 +23,7 @@
-
+
@@ -33,9 +33,25 @@
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
@@ -44,12 +60,12 @@
-
+
-
-
+
+
diff --git a/data/gogame.cfg b/data/gogame.cfg
index 3a82545..bfc04b1 100644
--- a/data/gogame.cfg
+++ b/data/gogame.cfg
@@ -1,7 +1,10 @@
-PlayerOne=ROOT_PAR
+PlayerOne=RANDOM
PlayerTwo=RANDOM
-GUIDelay=2000 //1 second
+GUIDelay=1000 //1 second
BoardSize=9
Komi=6.5
-NumGames=10 //Games for each player
-TurnTime=2000 //seconds per player per turn
\ No newline at end of file
+NumGames=1000 //Games for each color per player
+TurnTime=1000 //milliseconds per player per turn
+SpectatorBoardShown=false
+WhiteMoveLogged=false
+BlackMoveLogged=false
\ No newline at end of file
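These three flags are read back by GameSettings (below) with Boolean.parseBoolean, which returns true only for the exact, case-insensitive string "true"; any trailing character silently yields false, so the values must be bare true/false with no terminator. A minimal sketch of the pitfall:

    // Boolean.parseBoolean is all-or-nothing on the literal "true".
    System.out.println(Boolean.parseBoolean("true"));   // true
    System.out.println(Boolean.parseBoolean("true;"));  // false: the semicolon defeats the flag
    System.out.println(Boolean.parseBoolean("false;")); // false, but only by coincidence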
diff --git a/data/networks/Pass2.nn b/data/networks/Pass2.nn
index e5ae0da..e3d9157 100644
Binary files a/data/networks/Pass2.nn and b/data/networks/Pass2.nn differ
diff --git a/lib/encog-engine-2.5.0.jar b/lib/encog-engine-2.5.0.jar
deleted file mode 100644
index 12d9d63..0000000
Binary files a/lib/encog-engine-2.5.0.jar and /dev/null differ
diff --git a/lib/encog-java-core-javadoc.jar b/lib/encog-java-core-javadoc.jar
new file mode 100644
index 0000000..2cf8d93
Binary files /dev/null and b/lib/encog-java-core-javadoc.jar differ
diff --git a/lib/encog-java-core-sources.jar b/lib/encog-java-core-sources.jar
new file mode 100644
index 0000000..0cb6b50
Binary files /dev/null and b/lib/encog-java-core-sources.jar differ
diff --git a/lib/encog-java-core.jar b/lib/encog-java-core.jar
new file mode 100644
index 0000000..f846d91
Binary files /dev/null and b/lib/encog-java-core.jar differ
diff --git a/lib/neuroph-2.6.jar b/lib/neuroph-2.6.jar
deleted file mode 100644
index 2009162..0000000
Binary files a/lib/neuroph-2.6.jar and /dev/null differ
diff --git a/src/net/woodyfolsom/msproj/GameRecord.java b/src/net/woodyfolsom/msproj/GameRecord.java
index cc5a234..8202d5e 100644
--- a/src/net/woodyfolsom/msproj/GameRecord.java
+++ b/src/net/woodyfolsom/msproj/GameRecord.java
@@ -23,6 +23,16 @@ public class GameRecord {
moves.add(Action.NONE);
}
+ public GameRecord(GameRecord that) {
+ for(GameState gameState : that.gameStates) {
+ gameStates.add(new GameState(gameState));
+ }
+ // copy the move list; its first entry is Action.NONE, the initial 'move' that allows a game to start from a board setup
+ for (Action action : that.moves) {
+ moves.add(action);
+ }
+ }
+
/**
* Adds a comment for the current turn.
* @param comment
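A note on copy depth in the new copy constructor above: GameState entries are deep-copied while Action entries are shared, which is safe as long as Action values are immutable (the Action.NONE sentinel suggests enum-like constants). Hypothetical usage, taking a snapshot that later moves cannot disturb:

    // 'live' keeps accumulating moves; 'snapshot' stays frozen for analysis.
    GameRecord snapshot = new GameRecord(live);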
diff --git a/src/net/woodyfolsom/msproj/GameSettings.java b/src/net/woodyfolsom/msproj/GameSettings.java
index c9b98fa..f595b33 100644
--- a/src/net/woodyfolsom/msproj/GameSettings.java
+++ b/src/net/woodyfolsom/msproj/GameSettings.java
@@ -13,6 +13,9 @@ public class GameSettings {
private int boardSize = 9;
private double komi = 6.5;
private int numGames = 10;
+ private boolean spectatorBoardShown = false;
+ private boolean whiteMoveLogged = true;
+ private boolean blackMoveLogged = true;
private GameSettings() {
}
@@ -49,6 +52,12 @@ public class GameSettings {
gameSettings.setNumGames(Integer.parseInt(value));
} else if ("Komi".equals(name)) {
gameSettings.setKomi(Double.parseDouble(value));
+ } else if ("SpectatorBoardShown".equals(name)) {
+ gameSettings.setSpectatorBoardShown(Boolean.parseBoolean(value));
+ } else if ("WhiteMoveLogged".equals(name)) {
+ gameSettings.setWhiteMoveLogged(Boolean.parseBoolean(value));
+ } else if ("BlackMoveLogged".equals(name)) {
+ gameSettings.setBlackMoveLogged(Boolean.parseBoolean(value));
} else {
System.out.println("Ignoring game settings property with unrecognized name: " + name);
}
@@ -127,4 +136,29 @@ public class GameSettings {
sb.append(", GUIDelay=" + guiDelay);
return sb.toString();
}
+
+ public boolean isSpectatorBoardShown() {
+ return spectatorBoardShown;
+ }
+
+ private void setSpectatorBoardShown(boolean spectatorBoardShown) {
+ this.spectatorBoardShown = spectatorBoardShown;
+ }
+
+ public boolean isWhiteMoveLogged() {
+ return whiteMoveLogged;
+ }
+
+ private void setWhiteMoveLogged(boolean whiteMoveLogged) {
+ this.whiteMoveLogged = whiteMoveLogged;
+ }
+
+ public boolean isBlackMoveLogged() {
+ return blackMoveLogged;
+ }
+
+ private void setBlackMoveLogged(boolean blackMoveLogged) {
+ this.blackMoveLogged = blackMoveLogged;
+ }
+
}
\ No newline at end of file
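Putting the new accessors together, a caller might branch like this (a sketch; the GameSettings.load factory name is an assumption standing in for whatever static loader wraps the parsing loop shown above):

    GameSettings settings = GameSettings.load("data/gogame.cfg"); // hypothetical factory name
    if (!settings.isSpectatorBoardShown()) {
        System.out.println("Running headless: no spectator Goban will be opened.");
    }
    boolean logWhite = settings.isWhiteMoveLogged();

Note that the field defaults (whiteMoveLogged = true, blackMoveLogged = true) differ from the shipped gogame.cfg, which sets both flags to false, so a config file that omits the keys gets the noisier behavior.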
diff --git a/src/net/woodyfolsom/msproj/Referee.java b/src/net/woodyfolsom/msproj/Referee.java
index c2be1e2..6299d53 100644
--- a/src/net/woodyfolsom/msproj/Referee.java
+++ b/src/net/woodyfolsom/msproj/Referee.java
@@ -4,8 +4,6 @@ import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
-import java.text.DateFormat;
-import java.text.SimpleDateFormat;
import net.woodyfolsom.msproj.gui.Goban;
import net.woodyfolsom.msproj.policy.HumanGuiInput;
@@ -65,10 +63,11 @@ public class Referee {
return gameRecord;
}
- public GameResult play(GameConfig gameConfig, int gameNo) {
+ public GameResult play(GameConfig gameConfig, int gameNo,
+ boolean showSpectatorBoard, boolean logGameRecord) {
GameRecord gameRecord = new GameRecord(gameConfig);
- System.out.println("Game started.");
+ //System.out.println("Game started.");
GameState initialGameState = gameRecord.getGameState(gameRecord
.getNumTurns());
@@ -78,20 +77,21 @@ public class Referee {
whitePolicy.setState(initialGameState);
Goban spectatorBoard;
- if (blackPolicy instanceof HumanGuiInput
- || whitePolicy instanceof HumanGuiInput) {
+ if (blackPolicy instanceof HumanGuiInput || whitePolicy instanceof HumanGuiInput) {
System.out.println("Human is controlling the game board GUI.");
spectatorBoard = null;
- } else {
+ } else if (showSpectatorBoard) {
System.out.println("Starting game board GUI in spectator mode.");
- spectatorBoard = new Goban(gameConfig, null);
+ spectatorBoard = new Goban(gameConfig, null, "Game #" + gameNo);
+ } else { // showing the spectator board is disabled
+ spectatorBoard = null;
}
-
+
try {
while (!gameRecord.isFinished()) {
GameState gameState = gameRecord.getGameState(gameRecord
.getNumTurns());
- //System.out.println(gameState);
+ // System.out.println(gameState);
Player playerToMove = gameRecord.getPlayerToMove();
Policy policy = getPolicy(playerToMove);
@@ -120,8 +120,9 @@ public class Referee {
System.out.println("Game over. Result: " + result);
- //DateFormat dateFormat = new SimpleDateFormat("yyMMddHHmmssZ");
+ // DateFormat dateFormat = new SimpleDateFormat("yyMMddHHmmssZ");
+ if (logGameRecord) {
try {
// File sgfFile = new File("gogame-" + dateFormat.format(new Date())
@@ -143,8 +144,9 @@ public class Referee {
System.out.println("Unable to save game file due to IOException: "
+ ioe.getMessage());
}
-
- System.out.println("Game finished.");
+ }
+
+ //System.out.println("Game finished.");
return result;
}
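The expanded play signature in use (a sketch, reusing the referee and gameConfig names from the surrounding code): bulk self-play with no GUI and no .sgf output.

    // One headless, unlogged game; the result still comes back for tallying.
    GameResult result = referee.play(gameConfig, gameNo,
            false /* showSpectatorBoard */, false /* logGameRecord */);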
diff --git a/src/net/woodyfolsom/msproj/StandAloneGame.java b/src/net/woodyfolsom/msproj/StandAloneGame.java
index 1455db0..fe4ca42 100644
--- a/src/net/woodyfolsom/msproj/StandAloneGame.java
+++ b/src/net/woodyfolsom/msproj/StandAloneGame.java
@@ -19,8 +19,9 @@ import net.woodyfolsom.msproj.policy.RandomMovePolicy;
import net.woodyfolsom.msproj.policy.RootParallelization;
public class StandAloneGame {
- private static final int EXIT_NOMINAL = 0;
- private static final int EXIT_IO_EXCEPTION = 1;
+ public static final int EXIT_USER_QUIT = 1;
+ public static final int EXIT_NOMINAL = 0;
+ public static final int EXIT_IO_EXCEPTION = -1;
private int gameNo = 0;
@@ -38,7 +39,9 @@ public class StandAloneGame {
parsePlayerType(gameSettings.getPlayerOne()),
parsePlayerType(gameSettings.getPlayerTwo()),
gameSettings.getBoardSize(), gameSettings.getKomi(),
- gameSettings.getNumGames(), gameSettings.getTurnTime());
+ gameSettings.getNumGames(), gameSettings.getTurnTime(),
+ gameSettings.isSpectatorBoardShown(),
+ gameSettings.isBlackMoveLogged(), gameSettings.isWhiteMoveLogged());
} catch (IOException ioe) {
ioe.printStackTrace();
System.exit(EXIT_IO_EXCEPTION);
@@ -65,7 +68,8 @@ public class StandAloneGame {
}
public void playGame(PLAYER_TYPE playerType1, PLAYER_TYPE playerType2,
- int size, double komi, int rounds, long turnLength) {
+ int size, double komi, int rounds, long turnLength, boolean showSpectatorBoard,
+ boolean blackMoveLogged, boolean whiteMoveLogged) {
long startTime = System.currentTimeMillis();
@@ -74,32 +78,31 @@ public class StandAloneGame {
Referee referee = new Referee();
referee.setPolicy(Player.BLACK,
- getPolicy(playerType1, gameConfig, Player.BLACK, turnLength));
+ getPolicy(playerType1, gameConfig, Player.BLACK, turnLength, blackMoveLogged));
referee.setPolicy(Player.WHITE,
- getPolicy(playerType2, gameConfig, Player.WHITE, turnLength));
+ getPolicy(playerType2, gameConfig, Player.WHITE, turnLength, whiteMoveLogged));
List<GameResult> round1results = new ArrayList<GameResult>();
+ boolean logGameRecords = rounds <= 50;
for (int round = 0; round < rounds; round++) {
gameNo++;
- round1results.add(referee.play(gameConfig, gameNo));
+ round1results.add(referee.play(gameConfig, gameNo, showSpectatorBoard, logGameRecords));
}
List<GameResult> round2results = new ArrayList<GameResult>();
referee.setPolicy(Player.BLACK,
- getPolicy(playerType2, gameConfig, Player.BLACK, turnLength));
+ getPolicy(playerType2, gameConfig, Player.BLACK, turnLength, blackMoveLogged));
referee.setPolicy(Player.WHITE,
- getPolicy(playerType1, gameConfig, Player.WHITE, turnLength));
+ getPolicy(playerType1, gameConfig, Player.WHITE, turnLength, whiteMoveLogged));
for (int round = 0; round < rounds; round++) {
gameNo++;
- round2results.add(referee.play(gameConfig, gameNo));
+ round2results.add(referee.play(gameConfig, gameNo, showSpectatorBoard, logGameRecords));
}
long endTime = System.currentTimeMillis();
-
DateFormat dateFormat = new SimpleDateFormat("yyMMddHHmmss");
-
try {
File txtFile = new File("gotournament-"
@@ -107,14 +110,16 @@ public class StandAloneGame {
FileWriter writer = new FileWriter(txtFile);
try {
+ if (!logGameRecords) {
+ System.out.println("Each player is set to play more than 50 rounds as each color; omitting individual game .sgf log file output.");
+ }
+
logResults(writer, round1results, playerType1.toString(),
- playerType2.toString());
+ playerType2.toString());
logResults(writer, round2results, playerType2.toString(),
- playerType1.toString());
-
+ playerType1.toString());
writer.write("Elapsed Time: " + (endTime - startTime) / 1000.0
+ " seconds.");
-
System.out.println("Game tournament saved as "
+ txtFile.getAbsolutePath());
} finally {
@@ -149,19 +154,19 @@ public class StandAloneGame {
}
private Policy getPolicy(PLAYER_TYPE playerType, GameConfig gameConfig,
- Player player, long turnLength) {
+ Player player, long turnLength, boolean moveLogged) {
switch (playerType) {
case HUMAN:
return new HumanKeyboardInput();
case HUMAN_GUI:
- return new HumanGuiInput(new Goban(gameConfig, player));
+ return new HumanGuiInput(new Goban(gameConfig, player,""));
case ROOT_PAR:
return new RootParallelization(4, turnLength);
case UCT:
return new MonteCarloUCT(new RandomMovePolicy(), turnLength);
case RANDOM:
RandomMovePolicy randomMovePolicy = new RandomMovePolicy();
- randomMovePolicy.setLogging(true);
+ randomMovePolicy.setLogging(moveLogged);
return randomMovePolicy;
case RAVE:
return new MonteCarloAMAF(new RandomMovePolicy(), turnLength);
diff --git a/src/net/woodyfolsom/msproj/ann/AbstractNeuralNetFilter.java b/src/net/woodyfolsom/msproj/ann/AbstractNeuralNetFilter.java
new file mode 100644
index 0000000..698197f
--- /dev/null
+++ b/src/net/woodyfolsom/msproj/ann/AbstractNeuralNetFilter.java
@@ -0,0 +1,54 @@
+package net.woodyfolsom.msproj.ann;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+
+import org.encog.neural.networks.BasicNetwork;
+import org.encog.neural.networks.PersistBasicNetwork;
+
+public abstract class AbstractNeuralNetFilter implements NeuralNetFilter {
+ protected BasicNetwork neuralNetwork;
+ protected int actualTrainingEpochs = 0;
+ protected int maxTrainingEpochs = 1000;
+
+ public int getActualTrainingEpochs() {
+ return actualTrainingEpochs;
+ }
+
+ public int getMaxTrainingEpochs() {
+ return maxTrainingEpochs;
+ }
+
+ @Override
+ public BasicNetwork getNeuralNetwork() {
+ return neuralNetwork;
+ }
+
+ public void load(String filename) throws IOException {
+ FileInputStream fis = new FileInputStream(new File(filename));
+ neuralNetwork = (BasicNetwork) new PersistBasicNetwork().read(fis);
+ fis.close();
+ }
+
+ @Override
+ public void reset() {
+ neuralNetwork.reset();
+ }
+
+ @Override
+ public void reset(int seed) {
+ neuralNetwork.reset(seed);
+ }
+
+ public void save(String filename) throws IOException {
+ FileOutputStream fos = new FileOutputStream(new File(filename));
+ new PersistBasicNetwork().save(fos, getNeuralNetwork());
+ fos.close();
+ }
+
+ public void setMaxTrainingEpochs(int max) {
+ this.maxTrainingEpochs = max;
+ }
+}
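One caveat on the stream handling above: load and save close their streams only on the success path, so an IOException thrown by the Encog codec leaks the file descriptor. A minimal sketch of the same save logic with Java 7 try-with-resources, nothing else changed:

    public void save(String filename) throws IOException {
        try (FileOutputStream fos = new FileOutputStream(new File(filename))) {
            new PersistBasicNetwork().save(fos, getNeuralNetwork());
        } // fos is closed even when save() throws
    }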
diff --git a/src/net/woodyfolsom/msproj/ann/DoublePair.java b/src/net/woodyfolsom/msproj/ann/DoublePair.java
new file mode 100644
index 0000000..49f7eeb
--- /dev/null
+++ b/src/net/woodyfolsom/msproj/ann/DoublePair.java
@@ -0,0 +1,17 @@
+package net.woodyfolsom.msproj.ann;
+
+import org.encog.ml.data.basic.BasicMLData;
+
+public class DoublePair extends BasicMLData {
+ // private final double x;
+ // private final double y;
+
+ /**
+ *
+ */
+ private static final long serialVersionUID = 1L;
+
+ public DoublePair(double x, double y) {
+ super(new double[] { x, y });
+ }
+}
\ No newline at end of file
diff --git a/src/net/woodyfolsom/msproj/ann/ErrorCalculation.java b/src/net/woodyfolsom/msproj/ann/ErrorCalculation.java
new file mode 100644
index 0000000..7823a8f
--- /dev/null
+++ b/src/net/woodyfolsom/msproj/ann/ErrorCalculation.java
@@ -0,0 +1,95 @@
+package net.woodyfolsom.msproj.ann;
+
+import org.encog.mathutil.error.ErrorCalculationMode;
+
+/*
+ The initial version of this class was a verbatim copy from the Encog framework.
+ */
+
+public class ErrorCalculation {
+
+ private static ErrorCalculationMode mode = ErrorCalculationMode.MSE;
+
+ public static ErrorCalculationMode getMode() {
+ return ErrorCalculation.mode;
+ }
+
+ public static void setMode(final ErrorCalculationMode theMode) {
+ ErrorCalculation.mode = theMode;
+ }
+
+ private double globalError;
+
+ private int setSize;
+
+ public final double calculate() {
+ if (this.setSize == 0) {
+ return 0;
+ }
+
+ switch (ErrorCalculation.getMode()) {
+ case RMS:
+ return calculateRMS();
+ case MSE:
+ return calculateMSE();
+ case ESS:
+ return calculateESS();
+ default:
+ return calculateMSE();
+ }
+
+ }
+
+ public final double calculateMSE() {
+ if (this.setSize == 0) {
+ return 0;
+ }
+ final double err = this.globalError / this.setSize;
+ return err;
+
+ }
+
+ public final double calculateESS() {
+ if (this.setSize == 0) {
+ return 0;
+ }
+ final double err = this.globalError / 2;
+ return err;
+
+ }
+
+ public final double calculateRMS() {
+ if (this.setSize == 0) {
+ return 0;
+ }
+ final double err = Math.sqrt(this.globalError / this.setSize);
+ return err;
+ }
+
+ public final void reset() {
+ this.globalError = 0;
+ this.setSize = 0;
+ }
+
+ public final void updateError(final double actual, final double ideal) {
+
+ double delta = ideal - actual;
+
+ this.globalError += delta * delta;
+
+ this.setSize++;
+
+ }
+
+ public final void updateError(final double[] actual, final double[] ideal,
+ final double significance) {
+ for (int i = 0; i < actual.length; i++) {
+ double delta = (ideal[i] - actual[i]) * significance;
+
+ this.globalError += delta * delta;
+ }
+
+ this.setSize += ideal.length;
+ }
+
+}
\ No newline at end of file
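A tiny worked example of the three modes: two one-output samples with errors 0.3 and 0.4 give globalError = 0.09 + 0.16 = 0.25 and setSize = 2, so MSE = 0.25/2 = 0.125, RMS = sqrt(0.125) = 0.3535..., and ESS = 0.25/2 = 0.125 (ESS equals MSE here only because setSize happens to be 2):

    ErrorCalculation calc = new ErrorCalculation();
    calc.updateError(0.7, 1.0); // delta = 0.3
    calc.updateError(0.6, 1.0); // delta = 0.4
    System.out.println(calc.calculateMSE()); // 0.125
    System.out.println(calc.calculateRMS()); // 0.3535...
    System.out.println(calc.calculateESS()); // 0.125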
diff --git a/src/net/woodyfolsom/msproj/ann/GameStateMLData.java b/src/net/woodyfolsom/msproj/ann/GameStateMLData.java
new file mode 100644
index 0000000..47d588c
--- /dev/null
+++ b/src/net/woodyfolsom/msproj/ann/GameStateMLData.java
@@ -0,0 +1,25 @@
+package net.woodyfolsom.msproj.ann;
+
+import net.woodyfolsom.msproj.GameState;
+
+import org.encog.ml.data.basic.BasicMLData;
+
+public class GameStateMLData extends BasicMLData {
+
+ /**
+ *
+ */
+ private static final long serialVersionUID = 1L;
+
+ private GameState gameState;
+
+ public GameStateMLData(double[] d, GameState gameState) {
+ super(d);
+ // keep the originating GameState alongside the raw input vector
+ this.gameState = gameState;
+ }
+
+ public GameState getGameState() {
+ return gameState;
+ }
+}
diff --git a/src/net/woodyfolsom/msproj/ann/GameStateMLDataPair.java b/src/net/woodyfolsom/msproj/ann/GameStateMLDataPair.java
new file mode 100644
index 0000000..8f41126
--- /dev/null
+++ b/src/net/woodyfolsom/msproj/ann/GameStateMLDataPair.java
@@ -0,0 +1,121 @@
+package net.woodyfolsom.msproj.ann;
+
+import net.woodyfolsom.msproj.GameResult;
+import net.woodyfolsom.msproj.GameState;
+import net.woodyfolsom.msproj.Player;
+
+import org.encog.ml.data.MLData;
+import org.encog.ml.data.MLDataPair;
+import org.encog.ml.data.basic.BasicMLData;
+import org.encog.ml.data.basic.BasicMLDataPair;
+import org.encog.util.kmeans.Centroid;
+
+public class GameStateMLDataPair implements MLDataPair {
+ //private final String[] inputs = { "BlackScore", "WhiteScore" };
+ //private final String[] outputs = { "BlackWins", "WhiteWins" };
+
+ private BasicMLDataPair mlDataPairDelegate;
+ private GameState gameState;
+
+ public GameStateMLDataPair(GameState gameState) {
+ this.gameState = gameState;
+ mlDataPairDelegate = new BasicMLDataPair(
+ new GameStateMLData(createInput(), gameState), new BasicMLData(createIdeal()));
+ }
+
+ public GameStateMLDataPair(GameStateMLDataPair that) {
+ this.gameState = new GameState(that.gameState);
+ mlDataPairDelegate = new BasicMLDataPair(
+ that.mlDataPairDelegate.getInput(),
+ that.mlDataPairDelegate.getIdeal());
+ }
+
+ @Override
+ public MLDataPair clone() {
+ return new GameStateMLDataPair(this);
+ }
+
+ @Override
+ public Centroid createCentroid() {
+ return mlDataPairDelegate.createCentroid();
+ }
+
+ /**
+ * Creates a vector of normalized scores from GameState.
+ *
+ * @return
+ */
+ private double[] createInput() {
+
+ GameResult result = gameState.getResult();
+
+ double maxScore = gameState.getGameConfig().getSize()
+ * gameState.getGameConfig().getSize();
+
+ double whiteScore = Math.min(1.0, result.getWhiteScore() / maxScore);
+ double blackScore = Math.min(1.0, result.getBlackScore() / maxScore);
+
+ return new double[] { blackScore, whiteScore };
+ }
+
+ /**
+ * Creates a vector of values indicating strength of black/white win output
+ * from network.
+ *
+ * @return
+ */
+ private double[] createIdeal() {
+ GameResult result = gameState.getResult();
+
+ double blackWinner = result.isWinner(Player.BLACK) ? 1.0 : 0.0;
+ double whiteWinner = result.isWinner(Player.WHITE) ? 1.0 : 0.0;
+
+ return new double[] { blackWinner, whiteWinner };
+ }
+
+ @Override
+ public MLData getIdeal() {
+ return mlDataPairDelegate.getIdeal();
+ }
+
+ @Override
+ public double[] getIdealArray() {
+ return mlDataPairDelegate.getIdealArray();
+ }
+
+ @Override
+ public MLData getInput() {
+ return mlDataPairDelegate.getInput();
+ }
+
+ @Override
+ public double[] getInputArray() {
+ return mlDataPairDelegate.getInputArray();
+ }
+
+ @Override
+ public double getSignificance() {
+ return mlDataPairDelegate.getSignificance();
+ }
+
+ @Override
+ public boolean isSupervised() {
+ return mlDataPairDelegate.isSupervised();
+ }
+
+ @Override
+ public void setIdealArray(double[] arg0) {
+ mlDataPairDelegate.setIdealArray(arg0);
+ }
+
+ @Override
+ public void setInputArray(double[] arg0) {
+ mlDataPairDelegate.setInputArray(arg0);
+ }
+
+ @Override
+ public void setSignificance(double arg0) {
+ mlDataPairDelegate.setSignificance(arg0);
+ }
+
+}
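Concretely, on a 9x9 board maxScore = 81, so a game ending (say) Black 38, White 43.5 with komi included and White the winner is encoded as:

    // input = { 38 / 81, 43.5 / 81 } = { 0.4691..., 0.5370... }
    // ideal = { 0.0, 1.0 }  (White is the sole winner)

The Math.min(1.0, ...) clamp only matters in the corner case where a score plus komi exceeds the point count of the board.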
diff --git a/src/net/woodyfolsom/msproj/ann/GradientWorker.java b/src/net/woodyfolsom/msproj/ann/GradientWorker.java
new file mode 100644
index 0000000..0678e4e
--- /dev/null
+++ b/src/net/woodyfolsom/msproj/ann/GradientWorker.java
@@ -0,0 +1,172 @@
+package net.woodyfolsom.msproj.ann;
+/*
+ * Class copied verbatim from Encog framework due to dependency on Propagation
+ * implementation.
+ *
+ * Encog(tm) Core v3.2 - Java Version
+ * http://www.heatonresearch.com/encog/
+ * http://code.google.com/p/encog-java/
+
+ * Copyright 2008-2012 Heaton Research, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * For more information on Heaton Research copyrights, licenses
+ * and trademarks visit:
+ * http://www.heatonresearch.com/copyright
+ */
+
+import java.util.List;
+import java.util.Set;
+
+import org.encog.engine.network.activation.ActivationFunction;
+import org.encog.ml.data.MLDataPair;
+import org.encog.ml.data.MLDataSet;
+import org.encog.ml.data.basic.BasicMLDataPair;
+import org.encog.neural.error.ErrorFunction;
+import org.encog.neural.flat.FlatNetwork;
+import org.encog.util.EngineArray;
+import org.encog.util.concurrency.EngineTask;
+
+public class GradientWorker implements EngineTask {
+
+ private final FlatNetwork network;
+ private final ErrorCalculation errorCalculation = new ErrorCalculation();
+ private final double[] actual;
+ private final double[] layerDelta;
+ private final int[] layerCounts;
+ private final int[] layerFeedCounts;
+ private final int[] layerIndex;
+ private final int[] weightIndex;
+ private final double[] layerOutput;
+ private final double[] layerSums;
+ private final double[] gradients;
+ private final double[] weights;
+ private final MLDataPair pair;
+ private final Set<List<MLDataPair>> training;
+ private final int low;
+ private final int high;
+ private final TemporalDifferenceLearning owner;
+ private double[] flatSpot;
+ private final ErrorFunction errorFunction;
+
+ public GradientWorker(final FlatNetwork theNetwork,
+ final TemporalDifferenceLearning theOwner,
+ final Set<List<MLDataPair>> theTraining, final int theLow,
+ final int theHigh, final double[] flatSpot,
+ ErrorFunction ef) {
+ this.network = theNetwork;
+ this.training = theTraining;
+ this.low = theLow;
+ this.high = theHigh;
+ this.owner = theOwner;
+ this.flatSpot = flatSpot;
+ this.errorFunction = ef;
+
+ this.layerDelta = new double[network.getLayerOutput().length];
+ this.gradients = new double[network.getWeights().length];
+ this.actual = new double[network.getOutputCount()];
+
+ this.weights = network.getWeights();
+ this.layerIndex = network.getLayerIndex();
+ this.layerCounts = network.getLayerCounts();
+ this.weightIndex = network.getWeightIndex();
+ this.layerOutput = network.getLayerOutput();
+ this.layerSums = network.getLayerSums();
+ this.layerFeedCounts = network.getLayerFeedCounts();
+
+ this.pair = BasicMLDataPair.createPair(network.getInputCount(), network
+ .getOutputCount());
+ }
+
+ public FlatNetwork getNetwork() {
+ return this.network;
+ }
+
+ public double[] getWeights() {
+ return this.weights;
+ }
+
+ private void process(final double[] input, final double[] ideal, double s) {
+ this.network.compute(input, this.actual);
+
+ this.errorCalculation.updateError(this.actual, ideal, s);
+ this.errorFunction.calculateError(ideal, actual, this.layerDelta);
+
+ for (int i = 0; i < this.actual.length; i++) {
+
+ this.layerDelta[i] = ((this.network.getActivationFunctions()[0]
+ .derivativeFunction(this.layerSums[i],this.layerOutput[i]) + this.flatSpot[0]))
+ * (this.layerDelta[i] * s);
+ }
+
+ for (int i = this.network.getBeginTraining(); i < this.network
+ .getEndTraining(); i++) {
+ processLevel(i);
+ }
+ }
+
+ private void processLevel(final int currentLevel) {
+ final int fromLayerIndex = this.layerIndex[currentLevel + 1];
+ final int toLayerIndex = this.layerIndex[currentLevel];
+ final int fromLayerSize = this.layerCounts[currentLevel + 1];
+ final int toLayerSize = this.layerFeedCounts[currentLevel];
+
+ final int index = this.weightIndex[currentLevel];
+ final ActivationFunction activation = this.network
+ .getActivationFunctions()[currentLevel];
+ final double currentFlatSpot = this.flatSpot[currentLevel + 1];
+
+ // handle weights
+ int yi = fromLayerIndex;
+ for (int y = 0; y < fromLayerSize; y++) {
+ final double output = this.layerOutput[yi];
+ double sum = 0;
+ int xi = toLayerIndex;
+ int wi = index + y;
+ for (int x = 0; x < toLayerSize; x++) {
+ this.gradients[wi] += output * this.layerDelta[xi];
+ sum += this.weights[wi] * this.layerDelta[xi];
+ wi += fromLayerSize;
+ xi++;
+ }
+
+ this.layerDelta[yi] = sum
+ * (activation.derivativeFunction(this.layerSums[yi],this.layerOutput[yi])+currentFlatSpot);
+ yi++;
+ }
+ }
+
+ public final void run() {
+ try {
+ this.errorCalculation.reset();
+ //for (int i = this.low; i <= this.high; i++) {
+ // note: only the final (terminal) pair of each game sequence is used
+ // here; earlier positions are not yet folded into the training targets
+ for (List<MLDataPair> trainingSequence : training) {
+ MLDataPair mldp = trainingSequence.get(trainingSequence.size()-1);
+ this.pair.setInputArray(mldp.getInputArray());
+ if (this.pair.getIdealArray() != null) {
+ this.pair.setIdealArray(mldp.getIdealArray());
+ }
+ //this.training.getRecord(i, this.pair);
+ process(this.pair.getInputArray(), this.pair.getIdealArray(),pair.getSignificance());
+ }
+ //}
+ final double error = this.errorCalculation.calculate();
+ this.owner.report(this.gradients, error, null);
+ EngineArray.fill(this.gradients, 0);
+ } catch (final Throwable ex) {
+ this.owner.report(null, 0, ex);
+ }
+ }
+
+}
\ No newline at end of file
diff --git a/src/net/woodyfolsom/msproj/ann/NeuralNetFilter.java b/src/net/woodyfolsom/msproj/ann/NeuralNetFilter.java
new file mode 100644
index 0000000..64d40c7
--- /dev/null
+++ b/src/net/woodyfolsom/msproj/ann/NeuralNetFilter.java
@@ -0,0 +1,31 @@
+package net.woodyfolsom.msproj.ann;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Set;
+
+import org.encog.ml.data.MLData;
+import org.encog.ml.data.MLDataPair;
+import org.encog.ml.data.MLDataSet;
+import org.encog.neural.networks.BasicNetwork;
+
+public interface NeuralNetFilter {
+ BasicNetwork getNeuralNetwork();
+
+ public int getActualTrainingEpochs();
+ public int getInputSize();
+ public int getMaxTrainingEpochs();
+ public int getOutputSize();
+
+ public double computeValue(MLData input);
+ public double[] computeVector(MLData input);
+
+ public void learn(MLDataSet trainingSet);
+ public void learn(Set<List<MLDataPair>> trainingSet);
+
+ public void load(String fileName) throws IOException;
+ public void reset();
+ public void reset(int seed);
+ public void save(String fileName) throws IOException;
+ public void setMaxTrainingEpochs(int max);
+}
\ No newline at end of file
diff --git a/src/net/woodyfolsom/msproj/ann/NeuralNetLearner.java b/src/net/woodyfolsom/msproj/ann/NeuralNetLearner.java
deleted file mode 100644
index 42e9e59..0000000
--- a/src/net/woodyfolsom/msproj/ann/NeuralNetLearner.java
+++ /dev/null
@@ -1,15 +0,0 @@
-package net.woodyfolsom.msproj.ann;
-
-import org.neuroph.core.NeuralNetwork;
-import org.neuroph.core.learning.SupervisedTrainingElement;
-import org.neuroph.core.learning.TrainingSet;
-
-public interface NeuralNetLearner {
- void learn(TrainingSet<SupervisedTrainingElement> trainingSet);
-
- void reset();
-
- NeuralNetwork getNeuralNetwork();
-
- void setNeuralNetwork(NeuralNetwork neuralNetwork);
-}
diff --git a/src/net/woodyfolsom/msproj/ann/PassData.java b/src/net/woodyfolsom/msproj/ann/PassData.java
deleted file mode 100644
index be601b3..0000000
--- a/src/net/woodyfolsom/msproj/ann/PassData.java
+++ /dev/null
@@ -1,98 +0,0 @@
-package net.woodyfolsom.msproj.ann;
-
-import java.util.Arrays;
-
-import net.woodyfolsom.msproj.GameRecord;
-import net.woodyfolsom.msproj.GameResult;
-import net.woodyfolsom.msproj.GameState;
-import net.woodyfolsom.msproj.Player;
-
-import org.neuroph.core.learning.SupervisedTrainingElement;
-import org.neuroph.core.learning.TrainingSet;
-
-public class PassData {
- public enum DATA_TYPE { TRAINING, TEST, VALIDATION };
-
- public String[] inputs = { "BlackScore", "WhiteScore" };
- public String[] outputs = { "BlackWins", "WhiteWins" };
-
- private TrainingSet<SupervisedTrainingElement> testSet;
- private TrainingSet<SupervisedTrainingElement> trainingSet;
- private TrainingSet<SupervisedTrainingElement> valSet;
-
- public PassData() {
- testSet = new TrainingSet<SupervisedTrainingElement>(inputs.length, outputs.length);
- trainingSet = new TrainingSet<SupervisedTrainingElement>(inputs.length, outputs.length);
- valSet = new TrainingSet<SupervisedTrainingElement>(inputs.length, outputs.length);
- }
-
- public void addData(DATA_TYPE dataType, GameRecord gameRecord) {
- GameState finalState = gameRecord.getGameState(gameRecord.getNumTurns());
- GameResult result = finalState.getResult();
- double maxScore = finalState.getGameConfig().getSize() * finalState.getGameConfig().getSize();
-
- double whiteScore = Math.min(1.0, result.getWhiteScore() / maxScore);
- double blackScore = Math.min(1.0, result.getBlackScore() / maxScore);
-
- double blackWinner = result.isWinner(Player.BLACK) ? 1.0 : 0.0;
- double whiteWinner = result.isWinner(Player.WHITE) ? 1.0 : 0.0;
-
- addData(dataType, blackScore, whiteScore, blackWinner, whiteWinner);
- }
-
- public void addData(DATA_TYPE dataType, double...data ) {
- double[] desiredInput = Arrays.copyOfRange(data,0,inputs.length);
- double[] desiredOutput = Arrays.copyOfRange(data, inputs.length, data.length);
-
- switch (dataType) {
- case TEST :
- testSet.addElement(new SupervisedTrainingElement(desiredInput, desiredOutput));
- break;
- case TRAINING :
- trainingSet.addElement(new SupervisedTrainingElement(desiredInput, desiredOutput));
- System.out.println("Added training input data: " + getInput(desiredInput) + ", output data: " + getOutput(desiredOutput));
- break;
- case VALIDATION :
- valSet.addElement(new SupervisedTrainingElement(desiredInput, desiredOutput));
- break;
- default :
- throw new UnsupportedOperationException("invalid dataType " + dataType);
- }
- }
-
- public String getInput(double... inputValues) {
- StringBuilder sbuilder = new StringBuilder();
- boolean first = true;
- for (int i = 0; i < outputs.length; i++) {
- if (first) {
- first = false;
- } else {
- sbuilder.append(",");
- }
- sbuilder.append(inputs[i]);
- sbuilder.append(": ");
- sbuilder.append(inputValues[i]);
- }
- return sbuilder.toString();
- }
-
- public String getOutput(double... outputValues) {
- StringBuilder sbuilder = new StringBuilder();
- boolean first = true;
- for (int i = 0; i < outputs.length; i++) {
- if (first) {
- first = false;
- } else {
- sbuilder.append(",");
- }
- sbuilder.append(outputs[i]);
- sbuilder.append(": ");
- sbuilder.append(outputValues[i]);
- }
- return sbuilder.toString();
- }
-
- public TrainingSet<SupervisedTrainingElement> getTrainingSet() {
- return trainingSet;
- }
-}
diff --git a/src/net/woodyfolsom/msproj/ann/PassLearner.java b/src/net/woodyfolsom/msproj/ann/PassLearner.java
deleted file mode 100644
index 1d99127..0000000
--- a/src/net/woodyfolsom/msproj/ann/PassLearner.java
+++ /dev/null
@@ -1,127 +0,0 @@
-package net.woodyfolsom.msproj.ann;
-
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.FilenameFilter;
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-
-import net.woodyfolsom.msproj.Action;
-import net.woodyfolsom.msproj.GameRecord;
-import net.woodyfolsom.msproj.Referee;
-import net.woodyfolsom.msproj.ann.PassData.DATA_TYPE;
-
-import org.antlr.runtime.RecognitionException;
-import org.neuroph.core.NeuralNetwork;
-import org.neuroph.core.learning.SupervisedTrainingElement;
-import org.neuroph.core.learning.TrainingSet;
-import org.neuroph.nnet.MultiLayerPerceptron;
-import org.neuroph.util.TransferFunctionType;
-
-public class PassLearner implements NeuralNetLearner {
- private NeuralNetwork neuralNetwork;
-
- public PassLearner() {
- reset();
- }
-
- private File[] getDataFiles(String dirName) {
- File file = new File(dirName);
- return file.listFiles(new FilenameFilter() {
- @Override
- public boolean accept(File dir, String name) {
- return name.toLowerCase().endsWith(".sgf");
- }
- });
- }
-
- public static void main(String[] args) {
- new PassLearner().learnANN();
- }
-
- private void learnANN() {
- List<GameRecord> parsedRecords = new ArrayList<GameRecord>();
-
- for (File sgfFile : getDataFiles("data/games/random_vs_random")) {
- System.out.println("Parsing " + sgfFile.getPath() + "...");
- try {
- GameRecord gameRecord = parseSGF(sgfFile);
- while (!gameRecord.isFinished()) {
- System.out.println("Game is not finished, passing as player to move");
- gameRecord.play(gameRecord.getPlayerToMove(), Action.PASS);
- }
- parsedRecords.add(gameRecord);
- } catch (RecognitionException re) {
- re.printStackTrace();
- } catch (IOException ioe) {
- ioe.printStackTrace();
- }
- }
-
- PassData passData = new PassData();
-
- for (GameRecord gameRecord : parsedRecords) {
- System.out.println(gameRecord.getResult().getFullText());
- passData.addData(DATA_TYPE.TRAINING, gameRecord);
- }
-
- System.out.println("PassData: ");
- System.out.println(passData);
-
- learn(passData.getTrainingSet());
-
- getNeuralNetwork().setInput(0.75,0.25);
- System.out.println("Output of ann(0.75,0.25): " + passData.getOutput(getNeuralNetwork().getOutput()));
-
- getNeuralNetwork().setInput(0.25,0.50);
- System.out.println("Output of ann(0.50,0.99): " + passData.getOutput(getNeuralNetwork().getOutput()));
-
- getNeuralNetwork().save("data/networks/Pass2.nn");
-
- testNetwork(getNeuralNetwork(), passData.getTrainingSet());
- }
-
- public GameRecord parseSGF(File sgfFile) throws IOException,
- RecognitionException {
- FileInputStream sgfInputStream;
-
- sgfInputStream = new FileInputStream(sgfFile);
- return Referee.replay(sgfInputStream);
- }
-
- @Override
- public NeuralNetwork getNeuralNetwork() {
- return neuralNetwork;
- }
-
- @Override
- public void learn(TrainingSet<SupervisedTrainingElement> trainingSet) {
- this.neuralNetwork.learn(trainingSet);
- }
-
- @Override
- public void reset() {
- this.neuralNetwork = new MultiLayerPerceptron(
- TransferFunctionType.TANH, 2, 3, 2);
- }
-
- @Override
- public void setNeuralNetwork(NeuralNetwork neuralNetwork) {
- this.neuralNetwork = neuralNetwork;
- }
-
- private void testNetwork(NeuralNetwork nnet, TrainingSet<SupervisedTrainingElement> trainingSet) {
- for (SupervisedTrainingElement trainingElement : trainingSet.elements()) {
-
- nnet.setInput(trainingElement.getInput());
- nnet.calculate();
- double[] networkOutput = nnet.getOutput();
- System.out.print("Input: "
- + Arrays.toString(trainingElement.getInput()));
- System.out.println(" Output: " + Arrays.toString(networkOutput));
-
- }
-}
-}
diff --git a/src/net/woodyfolsom/msproj/ann/TemporalDifferenceLearning.java b/src/net/woodyfolsom/msproj/ann/TemporalDifferenceLearning.java
new file mode 100644
index 0000000..fc3018b
--- /dev/null
+++ b/src/net/woodyfolsom/msproj/ann/TemporalDifferenceLearning.java
@@ -0,0 +1,484 @@
+package net.woodyfolsom.msproj.ann;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import org.encog.EncogError;
+import org.encog.engine.network.activation.ActivationFunction;
+import org.encog.engine.network.activation.ActivationSigmoid;
+import org.encog.mathutil.IntRange;
+import org.encog.ml.MLMethod;
+import org.encog.ml.TrainingImplementationType;
+import org.encog.ml.data.MLDataPair;
+import org.encog.ml.data.MLDataSet;
+import org.encog.ml.train.MLTrain;
+import org.encog.ml.train.strategy.Strategy;
+import org.encog.ml.train.strategy.end.EndTrainingStrategy;
+import org.encog.neural.error.ErrorFunction;
+import org.encog.neural.error.LinearErrorFunction;
+import org.encog.neural.flat.FlatNetwork;
+import org.encog.neural.networks.ContainsFlat;
+import org.encog.neural.networks.training.LearningRate;
+import org.encog.neural.networks.training.Momentum;
+import org.encog.neural.networks.training.Train;
+import org.encog.neural.networks.training.TrainingError;
+import org.encog.neural.networks.training.propagation.TrainingContinuation;
+import org.encog.neural.networks.training.propagation.back.Backpropagation;
+import org.encog.neural.networks.training.strategy.SmartLearningRate;
+import org.encog.neural.networks.training.strategy.SmartMomentum;
+import org.encog.util.EncogValidate;
+import org.encog.util.EngineArray;
+import org.encog.util.concurrency.DetermineWorkload;
+import org.encog.util.concurrency.EngineConcurrency;
+import org.encog.util.concurrency.MultiThreadable;
+import org.encog.util.concurrency.TaskGroup;
+import org.encog.util.logging.EncogLogging;
+
+/**
+ * This class started as a verbatim copy of BackPropagation from the open-source
+ * Encog framework. It was merged with its super-classes to access protected
+ * fields without resorting to reflection.
+ */
+public class TemporalDifferenceLearning implements MLTrain, Momentum,
+ LearningRate, Train, MultiThreadable {
+ // New fields for TD(lambda)
+ private final double lambda;
+ // end new fields
+
+ // BackProp
+ public static final String LAST_DELTA = "LAST_DELTA";
+ private double learningRate;
+ private double momentum;
+ private double[] lastDelta;
+ // End BackProp
+
+ // Propagation
+ private FlatNetwork currentFlatNetwork;
+ private int numThreads;
+ protected double[] gradients;
+ private double[] lastGradient;
+ protected ContainsFlat network;
+ // private MLDataSet indexable;
+ private Set<List<MLDataPair>> indexable;
+ private GradientWorker[] workers;
+ private double totalError;
+ protected double lastError;
+ private Throwable reportedException;
+ private double[] flatSpot;
+ private boolean shouldFixFlatSpot;
+ private ErrorFunction ef = new LinearErrorFunction();
+ // End Propagation
+
+ // BasicTraining
+ private final List<Strategy> strategies = new ArrayList<Strategy>();
+ private Set<List<MLDataPair>> training;
+ private double error;
+ private int iteration;
+ private TrainingImplementationType implementationType;
+
+ // End BasicTraining
+
+ public TemporalDifferenceLearning(final ContainsFlat network,
+ final Set<List<MLDataPair>> training, double lambda) {
+ this(network, training, 0, 0, lambda);
+ addStrategy(new SmartLearningRate());
+ addStrategy(new SmartMomentum());
+ }
+
+ public TemporalDifferenceLearning(final ContainsFlat network,
+ Set<List<MLDataPair>> training, final double theLearnRate,
+ final double theMomentum, double lambda) {
+ initPropagation(network, training);
+ // TODO consider how to re-implement validation
+ // ValidateNetwork.validateMethodToData(network, training);
+ this.momentum = theMomentum;
+ this.learningRate = theLearnRate;
+ this.lastDelta = new double[network.getFlat().getWeights().length];
+ this.lambda = lambda;
+ }
+
+ private void initPropagation(final ContainsFlat network,
+ final Set<List<MLDataPair>> training) {
+ initBasicTraining(TrainingImplementationType.Iterative);
+ this.network = network;
+ this.currentFlatNetwork = network.getFlat();
+ setTraining(training);
+
+ this.gradients = new double[this.currentFlatNetwork.getWeights().length];
+ this.lastGradient = new double[this.currentFlatNetwork.getWeights().length];
+
+ this.indexable = training;
+ this.numThreads = 0;
+ this.reportedException = null;
+ this.shouldFixFlatSpot = true;
+ }
+
+ private void initBasicTraining(TrainingImplementationType implementationType) {
+ this.implementationType = implementationType;
+ }
+
+ // Methods from BackPropagation
+ @Override
+ public boolean canContinue() {
+ return false;
+ }
+
+ public double[] getLastDelta() {
+ return this.lastDelta;
+ }
+
+ @Override
+ public double getLearningRate() {
+ return this.learningRate;
+ }
+
+ @Override
+ public double getMomentum() {
+ return this.momentum;
+ }
+
+ public boolean isValidResume(final TrainingContinuation state) {
+ if (!state.getContents().containsKey(Backpropagation.LAST_DELTA)) {
+ return false;
+ }
+
+ if (!state.getTrainingType().equals(getClass().getSimpleName())) {
+ return false;
+ }
+
+ final double[] d = (double[]) state.get(Backpropagation.LAST_DELTA);
+ return d.length == ((ContainsFlat) getMethod()).getFlat().getWeights().length;
+ }
+
+ @Override
+ public TrainingContinuation pause() {
+ final TrainingContinuation result = new TrainingContinuation();
+ result.setTrainingType(this.getClass().getSimpleName());
+ result.set(Backpropagation.LAST_DELTA, this.lastDelta);
+ return result;
+ }
+
+ @Override
+ public void resume(final TrainingContinuation state) {
+ if (!isValidResume(state)) {
+ throw new TrainingError("Invalid training resume data length");
+ }
+
+ this.lastDelta = ((double[]) state.get(Backpropagation.LAST_DELTA));
+ }
+
+ @Override
+ public void setLearningRate(final double rate) {
+ this.learningRate = rate;
+ }
+
+ @Override
+ public void setMomentum(final double m) {
+ this.momentum = m;
+ }
+
+ public double updateWeight(final double[] gradients,
+ final double[] lastGradient, final int index) {
+ final double delta = (gradients[index] * this.learningRate)
+ + (this.lastDelta[index] * this.momentum);
+ this.lastDelta[index] = delta;
+
+ System.out.println("Updating weights for connection: " + index
+ + " with lambda: " + lambda);
+
+ return delta;
+ }
+
+ public void initOthers() {
+ }
+
+ // End methods from BackPropagation
+
+ // Methods from Propagation
+ public void finishTraining() {
+ basicFinishTraining();
+ }
+
+ public FlatNetwork getCurrentFlatNetwork() {
+ return this.currentFlatNetwork;
+ }
+
+ public MLMethod getMethod() {
+ return this.network;
+ }
+
+ public void iteration() {
+ iteration(1);
+ }
+
+ public void rollIteration() {
+ this.iteration++;
+ }
+
+ public void iteration(final int count) {
+
+ try {
+ for (int i = 0; i < count; i++) {
+
+ preIteration();
+
+ rollIteration();
+
+ calculateGradients();
+
+ if (this.currentFlatNetwork.isLimited()) {
+ learnLimited();
+ } else {
+ learn();
+ }
+
+ this.lastError = this.getError();
+
+ for (final GradientWorker worker : this.workers) {
+ EngineArray.arrayCopy(this.currentFlatNetwork.getWeights(),
+ 0, worker.getWeights(), 0,
+ this.currentFlatNetwork.getWeights().length);
+ }
+
+ if (this.currentFlatNetwork.getHasContext()) {
+ copyContexts();
+ }
+
+ if (this.reportedException != null) {
+ throw (new EncogError(this.reportedException));
+ }
+
+ postIteration();
+
+ EncogLogging.log(EncogLogging.LEVEL_INFO,
+ "Training iteration done, error: " + getError());
+
+ }
+ } catch (final ArrayIndexOutOfBoundsException ex) {
+ EncogValidate.validateNetworkForTraining(this.network,
+ getTraining());
+ throw new EncogError(ex);
+ }
+ }
+
+ public void setThreadCount(final int numThreads) {
+ this.numThreads = numThreads;
+ }
+
+ @Override
+ public int getThreadCount() {
+ return this.numThreads;
+ }
+
+ public void fixFlatSpot(boolean b) {
+ this.shouldFixFlatSpot = b;
+ }
+
+ public void setErrorFunction(ErrorFunction ef) {
+ this.ef = ef;
+ }
+
+ public void calculateGradients() {
+ if (this.workers == null) {
+ init();
+ }
+
+ if (this.currentFlatNetwork.getHasContext()) {
+ this.workers[0].getNetwork().clearContext();
+ }
+
+ this.totalError = 0;
+
+ if (this.workers.length > 1) {
+
+ final TaskGroup group = EngineConcurrency.getInstance()
+ .createTaskGroup();
+
+ for (final GradientWorker worker : this.workers) {
+ EngineConcurrency.getInstance().processTask(worker, group);
+ }
+
+ group.waitForComplete();
+ } else {
+ this.workers[0].run();
+ }
+
+ this.setError(this.totalError / this.workers.length);
+
+ }
+
+ /**
+ * Copy the contexts to keep them consistent with multithreaded training.
+ */
+ private void copyContexts() {
+
+ // copy the contexts (layer output) from each group to the next group
+ for (int i = 0; i < (this.workers.length - 1); i++) {
+ final double[] src = this.workers[i].getNetwork().getLayerOutput();
+ final double[] dst = this.workers[i + 1].getNetwork()
+ .getLayerOutput();
+ EngineArray.arrayCopy(src, dst);
+ }
+
+ // copy the contexts from the final group to the real network
+ EngineArray.arrayCopy(this.workers[this.workers.length - 1]
+ .getNetwork().getLayerOutput(), this.currentFlatNetwork
+ .getLayerOutput());
+ }
+
+ private void init() {
+ // fix flat spot, if needed
+ this.flatSpot = new double[this.currentFlatNetwork
+ .getActivationFunctions().length];
+
+ if (this.shouldFixFlatSpot) {
+ for (int i = 0; i < this.currentFlatNetwork
+ .getActivationFunctions().length; i++) {
+ final ActivationFunction af = this.currentFlatNetwork
+ .getActivationFunctions()[i];
+
+ if (af instanceof ActivationSigmoid) {
+ this.flatSpot[i] = 0.1;
+ } else {
+ this.flatSpot[i] = 0.0;
+ }
+ }
+ } else {
+ EngineArray.fill(this.flatSpot, 0.0);
+ }
+
+ // setup workers
+ final DetermineWorkload determine = new DetermineWorkload(
+ this.numThreads, (int) this.indexable.size());
+ // this.numThreads, (int) this.indexable.getRecordCount());
+
+ this.workers = new GradientWorker[determine.getThreadCount()];
+
+ int index = 0;
+
+ // handle CPU
+ for (final IntRange r : determine.calculateWorkers()) {
+ this.workers[index++] = new GradientWorker(
+ this.currentFlatNetwork.clone(), this, new HashSet<List<MLDataPair>>(
+ this.indexable), r.getLow(), r.getHigh(),
+ this.flatSpot, this.ef);
+ }
+
+ initOthers();
+ }
+
+ public void report(final double[] gradients, final double error,
+ final Throwable ex) {
+ synchronized (this) {
+ if (ex == null) {
+
+ for (int i = 0; i < gradients.length; i++) {
+ this.gradients[i] += gradients[i];
+ }
+ this.totalError += error;
+ } else {
+ this.reportedException = ex;
+ }
+ }
+ }
+
+ protected void learn() {
+ final double[] weights = this.currentFlatNetwork.getWeights();
+ for (int i = 0; i < this.gradients.length; i++) {
+ weights[i] += updateWeight(this.gradients, this.lastGradient, i);
+ this.gradients[i] = 0;
+ }
+ }
+
+ protected void learnLimited() {
+ final double limit = this.currentFlatNetwork.getConnectionLimit();
+ final double[] weights = this.currentFlatNetwork.getWeights();
+ for (int i = 0; i < this.gradients.length; i++) {
+ if (Math.abs(weights[i]) < limit) {
+ weights[i] = 0;
+ } else {
+ weights[i] += updateWeight(this.gradients, this.lastGradient, i);
+ }
+ this.gradients[i] = 0;
+ }
+ }
+
+ public double[] getLastGradient() {
+ return lastGradient;
+ }
+
+ // End methods from Propagation
+
+ // Methods from BasicTraining/
+ public void addStrategy(final Strategy strategy) {
+ strategy.init(this);
+ this.strategies.add(strategy);
+ }
+
+ public void basicFinishTraining() {
+ }
+
+ public double getError() {
+ return this.error;
+ }
+
+ public int getIteration() {
+ return this.iteration;
+ }
+
+ public List<Strategy> getStrategies() {
+ return this.strategies;
+ }
+
+ public MLDataSet getTraining() {
+ throw new UnsupportedOperationException(
+ "This learning method operates on Set>, not MLDataSet");
+ }
+
+ public boolean isTrainingDone() {
+ for (Strategy strategy : this.strategies) {
+ if (strategy instanceof EndTrainingStrategy) {
+ EndTrainingStrategy end = (EndTrainingStrategy) strategy;
+ if (end.shouldStop()) {
+ return true;
+ }
+ }
+ }
+
+ return false;
+ }
+
+ public void postIteration() {
+ for (final Strategy strategy : this.strategies) {
+ strategy.postIteration();
+ }
+ }
+
+ public void preIteration() {
+
+ this.iteration++;
+
+ for (final Strategy strategy : this.strategies) {
+ strategy.preIteration();
+ }
+ }
+
+ public void setError(final double error) {
+ this.error = error;
+ }
+
+ public void setIteration(final int iteration) {
+ this.iteration = iteration;
+ }
+
+ public void setTraining(final Set<List<MLDataPair>> training) {
+ this.training = training;
+ }
+
+ public TrainingImplementationType getImplementationType() {
+ return this.implementationType;
+ }
+ // End Methods from BasicTraining
+}
\ No newline at end of file
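As committed, lambda is stored and echoed by updateWeight() but never changes the arithmetic: the update is still plain backpropagation with momentum. A TD(lambda) version would typically keep a per-weight eligibility trace. The sketch below illustrates that standard scheme, not this project's confirmed design; the trace array and the gamma discount are assumed additions:

    // Assumed additions: double[] trace, same length as weights and zeroed
    // at the start of each episode, plus a discount factor gamma in [0, 1].
    public double updateWeight(final double[] gradients,
            final double[] lastGradient, final int index) {
        // decay the old trace, then fold in this step's gradient
        trace[index] = gamma * lambda * trace[index] + gradients[index];
        // apply the learning rate through the trace rather than the raw gradient
        final double delta = (this.learningRate * trace[index])
                + (this.lastDelta[index] * this.momentum);
        this.lastDelta[index] = delta;
        return delta;
    }

With lambda = 0 the trace collapses to the current gradient and the method reduces to the existing backpropagation update.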
diff --git a/src/net/woodyfolsom/msproj/ann/WinFilter.java b/src/net/woodyfolsom/msproj/ann/WinFilter.java
new file mode 100644
index 0000000..f2671f6
--- /dev/null
+++ b/src/net/woodyfolsom/msproj/ann/WinFilter.java
@@ -0,0 +1,112 @@
+package net.woodyfolsom.msproj.ann;
+
+import java.util.List;
+import java.util.Set;
+
+import net.woodyfolsom.msproj.GameState;
+import net.woodyfolsom.msproj.Player;
+
+import org.encog.engine.network.activation.ActivationSigmoid;
+import org.encog.ml.data.MLData;
+import org.encog.ml.data.MLDataPair;
+import org.encog.ml.data.MLDataSet;
+import org.encog.ml.train.MLTrain;
+import org.encog.neural.networks.BasicNetwork;
+import org.encog.neural.networks.layers.BasicLayer;
+
+public class WinFilter extends AbstractNeuralNetFilter implements
+ NeuralNetFilter {
+
+ public WinFilter() {
+ // create a neural network, without using a factory
+ BasicNetwork network = new BasicNetwork();
+ network.addLayer(new BasicLayer(null, false, 2));
+ network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 4));
+ network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 2));
+ network.getStructure().finalizeStructure();
+ network.reset();
+
+ this.neuralNetwork = network;
+ }
+
+ @Override
+ public double computeValue(MLData input) {
+ if (input instanceof GameStateMLData) {
+ double[] idealVector = computeVector(input);
+ GameState gameState = ((GameStateMLData) input).getGameState();
+ Player playerToMove = gameState.getPlayerToMove();
+ if (playerToMove == Player.BLACK) {
+ return idealVector[0];
+ } else if (playerToMove == Player.WHITE) {
+ return idealVector[1];
+ } else {
+ throw new RuntimeException("Invalid GameState.playerToMove: "
+ + playerToMove);
+ }
+ } else {
+ throw new UnsupportedOperationException(
+ "This NeuralNetFilter only accepts GameStates as input.");
+ }
+ }
+
+ @Override
+ public double[] computeVector(MLData input) {
+ if (input instanceof GameStateMLData) {
+ return neuralNetwork.compute(input).getData();
+ } else {
+ throw new UnsupportedOperationException(
+ "This NeuralNetFilter only accepts GameStates as input.");
+ }
+ }
+
+ @Override
+ public void learn(MLDataSet trainingData) {
+ throw new UnsupportedOperationException("This filter learns a Set>, not an MLDataSet");
+ }
+
+ @Override
+ public void learn(Set<List<MLDataPair>> trainingSet) {
+
+ // train the neural network
+ final MLTrain train = new TemporalDifferenceLearning(neuralNetwork,
+ trainingSet, 0.7, 0.8, 0.25);
+
+ actualTrainingEpochs = 0;
+
+ do {
+ train.iteration();
+ System.out.println("Epoch #" + actualTrainingEpochs + " Error:"
+ + train.getError());
+ actualTrainingEpochs++;
+ } while (train.getError() > 0.01
+ && actualTrainingEpochs <= maxTrainingEpochs);
+ }
+
+ @Override
+ public void reset() {
+ neuralNetwork.reset();
+ }
+
+ @Override
+ public void reset(int seed) {
+ neuralNetwork.reset(seed);
+ }
+
+ @Override
+ public BasicNetwork getNeuralNetwork() {
+ // defer to the network built in the constructor; returning null here
+ // would break save() in the abstract base class
+ return neuralNetwork;
+ }
+
+ @Override
+ public int getInputSize() {
+ return 2; // normalized black and white scores
+ }
+
+ @Override
+ public int getOutputSize() {
+ return 2; // black-win and white-win strengths
+ }
+}
\ No newline at end of file
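A hypothetical training call for the filter above, matching the Set<List<MLDataPair>> shape that NeuralNetFilter.learn expects (finalState stands in for some finished GameState; the output path is made up):

    Set<List<MLDataPair>> sequences = new HashSet<List<MLDataPair>>();
    List<MLDataPair> game = new ArrayList<MLDataPair>();
    game.add(new GameStateMLDataPair(finalState)); // terminal position of one game
    sequences.add(game);

    WinFilter filter = new WinFilter();
    filter.setMaxTrainingEpochs(500);
    filter.learn(sequences);
    filter.save("data/networks/Win.nn"); // hypothetical file name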
diff --git a/src/net/woodyfolsom/msproj/ann/XORFilter.java b/src/net/woodyfolsom/msproj/ann/XORFilter.java
new file mode 100644
index 0000000..e6be979
--- /dev/null
+++ b/src/net/woodyfolsom/msproj/ann/XORFilter.java
@@ -0,0 +1,83 @@
+package net.woodyfolsom.msproj.ann;
+
+import java.util.List;
+import java.util.Set;
+
+import org.encog.engine.network.activation.ActivationSigmoid;
+import org.encog.ml.data.MLData;
+import org.encog.ml.data.MLDataPair;
+import org.encog.ml.data.MLDataSet;
+import org.encog.ml.data.basic.BasicMLDataSet;
+import org.encog.ml.train.MLTrain;
+import org.encog.neural.networks.BasicNetwork;
+import org.encog.neural.networks.layers.BasicLayer;
+import org.encog.neural.networks.training.propagation.back.Backpropagation;
+
+/**
+ * Based on sample code from http://neuroph.sourceforge.net
+ *
+ * @author Woody
+ *
+ */
+public class XORFilter extends AbstractNeuralNetFilter implements
+ NeuralNetFilter {
+
+ public XORFilter() {
+ // create a neural network, without using a factory
+ BasicNetwork network = new BasicNetwork();
+ network.addLayer(new BasicLayer(null, false, 2));
+ network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 3));
+ network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 1));
+ network.getStructure().finalizeStructure();
+ network.reset();
+
+ this.neuralNetwork = network;
+ }
+
+ @Override
+ public void learn(MLDataSet trainingSet) {
+
+ // train the neural network
+ final MLTrain train = new Backpropagation(neuralNetwork,
+ trainingSet, 0.7, 0.8);
+
+ actualTrainingEpochs = 0;
+
+ do {
+ train.iteration();
+ System.out.println("Epoch #" + actualTrainingEpochs + " Error:"
+ + train.getError());
+ actualTrainingEpochs++;
+ } while (train.getError() > 0.01
+ && actualTrainingEpochs <= maxTrainingEpochs);
+ }
+
+ @Override
+ public double[] computeVector(MLData mlData) {
+ MLDataSet dataset = new BasicMLDataSet(new double[][] { mlData.getData() },
+ new double[][] { new double[getOutputSize()] });
+ MLData output = neuralNetwork.compute(dataset.get(0).getInput());
+ return output.getData();
+ }
+
+ @Override
+ public int getInputSize() {
+ return 2;
+ }
+
+ @Override
+ public int getOutputSize() {
+ return 1; // a single XOR output unit
+ }
+
+ @Override
+ public double computeValue(MLData input) {
+ return computeVector(input)[0];
+ }
+
+ @Override
+ public void learn(Set<List<MLDataPair>> trainingSet) {
+ throw new UnsupportedOperationException("This Filter learns an MLDataSet, not a Set<List<MLDataPair>>.");
+ }
+}
\ No newline at end of file
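The classic smoke test for the filter above, using Encog's BasicMLDataSet (already imported by the class) and BasicMLData:

    double[][] in  = { { 0, 0 }, { 0, 1 }, { 1, 0 }, { 1, 1 } };
    double[][] out = { { 0 }, { 1 }, { 1 }, { 0 } };

    XORFilter xor = new XORFilter();
    xor.learn(new BasicMLDataSet(in, out));
    // a trained net should print a value close to 1.0 here
    System.out.println(xor.computeValue(new BasicMLData(new double[] { 1, 0 })));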
diff --git a/src/net/woodyfolsom/msproj/ann/XORLearner.java b/src/net/woodyfolsom/msproj/ann/XORLearner.java
deleted file mode 100644
index 64fcde9..0000000
--- a/src/net/woodyfolsom/msproj/ann/XORLearner.java
+++ /dev/null
@@ -1,42 +0,0 @@
-package net.woodyfolsom.msproj.ann;
-
-import org.neuroph.core.NeuralNetwork;
-import org.neuroph.core.learning.SupervisedTrainingElement;
-import org.neuroph.core.learning.TrainingSet;
-import org.neuroph.nnet.MultiLayerPerceptron;
-import org.neuroph.util.TransferFunctionType;
-
-/**
- * Based on sample code from http://neuroph.sourceforge.net
- *
- * @author Woody
- *
- */
-public class XORLearner implements NeuralNetLearner {
- private NeuralNetwork neuralNetwork;
-
- public XORLearner() {
- reset();
- }
-
- @Override
- public NeuralNetwork getNeuralNetwork() {
- return neuralNetwork;
- }
-
- @Override
- public void learn(TrainingSet<SupervisedTrainingElement> trainingSet) {
- this.neuralNetwork.learn(trainingSet);
- }
-
- @Override
- public void reset() {
- this.neuralNetwork = new MultiLayerPerceptron(
- TransferFunctionType.TANH, 2, 3, 1);
- }
-
- @Override
- public void setNeuralNetwork(NeuralNetwork neuralNetwork) {
- this.neuralNetwork = neuralNetwork;
- }
-}
\ No newline at end of file
diff --git a/src/net/woodyfolsom/msproj/gui/Goban.java b/src/net/woodyfolsom/msproj/gui/Goban.java
index 54041cb..e206977 100644
--- a/src/net/woodyfolsom/msproj/gui/Goban.java
+++ b/src/net/woodyfolsom/msproj/gui/Goban.java
@@ -5,6 +5,7 @@ import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.WindowAdapter;
import java.awt.event.WindowEvent;
+import java.util.concurrent.atomic.AtomicInteger;
import javax.swing.JButton;
import javax.swing.JFrame;
@@ -14,22 +15,31 @@ import net.woodyfolsom.msproj.Action;
import net.woodyfolsom.msproj.GameConfig;
import net.woodyfolsom.msproj.GameState;
import net.woodyfolsom.msproj.Player;
+import net.woodyfolsom.msproj.StandAloneGame;
import net.woodyfolsom.msproj.sfx.SfxPlayer;
public class Goban extends JFrame {
private static final long serialVersionUID = 1L;
+ private static AtomicInteger openWindows = new AtomicInteger(0);
+
private GridPanel gridPanel;
private SfxPlayer sfxPlayer;
- public Goban(GameConfig gameConfig, Player guiPlayer) {
+ public Goban(GameConfig gameConfig, Player guiPlayer, String gameName) {
+ super(gameName);
+
setLayout(new BorderLayout());
addWindowListener(new WindowAdapter() {
@Override
public void windowClosing(WindowEvent e) {
sfxPlayer.cleanup();
+ int windowsLeftOpen = openWindows.addAndGet(-1);
+ if (windowsLeftOpen < 1) {
+ System.exit(StandAloneGame.EXIT_USER_QUIT);
+ }
}
});
@@ -42,9 +52,6 @@ public class Goban extends JFrame {
this.gridPanel = new GridPanel(gameConfig, guiPlayer, sfxPlayer);
add(gridPanel,BorderLayout.CENTER);
- setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
- setVisible(true);
-
JButton passBtn = new JButton("Pass");
JButton resignBtn = new JButton("Resign");
@@ -67,6 +74,9 @@ public class Goban extends JFrame {
bottomPanel.add(resignBtn);
add(bottomPanel, BorderLayout.SOUTH);
+
+ setVisible(true);
+ openWindows.addAndGet(1);
pack();
}
diff --git a/src/net/woodyfolsom/msproj/gui/GridPanel.java b/src/net/woodyfolsom/msproj/gui/GridPanel.java
index 9ec3d82..08b7570 100644
--- a/src/net/woodyfolsom/msproj/gui/GridPanel.java
+++ b/src/net/woodyfolsom/msproj/gui/GridPanel.java
@@ -229,7 +229,9 @@ public class GridPanel extends JPanel implements MouseListener,
@Override
public void run() {
- sfxPlayer.play();
+ if (sfxPlayer != null) {
+ sfxPlayer.play();
+ }
}}).start();
}
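The GridPanel guard keeps the asynchronous sound-effect thread from throwing a NullPointerException when no SfxPlayer was constructed (audio disabled or headless runs). Because the field is read from a newly spawned thread, a slightly stronger form captures it into a local first, so the null check and the play() call are guaranteed to see the same reference; a sketch with a stand-in Player interface:

final class SoundTrigger {
    interface Player { void play(); }

    // Capture the possibly-null player once; the anonymous Runnable then
    // closes over the immutable local rather than re-reading a field.
    static void playAsync(Player sfxPlayer) {
        final Player player = sfxPlayer;
        if (player != null) {
            new Thread(new Runnable() {
                @Override
                public void run() {
                    player.play();
                }
            }).start();
        }
    }
}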
diff --git a/src/net/woodyfolsom/msproj/policy/MonteCarlo.java b/src/net/woodyfolsom/msproj/policy/MonteCarlo.java
index d38602f..fb4d35a 100644
--- a/src/net/woodyfolsom/msproj/policy/MonteCarlo.java
+++ b/src/net/woodyfolsom/msproj/policy/MonteCarlo.java
@@ -13,7 +13,7 @@ import net.woodyfolsom.msproj.tree.GameTreeNode;
import net.woodyfolsom.msproj.tree.MonteCarloProperties;
public abstract class MonteCarlo implements Policy {
- protected static final int ROLLOUT_DEPTH_LIMIT = 150;
+ protected static final int ROLLOUT_DEPTH_LIMIT = 250;
protected int numStateEvaluations = 0;
protected Policy movePolicy;
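Raising ROLLOUT_DEPTH_LIMIT from 150 to 250 gives random playouts more room to reach a terminal position before being cut off and scored as they stand. The loop this constant presumably bounds looks like the following sketch; the three interfaces are hypothetical stand-ins, not the project's actual GameState/Action/Policy API:

interface Action {}
interface Policy { Action chooseAction(RolloutState s); }
interface RolloutState {
    boolean isTerminal();
    RolloutState apply(Action a);
    double score(); // e.g. area difference adjusted by komi
}

final class Rollout {
    static final int ROLLOUT_DEPTH_LIMIT = 250;

    // Play policy moves until the game ends or the depth cap is hit,
    // then score whatever position was reached.
    static double rollout(RolloutState state, Policy movePolicy) {
        int depth = 0;
        while (!state.isTerminal() && depth < ROLLOUT_DEPTH_LIMIT) {
            state = state.apply(movePolicy.chooseAction(state));
            depth++;
        }
        return state.score();
    }
}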
diff --git a/src/net/woodyfolsom/msproj/policy/RootParallelization.java b/src/net/woodyfolsom/msproj/policy/RootParallelization.java
index b648df1..c9c9e27 100644
--- a/src/net/woodyfolsom/msproj/policy/RootParallelization.java
+++ b/src/net/woodyfolsom/msproj/policy/RootParallelization.java
@@ -101,8 +101,8 @@ public class RootParallelization implements Policy {
System.out.println("It won "
+ bestWins + " out of " + bestSims
+ " rollouts among " + totalRollouts
- + " total rollouts (" + totalReward.keySet()
- + " possible actions) from the current state.");
+ + " total rollouts (" + totalReward.size()
+ + " possible moves evaluated) from the current state.");
return bestAction;
}
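The corrected log line prints totalReward.size(), the number of distinct root moves that received rollouts, rather than dumping the whole key set. In root parallelization, each worker searches an independent tree from the same root and the per-move totals are summed before the final pick; a sketch of that selection step, assuming each map value holds {wins, simulations} counts for one move:

import java.util.Map;

final class RootMerge {
    // Scan the merged per-move totals for the move with the most
    // winning rollouts; value[0] = wins, value[1] = simulations.
    static <A> A pickBest(Map<A, int[]> totalReward) {
        A bestAction = null;
        int bestWins = -1;
        for (Map.Entry<A, int[]> entry : totalReward.entrySet()) {
            if (entry.getValue()[0] > bestWins) {
                bestWins = entry.getValue()[0];
                bestAction = entry.getKey();
            }
        }
        return bestAction;
    }
}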
diff --git a/test/net/woodyfolsom/msproj/ann/NeuralNetLearnerTest.java b/test/net/woodyfolsom/msproj/ann/NeuralNetLearnerTest.java
deleted file mode 100644
index aa4d260..0000000
--- a/test/net/woodyfolsom/msproj/ann/NeuralNetLearnerTest.java
+++ /dev/null
@@ -1,86 +0,0 @@
-package net.woodyfolsom.msproj.ann;
-
-import java.io.File;
-import java.util.Arrays;
-
-import org.junit.AfterClass;
-import org.junit.BeforeClass;
-import org.junit.Test;
-import org.neuroph.core.NeuralNetwork;
-import org.neuroph.core.learning.SupervisedTrainingElement;
-import org.neuroph.core.learning.TrainingSet;
-
-public class NeuralNetLearnerTest {
- private static final String FILENAME = "myMlPerceptron.nnet";
-
- @AfterClass
- public static void deleteNewNet() {
- File file = new File(FILENAME);
- if (file.exists()) {
- file.delete();
- }
- }
-
- @BeforeClass
- public static void deleteSavedNet() {
- File file = new File(FILENAME);
- if (file.exists()) {
- file.delete();
- }
- }
-
- @Test
- public void testLearnSaveLoad() {
- NeuralNetLearner nnLearner = new XORLearner();
-
- // create training set (logical XOR function)
- TrainingSet trainingSet = new TrainingSet(
- 2, 1);
- for (int x = 0; x < 1000; x++) {
- trainingSet.addElement(new SupervisedTrainingElement(new double[] { 0,
- 0 }, new double[] { 0 }));
- trainingSet.addElement(new SupervisedTrainingElement(new double[] { 0,
- 1 }, new double[] { 1 }));
- trainingSet.addElement(new SupervisedTrainingElement(new double[] { 1,
- 0 }, new double[] { 1 }));
- trainingSet.addElement(new SupervisedTrainingElement(new double[] { 1,
- 1 }, new double[] { 0 }));
- }
-
- nnLearner.learn(trainingSet);
- NeuralNetwork nnet = nnLearner.getNeuralNetwork();
-
- TrainingSet valSet = new TrainingSet(
- 2, 1);
- valSet.addElement(new SupervisedTrainingElement(new double[] { 0,
- 0 }, new double[] { 0 }));
- valSet.addElement(new SupervisedTrainingElement(new double[] { 0,
- 1 }, new double[] { 1 }));
- valSet.addElement(new SupervisedTrainingElement(new double[] { 1,
- 0 }, new double[] { 1 }));
- valSet.addElement(new SupervisedTrainingElement(new double[] { 1,
- 1 }, new double[] { 0 }));
-
- System.out.println("Output from eval set (learned network):");
- testNetwork(nnet, valSet);
-
- nnet.save(FILENAME);
- nnet = NeuralNetwork.load(FILENAME);
-
- System.out.println("Output from eval set (learned network):");
- testNetwork(nnet, valSet);
- }
-
- private void testNetwork(NeuralNetwork nnet, TrainingSet trainingSet) {
- for (SupervisedTrainingElement trainingElement : trainingSet.elements()) {
-
- nnet.setInput(trainingElement.getInput());
- nnet.calculate();
- double[] networkOutput = nnet.getOutput();
- System.out.print("Input: "
- + Arrays.toString(trainingElement.getInput()));
- System.out.println(" Output: " + Arrays.toString(networkOutput));
-
- }
- }
-}
\ No newline at end of file
diff --git a/test/net/woodyfolsom/msproj/ann/PassNetworkTest.java b/test/net/woodyfolsom/msproj/ann/PassNetworkTest.java
deleted file mode 100644
index 134a8ac..0000000
--- a/test/net/woodyfolsom/msproj/ann/PassNetworkTest.java
+++ /dev/null
@@ -1,51 +0,0 @@
-package net.woodyfolsom.msproj.ann;
-
-import static org.junit.Assert.assertTrue;
-
-import org.junit.Test;
-import org.neuroph.core.NeuralNetwork;
-
-public class PassNetworkTest {
-
- @Test
- public void testSavedNetwork1() {
- NeuralNetwork passFilter = NeuralNetwork.load("data/networks/Pass1.nn");
- passFilter.setInput(0.75,0.25);
- passFilter.calculate();
-
- PassData passData = new PassData();
- double[] output = passFilter.getOutput();
- System.out.println("Output: " + passData.getOutput(output));
-
- assertTrue(output[0] > 0.50);
- assertTrue(output[1] < 0.50);
-
- passFilter.setInput(0.25,0.50);
- passFilter.calculate();
- output = passFilter.getOutput();
- System.out.println("Output: " + passData.getOutput(output));
- assertTrue(output[0] < 0.50);
- assertTrue(output[1] > 0.50);
- }
-
- @Test
- public void testSavedNetwork2() {
- NeuralNetwork passFilter = NeuralNetwork.load("data/networks/Pass2.nn");
- passFilter.setInput(0.75,0.25);
- passFilter.calculate();
-
- PassData passData = new PassData();
- double[] output = passFilter.getOutput();
- System.out.println("Output: " + passData.getOutput(output));
-
- assertTrue(output[0] > 0.50);
- assertTrue(output[1] < 0.50);
-
- passFilter.setInput(0.45,0.55);
- passFilter.calculate();
- output = passFilter.getOutput();
- System.out.println("Output: " + passData.getOutput(output));
- assertTrue(output[0] < 0.50);
- assertTrue(output[1] > 0.50);
- }
-}
diff --git a/test/net/woodyfolsom/msproj/ann/WinFilterTest.java b/test/net/woodyfolsom/msproj/ann/WinFilterTest.java
new file mode 100644
index 0000000..52eaf09
--- /dev/null
+++ b/test/net/woodyfolsom/msproj/ann/WinFilterTest.java
@@ -0,0 +1,66 @@
+package net.woodyfolsom.msproj.ann;
+
+import java.io.File;
+import java.io.FileFilter;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import net.woodyfolsom.msproj.GameRecord;
+import net.woodyfolsom.msproj.Referee;
+
+import org.antlr.runtime.RecognitionException;
+import org.encog.ml.data.MLData;
+import org.encog.ml.data.MLDataPair;
+import org.junit.Test;
+
+public class WinFilterTest {
+
+ @Test
+ public void testLearnSaveLoad() throws IOException, RecognitionException {
+ File[] sgfFiles = new File("data/games/random_vs_random")
+ .listFiles(new FileFilter() {
+ @Override
+ public boolean accept(File pathname) {
+ return pathname.getName().endsWith(".sgf");
+ }
+ });
+
+ Set<List<MLDataPair>> trainingData = new HashSet<List<MLDataPair>>();
+
+ for (File file : sgfFiles) {
+ FileInputStream fis = new FileInputStream(file);
+ GameRecord gameRecord = Referee.replay(fis);
+
+ List<MLDataPair> gameData = new ArrayList<MLDataPair>();
+ for (int i = 0; i <= gameRecord.getNumTurns(); i++) {
+ gameData.add(new GameStateMLDataPair(gameRecord.getGameState(i)));
+ }
+
+ trainingData.add(gameData);
+
+ fis.close();
+ }
+
+ WinFilter winFilter = new WinFilter();
+
+ winFilter.learn(trainingData);
+
+ for (List<MLDataPair> trainingSequence : trainingData) {
+ // Only the first and last states of each game are printed; intermediate states are skipped.
+ for (int stateIndex = 0; stateIndex < trainingSequence.size(); stateIndex++) {
+ if (stateIndex > 0 && stateIndex < trainingSequence.size()-1) {
+ continue;
+ }
+ MLData input = trainingSequence.get(stateIndex).getInput();
+
+ System.out.println("Turn " + stateIndex + ": " + input + " => "
+ + winFilter.computeValue(input));
+ }
+ }
+ }
+}
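WinFilterTest hands WinFilter a Set<List<MLDataPair>>, one list per replayed SGF game; this is the overload that the MLDataSet-only filter above rejects with UnsupportedOperationException. If WinFilter trains a single Encog network on every position, the per-game sequences presumably get flattened into one data set first; a sketch under that assumption:

import java.util.List;
import java.util.Set;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataSet;

final class TrainingDataUtil {
    // Collapse per-game training sequences into the flat MLDataSet
    // that Encog's propagation trainers consume.
    static MLDataSet flatten(Set<List<MLDataPair>> games) {
        MLDataSet all = new BasicMLDataSet();
        for (List<MLDataPair> game : games) {
            for (MLDataPair pair : game) {
                all.add(pair); // appends one input/ideal row
            }
        }
        return all;
    }
}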
diff --git a/test/net/woodyfolsom/msproj/ann/XORFilterTest.java b/test/net/woodyfolsom/msproj/ann/XORFilterTest.java
new file mode 100644
index 0000000..e0c2977
--- /dev/null
+++ b/test/net/woodyfolsom/msproj/ann/XORFilterTest.java
@@ -0,0 +1,79 @@
+package net.woodyfolsom.msproj.ann;
+
+import java.io.File;
+import java.io.IOException;
+
+import org.encog.ml.data.MLDataSet;
+import org.encog.ml.data.basic.BasicMLDataSet;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class XORFilterTest {
+ private static final String FILENAME = "xorPerceptron.net";
+
+ @AfterClass
+ public static void deleteNewNet() {
+ File file = new File(FILENAME);
+ if (file.exists()) {
+ file.delete();
+ }
+ }
+
+ @BeforeClass
+ public static void deleteSavedNet() {
+ File file = new File(FILENAME);
+ if (file.exists()) {
+ file.delete();
+ }
+ }
+
+ @Test
+ public void testLearnSaveLoad() throws IOException {
+ NeuralNetFilter nnLearner = new XORFilter();
+ System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");
+
+ // create training set (logical XOR function)
+ int size = 1;
+ double[][] trainingInput = new double[4 * size][];
+ double[][] trainingOutput = new double[4 * size][];
+ for (int i = 0; i < size; i++) {
+ trainingInput[i * 4 + 0] = new double[] { 0, 0 };
+ trainingInput[i * 4 + 1] = new double[] { 0, 1 };
+ trainingInput[i * 4 + 2] = new double[] { 1, 0 };
+ trainingInput[i * 4 + 3] = new double[] { 1, 1 };
+ trainingOutput[i * 4 + 0] = new double[] { 0 };
+ trainingOutput[i * 4 + 1] = new double[] { 1 };
+ trainingOutput[i * 4 + 2] = new double[] { 1 };
+ trainingOutput[i * 4 + 3] = new double[] { 0 };
+ }
+
+ // create training data
+ MLDataSet trainingSet = new BasicMLDataSet(trainingInput, trainingOutput);
+
+ nnLearner.learn(trainingSet);
+
+ double[][] validationSet = new double[4][2];
+
+ validationSet[0] = new double[] { 0, 0 };
+ validationSet[1] = new double[] { 0, 1 };
+ validationSet[2] = new double[] { 1, 0 };
+ validationSet[3] = new double[] { 1, 1 };
+
+ System.out.println("Output from eval set (learned network, pre-serialization):");
+ testNetwork(nnLearner, validationSet);
+
+ nnLearner.save(FILENAME);
+ nnLearner.load(FILENAME);
+
+ System.out.println("Output from eval set (learned network, post-serialization):");
+ testNetwork(nnLearner, validationSet);
+ }
+
+ private void testNetwork(NeuralNetFilter nnLearner, double[][] validationSet) {
+ for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
+ DoublePair dp = new DoublePair(validationSet[valIndex][0],validationSet[valIndex][1]);
+ System.out.println(dp + " => " + nnLearner.computeValue(dp));
+ }
+ }
+}
\ No newline at end of file
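XORFilterTest exercises the Encog-based replacement for the deleted Neuroph XORLearner. The filter's internals are not part of this patch, but the canonical Encog 3.x construction of a 2-3-1 tanh perceptron, which XORFilter plausibly wraps, looks like this sketch (the layer sizes mirror the old XORLearner; the 1% error target is an assumption):

import org.encog.engine.network.activation.ActivationTANH;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;
import org.encog.neural.networks.training.propagation.resilient.ResilientPropagation;

public class XorSketch {
    public static void main(String[] args) {
        double[][] input = { { 0, 0 }, { 0, 1 }, { 1, 0 }, { 1, 1 } };
        double[][] ideal = { { 0 }, { 1 }, { 1 }, { 0 } };
        MLDataSet trainingSet = new BasicMLDataSet(input, ideal);

        // 2 inputs, 3 hidden tanh units, 1 tanh output; biases on the
        // input and hidden layers only
        BasicNetwork network = new BasicNetwork();
        network.addLayer(new BasicLayer(null, true, 2));
        network.addLayer(new BasicLayer(new ActivationTANH(), true, 3));
        network.addLayer(new BasicLayer(new ActivationTANH(), false, 1));
        network.getStructure().finalizeStructure();
        network.reset();

        // train with resilient propagation until the error is below 1%
        ResilientPropagation train = new ResilientPropagation(network, trainingSet);
        int epoch = 0;
        do {
            train.iteration();
            epoch++;
        } while (train.getError() > 0.01);
        train.finishTraining();
        System.out.println("Trained in " + epoch + " epochs.");
    }
}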