Fixed unit tests, changed MDP generation to more reasonably seek the goal state, avoiding premature end of game.

Removed unused google-code classes.
Regenerate policy when AdaptiveComPlayer.setGameGoal() is called.
This commit is contained in:
Woody Folsom
2012-04-30 17:37:37 -04:00
parent 3800436cd9
commit 8f92ae65d8
19 changed files with 53 additions and 939 deletions

View File

@@ -1,87 +0,0 @@
package aima.core.environment.cellworld;
/**
* Artificial Intelligence A Modern Approach (3rd Edition): page 645.<br>
* <br>
* A representation of a Cell in the environment detailed in Figure 17.1.
*
* @param <C>
* the content type of the cell.
*
* @author Ciaran O'Reilly
* @author Ravi Mohan
*/
public class Cell<C> {
    private int x = 1;
    private int y = 1;
    private C content = null;

    /**
     * Construct a Cell.
     *
     * @param x
     *            the x position of the cell.
     * @param y
     *            the y position of the cell.
     * @param content
     *            the initial content of the cell (may be null).
     */
    public Cell(int x, int y, C content) {
        this.x = x;
        this.y = y;
        this.content = content;
    }

    /**
     * @return the x position of the cell.
     */
    public int getX() {
        return x;
    }

    /**
     * @return the y position of the cell.
     */
    public int getY() {
        return y;
    }

    /**
     * @return the content of the cell.
     */
    public C getContent() {
        return content;
    }

    /**
     * Set the cell's content.
     *
     * @param content
     *            the content to be placed in the cell.
     */
    public void setContent(C content) {
        this.content = content;
    }

    @Override
    public String toString() {
        return "<x=" + x + ", y=" + y + ", content=" + content + ">";
    }

    @Override
    public boolean equals(Object o) {
        if (o instanceof Cell<?>) {
            Cell<?> c = (Cell<?>) o;
            // Null-safe content comparison: content is allowed to be null
            // (the field defaults to null), so dereferencing it directly
            // would throw a NullPointerException.
            boolean sameContent = (content == null) ? c.content == null
                    : content.equals(c.content);
            return x == c.x && y == c.y && sameContent;
        }
        return false;
    }

    @Override
    public int hashCode() {
        // Standard 31-based combination of the same fields used by
        // equals(); null-safe (the original dereferenced content
        // unconditionally and mixed the fields with a weak formula).
        int h = 31 * x + y;
        return 31 * h + (content == null ? 0 : content.hashCode());
    }
}

View File

@@ -1,123 +0,0 @@
package aima.core.environment.cellworld;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
/**
* Artificial Intelligence A Modern Approach (3rd Edition): page 645.<br>
* <br>
*
* A representation for the environment depicted in figure 17.1.<br>
* <br>
* <b>Note:<b> the x and y coordinates are always positive integers starting at
* 1.<br>
* <b>Note:<b> If looking at a rectangle - the coordinate (x=1, y=1) will be the
* bottom left hand corner.<br>
*
*
* @param <C>
* the type of content for the Cells in the world.
*
* @author Ciaran O'Reilly
* @author Ravi Mohan
*/
public class CellWorld<C> {
    // Insertion-ordered set so iteration over the cells is deterministic.
    private Set<Cell<C>> cells = new LinkedHashSet<Cell<C>>();
    // Two-level index, x -> (y -> cell), for constant-time coordinate lookup.
    private Map<Integer, Map<Integer, Cell<C>>> cellLookup = new HashMap<Integer, Map<Integer, Cell<C>>>();

    /**
     * Construct a Cell World of xDimension * yDimension cells, each
     * initialized to the given default content value.
     *
     * @param xDimension
     *            the size of the x dimension.
     * @param yDimension
     *            the size of the y dimension.
     * @param defaultCellContent
     *            the default content to assign to each cell created.
     */
    public CellWorld(int xDimension, int yDimension, C defaultCellContent) {
        for (int col = 1; col <= xDimension; col++) {
            Map<Integer, Cell<C>> column = new HashMap<Integer, Cell<C>>();
            cellLookup.put(col, column);
            for (int row = 1; row <= yDimension; row++) {
                Cell<C> cell = new Cell<C>(col, row, defaultCellContent);
                column.put(row, cell);
                cells.add(cell);
            }
        }
    }

    /**
     * @return all the cells in this world.
     */
    public Set<Cell<C>> getCells() {
        return cells;
    }

    /**
     * Determine the cell an agent ends up in when performing the given
     * action in the given cell. Normally this is the adjacent cell in the
     * action's direction; if no such cell exists the agent stays put.
     *
     * @param s
     *            the cell location from which the action is to be performed.
     * @param a
     *            the action to perform (Up, Down, Left, or Right).
     * @return the resulting cell, or s itself when the move is blocked.
     */
    public Cell<C> result(Cell<C> s, CellWorldAction a) {
        int newX = a.getXResult(s.getX());
        int newY = a.getYResult(s.getY());
        Cell<C> target = getCellAt(newX, newY);
        // No adjoining cell in that direction: the agent bumps back in place.
        return target == null ? s : target;
    }

    /**
     * Remove the cell at the specified location from this Cell World. This
     * allows barriers to be introduced at arbitrary locations.
     *
     * @param x
     *            the x dimension of the cell to be removed.
     * @param y
     *            the y dimension of the cell to be removed.
     */
    public void removeCell(int x, int y) {
        Map<Integer, Cell<C>> column = cellLookup.get(x);
        if (column != null) {
            cells.remove(column.remove(y));
        }
    }

    /**
     * Get the cell at the specified x and y location.
     *
     * @param x
     *            the x dimension of the cell to be retrieved.
     * @param y
     *            the y dimension of the cell to be retrieved.
     * @return the cell at (x, y), or null if no cell exists there.
     */
    public Cell<C> getCellAt(int x, int y) {
        Map<Integer, Cell<C>> column = cellLookup.get(x);
        return column == null ? null : column.get(y);
    }
}

View File

@@ -1,142 +0,0 @@
package aima.core.environment.cellworld;
import java.util.LinkedHashSet;
import java.util.Set;
import aima.core.agent.Action;
/**
 * Artificial Intelligence A Modern Approach (3rd Edition): page 645.<br>
 * <br>
 *
 * The actions in every state are Up, Down, Left, and Right.<br>
 * <br>
 * <b>Note:</b> Moving 'Up' causes y to increase by 1, 'Down' y to decrease by
 * 1, 'Left' x to decrease by 1, and 'Right' x to increase by 1 within a Cell
 * World.
 *
 * @author Ciaran O'Reilly
 *
 */
public enum CellWorldAction implements Action {
    Up, Down, Left, Right, None;

    // All five values in declaration order; built once and shared.
    private static final Set<CellWorldAction> _actions = new LinkedHashSet<CellWorldAction>();
    static {
        _actions.add(Up);
        _actions.add(Down);
        _actions.add(Left);
        _actions.add(Right);
        _actions.add(None);
    }

    /**
     * @return a set of the actual actions.
     */
    public static final Set<CellWorldAction> actions() {
        return _actions;
    }

    //
    // START-Action
    // Retained (commented out) from the original Action interface; the
    // interface in use here no longer declares isNoOp().
    //@Override
    //public boolean isNoOp() {
    //	if (None == this) {
    //		return true;
    //	}
    //	return false;
    //}
    // END-Action
    //

    /**
     * Compute the x coordinate reached by applying this action. Only Left
     * and Right change x; all other actions leave it unchanged.
     *
     * @param curX
     *            the current x position.
     * @return the result on the x position of applying this action.
     */
    public int getXResult(int curX) {
        int newX = curX;
        switch (this) {
        case Left:
            newX--;
            break;
        case Right:
            newX++;
            break;
        }
        return newX;
    }

    /**
     * Compute the y coordinate reached by applying this action. Only Up
     * and Down change y; all other actions leave it unchanged.
     *
     * @param curY
     *            the current y position.
     * @return the result on the y position of applying this action.
     */
    public int getYResult(int curY) {
        int newY = curY;
        switch (this) {
        case Up:
            newY++;
            break;
        case Down:
            newY--;
            break;
        }
        return newY;
    }

    /**
     * The first of the two actions at right angles to this one (used to
     * model the agent slipping sideways). None maps to None.
     *
     * @return the first right angled action related to this action.
     */
    public CellWorldAction getFirstRightAngledAction() {
        CellWorldAction a = null;
        switch (this) {
        case Up:
        case Down:
            a = Left;
            break;
        case Left:
        case Right:
            a = Down;
            break;
        case None:
            a = None;
            break;
        }
        return a;
    }

    /**
     * The second of the two actions at right angles to this one. None maps
     * to None.
     *
     * @return the second right angled action related to this action.
     */
    public CellWorldAction getSecondRightAngledAction() {
        CellWorldAction a = null;
        switch (this) {
        case Up:
        case Down:
            a = Right;
            break;
        case Left:
        case Right:
            a = Up;
            break;
        case None:
            a = None;
            break;
        }
        return a;
    }
}

View File

@@ -1,27 +0,0 @@
package aima.core.environment.cellworld;
/**
*
* @author Ciaran O'Reilly
*
*/
public class CellWorldFactory {
    /**
     * Create the cell world as defined in Figure 17.1 in AIMA3e. (a) A
     * simple 4 x 3 environment that presents the agent with a sequential
     * decision problem.
     *
     * @return a cell world representation of Fig 17.1 in AIMA3e.
     */
    public static CellWorld<Double> createCellWorldForFig17_1() {
        // 4 x 3 grid where every cell starts with a content of -0.04.
        CellWorld<Double> world = new CellWorld<Double>(4, 3, -0.04);
        // (2, 2) is a wall in Fig 17.1, so it is removed from the world.
        world.removeCell(2, 2);
        // The two terminal cells: +1 at (4, 3) and -1 at (4, 2).
        world.getCellAt(4, 3).setContent(1.0);
        world.getCellAt(4, 2).setContent(-1.0);
        return world;
    }
}

View File

@@ -17,7 +17,12 @@ public class GridWorldFactory {
GridWorld<Double> cw = new GridWorld<Double>(maxTiles, maxScore, nonTerminalReward);
cw.getCellAt(maxTiles, maxScore).setContent(1.0);
for (int score = 1; score < maxScore; score++) {
cw.getCellAt(maxTiles, score).setContent(-0.2);
}
for (int tiles = 1; tiles < maxTiles; tiles++) {
cw.getCellAt(tiles, maxScore).setContent(-0.2);
}
return cw;
}
}

View File

@@ -6,9 +6,6 @@ import java.util.HashSet;
import java.util.List;
import java.util.Set;
import aima.core.environment.cellworld.Cell;
import aima.core.environment.cellworld.CellWorld;
import aima.core.environment.cellworld.CellWorldAction;
import aima.core.environment.gridworld.GridCell;
import aima.core.environment.gridworld.GridWorld;
import aima.core.environment.gridworld.GridWorldAction;
@@ -19,30 +16,11 @@ import aima.core.probability.mdp.TransitionProbabilityFunction;
import aima.core.probability.mdp.impl.MDP;
/**
*
* @author Ciaran O'Reilly
* @author Ravi Mohan
* Based on MDPFactory by Ciaran O'Reilly and Ravi Mohan.
* @author Woody
*/
public class MDPFactory {
/**
 * Constructs an MDP that can be used to generate the utility values
 * detailed in Fig 17.3.
 *
 * @param cw
 *            the cell world from figure 17.1.
 * @return an MDP over the cells of the given world, starting at (1, 1).
 */
public static MarkovDecisionProcess<Cell<Double>, CellWorldAction> createMDPForFigure17_3(
        final CellWorld<Double> cw) {
    // Assemble the three MDP components for the Fig 17.1 world, then
    // combine them with its state set and the (1, 1) initial state.
    ActionsFunction<Cell<Double>, CellWorldAction> actions = createActionsFunctionForFigure17_1(cw);
    TransitionProbabilityFunction<Cell<Double>, CellWorldAction> transitions = createTransitionProbabilityFunctionForFigure17_1(cw);
    RewardFunction<Cell<Double>> rewards = createRewardFunctionForFigure17_1();
    return new MDP<Cell<Double>, CellWorldAction>(cw.getCells(),
            cw.getCellAt(1, 1), actions, transitions, rewards);
}
public static MarkovDecisionProcess<GridCell<Double>, GridWorldAction> createMDPForTileGame(
final GridWorld<Double> cw, int maxTiles, int maxScore) {
@@ -52,36 +30,6 @@ public class MDPFactory {
createRewardFunctionForTileGame());
}
/**
 * Returns the allowed actions from a specified cell within the cell world
 * described in Fig 17.1.
 *
 * @param cw
 *            the cell world from figure 17.1.
 * @return the set of actions allowed at a particular cell. This set will be
 *         empty if at a terminal state.
 */
public static ActionsFunction<Cell<Double>, CellWorldAction> createActionsFunctionForFigure17_1(
        final CellWorld<Double> cw) {
    // The two absorbing states of Fig 17.1: (4, 3) and (4, 2).
    final Set<Cell<Double>> terminals = new HashSet<Cell<Double>>();
    terminals.add(cw.getCellAt(4, 3));
    terminals.add(cw.getCellAt(4, 2));
    return new ActionsFunction<Cell<Double>, CellWorldAction>() {
        @Override
        public Set<CellWorldAction> actions(Cell<Double> state) {
            // Terminal states admit no further actions; every other
            // cell offers the full action set.
            if (terminals.contains(state)) {
                return Collections.emptySet();
            }
            return CellWorldAction.actions();
        }
    };
}
public static ActionsFunction<GridCell<Double>, GridWorldAction> createActionsFunctionForTileGame(
final GridWorld<Double> cw, int maxTiles, int maxScore) {
final Set<GridCell<Double>> terminals = new HashSet<GridCell<Double>>();
@@ -102,59 +50,6 @@ public class MDPFactory {
return af;
}
/**
 * Figure 17.1 (b) Illustration of the transition model of the environment:
 * the 'intended' outcome occurs with probability 0.8, but with probability
 * 0.2 the agent moves at right angles to the intended direction. A
 * collision with a wall results in no movement.
 *
 * @param cw
 *            the cell world from figure 17.1.
 * @return the transition probability function as described in figure 17.1.
 */
public static TransitionProbabilityFunction<Cell<Double>, CellWorldAction> createTransitionProbabilityFunctionForFigure17_1(
        final CellWorld<Double> cw) {
    TransitionProbabilityFunction<Cell<Double>, CellWorldAction> tf = new TransitionProbabilityFunction<Cell<Double>, CellWorldAction>() {
        // Index-aligned with possibleOutcomes(): 0.8 for the intended
        // move, 0.1 for each of the two right-angled deviations.
        private double[] distribution = new double[] { 0.8, 0.1, 0.1 };

        @Override
        public double probability(Cell<Double> sDelta, Cell<Double> s,
                CellWorldAction a) {
            double prob = 0;
            List<Cell<Double>> outcomes = possibleOutcomes(s, a);
            for (int i = 0; i < outcomes.size(); i++) {
                if (sDelta.equals(outcomes.get(i))) {
                    // Note: You have to sum the matches to
                    // sDelta as the different actions
                    // could have the same effect (i.e.
                    // staying in place due to there being
                    // no adjacent cells), which increases
                    // the probability of the transition for
                    // that state.
                    prob += distribution[i];
                }
            }
            return prob;
        }

        // The three cells reachable from c under action a: the intended
        // target plus the two right-angled slips (walls bump back to c).
        private List<Cell<Double>> possibleOutcomes(Cell<Double> c,
                CellWorldAction a) {
            // There can be three possible outcomes for the planned action
            List<Cell<Double>> outcomes = new ArrayList<Cell<Double>>();
            outcomes.add(cw.result(c, a));
            outcomes.add(cw.result(c, a.getFirstRightAngledAction()));
            outcomes.add(cw.result(c, a.getSecondRightAngledAction()));
            return outcomes;
        }
    };
    return tf;
}
public static TransitionProbabilityFunction<GridCell<Double>, GridWorldAction> createTransitionProbabilityFunctionForTileGame(
final GridWorld<Double> cw) {
@@ -170,13 +65,6 @@ public class MDPFactory {
List<GridCell<Double>> outcomes = possibleOutcomes(s, a);
for (int i = 0; i < outcomes.size(); i++) {
if (sDelta.equals(outcomes.get(i))) {
// Note: You have to sum the matches to
// sDelta as the different actions
// could have the same effect (i.e.
// staying in place due to there being
// no adjacent cells), which increases
// the probability of the transition for
// that state.
prob += distribution[i];
}
}
@@ -198,7 +86,7 @@ public class MDPFactory {
}
private List<GridCell<Double>> possibleOutcomes(GridCell<Double> c,
GridWorldAction a) {
// There can be three possible outcomes for the planned action
List<GridCell<Double>> outcomes = new ArrayList<GridCell<Double>>();
switch (a) {
@@ -224,21 +112,6 @@ public class MDPFactory {
return tf;
}
/**
 * @return the reward function which takes the content of the cell as being
 *         the reward value.
 */
public static RewardFunction<Cell<Double>> createRewardFunctionForFigure17_1() {
    return new RewardFunction<Cell<Double>>() {
        @Override
        public double reward(Cell<Double> state) {
            // Each cell's content is its reward (set up by the factory
            // that built the world).
            return state.getContent();
        }
    };
}
public static RewardFunction<GridCell<Double>> createRewardFunctionForTileGame() {
RewardFunction<GridCell<Double>> rf = new RewardFunction<GridCell<Double>>() {
@Override

View File

@@ -119,5 +119,6 @@ public class AdaptiveComPlayer implements Player {
@Override
public void setGameGoal(GameGoal target) {
    // Remember the new goal and flag the policy as stale so it is
    // regenerated before the next move is chosen.
    this.target = target;
    this.calculatePolicy = true;
}
}

View File

@@ -1,18 +0,0 @@
package model.mdp;
public class Action {
    // These shared instances are compared by identity elsewhere
    // (e.g. Action.playToWin == action), so they must be final:
    // the original fields were reassignable, which would silently
    // break every identity comparison.

    /** Canonical "PlayToWin" action instance. */
    public static final Action playToWin = new Action("PlayToWin");

    /** Canonical "PlayToLose" action instance. */
    public static final Action playToLose = new Action("PlayToLose");

    //public static Action maintainScore = new Action();

    /** Display name reported by toString(). */
    private final String name;

    /**
     * Construct an action with the given display name.
     *
     * @param name
     *            the name this action reports from toString().
     */
    public Action(String name) {
        this.name = name;
    }

    @Override
    public String toString() {
        return name;
    }
}

View File

@@ -1,51 +0,0 @@
package model.mdp;
public class MDP {
    /** Reward received in every non-terminal state. */
    public static final double nonTerminalReward = -0.25;

    public enum MODE {
        CEIL, FLOOR
    }

    private final int maxScore;
    private final int maxTiles;
    private final MODE mode;

    /**
     * Construct an MDP over (score, tiles) states.
     *
     * @param maxScore
     *            the score at which the game ends.
     * @param maxTiles
     *            the tile count at which the game ends.
     * @param mode
     *            rounding mode (currently stored but not consulted here).
     */
    public MDP(int maxScore, int maxTiles, MODE mode) {
        this.maxScore = maxScore;
        this.maxTiles = maxTiles;
        this.mode = mode;
    }

    /**
     * The actions available in state (i, j).
     *
     * @return both play actions, or an empty array when the state is
     *         terminal (either bound has been reached).
     */
    public Action[] getActions(int i, int j) {
        boolean terminal = (i == maxScore) || (j == maxTiles);
        if (terminal) {
            return new Action[0];
        }
        return new Action[] { Action.playToLose, Action.playToWin };
    }

    public int getMaxScore() {
        return maxScore;
    }

    public int getMaxTiles() {
        return maxTiles;
    }

    /**
     * Reward for being in state (score, tiles): +10 at the goal (both
     * bounds reached together), negative at either premature bound, and a
     * small living penalty everywhere else.
     */
    public double getReward(int score, int tiles) {
        boolean scoreDone = score == maxScore;
        boolean tilesDone = tiles == maxTiles;
        if (scoreDone && tilesDone) {
            return 10.0;
        }
        // TODO scale linearly?
        if (scoreDone) {
            return -1.0;
        }
        if (tilesDone) {
            return -5.0;
        }
        return nonTerminalReward;
    }
}

View File

@@ -1,5 +0,0 @@
package model.mdp;
/**
 * A strategy for computing a Policy for a given MDP.
 */
public interface MDPSolver {
    /**
     * Solve the given MDP.
     *
     * @param mdp
     *            the decision process to solve.
     * @return the computed policy.
     */
    Policy solve(MDP mdp);
}

View File

@@ -1,7 +0,0 @@
package model.mdp;
import java.util.ArrayList;
/**
 * A policy represented as an ordered sequence of actions to take.
 */
public class Policy extends ArrayList<Action> {
    // ArrayList is Serializable, so subclasses should pin an explicit
    // serialVersionUID; without one the compiler-generated value changes
    // on recompilation and breaks previously serialized policies.
    private static final long serialVersionUID = 1L;
}

View File

@@ -1,34 +0,0 @@
package model.mdp;
/**
 * One probabilistic outcome of an action: with probability prob, the score
 * changes by scoreChange and the tile count by tileCountChange.
 */
public class Transition {
    private double prob;
    private int scoreChange;
    private int tileCountChange;

    /**
     * Construct a transition.
     *
     * @param prob
     *            the probability of this outcome.
     * @param scoreChange
     *            the delta applied to the score.
     * @param tileCountChange
     *            the delta applied to the tile count.
     */
    public Transition(double prob, int scoreChange, int tileCountChange) {
        this.prob = prob;
        this.scoreChange = scoreChange;
        this.tileCountChange = tileCountChange;
    }

    public double getProb() {
        return prob;
    }

    public void setProb(double newProb) {
        prob = newProb;
    }

    public int getScoreChange() {
        return scoreChange;
    }

    public void setScoreChange(int newScoreChange) {
        scoreChange = newScoreChange;
    }

    public int getTileCountChange() {
        return tileCountChange;
    }

    public void setTileCountChange(int newTileCountChange) {
        tileCountChange = newTileCountChange;
    }
}

View File

@@ -1,110 +0,0 @@
package model.mdp;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.List;
public class ValueIterationSolver implements MDPSolver {
    /** Maximum number of Bellman sweeps before giving up on convergence. */
    public int maxIterations = 10;
    /** Convergence threshold on max |U - U'|. */
    public final double DEFAULT_EPS = 0.1;
    /** Discount factor applied to future utility. */
    public final double GAMMA = 0.9; //discount

    private DecimalFormat fmt = new DecimalFormat("##.00");

    /**
     * Run value iteration over the (score, tiles) state grid of the given
     * MDP, printing the per-sweep max difference and the final utility
     * matrix to stdout.
     *
     * @param mdp
     *            the decision process to solve.
     * @return the policy (currently still empty — see TODO below).
     */
    public Policy solve(MDP mdp) {
        Policy policy = new Policy();
        double[][] utility = new double[mdp.getMaxScore() + 1][mdp.getMaxTiles() + 1];
        double[][] utilityPrime = new double[mdp.getMaxScore() + 1][mdp.getMaxTiles() + 1];
        // Initialize U' with the immediate rewards.
        for (int i = 0; i <= mdp.getMaxScore(); i++) {
            for (int j = 0; j <= mdp.getMaxTiles(); j++) {
                utilityPrime[i][j] = mdp.getReward(i, j);
            }
        }
        for (int iteration = 0; iteration < maxIterations; iteration++) {
            // U <- U'
            for (int i = 0; i <= mdp.getMaxScore(); i++) {
                for (int j = 0; j <= mdp.getMaxTiles(); j++) {
                    utility[i][j] = utilityPrime[i][j];
                }
            }
            // Bellman update:
            // U'(s) = R(s) + GAMMA * max_a sum_s' P(s'|s,a) * U(s')
            for (int i = 0; i <= mdp.getMaxScore(); i++) {
                for (int j = 0; j <= mdp.getMaxTiles(); j++) {
                    Action[] actions = mdp.getActions(i, j);
                    double aMax;
                    if (actions.length > 0) {
                        aMax = Double.NEGATIVE_INFINITY;
                    } else {
                        // Terminal state: no future utility contribution.
                        aMax = 0;
                    }
                    for (Action action : actions) {
                        List<Transition> transitions = getTransitions(action, mdp, i, j);
                        double aSum = 0.0;
                        for (Transition transition : transitions) {
                            int transI = transition.getScoreChange();
                            int transJ = transition.getTileCountChange();
                            if (i + transI >= 0 && i + transI <= mdp.getMaxScore()
                                    && j + transJ >= 0 && j + transJ <= mdp.getMaxTiles()) {
                                // BUG FIX: weight each successor utility by
                                // its transition probability. The original
                                // summed the raw utilities, computing a plain
                                // sum rather than an expected utility.
                                aSum += transition.getProb()
                                        * utility[i + transI][j + transJ];
                            }
                            // NOTE(review): out-of-range outcomes contribute
                            // 0, silently dropping their probability mass —
                            // confirm this is the intended boundary handling.
                        }
                        if (aSum > aMax) {
                            aMax = aSum;
                        }
                    }
                    utilityPrime[i][j] = mdp.getReward(i, j) + GAMMA * aMax;
                }
            }
            double maxDiff = getMaxDiff(utility, utilityPrime);
            System.out.println("Max diff |U - U'| = " + maxDiff);
            if (maxDiff < DEFAULT_EPS) {
                System.out.println("Solution to MDP converged: " + maxDiff);
                break;
            }
        }
        // Dump the utility matrix (one sweep behind U') for inspection.
        for (int i = 0; i < utility.length; i++) {
            StringBuilder sb = new StringBuilder();
            for (int j = 0; j < utility[i].length; j++) {
                sb.append(fmt.format(utility[i][j]));
                sb.append(" ");
            }
            System.out.println(sb);
        }
        //utility is now the utility Matrix
        // TODO: derive the greedy policy from the utility matrix; the
        // returned policy is currently empty.
        return policy;
    }

    /** @return max over all (i, j) of |u[i][j] - uPrime[i][j]|. */
    double getMaxDiff(double[][] u, double[][] uPrime) {
        double maxDiff = 0;
        for (int i = 0; i < u.length; i++) {
            for (int j = 0; j < u[i].length; j++) {
                maxDiff = Math.max(maxDiff, Math.abs(u[i][j] - uPrime[i][j]));
            }
        }
        return maxDiff;
    }

    /**
     * The outcome distribution for taking the given action in state
     * (score, tiles). Unknown actions yield an empty list.
     */
    private List<Transition> getTransitions(Action action, MDP mdp, int score, int tiles) {
        List<Transition> transitions = new ArrayList<Transition>();
        // NOTE(review): playToWin and playToLose currently share the same
        // distribution — looks like a placeholder; confirm the intended
        // model before relying on the resulting policy.
        if (Action.playToWin == action) {
            transitions.add(new Transition(0.9, 1, 1));
            transitions.add(new Transition(0.1, 1, -3));
        } else if (Action.playToLose == action) {
            transitions.add(new Transition(0.9, 1, 1));
            transitions.add(new Transition(0.1, 1, -3));
        } /*else if (Action.maintainScore == action) {
            transitions.add(new Transition(0.5,1,1));
            transitions.add(new Transition(0.5,1,-3));
        }*/
        return transitions;
    }
}

BIN
test/PlayerModel.dat Normal file

Binary file not shown.

View File

@@ -5,91 +5,57 @@ import junit.framework.Assert;
import org.junit.Before;
import org.junit.Test;
import aima.core.environment.cellworld.Cell;
import aima.core.environment.cellworld.CellWorld;
import aima.core.environment.cellworld.CellWorldAction;
import aima.core.environment.cellworld.CellWorldFactory;
import aima.core.environment.gridworld.GridCell;
import aima.core.environment.gridworld.GridWorld;
import aima.core.environment.gridworld.GridWorldAction;
import aima.core.environment.gridworld.GridWorldFactory;
import aima.core.probability.example.MDPFactory;
import aima.core.probability.mdp.MarkovDecisionProcess;
/**
*
* @author Ciaran O'Reilly
* @author Ravi Mohan
*
* Based on MarkovDecisionProcessTest by Ciaran O'Reilly and Ravi Mohan. Used under MIT license.
*/
public class MarkovDecisionProcessTest {
public static final double DELTA_THRESHOLD = 1e-3;
private CellWorld<Double> cw = null;
private MarkovDecisionProcess<Cell<Double>, CellWorldAction> mdp = null;
private double nonTerminalReward = -0.04;
private GridWorld<Double> gw = null;
private MarkovDecisionProcess<GridCell<Double>, GridWorldAction> mdp = null;
@Before
public void setUp() {
cw = CellWorldFactory.createCellWorldForFig17_1();
mdp = MDPFactory.createMDPForFigure17_3(cw);
int maxTiles = 6;
int maxScore = 10;
gw = GridWorldFactory.createGridWorldForTileGame(maxTiles, maxScore, nonTerminalReward);
mdp = MDPFactory.createMDPForTileGame(gw, maxTiles, maxScore);
}
@Test
public void testActions() {
// Ensure all actions can be performed in each cell
// except for the terminal states.
for (Cell<Double> s : cw.getCells()) {
if (4 == s.getX() && (3 == s.getY() || 2 == s.getY())) {
for (GridCell<Double> s : gw.getCells()) {
if (6 == s.getX() && 10 == s.getY()) {
Assert.assertEquals(0, mdp.actions(s).size());
} else {
Assert.assertEquals(5, mdp.actions(s).size());
Assert.assertEquals(3, mdp.actions(s).size());
}
}
}
@Test
public void testMDPTransitionModel() {
Assert.assertEquals(0.8, mdp.transitionProbability(cw.getCellAt(1, 2),
cw.getCellAt(1, 1), CellWorldAction.Up), DELTA_THRESHOLD);
Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(1, 1),
cw.getCellAt(1, 1), CellWorldAction.Up), DELTA_THRESHOLD);
Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(2, 1),
cw.getCellAt(1, 1), CellWorldAction.Up), DELTA_THRESHOLD);
Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(1, 3),
cw.getCellAt(1, 1), CellWorldAction.Up), DELTA_THRESHOLD);
Assert.assertEquals(0.9, mdp.transitionProbability(cw.getCellAt(1, 1),
cw.getCellAt(1, 1), CellWorldAction.Down), DELTA_THRESHOLD);
Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(2, 1),
cw.getCellAt(1, 1), CellWorldAction.Down), DELTA_THRESHOLD);
Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(3, 1),
cw.getCellAt(1, 1), CellWorldAction.Down), DELTA_THRESHOLD);
Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(1, 2),
cw.getCellAt(1, 1), CellWorldAction.Down), DELTA_THRESHOLD);
Assert.assertEquals(0.9, mdp.transitionProbability(cw.getCellAt(1, 1),
cw.getCellAt(1, 1), CellWorldAction.Left), DELTA_THRESHOLD);
Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(2, 1),
cw.getCellAt(1, 1), CellWorldAction.Left), DELTA_THRESHOLD);
Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(3, 1),
cw.getCellAt(1, 1), CellWorldAction.Left), DELTA_THRESHOLD);
Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(1, 2),
cw.getCellAt(1, 1), CellWorldAction.Left), DELTA_THRESHOLD);
Assert.assertEquals(0.8, mdp.transitionProbability(cw.getCellAt(2, 1),
cw.getCellAt(1, 1), CellWorldAction.Right), DELTA_THRESHOLD);
Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(1, 1),
cw.getCellAt(1, 1), CellWorldAction.Right), DELTA_THRESHOLD);
Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(1, 2),
cw.getCellAt(1, 1), CellWorldAction.Right), DELTA_THRESHOLD);
Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(1, 3),
cw.getCellAt(1, 1), CellWorldAction.Right), DELTA_THRESHOLD);
Assert.assertEquals(0.66, mdp.transitionProbability(gw.getCellAt(2, 2),
gw.getCellAt(1, 1), GridWorldAction.AddTile), DELTA_THRESHOLD);
}
@Test
public void testRewardFunction() {
// Ensure all actions can be performed in each cell.
for (Cell<Double> s : cw.getCells()) {
if (4 == s.getX() && 3 == s.getY()) {
for (GridCell<Double> s : gw.getCells()) {
if (6 == s.getX() && 10 == s.getY()) {
Assert.assertEquals(1.0, mdp.reward(s), DELTA_THRESHOLD);
} else if (4 == s.getX() && 2 == s.getY()) {
Assert.assertEquals(-1.0, mdp.reward(s), DELTA_THRESHOLD);
} else {
Assert.assertEquals(-0.04, mdp.reward(s), DELTA_THRESHOLD);
}

View File

@@ -1,15 +1,8 @@
package aima.core.probability.mdp;
import java.util.Map;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import aima.core.environment.cellworld.Cell;
import aima.core.environment.cellworld.CellWorld;
import aima.core.environment.cellworld.CellWorldAction;
import aima.core.environment.cellworld.CellWorldFactory;
import aima.core.environment.gridworld.GridCell;
import aima.core.environment.gridworld.GridWorld;
import aima.core.environment.gridworld.GridWorldAction;
@@ -18,7 +11,6 @@ import aima.core.probability.example.MDPFactory;
import aima.core.probability.mdp.MarkovDecisionProcess;
import aima.core.probability.mdp.impl.ModifiedPolicyEvaluation;
import aima.core.probability.mdp.search.PolicyIteration;
import aima.core.probability.mdp.search.ValueIteration;
/**
* @author Ravi Mohan
@@ -29,28 +21,31 @@ public class PolicyIterationTest {
public static final double DELTA_THRESHOLD = 1e-3;
private GridWorld<Double> gw = null;
private MarkovDecisionProcess<GridCell<Double>, GridWorldAction> mdp = null;
private MarkovDecisionProcess<GridCell<Double>, GridWorldAction> mdp = null;
private PolicyIteration<GridCell<Double>, GridWorldAction> pi = null;
final int maxTiles = 6;
final int maxScore = 10;
@Before
public void setUp() {
//take 10 turns to place 6 tiles
// take 10 turns to place 6 tiles
double defaultPenalty = -0.04;
gw = GridWorldFactory.createGridWorldForTileGame(maxTiles,maxScore,defaultPenalty);
gw = GridWorldFactory.createGridWorldForTileGame(maxTiles, maxScore,
defaultPenalty);
mdp = MDPFactory.createMDPForTileGame(gw, maxTiles, maxScore);
//gamma = 1.0
PolicyEvaluation<GridCell<Double>,GridWorldAction> pe = new ModifiedPolicyEvaluation<GridCell<Double>, GridWorldAction>(100,0.9);
// gamma = 1.0
PolicyEvaluation<GridCell<Double>, GridWorldAction> pe = new ModifiedPolicyEvaluation<GridCell<Double>, GridWorldAction>(
100, 0.9);
pi = new PolicyIteration<GridCell<Double>, GridWorldAction>(pe);
}
@Test
public void testPolicyIterationForTileGame() {
Policy<GridCell<Double>, GridWorldAction> policy = pi.policyIteration(mdp);
Policy<GridCell<Double>, GridWorldAction> policy = pi
.policyIteration(mdp);
for (int j = maxScore; j >= 1; j--) {
StringBuilder sb = new StringBuilder();
@@ -60,21 +55,5 @@ public class PolicyIterationTest {
}
System.out.println(sb.toString());
}
//Assert.assertEquals(0.705, U.get(gw.getCellAt(1, 1)), DELTA_THRESHOLD);
/*
Assert.assertEquals(0.762, U.get(cw1.getCellAt(1, 2)), DELTA_THRESHOLD);
Assert.assertEquals(0.812, U.get(cw1.getCellAt(1, 3)), DELTA_THRESHOLD);
Assert.assertEquals(0.655, U.get(cw1.getCellAt(2, 1)), DELTA_THRESHOLD);
Assert.assertEquals(0.868, U.get(cw1.getCellAt(2, 3)), DELTA_THRESHOLD);
Assert.assertEquals(0.611, U.get(cw1.getCellAt(3, 1)), DELTA_THRESHOLD);
Assert.assertEquals(0.660, U.get(cw1.getCellAt(3, 2)), DELTA_THRESHOLD);
Assert.assertEquals(0.918, U.get(cw1.getCellAt(3, 3)), DELTA_THRESHOLD);
Assert.assertEquals(0.388, U.get(cw1.getCellAt(4, 1)), DELTA_THRESHOLD);
Assert.assertEquals(-1.0, U.get(cw1.getCellAt(4, 2)), DELTA_THRESHOLD);
Assert.assertEquals(1.0, U.get(cw1.getCellAt(4, 3)), DELTA_THRESHOLD);*/
}
}

View File

@@ -1,64 +0,0 @@
package aima.core.probability.mdp;
import java.util.Map;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import aima.core.environment.cellworld.Cell;
import aima.core.environment.cellworld.CellWorld;
import aima.core.environment.cellworld.CellWorldAction;
import aima.core.environment.cellworld.CellWorldFactory;
import aima.core.probability.example.MDPFactory;
import aima.core.probability.mdp.MarkovDecisionProcess;
import aima.core.probability.mdp.search.ValueIteration;
/**
* @author Ravi Mohan
* @author Ciaran O'Reilly
*
*/
public class ValueIterationTest {
    // Tolerance for comparing computed utilities against the book values.
    public static final double DELTA_THRESHOLD = 1e-3;

    private CellWorld<Double> cw = null;
    private MarkovDecisionProcess<Cell<Double>, CellWorldAction> mdp = null;
    private ValueIteration<Cell<Double>, CellWorldAction> vi = null;

    @Before
    public void setUp() {
        // The 4x3 world of Fig 17.1 and its MDP; discount factor 1.0 as
        // used to produce the utilities of Fig 17.3.
        cw = CellWorldFactory.createCellWorldForFig17_1();
        mdp = MDPFactory.createMDPForFigure17_3(cw);
        vi = new ValueIteration<Cell<Double>, CellWorldAction>(1.0);
    }

    @Test
    public void testValueIterationForFig17_3() {
        // Expected utilities are the published values from AIMA3e Fig 17.3.
        Map<Cell<Double>, Double> U = vi.valueIteration(mdp, 0.0001);
        Assert.assertEquals(0.705, U.get(cw.getCellAt(1, 1)), DELTA_THRESHOLD);
        Assert.assertEquals(0.762, U.get(cw.getCellAt(1, 2)), DELTA_THRESHOLD);
        Assert.assertEquals(0.812, U.get(cw.getCellAt(1, 3)), DELTA_THRESHOLD);
        Assert.assertEquals(0.655, U.get(cw.getCellAt(2, 1)), DELTA_THRESHOLD);
        Assert.assertEquals(0.868, U.get(cw.getCellAt(2, 3)), DELTA_THRESHOLD);
        Assert.assertEquals(0.611, U.get(cw.getCellAt(3, 1)), DELTA_THRESHOLD);
        Assert.assertEquals(0.660, U.get(cw.getCellAt(3, 2)), DELTA_THRESHOLD);
        Assert.assertEquals(0.918, U.get(cw.getCellAt(3, 3)), DELTA_THRESHOLD);
        Assert.assertEquals(0.388, U.get(cw.getCellAt(4, 1)), DELTA_THRESHOLD);
        Assert.assertEquals(-1.0, U.get(cw.getCellAt(4, 2)), DELTA_THRESHOLD);
        Assert.assertEquals(1.0, U.get(cw.getCellAt(4, 3)), DELTA_THRESHOLD);
        // Print the utility grid row by row, top row (y=3) first.
        // (Cell (2, 2) is a wall, so its entry prints as null.)
        for (int j = 3; j >= 1; j--) {
            StringBuilder sb = new StringBuilder();
            for (int i = 1; i <= 4; i++) {
                sb.append(U.get(cw.getCellAt(i, j)));
                sb.append(" ");
            }
            System.out.println(sb.toString());
        }
    }
}

View File

@@ -6,10 +6,6 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import aima.core.environment.cellworld.Cell;
import aima.core.environment.cellworld.CellWorld;
import aima.core.environment.cellworld.CellWorldAction;
import aima.core.environment.cellworld.CellWorldFactory;
import aima.core.environment.gridworld.GridCell;
import aima.core.environment.gridworld.GridWorld;
import aima.core.environment.gridworld.GridWorldAction;
@@ -19,29 +15,30 @@ import aima.core.probability.mdp.MarkovDecisionProcess;
import aima.core.probability.mdp.search.ValueIteration;
/**
* @author Ravi Mohan
* @author Ciaran O'Reilly
*
* @author Woody
*
*/
public class ValueIterationTest2 {
public static final double DELTA_THRESHOLD = 1e-3;
private GridWorld<Double> gw = null;
private MarkovDecisionProcess<GridCell<Double>, GridWorldAction> mdp = null;
private MarkovDecisionProcess<GridCell<Double>, GridWorldAction> mdp = null;
private ValueIteration<GridCell<Double>, GridWorldAction> vi = null;
final int maxTiles = 6;
final int maxScore = 10;
@Before
public void setUp() {
//take 10 turns to place 6 tiles
// take 10 turns to place 6 tiles
double defaultPenalty = -0.04;
gw = GridWorldFactory.createGridWorldForTileGame(maxTiles,maxScore,defaultPenalty);
gw = GridWorldFactory.createGridWorldForTileGame(maxTiles, maxScore,
defaultPenalty);
mdp = MDPFactory.createMDPForTileGame(gw, maxTiles, maxScore);
//gamma = 1.0
// gamma = 1.0
vi = new ValueIteration<GridCell<Double>, GridWorldAction>(0.9);
}
@@ -57,20 +54,7 @@ public class ValueIterationTest2 {
}
System.out.println(sb.toString());
}
Assert.assertEquals(0.705, U.get(gw.getCellAt(1, 1)), DELTA_THRESHOLD);/*
Assert.assertEquals(0.762, U.get(cw1.getCellAt(1, 2)), DELTA_THRESHOLD);
Assert.assertEquals(0.812, U.get(cw1.getCellAt(1, 3)), DELTA_THRESHOLD);
Assert.assertEquals(0.655, U.get(cw1.getCellAt(2, 1)), DELTA_THRESHOLD);
Assert.assertEquals(0.868, U.get(cw1.getCellAt(2, 3)), DELTA_THRESHOLD);
Assert.assertEquals(0.611, U.get(cw1.getCellAt(3, 1)), DELTA_THRESHOLD);
Assert.assertEquals(0.660, U.get(cw1.getCellAt(3, 2)), DELTA_THRESHOLD);
Assert.assertEquals(0.918, U.get(cw1.getCellAt(3, 3)), DELTA_THRESHOLD);
Assert.assertEquals(0.388, U.get(cw1.getCellAt(4, 1)), DELTA_THRESHOLD);
Assert.assertEquals(-1.0, U.get(cw1.getCellAt(4, 2)), DELTA_THRESHOLD);
Assert.assertEquals(1.0, U.get(cw1.getCellAt(4, 3)), DELTA_THRESHOLD);*/
Assert.assertEquals(-0.1874236, U.get(gw.getCellAt(1, 1)), DELTA_THRESHOLD);
}
}

View File

@@ -1,26 +0,0 @@
package model.mdp;
import static org.junit.Assert.assertTrue;
import model.mdp.MDP.MODE;
import org.junit.Test;
public class ValueIterationSolverTest {
    @Test
    public void testSolve() {
        MDPSolver solver = new ValueIterationSolver();
        // Solve for a score of 6 in at most 10 turns. (The original comment
        // said "a score of 25 in at most 35 turns", which did not match the
        // values actually used below.)
        int maxScore = 6;
        int maxTurns = 10;
        MDP mdp = new MDP(maxScore,maxTurns,MODE.CEIL);
        Policy policy = solver.solve(mdp);
        // A valid policy must take at least one action per score point and
        // no more actions than there are turns.
        // NOTE(review): ValueIterationSolver.solve() still returns an empty
        // policy (policy extraction is a TODO), so this assertion is
        // expected to fail until that is implemented — confirm intent.
        assertTrue(policy.size() >= maxScore);
        assertTrue(policy.size() <= maxTurns);
        System.out.println("Policy: " + policy);
    }
}