Implemented agent which chooses to play winning, losing or random moves by solving a simplified MDP model of the game using policy iteration.
Portions of MDP/solver code by Ciaran O'Reilly and Ravi Mohan used under MIT license.
This commit is contained in:
@@ -0,0 +1,98 @@
|
||||
package aima.core.probability.mdp;
|
||||
|
||||
import junit.framework.Assert;
|
||||
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
||||
import aima.core.environment.cellworld.Cell;
|
||||
import aima.core.environment.cellworld.CellWorld;
|
||||
import aima.core.environment.cellworld.CellWorldAction;
|
||||
import aima.core.environment.cellworld.CellWorldFactory;
|
||||
import aima.core.probability.example.MDPFactory;
|
||||
import aima.core.probability.mdp.MarkovDecisionProcess;
|
||||
|
||||
/**
|
||||
*
|
||||
* @author Ciaran O'Reilly
|
||||
* @author Ravi Mohan
|
||||
*
|
||||
*/
|
||||
public class MarkovDecisionProcessTest {
|
||||
public static final double DELTA_THRESHOLD = 1e-3;
|
||||
|
||||
private CellWorld<Double> cw = null;
|
||||
private MarkovDecisionProcess<Cell<Double>, CellWorldAction> mdp = null;
|
||||
|
||||
@Before
|
||||
public void setUp() {
|
||||
cw = CellWorldFactory.createCellWorldForFig17_1();
|
||||
mdp = MDPFactory.createMDPForFigure17_3(cw);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testActions() {
|
||||
// Ensure all actions can be performed in each cell
|
||||
// except for the terminal states.
|
||||
for (Cell<Double> s : cw.getCells()) {
|
||||
if (4 == s.getX() && (3 == s.getY() || 2 == s.getY())) {
|
||||
Assert.assertEquals(0, mdp.actions(s).size());
|
||||
} else {
|
||||
Assert.assertEquals(5, mdp.actions(s).size());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMDPTransitionModel() {
|
||||
Assert.assertEquals(0.8, mdp.transitionProbability(cw.getCellAt(1, 2),
|
||||
cw.getCellAt(1, 1), CellWorldAction.Up), DELTA_THRESHOLD);
|
||||
Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(1, 1),
|
||||
cw.getCellAt(1, 1), CellWorldAction.Up), DELTA_THRESHOLD);
|
||||
Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(2, 1),
|
||||
cw.getCellAt(1, 1), CellWorldAction.Up), DELTA_THRESHOLD);
|
||||
Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(1, 3),
|
||||
cw.getCellAt(1, 1), CellWorldAction.Up), DELTA_THRESHOLD);
|
||||
|
||||
Assert.assertEquals(0.9, mdp.transitionProbability(cw.getCellAt(1, 1),
|
||||
cw.getCellAt(1, 1), CellWorldAction.Down), DELTA_THRESHOLD);
|
||||
Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(2, 1),
|
||||
cw.getCellAt(1, 1), CellWorldAction.Down), DELTA_THRESHOLD);
|
||||
Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(3, 1),
|
||||
cw.getCellAt(1, 1), CellWorldAction.Down), DELTA_THRESHOLD);
|
||||
Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(1, 2),
|
||||
cw.getCellAt(1, 1), CellWorldAction.Down), DELTA_THRESHOLD);
|
||||
|
||||
Assert.assertEquals(0.9, mdp.transitionProbability(cw.getCellAt(1, 1),
|
||||
cw.getCellAt(1, 1), CellWorldAction.Left), DELTA_THRESHOLD);
|
||||
Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(2, 1),
|
||||
cw.getCellAt(1, 1), CellWorldAction.Left), DELTA_THRESHOLD);
|
||||
Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(3, 1),
|
||||
cw.getCellAt(1, 1), CellWorldAction.Left), DELTA_THRESHOLD);
|
||||
Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(1, 2),
|
||||
cw.getCellAt(1, 1), CellWorldAction.Left), DELTA_THRESHOLD);
|
||||
|
||||
Assert.assertEquals(0.8, mdp.transitionProbability(cw.getCellAt(2, 1),
|
||||
cw.getCellAt(1, 1), CellWorldAction.Right), DELTA_THRESHOLD);
|
||||
Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(1, 1),
|
||||
cw.getCellAt(1, 1), CellWorldAction.Right), DELTA_THRESHOLD);
|
||||
Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(1, 2),
|
||||
cw.getCellAt(1, 1), CellWorldAction.Right), DELTA_THRESHOLD);
|
||||
Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(1, 3),
|
||||
cw.getCellAt(1, 1), CellWorldAction.Right), DELTA_THRESHOLD);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRewardFunction() {
|
||||
// Ensure all actions can be performed in each cell.
|
||||
for (Cell<Double> s : cw.getCells()) {
|
||||
if (4 == s.getX() && 3 == s.getY()) {
|
||||
Assert.assertEquals(1.0, mdp.reward(s), DELTA_THRESHOLD);
|
||||
} else if (4 == s.getX() && 2 == s.getY()) {
|
||||
Assert.assertEquals(-1.0, mdp.reward(s), DELTA_THRESHOLD);
|
||||
} else {
|
||||
Assert.assertEquals(-0.04, mdp.reward(s), DELTA_THRESHOLD);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
80
test/aima/core/probability/mdp/PolicyIterationTest.java
Normal file
80
test/aima/core/probability/mdp/PolicyIterationTest.java
Normal file
@@ -0,0 +1,80 @@
|
||||
package aima.core.probability.mdp;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
import org.junit.Assert;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
||||
import aima.core.environment.cellworld.Cell;
|
||||
import aima.core.environment.cellworld.CellWorld;
|
||||
import aima.core.environment.cellworld.CellWorldAction;
|
||||
import aima.core.environment.cellworld.CellWorldFactory;
|
||||
import aima.core.environment.gridworld.GridCell;
|
||||
import aima.core.environment.gridworld.GridWorld;
|
||||
import aima.core.environment.gridworld.GridWorldAction;
|
||||
import aima.core.environment.gridworld.GridWorldFactory;
|
||||
import aima.core.probability.example.MDPFactory;
|
||||
import aima.core.probability.mdp.MarkovDecisionProcess;
|
||||
import aima.core.probability.mdp.impl.ModifiedPolicyEvaluation;
|
||||
import aima.core.probability.mdp.search.PolicyIteration;
|
||||
import aima.core.probability.mdp.search.ValueIteration;
|
||||
|
||||
/**
|
||||
* @author Ravi Mohan
|
||||
* @author Ciaran O'Reilly
|
||||
*
|
||||
*/
|
||||
public class PolicyIterationTest {
|
||||
public static final double DELTA_THRESHOLD = 1e-3;
|
||||
|
||||
private GridWorld<Double> gw = null;
|
||||
private MarkovDecisionProcess<GridCell<Double>, GridWorldAction> mdp = null;
|
||||
private PolicyIteration<GridCell<Double>, GridWorldAction> pi = null;
|
||||
|
||||
final int maxTiles = 6;
|
||||
final int maxScore = 10;
|
||||
|
||||
@Before
|
||||
public void setUp() {
|
||||
//take 10 turns to place 6 tiles
|
||||
double defaultPenalty = -0.04;
|
||||
|
||||
gw = GridWorldFactory.createGridWorldForTileGame(maxTiles,maxScore,defaultPenalty);
|
||||
mdp = MDPFactory.createMDPForTileGame(gw, maxTiles, maxScore);
|
||||
|
||||
//gamma = 1.0
|
||||
PolicyEvaluation<GridCell<Double>,GridWorldAction> pe = new ModifiedPolicyEvaluation<GridCell<Double>, GridWorldAction>(100,0.9);
|
||||
pi = new PolicyIteration<GridCell<Double>, GridWorldAction>(pe);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPolicyIterationForTileGame() {
|
||||
Policy<GridCell<Double>, GridWorldAction> policy = pi.policyIteration(mdp);
|
||||
|
||||
for (int j = maxScore; j >= 1; j--) {
|
||||
StringBuilder sb = new StringBuilder();
|
||||
for (int i = 1; i <= maxTiles; i++) {
|
||||
sb.append(policy.action(gw.getCellAt(i, j)));
|
||||
sb.append(" ");
|
||||
}
|
||||
System.out.println(sb.toString());
|
||||
}
|
||||
|
||||
//Assert.assertEquals(0.705, U.get(gw.getCellAt(1, 1)), DELTA_THRESHOLD);
|
||||
/*
|
||||
Assert.assertEquals(0.762, U.get(cw1.getCellAt(1, 2)), DELTA_THRESHOLD);
|
||||
Assert.assertEquals(0.812, U.get(cw1.getCellAt(1, 3)), DELTA_THRESHOLD);
|
||||
|
||||
Assert.assertEquals(0.655, U.get(cw1.getCellAt(2, 1)), DELTA_THRESHOLD);
|
||||
Assert.assertEquals(0.868, U.get(cw1.getCellAt(2, 3)), DELTA_THRESHOLD);
|
||||
|
||||
Assert.assertEquals(0.611, U.get(cw1.getCellAt(3, 1)), DELTA_THRESHOLD);
|
||||
Assert.assertEquals(0.660, U.get(cw1.getCellAt(3, 2)), DELTA_THRESHOLD);
|
||||
Assert.assertEquals(0.918, U.get(cw1.getCellAt(3, 3)), DELTA_THRESHOLD);
|
||||
|
||||
Assert.assertEquals(0.388, U.get(cw1.getCellAt(4, 1)), DELTA_THRESHOLD);
|
||||
Assert.assertEquals(-1.0, U.get(cw1.getCellAt(4, 2)), DELTA_THRESHOLD);
|
||||
Assert.assertEquals(1.0, U.get(cw1.getCellAt(4, 3)), DELTA_THRESHOLD);*/
|
||||
}
|
||||
}
|
||||
64
test/aima/core/probability/mdp/ValueIterationTest.java
Normal file
64
test/aima/core/probability/mdp/ValueIterationTest.java
Normal file
@@ -0,0 +1,64 @@
|
||||
package aima.core.probability.mdp;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
import org.junit.Assert;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
||||
import aima.core.environment.cellworld.Cell;
|
||||
import aima.core.environment.cellworld.CellWorld;
|
||||
import aima.core.environment.cellworld.CellWorldAction;
|
||||
import aima.core.environment.cellworld.CellWorldFactory;
|
||||
import aima.core.probability.example.MDPFactory;
|
||||
import aima.core.probability.mdp.MarkovDecisionProcess;
|
||||
import aima.core.probability.mdp.search.ValueIteration;
|
||||
|
||||
/**
|
||||
* @author Ravi Mohan
|
||||
* @author Ciaran O'Reilly
|
||||
*
|
||||
*/
|
||||
public class ValueIterationTest {
|
||||
public static final double DELTA_THRESHOLD = 1e-3;
|
||||
|
||||
private CellWorld<Double> cw = null;
|
||||
private MarkovDecisionProcess<Cell<Double>, CellWorldAction> mdp = null;
|
||||
private ValueIteration<Cell<Double>, CellWorldAction> vi = null;
|
||||
|
||||
@Before
|
||||
public void setUp() {
|
||||
cw = CellWorldFactory.createCellWorldForFig17_1();
|
||||
mdp = MDPFactory.createMDPForFigure17_3(cw);
|
||||
vi = new ValueIteration<Cell<Double>, CellWorldAction>(1.0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testValueIterationForFig17_3() {
|
||||
Map<Cell<Double>, Double> U = vi.valueIteration(mdp, 0.0001);
|
||||
|
||||
Assert.assertEquals(0.705, U.get(cw.getCellAt(1, 1)), DELTA_THRESHOLD);
|
||||
Assert.assertEquals(0.762, U.get(cw.getCellAt(1, 2)), DELTA_THRESHOLD);
|
||||
Assert.assertEquals(0.812, U.get(cw.getCellAt(1, 3)), DELTA_THRESHOLD);
|
||||
|
||||
Assert.assertEquals(0.655, U.get(cw.getCellAt(2, 1)), DELTA_THRESHOLD);
|
||||
Assert.assertEquals(0.868, U.get(cw.getCellAt(2, 3)), DELTA_THRESHOLD);
|
||||
|
||||
Assert.assertEquals(0.611, U.get(cw.getCellAt(3, 1)), DELTA_THRESHOLD);
|
||||
Assert.assertEquals(0.660, U.get(cw.getCellAt(3, 2)), DELTA_THRESHOLD);
|
||||
Assert.assertEquals(0.918, U.get(cw.getCellAt(3, 3)), DELTA_THRESHOLD);
|
||||
|
||||
Assert.assertEquals(0.388, U.get(cw.getCellAt(4, 1)), DELTA_THRESHOLD);
|
||||
Assert.assertEquals(-1.0, U.get(cw.getCellAt(4, 2)), DELTA_THRESHOLD);
|
||||
Assert.assertEquals(1.0, U.get(cw.getCellAt(4, 3)), DELTA_THRESHOLD);
|
||||
|
||||
for (int j = 3; j >= 1; j--) {
|
||||
StringBuilder sb = new StringBuilder();
|
||||
for (int i = 1; i <= 4; i++) {
|
||||
sb.append(U.get(cw.getCellAt(i, j)));
|
||||
sb.append(" ");
|
||||
}
|
||||
System.out.println(sb.toString());
|
||||
}
|
||||
}
|
||||
}
|
||||
76
test/aima/core/probability/mdp/ValueIterationTest2.java
Normal file
76
test/aima/core/probability/mdp/ValueIterationTest2.java
Normal file
@@ -0,0 +1,76 @@
|
||||
package aima.core.probability.mdp;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
import org.junit.Assert;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
||||
import aima.core.environment.cellworld.Cell;
|
||||
import aima.core.environment.cellworld.CellWorld;
|
||||
import aima.core.environment.cellworld.CellWorldAction;
|
||||
import aima.core.environment.cellworld.CellWorldFactory;
|
||||
import aima.core.environment.gridworld.GridCell;
|
||||
import aima.core.environment.gridworld.GridWorld;
|
||||
import aima.core.environment.gridworld.GridWorldAction;
|
||||
import aima.core.environment.gridworld.GridWorldFactory;
|
||||
import aima.core.probability.example.MDPFactory;
|
||||
import aima.core.probability.mdp.MarkovDecisionProcess;
|
||||
import aima.core.probability.mdp.search.ValueIteration;
|
||||
|
||||
/**
|
||||
* @author Ravi Mohan
|
||||
* @author Ciaran O'Reilly
|
||||
*
|
||||
*/
|
||||
public class ValueIterationTest2 {
|
||||
public static final double DELTA_THRESHOLD = 1e-3;
|
||||
|
||||
private GridWorld<Double> gw = null;
|
||||
private MarkovDecisionProcess<GridCell<Double>, GridWorldAction> mdp = null;
|
||||
private ValueIteration<GridCell<Double>, GridWorldAction> vi = null;
|
||||
|
||||
final int maxTiles = 6;
|
||||
final int maxScore = 10;
|
||||
|
||||
@Before
|
||||
public void setUp() {
|
||||
//take 10 turns to place 6 tiles
|
||||
double defaultPenalty = -0.04;
|
||||
|
||||
gw = GridWorldFactory.createGridWorldForTileGame(maxTiles,maxScore,defaultPenalty);
|
||||
mdp = MDPFactory.createMDPForTileGame(gw, maxTiles, maxScore);
|
||||
|
||||
//gamma = 1.0
|
||||
vi = new ValueIteration<GridCell<Double>, GridWorldAction>(0.9);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testValueIterationForTileGame() {
|
||||
Map<GridCell<Double>, Double> U = vi.valueIteration(mdp, 1.0);
|
||||
|
||||
for (int j = maxScore; j >= 1; j--) {
|
||||
StringBuilder sb = new StringBuilder();
|
||||
for (int i = 1; i <= maxTiles; i++) {
|
||||
sb.append(U.get(gw.getCellAt(i, j)));
|
||||
sb.append(" ");
|
||||
}
|
||||
System.out.println(sb.toString());
|
||||
}
|
||||
|
||||
Assert.assertEquals(0.705, U.get(gw.getCellAt(1, 1)), DELTA_THRESHOLD);/*
|
||||
Assert.assertEquals(0.762, U.get(cw1.getCellAt(1, 2)), DELTA_THRESHOLD);
|
||||
Assert.assertEquals(0.812, U.get(cw1.getCellAt(1, 3)), DELTA_THRESHOLD);
|
||||
|
||||
Assert.assertEquals(0.655, U.get(cw1.getCellAt(2, 1)), DELTA_THRESHOLD);
|
||||
Assert.assertEquals(0.868, U.get(cw1.getCellAt(2, 3)), DELTA_THRESHOLD);
|
||||
|
||||
Assert.assertEquals(0.611, U.get(cw1.getCellAt(3, 1)), DELTA_THRESHOLD);
|
||||
Assert.assertEquals(0.660, U.get(cw1.getCellAt(3, 2)), DELTA_THRESHOLD);
|
||||
Assert.assertEquals(0.918, U.get(cw1.getCellAt(3, 3)), DELTA_THRESHOLD);
|
||||
|
||||
Assert.assertEquals(0.388, U.get(cw1.getCellAt(4, 1)), DELTA_THRESHOLD);
|
||||
Assert.assertEquals(-1.0, U.get(cw1.getCellAt(4, 2)), DELTA_THRESHOLD);
|
||||
Assert.assertEquals(1.0, U.get(cw1.getCellAt(4, 3)), DELTA_THRESHOLD);*/
|
||||
}
|
||||
}
|
||||
26
test/model/mdp/ValueIterationSolverTest.java
Normal file
26
test/model/mdp/ValueIterationSolverTest.java
Normal file
@@ -0,0 +1,26 @@
|
||||
package model.mdp;
|
||||
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import model.mdp.MDP.MODE;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
public class ValueIterationSolverTest {
|
||||
|
||||
@Test
|
||||
public void testSolve() {
|
||||
MDPSolver solver = new ValueIterationSolver();
|
||||
|
||||
//solve for a score of 25 in at most 35 turns
|
||||
int maxScore = 6;
|
||||
int maxTurns = 10;
|
||||
|
||||
MDP mdp = new MDP(maxScore,maxTurns,MODE.CEIL);
|
||||
Policy policy = solver.solve(mdp);
|
||||
|
||||
assertTrue(policy.size() >= maxScore);
|
||||
assertTrue(policy.size() <= maxTurns);
|
||||
|
||||
System.out.println("Policy: " + policy);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user