120 lines
3.9 KiB
Java
120 lines
3.9 KiB
Java
package dkohl.bayes.estimation;
|
|
|
|
import java.util.LinkedList;
|
|
|
|
import com.google.common.base.Preconditions;
|
|
|
|
import dkohl.bayes.bayesnet.BayesNet;
|
|
import dkohl.bayes.probability.Assignment;
|
|
import dkohl.bayes.probability.Probability;
|
|
import dkohl.bayes.probability.Variable;
|
|
import dkohl.bayes.probability.distribution.ContinousDistribution;
|
|
import dkohl.bayes.probability.distribution.ProbabilityDistribution;
|
|
import dkohl.bayes.probability.distribution.ProbabilityTable;
|
|
import dkohl.bayes.probability.distribution.ProbabilityTree;
|
|
import dkohl.bayes.statistic.DataSet;
|
|
|
|
/**
|
|
* Works for fully observable bayes nets
|
|
*
|
|
* @author Daniel Kohlsdorf
|
|
*/
|
|
public class MaximumLikelihoodEstimation {
|
|
|
|
/**
|
|
* Enumerates all possible assignments for a set of variables using depth
|
|
* first + back tracking. Estimates the probability for each assignment
|
|
* given data and inserts the values in a probability function.
|
|
*
|
|
* @param target
|
|
* The target variable
|
|
* @param assignments
|
|
* A list built by traversing the tree. Adding an assignment on
|
|
* each layer
|
|
* @param variables
|
|
* All variables to assign
|
|
* @param current
|
|
* The current variable
|
|
* @param data
|
|
* the data for estimation
|
|
* @param table
|
|
* the probability table
|
|
*/
|
|
private static void enumerate(Assignment target,
|
|
LinkedList<Assignment> assignments, LinkedList<Variable> variables,
|
|
int current, DataSet data, ProbabilityDistribution dist) {
|
|
|
|
if (assignments.size() == variables.size()) {
|
|
if (dist instanceof ProbabilityTable) {
|
|
double likelihood = data.prob(target, assignments)
|
|
.getProbability();
|
|
assignments.add(target);
|
|
ProbabilityTable table = (ProbabilityTable) dist;
|
|
table.setProbabilityForAssignment(assignments, new Probability(
|
|
likelihood));
|
|
}
|
|
if (dist instanceof ProbabilityTree) {
|
|
double likelihood = data.prob(target, assignments)
|
|
.getProbability();
|
|
assignments.add(target);
|
|
|
|
ProbabilityTree tree = (ProbabilityTree) dist;
|
|
tree.setProbabilityForAssignment(assignments, new Probability(
|
|
likelihood));
|
|
}
|
|
if (dist instanceof ContinousDistribution) {
|
|
ContinousDistribution pdf = (ContinousDistribution) dist;
|
|
for (LinkedList<Assignment> assignment : data
|
|
.getAssignmentMatchesForQuery(assignments)) {
|
|
pdf.pushAssignment(assignment);
|
|
}
|
|
}
|
|
return;
|
|
}
|
|
Variable variable = variables.get(assignments.size());
|
|
for (String value : variable.getDomain()) {
|
|
LinkedList<Assignment> new_assignments = new LinkedList<Assignment>(
|
|
assignments);
|
|
new_assignments.add(new Assignment(variable, value));
|
|
enumerate(target, new_assignments, variables, current + 1, data,
|
|
dist);
|
|
}
|
|
}
|
|
|
|
public static void estimate(DataSet data, BayesNet net, String targetName) {
|
|
Variable target = null;
|
|
// Search parent variables
|
|
LinkedList<String> parentNames = net.getParents(targetName);
|
|
LinkedList<Variable> allVariables = net.getVariables();
|
|
LinkedList<Variable> parentVariables = new LinkedList<Variable>();
|
|
for (Variable variable : allVariables) {
|
|
if (parentNames.contains(variable.getName())) {
|
|
parentVariables.add(variable);
|
|
}
|
|
if (variable.getName().equals(targetName)) {
|
|
target = variable;
|
|
}
|
|
}
|
|
Preconditions.checkState(target != null, "MLE: variable " + targetName
|
|
+ "Not in net");
|
|
|
|
/**
|
|
* For all assignments of this variable list all assignments of it's
|
|
* parents and estimate the probability for each assignment using data.
|
|
*/
|
|
ProbabilityDistribution dist = net.getNodes().get(targetName);
|
|
if (dist instanceof ContinousDistribution) {
|
|
enumerate(null, new LinkedList<Assignment>(), parentVariables, 0,
|
|
data, dist);
|
|
} else {
|
|
for (String value : target.getDomain()) {
|
|
enumerate(new Assignment(target, value),
|
|
new LinkedList<Assignment>(), parentVariables, 0, data,
|
|
dist);
|
|
}
|
|
}
|
|
net.updateDistribution(target, dist);
|
|
}
|
|
|
|
}
|