Work in progress.
This commit is contained in:
@@ -23,14 +23,19 @@ import dkohl.onthology.Ontology;
|
||||
|
||||
public class FoodNetBuilder {
|
||||
public static final String TASTE = "Taste";
|
||||
public static final String SOMEONE_VEGETARIAN = "Vegetarian";
|
||||
|
||||
public static final String SOMEONE_VEGETARIAN = "vegetarian";
|
||||
public static final String SOMEONE_ALLERGIC_NUTS = "allergic-nuts";
|
||||
|
||||
public static final String CONTAINS_MEAT = "Meat";
|
||||
public static final String CONTAINS_VEGETABLE = "Vegetable";
|
||||
public static final String CONTAINS_BEEF = TYPE.BEEF.toString();
|
||||
public static final String CONTAINS_PORK = TYPE.PORK.toString();
|
||||
public static final String CONTAINS_TOMATOS = TYPE.TOMATO.toString();
|
||||
public static final String CONTAINS_POTATOS = TYPE.POTATO.toString();
|
||||
|
||||
public static final String CONTAINS_NUTS = "Nuts";
|
||||
public static final String CONTAINS_GENERIC_NUTS = TYPE.GENERIC_NUTS.toString();
|
||||
|
||||
public static final String TRUE_VALUE = "true";
|
||||
public static final String FALSE_VALUE = "false";
|
||||
|
||||
@@ -39,27 +44,30 @@ public class FoodNetBuilder {
|
||||
public static final String RATING_DOMAIN[] = { "1", "2", "3", "4", "5",
|
||||
"6", "7", "8", "9", "10" };
|
||||
|
||||
private static final String[] VARIABLES = { SOMEONE_VEGETARIAN,
|
||||
private static final String[] VARIABLES = { SOMEONE_VEGETARIAN, SOMEONE_ALLERGIC_NUTS,
|
||||
CONTAINS_BEEF, CONTAINS_MEAT, CONTAINS_PORK, CONTAINS_POTATOS,
|
||||
CONTAINS_TOMATOS, CONTAINS_VEGETABLE, TASTE };
|
||||
CONTAINS_TOMATOS, CONTAINS_VEGETABLE, CONTAINS_NUTS, CONTAINS_GENERIC_NUTS, TASTE };
|
||||
|
||||
private static final String[] OBSERVED = { CONTAINS_BEEF, CONTAINS_PORK,
|
||||
CONTAINS_POTATOS, CONTAINS_TOMATOS, };
|
||||
CONTAINS_POTATOS, CONTAINS_TOMATOS, CONTAINS_GENERIC_NUTS};
|
||||
|
||||
public static Ontology createOntology() {
|
||||
HashSet<String> classes = new HashSet<String>();
|
||||
|
||||
classes.add(CONTAINS_MEAT);
|
||||
classes.add(CONTAINS_NUTS);
|
||||
classes.add(CONTAINS_VEGETABLE);
|
||||
|
||||
Ontology onthology = new Ontology(classes);
|
||||
Ontology ontology = new Ontology(classes);
|
||||
|
||||
onthology.define(CONTAINS_PORK, CONTAINS_MEAT);
|
||||
onthology.define(CONTAINS_BEEF, CONTAINS_MEAT);
|
||||
ontology.define(CONTAINS_PORK, CONTAINS_MEAT);
|
||||
ontology.define(CONTAINS_BEEF, CONTAINS_MEAT);
|
||||
|
||||
onthology.define(CONTAINS_TOMATOS, CONTAINS_VEGETABLE);
|
||||
onthology.define(CONTAINS_POTATOS, CONTAINS_VEGETABLE);
|
||||
return onthology;
|
||||
ontology.define(CONTAINS_GENERIC_NUTS, CONTAINS_NUTS);
|
||||
|
||||
ontology.define(CONTAINS_TOMATOS, CONTAINS_VEGETABLE);
|
||||
ontology.define(CONTAINS_POTATOS, CONTAINS_VEGETABLE);
|
||||
return ontology;
|
||||
}
|
||||
|
||||
private static Assignment build(String varible, String value) {
|
||||
@@ -86,14 +94,13 @@ public class FoodNetBuilder {
|
||||
return normPoint;
|
||||
}
|
||||
|
||||
public static DataSet getSurveyDataSet(Survey survey, RecipeBook recipeBook) {
|
||||
public static DataSet getSurveyDataSet(Survey survey, RecipeBook recipeBook, int startIndex, int endIndex) {
|
||||
DataSet data = FoodExampleBuilder.examples();
|
||||
Ontology onto = createOntology();
|
||||
|
||||
int nDishes = survey.getDishCount();
|
||||
for (int dinerIndex = 0; dinerIndex < survey.getDinerCount(); dinerIndex++) {
|
||||
Diner diner = survey.getDiner(dinerIndex);
|
||||
for (int dishIndex = 0; dishIndex < nDishes; dishIndex++) {
|
||||
for (int dishIndex = startIndex; dishIndex < endIndex; dishIndex++) {
|
||||
data.add(normalize(createDataPoint(recipeBook, survey.getDish(dishIndex), diner.getRating(dishIndex)),onto));
|
||||
}
|
||||
}
|
||||
@@ -118,20 +125,41 @@ public class FoodNetBuilder {
|
||||
if (ingredients.contains(TYPE.TOMATO)) {
|
||||
point.add(build(CONTAINS_TOMATOS, TRUE_VALUE));
|
||||
}
|
||||
if (ingredients.contains(TYPE.GENERIC_NUTS)) {
|
||||
point.add(build(CONTAINS_GENERIC_NUTS, TRUE_VALUE));
|
||||
}
|
||||
|
||||
point.add(build(TASTE, "" + weight));
|
||||
|
||||
return point;
|
||||
}
|
||||
|
||||
public static ProbabilityDistribution vegi() {
|
||||
public static ProbabilityDistribution vegi(Survey survey) {
|
||||
String names[] = { SOMEONE_VEGETARIAN };
|
||||
ProbabilityTable table = new ProbabilityTable(names);
|
||||
table.setProbabilityForAssignment("true;", new Probability(0));
|
||||
table.setProbabilityForAssignment("false;", new Probability(1));
|
||||
if (survey.isDiner("vegetarian")) {
|
||||
table.setProbabilityForAssignment("true;", new Probability(1));
|
||||
table.setProbabilityForAssignment("false;", new Probability(0));
|
||||
} else {
|
||||
table.setProbabilityForAssignment("true;", new Probability(0));
|
||||
table.setProbabilityForAssignment("false;", new Probability(1));
|
||||
}
|
||||
return table;
|
||||
}
|
||||
|
||||
public static ProbabilityDistribution allergicNuts(Survey survey) {
|
||||
String names[] = { SOMEONE_ALLERGIC_NUTS };
|
||||
ProbabilityTable table = new ProbabilityTable(names);
|
||||
if (survey.isDiner("allergic-nuts")) {
|
||||
table.setProbabilityForAssignment("true;", new Probability(1));
|
||||
table.setProbabilityForAssignment("false;", new Probability(0));
|
||||
} else {
|
||||
table.setProbabilityForAssignment("true;", new Probability(0));
|
||||
table.setProbabilityForAssignment("false;", new Probability(1));
|
||||
}
|
||||
return table;
|
||||
}
|
||||
|
||||
public static ProbabilityDistribution beef() {
|
||||
String names[] = { CONTAINS_MEAT, CONTAINS_BEEF };
|
||||
ProbabilityTable table = new ProbabilityTable(names);
|
||||
@@ -143,13 +171,25 @@ public class FoodNetBuilder {
|
||||
ProbabilityTable table = new ProbabilityTable(names);
|
||||
return table;
|
||||
}
|
||||
|
||||
public static ProbabilityDistribution meet() {
|
||||
|
||||
public static ProbabilityDistribution meat() {
|
||||
String names[] = { SOMEONE_VEGETARIAN, CONTAINS_MEAT };
|
||||
ProbabilityTable table = new ProbabilityTable(names);
|
||||
return table;
|
||||
}
|
||||
|
||||
|
||||
public static ProbabilityDistribution genericNuts() {
|
||||
String names[] = { CONTAINS_NUTS, CONTAINS_GENERIC_NUTS };
|
||||
ProbabilityTable table = new ProbabilityTable(names);
|
||||
return table;
|
||||
}
|
||||
|
||||
public static ProbabilityDistribution nuts() {
|
||||
String names[] = { SOMEONE_ALLERGIC_NUTS, CONTAINS_NUTS };
|
||||
ProbabilityTable table = new ProbabilityTable(names);
|
||||
return table;
|
||||
}
|
||||
|
||||
public static ProbabilityDistribution tomatos() {
|
||||
String names[] = { CONTAINS_VEGETABLE, CONTAINS_TOMATOS };
|
||||
ProbabilityTable table = new ProbabilityTable(names);
|
||||
@@ -170,27 +210,34 @@ public class FoodNetBuilder {
|
||||
|
||||
public static ProbabilityDistribution taste() {
|
||||
String names[] = { TASTE, CONTAINS_BEEF, CONTAINS_PORK,
|
||||
CONTAINS_POTATOS, CONTAINS_TOMATOS };
|
||||
CONTAINS_POTATOS, CONTAINS_TOMATOS, CONTAINS_GENERIC_NUTS };
|
||||
ContinousDistribution distribution = new ContinousDistribution(names, 0);
|
||||
return distribution;
|
||||
}
|
||||
|
||||
public static BayesNet createDishNet(Survey survey, RecipeBook recipeBook) {
|
||||
public static BayesNet createDishNet(Survey survey, RecipeBook recipeBook, int startIndex, int endIndex) {
|
||||
BayesNet net = new BayesNet(VARIABLES);
|
||||
|
||||
net.setDistribution(new Variable(SOMEONE_VEGETARIAN, DOMAIN), vegi());
|
||||
net.setDistribution(new Variable(CONTAINS_MEAT, DOMAIN), meet());
|
||||
net.setDistribution(new Variable(SOMEONE_VEGETARIAN, DOMAIN), vegi(survey));
|
||||
net.setDistribution(new Variable(SOMEONE_ALLERGIC_NUTS, DOMAIN), allergicNuts(survey));
|
||||
|
||||
net.setDistribution(new Variable(CONTAINS_MEAT, DOMAIN), meat());
|
||||
net.setDistribution(new Variable(CONTAINS_NUTS, DOMAIN), nuts());
|
||||
|
||||
net.setDistribution(new Variable(CONTAINS_VEGETABLE, DOMAIN),
|
||||
vegetables());
|
||||
net.setDistribution(new Variable(CONTAINS_BEEF, DOMAIN), beef());
|
||||
net.setDistribution(new Variable(CONTAINS_PORK, DOMAIN), pork());
|
||||
net.setDistribution(new Variable(CONTAINS_POTATOS, DOMAIN), potatos());
|
||||
net.setDistribution(new Variable(CONTAINS_TOMATOS, DOMAIN), tomatos());
|
||||
net.setDistribution(new Variable(CONTAINS_GENERIC_NUTS, DOMAIN), genericNuts());
|
||||
|
||||
net.setDistribution(new Variable(TASTE, RATING_DOMAIN), taste());
|
||||
|
||||
Ontology ontology = createOntology();
|
||||
for (String category : ontology.getClasses()) {
|
||||
net.connect(category, SOMEONE_VEGETARIAN);
|
||||
net.connect(category, SOMEONE_ALLERGIC_NUTS);
|
||||
}
|
||||
|
||||
for (String thing : OBSERVED) {
|
||||
@@ -198,7 +245,7 @@ public class FoodNetBuilder {
|
||||
net.connect(TASTE, thing);
|
||||
}
|
||||
|
||||
DataSet dataSet = getSurveyDataSet(survey, recipeBook);
|
||||
DataSet dataSet = getSurveyDataSet(survey, recipeBook, startIndex, endIndex);
|
||||
|
||||
for (String category : ontology.getClasses()) {
|
||||
MaximumLikelihoodEstimation.estimate(dataSet, net, category);
|
||||
|
||||
@@ -4,32 +4,27 @@ import java.util.LinkedList;
|
||||
|
||||
public class RootMeanSquareError {
|
||||
|
||||
private LinkedList<Double> expected;
|
||||
private LinkedList<Double> groundTruth;
|
||||
|
||||
public RootMeanSquareError() {
|
||||
expected = new LinkedList<Double>();
|
||||
groundTruth = new LinkedList<Double>();
|
||||
}
|
||||
|
||||
public void push(double predicted, double actual) {
|
||||
expected.add(predicted);
|
||||
groundTruth.add(actual);
|
||||
}
|
||||
|
||||
public double error() {
|
||||
double mean = 0.0;
|
||||
for(Double val : groundTruth) {
|
||||
mean += val;
|
||||
private LinkedList<Double> expected;
|
||||
private LinkedList<Double> groundTruth;
|
||||
|
||||
public RootMeanSquareError() {
|
||||
expected = new LinkedList<Double>();
|
||||
groundTruth = new LinkedList<Double>();
|
||||
}
|
||||
mean /= groundTruth.size();
|
||||
|
||||
double err = 0.0;
|
||||
for(Double val : expected) {
|
||||
err += Math.pow(mean - val, 2);
|
||||
|
||||
public void push(double predicted, double actual) {
|
||||
expected.add(predicted);
|
||||
groundTruth.add(actual);
|
||||
}
|
||||
return Math.sqrt(err);
|
||||
}
|
||||
|
||||
|
||||
|
||||
public double error() {
|
||||
double mean = 0.0;
|
||||
for (int i = 0; i < expected.size(); i++){
|
||||
mean += Math.pow(expected.get(i) - groundTruth.get(i), 2);
|
||||
}
|
||||
mean /= groundTruth.size();
|
||||
|
||||
return Math.sqrt(mean);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -6,24 +6,22 @@ import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
import net.woodyfolsom.cs6601.p2.Ingredient.TYPE;
|
||||
|
||||
import org.apache.commons.math3.stat.regression.SimpleRegression;
|
||||
|
||||
import dkohl.bayes.bayesnet.BayesNet;
|
||||
import dkohl.bayes.builders.FoodNetBuilder;
|
||||
import dkohl.bayes.inference.EnumerateAll;
|
||||
import dkohl.bayes.probability.Assignment;
|
||||
import dkohl.bayes.probability.ProbabilityAssignment;
|
||||
import dkohl.bayes.probability.Variable;
|
||||
import dkohl.util.RootMeanSquareError;
|
||||
|
||||
public class BayesChef {
|
||||
|
||||
public static void main(String... args) {
|
||||
System.out.println("Reading recipe book.");
|
||||
RecipeBook recipeBook = RecipeBookReader.readRecipeBook(new File("data/short_recipebook.xml"));
|
||||
RecipeBook recipeBook = RecipeBookReader.readRecipeBook(new File("data/long_recipebook.xml"));
|
||||
|
||||
System.out.println("Reading survey data.");
|
||||
Survey survey = SurveyReader.readSurvey(new File("data/short_survey.xml"));
|
||||
Survey survey = SurveyReader.readSurvey(new File("data/long_survey.xml"));
|
||||
|
||||
new BayesChef().reportBestMeal(recipeBook,survey);
|
||||
}
|
||||
@@ -36,29 +34,39 @@ public class BayesChef {
|
||||
System.out.println("Read data for " + numDiners + " diner(s).");
|
||||
|
||||
System.out.println("Creating Bayes net for survey dataset...");
|
||||
BayesNet net = FoodNetBuilder.createDishNet(survey, recipeBook);
|
||||
|
||||
System.out.println("Querying Bayes net for individual flavor preferences: ");
|
||||
double[][] flavorPrefs = new double[4][];
|
||||
int numSurveyDishes = survey.getDishCount();
|
||||
int bayesNetIndexStart = 0;
|
||||
int bayesNetIndexEnd = numSurveyDishes / 2;
|
||||
|
||||
flavorPrefs[0] = new double[] {0.0,printPreference(net, TYPE.PORK)};
|
||||
flavorPrefs[1] = new double[] {1.0,printPreference(net, TYPE.BEEF)};
|
||||
flavorPrefs[2] = new double[] {2.0,printPreference(net, TYPE.POTATO)};
|
||||
flavorPrefs[3] = new double[] {3.0,printPreference(net, TYPE.TOMATO)};
|
||||
int comparisonIndexStart = bayesNetIndexEnd;
|
||||
int comparisonIndexEnd = numSurveyDishes;
|
||||
|
||||
SimpleRegression simpleRegression = new SimpleRegression();
|
||||
simpleRegression.addData(flavorPrefs);
|
||||
//TODO change this to include min/max recipe indices for building the net
|
||||
BayesNet net = FoodNetBuilder.createDishNet(survey, recipeBook, bayesNetIndexStart, bayesNetIndexEnd);
|
||||
|
||||
System.out.println("Individual flavor pref MSE: " + simpleRegression.getMeanSquareError());
|
||||
simpleRegression.clear();
|
||||
System.out.println("Querying Bayes net for predicted flavor preferences: ");
|
||||
//TODO change this to query actual flavor combos from latter half of survey
|
||||
double[] predictedRating = new double[numSurveyDishes - bayesNetIndexEnd];
|
||||
|
||||
System.out.println("Querying Bayes net for recipe flavor preferences: ");
|
||||
flavorPrefs = new double[recipeBook.getSize()][];
|
||||
|
||||
for (int i = 0; i < recipeBook.getSize(); i++) {
|
||||
simpleRegression.addData(i,printPreference(net, recipeBook.getRecipe(i)));
|
||||
for (int i = comparisonIndexStart; i < comparisonIndexEnd; i++) {
|
||||
predictedRating[i - comparisonIndexStart] = printPreference(net, recipeBook.getRecipe(survey.getDish(i)));
|
||||
}
|
||||
System.out.println("Recipe flavor pref MSE: " + simpleRegression.getMeanSquareError());
|
||||
|
||||
System.out.println("Querying survey dataset for actual recipe flavor preferences: ");
|
||||
double[] actualRating = new double[predictedRating.length];
|
||||
|
||||
for (int i = comparisonIndexStart; i < comparisonIndexEnd; i++) {
|
||||
actualRating[i - comparisonIndexStart] = survey.getAverageRating(i);
|
||||
System.out.println("Actual average rating for " + survey.getDish(i) + ": " + actualRating[i - comparisonIndexStart]);
|
||||
}
|
||||
|
||||
RootMeanSquareError rmse = new RootMeanSquareError();
|
||||
for (int i = 0; i < actualRating.length; i++) {
|
||||
rmse.push(predictedRating[i], actualRating[i]);
|
||||
}
|
||||
|
||||
System.out.println("Root Mean Squared Error (predicted vs. actual): " + rmse.error());
|
||||
}
|
||||
|
||||
int printPreference(BayesNet net, Recipe recipe) {
|
||||
@@ -69,7 +77,7 @@ public class BayesChef {
|
||||
net, createQuery(ingredientTypes.toArray(new TYPE[ingredientTypes.size()])));
|
||||
|
||||
double max_val = 0.0;
|
||||
String max_arg = "N/A";
|
||||
String max_arg = "0";
|
||||
for (ProbabilityAssignment p : probs) {
|
||||
if (p.getProbability() > max_val) {
|
||||
max_val = p.getProbability();
|
||||
@@ -106,6 +114,7 @@ public class BayesChef {
|
||||
case BEEF:
|
||||
case POTATO:
|
||||
case TOMATO:
|
||||
case GENERIC_NUTS:
|
||||
assignment.add(new Assignment(new Variable(type.toString(), FoodNetBuilder.DOMAIN),
|
||||
FoodNetBuilder.TRUE_VALUE));
|
||||
break;
|
||||
|
||||
@@ -9,7 +9,7 @@ import com.thoughtworks.xstream.annotations.XStreamAlias;
|
||||
public class Diner {
|
||||
private int id;
|
||||
private Map<Integer,Integer> ratings = new HashMap<Integer,Integer>();
|
||||
private Map<Integer,Boolean> allergies = new HashMap<Integer,Boolean>();
|
||||
private Map<Integer,Boolean> categories = new HashMap<Integer,Boolean>();
|
||||
|
||||
public int getId() {
|
||||
return id;
|
||||
@@ -19,7 +19,7 @@ public class Diner {
|
||||
return ratings.get(dishId);
|
||||
}
|
||||
|
||||
public boolean isAllergic(int categoryId) {
|
||||
return allergies.get(categoryId);
|
||||
public boolean isCategory(int categoryId) {
|
||||
return categories.get(categoryId);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
package net.woodyfolsom.cs6601.p2;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
import com.thoughtworks.xstream.annotations.XStreamAlias;
|
||||
|
||||
@XStreamAlias("ing")
|
||||
public class Ingredient {
|
||||
public enum TYPE { ALCOHOL, BEEF, DAIRY, EGGS, FISH, GLUTEN, GRAIN, NUTS, PORK, POULTRY, POTATO, SHELLFISH, SPICE, SUGAR, TOMATO}
|
||||
public enum TYPE { ALCOHOL, BEEF, DAIRY, EGGS, FISH, GENERIC_NUTS, GLUTEN, GRAIN, PORK, POULTRY, POTATO, SHELLFISH, SPICE, SUGAR, TOMATO}
|
||||
|
||||
private String item;
|
||||
|
||||
@@ -30,11 +28,15 @@ public class Ingredient {
|
||||
public boolean isType(TYPE type) {
|
||||
switch (type) {
|
||||
case BEEF :
|
||||
return item.contains("beef");
|
||||
//For our purposes, veal is just expensive beef
|
||||
return item.contains("beef") || item.contains("veal");
|
||||
case DAIRY :
|
||||
return item.contains("margarine");
|
||||
return item.contains("margarine") || item.contains("milk");
|
||||
case EGGS :
|
||||
return item.equals("egg") || item.equals("eggs");
|
||||
case GENERIC_NUTS :
|
||||
//cashews, peanuts or generic nuts
|
||||
return item.contains("cashew") || item.contains("peanut") || item.contains("nuts");
|
||||
case GLUTEN :
|
||||
return item.contains("flour");
|
||||
case PORK :
|
||||
|
||||
@@ -4,6 +4,7 @@ import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
|
||||
import com.thoughtworks.xstream.annotations.XStreamAlias;
|
||||
import com.thoughtworks.xstream.annotations.XStreamImplicit;
|
||||
@@ -16,6 +17,32 @@ public class Survey {
|
||||
@XStreamImplicit(itemFieldName="diner")
|
||||
private List<Diner> diners = new ArrayList<Diner>();
|
||||
|
||||
public double getAverageRating(int recipeIndex) {
|
||||
double total = 0.0;
|
||||
for (Diner diner : diners) {
|
||||
total += diner.getRating(recipeIndex);
|
||||
}
|
||||
return total/diners.size();
|
||||
}
|
||||
|
||||
public boolean isDiner(String category) {
|
||||
for (int i = 0; i < diners.size(); i++) {
|
||||
if (isCategory(i,category)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
public boolean isCategory(int dinerIndex, String category) {
|
||||
for (Entry<Integer,String> entry : categories.entrySet()) {
|
||||
if (entry.getValue().equals(category)) {
|
||||
return diners.get(dinerIndex).isCategory(entry.getKey());
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
public Diner getDiner(int dinerIndex) {
|
||||
return diners.get(dinerIndex);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user