First working implementation of ANN which is trained using GameResults.
The PassFilter simply outputs BlackWins and WhiteWins (Range 0 - 1 but not presently clamped). In principle, this type of feedforward ANN can be used to decide whether a PASS will result in blackwins or whitewins at any stage. The goal is for the network to learn that passing while losing when valid moves exist is bad, but passing while winning is relatively harmless later in the game.
This commit is contained in:
29
test/net/woodyfolsom/msproj/GameBoardTest.java
Normal file
29
test/net/woodyfolsom/msproj/GameBoardTest.java
Normal file
@@ -0,0 +1,29 @@
|
||||
package net.woodyfolsom.msproj;
|
||||
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
public class GameBoardTest {
|
||||
@Test
|
||||
public void testCapture() {
|
||||
GameConfig gameConfig = new GameConfig(5);
|
||||
GameState gameState = new GameState(gameConfig);
|
||||
|
||||
assertTrue(gameState.placeStone(Player.BLACK, Action.getInstance("A2")));
|
||||
assertTrue(gameState.placeStone(Player.BLACK, Action.getInstance("B3")));
|
||||
assertTrue(gameState.placeStone(Player.BLACK, Action.getInstance("B1")));
|
||||
assertTrue(gameState.placeStone(Player.BLACK, Action.getInstance("C2")));
|
||||
|
||||
assertTrue(gameState.isSelfFill(Action.getInstance("B2"), Player.BLACK));
|
||||
assertFalse(gameState
|
||||
.isSelfFill(Action.getInstance("B2"), Player.WHITE));
|
||||
assertFalse(gameState
|
||||
.isSelfFill(Action.getInstance("B4"), Player.BLACK));
|
||||
assertFalse(gameState
|
||||
.isSelfFill(Action.getInstance("B4"), Player.BLACK));
|
||||
|
||||
System.out.println(gameState);
|
||||
}
|
||||
}
|
||||
@@ -1,28 +1,76 @@
|
||||
package net.woodyfolsom.msproj;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
|
||||
import net.woodyfolsom.msproj.GameResult;
|
||||
import net.woodyfolsom.msproj.policy.MonteCarloUCT;
|
||||
import net.woodyfolsom.msproj.policy.RandomMovePolicy;
|
||||
import net.woodyfolsom.msproj.policy.RootParallelization;
|
||||
import net.woodyfolsom.msproj.sgf.SGFLexer;
|
||||
import net.woodyfolsom.msproj.sgf.SGFNodeCollection;
|
||||
import net.woodyfolsom.msproj.sgf.SGFParser;
|
||||
|
||||
import org.antlr.runtime.ANTLRInputStream;
|
||||
import org.antlr.runtime.ANTLRStringStream;
|
||||
import org.antlr.runtime.CommonTokenStream;
|
||||
import org.antlr.runtime.RecognitionException;
|
||||
import org.junit.Test;
|
||||
|
||||
public class GameScoreTest {
|
||||
|
||||
// public static final String endGameSGF =
|
||||
// "(;FF[4]GM[1]SZ[9]KM[5.5];B[ef];W[ff];B[dg];W[aa];B[fc];W[da];B[cg];W[ei];B[gf]"
|
||||
// +
|
||||
// ";W[fi];B[ag];W[ii];B[bi];W[if];B[db];W[ci];B[cf];W[ih];B[bc];W[hb];B[eb];W[fh];B[ig];W[hc];B[be];W[he];B[gc];"
|
||||
// +
|
||||
// "W[id];B[cd];W[df];B[hf];W[ah];B[bh];W[fa];B[bg];W[fe];B[ec];W[eh];B[ee];W[bd];B[hg];W[ie];B[fg];W[ca];B[eg];"
|
||||
// +
|
||||
// "W[cb];B[ad];W[ba];B[ch];W[dh];B[gd];W[ic];B[ha];W[ab];B[gh];W[gb];B[ed];W[];B[])";
|
||||
|
||||
//public static final String endGameSGF = "(;FF[4]GM[1]SZ[9]KM[5.5]RE[W+0.5];B[ef];W[cb];B[fe];W[da];B[cd];W[hh];B[ed];W[cc];B[ci];W[bc];B[cg];W[fi];B[be];W[ea];B[hi];W[df];B[fd];W[bg];B[cf];W[aa];B[gd];W[ch];B[ad];W[dg];B[de];W[ge];B[bh];W[fa];B[ag];W[hd];B[if];W[bi];B[gf];W[bd];B[ah];W[gc];B[ff];W[ca];B[hf];W[dd];B[ce];W[ae];B[ga];W[hc];B[ac];W[gg];B[fg];W[fb];B[ie];W[dh];B[af];W[ec];B[dc];W[id];B[dd];W[eh];B[eb];W[gb];B[ae];W[ic];B[di];W[fh];B[ig];W[ab];B[ha];W[hg];B[hb];W[gi];B[ii];W[ia];B[fc];W[ba];B[eg];W[];B[db];W[];B[])";
|
||||
|
||||
public static final String endGameSGF = "(;FF[4]GM[1]SZ[6]KM[1.5]RE[B+0.5];B[bb];W[];B[ec];W[ef];B[ac];W[ed];B[ba];W[dc];B[cf];W[];B[])";
|
||||
@Test
|
||||
public void testGetAggregateScoreZero() {
|
||||
GameResult gameScore = new GameResult(0,0,19,0, true);
|
||||
assertEquals(gameScore.getNormalizedZeroScore(), gameScore.getNormalizedScore());
|
||||
GameResult gameScore = new GameResult(0, 0, 19, 0, true);
|
||||
assertEquals(gameScore.getNormalizedZeroScore(),
|
||||
gameScore.getNormalizedScore());
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testGetAggregateScoreBlackWinsNoKomi() {
|
||||
GameResult gameScore = new GameResult(25,2,19,0, true);
|
||||
GameResult gameScore = new GameResult(25, 2, 19, 0, true);
|
||||
assertEquals(407, gameScore.getNormalizedScore());
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testGetAggregateScoreWhiteWinsWithKomi() {
|
||||
GameResult gameScore = new GameResult(10,12,19,6.5, true);
|
||||
GameResult gameScore = new GameResult(10, 12, 19, 6.5, true);
|
||||
assertEquals(357, gameScore.getNormalizedScore());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testScoreEndGame() throws IOException, RecognitionException {
|
||||
InputStream is = new ByteArrayInputStream(endGameSGF.getBytes());
|
||||
GameRecord gameRecord = Referee.replay(is);
|
||||
assertEquals(11, gameRecord.getNumTurns());
|
||||
|
||||
GameState gameState9 = gameRecord.getGameState(9);
|
||||
|
||||
for (int i = 0; i < 5; i++) {
|
||||
//Action action = new RootParallelization(4, 1000L).getAction(gameRecord.getGameConfig(), gameState9, Player.WHITE);
|
||||
Action action = new MonteCarloUCT(new RandomMovePolicy(),1000L).getAction(gameRecord.getGameConfig(), gameState9, Player.WHITE);
|
||||
System.out.println("Suggested action for "+Player.WHITE+": " + action);
|
||||
}
|
||||
|
||||
gameState9.playStone(Player.WHITE, Action.PASS);
|
||||
gameState9.playStone(Player.BLACK, Action.PASS);
|
||||
assertTrue(gameState9.isTerminal());
|
||||
System.out.println(gameState9.getResult());
|
||||
}
|
||||
}
|
||||
30
test/net/woodyfolsom/msproj/ann/PassNetworkTest.java
Normal file
30
test/net/woodyfolsom/msproj/ann/PassNetworkTest.java
Normal file
@@ -0,0 +1,30 @@
|
||||
package net.woodyfolsom.msproj.ann;
|
||||
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.neuroph.core.NeuralNetwork;
|
||||
|
||||
public class PassNetworkTest {
|
||||
|
||||
@Test
|
||||
public void testSavedNetwork() {
|
||||
NeuralNetwork passFilter = NeuralNetwork.load("data/networks/Pass1.nn");
|
||||
passFilter.setInput(0.75,0.25);
|
||||
passFilter.calculate();
|
||||
|
||||
PassData passData = new PassData();
|
||||
double[] output = passFilter.getOutput();
|
||||
System.out.println("Output: " + passData.getOutput(output));
|
||||
|
||||
assertTrue(output[0] > 0.50);
|
||||
assertTrue(output[1] < 0.50);
|
||||
|
||||
passFilter.setInput(0.25,0.50);
|
||||
passFilter.calculate();
|
||||
output = passFilter.getOutput();
|
||||
System.out.println("Output: " + passData.getOutput(output));
|
||||
assertTrue(output[0] < 0.50);
|
||||
assertTrue(output[1] > 0.50);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user