diff --git a/lib/guava-r09.jar b/lib/guava-r09.jar new file mode 100644 index 0000000..f8da8b1 Binary files /dev/null and b/lib/guava-r09.jar differ diff --git a/src/dkohl/bayes/bayesnet/BayesNet.java b/src/dkohl/bayes/bayesnet/BayesNet.java new file mode 100644 index 0000000..767220f --- /dev/null +++ b/src/dkohl/bayes/bayesnet/BayesNet.java @@ -0,0 +1,45 @@ +package dkohl.bayes.bayesnet; + +import java.util.HashMap; +import java.util.LinkedList; + +import dkohl.bayes.probability.Variable; +import dkohl.bayes.probability.distribution.ProbabilityDistribution; + +/** + * Represents a Bayes net as a graph with a probability table associated with + * each node. + * + * @author Daniel Kohlsdorf + */ +public class BayesNet extends NamedGraph { + + /** + * The probability tables for each node + */ + private HashMap nodes; + private LinkedList variables; + + public BayesNet(String names[]) { + super(names); + this.nodes = new HashMap(); + this.variables = new LinkedList(); + } + + public void setDistribution(Variable node, ProbabilityDistribution dist) { + nodes.put(node.getName(), dist); + variables.add(node); + } + + public void updateDistribution(Variable node, ProbabilityDistribution dist) { + nodes.put(node.getName(), dist); + } + + public HashMap getNodes() { + return nodes; + } + + public LinkedList getVariables() { + return variables; + } +} diff --git a/src/dkohl/bayes/bayesnet/NamedGraph.java b/src/dkohl/bayes/bayesnet/NamedGraph.java new file mode 100644 index 0000000..ef3d435 --- /dev/null +++ b/src/dkohl/bayes/bayesnet/NamedGraph.java @@ -0,0 +1,98 @@ +package dkohl.bayes.bayesnet; + +import java.util.HashMap; +import java.util.LinkedList; + +import com.google.common.base.Preconditions; + +/** + * A Graph: G(V, E) implemented as a |V| x |V| matrix. + * + * Just one node type !!!! + * + * @author Daniel Kohlsdorf + */ +public class NamedGraph { + + /** + * net[i][j]: variable i is connected to j. + */ + private boolean net[][]; + + /** + * Mapping variable names to positions in the graph's matrix. + */ + private HashMap variable2pos; + + /** + * Mapping positions in the graph to variable names. + */ + private HashMap pos2variable; + + /** + * Initializes the graph of size: |VariableNames| x |VariableNames| + * + * @param variableNames + * The names of the variables + */ + public NamedGraph(String variableNames[]) { + variable2pos = new HashMap(); + pos2variable = new HashMap(); + int num_nodes = variableNames.length; + + net = new boolean[num_nodes][num_nodes]; + + for (int i = 0; i < num_nodes; i++) { + variable2pos.put(variableNames[i], i); + pos2variable.put(i, variableNames[i]); + + for (int j = 0; j < num_nodes; j++) { + net[i][j] = false; + } + } + } + + /** + * Connects two existing vertices in the graph. + * + * @param x + * the variable to connect + * @param y + * the parent (or other node for undirected graphs) + */ + public void connect(String x, String y) { + Preconditions.checkArgument(variable2pos.containsKey(x), + "Variable not known: " + x); + Preconditions.checkArgument(variable2pos.containsKey(y), + "Variable not known: " + y); + + int variable_index = variable2pos.get(x); + int bias_index = variable2pos.get(y); + + net[bias_index][variable_index] = true; + } + + /** + * Returns the names of the variable's parents + * + * @param variable + * the target variable + * @return list of variable names + */ + public LinkedList getParents(String variable) { + Preconditions.checkArgument(variable2pos.containsKey(variable), + "Variable not known: " + variable); + + LinkedList parents = new LinkedList(); + int variable_pos = variable2pos.get(variable); + + for (int i = 0; i < net.length; i++) { + if (net[i][variable_pos]) { + String variableName = pos2variable.get(i); + parents.add(variableName); + } + } + return parents; + } + +} diff --git a/src/dkohl/bayes/estimation/MaximumLikelihoodEstimation.java b/src/dkohl/bayes/estimation/MaximumLikelihoodEstimation.java new file mode 100644 index 0000000..4de8762 --- /dev/null +++ b/src/dkohl/bayes/estimation/MaximumLikelihoodEstimation.java @@ -0,0 +1,119 @@ +package dkohl.bayes.estimation; + +import java.util.LinkedList; + +import com.google.common.base.Preconditions; + +import dkohl.bayes.bayesnet.BayesNet; +import dkohl.bayes.probability.Assignment; +import dkohl.bayes.probability.Probability; +import dkohl.bayes.probability.Variable; +import dkohl.bayes.probability.distribution.ContinousDistribution; +import dkohl.bayes.probability.distribution.ProbabilityDistribution; +import dkohl.bayes.probability.distribution.ProbabilityTable; +import dkohl.bayes.probability.distribution.ProbabilityTree; +import dkohl.bayes.statistic.DataSet; + +/** + * Works for fully observable bayes nets + * + * @author Daniel Kohlsdorf + */ +public class MaximumLikelihoodEstimation { + + /** + * Enumerates all possible assignments for a set of variables using depth + * first + back tracking. Estimates the probability for each assignment + * given data and inserts the values in a probability function. + * + * @param target + * The target variable + * @param assignments + * A list built by traversing the tree. Adding an assignment on + * each layer + * @param variables + * All variables to assign + * @param current + * The current variable + * @param data + * the data for estimation + * @param table + * the probability table + */ + private static void enumerate(Assignment target, + LinkedList assignments, LinkedList variables, + int current, DataSet data, ProbabilityDistribution dist) { + + if (assignments.size() == variables.size()) { + if (dist instanceof ProbabilityTable) { + double likelihood = data.prob(target, assignments) + .getProbability(); + assignments.add(target); + ProbabilityTable table = (ProbabilityTable) dist; + table.setProbabilityForAssignment(assignments, new Probability( + likelihood)); + } + if (dist instanceof ProbabilityTree) { + double likelihood = data.prob(target, assignments) + .getProbability(); + assignments.add(target); + + ProbabilityTree tree = (ProbabilityTree) dist; + tree.setProbabilityForAssignment(assignments, new Probability( + likelihood)); + } + if (dist instanceof ContinousDistribution) { + ContinousDistribution pdf = (ContinousDistribution) dist; + for (LinkedList assignment : data + .getAssignmentMatchesForQuery(assignments)) { + pdf.pushAssignment(assignment); + } + } + return; + } + Variable variable = variables.get(assignments.size()); + for (String value : variable.getDomain()) { + LinkedList new_assignments = new LinkedList( + assignments); + new_assignments.add(new Assignment(variable, value)); + enumerate(target, new_assignments, variables, current + 1, data, + dist); + } + } + + public static void estimate(DataSet data, BayesNet net, String targetName) { + Variable target = null; + // Search parent variables + LinkedList parentNames = net.getParents(targetName); + LinkedList allVariables = net.getVariables(); + LinkedList parentVariables = new LinkedList(); + for (Variable variable : allVariables) { + if (parentNames.contains(variable.getName())) { + parentVariables.add(variable); + } + if (variable.getName().equals(targetName)) { + target = variable; + } + } + Preconditions.checkState(target != null, "MLE: variable " + targetName + + "Not in net"); + + /** + * For all assignments of this variable list all assignments of it's + * parents and estimate the probability for each assignment using data. + */ + ProbabilityDistribution dist = net.getNodes().get(targetName); + if (dist instanceof ContinousDistribution) { + enumerate(null, new LinkedList(), parentVariables, 0, + data, dist); + } else { + for (String value : target.getDomain()) { + enumerate(new Assignment(target, value), + new LinkedList(), parentVariables, 0, data, + dist); + } + } + net.updateDistribution(target, dist); + } + +} diff --git a/src/dkohl/bayes/inference/EnumerateAll.java b/src/dkohl/bayes/inference/EnumerateAll.java new file mode 100644 index 0000000..593061f --- /dev/null +++ b/src/dkohl/bayes/inference/EnumerateAll.java @@ -0,0 +1,118 @@ +package dkohl.bayes.inference; + +import java.util.LinkedList; +import java.util.List; + +import dkohl.bayes.bayesnet.BayesNet; +import dkohl.bayes.probability.Assignment; +import dkohl.bayes.probability.ProbabilityAssignment; +import dkohl.bayes.probability.Variable; + +/** + * Enumeration Algorithm: Exact inference in Baysian Networks. + * + * @author Daniel Kohlsdorf + * + */ +public class EnumerateAll { + + /** + * The probability distribution of a variable. + * + * @param query + * the variable + * @param net + * the bayes net defining the independence + * @param assignments + * a set of assignments in this net. + * @return + */ + public static LinkedList enumerateAsk( + Variable query, BayesNet net, LinkedList assignments) { + LinkedList variables = net.getVariables(); + LinkedList result = new LinkedList(); + + // Evaluate probability for each possible + // value of the variable by enumeration + for (String value : query.getDomain()) { + LinkedList temp = new LinkedList(); + temp.addAll(assignments); + temp.add(new Assignment(query, value)); + double prob = enumerateAll(net, variables, temp); + result.add(new ProbabilityAssignment(query, value, prob)); + } + return result; + } + + /** + * Decides if a variable is hidden or not. A variable is hidden if it is not + * assigned. + * + * @param variable + * the variable in question. + * @param assignments + * all the assignments + * @return true if not assigned + */ + private static boolean hidden(Variable variable, + LinkedList assignments) { + for (Assignment assignment : assignments) { + if (assignment.getVariable().getName().equals(variable.getName())) { + return false; + } + } + return true; + } + + /** + * Recursively evaluate probability of an assignment. + * + * Can be seen as depth first search + backtracking (branching on hidden + * nodes) + * + * @param net + * a bayes net defining independence + * @param variables + * the variables left to evaluate + * @param assignments + * all assignments + * @return + */ + public static double enumerateAll(BayesNet net, List variables, + LinkedList assignments) { + // if no variables left to evaluate, + // leaf node reached. + if (variables.isEmpty()) { + return 1; + } + // evaluate variable, recurse on rest + // PROLOG: [Variable|Rest]. + Variable variable = variables.get(0); + List rest = variables.subList(1, variables.size()); + // if current variable is hidden + if (hidden(variable, assignments)) { + // sum out all possible values for that variable + double sumOut = 0; + for (String value : variable.getDomain()) { + // by temporarily adding each value to the asigned variable set + LinkedList temp = new LinkedList(); + temp.addAll(assignments); + temp.add(new Assignment(variable, value)); + + // then evaluate this variable + double val = net.getNodes().get(variable.getName()).eval(temp) + .getProbability(); + // and all that depend on it + val *= enumerateAll(net, rest, temp); + + sumOut += val; + } + return sumOut; + } + // if not just evaluate variable and continue. + return net.getNodes().get(variable.getName()).eval(assignments) + .getProbability() + * enumerateAll(net, rest, assignments); + } + +} diff --git a/src/dkohl/bayes/probability/Assignment.java b/src/dkohl/bayes/probability/Assignment.java new file mode 100644 index 0000000..dc3a3d5 --- /dev/null +++ b/src/dkohl/bayes/probability/Assignment.java @@ -0,0 +1,66 @@ +package dkohl.bayes.probability; + +/** + * An assigned variable + * + * @author Daniel Kohlsdorf + */ +public class Assignment { + + public static final String NOT_ASSIGNED = "NOT_ASSIGNED"; + + /** + * The variable specifying the domain. + */ + private Variable variable; + + /** + * The assigned value + */ + private String value; + + public Assignment(Variable variable, String value) { + super(); + this.variable = variable; + this.value = value; + } + + /** + * Is this assignment a valid one given my domain? + * + * @return true if the assignment is valid + */ + public boolean valid() { + for (String outcome : variable.getDomain()) { + if (outcome.equals(value)) { + return true; + } + } + return false; + } + + public Variable getVariable() { + return variable; + } + + public void setVariable(Variable variable) { + this.variable = variable; + } + + public String getValue() { + return value; + } + + public void setValue(String value) { + this.value = value; + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(variable); + sb.append(" = "); + sb.append(value); + return sb.toString(); + } +} diff --git a/src/dkohl/bayes/probability/Probability.java b/src/dkohl/bayes/probability/Probability.java new file mode 100644 index 0000000..b5bd7aa --- /dev/null +++ b/src/dkohl/bayes/probability/Probability.java @@ -0,0 +1,49 @@ +package dkohl.bayes.probability; + +import com.google.common.base.Preconditions; + +/** + * A probabilistic event + * + * @author Daniel Kohlsdorf + */ +public class Probability { + + /** + * The probability for this event + */ + private double probability; + + public Probability(double probability) { + setProbability(probability); + } + + public double getProbability() { + return probability; + } + + /** + * Sets this probability to a value p that holds: + * + * 0 >= p <= 1, p in R + * + * @param probability + */ + public void setProbability(double probability) { + Preconditions.checkArgument(probability <= 1, + "Probability Error: P >= 1: " + probability); + Preconditions.checkArgument(probability >= 0, + "Probability Error: P <= 0: " + probability); + this.probability = probability; + } + + /** + * Returns the rest of the probability + * + * @return 1 - p + */ + public Probability rest() { + return new Probability(1 - probability); + } + +} diff --git a/src/dkohl/bayes/probability/ProbabilityAssignment.java b/src/dkohl/bayes/probability/ProbabilityAssignment.java new file mode 100644 index 0000000..a92591e --- /dev/null +++ b/src/dkohl/bayes/probability/ProbabilityAssignment.java @@ -0,0 +1,26 @@ +package dkohl.bayes.probability; + +public class ProbabilityAssignment extends Assignment { + + private double probability; + + public ProbabilityAssignment(Variable variable, String value, + double probability) { + super(variable, value); + this.probability = probability; + } + + public double getProbability() { + return probability; + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append("("); + sb.append(super.toString()); + sb.append(") "); + sb.append(probability); + return sb.toString(); + } +} diff --git a/src/dkohl/bayes/probability/Variable.java b/src/dkohl/bayes/probability/Variable.java new file mode 100644 index 0000000..3fbf03c --- /dev/null +++ b/src/dkohl/bayes/probability/Variable.java @@ -0,0 +1,50 @@ +package dkohl.bayes.probability; + +/** + * A named variable with a domain of possible values it can take + * + * @author Daniel Kohlsdorf + */ +public class Variable { + + /** + * The name of the variable + */ + private String name; + + /** + * The domain of the variable + */ + private String domain[]; + + public Variable(String name, String[] domain) { + super(); + this.name = name; + this.domain = domain; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public String[] getDomain() { + return domain; + } + + public void setDomain(String[] domain) { + this.domain = domain; + } + + public int domainSize() { + return domain.length; + } + + @Override + public String toString() { + return name; + } +} diff --git a/src/dkohl/bayes/probability/distribution/ContinousDistribution.java b/src/dkohl/bayes/probability/distribution/ContinousDistribution.java new file mode 100644 index 0000000..948aec8 --- /dev/null +++ b/src/dkohl/bayes/probability/distribution/ContinousDistribution.java @@ -0,0 +1,104 @@ +package dkohl.bayes.probability.distribution; + +import java.util.HashMap; +import java.util.LinkedList; + +import dkohl.bayes.probability.Assignment; +import dkohl.bayes.probability.Probability; + +public class ContinousDistribution implements ProbabilityDistribution { + + /** + * The involved variables mapping to row + */ + private HashMap variable2row; + + private HashMap distribution; + + private String name; + + public ContinousDistribution(String names[], int self) { + distribution = new HashMap(); + this.variable2row = new HashMap(); + int count = 0; + for (int i = 0; i < names.length; i++) { + if (i != self) { + variable2row.put(names[i], count); + count++; + } else { + name = names[i]; + } + } + } + + /** + * Generates a table entry key for an assignment + * + * @param assignment + * the assignment + * + * @return the key + */ + public String generateKey(LinkedList assignment) { + String[] values = new String[variable2row.size()]; + for (Assignment col : assignment) { + if (variable2row.containsKey(col.getVariable().getName())) { + int row = variable2row.get(col.getVariable().getName()); + values[row] = col.getValue(); + } + } + String key = ""; + for (String entry : values) { + key += entry + ";"; + } + return key; + } + + public Assignment value(LinkedList assignment) { + for (Assignment a : assignment) { + if (a.getVariable().getName().equals(name)) { + return a; + } + } + return null; + } + + @Override + public Probability eval(LinkedList assignment) { + String assignment_key = generateKey(assignment); + if (distribution.get(assignment_key) == null) { + return new Probability(0); + } + return new Probability(distribution.get(assignment_key).eval( + value(assignment))); + } + + public void pushAssignment(LinkedList assignment) { + String key = generateKey(assignment); + Gaussian gaussian = new Gaussian(); + if (distribution.containsKey(key)) { + gaussian = distribution.get(key); + } + gaussian.push(value(assignment)); + distribution.put(key, gaussian); + } + + public void estimate() { + for (String key : distribution.keySet()) { + distribution.get(key).estimate(); + } + } + + public String[] getNames() { + String[] names = new String[variable2row.size()]; + for (String name : variable2row.keySet()) { + names[variable2row.get(name)] = name; + } + return names; + } + + public HashMap getAssignments() { + return distribution; + } + +} diff --git a/src/dkohl/bayes/probability/distribution/Gaussian.java b/src/dkohl/bayes/probability/distribution/Gaussian.java new file mode 100644 index 0000000..f8a4034 --- /dev/null +++ b/src/dkohl/bayes/probability/distribution/Gaussian.java @@ -0,0 +1,49 @@ +package dkohl.bayes.probability.distribution; + +import java.util.Vector; + +import dkohl.bayes.probability.Assignment; + +public class Gaussian { + + public static final double TWO_PI = 2 * Math.PI; + + private double mean; + private double var; + private Vector samples; + + public Gaussian() { + samples = new Vector(); + } + + public void push(Assignment assignment) { + double sample = Double.valueOf(assignment.getValue()); + samples.add(sample); + } + + public void estimate() { + mean = 0; + for(Double sample : samples) { + mean += sample; + } + mean /= samples.size(); + + var = 0; + for(Double sample : samples) { + var += Math.pow(sample - mean, 2); + } + var /= samples.size(); + } + + public double eval(Assignment assignment) { + double sample = Double.valueOf(assignment.getValue()); + double fac = 1 / (Math.sqrt(TWO_PI * var)); + return fac * Math.exp(-0.5 * (Math.pow(sample - mean, 2) / var)); + } + + @Override + public String toString() { + return "< " + mean + ", " + Math.sqrt(var) + " >"; + } + +} diff --git a/src/dkohl/bayes/probability/distribution/ProbabilityDistribution.java b/src/dkohl/bayes/probability/distribution/ProbabilityDistribution.java new file mode 100644 index 0000000..a6f980e --- /dev/null +++ b/src/dkohl/bayes/probability/distribution/ProbabilityDistribution.java @@ -0,0 +1,24 @@ +package dkohl.bayes.probability.distribution; + +import java.util.LinkedList; + +import dkohl.bayes.probability.Assignment; +import dkohl.bayes.probability.Probability; + + +/** + * A probability distribution. + * + * @author Daniel Kohlsdorf + */ +public interface ProbabilityDistribution { + + /** + * Returns probability given an assignments. + * + * @param assignment A set of assigned variables + * @return probability of that assignment + */ + public Probability eval(LinkedList assignment); + +} diff --git a/src/dkohl/bayes/probability/distribution/ProbabilityTable.java b/src/dkohl/bayes/probability/distribution/ProbabilityTable.java new file mode 100644 index 0000000..0567dd4 --- /dev/null +++ b/src/dkohl/bayes/probability/distribution/ProbabilityTable.java @@ -0,0 +1,112 @@ +package dkohl.bayes.probability.distribution; + +import java.util.HashMap; +import java.util.LinkedList; + +import dkohl.bayes.probability.Assignment; +import dkohl.bayes.probability.Probability; + +/** + * A probability table defined as a table of all combinations of all possible + * values of the varibales involved. + * + * @author Daniel Kohlsdorf + */ +public class ProbabilityTable implements ProbabilityDistribution { + + /** + * The involved variables mapping to row + */ + private HashMap variable2row; + + /** + * An P(assignment key) assignment_key var1_value....varN_value : String + */ + private HashMap assignments; + + public ProbabilityTable(String names[]) { + assignments = new HashMap(); + this.variable2row = new HashMap(); + for (int i = 0; i < names.length; i++) { + variable2row.put(names[i], i); + } + } + + public ProbabilityTable(LinkedList names) { + assignments = new HashMap(); + this.variable2row = new HashMap(); + for (int i = 0; i < names.size(); i++) { + variable2row.put(names.get(i), i); + System.out.println(names.get(i)); + } + } + + /** + * Sets the probability for a table entry key. You can generate one from + * your assignments using generateKey, or just use this same method with + * your assignments. + * + * @param key + * the entry key + * @param probability + * the associated probability + */ + public void setProbabilityForAssignment(String key, Probability probability) { + assignments.put(key, probability); + } + + /** + * Sets the probability for the given assignment + * + * @param assignment + * the assignment + * @param probability + * the associated probability + */ + public void setProbabilityForAssignment(LinkedList assignment, + Probability probability) { + String key = generateKey(assignment); + assignments.put(key, probability); + } + + @Override + public Probability eval(LinkedList assignment) { + String key = generateKey(assignment); + return assignments.get(key); + } + + /** + * Generates a table entry key for an assignment + * + * @param assignment + * the assignment + * + * @return the key + */ + public String generateKey(LinkedList assignment) { + String[] values = new String[variable2row.size()]; + for (Assignment col : assignment) { + if (variable2row.containsKey(col.getVariable().getName())) { + int row = variable2row.get(col.getVariable().getName()); + values[row] = col.getValue(); + } + } + String key = ""; + for (String entry : values) { + key += entry + ";"; + } + return key; + } + + public HashMap getAssignments() { + return assignments; + } + + public String[] getNames() { + String[] names = new String[variable2row.size()]; + for (String name : variable2row.keySet()) { + names[variable2row.get(name)] = name; + } + return names; + } +} diff --git a/src/dkohl/bayes/probability/distribution/ProbabilityTree.java b/src/dkohl/bayes/probability/distribution/ProbabilityTree.java new file mode 100644 index 0000000..d4e9cbf --- /dev/null +++ b/src/dkohl/bayes/probability/distribution/ProbabilityTree.java @@ -0,0 +1,76 @@ +package dkohl.bayes.probability.distribution; + +import java.util.LinkedList; + +import dkohl.bayes.probability.Assignment; +import dkohl.bayes.probability.Probability; +import dkohl.bayes.probability.distribution.tree.DecisionNode; +import dkohl.bayes.probability.distribution.tree.ProbabilityLeaf; + +/** + * Probability distribution in descision tree representation + * + * DANGER: No node ordering during creation yet. + * + * @author Daniel Kohlsdorf + */ +public class ProbabilityTree implements ProbabilityDistribution { + + /** + * The trees root node + */ + private DecisionNode root = null; + + /** + * Initialize the tree + * by creating one path, along an assignment of + * variables and add a probability node at the end. + * + * @param assignment + * @param probability + */ + private void initTree(LinkedList assignment, Probability probability) { + root = new DecisionNode(assignment.get(0).getVariable().getName()); + + DecisionNode parent = root; + for(int i = 1; i < assignment.size(); i++) { + // add the child to the parent with the key of the parents assignment + DecisionNode child = new DecisionNode(assignment.get(i).getVariable().getName()); + parent.put(assignment.get(i - 1).getValue(), child); + parent = child; + } + parent.put(assignment.getLast().getValue(), new ProbabilityLeaf(probability)); + } + + /** + * Follow existing paths until + * successors do not contain the current assignment, + * then start inserting. + * + * @param assignment + * @param probability + */ + public void setProbabilityForAssignment(LinkedList assignment, Probability probability) { + if(root == null) { + initTree(assignment, probability); + } else { + DecisionNode parent = root; + for(int i = 1; i < assignment.size(); i++) { + if(parent.getSuccessors().containsKey(assignment.get(i - 1).getValue())) { + parent = (DecisionNode) parent.getSuccessors().get(assignment.get(i - 1).getValue()); + } else { + DecisionNode child = new DecisionNode(assignment.get(i).getVariable().getName()); + parent.put(assignment.get(i - 1).getValue(), child); + parent = child; + } + } + parent.put(assignment.getLast().getValue(), new ProbabilityLeaf(probability)); + } + } + + @Override + public Probability eval(LinkedList assignment) { + return root.eval(assignment); + } + +} diff --git a/src/dkohl/bayes/probability/distribution/tree/DecisionNode.java b/src/dkohl/bayes/probability/distribution/tree/DecisionNode.java new file mode 100644 index 0000000..0d7d58f --- /dev/null +++ b/src/dkohl/bayes/probability/distribution/tree/DecisionNode.java @@ -0,0 +1,61 @@ +package dkohl.bayes.probability.distribution.tree; + +import java.util.HashMap; +import java.util.LinkedList; + +import dkohl.bayes.probability.Assignment; +import dkohl.bayes.probability.Probability; +import dkohl.bayes.probability.distribution.ProbabilityDistribution; + +/** + * Represents a random variable. Maps each outcome to a successor. + * + * @author Daniel Kohlsdorf + */ +public class DecisionNode implements ProbabilityDistribution { + + /** + * The successors for each outcome + */ + private HashMap successors; + + /** + * Name of the variable + */ + private String variable; + + public DecisionNode(String variable) { + this.variable = variable; + successors = new HashMap(); + } + + public String getVariable() { + return variable; + } + + public void put(String value, ProbabilityDistribution distribution) { + successors.put(value, distribution); + } + + public HashMap getSuccessors() { + return successors; + } + + @Override + public Probability eval(LinkedList assignment) { + // Evaluate recursively + for (Assignment a : assignment) { + if (a.getVariable().getName().equals(variable)) { + return successors.get(a.getValue()).eval(assignment); + } + } + + try { + throw (new Exception("Domain Violation: " + variable)); + } catch (Exception e) { + e.printStackTrace(); + } + return null; + } + +} diff --git a/src/dkohl/bayes/probability/distribution/tree/ProbabilityLeaf.java b/src/dkohl/bayes/probability/distribution/tree/ProbabilityLeaf.java new file mode 100644 index 0000000..6e17bf8 --- /dev/null +++ b/src/dkohl/bayes/probability/distribution/tree/ProbabilityLeaf.java @@ -0,0 +1,29 @@ +package dkohl.bayes.probability.distribution.tree; + +import java.util.LinkedList; + +import dkohl.bayes.probability.Assignment; +import dkohl.bayes.probability.Probability; +import dkohl.bayes.probability.distribution.ProbabilityDistribution; + +/** + * Just a dummy node. Always returns the probability value, ignoring the + * assignment. + * + * @author Daniel Kohlsdorf + */ +public class ProbabilityLeaf implements ProbabilityDistribution { + + private Probability probability; + + public ProbabilityLeaf(Probability probability) { + super(); + this.probability = probability; + } + + @Override + public Probability eval(LinkedList assignment) { + return probability; + } + +} diff --git a/src/dkohl/bayes/statistic/DataPoint.java b/src/dkohl/bayes/statistic/DataPoint.java new file mode 100644 index 0000000..1ecf24b --- /dev/null +++ b/src/dkohl/bayes/statistic/DataPoint.java @@ -0,0 +1,27 @@ +package dkohl.bayes.statistic; + +import java.util.HashMap; + +import dkohl.bayes.probability.Assignment; + +/** + * A data point or feature vector, that keeps observations. + * + * @author Daniel Kohlsdorf + */ +public class DataPoint extends HashMap { + + private static final long serialVersionUID = 1L; + + public DataPoint(DataPoint point) { + putAll(point); + } + + public DataPoint() { + } + + public void add(Assignment assignment) { + put(assignment.getVariable().getName(), assignment); + } + +} \ No newline at end of file diff --git a/src/dkohl/bayes/statistic/DataSet.java b/src/dkohl/bayes/statistic/DataSet.java new file mode 100644 index 0000000..44aaa3f --- /dev/null +++ b/src/dkohl/bayes/statistic/DataSet.java @@ -0,0 +1,99 @@ +package dkohl.bayes.statistic; + +import java.util.LinkedList; +import java.util.Vector; + +import dkohl.bayes.probability.Assignment; +import dkohl.bayes.probability.Probability; + +/** + * A data set, defined as a vector of data points + * + * @author Daniel Kohlsdorf + */ +public class DataSet extends Vector { + + private static final long serialVersionUID = 1L; + + public LinkedList> getAssignmentMatchesForQuery( + LinkedList given) { + LinkedList> assignments = new LinkedList>(); + for (DataPoint point : this) { + boolean insert = true; + for (Assignment assignment : given) { + if (!match(point, assignment)) { + insert = false; + } + } + if (insert) { + assignments.add(new LinkedList(point.values())); + } + } + return assignments; + } + + /** + * Is the assignment equal to a data point / observation ? + * + * @param point + * the point / observation + * @param query + * the assignment + * @return + */ + private boolean match(String queryName, DataPoint point, + LinkedList query) { + boolean queryFound = false; + for (Assignment assignment : query) { + if (!match(point, assignment)) { + return false; + } + if (point.containsKey(queryName)) { + queryFound = true; + } + } + if (!queryFound) { + return false; + } + return true; + } + + private boolean match(DataPoint point, Assignment query) { + String name = query.getVariable().getName(); + String value = query.getValue(); + if (point.containsKey(name)) { + if (point.get(name).getValue().equals(value)) { + return true; + } + } + return false; + } + + /** + * Estimating probability for: P(Query | given_1 .... given_N) = #(Query | + * given_1 .... given_N) / #(given_1 .... given_N) + * + * @param query + * @param given + * @return + */ + public Probability prob(Assignment query, LinkedList given) { + int matches = 0; + int num_query_given = 0; + for (DataPoint point : this) { + if (match(query.getVariable().getName(), point, given)) { + matches += 1; + if (match(point, query)) { + num_query_given += 1; // point.getWeight(); + } + } + } + + if (matches == 0) { + return new Probability(0); + } + + return new Probability(num_query_given / ((double) matches)); + } + +} diff --git a/src/dkohl/onthology/Ontology.java b/src/dkohl/onthology/Ontology.java new file mode 100644 index 0000000..b0bb248 --- /dev/null +++ b/src/dkohl/onthology/Ontology.java @@ -0,0 +1,47 @@ +package dkohl.onthology; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.Set; + +import com.google.common.base.Preconditions; + +public class Ontology { + + /** + * Thing <- Class + */ + private HashMap inheritance; + + private HashMap> classes2thing; + + public Ontology(HashSet classes) { + inheritance = new HashMap(); + this.classes2thing = new HashMap>(); + for (String key : classes) { + classes2thing.put(key, new LinkedList()); + } + } + + public void define(String thing, String isA) { + Preconditions.checkArgument(classes2thing.containsKey(isA), "Class: " + + isA + "notDefined"); + LinkedList things = classes2thing.get(isA); + things.add(thing); + classes2thing.put(isA, things); + inheritance.put(thing, isA); + } + + public HashMap getInheritance() { + return inheritance; + } + + public HashMap> getClasses2thing() { + return classes2thing; + } + + public Set getClasses() { + return classes2thing.keySet(); + } +} diff --git a/test/dkohl/bayes/example/AlarmExampleTest.java b/test/dkohl/bayes/example/AlarmExampleTest.java new file mode 100644 index 0000000..303c225 --- /dev/null +++ b/test/dkohl/bayes/example/AlarmExampleTest.java @@ -0,0 +1,39 @@ +package dkohl.bayes.example; + +import static org.hamcrest.core.IsEqual.equalTo; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; + +import java.util.LinkedList; + +import org.junit.Test; + +import dkohl.bayes.bayesnet.BayesNet; +import dkohl.bayes.example.builders.AlarmNetBuilderTable; +import dkohl.bayes.example.builders.AlarmNetBuilderTree; +import dkohl.bayes.inference.EnumerateAll; +import dkohl.bayes.probability.ProbabilityAssignment; +import dkohl.bayes.probability.Variable; + +public class AlarmExampleTest { + + @Test + public void testAlarmExample() { + BayesNet sprinkler = AlarmNetBuilderTable.alarm(); + sprinkler = AlarmNetBuilderTree.sprinkler(); + // P(B | j, m) + LinkedList probs = EnumerateAll.enumerateAsk( + new Variable(AlarmNetBuilderTable.BURGLARY, + AlarmNetBuilderTable.DOMAIN), sprinkler, + AlarmNetBuilderTable.completeQueryBulgary()); + System.out.print("Burglary: <"); + for (ProbabilityAssignment p : probs) { + System.out.print("p: "+ p.toString() + ","); + } + System.out.println(">"); + + //assert that burglary is more likely to be false + assertThat(probs.size(),equalTo(2)); + assertTrue(probs.get(0).getProbability() < probs.get(1).getProbability()); + } +} \ No newline at end of file diff --git a/test/dkohl/bayes/example/FoodExampleTest.java b/test/dkohl/bayes/example/FoodExampleTest.java new file mode 100644 index 0000000..11519be --- /dev/null +++ b/test/dkohl/bayes/example/FoodExampleTest.java @@ -0,0 +1,153 @@ +package dkohl.bayes.example; + +import java.util.LinkedList; + +import static org.hamcrest.core.IsEqual.equalTo; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; + +import org.junit.Test; + +import dkohl.bayes.bayesnet.BayesNet; +import dkohl.bayes.example.builders.FoodExampleBuilder; +import dkohl.bayes.inference.EnumerateAll; +import dkohl.bayes.probability.ProbabilityAssignment; +import dkohl.bayes.probability.Variable; +import dkohl.bayes.probability.distribution.ContinousDistribution; +import dkohl.bayes.probability.distribution.ProbabilityTable; + +public class FoodExampleTest { + + @Test + public void testFoodExample() { + BayesNet net = FoodExampleBuilder.dishNet(); + + System.out.println("VEGETRAIAN: "); + ProbabilityTable table = (ProbabilityTable) net.getNodes().get( + FoodExampleBuilder.SOMEONE_VEGETARIAN); + for (String name : table.getNames()) { + System.out.print(name + " "); + } + System.out.println(); + for (String assignment : table.getAssignments().keySet()) { + System.out.println(assignment + " " + + table.getAssignments().get(assignment).getProbability()); + } + + System.out.println("MEET: "); + table = (ProbabilityTable) net.getNodes().get( + FoodExampleBuilder.CONTAINS_MEET); + for (String name : table.getNames()) { + System.out.print(name + " "); + } + System.out.println(); + for (String assignment : table.getAssignments().keySet()) { + System.out.println(assignment + " " + + table.getAssignments().get(assignment).getProbability()); + } + + System.out.println("VEGETABLES: "); + table = (ProbabilityTable) net.getNodes().get( + FoodExampleBuilder.CONTAINS_VEGETABLE); + for (String name : table.getNames()) { + System.out.print(name + " "); + } + System.out.println(); + for (String assignment : table.getAssignments().keySet()) { + System.out.println(assignment + " " + + table.getAssignments().get(assignment).getProbability()); + } + + System.out.println("BEEF: "); + table = (ProbabilityTable) net.getNodes().get( + FoodExampleBuilder.CONTAINS_BEEF); + for (String name : table.getNames()) { + System.out.print(name + " "); + } + System.out.println(); + for (String assignment : table.getAssignments().keySet()) { + System.out.println(assignment + " " + + table.getAssignments().get(assignment).getProbability()); + } + + System.out.println("PORK: "); + table = (ProbabilityTable) net.getNodes().get( + FoodExampleBuilder.CONTAINS_PORK); + for (String name : table.getNames()) { + System.out.print(name + " "); + } + System.out.println(); + for (String assignment : table.getAssignments().keySet()) { + System.out.println(assignment + " " + + table.getAssignments().get(assignment).getProbability()); + } + + System.out.println("POTATOS: "); + table = (ProbabilityTable) net.getNodes().get( + FoodExampleBuilder.CONTAINS_POTATOS); + for (String name : table.getNames()) { + System.out.print(name + " "); + } + System.out.println(); + for (String assignment : table.getAssignments().keySet()) { + System.out.println(assignment + " " + + table.getAssignments().get(assignment).getProbability()); + } + + System.out.println("TOMATOS: "); + table = (ProbabilityTable) net.getNodes().get( + FoodExampleBuilder.CONTAINS_TOMATOS); + for (String name : table.getNames()) { + System.out.print(name + " "); + } + System.out.println(); + for (String assignment : table.getAssignments().keySet()) { + System.out.println(assignment + " " + + table.getAssignments().get(assignment).getProbability()); + } + + System.out.println("TASTE: "); + ContinousDistribution dist = (ContinousDistribution) net.getNodes() + .get(FoodExampleBuilder.TASTE); + for (String name : dist.getNames()) { + System.out.print(name + " "); + } + System.out.println(); + + for (String assignment : dist.getAssignments().keySet()) { + System.out.println(assignment + " " + + dist.getAssignments().get(assignment)); + } + + LinkedList probs = EnumerateAll.enumerateAsk( + new Variable(FoodExampleBuilder.TASTE, + FoodExampleBuilder.RATING_DOMAIN), net, + FoodExampleBuilder.completeQueryTasteBeef()); + double max_val = 0; + String max_arg = null; + for (ProbabilityAssignment p : probs) { + if (p.getProbability() > max_val) { + max_val = p.getProbability(); + max_arg = p.getValue(); + } + } + System.out.println("TASE BEEF: " + max_arg + " " + max_val); + + probs = EnumerateAll.enumerateAsk(new Variable( + FoodExampleBuilder.TASTE, FoodExampleBuilder.RATING_DOMAIN), + net, FoodExampleBuilder.completeQueryTastePork()); + max_val = 0; + max_arg = null; + for (ProbabilityAssignment p : probs) { + if (p.getProbability() > max_val) { + max_val = p.getProbability(); + max_arg = p.getValue(); + } + } + + System.out.println("TASE PORK: " + max_arg + " " + max_val); + assertThat(max_arg, equalTo("10")); + assertTrue("Error: max_val for TASTE_PORK should be > 2.6%", max_val > 0.026); + } + +} diff --git a/test/dkohl/bayes/example/SprinklerNetExampleTest.java b/test/dkohl/bayes/example/SprinklerNetExampleTest.java new file mode 100644 index 0000000..64ce991 --- /dev/null +++ b/test/dkohl/bayes/example/SprinklerNetExampleTest.java @@ -0,0 +1,50 @@ +package dkohl.bayes.example; + +import org.junit.Test; +import static org.hamcrest.core.IsEqual.equalTo; +import static org.junit.Assert.assertThat; + +import dkohl.bayes.bayesnet.BayesNet; +import dkohl.bayes.estimation.MaximumLikelihoodEstimation; +import dkohl.bayes.example.builders.EstimateSprinklerNetBuilderTable; +import dkohl.bayes.probability.distribution.ProbabilityTable; +import dkohl.bayes.statistic.DataSet; + +/** + * Parameter estimation for the sprinkler net example + * + * http://www.cs.ubc.ca/~murphyk/Bayes/bnintro.html + * + * @author Daniel Kohlsdorf + */ +public class SprinklerNetExampleTest { + + @Test + public void testSprinklerNetExample() { + BayesNet net = EstimateSprinklerNetBuilderTable.sprinkler(); + DataSet data = EstimateSprinklerNetBuilderTable.dataSet(); + MaximumLikelihoodEstimation.estimate(data, net, + EstimateSprinklerNetBuilderTable.GRASS_WET); + + /** + * Output PDF for grass is wet + */ + ProbabilityTable table = (ProbabilityTable) net.getNodes().get( + EstimateSprinklerNetBuilderTable.GRASS_WET); + for (String name : table.getNames()) { + System.out.print(name + " | "); + } + System.out.println(); + for (String key : table.getAssignments().keySet()) { + System.out.println(key + " " + + table.getAssignments().get(key).getProbability()); + if ("false;false;true;".equals(key)) { + assertThat(table.getAssignments().get(key).getProbability(), equalTo(0.0)); + } + if ("true;false;true;".equals(key)) { + assertThat(table.getAssignments().get(key).getProbability(), equalTo(0.9)); + } + } + } + +} diff --git a/test/dkohl/bayes/example/builders/AlarmNetBuilderTable.java b/test/dkohl/bayes/example/builders/AlarmNetBuilderTable.java new file mode 100644 index 0000000..1a6c089 --- /dev/null +++ b/test/dkohl/bayes/example/builders/AlarmNetBuilderTable.java @@ -0,0 +1,176 @@ +package dkohl.bayes.example.builders; + +import java.util.LinkedList; + +import dkohl.bayes.bayesnet.BayesNet; +import dkohl.bayes.probability.Assignment; +import dkohl.bayes.probability.Probability; +import dkohl.bayes.probability.Variable; +import dkohl.bayes.probability.distribution.ProbabilityTable; + +/** + * The alarm example for baysian nets. + * + * Stuart Russel, Peter Norvig: Artificial Intelligence: A Modern Approach, 3ed + * Edition, Prentice Hall, 2010 + * + * @author Daniel Kohlsdorf + */ +public class AlarmNetBuilderTable { + + // Variable names + public static final String BURGLARY = "Burglary"; + public static final String EARTHQUAKE = "Earthquake"; + public static final String ALARM = "Alarm"; + public static final String JOHN = "John"; + public static final String MARRY = "Marry"; + + // Possible outcomes + public static final String TRUE = "true"; + public static final String FALSE = "false"; + + // A variables domain + public static final String DOMAIN[] = { TRUE, FALSE }; + + // A set of variables + public static final String VARIABLES[] = { BURGLARY, EARTHQUAKE, ALARM, + JOHN, MARRY }; + + /** + * Builds the query from the book: P(B| j, m) + */ + public static LinkedList completeQueryBulgary() { + LinkedList assignment = new LinkedList(); + assignment.add(new Assignment(new Variable(JOHN, DOMAIN), TRUE)); + assignment.add(new Assignment(new Variable(MARRY, DOMAIN), TRUE)); + return new LinkedList(assignment); + } + + /** + * Build burglary prior + * + * @param alarmNet + */ + private static void burglary(BayesNet alarmNet) { + Probability p_bulglary = new Probability(0.001); + String names[] = { BURGLARY }; + ProbabilityTable burglary = new ProbabilityTable(names); + burglary.setProbabilityForAssignment("true;", p_bulglary); + burglary.setProbabilityForAssignment("false;", p_bulglary.rest()); + alarmNet.setDistribution(new Variable(BURGLARY, DOMAIN), burglary); + } + + /** + * Build earthquake prior + * + * @param alarmNet + */ + private static void earthquake(BayesNet alarmNet) { + Probability p_earthquake = new Probability(0.002); + String names[] = { EARTHQUAKE }; + ProbabilityTable earthquake = new ProbabilityTable(names); + earthquake.setProbabilityForAssignment("true;", p_earthquake); + earthquake.setProbabilityForAssignment("false;", p_earthquake.rest()); + + alarmNet.setDistribution(new Variable(EARTHQUAKE, DOMAIN), earthquake); + } + + /** + * Jhon calls! + * + * @param alarmNet + */ + private static void jhon(BayesNet alarmNet) { + // P(ALARM) == true + Probability t = new Probability(.90); + + // P(ALARM) == false + Probability f = new Probability(.05); + String names[] = { ALARM, JOHN }; + ProbabilityTable jhon = new ProbabilityTable(names); + jhon.setProbabilityForAssignment("true;true;", t); + jhon.setProbabilityForAssignment("false;true;", f); + jhon.setProbabilityForAssignment("true;false", t.rest()); + jhon.setProbabilityForAssignment("false;false;", f.rest()); + + alarmNet.setDistribution(new Variable(JOHN, DOMAIN), jhon); + } + + /** + * Marry calls! + * + * @param alarmNet + */ + private static void mary(BayesNet alarmNet) { + // P(ALARM) == true + Probability t = new Probability(.70); + + // P(ALARM) == false + Probability f = new Probability(.01); + String names[] = { ALARM, MARRY }; + ProbabilityTable marry = new ProbabilityTable(names); + marry.setProbabilityForAssignment("true;true;", t); + marry.setProbabilityForAssignment("false;true;", f); + marry.setProbabilityForAssignment("true;false", t.rest()); + marry.setProbabilityForAssignment("false;false;", f.rest()); + alarmNet.setDistribution(new Variable(MARRY, DOMAIN), marry); + } + + /** + * The alarm goes off! + * + * @param alarmNet + */ + private static void alarm(BayesNet alarmNet) { + // P(ALARM | BURGLARY = true, EARTHQUAKE = true) + Probability tt = new Probability(.95); + + // P(ALARM | BURGLARY = true, EARTHQUAKE = false) + Probability tf = new Probability(.94); + + // P(ALARM | BURGLARY = false, EARTHQUAKE = true) + Probability ft = new Probability(.29); + + // P(ALARM | BURGLARY = false, EARTHQUAKE = false) + Probability ff = new Probability(.001); + String names[] = { BURGLARY, EARTHQUAKE, ALARM }; + ProbabilityTable alarm = new ProbabilityTable(names); + alarm.setProbabilityForAssignment(TRUE + ";" + TRUE + ";" + TRUE + ";", + tt); + alarm.setProbabilityForAssignment( + TRUE + ";" + FALSE + ";" + TRUE + ";", tf); + alarm.setProbabilityForAssignment( + FALSE + ";" + TRUE + ";" + TRUE + ";", ft); + alarm.setProbabilityForAssignment(FALSE + ";" + FALSE + ";" + TRUE + + ";", ff); + + alarm.setProbabilityForAssignment( + TRUE + ";" + TRUE + ";" + FALSE + ";", tt.rest()); + alarm.setProbabilityForAssignment(TRUE + ";" + FALSE + ";" + FALSE + + ";", tf.rest()); + alarm.setProbabilityForAssignment(FALSE + ";" + TRUE + ";" + FALSE + + ";", ft.rest()); + alarm.setProbabilityForAssignment(FALSE + ";" + FALSE + ";" + FALSE + + ";", ff.rest()); + alarmNet.setDistribution(new Variable(ALARM, DOMAIN), alarm); + } + + public static BayesNet alarm() { + BayesNet alarmNet = new BayesNet(VARIABLES); + + // set probability tables and priors + burglary(alarmNet); + earthquake(alarmNet); + alarm(alarmNet); + jhon(alarmNet); + mary(alarmNet); + + // construct the graph + alarmNet.connect(ALARM, BURGLARY); + alarmNet.connect(ALARM, EARTHQUAKE); + alarmNet.connect(JOHN, ALARM); + alarmNet.connect(MARRY, ALARM); + + return alarmNet; + } +} diff --git a/test/dkohl/bayes/example/builders/AlarmNetBuilderTree.java b/test/dkohl/bayes/example/builders/AlarmNetBuilderTree.java new file mode 100644 index 0000000..f455c95 --- /dev/null +++ b/test/dkohl/bayes/example/builders/AlarmNetBuilderTree.java @@ -0,0 +1,244 @@ +package dkohl.bayes.example.builders; + +import java.util.LinkedList; + +import dkohl.bayes.bayesnet.BayesNet; +import dkohl.bayes.probability.Assignment; +import dkohl.bayes.probability.Probability; +import dkohl.bayes.probability.Variable; +import dkohl.bayes.probability.distribution.ProbabilityTree; + +public class AlarmNetBuilderTree { + + // Variable names + public static final String BURGLARY = "Burglary"; + public static final String EARTHQUAKE = "Earthquake"; + public static final String ALARM = "Alarm"; + public static final String JOHN = "John"; + public static final String MARRY = "Marry"; + + // Possible outcomes + public static final String TRUE = "true"; + public static final String FALSE = "false"; + + // A variables domain + public static final String DOMAIN[] = { TRUE, FALSE }; + + // A set of variables + public static final String VARIABLES[] = { BURGLARY, EARTHQUAKE, ALARM, + JOHN, MARRY }; + + /** + * Builds the query from the book: P(B| j, m) + */ + public static LinkedList completeQueryBulgary() { + LinkedList assignment = new LinkedList(); + assignment.add(new Assignment(new Variable(JOHN, DOMAIN), TRUE)); + assignment.add(new Assignment(new Variable(MARRY, DOMAIN), TRUE)); + return new LinkedList(assignment); + } + + /** + * Build burglary prior + * + * @param sprinklerNet + */ + private static void burglary(BayesNet sprinklerNet) { + Probability p_bulglary = new Probability(0.001); + + ProbabilityTree burglary = new ProbabilityTree(); + + LinkedList assignment = new LinkedList(); + assignment.add(new Assignment(new Variable(BURGLARY, DOMAIN), TRUE)); + burglary.setProbabilityForAssignment(assignment, p_bulglary); + + assignment = new LinkedList(); + assignment.add(new Assignment(new Variable(BURGLARY, DOMAIN), FALSE)); + burglary.setProbabilityForAssignment(assignment, p_bulglary.rest()); + + sprinklerNet.setDistribution(new Variable(BURGLARY, DOMAIN), burglary); + } + + /** + * Build earthquake prior + * + * @param sprinklerNet + */ + private static void earthquake(BayesNet sprinklerNet) { + Probability p_earthquake = new Probability(0.002); + + ProbabilityTree earthquake = new ProbabilityTree(); + + LinkedList assignment = new LinkedList(); + assignment.add(new Assignment(new Variable(EARTHQUAKE, DOMAIN), TRUE)); + earthquake.setProbabilityForAssignment(assignment, p_earthquake); + + assignment = new LinkedList(); + assignment.add(new Assignment(new Variable(EARTHQUAKE, DOMAIN), FALSE)); + earthquake.setProbabilityForAssignment(assignment, p_earthquake.rest()); + + sprinklerNet.setDistribution(new Variable(EARTHQUAKE, DOMAIN), + earthquake); + } + + /** + * Jhon calls! + * + * @param sprinklerNet + */ + private static void jhon(BayesNet sprinklerNet) { + // P(ALARM) == true + Probability t = new Probability(.90); + + // P(ALARM) == false + Probability f = new Probability(.05); + + ProbabilityTree jhon = new ProbabilityTree(); + + LinkedList assignment = new LinkedList(); + assignment.add(new Assignment(new Variable(JOHN, DOMAIN), TRUE)); + assignment.add(new Assignment(new Variable(ALARM, DOMAIN), TRUE)); + jhon.setProbabilityForAssignment(assignment, t); + + assignment = new LinkedList(); + assignment.add(new Assignment(new Variable(JOHN, DOMAIN), TRUE)); + assignment.add(new Assignment(new Variable(ALARM, DOMAIN), FALSE)); + jhon.setProbabilityForAssignment(assignment, f); + + assignment = new LinkedList(); + assignment.add(new Assignment(new Variable(JOHN, DOMAIN), FALSE)); + assignment.add(new Assignment(new Variable(ALARM, DOMAIN), TRUE)); + jhon.setProbabilityForAssignment(assignment, t.rest()); + + assignment = new LinkedList(); + assignment.add(new Assignment(new Variable(JOHN, DOMAIN), FALSE)); + assignment.add(new Assignment(new Variable(ALARM, DOMAIN), FALSE)); + jhon.setProbabilityForAssignment(assignment, f.rest()); + + sprinklerNet.setDistribution(new Variable(JOHN, DOMAIN), jhon); + } + + /** + * Marry calls! + * + * @param sprinklerNet + */ + private static void mary(BayesNet sprinklerNet) { + // P(ALARM) == true + Probability t = new Probability(.70); + + // P(ALARM) == false + Probability f = new Probability(.01); + ProbabilityTree mary = new ProbabilityTree(); + + LinkedList assignment = new LinkedList(); + assignment.add(new Assignment(new Variable(MARRY, DOMAIN), TRUE)); + assignment.add(new Assignment(new Variable(ALARM, DOMAIN), TRUE)); + mary.setProbabilityForAssignment(assignment, t); + + assignment = new LinkedList(); + assignment.add(new Assignment(new Variable(MARRY, DOMAIN), TRUE)); + assignment.add(new Assignment(new Variable(ALARM, DOMAIN), FALSE)); + mary.setProbabilityForAssignment(assignment, f); + + assignment = new LinkedList(); + assignment.add(new Assignment(new Variable(MARRY, DOMAIN), FALSE)); + assignment.add(new Assignment(new Variable(ALARM, DOMAIN), TRUE)); + mary.setProbabilityForAssignment(assignment, t.rest()); + + assignment = new LinkedList(); + assignment.add(new Assignment(new Variable(MARRY, DOMAIN), FALSE)); + assignment.add(new Assignment(new Variable(ALARM, DOMAIN), FALSE)); + mary.setProbabilityForAssignment(assignment, f.rest()); + sprinklerNet.setDistribution(new Variable(MARRY, DOMAIN), mary); + } + + /** + * The alarm goes off! + * + * @param sprinklerNet + */ + private static void alarm(BayesNet sprinklerNet) { + // P(ALARM | BURGLARY = true, EARTHQUAKE = true) + Probability tt = new Probability(.95); + + // P(ALARM | BURGLARY = true, EARTHQUAKE = false) + Probability tf = new Probability(.94); + + // P(ALARM | BURGLARY = false, EARTHQUAKE = true) + Probability ft = new Probability(.29); + + // P(ALARM | BURGLARY = false, EARTHQUAKE = false) + Probability ff = new Probability(.001); + ProbabilityTree alarm = new ProbabilityTree(); + + LinkedList assignment = new LinkedList(); + assignment.add(new Assignment(new Variable(BURGLARY, DOMAIN), TRUE)); + assignment.add(new Assignment(new Variable(EARTHQUAKE, DOMAIN), TRUE)); + assignment.add(new Assignment(new Variable(ALARM, DOMAIN), TRUE)); + alarm.setProbabilityForAssignment(assignment, tt); + + assignment = new LinkedList(); + assignment.add(new Assignment(new Variable(BURGLARY, DOMAIN), TRUE)); + assignment.add(new Assignment(new Variable(EARTHQUAKE, DOMAIN), FALSE)); + assignment.add(new Assignment(new Variable(ALARM, DOMAIN), TRUE)); + alarm.setProbabilityForAssignment(assignment, tf); + + assignment = new LinkedList(); + assignment.add(new Assignment(new Variable(BURGLARY, DOMAIN), FALSE)); + assignment.add(new Assignment(new Variable(EARTHQUAKE, DOMAIN), TRUE)); + assignment.add(new Assignment(new Variable(ALARM, DOMAIN), TRUE)); + alarm.setProbabilityForAssignment(assignment, ft); + + assignment = new LinkedList(); + assignment.add(new Assignment(new Variable(BURGLARY, DOMAIN), FALSE)); + assignment.add(new Assignment(new Variable(EARTHQUAKE, DOMAIN), FALSE)); + assignment.add(new Assignment(new Variable(ALARM, DOMAIN), TRUE)); + alarm.setProbabilityForAssignment(assignment, ff); + + assignment = new LinkedList(); + assignment.add(new Assignment(new Variable(BURGLARY, DOMAIN), TRUE)); + assignment.add(new Assignment(new Variable(EARTHQUAKE, DOMAIN), TRUE)); + assignment.add(new Assignment(new Variable(ALARM, DOMAIN), FALSE)); + alarm.setProbabilityForAssignment(assignment, tt.rest()); + + assignment = new LinkedList(); + assignment.add(new Assignment(new Variable(BURGLARY, DOMAIN), TRUE)); + assignment.add(new Assignment(new Variable(EARTHQUAKE, DOMAIN), FALSE)); + assignment.add(new Assignment(new Variable(ALARM, DOMAIN), FALSE)); + alarm.setProbabilityForAssignment(assignment, tf.rest()); + + assignment = new LinkedList(); + assignment.add(new Assignment(new Variable(BURGLARY, DOMAIN), FALSE)); + assignment.add(new Assignment(new Variable(EARTHQUAKE, DOMAIN), TRUE)); + assignment.add(new Assignment(new Variable(ALARM, DOMAIN), FALSE)); + alarm.setProbabilityForAssignment(assignment, ft.rest()); + + assignment = new LinkedList(); + assignment.add(new Assignment(new Variable(BURGLARY, DOMAIN), FALSE)); + assignment.add(new Assignment(new Variable(EARTHQUAKE, DOMAIN), FALSE)); + assignment.add(new Assignment(new Variable(ALARM, DOMAIN), FALSE)); + alarm.setProbabilityForAssignment(assignment, ff.rest()); + + sprinklerNet.setDistribution(new Variable(ALARM, DOMAIN), alarm); + } + + public static BayesNet sprinkler() { + BayesNet sprinklerNet = new BayesNet(VARIABLES); + + // set probability tables and priors + burglary(sprinklerNet); + earthquake(sprinklerNet); + alarm(sprinklerNet); + jhon(sprinklerNet); + mary(sprinklerNet); + + // construct the graph + sprinklerNet.connect(ALARM, BURGLARY); + sprinklerNet.connect(ALARM, EARTHQUAKE); + sprinklerNet.connect(JOHN, ALARM); + sprinklerNet.connect(MARRY, ALARM); + + return sprinklerNet; + } +} \ No newline at end of file diff --git a/test/dkohl/bayes/example/builders/EstimateSprinklerNetBuilderTable.java b/test/dkohl/bayes/example/builders/EstimateSprinklerNetBuilderTable.java new file mode 100644 index 0000000..501a43a --- /dev/null +++ b/test/dkohl/bayes/example/builders/EstimateSprinklerNetBuilderTable.java @@ -0,0 +1,179 @@ +package dkohl.bayes.example.builders; + +import dkohl.bayes.bayesnet.BayesNet; +import dkohl.bayes.probability.Assignment; +import dkohl.bayes.probability.Probability; +import dkohl.bayes.probability.Variable; +import dkohl.bayes.probability.distribution.ProbabilityTable; +import dkohl.bayes.statistic.DataPoint; +import dkohl.bayes.statistic.DataSet; + +/** + * The Sprinkler net example + * + * http://www.cs.ubc.ca/~murphyk/Bayes/bnintro.html + * + * @author Daniel Kohlsdorf + */ +public class EstimateSprinklerNetBuilderTable { + + // Variable names + public static final String CLOUDY = "Cloudy"; + public static final String SPRINKLER = "Sprinkler"; + public static final String GRASS_WET = "GrassWet"; + public static final String RAIN = "Rain"; + + // Possible outcomes + public static final String TRUE = "true"; + public static final String FALSE = "false"; + + // A variables domain + public static final String DOMAIN[] = { TRUE, FALSE }; + + // A set of variables + public static final String VARIABLES[] = { CLOUDY, SPRINKLER, GRASS_WET, + RAIN }; + + private static void cloudy(BayesNet sprinklerNet) { + Probability p_cloudy = new Probability(0.5); + String names[] = { CLOUDY }; + ProbabilityTable cloudy = new ProbabilityTable(names); + cloudy.setProbabilityForAssignment("true;", p_cloudy); + cloudy.setProbabilityForAssignment("false;", p_cloudy.rest()); + sprinklerNet.setDistribution(new Variable(CLOUDY, DOMAIN), cloudy); + } + + private static void rain(BayesNet sprinklerNet) { + Probability p_cloudy = new Probability(0.8); + Probability p_notcloudy = new Probability(0.2); + + String names[] = { CLOUDY, RAIN }; + ProbabilityTable rain = new ProbabilityTable(names); + rain.setProbabilityForAssignment("true;true;", p_cloudy); + rain.setProbabilityForAssignment("false;false;", p_notcloudy); + + rain.setProbabilityForAssignment("false;true;", p_notcloudy.rest()); + rain.setProbabilityForAssignment("true;false;", p_cloudy.rest()); + sprinklerNet.setDistribution(new Variable(RAIN, DOMAIN), rain); + } + + private static void sprinkler(BayesNet sprinklerNet) { + Probability p_cloudy = new Probability(0.1); + Probability p_notcloudy = new Probability(0.5); + + String names[] = { CLOUDY, SPRINKLER }; + ProbabilityTable sprinkler = new ProbabilityTable(names); + sprinkler.setProbabilityForAssignment("true;true;", p_cloudy); + sprinkler.setProbabilityForAssignment("false;false;", p_notcloudy); + + sprinkler + .setProbabilityForAssignment("false;true;", p_notcloudy.rest()); + sprinkler.setProbabilityForAssignment("true;false;", p_cloudy.rest()); + sprinklerNet + .setDistribution(new Variable(SPRINKLER, DOMAIN), sprinkler); + } + + private static void grass(BayesNet sprinklerNet) { + String names[] = { RAIN, SPRINKLER, GRASS_WET }; + ProbabilityTable sprinkler = new ProbabilityTable(names); + sprinklerNet + .setDistribution(new Variable(GRASS_WET, DOMAIN), sprinkler); + } + + public static BayesNet sprinkler() { + BayesNet sprinkler = new BayesNet(VARIABLES); + + cloudy(sprinkler); + rain(sprinkler); + sprinkler(sprinkler); + grass(sprinkler); + + sprinkler.connect(RAIN, CLOUDY); + sprinkler.connect(SPRINKLER, CLOUDY); + sprinkler.connect(GRASS_WET, RAIN); + sprinkler.connect(GRASS_WET, SPRINKLER); + + return sprinkler; + } + + private static Assignment build(String varible, String value) { + return new Assignment(new Variable(varible, DOMAIN), value); + } + + public static DataSet dataSet() { + DataSet dataSet = new DataSet(); + + /** + * If rain is false and sprinkler is false, grass is never wet. + */ + DataPoint one = new DataPoint(); + one.add(build(RAIN, FALSE)); + one.add(build(SPRINKLER, FALSE)); + one.add(build(GRASS_WET, FALSE)); + dataSet.add(one); + + /** + * 1 / 10 times the grass is not wet when the sprinkler is on + */ + DataPoint two = new DataPoint(); + two.add(build(RAIN, FALSE)); + two.add(build(SPRINKLER, TRUE)); + two.add(build(GRASS_WET, FALSE)); + dataSet.add(two); + + /** + * 9 / 10 times the sprinkler is on and the grass is wet + */ + for (int i = 0; i < 9; i++) { + DataPoint point = new DataPoint(); + point.add(build(RAIN, FALSE)); + point.add(build(SPRINKLER, TRUE)); + point.add(build(GRASS_WET, TRUE)); + dataSet.add(point); + } + + /** + * 1 / 10 times the grass is not wet when it rains + */ + DataPoint three = new DataPoint(); + three.add(build(RAIN, TRUE)); + three.add(build(SPRINKLER, FALSE)); + three.add(build(GRASS_WET, FALSE)); + dataSet.add(three); + + /** + * 9 / 10 times it rains and the grass is wet + */ + for (int i = 0; i < 9; i++) { + DataPoint point = new DataPoint(); + point.add(build(RAIN, TRUE)); + point.add(build(SPRINKLER, FALSE)); + point.add(build(GRASS_WET, TRUE)); + dataSet.add(point); + } + + /** + * 1 / 100 times the grass is not wet when it rains and the sprinkler is + * on + */ + DataPoint four = new DataPoint(); + four.add(build(RAIN, TRUE)); + four.add(build(SPRINKLER, TRUE)); + four.add(build(GRASS_WET, FALSE)); + dataSet.add(four); + + /** + * 99 / 100 times it rains and the grass is wet + */ + for (int i = 0; i < 99; i++) { + DataPoint point = new DataPoint(); + point.add(build(RAIN, TRUE)); + point.add(build(SPRINKLER, TRUE)); + point.add(build(GRASS_WET, TRUE)); + dataSet.add(point); + } + + return dataSet; + } + +} diff --git a/test/dkohl/bayes/example/builders/FoodExampleBuilder.java b/test/dkohl/bayes/example/builders/FoodExampleBuilder.java new file mode 100644 index 0000000..ddf6424 --- /dev/null +++ b/test/dkohl/bayes/example/builders/FoodExampleBuilder.java @@ -0,0 +1,266 @@ +package dkohl.bayes.example.builders; + +import java.util.HashSet; +import java.util.LinkedList; + +import dkohl.bayes.bayesnet.BayesNet; +import dkohl.bayes.estimation.MaximumLikelihoodEstimation; +import dkohl.bayes.probability.Assignment; +import dkohl.bayes.probability.Probability; +import dkohl.bayes.probability.Variable; +import dkohl.bayes.probability.distribution.ContinousDistribution; +import dkohl.bayes.probability.distribution.ProbabilityDistribution; +import dkohl.bayes.probability.distribution.ProbabilityTable; +import dkohl.bayes.statistic.DataPoint; +import dkohl.bayes.statistic.DataSet; +import dkohl.onthology.Ontology; + +public class FoodExampleBuilder { + + public static final String TASTE = "Taste"; + public static final String SOMEONE_VEGETARIAN = "Vegetarian"; + public static final String CONTAINS_MEET = "Meet"; + public static final String CONTAINS_VEGETABLE = "Vegetable"; + public static final String CONTAINS_BEEF = "Beef"; + public static final String CONTAINS_PORK = "Pork"; + public static final String CONTAINS_TOMATOS = "Tomatos"; + public static final String CONTAINS_POTATOS = "Potatos"; + + public static final String TRUE_VALUE = "true"; + public static final String FALSE_VALUE = "false"; + + public static final String DOMAIN[] = { TRUE_VALUE, FALSE_VALUE }; + + public static final String RATING_DOMAIN[] = { "1", "2", "3", "4", "5", + "6", "7", "8", "9", "10" }; + + private static final String[] VARIABLES = { SOMEONE_VEGETARIAN, + CONTAINS_BEEF, CONTAINS_MEET, CONTAINS_PORK, CONTAINS_POTATOS, + CONTAINS_TOMATOS, CONTAINS_VEGETABLE, TASTE }; + + private static final String[] OBSERVED = { CONTAINS_BEEF, CONTAINS_PORK, + CONTAINS_POTATOS, CONTAINS_TOMATOS, }; + + public static LinkedList completeQueryTasteBeef() { + LinkedList assignment = new LinkedList(); + assignment.add(new Assignment(new Variable(CONTAINS_BEEF, DOMAIN), + TRUE_VALUE)); + assignment.add(new Assignment(new Variable(CONTAINS_TOMATOS, DOMAIN), + TRUE_VALUE)); + return new LinkedList(assignment); + } + + public static LinkedList completeQueryTastePork() { + LinkedList assignment = new LinkedList(); + assignment.add(new Assignment(new Variable(CONTAINS_PORK, DOMAIN), + TRUE_VALUE)); + assignment.add(new Assignment(new Variable(CONTAINS_POTATOS, DOMAIN), + TRUE_VALUE)); + return new LinkedList(assignment); + } + + public static Ontology onto() { + HashSet classes = new HashSet(); + classes.add(CONTAINS_MEET); + classes.add(CONTAINS_VEGETABLE); + Ontology onthology = new Ontology(classes); + onthology.define(CONTAINS_PORK, CONTAINS_MEET); + onthology.define(CONTAINS_BEEF, CONTAINS_MEET); + + onthology.define(CONTAINS_TOMATOS, CONTAINS_VEGETABLE); + onthology.define(CONTAINS_POTATOS, CONTAINS_VEGETABLE); + return onthology; + } + + private static Assignment build(String varible, String value) { + return new Assignment(new Variable(varible, DOMAIN), value); + } + + public static DataPoint beefPotatoDish(int weight) { + DataPoint point = new DataPoint(); + point.add(build(CONTAINS_BEEF, TRUE_VALUE)); + point.add(build(CONTAINS_POTATOS, TRUE_VALUE)); + point.add(build(TASTE, "" + weight)); + return point; + } + + public static DataPoint beefTomatoDish(int weight) { + DataPoint point = new DataPoint(); + point.add(build(CONTAINS_BEEF, TRUE_VALUE)); + point.add(build(CONTAINS_TOMATOS, TRUE_VALUE)); + point.add(build(TASTE, "" + weight)); + return point; + } + + public static DataPoint porkBeefDish(int weight) { + DataPoint point = new DataPoint(); + point.add(build(CONTAINS_BEEF, TRUE_VALUE)); + point.add(build(CONTAINS_PORK, TRUE_VALUE)); + point.add(build(TASTE, "" + weight)); + return point; + } + + public static DataPoint porkPotatoDish(int weight) { + DataPoint point = new DataPoint(); + point.add(build(CONTAINS_PORK, TRUE_VALUE)); + point.add(build(CONTAINS_POTATOS, TRUE_VALUE)); + point.add(build(TASTE, "" + weight)); + return point; + } + + public static DataPoint porkTomatoDish(int weight) { + DataPoint point = new DataPoint(); + point.add(build(CONTAINS_PORK, TRUE_VALUE)); + point.add(build(CONTAINS_TOMATOS, TRUE_VALUE)); + point.add(build(TASTE, "" + weight)); + return point; + } + + public static DataPoint potatoTomato(int weight) { + DataPoint point = new DataPoint(); + point.add(build(CONTAINS_POTATOS, TRUE_VALUE)); + point.add(build(CONTAINS_TOMATOS, TRUE_VALUE)); + point.add(build(TASTE, "" + weight)); + return point; + } + + public static DataPoint normalize(DataPoint point, Ontology onto) { + // resolve onthology + DataPoint normPoint = new DataPoint(point); + for (String key : point.keySet()) { + if (onto.getInheritance().containsKey(key)) { + normPoint + .add(build(onto.getInheritance().get(key), TRUE_VALUE)); + } + } + + // implement closed world assumption + // everything unknown is false + for (String variable : VARIABLES) { + if (!normPoint.containsKey(variable)) { + normPoint.add(build(variable, FALSE_VALUE)); + } + } + return normPoint; + } + + public static DataSet examples() { + DataSet data = new DataSet(); + Ontology onto = onto(); + // user one ratings + data.add(normalize(porkTomatoDish(10), onto)); + data.add(normalize(porkPotatoDish(9), onto)); + data.add(normalize(beefTomatoDish(3), onto)); + data.add(normalize(beefPotatoDish(0), onto)); + data.add(normalize(potatoTomato(0), onto)); + data.add(normalize(porkBeefDish(0), onto)); + + // user two ratings + data.add(normalize(porkTomatoDish(10), onto)); + data.add(normalize(porkTomatoDish(8), onto)); + data.add(normalize(porkPotatoDish(10), onto)); + data.add(normalize(beefTomatoDish(0), onto)); + data.add(normalize(beefPotatoDish(1), onto)); + data.add(normalize(potatoTomato(7), onto)); + data.add(normalize(porkBeefDish(10), onto)); + + // user three ratings + data.add(normalize(porkTomatoDish(10), onto)); + data.add(normalize(porkPotatoDish(10), onto)); + data.add(normalize(beefTomatoDish(3), onto)); + data.add(normalize(beefPotatoDish(3), onto)); + data.add(normalize(potatoTomato(3), onto)); + data.add(normalize(porkBeefDish(4), onto)); + + return data; + } + + public static ProbabilityDistribution vegi() { + String names[] = { SOMEONE_VEGETARIAN }; + ProbabilityTable table = new ProbabilityTable(names); + table.setProbabilityForAssignment("true;", new Probability(0)); + table.setProbabilityForAssignment("false;", new Probability(1)); + return table; + } + + public static ProbabilityDistribution beef() { + String names[] = { CONTAINS_MEET, CONTAINS_BEEF }; + ProbabilityTable table = new ProbabilityTable(names); + return table; + } + + public static ProbabilityDistribution pork() { + String names[] = { CONTAINS_MEET, CONTAINS_PORK }; + ProbabilityTable table = new ProbabilityTable(names); + return table; + } + + public static ProbabilityDistribution meet() { + String names[] = { SOMEONE_VEGETARIAN, CONTAINS_MEET }; + ProbabilityTable table = new ProbabilityTable(names); + return table; + } + + public static ProbabilityDistribution tomatos() { + String names[] = { CONTAINS_VEGETABLE, CONTAINS_TOMATOS }; + ProbabilityTable table = new ProbabilityTable(names); + return table; + } + + public static ProbabilityDistribution potatos() { + String names[] = { CONTAINS_VEGETABLE, CONTAINS_POTATOS }; + ProbabilityTable table = new ProbabilityTable(names); + return table; + } + + public static ProbabilityDistribution vegetables() { + String names[] = { CONTAINS_VEGETABLE, SOMEONE_VEGETARIAN }; + ProbabilityTable table = new ProbabilityTable(names); + return table; + } + + public static ProbabilityDistribution taste() { + String names[] = { TASTE, CONTAINS_BEEF, CONTAINS_PORK, + CONTAINS_POTATOS, CONTAINS_TOMATOS }; + ContinousDistribution distribution = new ContinousDistribution(names, 0); + return distribution; + } + + public static BayesNet dishNet() { + BayesNet net = new BayesNet(VARIABLES); + net.setDistribution(new Variable(SOMEONE_VEGETARIAN, DOMAIN), vegi()); + net.setDistribution(new Variable(CONTAINS_MEET, DOMAIN), meet()); + net.setDistribution(new Variable(CONTAINS_VEGETABLE, DOMAIN), + vegetables()); + net.setDistribution(new Variable(CONTAINS_BEEF, DOMAIN), beef()); + net.setDistribution(new Variable(CONTAINS_PORK, DOMAIN), pork()); + net.setDistribution(new Variable(CONTAINS_POTATOS, DOMAIN), potatos()); + net.setDistribution(new Variable(CONTAINS_TOMATOS, DOMAIN), tomatos()); + net.setDistribution(new Variable(TASTE, RATING_DOMAIN), taste()); + + Ontology onthology = onto(); + for (String category : onthology.getClasses()) { + net.connect(category, SOMEONE_VEGETARIAN); + } + + for (String thing : OBSERVED) { + net.connect(thing, onthology.getInheritance().get(thing)); + net.connect(TASTE, thing); + } + + for (String category : onthology.getClasses()) { + MaximumLikelihoodEstimation.estimate(examples(), net, category); + for (String thing : onthology.getClasses2thing().get(category)) { + MaximumLikelihoodEstimation.estimate(examples(), net, thing); + } + } + + MaximumLikelihoodEstimation.estimate(examples(), net, TASTE); + ContinousDistribution distribturion = (ContinousDistribution) net + .getNodes().get(TASTE); + distribturion.estimate(); + + return net; + } + +} diff --git a/test/net/woodyfolsom/cs6601/p2/SurveyDatasetReaderTest.java b/test/net/woodyfolsom/cs6601/p2/SurveyDatasetReaderTest.java index 7278b89..18fe860 100644 --- a/test/net/woodyfolsom/cs6601/p2/SurveyDatasetReaderTest.java +++ b/test/net/woodyfolsom/cs6601/p2/SurveyDatasetReaderTest.java @@ -32,5 +32,9 @@ public class SurveyDatasetReaderTest { assertFalse(recipe.getIngredients().contains(TYPE.RED_MEAT)); assertFalse(recipe.getIngredients().contains(TYPE.POULTRY)); assertFalse(recipe.getIngredients().contains(TYPE.SHELLFISH)); + + for (int rIndex = 0; rIndex < recipeBook.getSize(); rIndex++) { + System.out.println(recipeBook.getRecipe(rIndex).getHead().getTitle()); + } } }