diff --git a/src/aima/core/environment/cellworld/Cell.java b/src/aima/core/environment/cellworld/Cell.java deleted file mode 100644 index fa6c4ea..0000000 --- a/src/aima/core/environment/cellworld/Cell.java +++ /dev/null @@ -1,87 +0,0 @@ -package aima.core.environment.cellworld; - -/** - * Artificial Intelligence A Modern Approach (3rd Edition): page 645.
- *
- * A representation of a Cell in the environment detailed in Figure 17.1. - * - * @param - * the content type of the cell. - * - * @author Ciaran O'Reilly - * @author Ravi Mohan - */ -public class Cell { - private int x = 1; - private int y = 1; - private C content = null; - - /** - * Construct a Cell. - * - * @param x - * the x position of the cell. - * @param y - * the y position of the cell. - * @param content - * the initial content of the cell. - */ - public Cell(int x, int y, C content) { - this.x = x; - this.y = y; - this.content = content; - } - - /** - * - * @return the x position of the cell. - */ - public int getX() { - return x; - } - - /** - * - * @return the y position of the cell. - */ - public int getY() { - return y; - } - - /** - * - * @return the content of the cell. - */ - public C getContent() { - return content; - } - - /** - * Set the cell's content. - * - * @param content - * the content to be placed in the cell. - */ - public void setContent(C content) { - this.content = content; - } - - @Override - public String toString() { - return ""; - } - - @Override - public boolean equals(Object o) { - if (o instanceof Cell) { - Cell c = (Cell) o; - return x == c.x && y == c.y && content.equals(c.content); - } - return false; - } - - @Override - public int hashCode() { - return x + 23 + y + 31 * content.hashCode(); - } -} diff --git a/src/aima/core/environment/cellworld/CellWorld.java b/src/aima/core/environment/cellworld/CellWorld.java deleted file mode 100644 index 20d8a78..0000000 --- a/src/aima/core/environment/cellworld/CellWorld.java +++ /dev/null @@ -1,123 +0,0 @@ -package aima.core.environment.cellworld; - -import java.util.HashMap; -import java.util.LinkedHashSet; -import java.util.Map; -import java.util.Set; - -/** - * Artificial Intelligence A Modern Approach (3rd Edition): page 645.
- *
- * - * A representation for the environment depicted in figure 17.1.
- *
- * Note: the x and y coordinates are always positive integers starting at - * 1.
- * Note: If looking at a rectangle - the coordinate (x=1, y=1) will be the - * bottom left hand corner.
- * - * - * @param - * the type of content for the Cells in the world. - * - * @author Ciaran O'Reilly - * @author Ravi Mohan - */ -public class CellWorld { - private Set> cells = new LinkedHashSet>(); - private Map>> cellLookup = new HashMap>>(); - - /** - * Construct a Cell World with size xDimension * y Dimension cells, all with - * their values set to a default content value. - * - * @param xDimension - * the size of the x dimension. - * @param yDimension - * the size of the y dimension. - * - * @param defaultCellContent - * the default content to assign to each cell created. - */ - public CellWorld(int xDimension, int yDimension, C defaultCellContent) { - for (int x = 1; x <= xDimension; x++) { - Map> xCol = new HashMap>(); - for (int y = 1; y <= yDimension; y++) { - Cell c = new Cell(x, y, defaultCellContent); - cells.add(c); - xCol.put(y, c); - } - cellLookup.put(x, xCol); - } - } - - /** - * - * @return all the cells in this world. - */ - public Set> getCells() { - return cells; - } - - /** - * Determine what cell would be moved into if the specified action is - * performed in the specified cell. Normally, this will be the cell adjacent - * in the appropriate direction. However, if there is no cell in the - * adjacent direction of the action then the outcome of the action is to - * stay in the same cell as the action was performed in. - * - * @param s - * the cell location from which the action is to be performed. - * @param a - * the action to perform (Up, Down, Left, or Right). - * @return the Cell an agent would end up in if they performed the specified - * action from the specified cell location. - */ - public Cell result(Cell s, CellWorldAction a) { - Cell sDelta = getCellAt(a.getXResult(s.getX()), a.getYResult(s - .getY())); - if (null == sDelta) { - // Default to no effect - // (i.e. bumps back in place as no adjoining cell). - sDelta = s; - } - - return sDelta; - } - - /** - * Remove the cell at the specified location from this Cell World. This - * allows you to introduce barriers into different location. - * - * @param x - * the x dimension of the cell to be removed. - * @param y - * the y dimension of the cell to be removed. - */ - public void removeCell(int x, int y) { - Map> xCol = cellLookup.get(x); - if (null != xCol) { - cells.remove(xCol.remove(y)); - } - } - - /** - * Get the cell at the specified x and y locations. - * - * @param x - * the x dimension of the cell to be retrieved. - * @param y - * the y dimension of the cell to be retrieved. - * @return the cell at the specified x,y location, null if no cell exists at - * this location. - */ - public Cell getCellAt(int x, int y) { - Cell c = null; - Map> xCol = cellLookup.get(x); - if (null != xCol) { - c = xCol.get(y); - } - - return c; - } -} diff --git a/src/aima/core/environment/cellworld/CellWorldAction.java b/src/aima/core/environment/cellworld/CellWorldAction.java deleted file mode 100644 index ae14bd8..0000000 --- a/src/aima/core/environment/cellworld/CellWorldAction.java +++ /dev/null @@ -1,142 +0,0 @@ -package aima.core.environment.cellworld; - -import java.util.LinkedHashSet; -import java.util.Set; - -import aima.core.agent.Action; - -/** - * Artificial Intelligence A Modern Approach (3rd Edition): page 645.
- *
- * - * The actions in every state are Up, Down, Left, and Right.
- *
- * Note: Moving 'North' causes y to increase by 1, 'Down' y to decrease by - * 1, 'Left' x to decrease by 1, and 'Right' x to increase by 1 within a Cell - * World. - * - * @author Ciaran O'Reilly - * - */ -public enum CellWorldAction implements Action { - Up, Down, Left, Right, None; - - private static final Set _actions = new LinkedHashSet(); - static { - _actions.add(Up); - _actions.add(Down); - _actions.add(Left); - _actions.add(Right); - _actions.add(None); - } - - /** - * - * @return a set of the actual actions. - */ - public static final Set actions() { - return _actions; - } - - // - // START-Action - //@Override - //public boolean isNoOp() { - // if (None == this) { - // return true; - // } - // return false; - //} - // END-Action - // - - /** - * - * @param curX - * the current x position. - * @return the result on the x position of applying this action. - */ - public int getXResult(int curX) { - int newX = curX; - - switch (this) { - case Left: - newX--; - break; - case Right: - newX++; - break; - } - - return newX; - } - - /** - * - * @param curY - * the current y position. - * @return the result on the y position of applying this action. - */ - public int getYResult(int curY) { - int newY = curY; - - switch (this) { - case Up: - newY++; - break; - case Down: - newY--; - break; - } - - return newY; - } - - /** - * - * @return the first right angled action related to this action. - */ - public CellWorldAction getFirstRightAngledAction() { - CellWorldAction a = null; - - switch (this) { - case Up: - case Down: - a = Left; - break; - case Left: - case Right: - a = Down; - break; - case None: - a = None; - break; - } - - return a; - } - - /** - * - * @return the second right angled action related to this action. - */ - public CellWorldAction getSecondRightAngledAction() { - CellWorldAction a = null; - - switch (this) { - case Up: - case Down: - a = Right; - break; - case Left: - case Right: - a = Up; - break; - case None: - a = None; - break; - } - - return a; - } -} diff --git a/src/aima/core/environment/cellworld/CellWorldFactory.java b/src/aima/core/environment/cellworld/CellWorldFactory.java deleted file mode 100644 index 16ad6ac..0000000 --- a/src/aima/core/environment/cellworld/CellWorldFactory.java +++ /dev/null @@ -1,27 +0,0 @@ -package aima.core.environment.cellworld; - -/** - * - * @author Ciaran O'Reilly - * - */ -public class CellWorldFactory { - - /** - * Create the cell world as defined in Figure 17.1 in AIMA3e. (a) A simple 4 - * x 3 environment that presents the agent with a sequential decision - * problem. - * - * @return a cell world representation of Fig 17.1 in AIMA3e. - */ - public static CellWorld createCellWorldForFig17_1() { - CellWorld cw = new CellWorld(4, 3, -0.04); - - cw.removeCell(2, 2); - - cw.getCellAt(4, 3).setContent(1.0); - cw.getCellAt(4, 2).setContent(-1.0); - - return cw; - } -} \ No newline at end of file diff --git a/src/aima/core/environment/gridworld/GridWorldFactory.java b/src/aima/core/environment/gridworld/GridWorldFactory.java index 35afccf..0d5f767 100644 --- a/src/aima/core/environment/gridworld/GridWorldFactory.java +++ b/src/aima/core/environment/gridworld/GridWorldFactory.java @@ -17,7 +17,12 @@ public class GridWorldFactory { GridWorld cw = new GridWorld(maxTiles, maxScore, nonTerminalReward); cw.getCellAt(maxTiles, maxScore).setContent(1.0); - + for (int score = 1; score < maxScore; score++) { + cw.getCellAt(maxTiles, score).setContent(-0.2); + } + for (int tiles = 1; tiles < maxTiles; tiles++) { + cw.getCellAt(tiles, maxScore).setContent(-0.2); + } return cw; } } \ No newline at end of file diff --git a/src/aima/core/probability/example/MDPFactory.java b/src/aima/core/probability/example/MDPFactory.java index 5494a73..ab21fb0 100644 --- a/src/aima/core/probability/example/MDPFactory.java +++ b/src/aima/core/probability/example/MDPFactory.java @@ -6,9 +6,6 @@ import java.util.HashSet; import java.util.List; import java.util.Set; -import aima.core.environment.cellworld.Cell; -import aima.core.environment.cellworld.CellWorld; -import aima.core.environment.cellworld.CellWorldAction; import aima.core.environment.gridworld.GridCell; import aima.core.environment.gridworld.GridWorld; import aima.core.environment.gridworld.GridWorldAction; @@ -19,30 +16,11 @@ import aima.core.probability.mdp.TransitionProbabilityFunction; import aima.core.probability.mdp.impl.MDP; /** - * - * @author Ciaran O'Reilly - * @author Ravi Mohan + * Based on MDPFactory by Ciaran O'Reilly and Ravi Mohan. + * @author Woody */ public class MDPFactory { - /** - * Constructs an MDP that can be used to generate the utility values - * detailed in Fig 17.3. - * - * @param cw - * the cell world from figure 17.1. - * @return an MDP that can be used to generate the utility values detailed - * in Fig 17.3. - */ - public static MarkovDecisionProcess, CellWorldAction> createMDPForFigure17_3( - final CellWorld cw) { - - return new MDP, CellWorldAction>(cw.getCells(), - cw.getCellAt(1, 1), createActionsFunctionForFigure17_1(cw), - createTransitionProbabilityFunctionForFigure17_1(cw), - createRewardFunctionForFigure17_1()); - } - public static MarkovDecisionProcess, GridWorldAction> createMDPForTileGame( final GridWorld cw, int maxTiles, int maxScore) { @@ -52,36 +30,6 @@ public class MDPFactory { createRewardFunctionForTileGame()); } - /** - * Returns the allowed actions from a specified cell within the cell world - * described in Fig 17.1. - * - * @param cw - * the cell world from figure 17.1. - * @return the set of actions allowed at a particular cell. This set will be - * empty if at a terminal state. - */ - public static ActionsFunction, CellWorldAction> createActionsFunctionForFigure17_1( - final CellWorld cw) { - final Set> terminals = new HashSet>(); - terminals.add(cw.getCellAt(4, 3)); - terminals.add(cw.getCellAt(4, 2)); - - ActionsFunction, CellWorldAction> af = new ActionsFunction, CellWorldAction>() { - - @Override - public Set actions(Cell s) { - // All actions can be performed in each cell - // (except terminal states) - if (terminals.contains(s)) { - return Collections.emptySet(); - } - return CellWorldAction.actions(); - } - }; - return af; - } - public static ActionsFunction, GridWorldAction> createActionsFunctionForTileGame( final GridWorld cw, int maxTiles, int maxScore) { final Set> terminals = new HashSet>(); @@ -102,59 +50,6 @@ public class MDPFactory { return af; } - - /** - * Figure 17.1 (b) Illustration of the transition model of the environment: - * the 'intended' outcome occurs with probability 0.8, but with probability - * 0.2 the agent moves at right angles to the intended direction. A - * collision with a wall results in no movement. - * - * @param cw - * the cell world from figure 17.1. - * @return the transition probability function as described in figure 17.1. - */ - public static TransitionProbabilityFunction, CellWorldAction> createTransitionProbabilityFunctionForFigure17_1( - final CellWorld cw) { - TransitionProbabilityFunction, CellWorldAction> tf = new TransitionProbabilityFunction, CellWorldAction>() { - private double[] distribution = new double[] { 0.8, 0.1, 0.1 }; - - @Override - public double probability(Cell sDelta, Cell s, - CellWorldAction a) { - double prob = 0; - - List> outcomes = possibleOutcomes(s, a); - for (int i = 0; i < outcomes.size(); i++) { - if (sDelta.equals(outcomes.get(i))) { - // Note: You have to sum the matches to - // sDelta as the different actions - // could have the same effect (i.e. - // staying in place due to there being - // no adjacent cells), which increases - // the probability of the transition for - // that state. - prob += distribution[i]; - } - } - - return prob; - } - - private List> possibleOutcomes(Cell c, - CellWorldAction a) { - // There can be three possible outcomes for the planned action - List> outcomes = new ArrayList>(); - - outcomes.add(cw.result(c, a)); - outcomes.add(cw.result(c, a.getFirstRightAngledAction())); - outcomes.add(cw.result(c, a.getSecondRightAngledAction())); - - return outcomes; - } - }; - - return tf; - } public static TransitionProbabilityFunction, GridWorldAction> createTransitionProbabilityFunctionForTileGame( final GridWorld cw) { @@ -170,13 +65,6 @@ public class MDPFactory { List> outcomes = possibleOutcomes(s, a); for (int i = 0; i < outcomes.size(); i++) { if (sDelta.equals(outcomes.get(i))) { - // Note: You have to sum the matches to - // sDelta as the different actions - // could have the same effect (i.e. - // staying in place due to there being - // no adjacent cells), which increases - // the probability of the transition for - // that state. prob += distribution[i]; } } @@ -198,7 +86,7 @@ public class MDPFactory { } private List> possibleOutcomes(GridCell c, GridWorldAction a) { - // There can be three possible outcomes for the planned action + List> outcomes = new ArrayList>(); switch (a) { @@ -224,21 +112,6 @@ public class MDPFactory { return tf; } - /** - * - * @return the reward function which takes the content of the cell as being - * the reward value. - */ - public static RewardFunction> createRewardFunctionForFigure17_1() { - RewardFunction> rf = new RewardFunction>() { - @Override - public double reward(Cell s) { - return s.getContent(); - } - }; - return rf; - } - public static RewardFunction> createRewardFunctionForTileGame() { RewardFunction> rf = new RewardFunction>() { @Override diff --git a/src/model/comPlayer/AdaptiveComPlayer.java b/src/model/comPlayer/AdaptiveComPlayer.java index 5413f1b..03725c5 100644 --- a/src/model/comPlayer/AdaptiveComPlayer.java +++ b/src/model/comPlayer/AdaptiveComPlayer.java @@ -119,5 +119,6 @@ public class AdaptiveComPlayer implements Player { @Override public void setGameGoal(GameGoal target) { this.target = target; + this.calculatePolicy = true; } } \ No newline at end of file diff --git a/src/model/mdp/Action.java b/src/model/mdp/Action.java deleted file mode 100644 index cda7b6d..0000000 --- a/src/model/mdp/Action.java +++ /dev/null @@ -1,18 +0,0 @@ -package model.mdp; - -public class Action { - public static Action playToWin = new Action("PlayToWin"); - public static Action playToLose = new Action("PlayToLose"); - //public static Action maintainScore = new Action(); - - private final String name; - - public Action(String name) { - this.name = name; - } - - @Override - public String toString() { - return name; - } -} diff --git a/src/model/mdp/MDP.java b/src/model/mdp/MDP.java deleted file mode 100644 index fe534b7..0000000 --- a/src/model/mdp/MDP.java +++ /dev/null @@ -1,51 +0,0 @@ -package model.mdp; - -public class MDP { - public static final double nonTerminalReward = -0.25; - - public enum MODE { - CEIL, FLOOR - } - - private final int maxScore; - private final int maxTiles; - private final MODE mode; - - public MDP(int maxScore, int maxTiles, MODE mode) { - this.maxScore = maxScore; - this.maxTiles = maxTiles; - this.mode = mode; - } - - public Action[] getActions(int i, int j) { - if (i == maxScore) { - return new Action[0]; - } - if (j == maxTiles) { - return new Action[0]; - } - return new Action[]{Action.playToLose,Action.playToWin}; - } - - public int getMaxScore() { - return maxScore; - } - - public int getMaxTiles() { - return maxTiles; - } - - public double getReward(int score, int tiles) { - if (score == maxScore && tiles == maxTiles) { - return 10.0; - } - // TODO scale linearly? - if (score == maxScore) { - return -1.0; - } - if (tiles == maxTiles) { - return -5.0; - } - return nonTerminalReward; - } -} \ No newline at end of file diff --git a/src/model/mdp/MDPSolver.java b/src/model/mdp/MDPSolver.java deleted file mode 100644 index 812fed2..0000000 --- a/src/model/mdp/MDPSolver.java +++ /dev/null @@ -1,5 +0,0 @@ -package model.mdp; - -public interface MDPSolver { - Policy solve(MDP mdp); -} diff --git a/src/model/mdp/Policy.java b/src/model/mdp/Policy.java deleted file mode 100644 index 66b9b0c..0000000 --- a/src/model/mdp/Policy.java +++ /dev/null @@ -1,7 +0,0 @@ -package model.mdp; - -import java.util.ArrayList; - -public class Policy extends ArrayList{ - -} diff --git a/src/model/mdp/Transition.java b/src/model/mdp/Transition.java deleted file mode 100644 index 5148b8f..0000000 --- a/src/model/mdp/Transition.java +++ /dev/null @@ -1,34 +0,0 @@ -package model.mdp; - -public class Transition { - private double prob; - private int scoreChange; - private int tileCountChange; - - public Transition(double prob, int scoreChange, int tileCountChange) { - super(); - this.prob = prob; - this.scoreChange = scoreChange; - this.tileCountChange = tileCountChange; - } - - public double getProb() { - return prob; - } - public void setProb(double prob) { - this.prob = prob; - } - public int getScoreChange() { - return scoreChange; - } - public void setScoreChange(int scoreChange) { - this.scoreChange = scoreChange; - } - public int getTileCountChange() { - return tileCountChange; - } - public void setTileCountChange(int tileCountChange) { - this.tileCountChange = tileCountChange; - } - -} \ No newline at end of file diff --git a/src/model/mdp/ValueIterationSolver.java b/src/model/mdp/ValueIterationSolver.java deleted file mode 100644 index 35e9d87..0000000 --- a/src/model/mdp/ValueIterationSolver.java +++ /dev/null @@ -1,110 +0,0 @@ -package model.mdp; - -import java.text.DecimalFormat; -import java.util.ArrayList; -import java.util.List; - -public class ValueIterationSolver implements MDPSolver { - public int maxIterations = 10; - public final double DEFAULT_EPS = 0.1; - public final double GAMMA = 0.9; //discount - - private DecimalFormat fmt = new DecimalFormat("##.00"); - public Policy solve(MDP mdp) { - Policy policy = new Policy(); - - double[][] utility = new double[mdp.getMaxScore()+1][mdp.getMaxTiles()+1]; - double[][] utilityPrime = new double[mdp.getMaxScore()+1][mdp.getMaxTiles()+1]; - - for (int i = 0; i <= mdp.getMaxScore(); i++) { - //StringBuilder sb = new StringBuilder(); - for (int j = 0; j <= mdp.getMaxTiles(); j++) { - utilityPrime[i][j] = mdp.getReward(i, j); - //sb.append(fmt.format(utility[i][j])); - //sb.append(" "); - } - //System.out.println(sb); - } - - converged: - for (int iteration = 0; iteration < maxIterations; iteration++) { - for (int i = 0; i <= mdp.getMaxScore(); i++) { - for (int j = 0; j <= mdp.getMaxTiles(); j++) { - utility[i][j] = utilityPrime[i][j]; - } - } - for (int i = 0; i <= mdp.getMaxScore(); i++) { - for (int j = 0; j <= mdp.getMaxTiles(); j++) { - Action[] actions = mdp.getActions(i,j); - - double aMax; - if (actions.length > 0) { - aMax = Double.NEGATIVE_INFINITY; - } else { - aMax = 0; - } - - for (Action action : actions){ - List transitions = getTransitions(action,mdp,i,j); - double aSum = 0.0; - for (Transition transition : transitions) { - int transI = transition.getScoreChange(); - int transJ = transition.getTileCountChange(); - if (i+transI >= 0 && i+transI <= mdp.getMaxScore() - && j+transJ >= 0 && j+transJ <= mdp.getMaxTiles()) - aSum += utility[i+transI][j+transJ]; - } - if (aSum > aMax) { - aMax = aSum; - } - } - utilityPrime[i][j] = mdp.getReward(i,j) + GAMMA * aMax; - } - } - double maxDiff = getMaxDiff(utility,utilityPrime); - System.out.println("Max diff |U - U'| = " + maxDiff); - if (maxDiff < DEFAULT_EPS) { - System.out.println("Solution to MDP converged: " + maxDiff); - break converged; - } - } - - for (int i = 0; i < utility.length; i++) { - StringBuilder sb = new StringBuilder(); - for (int j = 0; j < utility[i].length; j++) { - sb.append(fmt.format(utility[i][j])); - sb.append(" "); - } - System.out.println(sb); - } - - //utility is now the utility Matrix - //get the policy - return policy; - } - - double getMaxDiff(double[][]u, double[][]uPrime) { - double maxDiff = 0; - for (int i = 0; i < u.length; i++) { - for (int j = 0; j < u[i].length; j++) { - maxDiff = Math.max(maxDiff,Math.abs(u[i][j] - uPrime[i][j])); - } - } - return maxDiff; - } - - private List getTransitions(Action action, MDP mdp, int score, int tiles) { - List transitions = new ArrayList(); - if (Action.playToWin == action) { - transitions.add(new Transition(0.9,1,1)); - transitions.add(new Transition(0.1,1,-3)); - } else if (Action.playToLose == action) { - transitions.add(new Transition(0.9,1,1)); - transitions.add(new Transition(0.1,1,-3)); - } /*else if (Action.maintainScore == action) { - transitions.add(new Transition(0.5,1,1)); - transitions.add(new Transition(0.5,1,-3)); - }*/ - return transitions; - } -} \ No newline at end of file diff --git a/test/PlayerModel.dat b/test/PlayerModel.dat new file mode 100644 index 0000000..7160940 Binary files /dev/null and b/test/PlayerModel.dat differ diff --git a/test/aima/core/probability/mdp/MarkovDecisionProcessTest.java b/test/aima/core/probability/mdp/MarkovDecisionProcessTest.java index e266e92..200c644 100644 --- a/test/aima/core/probability/mdp/MarkovDecisionProcessTest.java +++ b/test/aima/core/probability/mdp/MarkovDecisionProcessTest.java @@ -5,91 +5,57 @@ import junit.framework.Assert; import org.junit.Before; import org.junit.Test; -import aima.core.environment.cellworld.Cell; -import aima.core.environment.cellworld.CellWorld; -import aima.core.environment.cellworld.CellWorldAction; -import aima.core.environment.cellworld.CellWorldFactory; +import aima.core.environment.gridworld.GridCell; +import aima.core.environment.gridworld.GridWorld; +import aima.core.environment.gridworld.GridWorldAction; +import aima.core.environment.gridworld.GridWorldFactory; import aima.core.probability.example.MDPFactory; import aima.core.probability.mdp.MarkovDecisionProcess; /** - * - * @author Ciaran O'Reilly - * @author Ravi Mohan - * + * Based on MarkovDecisionProcessTest by Ciaran O'Reilly and Ravi Mohan. Used under MIT license. */ public class MarkovDecisionProcessTest { public static final double DELTA_THRESHOLD = 1e-3; - private CellWorld cw = null; - private MarkovDecisionProcess, CellWorldAction> mdp = null; + private double nonTerminalReward = -0.04; + private GridWorld gw = null; + private MarkovDecisionProcess, GridWorldAction> mdp = null; @Before public void setUp() { - cw = CellWorldFactory.createCellWorldForFig17_1(); - mdp = MDPFactory.createMDPForFigure17_3(cw); + int maxTiles = 6; + int maxScore = 10; + + gw = GridWorldFactory.createGridWorldForTileGame(maxTiles, maxScore, nonTerminalReward); + mdp = MDPFactory.createMDPForTileGame(gw, maxTiles, maxScore); } @Test public void testActions() { // Ensure all actions can be performed in each cell // except for the terminal states. - for (Cell s : cw.getCells()) { - if (4 == s.getX() && (3 == s.getY() || 2 == s.getY())) { + for (GridCell s : gw.getCells()) { + if (6 == s.getX() && 10 == s.getY()) { Assert.assertEquals(0, mdp.actions(s).size()); } else { - Assert.assertEquals(5, mdp.actions(s).size()); + Assert.assertEquals(3, mdp.actions(s).size()); } } } @Test public void testMDPTransitionModel() { - Assert.assertEquals(0.8, mdp.transitionProbability(cw.getCellAt(1, 2), - cw.getCellAt(1, 1), CellWorldAction.Up), DELTA_THRESHOLD); - Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(1, 1), - cw.getCellAt(1, 1), CellWorldAction.Up), DELTA_THRESHOLD); - Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(2, 1), - cw.getCellAt(1, 1), CellWorldAction.Up), DELTA_THRESHOLD); - Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(1, 3), - cw.getCellAt(1, 1), CellWorldAction.Up), DELTA_THRESHOLD); - - Assert.assertEquals(0.9, mdp.transitionProbability(cw.getCellAt(1, 1), - cw.getCellAt(1, 1), CellWorldAction.Down), DELTA_THRESHOLD); - Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(2, 1), - cw.getCellAt(1, 1), CellWorldAction.Down), DELTA_THRESHOLD); - Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(3, 1), - cw.getCellAt(1, 1), CellWorldAction.Down), DELTA_THRESHOLD); - Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(1, 2), - cw.getCellAt(1, 1), CellWorldAction.Down), DELTA_THRESHOLD); - - Assert.assertEquals(0.9, mdp.transitionProbability(cw.getCellAt(1, 1), - cw.getCellAt(1, 1), CellWorldAction.Left), DELTA_THRESHOLD); - Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(2, 1), - cw.getCellAt(1, 1), CellWorldAction.Left), DELTA_THRESHOLD); - Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(3, 1), - cw.getCellAt(1, 1), CellWorldAction.Left), DELTA_THRESHOLD); - Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(1, 2), - cw.getCellAt(1, 1), CellWorldAction.Left), DELTA_THRESHOLD); - - Assert.assertEquals(0.8, mdp.transitionProbability(cw.getCellAt(2, 1), - cw.getCellAt(1, 1), CellWorldAction.Right), DELTA_THRESHOLD); - Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(1, 1), - cw.getCellAt(1, 1), CellWorldAction.Right), DELTA_THRESHOLD); - Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(1, 2), - cw.getCellAt(1, 1), CellWorldAction.Right), DELTA_THRESHOLD); - Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(1, 3), - cw.getCellAt(1, 1), CellWorldAction.Right), DELTA_THRESHOLD); + Assert.assertEquals(0.66, mdp.transitionProbability(gw.getCellAt(2, 2), + gw.getCellAt(1, 1), GridWorldAction.AddTile), DELTA_THRESHOLD); } @Test public void testRewardFunction() { // Ensure all actions can be performed in each cell. - for (Cell s : cw.getCells()) { - if (4 == s.getX() && 3 == s.getY()) { + for (GridCell s : gw.getCells()) { + if (6 == s.getX() && 10 == s.getY()) { Assert.assertEquals(1.0, mdp.reward(s), DELTA_THRESHOLD); - } else if (4 == s.getX() && 2 == s.getY()) { - Assert.assertEquals(-1.0, mdp.reward(s), DELTA_THRESHOLD); } else { Assert.assertEquals(-0.04, mdp.reward(s), DELTA_THRESHOLD); } diff --git a/test/aima/core/probability/mdp/PolicyIterationTest.java b/test/aima/core/probability/mdp/PolicyIterationTest.java index 255f403..f9cbe22 100644 --- a/test/aima/core/probability/mdp/PolicyIterationTest.java +++ b/test/aima/core/probability/mdp/PolicyIterationTest.java @@ -1,15 +1,8 @@ package aima.core.probability.mdp; -import java.util.Map; - -import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import aima.core.environment.cellworld.Cell; -import aima.core.environment.cellworld.CellWorld; -import aima.core.environment.cellworld.CellWorldAction; -import aima.core.environment.cellworld.CellWorldFactory; import aima.core.environment.gridworld.GridCell; import aima.core.environment.gridworld.GridWorld; import aima.core.environment.gridworld.GridWorldAction; @@ -18,7 +11,6 @@ import aima.core.probability.example.MDPFactory; import aima.core.probability.mdp.MarkovDecisionProcess; import aima.core.probability.mdp.impl.ModifiedPolicyEvaluation; import aima.core.probability.mdp.search.PolicyIteration; -import aima.core.probability.mdp.search.ValueIteration; /** * @author Ravi Mohan @@ -29,28 +21,31 @@ public class PolicyIterationTest { public static final double DELTA_THRESHOLD = 1e-3; private GridWorld gw = null; - private MarkovDecisionProcess, GridWorldAction> mdp = null; + private MarkovDecisionProcess, GridWorldAction> mdp = null; private PolicyIteration, GridWorldAction> pi = null; final int maxTiles = 6; final int maxScore = 10; - + @Before public void setUp() { - //take 10 turns to place 6 tiles + // take 10 turns to place 6 tiles double defaultPenalty = -0.04; - - gw = GridWorldFactory.createGridWorldForTileGame(maxTiles,maxScore,defaultPenalty); + + gw = GridWorldFactory.createGridWorldForTileGame(maxTiles, maxScore, + defaultPenalty); mdp = MDPFactory.createMDPForTileGame(gw, maxTiles, maxScore); - - //gamma = 1.0 - PolicyEvaluation,GridWorldAction> pe = new ModifiedPolicyEvaluation, GridWorldAction>(100,0.9); + + // gamma = 1.0 + PolicyEvaluation, GridWorldAction> pe = new ModifiedPolicyEvaluation, GridWorldAction>( + 100, 0.9); pi = new PolicyIteration, GridWorldAction>(pe); } @Test public void testPolicyIterationForTileGame() { - Policy, GridWorldAction> policy = pi.policyIteration(mdp); + Policy, GridWorldAction> policy = pi + .policyIteration(mdp); for (int j = maxScore; j >= 1; j--) { StringBuilder sb = new StringBuilder(); @@ -60,21 +55,5 @@ public class PolicyIterationTest { } System.out.println(sb.toString()); } - - //Assert.assertEquals(0.705, U.get(gw.getCellAt(1, 1)), DELTA_THRESHOLD); - /* - Assert.assertEquals(0.762, U.get(cw1.getCellAt(1, 2)), DELTA_THRESHOLD); - Assert.assertEquals(0.812, U.get(cw1.getCellAt(1, 3)), DELTA_THRESHOLD); - - Assert.assertEquals(0.655, U.get(cw1.getCellAt(2, 1)), DELTA_THRESHOLD); - Assert.assertEquals(0.868, U.get(cw1.getCellAt(2, 3)), DELTA_THRESHOLD); - - Assert.assertEquals(0.611, U.get(cw1.getCellAt(3, 1)), DELTA_THRESHOLD); - Assert.assertEquals(0.660, U.get(cw1.getCellAt(3, 2)), DELTA_THRESHOLD); - Assert.assertEquals(0.918, U.get(cw1.getCellAt(3, 3)), DELTA_THRESHOLD); - - Assert.assertEquals(0.388, U.get(cw1.getCellAt(4, 1)), DELTA_THRESHOLD); - Assert.assertEquals(-1.0, U.get(cw1.getCellAt(4, 2)), DELTA_THRESHOLD); - Assert.assertEquals(1.0, U.get(cw1.getCellAt(4, 3)), DELTA_THRESHOLD);*/ } } diff --git a/test/aima/core/probability/mdp/ValueIterationTest.java b/test/aima/core/probability/mdp/ValueIterationTest.java deleted file mode 100644 index 9d1215e..0000000 --- a/test/aima/core/probability/mdp/ValueIterationTest.java +++ /dev/null @@ -1,64 +0,0 @@ -package aima.core.probability.mdp; - -import java.util.Map; - -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - -import aima.core.environment.cellworld.Cell; -import aima.core.environment.cellworld.CellWorld; -import aima.core.environment.cellworld.CellWorldAction; -import aima.core.environment.cellworld.CellWorldFactory; -import aima.core.probability.example.MDPFactory; -import aima.core.probability.mdp.MarkovDecisionProcess; -import aima.core.probability.mdp.search.ValueIteration; - -/** - * @author Ravi Mohan - * @author Ciaran O'Reilly - * - */ -public class ValueIterationTest { - public static final double DELTA_THRESHOLD = 1e-3; - - private CellWorld cw = null; - private MarkovDecisionProcess, CellWorldAction> mdp = null; - private ValueIteration, CellWorldAction> vi = null; - - @Before - public void setUp() { - cw = CellWorldFactory.createCellWorldForFig17_1(); - mdp = MDPFactory.createMDPForFigure17_3(cw); - vi = new ValueIteration, CellWorldAction>(1.0); - } - - @Test - public void testValueIterationForFig17_3() { - Map, Double> U = vi.valueIteration(mdp, 0.0001); - - Assert.assertEquals(0.705, U.get(cw.getCellAt(1, 1)), DELTA_THRESHOLD); - Assert.assertEquals(0.762, U.get(cw.getCellAt(1, 2)), DELTA_THRESHOLD); - Assert.assertEquals(0.812, U.get(cw.getCellAt(1, 3)), DELTA_THRESHOLD); - - Assert.assertEquals(0.655, U.get(cw.getCellAt(2, 1)), DELTA_THRESHOLD); - Assert.assertEquals(0.868, U.get(cw.getCellAt(2, 3)), DELTA_THRESHOLD); - - Assert.assertEquals(0.611, U.get(cw.getCellAt(3, 1)), DELTA_THRESHOLD); - Assert.assertEquals(0.660, U.get(cw.getCellAt(3, 2)), DELTA_THRESHOLD); - Assert.assertEquals(0.918, U.get(cw.getCellAt(3, 3)), DELTA_THRESHOLD); - - Assert.assertEquals(0.388, U.get(cw.getCellAt(4, 1)), DELTA_THRESHOLD); - Assert.assertEquals(-1.0, U.get(cw.getCellAt(4, 2)), DELTA_THRESHOLD); - Assert.assertEquals(1.0, U.get(cw.getCellAt(4, 3)), DELTA_THRESHOLD); - - for (int j = 3; j >= 1; j--) { - StringBuilder sb = new StringBuilder(); - for (int i = 1; i <= 4; i++) { - sb.append(U.get(cw.getCellAt(i, j))); - sb.append(" "); - } - System.out.println(sb.toString()); - } - } -} diff --git a/test/aima/core/probability/mdp/ValueIterationTest2.java b/test/aima/core/probability/mdp/ValueIterationTest2.java index a0c6ce1..7b1e0e2 100644 --- a/test/aima/core/probability/mdp/ValueIterationTest2.java +++ b/test/aima/core/probability/mdp/ValueIterationTest2.java @@ -6,10 +6,6 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import aima.core.environment.cellworld.Cell; -import aima.core.environment.cellworld.CellWorld; -import aima.core.environment.cellworld.CellWorldAction; -import aima.core.environment.cellworld.CellWorldFactory; import aima.core.environment.gridworld.GridCell; import aima.core.environment.gridworld.GridWorld; import aima.core.environment.gridworld.GridWorldAction; @@ -19,29 +15,30 @@ import aima.core.probability.mdp.MarkovDecisionProcess; import aima.core.probability.mdp.search.ValueIteration; /** - * @author Ravi Mohan - * @author Ciaran O'Reilly * + * @author Woody + * */ public class ValueIterationTest2 { public static final double DELTA_THRESHOLD = 1e-3; private GridWorld gw = null; - private MarkovDecisionProcess, GridWorldAction> mdp = null; + private MarkovDecisionProcess, GridWorldAction> mdp = null; private ValueIteration, GridWorldAction> vi = null; final int maxTiles = 6; final int maxScore = 10; - + @Before public void setUp() { - //take 10 turns to place 6 tiles + // take 10 turns to place 6 tiles double defaultPenalty = -0.04; - - gw = GridWorldFactory.createGridWorldForTileGame(maxTiles,maxScore,defaultPenalty); + + gw = GridWorldFactory.createGridWorldForTileGame(maxTiles, maxScore, + defaultPenalty); mdp = MDPFactory.createMDPForTileGame(gw, maxTiles, maxScore); - - //gamma = 1.0 + + // gamma = 1.0 vi = new ValueIteration, GridWorldAction>(0.9); } @@ -57,20 +54,7 @@ public class ValueIterationTest2 { } System.out.println(sb.toString()); } - - Assert.assertEquals(0.705, U.get(gw.getCellAt(1, 1)), DELTA_THRESHOLD);/* - Assert.assertEquals(0.762, U.get(cw1.getCellAt(1, 2)), DELTA_THRESHOLD); - Assert.assertEquals(0.812, U.get(cw1.getCellAt(1, 3)), DELTA_THRESHOLD); - Assert.assertEquals(0.655, U.get(cw1.getCellAt(2, 1)), DELTA_THRESHOLD); - Assert.assertEquals(0.868, U.get(cw1.getCellAt(2, 3)), DELTA_THRESHOLD); - - Assert.assertEquals(0.611, U.get(cw1.getCellAt(3, 1)), DELTA_THRESHOLD); - Assert.assertEquals(0.660, U.get(cw1.getCellAt(3, 2)), DELTA_THRESHOLD); - Assert.assertEquals(0.918, U.get(cw1.getCellAt(3, 3)), DELTA_THRESHOLD); - - Assert.assertEquals(0.388, U.get(cw1.getCellAt(4, 1)), DELTA_THRESHOLD); - Assert.assertEquals(-1.0, U.get(cw1.getCellAt(4, 2)), DELTA_THRESHOLD); - Assert.assertEquals(1.0, U.get(cw1.getCellAt(4, 3)), DELTA_THRESHOLD);*/ + Assert.assertEquals(-0.1874236, U.get(gw.getCellAt(1, 1)), DELTA_THRESHOLD); } } diff --git a/test/model/mdp/ValueIterationSolverTest.java b/test/model/mdp/ValueIterationSolverTest.java deleted file mode 100644 index dac3656..0000000 --- a/test/model/mdp/ValueIterationSolverTest.java +++ /dev/null @@ -1,26 +0,0 @@ -package model.mdp; - -import static org.junit.Assert.assertTrue; -import model.mdp.MDP.MODE; - -import org.junit.Test; - -public class ValueIterationSolverTest { - - @Test - public void testSolve() { - MDPSolver solver = new ValueIterationSolver(); - - //solve for a score of 25 in at most 35 turns - int maxScore = 6; - int maxTurns = 10; - - MDP mdp = new MDP(maxScore,maxTurns,MODE.CEIL); - Policy policy = solver.solve(mdp); - - assertTrue(policy.size() >= maxScore); - assertTrue(policy.size() <= maxTurns); - - System.out.println("Policy: " + policy); - } -}