Fixed Minimax search to use the new GameTreeNode and MinimaxProperties classes.

The previous implementation was overly complicated and was likely buggy except when searching exactly 2 plies ahead.
cs6601
2012-08-30 10:51:04 -04:00
parent 2e40440838
commit 4a1c64843d
12 changed files with 249 additions and 153 deletions
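
As context for the diffs below: a minimal sketch of how the new generic tree classes fit together, assuming an existing GameState instance is passed in. The sketch class name is hypothetical and not part of this commit; it only uses constructors and methods that appear in the diffs.

import net.woodyfolsom.msproj.Action;
import net.woodyfolsom.msproj.GameState;
import net.woodyfolsom.msproj.tree.GameTreeNode;
import net.woodyfolsom.msproj.tree.MinimaxProperties;

public class GameTreeNodeSketch {
    // Builds a two-node tree around an existing state to show how the typed
    // properties object (here MinimaxProperties) travels with each node.
    public static double sketch(GameState rootState, Action move) {
        GameTreeNode<MinimaxProperties> root =
                new GameTreeNode<MinimaxProperties>(rootState, new MinimaxProperties());

        GameState childState = new GameState(rootState); // copy constructor, as used in Minimax
        GameTreeNode<MinimaxProperties> child =
                new GameTreeNode<MinimaxProperties>(childState, new MinimaxProperties());
        root.addChild(move, child);

        child.getProperties().setReward(3.5);             // e.g. a leaf evaluation
        return root.getChild(move).getProperties().getReward();
    }
}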

View File

@@ -7,6 +7,10 @@ public class StateEvaluator {
this.gameConfig = gameConfig;
}
public GameConfig getGameConfig() {
return gameConfig;
}
public GameScore scoreGame(GameState gameState) {
GameBoard gameBoard;
if (gameState.getGameBoard().isTerritoryMarked()) {

View File

@@ -1,69 +0,0 @@
package net.woodyfolsom.msproj.policy;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import net.woodyfolsom.msproj.Action;
import net.woodyfolsom.msproj.GameState;
public class GameTreeNode {
private GameState gameState;
private GameTreeNode parent;
private int numVisits;
private int numWins;
private Map<Action, GameTreeNode> children = new HashMap<Action, GameTreeNode>();
public GameTreeNode(GameState gameState) {
this.gameState = gameState;
}
public void addChild(Action action, GameTreeNode child) {
children.put(action, child);
child.parent = this;
}
public Set<Action> getActions() {
return children.keySet();
}
public GameTreeNode getChild(Action action) {
return children.get(action);
}
public int getNumChildren() {
return children.size();
}
public GameState getGameState() {
return gameState;
}
public int getNumVisits() {
return numVisits;
}
public int getNumWins() {
return numWins;
}
public GameTreeNode getParent() {
return parent;
}
public boolean isRoot() {
return parent == null;
}
public boolean isTerminal() {
return children.size() == 0;
}
public void incrementVisits() {
numVisits++;
}
public void incrementWins() {
numWins++;
}
}

View File

@@ -1,77 +1,131 @@
package net.woodyfolsom.msproj.policy;
import java.util.ArrayList;
import java.util.List;
import net.woodyfolsom.msproj.Action;
import net.woodyfolsom.msproj.GameConfig;
import net.woodyfolsom.msproj.GameScore;
import net.woodyfolsom.msproj.GameState;
import net.woodyfolsom.msproj.GoGame;
import net.woodyfolsom.msproj.Player;
import net.woodyfolsom.msproj.StateEvaluator;
import net.woodyfolsom.msproj.tree.GameTreeNode;
import net.woodyfolsom.msproj.tree.MinimaxProperties;
public class Minimax implements Policy {
private static final int DEFAULT_RECURSIVE_PLAYS = 1;
private static final int DEFAULT_LOOKAHEAD = 1;
private final ValidMoveGenerator validMoveGenerator = new ValidMoveGenerator();
private int lookAhead;
public Minimax() {
this(DEFAULT_LOOKAHEAD);
}
public Minimax(int lookAhead) {
this.lookAhead = lookAhead;
}
@Override
public Action getAction(GameConfig gameConfig, GameState gameState,
Player color) {
MoveCandidate moveCandidate = findBestMinimaxResult(
DEFAULT_RECURSIVE_PLAYS * 2,
gameConfig, gameState, color, false, Action.PASS);
return moveCandidate.move;
}
private MoveCandidate findBestMinimaxResult(int recursionLevels,
GameConfig gameConfig, GameState gameState,
Player initialColor, boolean playAsOpponent, Action bestPrevMove) {
Player player) {
StateEvaluator stateEvaluator = new StateEvaluator(gameConfig);
List<MoveCandidate> randomMoveCandidates = new ArrayList<MoveCandidate>();
Player colorPlaying = GoGame.getColorToPlay(initialColor, playAsOpponent);
GameTreeNode<MinimaxProperties> rootNode = new GameTreeNode<MinimaxProperties>(gameState, new MinimaxProperties());
List<Action> validMoves = validMoveGenerator.getActions(gameConfig,
gameState, colorPlaying, ActionGenerator.ALL_ACTIONS);
for (Action randomMove : validMoves) {
GameState stateCopy = new GameState(gameState);
stateCopy.playStone(colorPlaying, randomMove);
if (recursionLevels > 1) {
randomMoveCandidates.add(findBestMinimaxResult(recursionLevels - 1,
gameConfig, stateCopy, initialColor,
!playAsOpponent, randomMove));
} else {
GameScore score = stateEvaluator.scoreGame(stateCopy);
randomMoveCandidates.add(new MoveCandidate(randomMove, score));
}
}
// TODO use a sorted list and just return the last element
MoveCandidate bestMove = randomMoveCandidates.get(0);
double bestScoreSoFar = bestMove.score.getScore(colorPlaying);
for (MoveCandidate moveCandidate : randomMoveCandidates) {
if (moveCandidate.score.getScore(colorPlaying) > bestScoreSoFar) {
bestMove = moveCandidate;
bestScoreSoFar = moveCandidate.score.getScore(colorPlaying);
}
}
// Fix to prevent thinking that the _opponent's_ best move is the move
// to make.
// If evaluating an opponent's move, the best move (for my opponent) is
// my previous move which gives the opponent the highest score.
// This should only happen if recursionLevels is initially odd.
if (playAsOpponent) {
return new MoveCandidate(bestPrevMove, bestMove.score);
} else { // if evaluating my own move, the move which gives me the
// highest score is the best.
return bestMove;
if (player == Player.BLACK) {
return getMax(
lookAhead * 2,
stateEvaluator,
rootNode,
player);
} else {
return getMin(
lookAhead * 2,
stateEvaluator,
rootNode,
player);
}
}
private Action getMax(int recursionLevels,
StateEvaluator stateEvaluator,
GameTreeNode<MinimaxProperties> node,
Player player) {
GameState gameState = new GameState(node.getGameState());
List<Action> validMoves = validMoveGenerator.getActions(stateEvaluator.getGameConfig(),
node.getGameState(), player, ActionGenerator.ALL_ACTIONS);
for (Action nextMove : validMoves) {
GameState nextState = new GameState(gameState);
nextState.playStone(player, nextMove);
GameTreeNode<MinimaxProperties> childNode = new GameTreeNode<MinimaxProperties>(nextState, new MinimaxProperties());
node.addChild(nextMove, childNode);
if (recursionLevels > 1) {
getMin(recursionLevels - 1, stateEvaluator, childNode, GoGame.getColorToPlay(player, true));
} else {
//tail condition - set reward of this leaf node
childNode.getProperties().setReward(stateEvaluator.scoreGame(nextState).getAggregateScore());
}
}
double maxScore = Double.NEGATIVE_INFINITY;
Action bestAction = Action.NONE;
for (Action nextMove : validMoves) {
GameTreeNode<MinimaxProperties> childNode = node.getChild(nextMove);
double gameScore = childNode.getProperties().getReward();
if (gameScore > maxScore) {
maxScore = gameScore;
bestAction = nextMove;
}
}
node.getProperties().setReward(maxScore);
return bestAction;
}
private Action getMin(int recursionLevels,
StateEvaluator stateEvaluator,
GameTreeNode<MinimaxProperties> node,
Player player) {
GameState gameState = new GameState(node.getGameState());
List<Action> validMoves = validMoveGenerator.getActions(stateEvaluator.getGameConfig(),
node.getGameState(), player, ActionGenerator.ALL_ACTIONS);
for (Action nextMove : validMoves) {
GameState nextState = new GameState(gameState);
nextState.playStone(player, nextMove);
GameTreeNode<MinimaxProperties> childNode = new GameTreeNode<MinimaxProperties>(nextState, new MinimaxProperties());
node.addChild(nextMove, childNode);
if (recursionLevels > 1) {
getMax(recursionLevels - 1, stateEvaluator, childNode, GoGame.getColorToPlay(player, true));
} else {
//tail condition - set reward of this leaf node
childNode.getProperties().setReward(stateEvaluator.scoreGame(nextState).getAggregateScore());
}
}
double minScore = Double.POSITIVE_INFINITY;
Action bestAction = Action.NONE;
for (Action nextMove : validMoves) {
GameTreeNode<MinimaxProperties> childNode = node.getChild(nextMove);
double gameScore = childNode.getProperties().getReward();
if (gameScore < minScore) {
minScore = gameScore;
bestAction = nextMove;
}
}
node.getProperties().setReward(minScore);
return bestAction;
}
}
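
A usage sketch of the reworked policy, assuming a configured GameConfig and GameState are already available; the wrapper class and method names are hypothetical, but the Minimax constructors, getAction signature, and playStone call are taken from the code above and the tests below.

import net.woodyfolsom.msproj.Action;
import net.woodyfolsom.msproj.GameConfig;
import net.woodyfolsom.msproj.GameState;
import net.woodyfolsom.msproj.Player;
import net.woodyfolsom.msproj.policy.Minimax;
import net.woodyfolsom.msproj.policy.Policy;

public class MinimaxUsageSketch {
    // Picks and plays one move for BLACK. new Minimax(2) expands
    // lookAhead * 2 = 4 plies; the no-arg constructor keeps the 2-ply
    // behaviour of the old implementation via DEFAULT_LOOKAHEAD = 1.
    public static void playOneMove(GameConfig gameConfig, GameState gameState) {
        Policy minimax = new Minimax(2);
        Action move = minimax.getAction(gameConfig, gameState, Player.BLACK);
        gameState.playStone(Player.BLACK, move); // playStone(Player, Action) as used in MinimaxTest
    }
}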

View File

@@ -7,6 +7,8 @@ import net.woodyfolsom.msproj.Action;
import net.woodyfolsom.msproj.GameConfig;
import net.woodyfolsom.msproj.GameState;
import net.woodyfolsom.msproj.Player;
import net.woodyfolsom.msproj.tree.GameTreeNode;
import net.woodyfolsom.msproj.tree.MonteCarloProperties;
public abstract class MonteCarlo implements Policy {
protected Policy movePolicy;
@@ -24,7 +26,7 @@ public abstract class MonteCarlo implements Policy {
* @param node
* @return
*/
public abstract List<GameTreeNode> descend(GameTreeNode node);
public abstract List<GameTreeNode<MonteCarloProperties>> descend(GameTreeNode<MonteCarloProperties> node);
@Override
public Action getAction(GameConfig gameConfig, GameState gameState,
@@ -35,20 +37,20 @@ public abstract class MonteCarlo implements Policy {
//Note that this may lose the game by forfeit even when picking any random move could
//result in a win.
GameTreeNode rootNode = new GameTreeNode(gameState);
GameTreeNode<MonteCarloProperties> rootNode = new GameTreeNode<MonteCarloProperties>(gameState, new MonteCarloProperties());
do {
//TODO these return types may need to be lists for some MC methods
List<GameTreeNode> selectedNodes = descend(rootNode);
List<GameTreeNode> newLeaves = new ArrayList<GameTreeNode>();
List<GameTreeNode<MonteCarloProperties>> selectedNodes = descend(rootNode);
List<GameTreeNode<MonteCarloProperties>> newLeaves = new ArrayList<GameTreeNode<MonteCarloProperties>>();
for (GameTreeNode selectedNode: selectedNodes) {
for (GameTreeNode newLeaf : grow(selectedNode)) {
for (GameTreeNode<MonteCarloProperties> selectedNode: selectedNodes) {
for (GameTreeNode<MonteCarloProperties> newLeaf : grow(selectedNode)) {
newLeaves.add(newLeaf);
}
}
for (GameTreeNode newLeaf : newLeaves) {
for (GameTreeNode<MonteCarloProperties> newLeaf : newLeaves) {
int reward = rollout(newLeaf);
update(newLeaf, reward);
}
@@ -63,13 +65,13 @@ public abstract class MonteCarlo implements Policy {
return elapsedTime;
}
public abstract Action getBestAction(GameTreeNode node);
public abstract Action getBestAction(GameTreeNode<MonteCarloProperties> node);
public abstract List<GameTreeNode> grow(GameTreeNode node);
public abstract List<GameTreeNode<MonteCarloProperties>> grow(GameTreeNode<MonteCarloProperties> node);
public abstract int rollout(GameTreeNode node);
public abstract int rollout(GameTreeNode<MonteCarloProperties> node);
public abstract void update(GameTreeNode node, int reward);
public abstract void update(GameTreeNode<MonteCarloProperties> node, int reward);
public long getSearchTimeLimit() {
return searchTimeLimit;
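
The loop above drives the abstract descend/grow/rollout/update hooks; update() itself is still a stub in MonteCarloUCT. A minimal sketch of the kind of body it might get, assuming a positive rollout reward counts as a win; that convention is an assumption, not something the commit states.

// Hypothetical body for update() in a MonteCarlo subclass such as MonteCarloUCT.
// Walks parent links back to the root (whose parent is null), bumping the visit
// count on every node and the win count when the reward is positive. The new
// generic GameTreeNode has no incrementVisits()/incrementWins(), so the counters
// are updated through the MonteCarloProperties setters.
@Override
public void update(GameTreeNode<MonteCarloProperties> node, int reward) {
    for (GameTreeNode<MonteCarloProperties> current = node;
            current != null;
            current = current.getParent()) {
        MonteCarloProperties props = current.getProperties();
        props.setVisits(props.getVisits() + 1);
        if (reward > 0) {
            props.setWins(props.getWins() + 1);
        }
    }
}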

View File

@@ -4,6 +4,8 @@ import java.util.ArrayList;
import java.util.List;
import net.woodyfolsom.msproj.Action;
import net.woodyfolsom.msproj.tree.GameTreeNode;
import net.woodyfolsom.msproj.tree.MonteCarloProperties;
public class MonteCarloUCT extends MonteCarlo {
@@ -12,17 +14,20 @@ public class MonteCarloUCT extends MonteCarlo {
}
@Override
public List<GameTreeNode> descend(GameTreeNode node) {
double bestScore = (double) node.getNumWins() / node.getNumVisits();
GameTreeNode bestNode = node;
public List<GameTreeNode<MonteCarloProperties>> descend(GameTreeNode<MonteCarloProperties> node) {
double bestScore = Double.NEGATIVE_INFINITY;
GameTreeNode<MonteCarloProperties> bestNode = node;
//This appears slightly redundant with getBestAction() but it is not -
//descend() may pick the current node rather than a child to expand (if a child has a good score but high/low uncertainty)
//but getBestAction specifically asks for the optimum action to take from the current node,
//even if it results in a worse next state.
for (Action action : node.getActions()) {
GameTreeNode childNode = node.getChild(action);
double childScore = (double) childNode.getNumWins() / childNode.getNumVisits();
GameTreeNode<MonteCarloProperties> childNode = node.getChild(action);
MonteCarloProperties properties = childNode.getProperties();
double childScore = (double) properties.getWins() / properties.getVisits();
if (childScore >= bestScore) {
bestScore = childScore;
bestNode = childNode;
@@ -30,7 +35,7 @@ public class MonteCarloUCT extends MonteCarlo {
}
if (bestNode == node) {
List<GameTreeNode> bestNodeList = new ArrayList<GameTreeNode>();
List<GameTreeNode<MonteCarloProperties>> bestNodeList = new ArrayList<GameTreeNode<MonteCarloProperties>>();
bestNodeList.add(bestNode);
return bestNodeList;
} else {
@@ -39,13 +44,16 @@ public class MonteCarloUCT extends MonteCarlo {
}
@Override
public Action getBestAction(GameTreeNode node) {
public Action getBestAction(GameTreeNode<MonteCarloProperties> node) {
Action bestAction = Action.NONE;
double bestScore = Double.NEGATIVE_INFINITY;
for (Action action : node.getActions()) {
GameTreeNode childNode = node.getChild(action);
double childScore = (double) childNode.getNumWins() / childNode.getNumVisits();
GameTreeNode<MonteCarloProperties> childNode = node.getChild(action);
MonteCarloProperties properties = childNode.getProperties();
double childScore = (double) properties.getWins() / properties.getVisits();
if (childScore >= bestScore) {
bestScore = childScore;
bestAction = action;
@@ -56,19 +64,19 @@ public class MonteCarloUCT extends MonteCarlo {
}
@Override
public List<GameTreeNode> grow(GameTreeNode node) {
public List<GameTreeNode<MonteCarloProperties>> grow(GameTreeNode<MonteCarloProperties> node) {
// TODO Auto-generated method stub
return null;
}
@Override
public int rollout(GameTreeNode node) {
public int rollout(GameTreeNode<MonteCarloProperties> node) {
// TODO Auto-generated method stub
return 0;
}
@Override
public void update(GameTreeNode node, int reward) {
public void update(GameTreeNode<MonteCarloProperties> node, int reward) {
// TODO Auto-generated method stub
}
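
The descend() comment above notes that selection should weigh a child's uncertainty as well as its score, but the committed version ranks children by raw win rate only. For reference, a standard UCB1 score (not part of this commit) built on the same counters would look roughly like this; the helper name and exploration constant are assumptions.

// Hypothetical UCB1 score for a child considered during descend(): win rate
// plus an exploration bonus that shrinks as the child accumulates visits.
private static double ucb1(GameTreeNode<MonteCarloProperties> parent,
        GameTreeNode<MonteCarloProperties> child,
        double explorationConstant) {
    MonteCarloProperties p = parent.getProperties();
    MonteCarloProperties c = child.getProperties();
    if (c.getVisits() == 0) {
        return Double.POSITIVE_INFINITY; // always try unvisited children first
    }
    double exploitation = (double) c.getWins() / c.getVisits();
    double exploration = explorationConstant
            * Math.sqrt(Math.log(p.getVisits()) / c.getVisits());
    return exploitation + exploration;
}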

View File

@@ -1,14 +0,0 @@
package net.woodyfolsom.msproj.policy;
import net.woodyfolsom.msproj.Action;
import net.woodyfolsom.msproj.GameScore;
public class MoveCandidate {
public final Action move;
public final GameScore score;
public MoveCandidate(Action move, GameScore score) {
this.move = move;
this.score = score;
}
}

View File

@@ -0,0 +1,18 @@
package net.woodyfolsom.msproj.tree;
public class AlphaBetaProperties extends GameTreeNodeProperties {
int alpha = 0;
int beta = 0;
public int getAlpha() {
return alpha;
}
public void setAlpha(int alpha) {
this.alpha = alpha;
}
public int getBeta() {
return beta;
}
public void setBeta(int beta) {
this.beta = beta;
}
}
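
AlphaBetaProperties is added here but not yet used by Minimax. Purely as an illustration of where it could plug in, a hypothetical helper applying the standard alpha-beta cutoff test at a max node; note the committed class stores int bounds while MinimaxProperties rewards are double, so a real integration would have to reconcile the types.

// Hypothetical helper, not part of this commit: raises alpha at a max node and
// reports whether the remaining siblings can be pruned (alpha >= beta).
static boolean raiseAlphaAndCheckCutoff(AlphaBetaProperties bounds, int childValue) {
    if (childValue > bounds.getAlpha()) {
        bounds.setAlpha(childValue); // better guaranteed value for the maximizer
    }
    return bounds.getAlpha() >= bounds.getBeta();
}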

View File

@@ -0,0 +1,57 @@
package net.woodyfolsom.msproj.tree;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import net.woodyfolsom.msproj.Action;
import net.woodyfolsom.msproj.GameState;
public class GameTreeNode<T extends GameTreeNodeProperties> {
private GameState gameState;
private GameTreeNode<T> parent;
private Map<Action, GameTreeNode<T>> children = new HashMap<Action, GameTreeNode<T>>();
private T properties;
public GameTreeNode(GameState gameState, T properties) {
this.gameState = gameState;
this.properties = properties;
}
public void addChild(Action action, GameTreeNode<T> child) {
children.put(action, child);
child.parent = this;
}
public Set<Action> getActions() {
return children.keySet();
}
public GameTreeNode<T> getChild(Action action) {
return children.get(action);
}
public int getNumChildren() {
return children.size();
}
public GameState getGameState() {
return gameState;
}
public GameTreeNode<T> getParent() {
return parent;
}
public T getProperties() {
return properties;
}
public boolean isRoot() {
return parent == null;
}
public boolean isTerminal() {
return children.size() == 0;
}
}

View File

@@ -0,0 +1,5 @@
package net.woodyfolsom.msproj.tree;
public abstract class GameTreeNodeProperties {
}

View File

@@ -0,0 +1,13 @@
package net.woodyfolsom.msproj.tree;
public class MinimaxProperties extends GameTreeNodeProperties {
private double reward = 0.0;
public double getReward() {
return reward;
}
public void setReward(double reward) {
this.reward = reward;
}
}

View File

@@ -0,0 +1,18 @@
package net.woodyfolsom.msproj.tree;
public class MonteCarloProperties extends GameTreeNodeProperties {
int visits = 0;
int wins = 0;
public int getVisits() {
return visits;
}
public void setVisits(int visits) {
this.visits = visits;
}
public int getWins() {
return wins;
}
public void setWins(int wins) {
this.wins = wins;
}
}

View File

@@ -24,7 +24,7 @@ public class MinimaxTest {
System.out.println(gameState);
System.out.println("Generated move: " + move);
assertEquals("Expected B3 but was: " + move, "B3", move);
assertEquals("Expected B3 but was: " + move, Action.getInstance("B3"), move);
gameState.playStone(Player.WHITE, move);
System.out.println(gameState);
@@ -45,7 +45,7 @@ public class MinimaxTest {
System.out.println(gameState);
System.out.println("Generated move: " + move);
assertEquals("Expected B3 but was: " + move, "B3", move);
assertEquals("Expected B3 but was: " + move, Action.getInstance("B3"), move);
gameState.playStone(Player.BLACK, move);
System.out.println(gameState);