Implementing temporal difference learning based heavily on Encog framework.
Not functional yet - incremental update.
This commit is contained in:
@@ -7,7 +7,6 @@
|
|||||||
<classpathentry kind="lib" path="lib/log4j-1.2.16.jar"/>
|
<classpathentry kind="lib" path="lib/log4j-1.2.16.jar"/>
|
||||||
<classpathentry kind="lib" path="lib/kgsGtp.jar"/>
|
<classpathentry kind="lib" path="lib/kgsGtp.jar"/>
|
||||||
<classpathentry kind="lib" path="lib/antlrworks-1.4.3.jar"/>
|
<classpathentry kind="lib" path="lib/antlrworks-1.4.3.jar"/>
|
||||||
<classpathentry kind="lib" path="lib/neuroph-2.6.jar"/>
|
<classpathentry kind="lib" path="lib/encog-java-core.jar" sourcepath="lib/encog-java-core-sources.jar"/>
|
||||||
<classpathentry kind="lib" path="lib/encog-engine-2.5.0.jar"/>
|
|
||||||
<classpathentry kind="output" path="bin"/>
|
<classpathentry kind="output" path="bin"/>
|
||||||
</classpath>
|
</classpath>
|
||||||
|
|||||||
26
build.xml
26
build.xml
@@ -23,7 +23,7 @@
|
|||||||
|
|
||||||
<target name="compile" depends="init" description="compile the source ">
|
<target name="compile" depends="init" description="compile the source ">
|
||||||
<!-- Compile the java code from ${src} into ${build} -->
|
<!-- Compile the java code from ${src} into ${build} -->
|
||||||
<javac srcdir="${src}" destdir="${build}" classpathref="build.classpath" debug="true" source="1.6" target="1.6"/>
|
<javac includeantruntime="false" srcdir="${src}" destdir="${build}" classpathref="build.classpath" debug="true"/>
|
||||||
</target>
|
</target>
|
||||||
|
|
||||||
<target name="compile-test" depends="compile">
|
<target name="compile-test" depends="compile">
|
||||||
@@ -33,9 +33,25 @@
|
|||||||
</target>
|
</target>
|
||||||
|
|
||||||
<target name="copy-resources">
|
<target name="copy-resources">
|
||||||
<copy todir="${dist}">
|
<copy todir="${dist}/data">
|
||||||
<fileset dir="data" />
|
<fileset dir="data" />
|
||||||
</copy>
|
</copy>
|
||||||
|
<copy todir="${build}/net/woodyfolsom/msproj/gui">
|
||||||
|
<fileset dir="${src}/net/woodyfolsom/msproj/gui">
|
||||||
|
<exclude name="**/*.java"/>
|
||||||
|
</fileset>
|
||||||
|
</copy>
|
||||||
|
<copy todir="${build}/net/woodyfolsom/msproj/sfx">
|
||||||
|
<fileset dir="${src}/net/woodyfolsom/msproj/sfx">
|
||||||
|
<exclude name="**/*.java"/>
|
||||||
|
</fileset>
|
||||||
|
</copy>
|
||||||
|
</target>
|
||||||
|
|
||||||
|
<target name="copy-libs">
|
||||||
|
<copy todir="${dist}/lib">
|
||||||
|
<fileset dir="lib" />
|
||||||
|
</copy>
|
||||||
</target>
|
</target>
|
||||||
|
|
||||||
<target name="clean" description="clean up">
|
<target name="clean" description="clean up">
|
||||||
@@ -44,12 +60,12 @@
|
|||||||
<delete dir="${dist}" />
|
<delete dir="${dist}" />
|
||||||
</target>
|
</target>
|
||||||
|
|
||||||
<target name="dist" depends="compile,copy-resources" description="generate the distribution">
|
<target name="dist" depends="compile,copy-resources,copy-libs" description="generate the distribution">
|
||||||
<jar jarfile="${dist}/GoGame.jar">
|
<jar jarfile="${dist}/GoGame.jar">
|
||||||
<fileset dir="${build}" excludes="**/*Test.class" />
|
<fileset dir="${build}" excludes="**/*Test.class" />
|
||||||
<manifest>
|
<manifest>
|
||||||
<attribute name="Main-Class" value="net.woodyfolsom.msproj.GoGame" />
|
<attribute name="Main-Class" value="net.woodyfolsom.msproj.StandAloneGame" />
|
||||||
<attribute name="Class-Path" value="kgsGtp.jar log4j-1.2.16.jar"/>
|
<attribute name="Class-Path" value="lib/kgsGtp.jar lib/log4j-1.2.16.jar lib/antlrworks-1.4.3.jar lib/encog-engine-2.5.0.jar lib/neuroph-2.6.jar"/>
|
||||||
</manifest>
|
</manifest>
|
||||||
</jar>
|
</jar>
|
||||||
</target>
|
</target>
|
||||||
|
|||||||
@@ -1,7 +1,10 @@
|
|||||||
PlayerOne=ROOT_PAR
|
PlayerOne=RANDOM
|
||||||
PlayerTwo=RANDOM
|
PlayerTwo=RANDOM
|
||||||
GUIDelay=2000 //1 second
|
GUIDelay=1000 //1 second
|
||||||
BoardSize=9
|
BoardSize=9
|
||||||
Komi=6.5
|
Komi=6.5
|
||||||
NumGames=10 //Games for each player
|
NumGames=1000 //Games for each color per player
|
||||||
TurnTime=2000 //seconds per player per turn
|
TurnTime=1000 //seconds per player per turn
|
||||||
|
SpectatorBoardShown=false;
|
||||||
|
WhiteMoveLogged=false;
|
||||||
|
BlackMoveLogged=false;
|
||||||
Binary file not shown.
Binary file not shown.
BIN
lib/encog-java-core-javadoc.jar
Normal file
BIN
lib/encog-java-core-javadoc.jar
Normal file
Binary file not shown.
BIN
lib/encog-java-core-sources.jar
Normal file
BIN
lib/encog-java-core-sources.jar
Normal file
Binary file not shown.
BIN
lib/encog-java-core.jar
Normal file
BIN
lib/encog-java-core.jar
Normal file
Binary file not shown.
Binary file not shown.
@@ -23,6 +23,16 @@ public class GameRecord {
|
|||||||
moves.add(Action.NONE);
|
moves.add(Action.NONE);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Copy constructor.
 *
 * Every GameState held by {@code that} is duplicated through the GameState
 * copy constructor, so the new record does not share state objects with the
 * source. The move list is copied element by element; the Action references
 * themselves are shared with the source record.
 *
 * @param that the record to copy
 */
public GameRecord(GameRecord that) {
    for (GameState original : that.gameStates) {
        GameState duplicate = new GameState(original);
        gameStates.add(duplicate);
    }
    // The source's move list is taken over as-is, including whatever
    // entries it already holds at the front.
    moves.addAll(that.moves);
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Adds a comment for the current turn.
|
* Adds a comment for the current turn.
|
||||||
* @param comment
|
* @param comment
|
||||||
|
|||||||
@@ -13,6 +13,9 @@ public class GameSettings {
|
|||||||
private int boardSize = 9;
|
private int boardSize = 9;
|
||||||
private double komi = 6.5;
|
private double komi = 6.5;
|
||||||
private int numGames = 10;
|
private int numGames = 10;
|
||||||
|
private boolean spectatorBoardShown = false;
|
||||||
|
private boolean whiteMoveLogged = true;
|
||||||
|
private boolean blackMoveLogged = true;
|
||||||
|
|
||||||
private GameSettings() {
|
private GameSettings() {
|
||||||
}
|
}
|
||||||
@@ -49,6 +52,12 @@ public class GameSettings {
|
|||||||
gameSettings.setNumGames(Integer.parseInt(value));
|
gameSettings.setNumGames(Integer.parseInt(value));
|
||||||
} else if ("Komi".equals(name)) {
|
} else if ("Komi".equals(name)) {
|
||||||
gameSettings.setKomi(Double.parseDouble(value));
|
gameSettings.setKomi(Double.parseDouble(value));
|
||||||
|
} else if ("SpectatorBoardShown".equals(name)) {
|
||||||
|
gameSettings.setSpectatorBoardShown(Boolean.parseBoolean(value));
|
||||||
|
} else if ("WhiteMoveLogged".equals(name)) {
|
||||||
|
gameSettings.setWhiteMoveLogged(Boolean.parseBoolean(value));
|
||||||
|
} else if ("BlackMoveLogged".equals(name)) {
|
||||||
|
gameSettings.setBlackMoveLogged(Boolean.parseBoolean(value));
|
||||||
} else {
|
} else {
|
||||||
System.out.println("Ignoring game settings property with unrecognized name: " + name);
|
System.out.println("Ignoring game settings property with unrecognized name: " + name);
|
||||||
}
|
}
|
||||||
@@ -127,4 +136,29 @@ public class GameSettings {
|
|||||||
sb.append(", GUIDelay=" + guiDelay);
|
sb.append(", GUIDelay=" + guiDelay);
|
||||||
return sb.toString();
|
return sb.toString();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Returns the current value of the spectatorBoardShown setting.
 *
 * @return the spectatorBoardShown flag
 */
public boolean isSpectatorBoardShown() {
|
||||||
|
return spectatorBoardShown;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * Stores the spectatorBoardShown setting. Private: only code inside
 * GameSettings can change the flag.
 *
 * @param spectatorBoardShown the new flag value
 */
private void setSpectatorBoardShown(boolean spectatorBoardShown) {
|
||||||
|
this.spectatorBoardShown = spectatorBoardShown;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * Returns the current value of the whiteMoveLogged setting.
 *
 * @return the whiteMoveLogged flag
 */
public boolean isWhiteMoveLogged() {
|
||||||
|
return whiteMoveLogged;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * Stores the whiteMoveLogged setting. Private: only code inside
 * GameSettings can change the flag.
 *
 * @param whiteMoveLogged the new flag value
 */
private void setWhiteMoveLogged(boolean whiteMoveLogged) {
|
||||||
|
this.whiteMoveLogged = whiteMoveLogged;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * Returns the current value of the blackMoveLogged setting.
 *
 * @return the blackMoveLogged flag
 */
public boolean isBlackMoveLogged() {
|
||||||
|
return blackMoveLogged;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * Stores the blackMoveLogged setting. Private: only code inside
 * GameSettings can change the flag.
 *
 * @param blackMoveLogged the new flag value
 */
private void setBlackMoveLogged(boolean blackMoveLogged) {
|
||||||
|
this.blackMoveLogged = blackMoveLogged;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -4,8 +4,6 @@ import java.io.File;
|
|||||||
import java.io.FileOutputStream;
|
import java.io.FileOutputStream;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.io.InputStream;
|
import java.io.InputStream;
|
||||||
import java.text.DateFormat;
|
|
||||||
import java.text.SimpleDateFormat;
|
|
||||||
|
|
||||||
import net.woodyfolsom.msproj.gui.Goban;
|
import net.woodyfolsom.msproj.gui.Goban;
|
||||||
import net.woodyfolsom.msproj.policy.HumanGuiInput;
|
import net.woodyfolsom.msproj.policy.HumanGuiInput;
|
||||||
@@ -65,10 +63,11 @@ public class Referee {
|
|||||||
return gameRecord;
|
return gameRecord;
|
||||||
}
|
}
|
||||||
|
|
||||||
public GameResult play(GameConfig gameConfig, int gameNo) {
|
public GameResult play(GameConfig gameConfig, int gameNo,
|
||||||
|
boolean showSpectatorBoard, boolean logGameRecord) {
|
||||||
GameRecord gameRecord = new GameRecord(gameConfig);
|
GameRecord gameRecord = new GameRecord(gameConfig);
|
||||||
|
|
||||||
System.out.println("Game started.");
|
//System.out.println("Game started.");
|
||||||
|
|
||||||
GameState initialGameState = gameRecord.getGameState(gameRecord
|
GameState initialGameState = gameRecord.getGameState(gameRecord
|
||||||
.getNumTurns());
|
.getNumTurns());
|
||||||
@@ -78,20 +77,21 @@ public class Referee {
|
|||||||
whitePolicy.setState(initialGameState);
|
whitePolicy.setState(initialGameState);
|
||||||
|
|
||||||
Goban spectatorBoard;
|
Goban spectatorBoard;
|
||||||
if (blackPolicy instanceof HumanGuiInput
|
if (blackPolicy instanceof HumanGuiInput || whitePolicy instanceof HumanGuiInput) {
|
||||||
|| whitePolicy instanceof HumanGuiInput) {
|
|
||||||
System.out.println("Human is controlling the game board GUI.");
|
System.out.println("Human is controlling the game board GUI.");
|
||||||
spectatorBoard = null;
|
spectatorBoard = null;
|
||||||
} else {
|
} else if (showSpectatorBoard){
|
||||||
System.out.println("Starting game board GUI in spectator mode.");
|
System.out.println("Starting game board GUI in spectator mode.");
|
||||||
spectatorBoard = new Goban(gameConfig, null);
|
spectatorBoard = new Goban(gameConfig, null, "Game #" + gameNo);
|
||||||
|
} else { // else showing spectator board is disabled
|
||||||
|
spectatorBoard = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
while (!gameRecord.isFinished()) {
|
while (!gameRecord.isFinished()) {
|
||||||
GameState gameState = gameRecord.getGameState(gameRecord
|
GameState gameState = gameRecord.getGameState(gameRecord
|
||||||
.getNumTurns());
|
.getNumTurns());
|
||||||
//System.out.println(gameState);
|
// System.out.println(gameState);
|
||||||
|
|
||||||
Player playerToMove = gameRecord.getPlayerToMove();
|
Player playerToMove = gameRecord.getPlayerToMove();
|
||||||
Policy policy = getPolicy(playerToMove);
|
Policy policy = getPolicy(playerToMove);
|
||||||
@@ -120,8 +120,9 @@ public class Referee {
|
|||||||
|
|
||||||
System.out.println("Game over. Result: " + result);
|
System.out.println("Game over. Result: " + result);
|
||||||
|
|
||||||
//DateFormat dateFormat = new SimpleDateFormat("yyMMddHHmmssZ");
|
// DateFormat dateFormat = new SimpleDateFormat("yyMMddHHmmssZ");
|
||||||
|
|
||||||
|
if (logGameRecord) {
|
||||||
try {
|
try {
|
||||||
|
|
||||||
// File sgfFile = new File("gogame-" + dateFormat.format(new Date())
|
// File sgfFile = new File("gogame-" + dateFormat.format(new Date())
|
||||||
@@ -143,8 +144,9 @@ public class Referee {
|
|||||||
System.out.println("Unable to save game file due to IOException: "
|
System.out.println("Unable to save game file due to IOException: "
|
||||||
+ ioe.getMessage());
|
+ ioe.getMessage());
|
||||||
}
|
}
|
||||||
|
}
|
||||||
System.out.println("Game finished.");
|
|
||||||
|
//System.out.println("Game finished.");
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,8 +19,9 @@ import net.woodyfolsom.msproj.policy.RandomMovePolicy;
|
|||||||
import net.woodyfolsom.msproj.policy.RootParallelization;
|
import net.woodyfolsom.msproj.policy.RootParallelization;
|
||||||
|
|
||||||
public class StandAloneGame {
|
public class StandAloneGame {
|
||||||
private static final int EXIT_NOMINAL = 0;
|
public static final int EXIT_USER_QUIT = 1;
|
||||||
private static final int EXIT_IO_EXCEPTION = 1;
|
public static final int EXIT_NOMINAL = 0;
|
||||||
|
public static final int EXIT_IO_EXCEPTION = -1;
|
||||||
|
|
||||||
private int gameNo = 0;
|
private int gameNo = 0;
|
||||||
|
|
||||||
@@ -38,7 +39,9 @@ public class StandAloneGame {
|
|||||||
parsePlayerType(gameSettings.getPlayerOne()),
|
parsePlayerType(gameSettings.getPlayerOne()),
|
||||||
parsePlayerType(gameSettings.getPlayerTwo()),
|
parsePlayerType(gameSettings.getPlayerTwo()),
|
||||||
gameSettings.getBoardSize(), gameSettings.getKomi(),
|
gameSettings.getBoardSize(), gameSettings.getKomi(),
|
||||||
gameSettings.getNumGames(), gameSettings.getTurnTime());
|
gameSettings.getNumGames(), gameSettings.getTurnTime(),
|
||||||
|
gameSettings.isSpectatorBoardShown(),
|
||||||
|
gameSettings.isBlackMoveLogged(), gameSettings.isWhiteMoveLogged());
|
||||||
} catch (IOException ioe) {
|
} catch (IOException ioe) {
|
||||||
ioe.printStackTrace();
|
ioe.printStackTrace();
|
||||||
System.exit(EXIT_IO_EXCEPTION);
|
System.exit(EXIT_IO_EXCEPTION);
|
||||||
@@ -65,7 +68,8 @@ public class StandAloneGame {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public void playGame(PLAYER_TYPE playerType1, PLAYER_TYPE playerType2,
|
public void playGame(PLAYER_TYPE playerType1, PLAYER_TYPE playerType2,
|
||||||
int size, double komi, int rounds, long turnLength) {
|
int size, double komi, int rounds, long turnLength, boolean showSpectatorBoard,
|
||||||
|
boolean blackMoveLogged, boolean whiteMoveLogged) {
|
||||||
|
|
||||||
long startTime = System.currentTimeMillis();
|
long startTime = System.currentTimeMillis();
|
||||||
|
|
||||||
@@ -74,32 +78,31 @@ public class StandAloneGame {
|
|||||||
|
|
||||||
Referee referee = new Referee();
|
Referee referee = new Referee();
|
||||||
referee.setPolicy(Player.BLACK,
|
referee.setPolicy(Player.BLACK,
|
||||||
getPolicy(playerType1, gameConfig, Player.BLACK, turnLength));
|
getPolicy(playerType1, gameConfig, Player.BLACK, turnLength, blackMoveLogged));
|
||||||
referee.setPolicy(Player.WHITE,
|
referee.setPolicy(Player.WHITE,
|
||||||
getPolicy(playerType2, gameConfig, Player.WHITE, turnLength));
|
getPolicy(playerType2, gameConfig, Player.WHITE, turnLength, whiteMoveLogged));
|
||||||
|
|
||||||
List<GameResult> round1results = new ArrayList<GameResult>();
|
List<GameResult> round1results = new ArrayList<GameResult>();
|
||||||
|
|
||||||
|
boolean logGameRecords = rounds <= 50;
|
||||||
for (int round = 0; round < rounds; round++) {
|
for (int round = 0; round < rounds; round++) {
|
||||||
gameNo++;
|
gameNo++;
|
||||||
round1results.add(referee.play(gameConfig, gameNo));
|
round1results.add(referee.play(gameConfig, gameNo, showSpectatorBoard, logGameRecords));
|
||||||
}
|
}
|
||||||
|
|
||||||
List<GameResult> round2results = new ArrayList<GameResult>();
|
List<GameResult> round2results = new ArrayList<GameResult>();
|
||||||
|
|
||||||
referee.setPolicy(Player.BLACK,
|
referee.setPolicy(Player.BLACK,
|
||||||
getPolicy(playerType2, gameConfig, Player.BLACK, turnLength));
|
getPolicy(playerType2, gameConfig, Player.BLACK, turnLength, blackMoveLogged));
|
||||||
referee.setPolicy(Player.WHITE,
|
referee.setPolicy(Player.WHITE,
|
||||||
getPolicy(playerType1, gameConfig, Player.WHITE, turnLength));
|
getPolicy(playerType1, gameConfig, Player.WHITE, turnLength, whiteMoveLogged));
|
||||||
for (int round = 0; round < rounds; round++) {
|
for (int round = 0; round < rounds; round++) {
|
||||||
gameNo++;
|
gameNo++;
|
||||||
round2results.add(referee.play(gameConfig, gameNo));
|
round2results.add(referee.play(gameConfig, gameNo, showSpectatorBoard, logGameRecords));
|
||||||
}
|
}
|
||||||
|
|
||||||
long endTime = System.currentTimeMillis();
|
long endTime = System.currentTimeMillis();
|
||||||
|
|
||||||
DateFormat dateFormat = new SimpleDateFormat("yyMMddHHmmss");
|
DateFormat dateFormat = new SimpleDateFormat("yyMMddHHmmss");
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
|
||||||
File txtFile = new File("gotournament-"
|
File txtFile = new File("gotournament-"
|
||||||
@@ -107,14 +110,16 @@ public class StandAloneGame {
|
|||||||
FileWriter writer = new FileWriter(txtFile);
|
FileWriter writer = new FileWriter(txtFile);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
if (!logGameRecords) {
|
||||||
|
System.out.println("Each player is set to play more than 50 rounds as each color; omitting individual game .sgf log file output.");
|
||||||
|
}
|
||||||
|
|
||||||
logResults(writer, round1results, playerType1.toString(),
|
logResults(writer, round1results, playerType1.toString(),
|
||||||
playerType2.toString());
|
playerType2.toString());
|
||||||
logResults(writer, round2results, playerType2.toString(),
|
logResults(writer, round2results, playerType2.toString(),
|
||||||
playerType1.toString());
|
playerType1.toString());
|
||||||
|
|
||||||
writer.write("Elapsed Time: " + (endTime - startTime) / 1000.0
|
writer.write("Elapsed Time: " + (endTime - startTime) / 1000.0
|
||||||
+ " seconds.");
|
+ " seconds.");
|
||||||
|
|
||||||
System.out.println("Game tournament saved as "
|
System.out.println("Game tournament saved as "
|
||||||
+ txtFile.getAbsolutePath());
|
+ txtFile.getAbsolutePath());
|
||||||
} finally {
|
} finally {
|
||||||
@@ -149,19 +154,19 @@ public class StandAloneGame {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private Policy getPolicy(PLAYER_TYPE playerType, GameConfig gameConfig,
|
private Policy getPolicy(PLAYER_TYPE playerType, GameConfig gameConfig,
|
||||||
Player player, long turnLength) {
|
Player player, long turnLength, boolean moveLogged) {
|
||||||
switch (playerType) {
|
switch (playerType) {
|
||||||
case HUMAN:
|
case HUMAN:
|
||||||
return new HumanKeyboardInput();
|
return new HumanKeyboardInput();
|
||||||
case HUMAN_GUI:
|
case HUMAN_GUI:
|
||||||
return new HumanGuiInput(new Goban(gameConfig, player));
|
return new HumanGuiInput(new Goban(gameConfig, player,""));
|
||||||
case ROOT_PAR:
|
case ROOT_PAR:
|
||||||
return new RootParallelization(4, turnLength);
|
return new RootParallelization(4, turnLength);
|
||||||
case UCT:
|
case UCT:
|
||||||
return new MonteCarloUCT(new RandomMovePolicy(), turnLength);
|
return new MonteCarloUCT(new RandomMovePolicy(), turnLength);
|
||||||
case RANDOM:
|
case RANDOM:
|
||||||
RandomMovePolicy randomMovePolicy = new RandomMovePolicy();
|
RandomMovePolicy randomMovePolicy = new RandomMovePolicy();
|
||||||
randomMovePolicy.setLogging(true);
|
randomMovePolicy.setLogging(moveLogged);
|
||||||
return randomMovePolicy;
|
return randomMovePolicy;
|
||||||
case RAVE:
|
case RAVE:
|
||||||
return new MonteCarloAMAF(new RandomMovePolicy(), turnLength);
|
return new MonteCarloAMAF(new RandomMovePolicy(), turnLength);
|
||||||
|
|||||||
54
src/net/woodyfolsom/msproj/ann/AbstractNeuralNetFilter.java
Normal file
54
src/net/woodyfolsom/msproj/ann/AbstractNeuralNetFilter.java
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.FileInputStream;
|
||||||
|
import java.io.FileOutputStream;
|
||||||
|
import java.io.IOException;
|
||||||
|
|
||||||
|
import org.encog.neural.networks.BasicNetwork;
|
||||||
|
import org.encog.neural.networks.PersistBasicNetwork;
|
||||||
|
|
||||||
|
public abstract class AbstractNeuralNetFilter implements NeuralNetFilter {
|
||||||
|
protected BasicNetwork neuralNetwork;
|
||||||
|
protected int actualTrainingEpochs = 0;
|
||||||
|
protected int maxTrainingEpochs = 1000;
|
||||||
|
|
||||||
|
public int getActualTrainingEpochs() {
|
||||||
|
return actualTrainingEpochs;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getMaxTrainingEpochs() {
|
||||||
|
return maxTrainingEpochs;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public BasicNetwork getNeuralNetwork() {
|
||||||
|
return neuralNetwork;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void load(String filename) throws IOException {
|
||||||
|
FileInputStream fis = new FileInputStream(new File(filename));
|
||||||
|
neuralNetwork = (BasicNetwork) new PersistBasicNetwork().read(fis);
|
||||||
|
fis.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void reset() {
|
||||||
|
neuralNetwork.reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void reset(int seed) {
|
||||||
|
neuralNetwork.reset(seed);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void save(String filename) throws IOException {
|
||||||
|
FileOutputStream fos = new FileOutputStream(new File(filename));
|
||||||
|
new PersistBasicNetwork().save(fos, getNeuralNetwork());
|
||||||
|
fos.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setMaxTrainingEpochs(int max) {
|
||||||
|
this.maxTrainingEpochs = max;
|
||||||
|
}
|
||||||
|
}
|
||||||
17
src/net/woodyfolsom/msproj/ann/DoublePair.java
Normal file
17
src/net/woodyfolsom/msproj/ann/DoublePair.java
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann;
|
||||||
|
|
||||||
|
import org.encog.ml.data.basic.BasicMLData;
|
||||||
|
|
||||||
|
public class DoublePair extends BasicMLData {
|
||||||
|
// private final double x;
|
||||||
|
// private final double y;
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
|
public DoublePair(double x, double y) {
|
||||||
|
super(new double[] { x, y });
|
||||||
|
}
|
||||||
|
}
|
||||||
95
src/net/woodyfolsom/msproj/ann/ErrorCalculation.java
Normal file
95
src/net/woodyfolsom/msproj/ann/ErrorCalculation.java
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann;
|
||||||
|
|
||||||
|
import org.encog.mathutil.error.ErrorCalculationMode;
|
||||||
|
|
||||||
|
/*
|
||||||
|
Initial version of this class was a verbatim copy from the Encog framework.
|
||||||
|
*/
|
||||||
|
|
||||||
|
public class ErrorCalculation {
|
||||||
|
|
||||||
|
private static ErrorCalculationMode mode = ErrorCalculationMode.MSE;
|
||||||
|
|
||||||
|
public static ErrorCalculationMode getMode() {
|
||||||
|
return ErrorCalculation.mode;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void setMode(final ErrorCalculationMode theMode) {
|
||||||
|
ErrorCalculation.mode = theMode;
|
||||||
|
}
|
||||||
|
|
||||||
|
private double globalError;
|
||||||
|
|
||||||
|
private int setSize;
|
||||||
|
|
||||||
|
public final double calculate() {
|
||||||
|
if (this.setSize == 0) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (ErrorCalculation.getMode()) {
|
||||||
|
case RMS:
|
||||||
|
return calculateRMS();
|
||||||
|
case MSE:
|
||||||
|
return calculateMSE();
|
||||||
|
case ESS:
|
||||||
|
return calculateESS();
|
||||||
|
default:
|
||||||
|
return calculateMSE();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public final double calculateMSE() {
|
||||||
|
if (this.setSize == 0) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
final double err = this.globalError / this.setSize;
|
||||||
|
return err;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public final double calculateESS() {
|
||||||
|
if (this.setSize == 0) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
final double err = this.globalError / 2;
|
||||||
|
return err;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public final double calculateRMS() {
|
||||||
|
if (this.setSize == 0) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
final double err = Math.sqrt(this.globalError / this.setSize);
|
||||||
|
return err;
|
||||||
|
}
|
||||||
|
|
||||||
|
public final void reset() {
|
||||||
|
this.globalError = 0;
|
||||||
|
this.setSize = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
public final void updateError(final double actual, final double ideal) {
|
||||||
|
|
||||||
|
double delta = ideal - actual;
|
||||||
|
|
||||||
|
this.globalError += delta * delta;
|
||||||
|
|
||||||
|
this.setSize++;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public final void updateError(final double[] actual, final double[] ideal,
|
||||||
|
final double significance) {
|
||||||
|
for (int i = 0; i < actual.length; i++) {
|
||||||
|
double delta = (ideal[i] - actual[i]) * significance;
|
||||||
|
|
||||||
|
this.globalError += delta * delta;
|
||||||
|
}
|
||||||
|
|
||||||
|
this.setSize += ideal.length;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
25
src/net/woodyfolsom/msproj/ann/GameStateMLData.java
Normal file
25
src/net/woodyfolsom/msproj/ann/GameStateMLData.java
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann;
|
||||||
|
|
||||||
|
import net.woodyfolsom.msproj.GameState;
|
||||||
|
|
||||||
|
import org.encog.ml.data.basic.BasicMLData;
|
||||||
|
|
||||||
|
public class GameStateMLData extends BasicMLData {
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
|
private GameState gameState;
|
||||||
|
|
||||||
|
public GameStateMLData(double[] d, GameState gameState) {
|
||||||
|
super(d);
|
||||||
|
// TODO Auto-generated constructor stub
|
||||||
|
this.gameState = gameState;
|
||||||
|
}
|
||||||
|
|
||||||
|
public GameState getGameState() {
|
||||||
|
return gameState;
|
||||||
|
}
|
||||||
|
}
|
||||||
121
src/net/woodyfolsom/msproj/ann/GameStateMLDataPair.java
Normal file
121
src/net/woodyfolsom/msproj/ann/GameStateMLDataPair.java
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann;
|
||||||
|
|
||||||
|
import net.woodyfolsom.msproj.GameResult;
|
||||||
|
import net.woodyfolsom.msproj.GameState;
|
||||||
|
import net.woodyfolsom.msproj.Player;
|
||||||
|
|
||||||
|
import org.encog.ml.data.MLData;
|
||||||
|
import org.encog.ml.data.MLDataPair;
|
||||||
|
import org.encog.ml.data.basic.BasicMLData;
|
||||||
|
import org.encog.ml.data.basic.BasicMLDataPair;
|
||||||
|
import org.encog.util.kmeans.Centroid;
|
||||||
|
|
||||||
|
public class GameStateMLDataPair implements MLDataPair {
|
||||||
|
//private final String[] inputs = { "BlackScore", "WhiteScore" };
|
||||||
|
//private final String[] outputs = { "BlackWins", "WhiteWins" };
|
||||||
|
|
||||||
|
private BasicMLDataPair mlDataPairDelegate;
|
||||||
|
private GameState gameState;
|
||||||
|
|
||||||
|
public GameStateMLDataPair(GameState gameState) {
|
||||||
|
this.gameState = gameState;
|
||||||
|
mlDataPairDelegate = new BasicMLDataPair(
|
||||||
|
new GameStateMLData(createInput(), gameState), new BasicMLData(createIdeal()));
|
||||||
|
}
|
||||||
|
|
||||||
|
public GameStateMLDataPair(GameStateMLDataPair that) {
|
||||||
|
this.gameState = new GameState(that.gameState);
|
||||||
|
mlDataPairDelegate = new BasicMLDataPair(
|
||||||
|
that.mlDataPairDelegate.getInput(),
|
||||||
|
that.mlDataPairDelegate.getIdeal());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MLDataPair clone() {
|
||||||
|
return new GameStateMLDataPair(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Centroid<MLDataPair> createCentroid() {
|
||||||
|
return mlDataPairDelegate.createCentroid();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a vector of normalized scores from GameState.
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
private double[] createInput() {
|
||||||
|
|
||||||
|
GameResult result = gameState.getResult();
|
||||||
|
|
||||||
|
double maxScore = gameState.getGameConfig().getSize()
|
||||||
|
* gameState.getGameConfig().getSize();
|
||||||
|
|
||||||
|
double whiteScore = Math.min(1.0, result.getWhiteScore() / maxScore);
|
||||||
|
double blackScore = Math.min(1.0, result.getBlackScore() / maxScore);
|
||||||
|
|
||||||
|
return new double[] { blackScore, whiteScore };
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a vector of values indicating strength of black/white win output
|
||||||
|
* from network.
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
private double[] createIdeal() {
|
||||||
|
GameResult result = gameState.getResult();
|
||||||
|
|
||||||
|
double blackWinner = result.isWinner(Player.BLACK) ? 1.0 : 0.0;
|
||||||
|
double whiteWinner = result.isWinner(Player.WHITE) ? 1.0 : 0.0;
|
||||||
|
|
||||||
|
return new double[] { blackWinner, whiteWinner };
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MLData getIdeal() {
|
||||||
|
return mlDataPairDelegate.getIdeal();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double[] getIdealArray() {
|
||||||
|
return mlDataPairDelegate.getIdealArray();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MLData getInput() {
|
||||||
|
return mlDataPairDelegate.getInput();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double[] getInputArray() {
|
||||||
|
return mlDataPairDelegate.getInputArray();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double getSignificance() {
|
||||||
|
return mlDataPairDelegate.getSignificance();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isSupervised() {
|
||||||
|
return mlDataPairDelegate.isSupervised();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setIdealArray(double[] arg0) {
|
||||||
|
mlDataPairDelegate.setIdealArray(arg0);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setInputArray(double[] arg0) {
|
||||||
|
mlDataPairDelegate.setInputArray(arg0);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setSignificance(double arg0) {
|
||||||
|
mlDataPairDelegate.setSignificance(arg0);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
172
src/net/woodyfolsom/msproj/ann/GradientWorker.java
Normal file
172
src/net/woodyfolsom/msproj/ann/GradientWorker.java
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann;
|
||||||
|
/*
|
||||||
|
* Class copied verbatim from Encog framework due to dependency on Propagation
|
||||||
|
* implementation.
|
||||||
|
*
|
||||||
|
* Encog(tm) Core v3.2 - Java Version
|
||||||
|
* http://www.heatonresearch.com/encog/
|
||||||
|
* http://code.google.com/p/encog-java/
|
||||||
|
|
||||||
|
* Copyright 2008-2012 Heaton Research, Inc.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*
|
||||||
|
* For more information on Heaton Research copyrights, licenses
|
||||||
|
* and trademarks visit:
|
||||||
|
* http://www.heatonresearch.com/copyright
|
||||||
|
*/
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
import org.encog.engine.network.activation.ActivationFunction;
|
||||||
|
import org.encog.ml.data.MLDataPair;
|
||||||
|
import org.encog.ml.data.MLDataSet;
|
||||||
|
import org.encog.ml.data.basic.BasicMLDataPair;
|
||||||
|
import org.encog.neural.error.ErrorFunction;
|
||||||
|
import org.encog.neural.flat.FlatNetwork;
|
||||||
|
import org.encog.util.EngineArray;
|
||||||
|
import org.encog.util.concurrency.EngineTask;
|
||||||
|
|
||||||
|
public class GradientWorker implements EngineTask {
|
||||||
|
|
||||||
|
private final FlatNetwork network;
|
||||||
|
private final ErrorCalculation errorCalculation = new ErrorCalculation();
|
||||||
|
private final double[] actual;
|
||||||
|
private final double[] layerDelta;
|
||||||
|
private final int[] layerCounts;
|
||||||
|
private final int[] layerFeedCounts;
|
||||||
|
private final int[] layerIndex;
|
||||||
|
private final int[] weightIndex;
|
||||||
|
private final double[] layerOutput;
|
||||||
|
private final double[] layerSums;
|
||||||
|
private final double[] gradients;
|
||||||
|
private final double[] weights;
|
||||||
|
private final MLDataPair pair;
|
||||||
|
private final Set<List<MLDataPair>> training;
|
||||||
|
private final int low;
|
||||||
|
private final int high;
|
||||||
|
private final TemporalDifferenceLearning owner;
|
||||||
|
private double[] flatSpot;
|
||||||
|
private final ErrorFunction errorFunction;
|
||||||
|
|
||||||
|
public GradientWorker(final FlatNetwork theNetwork,
|
||||||
|
final TemporalDifferenceLearning theOwner,
|
||||||
|
final Set<List<MLDataPair>> theTraining, final int theLow,
|
||||||
|
final int theHigh, final double[] flatSpot,
|
||||||
|
ErrorFunction ef) {
|
||||||
|
this.network = theNetwork;
|
||||||
|
this.training = theTraining;
|
||||||
|
this.low = theLow;
|
||||||
|
this.high = theHigh;
|
||||||
|
this.owner = theOwner;
|
||||||
|
this.flatSpot = flatSpot;
|
||||||
|
this.errorFunction = ef;
|
||||||
|
|
||||||
|
this.layerDelta = new double[network.getLayerOutput().length];
|
||||||
|
this.gradients = new double[network.getWeights().length];
|
||||||
|
this.actual = new double[network.getOutputCount()];
|
||||||
|
|
||||||
|
this.weights = network.getWeights();
|
||||||
|
this.layerIndex = network.getLayerIndex();
|
||||||
|
this.layerCounts = network.getLayerCounts();
|
||||||
|
this.weightIndex = network.getWeightIndex();
|
||||||
|
this.layerOutput = network.getLayerOutput();
|
||||||
|
this.layerSums = network.getLayerSums();
|
||||||
|
this.layerFeedCounts = network.getLayerFeedCounts();
|
||||||
|
|
||||||
|
this.pair = BasicMLDataPair.createPair(network.getInputCount(), network
|
||||||
|
.getOutputCount());
|
||||||
|
}
|
||||||
|
|
||||||
|
public FlatNetwork getNetwork() {
|
||||||
|
return this.network;
|
||||||
|
}
|
||||||
|
|
||||||
|
public double[] getWeights() {
|
||||||
|
return this.weights;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void process(final double[] input, final double[] ideal, double s) {
|
||||||
|
this.network.compute(input, this.actual);
|
||||||
|
|
||||||
|
this.errorCalculation.updateError(this.actual, ideal, s);
|
||||||
|
this.errorFunction.calculateError(ideal, actual, this.layerDelta);
|
||||||
|
|
||||||
|
for (int i = 0; i < this.actual.length; i++) {
|
||||||
|
|
||||||
|
this.layerDelta[i] = ((this.network.getActivationFunctions()[0]
|
||||||
|
.derivativeFunction(this.layerSums[i],this.layerOutput[i]) + this.flatSpot[0]))
|
||||||
|
* (this.layerDelta[i] * s);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = this.network.getBeginTraining(); i < this.network
|
||||||
|
.getEndTraining(); i++) {
|
||||||
|
processLevel(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void processLevel(final int currentLevel) {
|
||||||
|
final int fromLayerIndex = this.layerIndex[currentLevel + 1];
|
||||||
|
final int toLayerIndex = this.layerIndex[currentLevel];
|
||||||
|
final int fromLayerSize = this.layerCounts[currentLevel + 1];
|
||||||
|
final int toLayerSize = this.layerFeedCounts[currentLevel];
|
||||||
|
|
||||||
|
final int index = this.weightIndex[currentLevel];
|
||||||
|
final ActivationFunction activation = this.network
|
||||||
|
.getActivationFunctions()[currentLevel];
|
||||||
|
final double currentFlatSpot = this.flatSpot[currentLevel + 1];
|
||||||
|
|
||||||
|
// handle weights
|
||||||
|
int yi = fromLayerIndex;
|
||||||
|
for (int y = 0; y < fromLayerSize; y++) {
|
||||||
|
final double output = this.layerOutput[yi];
|
||||||
|
double sum = 0;
|
||||||
|
int xi = toLayerIndex;
|
||||||
|
int wi = index + y;
|
||||||
|
for (int x = 0; x < toLayerSize; x++) {
|
||||||
|
this.gradients[wi] += output * this.layerDelta[xi];
|
||||||
|
sum += this.weights[wi] * this.layerDelta[xi];
|
||||||
|
wi += fromLayerSize;
|
||||||
|
xi++;
|
||||||
|
}
|
||||||
|
|
||||||
|
this.layerDelta[yi] = sum
|
||||||
|
* (activation.derivativeFunction(this.layerSums[yi],this.layerOutput[yi])+currentFlatSpot);
|
||||||
|
yi++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public final void run() {
|
||||||
|
try {
|
||||||
|
this.errorCalculation.reset();
|
||||||
|
//for (int i = this.low; i <= this.high; i++) {
|
||||||
|
for (List<MLDataPair> trainingSequence : training) {
|
||||||
|
MLDataPair mldp = trainingSequence.get(trainingSequence.size()-1);
|
||||||
|
this.pair.setInputArray(mldp.getInputArray());
|
||||||
|
if (this.pair.getIdealArray() != null) {
|
||||||
|
this.pair.setIdealArray(mldp.getIdealArray());
|
||||||
|
}
|
||||||
|
//this.training.getRecord(i, this.pair);
|
||||||
|
process(this.pair.getInputArray(), this.pair.getIdealArray(),pair.getSignificance());
|
||||||
|
}
|
||||||
|
//}
|
||||||
|
final double error = this.errorCalculation.calculate();
|
||||||
|
this.owner.report(this.gradients, error, null);
|
||||||
|
EngineArray.fill(this.gradients, 0);
|
||||||
|
} catch (final Throwable ex) {
|
||||||
|
this.owner.report(null, 0, ex);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
31
src/net/woodyfolsom/msproj/ann/NeuralNetFilter.java
Normal file
31
src/net/woodyfolsom/msproj/ann/NeuralNetFilter.java
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
import org.encog.ml.data.MLData;
|
||||||
|
import org.encog.ml.data.MLDataPair;
|
||||||
|
import org.encog.ml.data.MLDataSet;
|
||||||
|
import org.encog.neural.networks.BasicNetwork;
|
||||||
|
|
||||||
|
public interface NeuralNetFilter {
|
||||||
|
BasicNetwork getNeuralNetwork();
|
||||||
|
|
||||||
|
public int getActualTrainingEpochs();
|
||||||
|
public int getInputSize();
|
||||||
|
public int getMaxTrainingEpochs();
|
||||||
|
public int getOutputSize();
|
||||||
|
|
||||||
|
public double computeValue(MLData input);
|
||||||
|
public double[] computeVector(MLData input);
|
||||||
|
|
||||||
|
public void learn(MLDataSet trainingSet);
|
||||||
|
public void learn(Set<List<MLDataPair>> trainingSet);
|
||||||
|
|
||||||
|
public void load(String fileName) throws IOException;
|
||||||
|
public void reset();
|
||||||
|
public void reset(int seed);
|
||||||
|
public void save(String fileName) throws IOException;
|
||||||
|
public void setMaxTrainingEpochs(int max);
|
||||||
|
}
|
||||||
@@ -1,15 +0,0 @@
|
|||||||
package net.woodyfolsom.msproj.ann;
|
|
||||||
|
|
||||||
import org.neuroph.core.NeuralNetwork;
|
|
||||||
import org.neuroph.core.learning.SupervisedTrainingElement;
|
|
||||||
import org.neuroph.core.learning.TrainingSet;
|
|
||||||
|
|
||||||
public interface NeuralNetLearner {
|
|
||||||
void learn(TrainingSet<SupervisedTrainingElement> trainingSet);
|
|
||||||
|
|
||||||
void reset();
|
|
||||||
|
|
||||||
NeuralNetwork getNeuralNetwork();
|
|
||||||
|
|
||||||
void setNeuralNetwork(NeuralNetwork neuralNetwork);
|
|
||||||
}
|
|
||||||
@@ -1,98 +0,0 @@
|
|||||||
package net.woodyfolsom.msproj.ann;
|
|
||||||
|
|
||||||
import java.util.Arrays;
|
|
||||||
|
|
||||||
import net.woodyfolsom.msproj.GameRecord;
|
|
||||||
import net.woodyfolsom.msproj.GameResult;
|
|
||||||
import net.woodyfolsom.msproj.GameState;
|
|
||||||
import net.woodyfolsom.msproj.Player;
|
|
||||||
|
|
||||||
import org.neuroph.core.learning.SupervisedTrainingElement;
|
|
||||||
import org.neuroph.core.learning.TrainingSet;
|
|
||||||
|
|
||||||
public class PassData {
|
|
||||||
public enum DATA_TYPE { TRAINING, TEST, VALIDATION };
|
|
||||||
|
|
||||||
public String[] inputs = { "BlackScore", "WhiteScore" };
|
|
||||||
public String[] outputs = { "BlackWins", "WhiteWins" };
|
|
||||||
|
|
||||||
private TrainingSet<SupervisedTrainingElement> testSet;
|
|
||||||
private TrainingSet<SupervisedTrainingElement> trainingSet;
|
|
||||||
private TrainingSet<SupervisedTrainingElement> valSet;
|
|
||||||
|
|
||||||
public PassData() {
|
|
||||||
testSet = new TrainingSet<SupervisedTrainingElement>(inputs.length, outputs.length);
|
|
||||||
trainingSet = new TrainingSet<SupervisedTrainingElement>(inputs.length, outputs.length);
|
|
||||||
valSet = new TrainingSet<SupervisedTrainingElement>(inputs.length, outputs.length);
|
|
||||||
}
|
|
||||||
|
|
||||||
public void addData(DATA_TYPE dataType, GameRecord gameRecord) {
|
|
||||||
GameState finalState = gameRecord.getGameState(gameRecord.getNumTurns());
|
|
||||||
GameResult result = finalState.getResult();
|
|
||||||
double maxScore = finalState.getGameConfig().getSize() * finalState.getGameConfig().getSize();
|
|
||||||
|
|
||||||
double whiteScore = Math.min(1.0, result.getWhiteScore() / maxScore);
|
|
||||||
double blackScore = Math.min(1.0, result.getBlackScore() / maxScore);
|
|
||||||
|
|
||||||
double blackWinner = result.isWinner(Player.BLACK) ? 1.0 : 0.0;
|
|
||||||
double whiteWinner = result.isWinner(Player.WHITE) ? 1.0 : 0.0;
|
|
||||||
|
|
||||||
addData(dataType, blackScore, whiteScore, blackWinner, whiteWinner);
|
|
||||||
}
|
|
||||||
|
|
||||||
public void addData(DATA_TYPE dataType, double...data ) {
|
|
||||||
double[] desiredInput = Arrays.copyOfRange(data,0,inputs.length);
|
|
||||||
double[] desiredOutput = Arrays.copyOfRange(data, inputs.length, data.length);
|
|
||||||
|
|
||||||
switch (dataType) {
|
|
||||||
case TEST :
|
|
||||||
testSet.addElement(new SupervisedTrainingElement(desiredInput, desiredOutput));
|
|
||||||
break;
|
|
||||||
case TRAINING :
|
|
||||||
trainingSet.addElement(new SupervisedTrainingElement(desiredInput, desiredOutput));
|
|
||||||
System.out.println("Added training input data: " + getInput(desiredInput) + ", output data: " + getOutput(desiredOutput));
|
|
||||||
break;
|
|
||||||
case VALIDATION :
|
|
||||||
valSet.addElement(new SupervisedTrainingElement(desiredInput, desiredOutput));
|
|
||||||
break;
|
|
||||||
default :
|
|
||||||
throw new UnsupportedOperationException("invalid dataType " + dataType);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public String getInput(double... inputValues) {
|
|
||||||
StringBuilder sbuilder = new StringBuilder();
|
|
||||||
boolean first = true;
|
|
||||||
for (int i = 0; i < outputs.length; i++) {
|
|
||||||
if (first) {
|
|
||||||
first = false;
|
|
||||||
} else {
|
|
||||||
sbuilder.append(",");
|
|
||||||
}
|
|
||||||
sbuilder.append(inputs[i]);
|
|
||||||
sbuilder.append(": ");
|
|
||||||
sbuilder.append(inputValues[i]);
|
|
||||||
}
|
|
||||||
return sbuilder.toString();
|
|
||||||
}
|
|
||||||
|
|
||||||
public String getOutput(double... outputValues) {
|
|
||||||
StringBuilder sbuilder = new StringBuilder();
|
|
||||||
boolean first = true;
|
|
||||||
for (int i = 0; i < outputs.length; i++) {
|
|
||||||
if (first) {
|
|
||||||
first = false;
|
|
||||||
} else {
|
|
||||||
sbuilder.append(",");
|
|
||||||
}
|
|
||||||
sbuilder.append(outputs[i]);
|
|
||||||
sbuilder.append(": ");
|
|
||||||
sbuilder.append(outputValues[i]);
|
|
||||||
}
|
|
||||||
return sbuilder.toString();
|
|
||||||
}
|
|
||||||
|
|
||||||
public TrainingSet<SupervisedTrainingElement> getTrainingSet() {
|
|
||||||
return trainingSet;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,127 +0,0 @@
|
|||||||
package net.woodyfolsom.msproj.ann;
|
|
||||||
|
|
||||||
import java.io.File;
|
|
||||||
import java.io.FileInputStream;
|
|
||||||
import java.io.FilenameFilter;
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
import net.woodyfolsom.msproj.Action;
|
|
||||||
import net.woodyfolsom.msproj.GameRecord;
|
|
||||||
import net.woodyfolsom.msproj.Referee;
|
|
||||||
import net.woodyfolsom.msproj.ann.PassData.DATA_TYPE;
|
|
||||||
|
|
||||||
import org.antlr.runtime.RecognitionException;
|
|
||||||
import org.neuroph.core.NeuralNetwork;
|
|
||||||
import org.neuroph.core.learning.SupervisedTrainingElement;
|
|
||||||
import org.neuroph.core.learning.TrainingSet;
|
|
||||||
import org.neuroph.nnet.MultiLayerPerceptron;
|
|
||||||
import org.neuroph.util.TransferFunctionType;
|
|
||||||
|
|
||||||
public class PassLearner implements NeuralNetLearner {
|
|
||||||
private NeuralNetwork neuralNetwork;
|
|
||||||
|
|
||||||
public PassLearner() {
|
|
||||||
reset();
|
|
||||||
}
|
|
||||||
|
|
||||||
private File[] getDataFiles(String dirName) {
|
|
||||||
File file = new File(dirName);
|
|
||||||
return file.listFiles(new FilenameFilter() {
|
|
||||||
@Override
|
|
||||||
public boolean accept(File dir, String name) {
|
|
||||||
return name.toLowerCase().endsWith(".sgf");
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
public static void main(String[] args) {
|
|
||||||
new PassLearner().learnANN();
|
|
||||||
}
|
|
||||||
|
|
||||||
private void learnANN() {
|
|
||||||
List<GameRecord> parsedRecords = new ArrayList<GameRecord>();
|
|
||||||
|
|
||||||
for (File sgfFile : getDataFiles("data/games/random_vs_random")) {
|
|
||||||
System.out.println("Parsing " + sgfFile.getPath() + "...");
|
|
||||||
try {
|
|
||||||
GameRecord gameRecord = parseSGF(sgfFile);
|
|
||||||
while (!gameRecord.isFinished()) {
|
|
||||||
System.out.println("Game is not finished, passing as player to move");
|
|
||||||
gameRecord.play(gameRecord.getPlayerToMove(), Action.PASS);
|
|
||||||
}
|
|
||||||
parsedRecords.add(gameRecord);
|
|
||||||
} catch (RecognitionException re) {
|
|
||||||
re.printStackTrace();
|
|
||||||
} catch (IOException ioe) {
|
|
||||||
ioe.printStackTrace();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
PassData passData = new PassData();
|
|
||||||
|
|
||||||
for (GameRecord gameRecord : parsedRecords) {
|
|
||||||
System.out.println(gameRecord.getResult().getFullText());
|
|
||||||
passData.addData(DATA_TYPE.TRAINING, gameRecord);
|
|
||||||
}
|
|
||||||
|
|
||||||
System.out.println("PassData: ");
|
|
||||||
System.out.println(passData);
|
|
||||||
|
|
||||||
learn(passData.getTrainingSet());
|
|
||||||
|
|
||||||
getNeuralNetwork().setInput(0.75,0.25);
|
|
||||||
System.out.println("Output of ann(0.75,0.25): " + passData.getOutput(getNeuralNetwork().getOutput()));
|
|
||||||
|
|
||||||
getNeuralNetwork().setInput(0.25,0.50);
|
|
||||||
System.out.println("Output of ann(0.50,0.99): " + passData.getOutput(getNeuralNetwork().getOutput()));
|
|
||||||
|
|
||||||
getNeuralNetwork().save("data/networks/Pass2.nn");
|
|
||||||
|
|
||||||
testNetwork(getNeuralNetwork(), passData.getTrainingSet());
|
|
||||||
}
|
|
||||||
|
|
||||||
public GameRecord parseSGF(File sgfFile) throws IOException,
|
|
||||||
RecognitionException {
|
|
||||||
FileInputStream sgfInputStream;
|
|
||||||
|
|
||||||
sgfInputStream = new FileInputStream(sgfFile);
|
|
||||||
return Referee.replay(sgfInputStream);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public NeuralNetwork getNeuralNetwork() {
|
|
||||||
return neuralNetwork;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void learn(TrainingSet<SupervisedTrainingElement> trainingSet) {
|
|
||||||
this.neuralNetwork.learn(trainingSet);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void reset() {
|
|
||||||
this.neuralNetwork = new MultiLayerPerceptron(
|
|
||||||
TransferFunctionType.TANH, 2, 3, 2);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void setNeuralNetwork(NeuralNetwork neuralNetwork) {
|
|
||||||
this.neuralNetwork = neuralNetwork;
|
|
||||||
}
|
|
||||||
|
|
||||||
private void testNetwork(NeuralNetwork nnet, TrainingSet<SupervisedTrainingElement> trainingSet) {
|
|
||||||
for (SupervisedTrainingElement trainingElement : trainingSet.elements()) {
|
|
||||||
|
|
||||||
nnet.setInput(trainingElement.getInput());
|
|
||||||
nnet.calculate();
|
|
||||||
double[] networkOutput = nnet.getOutput();
|
|
||||||
System.out.print("Input: "
|
|
||||||
+ Arrays.toString(trainingElement.getInput()));
|
|
||||||
System.out.println(" Output: " + Arrays.toString(networkOutput));
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
484
src/net/woodyfolsom/msproj/ann/TemporalDifferenceLearning.java
Normal file
484
src/net/woodyfolsom/msproj/ann/TemporalDifferenceLearning.java
Normal file
@@ -0,0 +1,484 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
import org.encog.EncogError;
|
||||||
|
import org.encog.engine.network.activation.ActivationFunction;
|
||||||
|
import org.encog.engine.network.activation.ActivationSigmoid;
|
||||||
|
import org.encog.mathutil.IntRange;
|
||||||
|
import org.encog.ml.MLMethod;
|
||||||
|
import org.encog.ml.TrainingImplementationType;
|
||||||
|
import org.encog.ml.data.MLDataPair;
|
||||||
|
import org.encog.ml.data.MLDataSet;
|
||||||
|
import org.encog.ml.train.MLTrain;
|
||||||
|
import org.encog.ml.train.strategy.Strategy;
|
||||||
|
import org.encog.ml.train.strategy.end.EndTrainingStrategy;
|
||||||
|
import org.encog.neural.error.ErrorFunction;
|
||||||
|
import org.encog.neural.error.LinearErrorFunction;
|
||||||
|
import org.encog.neural.flat.FlatNetwork;
|
||||||
|
import org.encog.neural.networks.ContainsFlat;
|
||||||
|
import org.encog.neural.networks.training.LearningRate;
|
||||||
|
import org.encog.neural.networks.training.Momentum;
|
||||||
|
import org.encog.neural.networks.training.Train;
|
||||||
|
import org.encog.neural.networks.training.TrainingError;
|
||||||
|
import org.encog.neural.networks.training.propagation.TrainingContinuation;
|
||||||
|
import org.encog.neural.networks.training.propagation.back.Backpropagation;
|
||||||
|
import org.encog.neural.networks.training.strategy.SmartLearningRate;
|
||||||
|
import org.encog.neural.networks.training.strategy.SmartMomentum;
|
||||||
|
import org.encog.util.EncogValidate;
|
||||||
|
import org.encog.util.EngineArray;
|
||||||
|
import org.encog.util.concurrency.DetermineWorkload;
|
||||||
|
import org.encog.util.concurrency.EngineConcurrency;
|
||||||
|
import org.encog.util.concurrency.MultiThreadable;
|
||||||
|
import org.encog.util.concurrency.TaskGroup;
|
||||||
|
import org.encog.util.logging.EncogLogging;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This class started as a verbatim copy of BackPropagation from the open-source
|
||||||
|
* Encog framework. It was merged with its super-classes to access protected
|
||||||
|
* fields without resorting to reflection.
|
||||||
|
*/
|
||||||
|
public class TemporalDifferenceLearning implements MLTrain, Momentum,
|
||||||
|
LearningRate, Train, MultiThreadable {
|
||||||
|
// New fields for TD(lambda)
|
||||||
|
private final double lambda;
|
||||||
|
// end new fields
|
||||||
|
|
||||||
|
// BackProp
|
||||||
|
public static final String LAST_DELTA = "LAST_DELTA";
|
||||||
|
private double learningRate;
|
||||||
|
private double momentum;
|
||||||
|
private double[] lastDelta;
|
||||||
|
// End BackProp
|
||||||
|
|
||||||
|
// Propagation
|
||||||
|
private FlatNetwork currentFlatNetwork;
|
||||||
|
private int numThreads;
|
||||||
|
protected double[] gradients;
|
||||||
|
private double[] lastGradient;
|
||||||
|
protected ContainsFlat network;
|
||||||
|
// private MLDataSet indexable;
|
||||||
|
private Set<List<MLDataPair>> indexable;
|
||||||
|
private GradientWorker[] workers;
|
||||||
|
private double totalError;
|
||||||
|
protected double lastError;
|
||||||
|
private Throwable reportedException;
|
||||||
|
private double[] flatSpot;
|
||||||
|
private boolean shouldFixFlatSpot;
|
||||||
|
private ErrorFunction ef = new LinearErrorFunction();
|
||||||
|
// End Propagation
|
||||||
|
|
||||||
|
// BasicTraining
|
||||||
|
private final List<Strategy> strategies = new ArrayList<Strategy>();
|
||||||
|
private Set<List<MLDataPair>> training;
|
||||||
|
private double error;
|
||||||
|
private int iteration;
|
||||||
|
private TrainingImplementationType implementationType;
|
||||||
|
|
||||||
|
// End BasicTraining
|
||||||
|
|
||||||
|
public TemporalDifferenceLearning(final ContainsFlat network,
|
||||||
|
final Set<List<MLDataPair>> training, double lambda) {
|
||||||
|
this(network, training, 0, 0, lambda);
|
||||||
|
addStrategy(new SmartLearningRate());
|
||||||
|
addStrategy(new SmartMomentum());
|
||||||
|
}
|
||||||
|
|
||||||
|
public TemporalDifferenceLearning(final ContainsFlat network,
|
||||||
|
Set<List<MLDataPair>> training, final double theLearnRate,
|
||||||
|
final double theMomentum, double lambda) {
|
||||||
|
initPropagation(network, training);
|
||||||
|
// TODO consider how to re-implement validation
|
||||||
|
// ValidateNetwork.validateMethodToData(network, training);
|
||||||
|
this.momentum = theMomentum;
|
||||||
|
this.learningRate = theLearnRate;
|
||||||
|
this.lastDelta = new double[network.getFlat().getWeights().length];
|
||||||
|
this.lambda = lambda;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void initPropagation(final ContainsFlat network,
|
||||||
|
final Set<List<MLDataPair>> training) {
|
||||||
|
initBasicTraining(TrainingImplementationType.Iterative);
|
||||||
|
this.network = network;
|
||||||
|
this.currentFlatNetwork = network.getFlat();
|
||||||
|
setTraining(training);
|
||||||
|
|
||||||
|
this.gradients = new double[this.currentFlatNetwork.getWeights().length];
|
||||||
|
this.lastGradient = new double[this.currentFlatNetwork.getWeights().length];
|
||||||
|
|
||||||
|
this.indexable = training;
|
||||||
|
this.numThreads = 0;
|
||||||
|
this.reportedException = null;
|
||||||
|
this.shouldFixFlatSpot = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void initBasicTraining(TrainingImplementationType implementationType) {
|
||||||
|
this.implementationType = implementationType;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Methods from BackPropagation
|
||||||
|
@Override
|
||||||
|
public boolean canContinue() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
public double[] getLastDelta() {
|
||||||
|
return this.lastDelta;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double getLearningRate() {
|
||||||
|
return this.learningRate;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double getMomentum() {
|
||||||
|
return this.momentum;
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean isValidResume(final TrainingContinuation state) {
|
||||||
|
if (!state.getContents().containsKey(Backpropagation.LAST_DELTA)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!state.getTrainingType().equals(getClass().getSimpleName())) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
final double[] d = (double[]) state.get(Backpropagation.LAST_DELTA);
|
||||||
|
return d.length == ((ContainsFlat) getMethod()).getFlat().getWeights().length;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public TrainingContinuation pause() {
|
||||||
|
final TrainingContinuation result = new TrainingContinuation();
|
||||||
|
result.setTrainingType(this.getClass().getSimpleName());
|
||||||
|
result.set(Backpropagation.LAST_DELTA, this.lastDelta);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void resume(final TrainingContinuation state) {
|
||||||
|
if (!isValidResume(state)) {
|
||||||
|
throw new TrainingError("Invalid training resume data length");
|
||||||
|
}
|
||||||
|
|
||||||
|
this.lastDelta = ((double[]) state.get(Backpropagation.LAST_DELTA));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setLearningRate(final double rate) {
|
||||||
|
this.learningRate = rate;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setMomentum(final double m) {
|
||||||
|
this.momentum = m;
|
||||||
|
}
|
||||||
|
|
||||||
|
public double updateWeight(final double[] gradients,
|
||||||
|
final double[] lastGradient, final int index) {
|
||||||
|
final double delta = (gradients[index] * this.learningRate)
|
||||||
|
+ (this.lastDelta[index] * this.momentum);
|
||||||
|
this.lastDelta[index] = delta;
|
||||||
|
|
||||||
|
System.out.println("Updating weights for connection: " + index
|
||||||
|
+ " with lambda: " + lambda);
|
||||||
|
|
||||||
|
return delta;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void initOthers() {
|
||||||
|
}
|
||||||
|
|
||||||
|
// End methods from BackPropagation
|
||||||
|
|
||||||
|
// Methods from Propagation
|
||||||
|
public void finishTraining() {
|
||||||
|
basicFinishTraining();
|
||||||
|
}
|
||||||
|
|
||||||
|
public FlatNetwork getCurrentFlatNetwork() {
|
||||||
|
return this.currentFlatNetwork;
|
||||||
|
}
|
||||||
|
|
||||||
|
public MLMethod getMethod() {
|
||||||
|
return this.network;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void iteration() {
|
||||||
|
iteration(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void rollIteration() {
|
||||||
|
this.iteration++;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void iteration(final int count) {
|
||||||
|
|
||||||
|
try {
|
||||||
|
for (int i = 0; i < count; i++) {
|
||||||
|
|
||||||
|
preIteration();
|
||||||
|
|
||||||
|
rollIteration();
|
||||||
|
|
||||||
|
calculateGradients();
|
||||||
|
|
||||||
|
if (this.currentFlatNetwork.isLimited()) {
|
||||||
|
learnLimited();
|
||||||
|
} else {
|
||||||
|
learn();
|
||||||
|
}
|
||||||
|
|
||||||
|
this.lastError = this.getError();
|
||||||
|
|
||||||
|
for (final GradientWorker worker : this.workers) {
|
||||||
|
EngineArray.arrayCopy(this.currentFlatNetwork.getWeights(),
|
||||||
|
0, worker.getWeights(), 0,
|
||||||
|
this.currentFlatNetwork.getWeights().length);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (this.currentFlatNetwork.getHasContext()) {
|
||||||
|
copyContexts();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (this.reportedException != null) {
|
||||||
|
throw (new EncogError(this.reportedException));
|
||||||
|
}
|
||||||
|
|
||||||
|
postIteration();
|
||||||
|
|
||||||
|
EncogLogging.log(EncogLogging.LEVEL_INFO,
|
||||||
|
"Training iteration done, error: " + getError());
|
||||||
|
|
||||||
|
}
|
||||||
|
} catch (final ArrayIndexOutOfBoundsException ex) {
|
||||||
|
EncogValidate.validateNetworkForTraining(this.network,
|
||||||
|
getTraining());
|
||||||
|
throw new EncogError(ex);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setThreadCount(final int numThreads) {
|
||||||
|
this.numThreads = numThreads;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getThreadCount() {
|
||||||
|
return this.numThreads;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void fixFlatSpot(boolean b) {
|
||||||
|
this.shouldFixFlatSpot = b;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setErrorFunction(ErrorFunction ef) {
|
||||||
|
this.ef = ef;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void calculateGradients() {
|
||||||
|
if (this.workers == null) {
|
||||||
|
init();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (this.currentFlatNetwork.getHasContext()) {
|
||||||
|
this.workers[0].getNetwork().clearContext();
|
||||||
|
}
|
||||||
|
|
||||||
|
this.totalError = 0;
|
||||||
|
|
||||||
|
if (this.workers.length > 1) {
|
||||||
|
|
||||||
|
final TaskGroup group = EngineConcurrency.getInstance()
|
||||||
|
.createTaskGroup();
|
||||||
|
|
||||||
|
for (final GradientWorker worker : this.workers) {
|
||||||
|
EngineConcurrency.getInstance().processTask(worker, group);
|
||||||
|
}
|
||||||
|
|
||||||
|
group.waitForComplete();
|
||||||
|
} else {
|
||||||
|
this.workers[0].run();
|
||||||
|
}
|
||||||
|
|
||||||
|
this.setError(this.totalError / this.workers.length);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Copy the contexts to keep them consistent with multithreaded training.
|
||||||
|
*/
|
||||||
|
private void copyContexts() {
|
||||||
|
|
||||||
|
// copy the contexts(layer outputO from each group to the next group
|
||||||
|
for (int i = 0; i < (this.workers.length - 1); i++) {
|
||||||
|
final double[] src = this.workers[i].getNetwork().getLayerOutput();
|
||||||
|
final double[] dst = this.workers[i + 1].getNetwork()
|
||||||
|
.getLayerOutput();
|
||||||
|
EngineArray.arrayCopy(src, dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
// copy the contexts from the final group to the real network
|
||||||
|
EngineArray.arrayCopy(this.workers[this.workers.length - 1]
|
||||||
|
.getNetwork().getLayerOutput(), this.currentFlatNetwork
|
||||||
|
.getLayerOutput());
|
||||||
|
}
|
||||||
|
|
||||||
|
private void init() {
|
||||||
|
// fix flat spot, if needed
|
||||||
|
this.flatSpot = new double[this.currentFlatNetwork
|
||||||
|
.getActivationFunctions().length];
|
||||||
|
|
||||||
|
if (this.shouldFixFlatSpot) {
|
||||||
|
for (int i = 0; i < this.currentFlatNetwork
|
||||||
|
.getActivationFunctions().length; i++) {
|
||||||
|
final ActivationFunction af = this.currentFlatNetwork
|
||||||
|
.getActivationFunctions()[i];
|
||||||
|
|
||||||
|
if (af instanceof ActivationSigmoid) {
|
||||||
|
this.flatSpot[i] = 0.1;
|
||||||
|
} else {
|
||||||
|
this.flatSpot[i] = 0.0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
EngineArray.fill(this.flatSpot, 0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// setup workers
|
||||||
|
final DetermineWorkload determine = new DetermineWorkload(
|
||||||
|
this.numThreads, (int) this.indexable.size());
|
||||||
|
// this.numThreads, (int) this.indexable.getRecordCount());
|
||||||
|
|
||||||
|
this.workers = new GradientWorker[determine.getThreadCount()];
|
||||||
|
|
||||||
|
int index = 0;
|
||||||
|
|
||||||
|
// handle CPU
|
||||||
|
for (final IntRange r : determine.calculateWorkers()) {
|
||||||
|
this.workers[index++] = new GradientWorker(
|
||||||
|
this.currentFlatNetwork.clone(), this, new HashSet(
|
||||||
|
this.indexable), r.getLow(), r.getHigh(),
|
||||||
|
this.flatSpot, this.ef);
|
||||||
|
}
|
||||||
|
|
||||||
|
initOthers();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void report(final double[] gradients, final double error,
|
||||||
|
final Throwable ex) {
|
||||||
|
synchronized (this) {
|
||||||
|
if (ex == null) {
|
||||||
|
|
||||||
|
for (int i = 0; i < gradients.length; i++) {
|
||||||
|
this.gradients[i] += gradients[i];
|
||||||
|
}
|
||||||
|
this.totalError += error;
|
||||||
|
} else {
|
||||||
|
this.reportedException = ex;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void learn() {
|
||||||
|
final double[] weights = this.currentFlatNetwork.getWeights();
|
||||||
|
for (int i = 0; i < this.gradients.length; i++) {
|
||||||
|
weights[i] += updateWeight(this.gradients, this.lastGradient, i);
|
||||||
|
this.gradients[i] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void learnLimited() {
|
||||||
|
final double limit = this.currentFlatNetwork.getConnectionLimit();
|
||||||
|
final double[] weights = this.currentFlatNetwork.getWeights();
|
||||||
|
for (int i = 0; i < this.gradients.length; i++) {
|
||||||
|
if (Math.abs(weights[i]) < limit) {
|
||||||
|
weights[i] = 0;
|
||||||
|
} else {
|
||||||
|
weights[i] += updateWeight(this.gradients, this.lastGradient, i);
|
||||||
|
}
|
||||||
|
this.gradients[i] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public double[] getLastGradient() {
|
||||||
|
return lastGradient;
|
||||||
|
}
|
||||||
|
|
||||||
|
// End methods from Propagation
|
||||||
|
|
||||||
|
// Methods from BasicTraining/
|
||||||
|
public void addStrategy(final Strategy strategy) {
|
||||||
|
strategy.init(this);
|
||||||
|
this.strategies.add(strategy);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void basicFinishTraining() {
|
||||||
|
}
|
||||||
|
|
||||||
|
public double getError() {
|
||||||
|
return this.error;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getIteration() {
|
||||||
|
return this.iteration;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<Strategy> getStrategies() {
|
||||||
|
return this.strategies;
|
||||||
|
}
|
||||||
|
|
||||||
|
public MLDataSet getTraining() {
|
||||||
|
throw new UnsupportedOperationException(
|
||||||
|
"This learning method operates on Set<List<MLData>>, not MLDataSet");
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean isTrainingDone() {
|
||||||
|
for (Strategy strategy : this.strategies) {
|
||||||
|
if (strategy instanceof EndTrainingStrategy) {
|
||||||
|
EndTrainingStrategy end = (EndTrainingStrategy) strategy;
|
||||||
|
if (end.shouldStop()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void postIteration() {
|
||||||
|
for (final Strategy strategy : this.strategies) {
|
||||||
|
strategy.postIteration();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void preIteration() {
|
||||||
|
|
||||||
|
this.iteration++;
|
||||||
|
|
||||||
|
for (final Strategy strategy : this.strategies) {
|
||||||
|
strategy.preIteration();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setError(final double error) {
|
||||||
|
this.error = error;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setIteration(final int iteration) {
|
||||||
|
this.iteration = iteration;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setTraining(final Set<List<MLDataPair>> training) {
|
||||||
|
this.training = training;
|
||||||
|
}
|
||||||
|
|
||||||
|
public TrainingImplementationType getImplementationType() {
|
||||||
|
return this.implementationType;
|
||||||
|
}
|
||||||
|
// End Methods from BasicTraining
|
||||||
|
}
|
||||||
112
src/net/woodyfolsom/msproj/ann/WinFilter.java
Normal file
112
src/net/woodyfolsom/msproj/ann/WinFilter.java
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
import net.woodyfolsom.msproj.GameState;
|
||||||
|
import net.woodyfolsom.msproj.Player;
|
||||||
|
|
||||||
|
import org.encog.engine.network.activation.ActivationSigmoid;
|
||||||
|
import org.encog.ml.data.MLData;
|
||||||
|
import org.encog.ml.data.MLDataPair;
|
||||||
|
import org.encog.ml.data.MLDataSet;
|
||||||
|
import org.encog.ml.train.MLTrain;
|
||||||
|
import org.encog.neural.networks.BasicNetwork;
|
||||||
|
import org.encog.neural.networks.layers.BasicLayer;
|
||||||
|
|
||||||
|
public class WinFilter extends AbstractNeuralNetFilter implements
|
||||||
|
NeuralNetFilter {
|
||||||
|
|
||||||
|
public WinFilter() {
|
||||||
|
// create a neural network, without using a factory
|
||||||
|
BasicNetwork network = new BasicNetwork();
|
||||||
|
network.addLayer(new BasicLayer(null, false, 2));
|
||||||
|
network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 4));
|
||||||
|
network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 2));
|
||||||
|
network.getStructure().finalizeStructure();
|
||||||
|
network.reset();
|
||||||
|
|
||||||
|
this.neuralNetwork = network;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double computeValue(MLData input) {
|
||||||
|
if (input instanceof GameStateMLData) {
|
||||||
|
double[] idealVector = computeVector(input);
|
||||||
|
GameState gameState = ((GameStateMLData) input).getGameState();
|
||||||
|
Player playerToMove = gameState.getPlayerToMove();
|
||||||
|
if (playerToMove == Player.BLACK) {
|
||||||
|
return idealVector[0];
|
||||||
|
} else if (playerToMove == Player.WHITE) {
|
||||||
|
return idealVector[1];
|
||||||
|
} else {
|
||||||
|
throw new RuntimeException("Invalid GameState.playerToMove: "
|
||||||
|
+ playerToMove);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
throw new UnsupportedOperationException(
|
||||||
|
"This NeuralNetFilter only accepts GameStates as input.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double[] computeVector(MLData input) {
|
||||||
|
if (input instanceof GameStateMLData) {
|
||||||
|
return neuralNetwork.compute(input).getData();
|
||||||
|
} else {
|
||||||
|
throw new UnsupportedOperationException(
|
||||||
|
"This NeuralNetFilter only accepts GameStates as input.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void learn(MLDataSet trainingData) {
|
||||||
|
throw new UnsupportedOperationException("This filter learns a Set<List<MLData>>, not an MLDataSet");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void learn(Set<List<MLDataPair>> trainingSet) {
|
||||||
|
|
||||||
|
// train the neural network
|
||||||
|
final MLTrain train = new TemporalDifferenceLearning(neuralNetwork,
|
||||||
|
trainingSet, 0.7, 0.8, 0.25);
|
||||||
|
|
||||||
|
actualTrainingEpochs = 0;
|
||||||
|
|
||||||
|
do {
|
||||||
|
train.iteration();
|
||||||
|
System.out.println("Epoch #" + actualTrainingEpochs + " Error:"
|
||||||
|
+ train.getError());
|
||||||
|
actualTrainingEpochs++;
|
||||||
|
} while (train.getError() > 0.01
|
||||||
|
&& actualTrainingEpochs <= maxTrainingEpochs);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void reset() {
|
||||||
|
neuralNetwork.reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void reset(int seed) {
|
||||||
|
neuralNetwork.reset(seed);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public BasicNetwork getNeuralNetwork() {
|
||||||
|
// TODO Auto-generated method stub
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getInputSize() {
|
||||||
|
// TODO Auto-generated method stub
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getOutputSize() {
|
||||||
|
// TODO Auto-generated method stub
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
83
src/net/woodyfolsom/msproj/ann/XORFilter.java
Normal file
83
src/net/woodyfolsom/msproj/ann/XORFilter.java
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
import org.encog.engine.network.activation.ActivationSigmoid;
|
||||||
|
import org.encog.ml.data.MLData;
|
||||||
|
import org.encog.ml.data.MLDataPair;
|
||||||
|
import org.encog.ml.data.MLDataSet;
|
||||||
|
import org.encog.ml.data.basic.BasicMLDataSet;
|
||||||
|
import org.encog.ml.train.MLTrain;
|
||||||
|
import org.encog.neural.networks.BasicNetwork;
|
||||||
|
import org.encog.neural.networks.layers.BasicLayer;
|
||||||
|
import org.encog.neural.networks.training.propagation.back.Backpropagation;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Based on sample code from http://neuroph.sourceforge.net
|
||||||
|
*
|
||||||
|
* @author Woody
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
public class XORFilter extends AbstractNeuralNetFilter implements
|
||||||
|
NeuralNetFilter {
|
||||||
|
|
||||||
|
public XORFilter() {
|
||||||
|
// create a neural network, without using a factory
|
||||||
|
BasicNetwork network = new BasicNetwork();
|
||||||
|
network.addLayer(new BasicLayer(null, false, 2));
|
||||||
|
network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 3));
|
||||||
|
network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 1));
|
||||||
|
network.getStructure().finalizeStructure();
|
||||||
|
network.reset();
|
||||||
|
|
||||||
|
this.neuralNetwork = network;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void learn(MLDataSet trainingSet) {
|
||||||
|
|
||||||
|
// train the neural network
|
||||||
|
final MLTrain train = new Backpropagation(neuralNetwork,
|
||||||
|
trainingSet, 0.7, 0.8);
|
||||||
|
|
||||||
|
actualTrainingEpochs = 0;
|
||||||
|
|
||||||
|
do {
|
||||||
|
train.iteration();
|
||||||
|
System.out.println("Epoch #" + actualTrainingEpochs + " Error:"
|
||||||
|
+ train.getError());
|
||||||
|
actualTrainingEpochs++;
|
||||||
|
} while (train.getError() > 0.01
|
||||||
|
&& actualTrainingEpochs <= maxTrainingEpochs);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double[] computeVector(MLData mlData) {
|
||||||
|
MLDataSet dataset = new BasicMLDataSet(new double[][] { mlData.getData() },
|
||||||
|
new double[][] { new double[getOutputSize()] });
|
||||||
|
MLData output = neuralNetwork.compute(dataset.get(0).getInput());
|
||||||
|
return output.getData();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getInputSize() {
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getOutputSize() {
|
||||||
|
// TODO Auto-generated method stub
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double computeValue(MLData input) {
|
||||||
|
return computeVector(input)[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void learn(Set<List<MLDataPair>> trainingSet) {
|
||||||
|
throw new UnsupportedOperationException("This Filter learns an MLDataSet, not a Set<List<MLData>>.");
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,42 +0,0 @@
|
|||||||
package net.woodyfolsom.msproj.ann;
|
|
||||||
|
|
||||||
import org.neuroph.core.NeuralNetwork;
|
|
||||||
import org.neuroph.core.learning.SupervisedTrainingElement;
|
|
||||||
import org.neuroph.core.learning.TrainingSet;
|
|
||||||
import org.neuroph.nnet.MultiLayerPerceptron;
|
|
||||||
import org.neuroph.util.TransferFunctionType;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Based on sample code from http://neuroph.sourceforge.net
|
|
||||||
*
|
|
||||||
* @author Woody
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
public class XORLearner implements NeuralNetLearner {
|
|
||||||
private NeuralNetwork neuralNetwork;
|
|
||||||
|
|
||||||
public XORLearner() {
|
|
||||||
reset();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public NeuralNetwork getNeuralNetwork() {
|
|
||||||
return neuralNetwork;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void learn(TrainingSet<SupervisedTrainingElement> trainingSet) {
|
|
||||||
this.neuralNetwork.learn(trainingSet);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void reset() {
|
|
||||||
this.neuralNetwork = new MultiLayerPerceptron(
|
|
||||||
TransferFunctionType.TANH, 2, 3, 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void setNeuralNetwork(NeuralNetwork neuralNetwork) {
|
|
||||||
this.neuralNetwork = neuralNetwork;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -5,6 +5,7 @@ import java.awt.event.ActionEvent;
|
|||||||
import java.awt.event.ActionListener;
|
import java.awt.event.ActionListener;
|
||||||
import java.awt.event.WindowAdapter;
|
import java.awt.event.WindowAdapter;
|
||||||
import java.awt.event.WindowEvent;
|
import java.awt.event.WindowEvent;
|
||||||
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
|
|
||||||
import javax.swing.JButton;
|
import javax.swing.JButton;
|
||||||
import javax.swing.JFrame;
|
import javax.swing.JFrame;
|
||||||
@@ -14,22 +15,31 @@ import net.woodyfolsom.msproj.Action;
|
|||||||
import net.woodyfolsom.msproj.GameConfig;
|
import net.woodyfolsom.msproj.GameConfig;
|
||||||
import net.woodyfolsom.msproj.GameState;
|
import net.woodyfolsom.msproj.GameState;
|
||||||
import net.woodyfolsom.msproj.Player;
|
import net.woodyfolsom.msproj.Player;
|
||||||
|
import net.woodyfolsom.msproj.StandAloneGame;
|
||||||
import net.woodyfolsom.msproj.sfx.SfxPlayer;
|
import net.woodyfolsom.msproj.sfx.SfxPlayer;
|
||||||
|
|
||||||
public class Goban extends JFrame {
|
public class Goban extends JFrame {
|
||||||
|
|
||||||
private static final long serialVersionUID = 1L;
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
|
private static AtomicInteger openWindows = new AtomicInteger(0);
|
||||||
|
|
||||||
private GridPanel gridPanel;
|
private GridPanel gridPanel;
|
||||||
private SfxPlayer sfxPlayer;
|
private SfxPlayer sfxPlayer;
|
||||||
|
|
||||||
public Goban(GameConfig gameConfig, Player guiPlayer) {
|
public Goban(GameConfig gameConfig, Player guiPlayer, String gameName) {
|
||||||
|
super(gameName);
|
||||||
|
|
||||||
setLayout(new BorderLayout());
|
setLayout(new BorderLayout());
|
||||||
|
|
||||||
addWindowListener(new WindowAdapter() {
|
addWindowListener(new WindowAdapter() {
|
||||||
@Override
|
@Override
|
||||||
public void windowClosing(WindowEvent e) {
|
public void windowClosing(WindowEvent e) {
|
||||||
sfxPlayer.cleanup();
|
sfxPlayer.cleanup();
|
||||||
|
int windowsLeftOpen = openWindows.addAndGet(-1);
|
||||||
|
if (windowsLeftOpen < 1) {
|
||||||
|
System.exit(StandAloneGame.EXIT_USER_QUIT);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -42,9 +52,6 @@ public class Goban extends JFrame {
|
|||||||
this.gridPanel = new GridPanel(gameConfig, guiPlayer, sfxPlayer);
|
this.gridPanel = new GridPanel(gameConfig, guiPlayer, sfxPlayer);
|
||||||
add(gridPanel,BorderLayout.CENTER);
|
add(gridPanel,BorderLayout.CENTER);
|
||||||
|
|
||||||
setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
|
|
||||||
setVisible(true);
|
|
||||||
|
|
||||||
JButton passBtn = new JButton("Pass");
|
JButton passBtn = new JButton("Pass");
|
||||||
JButton resignBtn = new JButton("Resign");
|
JButton resignBtn = new JButton("Resign");
|
||||||
|
|
||||||
@@ -67,6 +74,9 @@ public class Goban extends JFrame {
|
|||||||
bottomPanel.add(resignBtn);
|
bottomPanel.add(resignBtn);
|
||||||
|
|
||||||
add(bottomPanel, BorderLayout.SOUTH);
|
add(bottomPanel, BorderLayout.SOUTH);
|
||||||
|
|
||||||
|
setVisible(true);
|
||||||
|
openWindows.addAndGet(1);
|
||||||
pack();
|
pack();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -229,7 +229,9 @@ public class GridPanel extends JPanel implements MouseListener,
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void run() {
|
public void run() {
|
||||||
sfxPlayer.play();
|
if (sfxPlayer != null) {
|
||||||
|
sfxPlayer.play();
|
||||||
|
}
|
||||||
}}).start();
|
}}).start();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import net.woodyfolsom.msproj.tree.GameTreeNode;
|
|||||||
import net.woodyfolsom.msproj.tree.MonteCarloProperties;
|
import net.woodyfolsom.msproj.tree.MonteCarloProperties;
|
||||||
|
|
||||||
public abstract class MonteCarlo implements Policy {
|
public abstract class MonteCarlo implements Policy {
|
||||||
protected static final int ROLLOUT_DEPTH_LIMIT = 150;
|
protected static final int ROLLOUT_DEPTH_LIMIT = 250;
|
||||||
|
|
||||||
protected int numStateEvaluations = 0;
|
protected int numStateEvaluations = 0;
|
||||||
protected Policy movePolicy;
|
protected Policy movePolicy;
|
||||||
|
|||||||
@@ -101,8 +101,8 @@ public class RootParallelization implements Policy {
|
|||||||
System.out.println("It won "
|
System.out.println("It won "
|
||||||
+ bestWins + " out of " + bestSims
|
+ bestWins + " out of " + bestSims
|
||||||
+ " rollouts among " + totalRollouts
|
+ " rollouts among " + totalRollouts
|
||||||
+ " total rollouts (" + totalReward.keySet()
|
+ " total rollouts (" + totalReward.size()
|
||||||
+ " possible actions) from the current state.");
|
+ " possible moves evaluated) from the current state.");
|
||||||
|
|
||||||
return bestAction;
|
return bestAction;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,86 +0,0 @@
|
|||||||
package net.woodyfolsom.msproj.ann;
|
|
||||||
|
|
||||||
import java.io.File;
|
|
||||||
import java.util.Arrays;
|
|
||||||
|
|
||||||
import org.junit.AfterClass;
|
|
||||||
import org.junit.BeforeClass;
|
|
||||||
import org.junit.Test;
|
|
||||||
import org.neuroph.core.NeuralNetwork;
|
|
||||||
import org.neuroph.core.learning.SupervisedTrainingElement;
|
|
||||||
import org.neuroph.core.learning.TrainingSet;
|
|
||||||
|
|
||||||
public class NeuralNetLearnerTest {
|
|
||||||
private static final String FILENAME = "myMlPerceptron.nnet";
|
|
||||||
|
|
||||||
@AfterClass
|
|
||||||
public static void deleteNewNet() {
|
|
||||||
File file = new File(FILENAME);
|
|
||||||
if (file.exists()) {
|
|
||||||
file.delete();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@BeforeClass
|
|
||||||
public static void deleteSavedNet() {
|
|
||||||
File file = new File(FILENAME);
|
|
||||||
if (file.exists()) {
|
|
||||||
file.delete();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testLearnSaveLoad() {
|
|
||||||
NeuralNetLearner nnLearner = new XORLearner();
|
|
||||||
|
|
||||||
// create training set (logical XOR function)
|
|
||||||
TrainingSet<SupervisedTrainingElement> trainingSet = new TrainingSet<SupervisedTrainingElement>(
|
|
||||||
2, 1);
|
|
||||||
for (int x = 0; x < 1000; x++) {
|
|
||||||
trainingSet.addElement(new SupervisedTrainingElement(new double[] { 0,
|
|
||||||
0 }, new double[] { 0 }));
|
|
||||||
trainingSet.addElement(new SupervisedTrainingElement(new double[] { 0,
|
|
||||||
1 }, new double[] { 1 }));
|
|
||||||
trainingSet.addElement(new SupervisedTrainingElement(new double[] { 1,
|
|
||||||
0 }, new double[] { 1 }));
|
|
||||||
trainingSet.addElement(new SupervisedTrainingElement(new double[] { 1,
|
|
||||||
1 }, new double[] { 0 }));
|
|
||||||
}
|
|
||||||
|
|
||||||
nnLearner.learn(trainingSet);
|
|
||||||
NeuralNetwork nnet = nnLearner.getNeuralNetwork();
|
|
||||||
|
|
||||||
TrainingSet<SupervisedTrainingElement> valSet = new TrainingSet<SupervisedTrainingElement>(
|
|
||||||
2, 1);
|
|
||||||
valSet.addElement(new SupervisedTrainingElement(new double[] { 0,
|
|
||||||
0 }, new double[] { 0 }));
|
|
||||||
valSet.addElement(new SupervisedTrainingElement(new double[] { 0,
|
|
||||||
1 }, new double[] { 1 }));
|
|
||||||
valSet.addElement(new SupervisedTrainingElement(new double[] { 1,
|
|
||||||
0 }, new double[] { 1 }));
|
|
||||||
valSet.addElement(new SupervisedTrainingElement(new double[] { 1,
|
|
||||||
1 }, new double[] { 0 }));
|
|
||||||
|
|
||||||
System.out.println("Output from eval set (learned network):");
|
|
||||||
testNetwork(nnet, valSet);
|
|
||||||
|
|
||||||
nnet.save(FILENAME);
|
|
||||||
nnet = NeuralNetwork.load(FILENAME);
|
|
||||||
|
|
||||||
System.out.println("Output from eval set (learned network):");
|
|
||||||
testNetwork(nnet, valSet);
|
|
||||||
}
|
|
||||||
|
|
||||||
private void testNetwork(NeuralNetwork nnet, TrainingSet<SupervisedTrainingElement> trainingSet) {
|
|
||||||
for (SupervisedTrainingElement trainingElement : trainingSet.elements()) {
|
|
||||||
|
|
||||||
nnet.setInput(trainingElement.getInput());
|
|
||||||
nnet.calculate();
|
|
||||||
double[] networkOutput = nnet.getOutput();
|
|
||||||
System.out.print("Input: "
|
|
||||||
+ Arrays.toString(trainingElement.getInput()));
|
|
||||||
System.out.println(" Output: " + Arrays.toString(networkOutput));
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,51 +0,0 @@
|
|||||||
package net.woodyfolsom.msproj.ann;
|
|
||||||
|
|
||||||
import static org.junit.Assert.assertTrue;
|
|
||||||
|
|
||||||
import org.junit.Test;
|
|
||||||
import org.neuroph.core.NeuralNetwork;
|
|
||||||
|
|
||||||
public class PassNetworkTest {
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testSavedNetwork1() {
|
|
||||||
NeuralNetwork passFilter = NeuralNetwork.load("data/networks/Pass1.nn");
|
|
||||||
passFilter.setInput(0.75,0.25);
|
|
||||||
passFilter.calculate();
|
|
||||||
|
|
||||||
PassData passData = new PassData();
|
|
||||||
double[] output = passFilter.getOutput();
|
|
||||||
System.out.println("Output: " + passData.getOutput(output));
|
|
||||||
|
|
||||||
assertTrue(output[0] > 0.50);
|
|
||||||
assertTrue(output[1] < 0.50);
|
|
||||||
|
|
||||||
passFilter.setInput(0.25,0.50);
|
|
||||||
passFilter.calculate();
|
|
||||||
output = passFilter.getOutput();
|
|
||||||
System.out.println("Output: " + passData.getOutput(output));
|
|
||||||
assertTrue(output[0] < 0.50);
|
|
||||||
assertTrue(output[1] > 0.50);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testSavedNetwork2() {
|
|
||||||
NeuralNetwork passFilter = NeuralNetwork.load("data/networks/Pass2.nn");
|
|
||||||
passFilter.setInput(0.75,0.25);
|
|
||||||
passFilter.calculate();
|
|
||||||
|
|
||||||
PassData passData = new PassData();
|
|
||||||
double[] output = passFilter.getOutput();
|
|
||||||
System.out.println("Output: " + passData.getOutput(output));
|
|
||||||
|
|
||||||
assertTrue(output[0] > 0.50);
|
|
||||||
assertTrue(output[1] < 0.50);
|
|
||||||
|
|
||||||
passFilter.setInput(0.45,0.55);
|
|
||||||
passFilter.calculate();
|
|
||||||
output = passFilter.getOutput();
|
|
||||||
System.out.println("Output: " + passData.getOutput(output));
|
|
||||||
assertTrue(output[0] < 0.50);
|
|
||||||
assertTrue(output[1] > 0.50);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
66
test/net/woodyfolsom/msproj/ann/WinFilterTest.java
Normal file
66
test/net/woodyfolsom/msproj/ann/WinFilterTest.java
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.FileFilter;
|
||||||
|
import java.io.FileInputStream;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
import net.woodyfolsom.msproj.GameRecord;
|
||||||
|
import net.woodyfolsom.msproj.Referee;
|
||||||
|
|
||||||
|
import org.antlr.runtime.RecognitionException;
|
||||||
|
import org.encog.ml.data.MLData;
|
||||||
|
import org.encog.ml.data.MLDataPair;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
public class WinFilterTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testLearnSaveLoad() throws IOException, RecognitionException {
|
||||||
|
File[] sgfFiles = new File("data/games/random_vs_random")
|
||||||
|
.listFiles(new FileFilter() {
|
||||||
|
@Override
|
||||||
|
public boolean accept(File pathname) {
|
||||||
|
return pathname.getName().endsWith(".sgf");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Set<List<MLDataPair>> trainingData = new HashSet<List<MLDataPair>>();
|
||||||
|
|
||||||
|
for (File file : sgfFiles) {
|
||||||
|
FileInputStream fis = new FileInputStream(file);
|
||||||
|
GameRecord gameRecord = Referee.replay(fis);
|
||||||
|
|
||||||
|
List<MLDataPair> gameData = new ArrayList<MLDataPair>();
|
||||||
|
for (int i = 0; i <= gameRecord.getNumTurns(); i++) {
|
||||||
|
gameData.add(new GameStateMLDataPair(gameRecord.getGameState(i)));
|
||||||
|
}
|
||||||
|
|
||||||
|
trainingData.add(gameData);
|
||||||
|
|
||||||
|
fis.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
WinFilter winFilter = new WinFilter();
|
||||||
|
|
||||||
|
winFilter.learn(trainingData);
|
||||||
|
|
||||||
|
for (List<MLDataPair> trainingSequence : trainingData) {
|
||||||
|
//for (MLDataPair mlDataPair : trainingSequence) {
|
||||||
|
for (int stateIndex = 0; stateIndex < trainingSequence.size(); stateIndex++) {
|
||||||
|
if (stateIndex > 0 && stateIndex < trainingSequence.size()-1) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
MLData input = trainingSequence.get(stateIndex).getInput();
|
||||||
|
|
||||||
|
System.out.println("Turn " + stateIndex + ": " + input + " => "
|
||||||
|
+ winFilter.computeValue(input));
|
||||||
|
}
|
||||||
|
//}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
79
test/net/woodyfolsom/msproj/ann/XORFilterTest.java
Normal file
79
test/net/woodyfolsom/msproj/ann/XORFilterTest.java
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.IOException;
|
||||||
|
|
||||||
|
import org.encog.ml.data.MLDataSet;
|
||||||
|
import org.encog.ml.data.basic.BasicMLDataSet;
|
||||||
|
import org.junit.AfterClass;
|
||||||
|
import org.junit.BeforeClass;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
public class XORFilterTest {
|
||||||
|
private static final String FILENAME = "xorPerceptron.net";
|
||||||
|
|
||||||
|
@AfterClass
|
||||||
|
public static void deleteNewNet() {
|
||||||
|
File file = new File(FILENAME);
|
||||||
|
if (file.exists()) {
|
||||||
|
file.delete();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@BeforeClass
|
||||||
|
public static void deleteSavedNet() {
|
||||||
|
File file = new File(FILENAME);
|
||||||
|
if (file.exists()) {
|
||||||
|
file.delete();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testLearnSaveLoad() throws IOException {
|
||||||
|
NeuralNetFilter nnLearner = new XORFilter();
|
||||||
|
System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");
|
||||||
|
|
||||||
|
// create training set (logical XOR function)
|
||||||
|
int size = 1;
|
||||||
|
double[][] trainingInput = new double[4 * size][];
|
||||||
|
double[][] trainingOutput = new double[4 * size][];
|
||||||
|
for (int i = 0; i < size; i++) {
|
||||||
|
trainingInput[i * 4 + 0] = new double[] { 0, 0 };
|
||||||
|
trainingInput[i * 4 + 1] = new double[] { 0, 1 };
|
||||||
|
trainingInput[i * 4 + 2] = new double[] { 1, 0 };
|
||||||
|
trainingInput[i * 4 + 3] = new double[] { 1, 1 };
|
||||||
|
trainingOutput[i * 4 + 0] = new double[] { 0 };
|
||||||
|
trainingOutput[i * 4 + 1] = new double[] { 1 };
|
||||||
|
trainingOutput[i * 4 + 2] = new double[] { 1 };
|
||||||
|
trainingOutput[i * 4 + 3] = new double[] { 0 };
|
||||||
|
}
|
||||||
|
|
||||||
|
// create training data
|
||||||
|
MLDataSet trainingSet = new BasicMLDataSet(trainingInput, trainingOutput);
|
||||||
|
|
||||||
|
nnLearner.learn(trainingSet);
|
||||||
|
|
||||||
|
double[][] validationSet = new double[4][2];
|
||||||
|
|
||||||
|
validationSet[0] = new double[] { 0, 0 };
|
||||||
|
validationSet[1] = new double[] { 0, 1 };
|
||||||
|
validationSet[2] = new double[] { 1, 0 };
|
||||||
|
validationSet[3] = new double[] { 1, 1 };
|
||||||
|
|
||||||
|
System.out.println("Output from eval set (learned network, pre-serialization):");
|
||||||
|
testNetwork(nnLearner, validationSet);
|
||||||
|
|
||||||
|
nnLearner.save(FILENAME);
|
||||||
|
nnLearner.load(FILENAME);
|
||||||
|
|
||||||
|
System.out.println("Output from eval set (learned network, post-serialization):");
|
||||||
|
testNetwork(nnLearner, validationSet);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void testNetwork(NeuralNetFilter nnLearner, double[][] validationSet) {
|
||||||
|
for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
|
||||||
|
DoublePair dp = new DoublePair(validationSet[valIndex][0],validationSet[valIndex][1]);
|
||||||
|
System.out.println(dp + " => " + nnLearner.computeValue(dp));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user