Fixed AMAF, SMAF algorithms.
@@ -1,4 +1,4 @@
-PlayerOne=ROOT_PAR
+PlayerOne=SMAF
 PlayerTwo=RANDOM
 GUIDelay=1000 //1 second
 BoardSize=9
@@ -7,4 +7,4 @@ NumGames=1 //Games for each color per player
 TurnTime=2000 //seconds per player per turn
 SpectatorBoardShown=true
 WhiteMoveLogged=false
-BlackMoveLogged=false
+BlackMoveLogged=true
@@ -91,7 +91,6 @@ public class Referee {
 while (!gameRecord.isFinished()) {
 GameState gameState = gameRecord.getGameState(gameRecord
 .getNumTurns());
-// System.out.println(gameState);
 
 Player playerToMove = gameRecord.getPlayerToMove();
 Policy policy = getPolicy(playerToMove);
@@ -108,6 +107,11 @@ public class Referee {
 } else {
 System.out.println("Move rejected - try again.");
 }
+
+if (policy.isLogging()) {
+System.out.println(gameState);
+}
+
 }
 } catch (Exception ex) {
 System.out
@@ -13,6 +13,7 @@ import net.woodyfolsom.msproj.gui.Goban;
 import net.woodyfolsom.msproj.policy.HumanGuiInput;
 import net.woodyfolsom.msproj.policy.HumanKeyboardInput;
 import net.woodyfolsom.msproj.policy.MonteCarloAMAF;
+import net.woodyfolsom.msproj.policy.MonteCarloSMAF;
 import net.woodyfolsom.msproj.policy.MonteCarloUCT;
 import net.woodyfolsom.msproj.policy.Policy;
 import net.woodyfolsom.msproj.policy.RandomMovePolicy;
@@ -26,7 +27,7 @@ public class StandAloneGame {
 private int gameNo = 0;
 
 enum PLAYER_TYPE {
-HUMAN, HUMAN_GUI, ROOT_PAR, UCT, RANDOM, RAVE
+HUMAN, HUMAN_GUI, ROOT_PAR, UCT, RANDOM, RAVE, SMAF
 };
 
 public static void main(String[] args) throws IOException {
@@ -41,7 +42,8 @@ public class StandAloneGame {
 gameSettings.getBoardSize(), gameSettings.getKomi(),
 gameSettings.getNumGames(), gameSettings.getTurnTime(),
 gameSettings.isSpectatorBoardShown(),
-gameSettings.isBlackMoveLogged(), gameSettings.isWhiteMoveLogged());
+gameSettings.isBlackMoveLogged(),
+gameSettings.isWhiteMoveLogged());
 System.out.println("Press <Enter> or CTRL-C to exit");
 System.in.read(new byte[80]);
 } catch (IOException ioe) {
@@ -64,14 +66,17 @@ public class StandAloneGame {
 return PLAYER_TYPE.RANDOM;
 } else if ("RAVE".equalsIgnoreCase(playerTypeStr)) {
 return PLAYER_TYPE.RAVE;
+} else if ("SMAF".equalsIgnoreCase(playerTypeStr)) {
+return PLAYER_TYPE.SMAF;
 } else {
 throw new RuntimeException("Unknown player type: " + playerTypeStr);
 }
 }
 
 public void playGame(PLAYER_TYPE playerType1, PLAYER_TYPE playerType2,
-int size, double komi, int rounds, long turnLength, boolean showSpectatorBoard,
-boolean blackMoveLogged, boolean whiteMoveLogged) {
+int size, double komi, int rounds, long turnLength,
+boolean showSpectatorBoard, boolean blackMoveLogged,
+boolean whiteMoveLogged) {
 
 long startTime = System.currentTimeMillis();
 
@@ -79,28 +84,38 @@ public class StandAloneGame {
 gameConfig.setKomi(komi);
 
 Referee referee = new Referee();
-referee.setPolicy(Player.BLACK,
-getPolicy(playerType1, gameConfig, Player.BLACK, turnLength, blackMoveLogged));
-referee.setPolicy(Player.WHITE,
-getPolicy(playerType2, gameConfig, Player.WHITE, turnLength, whiteMoveLogged));
+referee.setPolicy(
+Player.BLACK,
+getPolicy(playerType1, gameConfig, Player.BLACK, turnLength,
+blackMoveLogged));
+referee.setPolicy(
+Player.WHITE,
+getPolicy(playerType2, gameConfig, Player.WHITE, turnLength,
+whiteMoveLogged));
 
 List<GameResult> round1results = new ArrayList<GameResult>();
 
 boolean logGameRecords = rounds <= 50;
 for (int round = 0; round < rounds; round++) {
 gameNo++;
-round1results.add(referee.play(gameConfig, gameNo, showSpectatorBoard, logGameRecords));
+round1results.add(referee.play(gameConfig, gameNo,
+showSpectatorBoard, logGameRecords));
 }
 
 List<GameResult> round2results = new ArrayList<GameResult>();
 
-referee.setPolicy(Player.BLACK,
-getPolicy(playerType2, gameConfig, Player.BLACK, turnLength, blackMoveLogged));
-referee.setPolicy(Player.WHITE,
-getPolicy(playerType1, gameConfig, Player.WHITE, turnLength, whiteMoveLogged));
+referee.setPolicy(
+Player.BLACK,
+getPolicy(playerType2, gameConfig, Player.BLACK, turnLength,
+blackMoveLogged));
+referee.setPolicy(
+Player.WHITE,
+getPolicy(playerType1, gameConfig, Player.WHITE, turnLength,
+whiteMoveLogged));
 for (int round = 0; round < rounds; round++) {
 gameNo++;
-round2results.add(referee.play(gameConfig, gameNo, showSpectatorBoard, logGameRecords));
+round2results.add(referee.play(gameConfig, gameNo,
+showSpectatorBoard, logGameRecords));
 }
 
 long endTime = System.currentTimeMillis();
@@ -113,13 +128,14 @@ public class StandAloneGame {
 
 try {
 if (!logGameRecords) {
-System.out.println("Each player is set to play more than 50 rounds as each color; omitting individual game .sgf log file output.");
+System.out
+.println("Each player is set to play more than 50 rounds as each color; omitting individual game .sgf log file output.");
 }
 
 logResults(writer, round1results, playerType1.toString(),
 playerType2.toString());
 logResults(writer, round2results, playerType2.toString(),
 playerType1.toString());
 writer.write("Elapsed Time: " + (endTime - startTime) / 1000.0
 + " seconds.");
 System.out.println("Game tournament saved as "
@@ -157,25 +173,38 @@ public class StandAloneGame {
 
 private Policy getPolicy(PLAYER_TYPE playerType, GameConfig gameConfig,
 Player player, long turnLength, boolean moveLogged) {
+
+Policy policy;
+
 switch (playerType) {
 case HUMAN:
-return new HumanKeyboardInput();
+policy = new HumanKeyboardInput();
+break;
 case HUMAN_GUI:
-return new HumanGuiInput(new Goban(gameConfig, player,""));
+policy = new HumanGuiInput(new Goban(gameConfig, player, ""));
+break;
 case ROOT_PAR:
-return new RootParallelization(4, turnLength);
+policy = new RootParallelization(4, turnLength);
+break;
 case UCT:
-return new MonteCarloUCT(new RandomMovePolicy(), turnLength);
+policy = new MonteCarloUCT(new RandomMovePolicy(), turnLength);
+break;
+case SMAF:
+policy = new MonteCarloSMAF(new RandomMovePolicy(), turnLength, 0);
+break;
 case RANDOM:
-RandomMovePolicy randomMovePolicy = new RandomMovePolicy();
-randomMovePolicy.setLogging(moveLogged);
-return randomMovePolicy;
+policy = new RandomMovePolicy();
+break;
 case RAVE:
-return new MonteCarloAMAF(new RandomMovePolicy(), turnLength);
+policy = new MonteCarloAMAF(new RandomMovePolicy(), turnLength);
+break;
 default:
 throw new IllegalArgumentException("Invalid PLAYER_TYPE: "
 + playerType);
 }
+
+policy.setLogging(moveLogged);
+return policy;
 }
 
 }
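
Taken together, the StandAloneGame changes above wire the new player type end to end: the settings file selects it (PlayerOne=SMAF), the string is mapped to PLAYER_TYPE.SMAF, and getPolicy constructs the policy and now applies the logging flag in one shared place instead of per case. A minimal illustrative sketch of what the SMAF branch of getPolicy yields; the literal values are placeholders taken from the settings file shown earlier, not code from the commit:

    // Illustrative only: the SMAF case of getPolicy, spelled out with example values.
    long turnLength = 2000L;       // TurnTime from the settings file
    boolean moveLogged = true;     // BlackMoveLogged / WhiteMoveLogged
    Policy policy = new MonteCarloSMAF(new RandomMovePolicy(), turnLength, 0);
    policy.setLogging(moveLogged); // the shared tail that every case now reaches
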
@@ -1,6 +1,9 @@
 package net.woodyfolsom.msproj.ann;
 
-import java.io.FileNotFoundException;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
 
@@ -16,66 +19,99 @@ import net.woodyfolsom.msproj.tictactoe.State;
 public class TTTFilterTrainer { // implements epsilon-greedy trainer? online
 // version of NeuralNetFilter
 
-public static void main(String[] args) throws FileNotFoundException {
-double alpha = 0.15;
-double lambda = .95;
-int maxGames = 1000;
+private boolean training = true;
+
+public static void main(String[] args) throws IOException {
+double alpha = 0.50;
+double lambda = 0.90;
+int maxGames = 100000;
 
 new TTTFilterTrainer().trainNetwork(alpha, lambda, maxGames);
 }
 
 public void trainNetwork(double alpha, double lambda, int maxGames)
-throws FileNotFoundException {
+throws IOException {
 
-FeedforwardNetwork neuralNetwork = new MultiLayerPerceptron(true, 9, 6,
-1);
-neuralNetwork.setName("TicTacToe");
-neuralNetwork.initWeights();
-TrainingMethod trainer = new TemporalDifference(alpha, lambda);
+FeedforwardNetwork neuralNetwork;
+if (training) {
+neuralNetwork = new MultiLayerPerceptron(true, 9, 9, 1);
+neuralNetwork.setName("TicTacToe");
+neuralNetwork.initWeights();
+TrainingMethod trainer = new TemporalDifference(alpha, lambda);
 
 System.out.println("Playing untrained games.");
 for (int i = 0; i < 10; i++) {
 System.out.println("" + (i + 1) + ". "
 + playOptimal(neuralNetwork).getResult());
 }
 
-System.out.println("Learning from " + maxGames
-+ " games of random self-play");
-
-int gamesPlayed = 0;
-List<RESULT> results = new ArrayList<RESULT>();
-do {
-GameRecord gameRecord = playEpsilonGreedy(0.90, neuralNetwork,
-trainer);
-System.out.println("Winner: " + gameRecord.getResult());
-gamesPlayed++;
-results.add(gameRecord.getResult());
-} while (gamesPlayed < maxGames);
-
-System.out.println("Learned network after " + maxGames
-+ " training games.");
-
-for (int i = 0; i < results.size(); i++) {
-if (i % 10 == 0) {
-System.out.println("" + (i + 1) + ". " + results.get(i));
-}
+System.out.println("Learning from " + maxGames
++ " games of random self-play");
+
+int gamesPlayed = 0;
+List<RESULT> results = new ArrayList<RESULT>();
+do {
+GameRecord gameRecord = playEpsilonGreedy(0.50, neuralNetwork,
+trainer);
+System.out.println("Winner: " + gameRecord.getResult());
+gamesPlayed++;
+results.add(gameRecord.getResult());
+} while (gamesPlayed < maxGames);
+
+System.out.println("Results of every 10th training game:");
+
+for (int i = 0; i < results.size(); i++) {
+if (i % 10 == 0) {
+System.out.println("" + (i + 1) + ". " + results.get(i));
+}
+}
+
+System.out.println("Learned network after " + maxGames
++ " training games.");
+} else {
+System.out.println("Loading TicTacToe network from file.");
+neuralNetwork = new MultiLayerPerceptron();
+FileInputStream fis = new FileInputStream(new File("ttt.net"));
+if (!new MultiLayerPerceptron().load(fis)) {
+System.out.println("Error loading ttt.net from file.");
+return;
+}
+fis.close();
 }
 
 evalTestCases(neuralNetwork);
 
 System.out.println("Playing optimal games.");
+List<RESULT> gameResults = new ArrayList<RESULT>();
 for (int i = 0; i < 10; i++) {
-System.out.println("" + (i + 1) + ". "
-+ playOptimal(neuralNetwork).getResult());
+gameResults.add(playOptimal(neuralNetwork).getResult());
 }
 
-/*
-* File output = new File("ttt.net");
-*
-* FileOutputStream fos = new FileOutputStream(output);
-*
-* neuralNetwork.save(fos);
-*/
+boolean suboptimalPlay = false;
+System.out.println("Optimal game summary: ");
+for (int i = 0; i < gameResults.size(); i++) {
+RESULT result = gameResults.get(i);
+System.out.println("" + (i + 1) + ". " + result);
+if (result != RESULT.X_WINS) {
+suboptimalPlay = true;
+}
+}
+
+File output = new File("ttt.net");
+
+FileOutputStream fos = new FileOutputStream(output);
+
+neuralNetwork.save(fos);
+
+System.out.println("Playing optimal vs random games.");
+for (int i = 0; i < 10; i++) {
+System.out.println("" + (i + 1) + ". "
++ playOptimalVsRandom(neuralNetwork).getResult());
+}
+
+if (suboptimalPlay) {
+System.out.println("Suboptimal play detected!");
+}
 }
 
 private void evalTestCases(FeedforwardNetwork neuralNetwork) {
@@ -115,6 +151,32 @@ public class TTTFilterTrainer { // implements epsilon-greedy trainer? online
 testNetwork(neuralNetwork, validationSet, inputNames, outputNames);
 }
 
+private GameRecord playOptimalVsRandom(FeedforwardNetwork neuralNetwork) {
+GameRecord gameRecord = new GameRecord();
+
+Policy neuralNetPolicy = new NeuralNetPolicy(neuralNetwork);
+Policy randomPolicy = new RandomPolicy();
+
+State state = gameRecord.getState();
+
+Policy[] policies = new Policy[] { neuralNetPolicy, randomPolicy };
+int turnNo = 0;
+do {
+Action action;
+State nextState;
+
+action = policies[turnNo % 2].getAction(gameRecord.getState());
+
+nextState = gameRecord.apply(action);
+System.out.println("Action " + action + " selected by policy "
++ policies[turnNo % 2].getName());
+System.out.println("Next board state: " + nextState);
+state = nextState;
+turnNo++;
+} while (!state.isTerminal());
+return gameRecord;
+}
+
 private GameRecord playOptimal(FeedforwardNetwork neuralNetwork) {
 GameRecord gameRecord = new GameRecord();
 
@@ -122,8 +184,6 @@ public class TTTFilterTrainer { // implements epsilon-greedy trainer? online
 
 State state = gameRecord.getState();
 
-System.out.println("Playing optimal game:");
-
 do {
 Action action;
 State nextState;
@@ -131,14 +191,12 @@ public class TTTFilterTrainer { // implements epsilon-greedy trainer? online
 action = neuralNetPolicy.getAction(gameRecord.getState());
 
 nextState = gameRecord.apply(action);
-System.out.println("Action " + action + " selected by policy " +
-neuralNetPolicy.getName());
+System.out.println("Action " + action + " selected by policy "
++ neuralNetPolicy.getName());
 System.out.println("Next board state: " + nextState);
 state = nextState;
 } while (!state.isTerminal());
 
-// finally, reinforce the actual reward
-
 return gameRecord;
 }
 
@@ -196,7 +254,7 @@ public class TTTFilterTrainer { // implements epsilon-greedy trainer? online
 for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
 NNDataPair dp = new NNDataPair(new NNData(inputNames,
 validationSet[valIndex]), new NNData(outputNames,
-new double[] {0.0}));
+new double[] { 0.0 }));
 System.out.println(dp + " => " + neuralNetwork.compute(dp));
 }
 }
@@ -4,7 +4,7 @@ import java.util.List;
 
 public class TemporalDifference extends TrainingMethod {
 private final double alpha;
-private final double gamma = 1.0;
+// private final double gamma = 1.0;
 private final double lambda;
 
 public TemporalDifference(double alpha, double lambda) {
@@ -81,23 +81,27 @@ public class TemporalDifference extends TrainingMethod {
 }
 }
 
-private void updateWeights(FeedforwardNetwork neuralNetwork, double predictionError) {
+private void updateWeights(FeedforwardNetwork neuralNetwork,
+double predictionError) {
 for (Connection connection : neuralNetwork.getConnections()) {
-/*Neuron srcNeuron = neuralNetwork.getNeuron(connection.getSrc());
-Neuron destNeuron = neuralNetwork.getNeuron(connection.getDest());
-
-double delta = alpha * srcNeuron.getOutput()
-* destNeuron.getGradient() * predictionError + connection.getTrace() * lambda;
-
-// TODO allow for momentum
-// double lastDelta = connection.getLastDelta();
-connection.addDelta(delta);*/
+/*
+* Neuron srcNeuron = neuralNetwork.getNeuron(connection.getSrc());
+* Neuron destNeuron =
+* neuralNetwork.getNeuron(connection.getDest());
+*
+* double delta = alpha * srcNeuron.getOutput()
+* destNeuron.getGradient() * predictionError +
+* connection.getTrace() * lambda;
+*
+* // TODO allow for momentum // double lastDelta =
+* connection.getLastDelta(); connection.addDelta(delta);
+*/
 Neuron srcNeuron = neuralNetwork.getNeuron(connection.getSrc());
 Neuron destNeuron = neuralNetwork.getNeuron(connection.getDest());
 double delta = alpha * srcNeuron.getOutput()
 * destNeuron.getGradient() + connection.getTrace() * lambda;
-//TODO allow for momentum
-//double lastDelta = connection.getLastDelta();
+// TODO allow for momentum
+// double lastDelta = connection.getLastDelta();
 connection.addDelta(delta);
 }
 }
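
Read as a formula, the per-connection update that the active (uncommented) code in updateWeights applies is, in the diff's own names, simply:

    delta = alpha * srcNeuron.getOutput() * destNeuron.getGradient()
            + connection.getTrace() * lambda

This is only a restatement of the lines above as a reading aid, not a change to them.
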
@@ -117,23 +121,23 @@ public class TemporalDifference extends TrainingMethod {
 @Override
 protected void iteratePattern(FeedforwardNetwork neuralNetwork,
 NNDataPair statePair, NNData nextReward) {
-//System.out.println("Learningrate: " + alpha);
+// System.out.println("Learningrate: " + alpha);
 
 zeroGradients(neuralNetwork);
 
-//System.out.println("Training with: " + statePair.getInput());
+// System.out.println("Training with: " + statePair.getInput());
 
 NNData ideal = nextReward;
 NNData actual = neuralNetwork.compute(statePair);
 
-//System.out.println("Updating weights. Ideal Output: " + ideal);
-//System.out.println("Actual Output: " + actual);
+// System.out.println("Updating weights. Ideal Output: " + ideal);
+// System.out.println("Actual Output: " + actual);
 
 // backpropagate the gradients w.r.t. output error
 backPropagate(neuralNetwork, ideal);
 
 double predictionError = statePair.getIdeal().getValues()[0] // reward_t
 + actual.getValues()[0] - nextReward.getValues()[0];
 
 updateWeights(neuralNetwork, predictionError);
 }
@@ -16,6 +16,15 @@ public class AlphaBeta implements Policy {
 
 private final ValidMoveGenerator validMoveGenerator = new ValidMoveGenerator();
 
+private boolean logging = false;
+public boolean isLogging() {
+return logging;
+}
+
+public void setLogging(boolean logging) {
+this.logging = logging;
+}
+
 private int lookAhead;
 private int numStateEvaluations = 0;
 
@@ -9,6 +9,15 @@ import net.woodyfolsom.msproj.Player;
 import net.woodyfolsom.msproj.gui.Goban;
 
 public class HumanGuiInput implements Policy {
+private boolean logging;
+public boolean isLogging() {
+return logging;
+}
+
+public void setLogging(boolean logging) {
+this.logging = logging;
+}
+
 private Goban goban;
 
 public HumanGuiInput(Goban goban) {
@@ -9,6 +9,15 @@ import net.woodyfolsom.msproj.GameState;
 import net.woodyfolsom.msproj.Player;
 
 public class HumanKeyboardInput implements Policy {
+private boolean logging = false;
+
+public boolean isLogging() {
+return logging;
+}
+
+public void setLogging(boolean logging) {
+this.logging = logging;
+}
 
 @Override
 public Action getAction(GameConfig gameConfig, GameState gameState,
@@ -16,6 +16,15 @@ public class Minimax implements Policy {
 
 private final ValidMoveGenerator validMoveGenerator = new ValidMoveGenerator();
 
+private boolean logging = false;
+public boolean isLogging() {
+return logging;
+}
+
+public void setLogging(boolean logging) {
+this.logging = logging;
+}
+
 private int lookAhead;
 private int numStateEvaluations = 0;
 
@@ -15,6 +15,15 @@ import net.woodyfolsom.msproj.tree.MonteCarloProperties;
 public abstract class MonteCarlo implements Policy {
 protected static final int ROLLOUT_DEPTH_LIMIT = 250;
 
+private boolean logging = false;
+public boolean isLogging() {
+return logging;
+}
+
+public void setLogging(boolean logging) {
+this.logging = logging;
+}
+
 protected int numStateEvaluations = 0;
 protected Policy movePolicy;
 
@@ -63,6 +63,43 @@ public class MonteCarloAMAF extends MonteCarloUCT {
 rootGameState, new AMAFProperties());
 }
 
+@Override
+public Action getBestAction(GameTreeNode<MonteCarloProperties> node) {
+Action bestAction = Action.NONE;
+double bestScore = Double.NEGATIVE_INFINITY;
+GameTreeNode<MonteCarloProperties> bestChild = null;
+
+for (Action action : node.getActions()) {
+GameTreeNode<MonteCarloProperties> childNode = node
+.getChild(action);
+
+AMAFProperties childProps = (AMAFProperties)childNode.getProperties();
+double childScore = childProps.getAmafWins() / (double)childProps.getAmafVisits();
+
+if (childScore >= bestScore) {
+bestScore = childScore;
+bestAction = action;
+bestChild = childNode;
+}
+}
+
+if (bestAction == Action.NONE) {
+System.out
+.println("MonteCarloUCT failed - no actions were found for the current game state (not even PASS).");
+} else {
+System.out.println("Action " + bestAction + " selected for "
++ node.getGameState().getPlayerToMove()
++ " with simulated win ratio of "
++ (bestScore * 100.0 + "%"));
+System.out.println("It was visited "
++ bestChild.getProperties().getVisits() + " times out of "
++ node.getProperties().getVisits() + " rollouts among "
++ node.getNumChildren()
++ " valid actions from the current state.");
+}
+return bestAction;
+}
+
 @Override
 protected double getNodeScore(GameTreeNode<MonteCarloProperties> gameTreeNode) {
 //double nodeVisits = gameTreeNode.getParent().getProperties().getVisits();
@@ -72,16 +109,8 @@ public class MonteCarloAMAF extends MonteCarloUCT {
 if (gameTreeNode.getGameState().isTerminal()) {
 nodeScore = 0.0;
 } else {
-/*
-MonteCarloProperties properties = gameTreeNode.getProperties();
-nodeScore = (double) (properties.getWins() / properties
-.getVisits())
-+ (TUNING_CONSTANT * Math.sqrt(Math.log(nodeVisits)
-/ gameTreeNode.getProperties().getVisits()));
-*
-*/
 AMAFProperties properties = (AMAFProperties) gameTreeNode.getProperties();
-nodeScore = (double) (properties.getAmafWins() / properties
+nodeScore = (properties.getAmafWins() / (double) properties
 .getAmafVisits())
 + (TUNING_CONSTANT * Math.sqrt(Math.log(parentAmafVisits)
 / properties.getAmafVisits()));
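
The cast change in getNodeScore above appears to be the heart of the AMAF fix: assuming getAmafWins() and getAmafVisits() return integer counters (as the setAmafWins/setAmafVisits calls elsewhere in this commit suggest), the old expression divided two ints before casting. A small illustrative sketch with made-up counts, not code from the commit:

    int amafWins = 3, amafVisits = 8;
    double before = (double) (amafWins / amafVisits); // 0.0   (integer division, then cast)
    double after = amafWins / (double) amafVisits;    // 0.375 (cast first, real division)
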
src/net/woodyfolsom/msproj/policy/MonteCarloSMAF.java (new file, 59 lines)
@@ -0,0 +1,59 @@
+package net.woodyfolsom.msproj.policy;
+
+import java.util.List;
+
+import net.woodyfolsom.msproj.Action;
+import net.woodyfolsom.msproj.Player;
+import net.woodyfolsom.msproj.tree.AMAFProperties;
+import net.woodyfolsom.msproj.tree.GameTreeNode;
+import net.woodyfolsom.msproj.tree.MonteCarloProperties;
+
+public class MonteCarloSMAF extends MonteCarloAMAF {
+private int horizon;
+
+public MonteCarloSMAF(Policy movePolicy, long searchTimeLimit, int horizon) {
+super(movePolicy, searchTimeLimit);
+this.horizon = horizon;
+}
+
+@Override
+public void update(GameTreeNode<MonteCarloProperties> node, Rollout rollout) {
+GameTreeNode<MonteCarloProperties> currentNode = node;
+//List<Action> subTreeActions = new ArrayList<Action>(rollout.getPlayout());
+
+List<Action> playout = rollout.getPlayout();
+int reward = rollout.getReward();
+while (currentNode != null) {
+AMAFProperties nodeProperties = (AMAFProperties)currentNode.getProperties();
+
+//Always update props for the current node
+nodeProperties.setWins(nodeProperties.getWins() + reward);
+nodeProperties.setVisits(nodeProperties.getVisits() + 1);
+nodeProperties.setAmafWins(nodeProperties.getAmafWins() + reward);
+nodeProperties.setAmafVisits(nodeProperties.getAmafVisits() + 1);
+
+GameTreeNode<MonteCarloProperties> parentNode = currentNode.getParent();
+if (parentNode != null) {
+Player playerToMove = parentNode.getGameState().getPlayerToMove();
+for (Action actionFromParent : parentNode.getActions()) {
+if (playout.subList(0, Math.max(horizon,playout.size())).contains(actionFromParent)) {
+GameTreeNode<MonteCarloProperties> subTreeChild = parentNode.getChild(actionFromParent);
+//Don't count AMAF properties for the current node twice
+if (subTreeChild == currentNode) {
+continue;
+}
+
+AMAFProperties siblingProperties = (AMAFProperties)subTreeChild.getProperties();
+//Only update AMAF properties if the sibling is reached by the same action with the same player to move
+if (rollout.hasPlay(playerToMove,actionFromParent)) {
+siblingProperties.setAmafWins(siblingProperties.getAmafWins() + reward);
+siblingProperties.setAmafVisits(siblingProperties.getAmafVisits() + 1);
+}
+}
+}
+
+currentNode = currentNode.getParent();
+}
+}
+
+}
@@ -90,11 +90,8 @@ public class MonteCarloUCT extends MonteCarlo {
 GameTreeNode<MonteCarloProperties> childNode = node
 .getChild(action);
 
-//MonteCarloProperties properties = childNode.getProperties();
-//double childScore = (double) properties.getWins()
-// / properties.getVisits();
-
-double childScore = getNodeScore(childNode);
+MonteCarloProperties childProps = childNode.getProperties();
+double childScore = childProps.getWins() / (double)childProps.getVisits();
 
 if (childScore >= bestScore) {
 bestScore = childScore;
@@ -17,4 +17,8 @@ public interface Policy {
 public int getNumStateEvaluations();
 
 public void setState(GameState gameState);
+
+boolean isLogging();
+
+void setLogging(boolean logging);
 }
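
Each concrete policy touched by this commit satisfies the two new interface methods in essentially the same way; the pattern, as it appears in AlphaBeta, Minimax, MonteCarlo, HumanGuiInput and HumanKeyboardInput above and in RootParallelization below, is simply:

    private boolean logging = false;

    public boolean isLogging() {
        return logging;
    }

    public void setLogging(boolean logging) {
        this.logging = logging;
    }
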
@@ -110,6 +110,7 @@ public class RandomMovePolicy implements Policy, ActionGenerator {
 return randomAction;
 }
 
+@Override
 public boolean isLogging() {
 return logging;
 }
@@ -13,7 +13,16 @@ import net.woodyfolsom.msproj.Player;
 import net.woodyfolsom.msproj.tree.MonteCarloProperties;
 
 public class RootParallelization implements Policy {
+private boolean logging = false;
 private int numTrees = 1;
+public boolean isLogging() {
+return logging;
+}
+
+public void setLogging(boolean logging) {
+this.logging = logging;
+}
+
 private long timeLimit = 1000L;
 
 public RootParallelization(int numTrees, long timeLimit) {