119 lines
3.4 KiB
Java
119 lines
3.4 KiB
Java
package dkohl.bayes.inference;
|
|
|
|
import java.util.LinkedList;
|
|
import java.util.List;
|
|
|
|
import dkohl.bayes.bayesnet.BayesNet;
|
|
import dkohl.bayes.probability.Assignment;
|
|
import dkohl.bayes.probability.ProbabilityAssignment;
|
|
import dkohl.bayes.probability.Variable;
|
|
|
|
/**
|
|
* Enumeration Algorithm: Exact inference in Baysian Networks.
|
|
*
|
|
* @author Daniel Kohlsdorf
|
|
*
|
|
*/
|
|
public class EnumerateAll {
|
|
|
|
/**
|
|
* The probability distribution of a variable.
|
|
*
|
|
* @param query
|
|
* the variable
|
|
* @param net
|
|
* the bayes net defining the independence
|
|
* @param assignments
|
|
* a set of assignments in this net.
|
|
* @return
|
|
*/
|
|
public static List<ProbabilityAssignment> enumerateAsk(
|
|
Variable query, BayesNet net, List<Assignment> assignments) {
|
|
LinkedList<Variable> variables = net.getVariables();
|
|
LinkedList<ProbabilityAssignment> result = new LinkedList<ProbabilityAssignment>();
|
|
|
|
// Evaluate probability for each possible
|
|
// value of the variable by enumeration
|
|
for (String value : query.getDomain()) {
|
|
LinkedList<Assignment> temp = new LinkedList<Assignment>();
|
|
temp.addAll(assignments);
|
|
temp.add(new Assignment(query, value));
|
|
double prob = enumerateAll(net, variables, temp);
|
|
result.add(new ProbabilityAssignment(query, value, prob));
|
|
}
|
|
return result;
|
|
}
|
|
|
|
/**
|
|
* Decides if a variable is hidden or not. A variable is hidden if it is not
|
|
* assigned.
|
|
*
|
|
* @param variable
|
|
* the variable in question.
|
|
* @param assignments
|
|
* all the assignments
|
|
* @return true if not assigned
|
|
*/
|
|
private static boolean hidden(Variable variable,
|
|
LinkedList<Assignment> assignments) {
|
|
for (Assignment assignment : assignments) {
|
|
if (assignment.getVariable().getName().equals(variable.getName())) {
|
|
return false;
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
/**
|
|
* Recursively evaluate probability of an assignment.
|
|
*
|
|
* Can be seen as depth first search + backtracking (branching on hidden
|
|
* nodes)
|
|
*
|
|
* @param net
|
|
* a bayes net defining independence
|
|
* @param variables
|
|
* the variables left to evaluate
|
|
* @param assignments
|
|
* all assignments
|
|
* @return
|
|
*/
|
|
public static double enumerateAll(BayesNet net, List<Variable> variables,
|
|
LinkedList<Assignment> assignments) {
|
|
// if no variables left to evaluate,
|
|
// leaf node reached.
|
|
if (variables.isEmpty()) {
|
|
return 1;
|
|
}
|
|
// evaluate variable, recurse on rest
|
|
// PROLOG: [Variable|Rest].
|
|
Variable variable = variables.get(0);
|
|
List<Variable> rest = variables.subList(1, variables.size());
|
|
// if current variable is hidden
|
|
if (hidden(variable, assignments)) {
|
|
// sum out all possible values for that variable
|
|
double sumOut = 0;
|
|
for (String value : variable.getDomain()) {
|
|
// by temporarily adding each value to the asigned variable set
|
|
LinkedList<Assignment> temp = new LinkedList<Assignment>();
|
|
temp.addAll(assignments);
|
|
temp.add(new Assignment(variable, value));
|
|
|
|
// then evaluate this variable
|
|
double val = net.getNodes().get(variable.getName()).eval(temp)
|
|
.getProbability();
|
|
// and all that depend on it
|
|
val *= enumerateAll(net, rest, temp);
|
|
|
|
sumOut += val;
|
|
}
|
|
return sumOut;
|
|
}
|
|
// if not just evaluate variable and continue.
|
|
return net.getNodes().get(variable.getName()).eval(assignments)
|
|
.getProbability()
|
|
* enumerateAll(net, rest, assignments);
|
|
}
|
|
|
|
}
|