package dkohl.bayes.estimation; import java.util.LinkedList; import com.google.common.base.Preconditions; import dkohl.bayes.bayesnet.BayesNet; import dkohl.bayes.probability.Assignment; import dkohl.bayes.probability.Probability; import dkohl.bayes.probability.Variable; import dkohl.bayes.probability.distribution.ContinousDistribution; import dkohl.bayes.probability.distribution.ProbabilityDistribution; import dkohl.bayes.probability.distribution.ProbabilityTable; import dkohl.bayes.probability.distribution.ProbabilityTree; import dkohl.bayes.statistic.DataSet; /** * Works for fully observable bayes nets * * @author Daniel Kohlsdorf */ public class MaximumLikelihoodEstimation { /** * Enumerates all possible assignments for a set of variables using depth * first + back tracking. Estimates the probability for each assignment * given data and inserts the values in a probability function. * * @param target * The target variable * @param assignments * A list built by traversing the tree. Adding an assignment on * each layer * @param variables * All variables to assign * @param current * The current variable * @param data * the data for estimation * @param table * the probability table */ private static void enumerate(Assignment target, LinkedList assignments, LinkedList variables, int current, DataSet data, ProbabilityDistribution dist) { if (assignments.size() == variables.size()) { if (dist instanceof ProbabilityTable) { double likelihood = data.prob(target, assignments) .getProbability(); assignments.add(target); ProbabilityTable table = (ProbabilityTable) dist; table.setProbabilityForAssignment(assignments, new Probability( likelihood)); } if (dist instanceof ProbabilityTree) { double likelihood = data.prob(target, assignments) .getProbability(); assignments.add(target); ProbabilityTree tree = (ProbabilityTree) dist; tree.setProbabilityForAssignment(assignments, new Probability( likelihood)); } if (dist instanceof ContinousDistribution) { ContinousDistribution pdf = (ContinousDistribution) dist; for (LinkedList assignment : data .getAssignmentMatchesForQuery(assignments)) { pdf.pushAssignment(assignment); } } return; } Variable variable = variables.get(assignments.size()); for (String value : variable.getDomain()) { LinkedList new_assignments = new LinkedList( assignments); new_assignments.add(new Assignment(variable, value)); enumerate(target, new_assignments, variables, current + 1, data, dist); } } public static void estimate(DataSet data, BayesNet net, String targetName) { Variable target = null; // Search parent variables LinkedList parentNames = net.getParents(targetName); LinkedList allVariables = net.getVariables(); LinkedList parentVariables = new LinkedList(); for (Variable variable : allVariables) { if (parentNames.contains(variable.getName())) { parentVariables.add(variable); } if (variable.getName().equals(targetName)) { target = variable; } } Preconditions.checkState(target != null, "MLE: variable " + targetName + "Not in net"); /** * For all assignments of this variable list all assignments of it's * parents and estimate the probability for each assignment using data. */ ProbabilityDistribution dist = net.getNodes().get(targetName); if (dist instanceof ContinousDistribution) { enumerate(null, new LinkedList(), parentVariables, 0, data, dist); } else { for (String value : target.getDomain()) { enumerate(new Assignment(target, value), new LinkedList(), parentVariables, 0, data, dist); } } net.updateDistribution(target, dist); } }