This commit is contained in:
dkohl
2012-03-12 21:01:30 -04:00
parent 9235ff8d0f
commit 12b4bab59d
3 changed files with 324 additions and 298 deletions

View File

@@ -34,7 +34,8 @@ public class FoodNetBuilder {
public static final String CONTAINS_TOMATOS = TYPE.TOMATO.toString(); public static final String CONTAINS_TOMATOS = TYPE.TOMATO.toString();
public static final String CONTAINS_POTATOS = TYPE.POTATO.toString(); public static final String CONTAINS_POTATOS = TYPE.POTATO.toString();
// public static final String CONTAINS_NUTS = "Nuts"; // public static final String CONTAINS_NUTS = "Nuts";
//public static final String CONTAINS_GENERIC_NUTS = TYPE.GENERIC_NUTS.toString(); // public static final String CONTAINS_GENERIC_NUTS =
// TYPE.GENERIC_NUTS.toString();
public static final String TRUE_VALUE = "true"; public static final String TRUE_VALUE = "true";
public static final String FALSE_VALUE = "false"; public static final String FALSE_VALUE = "false";
@@ -44,9 +45,15 @@ public class FoodNetBuilder {
public static final String RATING_DOMAIN[] = { "1", "2", "3", "4", "5", public static final String RATING_DOMAIN[] = { "1", "2", "3", "4", "5",
"6", "7", "8", "9", "10" }; "6", "7", "8", "9", "10" };
private static final String[] VARIABLES = { SOMEONE_VEGETARIAN, /*SOMEONE_ALLERGIC_NUTS,*/ private static final String[] VARIABLES = { SOMEONE_VEGETARIAN, /*
* SOMEONE_ALLERGIC_NUTS
* ,
*/
CONTAINS_BEEF, CONTAINS_MEAT, CONTAINS_PORK, CONTAINS_POTATOS, CONTAINS_BEEF, CONTAINS_MEAT, CONTAINS_PORK, CONTAINS_POTATOS,
CONTAINS_TOMATOS, CONTAINS_VEGETABLE, /*CONTAINS_NUTS, CONTAINS_GENERIC_NUTS,*/ TASTE }; CONTAINS_TOMATOS, CONTAINS_VEGETABLE, /*
* CONTAINS_NUTS,
* CONTAINS_GENERIC_NUTS,
*/TASTE };
private static final String[] OBSERVED = { CONTAINS_BEEF, CONTAINS_PORK, private static final String[] OBSERVED = { CONTAINS_BEEF, CONTAINS_PORK,
CONTAINS_POTATOS, CONTAINS_TOMATOS /* , CONTAINS_GENERIC_NUTS */}; CONTAINS_POTATOS, CONTAINS_TOMATOS /* , CONTAINS_GENERIC_NUTS */};
@@ -94,20 +101,24 @@ public class FoodNetBuilder {
return normPoint; return normPoint;
} }
public static DataSet getSurveyDataSet(Survey survey, RecipeBook recipeBook, int startIndex, int endIndex) { public static DataSet getSurveyDataSet(Survey survey,
RecipeBook recipeBook, int startIndex, int endIndex) {
DataSet data = FoodExampleBuilder.examples(); DataSet data = FoodExampleBuilder.examples();
Ontology onto = createOntology(); Ontology onto = createOntology();
for (int dinerIndex = 0; dinerIndex < survey.getDinerCount(); dinerIndex++) { for (int dinerIndex = 0; dinerIndex < survey.getDinerCount(); dinerIndex++) {
Diner diner = survey.getDiner(dinerIndex); Diner diner = survey.getDiner(dinerIndex);
for (int dishIndex = startIndex; dishIndex < endIndex; dishIndex++) { for (int dishIndex = startIndex; dishIndex < endIndex; dishIndex++) {
data.add(normalize(createDataPoint(recipeBook, survey.getDish(dishIndex), diner.getRating(dishIndex)),onto)); data.add(normalize(
createDataPoint(recipeBook, survey.getDish(dishIndex),
diner.getRating(dishIndex)), onto));
} }
} }
return data; return data;
} }
public static DataPoint createDataPoint(RecipeBook recipeBook, String recipeName, int weight) { public static DataPoint createDataPoint(RecipeBook recipeBook,
String recipeName, int weight) {
DataPoint point = new DataPoint(); DataPoint point = new DataPoint();
Recipe recipe = recipeBook.getRecipe(recipeName); Recipe recipe = recipeBook.getRecipe(recipeName);
@@ -126,9 +137,9 @@ public class FoodNetBuilder {
point.add(build(CONTAINS_TOMATOS, TRUE_VALUE)); point.add(build(CONTAINS_TOMATOS, TRUE_VALUE));
} }
/* /*
if (ingredients.contains(TYPE.GENERIC_NUTS)) { * if (ingredients.contains(TYPE.GENERIC_NUTS)) {
point.add(build(CONTAINS_GENERIC_NUTS, TRUE_VALUE)); * point.add(build(CONTAINS_GENERIC_NUTS, TRUE_VALUE)); }
}*/ */
point.add(build(TASTE, "" + weight)); point.add(build(TASTE, "" + weight));
@@ -136,31 +147,25 @@ public class FoodNetBuilder {
} }
public static ProbabilityDistribution vegi(Survey survey) { public static ProbabilityDistribution vegi(Survey survey) {
String names[] = { SOMEONE_VEGETARIAN }; String names[] = {CONTAINS_MEAT, SOMEONE_VEGETARIAN};
ProbabilityTable table = new ProbabilityTable(names); ProbabilityTable table = new ProbabilityTable(names);
//if (survey.isDiner("vegetarian")) { table.setProbabilityForAssignment("true;true;", new Probability(0));
// table.setProbabilityForAssignment("true;", new Probability(1)); table.setProbabilityForAssignment("false;true;", new Probability(1));
// table.setProbabilityForAssignment("false;", new Probability(0)); table.setProbabilityForAssignment("true;false;", new Probability(1));
//} else { table.setProbabilityForAssignment("false;false;", new Probability(1));
table.setProbabilityForAssignment("true;", new Probability(0));
table.setProbabilityForAssignment("false;", new Probability(1));
//}
return table; return table;
} }
/* /*
public static ProbabilityDistribution allergicNuts(Survey survey) { * public static ProbabilityDistribution allergicNuts(Survey survey) {
String names[] = { SOMEONE_ALLERGIC_NUTS }; * String names[] = { SOMEONE_ALLERGIC_NUTS }; ProbabilityTable table = new
ProbabilityTable table = new ProbabilityTable(names); * ProbabilityTable(names); if (survey.isDiner("allergic-nuts")) {
if (survey.isDiner("allergic-nuts")) { * table.setProbabilityForAssignment("true;", new Probability(1));
table.setProbabilityForAssignment("true;", new Probability(1)); * table.setProbabilityForAssignment("false;", new Probability(0)); } else {
table.setProbabilityForAssignment("false;", new Probability(0)); * table.setProbabilityForAssignment("true;", new Probability(0));
} else { * table.setProbabilityForAssignment("false;", new Probability(1)); } return
table.setProbabilityForAssignment("true;", new Probability(0)); * table; }
table.setProbabilityForAssignment("false;", new Probability(1)); */
}
return table;
}*/
public static ProbabilityDistribution beef() { public static ProbabilityDistribution beef() {
String names[] = { CONTAINS_MEAT, CONTAINS_BEEF }; String names[] = { CONTAINS_MEAT, CONTAINS_BEEF };
@@ -175,23 +180,20 @@ public class FoodNetBuilder {
} }
public static ProbabilityDistribution meat() { public static ProbabilityDistribution meat() {
String names[] = { SOMEONE_VEGETARIAN, CONTAINS_MEAT }; String names[] = { CONTAINS_MEAT };
ProbabilityTable table = new ProbabilityTable(names); ProbabilityTable table = new ProbabilityTable(names);
return table; return table;
} }
/* /*
public static ProbabilityDistribution genericNuts() { * public static ProbabilityDistribution genericNuts() { String names[] = {
String names[] = { CONTAINS_NUTS, CONTAINS_GENERIC_NUTS }; * CONTAINS_NUTS, CONTAINS_GENERIC_NUTS }; ProbabilityTable table = new
ProbabilityTable table = new ProbabilityTable(names); * ProbabilityTable(names); return table; }
return table; *
} * public static ProbabilityDistribution nuts() { String names[] = {
* SOMEONE_ALLERGIC_NUTS, CONTAINS_NUTS }; ProbabilityTable table = new
public static ProbabilityDistribution nuts() { * ProbabilityTable(names); return table; }
String names[] = { SOMEONE_ALLERGIC_NUTS, CONTAINS_NUTS }; */
ProbabilityTable table = new ProbabilityTable(names);
return table;
}*/
public static ProbabilityDistribution tomatos() { public static ProbabilityDistribution tomatos() {
String names[] = { CONTAINS_VEGETABLE, CONTAINS_TOMATOS }; String names[] = { CONTAINS_VEGETABLE, CONTAINS_TOMATOS };
@@ -206,7 +208,7 @@ public class FoodNetBuilder {
} }
public static ProbabilityDistribution vegetables() { public static ProbabilityDistribution vegetables() {
String names[] = { CONTAINS_VEGETABLE, SOMEONE_VEGETARIAN }; String names[] = { CONTAINS_VEGETABLE };
ProbabilityTable table = new ProbabilityTable(names); ProbabilityTable table = new ProbabilityTable(names);
return table; return table;
} }
@@ -218,22 +220,27 @@ public class FoodNetBuilder {
return distribution; return distribution;
} }
public static BayesNet createDishNet(Survey survey, RecipeBook recipeBook, int startIndex, int endIndex) { public static BayesNet createDishNet(Survey survey, RecipeBook recipeBook,
int startIndex, int endIndex) {
BayesNet net = new BayesNet(VARIABLES); BayesNet net = new BayesNet(VARIABLES);
// net.setDistribution(new Variable(SOMEONE_VEGETARIAN, DOMAIN), vegi(survey)); // net.setDistribution(new Variable(SOMEONE_VEGETARIAN, DOMAIN),
//net.setDistribution(new Variable(SOMEONE_ALLERGIC_NUTS, DOMAIN), allergicNuts(survey)); // vegi(survey));
// net.setDistribution(new Variable(SOMEONE_ALLERGIC_NUTS, DOMAIN),
// allergicNuts(survey));
net.setDistribution(new Variable(CONTAINS_MEAT, DOMAIN), meat()); net.setDistribution(new Variable(CONTAINS_MEAT, DOMAIN), meat());
// net.setDistribution(new Variable(CONTAINS_NUTS, DOMAIN), nuts()); // net.setDistribution(new Variable(CONTAINS_NUTS, DOMAIN), nuts());
net.setDistribution(new Variable(SOMEONE_VEGETARIAN, DOMAIN), vegi(survey)); net.setDistribution(new Variable(SOMEONE_VEGETARIAN, DOMAIN),
vegi(survey));
net.setDistribution(new Variable(CONTAINS_VEGETABLE, DOMAIN), net.setDistribution(new Variable(CONTAINS_VEGETABLE, DOMAIN),
vegetables()); vegetables());
net.setDistribution(new Variable(CONTAINS_BEEF, DOMAIN), beef()); net.setDistribution(new Variable(CONTAINS_BEEF, DOMAIN), beef());
net.setDistribution(new Variable(CONTAINS_PORK, DOMAIN), pork()); net.setDistribution(new Variable(CONTAINS_PORK, DOMAIN), pork());
net.setDistribution(new Variable(CONTAINS_POTATOS, DOMAIN), potatos()); net.setDistribution(new Variable(CONTAINS_POTATOS, DOMAIN), potatos());
net.setDistribution(new Variable(CONTAINS_TOMATOS, DOMAIN), tomatos()); net.setDistribution(new Variable(CONTAINS_TOMATOS, DOMAIN), tomatos());
//net.setDistribution(new Variable(CONTAINS_GENERIC_NUTS, DOMAIN), genericNuts()); // net.setDistribution(new Variable(CONTAINS_GENERIC_NUTS, DOMAIN),
// genericNuts());
net.setDistribution(new Variable(TASTE, RATING_DOMAIN), taste()); net.setDistribution(new Variable(TASTE, RATING_DOMAIN), taste());
@@ -249,7 +256,8 @@ public class FoodNetBuilder {
net.connect(TASTE, thing); net.connect(TASTE, thing);
} }
DataSet dataSet = getSurveyDataSet(survey, recipeBook, startIndex, endIndex); DataSet dataSet = getSurveyDataSet(survey, recipeBook, startIndex,
endIndex);
for (String category : ontology.getClasses()) { for (String category : ontology.getClasses()) {
MaximumLikelihoodEstimation.estimate(dataSet, net, category); MaximumLikelihoodEstimation.estimate(dataSet, net, category);

View File

@@ -99,6 +99,9 @@ public class EnumerateAll {
temp.addAll(assignments); temp.addAll(assignments);
temp.add(new Assignment(variable, value)); temp.add(new Assignment(variable, value));
if(net.getNodes().get(variable.getName()).eval(temp) == null) {
System.out.println(variable.getName());
}
// then evaluate this variable // then evaluate this variable
double val = net.getNodes().get(variable.getName()).eval(temp) double val = net.getNodes().get(variable.getName()).eval(temp)
.getProbability(); .getProbability();

View File

@@ -80,6 +80,20 @@ public class DataSet extends Vector<DataPoint> {
public Probability prob(Assignment query, LinkedList<Assignment> given) { public Probability prob(Assignment query, LinkedList<Assignment> given) {
int matches = 0; int matches = 0;
int num_query_given = 0; int num_query_given = 0;
if (given.size() == 0) {
for (DataPoint point : this) {
if (point.containsKey(query.getVariable().getName())) {
matches += 1;
if (point.get(query.getVariable().getName()).getValue()
.equals(query.getValue())) {
num_query_given += 1;
}
}
}
return new Probability(num_query_given / ((double) matches));
} else {
for (DataPoint point : this) { for (DataPoint point : this) {
if (match(query.getVariable().getName(), point, given)) { if (match(query.getVariable().getName(), point, given)) {
matches += 1; matches += 1;
@@ -95,5 +109,6 @@ public class DataSet extends Vector<DataPoint> {
return new Probability(num_query_given / ((double) matches)); return new Probability(num_query_given / ((double) matches));
} }
}
} }