Work in progress.

This commit is contained in:
Woody Folsom
2012-03-12 14:33:41 -04:00
parent c1f68f6f5c
commit 16a97ba39a
15 changed files with 1424 additions and 753 deletions

View File

@@ -6,24 +6,22 @@ import java.util.List;
import java.util.Set;
import net.woodyfolsom.cs6601.p2.Ingredient.TYPE;
import org.apache.commons.math3.stat.regression.SimpleRegression;
import dkohl.bayes.bayesnet.BayesNet;
import dkohl.bayes.builders.FoodNetBuilder;
import dkohl.bayes.inference.EnumerateAll;
import dkohl.bayes.probability.Assignment;
import dkohl.bayes.probability.ProbabilityAssignment;
import dkohl.bayes.probability.Variable;
import dkohl.util.RootMeanSquareError;
public class BayesChef {
public static void main(String... args) {
System.out.println("Reading recipe book.");
RecipeBook recipeBook = RecipeBookReader.readRecipeBook(new File("data/short_recipebook.xml"));
RecipeBook recipeBook = RecipeBookReader.readRecipeBook(new File("data/long_recipebook.xml"));
System.out.println("Reading survey data.");
Survey survey = SurveyReader.readSurvey(new File("data/short_survey.xml"));
Survey survey = SurveyReader.readSurvey(new File("data/long_survey.xml"));
new BayesChef().reportBestMeal(recipeBook,survey);
}
@@ -36,29 +34,39 @@ public class BayesChef {
System.out.println("Read data for " + numDiners + " diner(s).");
System.out.println("Creating Bayes net for survey dataset...");
BayesNet net = FoodNetBuilder.createDishNet(survey, recipeBook);
System.out.println("Querying Bayes net for individual flavor preferences: ");
double[][] flavorPrefs = new double[4][];
int numSurveyDishes = survey.getDishCount();
int bayesNetIndexStart = 0;
int bayesNetIndexEnd = numSurveyDishes / 2;
flavorPrefs[0] = new double[] {0.0,printPreference(net, TYPE.PORK)};
flavorPrefs[1] = new double[] {1.0,printPreference(net, TYPE.BEEF)};
flavorPrefs[2] = new double[] {2.0,printPreference(net, TYPE.POTATO)};
flavorPrefs[3] = new double[] {3.0,printPreference(net, TYPE.TOMATO)};
int comparisonIndexStart = bayesNetIndexEnd;
int comparisonIndexEnd = numSurveyDishes;
SimpleRegression simpleRegression = new SimpleRegression();
simpleRegression.addData(flavorPrefs);
//TODO change this to include min/max recipe indices for building the net
BayesNet net = FoodNetBuilder.createDishNet(survey, recipeBook, bayesNetIndexStart, bayesNetIndexEnd);
System.out.println("Individual flavor pref MSE: " + simpleRegression.getMeanSquareError());
simpleRegression.clear();
System.out.println("Querying Bayes net for predicted flavor preferences: ");
//TODO change this to query actual flavor combos from latter half of survey
double[] predictedRating = new double[numSurveyDishes - bayesNetIndexEnd];
System.out.println("Querying Bayes net for recipe flavor preferences: ");
flavorPrefs = new double[recipeBook.getSize()][];
for (int i = 0; i < recipeBook.getSize(); i++) {
simpleRegression.addData(i,printPreference(net, recipeBook.getRecipe(i)));
for (int i = comparisonIndexStart; i < comparisonIndexEnd; i++) {
predictedRating[i - comparisonIndexStart] = printPreference(net, recipeBook.getRecipe(survey.getDish(i)));
}
System.out.println("Recipe flavor pref MSE: " + simpleRegression.getMeanSquareError());
System.out.println("Querying survey dataset for actual recipe flavor preferences: ");
double[] actualRating = new double[predictedRating.length];
for (int i = comparisonIndexStart; i < comparisonIndexEnd; i++) {
actualRating[i - comparisonIndexStart] = survey.getAverageRating(i);
System.out.println("Actual average rating for " + survey.getDish(i) + ": " + actualRating[i - comparisonIndexStart]);
}
RootMeanSquareError rmse = new RootMeanSquareError();
for (int i = 0; i < actualRating.length; i++) {
rmse.push(predictedRating[i], actualRating[i]);
}
System.out.println("Root Mean Squared Error (predicted vs. actual): " + rmse.error());
}
int printPreference(BayesNet net, Recipe recipe) {
@@ -69,7 +77,7 @@ public class BayesChef {
net, createQuery(ingredientTypes.toArray(new TYPE[ingredientTypes.size()])));
double max_val = 0.0;
String max_arg = "N/A";
String max_arg = "0";
for (ProbabilityAssignment p : probs) {
if (p.getProbability() > max_val) {
max_val = p.getProbability();
@@ -106,6 +114,7 @@ public class BayesChef {
case BEEF:
case POTATO:
case TOMATO:
case GENERIC_NUTS:
assignment.add(new Assignment(new Variable(type.toString(), FoodNetBuilder.DOMAIN),
FoodNetBuilder.TRUE_VALUE));
break;

View File

@@ -9,7 +9,7 @@ import com.thoughtworks.xstream.annotations.XStreamAlias;
public class Diner {
private int id;
private Map<Integer,Integer> ratings = new HashMap<Integer,Integer>();
private Map<Integer,Boolean> allergies = new HashMap<Integer,Boolean>();
private Map<Integer,Boolean> categories = new HashMap<Integer,Boolean>();
public int getId() {
return id;
@@ -19,7 +19,7 @@ public class Diner {
return ratings.get(dishId);
}
public boolean isAllergic(int categoryId) {
return allergies.get(categoryId);
public boolean isCategory(int categoryId) {
return categories.get(categoryId);
}
}

View File

@@ -1,15 +1,13 @@
package net.woodyfolsom.cs6601.p2;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import com.thoughtworks.xstream.annotations.XStreamAlias;
@XStreamAlias("ing")
public class Ingredient {
public enum TYPE { ALCOHOL, BEEF, DAIRY, EGGS, FISH, GLUTEN, GRAIN, NUTS, PORK, POULTRY, POTATO, SHELLFISH, SPICE, SUGAR, TOMATO}
public enum TYPE { ALCOHOL, BEEF, DAIRY, EGGS, FISH, GENERIC_NUTS, GLUTEN, GRAIN, PORK, POULTRY, POTATO, SHELLFISH, SPICE, SUGAR, TOMATO}
private String item;
@@ -30,11 +28,15 @@ public class Ingredient {
public boolean isType(TYPE type) {
switch (type) {
case BEEF :
return item.contains("beef");
//For our purposes, veal is just expensive beef
return item.contains("beef") || item.contains("veal");
case DAIRY :
return item.contains("margarine");
return item.contains("margarine") || item.contains("milk");
case EGGS :
return item.equals("egg") || item.equals("eggs");
case GENERIC_NUTS :
//cashews, peanuts or generic nuts
return item.contains("cashew") || item.contains("peanut") || item.contains("nuts");
case GLUTEN :
return item.contains("flour");
case PORK :

View File

@@ -4,6 +4,7 @@ import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import com.thoughtworks.xstream.annotations.XStreamAlias;
import com.thoughtworks.xstream.annotations.XStreamImplicit;
@@ -16,6 +17,32 @@ public class Survey {
@XStreamImplicit(itemFieldName="diner")
private List<Diner> diners = new ArrayList<Diner>();
public double getAverageRating(int recipeIndex) {
double total = 0.0;
for (Diner diner : diners) {
total += diner.getRating(recipeIndex);
}
return total/diners.size();
}
public boolean isDiner(String category) {
for (int i = 0; i < diners.size(); i++) {
if (isCategory(i,category)) {
return true;
}
}
return false;
}
public boolean isCategory(int dinerIndex, String category) {
for (Entry<Integer,String> entry : categories.entrySet()) {
if (entry.getValue().equals(category)) {
return diners.get(dinerIndex).isCategory(entry.getKey());
}
}
return false;
}
public Diner getDiner(int dinerIndex) {
return diners.get(dinerIndex);
}