package model.mdp;

import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.List;

/**
 * Iteratively estimates a utility value for every (score, tiles) state of an
 * {@link MDP} and prints the resulting matrix to standard output.
 *
 * <p>Each sweep recomputes every state's utility as
 * {@code reward(state) + GAMMA * max over actions of (sum of successor utilities)}.
 * Sweeping stops after {@link #maxIterations} sweeps, or earlier once the
 * largest per-state change drops below {@link #DEFAULT_EPS}.
 *
 * <p>NOTE(review): the first constructor argument of each {@code Transition}
 * built in {@code getTransitions} (0.9 / 0.1) looks like a probability, but it
 * is never read when successor utilities are summed. If it is a probability the
 * sum should presumably be weighted by it; confirm against {@code Transition}.
 */
public class ValueIterationSolver implements MDPSolver {

    /** Maximum number of sweeps over the state space performed by {@code solve}. */
    public int maxIterations = 10;

    /** Sweeping stops once the largest absolute utility change is below this value. */
    public final double DEFAULT_EPS = 0.1;

    /** Discount factor multiplied into the best action value of each state. */
    public final double GAMMA = 0.9;

    /** Formats utilities for the matrix dump printed at the end of {@code solve}. */
    private final DecimalFormat fmt = new DecimalFormat("##.00");

    /**
     * Runs the iterative utility update on {@code mdp}, printing the maximum
     * change after every sweep and the utility matrix once sweeping stops.
     *
     * @param mdp the model supplying state bounds, rewards and available actions
     * @return a freshly constructed {@code Policy}; this method does not
     *         populate it (policy extraction is not implemented here)
     */
    public Policy solve(MDP mdp) {
        Policy policy = new Policy();

        int maxScore = mdp.getMaxScore();
        int maxTiles = mdp.getMaxTiles();
        double[][] utility = new double[maxScore + 1][maxTiles + 1];
        double[][] utilityPrime = new double[maxScore + 1][maxTiles + 1];

        // Seed the first estimate with the immediate reward of every state.
        for (int i = 0; i <= maxScore; i++) {
            for (int j = 0; j <= maxTiles; j++) {
                utilityPrime[i][j] = mdp.getReward(i, j);
            }
        }

        for (int iteration = 0; iteration < maxIterations; iteration++) {
            // The previous sweep's result becomes the read-only input of this sweep.
            for (int i = 0; i <= maxScore; i++) {
                System.arraycopy(utilityPrime[i], 0, utility[i], 0, maxTiles + 1);
            }

            for (int i = 0; i <= maxScore; i++) {
                for (int j = 0; j <= maxTiles; j++) {
                    double best = bestActionValue(mdp, utility, i, j);
                    utilityPrime[i][j] = mdp.getReward(i, j) + GAMMA * best;
                }
            }

            double maxDiff = getMaxDiff(utility, utilityPrime);
            System.out.println("Max diff |U - U'| = " + maxDiff);
            if (maxDiff < DEFAULT_EPS) {
                System.out.println("Solution to MDP converged: " + maxDiff);
                break;
            }
        }

        // NOTE(review): this prints `utility`, i.e. the input of the last sweep,
        // not the most recently computed `utilityPrime`. Kept as in the original.
        printMatrix(utility);

        // TODO: derive the policy from the utility matrix; it is returned empty for now.
        return policy;
    }

    /**
     * Returns the largest summed successor utility over all actions available
     * in state ({@code score}, {@code tiles}), or 0 when the state has no actions.
     * Successors that fall outside the utility matrix contribute nothing.
     */
    private double bestActionValue(MDP mdp, double[][] utility, int score, int tiles) {
        Action[] actions = mdp.getActions(score, tiles);
        if (actions.length == 0) {
            return 0;
        }

        double best = Double.NEGATIVE_INFINITY;
        for (Action action : actions) {
            double sum = 0.0;
            for (Transition transition : getTransitions(action, mdp, score, tiles)) {
                int nextScore = score + transition.getScoreChange();
                int nextTiles = tiles + transition.getTileCountChange();
                boolean inBounds = nextScore >= 0 && nextScore < utility.length
                        && nextTiles >= 0 && nextTiles < utility[nextScore].length;
                if (inBounds) {
                    sum += utility[nextScore][nextTiles];
                }
            }
            if (sum > best) {
                best = sum;
            }
        }
        return best;
    }

    /** Prints one line per row of {@code matrix}, each value formatted with {@code fmt}. */
    private void printMatrix(double[][] matrix) {
        for (double[] row : matrix) {
            StringBuilder sb = new StringBuilder();
            for (double value : row) {
                sb.append(fmt.format(value));
                sb.append(" ");
            }
            System.out.println(sb);
        }
    }

    /**
     * Computes the largest absolute element-wise difference between two matrices.
     *
     * @param u      first matrix; its dimensions drive the iteration
     * @param uPrime second matrix, read at the same indices as {@code u}
     * @return the maximum of {@code |u[i][j] - uPrime[i][j]|}, or 0 if {@code u} is empty
     */
    double getMaxDiff(double[][] u, double[][] uPrime) {
        double maxDiff = 0;
        for (int i = 0; i < u.length; i++) {
            for (int j = 0; j < u[i].length; j++) {
                maxDiff = Math.max(maxDiff, Math.abs(u[i][j] - uPrime[i][j]));
            }
        }
        return maxDiff;
    }

    /**
     * Builds the list of transitions for {@code action}. The {@code mdp},
     * {@code score} and {@code tiles} arguments are currently unused, so the
     * result depends only on the action.
     *
     * <p>NOTE(review): {@code playToWin} and {@code playToLose} yield identical
     * transitions; confirm whether that is intended.
     *
     * @return the transitions for the action; empty for any other action
     */
    private List<Transition> getTransitions(Action action, MDP mdp, int score, int tiles) {
        List<Transition> transitions = new ArrayList<Transition>();
        if (Action.playToWin == action) {
            transitions.add(new Transition(0.9, 1, 1));
            transitions.add(new Transition(0.1, 1, -3));
        } else if (Action.playToLose == action) {
            transitions.add(new Transition(0.9, 1, 1));
            transitions.add(new Transition(0.1, 1, -3));
        }
        /* Disabled in the original source:
        else if (Action.maintainScore == action) {
            transitions.add(new Transition(0.5,1,1));
            transitions.add(new Transition(0.5,1,-3));
        }*/
        return transitions;
    }
}