Added Daniel's Bayes Net code. Converted example code to unit tests. Minor code clean-up.
This commit is contained in:
153
test/dkohl/bayes/example/FoodExampleTest.java
Normal file
153
test/dkohl/bayes/example/FoodExampleTest.java
Normal file
@@ -0,0 +1,153 @@
|
||||
package dkohl.bayes.example;
|
||||
|
||||
import java.util.LinkedList;
|
||||
|
||||
import static org.hamcrest.core.IsEqual.equalTo;
|
||||
import static org.junit.Assert.assertThat;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
import dkohl.bayes.bayesnet.BayesNet;
|
||||
import dkohl.bayes.example.builders.FoodExampleBuilder;
|
||||
import dkohl.bayes.inference.EnumerateAll;
|
||||
import dkohl.bayes.probability.ProbabilityAssignment;
|
||||
import dkohl.bayes.probability.Variable;
|
||||
import dkohl.bayes.probability.distribution.ContinousDistribution;
|
||||
import dkohl.bayes.probability.distribution.ProbabilityTable;
|
||||
|
||||
public class FoodExampleTest {
|
||||
|
||||
@Test
|
||||
public void testFoodExample() {
|
||||
BayesNet net = FoodExampleBuilder.dishNet();
|
||||
|
||||
System.out.println("VEGETRAIAN: ");
|
||||
ProbabilityTable table = (ProbabilityTable) net.getNodes().get(
|
||||
FoodExampleBuilder.SOMEONE_VEGETARIAN);
|
||||
for (String name : table.getNames()) {
|
||||
System.out.print(name + " ");
|
||||
}
|
||||
System.out.println();
|
||||
for (String assignment : table.getAssignments().keySet()) {
|
||||
System.out.println(assignment + " "
|
||||
+ table.getAssignments().get(assignment).getProbability());
|
||||
}
|
||||
|
||||
System.out.println("MEET: ");
|
||||
table = (ProbabilityTable) net.getNodes().get(
|
||||
FoodExampleBuilder.CONTAINS_MEET);
|
||||
for (String name : table.getNames()) {
|
||||
System.out.print(name + " ");
|
||||
}
|
||||
System.out.println();
|
||||
for (String assignment : table.getAssignments().keySet()) {
|
||||
System.out.println(assignment + " "
|
||||
+ table.getAssignments().get(assignment).getProbability());
|
||||
}
|
||||
|
||||
System.out.println("VEGETABLES: ");
|
||||
table = (ProbabilityTable) net.getNodes().get(
|
||||
FoodExampleBuilder.CONTAINS_VEGETABLE);
|
||||
for (String name : table.getNames()) {
|
||||
System.out.print(name + " ");
|
||||
}
|
||||
System.out.println();
|
||||
for (String assignment : table.getAssignments().keySet()) {
|
||||
System.out.println(assignment + " "
|
||||
+ table.getAssignments().get(assignment).getProbability());
|
||||
}
|
||||
|
||||
System.out.println("BEEF: ");
|
||||
table = (ProbabilityTable) net.getNodes().get(
|
||||
FoodExampleBuilder.CONTAINS_BEEF);
|
||||
for (String name : table.getNames()) {
|
||||
System.out.print(name + " ");
|
||||
}
|
||||
System.out.println();
|
||||
for (String assignment : table.getAssignments().keySet()) {
|
||||
System.out.println(assignment + " "
|
||||
+ table.getAssignments().get(assignment).getProbability());
|
||||
}
|
||||
|
||||
System.out.println("PORK: ");
|
||||
table = (ProbabilityTable) net.getNodes().get(
|
||||
FoodExampleBuilder.CONTAINS_PORK);
|
||||
for (String name : table.getNames()) {
|
||||
System.out.print(name + " ");
|
||||
}
|
||||
System.out.println();
|
||||
for (String assignment : table.getAssignments().keySet()) {
|
||||
System.out.println(assignment + " "
|
||||
+ table.getAssignments().get(assignment).getProbability());
|
||||
}
|
||||
|
||||
System.out.println("POTATOS: ");
|
||||
table = (ProbabilityTable) net.getNodes().get(
|
||||
FoodExampleBuilder.CONTAINS_POTATOS);
|
||||
for (String name : table.getNames()) {
|
||||
System.out.print(name + " ");
|
||||
}
|
||||
System.out.println();
|
||||
for (String assignment : table.getAssignments().keySet()) {
|
||||
System.out.println(assignment + " "
|
||||
+ table.getAssignments().get(assignment).getProbability());
|
||||
}
|
||||
|
||||
System.out.println("TOMATOS: ");
|
||||
table = (ProbabilityTable) net.getNodes().get(
|
||||
FoodExampleBuilder.CONTAINS_TOMATOS);
|
||||
for (String name : table.getNames()) {
|
||||
System.out.print(name + " ");
|
||||
}
|
||||
System.out.println();
|
||||
for (String assignment : table.getAssignments().keySet()) {
|
||||
System.out.println(assignment + " "
|
||||
+ table.getAssignments().get(assignment).getProbability());
|
||||
}
|
||||
|
||||
System.out.println("TASTE: ");
|
||||
ContinousDistribution dist = (ContinousDistribution) net.getNodes()
|
||||
.get(FoodExampleBuilder.TASTE);
|
||||
for (String name : dist.getNames()) {
|
||||
System.out.print(name + " ");
|
||||
}
|
||||
System.out.println();
|
||||
|
||||
for (String assignment : dist.getAssignments().keySet()) {
|
||||
System.out.println(assignment + " "
|
||||
+ dist.getAssignments().get(assignment));
|
||||
}
|
||||
|
||||
LinkedList<ProbabilityAssignment> probs = EnumerateAll.enumerateAsk(
|
||||
new Variable(FoodExampleBuilder.TASTE,
|
||||
FoodExampleBuilder.RATING_DOMAIN), net,
|
||||
FoodExampleBuilder.completeQueryTasteBeef());
|
||||
double max_val = 0;
|
||||
String max_arg = null;
|
||||
for (ProbabilityAssignment p : probs) {
|
||||
if (p.getProbability() > max_val) {
|
||||
max_val = p.getProbability();
|
||||
max_arg = p.getValue();
|
||||
}
|
||||
}
|
||||
System.out.println("TASE BEEF: " + max_arg + " " + max_val);
|
||||
|
||||
probs = EnumerateAll.enumerateAsk(new Variable(
|
||||
FoodExampleBuilder.TASTE, FoodExampleBuilder.RATING_DOMAIN),
|
||||
net, FoodExampleBuilder.completeQueryTastePork());
|
||||
max_val = 0;
|
||||
max_arg = null;
|
||||
for (ProbabilityAssignment p : probs) {
|
||||
if (p.getProbability() > max_val) {
|
||||
max_val = p.getProbability();
|
||||
max_arg = p.getValue();
|
||||
}
|
||||
}
|
||||
|
||||
System.out.println("TASE PORK: " + max_arg + " " + max_val);
|
||||
assertThat(max_arg, equalTo("10"));
|
||||
assertTrue("Error: max_val for TASTE_PORK should be > 2.6%", max_val > 0.026);
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user