Fixed unit tests, changed MDP generation to more reasonably seek the goal state, avoiding premature end of game.

Removed unused google-code classes.
Regenerate policy when AdaptiveComPlayer.setTarget() is called.
This commit is contained in:
Woody Folsom
2012-04-30 17:37:37 -04:00
parent 3800436cd9
commit 8f92ae65d8
19 changed files with 53 additions and 939 deletions

View File

@@ -119,5 +119,6 @@ public class AdaptiveComPlayer implements Player {
/**
 * Sets the goal this player pursues and flags the policy as stale.
 * NOTE(review): calculatePolicy is presumably read elsewhere to trigger
 * policy regeneration (per the commit message) — confirm against the
 * rest of AdaptiveComPlayer.
 */
@Override
public void setGameGoal(GameGoal target) {
this.target = target;
this.calculatePolicy = true; // force the policy to be recomputed for the new goal
}
}

View File

@@ -1,18 +0,0 @@
package model.mdp;
/**
 * An action available to the agent in the score/tile-count MDP.
 *
 * <p>Two shared instances are exposed; the solver compares actions by
 * identity ({@code ==}), so the singleton references are declared
 * {@code final} to prevent accidental reassignment (previously they were
 * mutable public statics, which would have silently broken identity-based
 * dispatch).
 */
public class Action {

    /** Shared singleton: play so as to reach the goal score. */
    public static final Action playToWin = new Action("PlayToWin");

    /** Shared singleton: play so as to avoid reaching the goal score. */
    public static final Action playToLose = new Action("PlayToLose");

    /** Display name, returned verbatim by {@link #toString()}. */
    private final String name;

    /**
     * Creates a named action.
     *
     * @param name display name for this action
     */
    public Action(String name) {
        this.name = name;
    }

    @Override
    public String toString() {
        return name;
    }
}

View File

@@ -1,51 +0,0 @@
package model.mdp;
/**
 * A Markov decision process over game states indexed by (score, tileCount).
 *
 * <p>States where the score or the tile count has reached its maximum are
 * terminal (no actions available). Rewards: +10 at the joint goal
 * (maxScore, maxTiles), -1 when only the score caps out, -5 when only the
 * tile count caps out, and a small step penalty everywhere else.
 */
public class MDP {

    /** Step penalty received in every non-terminal state. */
    public static final double nonTerminalReward = -0.25;

    /** Rounding mode supplied by the caller; stored but not consulted here. */
    public enum MODE {
        CEIL, FLOOR
    }

    private final int maxScore;
    private final int maxTiles;
    private final MODE mode;

    /**
     * @param maxScore terminal score bound (inclusive state index)
     * @param maxTiles terminal tile-count bound (inclusive state index)
     * @param mode     rounding mode, retained for callers
     */
    public MDP(int maxScore, int maxTiles, MODE mode) {
        this.maxScore = maxScore;
        this.maxTiles = maxTiles;
        this.mode = mode;
    }

    /**
     * Returns the actions available in state (i, j); terminal states
     * (either coordinate at its maximum) offer none.
     */
    public Action[] getActions(int i, int j) {
        boolean terminal = (i == maxScore) || (j == maxTiles);
        return terminal
                ? new Action[0]
                : new Action[]{Action.playToLose, Action.playToWin};
    }

    public int getMaxScore() {
        return maxScore;
    }

    public int getMaxTiles() {
        return maxTiles;
    }

    /**
     * Immediate reward for state (score, tiles).
     */
    public double getReward(int score, int tiles) {
        boolean scoreCapped = (score == maxScore);
        boolean tilesCapped = (tiles == maxTiles);
        if (scoreCapped && tilesCapped) {
            return 10.0; // joint goal state
        }
        if (scoreCapped) {
            // TODO scale linearly?
            return -1.0;
        }
        return tilesCapped ? -5.0 : nonTerminalReward;
    }
}

View File

@@ -1,5 +0,0 @@
package model.mdp;
/**
 * Strategy interface for algorithms that compute a {@link Policy} for a
 * given {@link MDP} (e.g. value iteration).
 */
public interface MDPSolver {
/**
 * Solves the given MDP and returns the resulting policy.
 *
 * @param mdp the decision process to solve
 * @return the computed policy
 */
Policy solve(MDP mdp);
}

View File

@@ -1,7 +0,0 @@
package model.mdp;
import java.util.ArrayList;
/**
 * A policy represented as an ordered list of {@link Action}s.
 *
 * <p>NOTE(review): inheriting from ArrayList exposes the entire mutable
 * List API as part of this type's contract; composition (wrapping a list)
 * would be safer — confirm no caller relies on List identity before
 * changing. Also consider declaring serialVersionUID since ArrayList is
 * Serializable.
 */
public class Policy extends ArrayList<Action>{
}

View File

@@ -1,34 +0,0 @@
package model.mdp;
/**
 * One probabilistic outcome of taking an action: with probability
 * {@code prob} the score changes by {@code scoreChange} and the tile
 * count changes by {@code tileCountChange}. Plain mutable bean.
 */
public class Transition {

    private double prob;          // probability of this outcome
    private int scoreChange;      // delta applied to the score
    private int tileCountChange;  // delta applied to the tile count

    /**
     * @param prob            probability of this outcome
     * @param scoreChange     score delta
     * @param tileCountChange tile-count delta
     */
    public Transition(double prob, int scoreChange, int tileCountChange) {
        this.prob = prob;
        this.scoreChange = scoreChange;
        this.tileCountChange = tileCountChange;
    }

    public double getProb() {
        return prob;
    }

    public void setProb(double prob) {
        this.prob = prob;
    }

    public int getScoreChange() {
        return scoreChange;
    }

    public void setScoreChange(int scoreChange) {
        this.scoreChange = scoreChange;
    }

    public int getTileCountChange() {
        return tileCountChange;
    }

    public void setTileCountChange(int tileCountChange) {
        this.tileCountChange = tileCountChange;
    }
}

View File

@@ -1,110 +0,0 @@
package model.mdp;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.List;
/**
 * Solves an {@link MDP} by value iteration: repeatedly applies the Bellman
 * optimality backup U'(s) = R(s) + gamma * max_a sum_s' P(s'|s,a) U(s')
 * until the utilities converge or {@link #maxIterations} sweeps elapse.
 *
 * <p>BUG FIX: the backup previously summed successor utilities without
 * weighting them by the transition probability, so the "expected" utility
 * was just an unweighted sum. Each successor is now multiplied by
 * {@code transition.getProb()}.
 */
public class ValueIterationSolver implements MDPSolver {

    /** Hard cap on sweeps over the state space. */
    public int maxIterations = 10;

    /** Convergence threshold on max |U - U'|. */
    public final double DEFAULT_EPS = 0.1;

    /** Discount factor. */
    public final double GAMMA = 0.9; //discount

    private DecimalFormat fmt = new DecimalFormat("##.00");

    /**
     * Runs value iteration over the MDP's (score, tiles) state grid.
     *
     * @param mdp the process to solve
     * @return the computed policy — NOTE: policy extraction from the
     *         utility matrix is still unimplemented; an empty Policy is
     *         returned (pre-existing TODO).
     */
    public Policy solve(MDP mdp) {
        Policy policy = new Policy();
        double[][] utility = new double[mdp.getMaxScore() + 1][mdp.getMaxTiles() + 1];
        double[][] utilityPrime = new double[mdp.getMaxScore() + 1][mdp.getMaxTiles() + 1];

        // Initialise U' with the immediate rewards.
        for (int i = 0; i <= mdp.getMaxScore(); i++) {
            for (int j = 0; j <= mdp.getMaxTiles(); j++) {
                utilityPrime[i][j] = mdp.getReward(i, j);
            }
        }

        for (int iteration = 0; iteration < maxIterations; iteration++) {
            // U <- U' before the next sweep.
            for (int i = 0; i <= mdp.getMaxScore(); i++) {
                for (int j = 0; j <= mdp.getMaxTiles(); j++) {
                    utility[i][j] = utilityPrime[i][j];
                }
            }
            // Bellman backup for every state.
            for (int i = 0; i <= mdp.getMaxScore(); i++) {
                for (int j = 0; j <= mdp.getMaxTiles(); j++) {
                    Action[] actions = mdp.getActions(i, j);
                    // Terminal states (no actions) contribute no future value.
                    double aMax;
                    if (actions.length > 0) {
                        aMax = Double.NEGATIVE_INFINITY;
                    } else {
                        aMax = 0;
                    }
                    for (Action action : actions) {
                        List<Transition> transitions = getTransitions(action, mdp, i, j);
                        double aSum = 0.0;
                        for (Transition transition : transitions) {
                            int transI = transition.getScoreChange();
                            int transJ = transition.getTileCountChange();
                            // Ignore transitions that would leave the grid.
                            if (i + transI >= 0 && i + transI <= mdp.getMaxScore()
                                    && j + transJ >= 0 && j + transJ <= mdp.getMaxTiles()) {
                                // FIX: weight each successor utility by its
                                // transition probability (Bellman equation);
                                // the probability factor was missing before.
                                aSum += transition.getProb() * utility[i + transI][j + transJ];
                            }
                        }
                        if (aSum > aMax) {
                            aMax = aSum;
                        }
                    }
                    utilityPrime[i][j] = mdp.getReward(i, j) + GAMMA * aMax;
                }
            }
            double maxDiff = getMaxDiff(utility, utilityPrime);
            System.out.println("Max diff |U - U'| = " + maxDiff);
            if (maxDiff < DEFAULT_EPS) {
                System.out.println("Solution to MDP converged: " + maxDiff);
                break; // labeled break was redundant — plain break exits this loop
            }
        }

        // Debug dump of the final utility matrix.
        for (int i = 0; i < utility.length; i++) {
            StringBuilder sb = new StringBuilder();
            for (int j = 0; j < utility[i].length; j++) {
                sb.append(fmt.format(utility[i][j]));
                sb.append(" ");
            }
            System.out.println(sb);
        }
        //utility is now the utility Matrix
        //get the policy
        return policy;
    }

    /** Returns max over all states of |u - uPrime|, the convergence measure. */
    double getMaxDiff(double[][] u, double[][] uPrime) {
        double maxDiff = 0;
        for (int i = 0; i < u.length; i++) {
            for (int j = 0; j < u[i].length; j++) {
                maxDiff = Math.max(maxDiff, Math.abs(u[i][j] - uPrime[i][j]));
            }
        }
        return maxDiff;
    }

    /**
     * Transition model for (action, state). The state arguments are
     * currently unused.
     *
     * <p>NOTE(review): playToWin and playToLose currently share an
     * identical transition model, which makes the action choice
     * meaningless — presumably a placeholder; confirm the intended
     * per-action probabilities.
     */
    private List<Transition> getTransitions(Action action, MDP mdp, int score, int tiles) {
        List<Transition> transitions = new ArrayList<Transition>();
        if (Action.playToWin == action) {
            transitions.add(new Transition(0.9, 1, 1));
            transitions.add(new Transition(0.1, 1, -3));
        } else if (Action.playToLose == action) {
            transitions.add(new Transition(0.9, 1, 1));
            transitions.add(new Transition(0.1, 1, -3));
        }
        return transitions;
    }
}