diff --git a/src/aima/core/agent/Action.java b/src/aima/core/agent/Action.java new file mode 100644 index 0000000..e21e090 --- /dev/null +++ b/src/aima/core/agent/Action.java @@ -0,0 +1,19 @@ +package aima.core.agent; + +/** + * Describes an Action that can or has been taken by an Agent via one of its + * Actuators. + * + * @author Ciaran O'Reilly + */ +public interface Action { + + /** + * Indicates whether or not this Action is a 'No Operation'.
+ * Note: AIMA3e - NoOp, or no operation, is the name of an assembly language + * instruction that does nothing. + * + * @return true if this is a NoOp Action. + */ + //boolean isNoOp(); +} diff --git a/src/aima/core/environment/cellworld/Cell.java b/src/aima/core/environment/cellworld/Cell.java new file mode 100644 index 0000000..fa6c4ea --- /dev/null +++ b/src/aima/core/environment/cellworld/Cell.java @@ -0,0 +1,87 @@ +package aima.core.environment.cellworld; + +/** + * Artificial Intelligence A Modern Approach (3rd Edition): page 645.
+ *
+ * A representation of a Cell in the environment detailed in Figure 17.1. + * + * @param + * the content type of the cell. + * + * @author Ciaran O'Reilly + * @author Ravi Mohan + */ +public class Cell { + private int x = 1; + private int y = 1; + private C content = null; + + /** + * Construct a Cell. + * + * @param x + * the x position of the cell. + * @param y + * the y position of the cell. + * @param content + * the initial content of the cell. + */ + public Cell(int x, int y, C content) { + this.x = x; + this.y = y; + this.content = content; + } + + /** + * + * @return the x position of the cell. + */ + public int getX() { + return x; + } + + /** + * + * @return the y position of the cell. + */ + public int getY() { + return y; + } + + /** + * + * @return the content of the cell. + */ + public C getContent() { + return content; + } + + /** + * Set the cell's content. + * + * @param content + * the content to be placed in the cell. + */ + public void setContent(C content) { + this.content = content; + } + + @Override + public String toString() { + return ""; + } + + @Override + public boolean equals(Object o) { + if (o instanceof Cell) { + Cell c = (Cell) o; + return x == c.x && y == c.y && content.equals(c.content); + } + return false; + } + + @Override + public int hashCode() { + return x + 23 + y + 31 * content.hashCode(); + } +} diff --git a/src/aima/core/environment/cellworld/CellWorld.java b/src/aima/core/environment/cellworld/CellWorld.java new file mode 100644 index 0000000..20d8a78 --- /dev/null +++ b/src/aima/core/environment/cellworld/CellWorld.java @@ -0,0 +1,123 @@ +package aima.core.environment.cellworld; + +import java.util.HashMap; +import java.util.LinkedHashSet; +import java.util.Map; +import java.util.Set; + +/** + * Artificial Intelligence A Modern Approach (3rd Edition): page 645.
+ *
+ * + * A representation for the environment depicted in figure 17.1.
+ *
+ * Note: the x and y coordinates are always positive integers starting at + * 1.
+ * Note: If looking at a rectangle - the coordinate (x=1, y=1) will be the + * bottom left hand corner.
+ * + * + * @param + * the type of content for the Cells in the world. + * + * @author Ciaran O'Reilly + * @author Ravi Mohan + */ +public class CellWorld { + private Set> cells = new LinkedHashSet>(); + private Map>> cellLookup = new HashMap>>(); + + /** + * Construct a Cell World with size xDimension * y Dimension cells, all with + * their values set to a default content value. + * + * @param xDimension + * the size of the x dimension. + * @param yDimension + * the size of the y dimension. + * + * @param defaultCellContent + * the default content to assign to each cell created. + */ + public CellWorld(int xDimension, int yDimension, C defaultCellContent) { + for (int x = 1; x <= xDimension; x++) { + Map> xCol = new HashMap>(); + for (int y = 1; y <= yDimension; y++) { + Cell c = new Cell(x, y, defaultCellContent); + cells.add(c); + xCol.put(y, c); + } + cellLookup.put(x, xCol); + } + } + + /** + * + * @return all the cells in this world. + */ + public Set> getCells() { + return cells; + } + + /** + * Determine what cell would be moved into if the specified action is + * performed in the specified cell. Normally, this will be the cell adjacent + * in the appropriate direction. However, if there is no cell in the + * adjacent direction of the action then the outcome of the action is to + * stay in the same cell as the action was performed in. + * + * @param s + * the cell location from which the action is to be performed. + * @param a + * the action to perform (Up, Down, Left, or Right). + * @return the Cell an agent would end up in if they performed the specified + * action from the specified cell location. + */ + public Cell result(Cell s, CellWorldAction a) { + Cell sDelta = getCellAt(a.getXResult(s.getX()), a.getYResult(s + .getY())); + if (null == sDelta) { + // Default to no effect + // (i.e. bumps back in place as no adjoining cell). + sDelta = s; + } + + return sDelta; + } + + /** + * Remove the cell at the specified location from this Cell World. 
This + * allows you to introduce barriers into different location. + * + * @param x + * the x dimension of the cell to be removed. + * @param y + * the y dimension of the cell to be removed. + */ + public void removeCell(int x, int y) { + Map> xCol = cellLookup.get(x); + if (null != xCol) { + cells.remove(xCol.remove(y)); + } + } + + /** + * Get the cell at the specified x and y locations. + * + * @param x + * the x dimension of the cell to be retrieved. + * @param y + * the y dimension of the cell to be retrieved. + * @return the cell at the specified x,y location, null if no cell exists at + * this location. + */ + public Cell getCellAt(int x, int y) { + Cell c = null; + Map> xCol = cellLookup.get(x); + if (null != xCol) { + c = xCol.get(y); + } + + return c; + } +} diff --git a/src/aima/core/environment/cellworld/CellWorldAction.java b/src/aima/core/environment/cellworld/CellWorldAction.java new file mode 100644 index 0000000..ae14bd8 --- /dev/null +++ b/src/aima/core/environment/cellworld/CellWorldAction.java @@ -0,0 +1,142 @@ +package aima.core.environment.cellworld; + +import java.util.LinkedHashSet; +import java.util.Set; + +import aima.core.agent.Action; + +/** + * Artificial Intelligence A Modern Approach (3rd Edition): page 645.
+ *
+ * + * The actions in every state are Up, Down, Left, and Right (this implementation also defines None as a no-op).
+ *
+ * Note: Moving 'Up' causes y to increase by 1, 'Down' y to decrease by + 1, 'Left' x to decrease by 1, and 'Right' x to increase by 1 within a Cell + World. + * + * @author Ciaran O'Reilly + * + */ +public enum CellWorldAction implements Action { + Up, Down, Left, Right, None; + + private static final Set _actions = new LinkedHashSet(); + static { + _actions.add(Up); + _actions.add(Down); + _actions.add(Left); + _actions.add(Right); + _actions.add(None); + } + + /** + * + * @return a set of the actual actions. + */ + public static final Set actions() { + return _actions; + } + + // + // START-Action + //@Override + //public boolean isNoOp() { + // if (None == this) { + // return true; + // } + // return false; + //} + // END-Action + // + + /** + * + * @param curX + * the current x position. + * @return the result on the x position of applying this action. + */ + public int getXResult(int curX) { + int newX = curX; + + switch (this) { + case Left: + newX--; + break; + case Right: + newX++; + break; + } + + return newX; + } + + /** + * + * @param curY + * the current y position. + * @return the result on the y position of applying this action. + */ + public int getYResult(int curY) { + int newY = curY; + + switch (this) { + case Up: + newY++; + break; + case Down: + newY--; + break; + } + + return newY; + } + + /** + * + * @return the first right angled action related to this action. + */ + public CellWorldAction getFirstRightAngledAction() { + CellWorldAction a = null; + + switch (this) { + case Up: + case Down: + a = Left; + break; + case Left: + case Right: + a = Down; + break; + case None: + a = None; + break; + } + + return a; + } + + /** + * + * @return the second right angled action related to this action.
+ */ + public CellWorldAction getSecondRightAngledAction() { + CellWorldAction a = null; + + switch (this) { + case Up: + case Down: + a = Right; + break; + case Left: + case Right: + a = Up; + break; + case None: + a = None; + break; + } + + return a; + } +} diff --git a/src/aima/core/environment/cellworld/CellWorldFactory.java b/src/aima/core/environment/cellworld/CellWorldFactory.java new file mode 100644 index 0000000..16ad6ac --- /dev/null +++ b/src/aima/core/environment/cellworld/CellWorldFactory.java @@ -0,0 +1,27 @@ +package aima.core.environment.cellworld; + +/** + * + * @author Ciaran O'Reilly + * + */ +public class CellWorldFactory { + + /** + * Create the cell world as defined in Figure 17.1 in AIMA3e. (a) A simple 4 + * x 3 environment that presents the agent with a sequential decision + * problem. + * + * @return a cell world representation of Fig 17.1 in AIMA3e. + */ + public static CellWorld createCellWorldForFig17_1() { + CellWorld cw = new CellWorld(4, 3, -0.04); + + cw.removeCell(2, 2); + + cw.getCellAt(4, 3).setContent(1.0); + cw.getCellAt(4, 2).setContent(-1.0); + + return cw; + } +} \ No newline at end of file diff --git a/src/aima/core/environment/gridworld/GridCell.java b/src/aima/core/environment/gridworld/GridCell.java new file mode 100644 index 0000000..3390949 --- /dev/null +++ b/src/aima/core/environment/gridworld/GridCell.java @@ -0,0 +1,87 @@ +package aima.core.environment.gridworld; + +/** + * Artificial Intelligence A Modern Approach (3rd Edition): page 645.
+ *
+ * A representation of a Cell in the environment detailed in Figure 17.1. + * + * @param + * the content type of the cell. + * + * @author Ciaran O'Reilly + * @author Ravi Mohan + */ +public class GridCell { + private int x = 1; + private int y = 1; + private C content = null; + + /** + * Construct a Cell. + * + * @param x + * the x position of the cell. + * @param y + * the y position of the cell. + * @param content + * the initial content of the cell. + */ + public GridCell(int x, int y, C content) { + this.x = x; + this.y = y; + this.content = content; + } + + /** + * + * @return the x position of the cell. + */ + public int getX() { + return x; + } + + /** + * + * @return the y position of the cell. + */ + public int getY() { + return y; + } + + /** + * + * @return the content of the cell. + */ + public C getContent() { + return content; + } + + /** + * Set the cell's content. + * + * @param content + * the content to be placed in the cell. + */ + public void setContent(C content) { + this.content = content; + } + + @Override + public String toString() { + return ""; + } + + @Override + public boolean equals(Object o) { + if (o instanceof GridCell) { + GridCell c = (GridCell) o; + return x == c.x && y == c.y && content.equals(c.content); + } + return false; + } + + @Override + public int hashCode() { + return x + 23 + y + 31 * content.hashCode(); + } +} diff --git a/src/aima/core/environment/gridworld/GridWorld.java b/src/aima/core/environment/gridworld/GridWorld.java new file mode 100644 index 0000000..7faa7e1 --- /dev/null +++ b/src/aima/core/environment/gridworld/GridWorld.java @@ -0,0 +1,56 @@ +package aima.core.environment.gridworld; + +import java.util.HashMap; +import java.util.LinkedHashSet; +import java.util.Map; +import java.util.Set; + +public class GridWorld { + private Set> cells = new LinkedHashSet>(); + private Map>> cellLookup = new HashMap>>(); + + public GridWorld(int xDimension, int yDimension, C defaultCellContent) { + for (int x = 1; x <= 
xDimension; x++) { + Map> xCol = new HashMap>(); + for (int y = 1; y <= yDimension; y++) { + GridCell c = new GridCell(x, y, defaultCellContent); + cells.add(c); + xCol.put(y, c); + } + cellLookup.put(x, xCol); + } + } + + public Set> getCells() { + return cells; + } + + public GridCell result(GridCell s, GridWorldAction a) { + GridCell sDelta = getCellAt(a.getXResult(s.getX()), a.getYResult(s + .getY())); + if (null == sDelta) { + // Default to no effect + // (i.e. bumps back in place as no adjoining cell). + sDelta = s; + } + + return sDelta; + } + + public void removeCell(int x, int y) { + Map> xCol = cellLookup.get(x); + if (null != xCol) { + cells.remove(xCol.remove(y)); + } + } + + public GridCell getCellAt(int x, int y) { + GridCell c = null; + Map> xCol = cellLookup.get(x); + if (null != xCol) { + c = xCol.get(y); + } + + return c; + } +} \ No newline at end of file diff --git a/src/aima/core/environment/gridworld/GridWorldAction.java b/src/aima/core/environment/gridworld/GridWorldAction.java new file mode 100644 index 0000000..4ee86ab --- /dev/null +++ b/src/aima/core/environment/gridworld/GridWorldAction.java @@ -0,0 +1,56 @@ +package aima.core.environment.gridworld; + +import java.util.LinkedHashSet; +import java.util.Set; + +import aima.core.agent.Action; + +public enum GridWorldAction implements Action { + AddTile,CaptureThree,RandomMove; + + private static final Set _actions = new LinkedHashSet(); + static { + _actions.add(AddTile); // try to add a tile, turn (low chance of capture) + _actions.add(CaptureThree); // try to subtract two tiles, add a turn (high chance of capture) + _actions.add(RandomMove); // try add a tile, add a turn (even chance of add/capture) + } + + public static final Set actions() { + return _actions; + } + + // + // START-Action + //@Override + //public boolean isNoOp() { + // if (None == this) { + // return true; + // } + // return false; + //} + // END-Action + // + + public int getXResult(int curX) { + int newX = curX; + + 
switch (this) { + case AddTile: + newX++; + break; + case CaptureThree: + newX-=2; + break; + case RandomMove: + newX--; + break; + } + + return newX; + } + + public int getYResult(int curY) { + //the score increments by 1 at every action, regardless + return curY+1; + } +} \ No newline at end of file diff --git a/src/aima/core/environment/gridworld/GridWorldFactory.java b/src/aima/core/environment/gridworld/GridWorldFactory.java new file mode 100644 index 0000000..35afccf --- /dev/null +++ b/src/aima/core/environment/gridworld/GridWorldFactory.java @@ -0,0 +1,23 @@ +package aima.core.environment.gridworld; + +/** + * + * @author Woody Folsom + * + */ +public class GridWorldFactory { + + /** + * Create a CellWorld modeling a TileGame where the objective is to reach the maximum Number of tiles without + * exceeding targetScore. + * + * @return a cell world representation of Fig 17.1 in AIMA3e. + */ + public static GridWorld createGridWorldForTileGame(int maxTiles, int maxScore, double nonTerminalReward) { + GridWorld cw = new GridWorld(maxTiles, maxScore, nonTerminalReward); + + cw.getCellAt(maxTiles, maxScore).setContent(1.0); + + return cw; + } +} \ No newline at end of file diff --git a/src/aima/core/probability/example/MDPFactory.java b/src/aima/core/probability/example/MDPFactory.java new file mode 100644 index 0000000..5494a73 --- /dev/null +++ b/src/aima/core/probability/example/MDPFactory.java @@ -0,0 +1,251 @@ +package aima.core.probability.example; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import aima.core.environment.cellworld.Cell; +import aima.core.environment.cellworld.CellWorld; +import aima.core.environment.cellworld.CellWorldAction; +import aima.core.environment.gridworld.GridCell; +import aima.core.environment.gridworld.GridWorld; +import aima.core.environment.gridworld.GridWorldAction; +import aima.core.probability.mdp.ActionsFunction; +import 
aima.core.probability.mdp.MarkovDecisionProcess; +import aima.core.probability.mdp.RewardFunction; +import aima.core.probability.mdp.TransitionProbabilityFunction; +import aima.core.probability.mdp.impl.MDP; + +/** + * + * @author Ciaran O'Reilly + * @author Ravi Mohan + */ +public class MDPFactory { + + /** + * Constructs an MDP that can be used to generate the utility values + * detailed in Fig 17.3. + * + * @param cw + * the cell world from figure 17.1. + * @return an MDP that can be used to generate the utility values detailed + * in Fig 17.3. + */ + public static MarkovDecisionProcess, CellWorldAction> createMDPForFigure17_3( + final CellWorld cw) { + + return new MDP, CellWorldAction>(cw.getCells(), + cw.getCellAt(1, 1), createActionsFunctionForFigure17_1(cw), + createTransitionProbabilityFunctionForFigure17_1(cw), + createRewardFunctionForFigure17_1()); + } + + public static MarkovDecisionProcess, GridWorldAction> createMDPForTileGame( + final GridWorld cw, int maxTiles, int maxScore) { + + return new MDP, GridWorldAction>(cw.getCells(), + cw.getCellAt(1, 1), createActionsFunctionForTileGame(cw,maxTiles,maxScore), + createTransitionProbabilityFunctionForTileGame(cw), + createRewardFunctionForTileGame()); + } + + /** + * Returns the allowed actions from a specified cell within the cell world + * described in Fig 17.1. + * + * @param cw + * the cell world from figure 17.1. + * @return the set of actions allowed at a particular cell. This set will be + * empty if at a terminal state. 
+ */ + public static ActionsFunction, CellWorldAction> createActionsFunctionForFigure17_1( + final CellWorld cw) { + final Set> terminals = new HashSet>(); + terminals.add(cw.getCellAt(4, 3)); + terminals.add(cw.getCellAt(4, 2)); + + ActionsFunction, CellWorldAction> af = new ActionsFunction, CellWorldAction>() { + + @Override + public Set actions(Cell s) { + // All actions can be performed in each cell + // (except terminal states) + if (terminals.contains(s)) { + return Collections.emptySet(); + } + return CellWorldAction.actions(); + } + }; + return af; + } + + public static ActionsFunction, GridWorldAction> createActionsFunctionForTileGame( + final GridWorld cw, int maxTiles, int maxScore) { + final Set> terminals = new HashSet>(); + terminals.add(cw.getCellAt(maxTiles,maxScore)); + + ActionsFunction, GridWorldAction> af = new ActionsFunction, GridWorldAction>() { + + @Override + public Set actions(GridCell s) { + // All actions can be performed in each cell + // (except terminal states) + if (terminals.contains(s)) { + return Collections.emptySet(); + } + return GridWorldAction.actions(); + } + }; + + return af; + } + + /** + * Figure 17.1 (b) Illustration of the transition model of the environment: + * the 'intended' outcome occurs with probability 0.8, but with probability + * 0.2 the agent moves at right angles to the intended direction. A + * collision with a wall results in no movement. + * + * @param cw + * the cell world from figure 17.1. + * @return the transition probability function as described in figure 17.1. 
+ */ + public static TransitionProbabilityFunction, CellWorldAction> createTransitionProbabilityFunctionForFigure17_1( + final CellWorld cw) { + TransitionProbabilityFunction, CellWorldAction> tf = new TransitionProbabilityFunction, CellWorldAction>() { + private double[] distribution = new double[] { 0.8, 0.1, 0.1 }; + + @Override + public double probability(Cell sDelta, Cell s, + CellWorldAction a) { + double prob = 0; + + List> outcomes = possibleOutcomes(s, a); + for (int i = 0; i < outcomes.size(); i++) { + if (sDelta.equals(outcomes.get(i))) { + // Note: You have to sum the matches to + // sDelta as the different actions + // could have the same effect (i.e. + // staying in place due to there being + // no adjacent cells), which increases + // the probability of the transition for + // that state. + prob += distribution[i]; + } + } + + return prob; + } + + private List> possibleOutcomes(Cell c, + CellWorldAction a) { + // There can be three possible outcomes for the planned action + List> outcomes = new ArrayList>(); + + outcomes.add(cw.result(c, a)); + outcomes.add(cw.result(c, a.getFirstRightAngledAction())); + outcomes.add(cw.result(c, a.getSecondRightAngledAction())); + + return outcomes; + } + }; + + return tf; + } + + public static TransitionProbabilityFunction, GridWorldAction> createTransitionProbabilityFunctionForTileGame( + final GridWorld cw) { + + TransitionProbabilityFunction, GridWorldAction> tf = new TransitionProbabilityFunction, GridWorldAction>() { + + @Override + public double probability(GridCell sDelta, GridCell s, + GridWorldAction a) { + double prob = 0; + + double[] distribution = getDistribution(a); + List> outcomes = possibleOutcomes(s, a); + for (int i = 0; i < outcomes.size(); i++) { + if (sDelta.equals(outcomes.get(i))) { + // Note: You have to sum the matches to + // sDelta as the different actions + // could have the same effect (i.e. 
+ // staying in place due to there being + // no adjacent cells), which increases + // the probability of the transition for + // that state. + prob += distribution[i]; + } + } + + return prob; + } + + private double[] getDistribution(GridWorldAction a) { + switch (a) { + case AddTile : + return new double[] { 0.66, 0.34 }; + case CaptureThree : + return new double[] { 0.34, 0.66 }; + case RandomMove : + return new double[] { 0.50, 0.50 }; + default : + throw new RuntimeException("Unrecognized action: " + a); + } + } + private List> possibleOutcomes(GridCell c, + GridWorldAction a) { + // There can be three possible outcomes for the planned action + List> outcomes = new ArrayList>(); + + switch (a) { + case AddTile : + outcomes.add(cw.result(c, GridWorldAction.AddTile)); + outcomes.add(cw.result(c, GridWorldAction.CaptureThree)); + break; + case CaptureThree : + outcomes.add(cw.result(c, GridWorldAction.AddTile)); + outcomes.add(cw.result(c, GridWorldAction.CaptureThree)); + break; + case RandomMove : + outcomes.add(cw.result(c, GridWorldAction.AddTile)); + outcomes.add(cw.result(c, GridWorldAction.CaptureThree)); + default : + //no possible outcomes for unrecognized actions + } + + return outcomes; + } + }; + + return tf; + } + + /** + * + * @return the reward function which takes the content of the cell as being + * the reward value. 
+ */ + public static RewardFunction> createRewardFunctionForFigure17_1() { + RewardFunction> rf = new RewardFunction>() { + @Override + public double reward(Cell s) { + return s.getContent(); + } + }; + return rf; + } + + public static RewardFunction> createRewardFunctionForTileGame() { + RewardFunction> rf = new RewardFunction>() { + @Override + public double reward(GridCell s) { + return s.getContent(); + } + }; + return rf; + } +} \ No newline at end of file diff --git a/src/aima/core/probability/mdp/ActionsFunction.java b/src/aima/core/probability/mdp/ActionsFunction.java new file mode 100644 index 0000000..cf7af83 --- /dev/null +++ b/src/aima/core/probability/mdp/ActionsFunction.java @@ -0,0 +1,27 @@ +package aima.core.probability.mdp; + +import java.util.Set; + +import aima.core.agent.Action; + +/** + * An interface for MDP action functions. + * + * @param + * the state type. + * @param + * the action type. + * + * @author Ciaran O'Reilly + * @author Ravi Mohan + */ +public interface ActionsFunction { + /** + * Get the set of actions for state s. + * + * @param s + * the state. + * @return the set of actions for state s. + */ + Set actions(S s); +} diff --git a/src/aima/core/probability/mdp/MarkovDecisionProcess.java b/src/aima/core/probability/mdp/MarkovDecisionProcess.java new file mode 100644 index 0000000..e729e55 --- /dev/null +++ b/src/aima/core/probability/mdp/MarkovDecisionProcess.java @@ -0,0 +1,79 @@ +package aima.core.probability.mdp; + +import java.util.Set; + +import aima.core.agent.Action; + +/** + * Artificial Intelligence A Modern Approach (3rd Edition): page 647.
+ *
+ * + * A sequential decision problem for a fully observable, stochastic environment + * with a Markovian transition model and additive rewards is called a Markov + * decision process, or MDP, and consists of a set of states (with an + * initial state s0); a set ACTIONS(s) of actions in each state; a + * transition model P(s' | s, a); and a reward function R(s).
+ *
+ * Note: Some definitions of MDPs allow the reward to depend on the + * action and outcome too, so the reward function is R(s, a, s'). This + * simplifies the description of some environments but does not change the + * problem in any fundamental way. + * + * @param + * the state type. + * @param
+ * the action type. + * + * @author Ciaran O'Reilly + * @author Ravi Mohan + * + */ +public interface MarkovDecisionProcess { + + /** + * Get the set of states associated with the Markov decision process. + * + * @return the set of states associated with the Markov decision process. + */ + Set states(); + + /** + * Get the initial state s0 for this instance of a Markov + * decision process. + * + * @return the initial state s0. + */ + S getInitialState(); + + /** + * Get the set of actions for state s. + * + * @param s + * the state. + * @return the set of actions for state s. + */ + Set actions(S s); + + /** + * Return the probability of going from state s using action a to s' based + * on the underlying transition model P(s' | s, a). + * + * @param sDelta + * the state s' being transitioned to. + * @param s + * the state s being transitions from. + * @param a + * the action used to move from state s to s'. + * @return the probability of going from state s using action a to s'. + */ + double transitionProbability(S sDelta, S s, A a); + + /** + * Get the reward associated with being in state s. + * + * @param s + * the state whose award is sought. + * @return the reward associated with being in state s. + */ + double reward(S s); +} diff --git a/src/aima/core/probability/mdp/Policy.java b/src/aima/core/probability/mdp/Policy.java new file mode 100644 index 0000000..59ca984 --- /dev/null +++ b/src/aima/core/probability/mdp/Policy.java @@ -0,0 +1,34 @@ +package aima.core.probability.mdp; + +import aima.core.agent.Action; + +/** + * Artificial Intelligence A Modern Approach (3rd Edition): page 647.
+ *
+ * + * A solution to a Markov decision process is called a policy. It + * specifies what the agent should do for any state that the agent might reach. + * It is traditional to denote a policy by π, and π(s) is the action + * recommended by the policy π for state s. If the agent has a complete + * policy, then no matter what the outcome of any action, the agent will always + * know what to do next. + * + * @param + * the state type. + * @param
+ * the action type. + * + * @author Ciaran O'Reilly + * @author Ravi Mohan + * + */ +public interface Policy { + /** + * π(s) is the action recommended by the policy π for state s. + * + * @param s + * the state s + * @return the action recommended by the policy π for state s. + */ + A action(S s); +} diff --git a/src/aima/core/probability/mdp/PolicyEvaluation.java b/src/aima/core/probability/mdp/PolicyEvaluation.java new file mode 100644 index 0000000..0ae5b48 --- /dev/null +++ b/src/aima/core/probability/mdp/PolicyEvaluation.java @@ -0,0 +1,39 @@ +package aima.core.probability.mdp; + +import java.util.Map; + +import aima.core.agent.Action; + +/** + * Artificial Intelligence A Modern Approach (3rd Edition): page 656.
+ *
+ * Given a policy πi, calculate + * Ui=Uπi, the utility of each state if + * πi were to be executed. + * + * @param + * the state type. + * @param
+ * the action type. + * + * @author Ciaran O'Reilly + * @author Ravi Mohan + */ +public interface PolicyEvaluation { + /** + * Policy evaluation: given a policy πi, calculate + * Ui=Uπi, the utility of each state if + * πi were to be executed. + * + * @param pi_i + * a policy vector indexed by state + * @param U + * a vector of utilities for states in S + * @param mdp + * an MDP with states S, actions A(s), transition model P(s'|s,a) + * @return Ui=Uπi, the utility of each + * state if πi were to be executed. + */ + Map evaluate(Map pi_i, Map U, + MarkovDecisionProcess mdp); +} diff --git a/src/aima/core/probability/mdp/RewardFunction.java b/src/aima/core/probability/mdp/RewardFunction.java new file mode 100644 index 0000000..e0085db --- /dev/null +++ b/src/aima/core/probability/mdp/RewardFunction.java @@ -0,0 +1,21 @@ +package aima.core.probability.mdp; + +/** + * An interface for MDP reward functions. + * + * @param + * the state type. + * @author Ciaran O'Reilly + * @author Ravi Mohan + */ +public interface RewardFunction { + + /** + * Get the reward associated with being in state s. + * + * @param s + * the state whose award is sought. + * @return the reward associated with being in state s. + */ + double reward(S s); +} \ No newline at end of file diff --git a/src/aima/core/probability/mdp/TransitionProbabilityFunction.java b/src/aima/core/probability/mdp/TransitionProbabilityFunction.java new file mode 100644 index 0000000..c3dd710 --- /dev/null +++ b/src/aima/core/probability/mdp/TransitionProbabilityFunction.java @@ -0,0 +1,31 @@ +package aima.core.probability.mdp; + +import aima.core.agent.Action; + +/** + * An interface for MDP transition probability functions. + * + * @param + * the state type. + * @param + * the action type. 
+ * + * @author Ciaran O'Reilly + * @author Ravi Mohan + */ +public interface TransitionProbabilityFunction { + + /** + * Return the probability of going from state s using action a to s' based + * on the underlying transition model P(s' | s, a). + * + * @param sDelta + * the state s' being transitioned to. + * @param s + * the state s being transitions from. + * @param a + * the action used to move from state s to s'. + * @return the probability of going from state s using action a to s'. + */ + double probability(S sDelta, S s, A a); +} diff --git a/src/aima/core/probability/mdp/impl/LookupPolicy.java b/src/aima/core/probability/mdp/impl/LookupPolicy.java new file mode 100644 index 0000000..2aa688e --- /dev/null +++ b/src/aima/core/probability/mdp/impl/LookupPolicy.java @@ -0,0 +1,36 @@ +package aima.core.probability.mdp.impl; + +import java.util.HashMap; +import java.util.Map; + +import aima.core.agent.Action; +import aima.core.probability.mdp.Policy; + +/** + * Default implementation of the Policy interface using an underlying Map to + * look up an action associated with a state. + * + * @param + * the state type. + * @param + * the action type. 
+ * + * @author Ciaran O'Reilly + */ +public class LookupPolicy implements Policy { + private Map policy = new HashMap(); + + public LookupPolicy(Map aPolicy) { + policy.putAll(aPolicy); + } + + // + // START-Policy + @Override + public A action(S s) { + return policy.get(s); + } + + // END-Policy + // +} diff --git a/src/aima/core/probability/mdp/impl/MDP.java b/src/aima/core/probability/mdp/impl/MDP.java new file mode 100644 index 0000000..69113ba --- /dev/null +++ b/src/aima/core/probability/mdp/impl/MDP.java @@ -0,0 +1,69 @@ +package aima.core.probability.mdp.impl; + +import java.util.Set; + +import aima.core.agent.Action; +import aima.core.probability.mdp.ActionsFunction; +import aima.core.probability.mdp.MarkovDecisionProcess; +import aima.core.probability.mdp.RewardFunction; +import aima.core.probability.mdp.TransitionProbabilityFunction; + +/** + * Default implementation of the MarkovDecisionProcess interface. + * + * @param + * the state type. + * @param + * the action type. + * + * @author Ciaran O'Reilly + * @author Ravi Mohan + */ +public class MDP implements MarkovDecisionProcess { + private Set states = null; + private S initialState = null; + private ActionsFunction actionsFunction = null; + private TransitionProbabilityFunction transitionProbabilityFunction = null; + private RewardFunction rewardFunction = null; + + public MDP(Set states, S initialState, + ActionsFunction actionsFunction, + TransitionProbabilityFunction transitionProbabilityFunction, + RewardFunction rewardFunction) { + this.states = states; + this.initialState = initialState; + this.actionsFunction = actionsFunction; + this.transitionProbabilityFunction = transitionProbabilityFunction; + this.rewardFunction = rewardFunction; + } + + // + // START-MarkovDecisionProcess + @Override + public Set states() { + return states; + } + + @Override + public S getInitialState() { + return initialState; + } + + @Override + public Set actions(S s) { + return actionsFunction.actions(s); + } + + 
@Override + public double transitionProbability(S sDelta, S s, A a) { + return transitionProbabilityFunction.probability(sDelta, s, a); + } + + @Override + public double reward(S s) { + return rewardFunction.reward(s); + } + + // END-MarkovDecisionProcess + // +} diff --git a/src/aima/core/probability/mdp/impl/ModifiedPolicyEvaluation.java b/src/aima/core/probability/mdp/impl/ModifiedPolicyEvaluation.java new file mode 100644 index 0000000..a910c94 --- /dev/null +++ b/src/aima/core/probability/mdp/impl/ModifiedPolicyEvaluation.java @@ -0,0 +1,93 @@ +package aima.core.probability.mdp.impl; + +import java.util.HashMap; +import java.util.Map; + +import aima.core.agent.Action; +import aima.core.probability.mdp.MarkovDecisionProcess; +import aima.core.probability.mdp.PolicyEvaluation; + +/** + * Artificial Intelligence A Modern Approach (3rd Edition): page 657.
+ *
+ * For small state spaces, policy evaluation using exact solution methods is + * often the most efficient approach. For large state spaces, O(n3) + * time might be prohibitive. Fortunately, it is not necessary to do exact + * policy evaluation. Instead, we can perform some number of simplified value + * iteration steps (simplified because the policy is fixed) to give a reasonably + * good approximation of utilities. The simplified Bellman update for this + * process is:
+ *
+ * + *
+ * Ui+1(s) <- R(s) + γΣs'P(s'|s,πi(s))Ui(s')
+ * 
+ * + * and this is repeated k times to produce the next utility estimate. The + * resulting algorithm is called modified policy iteration. It is often + * much more efficient than standard policy iteration or value iteration. + * + * + * @param + * the state type. + * @param
+ * the action type. + * + * @author Ciaran O'Reilly + * @author Ravi Mohan + * + */ +public class ModifiedPolicyEvaluation implements PolicyEvaluation { + // # iterations to use to produce the next utility estimate + private int k; + // discount γ to be used. + private double gamma; + + /** + * Constructor. + * + * @param k + * number iterations to use to produce the next utility estimate + * @param gamma + * discount γ to be used + */ + public ModifiedPolicyEvaluation(int k, double gamma) { + if (gamma > 1.0 || gamma <= 0.0) { + throw new IllegalArgumentException("Gamma must be > 0 and <= 1.0"); + } + this.k = k; + this.gamma = gamma; + } + + // + // START-PolicyEvaluation + @Override + public Map evaluate(Map pi_i, Map U, + MarkovDecisionProcess mdp) { + Map U_i = new HashMap(U); + Map U_ip1 = new HashMap(U); + // repeat k times to produce the next utility estimate + for (int i = 0; i < k; i++) { + // Ui+1(s) <- R(s) + + // γΣs'P(s'|s,πi(s))Ui(s') + for (S s : U.keySet()) { + A ap_i = pi_i.get(s); + double aSum = 0; + // Handle terminal states (i.e. 
no actions) + if (null != ap_i) { + for (S sDelta : U.keySet()) { + aSum += mdp.transitionProbability(sDelta, s, ap_i) + * U_i.get(sDelta); + } + } + U_ip1.put(s, mdp.reward(s) + gamma * aSum); + } + + U_i.putAll(U_ip1); + } + return U_ip1; + } + + // END-PolicyEvaluation + // +} diff --git a/src/aima/core/probability/mdp/search/PolicyIteration.java b/src/aima/core/probability/mdp/search/PolicyIteration.java new file mode 100644 index 0000000..8d2692d --- /dev/null +++ b/src/aima/core/probability/mdp/search/PolicyIteration.java @@ -0,0 +1,144 @@ +package aima.core.probability.mdp.search; + +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import aima.core.agent.Action; +import aima.core.probability.mdp.MarkovDecisionProcess; +import aima.core.probability.mdp.Policy; +import aima.core.probability.mdp.PolicyEvaluation; +import aima.core.probability.mdp.impl.LookupPolicy; +import aima.core.util.Util; + +/** + * Artificial Intelligence A Modern Approach (3rd Edition): page 657.
+ *
+ * + *
+ * function POLICY-ITERATION(mdp) returns a policy
+ *   inputs: mdp, an MDP with states S, actions A(s), transition model P(s' | s, a)
+ *   local variables: U, a vector of utilities for states in S, initially zero
+ *                    π, a policy vector indexed by state, initially random
+ *                    
+ *   repeat
+ *      U <- POLICY-EVALUATION(π, U, mdp)
+ *      unchanged? <- true
+ *      for each state s in S do
+ *          if maxa ∈ A(s) Σs'P(s'|s,a)U[s'] > Σs'P(s'|s,π[s])U[s'] then do
+ *             π[s] <- argmaxa ∈ A(s) Σs'P(s'|s,a)U[s']
+ *             unchanged? <- false
+ *   until unchanged?
+ *   return π
+ * 
+ * + * Figure 17.7 The policy iteration algorithm for calculating an optimal policy. + * + * @param + * the state type. + * @param
+ * the action type. + * + * @author Ciaran O'Reilly + * @author Ravi Mohan + * + */ +public class PolicyIteration { + + private PolicyEvaluation policyEvaluation = null; + + /** + * Constructor. + * + * @param policyEvaluation + * the policy evaluation function to use. + */ + public PolicyIteration(PolicyEvaluation policyEvaluation) { + this.policyEvaluation = policyEvaluation; + } + + // function POLICY-ITERATION(mdp) returns a policy + /** + * The policy iteration algorithm for calculating an optimal policy. + * + * @param mdp + * an MDP with states S, actions A(s), transition model P(s'|s,a) + * @return an optimal policy + */ + public Policy policyIteration(MarkovDecisionProcess mdp) { + // local variables: U, a vector of utilities for states in S, initially + // zero + Map U = Util.create(mdp.states(), new Double(0)); + // π, a policy vector indexed by state, initially random + Map pi = initialPolicyVector(mdp); + boolean unchanged; + // repeat + do { + // U <- POLICY-EVALUATION(π, U, mdp) + U = policyEvaluation.evaluate(pi, U, mdp); + // unchanged? <- true + unchanged = true; + // for each state s in S do + for (S s : mdp.states()) { + // calculate: + // maxa ∈ A(s) + // Σs'P(s'|s,a)U[s'] + double aMax = Double.NEGATIVE_INFINITY, piVal = 0; + A aArgmax = pi.get(s); + for (A a : mdp.actions(s)) { + double aSum = 0; + for (S sDelta : mdp.states()) { + aSum += mdp.transitionProbability(sDelta, s, a) + * U.get(sDelta); + } + if (aSum > aMax) { + aMax = aSum; + aArgmax = a; + } + // track: + // Σs'P(s'|s,π[s])U[s'] + if (a.equals(pi.get(s))) { + piVal = aSum; + } + } + // if maxa ∈ A(s) + // Σs'P(s'|s,a)U[s'] + // > Σs'P(s'|s,π[s])U[s'] then do + if (aMax > piVal) { + // π[s] <- argmaxa ∈A(s) + // Σs'P(s'|s,a)U[s'] + pi.put(s, aArgmax); + // unchanged? <- false + unchanged = false; + } + } + // until unchanged? + } while (!unchanged); + + // return π + return new LookupPolicy(pi); + } + + /** + * Create a policy vector indexed by state, initially random. 
+ * + * @param mdp + * an MDP with states S, actions A(s), transition model P(s'|s,a) + * @return a policy vector indexed by state, initially random. + */ + public static Map initialPolicyVector( + MarkovDecisionProcess mdp) { + Map pi = new LinkedHashMap(); + List actions = new ArrayList(); + for (S s : mdp.states()) { + actions.clear(); + actions.addAll(mdp.actions(s)); + // Handle terminal states (i.e. no actions). + if (actions.size() > 0) { + pi.put(s, Util.selectRandomlyFromList(actions)); + } + } + return pi; + } +} diff --git a/src/aima/core/probability/mdp/search/ValueIteration.java b/src/aima/core/probability/mdp/search/ValueIteration.java new file mode 100644 index 0000000..3c577c9 --- /dev/null +++ b/src/aima/core/probability/mdp/search/ValueIteration.java @@ -0,0 +1,129 @@ +package aima.core.probability.mdp.search; + +import java.util.Map; +import java.util.Set; + +import aima.core.agent.Action; +import aima.core.probability.mdp.MarkovDecisionProcess; +import aima.core.util.Util; + +/** + * Artificial Intelligence A Modern Approach (3rd Edition): page 653.
+ *
+ * + *
+ * function VALUE-ITERATION(mdp, ε) returns a utility function
+ *   inputs: mdp, an MDP with states S, actions A(s), transition model P(s' | s, a),
+ *             rewards R(s), discount γ
+ *           ε the maximum error allowed in the utility of any state
+ *   local variables: U, U', vectors of utilities for states in S, initially zero
+ *                    δ the maximum change in the utility of any state in an iteration
+ *                    
+ *   repeat
+ *       U <- U'; δ <- 0
+ *       for each state s in S do
+ *           U'[s] <- R(s) + γ  maxa ∈ A(s) Σs'P(s' | s, a) U[s']
+ *           if |U'[s] - U[s]| > δ then δ <- |U'[s] - U[s]|
+ *   until δ < ε(1 - γ)/γ
+ *   return U
+ * 
+ * + * Figure 17.4 The value iteration algorithm for calculating utilities of + * states. The termination condition is from Equation (17.8):
+ * + *
+ * if ||Ui+1 - Ui|| < ε(1 - γ)/γ then ||Ui+1 - U|| < ε
+ * 
+ * + * @param + * the state type. + * @param
+ * the action type. + * + * @author Ciaran O'Reilly + * @author Ravi Mohan + * + */ +public class ValueIteration { + // discount γ to be used. + private double gamma = 0; + + /** + * Constructor. + * + * @param gamma + * discount γ to be used. + */ + public ValueIteration(double gamma) { + if (gamma > 1.0 || gamma <= 0.0) { + throw new IllegalArgumentException("Gamma must be > 0 and <= 1.0"); + } + this.gamma = gamma; + } + + // function VALUE-ITERATION(mdp, ε) returns a utility function + /** + * The value iteration algorithm for calculating the utility of states. + * + * @param mdp + * an MDP with states S, actions A(s),
+ * transition model P(s' | s, a), rewards R(s) + * @param epsilon + * the maximum error allowed in the utility of any state + * @return a vector of utilities for states in S + */ + public Map valueIteration(MarkovDecisionProcess mdp, + double epsilon) { + // + // local variables: U, U', vectors of utilities for states in S, + // initially zero + Map U = Util.create(mdp.states(), new Double(0)); + Map Udelta = Util.create(mdp.states(), new Double(0)); + // δ the maximum change in the utility of any state in an + // iteration + double delta = 0; + // Note: Just calculate this once for efficiency purposes: + // ε(1 - γ)/γ + double minDelta = epsilon * (1 - gamma) / gamma; + + // repeat + do { + // U <- U'; δ <- 0 + U.putAll(Udelta); + delta = 0; + // for each state s in S do + for (S s : mdp.states()) { + // maxa ∈ A(s) + Set
actions = mdp.actions(s); + // Handle terminal states (i.e. no actions). + double aMax = 0; + if (actions.size() > 0) { + aMax = Double.NEGATIVE_INFINITY; + } + for (A a : actions) { + // Σs'P(s' | s, a) U[s'] + double aSum = 0; + for (S sDelta : mdp.states()) { + aSum += mdp.transitionProbability(sDelta, s, a) + * U.get(sDelta); + } + if (aSum > aMax) { + aMax = aSum; + } + } + // U'[s] <- R(s) + γ + // maxa ∈ A(s) + Udelta.put(s, mdp.reward(s) + gamma * aMax); + // if |U'[s] - U[s]| > δ then δ <- |U'[s] - U[s]| + double aDiff = Math.abs(Udelta.get(s) - U.get(s)); + if (aDiff > delta) { + delta = aDiff; + } + } + // until δ < ε(1 - γ)/γ + } while (delta > minDelta); + + // return U + return U; + } +} diff --git a/src/aima/core/util/Util.java b/src/aima/core/util/Util.java new file mode 100644 index 0000000..ef04fdd --- /dev/null +++ b/src/aima/core/util/Util.java @@ -0,0 +1,240 @@ +package aima.core.util; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Hashtable; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; + +/** + * @author Ravi Mohan + * + */ +public class Util { + public static final String NO = "No"; + public static final String YES = "Yes"; + // + private static Random _r = new Random(); + + /** + * Get the first element from a list. + * + * @param l + * the list the first element is to be extracted from. + * @return the first element of the passed in list. + */ + public static T first(List l) { + return l.get(0); + } + + /** + * Get a sublist of all of the elements in the list except for first. + * + * @param l + * the list the rest of the elements are to be extracted from. + * @return a list of all of the elements in the passed in list except for + * the first element. + */ + public static List rest(List l) { + return l.subList(1, l.size()); + } + + /** + * Create a Map with the passed in keys having their values + * initialized to the passed in value. 
+ * + * @param keys + * the keys for the newly constructed map. + * @param value + * the value to be associated with each of the maps keys. + * @return a map with the passed in keys initialized to value. + */ + public static Map create(Collection keys, V value) { + Map map = new LinkedHashMap(); + + for (K k : keys) { + map.put(k, value); + } + + return map; + } + + /** + * Randomly select an element from a list. + * + * @param + * the type of element to be returned from the list l. + * @param l + * a list of type T from which an element is to be selected + * randomly. + * @return a randomly selected element from l. + */ + public static T selectRandomlyFromList(List l) { + return l.get(_r.nextInt(l.size())); + } + + public static boolean randomBoolean() { + int trueOrFalse = _r.nextInt(2); + return (!(trueOrFalse == 0)); + } + + public static double[] normalize(double[] probDist) { + int len = probDist.length; + double total = 0.0; + for (double d : probDist) { + total = total + d; + } + + double[] normalized = new double[len]; + if (total != 0) { + for (int i = 0; i < len; i++) { + normalized[i] = probDist[i] / total; + } + } + + return normalized; + } + + public static List normalize(List values) { + double[] valuesAsArray = new double[values.size()]; + for (int i = 0; i < valuesAsArray.length; i++) { + valuesAsArray[i] = values.get(i); + } + double[] normalized = normalize(valuesAsArray); + List results = new ArrayList(); + for (int i = 0; i < normalized.length; i++) { + results.add(normalized[i]); + } + return results; + } + + public static int min(int i, int j) { + return (i > j ? j : i); + } + + public static int max(int i, int j) { + return (i < j ? 
j : i); + } + + public static int max(int i, int j, int k) { + return max(max(i, j), k); + } + + public static int min(int i, int j, int k) { + return min(min(i, j), k); + } + + public static T mode(List l) { + Hashtable hash = new Hashtable(); + for (T obj : l) { + if (hash.containsKey(obj)) { + hash.put(obj, hash.get(obj).intValue() + 1); + } else { + hash.put(obj, 1); + } + } + + T maxkey = hash.keySet().iterator().next(); + for (T key : hash.keySet()) { + if (hash.get(key) > hash.get(maxkey)) { + maxkey = key; + } + } + return maxkey; + } + + public static String[] yesno() { + return new String[] { YES, NO }; + } + + public static double log2(double d) { + return Math.log(d) / Math.log(2); + } + + public static double information(double[] probabilities) { + double total = 0.0; + for (double d : probabilities) { + total += (-1.0 * log2(d) * d); + } + return total; + } + + public static List removeFrom(List list, T member) { + List newList = new ArrayList(list); + newList.remove(member); + return newList; + } + + public static double sumOfSquares(List list) { + double accum = 0; + for (T item : list) { + accum = accum + (item.doubleValue() * item.doubleValue()); + } + return accum; + } + + public static String ntimes(String s, int n) { + StringBuffer buf = new StringBuffer(); + for (int i = 0; i < n; i++) { + buf.append(s); + } + return buf.toString(); + } + + public static void checkForNanOrInfinity(double d) { + if (Double.isNaN(d)) { + throw new RuntimeException("Not a Number"); + } + if (Double.isInfinite(d)) { + throw new RuntimeException("Infinite Number"); + } + } + + public static int randomNumberBetween(int i, int j) { + /* i,j bothinclusive */ + return _r.nextInt(j - i + 1) + i; + } + + public static double calculateMean(List lst) { + Double sum = 0.0; + for (Double d : lst) { + sum = sum + d.doubleValue(); + } + return sum / lst.size(); + } + + public static double calculateStDev(List values, double mean) { + + int listSize = values.size(); + + Double 
sumOfDiffSquared = 0.0; + for (Double value : values) { + double diffFromMean = value - mean; + sumOfDiffSquared += ((diffFromMean * diffFromMean) / (listSize - 1)); + // division moved here to avoid sum becoming too big if this + // doesn't work use incremental formulation + + } + double variance = sumOfDiffSquared; + // (listSize - 1); + // assumes at least 2 members in list. + return Math.sqrt(variance); + } + + public static List normalizeFromMeanAndStdev(List values, + double mean, double stdev) { + List normalized = new ArrayList(); + for (Double d : values) { + normalized.add((d - mean) / stdev); + } + return normalized; + } + + public static double generateRandomDoubleBetween(double lowerLimit, + double upperLimit) { + + return lowerLimit + ((upperLimit - lowerLimit) * _r.nextDouble()); + } +} \ No newline at end of file diff --git a/src/model/comPlayer/AdaptiveComPlayer.java b/src/model/comPlayer/AdaptiveComPlayer.java new file mode 100644 index 0000000..efd2995 --- /dev/null +++ b/src/model/comPlayer/AdaptiveComPlayer.java @@ -0,0 +1,111 @@ +package model.comPlayer; + +import model.Board; +import model.BoardScorer; +import model.Move; +import model.comPlayer.generator.AlphaBetaMoveGenerator; +import model.comPlayer.generator.MonteCarloMoveGenerator; +import model.comPlayer.generator.MoveGenerator; +import model.playerModel.PlayerModel; +import aima.core.environment.gridworld.GridCell; +import aima.core.environment.gridworld.GridWorld; +import aima.core.environment.gridworld.GridWorldAction; +import aima.core.environment.gridworld.GridWorldFactory; +import aima.core.probability.example.MDPFactory; +import aima.core.probability.mdp.MarkovDecisionProcess; +import aima.core.probability.mdp.Policy; +import aima.core.probability.mdp.PolicyEvaluation; +import aima.core.probability.mdp.impl.ModifiedPolicyEvaluation; +import aima.core.probability.mdp.search.PolicyIteration; + +public class AdaptiveComPlayer implements Player { + private final MoveGenerator 
abMoveGenerator = new AlphaBetaMoveGenerator(); + private final MoveGenerator mcMoveGenerator = new MonteCarloMoveGenerator(); + + private BoardScorer boardScorer = new BoardScorer(); + private boolean calculatePolicy = true; + private GridWorld gw = null; + private MarkovDecisionProcess, GridWorldAction> mdp = null; + private Policy, GridWorldAction> policy = null; + private PolicyIteration, GridWorldAction> pi = null; + + @Override + public void denyMove() { + throw new UnsupportedOperationException("Not implemented"); + } + + @Override + public Move getMove(Board board, PlayerModel player) { + if (calculatePolicy) { + System.out.println("Calculating policy for PlayerModel: " + player); + + // take 10 turns to place 6 tiles + double defaultPenalty = -0.25; + + int maxScore = player.getTargetScore().getTargetScore(); + int maxTiles = Board.NUM_COLS * Board.NUM_ROWS; + + gw = GridWorldFactory.createGridWorldForTileGame(maxTiles, + maxScore, defaultPenalty); + mdp = MDPFactory.createMDPForTileGame(gw, maxTiles, maxScore); + + // gamma = 1.0 + PolicyEvaluation, GridWorldAction> pe = new ModifiedPolicyEvaluation, GridWorldAction>( + 50, 0.9); + pi = new PolicyIteration, GridWorldAction>(pe); + policy = pi.policyIteration(mdp); + + System.out.println("Optimum policy calculated."); + + for (int j = maxScore; j >= 1; j--) { + StringBuilder sb = new StringBuilder(); + for (int i = 1; i <= maxTiles; i++) { + sb.append(policy.action(gw.getCellAt(i, j))); + sb.append(" "); + } + System.out.println(sb.toString()); + } + + calculatePolicy = false; + } else { + System.out.println("Using pre-calculated policy"); + } + + GridCell state = getState(board); + GridWorldAction action = policy.action(state); + + if (action == null || state == null) { + System.out.println("Board state outside of parameters of MDP. 
Reverting to failsafe behavior."); + action = GridWorldAction.RandomMove; + } + System.out.println("Performing action " + action + " at state " + state + " per policy."); + switch (action) { + case AddTile: + //System.out.println("Performing action #" + GridWorldAction.AddTile.ordinal()); + return abMoveGenerator.genMove(board, false); + case CaptureThree: + //System.out.println("Performing action #" + GridWorldAction.CaptureThree.ordinal()); + return mcMoveGenerator.genMove(board, false); + case RandomMove: + //System.out.println("Performing action #" + GridWorldAction.None.ordinal()); + return mcMoveGenerator.genMove(board, false); + default: + //System.out.println("Performing failsafe action"); + return mcMoveGenerator.genMove(board, false); + } + } + + private GridCell getState(Board board) { + return gw.getCellAt(board.getTurn(), boardScorer.getScore(board)); + } + + @Override + public boolean isReady() { + return true; // always ready to play a random valid move + } + + @Override + public String toString() { + return "Alpha-Beta ComPlayer"; + } +} \ No newline at end of file diff --git a/src/model/mdp/Action.java b/src/model/mdp/Action.java new file mode 100644 index 0000000..cda7b6d --- /dev/null +++ b/src/model/mdp/Action.java @@ -0,0 +1,18 @@ +package model.mdp; + +public class Action { + public static Action playToWin = new Action("PlayToWin"); + public static Action playToLose = new Action("PlayToLose"); + //public static Action maintainScore = new Action(); + + private final String name; + + public Action(String name) { + this.name = name; + } + + @Override + public String toString() { + return name; + } +} diff --git a/src/model/mdp/MDP.java b/src/model/mdp/MDP.java new file mode 100644 index 0000000..fe534b7 --- /dev/null +++ b/src/model/mdp/MDP.java @@ -0,0 +1,51 @@ +package model.mdp; + +public class MDP { + public static final double nonTerminalReward = -0.25; + + public enum MODE { + CEIL, FLOOR + } + + private final int maxScore; + private final 
int maxTiles; + private final MODE mode; + + public MDP(int maxScore, int maxTiles, MODE mode) { + this.maxScore = maxScore; + this.maxTiles = maxTiles; + this.mode = mode; + } + + public Action[] getActions(int i, int j) { + if (i == maxScore) { + return new Action[0]; + } + if (j == maxTiles) { + return new Action[0]; + } + return new Action[]{Action.playToLose,Action.playToWin}; + } + + public int getMaxScore() { + return maxScore; + } + + public int getMaxTiles() { + return maxTiles; + } + + public double getReward(int score, int tiles) { + if (score == maxScore && tiles == maxTiles) { + return 10.0; + } + // TODO scale linearly? + if (score == maxScore) { + return -1.0; + } + if (tiles == maxTiles) { + return -5.0; + } + return nonTerminalReward; + } +} \ No newline at end of file diff --git a/src/model/mdp/MDPSolver.java b/src/model/mdp/MDPSolver.java new file mode 100644 index 0000000..812fed2 --- /dev/null +++ b/src/model/mdp/MDPSolver.java @@ -0,0 +1,5 @@ +package model.mdp; + +public interface MDPSolver { + Policy solve(MDP mdp); +} diff --git a/src/model/mdp/Policy.java b/src/model/mdp/Policy.java new file mode 100644 index 0000000..66b9b0c --- /dev/null +++ b/src/model/mdp/Policy.java @@ -0,0 +1,7 @@ +package model.mdp; + +import java.util.ArrayList; + +public class Policy extends ArrayList{ + +} diff --git a/src/model/mdp/Transition.java b/src/model/mdp/Transition.java new file mode 100644 index 0000000..5148b8f --- /dev/null +++ b/src/model/mdp/Transition.java @@ -0,0 +1,34 @@ +package model.mdp; + +public class Transition { + private double prob; + private int scoreChange; + private int tileCountChange; + + public Transition(double prob, int scoreChange, int tileCountChange) { + super(); + this.prob = prob; + this.scoreChange = scoreChange; + this.tileCountChange = tileCountChange; + } + + public double getProb() { + return prob; + } + public void setProb(double prob) { + this.prob = prob; + } + public int getScoreChange() { + return scoreChange; + } 
+ public void setScoreChange(int scoreChange) { + this.scoreChange = scoreChange; + } + public int getTileCountChange() { + return tileCountChange; + } + public void setTileCountChange(int tileCountChange) { + this.tileCountChange = tileCountChange; + } + +} \ No newline at end of file diff --git a/src/model/mdp/ValueIterationSolver.java b/src/model/mdp/ValueIterationSolver.java new file mode 100644 index 0000000..35e9d87 --- /dev/null +++ b/src/model/mdp/ValueIterationSolver.java @@ -0,0 +1,110 @@ +package model.mdp; + +import java.text.DecimalFormat; +import java.util.ArrayList; +import java.util.List; + +public class ValueIterationSolver implements MDPSolver { + public int maxIterations = 10; + public final double DEFAULT_EPS = 0.1; + public final double GAMMA = 0.9; //discount + + private DecimalFormat fmt = new DecimalFormat("##.00"); + public Policy solve(MDP mdp) { + Policy policy = new Policy(); + + double[][] utility = new double[mdp.getMaxScore()+1][mdp.getMaxTiles()+1]; + double[][] utilityPrime = new double[mdp.getMaxScore()+1][mdp.getMaxTiles()+1]; + + for (int i = 0; i <= mdp.getMaxScore(); i++) { + //StringBuilder sb = new StringBuilder(); + for (int j = 0; j <= mdp.getMaxTiles(); j++) { + utilityPrime[i][j] = mdp.getReward(i, j); + //sb.append(fmt.format(utility[i][j])); + //sb.append(" "); + } + //System.out.println(sb); + } + + converged: + for (int iteration = 0; iteration < maxIterations; iteration++) { + for (int i = 0; i <= mdp.getMaxScore(); i++) { + for (int j = 0; j <= mdp.getMaxTiles(); j++) { + utility[i][j] = utilityPrime[i][j]; + } + } + for (int i = 0; i <= mdp.getMaxScore(); i++) { + for (int j = 0; j <= mdp.getMaxTiles(); j++) { + Action[] actions = mdp.getActions(i,j); + + double aMax; + if (actions.length > 0) { + aMax = Double.NEGATIVE_INFINITY; + } else { + aMax = 0; + } + + for (Action action : actions){ + List transitions = getTransitions(action,mdp,i,j); + double aSum = 0.0; + for (Transition transition : transitions) { + int 
transI = transition.getScoreChange(); + int transJ = transition.getTileCountChange(); + if (i+transI >= 0 && i+transI <= mdp.getMaxScore() + && j+transJ >= 0 && j+transJ <= mdp.getMaxTiles()) + aSum += utility[i+transI][j+transJ]; + } + if (aSum > aMax) { + aMax = aSum; + } + } + utilityPrime[i][j] = mdp.getReward(i,j) + GAMMA * aMax; + } + } + double maxDiff = getMaxDiff(utility,utilityPrime); + System.out.println("Max diff |U - U'| = " + maxDiff); + if (maxDiff < DEFAULT_EPS) { + System.out.println("Solution to MDP converged: " + maxDiff); + break converged; + } + } + + for (int i = 0; i < utility.length; i++) { + StringBuilder sb = new StringBuilder(); + for (int j = 0; j < utility[i].length; j++) { + sb.append(fmt.format(utility[i][j])); + sb.append(" "); + } + System.out.println(sb); + } + + //utility is now the utility Matrix + //get the policy + return policy; + } + + double getMaxDiff(double[][]u, double[][]uPrime) { + double maxDiff = 0; + for (int i = 0; i < u.length; i++) { + for (int j = 0; j < u[i].length; j++) { + maxDiff = Math.max(maxDiff,Math.abs(u[i][j] - uPrime[i][j])); + } + } + return maxDiff; + } + + private List getTransitions(Action action, MDP mdp, int score, int tiles) { + List transitions = new ArrayList(); + if (Action.playToWin == action) { + transitions.add(new Transition(0.9,1,1)); + transitions.add(new Transition(0.1,1,-3)); + } else if (Action.playToLose == action) { + transitions.add(new Transition(0.9,1,1)); + transitions.add(new Transition(0.1,1,-3)); + } /*else if (Action.maintainScore == action) { + transitions.add(new Transition(0.5,1,1)); + transitions.add(new Transition(0.5,1,-3)); + }*/ + return transitions; + } +} \ No newline at end of file diff --git a/src/view/ParsedArgs.java b/src/view/ParsedArgs.java index 17c547d..210895a 100644 --- a/src/view/ParsedArgs.java +++ b/src/view/ParsedArgs.java @@ -1,5 +1,6 @@ package view; +import model.comPlayer.AdaptiveComPlayer; import model.comPlayer.AlphaBetaComPlayer; import 
model.comPlayer.MinimaxComPlayer; import model.comPlayer.MonteCarloComPlayer; @@ -7,16 +8,19 @@ import model.comPlayer.Player; import model.comPlayer.RandomComPlayer; public class ParsedArgs { - public static final String COM_RANDOM = "RANDOM"; - public static final String COM_MINIMAX = "MINIMAX"; + public static final String COM_ADAPTIVE = "ADAPTIVE"; public static final String COM_ALPHABETA = "ALPHABETA"; + public static final String COM_MINIMAX = "MINIMAX"; public static final String COM_MONTECARLO = "MONTECARLO"; + public static final String COM_RANDOM = "RANDOM"; public static final String COM_DEFAULT = COM_ALPHABETA; private String comPlayer = COM_DEFAULT; public Player getComPlayer() { - if (COM_RANDOM.equalsIgnoreCase(comPlayer)) { + if (COM_ADAPTIVE.equalsIgnoreCase(comPlayer)) { + return new AdaptiveComPlayer(); + } else if (COM_RANDOM.equalsIgnoreCase(comPlayer)) { return new RandomComPlayer(); } else if (COM_MINIMAX.equalsIgnoreCase(comPlayer)) { return new MinimaxComPlayer(); diff --git a/test/aima/core/probability/mdp/MarkovDecisionProcessTest.java b/test/aima/core/probability/mdp/MarkovDecisionProcessTest.java new file mode 100644 index 0000000..e266e92 --- /dev/null +++ b/test/aima/core/probability/mdp/MarkovDecisionProcessTest.java @@ -0,0 +1,98 @@ +package aima.core.probability.mdp; + +import junit.framework.Assert; + +import org.junit.Before; +import org.junit.Test; + +import aima.core.environment.cellworld.Cell; +import aima.core.environment.cellworld.CellWorld; +import aima.core.environment.cellworld.CellWorldAction; +import aima.core.environment.cellworld.CellWorldFactory; +import aima.core.probability.example.MDPFactory; +import aima.core.probability.mdp.MarkovDecisionProcess; + +/** + * + * @author Ciaran O'Reilly + * @author Ravi Mohan + * + */ +public class MarkovDecisionProcessTest { + public static final double DELTA_THRESHOLD = 1e-3; + + private CellWorld cw = null; + private MarkovDecisionProcess, CellWorldAction> mdp = null; + + 
@Before + public void setUp() { + cw = CellWorldFactory.createCellWorldForFig17_1(); + mdp = MDPFactory.createMDPForFigure17_3(cw); + } + + @Test + public void testActions() { + // Ensure all actions can be performed in each cell + // except for the terminal states. + for (Cell s : cw.getCells()) { + if (4 == s.getX() && (3 == s.getY() || 2 == s.getY())) { + Assert.assertEquals(0, mdp.actions(s).size()); + } else { + Assert.assertEquals(5, mdp.actions(s).size()); + } + } + } + + @Test + public void testMDPTransitionModel() { + Assert.assertEquals(0.8, mdp.transitionProbability(cw.getCellAt(1, 2), + cw.getCellAt(1, 1), CellWorldAction.Up), DELTA_THRESHOLD); + Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(1, 1), + cw.getCellAt(1, 1), CellWorldAction.Up), DELTA_THRESHOLD); + Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(2, 1), + cw.getCellAt(1, 1), CellWorldAction.Up), DELTA_THRESHOLD); + Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(1, 3), + cw.getCellAt(1, 1), CellWorldAction.Up), DELTA_THRESHOLD); + + Assert.assertEquals(0.9, mdp.transitionProbability(cw.getCellAt(1, 1), + cw.getCellAt(1, 1), CellWorldAction.Down), DELTA_THRESHOLD); + Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(2, 1), + cw.getCellAt(1, 1), CellWorldAction.Down), DELTA_THRESHOLD); + Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(3, 1), + cw.getCellAt(1, 1), CellWorldAction.Down), DELTA_THRESHOLD); + Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(1, 2), + cw.getCellAt(1, 1), CellWorldAction.Down), DELTA_THRESHOLD); + + Assert.assertEquals(0.9, mdp.transitionProbability(cw.getCellAt(1, 1), + cw.getCellAt(1, 1), CellWorldAction.Left), DELTA_THRESHOLD); + Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(2, 1), + cw.getCellAt(1, 1), CellWorldAction.Left), DELTA_THRESHOLD); + Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(3, 1), + cw.getCellAt(1, 1), 
CellWorldAction.Left), DELTA_THRESHOLD); + Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(1, 2), + cw.getCellAt(1, 1), CellWorldAction.Left), DELTA_THRESHOLD); + + Assert.assertEquals(0.8, mdp.transitionProbability(cw.getCellAt(2, 1), + cw.getCellAt(1, 1), CellWorldAction.Right), DELTA_THRESHOLD); + Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(1, 1), + cw.getCellAt(1, 1), CellWorldAction.Right), DELTA_THRESHOLD); + Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(1, 2), + cw.getCellAt(1, 1), CellWorldAction.Right), DELTA_THRESHOLD); + Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(1, 3), + cw.getCellAt(1, 1), CellWorldAction.Right), DELTA_THRESHOLD); + } + + @Test + public void testRewardFunction() { + // Ensure all actions can be performed in each cell. + for (Cell s : cw.getCells()) { + if (4 == s.getX() && 3 == s.getY()) { + Assert.assertEquals(1.0, mdp.reward(s), DELTA_THRESHOLD); + } else if (4 == s.getX() && 2 == s.getY()) { + Assert.assertEquals(-1.0, mdp.reward(s), DELTA_THRESHOLD); + } else { + Assert.assertEquals(-0.04, mdp.reward(s), DELTA_THRESHOLD); + } + } + } +} diff --git a/test/aima/core/probability/mdp/PolicyIterationTest.java b/test/aima/core/probability/mdp/PolicyIterationTest.java new file mode 100644 index 0000000..255f403 --- /dev/null +++ b/test/aima/core/probability/mdp/PolicyIterationTest.java @@ -0,0 +1,80 @@ +package aima.core.probability.mdp; + +import java.util.Map; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import aima.core.environment.cellworld.Cell; +import aima.core.environment.cellworld.CellWorld; +import aima.core.environment.cellworld.CellWorldAction; +import aima.core.environment.cellworld.CellWorldFactory; +import aima.core.environment.gridworld.GridCell; +import aima.core.environment.gridworld.GridWorld; +import aima.core.environment.gridworld.GridWorldAction; +import aima.core.environment.gridworld.GridWorldFactory; 
+import aima.core.probability.example.MDPFactory; +import aima.core.probability.mdp.MarkovDecisionProcess; +import aima.core.probability.mdp.impl.ModifiedPolicyEvaluation; +import aima.core.probability.mdp.search.PolicyIteration; +import aima.core.probability.mdp.search.ValueIteration; + +/** + * @author Ravi Mohan + * @author Ciaran O'Reilly + * + */ +public class PolicyIterationTest { + public static final double DELTA_THRESHOLD = 1e-3; + + private GridWorld gw = null; + private MarkovDecisionProcess, GridWorldAction> mdp = null; + private PolicyIteration, GridWorldAction> pi = null; + + final int maxTiles = 6; + final int maxScore = 10; + + @Before + public void setUp() { + //take 10 turns to place 6 tiles + double defaultPenalty = -0.04; + + gw = GridWorldFactory.createGridWorldForTileGame(maxTiles,maxScore,defaultPenalty); + mdp = MDPFactory.createMDPForTileGame(gw, maxTiles, maxScore); + + //gamma = 0.9 + PolicyEvaluation,GridWorldAction> pe = new ModifiedPolicyEvaluation, GridWorldAction>(100,0.9); + pi = new PolicyIteration, GridWorldAction>(pe); + } + + @Test + public void testPolicyIterationForTileGame() { + Policy, GridWorldAction> policy = pi.policyIteration(mdp); + + for (int j = maxScore; j >= 1; j--) { + StringBuilder sb = new StringBuilder(); + for (int i = 1; i <= maxTiles; i++) { + sb.append(policy.action(gw.getCellAt(i, j))); + sb.append(" "); + } + System.out.println(sb.toString()); + } + + //Assert.assertEquals(0.705, U.get(gw.getCellAt(1, 1)), DELTA_THRESHOLD); + /* + Assert.assertEquals(0.762, U.get(cw1.getCellAt(1, 2)), DELTA_THRESHOLD); + Assert.assertEquals(0.812, U.get(cw1.getCellAt(1, 3)), DELTA_THRESHOLD); + + Assert.assertEquals(0.655, U.get(cw1.getCellAt(2, 1)), DELTA_THRESHOLD); + Assert.assertEquals(0.868, U.get(cw1.getCellAt(2, 3)), DELTA_THRESHOLD); + + Assert.assertEquals(0.611, U.get(cw1.getCellAt(3, 1)), DELTA_THRESHOLD); + Assert.assertEquals(0.660, U.get(cw1.getCellAt(3, 2)), DELTA_THRESHOLD); + Assert.assertEquals(0.918, 
U.get(cw1.getCellAt(3, 3)), DELTA_THRESHOLD); + + Assert.assertEquals(0.388, U.get(cw1.getCellAt(4, 1)), DELTA_THRESHOLD); + Assert.assertEquals(-1.0, U.get(cw1.getCellAt(4, 2)), DELTA_THRESHOLD); + Assert.assertEquals(1.0, U.get(cw1.getCellAt(4, 3)), DELTA_THRESHOLD);*/ + } +} diff --git a/test/aima/core/probability/mdp/ValueIterationTest.java b/test/aima/core/probability/mdp/ValueIterationTest.java new file mode 100644 index 0000000..9d1215e --- /dev/null +++ b/test/aima/core/probability/mdp/ValueIterationTest.java @@ -0,0 +1,64 @@ +package aima.core.probability.mdp; + +import java.util.Map; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import aima.core.environment.cellworld.Cell; +import aima.core.environment.cellworld.CellWorld; +import aima.core.environment.cellworld.CellWorldAction; +import aima.core.environment.cellworld.CellWorldFactory; +import aima.core.probability.example.MDPFactory; +import aima.core.probability.mdp.MarkovDecisionProcess; +import aima.core.probability.mdp.search.ValueIteration; + +/** + * @author Ravi Mohan + * @author Ciaran O'Reilly + * + */ +public class ValueIterationTest { + public static final double DELTA_THRESHOLD = 1e-3; + + private CellWorld cw = null; + private MarkovDecisionProcess, CellWorldAction> mdp = null; + private ValueIteration, CellWorldAction> vi = null; + + @Before + public void setUp() { + cw = CellWorldFactory.createCellWorldForFig17_1(); + mdp = MDPFactory.createMDPForFigure17_3(cw); + vi = new ValueIteration, CellWorldAction>(1.0); + } + + @Test + public void testValueIterationForFig17_3() { + Map, Double> U = vi.valueIteration(mdp, 0.0001); + + Assert.assertEquals(0.705, U.get(cw.getCellAt(1, 1)), DELTA_THRESHOLD); + Assert.assertEquals(0.762, U.get(cw.getCellAt(1, 2)), DELTA_THRESHOLD); + Assert.assertEquals(0.812, U.get(cw.getCellAt(1, 3)), DELTA_THRESHOLD); + + Assert.assertEquals(0.655, U.get(cw.getCellAt(2, 1)), DELTA_THRESHOLD); + Assert.assertEquals(0.868, 
U.get(cw.getCellAt(2, 3)), DELTA_THRESHOLD); + + Assert.assertEquals(0.611, U.get(cw.getCellAt(3, 1)), DELTA_THRESHOLD); + Assert.assertEquals(0.660, U.get(cw.getCellAt(3, 2)), DELTA_THRESHOLD); + Assert.assertEquals(0.918, U.get(cw.getCellAt(3, 3)), DELTA_THRESHOLD); + + Assert.assertEquals(0.388, U.get(cw.getCellAt(4, 1)), DELTA_THRESHOLD); + Assert.assertEquals(-1.0, U.get(cw.getCellAt(4, 2)), DELTA_THRESHOLD); + Assert.assertEquals(1.0, U.get(cw.getCellAt(4, 3)), DELTA_THRESHOLD); + + for (int j = 3; j >= 1; j--) { + StringBuilder sb = new StringBuilder(); + for (int i = 1; i <= 4; i++) { + sb.append(U.get(cw.getCellAt(i, j))); + sb.append(" "); + } + System.out.println(sb.toString()); + } + } +} diff --git a/test/aima/core/probability/mdp/ValueIterationTest2.java b/test/aima/core/probability/mdp/ValueIterationTest2.java new file mode 100644 index 0000000..a0c6ce1 --- /dev/null +++ b/test/aima/core/probability/mdp/ValueIterationTest2.java @@ -0,0 +1,76 @@ +package aima.core.probability.mdp; + +import java.util.Map; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import aima.core.environment.cellworld.Cell; +import aima.core.environment.cellworld.CellWorld; +import aima.core.environment.cellworld.CellWorldAction; +import aima.core.environment.cellworld.CellWorldFactory; +import aima.core.environment.gridworld.GridCell; +import aima.core.environment.gridworld.GridWorld; +import aima.core.environment.gridworld.GridWorldAction; +import aima.core.environment.gridworld.GridWorldFactory; +import aima.core.probability.example.MDPFactory; +import aima.core.probability.mdp.MarkovDecisionProcess; +import aima.core.probability.mdp.search.ValueIteration; + +/** + * @author Ravi Mohan + * @author Ciaran O'Reilly + * + */ +public class ValueIterationTest2 { + public static final double DELTA_THRESHOLD = 1e-3; + + private GridWorld gw = null; + private MarkovDecisionProcess, GridWorldAction> mdp = null; + private ValueIteration, 
GridWorldAction> vi = null; + + final int maxTiles = 6; + final int maxScore = 10; + + @Before + public void setUp() { + //take 10 turns to place 6 tiles + double defaultPenalty = -0.04; + + gw = GridWorldFactory.createGridWorldForTileGame(maxTiles,maxScore,defaultPenalty); + mdp = MDPFactory.createMDPForTileGame(gw, maxTiles, maxScore); + + //gamma = 0.9 + vi = new ValueIteration, GridWorldAction>(0.9); + } + + @Test + public void testValueIterationForTileGame() { + Map, Double> U = vi.valueIteration(mdp, 1.0); + + for (int j = maxScore; j >= 1; j--) { + StringBuilder sb = new StringBuilder(); + for (int i = 1; i <= maxTiles; i++) { + sb.append(U.get(gw.getCellAt(i, j))); + sb.append(" "); + } + System.out.println(sb.toString()); + } + + Assert.assertEquals(0.705, U.get(gw.getCellAt(1, 1)), DELTA_THRESHOLD);/* + Assert.assertEquals(0.762, U.get(cw1.getCellAt(1, 2)), DELTA_THRESHOLD); + Assert.assertEquals(0.812, U.get(cw1.getCellAt(1, 3)), DELTA_THRESHOLD); + + Assert.assertEquals(0.655, U.get(cw1.getCellAt(2, 1)), DELTA_THRESHOLD); + Assert.assertEquals(0.868, U.get(cw1.getCellAt(2, 3)), DELTA_THRESHOLD); + + Assert.assertEquals(0.611, U.get(cw1.getCellAt(3, 1)), DELTA_THRESHOLD); + Assert.assertEquals(0.660, U.get(cw1.getCellAt(3, 2)), DELTA_THRESHOLD); + Assert.assertEquals(0.918, U.get(cw1.getCellAt(3, 3)), DELTA_THRESHOLD); + + Assert.assertEquals(0.388, U.get(cw1.getCellAt(4, 1)), DELTA_THRESHOLD); + Assert.assertEquals(-1.0, U.get(cw1.getCellAt(4, 2)), DELTA_THRESHOLD); + Assert.assertEquals(1.0, U.get(cw1.getCellAt(4, 3)), DELTA_THRESHOLD);*/ + } +} diff --git a/test/model/mdp/ValueIterationSolverTest.java b/test/model/mdp/ValueIterationSolverTest.java new file mode 100644 index 0000000..dac3656 --- /dev/null +++ b/test/model/mdp/ValueIterationSolverTest.java @@ -0,0 +1,26 @@ +package model.mdp; + +import static org.junit.Assert.assertTrue; +import model.mdp.MDP.MODE; + +import org.junit.Test; + +public class ValueIterationSolverTest { + + @Test + public 
void testSolve() { + MDPSolver solver = new ValueIterationSolver(); + + //solve for a score of 6 in at most 10 turns + int maxScore = 6; + int maxTurns = 10; + + MDP mdp = new MDP(maxScore,maxTurns,MODE.CEIL); + Policy policy = solver.solve(mdp); + + assertTrue(policy.size() >= maxScore); + assertTrue(policy.size() <= maxTurns); + + System.out.println("Policy: " + policy); + } +}