package aima.core.probability.mdp; import java.util.Map; import org.junit.Assert; import org.junit.Before; import org.junit.Test; import aima.core.environment.cellworld.Cell; import aima.core.environment.cellworld.CellWorld; import aima.core.environment.cellworld.CellWorldAction; import aima.core.environment.cellworld.CellWorldFactory; import aima.core.environment.gridworld.GridCell; import aima.core.environment.gridworld.GridWorld; import aima.core.environment.gridworld.GridWorldAction; import aima.core.environment.gridworld.GridWorldFactory; import aima.core.probability.example.MDPFactory; import aima.core.probability.mdp.MarkovDecisionProcess; import aima.core.probability.mdp.search.ValueIteration; /** * @author Ravi Mohan * @author Ciaran O'Reilly * */ public class ValueIterationTest2 { public static final double DELTA_THRESHOLD = 1e-3; private GridWorld gw = null; private MarkovDecisionProcess, GridWorldAction> mdp = null; private ValueIteration, GridWorldAction> vi = null; final int maxTiles = 6; final int maxScore = 10; @Before public void setUp() { //take 10 turns to place 6 tiles double defaultPenalty = -0.04; gw = GridWorldFactory.createGridWorldForTileGame(maxTiles,maxScore,defaultPenalty); mdp = MDPFactory.createMDPForTileGame(gw, maxTiles, maxScore); //gamma = 1.0 vi = new ValueIteration, GridWorldAction>(0.9); } @Test public void testValueIterationForTileGame() { Map, Double> U = vi.valueIteration(mdp, 1.0); for (int j = maxScore; j >= 1; j--) { StringBuilder sb = new StringBuilder(); for (int i = 1; i <= maxTiles; i++) { sb.append(U.get(gw.getCellAt(i, j))); sb.append(" "); } System.out.println(sb.toString()); } Assert.assertEquals(0.705, U.get(gw.getCellAt(1, 1)), DELTA_THRESHOLD);/* Assert.assertEquals(0.762, U.get(cw1.getCellAt(1, 2)), DELTA_THRESHOLD); Assert.assertEquals(0.812, U.get(cw1.getCellAt(1, 3)), DELTA_THRESHOLD); Assert.assertEquals(0.655, U.get(cw1.getCellAt(2, 1)), DELTA_THRESHOLD); Assert.assertEquals(0.868, U.get(cw1.getCellAt(2, 3)), DELTA_THRESHOLD); Assert.assertEquals(0.611, U.get(cw1.getCellAt(3, 1)), DELTA_THRESHOLD); Assert.assertEquals(0.660, U.get(cw1.getCellAt(3, 2)), DELTA_THRESHOLD); Assert.assertEquals(0.918, U.get(cw1.getCellAt(3, 3)), DELTA_THRESHOLD); Assert.assertEquals(0.388, U.get(cw1.getCellAt(4, 1)), DELTA_THRESHOLD); Assert.assertEquals(-1.0, U.get(cw1.getCellAt(4, 2)), DELTA_THRESHOLD); Assert.assertEquals(1.0, U.get(cw1.getCellAt(4, 3)), DELTA_THRESHOLD);*/ } }