diff --git a/src/net/woodyfolsom/cs6601/p2/BayesChef.java b/src/net/woodyfolsom/cs6601/p2/BayesChef.java index ec5344e..f39cd59 100644 --- a/src/net/woodyfolsom/cs6601/p2/BayesChef.java +++ b/src/net/woodyfolsom/cs6601/p2/BayesChef.java @@ -3,11 +3,14 @@ package net.woodyfolsom.cs6601.p2; import java.io.File; import java.util.LinkedList; import java.util.List; +import java.util.Set; import net.woodyfolsom.cs6601.p2.Ingredient.TYPE; + +import org.apache.commons.math3.stat.regression.SimpleRegression; + import dkohl.bayes.bayesnet.BayesNet; import dkohl.bayes.builders.FoodNetBuilder; -import dkohl.bayes.example.builders.FoodExampleBuilder; import dkohl.bayes.inference.EnumerateAll; import dkohl.bayes.probability.Assignment; import dkohl.bayes.probability.ProbabilityAssignment; @@ -32,54 +35,53 @@ public class BayesChef { int numDiners = survey.getDinerCount(); System.out.println("Read data for " + numDiners + " diner(s)."); - /* - System.out.println("Setting evidence for first 11 recipes."); - - System.out.println("Evaluating preference for remaining recipes."); - - //read survey prefs from survey.xml - double[][] surveyPrefs = new double[numDiners][]; - for (int i = 0; i < numDiners; i++) { - surveyPrefs[i] = new double[numRecipes/2]; - for (int j = 0; j < numRecipes/2; j++) { - surveyPrefs[i][j] = survey.getDiner(i).getRating(j); - } - } - - //generating stub evaluated preferences to test RMSE calculation - double[][] calculatedPrefs = new double[numDiners][]; - calculatedPrefs[0] = new double[numRecipes/2]; - - System.out.println("RMSE for recipes #" + numRecipes/2 + "-" + numRecipes +" (calculated vs. surveyed preference):"); - - for (int dinerIndex = 0; dinerIndex < survey.getDinerCount(); dinerIndex++) { - SimpleRegression simpleRegression = new SimpleRegression(); - - for (int i = 0; i < numRecipes/2; i++) { - simpleRegression.addData(i,calculatedPrefs[dinerIndex][i]); - } - double calculatedMSE = simpleRegression.getMeanSquareError(); - simpleRegression.clear(); - - for (int i = 0; i < numRecipes/2; i++) { - simpleRegression.addData(i,surveyPrefs[dinerIndex][i]); - } - - double surveyMSE = simpleRegression.getMeanSquareError(); - - System.out.println("Diner # " + (dinerIndex + 1) + ": " + calculatedMSE + " vs. "+ surveyMSE); - }*/ - + System.out.println("Creating Bayes net for survey dataset..."); BayesNet net = FoodNetBuilder.createDishNet(survey, recipeBook); - //BayesNet net = FoodExampleBuilder.dishNet(); - printPreference(net, TYPE.PORK); - printPreference(net, TYPE.BEEF); - printPreference(net, TYPE.POTATO); - printPreference(net, TYPE.TOMATO); + System.out.println("Querying Bayes net for individual flavor preferences: "); + double[][] flavorPrefs = new double[4][]; + + flavorPrefs[0] = new double[] {0.0,printPreference(net, TYPE.PORK)}; + flavorPrefs[1] = new double[] {1.0,printPreference(net, TYPE.BEEF)}; + flavorPrefs[2] = new double[] {2.0,printPreference(net, TYPE.POTATO)}; + flavorPrefs[3] = new double[] {3.0,printPreference(net, TYPE.TOMATO)}; + + SimpleRegression simpleRegression = new SimpleRegression(); + simpleRegression.addData(flavorPrefs); + + System.out.println("Individual flavor pref MSE: " + simpleRegression.getMeanSquareError()); + simpleRegression.clear(); + + System.out.println("Querying Bayes net for recipe flavor preferences: "); + flavorPrefs = new double[recipeBook.getSize()][]; + + for (int i = 0; i < recipeBook.getSize(); i++) { + simpleRegression.addData(i,printPreference(net, recipeBook.getRecipe(i))); + } + System.out.println("Recipe flavor pref MSE: " + simpleRegression.getMeanSquareError()); } - void printPreference(BayesNet net, TYPE type) { + int printPreference(BayesNet net, Recipe recipe) { + Set ingredientTypes = recipe.getIngredients().getTypes(); + + List probs = EnumerateAll.enumerateAsk(new Variable( + FoodNetBuilder.TASTE, FoodNetBuilder.RATING_DOMAIN), + net, createQuery(ingredientTypes.toArray(new TYPE[ingredientTypes.size()]))); + + double max_val = 0.0; + String max_arg = "N/A"; + for (ProbabilityAssignment p : probs) { + if (p.getProbability() > max_val) { + max_val = p.getProbability(); + max_arg = p.getValue(); + } + } + + System.out.println("TASTE for " + recipe.getName() + " " + ingredientTypes + " : " + max_arg + " " + max_val); + return Integer.valueOf(max_arg); + } + + int printPreference(BayesNet net, TYPE type) { List probs = EnumerateAll.enumerateAsk(new Variable( FoodNetBuilder.TASTE, FoodNetBuilder.RATING_DOMAIN), net, createQuery(type)); @@ -93,10 +95,12 @@ public class BayesChef { } System.out.println("TASTE " + type + ": " + max_arg + " " + max_val); + return Integer.valueOf(max_arg); } - List createQuery(TYPE type) { + List createQuery(TYPE... types) { List assignment = new LinkedList(); + for (TYPE type : types) { switch (type) { case PORK: case BEEF: @@ -107,6 +111,7 @@ public class BayesChef { break; default: } + } return assignment; } } diff --git a/src/net/woodyfolsom/cs6601/p2/Ingredient.java b/src/net/woodyfolsom/cs6601/p2/Ingredient.java index 26bbfe4..6ba48a4 100644 --- a/src/net/woodyfolsom/cs6601/p2/Ingredient.java +++ b/src/net/woodyfolsom/cs6601/p2/Ingredient.java @@ -1,5 +1,10 @@ package net.woodyfolsom.cs6601.p2; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + import com.thoughtworks.xstream.annotations.XStreamAlias; @XStreamAlias("ing") @@ -12,6 +17,16 @@ public class Ingredient { return item; } + public Set getTypes() { + Set types = new HashSet(); + for (TYPE type : TYPE.values()) { + if (isType(type)) { + types.add(type); + } + } + return types; + } + public boolean isType(TYPE type) { switch (type) { case BEEF : diff --git a/src/net/woodyfolsom/cs6601/p2/Ingredients.java b/src/net/woodyfolsom/cs6601/p2/Ingredients.java index 077b569..8789018 100644 --- a/src/net/woodyfolsom/cs6601/p2/Ingredients.java +++ b/src/net/woodyfolsom/cs6601/p2/Ingredients.java @@ -1,7 +1,11 @@ package net.woodyfolsom.cs6601.p2; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; +import java.util.Set; + +import net.woodyfolsom.cs6601.p2.Ingredient.TYPE; import com.thoughtworks.xstream.annotations.XStreamAlias; import com.thoughtworks.xstream.annotations.XStreamImplicit; @@ -31,6 +35,14 @@ public class Ingredients { return ingredients; } + public Set getTypes() { + Set types = new HashSet(); + for (Ingredient ingredient : ingredients) { + types.addAll(ingredient.getTypes()); + } + return types; + } + public void setIngredients(List ingredients) { this.ingredients = ingredients; } diff --git a/src/net/woodyfolsom/cs6601/p2/Recipe.java b/src/net/woodyfolsom/cs6601/p2/Recipe.java index fb1a655..57d329a 100644 --- a/src/net/woodyfolsom/cs6601/p2/Recipe.java +++ b/src/net/woodyfolsom/cs6601/p2/Recipe.java @@ -14,4 +14,8 @@ public class Recipe { public Ingredients getIngredients() { return ingredients; } + + public String getName() { + return head.getTitle(); + } } \ No newline at end of file