Added Daniel's Bayes Net code. Converted example code to unit tests. Minor code clean-up.

This commit is contained in:
Woody Folsom
2012-03-11 10:33:45 -04:00
parent a021dc2fc0
commit 571d0a1922
27 changed files with 2310 additions and 0 deletions

View File

@@ -0,0 +1,39 @@
package dkohl.bayes.example;
import static org.hamcrest.core.IsEqual.equalTo;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import java.util.LinkedList;
import org.junit.Test;
import dkohl.bayes.bayesnet.BayesNet;
import dkohl.bayes.example.builders.AlarmNetBuilderTable;
import dkohl.bayes.example.builders.AlarmNetBuilderTree;
import dkohl.bayes.inference.EnumerateAll;
import dkohl.bayes.probability.ProbabilityAssignment;
import dkohl.bayes.probability.Variable;
public class AlarmExampleTest {
@Test
public void testAlarmExample() {
BayesNet sprinkler = AlarmNetBuilderTable.alarm();
sprinkler = AlarmNetBuilderTree.sprinkler();
// P(B | j, m)
LinkedList<ProbabilityAssignment> probs = EnumerateAll.enumerateAsk(
new Variable(AlarmNetBuilderTable.BURGLARY,
AlarmNetBuilderTable.DOMAIN), sprinkler,
AlarmNetBuilderTable.completeQueryBulgary());
System.out.print("Burglary: <");
for (ProbabilityAssignment p : probs) {
System.out.print("p: "+ p.toString() + ",");
}
System.out.println(">");
//assert that burglary is more likely to be false
assertThat(probs.size(),equalTo(2));
assertTrue(probs.get(0).getProbability() < probs.get(1).getProbability());
}
}

View File

@@ -0,0 +1,153 @@
package dkohl.bayes.example;
import java.util.LinkedList;
import static org.hamcrest.core.IsEqual.equalTo;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import org.junit.Test;
import dkohl.bayes.bayesnet.BayesNet;
import dkohl.bayes.example.builders.FoodExampleBuilder;
import dkohl.bayes.inference.EnumerateAll;
import dkohl.bayes.probability.ProbabilityAssignment;
import dkohl.bayes.probability.Variable;
import dkohl.bayes.probability.distribution.ContinousDistribution;
import dkohl.bayes.probability.distribution.ProbabilityTable;
public class FoodExampleTest {
@Test
public void testFoodExample() {
BayesNet net = FoodExampleBuilder.dishNet();
System.out.println("VEGETRAIAN: ");
ProbabilityTable table = (ProbabilityTable) net.getNodes().get(
FoodExampleBuilder.SOMEONE_VEGETARIAN);
for (String name : table.getNames()) {
System.out.print(name + " ");
}
System.out.println();
for (String assignment : table.getAssignments().keySet()) {
System.out.println(assignment + " "
+ table.getAssignments().get(assignment).getProbability());
}
System.out.println("MEET: ");
table = (ProbabilityTable) net.getNodes().get(
FoodExampleBuilder.CONTAINS_MEET);
for (String name : table.getNames()) {
System.out.print(name + " ");
}
System.out.println();
for (String assignment : table.getAssignments().keySet()) {
System.out.println(assignment + " "
+ table.getAssignments().get(assignment).getProbability());
}
System.out.println("VEGETABLES: ");
table = (ProbabilityTable) net.getNodes().get(
FoodExampleBuilder.CONTAINS_VEGETABLE);
for (String name : table.getNames()) {
System.out.print(name + " ");
}
System.out.println();
for (String assignment : table.getAssignments().keySet()) {
System.out.println(assignment + " "
+ table.getAssignments().get(assignment).getProbability());
}
System.out.println("BEEF: ");
table = (ProbabilityTable) net.getNodes().get(
FoodExampleBuilder.CONTAINS_BEEF);
for (String name : table.getNames()) {
System.out.print(name + " ");
}
System.out.println();
for (String assignment : table.getAssignments().keySet()) {
System.out.println(assignment + " "
+ table.getAssignments().get(assignment).getProbability());
}
System.out.println("PORK: ");
table = (ProbabilityTable) net.getNodes().get(
FoodExampleBuilder.CONTAINS_PORK);
for (String name : table.getNames()) {
System.out.print(name + " ");
}
System.out.println();
for (String assignment : table.getAssignments().keySet()) {
System.out.println(assignment + " "
+ table.getAssignments().get(assignment).getProbability());
}
System.out.println("POTATOS: ");
table = (ProbabilityTable) net.getNodes().get(
FoodExampleBuilder.CONTAINS_POTATOS);
for (String name : table.getNames()) {
System.out.print(name + " ");
}
System.out.println();
for (String assignment : table.getAssignments().keySet()) {
System.out.println(assignment + " "
+ table.getAssignments().get(assignment).getProbability());
}
System.out.println("TOMATOS: ");
table = (ProbabilityTable) net.getNodes().get(
FoodExampleBuilder.CONTAINS_TOMATOS);
for (String name : table.getNames()) {
System.out.print(name + " ");
}
System.out.println();
for (String assignment : table.getAssignments().keySet()) {
System.out.println(assignment + " "
+ table.getAssignments().get(assignment).getProbability());
}
System.out.println("TASTE: ");
ContinousDistribution dist = (ContinousDistribution) net.getNodes()
.get(FoodExampleBuilder.TASTE);
for (String name : dist.getNames()) {
System.out.print(name + " ");
}
System.out.println();
for (String assignment : dist.getAssignments().keySet()) {
System.out.println(assignment + " "
+ dist.getAssignments().get(assignment));
}
LinkedList<ProbabilityAssignment> probs = EnumerateAll.enumerateAsk(
new Variable(FoodExampleBuilder.TASTE,
FoodExampleBuilder.RATING_DOMAIN), net,
FoodExampleBuilder.completeQueryTasteBeef());
double max_val = 0;
String max_arg = null;
for (ProbabilityAssignment p : probs) {
if (p.getProbability() > max_val) {
max_val = p.getProbability();
max_arg = p.getValue();
}
}
System.out.println("TASE BEEF: " + max_arg + " " + max_val);
probs = EnumerateAll.enumerateAsk(new Variable(
FoodExampleBuilder.TASTE, FoodExampleBuilder.RATING_DOMAIN),
net, FoodExampleBuilder.completeQueryTastePork());
max_val = 0;
max_arg = null;
for (ProbabilityAssignment p : probs) {
if (p.getProbability() > max_val) {
max_val = p.getProbability();
max_arg = p.getValue();
}
}
System.out.println("TASE PORK: " + max_arg + " " + max_val);
assertThat(max_arg, equalTo("10"));
assertTrue("Error: max_val for TASTE_PORK should be > 2.6%", max_val > 0.026);
}
}

View File

@@ -0,0 +1,50 @@
package dkohl.bayes.example;
import org.junit.Test;
import static org.hamcrest.core.IsEqual.equalTo;
import static org.junit.Assert.assertThat;
import dkohl.bayes.bayesnet.BayesNet;
import dkohl.bayes.estimation.MaximumLikelihoodEstimation;
import dkohl.bayes.example.builders.EstimateSprinklerNetBuilderTable;
import dkohl.bayes.probability.distribution.ProbabilityTable;
import dkohl.bayes.statistic.DataSet;
/**
* Parameter estimation for the sprinkler net example
*
* http://www.cs.ubc.ca/~murphyk/Bayes/bnintro.html
*
* @author Daniel Kohlsdorf
*/
public class SprinklerNetExampleTest {
@Test
public void testSprinklerNetExample() {
BayesNet net = EstimateSprinklerNetBuilderTable.sprinkler();
DataSet data = EstimateSprinklerNetBuilderTable.dataSet();
MaximumLikelihoodEstimation.estimate(data, net,
EstimateSprinklerNetBuilderTable.GRASS_WET);
/**
* Output PDF for grass is wet
*/
ProbabilityTable table = (ProbabilityTable) net.getNodes().get(
EstimateSprinklerNetBuilderTable.GRASS_WET);
for (String name : table.getNames()) {
System.out.print(name + " | ");
}
System.out.println();
for (String key : table.getAssignments().keySet()) {
System.out.println(key + " "
+ table.getAssignments().get(key).getProbability());
if ("false;false;true;".equals(key)) {
assertThat(table.getAssignments().get(key).getProbability(), equalTo(0.0));
}
if ("true;false;true;".equals(key)) {
assertThat(table.getAssignments().get(key).getProbability(), equalTo(0.9));
}
}
}
}

View File

@@ -0,0 +1,176 @@
package dkohl.bayes.example.builders;
import java.util.LinkedList;
import dkohl.bayes.bayesnet.BayesNet;
import dkohl.bayes.probability.Assignment;
import dkohl.bayes.probability.Probability;
import dkohl.bayes.probability.Variable;
import dkohl.bayes.probability.distribution.ProbabilityTable;
/**
* The alarm example for baysian nets.
*
* Stuart Russel, Peter Norvig: Artificial Intelligence: A Modern Approach, 3ed
* Edition, Prentice Hall, 2010
*
* @author Daniel Kohlsdorf
*/
public class AlarmNetBuilderTable {
// Variable names
public static final String BURGLARY = "Burglary";
public static final String EARTHQUAKE = "Earthquake";
public static final String ALARM = "Alarm";
public static final String JOHN = "John";
public static final String MARRY = "Marry";
// Possible outcomes
public static final String TRUE = "true";
public static final String FALSE = "false";
// A variables domain
public static final String DOMAIN[] = { TRUE, FALSE };
// A set of variables
public static final String VARIABLES[] = { BURGLARY, EARTHQUAKE, ALARM,
JOHN, MARRY };
/**
* Builds the query from the book: P(B| j, m)
*/
public static LinkedList<Assignment> completeQueryBulgary() {
LinkedList<Assignment> assignment = new LinkedList<Assignment>();
assignment.add(new Assignment(new Variable(JOHN, DOMAIN), TRUE));
assignment.add(new Assignment(new Variable(MARRY, DOMAIN), TRUE));
return new LinkedList<Assignment>(assignment);
}
/**
* Build burglary prior
*
* @param alarmNet
*/
private static void burglary(BayesNet alarmNet) {
Probability p_bulglary = new Probability(0.001);
String names[] = { BURGLARY };
ProbabilityTable burglary = new ProbabilityTable(names);
burglary.setProbabilityForAssignment("true;", p_bulglary);
burglary.setProbabilityForAssignment("false;", p_bulglary.rest());
alarmNet.setDistribution(new Variable(BURGLARY, DOMAIN), burglary);
}
/**
* Build earthquake prior
*
* @param alarmNet
*/
private static void earthquake(BayesNet alarmNet) {
Probability p_earthquake = new Probability(0.002);
String names[] = { EARTHQUAKE };
ProbabilityTable earthquake = new ProbabilityTable(names);
earthquake.setProbabilityForAssignment("true;", p_earthquake);
earthquake.setProbabilityForAssignment("false;", p_earthquake.rest());
alarmNet.setDistribution(new Variable(EARTHQUAKE, DOMAIN), earthquake);
}
/**
* Jhon calls!
*
* @param alarmNet
*/
private static void jhon(BayesNet alarmNet) {
// P(ALARM) == true
Probability t = new Probability(.90);
// P(ALARM) == false
Probability f = new Probability(.05);
String names[] = { ALARM, JOHN };
ProbabilityTable jhon = new ProbabilityTable(names);
jhon.setProbabilityForAssignment("true;true;", t);
jhon.setProbabilityForAssignment("false;true;", f);
jhon.setProbabilityForAssignment("true;false", t.rest());
jhon.setProbabilityForAssignment("false;false;", f.rest());
alarmNet.setDistribution(new Variable(JOHN, DOMAIN), jhon);
}
/**
* Marry calls!
*
* @param alarmNet
*/
private static void mary(BayesNet alarmNet) {
// P(ALARM) == true
Probability t = new Probability(.70);
// P(ALARM) == false
Probability f = new Probability(.01);
String names[] = { ALARM, MARRY };
ProbabilityTable marry = new ProbabilityTable(names);
marry.setProbabilityForAssignment("true;true;", t);
marry.setProbabilityForAssignment("false;true;", f);
marry.setProbabilityForAssignment("true;false", t.rest());
marry.setProbabilityForAssignment("false;false;", f.rest());
alarmNet.setDistribution(new Variable(MARRY, DOMAIN), marry);
}
/**
* The alarm goes off!
*
* @param alarmNet
*/
private static void alarm(BayesNet alarmNet) {
// P(ALARM | BURGLARY = true, EARTHQUAKE = true)
Probability tt = new Probability(.95);
// P(ALARM | BURGLARY = true, EARTHQUAKE = false)
Probability tf = new Probability(.94);
// P(ALARM | BURGLARY = false, EARTHQUAKE = true)
Probability ft = new Probability(.29);
// P(ALARM | BURGLARY = false, EARTHQUAKE = false)
Probability ff = new Probability(.001);
String names[] = { BURGLARY, EARTHQUAKE, ALARM };
ProbabilityTable alarm = new ProbabilityTable(names);
alarm.setProbabilityForAssignment(TRUE + ";" + TRUE + ";" + TRUE + ";",
tt);
alarm.setProbabilityForAssignment(
TRUE + ";" + FALSE + ";" + TRUE + ";", tf);
alarm.setProbabilityForAssignment(
FALSE + ";" + TRUE + ";" + TRUE + ";", ft);
alarm.setProbabilityForAssignment(FALSE + ";" + FALSE + ";" + TRUE
+ ";", ff);
alarm.setProbabilityForAssignment(
TRUE + ";" + TRUE + ";" + FALSE + ";", tt.rest());
alarm.setProbabilityForAssignment(TRUE + ";" + FALSE + ";" + FALSE
+ ";", tf.rest());
alarm.setProbabilityForAssignment(FALSE + ";" + TRUE + ";" + FALSE
+ ";", ft.rest());
alarm.setProbabilityForAssignment(FALSE + ";" + FALSE + ";" + FALSE
+ ";", ff.rest());
alarmNet.setDistribution(new Variable(ALARM, DOMAIN), alarm);
}
public static BayesNet alarm() {
BayesNet alarmNet = new BayesNet(VARIABLES);
// set probability tables and priors
burglary(alarmNet);
earthquake(alarmNet);
alarm(alarmNet);
jhon(alarmNet);
mary(alarmNet);
// construct the graph
alarmNet.connect(ALARM, BURGLARY);
alarmNet.connect(ALARM, EARTHQUAKE);
alarmNet.connect(JOHN, ALARM);
alarmNet.connect(MARRY, ALARM);
return alarmNet;
}
}

View File

@@ -0,0 +1,244 @@
package dkohl.bayes.example.builders;
import java.util.LinkedList;
import dkohl.bayes.bayesnet.BayesNet;
import dkohl.bayes.probability.Assignment;
import dkohl.bayes.probability.Probability;
import dkohl.bayes.probability.Variable;
import dkohl.bayes.probability.distribution.ProbabilityTree;
public class AlarmNetBuilderTree {
// Variable names
public static final String BURGLARY = "Burglary";
public static final String EARTHQUAKE = "Earthquake";
public static final String ALARM = "Alarm";
public static final String JOHN = "John";
public static final String MARRY = "Marry";
// Possible outcomes
public static final String TRUE = "true";
public static final String FALSE = "false";
// A variables domain
public static final String DOMAIN[] = { TRUE, FALSE };
// A set of variables
public static final String VARIABLES[] = { BURGLARY, EARTHQUAKE, ALARM,
JOHN, MARRY };
/**
* Builds the query from the book: P(B| j, m)
*/
public static LinkedList<Assignment> completeQueryBulgary() {
LinkedList<Assignment> assignment = new LinkedList<Assignment>();
assignment.add(new Assignment(new Variable(JOHN, DOMAIN), TRUE));
assignment.add(new Assignment(new Variable(MARRY, DOMAIN), TRUE));
return new LinkedList<Assignment>(assignment);
}
/**
* Build burglary prior
*
* @param sprinklerNet
*/
private static void burglary(BayesNet sprinklerNet) {
Probability p_bulglary = new Probability(0.001);
ProbabilityTree burglary = new ProbabilityTree();
LinkedList<Assignment> assignment = new LinkedList<Assignment>();
assignment.add(new Assignment(new Variable(BURGLARY, DOMAIN), TRUE));
burglary.setProbabilityForAssignment(assignment, p_bulglary);
assignment = new LinkedList<Assignment>();
assignment.add(new Assignment(new Variable(BURGLARY, DOMAIN), FALSE));
burglary.setProbabilityForAssignment(assignment, p_bulglary.rest());
sprinklerNet.setDistribution(new Variable(BURGLARY, DOMAIN), burglary);
}
/**
* Build earthquake prior
*
* @param sprinklerNet
*/
private static void earthquake(BayesNet sprinklerNet) {
Probability p_earthquake = new Probability(0.002);
ProbabilityTree earthquake = new ProbabilityTree();
LinkedList<Assignment> assignment = new LinkedList<Assignment>();
assignment.add(new Assignment(new Variable(EARTHQUAKE, DOMAIN), TRUE));
earthquake.setProbabilityForAssignment(assignment, p_earthquake);
assignment = new LinkedList<Assignment>();
assignment.add(new Assignment(new Variable(EARTHQUAKE, DOMAIN), FALSE));
earthquake.setProbabilityForAssignment(assignment, p_earthquake.rest());
sprinklerNet.setDistribution(new Variable(EARTHQUAKE, DOMAIN),
earthquake);
}
/**
* Jhon calls!
*
* @param sprinklerNet
*/
private static void jhon(BayesNet sprinklerNet) {
// P(ALARM) == true
Probability t = new Probability(.90);
// P(ALARM) == false
Probability f = new Probability(.05);
ProbabilityTree jhon = new ProbabilityTree();
LinkedList<Assignment> assignment = new LinkedList<Assignment>();
assignment.add(new Assignment(new Variable(JOHN, DOMAIN), TRUE));
assignment.add(new Assignment(new Variable(ALARM, DOMAIN), TRUE));
jhon.setProbabilityForAssignment(assignment, t);
assignment = new LinkedList<Assignment>();
assignment.add(new Assignment(new Variable(JOHN, DOMAIN), TRUE));
assignment.add(new Assignment(new Variable(ALARM, DOMAIN), FALSE));
jhon.setProbabilityForAssignment(assignment, f);
assignment = new LinkedList<Assignment>();
assignment.add(new Assignment(new Variable(JOHN, DOMAIN), FALSE));
assignment.add(new Assignment(new Variable(ALARM, DOMAIN), TRUE));
jhon.setProbabilityForAssignment(assignment, t.rest());
assignment = new LinkedList<Assignment>();
assignment.add(new Assignment(new Variable(JOHN, DOMAIN), FALSE));
assignment.add(new Assignment(new Variable(ALARM, DOMAIN), FALSE));
jhon.setProbabilityForAssignment(assignment, f.rest());
sprinklerNet.setDistribution(new Variable(JOHN, DOMAIN), jhon);
}
/**
* Marry calls!
*
* @param sprinklerNet
*/
private static void mary(BayesNet sprinklerNet) {
// P(ALARM) == true
Probability t = new Probability(.70);
// P(ALARM) == false
Probability f = new Probability(.01);
ProbabilityTree mary = new ProbabilityTree();
LinkedList<Assignment> assignment = new LinkedList<Assignment>();
assignment.add(new Assignment(new Variable(MARRY, DOMAIN), TRUE));
assignment.add(new Assignment(new Variable(ALARM, DOMAIN), TRUE));
mary.setProbabilityForAssignment(assignment, t);
assignment = new LinkedList<Assignment>();
assignment.add(new Assignment(new Variable(MARRY, DOMAIN), TRUE));
assignment.add(new Assignment(new Variable(ALARM, DOMAIN), FALSE));
mary.setProbabilityForAssignment(assignment, f);
assignment = new LinkedList<Assignment>();
assignment.add(new Assignment(new Variable(MARRY, DOMAIN), FALSE));
assignment.add(new Assignment(new Variable(ALARM, DOMAIN), TRUE));
mary.setProbabilityForAssignment(assignment, t.rest());
assignment = new LinkedList<Assignment>();
assignment.add(new Assignment(new Variable(MARRY, DOMAIN), FALSE));
assignment.add(new Assignment(new Variable(ALARM, DOMAIN), FALSE));
mary.setProbabilityForAssignment(assignment, f.rest());
sprinklerNet.setDistribution(new Variable(MARRY, DOMAIN), mary);
}
/**
* The alarm goes off!
*
* @param sprinklerNet
*/
private static void alarm(BayesNet sprinklerNet) {
// P(ALARM | BURGLARY = true, EARTHQUAKE = true)
Probability tt = new Probability(.95);
// P(ALARM | BURGLARY = true, EARTHQUAKE = false)
Probability tf = new Probability(.94);
// P(ALARM | BURGLARY = false, EARTHQUAKE = true)
Probability ft = new Probability(.29);
// P(ALARM | BURGLARY = false, EARTHQUAKE = false)
Probability ff = new Probability(.001);
ProbabilityTree alarm = new ProbabilityTree();
LinkedList<Assignment> assignment = new LinkedList<Assignment>();
assignment.add(new Assignment(new Variable(BURGLARY, DOMAIN), TRUE));
assignment.add(new Assignment(new Variable(EARTHQUAKE, DOMAIN), TRUE));
assignment.add(new Assignment(new Variable(ALARM, DOMAIN), TRUE));
alarm.setProbabilityForAssignment(assignment, tt);
assignment = new LinkedList<Assignment>();
assignment.add(new Assignment(new Variable(BURGLARY, DOMAIN), TRUE));
assignment.add(new Assignment(new Variable(EARTHQUAKE, DOMAIN), FALSE));
assignment.add(new Assignment(new Variable(ALARM, DOMAIN), TRUE));
alarm.setProbabilityForAssignment(assignment, tf);
assignment = new LinkedList<Assignment>();
assignment.add(new Assignment(new Variable(BURGLARY, DOMAIN), FALSE));
assignment.add(new Assignment(new Variable(EARTHQUAKE, DOMAIN), TRUE));
assignment.add(new Assignment(new Variable(ALARM, DOMAIN), TRUE));
alarm.setProbabilityForAssignment(assignment, ft);
assignment = new LinkedList<Assignment>();
assignment.add(new Assignment(new Variable(BURGLARY, DOMAIN), FALSE));
assignment.add(new Assignment(new Variable(EARTHQUAKE, DOMAIN), FALSE));
assignment.add(new Assignment(new Variable(ALARM, DOMAIN), TRUE));
alarm.setProbabilityForAssignment(assignment, ff);
assignment = new LinkedList<Assignment>();
assignment.add(new Assignment(new Variable(BURGLARY, DOMAIN), TRUE));
assignment.add(new Assignment(new Variable(EARTHQUAKE, DOMAIN), TRUE));
assignment.add(new Assignment(new Variable(ALARM, DOMAIN), FALSE));
alarm.setProbabilityForAssignment(assignment, tt.rest());
assignment = new LinkedList<Assignment>();
assignment.add(new Assignment(new Variable(BURGLARY, DOMAIN), TRUE));
assignment.add(new Assignment(new Variable(EARTHQUAKE, DOMAIN), FALSE));
assignment.add(new Assignment(new Variable(ALARM, DOMAIN), FALSE));
alarm.setProbabilityForAssignment(assignment, tf.rest());
assignment = new LinkedList<Assignment>();
assignment.add(new Assignment(new Variable(BURGLARY, DOMAIN), FALSE));
assignment.add(new Assignment(new Variable(EARTHQUAKE, DOMAIN), TRUE));
assignment.add(new Assignment(new Variable(ALARM, DOMAIN), FALSE));
alarm.setProbabilityForAssignment(assignment, ft.rest());
assignment = new LinkedList<Assignment>();
assignment.add(new Assignment(new Variable(BURGLARY, DOMAIN), FALSE));
assignment.add(new Assignment(new Variable(EARTHQUAKE, DOMAIN), FALSE));
assignment.add(new Assignment(new Variable(ALARM, DOMAIN), FALSE));
alarm.setProbabilityForAssignment(assignment, ff.rest());
sprinklerNet.setDistribution(new Variable(ALARM, DOMAIN), alarm);
}
public static BayesNet sprinkler() {
BayesNet sprinklerNet = new BayesNet(VARIABLES);
// set probability tables and priors
burglary(sprinklerNet);
earthquake(sprinklerNet);
alarm(sprinklerNet);
jhon(sprinklerNet);
mary(sprinklerNet);
// construct the graph
sprinklerNet.connect(ALARM, BURGLARY);
sprinklerNet.connect(ALARM, EARTHQUAKE);
sprinklerNet.connect(JOHN, ALARM);
sprinklerNet.connect(MARRY, ALARM);
return sprinklerNet;
}
}

View File

@@ -0,0 +1,179 @@
package dkohl.bayes.example.builders;
import dkohl.bayes.bayesnet.BayesNet;
import dkohl.bayes.probability.Assignment;
import dkohl.bayes.probability.Probability;
import dkohl.bayes.probability.Variable;
import dkohl.bayes.probability.distribution.ProbabilityTable;
import dkohl.bayes.statistic.DataPoint;
import dkohl.bayes.statistic.DataSet;
/**
* The Sprinkler net example
*
* http://www.cs.ubc.ca/~murphyk/Bayes/bnintro.html
*
* @author Daniel Kohlsdorf
*/
public class EstimateSprinklerNetBuilderTable {
// Variable names
public static final String CLOUDY = "Cloudy";
public static final String SPRINKLER = "Sprinkler";
public static final String GRASS_WET = "GrassWet";
public static final String RAIN = "Rain";
// Possible outcomes
public static final String TRUE = "true";
public static final String FALSE = "false";
// A variables domain
public static final String DOMAIN[] = { TRUE, FALSE };
// A set of variables
public static final String VARIABLES[] = { CLOUDY, SPRINKLER, GRASS_WET,
RAIN };
private static void cloudy(BayesNet sprinklerNet) {
Probability p_cloudy = new Probability(0.5);
String names[] = { CLOUDY };
ProbabilityTable cloudy = new ProbabilityTable(names);
cloudy.setProbabilityForAssignment("true;", p_cloudy);
cloudy.setProbabilityForAssignment("false;", p_cloudy.rest());
sprinklerNet.setDistribution(new Variable(CLOUDY, DOMAIN), cloudy);
}
private static void rain(BayesNet sprinklerNet) {
Probability p_cloudy = new Probability(0.8);
Probability p_notcloudy = new Probability(0.2);
String names[] = { CLOUDY, RAIN };
ProbabilityTable rain = new ProbabilityTable(names);
rain.setProbabilityForAssignment("true;true;", p_cloudy);
rain.setProbabilityForAssignment("false;false;", p_notcloudy);
rain.setProbabilityForAssignment("false;true;", p_notcloudy.rest());
rain.setProbabilityForAssignment("true;false;", p_cloudy.rest());
sprinklerNet.setDistribution(new Variable(RAIN, DOMAIN), rain);
}
private static void sprinkler(BayesNet sprinklerNet) {
Probability p_cloudy = new Probability(0.1);
Probability p_notcloudy = new Probability(0.5);
String names[] = { CLOUDY, SPRINKLER };
ProbabilityTable sprinkler = new ProbabilityTable(names);
sprinkler.setProbabilityForAssignment("true;true;", p_cloudy);
sprinkler.setProbabilityForAssignment("false;false;", p_notcloudy);
sprinkler
.setProbabilityForAssignment("false;true;", p_notcloudy.rest());
sprinkler.setProbabilityForAssignment("true;false;", p_cloudy.rest());
sprinklerNet
.setDistribution(new Variable(SPRINKLER, DOMAIN), sprinkler);
}
private static void grass(BayesNet sprinklerNet) {
String names[] = { RAIN, SPRINKLER, GRASS_WET };
ProbabilityTable sprinkler = new ProbabilityTable(names);
sprinklerNet
.setDistribution(new Variable(GRASS_WET, DOMAIN), sprinkler);
}
public static BayesNet sprinkler() {
BayesNet sprinkler = new BayesNet(VARIABLES);
cloudy(sprinkler);
rain(sprinkler);
sprinkler(sprinkler);
grass(sprinkler);
sprinkler.connect(RAIN, CLOUDY);
sprinkler.connect(SPRINKLER, CLOUDY);
sprinkler.connect(GRASS_WET, RAIN);
sprinkler.connect(GRASS_WET, SPRINKLER);
return sprinkler;
}
private static Assignment build(String varible, String value) {
return new Assignment(new Variable(varible, DOMAIN), value);
}
public static DataSet dataSet() {
DataSet dataSet = new DataSet();
/**
* If rain is false and sprinkler is false, grass is never wet.
*/
DataPoint one = new DataPoint();
one.add(build(RAIN, FALSE));
one.add(build(SPRINKLER, FALSE));
one.add(build(GRASS_WET, FALSE));
dataSet.add(one);
/**
* 1 / 10 times the grass is not wet when the sprinkler is on
*/
DataPoint two = new DataPoint();
two.add(build(RAIN, FALSE));
two.add(build(SPRINKLER, TRUE));
two.add(build(GRASS_WET, FALSE));
dataSet.add(two);
/**
* 9 / 10 times the sprinkler is on and the grass is wet
*/
for (int i = 0; i < 9; i++) {
DataPoint point = new DataPoint();
point.add(build(RAIN, FALSE));
point.add(build(SPRINKLER, TRUE));
point.add(build(GRASS_WET, TRUE));
dataSet.add(point);
}
/**
* 1 / 10 times the grass is not wet when it rains
*/
DataPoint three = new DataPoint();
three.add(build(RAIN, TRUE));
three.add(build(SPRINKLER, FALSE));
three.add(build(GRASS_WET, FALSE));
dataSet.add(three);
/**
* 9 / 10 times it rains and the grass is wet
*/
for (int i = 0; i < 9; i++) {
DataPoint point = new DataPoint();
point.add(build(RAIN, TRUE));
point.add(build(SPRINKLER, FALSE));
point.add(build(GRASS_WET, TRUE));
dataSet.add(point);
}
/**
* 1 / 100 times the grass is not wet when it rains and the sprinkler is
* on
*/
DataPoint four = new DataPoint();
four.add(build(RAIN, TRUE));
four.add(build(SPRINKLER, TRUE));
four.add(build(GRASS_WET, FALSE));
dataSet.add(four);
/**
* 99 / 100 times it rains and the grass is wet
*/
for (int i = 0; i < 99; i++) {
DataPoint point = new DataPoint();
point.add(build(RAIN, TRUE));
point.add(build(SPRINKLER, TRUE));
point.add(build(GRASS_WET, TRUE));
dataSet.add(point);
}
return dataSet;
}
}

View File

@@ -0,0 +1,266 @@
package dkohl.bayes.example.builders;
import java.util.HashSet;
import java.util.LinkedList;
import dkohl.bayes.bayesnet.BayesNet;
import dkohl.bayes.estimation.MaximumLikelihoodEstimation;
import dkohl.bayes.probability.Assignment;
import dkohl.bayes.probability.Probability;
import dkohl.bayes.probability.Variable;
import dkohl.bayes.probability.distribution.ContinousDistribution;
import dkohl.bayes.probability.distribution.ProbabilityDistribution;
import dkohl.bayes.probability.distribution.ProbabilityTable;
import dkohl.bayes.statistic.DataPoint;
import dkohl.bayes.statistic.DataSet;
import dkohl.onthology.Ontology;
public class FoodExampleBuilder {
public static final String TASTE = "Taste";
public static final String SOMEONE_VEGETARIAN = "Vegetarian";
public static final String CONTAINS_MEET = "Meet";
public static final String CONTAINS_VEGETABLE = "Vegetable";
public static final String CONTAINS_BEEF = "Beef";
public static final String CONTAINS_PORK = "Pork";
public static final String CONTAINS_TOMATOS = "Tomatos";
public static final String CONTAINS_POTATOS = "Potatos";
public static final String TRUE_VALUE = "true";
public static final String FALSE_VALUE = "false";
public static final String DOMAIN[] = { TRUE_VALUE, FALSE_VALUE };
public static final String RATING_DOMAIN[] = { "1", "2", "3", "4", "5",
"6", "7", "8", "9", "10" };
private static final String[] VARIABLES = { SOMEONE_VEGETARIAN,
CONTAINS_BEEF, CONTAINS_MEET, CONTAINS_PORK, CONTAINS_POTATOS,
CONTAINS_TOMATOS, CONTAINS_VEGETABLE, TASTE };
private static final String[] OBSERVED = { CONTAINS_BEEF, CONTAINS_PORK,
CONTAINS_POTATOS, CONTAINS_TOMATOS, };
public static LinkedList<Assignment> completeQueryTasteBeef() {
LinkedList<Assignment> assignment = new LinkedList<Assignment>();
assignment.add(new Assignment(new Variable(CONTAINS_BEEF, DOMAIN),
TRUE_VALUE));
assignment.add(new Assignment(new Variable(CONTAINS_TOMATOS, DOMAIN),
TRUE_VALUE));
return new LinkedList<Assignment>(assignment);
}
public static LinkedList<Assignment> completeQueryTastePork() {
LinkedList<Assignment> assignment = new LinkedList<Assignment>();
assignment.add(new Assignment(new Variable(CONTAINS_PORK, DOMAIN),
TRUE_VALUE));
assignment.add(new Assignment(new Variable(CONTAINS_POTATOS, DOMAIN),
TRUE_VALUE));
return new LinkedList<Assignment>(assignment);
}
public static Ontology onto() {
HashSet<String> classes = new HashSet<String>();
classes.add(CONTAINS_MEET);
classes.add(CONTAINS_VEGETABLE);
Ontology onthology = new Ontology(classes);
onthology.define(CONTAINS_PORK, CONTAINS_MEET);
onthology.define(CONTAINS_BEEF, CONTAINS_MEET);
onthology.define(CONTAINS_TOMATOS, CONTAINS_VEGETABLE);
onthology.define(CONTAINS_POTATOS, CONTAINS_VEGETABLE);
return onthology;
}
private static Assignment build(String varible, String value) {
return new Assignment(new Variable(varible, DOMAIN), value);
}
public static DataPoint beefPotatoDish(int weight) {
DataPoint point = new DataPoint();
point.add(build(CONTAINS_BEEF, TRUE_VALUE));
point.add(build(CONTAINS_POTATOS, TRUE_VALUE));
point.add(build(TASTE, "" + weight));
return point;
}
public static DataPoint beefTomatoDish(int weight) {
DataPoint point = new DataPoint();
point.add(build(CONTAINS_BEEF, TRUE_VALUE));
point.add(build(CONTAINS_TOMATOS, TRUE_VALUE));
point.add(build(TASTE, "" + weight));
return point;
}
public static DataPoint porkBeefDish(int weight) {
DataPoint point = new DataPoint();
point.add(build(CONTAINS_BEEF, TRUE_VALUE));
point.add(build(CONTAINS_PORK, TRUE_VALUE));
point.add(build(TASTE, "" + weight));
return point;
}
public static DataPoint porkPotatoDish(int weight) {
DataPoint point = new DataPoint();
point.add(build(CONTAINS_PORK, TRUE_VALUE));
point.add(build(CONTAINS_POTATOS, TRUE_VALUE));
point.add(build(TASTE, "" + weight));
return point;
}
public static DataPoint porkTomatoDish(int weight) {
DataPoint point = new DataPoint();
point.add(build(CONTAINS_PORK, TRUE_VALUE));
point.add(build(CONTAINS_TOMATOS, TRUE_VALUE));
point.add(build(TASTE, "" + weight));
return point;
}
public static DataPoint potatoTomato(int weight) {
DataPoint point = new DataPoint();
point.add(build(CONTAINS_POTATOS, TRUE_VALUE));
point.add(build(CONTAINS_TOMATOS, TRUE_VALUE));
point.add(build(TASTE, "" + weight));
return point;
}
public static DataPoint normalize(DataPoint point, Ontology onto) {
// resolve onthology
DataPoint normPoint = new DataPoint(point);
for (String key : point.keySet()) {
if (onto.getInheritance().containsKey(key)) {
normPoint
.add(build(onto.getInheritance().get(key), TRUE_VALUE));
}
}
// implement closed world assumption
// everything unknown is false
for (String variable : VARIABLES) {
if (!normPoint.containsKey(variable)) {
normPoint.add(build(variable, FALSE_VALUE));
}
}
return normPoint;
}
public static DataSet examples() {
DataSet data = new DataSet();
Ontology onto = onto();
// user one ratings
data.add(normalize(porkTomatoDish(10), onto));
data.add(normalize(porkPotatoDish(9), onto));
data.add(normalize(beefTomatoDish(3), onto));
data.add(normalize(beefPotatoDish(0), onto));
data.add(normalize(potatoTomato(0), onto));
data.add(normalize(porkBeefDish(0), onto));
// user two ratings
data.add(normalize(porkTomatoDish(10), onto));
data.add(normalize(porkTomatoDish(8), onto));
data.add(normalize(porkPotatoDish(10), onto));
data.add(normalize(beefTomatoDish(0), onto));
data.add(normalize(beefPotatoDish(1), onto));
data.add(normalize(potatoTomato(7), onto));
data.add(normalize(porkBeefDish(10), onto));
// user three ratings
data.add(normalize(porkTomatoDish(10), onto));
data.add(normalize(porkPotatoDish(10), onto));
data.add(normalize(beefTomatoDish(3), onto));
data.add(normalize(beefPotatoDish(3), onto));
data.add(normalize(potatoTomato(3), onto));
data.add(normalize(porkBeefDish(4), onto));
return data;
}
public static ProbabilityDistribution vegi() {
String names[] = { SOMEONE_VEGETARIAN };
ProbabilityTable table = new ProbabilityTable(names);
table.setProbabilityForAssignment("true;", new Probability(0));
table.setProbabilityForAssignment("false;", new Probability(1));
return table;
}
public static ProbabilityDistribution beef() {
String names[] = { CONTAINS_MEET, CONTAINS_BEEF };
ProbabilityTable table = new ProbabilityTable(names);
return table;
}
public static ProbabilityDistribution pork() {
String names[] = { CONTAINS_MEET, CONTAINS_PORK };
ProbabilityTable table = new ProbabilityTable(names);
return table;
}
public static ProbabilityDistribution meet() {
String names[] = { SOMEONE_VEGETARIAN, CONTAINS_MEET };
ProbabilityTable table = new ProbabilityTable(names);
return table;
}
public static ProbabilityDistribution tomatos() {
String names[] = { CONTAINS_VEGETABLE, CONTAINS_TOMATOS };
ProbabilityTable table = new ProbabilityTable(names);
return table;
}
public static ProbabilityDistribution potatos() {
String names[] = { CONTAINS_VEGETABLE, CONTAINS_POTATOS };
ProbabilityTable table = new ProbabilityTable(names);
return table;
}
public static ProbabilityDistribution vegetables() {
String names[] = { CONTAINS_VEGETABLE, SOMEONE_VEGETARIAN };
ProbabilityTable table = new ProbabilityTable(names);
return table;
}
public static ProbabilityDistribution taste() {
String names[] = { TASTE, CONTAINS_BEEF, CONTAINS_PORK,
CONTAINS_POTATOS, CONTAINS_TOMATOS };
ContinousDistribution distribution = new ContinousDistribution(names, 0);
return distribution;
}
public static BayesNet dishNet() {
BayesNet net = new BayesNet(VARIABLES);
net.setDistribution(new Variable(SOMEONE_VEGETARIAN, DOMAIN), vegi());
net.setDistribution(new Variable(CONTAINS_MEET, DOMAIN), meet());
net.setDistribution(new Variable(CONTAINS_VEGETABLE, DOMAIN),
vegetables());
net.setDistribution(new Variable(CONTAINS_BEEF, DOMAIN), beef());
net.setDistribution(new Variable(CONTAINS_PORK, DOMAIN), pork());
net.setDistribution(new Variable(CONTAINS_POTATOS, DOMAIN), potatos());
net.setDistribution(new Variable(CONTAINS_TOMATOS, DOMAIN), tomatos());
net.setDistribution(new Variable(TASTE, RATING_DOMAIN), taste());
Ontology onthology = onto();
for (String category : onthology.getClasses()) {
net.connect(category, SOMEONE_VEGETARIAN);
}
for (String thing : OBSERVED) {
net.connect(thing, onthology.getInheritance().get(thing));
net.connect(TASTE, thing);
}
for (String category : onthology.getClasses()) {
MaximumLikelihoodEstimation.estimate(examples(), net, category);
for (String thing : onthology.getClasses2thing().get(category)) {
MaximumLikelihoodEstimation.estimate(examples(), net, thing);
}
}
MaximumLikelihoodEstimation.estimate(examples(), net, TASTE);
ContinousDistribution distribturion = (ContinousDistribution) net
.getNodes().get(TASTE);
distribturion.estimate();
return net;
}
}

View File

@@ -32,5 +32,9 @@ public class SurveyDatasetReaderTest {
assertFalse(recipe.getIngredients().contains(TYPE.RED_MEAT));
assertFalse(recipe.getIngredients().contains(TYPE.POULTRY));
assertFalse(recipe.getIngredients().contains(TYPE.SHELLFISH));
for (int rIndex = 0; rIndex < recipeBook.getSize(); rIndex++) {
System.out.println(recipeBook.getRecipe(rIndex).getHead().getTitle());
}
}
}