This commit is contained in:
dkohl
2012-03-12 21:01:30 -04:00
parent 9235ff8d0f
commit 12b4bab59d
3 changed files with 324 additions and 298 deletions

View File

@@ -22,248 +22,256 @@ import dkohl.bayes.statistic.DataSet;
import dkohl.onthology.Ontology; import dkohl.onthology.Ontology;
public class FoodNetBuilder { public class FoodNetBuilder {
public static final String TASTE = "Taste"; public static final String TASTE = "Taste";
public static final String SOMEONE_VEGETARIAN = "vegetarian";
//public static final String SOMEONE_ALLERGIC_NUTS = "allergic-nuts";
public static final String CONTAINS_MEAT = "Meat";
public static final String CONTAINS_VEGETABLE = "Vegetable";
public static final String CONTAINS_BEEF = TYPE.BEEF.toString();
public static final String CONTAINS_PORK = TYPE.PORK.toString();
public static final String CONTAINS_TOMATOS = TYPE.TOMATO.toString();
public static final String CONTAINS_POTATOS = TYPE.POTATO.toString();
//public static final String CONTAINS_NUTS = "Nuts";
//public static final String CONTAINS_GENERIC_NUTS = TYPE.GENERIC_NUTS.toString();
public static final String TRUE_VALUE = "true";
public static final String FALSE_VALUE = "false";
public static final String DOMAIN[] = { TRUE_VALUE, FALSE_VALUE }; public static final String SOMEONE_VEGETARIAN = "vegetarian";
// public static final String SOMEONE_ALLERGIC_NUTS = "allergic-nuts";
public static final String RATING_DOMAIN[] = { "1", "2", "3", "4", "5", public static final String CONTAINS_MEAT = "Meat";
"6", "7", "8", "9", "10" }; public static final String CONTAINS_VEGETABLE = "Vegetable";
public static final String CONTAINS_BEEF = TYPE.BEEF.toString();
public static final String CONTAINS_PORK = TYPE.PORK.toString();
public static final String CONTAINS_TOMATOS = TYPE.TOMATO.toString();
public static final String CONTAINS_POTATOS = TYPE.POTATO.toString();
// public static final String CONTAINS_NUTS = "Nuts";
// public static final String CONTAINS_GENERIC_NUTS =
// TYPE.GENERIC_NUTS.toString();
private static final String[] VARIABLES = { SOMEONE_VEGETARIAN, /*SOMEONE_ALLERGIC_NUTS,*/ public static final String TRUE_VALUE = "true";
CONTAINS_BEEF, CONTAINS_MEAT, CONTAINS_PORK, CONTAINS_POTATOS, public static final String FALSE_VALUE = "false";
CONTAINS_TOMATOS, CONTAINS_VEGETABLE, /*CONTAINS_NUTS, CONTAINS_GENERIC_NUTS,*/ TASTE };
private static final String[] OBSERVED = { CONTAINS_BEEF, CONTAINS_PORK, public static final String DOMAIN[] = { TRUE_VALUE, FALSE_VALUE };
CONTAINS_POTATOS, CONTAINS_TOMATOS/*, CONTAINS_GENERIC_NUTS*/};
public static Ontology createOntology() { public static final String RATING_DOMAIN[] = { "1", "2", "3", "4", "5",
HashSet<String> classes = new HashSet<String>(); "6", "7", "8", "9", "10" };
classes.add(CONTAINS_MEAT);
//classes.add(CONTAINS_NUTS);
classes.add(CONTAINS_VEGETABLE);
Ontology ontology = new Ontology(classes);
ontology.define(CONTAINS_PORK, CONTAINS_MEAT);
ontology.define(CONTAINS_BEEF, CONTAINS_MEAT);
//ontology.define(CONTAINS_GENERIC_NUTS, CONTAINS_NUTS); private static final String[] VARIABLES = { SOMEONE_VEGETARIAN, /*
* SOMEONE_ALLERGIC_NUTS
ontology.define(CONTAINS_TOMATOS, CONTAINS_VEGETABLE); * ,
ontology.define(CONTAINS_POTATOS, CONTAINS_VEGETABLE); */
return ontology; CONTAINS_BEEF, CONTAINS_MEAT, CONTAINS_PORK, CONTAINS_POTATOS,
CONTAINS_TOMATOS, CONTAINS_VEGETABLE, /*
* CONTAINS_NUTS,
* CONTAINS_GENERIC_NUTS,
*/TASTE };
private static final String[] OBSERVED = { CONTAINS_BEEF, CONTAINS_PORK,
CONTAINS_POTATOS, CONTAINS_TOMATOS /* , CONTAINS_GENERIC_NUTS */};
public static Ontology createOntology() {
HashSet<String> classes = new HashSet<String>();
classes.add(CONTAINS_MEAT);
// classes.add(CONTAINS_NUTS);
classes.add(CONTAINS_VEGETABLE);
Ontology ontology = new Ontology(classes);
ontology.define(CONTAINS_PORK, CONTAINS_MEAT);
ontology.define(CONTAINS_BEEF, CONTAINS_MEAT);
// ontology.define(CONTAINS_GENERIC_NUTS, CONTAINS_NUTS);
ontology.define(CONTAINS_TOMATOS, CONTAINS_VEGETABLE);
ontology.define(CONTAINS_POTATOS, CONTAINS_VEGETABLE);
return ontology;
}
private static Assignment build(String varible, String value) {
return new Assignment(new Variable(varible, DOMAIN), value);
}
public static DataPoint normalize(DataPoint point, Ontology onto) {
// resolve onthology
DataPoint normPoint = new DataPoint(point);
for (String key : point.keySet()) {
if (onto.getInheritance().containsKey(key)) {
normPoint
.add(build(onto.getInheritance().get(key), TRUE_VALUE));
}
} }
private static Assignment build(String varible, String value) { // implement closed world assumption
return new Assignment(new Variable(varible, DOMAIN), value); // everything unknown is false
for (String variable : VARIABLES) {
if (!normPoint.containsKey(variable)) {
normPoint.add(build(variable, FALSE_VALUE));
}
} }
return normPoint;
}
public static DataPoint normalize(DataPoint point, Ontology onto) { public static DataSet getSurveyDataSet(Survey survey,
// resolve onthology RecipeBook recipeBook, int startIndex, int endIndex) {
DataPoint normPoint = new DataPoint(point); DataSet data = FoodExampleBuilder.examples();
for (String key : point.keySet()) { Ontology onto = createOntology();
if (onto.getInheritance().containsKey(key)) {
normPoint
.add(build(onto.getInheritance().get(key), TRUE_VALUE));
}
}
// implement closed world assumption for (int dinerIndex = 0; dinerIndex < survey.getDinerCount(); dinerIndex++) {
// everything unknown is false Diner diner = survey.getDiner(dinerIndex);
for (String variable : VARIABLES) { for (int dishIndex = startIndex; dishIndex < endIndex; dishIndex++) {
if (!normPoint.containsKey(variable)) { data.add(normalize(
normPoint.add(build(variable, FALSE_VALUE)); createDataPoint(recipeBook, survey.getDish(dishIndex),
} diner.getRating(dishIndex)), onto));
} }
return normPoint;
} }
return data;
}
public static DataSet getSurveyDataSet(Survey survey, RecipeBook recipeBook, int startIndex, int endIndex) { public static DataPoint createDataPoint(RecipeBook recipeBook,
DataSet data = FoodExampleBuilder.examples(); String recipeName, int weight) {
Ontology onto = createOntology(); DataPoint point = new DataPoint();
for (int dinerIndex = 0; dinerIndex < survey.getDinerCount(); dinerIndex++) {
Diner diner = survey.getDiner(dinerIndex);
for (int dishIndex = startIndex; dishIndex < endIndex; dishIndex++) {
data.add(normalize(createDataPoint(recipeBook, survey.getDish(dishIndex), diner.getRating(dishIndex)),onto));
}
}
return data;
}
public static DataPoint createDataPoint(RecipeBook recipeBook, String recipeName, int weight) {
DataPoint point = new DataPoint();
Recipe recipe = recipeBook.getRecipe(recipeName);
Ingredients ingredients = recipe.getIngredients();
if (ingredients.contains(TYPE.BEEF)) {
point.add(build(CONTAINS_BEEF, TRUE_VALUE));
}
if (ingredients.contains(TYPE.PORK)) {
point.add(build(CONTAINS_PORK, TRUE_VALUE));
}
if (ingredients.contains(TYPE.POTATO)) {
point.add(build(CONTAINS_POTATOS, TRUE_VALUE));
}
if (ingredients.contains(TYPE.TOMATO)) {
point.add(build(CONTAINS_TOMATOS, TRUE_VALUE));
}
/*
if (ingredients.contains(TYPE.GENERIC_NUTS)) {
point.add(build(CONTAINS_GENERIC_NUTS, TRUE_VALUE));
}*/
point.add(build(TASTE, "" + weight));
return point;
}
public static ProbabilityDistribution vegi(Survey survey) { Recipe recipe = recipeBook.getRecipe(recipeName);
String names[] = { SOMEONE_VEGETARIAN }; Ingredients ingredients = recipe.getIngredients();
ProbabilityTable table = new ProbabilityTable(names);
//if (survey.isDiner("vegetarian")) {
// table.setProbabilityForAssignment("true;", new Probability(1));
// table.setProbabilityForAssignment("false;", new Probability(0));
//} else {
table.setProbabilityForAssignment("true;", new Probability(0));
table.setProbabilityForAssignment("false;", new Probability(1));
//}
return table;
}
if (ingredients.contains(TYPE.BEEF)) {
point.add(build(CONTAINS_BEEF, TRUE_VALUE));
}
if (ingredients.contains(TYPE.PORK)) {
point.add(build(CONTAINS_PORK, TRUE_VALUE));
}
if (ingredients.contains(TYPE.POTATO)) {
point.add(build(CONTAINS_POTATOS, TRUE_VALUE));
}
if (ingredients.contains(TYPE.TOMATO)) {
point.add(build(CONTAINS_TOMATOS, TRUE_VALUE));
}
/* /*
public static ProbabilityDistribution allergicNuts(Survey survey) { * if (ingredients.contains(TYPE.GENERIC_NUTS)) {
String names[] = { SOMEONE_ALLERGIC_NUTS }; * point.add(build(CONTAINS_GENERIC_NUTS, TRUE_VALUE)); }
ProbabilityTable table = new ProbabilityTable(names); */
if (survey.isDiner("allergic-nuts")) {
table.setProbabilityForAssignment("true;", new Probability(1)); point.add(build(TASTE, "" + weight));
table.setProbabilityForAssignment("false;", new Probability(0));
} else { return point;
table.setProbabilityForAssignment("true;", new Probability(0)); }
table.setProbabilityForAssignment("false;", new Probability(1));
} public static ProbabilityDistribution vegi(Survey survey) {
return table; String names[] = {CONTAINS_MEAT, SOMEONE_VEGETARIAN};
}*/ ProbabilityTable table = new ProbabilityTable(names);
table.setProbabilityForAssignment("true;true;", new Probability(0));
public static ProbabilityDistribution beef() { table.setProbabilityForAssignment("false;true;", new Probability(1));
String names[] = { CONTAINS_MEAT, CONTAINS_BEEF }; table.setProbabilityForAssignment("true;false;", new Probability(1));
ProbabilityTable table = new ProbabilityTable(names); table.setProbabilityForAssignment("false;false;", new Probability(1));
return table; return table;
}
/*
* public static ProbabilityDistribution allergicNuts(Survey survey) {
* String names[] = { SOMEONE_ALLERGIC_NUTS }; ProbabilityTable table = new
* ProbabilityTable(names); if (survey.isDiner("allergic-nuts")) {
* table.setProbabilityForAssignment("true;", new Probability(1));
* table.setProbabilityForAssignment("false;", new Probability(0)); } else {
* table.setProbabilityForAssignment("true;", new Probability(0));
* table.setProbabilityForAssignment("false;", new Probability(1)); } return
* table; }
*/
public static ProbabilityDistribution beef() {
String names[] = { CONTAINS_MEAT, CONTAINS_BEEF };
ProbabilityTable table = new ProbabilityTable(names);
return table;
}
public static ProbabilityDistribution pork() {
String names[] = { CONTAINS_MEAT, CONTAINS_PORK };
ProbabilityTable table = new ProbabilityTable(names);
return table;
}
public static ProbabilityDistribution meat() {
String names[] = { CONTAINS_MEAT };
ProbabilityTable table = new ProbabilityTable(names);
return table;
}
/*
* public static ProbabilityDistribution genericNuts() { String names[] = {
* CONTAINS_NUTS, CONTAINS_GENERIC_NUTS }; ProbabilityTable table = new
* ProbabilityTable(names); return table; }
*
* public static ProbabilityDistribution nuts() { String names[] = {
* SOMEONE_ALLERGIC_NUTS, CONTAINS_NUTS }; ProbabilityTable table = new
* ProbabilityTable(names); return table; }
*/
public static ProbabilityDistribution tomatos() {
String names[] = { CONTAINS_VEGETABLE, CONTAINS_TOMATOS };
ProbabilityTable table = new ProbabilityTable(names);
return table;
}
public static ProbabilityDistribution potatos() {
String names[] = { CONTAINS_VEGETABLE, CONTAINS_POTATOS };
ProbabilityTable table = new ProbabilityTable(names);
return table;
}
public static ProbabilityDistribution vegetables() {
String names[] = { CONTAINS_VEGETABLE };
ProbabilityTable table = new ProbabilityTable(names);
return table;
}
public static ProbabilityDistribution taste() {
String names[] = { TASTE, CONTAINS_BEEF, CONTAINS_PORK,
CONTAINS_POTATOS, CONTAINS_TOMATOS /* , CONTAINS_GENERIC_NUTS */};
ContinousDistribution distribution = new ContinousDistribution(names, 0);
return distribution;
}
public static BayesNet createDishNet(Survey survey, RecipeBook recipeBook,
int startIndex, int endIndex) {
BayesNet net = new BayesNet(VARIABLES);
// net.setDistribution(new Variable(SOMEONE_VEGETARIAN, DOMAIN),
// vegi(survey));
// net.setDistribution(new Variable(SOMEONE_ALLERGIC_NUTS, DOMAIN),
// allergicNuts(survey));
net.setDistribution(new Variable(CONTAINS_MEAT, DOMAIN), meat());
// net.setDistribution(new Variable(CONTAINS_NUTS, DOMAIN), nuts());
net.setDistribution(new Variable(SOMEONE_VEGETARIAN, DOMAIN),
vegi(survey));
net.setDistribution(new Variable(CONTAINS_VEGETABLE, DOMAIN),
vegetables());
net.setDistribution(new Variable(CONTAINS_BEEF, DOMAIN), beef());
net.setDistribution(new Variable(CONTAINS_PORK, DOMAIN), pork());
net.setDistribution(new Variable(CONTAINS_POTATOS, DOMAIN), potatos());
net.setDistribution(new Variable(CONTAINS_TOMATOS, DOMAIN), tomatos());
// net.setDistribution(new Variable(CONTAINS_GENERIC_NUTS, DOMAIN),
// genericNuts());
net.setDistribution(new Variable(TASTE, RATING_DOMAIN), taste());
Ontology ontology = createOntology();
// for (String category : ontology.getClasses()) {
// net.connect(category, SOMEONE_VEGETARIAN);
// /*net.connect(category, SOMEONE_ALLERGIC_NUTS);*/
// }
net.connect(SOMEONE_VEGETARIAN, CONTAINS_MEAT);
for (String thing : OBSERVED) {
net.connect(thing, ontology.getInheritance().get(thing));
net.connect(TASTE, thing);
} }
public static ProbabilityDistribution pork() { DataSet dataSet = getSurveyDataSet(survey, recipeBook, startIndex,
String names[] = { CONTAINS_MEAT, CONTAINS_PORK }; endIndex);
ProbabilityTable table = new ProbabilityTable(names);
return table; for (String category : ontology.getClasses()) {
} MaximumLikelihoodEstimation.estimate(dataSet, net, category);
for (String thing : ontology.getClasses2thing().get(category)) {
public static ProbabilityDistribution meat() { MaximumLikelihoodEstimation.estimate(dataSet, net, thing);
String names[] = { SOMEONE_VEGETARIAN, CONTAINS_MEAT }; }
ProbabilityTable table = new ProbabilityTable(names);
return table;
}
/*
public static ProbabilityDistribution genericNuts() {
String names[] = { CONTAINS_NUTS, CONTAINS_GENERIC_NUTS };
ProbabilityTable table = new ProbabilityTable(names);
return table;
}
public static ProbabilityDistribution nuts() {
String names[] = { SOMEONE_ALLERGIC_NUTS, CONTAINS_NUTS };
ProbabilityTable table = new ProbabilityTable(names);
return table;
}*/
public static ProbabilityDistribution tomatos() {
String names[] = { CONTAINS_VEGETABLE, CONTAINS_TOMATOS };
ProbabilityTable table = new ProbabilityTable(names);
return table;
} }
public static ProbabilityDistribution potatos() { MaximumLikelihoodEstimation.estimate(dataSet, net, TASTE);
String names[] = { CONTAINS_VEGETABLE, CONTAINS_POTATOS }; ContinousDistribution distribturion = (ContinousDistribution) net
ProbabilityTable table = new ProbabilityTable(names); .getNodes().get(TASTE);
return table;
}
public static ProbabilityDistribution vegetables() { distribturion.estimate();
String names[] = { CONTAINS_VEGETABLE, SOMEONE_VEGETARIAN };
ProbabilityTable table = new ProbabilityTable(names);
return table;
}
public static ProbabilityDistribution taste() { return net;
String names[] = { TASTE, CONTAINS_BEEF, CONTAINS_PORK, }
CONTAINS_POTATOS, CONTAINS_TOMATOS/*, CONTAINS_GENERIC_NUTS*/ };
ContinousDistribution distribution = new ContinousDistribution(names, 0);
return distribution;
}
public static BayesNet createDishNet(Survey survey, RecipeBook recipeBook, int startIndex, int endIndex) {
BayesNet net = new BayesNet(VARIABLES);
// net.setDistribution(new Variable(SOMEONE_VEGETARIAN, DOMAIN), vegi(survey));
//net.setDistribution(new Variable(SOMEONE_ALLERGIC_NUTS, DOMAIN), allergicNuts(survey));
net.setDistribution(new Variable(CONTAINS_MEAT, DOMAIN), meat());
//net.setDistribution(new Variable(CONTAINS_NUTS, DOMAIN), nuts());
net.setDistribution(new Variable(SOMEONE_VEGETARIAN, DOMAIN), vegi(survey));
net.setDistribution(new Variable(CONTAINS_VEGETABLE, DOMAIN),
vegetables());
net.setDistribution(new Variable(CONTAINS_BEEF, DOMAIN), beef());
net.setDistribution(new Variable(CONTAINS_PORK, DOMAIN), pork());
net.setDistribution(new Variable(CONTAINS_POTATOS, DOMAIN), potatos());
net.setDistribution(new Variable(CONTAINS_TOMATOS, DOMAIN), tomatos());
//net.setDistribution(new Variable(CONTAINS_GENERIC_NUTS, DOMAIN), genericNuts());
net.setDistribution(new Variable(TASTE, RATING_DOMAIN), taste());
Ontology ontology = createOntology();
// for (String category : ontology.getClasses()) {
// net.connect(category, SOMEONE_VEGETARIAN);
// /*net.connect(category, SOMEONE_ALLERGIC_NUTS);*/
// }
net.connect(SOMEONE_VEGETARIAN, CONTAINS_MEAT);
for (String thing : OBSERVED) {
net.connect(thing, ontology.getInheritance().get(thing));
net.connect(TASTE, thing);
}
DataSet dataSet = getSurveyDataSet(survey, recipeBook, startIndex, endIndex);
for (String category : ontology.getClasses()) {
MaximumLikelihoodEstimation.estimate(dataSet, net, category);
for (String thing : ontology.getClasses2thing().get(category)) {
MaximumLikelihoodEstimation.estimate(dataSet, net, thing);
}
}
MaximumLikelihoodEstimation.estimate(dataSet, net, TASTE);
ContinousDistribution distribturion = (ContinousDistribution) net
.getNodes().get(TASTE);
distribturion.estimate();
return net;
}
} }

View File

@@ -98,7 +98,10 @@ public class EnumerateAll {
LinkedList<Assignment> temp = new LinkedList<Assignment>(); LinkedList<Assignment> temp = new LinkedList<Assignment>();
temp.addAll(assignments); temp.addAll(assignments);
temp.add(new Assignment(variable, value)); temp.add(new Assignment(variable, value));
if(net.getNodes().get(variable.getName()).eval(temp) == null) {
System.out.println(variable.getName());
}
// then evaluate this variable // then evaluate this variable
double val = net.getNodes().get(variable.getName()).eval(temp) double val = net.getNodes().get(variable.getName()).eval(temp)
.getProbability(); .getProbability();

View File

@@ -13,87 +13,102 @@ import dkohl.bayes.probability.Probability;
*/ */
public class DataSet extends Vector<DataPoint> { public class DataSet extends Vector<DataPoint> {
private static final long serialVersionUID = 1L; private static final long serialVersionUID = 1L;
public LinkedList<LinkedList<Assignment>> getAssignmentMatchesForQuery( public LinkedList<LinkedList<Assignment>> getAssignmentMatchesForQuery(
LinkedList<Assignment> given) { LinkedList<Assignment> given) {
LinkedList<LinkedList<Assignment>> assignments = new LinkedList<LinkedList<Assignment>>(); LinkedList<LinkedList<Assignment>> assignments = new LinkedList<LinkedList<Assignment>>();
for (DataPoint point : this) { for (DataPoint point : this) {
boolean insert = true; boolean insert = true;
for (Assignment assignment : given) { for (Assignment assignment : given) {
if (!match(point, assignment)) { if (!match(point, assignment)) {
insert = false; insert = false;
}
}
if (insert) {
assignments.add(new LinkedList<Assignment>(point.values()));
}
} }
return assignments; }
if (insert) {
assignments.add(new LinkedList<Assignment>(point.values()));
}
} }
return assignments;
}
/** /**
* Is the assignment equal to a data point / observation ? * Is the assignment equal to a data point / observation ?
* *
* @param point * @param point
* the point / observation * the point / observation
* @param query * @param query
* the assignment * the assignment
* @return * @return
*/ */
private boolean match(String queryName, DataPoint point, private boolean match(String queryName, DataPoint point,
LinkedList<Assignment> query) { LinkedList<Assignment> query) {
boolean queryFound = false; boolean queryFound = false;
for (Assignment assignment : query) { for (Assignment assignment : query) {
if (!match(point, assignment)) { if (!match(point, assignment)) {
return false;
}
if (point.containsKey(queryName)) {
queryFound = true;
}
}
if (!queryFound) {
return false;
}
return true;
}
private boolean match(DataPoint point, Assignment query) {
String name = query.getVariable().getName();
String value = query.getValue();
if (point.containsKey(name)) {
if (point.get(name).getValue().equals(value)) {
return true;
}
}
return false; return false;
}
if (point.containsKey(queryName)) {
queryFound = true;
}
} }
if (!queryFound) {
/** return false;
* Estimating probability for: P(Query | given_1 .... given_N) = #(Query |
* given_1 .... given_N) / #(given_1 .... given_N)
*
* @param query
* @param given
* @return
*/
public Probability prob(Assignment query, LinkedList<Assignment> given) {
int matches = 0;
int num_query_given = 0;
for (DataPoint point : this) {
if (match(query.getVariable().getName(), point, given)) {
matches += 1;
if (match(point, query)) {
num_query_given += 1; // point.getWeight();
}
}
}
if (matches == 0) {
return new Probability(0);
}
return new Probability(num_query_given / ((double) matches));
} }
return true;
}
private boolean match(DataPoint point, Assignment query) {
String name = query.getVariable().getName();
String value = query.getValue();
if (point.containsKey(name)) {
if (point.get(name).getValue().equals(value)) {
return true;
}
}
return false;
}
/**
* Estimating probability for: P(Query | given_1 .... given_N) = #(Query |
* given_1 .... given_N) / #(given_1 .... given_N)
*
* @param query
* @param given
* @return
*/
public Probability prob(Assignment query, LinkedList<Assignment> given) {
int matches = 0;
int num_query_given = 0;
if (given.size() == 0) {
for (DataPoint point : this) {
if (point.containsKey(query.getVariable().getName())) {
matches += 1;
if (point.get(query.getVariable().getName()).getValue()
.equals(query.getValue())) {
num_query_given += 1;
}
}
}
return new Probability(num_query_given / ((double) matches));
} else {
for (DataPoint point : this) {
if (match(query.getVariable().getName(), point, given)) {
matches += 1;
if (match(point, query)) {
num_query_given += 1; // point.getWeight();
}
}
}
if (matches == 0) {
return new Probability(0);
}
return new Probability(num_query_given / ((double) matches));
}
}
} }