154 lines
4.8 KiB
Java
154 lines
4.8 KiB
Java
package dkohl.bayes.example;
|
|
|
|
import java.util.LinkedList;
|
|
|
|
import static org.hamcrest.core.IsEqual.equalTo;
|
|
import static org.junit.Assert.assertThat;
|
|
import static org.junit.Assert.assertTrue;
|
|
|
|
import org.junit.Test;
|
|
|
|
import dkohl.bayes.bayesnet.BayesNet;
|
|
import dkohl.bayes.example.builders.FoodExampleBuilder;
|
|
import dkohl.bayes.inference.EnumerateAll;
|
|
import dkohl.bayes.probability.ProbabilityAssignment;
|
|
import dkohl.bayes.probability.Variable;
|
|
import dkohl.bayes.probability.distribution.ContinousDistribution;
|
|
import dkohl.bayes.probability.distribution.ProbabilityTable;
|
|
|
|
public class FoodExampleTest {
|
|
|
|
@Test
|
|
public void testFoodExample() {
|
|
BayesNet net = FoodExampleBuilder.dishNet();
|
|
|
|
System.out.println("VEGETRAIAN: ");
|
|
ProbabilityTable table = (ProbabilityTable) net.getNodes().get(
|
|
FoodExampleBuilder.SOMEONE_VEGETARIAN);
|
|
for (String name : table.getNames()) {
|
|
System.out.print(name + " ");
|
|
}
|
|
System.out.println();
|
|
for (String assignment : table.getAssignments().keySet()) {
|
|
System.out.println(assignment + " "
|
|
+ table.getAssignments().get(assignment).getProbability());
|
|
}
|
|
|
|
System.out.println("MEAT: ");
|
|
table = (ProbabilityTable) net.getNodes().get(
|
|
FoodExampleBuilder.CONTAINS_MEAT);
|
|
for (String name : table.getNames()) {
|
|
System.out.print(name + " ");
|
|
}
|
|
System.out.println();
|
|
for (String assignment : table.getAssignments().keySet()) {
|
|
System.out.println(assignment + " "
|
|
+ table.getAssignments().get(assignment).getProbability());
|
|
}
|
|
|
|
System.out.println("VEGETABLES: ");
|
|
table = (ProbabilityTable) net.getNodes().get(
|
|
FoodExampleBuilder.CONTAINS_VEGETABLE);
|
|
for (String name : table.getNames()) {
|
|
System.out.print(name + " ");
|
|
}
|
|
System.out.println();
|
|
for (String assignment : table.getAssignments().keySet()) {
|
|
System.out.println(assignment + " "
|
|
+ table.getAssignments().get(assignment).getProbability());
|
|
}
|
|
|
|
System.out.println("BEEF: ");
|
|
table = (ProbabilityTable) net.getNodes().get(
|
|
FoodExampleBuilder.CONTAINS_BEEF);
|
|
for (String name : table.getNames()) {
|
|
System.out.print(name + " ");
|
|
}
|
|
System.out.println();
|
|
for (String assignment : table.getAssignments().keySet()) {
|
|
System.out.println(assignment + " "
|
|
+ table.getAssignments().get(assignment).getProbability());
|
|
}
|
|
|
|
System.out.println("PORK: ");
|
|
table = (ProbabilityTable) net.getNodes().get(
|
|
FoodExampleBuilder.CONTAINS_PORK);
|
|
for (String name : table.getNames()) {
|
|
System.out.print(name + " ");
|
|
}
|
|
System.out.println();
|
|
for (String assignment : table.getAssignments().keySet()) {
|
|
System.out.println(assignment + " "
|
|
+ table.getAssignments().get(assignment).getProbability());
|
|
}
|
|
|
|
System.out.println("POTATOS: ");
|
|
table = (ProbabilityTable) net.getNodes().get(
|
|
FoodExampleBuilder.CONTAINS_POTATOS);
|
|
for (String name : table.getNames()) {
|
|
System.out.print(name + " ");
|
|
}
|
|
System.out.println();
|
|
for (String assignment : table.getAssignments().keySet()) {
|
|
System.out.println(assignment + " "
|
|
+ table.getAssignments().get(assignment).getProbability());
|
|
}
|
|
|
|
System.out.println("TOMATOS: ");
|
|
table = (ProbabilityTable) net.getNodes().get(
|
|
FoodExampleBuilder.CONTAINS_TOMATOS);
|
|
for (String name : table.getNames()) {
|
|
System.out.print(name + " ");
|
|
}
|
|
System.out.println();
|
|
for (String assignment : table.getAssignments().keySet()) {
|
|
System.out.println(assignment + " "
|
|
+ table.getAssignments().get(assignment).getProbability());
|
|
}
|
|
|
|
System.out.println("TASTE: ");
|
|
ContinousDistribution dist = (ContinousDistribution) net.getNodes()
|
|
.get(FoodExampleBuilder.TASTE);
|
|
for (String name : dist.getNames()) {
|
|
System.out.print(name + " ");
|
|
}
|
|
System.out.println();
|
|
|
|
for (String assignment : dist.getAssignments().keySet()) {
|
|
System.out.println(assignment + " "
|
|
+ dist.getAssignments().get(assignment));
|
|
}
|
|
|
|
LinkedList<ProbabilityAssignment> probs = EnumerateAll.enumerateAsk(
|
|
new Variable(FoodExampleBuilder.TASTE,
|
|
FoodExampleBuilder.RATING_DOMAIN), net,
|
|
FoodExampleBuilder.completeQueryTasteBeef());
|
|
double max_val = 0;
|
|
String max_arg = null;
|
|
for (ProbabilityAssignment p : probs) {
|
|
if (p.getProbability() > max_val) {
|
|
max_val = p.getProbability();
|
|
max_arg = p.getValue();
|
|
}
|
|
}
|
|
System.out.println("TASE BEEF: " + max_arg + " " + max_val);
|
|
|
|
probs = EnumerateAll.enumerateAsk(new Variable(
|
|
FoodExampleBuilder.TASTE, FoodExampleBuilder.RATING_DOMAIN),
|
|
net, FoodExampleBuilder.completeQueryTastePork());
|
|
max_val = 0;
|
|
max_arg = null;
|
|
for (ProbabilityAssignment p : probs) {
|
|
if (p.getProbability() > max_val) {
|
|
max_val = p.getProbability();
|
|
max_arg = p.getValue();
|
|
}
|
|
}
|
|
|
|
System.out.println("TASE PORK: " + max_arg + " " + max_val);
|
|
assertThat(max_arg, equalTo("10"));
|
|
assertTrue("Error: max_val for TASTE_PORK should be > 2.6%", max_val > 0.026);
|
|
}
|
|
|
|
}
|