Fixed unit tests, changed MDP generation to more reasonably seek the goal state, avoiding premature end of game.

Removed unused google-code classes.
Regenerate policy when AdaptiveComPlayer.setTarget() is called.
This commit is contained in:
Woody Folsom
2012-04-30 17:37:37 -04:00
parent 3800436cd9
commit 8f92ae65d8
19 changed files with 53 additions and 939 deletions

View File

@@ -5,91 +5,57 @@ import junit.framework.Assert;
import org.junit.Before;
import org.junit.Test;
import aima.core.environment.cellworld.Cell;
import aima.core.environment.cellworld.CellWorld;
import aima.core.environment.cellworld.CellWorldAction;
import aima.core.environment.cellworld.CellWorldFactory;
import aima.core.environment.gridworld.GridCell;
import aima.core.environment.gridworld.GridWorld;
import aima.core.environment.gridworld.GridWorldAction;
import aima.core.environment.gridworld.GridWorldFactory;
import aima.core.probability.example.MDPFactory;
import aima.core.probability.mdp.MarkovDecisionProcess;
/**
*
* @author Ciaran O'Reilly
* @author Ravi Mohan
*
* Based on MarkovDecisionProcessTest by Ciaran O'Reilly and Ravi Mohan. Used under MIT license.
*/
public class MarkovDecisionProcessTest {
public static final double DELTA_THRESHOLD = 1e-3;
private CellWorld<Double> cw = null;
private MarkovDecisionProcess<Cell<Double>, CellWorldAction> mdp = null;
private double nonTerminalReward = -0.04;
private GridWorld<Double> gw = null;
private MarkovDecisionProcess<GridCell<Double>, GridWorldAction> mdp = null;
@Before
public void setUp() {
cw = CellWorldFactory.createCellWorldForFig17_1();
mdp = MDPFactory.createMDPForFigure17_3(cw);
int maxTiles = 6;
int maxScore = 10;
gw = GridWorldFactory.createGridWorldForTileGame(maxTiles, maxScore, nonTerminalReward);
mdp = MDPFactory.createMDPForTileGame(gw, maxTiles, maxScore);
}
@Test
public void testActions() {
// Ensure all actions can be performed in each cell
// except for the terminal states.
for (Cell<Double> s : cw.getCells()) {
if (4 == s.getX() && (3 == s.getY() || 2 == s.getY())) {
for (GridCell<Double> s : gw.getCells()) {
if (6 == s.getX() && 10 == s.getY()) {
Assert.assertEquals(0, mdp.actions(s).size());
} else {
Assert.assertEquals(5, mdp.actions(s).size());
Assert.assertEquals(3, mdp.actions(s).size());
}
}
}
@Test
public void testMDPTransitionModel() {
Assert.assertEquals(0.8, mdp.transitionProbability(cw.getCellAt(1, 2),
cw.getCellAt(1, 1), CellWorldAction.Up), DELTA_THRESHOLD);
Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(1, 1),
cw.getCellAt(1, 1), CellWorldAction.Up), DELTA_THRESHOLD);
Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(2, 1),
cw.getCellAt(1, 1), CellWorldAction.Up), DELTA_THRESHOLD);
Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(1, 3),
cw.getCellAt(1, 1), CellWorldAction.Up), DELTA_THRESHOLD);
Assert.assertEquals(0.9, mdp.transitionProbability(cw.getCellAt(1, 1),
cw.getCellAt(1, 1), CellWorldAction.Down), DELTA_THRESHOLD);
Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(2, 1),
cw.getCellAt(1, 1), CellWorldAction.Down), DELTA_THRESHOLD);
Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(3, 1),
cw.getCellAt(1, 1), CellWorldAction.Down), DELTA_THRESHOLD);
Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(1, 2),
cw.getCellAt(1, 1), CellWorldAction.Down), DELTA_THRESHOLD);
Assert.assertEquals(0.9, mdp.transitionProbability(cw.getCellAt(1, 1),
cw.getCellAt(1, 1), CellWorldAction.Left), DELTA_THRESHOLD);
Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(2, 1),
cw.getCellAt(1, 1), CellWorldAction.Left), DELTA_THRESHOLD);
Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(3, 1),
cw.getCellAt(1, 1), CellWorldAction.Left), DELTA_THRESHOLD);
Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(1, 2),
cw.getCellAt(1, 1), CellWorldAction.Left), DELTA_THRESHOLD);
Assert.assertEquals(0.8, mdp.transitionProbability(cw.getCellAt(2, 1),
cw.getCellAt(1, 1), CellWorldAction.Right), DELTA_THRESHOLD);
Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(1, 1),
cw.getCellAt(1, 1), CellWorldAction.Right), DELTA_THRESHOLD);
Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(1, 2),
cw.getCellAt(1, 1), CellWorldAction.Right), DELTA_THRESHOLD);
Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(1, 3),
cw.getCellAt(1, 1), CellWorldAction.Right), DELTA_THRESHOLD);
Assert.assertEquals(0.66, mdp.transitionProbability(gw.getCellAt(2, 2),
gw.getCellAt(1, 1), GridWorldAction.AddTile), DELTA_THRESHOLD);
}
@Test
public void testRewardFunction() {
// Ensure all actions can be performed in each cell.
for (Cell<Double> s : cw.getCells()) {
if (4 == s.getX() && 3 == s.getY()) {
for (GridCell<Double> s : gw.getCells()) {
if (6 == s.getX() && 10 == s.getY()) {
Assert.assertEquals(1.0, mdp.reward(s), DELTA_THRESHOLD);
} else if (4 == s.getX() && 2 == s.getY()) {
Assert.assertEquals(-1.0, mdp.reward(s), DELTA_THRESHOLD);
} else {
Assert.assertEquals(-0.04, mdp.reward(s), DELTA_THRESHOLD);
}

View File

@@ -1,15 +1,8 @@
package aima.core.probability.mdp;
import java.util.Map;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import aima.core.environment.cellworld.Cell;
import aima.core.environment.cellworld.CellWorld;
import aima.core.environment.cellworld.CellWorldAction;
import aima.core.environment.cellworld.CellWorldFactory;
import aima.core.environment.gridworld.GridCell;
import aima.core.environment.gridworld.GridWorld;
import aima.core.environment.gridworld.GridWorldAction;
@@ -18,7 +11,6 @@ import aima.core.probability.example.MDPFactory;
import aima.core.probability.mdp.MarkovDecisionProcess;
import aima.core.probability.mdp.impl.ModifiedPolicyEvaluation;
import aima.core.probability.mdp.search.PolicyIteration;
import aima.core.probability.mdp.search.ValueIteration;
/**
* @author Ravi Mohan
@@ -29,28 +21,31 @@ public class PolicyIterationTest {
public static final double DELTA_THRESHOLD = 1e-3;
private GridWorld<Double> gw = null;
private MarkovDecisionProcess<GridCell<Double>, GridWorldAction> mdp = null;
private MarkovDecisionProcess<GridCell<Double>, GridWorldAction> mdp = null;
private PolicyIteration<GridCell<Double>, GridWorldAction> pi = null;
final int maxTiles = 6;
final int maxScore = 10;
@Before
public void setUp() {
//take 10 turns to place 6 tiles
// take 10 turns to place 6 tiles
double defaultPenalty = -0.04;
gw = GridWorldFactory.createGridWorldForTileGame(maxTiles,maxScore,defaultPenalty);
gw = GridWorldFactory.createGridWorldForTileGame(maxTiles, maxScore,
defaultPenalty);
mdp = MDPFactory.createMDPForTileGame(gw, maxTiles, maxScore);
//gamma = 1.0
PolicyEvaluation<GridCell<Double>,GridWorldAction> pe = new ModifiedPolicyEvaluation<GridCell<Double>, GridWorldAction>(100,0.9);
// gamma = 1.0
PolicyEvaluation<GridCell<Double>, GridWorldAction> pe = new ModifiedPolicyEvaluation<GridCell<Double>, GridWorldAction>(
100, 0.9);
pi = new PolicyIteration<GridCell<Double>, GridWorldAction>(pe);
}
@Test
public void testPolicyIterationForTileGame() {
Policy<GridCell<Double>, GridWorldAction> policy = pi.policyIteration(mdp);
Policy<GridCell<Double>, GridWorldAction> policy = pi
.policyIteration(mdp);
for (int j = maxScore; j >= 1; j--) {
StringBuilder sb = new StringBuilder();
@@ -60,21 +55,5 @@ public class PolicyIterationTest {
}
System.out.println(sb.toString());
}
//Assert.assertEquals(0.705, U.get(gw.getCellAt(1, 1)), DELTA_THRESHOLD);
/*
Assert.assertEquals(0.762, U.get(cw1.getCellAt(1, 2)), DELTA_THRESHOLD);
Assert.assertEquals(0.812, U.get(cw1.getCellAt(1, 3)), DELTA_THRESHOLD);
Assert.assertEquals(0.655, U.get(cw1.getCellAt(2, 1)), DELTA_THRESHOLD);
Assert.assertEquals(0.868, U.get(cw1.getCellAt(2, 3)), DELTA_THRESHOLD);
Assert.assertEquals(0.611, U.get(cw1.getCellAt(3, 1)), DELTA_THRESHOLD);
Assert.assertEquals(0.660, U.get(cw1.getCellAt(3, 2)), DELTA_THRESHOLD);
Assert.assertEquals(0.918, U.get(cw1.getCellAt(3, 3)), DELTA_THRESHOLD);
Assert.assertEquals(0.388, U.get(cw1.getCellAt(4, 1)), DELTA_THRESHOLD);
Assert.assertEquals(-1.0, U.get(cw1.getCellAt(4, 2)), DELTA_THRESHOLD);
Assert.assertEquals(1.0, U.get(cw1.getCellAt(4, 3)), DELTA_THRESHOLD);*/
}
}

View File

@@ -1,64 +0,0 @@
package aima.core.probability.mdp;
import java.util.Map;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import aima.core.environment.cellworld.Cell;
import aima.core.environment.cellworld.CellWorld;
import aima.core.environment.cellworld.CellWorldAction;
import aima.core.environment.cellworld.CellWorldFactory;
import aima.core.probability.example.MDPFactory;
import aima.core.probability.mdp.MarkovDecisionProcess;
import aima.core.probability.mdp.search.ValueIteration;
/**
* @author Ravi Mohan
* @author Ciaran O'Reilly
*
*/
public class ValueIterationTest {
public static final double DELTA_THRESHOLD = 1e-3;
private CellWorld<Double> cw = null;
private MarkovDecisionProcess<Cell<Double>, CellWorldAction> mdp = null;
private ValueIteration<Cell<Double>, CellWorldAction> vi = null;
@Before
public void setUp() {
cw = CellWorldFactory.createCellWorldForFig17_1();
mdp = MDPFactory.createMDPForFigure17_3(cw);
vi = new ValueIteration<Cell<Double>, CellWorldAction>(1.0);
}
@Test
public void testValueIterationForFig17_3() {
Map<Cell<Double>, Double> U = vi.valueIteration(mdp, 0.0001);
Assert.assertEquals(0.705, U.get(cw.getCellAt(1, 1)), DELTA_THRESHOLD);
Assert.assertEquals(0.762, U.get(cw.getCellAt(1, 2)), DELTA_THRESHOLD);
Assert.assertEquals(0.812, U.get(cw.getCellAt(1, 3)), DELTA_THRESHOLD);
Assert.assertEquals(0.655, U.get(cw.getCellAt(2, 1)), DELTA_THRESHOLD);
Assert.assertEquals(0.868, U.get(cw.getCellAt(2, 3)), DELTA_THRESHOLD);
Assert.assertEquals(0.611, U.get(cw.getCellAt(3, 1)), DELTA_THRESHOLD);
Assert.assertEquals(0.660, U.get(cw.getCellAt(3, 2)), DELTA_THRESHOLD);
Assert.assertEquals(0.918, U.get(cw.getCellAt(3, 3)), DELTA_THRESHOLD);
Assert.assertEquals(0.388, U.get(cw.getCellAt(4, 1)), DELTA_THRESHOLD);
Assert.assertEquals(-1.0, U.get(cw.getCellAt(4, 2)), DELTA_THRESHOLD);
Assert.assertEquals(1.0, U.get(cw.getCellAt(4, 3)), DELTA_THRESHOLD);
for (int j = 3; j >= 1; j--) {
StringBuilder sb = new StringBuilder();
for (int i = 1; i <= 4; i++) {
sb.append(U.get(cw.getCellAt(i, j)));
sb.append(" ");
}
System.out.println(sb.toString());
}
}
}

View File

@@ -6,10 +6,6 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import aima.core.environment.cellworld.Cell;
import aima.core.environment.cellworld.CellWorld;
import aima.core.environment.cellworld.CellWorldAction;
import aima.core.environment.cellworld.CellWorldFactory;
import aima.core.environment.gridworld.GridCell;
import aima.core.environment.gridworld.GridWorld;
import aima.core.environment.gridworld.GridWorldAction;
@@ -19,29 +15,30 @@ import aima.core.probability.mdp.MarkovDecisionProcess;
import aima.core.probability.mdp.search.ValueIteration;
/**
* @author Ravi Mohan
* @author Ciaran O'Reilly
*
* @author Woody
*
*/
public class ValueIterationTest2 {
public static final double DELTA_THRESHOLD = 1e-3;
private GridWorld<Double> gw = null;
private MarkovDecisionProcess<GridCell<Double>, GridWorldAction> mdp = null;
private MarkovDecisionProcess<GridCell<Double>, GridWorldAction> mdp = null;
private ValueIteration<GridCell<Double>, GridWorldAction> vi = null;
final int maxTiles = 6;
final int maxScore = 10;
@Before
public void setUp() {
//take 10 turns to place 6 tiles
// take 10 turns to place 6 tiles
double defaultPenalty = -0.04;
gw = GridWorldFactory.createGridWorldForTileGame(maxTiles,maxScore,defaultPenalty);
gw = GridWorldFactory.createGridWorldForTileGame(maxTiles, maxScore,
defaultPenalty);
mdp = MDPFactory.createMDPForTileGame(gw, maxTiles, maxScore);
//gamma = 1.0
// gamma = 1.0
vi = new ValueIteration<GridCell<Double>, GridWorldAction>(0.9);
}
@@ -57,20 +54,7 @@ public class ValueIterationTest2 {
}
System.out.println(sb.toString());
}
Assert.assertEquals(0.705, U.get(gw.getCellAt(1, 1)), DELTA_THRESHOLD);/*
Assert.assertEquals(0.762, U.get(cw1.getCellAt(1, 2)), DELTA_THRESHOLD);
Assert.assertEquals(0.812, U.get(cw1.getCellAt(1, 3)), DELTA_THRESHOLD);
Assert.assertEquals(0.655, U.get(cw1.getCellAt(2, 1)), DELTA_THRESHOLD);
Assert.assertEquals(0.868, U.get(cw1.getCellAt(2, 3)), DELTA_THRESHOLD);
Assert.assertEquals(0.611, U.get(cw1.getCellAt(3, 1)), DELTA_THRESHOLD);
Assert.assertEquals(0.660, U.get(cw1.getCellAt(3, 2)), DELTA_THRESHOLD);
Assert.assertEquals(0.918, U.get(cw1.getCellAt(3, 3)), DELTA_THRESHOLD);
Assert.assertEquals(0.388, U.get(cw1.getCellAt(4, 1)), DELTA_THRESHOLD);
Assert.assertEquals(-1.0, U.get(cw1.getCellAt(4, 2)), DELTA_THRESHOLD);
Assert.assertEquals(1.0, U.get(cw1.getCellAt(4, 3)), DELTA_THRESHOLD);*/
Assert.assertEquals(-0.1874236, U.get(gw.getCellAt(1, 1)), DELTA_THRESHOLD);
}
}