Work in progress.
Able to read short survey, recipe book but crashing in Bayes Net code.
This commit is contained in:
218
src/dkohl/bayes/builders/FoodNetBuilder.java
Normal file
218
src/dkohl/bayes/builders/FoodNetBuilder.java
Normal file
@@ -0,0 +1,218 @@
|
||||
package dkohl.bayes.builders;
|
||||
|
||||
import java.util.HashSet;
|
||||
|
||||
import net.woodyfolsom.cs6601.p2.Diner;
|
||||
import net.woodyfolsom.cs6601.p2.Ingredient.TYPE;
|
||||
import net.woodyfolsom.cs6601.p2.Ingredients;
|
||||
import net.woodyfolsom.cs6601.p2.Recipe;
|
||||
import net.woodyfolsom.cs6601.p2.RecipeBook;
|
||||
import net.woodyfolsom.cs6601.p2.Survey;
|
||||
import dkohl.bayes.bayesnet.BayesNet;
|
||||
import dkohl.bayes.estimation.MaximumLikelihoodEstimation;
|
||||
import dkohl.bayes.example.builders.FoodExampleBuilder;
|
||||
import dkohl.bayes.probability.Assignment;
|
||||
import dkohl.bayes.probability.Probability;
|
||||
import dkohl.bayes.probability.Variable;
|
||||
import dkohl.bayes.probability.distribution.ContinousDistribution;
|
||||
import dkohl.bayes.probability.distribution.ProbabilityDistribution;
|
||||
import dkohl.bayes.probability.distribution.ProbabilityTable;
|
||||
import dkohl.bayes.statistic.DataPoint;
|
||||
import dkohl.bayes.statistic.DataSet;
|
||||
import dkohl.onthology.Ontology;
|
||||
|
||||
public class FoodNetBuilder {
|
||||
public static final String TASTE = "Taste";
|
||||
public static final String SOMEONE_VEGETARIAN = "Vegetarian";
|
||||
public static final String CONTAINS_MEAT = "Meat";
|
||||
public static final String CONTAINS_VEGETABLE = "Vegetable";
|
||||
public static final String CONTAINS_BEEF = TYPE.BEEF.toString();
|
||||
public static final String CONTAINS_PORK = TYPE.PORK.toString();
|
||||
public static final String CONTAINS_TOMATOS = TYPE.TOMATO.toString();
|
||||
public static final String CONTAINS_POTATOS = TYPE.POTATO.toString();
|
||||
|
||||
public static final String TRUE_VALUE = "true";
|
||||
public static final String FALSE_VALUE = "false";
|
||||
|
||||
public static final String DOMAIN[] = { TRUE_VALUE, FALSE_VALUE };
|
||||
|
||||
public static final String RATING_DOMAIN[] = { "1", "2", "3", "4", "5",
|
||||
"6", "7", "8", "9", "10" };
|
||||
|
||||
private static final String[] VARIABLES = { SOMEONE_VEGETARIAN,
|
||||
CONTAINS_BEEF, CONTAINS_MEAT, CONTAINS_PORK, CONTAINS_POTATOS,
|
||||
CONTAINS_TOMATOS, CONTAINS_VEGETABLE, TASTE };
|
||||
|
||||
private static final String[] OBSERVED = { CONTAINS_BEEF, CONTAINS_PORK,
|
||||
CONTAINS_POTATOS, CONTAINS_TOMATOS, };
|
||||
|
||||
public static Ontology createOntology() {
|
||||
HashSet<String> classes = new HashSet<String>();
|
||||
|
||||
classes.add(CONTAINS_MEAT);
|
||||
classes.add(CONTAINS_VEGETABLE);
|
||||
|
||||
Ontology onthology = new Ontology(classes);
|
||||
|
||||
onthology.define(CONTAINS_PORK, CONTAINS_MEAT);
|
||||
onthology.define(CONTAINS_BEEF, CONTAINS_MEAT);
|
||||
|
||||
onthology.define(CONTAINS_TOMATOS, CONTAINS_VEGETABLE);
|
||||
onthology.define(CONTAINS_POTATOS, CONTAINS_VEGETABLE);
|
||||
return onthology;
|
||||
}
|
||||
|
||||
private static Assignment build(String varible, String value) {
|
||||
return new Assignment(new Variable(varible, DOMAIN), value);
|
||||
}
|
||||
|
||||
public static DataPoint normalize(DataPoint point, Ontology onto) {
|
||||
// resolve onthology
|
||||
DataPoint normPoint = new DataPoint(point);
|
||||
for (String key : point.keySet()) {
|
||||
if (onto.getInheritance().containsKey(key)) {
|
||||
normPoint
|
||||
.add(build(onto.getInheritance().get(key), TRUE_VALUE));
|
||||
}
|
||||
}
|
||||
|
||||
// implement closed world assumption
|
||||
// everything unknown is false
|
||||
for (String variable : VARIABLES) {
|
||||
if (!normPoint.containsKey(variable)) {
|
||||
normPoint.add(build(variable, FALSE_VALUE));
|
||||
}
|
||||
}
|
||||
return normPoint;
|
||||
}
|
||||
|
||||
public static DataSet getSurveyDataSet(Survey survey, RecipeBook recipeBook) {
|
||||
DataSet data = FoodExampleBuilder.examples();
|
||||
Ontology onto = createOntology();
|
||||
|
||||
int nDishes = survey.getDishCount();
|
||||
for (int dinerIndex = 0; dinerIndex < survey.getDinerCount(); dinerIndex++) {
|
||||
Diner diner = survey.getDiner(dinerIndex);
|
||||
for (int dishIndex = 0; dishIndex < nDishes; dishIndex++) {
|
||||
data.add(normalize(createDataPoint(recipeBook, survey.getDish(dishIndex), diner.getRating(dishIndex)),onto));
|
||||
}
|
||||
}
|
||||
return data;
|
||||
}
|
||||
|
||||
public static DataPoint createDataPoint(RecipeBook recipeBook, String recipeName, int weight) {
|
||||
DataPoint point = new DataPoint();
|
||||
|
||||
Recipe recipe = recipeBook.getRecipe(recipeName);
|
||||
Ingredients ingredients = recipe.getIngredients();
|
||||
|
||||
if (ingredients.contains(TYPE.BEEF)) {
|
||||
point.add(build(CONTAINS_BEEF, TRUE_VALUE));
|
||||
}
|
||||
if (ingredients.contains(TYPE.PORK)) {
|
||||
point.add(build(CONTAINS_PORK, TRUE_VALUE));
|
||||
}
|
||||
if (ingredients.contains(TYPE.POTATO)) {
|
||||
point.add(build(CONTAINS_POTATOS, TRUE_VALUE));
|
||||
}
|
||||
if (ingredients.contains(TYPE.TOMATO)) {
|
||||
point.add(build(CONTAINS_TOMATOS, TRUE_VALUE));
|
||||
}
|
||||
|
||||
point.add(build(TASTE, "" + weight));
|
||||
|
||||
return point;
|
||||
}
|
||||
|
||||
public static ProbabilityDistribution vegi() {
|
||||
String names[] = { SOMEONE_VEGETARIAN };
|
||||
ProbabilityTable table = new ProbabilityTable(names);
|
||||
table.setProbabilityForAssignment("true;", new Probability(0));
|
||||
table.setProbabilityForAssignment("false;", new Probability(1));
|
||||
return table;
|
||||
}
|
||||
|
||||
public static ProbabilityDistribution beef() {
|
||||
String names[] = { CONTAINS_MEAT, CONTAINS_BEEF };
|
||||
ProbabilityTable table = new ProbabilityTable(names);
|
||||
return table;
|
||||
}
|
||||
|
||||
public static ProbabilityDistribution pork() {
|
||||
String names[] = { CONTAINS_MEAT, CONTAINS_PORK };
|
||||
ProbabilityTable table = new ProbabilityTable(names);
|
||||
return table;
|
||||
}
|
||||
|
||||
public static ProbabilityDistribution meet() {
|
||||
String names[] = { SOMEONE_VEGETARIAN, CONTAINS_MEAT };
|
||||
ProbabilityTable table = new ProbabilityTable(names);
|
||||
return table;
|
||||
}
|
||||
|
||||
public static ProbabilityDistribution tomatos() {
|
||||
String names[] = { CONTAINS_VEGETABLE, CONTAINS_TOMATOS };
|
||||
ProbabilityTable table = new ProbabilityTable(names);
|
||||
return table;
|
||||
}
|
||||
|
||||
public static ProbabilityDistribution potatos() {
|
||||
String names[] = { CONTAINS_VEGETABLE, CONTAINS_POTATOS };
|
||||
ProbabilityTable table = new ProbabilityTable(names);
|
||||
return table;
|
||||
}
|
||||
|
||||
public static ProbabilityDistribution vegetables() {
|
||||
String names[] = { CONTAINS_VEGETABLE, SOMEONE_VEGETARIAN };
|
||||
ProbabilityTable table = new ProbabilityTable(names);
|
||||
return table;
|
||||
}
|
||||
|
||||
public static ProbabilityDistribution taste() {
|
||||
String names[] = { TASTE, CONTAINS_BEEF, CONTAINS_PORK,
|
||||
CONTAINS_POTATOS, CONTAINS_TOMATOS };
|
||||
ContinousDistribution distribution = new ContinousDistribution(names, 0);
|
||||
return distribution;
|
||||
}
|
||||
|
||||
public static BayesNet createDishNet(Survey survey, RecipeBook recipeBook) {
|
||||
BayesNet net = new BayesNet(VARIABLES);
|
||||
|
||||
net.setDistribution(new Variable(SOMEONE_VEGETARIAN, DOMAIN), vegi());
|
||||
net.setDistribution(new Variable(CONTAINS_MEAT, DOMAIN), meet());
|
||||
net.setDistribution(new Variable(CONTAINS_VEGETABLE, DOMAIN),
|
||||
vegetables());
|
||||
net.setDistribution(new Variable(CONTAINS_BEEF, DOMAIN), beef());
|
||||
net.setDistribution(new Variable(CONTAINS_PORK, DOMAIN), pork());
|
||||
net.setDistribution(new Variable(CONTAINS_POTATOS, DOMAIN), potatos());
|
||||
net.setDistribution(new Variable(CONTAINS_TOMATOS, DOMAIN), tomatos());
|
||||
net.setDistribution(new Variable(TASTE, RATING_DOMAIN), taste());
|
||||
|
||||
Ontology ontology = createOntology();
|
||||
for (String category : ontology.getClasses()) {
|
||||
net.connect(category, SOMEONE_VEGETARIAN);
|
||||
}
|
||||
|
||||
for (String thing : OBSERVED) {
|
||||
net.connect(thing, ontology.getInheritance().get(thing));
|
||||
net.connect(TASTE, thing);
|
||||
}
|
||||
|
||||
DataSet dataSet = getSurveyDataSet(survey, recipeBook);
|
||||
|
||||
for (String category : ontology.getClasses()) {
|
||||
MaximumLikelihoodEstimation.estimate(dataSet, net, category);
|
||||
for (String thing : ontology.getClasses2thing().get(category)) {
|
||||
MaximumLikelihoodEstimation.estimate(dataSet, net, thing);
|
||||
}
|
||||
}
|
||||
|
||||
MaximumLikelihoodEstimation.estimate(dataSet, net, TASTE);
|
||||
ContinousDistribution distribturion = (ContinousDistribution) net
|
||||
.getNodes().get(TASTE);
|
||||
|
||||
distribturion.estimate();
|
||||
|
||||
return net;
|
||||
}
|
||||
}
|
||||
@@ -27,8 +27,8 @@ public class EnumerateAll {
|
||||
* a set of assignments in this net.
|
||||
* @return
|
||||
*/
|
||||
public static LinkedList<ProbabilityAssignment> enumerateAsk(
|
||||
Variable query, BayesNet net, LinkedList<Assignment> assignments) {
|
||||
public static List<ProbabilityAssignment> enumerateAsk(
|
||||
Variable query, BayesNet net, List<Assignment> assignments) {
|
||||
LinkedList<Variable> variables = net.getVariables();
|
||||
LinkedList<ProbabilityAssignment> result = new LinkedList<ProbabilityAssignment>();
|
||||
|
||||
|
||||
@@ -1,61 +1,112 @@
|
||||
package net.woodyfolsom.cs6601.p2;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
|
||||
import org.apache.commons.math3.stat.regression.SimpleRegression;
|
||||
import net.woodyfolsom.cs6601.p2.Ingredient.TYPE;
|
||||
import dkohl.bayes.bayesnet.BayesNet;
|
||||
import dkohl.bayes.builders.FoodNetBuilder;
|
||||
import dkohl.bayes.example.builders.FoodExampleBuilder;
|
||||
import dkohl.bayes.inference.EnumerateAll;
|
||||
import dkohl.bayes.probability.Assignment;
|
||||
import dkohl.bayes.probability.ProbabilityAssignment;
|
||||
import dkohl.bayes.probability.Variable;
|
||||
|
||||
public class BayesChef {
|
||||
|
||||
public static void main(String... args) {
|
||||
System.out.println("Reading recipe book.");
|
||||
RecipeBook recipeBook = RecipeBookReader.readRecipeBook(new File("data/survey_recipes.xml"));
|
||||
|
||||
System.out.println("Read data for " + recipeBook.getSize() + " recipes.");
|
||||
RecipeBook recipeBook = RecipeBookReader.readRecipeBook(new File("data/short_recipebook.xml"));
|
||||
|
||||
System.out.println("Reading survey data.");
|
||||
Survey survey = SurveyReader.readSurvey(new File("data/survey.xml"));
|
||||
Survey survey = SurveyReader.readSurvey(new File("data/short_survey.xml"));
|
||||
|
||||
System.out.println("Read data for " + survey.getDinerCount() + " diner(s).");
|
||||
new BayesChef().reportBestMeal(recipeBook,survey);
|
||||
}
|
||||
|
||||
private void reportBestMeal(RecipeBook recipeBook, Survey survey) {
|
||||
int numRecipes = recipeBook.getSize();
|
||||
System.out.println("Read data for " + numRecipes + " recipes.");
|
||||
|
||||
int numDiners = survey.getDinerCount();
|
||||
System.out.println("Read data for " + numDiners + " diner(s).");
|
||||
|
||||
/*
|
||||
System.out.println("Setting evidence for first 11 recipes.");
|
||||
|
||||
System.out.println("Evaluating preference for remaining recipes.");
|
||||
|
||||
//read survey prefs from survey.xml
|
||||
double[][] surveyPrefs = new double[5][];
|
||||
for (int i = 0; i < 5; i++) {
|
||||
surveyPrefs[i] = new double[11];
|
||||
for (int j = 0; j < 11; j++) {
|
||||
double[][] surveyPrefs = new double[numDiners][];
|
||||
for (int i = 0; i < numDiners; i++) {
|
||||
surveyPrefs[i] = new double[numRecipes/2];
|
||||
for (int j = 0; j < numRecipes/2; j++) {
|
||||
surveyPrefs[i][j] = survey.getDiner(i).getRating(j);
|
||||
}
|
||||
}
|
||||
|
||||
//generating stub evaluated preferences to test RMSE calculation
|
||||
double[][] calculatedPrefs = new double[5][];
|
||||
calculatedPrefs[0] = new double[] { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 };
|
||||
calculatedPrefs[1] = new double[] { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 };
|
||||
calculatedPrefs[2] = new double[] { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 };
|
||||
calculatedPrefs[3] = new double[] { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 };
|
||||
calculatedPrefs[4] = new double[] { 10.0, 10.0, 0.0, 10.0, 10.0, 4.0, 10.0, 10.0, 10.0, 10.0, 10.0 };
|
||||
double[][] calculatedPrefs = new double[numDiners][];
|
||||
calculatedPrefs[0] = new double[numRecipes/2];
|
||||
|
||||
System.out.println("RMSE for recipes #12-22 (calculated vs. surveyed preference):");
|
||||
System.out.println("RMSE for recipes #" + numRecipes/2 + "-" + numRecipes +" (calculated vs. surveyed preference):");
|
||||
|
||||
for (int dinerIndex = 0; dinerIndex < survey.getDinerCount(); dinerIndex++) {
|
||||
SimpleRegression simpleRegression = new SimpleRegression();
|
||||
|
||||
for (int i = 0; i < 11; i++) {
|
||||
for (int i = 0; i < numRecipes/2; i++) {
|
||||
simpleRegression.addData(i,calculatedPrefs[dinerIndex][i]);
|
||||
}
|
||||
double calculatedMSE = simpleRegression.getMeanSquareError();
|
||||
simpleRegression.clear();
|
||||
|
||||
for (int i = 0; i < 11; i++) {
|
||||
for (int i = 0; i < numRecipes/2; i++) {
|
||||
simpleRegression.addData(i,surveyPrefs[dinerIndex][i]);
|
||||
}
|
||||
|
||||
double surveyMSE = simpleRegression.getMeanSquareError();
|
||||
|
||||
System.out.println("Diner # " + (dinerIndex + 1) + ": " + calculatedMSE + " vs. "+ surveyMSE);
|
||||
}*/
|
||||
|
||||
BayesNet net = FoodNetBuilder.createDishNet(survey, recipeBook);
|
||||
//BayesNet net = FoodExampleBuilder.dishNet();
|
||||
|
||||
printPreference(net, TYPE.PORK);
|
||||
printPreference(net, TYPE.BEEF);
|
||||
printPreference(net, TYPE.POTATO);
|
||||
printPreference(net, TYPE.TOMATO);
|
||||
}
|
||||
|
||||
void printPreference(BayesNet net, TYPE type) {
|
||||
List<ProbabilityAssignment> probs = EnumerateAll.enumerateAsk(new Variable(
|
||||
FoodNetBuilder.TASTE, FoodNetBuilder.RATING_DOMAIN),
|
||||
net, createQuery(type));
|
||||
double max_val = 0.0;
|
||||
String max_arg = "N/A";
|
||||
for (ProbabilityAssignment p : probs) {
|
||||
if (p.getProbability() > max_val) {
|
||||
max_val = p.getProbability();
|
||||
max_arg = p.getValue();
|
||||
}
|
||||
}
|
||||
|
||||
System.out.println("TASTE " + type + ": " + max_arg + " " + max_val);
|
||||
}
|
||||
|
||||
List<Assignment> createQuery(TYPE type) {
|
||||
List<Assignment> assignment = new LinkedList<Assignment>();
|
||||
switch (type) {
|
||||
case PORK:
|
||||
case BEEF:
|
||||
case POTATO:
|
||||
case TOMATO:
|
||||
assignment.add(new Assignment(new Variable(type.toString(), FoodNetBuilder.DOMAIN),
|
||||
FoodNetBuilder.TRUE_VALUE));
|
||||
break;
|
||||
default:
|
||||
}
|
||||
return assignment;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ import com.thoughtworks.xstream.annotations.XStreamAlias;
|
||||
|
||||
@XStreamAlias("ing")
|
||||
public class Ingredient {
|
||||
public enum TYPE { ALCOHOL, DAIRY, EGGS, FISH, GLUTEN, GRAIN, NUTS, PORK, POULTRY, RED_MEAT, SHELLFISH, SPICE, SUGAR}
|
||||
public enum TYPE { ALCOHOL, BEEF, DAIRY, EGGS, FISH, GLUTEN, GRAIN, NUTS, PORK, POULTRY, POTATO, SHELLFISH, SPICE, SUGAR, TOMATO}
|
||||
|
||||
private String item;
|
||||
|
||||
@@ -14,20 +14,31 @@ public class Ingredient {
|
||||
|
||||
public boolean isType(TYPE type) {
|
||||
switch (type) {
|
||||
case BEEF :
|
||||
return item.contains("beef");
|
||||
case DAIRY :
|
||||
return item.contains("margarine");
|
||||
case EGGS :
|
||||
return item.equals("egg") || item.equals("eggs");
|
||||
case GLUTEN :
|
||||
return item.contains("flour");
|
||||
case RED_MEAT :
|
||||
return item.contains("beef");
|
||||
case PORK :
|
||||
return item.contains("pork");
|
||||
case POTATO :
|
||||
return item.contains("potato");
|
||||
case SPICE :
|
||||
return item.endsWith("cinnamon") || item.endsWith("nutmeg") || item.endsWith("cloves");
|
||||
case SUGAR :
|
||||
return item.endsWith("sugar");
|
||||
case TOMATO :
|
||||
return item.contains("tomato");
|
||||
default : //unknown ingredient, e.g. coffee, bananas, honey
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public Object readResolve() {
|
||||
item = item.toLowerCase();
|
||||
return this;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,6 +15,16 @@ public class RecipeBook {
|
||||
return recipes.get(index);
|
||||
}
|
||||
|
||||
//TODO build an index of recipes by name?
|
||||
public Recipe getRecipe(String recipeName) {
|
||||
for (Recipe recipe : recipes) {
|
||||
if (recipeName.equalsIgnoreCase(recipe.getHead().getTitle())) {
|
||||
return recipe;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
public int getSize() {
|
||||
return recipes.size();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user