Files
cs6601p2/test/dkohl/bayes/example/FoodExampleTest.java

154 lines
4.8 KiB
Java

package dkohl.bayes.example;
import static org.hamcrest.core.IsEqual.equalTo;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import java.util.List;
import org.junit.Test;
import dkohl.bayes.bayesnet.BayesNet;
import dkohl.bayes.builders.FoodExampleBuilder;
import dkohl.bayes.inference.EnumerateAll;
import dkohl.bayes.probability.ProbabilityAssignment;
import dkohl.bayes.probability.Variable;
import dkohl.bayes.probability.distribution.ContinousDistribution;
import dkohl.bayes.probability.distribution.ProbabilityTable;
/**
 * Smoke test for the dish Bayes net built by {@link FoodExampleBuilder#dishNet()}.
 *
 * <p>Dumps the probability tables of the discrete nodes and the assignments of the continuous
 * TASTE node to stdout. It then runs {@link EnumerateAll#enumerateAsk} for TASTE under the
 * "beef" and "pork" queries from {@link FoodExampleBuilder} and reports the most probable
 * rating for each. Only the pork result is asserted on.
 */
public class FoodExampleTest {

    @Test
    public void testFoodExample() {
        BayesNet net = FoodExampleBuilder.dishNet();

        printTable("VEGETARIAN: ",
                (ProbabilityTable) net.getNodes().get(FoodExampleBuilder.SOMEONE_VEGETARIAN));
        printTable("MEAT: ",
                (ProbabilityTable) net.getNodes().get(FoodExampleBuilder.CONTAINS_MEAT));
        printTable("VEGETABLES: ",
                (ProbabilityTable) net.getNodes().get(FoodExampleBuilder.CONTAINS_VEGETABLE));
        printTable("BEEF: ",
                (ProbabilityTable) net.getNodes().get(FoodExampleBuilder.CONTAINS_BEEF));
        printTable("PORK: ",
                (ProbabilityTable) net.getNodes().get(FoodExampleBuilder.CONTAINS_PORK));
        printTable("POTATOS: ",
                (ProbabilityTable) net.getNodes().get(FoodExampleBuilder.CONTAINS_POTATOS));
        printTable("TOMATOS: ",
                (ProbabilityTable) net.getNodes().get(FoodExampleBuilder.CONTAINS_TOMATOS));
        printDistribution("TASTE: ",
                (ContinousDistribution) net.getNodes().get(FoodExampleBuilder.TASTE));

        // A fresh query Variable is built per call, as before, in case enumerateAsk keeps
        // or mutates it (not visible from here).
        ProbabilityAssignment beef = mostProbable(EnumerateAll.enumerateAsk(
                new Variable(FoodExampleBuilder.TASTE, FoodExampleBuilder.RATING_DOMAIN),
                net, FoodExampleBuilder.completeQueryTasteBeef()));
        System.out.println("TASTE BEEF: " + describe(beef));

        ProbabilityAssignment pork = mostProbable(EnumerateAll.enumerateAsk(
                new Variable(FoodExampleBuilder.TASTE, FoodExampleBuilder.RATING_DOMAIN),
                net, FoodExampleBuilder.completeQueryTastePork()));
        System.out.println("TASTE PORK: " + describe(pork));

        assertTrue("Error: no TASTE rating with non-zero probability for TASTE_PORK",
                pork != null);
        double porkProbability = pork.getProbability();
        assertThat(pork.getValue(), equalTo("10"));
        assertTrue("Error: max_val for TASTE_PORK should be > 2.6%", porkProbability > 0.026);
    }

    /**
     * Prints a discrete probability table: a label line, a line with the table's variable
     * names, then one "assignment probability" line per entry.
     *
     * @param label heading printed before the table
     * @param table the table to dump
     */
    private static void printTable(String label, ProbabilityTable table) {
        System.out.println(label);
        for (String name : table.getNames()) {
            System.out.print(name + " ");
        }
        System.out.println();
        for (String assignment : table.getAssignments().keySet()) {
            System.out.println(assignment + " "
                    + table.getAssignments().get(assignment).getProbability());
        }
    }

    /**
     * Prints a continuous distribution: a label line, a line with its variable names, then
     * one "assignment value" line per entry, where the value is printed via its toString().
     *
     * @param label heading printed before the distribution
     * @param dist the distribution to dump
     */
    private static void printDistribution(String label, ContinousDistribution dist) {
        System.out.println(label);
        for (String name : dist.getNames()) {
            System.out.print(name + " ");
        }
        System.out.println();
        for (String assignment : dist.getAssignments().keySet()) {
            System.out.println(assignment + " " + dist.getAssignments().get(assignment));
        }
    }

    /**
     * Returns the assignment with the highest probability. Ties go to the first such
     * assignment in iteration order.
     *
     * @param probs the posterior returned by {@link EnumerateAll#enumerateAsk}
     * @return the most probable assignment, or {@code null} if no assignment has a
     *     probability strictly greater than zero
     */
    private static ProbabilityAssignment mostProbable(List<ProbabilityAssignment> probs) {
        ProbabilityAssignment best = null;
        double bestProbability = 0;
        for (ProbabilityAssignment candidate : probs) {
            if (candidate.getProbability() > bestProbability) {
                bestProbability = candidate.getProbability();
                best = candidate;
            }
        }
        return best;
    }

    /**
     * Formats an argmax result as "value probability". Produces "null 0.0" when there is no
     * result, which is the text the previous inline loop printed in that case.
     */
    private static String describe(ProbabilityAssignment best) {
        if (best == null) {
            return "null 0.0";
        }
        double probability = best.getProbability();
        return best.getValue() + " " + probability;
    }
}