package dkohl.bayes.example; import java.util.LinkedList; import static org.hamcrest.core.IsEqual.equalTo; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; import org.junit.Test; import dkohl.bayes.bayesnet.BayesNet; import dkohl.bayes.example.builders.FoodExampleBuilder; import dkohl.bayes.inference.EnumerateAll; import dkohl.bayes.probability.ProbabilityAssignment; import dkohl.bayes.probability.Variable; import dkohl.bayes.probability.distribution.ContinousDistribution; import dkohl.bayes.probability.distribution.ProbabilityTable; public class FoodExampleTest { @Test public void testFoodExample() { BayesNet net = FoodExampleBuilder.dishNet(); System.out.println("VEGETRAIAN: "); ProbabilityTable table = (ProbabilityTable) net.getNodes().get( FoodExampleBuilder.SOMEONE_VEGETARIAN); for (String name : table.getNames()) { System.out.print(name + " "); } System.out.println(); for (String assignment : table.getAssignments().keySet()) { System.out.println(assignment + " " + table.getAssignments().get(assignment).getProbability()); } System.out.println("MEAT: "); table = (ProbabilityTable) net.getNodes().get( FoodExampleBuilder.CONTAINS_MEAT); for (String name : table.getNames()) { System.out.print(name + " "); } System.out.println(); for (String assignment : table.getAssignments().keySet()) { System.out.println(assignment + " " + table.getAssignments().get(assignment).getProbability()); } System.out.println("VEGETABLES: "); table = (ProbabilityTable) net.getNodes().get( FoodExampleBuilder.CONTAINS_VEGETABLE); for (String name : table.getNames()) { System.out.print(name + " "); } System.out.println(); for (String assignment : table.getAssignments().keySet()) { System.out.println(assignment + " " + table.getAssignments().get(assignment).getProbability()); } System.out.println("BEEF: "); table = (ProbabilityTable) net.getNodes().get( FoodExampleBuilder.CONTAINS_BEEF); for (String name : table.getNames()) { System.out.print(name + " "); } System.out.println(); for (String assignment : table.getAssignments().keySet()) { System.out.println(assignment + " " + table.getAssignments().get(assignment).getProbability()); } System.out.println("PORK: "); table = (ProbabilityTable) net.getNodes().get( FoodExampleBuilder.CONTAINS_PORK); for (String name : table.getNames()) { System.out.print(name + " "); } System.out.println(); for (String assignment : table.getAssignments().keySet()) { System.out.println(assignment + " " + table.getAssignments().get(assignment).getProbability()); } System.out.println("POTATOS: "); table = (ProbabilityTable) net.getNodes().get( FoodExampleBuilder.CONTAINS_POTATOS); for (String name : table.getNames()) { System.out.print(name + " "); } System.out.println(); for (String assignment : table.getAssignments().keySet()) { System.out.println(assignment + " " + table.getAssignments().get(assignment).getProbability()); } System.out.println("TOMATOS: "); table = (ProbabilityTable) net.getNodes().get( FoodExampleBuilder.CONTAINS_TOMATOS); for (String name : table.getNames()) { System.out.print(name + " "); } System.out.println(); for (String assignment : table.getAssignments().keySet()) { System.out.println(assignment + " " + table.getAssignments().get(assignment).getProbability()); } System.out.println("TASTE: "); ContinousDistribution dist = (ContinousDistribution) net.getNodes() .get(FoodExampleBuilder.TASTE); for (String name : dist.getNames()) { System.out.print(name + " "); } System.out.println(); for (String assignment : dist.getAssignments().keySet()) { System.out.println(assignment + " " + dist.getAssignments().get(assignment)); } LinkedList probs = EnumerateAll.enumerateAsk( new Variable(FoodExampleBuilder.TASTE, FoodExampleBuilder.RATING_DOMAIN), net, FoodExampleBuilder.completeQueryTasteBeef()); double max_val = 0; String max_arg = null; for (ProbabilityAssignment p : probs) { if (p.getProbability() > max_val) { max_val = p.getProbability(); max_arg = p.getValue(); } } System.out.println("TASE BEEF: " + max_arg + " " + max_val); probs = EnumerateAll.enumerateAsk(new Variable( FoodExampleBuilder.TASTE, FoodExampleBuilder.RATING_DOMAIN), net, FoodExampleBuilder.completeQueryTastePork()); max_val = 0; max_arg = null; for (ProbabilityAssignment p : probs) { if (p.getProbability() > max_val) { max_val = p.getProbability(); max_arg = p.getValue(); } } System.out.println("TASE PORK: " + max_arg + " " + max_val); assertThat(max_arg, equalTo("10")); assertTrue("Error: max_val for TASTE_PORK should be > 2.6%", max_val > 0.026); } }