BayesChef now reports MSE of evidence and hidden recipe flavor combos.

Also, improved formatting of output to show exactly which flavors were included in the calculations.
This commit is contained in:
Woody Folsom
2012-03-12 09:18:28 -04:00
parent 06066f470a
commit b1e5f2c74e
4 changed files with 82 additions and 46 deletions

View File

@@ -3,11 +3,14 @@ package net.woodyfolsom.cs6601.p2;
import java.io.File; import java.io.File;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Set;
import net.woodyfolsom.cs6601.p2.Ingredient.TYPE; import net.woodyfolsom.cs6601.p2.Ingredient.TYPE;
import org.apache.commons.math3.stat.regression.SimpleRegression;
import dkohl.bayes.bayesnet.BayesNet; import dkohl.bayes.bayesnet.BayesNet;
import dkohl.bayes.builders.FoodNetBuilder; import dkohl.bayes.builders.FoodNetBuilder;
import dkohl.bayes.example.builders.FoodExampleBuilder;
import dkohl.bayes.inference.EnumerateAll; import dkohl.bayes.inference.EnumerateAll;
import dkohl.bayes.probability.Assignment; import dkohl.bayes.probability.Assignment;
import dkohl.bayes.probability.ProbabilityAssignment; import dkohl.bayes.probability.ProbabilityAssignment;
@@ -32,54 +35,53 @@ public class BayesChef {
int numDiners = survey.getDinerCount(); int numDiners = survey.getDinerCount();
System.out.println("Read data for " + numDiners + " diner(s)."); System.out.println("Read data for " + numDiners + " diner(s).");
/* System.out.println("Creating Bayes net for survey dataset...");
System.out.println("Setting evidence for first 11 recipes."); BayesNet net = FoodNetBuilder.createDishNet(survey, recipeBook);
System.out.println("Evaluating preference for remaining recipes."); System.out.println("Querying Bayes net for individual flavor preferences: ");
double[][] flavorPrefs = new double[4][];
//read survey prefs from survey.xml flavorPrefs[0] = new double[] {0.0,printPreference(net, TYPE.PORK)};
double[][] surveyPrefs = new double[numDiners][]; flavorPrefs[1] = new double[] {1.0,printPreference(net, TYPE.BEEF)};
for (int i = 0; i < numDiners; i++) { flavorPrefs[2] = new double[] {2.0,printPreference(net, TYPE.POTATO)};
surveyPrefs[i] = new double[numRecipes/2]; flavorPrefs[3] = new double[] {3.0,printPreference(net, TYPE.TOMATO)};
for (int j = 0; j < numRecipes/2; j++) {
surveyPrefs[i][j] = survey.getDiner(i).getRating(j);
}
}
//generating stub evaluated preferences to test RMSE calculation
double[][] calculatedPrefs = new double[numDiners][];
calculatedPrefs[0] = new double[numRecipes/2];
System.out.println("RMSE for recipes #" + numRecipes/2 + "-" + numRecipes +" (calculated vs. surveyed preference):");
for (int dinerIndex = 0; dinerIndex < survey.getDinerCount(); dinerIndex++) {
SimpleRegression simpleRegression = new SimpleRegression(); SimpleRegression simpleRegression = new SimpleRegression();
simpleRegression.addData(flavorPrefs);
for (int i = 0; i < numRecipes/2; i++) { System.out.println("Individual flavor pref MSE: " + simpleRegression.getMeanSquareError());
simpleRegression.addData(i,calculatedPrefs[dinerIndex][i]);
}
double calculatedMSE = simpleRegression.getMeanSquareError();
simpleRegression.clear(); simpleRegression.clear();
for (int i = 0; i < numRecipes/2; i++) { System.out.println("Querying Bayes net for recipe flavor preferences: ");
simpleRegression.addData(i,surveyPrefs[dinerIndex][i]); flavorPrefs = new double[recipeBook.getSize()][];
for (int i = 0; i < recipeBook.getSize(); i++) {
simpleRegression.addData(i,printPreference(net, recipeBook.getRecipe(i)));
}
System.out.println("Recipe flavor pref MSE: " + simpleRegression.getMeanSquareError());
} }
double surveyMSE = simpleRegression.getMeanSquareError(); int printPreference(BayesNet net, Recipe recipe) {
Set<TYPE> ingredientTypes = recipe.getIngredients().getTypes();
System.out.println("Diner # " + (dinerIndex + 1) + ": " + calculatedMSE + " vs. "+ surveyMSE); List<ProbabilityAssignment> probs = EnumerateAll.enumerateAsk(new Variable(
}*/ FoodNetBuilder.TASTE, FoodNetBuilder.RATING_DOMAIN),
net, createQuery(ingredientTypes.toArray(new TYPE[ingredientTypes.size()])));
BayesNet net = FoodNetBuilder.createDishNet(survey, recipeBook); double max_val = 0.0;
//BayesNet net = FoodExampleBuilder.dishNet(); String max_arg = "N/A";
for (ProbabilityAssignment p : probs) {
printPreference(net, TYPE.PORK); if (p.getProbability() > max_val) {
printPreference(net, TYPE.BEEF); max_val = p.getProbability();
printPreference(net, TYPE.POTATO); max_arg = p.getValue();
printPreference(net, TYPE.TOMATO); }
} }
void printPreference(BayesNet net, TYPE type) { System.out.println("TASTE for " + recipe.getName() + " " + ingredientTypes + " : " + max_arg + " " + max_val);
return Integer.valueOf(max_arg);
}
int printPreference(BayesNet net, TYPE type) {
List<ProbabilityAssignment> probs = EnumerateAll.enumerateAsk(new Variable( List<ProbabilityAssignment> probs = EnumerateAll.enumerateAsk(new Variable(
FoodNetBuilder.TASTE, FoodNetBuilder.RATING_DOMAIN), FoodNetBuilder.TASTE, FoodNetBuilder.RATING_DOMAIN),
net, createQuery(type)); net, createQuery(type));
@@ -93,10 +95,12 @@ public class BayesChef {
} }
System.out.println("TASTE " + type + ": " + max_arg + " " + max_val); System.out.println("TASTE " + type + ": " + max_arg + " " + max_val);
return Integer.valueOf(max_arg);
} }
List<Assignment> createQuery(TYPE type) { List<Assignment> createQuery(TYPE... types) {
List<Assignment> assignment = new LinkedList<Assignment>(); List<Assignment> assignment = new LinkedList<Assignment>();
for (TYPE type : types) {
switch (type) { switch (type) {
case PORK: case PORK:
case BEEF: case BEEF:
@@ -107,6 +111,7 @@ public class BayesChef {
break; break;
default: default:
} }
}
return assignment; return assignment;
} }
} }

View File

@@ -1,5 +1,10 @@
package net.woodyfolsom.cs6601.p2; package net.woodyfolsom.cs6601.p2;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import com.thoughtworks.xstream.annotations.XStreamAlias; import com.thoughtworks.xstream.annotations.XStreamAlias;
@XStreamAlias("ing") @XStreamAlias("ing")
@@ -12,6 +17,16 @@ public class Ingredient {
return item; return item;
} }
public Set<TYPE> getTypes() {
Set<TYPE> types = new HashSet<TYPE>();
for (TYPE type : TYPE.values()) {
if (isType(type)) {
types.add(type);
}
}
return types;
}
public boolean isType(TYPE type) { public boolean isType(TYPE type) {
switch (type) { switch (type) {
case BEEF : case BEEF :

View File

@@ -1,7 +1,11 @@
package net.woodyfolsom.cs6601.p2; package net.woodyfolsom.cs6601.p2;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Set;
import net.woodyfolsom.cs6601.p2.Ingredient.TYPE;
import com.thoughtworks.xstream.annotations.XStreamAlias; import com.thoughtworks.xstream.annotations.XStreamAlias;
import com.thoughtworks.xstream.annotations.XStreamImplicit; import com.thoughtworks.xstream.annotations.XStreamImplicit;
@@ -31,6 +35,14 @@ public class Ingredients {
return ingredients; return ingredients;
} }
public Set<TYPE> getTypes() {
Set<TYPE> types = new HashSet<TYPE>();
for (Ingredient ingredient : ingredients) {
types.addAll(ingredient.getTypes());
}
return types;
}
public void setIngredients(List<Ingredient> ingredients) { public void setIngredients(List<Ingredient> ingredients) {
this.ingredients = ingredients; this.ingredients = ingredients;
} }

View File

@@ -14,4 +14,8 @@ public class Recipe {
public Ingredients getIngredients() { public Ingredients getIngredients() {
return ingredients; return ingredients;
} }
public String getName() {
return head.getTitle();
}
} }