diff --git a/.classpath b/.classpath index d25bc2e..6f91090 100644 --- a/.classpath +++ b/.classpath @@ -7,7 +7,6 @@ - - + diff --git a/build.xml b/build.xml index b641d96..19cde56 100644 --- a/build.xml +++ b/build.xml @@ -23,7 +23,7 @@ - + @@ -33,9 +33,25 @@ - + + + + + + + + + + + + + + + + + @@ -44,12 +60,12 @@ - + - - + + diff --git a/data/gogame.cfg b/data/gogame.cfg index 3a82545..bfc04b1 100644 --- a/data/gogame.cfg +++ b/data/gogame.cfg @@ -1,7 +1,10 @@ -PlayerOne=ROOT_PAR +PlayerOne=RANDOM PlayerTwo=RANDOM -GUIDelay=2000 //1 second +GUIDelay=1000 //1 second BoardSize=9 Komi=6.5 -NumGames=10 //Games for each player -TurnTime=2000 //seconds per player per turn \ No newline at end of file +NumGames=1000 //Games for each color per player +TurnTime=1000 //seconds per player per turn +SpectatorBoardShown=false; +WhiteMoveLogged=false; +BlackMoveLogged=false; \ No newline at end of file diff --git a/data/networks/Pass2.nn b/data/networks/Pass2.nn index e5ae0da..e3d9157 100644 Binary files a/data/networks/Pass2.nn and b/data/networks/Pass2.nn differ diff --git a/lib/encog-engine-2.5.0.jar b/lib/encog-engine-2.5.0.jar deleted file mode 100644 index 12d9d63..0000000 Binary files a/lib/encog-engine-2.5.0.jar and /dev/null differ diff --git a/lib/encog-java-core-javadoc.jar b/lib/encog-java-core-javadoc.jar new file mode 100644 index 0000000..2cf8d93 Binary files /dev/null and b/lib/encog-java-core-javadoc.jar differ diff --git a/lib/encog-java-core-sources.jar b/lib/encog-java-core-sources.jar new file mode 100644 index 0000000..0cb6b50 Binary files /dev/null and b/lib/encog-java-core-sources.jar differ diff --git a/lib/encog-java-core.jar b/lib/encog-java-core.jar new file mode 100644 index 0000000..f846d91 Binary files /dev/null and b/lib/encog-java-core.jar differ diff --git a/lib/neuroph-2.6.jar b/lib/neuroph-2.6.jar deleted file mode 100644 index 2009162..0000000 Binary files a/lib/neuroph-2.6.jar and /dev/null differ diff --git a/src/net/woodyfolsom/msproj/GameRecord.java b/src/net/woodyfolsom/msproj/GameRecord.java index cc5a234..8202d5e 100644 --- a/src/net/woodyfolsom/msproj/GameRecord.java +++ b/src/net/woodyfolsom/msproj/GameRecord.java @@ -23,6 +23,16 @@ public class GameRecord { moves.add(Action.NONE); } + public GameRecord(GameRecord that) { + for(GameState gameState : that.gameStates) { + gameStates.add(new GameState(gameState)); + } + //initial 'move' of Action.NONE allows for a game that starts with a board setup + for (Action action : that.moves) { + moves.add(action); + } + } + /** * Adds a comment for the current turn. 
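The GameRecord copy constructor added above deep-copies each GameState while sharing the Action references, which stays correct only as long as Action is immutable. A minimal, self-contained sketch of that copy discipline; the Move and Board classes below are illustrative stand-ins, not types from this patch:

```java
import java.util.ArrayList;
import java.util.List;

public class CopyDiscipline {
    /** Immutable value object: instances can be shared between copies. */
    static final class Move {
        final int x, y;
        Move(int x, int y) { this.x = x; this.y = y; }
    }

    /** Mutable state: must be duplicated, as GameState is in GameRecord(GameRecord). */
    static final class Board {
        final int[] stones;
        Board(int size) { stones = new int[size * size]; }
        Board(Board that) { stones = that.stones.clone(); } // deep copy
    }

    final List<Board> states = new ArrayList<Board>();
    final List<Move> moves = new ArrayList<Move>();

    CopyDiscipline() { }

    CopyDiscipline(CopyDiscipline that) {
        for (Board board : that.states) {
            states.add(new Board(board)); // mutable: copy element by element
        }
        moves.addAll(that.moves); // immutable: sharing references is safe
    }
}
```

If Action ever became mutable, the shared references would let one record's history mutate another's, and the moves loop would need the same deep-copy treatment as gameStates.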
* @param comment diff --git a/src/net/woodyfolsom/msproj/GameSettings.java b/src/net/woodyfolsom/msproj/GameSettings.java index c9b98fa..f595b33 100644 --- a/src/net/woodyfolsom/msproj/GameSettings.java +++ b/src/net/woodyfolsom/msproj/GameSettings.java @@ -13,6 +13,9 @@ public class GameSettings { private int boardSize = 9; private double komi = 6.5; private int numGames = 10; + private boolean spectatorBoardShown = false; + private boolean whiteMoveLogged = true; + private boolean blackMoveLogged = true; private GameSettings() { } @@ -49,6 +52,12 @@ public class GameSettings { gameSettings.setNumGames(Integer.parseInt(value)); } else if ("Komi".equals(name)) { gameSettings.setKomi(Double.parseDouble(value)); + } else if ("SpectatorBoardShown".equals(name)) { + gameSettings.setSpectatorBoardShown(Boolean.parseBoolean(value)); + } else if ("WhiteMoveLogged".equals(name)) { + gameSettings.setWhiteMoveLogged(Boolean.parseBoolean(value)); + } else if ("BlackMoveLogged".equals(name)) { + gameSettings.setBlackMoveLogged(Boolean.parseBoolean(value)); } else { System.out.println("Ignoring game settings property with unrecognized name: " + name); } @@ -127,4 +136,29 @@ public class GameSettings { sb.append(", GUIDelay=" + guiDelay); return sb.toString(); } + + public boolean isSpectatorBoardShown() { + return spectatorBoardShown; + } + + private void setSpectatorBoardShown(boolean spectatorBoardShown) { + this.spectatorBoardShown = spectatorBoardShown; + } + + public boolean isWhiteMoveLogged() { + return whiteMoveLogged; + } + + private void setWhiteMoveLogged(boolean whiteMoveLogged) { + this.whiteMoveLogged = whiteMoveLogged; + } + + public boolean isBlackMoveLogged() { + return blackMoveLogged; + } + + private void setBlackMoveLogged(boolean blackMoveLogged) { + this.blackMoveLogged = blackMoveLogged; + } + } \ No newline at end of file diff --git a/src/net/woodyfolsom/msproj/Referee.java b/src/net/woodyfolsom/msproj/Referee.java index c2be1e2..6299d53 100644 --- a/src/net/woodyfolsom/msproj/Referee.java +++ b/src/net/woodyfolsom/msproj/Referee.java @@ -4,8 +4,6 @@ import java.io.File; import java.io.FileOutputStream; import java.io.IOException; import java.io.InputStream; -import java.text.DateFormat; -import java.text.SimpleDateFormat; import net.woodyfolsom.msproj.gui.Goban; import net.woodyfolsom.msproj.policy.HumanGuiInput; @@ -65,10 +63,11 @@ public class Referee { return gameRecord; } - public GameResult play(GameConfig gameConfig, int gameNo) { + public GameResult play(GameConfig gameConfig, int gameNo, + boolean showSpectatorBoard, boolean logGameRecord) { GameRecord gameRecord = new GameRecord(gameConfig); - System.out.println("Game started."); + //System.out.println("Game started."); GameState initialGameState = gameRecord.getGameState(gameRecord .getNumTurns()); @@ -78,20 +77,21 @@ public class Referee { whitePolicy.setState(initialGameState); Goban spectatorBoard; - if (blackPolicy instanceof HumanGuiInput - || whitePolicy instanceof HumanGuiInput) { + if (blackPolicy instanceof HumanGuiInput || whitePolicy instanceof HumanGuiInput) { System.out.println("Human is controlling the game board GUI."); spectatorBoard = null; - } else { + } else if (showSpectatorBoard){ System.out.println("Starting game board GUI in spectator mode."); - spectatorBoard = new Goban(gameConfig, null); + spectatorBoard = new Goban(gameConfig, null, "Game #" + gameNo); + } else { // else showing spectator board is disabled + spectatorBoard = null; } - + try { while (!gameRecord.isFinished()) { GameState 
gameState = gameRecord.getGameState(gameRecord .getNumTurns()); - //System.out.println(gameState); + // System.out.println(gameState); Player playerToMove = gameRecord.getPlayerToMove(); Policy policy = getPolicy(playerToMove); @@ -120,8 +120,9 @@ public class Referee { System.out.println("Game over. Result: " + result); - //DateFormat dateFormat = new SimpleDateFormat("yyMMddHHmmssZ"); + // DateFormat dateFormat = new SimpleDateFormat("yyMMddHHmmssZ"); + if (logGameRecord) { try { // File sgfFile = new File("gogame-" + dateFormat.format(new Date()) @@ -143,8 +144,9 @@ public class Referee { System.out.println("Unable to save game file due to IOException: " + ioe.getMessage()); } - - System.out.println("Game finished."); + } + + //System.out.println("Game finished."); return result; } diff --git a/src/net/woodyfolsom/msproj/StandAloneGame.java b/src/net/woodyfolsom/msproj/StandAloneGame.java index 1455db0..fe4ca42 100644 --- a/src/net/woodyfolsom/msproj/StandAloneGame.java +++ b/src/net/woodyfolsom/msproj/StandAloneGame.java @@ -19,8 +19,9 @@ import net.woodyfolsom.msproj.policy.RandomMovePolicy; import net.woodyfolsom.msproj.policy.RootParallelization; public class StandAloneGame { - private static final int EXIT_NOMINAL = 0; - private static final int EXIT_IO_EXCEPTION = 1; + public static final int EXIT_USER_QUIT = 1; + public static final int EXIT_NOMINAL = 0; + public static final int EXIT_IO_EXCEPTION = -1; private int gameNo = 0; @@ -38,7 +39,9 @@ public class StandAloneGame { parsePlayerType(gameSettings.getPlayerOne()), parsePlayerType(gameSettings.getPlayerTwo()), gameSettings.getBoardSize(), gameSettings.getKomi(), - gameSettings.getNumGames(), gameSettings.getTurnTime()); + gameSettings.getNumGames(), gameSettings.getTurnTime(), + gameSettings.isSpectatorBoardShown(), + gameSettings.isBlackMoveLogged(), gameSettings.isWhiteMoveLogged()); } catch (IOException ioe) { ioe.printStackTrace(); System.exit(EXIT_IO_EXCEPTION); @@ -65,7 +68,8 @@ public class StandAloneGame { } public void playGame(PLAYER_TYPE playerType1, PLAYER_TYPE playerType2, - int size, double komi, int rounds, long turnLength) { + int size, double komi, int rounds, long turnLength, boolean showSpectatorBoard, + boolean blackMoveLogged, boolean whiteMoveLogged) { long startTime = System.currentTimeMillis(); @@ -74,32 +78,31 @@ public class StandAloneGame { Referee referee = new Referee(); referee.setPolicy(Player.BLACK, - getPolicy(playerType1, gameConfig, Player.BLACK, turnLength)); + getPolicy(playerType1, gameConfig, Player.BLACK, turnLength, blackMoveLogged)); referee.setPolicy(Player.WHITE, - getPolicy(playerType2, gameConfig, Player.WHITE, turnLength)); + getPolicy(playerType2, gameConfig, Player.WHITE, turnLength, whiteMoveLogged)); List round1results = new ArrayList(); + boolean logGameRecords = rounds <= 50; for (int round = 0; round < rounds; round++) { gameNo++; - round1results.add(referee.play(gameConfig, gameNo)); + round1results.add(referee.play(gameConfig, gameNo, showSpectatorBoard, logGameRecords)); } List round2results = new ArrayList(); referee.setPolicy(Player.BLACK, - getPolicy(playerType2, gameConfig, Player.BLACK, turnLength)); + getPolicy(playerType2, gameConfig, Player.BLACK, turnLength, blackMoveLogged)); referee.setPolicy(Player.WHITE, - getPolicy(playerType1, gameConfig, Player.WHITE, turnLength)); + getPolicy(playerType1, gameConfig, Player.WHITE, turnLength, whiteMoveLogged)); for (int round = 0; round < rounds; round++) { gameNo++; - round2results.add(referee.play(gameConfig, 
gameNo)); + round2results.add(referee.play(gameConfig, gameNo, showSpectatorBoard, logGameRecords)); } long endTime = System.currentTimeMillis(); - DateFormat dateFormat = new SimpleDateFormat("yyMMddHHmmss"); - try { File txtFile = new File("gotournament-" @@ -107,14 +110,16 @@ public class StandAloneGame { FileWriter writer = new FileWriter(txtFile); try { + if (!logGameRecords) { + System.out.println("Each player is set to play more than 50 rounds as each color; omitting individual game .sgf log file output."); + } + logResults(writer, round1results, playerType1.toString(), - playerType2.toString()); + playerType2.toString()); logResults(writer, round2results, playerType2.toString(), - playerType1.toString()); - + playerType1.toString()); writer.write("Elapsed Time: " + (endTime - startTime) / 1000.0 + " seconds."); - System.out.println("Game tournament saved as " + txtFile.getAbsolutePath()); } finally { @@ -149,19 +154,19 @@ public class StandAloneGame { } private Policy getPolicy(PLAYER_TYPE playerType, GameConfig gameConfig, - Player player, long turnLength) { + Player player, long turnLength, boolean moveLogged) { switch (playerType) { case HUMAN: return new HumanKeyboardInput(); case HUMAN_GUI: - return new HumanGuiInput(new Goban(gameConfig, player)); + return new HumanGuiInput(new Goban(gameConfig, player,"")); case ROOT_PAR: return new RootParallelization(4, turnLength); case UCT: return new MonteCarloUCT(new RandomMovePolicy(), turnLength); case RANDOM: RandomMovePolicy randomMovePolicy = new RandomMovePolicy(); - randomMovePolicy.setLogging(true); + randomMovePolicy.setLogging(moveLogged); return randomMovePolicy; case RAVE: return new MonteCarloAMAF(new RandomMovePolicy(), turnLength); diff --git a/src/net/woodyfolsom/msproj/ann/AbstractNeuralNetFilter.java b/src/net/woodyfolsom/msproj/ann/AbstractNeuralNetFilter.java new file mode 100644 index 0000000..698197f --- /dev/null +++ b/src/net/woodyfolsom/msproj/ann/AbstractNeuralNetFilter.java @@ -0,0 +1,54 @@ +package net.woodyfolsom.msproj.ann; + +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; + +import org.encog.neural.networks.BasicNetwork; +import org.encog.neural.networks.PersistBasicNetwork; + +public abstract class AbstractNeuralNetFilter implements NeuralNetFilter { + protected BasicNetwork neuralNetwork; + protected int actualTrainingEpochs = 0; + protected int maxTrainingEpochs = 1000; + + public int getActualTrainingEpochs() { + return actualTrainingEpochs; + } + + public int getMaxTrainingEpochs() { + return maxTrainingEpochs; + } + + @Override + public BasicNetwork getNeuralNetwork() { + return neuralNetwork; + } + + public void load(String filename) throws IOException { + FileInputStream fis = new FileInputStream(new File(filename)); + neuralNetwork = (BasicNetwork) new PersistBasicNetwork().read(fis); + fis.close(); + } + + @Override + public void reset() { + neuralNetwork.reset(); + } + + @Override + public void reset(int seed) { + neuralNetwork.reset(seed); + } + + public void save(String filename) throws IOException { + FileOutputStream fos = new FileOutputStream(new File(filename)); + new PersistBasicNetwork().save(fos, getNeuralNetwork()); + fos.close(); + } + + public void setMaxTrainingEpochs(int max) { + this.maxTrainingEpochs = max; + } +} diff --git a/src/net/woodyfolsom/msproj/ann/DoublePair.java b/src/net/woodyfolsom/msproj/ann/DoublePair.java new file mode 100644 index 0000000..49f7eeb --- /dev/null +++ 
b/src/net/woodyfolsom/msproj/ann/DoublePair.java @@ -0,0 +1,17 @@ +package net.woodyfolsom.msproj.ann; + +import org.encog.ml.data.basic.BasicMLData; + +public class DoublePair extends BasicMLData { + // private final double x; + // private final double y; + + /** + * + */ + private static final long serialVersionUID = 1L; + + public DoublePair(double x, double y) { + super(new double[] { x, y }); + } +} \ No newline at end of file diff --git a/src/net/woodyfolsom/msproj/ann/ErrorCalculation.java b/src/net/woodyfolsom/msproj/ann/ErrorCalculation.java new file mode 100644 index 0000000..7823a8f --- /dev/null +++ b/src/net/woodyfolsom/msproj/ann/ErrorCalculation.java @@ -0,0 +1,95 @@ +package net.woodyfolsom.msproj.ann; + +import org.encog.mathutil.error.ErrorCalculationMode; + +/* + The initial version of this class was a verbatim copy from the Encog framework. + */ + +public class ErrorCalculation { + + private static ErrorCalculationMode mode = ErrorCalculationMode.MSE; + + public static ErrorCalculationMode getMode() { + return ErrorCalculation.mode; + } + + public static void setMode(final ErrorCalculationMode theMode) { + ErrorCalculation.mode = theMode; + } + + private double globalError; + + private int setSize; + + public final double calculate() { + if (this.setSize == 0) { + return 0; + } + + switch (ErrorCalculation.getMode()) { + case RMS: + return calculateRMS(); + case MSE: + return calculateMSE(); + case ESS: + return calculateESS(); + default: + return calculateMSE(); + } + + } + + public final double calculateMSE() { + if (this.setSize == 0) { + return 0; + } + final double err = this.globalError / this.setSize; + return err; + + } + + public final double calculateESS() { + if (this.setSize == 0) { + return 0; + } + final double err = this.globalError / 2; + return err; + + } + + public final double calculateRMS() { + if (this.setSize == 0) { + return 0; + } + final double err = Math.sqrt(this.globalError / this.setSize); + return err; + } + + public final void reset() { + this.globalError = 0; + this.setSize = 0; + } + + public final void updateError(final double actual, final double ideal) { + + double delta = ideal - actual; + + this.globalError += delta * delta; + + this.setSize++; + + } + + public final void updateError(final double[] actual, final double[] ideal, + final double significance) { + for (int i = 0; i < actual.length; i++) { + double delta = (ideal[i] - actual[i]) * significance; + + this.globalError += delta * delta; + } + + this.setSize += ideal.length; + } + +} \ No newline at end of file diff --git a/src/net/woodyfolsom/msproj/ann/GameStateMLData.java b/src/net/woodyfolsom/msproj/ann/GameStateMLData.java new file mode 100644 index 0000000..47d588c --- /dev/null +++ b/src/net/woodyfolsom/msproj/ann/GameStateMLData.java @@ -0,0 +1,25 @@ +package net.woodyfolsom.msproj.ann; + +import net.woodyfolsom.msproj.GameState; + +import org.encog.ml.data.basic.BasicMLData; + +public class GameStateMLData extends BasicMLData { + + /** + * + */ + private static final long serialVersionUID = 1L; + + private GameState gameState; + + public GameStateMLData(double[] d, GameState gameState) { + super(d); + this.gameState = gameState; + } + + public GameState getGameState() { + return gameState; + } +} diff --git a/src/net/woodyfolsom/msproj/ann/GameStateMLDataPair.java b/src/net/woodyfolsom/msproj/ann/GameStateMLDataPair.java new file mode 100644 index 0000000..8f41126 --- /dev/null +++ 
b/src/net/woodyfolsom/msproj/ann/GameStateMLDataPair.java @@ -0,0 +1,121 @@ +package net.woodyfolsom.msproj.ann; + +import net.woodyfolsom.msproj.GameResult; +import net.woodyfolsom.msproj.GameState; +import net.woodyfolsom.msproj.Player; + +import org.encog.ml.data.MLData; +import org.encog.ml.data.MLDataPair; +import org.encog.ml.data.basic.BasicMLData; +import org.encog.ml.data.basic.BasicMLDataPair; +import org.encog.util.kmeans.Centroid; + +public class GameStateMLDataPair implements MLDataPair { + //private final String[] inputs = { "BlackScore", "WhiteScore" }; + //private final String[] outputs = { "BlackWins", "WhiteWins" }; + + private BasicMLDataPair mlDataPairDelegate; + private GameState gameState; + + public GameStateMLDataPair(GameState gameState) { + this.gameState = gameState; + mlDataPairDelegate = new BasicMLDataPair( + new GameStateMLData(createInput(), gameState), new BasicMLData(createIdeal())); + } + + public GameStateMLDataPair(GameStateMLDataPair that) { + this.gameState = new GameState(that.gameState); + mlDataPairDelegate = new BasicMLDataPair( + that.mlDataPairDelegate.getInput(), + that.mlDataPairDelegate.getIdeal()); + } + + @Override + public MLDataPair clone() { + return new GameStateMLDataPair(this); + } + + @Override + public Centroid createCentroid() { + return mlDataPairDelegate.createCentroid(); + } + + /** + * Creates a vector of normalized scores from GameState. + * + * @return + */ + private double[] createInput() { + + GameResult result = gameState.getResult(); + + double maxScore = gameState.getGameConfig().getSize() + * gameState.getGameConfig().getSize(); + + double whiteScore = Math.min(1.0, result.getWhiteScore() / maxScore); + double blackScore = Math.min(1.0, result.getBlackScore() / maxScore); + + return new double[] { blackScore, whiteScore }; + } + + /** + * Creates a vector of values indicating strength of black/white win output + * from network. + * + * @return + */ + private double[] createIdeal() { + GameResult result = gameState.getResult(); + + double blackWinner = result.isWinner(Player.BLACK) ? 1.0 : 0.0; + double whiteWinner = result.isWinner(Player.WHITE) ? 1.0 : 0.0; + + return new double[] { blackWinner, whiteWinner }; + } + + @Override + public MLData getIdeal() { + return mlDataPairDelegate.getIdeal(); + } + + @Override + public double[] getIdealArray() { + return mlDataPairDelegate.getIdealArray(); + } + + @Override + public MLData getInput() { + return mlDataPairDelegate.getInput(); + } + + @Override + public double[] getInputArray() { + return mlDataPairDelegate.getInputArray(); + } + + @Override + public double getSignificance() { + return mlDataPairDelegate.getSignificance(); + } + + @Override + public boolean isSupervised() { + return mlDataPairDelegate.isSupervised(); + } + + @Override + public void setIdealArray(double[] arg0) { + mlDataPairDelegate.setIdealArray(arg0); + } + + @Override + public void setInputArray(double[] arg0) { + mlDataPairDelegate.setInputArray(arg0); + } + + @Override + public void setSignificance(double arg0) { + mlDataPairDelegate.setSignificance(arg0); + } + +} diff --git a/src/net/woodyfolsom/msproj/ann/GradientWorker.java b/src/net/woodyfolsom/msproj/ann/GradientWorker.java new file mode 100644 index 0000000..0678e4e --- /dev/null +++ b/src/net/woodyfolsom/msproj/ann/GradientWorker.java @@ -0,0 +1,172 @@ +package net.woodyfolsom.msproj.ann; +/* + * Class copied verbatim from Encog framework due to dependency on Propagation + * implementation. 
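Given createInput() and createIdeal() above, each finished game reduces to a two-element input vector of area-normalized scores plus a one-hot win vector. A worked example, runnable on its own, with assumed final scores (Black 45.0, White 42.5 including komi, on a 9x9 board):

```java
public class EncodingExample {
    public static void main(String[] args) {
        // Assumed example values: a finished 9x9 game, Black 45.0 points,
        // White 42.5 points (komi included). maxScore = size * size = 81.
        double maxScore = 9 * 9;
        double blackScore = Math.min(1.0, 45.0 / maxScore); // ~0.556
        double whiteScore = Math.min(1.0, 42.5 / maxScore); // ~0.525

        double[] input = { blackScore, whiteScore }; // createInput() ordering
        double[] ideal = { 1.0, 0.0 };               // createIdeal(): Black won

        System.out.printf("input=[%.3f, %.3f], ideal=[%.1f, %.1f]%n",
                input[0], input[1], ideal[0], ideal[1]);
    }
}
```

The Math.min clamp only matters for degenerate cases where a score can exceed the board area, such as a large komi on a small board.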
+ * + * Encog(tm) Core v3.2 - Java Version + * http://www.heatonresearch.com/encog/ + * http://code.google.com/p/encog-java/ + + * Copyright 2008-2012 Heaton Research, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * For more information on Heaton Research copyrights, licenses + * and trademarks visit: + * http://www.heatonresearch.com/copyright + */ + +import java.util.List; +import java.util.Set; + +import org.encog.engine.network.activation.ActivationFunction; +import org.encog.ml.data.MLDataPair; +import org.encog.ml.data.MLDataSet; +import org.encog.ml.data.basic.BasicMLDataPair; +import org.encog.neural.error.ErrorFunction; +import org.encog.neural.flat.FlatNetwork; +import org.encog.util.EngineArray; +import org.encog.util.concurrency.EngineTask; + +public class GradientWorker implements EngineTask { + + private final FlatNetwork network; + private final ErrorCalculation errorCalculation = new ErrorCalculation(); + private final double[] actual; + private final double[] layerDelta; + private final int[] layerCounts; + private final int[] layerFeedCounts; + private final int[] layerIndex; + private final int[] weightIndex; + private final double[] layerOutput; + private final double[] layerSums; + private final double[] gradients; + private final double[] weights; + private final MLDataPair pair; + private final Set> training; + private final int low; + private final int high; + private final TemporalDifferenceLearning owner; + private double[] flatSpot; + private final ErrorFunction errorFunction; + + public GradientWorker(final FlatNetwork theNetwork, + final TemporalDifferenceLearning theOwner, + final Set> theTraining, final int theLow, + final int theHigh, final double[] flatSpot, + ErrorFunction ef) { + this.network = theNetwork; + this.training = theTraining; + this.low = theLow; + this.high = theHigh; + this.owner = theOwner; + this.flatSpot = flatSpot; + this.errorFunction = ef; + + this.layerDelta = new double[network.getLayerOutput().length]; + this.gradients = new double[network.getWeights().length]; + this.actual = new double[network.getOutputCount()]; + + this.weights = network.getWeights(); + this.layerIndex = network.getLayerIndex(); + this.layerCounts = network.getLayerCounts(); + this.weightIndex = network.getWeightIndex(); + this.layerOutput = network.getLayerOutput(); + this.layerSums = network.getLayerSums(); + this.layerFeedCounts = network.getLayerFeedCounts(); + + this.pair = BasicMLDataPair.createPair(network.getInputCount(), network + .getOutputCount()); + } + + public FlatNetwork getNetwork() { + return this.network; + } + + public double[] getWeights() { + return this.weights; + } + + private void process(final double[] input, final double[] ideal, double s) { + this.network.compute(input, this.actual); + + this.errorCalculation.updateError(this.actual, ideal, s); + this.errorFunction.calculateError(ideal, actual, this.layerDelta); + + for (int i = 0; i < this.actual.length; i++) { + + this.layerDelta[i] = 
((this.network.getActivationFunctions()[0] + .derivativeFunction(this.layerSums[i],this.layerOutput[i]) + this.flatSpot[0])) + * (this.layerDelta[i] * s); + } + + for (int i = this.network.getBeginTraining(); i < this.network + .getEndTraining(); i++) { + processLevel(i); + } + } + + private void processLevel(final int currentLevel) { + final int fromLayerIndex = this.layerIndex[currentLevel + 1]; + final int toLayerIndex = this.layerIndex[currentLevel]; + final int fromLayerSize = this.layerCounts[currentLevel + 1]; + final int toLayerSize = this.layerFeedCounts[currentLevel]; + + final int index = this.weightIndex[currentLevel]; + final ActivationFunction activation = this.network + .getActivationFunctions()[currentLevel]; + final double currentFlatSpot = this.flatSpot[currentLevel + 1]; + + // handle weights + int yi = fromLayerIndex; + for (int y = 0; y < fromLayerSize; y++) { + final double output = this.layerOutput[yi]; + double sum = 0; + int xi = toLayerIndex; + int wi = index + y; + for (int x = 0; x < toLayerSize; x++) { + this.gradients[wi] += output * this.layerDelta[xi]; + sum += this.weights[wi] * this.layerDelta[xi]; + wi += fromLayerSize; + xi++; + } + + this.layerDelta[yi] = sum + * (activation.derivativeFunction(this.layerSums[yi],this.layerOutput[yi])+currentFlatSpot); + yi++; + } + } + + public final void run() { + try { + this.errorCalculation.reset(); + //for (int i = this.low; i <= this.high; i++) { + for (List trainingSequence : training) { + MLDataPair mldp = trainingSequence.get(trainingSequence.size()-1); + this.pair.setInputArray(mldp.getInputArray()); + if (this.pair.getIdealArray() != null) { + this.pair.setIdealArray(mldp.getIdealArray()); + } + //this.training.getRecord(i, this.pair); + process(this.pair.getInputArray(), this.pair.getIdealArray(),pair.getSignificance()); + } + //} + final double error = this.errorCalculation.calculate(); + this.owner.report(this.gradients, error, null); + EngineArray.fill(this.gradients, 0); + } catch (final Throwable ex) { + this.owner.report(null, 0, ex); + } + } + +} \ No newline at end of file diff --git a/src/net/woodyfolsom/msproj/ann/NeuralNetFilter.java b/src/net/woodyfolsom/msproj/ann/NeuralNetFilter.java new file mode 100644 index 0000000..64d40c7 --- /dev/null +++ b/src/net/woodyfolsom/msproj/ann/NeuralNetFilter.java @@ -0,0 +1,31 @@ +package net.woodyfolsom.msproj.ann; + +import java.io.IOException; +import java.util.List; +import java.util.Set; + +import org.encog.ml.data.MLData; +import org.encog.ml.data.MLDataPair; +import org.encog.ml.data.MLDataSet; +import org.encog.neural.networks.BasicNetwork; + +public interface NeuralNetFilter { + BasicNetwork getNeuralNetwork(); + + public int getActualTrainingEpochs(); + public int getInputSize(); + public int getMaxTrainingEpochs(); + public int getOutputSize(); + + public double computeValue(MLData input); + public double[] computeVector(MLData input); + + public void learn(MLDataSet trainingSet); + public void learn(Set> trainingSet); + + public void load(String fileName) throws IOException; + public void reset(); + public void reset(int seed); + public void save(String fileName) throws IOException; + public void setMaxTrainingEpochs(int max); +} \ No newline at end of file diff --git a/src/net/woodyfolsom/msproj/ann/NeuralNetLearner.java b/src/net/woodyfolsom/msproj/ann/NeuralNetLearner.java deleted file mode 100644 index 42e9e59..0000000 --- a/src/net/woodyfolsom/msproj/ann/NeuralNetLearner.java +++ /dev/null @@ -1,15 +0,0 @@ -package net.woodyfolsom.msproj.ann; 
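GradientWorker.run() above walks a Set<List<MLDataPair>> of whole games and, as currently written, computes its gradient from only the final pair of each sequence (trainingSequence.get(trainingSequence.size() - 1)). A minimal sketch of assembling that episodic structure from stock Encog types; all numbers are placeholders:

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.ml.data.basic.BasicMLDataPair;

public class EpisodeSetExample {
    public static void main(String[] args) {
        Set<List<MLDataPair>> episodes = new HashSet<List<MLDataPair>>();

        // One "game": (input, ideal) pairs ordered from first turn to last.
        List<MLDataPair> game = new ArrayList<MLDataPair>();
        game.add(pair(new double[] { 0.10, 0.05 }, new double[] { 1, 0 }));
        game.add(pair(new double[] { 0.40, 0.30 }, new double[] { 1, 0 }));
        game.add(pair(new double[] { 0.56, 0.52 }, new double[] { 1, 0 })); // terminal state
        episodes.add(game);

        // As written, GradientWorker.run() reads game.get(game.size() - 1),
        // so only this terminal pair contributes to the gradient.
        System.out.println("episodes=" + episodes.size()
                + ", turns in first game=" + game.size());
    }

    private static MLDataPair pair(double[] in, double[] ideal) {
        return new BasicMLDataPair(new BasicMLData(in), new BasicMLData(ideal));
    }
}
```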
- -import org.neuroph.core.NeuralNetwork; -import org.neuroph.core.learning.SupervisedTrainingElement; -import org.neuroph.core.learning.TrainingSet; - -public interface NeuralNetLearner { - void learn(TrainingSet trainingSet); - - void reset(); - - NeuralNetwork getNeuralNetwork(); - - void setNeuralNetwork(NeuralNetwork neuralNetwork); -} diff --git a/src/net/woodyfolsom/msproj/ann/PassData.java b/src/net/woodyfolsom/msproj/ann/PassData.java deleted file mode 100644 index be601b3..0000000 --- a/src/net/woodyfolsom/msproj/ann/PassData.java +++ /dev/null @@ -1,98 +0,0 @@ -package net.woodyfolsom.msproj.ann; - -import java.util.Arrays; - -import net.woodyfolsom.msproj.GameRecord; -import net.woodyfolsom.msproj.GameResult; -import net.woodyfolsom.msproj.GameState; -import net.woodyfolsom.msproj.Player; - -import org.neuroph.core.learning.SupervisedTrainingElement; -import org.neuroph.core.learning.TrainingSet; - -public class PassData { - public enum DATA_TYPE { TRAINING, TEST, VALIDATION }; - - public String[] inputs = { "BlackScore", "WhiteScore" }; - public String[] outputs = { "BlackWins", "WhiteWins" }; - - private TrainingSet testSet; - private TrainingSet trainingSet; - private TrainingSet valSet; - - public PassData() { - testSet = new TrainingSet(inputs.length, outputs.length); - trainingSet = new TrainingSet(inputs.length, outputs.length); - valSet = new TrainingSet(inputs.length, outputs.length); - } - - public void addData(DATA_TYPE dataType, GameRecord gameRecord) { - GameState finalState = gameRecord.getGameState(gameRecord.getNumTurns()); - GameResult result = finalState.getResult(); - double maxScore = finalState.getGameConfig().getSize() * finalState.getGameConfig().getSize(); - - double whiteScore = Math.min(1.0, result.getWhiteScore() / maxScore); - double blackScore = Math.min(1.0, result.getBlackScore() / maxScore); - - double blackWinner = result.isWinner(Player.BLACK) ? 1.0 : 0.0; - double whiteWinner = result.isWinner(Player.WHITE) ? 1.0 : 0.0; - - addData(dataType, blackScore, whiteScore, blackWinner, whiteWinner); - } - - public void addData(DATA_TYPE dataType, double...data ) { - double[] desiredInput = Arrays.copyOfRange(data,0,inputs.length); - double[] desiredOutput = Arrays.copyOfRange(data, inputs.length, data.length); - - switch (dataType) { - case TEST : - testSet.addElement(new SupervisedTrainingElement(desiredInput, desiredOutput)); - break; - case TRAINING : - trainingSet.addElement(new SupervisedTrainingElement(desiredInput, desiredOutput)); - System.out.println("Added training input data: " + getInput(desiredInput) + ", output data: " + getOutput(desiredOutput)); - break; - case VALIDATION : - valSet.addElement(new SupervisedTrainingElement(desiredInput, desiredOutput)); - break; - default : - throw new UnsupportedOperationException("invalid dataType " + dataType); - } - } - - public String getInput(double... inputValues) { - StringBuilder sbuilder = new StringBuilder(); - boolean first = true; - for (int i = 0; i < outputs.length; i++) { - if (first) { - first = false; - } else { - sbuilder.append(","); - } - sbuilder.append(inputs[i]); - sbuilder.append(": "); - sbuilder.append(inputValues[i]); - } - return sbuilder.toString(); - } - - public String getOutput(double... 
outputValues) { - StringBuilder sbuilder = new StringBuilder(); - boolean first = true; - for (int i = 0; i < outputs.length; i++) { - if (first) { - first = false; - } else { - sbuilder.append(","); - } - sbuilder.append(outputs[i]); - sbuilder.append(": "); - sbuilder.append(outputValues[i]); - } - return sbuilder.toString(); - } - - public TrainingSet getTrainingSet() { - return trainingSet; - } -} diff --git a/src/net/woodyfolsom/msproj/ann/PassLearner.java b/src/net/woodyfolsom/msproj/ann/PassLearner.java deleted file mode 100644 index 1d99127..0000000 --- a/src/net/woodyfolsom/msproj/ann/PassLearner.java +++ /dev/null @@ -1,127 +0,0 @@ -package net.woodyfolsom.msproj.ann; - -import java.io.File; -import java.io.FileInputStream; -import java.io.FilenameFilter; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -import net.woodyfolsom.msproj.Action; -import net.woodyfolsom.msproj.GameRecord; -import net.woodyfolsom.msproj.Referee; -import net.woodyfolsom.msproj.ann.PassData.DATA_TYPE; - -import org.antlr.runtime.RecognitionException; -import org.neuroph.core.NeuralNetwork; -import org.neuroph.core.learning.SupervisedTrainingElement; -import org.neuroph.core.learning.TrainingSet; -import org.neuroph.nnet.MultiLayerPerceptron; -import org.neuroph.util.TransferFunctionType; - -public class PassLearner implements NeuralNetLearner { - private NeuralNetwork neuralNetwork; - - public PassLearner() { - reset(); - } - - private File[] getDataFiles(String dirName) { - File file = new File(dirName); - return file.listFiles(new FilenameFilter() { - @Override - public boolean accept(File dir, String name) { - return name.toLowerCase().endsWith(".sgf"); - } - }); - } - - public static void main(String[] args) { - new PassLearner().learnANN(); - } - - private void learnANN() { - List parsedRecords = new ArrayList(); - - for (File sgfFile : getDataFiles("data/games/random_vs_random")) { - System.out.println("Parsing " + sgfFile.getPath() + "..."); - try { - GameRecord gameRecord = parseSGF(sgfFile); - while (!gameRecord.isFinished()) { - System.out.println("Game is not finished, passing as player to move"); - gameRecord.play(gameRecord.getPlayerToMove(), Action.PASS); - } - parsedRecords.add(gameRecord); - } catch (RecognitionException re) { - re.printStackTrace(); - } catch (IOException ioe) { - ioe.printStackTrace(); - } - } - - PassData passData = new PassData(); - - for (GameRecord gameRecord : parsedRecords) { - System.out.println(gameRecord.getResult().getFullText()); - passData.addData(DATA_TYPE.TRAINING, gameRecord); - } - - System.out.println("PassData: "); - System.out.println(passData); - - learn(passData.getTrainingSet()); - - getNeuralNetwork().setInput(0.75,0.25); - System.out.println("Output of ann(0.75,0.25): " + passData.getOutput(getNeuralNetwork().getOutput())); - - getNeuralNetwork().setInput(0.25,0.50); - System.out.println("Output of ann(0.50,0.99): " + passData.getOutput(getNeuralNetwork().getOutput())); - - getNeuralNetwork().save("data/networks/Pass2.nn"); - - testNetwork(getNeuralNetwork(), passData.getTrainingSet()); - } - - public GameRecord parseSGF(File sgfFile) throws IOException, - RecognitionException { - FileInputStream sgfInputStream; - - sgfInputStream = new FileInputStream(sgfFile); - return Referee.replay(sgfInputStream); - } - - @Override - public NeuralNetwork getNeuralNetwork() { - return neuralNetwork; - } - - @Override - public void learn(TrainingSet trainingSet) { - 
this.neuralNetwork.learn(trainingSet); - } - - @Override - public void reset() { - this.neuralNetwork = new MultiLayerPerceptron( - TransferFunctionType.TANH, 2, 3, 2); - } - - @Override - public void setNeuralNetwork(NeuralNetwork neuralNetwork) { - this.neuralNetwork = neuralNetwork; - } - -private void testNetwork(NeuralNetwork nnet, TrainingSet trainingSet) { - for (SupervisedTrainingElement trainingElement : trainingSet.elements()) { - - nnet.setInput(trainingElement.getInput()); - nnet.calculate(); - double[] networkOutput = nnet.getOutput(); - System.out.print("Input: " - + Arrays.toString(trainingElement.getInput())); - System.out.println(" Output: " + Arrays.toString(networkOutput)); - - } -} -} diff --git a/src/net/woodyfolsom/msproj/ann/TemporalDifferenceLearning.java b/src/net/woodyfolsom/msproj/ann/TemporalDifferenceLearning.java new file mode 100644 index 0000000..fc3018b --- /dev/null +++ b/src/net/woodyfolsom/msproj/ann/TemporalDifferenceLearning.java @@ -0,0 +1,484 @@ +package net.woodyfolsom.msproj.ann; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import org.encog.EncogError; +import org.encog.engine.network.activation.ActivationFunction; +import org.encog.engine.network.activation.ActivationSigmoid; +import org.encog.mathutil.IntRange; +import org.encog.ml.MLMethod; +import org.encog.ml.TrainingImplementationType; +import org.encog.ml.data.MLDataPair; +import org.encog.ml.data.MLDataSet; +import org.encog.ml.train.MLTrain; +import org.encog.ml.train.strategy.Strategy; +import org.encog.ml.train.strategy.end.EndTrainingStrategy; +import org.encog.neural.error.ErrorFunction; +import org.encog.neural.error.LinearErrorFunction; +import org.encog.neural.flat.FlatNetwork; +import org.encog.neural.networks.ContainsFlat; +import org.encog.neural.networks.training.LearningRate; +import org.encog.neural.networks.training.Momentum; +import org.encog.neural.networks.training.Train; +import org.encog.neural.networks.training.TrainingError; +import org.encog.neural.networks.training.propagation.TrainingContinuation; +import org.encog.neural.networks.training.propagation.back.Backpropagation; +import org.encog.neural.networks.training.strategy.SmartLearningRate; +import org.encog.neural.networks.training.strategy.SmartMomentum; +import org.encog.util.EncogValidate; +import org.encog.util.EngineArray; +import org.encog.util.concurrency.DetermineWorkload; +import org.encog.util.concurrency.EngineConcurrency; +import org.encog.util.concurrency.MultiThreadable; +import org.encog.util.concurrency.TaskGroup; +import org.encog.util.logging.EncogLogging; + +/** + * This class started as a verbatim copy of BackPropagation from the open-source + * Encog framework. It was merged with its super-classes to access protected + * fields without resorting to reflection. 
+ */ +public class TemporalDifferenceLearning implements MLTrain, Momentum, + LearningRate, Train, MultiThreadable { + // New fields for TD(lambda) + private final double lambda; + // end new fields + + // BackProp + public static final String LAST_DELTA = "LAST_DELTA"; + private double learningRate; + private double momentum; + private double[] lastDelta; + // End BackProp + + // Propagation + private FlatNetwork currentFlatNetwork; + private int numThreads; + protected double[] gradients; + private double[] lastGradient; + protected ContainsFlat network; + // private MLDataSet indexable; + private Set> indexable; + private GradientWorker[] workers; + private double totalError; + protected double lastError; + private Throwable reportedException; + private double[] flatSpot; + private boolean shouldFixFlatSpot; + private ErrorFunction ef = new LinearErrorFunction(); + // End Propagation + + // BasicTraining + private final List strategies = new ArrayList(); + private Set> training; + private double error; + private int iteration; + private TrainingImplementationType implementationType; + + // End BasicTraining + + public TemporalDifferenceLearning(final ContainsFlat network, + final Set> training, double lambda) { + this(network, training, 0, 0, lambda); + addStrategy(new SmartLearningRate()); + addStrategy(new SmartMomentum()); + } + + public TemporalDifferenceLearning(final ContainsFlat network, + Set> training, final double theLearnRate, + final double theMomentum, double lambda) { + initPropagation(network, training); + // TODO consider how to re-implement validation + // ValidateNetwork.validateMethodToData(network, training); + this.momentum = theMomentum; + this.learningRate = theLearnRate; + this.lastDelta = new double[network.getFlat().getWeights().length]; + this.lambda = lambda; + } + + private void initPropagation(final ContainsFlat network, + final Set> training) { + initBasicTraining(TrainingImplementationType.Iterative); + this.network = network; + this.currentFlatNetwork = network.getFlat(); + setTraining(training); + + this.gradients = new double[this.currentFlatNetwork.getWeights().length]; + this.lastGradient = new double[this.currentFlatNetwork.getWeights().length]; + + this.indexable = training; + this.numThreads = 0; + this.reportedException = null; + this.shouldFixFlatSpot = true; + } + + private void initBasicTraining(TrainingImplementationType implementationType) { + this.implementationType = implementationType; + } + + // Methods from BackPropagation + @Override + public boolean canContinue() { + return false; + } + + public double[] getLastDelta() { + return this.lastDelta; + } + + @Override + public double getLearningRate() { + return this.learningRate; + } + + @Override + public double getMomentum() { + return this.momentum; + } + + public boolean isValidResume(final TrainingContinuation state) { + if (!state.getContents().containsKey(Backpropagation.LAST_DELTA)) { + return false; + } + + if (!state.getTrainingType().equals(getClass().getSimpleName())) { + return false; + } + + final double[] d = (double[]) state.get(Backpropagation.LAST_DELTA); + return d.length == ((ContainsFlat) getMethod()).getFlat().getWeights().length; + } + + @Override + public TrainingContinuation pause() { + final TrainingContinuation result = new TrainingContinuation(); + result.setTrainingType(this.getClass().getSimpleName()); + result.set(Backpropagation.LAST_DELTA, this.lastDelta); + return result; + } + + @Override + public void resume(final TrainingContinuation state) { + if 
(!isValidResume(state)) { + throw new TrainingError("Invalid training resume data length"); + } + + this.lastDelta = ((double[]) state.get(Backpropagation.LAST_DELTA)); + } + + @Override + public void setLearningRate(final double rate) { + this.learningRate = rate; + } + + @Override + public void setMomentum(final double m) { + this.momentum = m; + } + + public double updateWeight(final double[] gradients, + final double[] lastGradient, final int index) { + final double delta = (gradients[index] * this.learningRate) + + (this.lastDelta[index] * this.momentum); + this.lastDelta[index] = delta; + + System.out.println("Updating weights for connection: " + index + + " with lambda: " + lambda); + + return delta; + } + + public void initOthers() { + } + + // End methods from BackPropagation + + // Methods from Propagation + public void finishTraining() { + basicFinishTraining(); + } + + public FlatNetwork getCurrentFlatNetwork() { + return this.currentFlatNetwork; + } + + public MLMethod getMethod() { + return this.network; + } + + public void iteration() { + iteration(1); + } + + public void rollIteration() { + this.iteration++; + } + + public void iteration(final int count) { + + try { + for (int i = 0; i < count; i++) { + + preIteration(); + + rollIteration(); + + calculateGradients(); + + if (this.currentFlatNetwork.isLimited()) { + learnLimited(); + } else { + learn(); + } + + this.lastError = this.getError(); + + for (final GradientWorker worker : this.workers) { + EngineArray.arrayCopy(this.currentFlatNetwork.getWeights(), + 0, worker.getWeights(), 0, + this.currentFlatNetwork.getWeights().length); + } + + if (this.currentFlatNetwork.getHasContext()) { + copyContexts(); + } + + if (this.reportedException != null) { + throw (new EncogError(this.reportedException)); + } + + postIteration(); + + EncogLogging.log(EncogLogging.LEVEL_INFO, + "Training iteration done, error: " + getError()); + + } + } catch (final ArrayIndexOutOfBoundsException ex) { + EncogValidate.validateNetworkForTraining(this.network, + getTraining()); + throw new EncogError(ex); + } + } + + public void setThreadCount(final int numThreads) { + this.numThreads = numThreads; + } + + @Override + public int getThreadCount() { + return this.numThreads; + } + + public void fixFlatSpot(boolean b) { + this.shouldFixFlatSpot = b; + } + + public void setErrorFunction(ErrorFunction ef) { + this.ef = ef; + } + + public void calculateGradients() { + if (this.workers == null) { + init(); + } + + if (this.currentFlatNetwork.getHasContext()) { + this.workers[0].getNetwork().clearContext(); + } + + this.totalError = 0; + + if (this.workers.length > 1) { + + final TaskGroup group = EngineConcurrency.getInstance() + .createTaskGroup(); + + for (final GradientWorker worker : this.workers) { + EngineConcurrency.getInstance().processTask(worker, group); + } + + group.waitForComplete(); + } else { + this.workers[0].run(); + } + + this.setError(this.totalError / this.workers.length); + + } + + /** + * Copy the contexts to keep them consistent with multithreaded training. 
+ */ + private void copyContexts() { + + // copy the contexts(layer outputO from each group to the next group + for (int i = 0; i < (this.workers.length - 1); i++) { + final double[] src = this.workers[i].getNetwork().getLayerOutput(); + final double[] dst = this.workers[i + 1].getNetwork() + .getLayerOutput(); + EngineArray.arrayCopy(src, dst); + } + + // copy the contexts from the final group to the real network + EngineArray.arrayCopy(this.workers[this.workers.length - 1] + .getNetwork().getLayerOutput(), this.currentFlatNetwork + .getLayerOutput()); + } + + private void init() { + // fix flat spot, if needed + this.flatSpot = new double[this.currentFlatNetwork + .getActivationFunctions().length]; + + if (this.shouldFixFlatSpot) { + for (int i = 0; i < this.currentFlatNetwork + .getActivationFunctions().length; i++) { + final ActivationFunction af = this.currentFlatNetwork + .getActivationFunctions()[i]; + + if (af instanceof ActivationSigmoid) { + this.flatSpot[i] = 0.1; + } else { + this.flatSpot[i] = 0.0; + } + } + } else { + EngineArray.fill(this.flatSpot, 0.0); + } + + // setup workers + final DetermineWorkload determine = new DetermineWorkload( + this.numThreads, (int) this.indexable.size()); + // this.numThreads, (int) this.indexable.getRecordCount()); + + this.workers = new GradientWorker[determine.getThreadCount()]; + + int index = 0; + + // handle CPU + for (final IntRange r : determine.calculateWorkers()) { + this.workers[index++] = new GradientWorker( + this.currentFlatNetwork.clone(), this, new HashSet( + this.indexable), r.getLow(), r.getHigh(), + this.flatSpot, this.ef); + } + + initOthers(); + } + + public void report(final double[] gradients, final double error, + final Throwable ex) { + synchronized (this) { + if (ex == null) { + + for (int i = 0; i < gradients.length; i++) { + this.gradients[i] += gradients[i]; + } + this.totalError += error; + } else { + this.reportedException = ex; + } + } + } + + protected void learn() { + final double[] weights = this.currentFlatNetwork.getWeights(); + for (int i = 0; i < this.gradients.length; i++) { + weights[i] += updateWeight(this.gradients, this.lastGradient, i); + this.gradients[i] = 0; + } + } + + protected void learnLimited() { + final double limit = this.currentFlatNetwork.getConnectionLimit(); + final double[] weights = this.currentFlatNetwork.getWeights(); + for (int i = 0; i < this.gradients.length; i++) { + if (Math.abs(weights[i]) < limit) { + weights[i] = 0; + } else { + weights[i] += updateWeight(this.gradients, this.lastGradient, i); + } + this.gradients[i] = 0; + } + } + + public double[] getLastGradient() { + return lastGradient; + } + + // End methods from Propagation + + // Methods from BasicTraining/ + public void addStrategy(final Strategy strategy) { + strategy.init(this); + this.strategies.add(strategy); + } + + public void basicFinishTraining() { + } + + public double getError() { + return this.error; + } + + public int getIteration() { + return this.iteration; + } + + public List getStrategies() { + return this.strategies; + } + + public MLDataSet getTraining() { + throw new UnsupportedOperationException( + "This learning method operates on Set>, not MLDataSet"); + } + + public boolean isTrainingDone() { + for (Strategy strategy : this.strategies) { + if (strategy instanceof EndTrainingStrategy) { + EndTrainingStrategy end = (EndTrainingStrategy) strategy; + if (end.shouldStop()) { + return true; + } + } + } + + return false; + } + + public void postIteration() { + for (final Strategy strategy : 
this.strategies) { + strategy.postIteration(); + } + } + + public void preIteration() { + + this.iteration++; + + for (final Strategy strategy : this.strategies) { + strategy.preIteration(); + } + } + + public void setError(final double error) { + this.error = error; + } + + public void setIteration(final int iteration) { + this.iteration = iteration; + } + + public void setTraining(final Set<List<MLDataPair>> training) { + this.training = training; + } + + public TrainingImplementationType getImplementationType() { + return this.implementationType; + } + // End Methods from BasicTraining +} \ No newline at end of file diff --git a/src/net/woodyfolsom/msproj/ann/WinFilter.java b/src/net/woodyfolsom/msproj/ann/WinFilter.java new file mode 100644 index 0000000..f2671f6 --- /dev/null +++ b/src/net/woodyfolsom/msproj/ann/WinFilter.java @@ -0,0 +1,112 @@ +package net.woodyfolsom.msproj.ann; + +import java.util.List; +import java.util.Set; + +import net.woodyfolsom.msproj.GameState; +import net.woodyfolsom.msproj.Player; + +import org.encog.engine.network.activation.ActivationSigmoid; +import org.encog.ml.data.MLData; +import org.encog.ml.data.MLDataPair; +import org.encog.ml.data.MLDataSet; +import org.encog.ml.train.MLTrain; +import org.encog.neural.networks.BasicNetwork; +import org.encog.neural.networks.layers.BasicLayer; + +public class WinFilter extends AbstractNeuralNetFilter implements + NeuralNetFilter { + + public WinFilter() { + // create a neural network, without using a factory + BasicNetwork network = new BasicNetwork(); + network.addLayer(new BasicLayer(null, false, 2)); + network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 4)); + network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 2)); + network.getStructure().finalizeStructure(); + network.reset(); + + this.neuralNetwork = network; + } + + @Override + public double computeValue(MLData input) { + if (input instanceof GameStateMLData) { + double[] idealVector = computeVector(input); + GameState gameState = ((GameStateMLData) input).getGameState(); + Player playerToMove = gameState.getPlayerToMove(); + if (playerToMove == Player.BLACK) { + return idealVector[0]; + } else if (playerToMove == Player.WHITE) { + return idealVector[1]; + } else { + throw new RuntimeException("Invalid GameState.playerToMove: " + + playerToMove); + } + } else { + throw new UnsupportedOperationException( + "This NeuralNetFilter only accepts GameStates as input."); + } + } + + @Override + public double[] computeVector(MLData input) { + if (input instanceof GameStateMLData) { + return neuralNetwork.compute(input).getData(); + } else { + throw new UnsupportedOperationException( + "This NeuralNetFilter only accepts GameStates as input."); + } + } + + @Override + public void learn(MLDataSet trainingData) { + throw new UnsupportedOperationException("This filter learns a Set<List<MLDataPair>>, not an MLDataSet"); + } + + @Override + public void learn(Set<List<MLDataPair>> trainingSet) { + + // train the neural network + final MLTrain train = new TemporalDifferenceLearning(neuralNetwork, + trainingSet, 0.7, 0.8, 0.25); + + actualTrainingEpochs = 0; + + do { + train.iteration(); + System.out.println("Epoch #" + actualTrainingEpochs + " Error:" + + train.getError()); + actualTrainingEpochs++; + } while (train.getError() > 0.01 + && actualTrainingEpochs <= maxTrainingEpochs); + } + + @Override + public void reset() { + neuralNetwork.reset(); + } + + @Override + public void reset(int seed) { + neuralNetwork.reset(seed); + } + + @Override + public BasicNetwork getNeuralNetwork() { + return neuralNetwork; + } + + @Override + public int getInputSize() { + return 2; + } + + @Override + public int getOutputSize() { + return 2; + } +} \ No newline at end of file diff --git a/src/net/woodyfolsom/msproj/ann/XORFilter.java b/src/net/woodyfolsom/msproj/ann/XORFilter.java new file mode 100644 index 0000000..e6be979 --- /dev/null +++ b/src/net/woodyfolsom/msproj/ann/XORFilter.java @@ -0,0 +1,83 @@ +package net.woodyfolsom.msproj.ann; + +import java.util.List; +import java.util.Set; + +import org.encog.engine.network.activation.ActivationSigmoid; +import org.encog.ml.data.MLData; +import org.encog.ml.data.MLDataPair; +import org.encog.ml.data.MLDataSet; +import org.encog.ml.data.basic.BasicMLDataSet; +import org.encog.ml.train.MLTrain; +import org.encog.neural.networks.BasicNetwork; +import org.encog.neural.networks.layers.BasicLayer; +import org.encog.neural.networks.training.propagation.back.Backpropagation; + +/** + * Based on sample code from http://neuroph.sourceforge.net + * + * @author Woody + * + */ +public class XORFilter extends AbstractNeuralNetFilter implements + NeuralNetFilter { + + public XORFilter() { + // create a neural network, without using a factory + BasicNetwork network = new BasicNetwork(); + network.addLayer(new BasicLayer(null, false, 2)); + network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 3)); + network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 1)); + network.getStructure().finalizeStructure(); + network.reset(); + + this.neuralNetwork = network; + } + + @Override + public void learn(MLDataSet trainingSet) { + + // train the neural network + final MLTrain train = new Backpropagation(neuralNetwork, + trainingSet, 0.7, 0.8); + + actualTrainingEpochs = 0; + + do { + train.iteration(); + System.out.println("Epoch #" + actualTrainingEpochs + " Error:" + + train.getError()); + actualTrainingEpochs++; + } while (train.getError() > 0.01 + && actualTrainingEpochs <= maxTrainingEpochs); + } + + @Override + public double[] computeVector(MLData mlData) { + MLDataSet dataset = new BasicMLDataSet(new double[][] { mlData.getData() }, + new double[][] { new double[getOutputSize()] }); + MLData output = neuralNetwork.compute(dataset.get(0).getInput()); + return output.getData(); + } + + @Override + public int getInputSize() { + return 2; + } + + @Override + public int getOutputSize() { + return 1; + } + + @Override + public double computeValue(MLData input) { + return computeVector(input)[0]; + } + + @Override + public void learn(Set<List<MLDataPair>> trainingSet) { + throw new UnsupportedOperationException("This Filter learns an MLDataSet, not a Set<List<MLDataPair>>."); + } +} \ No newline at end of file diff --git a/src/net/woodyfolsom/msproj/ann/XORLearner.java b/src/net/woodyfolsom/msproj/ann/XORLearner.java deleted file mode 100644 index 64fcde9..0000000 --- a/src/net/woodyfolsom/msproj/ann/XORLearner.java +++ /dev/null @@ -1,42 +0,0 @@ -package net.woodyfolsom.msproj.ann; - -import org.neuroph.core.NeuralNetwork; -import org.neuroph.core.learning.SupervisedTrainingElement; -import org.neuroph.core.learning.TrainingSet; -import org.neuroph.nnet.MultiLayerPerceptron; -import org.neuroph.util.TransferFunctionType; - -/** - * Based on sample code from http://neuroph.sourceforge.net - * - * @author Woody - * - */ -public class XORLearner implements NeuralNetLearner { - private NeuralNetwork neuralNetwork; - - public XORLearner() { - reset(); - }
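Taken together with AbstractNeuralNetFilter, the new XORFilter can be exercised end to end. A sketch under stated assumptions: the xor.nn file name and the epoch cap are arbitrary choices for the example, not values from this patch.

```java
import java.util.List;

import net.woodyfolsom.msproj.ann.NeuralNetFilter;
import net.woodyfolsom.msproj.ann.XORFilter;

import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.ml.data.basic.BasicMLDataSet;

public class XorRoundTrip {
    public static void main(String[] args) throws Exception {
        double[][] input = { { 0, 0 }, { 0, 1 }, { 1, 0 }, { 1, 1 } };
        double[][] ideal = { { 0 }, { 1 }, { 1 }, { 0 } };
        MLDataSet xor = new BasicMLDataSet(input, ideal);

        NeuralNetFilter filter = new XORFilter();
        filter.setMaxTrainingEpochs(5000); // arbitrary cap for the example
        filter.learn(xor);                 // Backpropagation until error <= 0.01 or cap

        // Round trip through the PersistBasicNetwork support in AbstractNeuralNetFilter.
        filter.save("xor.nn");
        filter.load("xor.nn");

        double out = filter.computeValue(new BasicMLData(new double[] { 1, 0 }));
        System.out.println("trained in " + filter.getActualTrainingEpochs()
                + " epochs, xor(1,0) ~ " + out);
    }
}
```

Note that learn() stops at error 0.01 or at maxTrainingEpochs, whichever comes first, so an unlucky weight initialization can leave the printed value undertrained.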
- - @Override - public NeuralNetwork getNeuralNetwork() { - return neuralNetwork; - } - - @Override - public void learn(TrainingSet trainingSet) { - this.neuralNetwork.learn(trainingSet); - } - - @Override - public void reset() { - this.neuralNetwork = new MultiLayerPerceptron( - TransferFunctionType.TANH, 2, 3, 1); - } - - @Override - public void setNeuralNetwork(NeuralNetwork neuralNetwork) { - this.neuralNetwork = neuralNetwork; - } -} \ No newline at end of file diff --git a/src/net/woodyfolsom/msproj/gui/Goban.java b/src/net/woodyfolsom/msproj/gui/Goban.java index 54041cb..e206977 100644 --- a/src/net/woodyfolsom/msproj/gui/Goban.java +++ b/src/net/woodyfolsom/msproj/gui/Goban.java @@ -5,6 +5,7 @@ import java.awt.event.ActionEvent; import java.awt.event.ActionListener; import java.awt.event.WindowAdapter; import java.awt.event.WindowEvent; +import java.util.concurrent.atomic.AtomicInteger; import javax.swing.JButton; import javax.swing.JFrame; @@ -14,22 +15,31 @@ import net.woodyfolsom.msproj.Action; import net.woodyfolsom.msproj.GameConfig; import net.woodyfolsom.msproj.GameState; import net.woodyfolsom.msproj.Player; +import net.woodyfolsom.msproj.StandAloneGame; import net.woodyfolsom.msproj.sfx.SfxPlayer; public class Goban extends JFrame { private static final long serialVersionUID = 1L; + private static AtomicInteger openWindows = new AtomicInteger(0); + private GridPanel gridPanel; private SfxPlayer sfxPlayer; - public Goban(GameConfig gameConfig, Player guiPlayer) { + public Goban(GameConfig gameConfig, Player guiPlayer, String gameName) { + super(gameName); + setLayout(new BorderLayout()); addWindowListener(new WindowAdapter() { @Override public void windowClosing(WindowEvent e) { sfxPlayer.cleanup(); + int windowsLeftOpen = openWindows.addAndGet(-1); + if (windowsLeftOpen < 1) { + System.exit(StandAloneGame.EXIT_USER_QUIT); + } } }); @@ -42,9 +52,6 @@ public class Goban extends JFrame { this.gridPanel = new GridPanel(gameConfig, guiPlayer, sfxPlayer); add(gridPanel,BorderLayout.CENTER); - setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); - setVisible(true); - JButton passBtn = new JButton("Pass"); JButton resignBtn = new JButton("Resign"); @@ -67,6 +74,9 @@ public class Goban extends JFrame { bottomPanel.add(resignBtn); add(bottomPanel, BorderLayout.SOUTH); + + setVisible(true); + openWindows.addAndGet(1); pack(); } diff --git a/src/net/woodyfolsom/msproj/gui/GridPanel.java b/src/net/woodyfolsom/msproj/gui/GridPanel.java index 9ec3d82..08b7570 100644 --- a/src/net/woodyfolsom/msproj/gui/GridPanel.java +++ b/src/net/woodyfolsom/msproj/gui/GridPanel.java @@ -229,7 +229,9 @@ public class GridPanel extends JPanel implements MouseListener, @Override public void run() { - sfxPlayer.play(); + if (sfxPlayer != null) { + sfxPlayer.play(); + } }}).start(); } diff --git a/src/net/woodyfolsom/msproj/policy/MonteCarlo.java b/src/net/woodyfolsom/msproj/policy/MonteCarlo.java index d38602f..fb4d35a 100644 --- a/src/net/woodyfolsom/msproj/policy/MonteCarlo.java +++ b/src/net/woodyfolsom/msproj/policy/MonteCarlo.java @@ -13,7 +13,7 @@ import net.woodyfolsom.msproj.tree.GameTreeNode; import net.woodyfolsom.msproj.tree.MonteCarloProperties; public abstract class MonteCarlo implements Policy { - protected static final int ROLLOUT_DEPTH_LIMIT = 150; + protected static final int ROLLOUT_DEPTH_LIMIT = 250; protected int numStateEvaluations = 0; protected Policy movePolicy; diff --git a/src/net/woodyfolsom/msproj/policy/RootParallelization.java 
diff --git a/src/net/woodyfolsom/msproj/policy/RootParallelization.java b/src/net/woodyfolsom/msproj/policy/RootParallelization.java
index b648df1..c9c9e27 100644
--- a/src/net/woodyfolsom/msproj/policy/RootParallelization.java
+++ b/src/net/woodyfolsom/msproj/policy/RootParallelization.java
@@ -101,8 +101,8 @@ public class RootParallelization implements Policy {
 
 		System.out.println("It won " + bestWins + " out of " + bestSims
 				+ " rollouts among " + totalRollouts
-				+ " total rollouts (" + totalReward.keySet()
-				+ " possible actions) from the current state.");
+				+ " total rollouts (" + totalReward.size()
+				+ " possible moves evaluated) from the current state.");
 
 		return bestAction;
 	}
diff --git a/test/net/woodyfolsom/msproj/ann/NeuralNetLearnerTest.java b/test/net/woodyfolsom/msproj/ann/NeuralNetLearnerTest.java
deleted file mode 100644
index aa4d260..0000000
--- a/test/net/woodyfolsom/msproj/ann/NeuralNetLearnerTest.java
+++ /dev/null
@@ -1,86 +0,0 @@
-package net.woodyfolsom.msproj.ann;
-
-import java.io.File;
-import java.util.Arrays;
-
-import org.junit.AfterClass;
-import org.junit.BeforeClass;
-import org.junit.Test;
-import org.neuroph.core.NeuralNetwork;
-import org.neuroph.core.learning.SupervisedTrainingElement;
-import org.neuroph.core.learning.TrainingSet;
-
-public class NeuralNetLearnerTest {
-	private static final String FILENAME = "myMlPerceptron.nnet";
-
-	@AfterClass
-	public static void deleteNewNet() {
-		File file = new File(FILENAME);
-		if (file.exists()) {
-			file.delete();
-		}
-	}
-
-	@BeforeClass
-	public static void deleteSavedNet() {
-		File file = new File(FILENAME);
-		if (file.exists()) {
-			file.delete();
-		}
-	}
-
-	@Test
-	public void testLearnSaveLoad() {
-		NeuralNetLearner nnLearner = new XORLearner();
-
-		// create training set (logical XOR function)
-		TrainingSet<SupervisedTrainingElement> trainingSet = new TrainingSet<SupervisedTrainingElement>(
-				2, 1);
-		for (int x = 0; x < 1000; x++) {
-			trainingSet.addElement(new SupervisedTrainingElement(new double[] { 0,
-					0 }, new double[] { 0 }));
-			trainingSet.addElement(new SupervisedTrainingElement(new double[] { 0,
-					1 }, new double[] { 1 }));
-			trainingSet.addElement(new SupervisedTrainingElement(new double[] { 1,
-					0 }, new double[] { 1 }));
-			trainingSet.addElement(new SupervisedTrainingElement(new double[] { 1,
-					1 }, new double[] { 0 }));
-		}
-
-		nnLearner.learn(trainingSet);
-		NeuralNetwork nnet = nnLearner.getNeuralNetwork();
-
-		TrainingSet<SupervisedTrainingElement> valSet = new TrainingSet<SupervisedTrainingElement>(
-				2, 1);
-		valSet.addElement(new SupervisedTrainingElement(new double[] { 0,
-				0 }, new double[] { 0 }));
-		valSet.addElement(new SupervisedTrainingElement(new double[] { 0,
-				1 }, new double[] { 1 }));
-		valSet.addElement(new SupervisedTrainingElement(new double[] { 1,
-				0 }, new double[] { 1 }));
-		valSet.addElement(new SupervisedTrainingElement(new double[] { 1,
-				1 }, new double[] { 0 }));
-
-		System.out.println("Output from eval set (learned network):");
-		testNetwork(nnet, valSet);
-
-		nnet.save(FILENAME);
-		nnet = NeuralNetwork.load(FILENAME);
-
-		System.out.println("Output from eval set (learned network):");
-		testNetwork(nnet, valSet);
-	}
-
-	private void testNetwork(NeuralNetwork nnet, TrainingSet<SupervisedTrainingElement> trainingSet) {
-		for (SupervisedTrainingElement trainingElement : trainingSet.elements()) {
-
-			nnet.setInput(trainingElement.getInput());
-			nnet.calculate();
-			double[] networkOutput = nnet.getOutput();
-			System.out.print("Input: "
-					+ Arrays.toString(trainingElement.getInput()));
-			System.out.println(" Output: " + Arrays.toString(networkOutput));
-
-		}
-	}
-}
\ No newline at end of file
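The deleted NeuralNetLearnerTest built its XOR data with Neuroph's TrainingSet and SupervisedTrainingElement. For anyone porting similar tests to Encog 3, the rough equivalent builds a BasicMLDataSet pair by pair; the class and method names below are illustrative only, not part of this patch:

	import org.encog.ml.data.MLDataSet;
	import org.encog.ml.data.basic.BasicMLData;
	import org.encog.ml.data.basic.BasicMLDataPair;
	import org.encog.ml.data.basic.BasicMLDataSet;

	public class XorDataSample { // hypothetical helper class, for illustration
		// SupervisedTrainingElement(in, out) maps onto
		// BasicMLDataPair(new BasicMLData(in), new BasicMLData(out)).
		public static MLDataSet xorTrainingSet() {
			double[][] in = { { 0, 0 }, { 0, 1 }, { 1, 0 }, { 1, 1 } };
			double[][] out = { { 0 }, { 1 }, { 1 }, { 0 } };
			BasicMLDataSet trainingSet = new BasicMLDataSet();
			for (int i = 0; i < in.length; i++) {
				trainingSet.add(new BasicMLDataPair(
						new BasicMLData(in[i]), new BasicMLData(out[i])));
			}
			return trainingSet;
		}
	}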
diff --git a/test/net/woodyfolsom/msproj/ann/PassNetworkTest.java b/test/net/woodyfolsom/msproj/ann/PassNetworkTest.java
deleted file mode 100644
index 134a8ac..0000000
--- a/test/net/woodyfolsom/msproj/ann/PassNetworkTest.java
+++ /dev/null
@@ -1,51 +0,0 @@
-package net.woodyfolsom.msproj.ann;
-
-import static org.junit.Assert.assertTrue;
-
-import org.junit.Test;
-import org.neuroph.core.NeuralNetwork;
-
-public class PassNetworkTest {
-
-	@Test
-	public void testSavedNetwork1() {
-		NeuralNetwork passFilter = NeuralNetwork.load("data/networks/Pass1.nn");
-		passFilter.setInput(0.75, 0.25);
-		passFilter.calculate();
-
-		PassData passData = new PassData();
-		double[] output = passFilter.getOutput();
-		System.out.println("Output: " + passData.getOutput(output));
-
-		assertTrue(output[0] > 0.50);
-		assertTrue(output[1] < 0.50);
-
-		passFilter.setInput(0.25, 0.50);
-		passFilter.calculate();
-		output = passFilter.getOutput();
-		System.out.println("Output: " + passData.getOutput(output));
-		assertTrue(output[0] < 0.50);
-		assertTrue(output[1] > 0.50);
-	}
-
-	@Test
-	public void testSavedNetwork2() {
-		NeuralNetwork passFilter = NeuralNetwork.load("data/networks/Pass2.nn");
-		passFilter.setInput(0.75, 0.25);
-		passFilter.calculate();
-
-		PassData passData = new PassData();
-		double[] output = passFilter.getOutput();
-		System.out.println("Output: " + passData.getOutput(output));
-
-		assertTrue(output[0] > 0.50);
-		assertTrue(output[1] < 0.50);
-
-		passFilter.setInput(0.45, 0.55);
-		passFilter.calculate();
-		output = passFilter.getOutput();
-		System.out.println("Output: " + passData.getOutput(output));
-		assertTrue(output[0] < 0.50);
-		assertTrue(output[1] > 0.50);
-	}
-}
diff --git a/test/net/woodyfolsom/msproj/ann/WinFilterTest.java b/test/net/woodyfolsom/msproj/ann/WinFilterTest.java
new file mode 100644
index 0000000..52eaf09
--- /dev/null
+++ b/test/net/woodyfolsom/msproj/ann/WinFilterTest.java
@@ -0,0 +1,66 @@
+package net.woodyfolsom.msproj.ann;
+
+import java.io.File;
+import java.io.FileFilter;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import net.woodyfolsom.msproj.GameRecord;
+import net.woodyfolsom.msproj.Referee;
+
+import org.antlr.runtime.RecognitionException;
+import org.encog.ml.data.MLData;
+import org.encog.ml.data.MLDataPair;
+import org.junit.Test;
+
+public class WinFilterTest {
+
+	@Test
+	public void testLearnSaveLoad() throws IOException, RecognitionException {
+		File[] sgfFiles = new File("data/games/random_vs_random")
+				.listFiles(new FileFilter() {
+					@Override
+					public boolean accept(File pathname) {
+						return pathname.getName().endsWith(".sgf");
+					}
+				});
+
+		Set<List<MLDataPair>> trainingData = new HashSet<List<MLDataPair>>();
+
+		for (File file : sgfFiles) {
+			FileInputStream fis = new FileInputStream(file);
+			GameRecord gameRecord = Referee.replay(fis);
+
+			List<MLDataPair> gameData = new ArrayList<MLDataPair>();
+			for (int i = 0; i <= gameRecord.getNumTurns(); i++) {
+				gameData.add(new GameStateMLDataPair(gameRecord.getGameState(i)));
+			}
+
+			trainingData.add(gameData);
+
+			fis.close();
+		}
+
+		WinFilter winFilter = new WinFilter();
+
+		winFilter.learn(trainingData);
+
+		// print the learned value estimate for the first and last state of each game
+		for (List<MLDataPair> trainingSequence : trainingData) {
+			for (int stateIndex = 0; stateIndex < trainingSequence.size(); stateIndex++) {
+				if (stateIndex > 0 && stateIndex < trainingSequence.size() - 1) {
+					continue;
+				}
+				MLData input = trainingSequence.get(stateIndex).getInput();
+
+				System.out.println("Turn " + stateIndex + ": " + input + " => "
+						+ winFilter.computeValue(input));
+			}
+		}
+	}
+}
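One robustness note on WinFilterTest above: fis.close() only runs on the success path, so a file handle leaks whenever Referee.replay throws. Assuming the project can target Java 7 or later, a try-with-resources variant of the same loop (reusing the test's sgfFiles and trainingData locals) closes the stream on every path; on Java 6, try/finally gives the same effect:

	// Sketch: leak-proof variant of the replay loop (requires Java 7+).
	for (File file : sgfFiles) {
		try (FileInputStream fis = new FileInputStream(file)) {
			GameRecord gameRecord = Referee.replay(fis);

			List<MLDataPair> gameData = new ArrayList<MLDataPair>();
			for (int i = 0; i <= gameRecord.getNumTurns(); i++) {
				gameData.add(new GameStateMLDataPair(gameRecord.getGameState(i)));
			}
			trainingData.add(gameData);
		} // fis is closed here even if replay(...) throws
	}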
diff --git a/test/net/woodyfolsom/msproj/ann/XORFilterTest.java b/test/net/woodyfolsom/msproj/ann/XORFilterTest.java
new file mode 100644
index 0000000..e0c2977
--- /dev/null
+++ b/test/net/woodyfolsom/msproj/ann/XORFilterTest.java
@@ -0,0 +1,79 @@
+package net.woodyfolsom.msproj.ann;
+
+import java.io.File;
+import java.io.IOException;
+
+import org.encog.ml.data.MLDataSet;
+import org.encog.ml.data.basic.BasicMLDataSet;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class XORFilterTest {
+	private static final String FILENAME = "xorPerceptron.net";
+
+	@AfterClass
+	public static void deleteNewNet() {
+		File file = new File(FILENAME);
+		if (file.exists()) {
+			file.delete();
+		}
+	}
+
+	@BeforeClass
+	public static void deleteSavedNet() {
+		File file = new File(FILENAME);
+		if (file.exists()) {
+			file.delete();
+		}
+	}
+
+	@Test
+	public void testLearnSaveLoad() throws IOException {
+		NeuralNetFilter nnLearner = new XORFilter();
+
+		// create training set (logical XOR function)
+		int size = 1;
+		double[][] trainingInput = new double[4 * size][];
+		double[][] trainingOutput = new double[4 * size][];
+		for (int i = 0; i < size; i++) {
+			trainingInput[i * 4 + 0] = new double[] { 0, 0 };
+			trainingInput[i * 4 + 1] = new double[] { 0, 1 };
+			trainingInput[i * 4 + 2] = new double[] { 1, 0 };
+			trainingInput[i * 4 + 3] = new double[] { 1, 1 };
+			trainingOutput[i * 4 + 0] = new double[] { 0 };
+			trainingOutput[i * 4 + 1] = new double[] { 1 };
+			trainingOutput[i * 4 + 2] = new double[] { 1 };
+			trainingOutput[i * 4 + 3] = new double[] { 0 };
+		}
+
+		// create training data
+		MLDataSet trainingSet = new BasicMLDataSet(trainingInput, trainingOutput);
+
+		nnLearner.learn(trainingSet);
+		System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");
+
+		double[][] validationSet = new double[4][2];
+
+		validationSet[0] = new double[] { 0, 0 };
+		validationSet[1] = new double[] { 0, 1 };
+		validationSet[2] = new double[] { 1, 0 };
+		validationSet[3] = new double[] { 1, 1 };
+
+		System.out.println("Output from eval set (learned network, pre-serialization):");
+		testNetwork(nnLearner, validationSet);
+
+		nnLearner.save(FILENAME);
+		nnLearner.load(FILENAME);
+
+		System.out.println("Output from eval set (learned network, post-serialization):");
+		testNetwork(nnLearner, validationSet);
+	}
+
+	private void testNetwork(NeuralNetFilter nnLearner, double[][] validationSet) {
+		for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
+			DoublePair dp = new DoublePair(validationSet[valIndex][0], validationSet[valIndex][1]);
+			System.out.println(dp + " => " + nnLearner.computeValue(dp));
+		}
+	}
+}
\ No newline at end of file
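XORFilterTest prints its validation outputs rather than asserting on them, unlike the deleted PassNetworkTest, which asserted output thresholds. A sketch of a self-verifying variant follows; assertNetwork is a hypothetical helper, and the 0.4 tolerance is an arbitrary choice loosely consistent with the 0.01 training-error target, not a project value:

	import static org.junit.Assert.assertEquals;

	// Sketch: assert each XOR output is within a tolerance of its ideal value,
	// instead of relying on a human reading the console output.
	private void assertNetwork(NeuralNetFilter nnLearner, double[][] inputs, double[] expected) {
		for (int i = 0; i < inputs.length; i++) {
			DoublePair dp = new DoublePair(inputs[i][0], inputs[i][1]);
			assertEquals("XOR(" + inputs[i][0] + "," + inputs[i][1] + ")",
					expected[i], nnLearner.computeValue(dp), 0.4);
		}
	}

Called as assertNetwork(nnLearner, validationSet, new double[] { 0, 1, 1, 0 }) after each testNetwork call, this would fail the build on a badly trained or badly deserialized network.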