Fixed use of Zobrist hash for positional superko detection.

This commit is contained in:
cs6601
2012-09-04 16:02:49 -04:00
parent 0bbcb1054d
commit d4acc5beda
14 changed files with 507 additions and 152 deletions

View File

@@ -1,6 +1,8 @@
package net.woodyfolsom.msproj; package net.woodyfolsom.msproj;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List;
public class GameBoard { public class GameBoard {
public static final char BLACK_STONE = 'X'; public static final char BLACK_STONE = 'X';
@@ -15,20 +17,35 @@ public class GameBoard {
private boolean territoryMarked = false; private boolean territoryMarked = false;
private int size; private int size;
private char[] board; private char[] board;
private List<Integer> captureList;
private List<Long> boardHashHistory = new ArrayList<Long>();
private ZobristHashGenerator zobristHashGenerator;
public GameBoard(int size) { public GameBoard(int size) {
this.size = size; this.size = size;
board = new char[size * size]; board = new char[size * size];
Arrays.fill(board, '.'); Arrays.fill(board, '.');
zobristHashGenerator = ZobristHashGenerator.getInstance(size);
boardHashHistory.add(zobristHashGenerator.getEmptyBoardHash());
captureList = new ArrayList<Integer>();
} }
public GameBoard(GameBoard that) { public GameBoard(GameBoard that) {
this.size = that.size; this.size = that.size;
this.board = Arrays.copyOf(that.board, that.board.length); this.board = Arrays.copyOf(that.board, that.board.length);
zobristHashGenerator = ZobristHashGenerator.getInstance(size);
boardHashHistory = new ArrayList<Long>(that.boardHashHistory);
captureList = new ArrayList<Integer>(that.captureList);
}
public long getZobristHash() {
return boardHashHistory.get(boardHashHistory.size() - 1);
} }
public void clear() { public void clear() {
territoryMarked = false; territoryMarked = false;
captureList.clear();
Arrays.fill(board, EMPTY_INTERSECTION); Arrays.fill(board, EMPTY_INTERSECTION);
} }
@@ -44,6 +61,10 @@ public class GameBoard {
return stoneCount; return stoneCount;
} }
/**
 * Returns the list of board indices recorded as captured.
 * <p>
 * This is the live internal list, not a copy:
 * {@link #captureMarkedGroup(char)} appends to it, while
 * {@link #clearCaptureList()} and {@link #clear()} empty it.
 *
 * @return the internal capture list
 */
public List<Integer> getCaptureList() {
    return this.captureList;
}
@Override @Override
// TODO: implement as Zobrist hash. // TODO: implement as Zobrist hash.
public int hashCode() { public int hashCode() {
@@ -102,7 +123,8 @@ public class GameBoard {
} else if (stoneSymbol == GameBoard.WHITE_STONE) { } else if (stoneSymbol == GameBoard.WHITE_STONE) {
return GameBoard.BLACK_STONE; return GameBoard.BLACK_STONE;
} else { } else {
throw new IllegalArgumentException("StoneSymbol must be BLACK_STONE or WHITE_STONE"); throw new IllegalArgumentException(
"StoneSymbol must be BLACK_STONE or WHITE_STONE");
} }
} }
@@ -111,16 +133,23 @@ public class GameBoard {
} }
/** /**
* @param colLabel [A..T] (skipping I * @param colLabel
* @param rowNumber 1-based * [A..T] (skipping I
* @param rowNumber
* 1-based
* @return * @return
*/ */
public char getSymbolAt(char colLabel, int rowNumber) { public char getSymbolAt(char colLabel, int rowNumber) {
return getSymbolAt(getColumnIndex(colLabel), rowNumber - 1); return getSymbolAt(getColumnIndex(colLabel), rowNumber - 1);
} }
/**
 * Returns the symbol stored at a raw offset of the backing board array.
 *
 * @param index
 *            offset into the one-dimensional board array
 * @return the character stored at that offset
 */
public char getSymbolAt(int index) {
    final char symbol = board[index];
    return symbol;
}
/** /**
* 0-based. * 0-based.
*
* @param col * @param col
* @param row * @param row
* @return * @return
@@ -138,6 +167,16 @@ public class GameBoard {
return getSymbolAt(col, row) == EMPTY_INTERSECTION; return getSymbolAt(col, row) == EMPTY_INTERSECTION;
} }
/**
 * Reports whether the newest entry of the board hash history also occurs at
 * an earlier position in that history.
 * <p>
 * Only hash values are compared; the board contents themselves are not.
 *
 * @return {@code true} if an earlier history entry equals the newest one
 */
public boolean isKoViolation() {
    final int newestIndex = boardHashHistory.size() - 1;
    final Long newestHash = boardHashHistory.get(newestIndex);
    // indexOf yields the first entry with an equal value. The newest entry
    // always matches itself, so a smaller index means an earlier duplicate.
    return boardHashHistory.indexOf(newestHash) < newestIndex;
}
public boolean isTerritoryMarked() { public boolean isTerritoryMarked() {
return territoryMarked; return territoryMarked;
} }
@@ -170,26 +209,85 @@ public class GameBoard {
if (getSymbolAt(colLabel, rowNum) == EMPTY_INTERSECTION) { if (getSymbolAt(colLabel, rowNum) == EMPTY_INTERSECTION) {
return false; return false;
} }
setSymbolAt(colLabel, rowNum, EMPTY_INTERSECTION); setSymbolAt(colLabel, rowNum, EMPTY_INTERSECTION);
return true; return true;
} }
public int replaceSymbol(char symbol, char replacement) { public int captureMarkedGroup(char opponentSymbol) {
int numReplaced = 0; int numReplaced = 0;
for (int i = 0; i < board.length; i++) { for (int i = 0; i < board.length; i++) {
if (board[i] == symbol) { if (board[i] == MARKED_GROUP) {
board[i] = replacement; board[i] = opponentSymbol;
setSymbolAt(i,EMPTY_INTERSECTION);
captureList.add(i);
numReplaced++; numReplaced++;
} }
} }
return numReplaced; return numReplaced;
} }
/**
 * Empties the list of captured board indices.
 */
public void clearCaptureList() {
    this.captureList.clear();
}
/**
 * Rewrites every intersection currently holding {@code MARKED_TERRITORY} to
 * the given symbol.
 * <p>
 * The board array is written directly, so the board hash history is left
 * untouched.
 *
 * @param ownerSymbol
 *            symbol to store in place of each territory marker
 */
public void markTerritory(char ownerSymbol) {
    for (int position = 0; position < board.length; position++) {
        if (board[position] != MARKED_TERRITORY) {
            continue;
        }
        board[position] = ownerSymbol;
    }
}
/**
 * Restores every intersection currently holding {@code MARKED_GROUP} to the
 * given stone symbol.
 * <p>
 * The board array is written directly, so the board hash history is left
 * untouched.
 *
 * @param opponentSymbol
 *            symbol to store in place of each group marker
 */
public void unmarkGroup(char opponentSymbol) {
    int position = 0;
    while (position < board.length) {
        if (board[position] == MARKED_GROUP) {
            board[position] = opponentSymbol;
        }
        position++;
    }
}
//TODO change boardHashHistory to stack
/**
 * Discards the newest entry of the board hash history, undoing the effect
 * of the matching {@link #pushHashHistory()} call.
 *
 * @throws IndexOutOfBoundsException
 *             if the history is empty
 */
public void popHashHistory() {
    // Remove by position. Passing the boxed Long returned by get() would
    // select List.remove(Object), which deletes the FIRST entry with an
    // equal value (after a linear scan) rather than the last entry.
    boardHashHistory.remove(boardHashHistory.size() - 1);
}
/**
 * Appends a copy of the newest hash to the board hash history, so that
 * later {@code setSymbolAt} calls update the new entry and leave the
 * previous one intact.
 */
public void pushHashHistory() {
    final int topIndex = boardHashHistory.size() - 1;
    final Long topHash = boardHashHistory.get(topIndex);
    boardHashHistory.add(topHash);
}
public void setSymbolAt(char colLabel, int rowNumber, char symbol) { public void setSymbolAt(char colLabel, int rowNumber, char symbol) {
setSymbolAt(getColumnIndex(colLabel), rowNumber - 1, symbol); setSymbolAt(getColumnIndex(colLabel), rowNumber - 1, symbol);
} }
public void setSymbolAt(int col, int row, char symbol) {
board[(size - row - 1) * size + col] = symbol; public void setSymbolAt(int index, char newSymbol) {
char oldSymbol = board[index];
//TODO marked intersections should really be stored in
//a separate array to ensure that the hash code is always in sync with
//an actual or transitional board position
if (oldSymbol == MARKED_GROUP || newSymbol == MARKED_GROUP || oldSymbol == MARKED_TERRITORY || newSymbol == MARKED_TERRITORY) {
board[index] = newSymbol;
} else {
int hashIndex = boardHashHistory.size() - 1;
long currentHashCode = boardHashHistory.get(hashIndex);
board[index] = newSymbol;
currentHashCode ^= zobristHashGenerator.getHashCode(index,
oldSymbol);
currentHashCode ^= zobristHashGenerator.getHashCode(index,
newSymbol);
boardHashHistory.set(hashIndex, currentHashCode);
}
}
public void setSymbolAt(int col, int row, char newSymbol) {
setSymbolAt((size - row - 1) * size + col, newSymbol);
} }
public void setTerritoryMarked(boolean territoryMarked) { public void setTerritoryMarked(boolean territoryMarked) {
@@ -201,10 +299,14 @@ public class GameBoard {
return; return;
} }
replaceSymbol(BLACK_TERRITORY,EMPTY_INTERSECTION); for (int i = 0; i < board.length; i++) {
replaceSymbol(WHITE_TERRITORY,EMPTY_INTERSECTION); char currentSymbol = board[i];
replaceSymbol(UNOWNED_TERRITORY,EMPTY_INTERSECTION); if (currentSymbol == BLACK_TERRITORY
|| currentSymbol == WHITE_TERRITORY
|| currentSymbol == UNOWNED_TERRITORY) {
board[i] = EMPTY_INTERSECTION;
}
}
territoryMarked = false; territoryMarked = false;
} }
} }

View File

@@ -3,12 +3,7 @@ package net.woodyfolsom.msproj;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import org.apache.log4j.Logger;
public class GameState { public class GameState {
private static final Logger LOGGER = Logger.getLogger(GameState.class.getName());
private int blackPrisoners = 0; private int blackPrisoners = 0;
private int whitePrisoners = 0; private int whitePrisoners = 0;
private GameBoard gameBoard; private GameBoard gameBoard;
@@ -18,7 +13,6 @@ public class GameState {
throw new IllegalArgumentException("Invalid board size: " + size); throw new IllegalArgumentException("Invalid board size: " + size);
} }
gameBoard = new GameBoard(size); gameBoard = new GameBoard(size);
LOGGER.info("Created new GameBoard of size " + size);
} }
public GameState(GameState that) { public GameState(GameState that) {
@@ -42,7 +36,8 @@ public class GameState {
for (int colIndex = 0; colIndex < gameBoard.getSize(); colIndex++) { for (int colIndex = 0; colIndex < gameBoard.getSize(); colIndex++) {
for (int rowIndex = 0; rowIndex < gameBoard.getSize(); rowIndex++) { for (int rowIndex = 0; rowIndex < gameBoard.getSize(); rowIndex++) {
if (GameBoard.EMPTY_INTERSECTION == gameBoard.getSymbolAt(colIndex, rowIndex)); if (GameBoard.EMPTY_INTERSECTION == gameBoard.getSymbolAt(
colIndex, rowIndex))
emptyCoords.add(GameBoard.getCoordinate(colIndex, rowIndex)); emptyCoords.add(GameBoard.getCoordinate(colIndex, rowIndex));
} }
} }
@@ -59,19 +54,9 @@ public class GameState {
} }
public boolean playStone(Player player, String move) { public boolean playStone(Player player, String move) {
//Opponent passes? Just ignore it. return playStone(player, Action.getInstance(move));
Action action = Action.getInstance(move);
if (action.isPass()) {
return true;
} }
if (action.isNone()) {
return false;
}
return playStone(player, action);
}
/** /**
* Places a stone at the requested coordinate. Placement is legal if the * Places a stone at the requested coordinate. Placement is legal if the
* coordinate is currently empty, has at least one liberty (empty neighbor), * coordinate is currently empty, has at least one liberty (empty neighbor),
@@ -87,57 +72,82 @@ public class GameState {
* @return * @return
*/ */
public boolean playStone(Player player, Action action) { public boolean playStone(Player player, Action action) {
if (action == Action.PASS) {
if (player == Player.NONE) {
throw new IllegalArgumentException("Cannot play as " + player);
}
if (action.isPass()) {
return true; return true;
} }
char currentStone = gameBoard.getSymbolAt(action.getColumn(), action.getRow()); if (action.isNone()) {
return false;
}
char currentStone = gameBoard.getSymbolAt(action.getColumn(),
action.getRow());
if (currentStone != GameBoard.EMPTY_INTERSECTION) { if (currentStone != GameBoard.EMPTY_INTERSECTION) {
return false; return false;
} }
//Place stone as requested, then check for (1) captured neighbors and (2) illegal move due to 0 liberties. assertCorrectHash();
gameBoard.pushHashHistory();
gameBoard.clearCaptureList();
// Place stone as requested, then check for (1) captured neighbors and
// (2) illegal move due to 0 liberties.
char stoneSymbol = player.getStoneSymbol(); char stoneSymbol = player.getStoneSymbol();
//Player opponent = GoGame.getNextPlayer(player);
gameBoard.setSymbolAt(action.getColumn(), action.getRow(), stoneSymbol); gameBoard.setSymbolAt(action.getColumn(), action.getRow(), stoneSymbol);
// look for captured adjacent groups and increment the prisoner counter // look for captured adjacent groups and increment the prisoner counter
char opponentSymbol = GameBoard.getOpponentSymbol(player.getStoneSymbol()); char opponentSymbol = GoGame.getNextPlayer(player).getStoneSymbol();
int col = GameBoard.getColumnIndex(action.getColumn()); int col = GameBoard.getColumnIndex(action.getColumn());
int row = action.getRow() - 1; int row = action.getRow() - 1;
int prisonerCount = 0; int prisonerCount = 0;
if (col > 0 && gameBoard.getSymbolAt(col - 1, row) == opponentSymbol) { if (col > 0 && gameBoard.getSymbolAt(col - 1, row) == opponentSymbol) {
int liberties = LibertyCounter.countLiberties(gameBoard, col-1, row, opponentSymbol,true); int liberties = LibertyCounter.countLiberties(gameBoard, col - 1,
row, opponentSymbol, true);
if (liberties == 0) { if (liberties == 0) {
prisonerCount += gameBoard.replaceSymbol(GameBoard.MARKED_GROUP,GameBoard.EMPTY_INTERSECTION); prisonerCount += gameBoard.captureMarkedGroup(opponentSymbol);
} else { } else {
gameBoard.replaceSymbol(GameBoard.MARKED_GROUP, opponentSymbol); gameBoard.unmarkGroup(opponentSymbol);
} }
} }
if (col < gameBoard.getSize() - 1 && gameBoard.getSymbolAt(col+1, row) == opponentSymbol) { if (col < gameBoard.getSize() - 1
int liberties = LibertyCounter.countLiberties(gameBoard, col+1, row, opponentSymbol,true); && gameBoard.getSymbolAt(col + 1, row) == opponentSymbol) {
int liberties = LibertyCounter.countLiberties(gameBoard, col + 1,
row, opponentSymbol, true);
if (liberties == 0) { if (liberties == 0) {
prisonerCount += gameBoard.replaceSymbol(GameBoard.MARKED_GROUP,GameBoard.EMPTY_INTERSECTION); prisonerCount += gameBoard.captureMarkedGroup(opponentSymbol);
} else { } else {
gameBoard.replaceSymbol(GameBoard.MARKED_GROUP, opponentSymbol); gameBoard.unmarkGroup(opponentSymbol);
} }
} }
if (row > 0 && gameBoard.getSymbolAt(col, row - 1) == opponentSymbol) { if (row > 0 && gameBoard.getSymbolAt(col, row - 1) == opponentSymbol) {
int liberties = LibertyCounter.countLiberties(gameBoard, col, row-1, opponentSymbol,true); int liberties = LibertyCounter.countLiberties(gameBoard, col,
row - 1, opponentSymbol, true);
if (liberties == 0) { if (liberties == 0) {
prisonerCount += gameBoard.replaceSymbol(GameBoard.MARKED_GROUP,GameBoard.EMPTY_INTERSECTION); prisonerCount += gameBoard.captureMarkedGroup(opponentSymbol);
} else { } else {
gameBoard.replaceSymbol(GameBoard.MARKED_GROUP, opponentSymbol); gameBoard.unmarkGroup(opponentSymbol);
} }
} }
if (row < gameBoard.getSize() - 1 && gameBoard.getSymbolAt(col, row+1) == opponentSymbol) { if (row < gameBoard.getSize() - 1
int liberties = LibertyCounter.countLiberties(gameBoard, col, row+1, opponentSymbol,true); && gameBoard.getSymbolAt(col, row + 1) == opponentSymbol) {
int liberties = LibertyCounter.countLiberties(gameBoard, col,
row + 1, opponentSymbol, true);
if (liberties == 0) { if (liberties == 0) {
prisonerCount += gameBoard.replaceSymbol(GameBoard.MARKED_GROUP,GameBoard.EMPTY_INTERSECTION); prisonerCount += gameBoard.captureMarkedGroup(opponentSymbol);
} else { } else {
gameBoard.replaceSymbol(GameBoard.MARKED_GROUP, opponentSymbol); gameBoard.unmarkGroup(opponentSymbol);
} }
} }
@@ -147,13 +157,59 @@ public class GameState {
whitePrisoners += prisonerCount; whitePrisoners += prisonerCount;
} }
//Moved test for 0 liberties until after attempting to capture neighboring groups. // Moved test for 0 liberties until after attempting to capture
if (0 == LibertyCounter.countLiberties(gameBoard, action.getColumn(), action.getRow(), stoneSymbol)) { // neighboring groups.
// This will only happen if no neighboring groups were captured, hence there is nothing to undo.
// So return now.
if (0 == LibertyCounter.countLiberties(gameBoard, action.getColumn(),
action.getRow(), stoneSymbol)) {
gameBoard.removeStone(action.getColumn(), action.getRow()); gameBoard.removeStone(action.getColumn(), action.getRow());
gameBoard.popHashHistory();
return false; return false;
} }
// If this hashcode has already appeared, then probably Ko violation.
// TODO change this to a Map<Long, List<GameBoard>> and check for
// complete board equality
if (gameBoard.isKoViolation()) {
List<Integer> captureList = gameBoard.getCaptureList();
for (int i : captureList) {
gameBoard.setSymbolAt(i, opponentSymbol);
}
//And finally, remove the originally played stone, which was never valid due to ko.
gameBoard.removeStone(action.getColumn(), action.getRow());
gameBoard.clearCaptureList();
gameBoard.popHashHistory();
//assertCorrectHash();
return false;
} else {
//assertCorrectHash();
return true; return true;
} }
}
/**
 * Debugging aid: recomputes the Zobrist hash of the current board from
 * scratch and compares it with the hash the board reports.
 * <p>
 * The recomputation starts from the empty-board hash and, for each board
 * index, XORs out the empty-intersection code and XORs in the code of the
 * symbol actually stored there.
 *
 * @throws RuntimeException
 *             if the recomputed hash differs from the board's hash
 */
@Deprecated
private void assertCorrectHash() {
    final long reportedHash = gameBoard.getZobristHash();

    final int edge = gameBoard.getSize();
    final ZobristHashGenerator generator = ZobristHashGenerator.getInstance(edge);
    final int cellCount = edge * edge;

    long expectedHash = generator.getEmptyBoardHash();
    for (int cell = 0; cell < cellCount; cell++) {
        expectedHash ^= generator.getHashCode(cell, GameBoard.EMPTY_INTERSECTION);
        expectedHash ^= generator.getHashCode(cell, gameBoard.getSymbolAt(cell));
    }

    if (reportedHash == expectedHash) {
        return;
    }
    throw new RuntimeException("Zobrist hash code mismatch");
}
public String toString() { public String toString() {
int boardSize = gameBoard.getSize(); int boardSize = gameBoard.getSize();

View File

@@ -158,8 +158,7 @@ public class GoGame {
DOMConfigurator.configure("log4j.xml"); DOMConfigurator.configure("log4j.xml");
} }
public static Player getColorToPlay(Player player, boolean playAsOpponent) { public static Player getNextPlayer(Player player) {
if (playAsOpponent) {
if (player == Player.WHITE) { if (player == Player.WHITE) {
return Player.BLACK; return Player.BLACK;
} else if (player == Player.BLACK) { } else if (player == Player.BLACK) {
@@ -167,8 +166,5 @@ public class GoGame {
} else { } else {
return Player.NONE; return Player.NONE;
} }
} else {
return player;
}
} }
} }

View File

@@ -7,9 +7,11 @@ public class LibertyCounter {
public static int countLiberties(GameBoard gameBoard, int col, int row, char groupColor, boolean markGroup) { public static int countLiberties(GameBoard gameBoard, int col, int row, char groupColor, boolean markGroup) {
int liberties = markGroup(gameBoard, col, row, groupColor); int liberties = markGroup(gameBoard, col, row, groupColor);
if (!markGroup) { if (!markGroup) {
gameBoard.replaceSymbol(GameBoard.MARKED_GROUP,groupColor); gameBoard.unmarkGroup(groupColor);
} }
return liberties; return liberties;
} }

View File

@@ -19,11 +19,11 @@ public class TerritoryMarker {
} }
int ownedBy = findTerritory(gameBoard,col,row); int ownedBy = findTerritory(gameBoard,col,row);
if (ownedBy == BLACK) { if (ownedBy == BLACK) {
gameBoard.replaceSymbol(TERRITORY_MARKER, BLACK_TERRITORY); gameBoard.markTerritory(BLACK_TERRITORY);
} else if (ownedBy == WHITE) { } else if (ownedBy == WHITE) {
gameBoard.replaceSymbol(TERRITORY_MARKER, WHITE_TERRITORY); gameBoard.markTerritory(WHITE_TERRITORY);
} else { } else {
gameBoard.replaceSymbol(TERRITORY_MARKER, UNOWNED_TERRITORY); gameBoard.markTerritory(UNOWNED_TERRITORY);
} }
} }
} }

View File

@@ -2,26 +2,61 @@ package net.woodyfolsom.msproj;
import java.math.BigInteger; import java.math.BigInteger;
import java.security.SecureRandom; import java.security.SecureRandom;
import java.util.HashMap;
import java.util.Map;
public class ZobristHashGenerator { public class ZobristHashGenerator {
private static final Map<Integer, ZobristHashGenerator> zhgMap = new HashMap<Integer, ZobristHashGenerator>();
private long emptyBoardHash;
private long[] randomBitFields; private long[] randomBitFields;
private ZobristHashGenerator() { private ZobristHashGenerator(int boardSize) {
} // Fields are 0, BLACK, WHITE
public static ZobristHashGenerator getInstance(int boardSize) {
int nRandomFields = 3 * boardSize * boardSize; int nRandomFields = 3 * boardSize * boardSize;
ZobristHashGenerator zobHashGen = new ZobristHashGenerator();
SecureRandom secureRandom = new SecureRandom(); SecureRandom secureRandom = new SecureRandom();
zobHashGen.randomBitFields = new long[nRandomFields]; randomBitFields = new long[nRandomFields];
byte[] nextBytes = new byte[8]; byte[] nextBytes = new byte[8];
for (int i = 0; i < nRandomFields; i++) { for (int i = 0; i < nRandomFields; i++) {
secureRandom.nextBytes(nextBytes); secureRandom.nextBytes(nextBytes);
zobHashGen.randomBitFields[i] = new BigInteger(nextBytes) randomBitFields[i] = new BigInteger(nextBytes).longValue();
.longValue();
} }
emptyBoardHash = 0L;
for (int i = 0; i < randomBitFields.length / 3; i++) {
emptyBoardHash ^= randomBitFields[i * 3];
}
}
public static ZobristHashGenerator getInstance(int boardSize) {
if (!zhgMap.containsKey(boardSize)) {
ZobristHashGenerator zobHashGen = new ZobristHashGenerator(
boardSize);
// TODO add check for minimum hamming distance/colinearity check // TODO add check for minimum hamming distance/colinearity check
return zobHashGen; zhgMap.put(boardSize, zobHashGen);
}
return zhgMap.get(boardSize);
}
/**
 * Returns the hash value this generator computed at construction time for a
 * board on which every intersection is empty.
 *
 * @return the empty-board hash
 */
public long getEmptyBoardHash() {
    return this.emptyBoardHash;
}
public long getHashCode(int index, char stoneType) {
switch (stoneType) {
case GameBoard.EMPTY_INTERSECTION:
return randomBitFields[index * 3];
case GameBoard.BLACK_STONE:
return randomBitFields[index * 3 + 1];
case GameBoard.WHITE_STONE:
return randomBitFields[index * 3 + 2];
default:
throw new IllegalArgumentException("No hash code for stone type: "
+ stoneType);
}
} }
} }

View File

@@ -86,7 +86,7 @@ public class AlphaBeta implements Policy {
node.addChild(nextMove, childNode); node.addChild(nextMove, childNode);
getMin(recursionLevels - 1, stateEvaluator, childNode, getMin(recursionLevels - 1, stateEvaluator, childNode,
GoGame.getColorToPlay(player, true)); GoGame.getNextPlayer(player));
double gameScore = childNode.getProperties().getReward(); double gameScore = childNode.getProperties().getReward();
@@ -145,7 +145,7 @@ public class AlphaBeta implements Policy {
node.addChild(nextMove, childNode); node.addChild(nextMove, childNode);
getMax(recursionLevels - 1, stateEvaluator, childNode, getMax(recursionLevels - 1, stateEvaluator, childNode,
GoGame.getColorToPlay(player, true)); GoGame.getNextPlayer(player));
double gameScore = childNode.getProperties().getReward(); double gameScore = childNode.getProperties().getReward();

View File

@@ -80,7 +80,7 @@ public class Minimax implements Policy {
node.addChild(nextMove, childNode); node.addChild(nextMove, childNode);
getMin(recursionLevels - 1, stateEvaluator, childNode, getMin(recursionLevels - 1, stateEvaluator, childNode,
GoGame.getColorToPlay(player, true)); GoGame.getNextPlayer(player));
double gameScore = childNode.getProperties().getReward(); double gameScore = childNode.getProperties().getReward();
@@ -126,7 +126,7 @@ public class Minimax implements Policy {
node.addChild(nextMove, childNode); node.addChild(nextMove, childNode);
getMax(recursionLevels - 1, stateEvaluator, childNode, getMax(recursionLevels - 1, stateEvaluator, childNode,
GoGame.getColorToPlay(player, true)); GoGame.getNextPlayer(player));
double gameScore = childNode.getProperties().getReward(); double gameScore = childNode.getProperties().getReward();

View File

@@ -51,7 +51,7 @@ public abstract class MonteCarlo implements Policy {
List<GameTreeNode<MonteCarloProperties>> selectedNodes = descend(rootNode); List<GameTreeNode<MonteCarloProperties>> selectedNodes = descend(rootNode);
List<GameTreeNode<MonteCarloProperties>> newLeaves = new ArrayList<GameTreeNode<MonteCarloProperties>>(); List<GameTreeNode<MonteCarloProperties>> newLeaves = new ArrayList<GameTreeNode<MonteCarloProperties>>();
Player nextPlayer = GoGame.getColorToPlay(player, true); Player nextPlayer = GoGame.getNextPlayer(player);
for (GameTreeNode<MonteCarloProperties> selectedNode: selectedNodes) { for (GameTreeNode<MonteCarloProperties> selectedNode: selectedNodes) {
for (GameTreeNode<MonteCarloProperties> newLeaf : grow(gameConfig, selectedNode, nextPlayer)) { for (GameTreeNode<MonteCarloProperties> newLeaf : grow(gameConfig, selectedNode, nextPlayer)) {
@@ -65,7 +65,9 @@ public abstract class MonteCarlo implements Policy {
} }
elapsedTime = System.currentTimeMillis() - startTime; elapsedTime = System.currentTimeMillis() - startTime;
} while (elapsedTime < searchTimeLimit); //} while (elapsedTime < searchTimeLimit);
//TODO: for debugging, temporarily specify the number of state evaluations rather than time limit
} while (numStateEvaluations < searchTimeLimit);
return getBestAction(rootNode); return getBestAction(rootNode);
} }

View File

@@ -105,10 +105,12 @@ public class MonteCarloUCT extends MonteCarlo {
Player currentPlayer = player; Player currentPlayer = player;
do { do {
rolloutDepth++; rolloutDepth++;
action = randomMovePolicy.getAction(gameConfig, node.getGameState(), player); action = randomMovePolicy.getAction(gameConfig, finalGameState, currentPlayer);
if (action != Action.NONE) { if (action != Action.NONE) {
finalGameState.playStone(currentPlayer, action); if (!finalGameState.playStone(currentPlayer, action)) {
currentPlayer = GoGame.getColorToPlay(currentPlayer, true); throw new RuntimeException("Failed to play move selected by RandomMovePolicy");
}
currentPlayer = GoGame.getNextPlayer(currentPlayer);
} }
} while (action != Action.NONE && rolloutDepth < ROLLOUT_DEPTH_LIMIT); } while (action != Action.NONE && rolloutDepth < ROLLOUT_DEPTH_LIMIT);

View File

@@ -0,0 +1,33 @@
package net.woodyfolsom.msproj;
import static org.junit.Assert.*;
import java.util.List;
import org.junit.Test;
/**
 * Tests for {@link GameState}.
 */
public class GameStateTest {

    /**
     * After three stones are played in column A of a 3x3 board, the empty
     * coordinates must exclude column A and include columns B and C.
     */
    @Test
    public void testGetEmptyCoords() {
        GameState gameState = new GameState(3);
        gameState.playStone(Player.BLACK, "A1");
        gameState.playStone(Player.WHITE, "A2");
        gameState.playStone(Player.BLACK, "A3");

        List<String> emptyCoords = gameState.getEmptyCoords();

        for (String occupied : new String[] { "A1", "A2", "A3" }) {
            assertFalse(emptyCoords.contains(occupied));
        }
        for (String vacant : new String[] { "B1", "B2", "B3", "C1", "C2", "C3" }) {
            assertTrue(emptyCoords.contains(vacant));
        }
    }
}

View File

@@ -1,8 +1,14 @@
package net.woodyfolsom.msproj; package net.woodyfolsom.msproj;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import java.util.List;
import net.woodyfolsom.msproj.policy.ActionGenerator;
import net.woodyfolsom.msproj.policy.ValidMoveGenerator;
import org.junit.Test; import org.junit.Test;
public class IllegalMoveTest { public class IllegalMoveTest {
@@ -63,4 +69,27 @@ public class IllegalMoveTest {
System.out.println(gameState); System.out.println(gameState);
assertFalse("Play by WHITE at J1 should have failed.",gameState.playStone(Player.WHITE, Action.getInstance("J1"))); assertFalse("Play by WHITE at J1 should have failed.",gameState.playStone(Player.WHITE, Action.getInstance("J1")));
} }
/**
 * BLACK playing A2 on a 3x3 board where WHITE holds A1, B1, B2, A3 and B3
 * must be rejected, and the valid-move generator must offer BLACK only
 * PASS, C1, C2 and C3.
 */
@Test
public void testIllegalMoveSuicide() {
    GameState gameState = new GameState(3);
    for (String whiteMove : new String[] { "A1", "B1", "B2", "A3", "B3" }) {
        gameState.playStone(Player.WHITE, Action.getInstance(whiteMove));
    }

    System.out.println("State before move: ");
    System.out.println(gameState);

    boolean suicideAccepted = gameState.playStone(Player.BLACK, Action.getInstance("A2"));
    assertFalse("Play by BLACK at A2 should have failed.", suicideAccepted);

    List<Action> validMoves = new ValidMoveGenerator().getActions(
            new GameConfig(), gameState, Player.BLACK, ActionGenerator.ALL_ACTIONS);

    assertEquals(4, validMoves.size());
    assertTrue(validMoves.contains(Action.PASS));
    for (String legalMove : new String[] { "C1", "C2", "C3" }) {
        assertTrue(validMoves.contains(Action.getInstance(legalMove)));
    }
    assertFalse(validMoves.contains(Action.getInstance("A2")));
}
} }

View File

@@ -14,4 +14,56 @@ public class LegalMoveTest {
assertTrue(gameState.playStone(Player.WHITE, Action.getInstance("B2"))); assertTrue(gameState.playStone(Player.WHITE, Action.getInstance("B2")));
System.out.println(gameState); System.out.println(gameState);
} }
/**
 * Unit test based on an illegal move from a 9x9 game using MonteCarloUCT;
 * the illegal move was detected by the gokgs.com server. After the position
 * is set up, WHITE at H5 must be accepted.
 */
@Test
public void testLegalMove2Liberties() {
    final String[] blackMoves = {
            "G5", "G7", "F6", "H6", "C7", "D7", "E7", "F7", "G8", "H9",
            "J7", "E5", "F4", "G3", "D4", "E3", "B4", "C3", "D2", "E1" };
    final String[] whiteMoves = {
            "H8", "H7", "D9", "D8", "E8", "A7", "A6", "B8", "B7", "B6",
            "C5", "D5", "D6", "A3", "B3", "B1", "F1" };

    GameState gameState = new GameState(9);
    // All BLACK stones are placed first, then all WHITE stones, in the
    // listed order.
    for (String coordinate : blackMoves) {
        gameState.playStone(Player.BLACK, Action.getInstance(coordinate));
    }
    for (String coordinate : whiteMoves) {
        gameState.playStone(Player.WHITE, Action.getInstance(coordinate));
    }

    System.out.println("State before move: ");
    System.out.println(gameState);

    boolean moveAccepted = gameState.playStone(Player.WHITE, Action.getInstance("H5"));
    assertTrue("Play by WHITE at H5 should not have failed.", moveAccepted);
}
} }

View File

@@ -1,6 +1,8 @@
package net.woodyfolsom.msproj.policy; package net.woodyfolsom.msproj.policy;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
@@ -35,21 +37,14 @@ public class RandomTest {
gameState.playStone(Player.BLACK, Action.getInstance("A4")); gameState.playStone(Player.BLACK, Action.getInstance("A4"));
gameState.playStone(Player.BLACK, Action.getInstance("B1"));; gameState.playStone(Player.BLACK, Action.getInstance("B1"));;
gameState.playStone(Player.BLACK, Action.getInstance("B2")); gameState.playStone(Player.BLACK, Action.getInstance("B2"));
//gameState.playStone('B', 3, GameBoard.BLACK_STONE);
gameState.playStone(Player.BLACK, Action.getInstance("B4")); gameState.playStone(Player.BLACK, Action.getInstance("B4"));
gameState.playStone(Player.BLACK, Action.getInstance("C2")); gameState.playStone(Player.BLACK, Action.getInstance("C2"));
gameState.playStone(Player.BLACK, Action.getInstance("C3")); gameState.playStone(Player.BLACK, Action.getInstance("C3"));
gameState.playStone(Player.BLACK, Action.getInstance("C4")); gameState.playStone(Player.BLACK, Action.getInstance("C4"));
gameState.playStone(Player.BLACK, Action.getInstance("D4")); gameState.playStone(Player.BLACK, Action.getInstance("D4"));
gameState.playStone(Player.WHITE, Action.getInstance("C1")); gameState.playStone(Player.WHITE, Action.getInstance("C1"));
gameState.playStone(Player.WHITE, Action.getInstance("D2")); gameState.playStone(Player.WHITE, Action.getInstance("D2"));
gameState.playStone(Player.WHITE, Action.getInstance("D3")); gameState.playStone(Player.WHITE, Action.getInstance("D3"));
System.out.println("State before random WHITE move selection:"); System.out.println("State before random WHITE move selection:");
System.out.println(gameState); System.out.println(gameState);
//This is correct - checked vs. MFOG //This is correct - checked vs. MFOG
@@ -61,4 +56,55 @@ public class RandomTest {
System.out.println(gameState); System.out.println(gameState);
} }
/**
 * RandomMovePolicy must never offer BLACK the suicide move A2 on a 3x3
 * board where WHITE holds A1, B1, B2, A3 and B3.
 */
@Test
public void testIllegalMoveSuicide() {
    GameState gameState = new GameState(3);
    for (String whiteMove : new String[] { "A1", "B1", "B2", "A3", "B3" }) {
        gameState.playStone(Player.WHITE, Action.getInstance(whiteMove));
    }

    System.out.println("State before move: ");
    System.out.println(gameState);

    RandomMovePolicy randomMovePolicy = new RandomMovePolicy();

    // If the bug recurs, the illegal move is picked with probability 1/4 per
    // call, so there is only a minute chance (5E-7) that 50 calls all miss it.
    for (int attempt = 0; attempt < 50; attempt++) {
        Action action = randomMovePolicy.getAction(new GameConfig(), gameState, Player.BLACK);
        assertFalse("RandomMovePolicy returned illegal suicide move A2",
                action.equals(Action.getInstance("A2")));
    }
}
/**
 * After BLACK plays B2 in the prepared 4x4 position, RandomMovePolicy must
 * always return some move for WHITE and must never return B3, which would
 * be a Ko violation.
 */
@Test
public void testIllegalMoveKo() {
    GameState gameState = new GameState(4);
    for (String whiteMove : new String[] { "B1", "A2", "C2", "B3" }) {
        gameState.playStone(Player.WHITE, Action.getInstance(whiteMove));
    }
    for (String blackMove : new String[] { "A3", "C3", "B4" }) {
        gameState.playStone(Player.BLACK, Action.getInstance(blackMove));
    }

    System.out.println("State before move: ");
    System.out.println(gameState);

    assertTrue(gameState.playStone(Player.BLACK, Action.getInstance("B2")));

    System.out.println("State after move: ");
    System.out.println(gameState);

    RandomMovePolicy randomMovePolicy = new RandomMovePolicy();

    // Over 50 selections the policy must never return B3.
    for (int attempt = 0; attempt < 50; attempt++) {
        Action action = randomMovePolicy.getAction(new GameConfig(), gameState, Player.WHITE);
        assertFalse(action.equals(Action.NONE));
        assertFalse("RandomMovePolicy returned Ko violation move B3",
                action.equals(Action.getInstance("B3")));
    }
}
} }