Added Daniel's Bayes Net code. Converted example code to unit tests. Minor code clean-up.
This commit is contained in:
118
src/dkohl/bayes/inference/EnumerateAll.java
Normal file
118
src/dkohl/bayes/inference/EnumerateAll.java
Normal file
@@ -0,0 +1,118 @@
|
||||
package dkohl.bayes.inference;
|
||||
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
|
||||
import dkohl.bayes.bayesnet.BayesNet;
|
||||
import dkohl.bayes.probability.Assignment;
|
||||
import dkohl.bayes.probability.ProbabilityAssignment;
|
||||
import dkohl.bayes.probability.Variable;
|
||||
|
||||
/**
|
||||
* Enumeration Algorithm: Exact inference in Baysian Networks.
|
||||
*
|
||||
* @author Daniel Kohlsdorf
|
||||
*
|
||||
*/
|
||||
public class EnumerateAll {
|
||||
|
||||
/**
|
||||
* The probability distribution of a variable.
|
||||
*
|
||||
* @param query
|
||||
* the variable
|
||||
* @param net
|
||||
* the bayes net defining the independence
|
||||
* @param assignments
|
||||
* a set of assignments in this net.
|
||||
* @return
|
||||
*/
|
||||
public static LinkedList<ProbabilityAssignment> enumerateAsk(
|
||||
Variable query, BayesNet net, LinkedList<Assignment> assignments) {
|
||||
LinkedList<Variable> variables = net.getVariables();
|
||||
LinkedList<ProbabilityAssignment> result = new LinkedList<ProbabilityAssignment>();
|
||||
|
||||
// Evaluate probability for each possible
|
||||
// value of the variable by enumeration
|
||||
for (String value : query.getDomain()) {
|
||||
LinkedList<Assignment> temp = new LinkedList<Assignment>();
|
||||
temp.addAll(assignments);
|
||||
temp.add(new Assignment(query, value));
|
||||
double prob = enumerateAll(net, variables, temp);
|
||||
result.add(new ProbabilityAssignment(query, value, prob));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Decides if a variable is hidden or not. A variable is hidden if it is not
|
||||
* assigned.
|
||||
*
|
||||
* @param variable
|
||||
* the variable in question.
|
||||
* @param assignments
|
||||
* all the assignments
|
||||
* @return true if not assigned
|
||||
*/
|
||||
private static boolean hidden(Variable variable,
|
||||
LinkedList<Assignment> assignments) {
|
||||
for (Assignment assignment : assignments) {
|
||||
if (assignment.getVariable().getName().equals(variable.getName())) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Recursively evaluate probability of an assignment.
|
||||
*
|
||||
* Can be seen as depth first search + backtracking (branching on hidden
|
||||
* nodes)
|
||||
*
|
||||
* @param net
|
||||
* a bayes net defining independence
|
||||
* @param variables
|
||||
* the variables left to evaluate
|
||||
* @param assignments
|
||||
* all assignments
|
||||
* @return
|
||||
*/
|
||||
public static double enumerateAll(BayesNet net, List<Variable> variables,
|
||||
LinkedList<Assignment> assignments) {
|
||||
// if no variables left to evaluate,
|
||||
// leaf node reached.
|
||||
if (variables.isEmpty()) {
|
||||
return 1;
|
||||
}
|
||||
// evaluate variable, recurse on rest
|
||||
// PROLOG: [Variable|Rest].
|
||||
Variable variable = variables.get(0);
|
||||
List<Variable> rest = variables.subList(1, variables.size());
|
||||
// if current variable is hidden
|
||||
if (hidden(variable, assignments)) {
|
||||
// sum out all possible values for that variable
|
||||
double sumOut = 0;
|
||||
for (String value : variable.getDomain()) {
|
||||
// by temporarily adding each value to the asigned variable set
|
||||
LinkedList<Assignment> temp = new LinkedList<Assignment>();
|
||||
temp.addAll(assignments);
|
||||
temp.add(new Assignment(variable, value));
|
||||
|
||||
// then evaluate this variable
|
||||
double val = net.getNodes().get(variable.getName()).eval(temp)
|
||||
.getProbability();
|
||||
// and all that depend on it
|
||||
val *= enumerateAll(net, rest, temp);
|
||||
|
||||
sumOut += val;
|
||||
}
|
||||
return sumOut;
|
||||
}
|
||||
// if not just evaluate variable and continue.
|
||||
return net.getNodes().get(variable.getName()).eval(assignments)
|
||||
.getProbability()
|
||||
* enumerateAll(net, rest, assignments);
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user