Fixed unit tests, changed MDP generation to more reasonably seek the goal state, avoiding premature end of game.
Removed unused google-code classes. Regenerate policy when AdaptiveComPlayer.setGameGoal() is called.
This commit is contained in:
@@ -1,87 +0,0 @@
|
|||||||
package aima.core.environment.cellworld;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Artificial Intelligence A Modern Approach (3rd Edition): page 645.<br>
|
|
||||||
* <br>
|
|
||||||
* A representation of a Cell in the environment detailed in Figure 17.1.
|
|
||||||
*
|
|
||||||
* @param <C>
|
|
||||||
* the content type of the cell.
|
|
||||||
*
|
|
||||||
* @author Ciaran O'Reilly
|
|
||||||
* @author Ravi Mohan
|
|
||||||
*/
|
|
||||||
public class Cell<C> {
|
|
||||||
private int x = 1;
|
|
||||||
private int y = 1;
|
|
||||||
private C content = null;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Construct a Cell.
|
|
||||||
*
|
|
||||||
* @param x
|
|
||||||
* the x position of the cell.
|
|
||||||
* @param y
|
|
||||||
* the y position of the cell.
|
|
||||||
* @param content
|
|
||||||
* the initial content of the cell.
|
|
||||||
*/
|
|
||||||
public Cell(int x, int y, C content) {
|
|
||||||
this.x = x;
|
|
||||||
this.y = y;
|
|
||||||
this.content = content;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @return the x position of the cell.
|
|
||||||
*/
|
|
||||||
public int getX() {
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @return the y position of the cell.
|
|
||||||
*/
|
|
||||||
public int getY() {
|
|
||||||
return y;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @return the content of the cell.
|
|
||||||
*/
|
|
||||||
public C getContent() {
|
|
||||||
return content;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Set the cell's content.
|
|
||||||
*
|
|
||||||
* @param content
|
|
||||||
* the content to be placed in the cell.
|
|
||||||
*/
|
|
||||||
public void setContent(C content) {
|
|
||||||
this.content = content;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String toString() {
|
|
||||||
return "<x=" + x + ", y=" + y + ", content=" + content + ">";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean equals(Object o) {
|
|
||||||
if (o instanceof Cell<?>) {
|
|
||||||
Cell<?> c = (Cell<?>) o;
|
|
||||||
return x == c.x && y == c.y && content.equals(c.content);
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int hashCode() {
|
|
||||||
return x + 23 + y + 31 * content.hashCode();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,123 +0,0 @@
|
|||||||
package aima.core.environment.cellworld;
|
|
||||||
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.LinkedHashSet;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Set;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Artificial Intelligence A Modern Approach (3rd Edition): page 645.<br>
|
|
||||||
* <br>
|
|
||||||
*
|
|
||||||
* A representation for the environment depicted in figure 17.1.<br>
|
|
||||||
* <br>
|
|
||||||
* <b>Note:<b> the x and y coordinates are always positive integers starting at
|
|
||||||
* 1.<br>
|
|
||||||
* <b>Note:<b> If looking at a rectangle - the coordinate (x=1, y=1) will be the
|
|
||||||
* bottom left hand corner.<br>
|
|
||||||
*
|
|
||||||
*
|
|
||||||
* @param <C>
|
|
||||||
* the type of content for the Cells in the world.
|
|
||||||
*
|
|
||||||
* @author Ciaran O'Reilly
|
|
||||||
* @author Ravi Mohan
|
|
||||||
*/
|
|
||||||
public class CellWorld<C> {
|
|
||||||
private Set<Cell<C>> cells = new LinkedHashSet<Cell<C>>();
|
|
||||||
private Map<Integer, Map<Integer, Cell<C>>> cellLookup = new HashMap<Integer, Map<Integer, Cell<C>>>();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Construct a Cell World with size xDimension * y Dimension cells, all with
|
|
||||||
* their values set to a default content value.
|
|
||||||
*
|
|
||||||
* @param xDimension
|
|
||||||
* the size of the x dimension.
|
|
||||||
* @param yDimension
|
|
||||||
* the size of the y dimension.
|
|
||||||
*
|
|
||||||
* @param defaultCellContent
|
|
||||||
* the default content to assign to each cell created.
|
|
||||||
*/
|
|
||||||
public CellWorld(int xDimension, int yDimension, C defaultCellContent) {
|
|
||||||
for (int x = 1; x <= xDimension; x++) {
|
|
||||||
Map<Integer, Cell<C>> xCol = new HashMap<Integer, Cell<C>>();
|
|
||||||
for (int y = 1; y <= yDimension; y++) {
|
|
||||||
Cell<C> c = new Cell<C>(x, y, defaultCellContent);
|
|
||||||
cells.add(c);
|
|
||||||
xCol.put(y, c);
|
|
||||||
}
|
|
||||||
cellLookup.put(x, xCol);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @return all the cells in this world.
|
|
||||||
*/
|
|
||||||
public Set<Cell<C>> getCells() {
|
|
||||||
return cells;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Determine what cell would be moved into if the specified action is
|
|
||||||
* performed in the specified cell. Normally, this will be the cell adjacent
|
|
||||||
* in the appropriate direction. However, if there is no cell in the
|
|
||||||
* adjacent direction of the action then the outcome of the action is to
|
|
||||||
* stay in the same cell as the action was performed in.
|
|
||||||
*
|
|
||||||
* @param s
|
|
||||||
* the cell location from which the action is to be performed.
|
|
||||||
* @param a
|
|
||||||
* the action to perform (Up, Down, Left, or Right).
|
|
||||||
* @return the Cell an agent would end up in if they performed the specified
|
|
||||||
* action from the specified cell location.
|
|
||||||
*/
|
|
||||||
public Cell<C> result(Cell<C> s, CellWorldAction a) {
|
|
||||||
Cell<C> sDelta = getCellAt(a.getXResult(s.getX()), a.getYResult(s
|
|
||||||
.getY()));
|
|
||||||
if (null == sDelta) {
|
|
||||||
// Default to no effect
|
|
||||||
// (i.e. bumps back in place as no adjoining cell).
|
|
||||||
sDelta = s;
|
|
||||||
}
|
|
||||||
|
|
||||||
return sDelta;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Remove the cell at the specified location from this Cell World. This
|
|
||||||
* allows you to introduce barriers into different location.
|
|
||||||
*
|
|
||||||
* @param x
|
|
||||||
* the x dimension of the cell to be removed.
|
|
||||||
* @param y
|
|
||||||
* the y dimension of the cell to be removed.
|
|
||||||
*/
|
|
||||||
public void removeCell(int x, int y) {
|
|
||||||
Map<Integer, Cell<C>> xCol = cellLookup.get(x);
|
|
||||||
if (null != xCol) {
|
|
||||||
cells.remove(xCol.remove(y));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get the cell at the specified x and y locations.
|
|
||||||
*
|
|
||||||
* @param x
|
|
||||||
* the x dimension of the cell to be retrieved.
|
|
||||||
* @param y
|
|
||||||
* the y dimension of the cell to be retrieved.
|
|
||||||
* @return the cell at the specified x,y location, null if no cell exists at
|
|
||||||
* this location.
|
|
||||||
*/
|
|
||||||
public Cell<C> getCellAt(int x, int y) {
|
|
||||||
Cell<C> c = null;
|
|
||||||
Map<Integer, Cell<C>> xCol = cellLookup.get(x);
|
|
||||||
if (null != xCol) {
|
|
||||||
c = xCol.get(y);
|
|
||||||
}
|
|
||||||
|
|
||||||
return c;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,142 +0,0 @@
|
|||||||
package aima.core.environment.cellworld;
|
|
||||||
|
|
||||||
import java.util.LinkedHashSet;
|
|
||||||
import java.util.Set;
|
|
||||||
|
|
||||||
import aima.core.agent.Action;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Artificial Intelligence A Modern Approach (3rd Edition): page 645.<br>
|
|
||||||
* <br>
|
|
||||||
*
|
|
||||||
* The actions in every state are Up, Down, Left, and Right.<br>
|
|
||||||
* <br>
|
|
||||||
* <b>Note:<b> Moving 'North' causes y to increase by 1, 'Down' y to decrease by
|
|
||||||
* 1, 'Left' x to decrease by 1, and 'Right' x to increase by 1 within a Cell
|
|
||||||
* World.
|
|
||||||
*
|
|
||||||
* @author Ciaran O'Reilly
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
public enum CellWorldAction implements Action {
|
|
||||||
Up, Down, Left, Right, None;
|
|
||||||
|
|
||||||
private static final Set<CellWorldAction> _actions = new LinkedHashSet<CellWorldAction>();
|
|
||||||
static {
|
|
||||||
_actions.add(Up);
|
|
||||||
_actions.add(Down);
|
|
||||||
_actions.add(Left);
|
|
||||||
_actions.add(Right);
|
|
||||||
_actions.add(None);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @return a set of the actual actions.
|
|
||||||
*/
|
|
||||||
public static final Set<CellWorldAction> actions() {
|
|
||||||
return _actions;
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
|
||||||
// START-Action
|
|
||||||
//@Override
|
|
||||||
//public boolean isNoOp() {
|
|
||||||
// if (None == this) {
|
|
||||||
// return true;
|
|
||||||
// }
|
|
||||||
// return false;
|
|
||||||
//}
|
|
||||||
// END-Action
|
|
||||||
//
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param curX
|
|
||||||
* the current x position.
|
|
||||||
* @return the result on the x position of applying this action.
|
|
||||||
*/
|
|
||||||
public int getXResult(int curX) {
|
|
||||||
int newX = curX;
|
|
||||||
|
|
||||||
switch (this) {
|
|
||||||
case Left:
|
|
||||||
newX--;
|
|
||||||
break;
|
|
||||||
case Right:
|
|
||||||
newX++;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
return newX;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param curY
|
|
||||||
* the current y position.
|
|
||||||
* @return the result on the y position of applying this action.
|
|
||||||
*/
|
|
||||||
public int getYResult(int curY) {
|
|
||||||
int newY = curY;
|
|
||||||
|
|
||||||
switch (this) {
|
|
||||||
case Up:
|
|
||||||
newY++;
|
|
||||||
break;
|
|
||||||
case Down:
|
|
||||||
newY--;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
return newY;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @return the first right angled action related to this action.
|
|
||||||
*/
|
|
||||||
public CellWorldAction getFirstRightAngledAction() {
|
|
||||||
CellWorldAction a = null;
|
|
||||||
|
|
||||||
switch (this) {
|
|
||||||
case Up:
|
|
||||||
case Down:
|
|
||||||
a = Left;
|
|
||||||
break;
|
|
||||||
case Left:
|
|
||||||
case Right:
|
|
||||||
a = Down;
|
|
||||||
break;
|
|
||||||
case None:
|
|
||||||
a = None;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
return a;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @return the second right angled action related to this action.
|
|
||||||
*/
|
|
||||||
public CellWorldAction getSecondRightAngledAction() {
|
|
||||||
CellWorldAction a = null;
|
|
||||||
|
|
||||||
switch (this) {
|
|
||||||
case Up:
|
|
||||||
case Down:
|
|
||||||
a = Right;
|
|
||||||
break;
|
|
||||||
case Left:
|
|
||||||
case Right:
|
|
||||||
a = Up;
|
|
||||||
break;
|
|
||||||
case None:
|
|
||||||
a = None;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
return a;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,27 +0,0 @@
|
|||||||
package aima.core.environment.cellworld;
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @author Ciaran O'Reilly
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
public class CellWorldFactory {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create the cell world as defined in Figure 17.1 in AIMA3e. (a) A simple 4
|
|
||||||
* x 3 environment that presents the agent with a sequential decision
|
|
||||||
* problem.
|
|
||||||
*
|
|
||||||
* @return a cell world representation of Fig 17.1 in AIMA3e.
|
|
||||||
*/
|
|
||||||
public static CellWorld<Double> createCellWorldForFig17_1() {
|
|
||||||
CellWorld<Double> cw = new CellWorld<Double>(4, 3, -0.04);
|
|
||||||
|
|
||||||
cw.removeCell(2, 2);
|
|
||||||
|
|
||||||
cw.getCellAt(4, 3).setContent(1.0);
|
|
||||||
cw.getCellAt(4, 2).setContent(-1.0);
|
|
||||||
|
|
||||||
return cw;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -17,7 +17,12 @@ public class GridWorldFactory {
|
|||||||
GridWorld<Double> cw = new GridWorld<Double>(maxTiles, maxScore, nonTerminalReward);
|
GridWorld<Double> cw = new GridWorld<Double>(maxTiles, maxScore, nonTerminalReward);
|
||||||
|
|
||||||
cw.getCellAt(maxTiles, maxScore).setContent(1.0);
|
cw.getCellAt(maxTiles, maxScore).setContent(1.0);
|
||||||
|
for (int score = 1; score < maxScore; score++) {
|
||||||
|
cw.getCellAt(maxTiles, score).setContent(-0.2);
|
||||||
|
}
|
||||||
|
for (int tiles = 1; tiles < maxTiles; tiles++) {
|
||||||
|
cw.getCellAt(tiles, maxScore).setContent(-0.2);
|
||||||
|
}
|
||||||
return cw;
|
return cw;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -6,9 +6,6 @@ import java.util.HashSet;
|
|||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|
||||||
import aima.core.environment.cellworld.Cell;
|
|
||||||
import aima.core.environment.cellworld.CellWorld;
|
|
||||||
import aima.core.environment.cellworld.CellWorldAction;
|
|
||||||
import aima.core.environment.gridworld.GridCell;
|
import aima.core.environment.gridworld.GridCell;
|
||||||
import aima.core.environment.gridworld.GridWorld;
|
import aima.core.environment.gridworld.GridWorld;
|
||||||
import aima.core.environment.gridworld.GridWorldAction;
|
import aima.core.environment.gridworld.GridWorldAction;
|
||||||
@@ -19,30 +16,11 @@ import aima.core.probability.mdp.TransitionProbabilityFunction;
|
|||||||
import aima.core.probability.mdp.impl.MDP;
|
import aima.core.probability.mdp.impl.MDP;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
* Based on MDPFactory by Ciaran O'Reilly and Ravi Mohan.
|
||||||
* @author Ciaran O'Reilly
|
* @author Woody
|
||||||
* @author Ravi Mohan
|
|
||||||
*/
|
*/
|
||||||
public class MDPFactory {
|
public class MDPFactory {
|
||||||
|
|
||||||
/**
|
|
||||||
* Constructs an MDP that can be used to generate the utility values
|
|
||||||
* detailed in Fig 17.3.
|
|
||||||
*
|
|
||||||
* @param cw
|
|
||||||
* the cell world from figure 17.1.
|
|
||||||
* @return an MDP that can be used to generate the utility values detailed
|
|
||||||
* in Fig 17.3.
|
|
||||||
*/
|
|
||||||
public static MarkovDecisionProcess<Cell<Double>, CellWorldAction> createMDPForFigure17_3(
|
|
||||||
final CellWorld<Double> cw) {
|
|
||||||
|
|
||||||
return new MDP<Cell<Double>, CellWorldAction>(cw.getCells(),
|
|
||||||
cw.getCellAt(1, 1), createActionsFunctionForFigure17_1(cw),
|
|
||||||
createTransitionProbabilityFunctionForFigure17_1(cw),
|
|
||||||
createRewardFunctionForFigure17_1());
|
|
||||||
}
|
|
||||||
|
|
||||||
public static MarkovDecisionProcess<GridCell<Double>, GridWorldAction> createMDPForTileGame(
|
public static MarkovDecisionProcess<GridCell<Double>, GridWorldAction> createMDPForTileGame(
|
||||||
final GridWorld<Double> cw, int maxTiles, int maxScore) {
|
final GridWorld<Double> cw, int maxTiles, int maxScore) {
|
||||||
|
|
||||||
@@ -52,36 +30,6 @@ public class MDPFactory {
|
|||||||
createRewardFunctionForTileGame());
|
createRewardFunctionForTileGame());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns the allowed actions from a specified cell within the cell world
|
|
||||||
* described in Fig 17.1.
|
|
||||||
*
|
|
||||||
* @param cw
|
|
||||||
* the cell world from figure 17.1.
|
|
||||||
* @return the set of actions allowed at a particular cell. This set will be
|
|
||||||
* empty if at a terminal state.
|
|
||||||
*/
|
|
||||||
public static ActionsFunction<Cell<Double>, CellWorldAction> createActionsFunctionForFigure17_1(
|
|
||||||
final CellWorld<Double> cw) {
|
|
||||||
final Set<Cell<Double>> terminals = new HashSet<Cell<Double>>();
|
|
||||||
terminals.add(cw.getCellAt(4, 3));
|
|
||||||
terminals.add(cw.getCellAt(4, 2));
|
|
||||||
|
|
||||||
ActionsFunction<Cell<Double>, CellWorldAction> af = new ActionsFunction<Cell<Double>, CellWorldAction>() {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Set<CellWorldAction> actions(Cell<Double> s) {
|
|
||||||
// All actions can be performed in each cell
|
|
||||||
// (except terminal states)
|
|
||||||
if (terminals.contains(s)) {
|
|
||||||
return Collections.emptySet();
|
|
||||||
}
|
|
||||||
return CellWorldAction.actions();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
return af;
|
|
||||||
}
|
|
||||||
|
|
||||||
public static ActionsFunction<GridCell<Double>, GridWorldAction> createActionsFunctionForTileGame(
|
public static ActionsFunction<GridCell<Double>, GridWorldAction> createActionsFunctionForTileGame(
|
||||||
final GridWorld<Double> cw, int maxTiles, int maxScore) {
|
final GridWorld<Double> cw, int maxTiles, int maxScore) {
|
||||||
final Set<GridCell<Double>> terminals = new HashSet<GridCell<Double>>();
|
final Set<GridCell<Double>> terminals = new HashSet<GridCell<Double>>();
|
||||||
@@ -102,59 +50,6 @@ public class MDPFactory {
|
|||||||
|
|
||||||
return af;
|
return af;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Figure 17.1 (b) Illustration of the transition model of the environment:
|
|
||||||
* the 'intended' outcome occurs with probability 0.8, but with probability
|
|
||||||
* 0.2 the agent moves at right angles to the intended direction. A
|
|
||||||
* collision with a wall results in no movement.
|
|
||||||
*
|
|
||||||
* @param cw
|
|
||||||
* the cell world from figure 17.1.
|
|
||||||
* @return the transition probability function as described in figure 17.1.
|
|
||||||
*/
|
|
||||||
public static TransitionProbabilityFunction<Cell<Double>, CellWorldAction> createTransitionProbabilityFunctionForFigure17_1(
|
|
||||||
final CellWorld<Double> cw) {
|
|
||||||
TransitionProbabilityFunction<Cell<Double>, CellWorldAction> tf = new TransitionProbabilityFunction<Cell<Double>, CellWorldAction>() {
|
|
||||||
private double[] distribution = new double[] { 0.8, 0.1, 0.1 };
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double probability(Cell<Double> sDelta, Cell<Double> s,
|
|
||||||
CellWorldAction a) {
|
|
||||||
double prob = 0;
|
|
||||||
|
|
||||||
List<Cell<Double>> outcomes = possibleOutcomes(s, a);
|
|
||||||
for (int i = 0; i < outcomes.size(); i++) {
|
|
||||||
if (sDelta.equals(outcomes.get(i))) {
|
|
||||||
// Note: You have to sum the matches to
|
|
||||||
// sDelta as the different actions
|
|
||||||
// could have the same effect (i.e.
|
|
||||||
// staying in place due to there being
|
|
||||||
// no adjacent cells), which increases
|
|
||||||
// the probability of the transition for
|
|
||||||
// that state.
|
|
||||||
prob += distribution[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return prob;
|
|
||||||
}
|
|
||||||
|
|
||||||
private List<Cell<Double>> possibleOutcomes(Cell<Double> c,
|
|
||||||
CellWorldAction a) {
|
|
||||||
// There can be three possible outcomes for the planned action
|
|
||||||
List<Cell<Double>> outcomes = new ArrayList<Cell<Double>>();
|
|
||||||
|
|
||||||
outcomes.add(cw.result(c, a));
|
|
||||||
outcomes.add(cw.result(c, a.getFirstRightAngledAction()));
|
|
||||||
outcomes.add(cw.result(c, a.getSecondRightAngledAction()));
|
|
||||||
|
|
||||||
return outcomes;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
return tf;
|
|
||||||
}
|
|
||||||
|
|
||||||
public static TransitionProbabilityFunction<GridCell<Double>, GridWorldAction> createTransitionProbabilityFunctionForTileGame(
|
public static TransitionProbabilityFunction<GridCell<Double>, GridWorldAction> createTransitionProbabilityFunctionForTileGame(
|
||||||
final GridWorld<Double> cw) {
|
final GridWorld<Double> cw) {
|
||||||
@@ -170,13 +65,6 @@ public class MDPFactory {
|
|||||||
List<GridCell<Double>> outcomes = possibleOutcomes(s, a);
|
List<GridCell<Double>> outcomes = possibleOutcomes(s, a);
|
||||||
for (int i = 0; i < outcomes.size(); i++) {
|
for (int i = 0; i < outcomes.size(); i++) {
|
||||||
if (sDelta.equals(outcomes.get(i))) {
|
if (sDelta.equals(outcomes.get(i))) {
|
||||||
// Note: You have to sum the matches to
|
|
||||||
// sDelta as the different actions
|
|
||||||
// could have the same effect (i.e.
|
|
||||||
// staying in place due to there being
|
|
||||||
// no adjacent cells), which increases
|
|
||||||
// the probability of the transition for
|
|
||||||
// that state.
|
|
||||||
prob += distribution[i];
|
prob += distribution[i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -198,7 +86,7 @@ public class MDPFactory {
|
|||||||
}
|
}
|
||||||
private List<GridCell<Double>> possibleOutcomes(GridCell<Double> c,
|
private List<GridCell<Double>> possibleOutcomes(GridCell<Double> c,
|
||||||
GridWorldAction a) {
|
GridWorldAction a) {
|
||||||
// There can be three possible outcomes for the planned action
|
|
||||||
List<GridCell<Double>> outcomes = new ArrayList<GridCell<Double>>();
|
List<GridCell<Double>> outcomes = new ArrayList<GridCell<Double>>();
|
||||||
|
|
||||||
switch (a) {
|
switch (a) {
|
||||||
@@ -224,21 +112,6 @@ public class MDPFactory {
|
|||||||
return tf;
|
return tf;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @return the reward function which takes the content of the cell as being
|
|
||||||
* the reward value.
|
|
||||||
*/
|
|
||||||
public static RewardFunction<Cell<Double>> createRewardFunctionForFigure17_1() {
|
|
||||||
RewardFunction<Cell<Double>> rf = new RewardFunction<Cell<Double>>() {
|
|
||||||
@Override
|
|
||||||
public double reward(Cell<Double> s) {
|
|
||||||
return s.getContent();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
return rf;
|
|
||||||
}
|
|
||||||
|
|
||||||
public static RewardFunction<GridCell<Double>> createRewardFunctionForTileGame() {
|
public static RewardFunction<GridCell<Double>> createRewardFunctionForTileGame() {
|
||||||
RewardFunction<GridCell<Double>> rf = new RewardFunction<GridCell<Double>>() {
|
RewardFunction<GridCell<Double>> rf = new RewardFunction<GridCell<Double>>() {
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@@ -119,5 +119,6 @@ public class AdaptiveComPlayer implements Player {
|
|||||||
@Override
|
@Override
|
||||||
public void setGameGoal(GameGoal target) {
|
public void setGameGoal(GameGoal target) {
|
||||||
this.target = target;
|
this.target = target;
|
||||||
|
this.calculatePolicy = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1,18 +0,0 @@
|
|||||||
package model.mdp;
|
|
||||||
|
|
||||||
/**
 * A named action. The two shared instances below are constants: they are
 * declared final so that code comparing against them by reference cannot be
 * broken by a reassignment.
 */
public class Action {
    /** Shared action instance named "PlayToWin". */
    public static final Action playToWin = new Action("PlayToWin");
    /** Shared action instance named "PlayToLose". */
    public static final Action playToLose = new Action("PlayToLose");

    private final String name;

    /**
     * Create an action.
     *
     * @param name
     *            the display name, returned by {@link #toString()}.
     */
    public Action(String name) {
        this.name = name;
    }

    /**
     * @return the name given at construction.
     */
    @Override
    public String toString() {
        return name;
    }
}
|
|
||||||
@@ -1,51 +0,0 @@
|
|||||||
package model.mdp;
|
|
||||||
|
|
||||||
public class MDP {
|
|
||||||
public static final double nonTerminalReward = -0.25;
|
|
||||||
|
|
||||||
public enum MODE {
|
|
||||||
CEIL, FLOOR
|
|
||||||
}
|
|
||||||
|
|
||||||
private final int maxScore;
|
|
||||||
private final int maxTiles;
|
|
||||||
private final MODE mode;
|
|
||||||
|
|
||||||
public MDP(int maxScore, int maxTiles, MODE mode) {
|
|
||||||
this.maxScore = maxScore;
|
|
||||||
this.maxTiles = maxTiles;
|
|
||||||
this.mode = mode;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Action[] getActions(int i, int j) {
|
|
||||||
if (i == maxScore) {
|
|
||||||
return new Action[0];
|
|
||||||
}
|
|
||||||
if (j == maxTiles) {
|
|
||||||
return new Action[0];
|
|
||||||
}
|
|
||||||
return new Action[]{Action.playToLose,Action.playToWin};
|
|
||||||
}
|
|
||||||
|
|
||||||
public int getMaxScore() {
|
|
||||||
return maxScore;
|
|
||||||
}
|
|
||||||
|
|
||||||
public int getMaxTiles() {
|
|
||||||
return maxTiles;
|
|
||||||
}
|
|
||||||
|
|
||||||
public double getReward(int score, int tiles) {
|
|
||||||
if (score == maxScore && tiles == maxTiles) {
|
|
||||||
return 10.0;
|
|
||||||
}
|
|
||||||
// TODO scale linearly?
|
|
||||||
if (score == maxScore) {
|
|
||||||
return -1.0;
|
|
||||||
}
|
|
||||||
if (tiles == maxTiles) {
|
|
||||||
return -5.0;
|
|
||||||
}
|
|
||||||
return nonTerminalReward;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
package model.mdp;
|
|
||||||
|
|
||||||
public interface MDPSolver {
|
|
||||||
Policy solve(MDP mdp);
|
|
||||||
}
|
|
||||||
@@ -1,7 +0,0 @@
|
|||||||
package model.mdp;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
|
|
||||||
public class Policy extends ArrayList<Action>{
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -1,34 +0,0 @@
|
|||||||
package model.mdp;
|
|
||||||
|
|
||||||
/**
 * One possible outcome of an action: a probability together with the change
 * it causes to the score and to the tile count.
 */
public class Transition {
    private double prob;
    private int scoreChange;
    private int tileCountChange;

    /**
     * @param prob
     *            the probability of this outcome.
     * @param scoreChange
     *            the change applied to the score.
     * @param tileCountChange
     *            the change applied to the tile count.
     */
    public Transition(double prob, int scoreChange, int tileCountChange) {
        this.prob = prob;
        this.scoreChange = scoreChange;
        this.tileCountChange = tileCountChange;
    }

    /** @return the probability of this outcome. */
    public double getProb() {
        return this.prob;
    }

    /**
     * @param prob
     *            the new probability of this outcome.
     */
    public void setProb(double prob) {
        this.prob = prob;
    }

    /** @return the change applied to the score. */
    public int getScoreChange() {
        return this.scoreChange;
    }

    /**
     * @param scoreChange
     *            the new change applied to the score.
     */
    public void setScoreChange(int scoreChange) {
        this.scoreChange = scoreChange;
    }

    /** @return the change applied to the tile count. */
    public int getTileCountChange() {
        return this.tileCountChange;
    }

    /**
     * @param tileCountChange
     *            the new change applied to the tile count.
     */
    public void setTileCountChange(int tileCountChange) {
        this.tileCountChange = tileCountChange;
    }
}
|
|
||||||
@@ -1,110 +0,0 @@
|
|||||||
package model.mdp;
|
|
||||||
|
|
||||||
import java.text.DecimalFormat;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
public class ValueIterationSolver implements MDPSolver {
|
|
||||||
public int maxIterations = 10;
|
|
||||||
public final double DEFAULT_EPS = 0.1;
|
|
||||||
public final double GAMMA = 0.9; //discount
|
|
||||||
|
|
||||||
private DecimalFormat fmt = new DecimalFormat("##.00");
|
|
||||||
public Policy solve(MDP mdp) {
|
|
||||||
Policy policy = new Policy();
|
|
||||||
|
|
||||||
double[][] utility = new double[mdp.getMaxScore()+1][mdp.getMaxTiles()+1];
|
|
||||||
double[][] utilityPrime = new double[mdp.getMaxScore()+1][mdp.getMaxTiles()+1];
|
|
||||||
|
|
||||||
for (int i = 0; i <= mdp.getMaxScore(); i++) {
|
|
||||||
//StringBuilder sb = new StringBuilder();
|
|
||||||
for (int j = 0; j <= mdp.getMaxTiles(); j++) {
|
|
||||||
utilityPrime[i][j] = mdp.getReward(i, j);
|
|
||||||
//sb.append(fmt.format(utility[i][j]));
|
|
||||||
//sb.append(" ");
|
|
||||||
}
|
|
||||||
//System.out.println(sb);
|
|
||||||
}
|
|
||||||
|
|
||||||
converged:
|
|
||||||
for (int iteration = 0; iteration < maxIterations; iteration++) {
|
|
||||||
for (int i = 0; i <= mdp.getMaxScore(); i++) {
|
|
||||||
for (int j = 0; j <= mdp.getMaxTiles(); j++) {
|
|
||||||
utility[i][j] = utilityPrime[i][j];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (int i = 0; i <= mdp.getMaxScore(); i++) {
|
|
||||||
for (int j = 0; j <= mdp.getMaxTiles(); j++) {
|
|
||||||
Action[] actions = mdp.getActions(i,j);
|
|
||||||
|
|
||||||
double aMax;
|
|
||||||
if (actions.length > 0) {
|
|
||||||
aMax = Double.NEGATIVE_INFINITY;
|
|
||||||
} else {
|
|
||||||
aMax = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (Action action : actions){
|
|
||||||
List<Transition> transitions = getTransitions(action,mdp,i,j);
|
|
||||||
double aSum = 0.0;
|
|
||||||
for (Transition transition : transitions) {
|
|
||||||
int transI = transition.getScoreChange();
|
|
||||||
int transJ = transition.getTileCountChange();
|
|
||||||
if (i+transI >= 0 && i+transI <= mdp.getMaxScore()
|
|
||||||
&& j+transJ >= 0 && j+transJ <= mdp.getMaxTiles())
|
|
||||||
aSum += utility[i+transI][j+transJ];
|
|
||||||
}
|
|
||||||
if (aSum > aMax) {
|
|
||||||
aMax = aSum;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
utilityPrime[i][j] = mdp.getReward(i,j) + GAMMA * aMax;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
double maxDiff = getMaxDiff(utility,utilityPrime);
|
|
||||||
System.out.println("Max diff |U - U'| = " + maxDiff);
|
|
||||||
if (maxDiff < DEFAULT_EPS) {
|
|
||||||
System.out.println("Solution to MDP converged: " + maxDiff);
|
|
||||||
break converged;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int i = 0; i < utility.length; i++) {
|
|
||||||
StringBuilder sb = new StringBuilder();
|
|
||||||
for (int j = 0; j < utility[i].length; j++) {
|
|
||||||
sb.append(fmt.format(utility[i][j]));
|
|
||||||
sb.append(" ");
|
|
||||||
}
|
|
||||||
System.out.println(sb);
|
|
||||||
}
|
|
||||||
|
|
||||||
//utility is now the utility Matrix
|
|
||||||
//get the policy
|
|
||||||
return policy;
|
|
||||||
}
|
|
||||||
|
|
||||||
double getMaxDiff(double[][]u, double[][]uPrime) {
|
|
||||||
double maxDiff = 0;
|
|
||||||
for (int i = 0; i < u.length; i++) {
|
|
||||||
for (int j = 0; j < u[i].length; j++) {
|
|
||||||
maxDiff = Math.max(maxDiff,Math.abs(u[i][j] - uPrime[i][j]));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return maxDiff;
|
|
||||||
}
|
|
||||||
|
|
||||||
private List<Transition> getTransitions(Action action, MDP mdp, int score, int tiles) {
|
|
||||||
List<Transition> transitions = new ArrayList<Transition>();
|
|
||||||
if (Action.playToWin == action) {
|
|
||||||
transitions.add(new Transition(0.9,1,1));
|
|
||||||
transitions.add(new Transition(0.1,1,-3));
|
|
||||||
} else if (Action.playToLose == action) {
|
|
||||||
transitions.add(new Transition(0.9,1,1));
|
|
||||||
transitions.add(new Transition(0.1,1,-3));
|
|
||||||
} /*else if (Action.maintainScore == action) {
|
|
||||||
transitions.add(new Transition(0.5,1,1));
|
|
||||||
transitions.add(new Transition(0.5,1,-3));
|
|
||||||
}*/
|
|
||||||
return transitions;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
BIN
test/PlayerModel.dat
Normal file
BIN
test/PlayerModel.dat
Normal file
Binary file not shown.
@@ -5,91 +5,57 @@ import junit.framework.Assert;
|
|||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
import aima.core.environment.cellworld.Cell;
|
import aima.core.environment.gridworld.GridCell;
|
||||||
import aima.core.environment.cellworld.CellWorld;
|
import aima.core.environment.gridworld.GridWorld;
|
||||||
import aima.core.environment.cellworld.CellWorldAction;
|
import aima.core.environment.gridworld.GridWorldAction;
|
||||||
import aima.core.environment.cellworld.CellWorldFactory;
|
import aima.core.environment.gridworld.GridWorldFactory;
|
||||||
import aima.core.probability.example.MDPFactory;
|
import aima.core.probability.example.MDPFactory;
|
||||||
import aima.core.probability.mdp.MarkovDecisionProcess;
|
import aima.core.probability.mdp.MarkovDecisionProcess;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
* Based on MarkovDecisionProcessTest by Ciaran O'Reilly and Ravi Mohan. Used under MIT license.
|
||||||
* @author Ciaran O'Reilly
|
|
||||||
* @author Ravi Mohan
|
|
||||||
*
|
|
||||||
*/
|
*/
|
||||||
public class MarkovDecisionProcessTest {
|
public class MarkovDecisionProcessTest {
|
||||||
public static final double DELTA_THRESHOLD = 1e-3;
|
public static final double DELTA_THRESHOLD = 1e-3;
|
||||||
|
|
||||||
private CellWorld<Double> cw = null;
|
private double nonTerminalReward = -0.04;
|
||||||
private MarkovDecisionProcess<Cell<Double>, CellWorldAction> mdp = null;
|
private GridWorld<Double> gw = null;
|
||||||
|
private MarkovDecisionProcess<GridCell<Double>, GridWorldAction> mdp = null;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
cw = CellWorldFactory.createCellWorldForFig17_1();
|
int maxTiles = 6;
|
||||||
mdp = MDPFactory.createMDPForFigure17_3(cw);
|
int maxScore = 10;
|
||||||
|
|
||||||
|
gw = GridWorldFactory.createGridWorldForTileGame(maxTiles, maxScore, nonTerminalReward);
|
||||||
|
mdp = MDPFactory.createMDPForTileGame(gw, maxTiles, maxScore);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testActions() {
|
public void testActions() {
|
||||||
// Ensure all actions can be performed in each cell
|
// Ensure all actions can be performed in each cell
|
||||||
// except for the terminal states.
|
// except for the terminal states.
|
||||||
for (Cell<Double> s : cw.getCells()) {
|
for (GridCell<Double> s : gw.getCells()) {
|
||||||
if (4 == s.getX() && (3 == s.getY() || 2 == s.getY())) {
|
if (6 == s.getX() && 10 == s.getY()) {
|
||||||
Assert.assertEquals(0, mdp.actions(s).size());
|
Assert.assertEquals(0, mdp.actions(s).size());
|
||||||
} else {
|
} else {
|
||||||
Assert.assertEquals(5, mdp.actions(s).size());
|
Assert.assertEquals(3, mdp.actions(s).size());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testMDPTransitionModel() {
|
public void testMDPTransitionModel() {
|
||||||
Assert.assertEquals(0.8, mdp.transitionProbability(cw.getCellAt(1, 2),
|
Assert.assertEquals(0.66, mdp.transitionProbability(gw.getCellAt(2, 2),
|
||||||
cw.getCellAt(1, 1), CellWorldAction.Up), DELTA_THRESHOLD);
|
gw.getCellAt(1, 1), GridWorldAction.AddTile), DELTA_THRESHOLD);
|
||||||
Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(1, 1),
|
|
||||||
cw.getCellAt(1, 1), CellWorldAction.Up), DELTA_THRESHOLD);
|
|
||||||
Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(2, 1),
|
|
||||||
cw.getCellAt(1, 1), CellWorldAction.Up), DELTA_THRESHOLD);
|
|
||||||
Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(1, 3),
|
|
||||||
cw.getCellAt(1, 1), CellWorldAction.Up), DELTA_THRESHOLD);
|
|
||||||
|
|
||||||
Assert.assertEquals(0.9, mdp.transitionProbability(cw.getCellAt(1, 1),
|
|
||||||
cw.getCellAt(1, 1), CellWorldAction.Down), DELTA_THRESHOLD);
|
|
||||||
Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(2, 1),
|
|
||||||
cw.getCellAt(1, 1), CellWorldAction.Down), DELTA_THRESHOLD);
|
|
||||||
Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(3, 1),
|
|
||||||
cw.getCellAt(1, 1), CellWorldAction.Down), DELTA_THRESHOLD);
|
|
||||||
Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(1, 2),
|
|
||||||
cw.getCellAt(1, 1), CellWorldAction.Down), DELTA_THRESHOLD);
|
|
||||||
|
|
||||||
Assert.assertEquals(0.9, mdp.transitionProbability(cw.getCellAt(1, 1),
|
|
||||||
cw.getCellAt(1, 1), CellWorldAction.Left), DELTA_THRESHOLD);
|
|
||||||
Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(2, 1),
|
|
||||||
cw.getCellAt(1, 1), CellWorldAction.Left), DELTA_THRESHOLD);
|
|
||||||
Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(3, 1),
|
|
||||||
cw.getCellAt(1, 1), CellWorldAction.Left), DELTA_THRESHOLD);
|
|
||||||
Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(1, 2),
|
|
||||||
cw.getCellAt(1, 1), CellWorldAction.Left), DELTA_THRESHOLD);
|
|
||||||
|
|
||||||
Assert.assertEquals(0.8, mdp.transitionProbability(cw.getCellAt(2, 1),
|
|
||||||
cw.getCellAt(1, 1), CellWorldAction.Right), DELTA_THRESHOLD);
|
|
||||||
Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(1, 1),
|
|
||||||
cw.getCellAt(1, 1), CellWorldAction.Right), DELTA_THRESHOLD);
|
|
||||||
Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(1, 2),
|
|
||||||
cw.getCellAt(1, 1), CellWorldAction.Right), DELTA_THRESHOLD);
|
|
||||||
Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(1, 3),
|
|
||||||
cw.getCellAt(1, 1), CellWorldAction.Right), DELTA_THRESHOLD);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRewardFunction() {
|
public void testRewardFunction() {
|
||||||
// Ensure all actions can be performed in each cell.
|
// Ensure each cell yields the expected reward: 1.0 for the goal cell, -0.04 for every other cell.
|
||||||
for (Cell<Double> s : cw.getCells()) {
|
for (GridCell<Double> s : gw.getCells()) {
|
||||||
if (4 == s.getX() && 3 == s.getY()) {
|
if (6 == s.getX() && 10 == s.getY()) {
|
||||||
Assert.assertEquals(1.0, mdp.reward(s), DELTA_THRESHOLD);
|
Assert.assertEquals(1.0, mdp.reward(s), DELTA_THRESHOLD);
|
||||||
} else if (4 == s.getX() && 2 == s.getY()) {
|
|
||||||
Assert.assertEquals(-1.0, mdp.reward(s), DELTA_THRESHOLD);
|
|
||||||
} else {
|
} else {
|
||||||
Assert.assertEquals(-0.04, mdp.reward(s), DELTA_THRESHOLD);
|
Assert.assertEquals(-0.04, mdp.reward(s), DELTA_THRESHOLD);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,15 +1,8 @@
|
|||||||
package aima.core.probability.mdp;
|
package aima.core.probability.mdp;
|
||||||
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
import org.junit.Assert;
|
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
import aima.core.environment.cellworld.Cell;
|
|
||||||
import aima.core.environment.cellworld.CellWorld;
|
|
||||||
import aima.core.environment.cellworld.CellWorldAction;
|
|
||||||
import aima.core.environment.cellworld.CellWorldFactory;
|
|
||||||
import aima.core.environment.gridworld.GridCell;
|
import aima.core.environment.gridworld.GridCell;
|
||||||
import aima.core.environment.gridworld.GridWorld;
|
import aima.core.environment.gridworld.GridWorld;
|
||||||
import aima.core.environment.gridworld.GridWorldAction;
|
import aima.core.environment.gridworld.GridWorldAction;
|
||||||
@@ -18,7 +11,6 @@ import aima.core.probability.example.MDPFactory;
|
|||||||
import aima.core.probability.mdp.MarkovDecisionProcess;
|
import aima.core.probability.mdp.MarkovDecisionProcess;
|
||||||
import aima.core.probability.mdp.impl.ModifiedPolicyEvaluation;
|
import aima.core.probability.mdp.impl.ModifiedPolicyEvaluation;
|
||||||
import aima.core.probability.mdp.search.PolicyIteration;
|
import aima.core.probability.mdp.search.PolicyIteration;
|
||||||
import aima.core.probability.mdp.search.ValueIteration;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author Ravi Mohan
|
* @author Ravi Mohan
|
||||||
@@ -29,28 +21,31 @@ public class PolicyIterationTest {
|
|||||||
public static final double DELTA_THRESHOLD = 1e-3;
|
public static final double DELTA_THRESHOLD = 1e-3;
|
||||||
|
|
||||||
private GridWorld<Double> gw = null;
|
private GridWorld<Double> gw = null;
|
||||||
private MarkovDecisionProcess<GridCell<Double>, GridWorldAction> mdp = null;
|
private MarkovDecisionProcess<GridCell<Double>, GridWorldAction> mdp = null;
|
||||||
private PolicyIteration<GridCell<Double>, GridWorldAction> pi = null;
|
private PolicyIteration<GridCell<Double>, GridWorldAction> pi = null;
|
||||||
|
|
||||||
final int maxTiles = 6;
|
final int maxTiles = 6;
|
||||||
final int maxScore = 10;
|
final int maxScore = 10;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
//take 10 turns to place 6 tiles
|
// take 10 turns to place 6 tiles
|
||||||
double defaultPenalty = -0.04;
|
double defaultPenalty = -0.04;
|
||||||
|
|
||||||
gw = GridWorldFactory.createGridWorldForTileGame(maxTiles,maxScore,defaultPenalty);
|
gw = GridWorldFactory.createGridWorldForTileGame(maxTiles, maxScore,
|
||||||
|
defaultPenalty);
|
||||||
mdp = MDPFactory.createMDPForTileGame(gw, maxTiles, maxScore);
|
mdp = MDPFactory.createMDPForTileGame(gw, maxTiles, maxScore);
|
||||||
|
|
||||||
//gamma = 1.0
|
// gamma = 0.9
|
||||||
PolicyEvaluation<GridCell<Double>,GridWorldAction> pe = new ModifiedPolicyEvaluation<GridCell<Double>, GridWorldAction>(100,0.9);
|
PolicyEvaluation<GridCell<Double>, GridWorldAction> pe = new ModifiedPolicyEvaluation<GridCell<Double>, GridWorldAction>(
|
||||||
|
100, 0.9);
|
||||||
pi = new PolicyIteration<GridCell<Double>, GridWorldAction>(pe);
|
pi = new PolicyIteration<GridCell<Double>, GridWorldAction>(pe);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testPolicyIterationForTileGame() {
|
public void testPolicyIterationForTileGame() {
|
||||||
Policy<GridCell<Double>, GridWorldAction> policy = pi.policyIteration(mdp);
|
Policy<GridCell<Double>, GridWorldAction> policy = pi
|
||||||
|
.policyIteration(mdp);
|
||||||
|
|
||||||
for (int j = maxScore; j >= 1; j--) {
|
for (int j = maxScore; j >= 1; j--) {
|
||||||
StringBuilder sb = new StringBuilder();
|
StringBuilder sb = new StringBuilder();
|
||||||
@@ -60,21 +55,5 @@ public class PolicyIterationTest {
|
|||||||
}
|
}
|
||||||
System.out.println(sb.toString());
|
System.out.println(sb.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
//Assert.assertEquals(0.705, U.get(gw.getCellAt(1, 1)), DELTA_THRESHOLD);
|
|
||||||
/*
|
|
||||||
Assert.assertEquals(0.762, U.get(cw1.getCellAt(1, 2)), DELTA_THRESHOLD);
|
|
||||||
Assert.assertEquals(0.812, U.get(cw1.getCellAt(1, 3)), DELTA_THRESHOLD);
|
|
||||||
|
|
||||||
Assert.assertEquals(0.655, U.get(cw1.getCellAt(2, 1)), DELTA_THRESHOLD);
|
|
||||||
Assert.assertEquals(0.868, U.get(cw1.getCellAt(2, 3)), DELTA_THRESHOLD);
|
|
||||||
|
|
||||||
Assert.assertEquals(0.611, U.get(cw1.getCellAt(3, 1)), DELTA_THRESHOLD);
|
|
||||||
Assert.assertEquals(0.660, U.get(cw1.getCellAt(3, 2)), DELTA_THRESHOLD);
|
|
||||||
Assert.assertEquals(0.918, U.get(cw1.getCellAt(3, 3)), DELTA_THRESHOLD);
|
|
||||||
|
|
||||||
Assert.assertEquals(0.388, U.get(cw1.getCellAt(4, 1)), DELTA_THRESHOLD);
|
|
||||||
Assert.assertEquals(-1.0, U.get(cw1.getCellAt(4, 2)), DELTA_THRESHOLD);
|
|
||||||
Assert.assertEquals(1.0, U.get(cw1.getCellAt(4, 3)), DELTA_THRESHOLD);*/
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,64 +0,0 @@
|
|||||||
package aima.core.probability.mdp;
|
|
||||||
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
import org.junit.Assert;
|
|
||||||
import org.junit.Before;
|
|
||||||
import org.junit.Test;
|
|
||||||
|
|
||||||
import aima.core.environment.cellworld.Cell;
|
|
||||||
import aima.core.environment.cellworld.CellWorld;
|
|
||||||
import aima.core.environment.cellworld.CellWorldAction;
|
|
||||||
import aima.core.environment.cellworld.CellWorldFactory;
|
|
||||||
import aima.core.probability.example.MDPFactory;
|
|
||||||
import aima.core.probability.mdp.MarkovDecisionProcess;
|
|
||||||
import aima.core.probability.mdp.search.ValueIteration;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Ravi Mohan
|
|
||||||
* @author Ciaran O'Reilly
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
public class ValueIterationTest {
|
|
||||||
public static final double DELTA_THRESHOLD = 1e-3;
|
|
||||||
|
|
||||||
private CellWorld<Double> cw = null;
|
|
||||||
private MarkovDecisionProcess<Cell<Double>, CellWorldAction> mdp = null;
|
|
||||||
private ValueIteration<Cell<Double>, CellWorldAction> vi = null;
|
|
||||||
|
|
||||||
@Before
|
|
||||||
public void setUp() {
|
|
||||||
cw = CellWorldFactory.createCellWorldForFig17_1();
|
|
||||||
mdp = MDPFactory.createMDPForFigure17_3(cw);
|
|
||||||
vi = new ValueIteration<Cell<Double>, CellWorldAction>(1.0);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testValueIterationForFig17_3() {
|
|
||||||
Map<Cell<Double>, Double> U = vi.valueIteration(mdp, 0.0001);
|
|
||||||
|
|
||||||
Assert.assertEquals(0.705, U.get(cw.getCellAt(1, 1)), DELTA_THRESHOLD);
|
|
||||||
Assert.assertEquals(0.762, U.get(cw.getCellAt(1, 2)), DELTA_THRESHOLD);
|
|
||||||
Assert.assertEquals(0.812, U.get(cw.getCellAt(1, 3)), DELTA_THRESHOLD);
|
|
||||||
|
|
||||||
Assert.assertEquals(0.655, U.get(cw.getCellAt(2, 1)), DELTA_THRESHOLD);
|
|
||||||
Assert.assertEquals(0.868, U.get(cw.getCellAt(2, 3)), DELTA_THRESHOLD);
|
|
||||||
|
|
||||||
Assert.assertEquals(0.611, U.get(cw.getCellAt(3, 1)), DELTA_THRESHOLD);
|
|
||||||
Assert.assertEquals(0.660, U.get(cw.getCellAt(3, 2)), DELTA_THRESHOLD);
|
|
||||||
Assert.assertEquals(0.918, U.get(cw.getCellAt(3, 3)), DELTA_THRESHOLD);
|
|
||||||
|
|
||||||
Assert.assertEquals(0.388, U.get(cw.getCellAt(4, 1)), DELTA_THRESHOLD);
|
|
||||||
Assert.assertEquals(-1.0, U.get(cw.getCellAt(4, 2)), DELTA_THRESHOLD);
|
|
||||||
Assert.assertEquals(1.0, U.get(cw.getCellAt(4, 3)), DELTA_THRESHOLD);
|
|
||||||
|
|
||||||
for (int j = 3; j >= 1; j--) {
|
|
||||||
StringBuilder sb = new StringBuilder();
|
|
||||||
for (int i = 1; i <= 4; i++) {
|
|
||||||
sb.append(U.get(cw.getCellAt(i, j)));
|
|
||||||
sb.append(" ");
|
|
||||||
}
|
|
||||||
System.out.println(sb.toString());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -6,10 +6,6 @@ import org.junit.Assert;
|
|||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
import aima.core.environment.cellworld.Cell;
|
|
||||||
import aima.core.environment.cellworld.CellWorld;
|
|
||||||
import aima.core.environment.cellworld.CellWorldAction;
|
|
||||||
import aima.core.environment.cellworld.CellWorldFactory;
|
|
||||||
import aima.core.environment.gridworld.GridCell;
|
import aima.core.environment.gridworld.GridCell;
|
||||||
import aima.core.environment.gridworld.GridWorld;
|
import aima.core.environment.gridworld.GridWorld;
|
||||||
import aima.core.environment.gridworld.GridWorldAction;
|
import aima.core.environment.gridworld.GridWorldAction;
|
||||||
@@ -19,29 +15,30 @@ import aima.core.probability.mdp.MarkovDecisionProcess;
|
|||||||
import aima.core.probability.mdp.search.ValueIteration;
|
import aima.core.probability.mdp.search.ValueIteration;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author Ravi Mohan
|
|
||||||
* @author Ciaran O'Reilly
|
|
||||||
*
|
*
|
||||||
|
* @author Woody
|
||||||
|
*
|
||||||
*/
|
*/
|
||||||
public class ValueIterationTest2 {
|
public class ValueIterationTest2 {
|
||||||
public static final double DELTA_THRESHOLD = 1e-3;
|
public static final double DELTA_THRESHOLD = 1e-3;
|
||||||
|
|
||||||
private GridWorld<Double> gw = null;
|
private GridWorld<Double> gw = null;
|
||||||
private MarkovDecisionProcess<GridCell<Double>, GridWorldAction> mdp = null;
|
private MarkovDecisionProcess<GridCell<Double>, GridWorldAction> mdp = null;
|
||||||
private ValueIteration<GridCell<Double>, GridWorldAction> vi = null;
|
private ValueIteration<GridCell<Double>, GridWorldAction> vi = null;
|
||||||
|
|
||||||
final int maxTiles = 6;
|
final int maxTiles = 6;
|
||||||
final int maxScore = 10;
|
final int maxScore = 10;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
//take 10 turns to place 6 tiles
|
// take 10 turns to place 6 tiles
|
||||||
double defaultPenalty = -0.04;
|
double defaultPenalty = -0.04;
|
||||||
|
|
||||||
gw = GridWorldFactory.createGridWorldForTileGame(maxTiles,maxScore,defaultPenalty);
|
gw = GridWorldFactory.createGridWorldForTileGame(maxTiles, maxScore,
|
||||||
|
defaultPenalty);
|
||||||
mdp = MDPFactory.createMDPForTileGame(gw, maxTiles, maxScore);
|
mdp = MDPFactory.createMDPForTileGame(gw, maxTiles, maxScore);
|
||||||
|
|
||||||
//gamma = 1.0
|
// gamma = 0.9
|
||||||
vi = new ValueIteration<GridCell<Double>, GridWorldAction>(0.9);
|
vi = new ValueIteration<GridCell<Double>, GridWorldAction>(0.9);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -57,20 +54,7 @@ public class ValueIterationTest2 {
|
|||||||
}
|
}
|
||||||
System.out.println(sb.toString());
|
System.out.println(sb.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
Assert.assertEquals(0.705, U.get(gw.getCellAt(1, 1)), DELTA_THRESHOLD);/*
|
|
||||||
Assert.assertEquals(0.762, U.get(cw1.getCellAt(1, 2)), DELTA_THRESHOLD);
|
|
||||||
Assert.assertEquals(0.812, U.get(cw1.getCellAt(1, 3)), DELTA_THRESHOLD);
|
|
||||||
|
|
||||||
Assert.assertEquals(0.655, U.get(cw1.getCellAt(2, 1)), DELTA_THRESHOLD);
|
Assert.assertEquals(-0.1874236, U.get(gw.getCellAt(1, 1)), DELTA_THRESHOLD);
|
||||||
Assert.assertEquals(0.868, U.get(cw1.getCellAt(2, 3)), DELTA_THRESHOLD);
|
|
||||||
|
|
||||||
Assert.assertEquals(0.611, U.get(cw1.getCellAt(3, 1)), DELTA_THRESHOLD);
|
|
||||||
Assert.assertEquals(0.660, U.get(cw1.getCellAt(3, 2)), DELTA_THRESHOLD);
|
|
||||||
Assert.assertEquals(0.918, U.get(cw1.getCellAt(3, 3)), DELTA_THRESHOLD);
|
|
||||||
|
|
||||||
Assert.assertEquals(0.388, U.get(cw1.getCellAt(4, 1)), DELTA_THRESHOLD);
|
|
||||||
Assert.assertEquals(-1.0, U.get(cw1.getCellAt(4, 2)), DELTA_THRESHOLD);
|
|
||||||
Assert.assertEquals(1.0, U.get(cw1.getCellAt(4, 3)), DELTA_THRESHOLD);*/
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,26 +0,0 @@
|
|||||||
package model.mdp;
|
|
||||||
|
|
||||||
import static org.junit.Assert.assertTrue;
|
|
||||||
import model.mdp.MDP.MODE;
|
|
||||||
|
|
||||||
import org.junit.Test;
|
|
||||||
|
|
||||||
public class ValueIterationSolverTest {
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testSolve() {
|
|
||||||
MDPSolver solver = new ValueIterationSolver();
|
|
||||||
|
|
||||||
//solve for a score of 25 in at most 35 turns
|
|
||||||
int maxScore = 6;
|
|
||||||
int maxTurns = 10;
|
|
||||||
|
|
||||||
MDP mdp = new MDP(maxScore,maxTurns,MODE.CEIL);
|
|
||||||
Policy policy = solver.solve(mdp);
|
|
||||||
|
|
||||||
assertTrue(policy.size() >= maxScore);
|
|
||||||
assertTrue(policy.size() <= maxTurns);
|
|
||||||
|
|
||||||
System.out.println("Policy: " + policy);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user