Fixed Minimax search to use the new GameTreeNode, MinimaxProperty classes.

The previous implementation was overly complicated and may have been buggy except when searching only 2 plies ahead.
This commit is contained in:
cs6601
2012-08-30 10:51:04 -04:00
parent 2e40440838
commit 4a1c64843d
12 changed files with 249 additions and 153 deletions

View File

@@ -7,6 +7,10 @@ public class StateEvaluator {
this.gameConfig = gameConfig; this.gameConfig = gameConfig;
} }
/** Returns the GameConfig this evaluator was constructed with. */
public GameConfig getGameConfig() {
return gameConfig;
}
public GameScore scoreGame(GameState gameState) { public GameScore scoreGame(GameState gameState) {
GameBoard gameBoard; GameBoard gameBoard;
if (gameState.getGameBoard().isTerritoryMarked()) { if (gameState.getGameBoard().isTerritoryMarked()) {

View File

@@ -1,69 +0,0 @@
package net.woodyfolsom.msproj.policy;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import net.woodyfolsom.msproj.Action;
import net.woodyfolsom.msproj.GameState;
/**
 * A node in a game tree: wraps a GameState, tracks visit/win statistics for
 * Monte Carlo style search, and links to its parent and to children keyed by
 * the Action that produced them.
 */
public class GameTreeNode {
    private final GameState gameState;
    private final Map<Action, GameTreeNode> children = new HashMap<Action, GameTreeNode>();
    private GameTreeNode parent;
    private int numVisits;
    private int numWins;

    public GameTreeNode(GameState gameState) {
        this.gameState = gameState;
    }

    /** Registers {@code child} under {@code action} and re-parents it to this node. */
    public void addChild(Action action, GameTreeNode child) {
        child.parent = this;
        children.put(action, child);
    }

    /** @return the actions for which children exist (live view of the child map) */
    public Set<Action> getActions() {
        return children.keySet();
    }

    public GameTreeNode getChild(Action action) {
        return children.get(action);
    }

    public GameState getGameState() {
        return gameState;
    }

    public int getNumChildren() {
        return children.size();
    }

    public int getNumVisits() {
        return numVisits;
    }

    public int getNumWins() {
        return numWins;
    }

    public GameTreeNode getParent() {
        return parent;
    }

    /** Adds one to the visit count. */
    public void incrementVisits() {
        numVisits++;
    }

    /** Adds one to the win count. */
    public void incrementWins() {
        numWins++;
    }

    /** @return true if this node has no parent */
    public boolean isRoot() {
        return parent == null;
    }

    /** @return true if this node currently has no children (i.e. is a leaf) */
    public boolean isTerminal() {
        return children.isEmpty();
    }
}

View File

@@ -1,77 +1,131 @@
package net.woodyfolsom.msproj.policy; package net.woodyfolsom.msproj.policy;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import net.woodyfolsom.msproj.Action; import net.woodyfolsom.msproj.Action;
import net.woodyfolsom.msproj.GameConfig; import net.woodyfolsom.msproj.GameConfig;
import net.woodyfolsom.msproj.GameScore;
import net.woodyfolsom.msproj.GameState; import net.woodyfolsom.msproj.GameState;
import net.woodyfolsom.msproj.GoGame; import net.woodyfolsom.msproj.GoGame;
import net.woodyfolsom.msproj.Player; import net.woodyfolsom.msproj.Player;
import net.woodyfolsom.msproj.StateEvaluator; import net.woodyfolsom.msproj.StateEvaluator;
import net.woodyfolsom.msproj.tree.GameTreeNode;
import net.woodyfolsom.msproj.tree.MinimaxProperties;
public class Minimax implements Policy { public class Minimax implements Policy {
private static final int DEFAULT_RECURSIVE_PLAYS = 1; private static final int DEFAULT_LOOKAHEAD = 1;
private final ValidMoveGenerator validMoveGenerator = new ValidMoveGenerator(); private final ValidMoveGenerator validMoveGenerator = new ValidMoveGenerator();
private int lookAhead;
/** Creates a Minimax policy with the default lookahead (DEFAULT_LOOKAHEAD). */
public Minimax() {
this(DEFAULT_LOOKAHEAD);
}
/**
 * Creates a Minimax policy.
 *
 * @param lookAhead lookahead budget; the search depth passed to getMax/getMin
 *                  is lookAhead * 2 plies
 */
public Minimax(int lookAhead) {
this.lookAhead = lookAhead;
}
@Override @Override
public Action getAction(GameConfig gameConfig, GameState gameState, public Action getAction(GameConfig gameConfig, GameState gameState,
Player color) { Player player) {
MoveCandidate moveCandidate = findBestMinimaxResult(
DEFAULT_RECURSIVE_PLAYS * 2,
gameConfig, gameState, color, false, Action.PASS);
return moveCandidate.move;
}
private MoveCandidate findBestMinimaxResult(int recursionLevels,
GameConfig gameConfig, GameState gameState,
Player initialColor, boolean playAsOpponent, Action bestPrevMove) {
StateEvaluator stateEvaluator = new StateEvaluator(gameConfig); StateEvaluator stateEvaluator = new StateEvaluator(gameConfig);
List<MoveCandidate> randomMoveCandidates = new ArrayList<MoveCandidate>();
Player colorPlaying = GoGame.getColorToPlay(initialColor, playAsOpponent); GameTreeNode<MinimaxProperties> rootNode = new GameTreeNode<MinimaxProperties>(gameState, new MinimaxProperties());
List<Action> validMoves = validMoveGenerator.getActions(gameConfig, if (player == Player.BLACK) {
gameState, colorPlaying, ActionGenerator.ALL_ACTIONS); return getMax(
lookAhead * 2,
for (Action randomMove : validMoves) { stateEvaluator,
GameState stateCopy = new GameState(gameState); rootNode,
stateCopy.playStone(colorPlaying, randomMove); player);
if (recursionLevels > 1) { } else {
randomMoveCandidates.add(findBestMinimaxResult(recursionLevels - 1, return getMin(
gameConfig, stateCopy, initialColor, lookAhead * 2,
!playAsOpponent, randomMove)); stateEvaluator,
} else { rootNode,
GameScore score = stateEvaluator.scoreGame(stateCopy); player);
randomMoveCandidates.add(new MoveCandidate(randomMove, score));
}
}
// TODO use a sorted list and just return the last element
MoveCandidate bestMove = randomMoveCandidates.get(0);
double bestScoreSoFar = bestMove.score.getScore(colorPlaying);
for (MoveCandidate moveCandidate : randomMoveCandidates) {
if (moveCandidate.score.getScore(colorPlaying) > bestScoreSoFar) {
bestMove = moveCandidate;
bestScoreSoFar = moveCandidate.score.getScore(colorPlaying);
}
}
// Fix to prevent thinking that the _opponent's_ best move is the move
// to make.
// If evaluating an opponent's move, the best move (for my opponent) is
// my previous move which gives the opponent the highest score.
// This should only happen if recursionLevels is initially odd.
if (playAsOpponent) {
return new MoveCandidate(bestPrevMove, bestMove.score);
} else { // if evaluating my own move, the move which gives me the
// highest score is the best.
return bestMove;
} }
} }
/**
 * Expands {@code node} one ply for the maximizing player and returns the
 * action leading to the child with the highest reward. Recurses into
 * {@link #getMin} for the opponent until {@code recursionLevels} reaches 1,
 * at which point leaf states are scored with the StateEvaluator. The best
 * child reward is also stored on {@code node}'s properties so ancestors can
 * read it.
 *
 * Improvements over the previous version: the node's GameState was copied
 * into a temporary and then copied again for every move — we now copy
 * directly from the node's state once per move; and the separate second pass
 * over the children to select the maximum is folded into the expansion loop.
 *
 * @param recursionLevels remaining search depth in plies (expected >= 1)
 * @param stateEvaluator  evaluator used to score leaf positions
 * @param node            node whose children are generated; its reward is set
 * @param player          the maximizing player to move at this node
 * @return the best action found, or Action.NONE if there are no valid moves
 */
private Action getMax(int recursionLevels,
        StateEvaluator stateEvaluator,
        GameTreeNode<MinimaxProperties> node,
        Player player) {
    List<Action> validMoves = validMoveGenerator.getActions(
            stateEvaluator.getGameConfig(), node.getGameState(), player,
            ActionGenerator.ALL_ACTIONS);
    Action bestAction = Action.NONE;
    double maxScore = Double.NEGATIVE_INFINITY;
    for (Action nextMove : validMoves) {
        GameState nextState = new GameState(node.getGameState());
        nextState.playStone(player, nextMove);
        GameTreeNode<MinimaxProperties> childNode =
                new GameTreeNode<MinimaxProperties>(nextState, new MinimaxProperties());
        node.addChild(nextMove, childNode);
        if (recursionLevels > 1) {
            // Opponent replies: minimize at the next ply.
            getMin(recursionLevels - 1, stateEvaluator, childNode,
                    GoGame.getColorToPlay(player, true));
        } else {
            // Tail condition - set reward of this leaf node directly.
            childNode.getProperties().setReward(
                    stateEvaluator.scoreGame(nextState).getAggregateScore());
        }
        double gameScore = childNode.getProperties().getReward();
        if (gameScore > maxScore) {
            maxScore = gameScore;
            bestAction = nextMove;
        }
    }
    node.getProperties().setReward(maxScore);
    return bestAction;
}
/**
 * Expands {@code node} one ply for the minimizing player and returns the
 * action leading to the child with the lowest reward. Recurses into
 * {@link #getMax} for the opponent until {@code recursionLevels} reaches 1,
 * at which point leaf states are scored with the StateEvaluator. The best
 * (lowest) child reward is also stored on {@code node}'s properties.
 *
 * Improvements over the previous version: the node's GameState was copied
 * into a temporary and then copied again for every move — we now copy
 * directly from the node's state once per move; and the separate second pass
 * over the children to select the minimum is folded into the expansion loop.
 *
 * @param recursionLevels remaining search depth in plies (expected >= 1)
 * @param stateEvaluator  evaluator used to score leaf positions
 * @param node            node whose children are generated; its reward is set
 * @param player          the minimizing player to move at this node
 * @return the best action found, or Action.NONE if there are no valid moves
 */
private Action getMin(int recursionLevels,
        StateEvaluator stateEvaluator,
        GameTreeNode<MinimaxProperties> node,
        Player player) {
    List<Action> validMoves = validMoveGenerator.getActions(
            stateEvaluator.getGameConfig(), node.getGameState(), player,
            ActionGenerator.ALL_ACTIONS);
    Action bestAction = Action.NONE;
    double minScore = Double.POSITIVE_INFINITY;
    for (Action nextMove : validMoves) {
        GameState nextState = new GameState(node.getGameState());
        nextState.playStone(player, nextMove);
        GameTreeNode<MinimaxProperties> childNode =
                new GameTreeNode<MinimaxProperties>(nextState, new MinimaxProperties());
        node.addChild(nextMove, childNode);
        if (recursionLevels > 1) {
            // Opponent replies: maximize at the next ply.
            getMax(recursionLevels - 1, stateEvaluator, childNode,
                    GoGame.getColorToPlay(player, true));
        } else {
            // Tail condition - set reward of this leaf node directly.
            childNode.getProperties().setReward(
                    stateEvaluator.scoreGame(nextState).getAggregateScore());
        }
        double gameScore = childNode.getProperties().getReward();
        if (gameScore < minScore) {
            minScore = gameScore;
            bestAction = nextMove;
        }
    }
    node.getProperties().setReward(minScore);
    return bestAction;
}
} }

View File

@@ -7,6 +7,8 @@ import net.woodyfolsom.msproj.Action;
import net.woodyfolsom.msproj.GameConfig; import net.woodyfolsom.msproj.GameConfig;
import net.woodyfolsom.msproj.GameState; import net.woodyfolsom.msproj.GameState;
import net.woodyfolsom.msproj.Player; import net.woodyfolsom.msproj.Player;
import net.woodyfolsom.msproj.tree.GameTreeNode;
import net.woodyfolsom.msproj.tree.MonteCarloProperties;
public abstract class MonteCarlo implements Policy { public abstract class MonteCarlo implements Policy {
protected Policy movePolicy; protected Policy movePolicy;
@@ -24,7 +26,7 @@ public abstract class MonteCarlo implements Policy {
* @param node * @param node
* @return * @return
*/ */
public abstract List<GameTreeNode> descend(GameTreeNode node); public abstract List<GameTreeNode<MonteCarloProperties>> descend(GameTreeNode<MonteCarloProperties> node);
@Override @Override
public Action getAction(GameConfig gameConfig, GameState gameState, public Action getAction(GameConfig gameConfig, GameState gameState,
@@ -35,20 +37,20 @@ public abstract class MonteCarlo implements Policy {
//Note that this may lose the game by forfeit even when picking any random move could //Note that this may lose the game by forfeit even when picking any random move could
//result in a win. //result in a win.
GameTreeNode rootNode = new GameTreeNode(gameState); GameTreeNode<MonteCarloProperties> rootNode = new GameTreeNode<MonteCarloProperties>(gameState, new MonteCarloProperties());
do { do {
//TODO these return types may need to be lists for some MC methods //TODO these return types may need to be lists for some MC methods
List<GameTreeNode> selectedNodes = descend(rootNode); List<GameTreeNode<MonteCarloProperties>> selectedNodes = descend(rootNode);
List<GameTreeNode> newLeaves = new ArrayList<GameTreeNode>(); List<GameTreeNode<MonteCarloProperties>> newLeaves = new ArrayList<GameTreeNode<MonteCarloProperties>>();
for (GameTreeNode selectedNode: selectedNodes) { for (GameTreeNode<MonteCarloProperties> selectedNode: selectedNodes) {
for (GameTreeNode newLeaf : grow(selectedNode)) { for (GameTreeNode<MonteCarloProperties> newLeaf : grow(selectedNode)) {
newLeaves.add(newLeaf); newLeaves.add(newLeaf);
} }
} }
for (GameTreeNode newLeaf : newLeaves) { for (GameTreeNode<MonteCarloProperties> newLeaf : newLeaves) {
int reward = rollout(newLeaf); int reward = rollout(newLeaf);
update(newLeaf, reward); update(newLeaf, reward);
} }
@@ -63,13 +65,13 @@ public abstract class MonteCarlo implements Policy {
return elapsedTime; return elapsedTime;
} }
public abstract Action getBestAction(GameTreeNode node); public abstract Action getBestAction(GameTreeNode<MonteCarloProperties> node);
public abstract List<GameTreeNode> grow(GameTreeNode node); public abstract List<GameTreeNode<MonteCarloProperties>> grow(GameTreeNode<MonteCarloProperties> node);
public abstract int rollout(GameTreeNode node); public abstract int rollout(GameTreeNode<MonteCarloProperties> node);
public abstract void update(GameTreeNode node, int reward); public abstract void update(GameTreeNode<MonteCarloProperties> node, int reward);
public long getSearchTimeLimit() { public long getSearchTimeLimit() {
return searchTimeLimit; return searchTimeLimit;

View File

@@ -4,6 +4,8 @@ import java.util.ArrayList;
import java.util.List; import java.util.List;
import net.woodyfolsom.msproj.Action; import net.woodyfolsom.msproj.Action;
import net.woodyfolsom.msproj.tree.GameTreeNode;
import net.woodyfolsom.msproj.tree.MonteCarloProperties;
public class MonteCarloUCT extends MonteCarlo { public class MonteCarloUCT extends MonteCarlo {
@@ -12,17 +14,20 @@ public class MonteCarloUCT extends MonteCarlo {
} }
@Override @Override
public List<GameTreeNode> descend(GameTreeNode node) { public List<GameTreeNode<MonteCarloProperties>> descend(GameTreeNode<MonteCarloProperties> node) {
double bestScore = (double) node.getNumWins() / node.getNumVisits(); double bestScore = Double.NEGATIVE_INFINITY;
GameTreeNode bestNode = node; GameTreeNode<MonteCarloProperties> bestNode = node;
//This appears slightly redundant with getBestAction() but it is not - //This appears slightly redundant with getBestAction() but it is not -
//descend() may pick the current node rather than a child to expand (if a child has a good score but high/low uncertainty) //descend() may pick the current node rather than a child to expand (if a child has a good score but high/low uncertainty)
//but getBestAction specifically asks for the optimum action to take from the current node, //but getBestAction specifically asks for the optimum action to take from the current node,
//even if it results in a worse next state. //even if it results in a worse next state.
for (Action action : node.getActions()) { for (Action action : node.getActions()) {
GameTreeNode childNode = node.getChild(action); GameTreeNode<MonteCarloProperties> childNode = node.getChild(action);
double childScore = (double) childNode.getNumWins() / childNode.getNumVisits();
MonteCarloProperties properties = childNode.getProperties();
double childScore = (double) properties.getWins() / properties.getVisits();
if (childScore >= bestScore) { if (childScore >= bestScore) {
bestScore = childScore; bestScore = childScore;
bestNode = childNode; bestNode = childNode;
@@ -30,7 +35,7 @@ public class MonteCarloUCT extends MonteCarlo {
} }
if (bestNode == node) { if (bestNode == node) {
List<GameTreeNode> bestNodeList = new ArrayList<GameTreeNode>(); List<GameTreeNode<MonteCarloProperties>> bestNodeList = new ArrayList<GameTreeNode<MonteCarloProperties>>();
bestNodeList.add(bestNode); bestNodeList.add(bestNode);
return bestNodeList; return bestNodeList;
} else { } else {
@@ -39,13 +44,16 @@ public class MonteCarloUCT extends MonteCarlo {
} }
@Override @Override
public Action getBestAction(GameTreeNode node) { public Action getBestAction(GameTreeNode<MonteCarloProperties> node) {
Action bestAction = Action.NONE; Action bestAction = Action.NONE;
double bestScore = Double.NEGATIVE_INFINITY; double bestScore = Double.NEGATIVE_INFINITY;
for (Action action : node.getActions()) { for (Action action : node.getActions()) {
GameTreeNode childNode = node.getChild(action); GameTreeNode<MonteCarloProperties> childNode = node.getChild(action);
double childScore = (double) childNode.getNumWins() / childNode.getNumVisits();
MonteCarloProperties properties = childNode.getProperties();
double childScore = (double) properties.getWins() / properties.getVisits();
if (childScore >= bestScore) { if (childScore >= bestScore) {
bestScore = childScore; bestScore = childScore;
bestAction = action; bestAction = action;
@@ -56,19 +64,19 @@ public class MonteCarloUCT extends MonteCarlo {
} }
@Override @Override
public List<GameTreeNode> grow(GameTreeNode node) { public List<GameTreeNode<MonteCarloProperties>> grow(GameTreeNode<MonteCarloProperties> node) {
// TODO Auto-generated method stub // TODO Auto-generated method stub
return null; return null;
} }
@Override @Override
public int rollout(GameTreeNode node) { public int rollout(GameTreeNode<MonteCarloProperties> node) {
// TODO Auto-generated method stub // TODO Auto-generated method stub
return 0; return 0;
} }
@Override @Override
public void update(GameTreeNode node, int reward) { public void update(GameTreeNode<MonteCarloProperties> node, int reward) {
// TODO Auto-generated method stub // TODO Auto-generated method stub
} }

View File

@@ -1,14 +0,0 @@
package net.woodyfolsom.msproj.policy;
import net.woodyfolsom.msproj.Action;
import net.woodyfolsom.msproj.GameScore;
/**
 * Immutable pairing of a candidate move with the score of the position it
 * produces; used to rank moves during search.
 */
public class MoveCandidate {
    /** The candidate move. */
    public final Action move;
    /** Score of the game after playing {@code move}. */
    public final GameScore score;

    public MoveCandidate(Action move, GameScore score) {
        this.move = move;
        this.score = score;
    }
}

View File

@@ -0,0 +1,18 @@
package net.woodyfolsom.msproj.tree;
/**
 * Per-node bounds for alpha-beta search, attached to a GameTreeNode.
 *
 * NOTE(review): the class name is misspelled ("Propeties" should be
 * "Properties"); renaming would touch every referencing file, so it is only
 * flagged here.
 * NOTE(review): alpha and beta are initialized to 0 rather than the
 * conventional -infinity/+infinity bounds — confirm callers set them before
 * pruning decisions are made.
 */
public class AlphaBetaPropeties extends GameTreeNodeProperties{
// Package-private bounds; exposed via getters/setters below.
int alpha = 0;
int beta = 0;
/** @return the current alpha bound */
public int getAlpha() {
return alpha;
}
/** @param alpha the alpha bound to store */
public void setAlpha(int alpha) {
this.alpha = alpha;
}
/** @return the current beta bound */
public int getBeta() {
return beta;
}
/** @param beta the beta bound to store */
public void setBeta(int beta) {
this.beta = beta;
}
}

View File

@@ -0,0 +1,57 @@
package net.woodyfolsom.msproj.tree;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import net.woodyfolsom.msproj.Action;
import net.woodyfolsom.msproj.GameState;
/**
 * A node in a generic game tree. Each node wraps a GameState, carries a
 * properties object of type {@code T} holding algorithm-specific data
 * (minimax rewards, Monte Carlo statistics, ...), and links to its parent
 * and to children keyed by the Action that produced them.
 *
 * @param <T> the per-node properties type
 */
public class GameTreeNode<T extends GameTreeNodeProperties> {
    private final GameState gameState;
    private final T properties;
    private final Map<Action, GameTreeNode<T>> children = new HashMap<Action, GameTreeNode<T>>();
    private GameTreeNode<T> parent;

    public GameTreeNode(GameState gameState, T properties) {
        this.gameState = gameState;
        this.properties = properties;
    }

    /** Registers {@code child} under {@code action} and re-parents it to this node. */
    public void addChild(Action action, GameTreeNode<T> child) {
        child.parent = this;
        children.put(action, child);
    }

    /** @return the actions for which children exist (live view of the child map) */
    public Set<Action> getActions() {
        return children.keySet();
    }

    public GameTreeNode<T> getChild(Action action) {
        return children.get(action);
    }

    public GameState getGameState() {
        return gameState;
    }

    public int getNumChildren() {
        return children.size();
    }

    public GameTreeNode<T> getParent() {
        return parent;
    }

    public T getProperties() {
        return properties;
    }

    /** @return true if this node has no parent */
    public boolean isRoot() {
        return parent == null;
    }

    /** @return true if this node currently has no children (i.e. is a leaf) */
    public boolean isTerminal() {
        return children.isEmpty();
    }
}

View File

@@ -0,0 +1,5 @@
package net.woodyfolsom.msproj.tree;
/**
 * Marker base class for per-node data attached to a GameTreeNode; concrete
 * subclasses (e.g. MinimaxProperties, MonteCarloProperties) carry the
 * algorithm-specific statistics.
 */
public abstract class GameTreeNodeProperties {
}

View File

@@ -0,0 +1,13 @@
package net.woodyfolsom.msproj.tree;
/**
 * Per-node data for Minimax search: the backed-up reward computed for the
 * node's game state.
 */
public class MinimaxProperties extends GameTreeNodeProperties {
    // Backed-up score; remains 0.0 until the search assigns a value.
    private double reward = 0.0;

    /** @return the reward currently stored for this node */
    public double getReward() {
        return reward;
    }

    /** @param reward the backed-up score to store for this node */
    public void setReward(double reward) {
        this.reward = reward;
    }
}

View File

@@ -0,0 +1,18 @@
package net.woodyfolsom.msproj.tree;
/**
 * Per-node statistics for Monte Carlo tree search: how often the node has
 * been visited and how many wins were recorded through it.
 *
 * Adds incrementVisits()/incrementWins() helpers matching the API of the
 * GameTreeNode class this commit replaces, so update code does not have to
 * write setVisits(getVisits() + 1).
 */
public class MonteCarloProperties extends GameTreeNodeProperties {
    // Package-private counters, exposed via the accessors below.
    int visits = 0;
    int wins = 0;

    public int getVisits() {
        return visits;
    }

    public void setVisits(int visits) {
        this.visits = visits;
    }

    /** Adds one to the visit count. */
    public void incrementVisits() {
        visits++;
    }

    public int getWins() {
        return wins;
    }

    public void setWins(int wins) {
        this.wins = wins;
    }

    /** Adds one to the win count. */
    public void incrementWins() {
        wins++;
    }
}

View File

@@ -24,7 +24,7 @@ public class MinimaxTest {
System.out.println(gameState); System.out.println(gameState);
System.out.println("Generated move: " + move); System.out.println("Generated move: " + move);
assertEquals("Expected B3 but was: " + move, "B3", move); assertEquals("Expected B3 but was: " + move, Action.getInstance("B3"), move);
gameState.playStone(Player.WHITE, move); gameState.playStone(Player.WHITE, move);
System.out.println(gameState); System.out.println(gameState);
@@ -45,7 +45,7 @@ public class MinimaxTest {
System.out.println(gameState); System.out.println(gameState);
System.out.println("Generated move: " + move); System.out.println("Generated move: " + move);
assertEquals("Expected B3 but was: " + move, "B3", move); assertEquals("Expected B3 but was: " + move, Action.getInstance("B3"), move);
gameState.playStone(Player.BLACK, move); gameState.playStone(Player.BLACK, move);
System.out.println(gameState); System.out.println(gameState);