First working implementation of an ANN that is trained using GameResults.

The PassFilter simply outputs BlackWins and WhiteWins (each nominally in the range 0 - 1, but not presently clamped).
In principle, this type of feedforward ANN can be used to decide whether a PASS will result in BlackWins or WhiteWins at any stage of the game.
The goal is for the network to learn that passing while losing when valid moves exist is bad, but passing while winning is relatively harmless later in the game.
This commit is contained in:
2012-11-17 18:40:31 -05:00
parent d9d6ecda80
commit aca8320600
37 changed files with 1040 additions and 544 deletions

View File

@@ -0,0 +1,29 @@
package net.woodyfolsom.msproj;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import org.junit.Test;
/**
 * Tests for stone placement and self-fill detection on a small board.
 */
public class GameBoardTest {
    /**
     * Places four black stones around B2 and checks
     * {@code GameState.isSelfFill} for both players at B2 and at B4.
     *
     * <p>NOTE(review): the method name says "capture" but the assertions only
     * exercise {@code isSelfFill}; the name is kept so existing test reports
     * and filters continue to match.
     */
    @Test
    public void testCapture() {
        // Presumably a 5x5 board -- confirm against GameConfig's constructor.
        GameConfig gameConfig = new GameConfig(5);
        GameState gameState = new GameState(gameConfig);

        // Black stones on A2, B3, B1 and C2 (presumably the four orthogonal
        // neighbours of B2).
        assertTrue(gameState.placeStone(Player.BLACK, Action.getInstance("A2")));
        assertTrue(gameState.placeStone(Player.BLACK, Action.getInstance("B3")));
        assertTrue(gameState.placeStone(Player.BLACK, Action.getInstance("B1")));
        assertTrue(gameState.placeStone(Player.BLACK, Action.getInstance("C2")));

        // B2 is expected to be a self-fill for BLACK only.
        assertTrue(gameState.isSelfFill(Action.getInstance("B2"), Player.BLACK));
        assertFalse(gameState.isSelfFill(Action.getInstance("B2"), Player.WHITE));

        // B4 is expected to be a self-fill for neither player. Both colours
        // are checked, mirroring the B2 pair above, instead of asserting the
        // BLACK case twice.
        assertFalse(gameState.isSelfFill(Action.getInstance("B4"), Player.BLACK));
        assertFalse(gameState.isSelfFill(Action.getInstance("B4"), Player.WHITE));

        System.out.println(gameState);
    }
}

View File

@@ -1,28 +1,76 @@
package net.woodyfolsom.msproj;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import net.woodyfolsom.msproj.GameResult;
import net.woodyfolsom.msproj.policy.MonteCarloUCT;
import net.woodyfolsom.msproj.policy.RandomMovePolicy;
import net.woodyfolsom.msproj.policy.RootParallelization;
import net.woodyfolsom.msproj.sgf.SGFLexer;
import net.woodyfolsom.msproj.sgf.SGFNodeCollection;
import net.woodyfolsom.msproj.sgf.SGFParser;
import org.antlr.runtime.ANTLRInputStream;
import org.antlr.runtime.ANTLRStringStream;
import org.antlr.runtime.CommonTokenStream;
import org.antlr.runtime.RecognitionException;
import org.junit.Test;
/**
 * Tests for {@code GameResult} score normalization and for scoring a game
 * replayed from an SGF record.
 */
public class GameScoreTest {
    // Alternative SGF fixtures, currently disabled:
    // public static final String endGameSGF =
    // "(;FF[4]GM[1]SZ[9]KM[5.5];B[ef];W[ff];B[dg];W[aa];B[fc];W[da];B[cg];W[ei];B[gf]"
    // +
    // ";W[fi];B[ag];W[ii];B[bi];W[if];B[db];W[ci];B[cf];W[ih];B[bc];W[hb];B[eb];W[fh];B[ig];W[hc];B[be];W[he];B[gc];"
    // +
    // "W[id];B[cd];W[df];B[hf];W[ah];B[bh];W[fa];B[bg];W[fe];B[ec];W[eh];B[ee];W[bd];B[hg];W[ie];B[fg];W[ca];B[eg];"
    // +
    // "W[cb];B[ad];W[ba];B[ch];W[dh];B[gd];W[ic];B[ha];W[ab];B[gh];W[gb];B[ed];W[];B[])";
    //public static final String endGameSGF = "(;FF[4]GM[1]SZ[9]KM[5.5]RE[W+0.5];B[ef];W[cb];B[fe];W[da];B[cd];W[hh];B[ed];W[cc];B[ci];W[bc];B[cg];W[fi];B[be];W[ea];B[hi];W[df];B[fd];W[bg];B[cf];W[aa];B[gd];W[ch];B[ad];W[dg];B[de];W[ge];B[bh];W[fa];B[ag];W[hd];B[if];W[bi];B[gf];W[bd];B[ah];W[gc];B[ff];W[ca];B[hf];W[dd];B[ce];W[ae];B[ga];W[hc];B[ac];W[gg];B[fg];W[fb];B[ie];W[dh];B[af];W[ec];B[dc];W[id];B[dd];W[eh];B[eb];W[gb];B[ae];W[ic];B[di];W[fh];B[ig];W[ab];B[ha];W[hg];B[hb];W[gi];B[ii];W[ia];B[fc];W[ba];B[eg];W[];B[db];W[];B[])";

    /** SGF record replayed by {@link #testScoreEndGame()}; it contains 11 moves. */
    public static final String endGameSGF = "(;FF[4]GM[1]SZ[6]KM[1.5]RE[B+0.5];B[bb];W[];B[ec];W[ef];B[ac];W[ed];B[ba];W[dc];B[cf];W[];B[])";

    /**
     * A result built from all-zero scores must normalize to the value
     * reported by {@code getNormalizedZeroScore()}.
     */
    @Test
    public void testGetAggregateScoreZero() {
        GameResult gameScore = new GameResult(0, 0, 19, 0, true);
        assertEquals(gameScore.getNormalizedZeroScore(),
                gameScore.getNormalizedScore());
    }

    /** Checks the normalized score for {@code GameResult(25, 2, 19, 0, true)}. */
    @Test
    public void testGetAggregateScoreBlackWinsNoKomi() {
        GameResult gameScore = new GameResult(25, 2, 19, 0, true);
        assertEquals(407, gameScore.getNormalizedScore());
    }

    /** Checks the normalized score for {@code GameResult(10, 12, 19, 6.5, true)}. */
    @Test
    public void testGetAggregateScoreWhiteWinsWithKomi() {
        GameResult gameScore = new GameResult(10, 12, 19, 6.5, true);
        assertEquals(357, gameScore.getNormalizedScore());
    }

    /**
     * Replays {@link #endGameSGF}, prints five suggested actions for WHITE
     * from the state at index 9, then plays a pass for each player on that
     * state and asserts that it is terminal.
     *
     * @throws IOException declared for the {@code Referee.replay} call
     * @throws RecognitionException declared for the {@code Referee.replay} call
     */
    @Test
    public void testScoreEndGame() throws IOException, RecognitionException {
        InputStream is = new ByteArrayInputStream(endGameSGF.getBytes());
        GameRecord gameRecord = Referee.replay(is);
        assertEquals(11, gameRecord.getNumTurns());

        GameState gameState9 = gameRecord.getGameState(9);

        // Informational only: the suggested actions are printed, not asserted.
        // A fresh policy instance is created for every suggestion.
        for (int i = 0; i < 5; i++) {
            //Action action = new RootParallelization(4, 1000L).getAction(gameRecord.getGameConfig(), gameState9, Player.WHITE);
            Action action = new MonteCarloUCT(new RandomMovePolicy(), 1000L)
                    .getAction(gameRecord.getGameConfig(), gameState9, Player.WHITE);
            System.out.println("Suggested action for " + Player.WHITE + ": " + action);
        }

        // One pass per player; the state is then expected to be terminal.
        gameState9.playStone(Player.WHITE, Action.PASS);
        gameState9.playStone(Player.BLACK, Action.PASS);
        assertTrue(gameState9.isTerminal());

        System.out.println(gameState9.getResult());
    }
}

View File

@@ -0,0 +1,30 @@
package net.woodyfolsom.msproj.ann;
import static org.junit.Assert.assertTrue;
import org.junit.Test;
import org.neuroph.core.NeuralNetwork;
/**
 * Smoke test for the pass-filter network stored on disk.
 */
public class PassNetworkTest {
    /**
     * Loads {@code data/networks/Pass1.nn}, evaluates it on two input pairs
     * and asserts which of the two outputs lies above 0.50 in each case.
     */
    @Test
    public void testSavedNetwork() {
        NeuralNetwork passFilter = NeuralNetwork.load("data/networks/Pass1.nn");
        PassData formatter = new PassData();

        // First input pair: output[0] must exceed 0.50, output[1] must not.
        passFilter.setInput(0.75, 0.25);
        passFilter.calculate();
        double[] firstResult = passFilter.getOutput();
        System.out.println("Output: " + formatter.getOutput(firstResult));
        assertTrue(firstResult[0] > 0.50);
        assertTrue(firstResult[1] < 0.50);

        // Second input pair: the expectation is reversed.
        passFilter.setInput(0.25, 0.50);
        passFilter.calculate();
        double[] secondResult = passFilter.getOutput();
        System.out.println("Output: " + formatter.getOutput(secondResult));
        assertTrue(secondResult[0] < 0.50);
        assertTrue(secondResult[1] > 0.50);
    }
}