states() {
+ return states;
+ }
+
@Override
public S getInitialState() {
	// The MDP's designated start state, as supplied at construction time.
	return initialState;
}
+
+ @Override
+ public Set actions(S s) {
+ return actionsFunction.actions(s);
+ }
+
@Override
public double transitionProbability(S sDelta, S s, A a) {
	// P(s' | s, a): probability of reaching sDelta from s when taking action
	// a, delegated to the pluggable TransitionProbabilityFunction.
	return transitionProbabilityFunction.probability(sDelta, s, a);
}
+
@Override
public double reward(S s) {
	// R(s): immediate reward for being in state s, delegated to the
	// pluggable RewardFunction.
	return rewardFunction.reward(s);
}
+
+ // END-MarkovDecisionProcess
+ //
+}
diff --git a/src/aima/core/probability/mdp/impl/ModifiedPolicyEvaluation.java b/src/aima/core/probability/mdp/impl/ModifiedPolicyEvaluation.java
new file mode 100644
index 0000000..a910c94
--- /dev/null
+++ b/src/aima/core/probability/mdp/impl/ModifiedPolicyEvaluation.java
@@ -0,0 +1,93 @@
+package aima.core.probability.mdp.impl;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import aima.core.agent.Action;
+import aima.core.probability.mdp.MarkovDecisionProcess;
+import aima.core.probability.mdp.PolicyEvaluation;
+
+/**
+ * Artificial Intelligence A Modern Approach (3rd Edition): page 657.
+ *
+ * For small state spaces, policy evaluation using exact solution methods is
+ * often the most efficient approach. For large state spaces, O(n3)
+ * time might be prohibitive. Fortunately, it is not necessary to do exact
+ * policy evaluation. Instead, we can perform some number of simplified value
+ * iteration steps (simplified because the policy is fixed) to give a reasonably
+ * good approximation of utilities. The simplified Bellman update for this
+ * process is:
+ *
+ *
+ *
+ * Ui+1(s) <- R(s) + γΣs'P(s'|s,πi(s))Ui(s')
+ *
+ *
+ * and this is repeated k times to produce the next utility estimate. The
+ * resulting algorithm is called modified policy iteration. It is often
+ * much more efficient than standard policy iteration or value iteration.
+ *
+ *
+ * @param
+ * the state type.
+ * @param
+ * the action type.
+ *
+ * @author Ciaran O'Reilly
+ * @author Ravi Mohan
+ *
+ */
+public class ModifiedPolicyEvaluation implements PolicyEvaluation {
+ // # iterations to use to produce the next utility estimate
+ private int k;
+ // discount γ to be used.
+ private double gamma;
+
+ /**
+ * Constructor.
+ *
+ * @param k
+ * number iterations to use to produce the next utility estimate
+ * @param gamma
+ * discount γ to be used
+ */
+ public ModifiedPolicyEvaluation(int k, double gamma) {
+ if (gamma > 1.0 || gamma <= 0.0) {
+ throw new IllegalArgumentException("Gamma must be > 0 and <= 1.0");
+ }
+ this.k = k;
+ this.gamma = gamma;
+ }
+
+ //
+ // START-PolicyEvaluation
+ @Override
+ public Map evaluate(Map pi_i, Map U,
+ MarkovDecisionProcess mdp) {
+ Map U_i = new HashMap(U);
+ Map U_ip1 = new HashMap(U);
+ // repeat k times to produce the next utility estimate
+ for (int i = 0; i < k; i++) {
+ // Ui+1(s) <- R(s) +
+ // γΣs'P(s'|s,πi(s))Ui(s')
+ for (S s : U.keySet()) {
+ A ap_i = pi_i.get(s);
+ double aSum = 0;
+ // Handle terminal states (i.e. no actions)
+ if (null != ap_i) {
+ for (S sDelta : U.keySet()) {
+ aSum += mdp.transitionProbability(sDelta, s, ap_i)
+ * U_i.get(sDelta);
+ }
+ }
+ U_ip1.put(s, mdp.reward(s) + gamma * aSum);
+ }
+
+ U_i.putAll(U_ip1);
+ }
+ return U_ip1;
+ }
+
+ // END-PolicyEvaluation
+ //
+}
diff --git a/src/aima/core/probability/mdp/search/PolicyIteration.java b/src/aima/core/probability/mdp/search/PolicyIteration.java
new file mode 100644
index 0000000..8d2692d
--- /dev/null
+++ b/src/aima/core/probability/mdp/search/PolicyIteration.java
@@ -0,0 +1,144 @@
+package aima.core.probability.mdp.search;
+
+import java.util.ArrayList;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+
+import aima.core.agent.Action;
+import aima.core.probability.mdp.MarkovDecisionProcess;
+import aima.core.probability.mdp.Policy;
+import aima.core.probability.mdp.PolicyEvaluation;
+import aima.core.probability.mdp.impl.LookupPolicy;
+import aima.core.util.Util;
+
+/**
+ * Artificial Intelligence A Modern Approach (3rd Edition): page 657.
+ *
+ *
+ *
+ * function POLICY-ITERATION(mdp) returns a policy
+ * inputs: mdp, an MDP with states S, actions A(s), transition model P(s' | s, a)
+ * local variables: U, a vector of utilities for states in S, initially zero
+ * π, a policy vector indexed by state, initially random
+ *
+ * repeat
+ * U <- POLICY-EVALUATION(π, U, mdp)
+ * unchanged? <- true
+ * for each state s in S do
+ * if maxa ∈ A(s) Σs'P(s'|s,a)U[s'] > Σs'P(s'|s,π[s])U[s'] then do
+ * π[s] <- argmaxa ∈ A(s) Σs'P(s'|s,a)U[s']
+ * unchanged? <- false
+ * until unchanged?
+ * return π
+ *
+ *
+ * Figure 17.7 The policy iteration algorithm for calculating an optimal policy.
+ *
+ * @param
+ * the state type.
+ * @param
+ * the action type.
+ *
+ * @author Ciaran O'Reilly
+ * @author Ravi Mohan
+ *
+ */
+public class PolicyIteration {
+
+ private PolicyEvaluation policyEvaluation = null;
+
+ /**
+ * Constructor.
+ *
+ * @param policyEvaluation
+ * the policy evaluation function to use.
+ */
+ public PolicyIteration(PolicyEvaluation policyEvaluation) {
+ this.policyEvaluation = policyEvaluation;
+ }
+
+ // function POLICY-ITERATION(mdp) returns a policy
+ /**
+ * The policy iteration algorithm for calculating an optimal policy.
+ *
+ * @param mdp
+ * an MDP with states S, actions A(s), transition model P(s'|s,a)
+ * @return an optimal policy
+ */
+ public Policy policyIteration(MarkovDecisionProcess mdp) {
+ // local variables: U, a vector of utilities for states in S, initially
+ // zero
+ Map U = Util.create(mdp.states(), new Double(0));
+ // π, a policy vector indexed by state, initially random
+ Map pi = initialPolicyVector(mdp);
+ boolean unchanged;
+ // repeat
+ do {
+ // U <- POLICY-EVALUATION(π, U, mdp)
+ U = policyEvaluation.evaluate(pi, U, mdp);
+ // unchanged? <- true
+ unchanged = true;
+ // for each state s in S do
+ for (S s : mdp.states()) {
+ // calculate:
+ // maxa ∈ A(s)
+ // Σs'P(s'|s,a)U[s']
+ double aMax = Double.NEGATIVE_INFINITY, piVal = 0;
+ A aArgmax = pi.get(s);
+ for (A a : mdp.actions(s)) {
+ double aSum = 0;
+ for (S sDelta : mdp.states()) {
+ aSum += mdp.transitionProbability(sDelta, s, a)
+ * U.get(sDelta);
+ }
+ if (aSum > aMax) {
+ aMax = aSum;
+ aArgmax = a;
+ }
+ // track:
+ // Σs'P(s'|s,π[s])U[s']
+ if (a.equals(pi.get(s))) {
+ piVal = aSum;
+ }
+ }
+ // if maxa ∈ A(s)
+ // Σs'P(s'|s,a)U[s']
+ // > Σs'P(s'|s,π[s])U[s'] then do
+ if (aMax > piVal) {
+ // π[s] <- argmaxa ∈A(s)
+ // Σs'P(s'|s,a)U[s']
+ pi.put(s, aArgmax);
+ // unchanged? <- false
+ unchanged = false;
+ }
+ }
+ // until unchanged?
+ } while (!unchanged);
+
+ // return π
+ return new LookupPolicy(pi);
+ }
+
+ /**
+ * Create a policy vector indexed by state, initially random.
+ *
+ * @param mdp
+ * an MDP with states S, actions A(s), transition model P(s'|s,a)
+ * @return a policy vector indexed by state, initially random.
+ */
+ public static Map initialPolicyVector(
+ MarkovDecisionProcess mdp) {
+ Map pi = new LinkedHashMap();
+ List actions = new ArrayList();
+ for (S s : mdp.states()) {
+ actions.clear();
+ actions.addAll(mdp.actions(s));
+ // Handle terminal states (i.e. no actions).
+ if (actions.size() > 0) {
+ pi.put(s, Util.selectRandomlyFromList(actions));
+ }
+ }
+ return pi;
+ }
+}
diff --git a/src/aima/core/probability/mdp/search/ValueIteration.java b/src/aima/core/probability/mdp/search/ValueIteration.java
new file mode 100644
index 0000000..3c577c9
--- /dev/null
+++ b/src/aima/core/probability/mdp/search/ValueIteration.java
@@ -0,0 +1,129 @@
+package aima.core.probability.mdp.search;
+
+import java.util.Map;
+import java.util.Set;
+
+import aima.core.agent.Action;
+import aima.core.probability.mdp.MarkovDecisionProcess;
+import aima.core.util.Util;
+
+/**
+ * Artificial Intelligence A Modern Approach (3rd Edition): page 653.
+ *
+ *
+ *
+ * function VALUE-ITERATION(mdp, ε) returns a utility function
+ * inputs: mdp, an MDP with states S, actions A(s), transition model P(s' | s, a),
+ * rewards R(s), discount γ
+ * ε the maximum error allowed in the utility of any state
+ * local variables: U, U', vectors of utilities for states in S, initially zero
+ * δ the maximum change in the utility of any state in an iteration
+ *
+ * repeat
+ * U <- U'; δ <- 0
+ * for each state s in S do
+ * U'[s] <- R(s) + γ maxa ∈ A(s) Σs'P(s' | s, a) U[s']
+ * if |U'[s] - U[s]| > δ then δ <- |U'[s] - U[s]|
+ * until δ < ε(1 - γ)/γ
+ * return U
+ *
+ *
+ * Figure 17.4 The value iteration algorithm for calculating utilities of
+ * states. The termination condition is from Equation (17.8):
+ *
+ *
+ * if ||Ui+1 - Ui|| < ε(1 - γ)/γ then ||Ui+1 - U|| < ε
+ *
+ *
+ * @param
+ * the state type.
+ * @param
+ * the action type.
+ *
+ * @author Ciaran O'Reilly
+ * @author Ravi Mohan
+ *
+ */
+public class ValueIteration {
+ // discount γ to be used.
+ private double gamma = 0;
+
+ /**
+ * Constructor.
+ *
+ * @param gamma
+ * discount γ to be used.
+ */
+ public ValueIteration(double gamma) {
+ if (gamma > 1.0 || gamma <= 0.0) {
+ throw new IllegalArgumentException("Gamma must be > 0 and <= 1.0");
+ }
+ this.gamma = gamma;
+ }
+
+ // function VALUE-ITERATION(mdp, ε) returns a utility function
+ /**
+ * The value iteration algorithm for calculating the utility of states.
+ *
+ * @param mdp
+ * an MDP with states S, actions A(s),
+ * transition model P(s' | s, a), rewards R(s)
+ * @param epsilon
+ * the maximum error allowed in the utility of any state
+ * @return a vector of utilities for states in S
+ */
+ public Map valueIteration(MarkovDecisionProcess mdp,
+ double epsilon) {
+ //
+ // local variables: U, U', vectors of utilities for states in S,
+ // initially zero
+ Map U = Util.create(mdp.states(), new Double(0));
+ Map Udelta = Util.create(mdp.states(), new Double(0));
+ // δ the maximum change in the utility of any state in an
+ // iteration
+ double delta = 0;
+ // Note: Just calculate this once for efficiency purposes:
+ // ε(1 - γ)/γ
+ double minDelta = epsilon * (1 - gamma) / gamma;
+
+ // repeat
+ do {
+ // U <- U'; δ <- 0
+ U.putAll(Udelta);
+ delta = 0;
+ // for each state s in S do
+ for (S s : mdp.states()) {
+ // maxa ∈ A(s)
+ Set actions = mdp.actions(s);
+ // Handle terminal states (i.e. no actions).
+ double aMax = 0;
+ if (actions.size() > 0) {
+ aMax = Double.NEGATIVE_INFINITY;
+ }
+ for (A a : actions) {
+ // Σs'P(s' | s, a) U[s']
+ double aSum = 0;
+ for (S sDelta : mdp.states()) {
+ aSum += mdp.transitionProbability(sDelta, s, a)
+ * U.get(sDelta);
+ }
+ if (aSum > aMax) {
+ aMax = aSum;
+ }
+ }
+ // U'[s] <- R(s) + γ
+ // maxa ∈ A(s)
+ Udelta.put(s, mdp.reward(s) + gamma * aMax);
+ // if |U'[s] - U[s]| > δ then δ <- |U'[s] - U[s]|
+ double aDiff = Math.abs(Udelta.get(s) - U.get(s));
+ if (aDiff > delta) {
+ delta = aDiff;
+ }
+ }
+ // until δ < ε(1 - γ)/γ
+ } while (delta > minDelta);
+
+ // return U
+ return U;
+ }
+}
diff --git a/src/aima/core/util/Util.java b/src/aima/core/util/Util.java
new file mode 100644
index 0000000..ef04fdd
--- /dev/null
+++ b/src/aima/core/util/Util.java
@@ -0,0 +1,240 @@
+package aima.core.util;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Hashtable;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
/**
 * General-purpose static helper routines (random selection, normalization,
 * simple statistics, string utilities).
 *
 * @author Ravi Mohan
 */
public class Util {
	public static final String NO = "No";
	public static final String YES = "Yes";
	//
	private static Random _r = new Random();

	// Utility class: not meant to be instantiated.
	private Util() {
	}

	/**
	 * Get the first element from a list.
	 *
	 * @param l
	 *            the list the first element is to be extracted from.
	 * @return the first element of the passed in list.
	 */
	public static <T> T first(List<T> l) {
		return l.get(0);
	}

	/**
	 * Get a sublist of all of the elements in the list except for first.
	 *
	 * @param l
	 *            the list the rest of the elements are to be extracted from.
	 * @return a list of all of the elements in the passed in list except for
	 *         the first element.
	 */
	public static <T> List<T> rest(List<T> l) {
		return l.subList(1, l.size());
	}

	/**
	 * Create a Map with the passed in keys having their values initialized to
	 * the passed in value.
	 *
	 * @param keys
	 *            the keys for the newly constructed map.
	 * @param value
	 *            the value to be associated with each of the maps keys.
	 * @return a map with the passed in keys initialized to value.
	 */
	public static <K, V> Map<K, V> create(Collection<K> keys, V value) {
		Map<K, V> map = new LinkedHashMap<K, V>();

		for (K k : keys) {
			map.put(k, value);
		}

		return map;
	}

	/**
	 * Randomly select an element from a list.
	 *
	 * @param <T>
	 *            the type of element to be returned from the list l.
	 * @param l
	 *            a list of type T from which an element is to be selected
	 *            randomly.
	 * @return a randomly selected element from l.
	 */
	public static <T> T selectRandomlyFromList(List<T> l) {
		return l.get(_r.nextInt(l.size()));
	}

	/** @return true or false with equal probability. */
	public static boolean randomBoolean() {
		int trueOrFalse = _r.nextInt(2);
		return (!(trueOrFalse == 0));
	}

	/**
	 * Normalize a probability distribution so its entries sum to 1.
	 *
	 * @param probDist
	 *            the (non-negative) values to normalize.
	 * @return a new array of the values divided by their total; all zeros if
	 *         the total is 0.
	 */
	public static double[] normalize(double[] probDist) {
		int len = probDist.length;
		double total = 0.0;
		for (double d : probDist) {
			total = total + d;
		}

		double[] normalized = new double[len];
		if (total != 0) {
			for (int i = 0; i < len; i++) {
				normalized[i] = probDist[i] / total;
			}
		}

		return normalized;
	}

	/** List variant of {@link #normalize(double[])}. */
	public static List<Double> normalize(List<Double> values) {
		double[] valuesAsArray = new double[values.size()];
		for (int i = 0; i < valuesAsArray.length; i++) {
			valuesAsArray[i] = values.get(i);
		}
		double[] normalized = normalize(valuesAsArray);
		List<Double> results = new ArrayList<Double>();
		for (int i = 0; i < normalized.length; i++) {
			results.add(normalized[i]);
		}
		return results;
	}

	public static int min(int i, int j) {
		return (i > j ? j : i);
	}

	public static int max(int i, int j) {
		return (i < j ? j : i);
	}

	public static int max(int i, int j, int k) {
		return max(max(i, j), k);
	}

	public static int min(int i, int j, int k) {
		return min(min(i, j), k);
	}

	/**
	 * The most frequently occurring element of a list.
	 *
	 * @param l
	 *            a non-empty list (throws on empty input).
	 * @return the element with the highest occurrence count (ties broken by
	 *         hash iteration order).
	 */
	public static <T> T mode(List<T> l) {
		Hashtable<T, Integer> hash = new Hashtable<T, Integer>();
		for (T obj : l) {
			if (hash.containsKey(obj)) {
				hash.put(obj, hash.get(obj).intValue() + 1);
			} else {
				hash.put(obj, 1);
			}
		}

		T maxkey = hash.keySet().iterator().next();
		for (T key : hash.keySet()) {
			if (hash.get(key) > hash.get(maxkey)) {
				maxkey = key;
			}
		}
		return maxkey;
	}

	public static String[] yesno() {
		return new String[] { YES, NO };
	}

	/** Base-2 logarithm. */
	public static double log2(double d) {
		return Math.log(d) / Math.log(2);
	}

	/** Shannon information (entropy) of a probability distribution, in bits. */
	public static double information(double[] probabilities) {
		double total = 0.0;
		for (double d : probabilities) {
			total += (-1.0 * log2(d) * d);
		}
		return total;
	}

	/** A copy of list with the first occurrence of member removed. */
	public static <T> List<T> removeFrom(List<T> list, T member) {
		List<T> newList = new ArrayList<T>(list);
		newList.remove(member);
		return newList;
	}

	/** Sum of the squares of the numeric list's elements. */
	public static <T extends Number> double sumOfSquares(List<T> list) {
		double accum = 0;
		for (T item : list) {
			accum = accum + (item.doubleValue() * item.doubleValue());
		}
		return accum;
	}

	/** The string s repeated n times. */
	public static String ntimes(String s, int n) {
		StringBuilder buf = new StringBuilder();
		for (int i = 0; i < n; i++) {
			buf.append(s);
		}
		return buf.toString();
	}

	/**
	 * @throws RuntimeException
	 *             if d is NaN or infinite.
	 */
	public static void checkForNanOrInfinity(double d) {
		if (Double.isNaN(d)) {
			throw new RuntimeException("Not a Number");
		}
		if (Double.isInfinite(d)) {
			throw new RuntimeException("Infinite Number");
		}
	}

	/** A uniformly random int between i and j, both inclusive. */
	public static int randomNumberBetween(int i, int j) {
		/* i, j both inclusive */
		return _r.nextInt(j - i + 1) + i;
	}

	/** Arithmetic mean of the list (NaN for an empty list). */
	public static double calculateMean(List<Double> lst) {
		double sum = 0.0;
		for (Double d : lst) {
			sum = sum + d.doubleValue();
		}
		return sum / lst.size();
	}

	/**
	 * Sample standard deviation about the supplied mean. Assumes at least 2
	 * members in the list (divides by listSize - 1).
	 */
	public static double calculateStDev(List<Double> values, double mean) {

		int listSize = values.size();

		double sumOfDiffSquared = 0.0;
		for (Double value : values) {
			double diffFromMean = value - mean;
			// division done per-term to avoid the running sum becoming too
			// big; if this doesn't work use an incremental formulation
			sumOfDiffSquared += ((diffFromMean * diffFromMean) / (listSize - 1));
		}
		double variance = sumOfDiffSquared;
		return Math.sqrt(variance);
	}

	/** Z-score normalize the values: (x - mean) / stdev for each element. */
	public static List<Double> normalizeFromMeanAndStdev(List<Double> values,
			double mean, double stdev) {
		List<Double> normalized = new ArrayList<Double>();
		for (Double d : values) {
			normalized.add((d - mean) / stdev);
		}
		return normalized;
	}

	/** A uniformly random double in [lowerLimit, upperLimit). */
	public static double generateRandomDoubleBetween(double lowerLimit,
			double upperLimit) {

		return lowerLimit + ((upperLimit - lowerLimit) * _r.nextDouble());
	}
}
\ No newline at end of file
diff --git a/src/model/comPlayer/AdaptiveComPlayer.java b/src/model/comPlayer/AdaptiveComPlayer.java
new file mode 100644
index 0000000..5413f1b
--- /dev/null
+++ b/src/model/comPlayer/AdaptiveComPlayer.java
@@ -0,0 +1,123 @@
+package model.comPlayer;
+
+import model.Board;
+import model.BoardScorer;
+import model.Move;
+import model.comPlayer.generator.AlphaBetaMoveGenerator;
+import model.comPlayer.generator.MonteCarloMoveGenerator;
+import model.comPlayer.generator.MoveGenerator;
+import model.playerModel.GameGoal;
+import model.playerModel.PlayerModel;
+import aima.core.environment.gridworld.GridCell;
+import aima.core.environment.gridworld.GridWorld;
+import aima.core.environment.gridworld.GridWorldAction;
+import aima.core.environment.gridworld.GridWorldFactory;
+import aima.core.probability.example.MDPFactory;
+import aima.core.probability.mdp.MarkovDecisionProcess;
+import aima.core.probability.mdp.Policy;
+import aima.core.probability.mdp.PolicyEvaluation;
+import aima.core.probability.mdp.impl.ModifiedPolicyEvaluation;
+import aima.core.probability.mdp.search.PolicyIteration;
+
+public class AdaptiveComPlayer implements Player {
+ private final MoveGenerator abMoveGenerator = new AlphaBetaMoveGenerator();
+ private final MoveGenerator mcMoveGenerator = new MonteCarloMoveGenerator();
+
+ private BoardScorer boardScorer = new BoardScorer();
+ private boolean calculatePolicy = true;
+ private GameGoal target = null;
+ private GridWorld gw = null;
+ private MarkovDecisionProcess, GridWorldAction> mdp = null;
+ private Policy, GridWorldAction> policy = null;
+ private PolicyIteration, GridWorldAction> pi = null;
+
+ @Override
+ public void denyMove() {
+ throw new UnsupportedOperationException("Not implemented");
+ }
+
+ @Override
+ public Move getMove(Board board, PlayerModel player) {
+ if (calculatePolicy) {
+ System.out.println("Calculating policy for PlayerModel: " + player);
+
+ // take 10 turns to place 6 tiles
+ double defaultPenalty = -0.25;
+
+ int maxScore = target.getTargetScore();
+ int maxTiles = Board.NUM_COLS * Board.NUM_ROWS;
+
+ gw = GridWorldFactory.createGridWorldForTileGame(maxTiles,
+ maxScore, defaultPenalty);
+ mdp = MDPFactory.createMDPForTileGame(gw, maxTiles, maxScore);
+
+ // gamma = 1.0
+ PolicyEvaluation, GridWorldAction> pe = new ModifiedPolicyEvaluation, GridWorldAction>(
+ 50, 0.9);
+ pi = new PolicyIteration, GridWorldAction>(pe);
+ policy = pi.policyIteration(mdp);
+
+ System.out.println("Optimum policy calculated.");
+
+ for (int j = maxScore; j >= 1; j--) {
+ StringBuilder sb = new StringBuilder();
+ for (int i = 1; i <= maxTiles; i++) {
+ sb.append(policy.action(gw.getCellAt(i, j)));
+ sb.append(" ");
+ }
+ System.out.println(sb.toString());
+ }
+
+ calculatePolicy = false;
+ } else {
+ System.out.println("Using pre-calculated policy");
+ }
+
+ GridCell state = getState(board);
+ GridWorldAction action = policy.action(state);
+
+ if (action == null || state == null) {
+ System.out
+ .println("Board state outside of parameters of MDP. Reverting to failsafe behavior.");
+ action = GridWorldAction.RandomMove;
+ }
+ System.out.println("Performing action " + action + " at state " + state
+ + " per policy.");
+ switch (action) {
+ case AddTile:
+ // System.out.println("Performing action #" +
+ // GridWorldAction.AddTile.ordinal());
+ return abMoveGenerator.genMove(board, false);
+ case CaptureThree:
+ // System.out.println("Performing action #" +
+ // GridWorldAction.CaptureThree.ordinal());
+ return mcMoveGenerator.genMove(board, false);
+ case RandomMove:
+ // System.out.println("Performing action #" +
+ // GridWorldAction.None.ordinal());
+ return mcMoveGenerator.genMove(board, false);
+ default:
+ // System.out.println("Performing failsafe action");
+ return mcMoveGenerator.genMove(board, false);
+ }
+ }
+
+ private GridCell getState(Board board) {
+ return gw.getCellAt(board.getTurn(), boardScorer.getScore(board));
+ }
+
+ @Override
+ public boolean isReady() {
+ return true; // always ready to play a random valid move
+ }
+
+ @Override
+ public String toString() {
+ return "Adaptive ComPlayer";
+ }
+
+ @Override
+ public void setGameGoal(GameGoal target) {
+ this.target = target;
+ }
+}
\ No newline at end of file
diff --git a/src/model/mdp/Action.java b/src/model/mdp/Action.java
new file mode 100644
index 0000000..cda7b6d
--- /dev/null
+++ b/src/model/mdp/Action.java
@@ -0,0 +1,18 @@
+package model.mdp;
+
/**
 * A named action in the tile-game MDP. Two canonical actions are provided as
 * shared constants.
 */
public class Action {
	// Made final: these shared constants were previously mutable public
	// statics, which allowed accidental reassignment.
	public static final Action playToWin = new Action("PlayToWin");
	public static final Action playToLose = new Action("PlayToLose");
	// public static Action maintainScore = new Action();

	private final String name;

	/**
	 * @param name
	 *            human-readable action name, used by {@link #toString()}.
	 */
	public Action(String name) {
		this.name = name;
	}

	@Override
	public String toString() {
		return name;
	}
}
diff --git a/src/model/mdp/MDP.java b/src/model/mdp/MDP.java
new file mode 100644
index 0000000..fe534b7
--- /dev/null
+++ b/src/model/mdp/MDP.java
@@ -0,0 +1,51 @@
+package model.mdp;
+
+public class MDP {
+ public static final double nonTerminalReward = -0.25;
+
+ public enum MODE {
+ CEIL, FLOOR
+ }
+
+ private final int maxScore;
+ private final int maxTiles;
+ private final MODE mode;
+
+ public MDP(int maxScore, int maxTiles, MODE mode) {
+ this.maxScore = maxScore;
+ this.maxTiles = maxTiles;
+ this.mode = mode;
+ }
+
+ public Action[] getActions(int i, int j) {
+ if (i == maxScore) {
+ return new Action[0];
+ }
+ if (j == maxTiles) {
+ return new Action[0];
+ }
+ return new Action[]{Action.playToLose,Action.playToWin};
+ }
+
+ public int getMaxScore() {
+ return maxScore;
+ }
+
+ public int getMaxTiles() {
+ return maxTiles;
+ }
+
+ public double getReward(int score, int tiles) {
+ if (score == maxScore && tiles == maxTiles) {
+ return 10.0;
+ }
+ // TODO scale linearly?
+ if (score == maxScore) {
+ return -1.0;
+ }
+ if (tiles == maxTiles) {
+ return -5.0;
+ }
+ return nonTerminalReward;
+ }
+}
\ No newline at end of file
diff --git a/src/model/mdp/MDPSolver.java b/src/model/mdp/MDPSolver.java
new file mode 100644
index 0000000..812fed2
--- /dev/null
+++ b/src/model/mdp/MDPSolver.java
@@ -0,0 +1,5 @@
+package model.mdp;
+
/**
 * Strategy interface for algorithms that compute a {@link Policy} for a given
 * {@link MDP} (e.g. value iteration).
 */
public interface MDPSolver {
	/**
	 * Solve the given MDP.
	 *
	 * @param mdp
	 *            the Markov decision process to solve.
	 * @return the computed policy.
	 */
	Policy solve(MDP mdp);
}
diff --git a/src/model/mdp/Policy.java b/src/model/mdp/Policy.java
new file mode 100644
index 0000000..66b9b0c
--- /dev/null
+++ b/src/model/mdp/Policy.java
@@ -0,0 +1,7 @@
+package model.mdp;
+
+import java.util.ArrayList;
+
+public class Policy extends ArrayList{
+
+}
diff --git a/src/model/mdp/Transition.java b/src/model/mdp/Transition.java
new file mode 100644
index 0000000..5148b8f
--- /dev/null
+++ b/src/model/mdp/Transition.java
@@ -0,0 +1,34 @@
+package model.mdp;
+
/**
 * A single stochastic transition in the tile-game MDP: with probability
 * {@code prob} the score changes by {@code scoreChange} and the tile count by
 * {@code tileCountChange}.
 */
public class Transition {
	private double prob;
	private int scoreChange;
	private int tileCountChange;

	/**
	 * @param prob
	 *            probability of this transition occurring.
	 * @param scoreChange
	 *            delta applied to the score.
	 * @param tileCountChange
	 *            delta applied to the tile count.
	 */
	public Transition(double prob, int scoreChange, int tileCountChange) {
		this.prob = prob;
		this.scoreChange = scoreChange;
		this.tileCountChange = tileCountChange;
	}

	public double getProb() {
		return prob;
	}

	public void setProb(double prob) {
		this.prob = prob;
	}

	public int getScoreChange() {
		return scoreChange;
	}

	public void setScoreChange(int scoreChange) {
		this.scoreChange = scoreChange;
	}

	public int getTileCountChange() {
		return tileCountChange;
	}

	public void setTileCountChange(int tileCountChange) {
		this.tileCountChange = tileCountChange;
	}
}
\ No newline at end of file
diff --git a/src/model/mdp/ValueIterationSolver.java b/src/model/mdp/ValueIterationSolver.java
new file mode 100644
index 0000000..35e9d87
--- /dev/null
+++ b/src/model/mdp/ValueIterationSolver.java
@@ -0,0 +1,110 @@
+package model.mdp;
+
+import java.text.DecimalFormat;
+import java.util.ArrayList;
+import java.util.List;
+
+public class ValueIterationSolver implements MDPSolver {
+ public int maxIterations = 10;
+ public final double DEFAULT_EPS = 0.1;
+ public final double GAMMA = 0.9; //discount
+
+ private DecimalFormat fmt = new DecimalFormat("##.00");
+ public Policy solve(MDP mdp) {
+ Policy policy = new Policy();
+
+ double[][] utility = new double[mdp.getMaxScore()+1][mdp.getMaxTiles()+1];
+ double[][] utilityPrime = new double[mdp.getMaxScore()+1][mdp.getMaxTiles()+1];
+
+ for (int i = 0; i <= mdp.getMaxScore(); i++) {
+ //StringBuilder sb = new StringBuilder();
+ for (int j = 0; j <= mdp.getMaxTiles(); j++) {
+ utilityPrime[i][j] = mdp.getReward(i, j);
+ //sb.append(fmt.format(utility[i][j]));
+ //sb.append(" ");
+ }
+ //System.out.println(sb);
+ }
+
+ converged:
+ for (int iteration = 0; iteration < maxIterations; iteration++) {
+ for (int i = 0; i <= mdp.getMaxScore(); i++) {
+ for (int j = 0; j <= mdp.getMaxTiles(); j++) {
+ utility[i][j] = utilityPrime[i][j];
+ }
+ }
+ for (int i = 0; i <= mdp.getMaxScore(); i++) {
+ for (int j = 0; j <= mdp.getMaxTiles(); j++) {
+ Action[] actions = mdp.getActions(i,j);
+
+ double aMax;
+ if (actions.length > 0) {
+ aMax = Double.NEGATIVE_INFINITY;
+ } else {
+ aMax = 0;
+ }
+
+ for (Action action : actions){
+ List transitions = getTransitions(action,mdp,i,j);
+ double aSum = 0.0;
+ for (Transition transition : transitions) {
+ int transI = transition.getScoreChange();
+ int transJ = transition.getTileCountChange();
+ if (i+transI >= 0 && i+transI <= mdp.getMaxScore()
+ && j+transJ >= 0 && j+transJ <= mdp.getMaxTiles())
+ aSum += utility[i+transI][j+transJ];
+ }
+ if (aSum > aMax) {
+ aMax = aSum;
+ }
+ }
+ utilityPrime[i][j] = mdp.getReward(i,j) + GAMMA * aMax;
+ }
+ }
+ double maxDiff = getMaxDiff(utility,utilityPrime);
+ System.out.println("Max diff |U - U'| = " + maxDiff);
+ if (maxDiff < DEFAULT_EPS) {
+ System.out.println("Solution to MDP converged: " + maxDiff);
+ break converged;
+ }
+ }
+
+ for (int i = 0; i < utility.length; i++) {
+ StringBuilder sb = new StringBuilder();
+ for (int j = 0; j < utility[i].length; j++) {
+ sb.append(fmt.format(utility[i][j]));
+ sb.append(" ");
+ }
+ System.out.println(sb);
+ }
+
+ //utility is now the utility Matrix
+ //get the policy
+ return policy;
+ }
+
+ double getMaxDiff(double[][]u, double[][]uPrime) {
+ double maxDiff = 0;
+ for (int i = 0; i < u.length; i++) {
+ for (int j = 0; j < u[i].length; j++) {
+ maxDiff = Math.max(maxDiff,Math.abs(u[i][j] - uPrime[i][j]));
+ }
+ }
+ return maxDiff;
+ }
+
+ private List getTransitions(Action action, MDP mdp, int score, int tiles) {
+ List transitions = new ArrayList();
+ if (Action.playToWin == action) {
+ transitions.add(new Transition(0.9,1,1));
+ transitions.add(new Transition(0.1,1,-3));
+ } else if (Action.playToLose == action) {
+ transitions.add(new Transition(0.9,1,1));
+ transitions.add(new Transition(0.1,1,-3));
+ } /*else if (Action.maintainScore == action) {
+ transitions.add(new Transition(0.5,1,1));
+ transitions.add(new Transition(0.5,1,-3));
+ }*/
+ return transitions;
+ }
+}
\ No newline at end of file
diff --git a/src/view/ParsedArgs.java b/src/view/ParsedArgs.java
index dcdeac5..3dfe4e1 100644
--- a/src/view/ParsedArgs.java
+++ b/src/view/ParsedArgs.java
@@ -1,5 +1,6 @@
package view;
+import model.comPlayer.AdaptiveComPlayer;
import model.comPlayer.AlphaBetaComPlayer;
import model.comPlayer.ComboPlayer;
import model.comPlayer.CountingPlayer;
@@ -10,6 +11,7 @@ import model.comPlayer.Player;
import model.comPlayer.RandomComPlayer;
public class ParsedArgs {
+ public static final String COM_ADAPTIVE = "ADAPTIVE";
public static final String COM_ALPHABETA = "ALPHABETA";
public static final String COM_ANN = "NEURALNET";
public static final String COM_COMBO = "COMBO";
@@ -22,7 +24,9 @@ public class ParsedArgs {
private String comPlayer = COM_DEFAULT;
public Player getComPlayer() {
- if (COM_RANDOM.equalsIgnoreCase(comPlayer)) {
+ if (COM_ADAPTIVE.equalsIgnoreCase(comPlayer)) {
+ return new AdaptiveComPlayer();
+ } else if (COM_RANDOM.equalsIgnoreCase(comPlayer)) {
return new RandomComPlayer();
} else if (COM_MINIMAX.equalsIgnoreCase(comPlayer)) {
return new MinimaxComPlayer();
diff --git a/test/aima/core/probability/mdp/MarkovDecisionProcessTest.java b/test/aima/core/probability/mdp/MarkovDecisionProcessTest.java
new file mode 100644
index 0000000..e266e92
--- /dev/null
+++ b/test/aima/core/probability/mdp/MarkovDecisionProcessTest.java
@@ -0,0 +1,98 @@
+package aima.core.probability.mdp;
+
+import junit.framework.Assert;
+
+import org.junit.Before;
+import org.junit.Test;
+
+import aima.core.environment.cellworld.Cell;
+import aima.core.environment.cellworld.CellWorld;
+import aima.core.environment.cellworld.CellWorldAction;
+import aima.core.environment.cellworld.CellWorldFactory;
+import aima.core.probability.example.MDPFactory;
+import aima.core.probability.mdp.MarkovDecisionProcess;
+
+/**
+ *
+ * @author Ciaran O'Reilly
+ * @author Ravi Mohan
+ *
+ */
+public class MarkovDecisionProcessTest {
+ public static final double DELTA_THRESHOLD = 1e-3;
+
+ private CellWorld cw = null;
+ private MarkovDecisionProcess| , CellWorldAction> mdp = null;
+
+ @Before
+ public void setUp() {
+ cw = CellWorldFactory.createCellWorldForFig17_1();
+ mdp = MDPFactory.createMDPForFigure17_3(cw);
+ }
+
+ @Test
+ public void testActions() {
+ // Ensure all actions can be performed in each cell
+ // except for the terminal states.
+ for (Cell s : cw.getCells()) {
+ if (4 == s.getX() && (3 == s.getY() || 2 == s.getY())) {
+ Assert.assertEquals(0, mdp.actions(s).size());
+ } else {
+ Assert.assertEquals(5, mdp.actions(s).size());
+ }
+ }
+ }
+
+ @Test
+ public void testMDPTransitionModel() {
+ Assert.assertEquals(0.8, mdp.transitionProbability(cw.getCellAt(1, 2),
+ cw.getCellAt(1, 1), CellWorldAction.Up), DELTA_THRESHOLD);
+ Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(1, 1),
+ cw.getCellAt(1, 1), CellWorldAction.Up), DELTA_THRESHOLD);
+ Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(2, 1),
+ cw.getCellAt(1, 1), CellWorldAction.Up), DELTA_THRESHOLD);
+ Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(1, 3),
+ cw.getCellAt(1, 1), CellWorldAction.Up), DELTA_THRESHOLD);
+
+ Assert.assertEquals(0.9, mdp.transitionProbability(cw.getCellAt(1, 1),
+ cw.getCellAt(1, 1), CellWorldAction.Down), DELTA_THRESHOLD);
+ Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(2, 1),
+ cw.getCellAt(1, 1), CellWorldAction.Down), DELTA_THRESHOLD);
+ Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(3, 1),
+ cw.getCellAt(1, 1), CellWorldAction.Down), DELTA_THRESHOLD);
+ Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(1, 2),
+ cw.getCellAt(1, 1), CellWorldAction.Down), DELTA_THRESHOLD);
+
+ Assert.assertEquals(0.9, mdp.transitionProbability(cw.getCellAt(1, 1),
+ cw.getCellAt(1, 1), CellWorldAction.Left), DELTA_THRESHOLD);
+ Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(2, 1),
+ cw.getCellAt(1, 1), CellWorldAction.Left), DELTA_THRESHOLD);
+ Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(3, 1),
+ cw.getCellAt(1, 1), CellWorldAction.Left), DELTA_THRESHOLD);
+ Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(1, 2),
+ cw.getCellAt(1, 1), CellWorldAction.Left), DELTA_THRESHOLD);
+
+ Assert.assertEquals(0.8, mdp.transitionProbability(cw.getCellAt(2, 1),
+ cw.getCellAt(1, 1), CellWorldAction.Right), DELTA_THRESHOLD);
+ Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(1, 1),
+ cw.getCellAt(1, 1), CellWorldAction.Right), DELTA_THRESHOLD);
+ Assert.assertEquals(0.1, mdp.transitionProbability(cw.getCellAt(1, 2),
+ cw.getCellAt(1, 1), CellWorldAction.Right), DELTA_THRESHOLD);
+ Assert.assertEquals(0.0, mdp.transitionProbability(cw.getCellAt(1, 3),
+ cw.getCellAt(1, 1), CellWorldAction.Right), DELTA_THRESHOLD);
+ }
+
+ @Test
+ public void testRewardFunction() {
+ // Ensure all actions can be performed in each cell.
+ for (Cell s : cw.getCells()) {
+ if (4 == s.getX() && 3 == s.getY()) {
+ Assert.assertEquals(1.0, mdp.reward(s), DELTA_THRESHOLD);
+ } else if (4 == s.getX() && 2 == s.getY()) {
+ Assert.assertEquals(-1.0, mdp.reward(s), DELTA_THRESHOLD);
+ } else {
+ Assert.assertEquals(-0.04, mdp.reward(s), DELTA_THRESHOLD);
+ }
+ }
+ }
+}
diff --git a/test/aima/core/probability/mdp/PolicyIterationTest.java b/test/aima/core/probability/mdp/PolicyIterationTest.java
new file mode 100644
index 0000000..255f403
--- /dev/null
+++ b/test/aima/core/probability/mdp/PolicyIterationTest.java
@@ -0,0 +1,80 @@
+package aima.core.probability.mdp;
+
+import java.util.Map;
+
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import aima.core.environment.cellworld.Cell;
+import aima.core.environment.cellworld.CellWorld;
+import aima.core.environment.cellworld.CellWorldAction;
+import aima.core.environment.cellworld.CellWorldFactory;
+import aima.core.environment.gridworld.GridCell;
+import aima.core.environment.gridworld.GridWorld;
+import aima.core.environment.gridworld.GridWorldAction;
+import aima.core.environment.gridworld.GridWorldFactory;
+import aima.core.probability.example.MDPFactory;
+import aima.core.probability.mdp.MarkovDecisionProcess;
+import aima.core.probability.mdp.impl.ModifiedPolicyEvaluation;
+import aima.core.probability.mdp.search.PolicyIteration;
+import aima.core.probability.mdp.search.ValueIteration;
+
+/**
+ * @author Ravi Mohan
+ * @author Ciaran O'Reilly
+ *
+ */
+public class PolicyIterationTest {
+ public static final double DELTA_THRESHOLD = 1e-3;
+
+ private GridWorld gw = null;
+ private MarkovDecisionProcess<GridCell<Double>, GridWorldAction> mdp = null;
+ private PolicyIteration<GridCell<Double>, GridWorldAction> pi = null;
+
+ final int maxTiles = 6;
+ final int maxScore = 10;
+
+ @Before
+ public void setUp() {
+ //take 10 turns to place 6 tiles
+ double defaultPenalty = -0.04;
+
+ gw = GridWorldFactory.createGridWorldForTileGame(maxTiles,maxScore,defaultPenalty);
+ mdp = MDPFactory.createMDPForTileGame(gw, maxTiles, maxScore);
+
+ //gamma = 0.9
+ PolicyEvaluation<GridCell<Double>,GridWorldAction> pe = new ModifiedPolicyEvaluation<GridCell<Double>, GridWorldAction>(100,0.9);
+ pi = new PolicyIteration<GridCell<Double>, GridWorldAction>(pe);
+ }
+
+ @Test
+ public void testPolicyIterationForTileGame() {
+ Policy<GridCell<Double>, GridWorldAction> policy = pi.policyIteration(mdp);
+
+ for (int j = maxScore; j >= 1; j--) {
+ StringBuilder sb = new StringBuilder();
+ for (int i = 1; i <= maxTiles; i++) {
+ sb.append(policy.action(gw.getCellAt(i, j)));
+ sb.append(" ");
+ }
+ System.out.println(sb.toString());
+ }
+
+ //Assert.assertEquals(0.705, U.get(gw.getCellAt(1, 1)), DELTA_THRESHOLD);
+ /*
+ Assert.assertEquals(0.762, U.get(cw1.getCellAt(1, 2)), DELTA_THRESHOLD);
+ Assert.assertEquals(0.812, U.get(cw1.getCellAt(1, 3)), DELTA_THRESHOLD);
+
+ Assert.assertEquals(0.655, U.get(cw1.getCellAt(2, 1)), DELTA_THRESHOLD);
+ Assert.assertEquals(0.868, U.get(cw1.getCellAt(2, 3)), DELTA_THRESHOLD);
+
+ Assert.assertEquals(0.611, U.get(cw1.getCellAt(3, 1)), DELTA_THRESHOLD);
+ Assert.assertEquals(0.660, U.get(cw1.getCellAt(3, 2)), DELTA_THRESHOLD);
+ Assert.assertEquals(0.918, U.get(cw1.getCellAt(3, 3)), DELTA_THRESHOLD);
+
+ Assert.assertEquals(0.388, U.get(cw1.getCellAt(4, 1)), DELTA_THRESHOLD);
+ Assert.assertEquals(-1.0, U.get(cw1.getCellAt(4, 2)), DELTA_THRESHOLD);
+ Assert.assertEquals(1.0, U.get(cw1.getCellAt(4, 3)), DELTA_THRESHOLD);*/
+ }
+}
diff --git a/test/aima/core/probability/mdp/ValueIterationTest.java b/test/aima/core/probability/mdp/ValueIterationTest.java
new file mode 100644
index 0000000..9d1215e
--- /dev/null
+++ b/test/aima/core/probability/mdp/ValueIterationTest.java
@@ -0,0 +1,64 @@
+package aima.core.probability.mdp;
+
+import java.util.Map;
+
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import aima.core.environment.cellworld.Cell;
+import aima.core.environment.cellworld.CellWorld;
+import aima.core.environment.cellworld.CellWorldAction;
+import aima.core.environment.cellworld.CellWorldFactory;
+import aima.core.probability.example.MDPFactory;
+import aima.core.probability.mdp.MarkovDecisionProcess;
+import aima.core.probability.mdp.search.ValueIteration;
+
+/**
+ * @author Ravi Mohan
+ * @author Ciaran O'Reilly
+ *
+ */
+public class ValueIterationTest {
+ public static final double DELTA_THRESHOLD = 1e-3;
+
+ private CellWorld<Double> cw = null;
+ private MarkovDecisionProcess |