269 lines
9.1 KiB
Java
269 lines
9.1 KiB
Java
package dkohl.bayes.example.builders;
|
|
|
|
import java.util.HashSet;
|
|
import java.util.LinkedList;
|
|
|
|
import net.woodyfolsom.cs6601.p2.Ingredient.TYPE;
|
|
|
|
import dkohl.bayes.bayesnet.BayesNet;
|
|
import dkohl.bayes.estimation.MaximumLikelihoodEstimation;
|
|
import dkohl.bayes.probability.Assignment;
|
|
import dkohl.bayes.probability.Probability;
|
|
import dkohl.bayes.probability.Variable;
|
|
import dkohl.bayes.probability.distribution.ContinousDistribution;
|
|
import dkohl.bayes.probability.distribution.ProbabilityDistribution;
|
|
import dkohl.bayes.probability.distribution.ProbabilityTable;
|
|
import dkohl.bayes.statistic.DataPoint;
|
|
import dkohl.bayes.statistic.DataSet;
|
|
import dkohl.onthology.Ontology;
|
|
|
|
public class FoodExampleBuilder {
|
|
|
|
public static final String TASTE = "Taste";
|
|
public static final String SOMEONE_VEGETARIAN = "Vegetarian";
|
|
public static final String CONTAINS_MEAT = "Meat";
|
|
public static final String CONTAINS_VEGETABLE = "Vegetable";
|
|
public static final String CONTAINS_BEEF = TYPE.BEEF.toString();
|
|
public static final String CONTAINS_PORK = TYPE.PORK.toString();
|
|
public static final String CONTAINS_TOMATOS = TYPE.TOMATO.toString();
|
|
public static final String CONTAINS_POTATOS = TYPE.POTATO.toString();
|
|
|
|
public static final String TRUE_VALUE = Boolean.TRUE.toString();
|
|
public static final String FALSE_VALUE = Boolean.FALSE.toString();
|
|
|
|
public static final String DOMAIN[] = { TRUE_VALUE, FALSE_VALUE };
|
|
|
|
public static final String RATING_DOMAIN[] = { "1", "2", "3", "4", "5",
|
|
"6", "7", "8", "9", "10" };
|
|
|
|
private static final String[] VARIABLES = { SOMEONE_VEGETARIAN,
|
|
CONTAINS_BEEF, CONTAINS_MEAT, CONTAINS_PORK, CONTAINS_POTATOS,
|
|
CONTAINS_TOMATOS, CONTAINS_VEGETABLE, TASTE };
|
|
|
|
private static final String[] OBSERVED = { CONTAINS_BEEF, CONTAINS_PORK,
|
|
CONTAINS_POTATOS, CONTAINS_TOMATOS, };
|
|
|
|
public static LinkedList<Assignment> completeQueryTasteBeef() {
|
|
LinkedList<Assignment> assignment = new LinkedList<Assignment>();
|
|
assignment.add(new Assignment(new Variable(CONTAINS_BEEF, DOMAIN),
|
|
TRUE_VALUE));
|
|
assignment.add(new Assignment(new Variable(CONTAINS_TOMATOS, DOMAIN),
|
|
TRUE_VALUE));
|
|
return new LinkedList<Assignment>(assignment);
|
|
}
|
|
|
|
public static LinkedList<Assignment> completeQueryTastePork() {
|
|
LinkedList<Assignment> assignment = new LinkedList<Assignment>();
|
|
assignment.add(new Assignment(new Variable(CONTAINS_PORK, DOMAIN),
|
|
TRUE_VALUE));
|
|
assignment.add(new Assignment(new Variable(CONTAINS_POTATOS, DOMAIN),
|
|
TRUE_VALUE));
|
|
return new LinkedList<Assignment>(assignment);
|
|
}
|
|
|
|
public static Ontology createOntology() {
|
|
HashSet<String> classes = new HashSet<String>();
|
|
classes.add(CONTAINS_MEAT);
|
|
classes.add(CONTAINS_VEGETABLE);
|
|
Ontology onthology = new Ontology(classes);
|
|
onthology.define(CONTAINS_PORK, CONTAINS_MEAT);
|
|
onthology.define(CONTAINS_BEEF, CONTAINS_MEAT);
|
|
|
|
onthology.define(CONTAINS_TOMATOS, CONTAINS_VEGETABLE);
|
|
onthology.define(CONTAINS_POTATOS, CONTAINS_VEGETABLE);
|
|
return onthology;
|
|
}
|
|
|
|
private static Assignment build(String varible, String value) {
|
|
return new Assignment(new Variable(varible, DOMAIN), value);
|
|
}
|
|
|
|
public static DataPoint beefPotatoDish(int weight) {
|
|
DataPoint point = new DataPoint();
|
|
point.add(build(CONTAINS_BEEF, TRUE_VALUE));
|
|
point.add(build(CONTAINS_POTATOS, TRUE_VALUE));
|
|
point.add(build(TASTE, "" + weight));
|
|
return point;
|
|
}
|
|
|
|
public static DataPoint beefTomatoDish(int weight) {
|
|
DataPoint point = new DataPoint();
|
|
point.add(build(CONTAINS_BEEF, TRUE_VALUE));
|
|
point.add(build(CONTAINS_TOMATOS, TRUE_VALUE));
|
|
point.add(build(TASTE, "" + weight));
|
|
return point;
|
|
}
|
|
|
|
public static DataPoint porkBeefDish(int weight) {
|
|
DataPoint point = new DataPoint();
|
|
point.add(build(CONTAINS_BEEF, TRUE_VALUE));
|
|
point.add(build(CONTAINS_PORK, TRUE_VALUE));
|
|
point.add(build(TASTE, "" + weight));
|
|
return point;
|
|
}
|
|
|
|
public static DataPoint porkPotatoDish(int weight) {
|
|
DataPoint point = new DataPoint();
|
|
point.add(build(CONTAINS_PORK, TRUE_VALUE));
|
|
point.add(build(CONTAINS_POTATOS, TRUE_VALUE));
|
|
point.add(build(TASTE, "" + weight));
|
|
return point;
|
|
}
|
|
|
|
public static DataPoint porkTomatoDish(int weight) {
|
|
DataPoint point = new DataPoint();
|
|
point.add(build(CONTAINS_PORK, TRUE_VALUE));
|
|
point.add(build(CONTAINS_TOMATOS, TRUE_VALUE));
|
|
point.add(build(TASTE, "" + weight));
|
|
return point;
|
|
}
|
|
|
|
public static DataPoint potatoTomato(int weight) {
|
|
DataPoint point = new DataPoint();
|
|
point.add(build(CONTAINS_POTATOS, TRUE_VALUE));
|
|
point.add(build(CONTAINS_TOMATOS, TRUE_VALUE));
|
|
point.add(build(TASTE, "" + weight));
|
|
return point;
|
|
}
|
|
|
|
public static DataPoint normalize(DataPoint point, Ontology onto) {
|
|
// resolve onthology
|
|
DataPoint normPoint = new DataPoint(point);
|
|
for (String key : point.keySet()) {
|
|
if (onto.getInheritance().containsKey(key)) {
|
|
normPoint
|
|
.add(build(onto.getInheritance().get(key), TRUE_VALUE));
|
|
}
|
|
}
|
|
|
|
// implement closed world assumption
|
|
// everything unknown is false
|
|
for (String variable : VARIABLES) {
|
|
if (!normPoint.containsKey(variable)) {
|
|
normPoint.add(build(variable, FALSE_VALUE));
|
|
}
|
|
}
|
|
return normPoint;
|
|
}
|
|
|
|
public static DataSet examples() {
|
|
DataSet data = new DataSet();
|
|
Ontology onto = createOntology();
|
|
// user one ratings
|
|
data.add(normalize(porkTomatoDish(10), onto));
|
|
data.add(normalize(porkPotatoDish(9), onto));
|
|
data.add(normalize(beefTomatoDish(3), onto));
|
|
data.add(normalize(beefPotatoDish(0), onto));
|
|
data.add(normalize(potatoTomato(0), onto));
|
|
data.add(normalize(porkBeefDish(0), onto));
|
|
|
|
// user two ratings
|
|
data.add(normalize(porkTomatoDish(10), onto));
|
|
data.add(normalize(porkTomatoDish(8), onto));
|
|
data.add(normalize(porkPotatoDish(10), onto));
|
|
data.add(normalize(beefTomatoDish(0), onto));
|
|
data.add(normalize(beefPotatoDish(1), onto));
|
|
data.add(normalize(potatoTomato(7), onto));
|
|
data.add(normalize(porkBeefDish(10), onto));
|
|
|
|
// user three ratings
|
|
data.add(normalize(porkTomatoDish(10), onto));
|
|
data.add(normalize(porkPotatoDish(10), onto));
|
|
data.add(normalize(beefTomatoDish(3), onto));
|
|
data.add(normalize(beefPotatoDish(3), onto));
|
|
data.add(normalize(potatoTomato(3), onto));
|
|
data.add(normalize(porkBeefDish(4), onto));
|
|
|
|
return data;
|
|
}
|
|
|
|
public static ProbabilityDistribution vegi() {
|
|
String names[] = { SOMEONE_VEGETARIAN };
|
|
ProbabilityTable table = new ProbabilityTable(names);
|
|
table.setProbabilityForAssignment("true;", new Probability(0));
|
|
table.setProbabilityForAssignment("false;", new Probability(1));
|
|
return table;
|
|
}
|
|
|
|
public static ProbabilityDistribution beef() {
|
|
String names[] = { CONTAINS_MEAT, CONTAINS_BEEF };
|
|
ProbabilityTable table = new ProbabilityTable(names);
|
|
return table;
|
|
}
|
|
|
|
public static ProbabilityDistribution pork() {
|
|
String names[] = { CONTAINS_MEAT, CONTAINS_PORK };
|
|
ProbabilityTable table = new ProbabilityTable(names);
|
|
return table;
|
|
}
|
|
|
|
public static ProbabilityDistribution meet() {
|
|
String names[] = { SOMEONE_VEGETARIAN, CONTAINS_MEAT };
|
|
ProbabilityTable table = new ProbabilityTable(names);
|
|
return table;
|
|
}
|
|
|
|
public static ProbabilityDistribution tomatos() {
|
|
String names[] = { CONTAINS_VEGETABLE, CONTAINS_TOMATOS };
|
|
ProbabilityTable table = new ProbabilityTable(names);
|
|
return table;
|
|
}
|
|
|
|
public static ProbabilityDistribution potatos() {
|
|
String names[] = { CONTAINS_VEGETABLE, CONTAINS_POTATOS };
|
|
ProbabilityTable table = new ProbabilityTable(names);
|
|
return table;
|
|
}
|
|
|
|
public static ProbabilityDistribution vegetables() {
|
|
String names[] = { CONTAINS_VEGETABLE, SOMEONE_VEGETARIAN };
|
|
ProbabilityTable table = new ProbabilityTable(names);
|
|
return table;
|
|
}
|
|
|
|
public static ProbabilityDistribution taste() {
|
|
String names[] = { TASTE, CONTAINS_BEEF, CONTAINS_PORK,
|
|
CONTAINS_POTATOS, CONTAINS_TOMATOS };
|
|
ContinousDistribution distribution = new ContinousDistribution(names, 0);
|
|
return distribution;
|
|
}
|
|
|
|
public static BayesNet dishNet() {
|
|
BayesNet net = new BayesNet(VARIABLES);
|
|
net.setDistribution(new Variable(SOMEONE_VEGETARIAN, DOMAIN), vegi());
|
|
net.setDistribution(new Variable(CONTAINS_MEAT, DOMAIN), meet());
|
|
net.setDistribution(new Variable(CONTAINS_VEGETABLE, DOMAIN),
|
|
vegetables());
|
|
net.setDistribution(new Variable(CONTAINS_BEEF, DOMAIN), beef());
|
|
net.setDistribution(new Variable(CONTAINS_PORK, DOMAIN), pork());
|
|
net.setDistribution(new Variable(CONTAINS_POTATOS, DOMAIN), potatos());
|
|
net.setDistribution(new Variable(CONTAINS_TOMATOS, DOMAIN), tomatos());
|
|
net.setDistribution(new Variable(TASTE, RATING_DOMAIN), taste());
|
|
|
|
Ontology ontology = createOntology();
|
|
for (String category : ontology.getClasses()) {
|
|
net.connect(category, SOMEONE_VEGETARIAN);
|
|
}
|
|
|
|
for (String thing : OBSERVED) {
|
|
net.connect(thing, ontology.getInheritance().get(thing));
|
|
net.connect(TASTE, thing);
|
|
}
|
|
|
|
for (String category : ontology.getClasses()) {
|
|
MaximumLikelihoodEstimation.estimate(examples(), net, category);
|
|
for (String thing : ontology.getClasses2thing().get(category)) {
|
|
MaximumLikelihoodEstimation.estimate(examples(), net, thing);
|
|
}
|
|
}
|
|
|
|
MaximumLikelihoodEstimation.estimate(examples(), net, TASTE);
|
|
ContinousDistribution distribturion = (ContinousDistribution) net
|
|
.getNodes().get(TASTE);
|
|
distribturion.estimate();
|
|
|
|
return net;
|
|
}
|
|
|
|
}
|