Added Daniel's Bayes Net code. Converted example code to unit tests. Minor code clean-up.
This commit is contained in:
119
src/dkohl/bayes/estimation/MaximumLikelihoodEstimation.java
Normal file
119
src/dkohl/bayes/estimation/MaximumLikelihoodEstimation.java
Normal file
@@ -0,0 +1,119 @@
|
||||
package dkohl.bayes.estimation;
|
||||
|
||||
import java.util.LinkedList;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
|
||||
import dkohl.bayes.bayesnet.BayesNet;
|
||||
import dkohl.bayes.probability.Assignment;
|
||||
import dkohl.bayes.probability.Probability;
|
||||
import dkohl.bayes.probability.Variable;
|
||||
import dkohl.bayes.probability.distribution.ContinousDistribution;
|
||||
import dkohl.bayes.probability.distribution.ProbabilityDistribution;
|
||||
import dkohl.bayes.probability.distribution.ProbabilityTable;
|
||||
import dkohl.bayes.probability.distribution.ProbabilityTree;
|
||||
import dkohl.bayes.statistic.DataSet;
|
||||
|
||||
/**
|
||||
* Works for fully observable bayes nets
|
||||
*
|
||||
* @author Daniel Kohlsdorf
|
||||
*/
|
||||
public class MaximumLikelihoodEstimation {
|
||||
|
||||
/**
|
||||
* Enumerates all possible assignments for a set of variables using depth
|
||||
* first + back tracking. Estimates the probability for each assignment
|
||||
* given data and inserts the values in a probability function.
|
||||
*
|
||||
* @param target
|
||||
* The target variable
|
||||
* @param assignments
|
||||
* A list built by traversing the tree. Adding an assignment on
|
||||
* each layer
|
||||
* @param variables
|
||||
* All variables to assign
|
||||
* @param current
|
||||
* The current variable
|
||||
* @param data
|
||||
* the data for estimation
|
||||
* @param table
|
||||
* the probability table
|
||||
*/
|
||||
private static void enumerate(Assignment target,
|
||||
LinkedList<Assignment> assignments, LinkedList<Variable> variables,
|
||||
int current, DataSet data, ProbabilityDistribution dist) {
|
||||
|
||||
if (assignments.size() == variables.size()) {
|
||||
if (dist instanceof ProbabilityTable) {
|
||||
double likelihood = data.prob(target, assignments)
|
||||
.getProbability();
|
||||
assignments.add(target);
|
||||
ProbabilityTable table = (ProbabilityTable) dist;
|
||||
table.setProbabilityForAssignment(assignments, new Probability(
|
||||
likelihood));
|
||||
}
|
||||
if (dist instanceof ProbabilityTree) {
|
||||
double likelihood = data.prob(target, assignments)
|
||||
.getProbability();
|
||||
assignments.add(target);
|
||||
|
||||
ProbabilityTree tree = (ProbabilityTree) dist;
|
||||
tree.setProbabilityForAssignment(assignments, new Probability(
|
||||
likelihood));
|
||||
}
|
||||
if (dist instanceof ContinousDistribution) {
|
||||
ContinousDistribution pdf = (ContinousDistribution) dist;
|
||||
for (LinkedList<Assignment> assignment : data
|
||||
.getAssignmentMatchesForQuery(assignments)) {
|
||||
pdf.pushAssignment(assignment);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
Variable variable = variables.get(assignments.size());
|
||||
for (String value : variable.getDomain()) {
|
||||
LinkedList<Assignment> new_assignments = new LinkedList<Assignment>(
|
||||
assignments);
|
||||
new_assignments.add(new Assignment(variable, value));
|
||||
enumerate(target, new_assignments, variables, current + 1, data,
|
||||
dist);
|
||||
}
|
||||
}
|
||||
|
||||
public static void estimate(DataSet data, BayesNet net, String targetName) {
|
||||
Variable target = null;
|
||||
// Search parent variables
|
||||
LinkedList<String> parentNames = net.getParents(targetName);
|
||||
LinkedList<Variable> allVariables = net.getVariables();
|
||||
LinkedList<Variable> parentVariables = new LinkedList<Variable>();
|
||||
for (Variable variable : allVariables) {
|
||||
if (parentNames.contains(variable.getName())) {
|
||||
parentVariables.add(variable);
|
||||
}
|
||||
if (variable.getName().equals(targetName)) {
|
||||
target = variable;
|
||||
}
|
||||
}
|
||||
Preconditions.checkState(target != null, "MLE: variable " + targetName
|
||||
+ "Not in net");
|
||||
|
||||
/**
|
||||
* For all assignments of this variable list all assignments of it's
|
||||
* parents and estimate the probability for each assignment using data.
|
||||
*/
|
||||
ProbabilityDistribution dist = net.getNodes().get(targetName);
|
||||
if (dist instanceof ContinousDistribution) {
|
||||
enumerate(null, new LinkedList<Assignment>(), parentVariables, 0,
|
||||
data, dist);
|
||||
} else {
|
||||
for (String value : target.getDomain()) {
|
||||
enumerate(new Assignment(target, value),
|
||||
new LinkedList<Assignment>(), parentVariables, 0, data,
|
||||
dist);
|
||||
}
|
||||
}
|
||||
net.updateDistribution(target, dist);
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user