Files
cs6601p2/src/dkohl/bayes/estimation/MaximumLikelihoodEstimation.java

120 lines
3.9 KiB
Java

package dkohl.bayes.estimation;
import java.util.LinkedList;
import com.google.common.base.Preconditions;
import dkohl.bayes.bayesnet.BayesNet;
import dkohl.bayes.probability.Assignment;
import dkohl.bayes.probability.Probability;
import dkohl.bayes.probability.Variable;
import dkohl.bayes.probability.distribution.ContinousDistribution;
import dkohl.bayes.probability.distribution.ProbabilityDistribution;
import dkohl.bayes.probability.distribution.ProbabilityTable;
import dkohl.bayes.probability.distribution.ProbabilityTree;
import dkohl.bayes.statistic.DataSet;
/**
 * Maximum likelihood estimation of a single node's distribution in a bayes
 * net from a data set. Works for fully observable bayes nets.
 *
 * @author Daniel Kohlsdorf
 */
public class MaximumLikelihoodEstimation {

    /**
     * Enumerates all possible assignments for a set of variables using depth
     * first + back tracking. Each recursion level assigns the next variable
     * (the one at index {@code assignments.size()}). Once every variable is
     * assigned, the distribution is updated from the data:
     * <ul>
     * <li>{@link ProbabilityTable} / {@link ProbabilityTree}: the value
     * returned by {@code data.prob(target, assignments)} is stored for the
     * assignment list extended by {@code target}.</li>
     * <li>{@link ContinousDistribution}: every entry returned by
     * {@code data.getAssignmentMatchesForQuery(assignments)} is pushed into
     * the distribution; {@code target} is not used.</li>
     * </ul>
     *
     * @param target
     *            The assignment of the target variable; may be
     *            {@code null} for continuous distributions
     * @param assignments
     *            A list built by traversing the tree. Adding an assignment
     *            on each layer. On a table / tree leaf, {@code target} is
     *            appended to this list in place.
     * @param variables
     *            All variables to assign
     * @param data
     *            the data for estimation
     * @param dist
     *            the distribution that receives the estimates
     */
    private static void enumerate(Assignment target,
            LinkedList<Assignment> assignments, LinkedList<Variable> variables,
            DataSet data, ProbabilityDistribution dist) {
        // Leaf: every variable has a value, so update the distribution.
        if (assignments.size() == variables.size()) {
            if (dist instanceof ProbabilityTable) {
                // Query the data BEFORE appending the target: the list must
                // only hold the conditioning assignments at this point.
                double likelihood = data.prob(target, assignments)
                        .getProbability();
                assignments.add(target);
                ProbabilityTable table = (ProbabilityTable) dist;
                table.setProbabilityForAssignment(assignments,
                        new Probability(likelihood));
            }
            if (dist instanceof ProbabilityTree) {
                double likelihood = data.prob(target, assignments)
                        .getProbability();
                assignments.add(target);
                ProbabilityTree tree = (ProbabilityTree) dist;
                tree.setProbabilityForAssignment(assignments,
                        new Probability(likelihood));
            }
            if (dist instanceof ContinousDistribution) {
                ContinousDistribution pdf = (ContinousDistribution) dist;
                for (LinkedList<Assignment> match : data
                        .getAssignmentMatchesForQuery(assignments)) {
                    pdf.pushAssignment(match);
                }
            }
            return;
        }

        // Inner node: branch on every value of the next unassigned variable.
        // Each branch works on its own copy, so siblings never see each
        // other's assignments (this is the "back tracking").
        Variable variable = variables.get(assignments.size());
        for (String value : variable.getDomain()) {
            LinkedList<Assignment> extended = new LinkedList<Assignment>(
                    assignments);
            extended.add(new Assignment(variable, value));
            enumerate(target, extended, variables, data, dist);
        }
    }

    /**
     * Estimates the distribution of one node from data and writes it back
     * into the net via {@code net.updateDistribution}.
     *
     * @param data
     *            the data for estimation
     * @param net
     *            the bayes net containing the node; its distribution object
     *            for the node is modified in place
     * @param targetName
     *            name of the variable whose distribution is estimated
     * @throws IllegalStateException
     *             if no variable named {@code targetName} is in the net
     */
    public static void estimate(DataSet data, BayesNet net, String targetName) {
        Variable target = null;

        // Search parent variables. Note: the parents end up in the order of
        // net.getVariables(), not in the order of net.getParents().
        LinkedList<String> parentNames = net.getParents(targetName);
        LinkedList<Variable> allVariables = net.getVariables();
        LinkedList<Variable> parentVariables = new LinkedList<Variable>();
        for (Variable variable : allVariables) {
            if (parentNames.contains(variable.getName())) {
                parentVariables.add(variable);
            }
            if (variable.getName().equals(targetName)) {
                target = variable;
            }
        }
        // Template form: the message is only built when the check fails, and
        // the variable name is properly separated from the text.
        Preconditions.checkState(target != null,
                "MLE: variable %s not in net", targetName);

        // For all assignments of this variable list all assignments of its
        // parents and estimate the probability for each assignment using
        // data. Continuous distributions have no discrete target values to
        // iterate, so they are filled in a single pass over the parents.
        ProbabilityDistribution dist = net.getNodes().get(targetName);
        if (dist instanceof ContinousDistribution) {
            enumerate(null, new LinkedList<Assignment>(), parentVariables,
                    data, dist);
        } else {
            for (String value : target.getDomain()) {
                enumerate(new Assignment(target, value),
                        new LinkedList<Assignment>(), parentVariables, data,
                        dist);
            }
        }
        net.updateDistribution(target, dist);
    }
}