Work in progress.

This commit is contained in:
Woody Folsom
2012-03-12 14:33:41 -04:00
parent c1f68f6f5c
commit 16a97ba39a
15 changed files with 1424 additions and 753 deletions

View File

@@ -23,14 +23,19 @@ import dkohl.onthology.Ontology;
public class FoodNetBuilder {
public static final String TASTE = "Taste";
public static final String SOMEONE_VEGETARIAN = "Vegetarian";
public static final String SOMEONE_VEGETARIAN = "vegetarian";
public static final String SOMEONE_ALLERGIC_NUTS = "allergic-nuts";
public static final String CONTAINS_MEAT = "Meat";
public static final String CONTAINS_VEGETABLE = "Vegetable";
public static final String CONTAINS_BEEF = TYPE.BEEF.toString();
public static final String CONTAINS_PORK = TYPE.PORK.toString();
public static final String CONTAINS_TOMATOS = TYPE.TOMATO.toString();
public static final String CONTAINS_POTATOS = TYPE.POTATO.toString();
public static final String CONTAINS_NUTS = "Nuts";
public static final String CONTAINS_GENERIC_NUTS = TYPE.GENERIC_NUTS.toString();
public static final String TRUE_VALUE = "true";
public static final String FALSE_VALUE = "false";
@@ -39,27 +44,30 @@ public class FoodNetBuilder {
public static final String RATING_DOMAIN[] = { "1", "2", "3", "4", "5",
"6", "7", "8", "9", "10" };
private static final String[] VARIABLES = { SOMEONE_VEGETARIAN,
private static final String[] VARIABLES = { SOMEONE_VEGETARIAN, SOMEONE_ALLERGIC_NUTS,
CONTAINS_BEEF, CONTAINS_MEAT, CONTAINS_PORK, CONTAINS_POTATOS,
CONTAINS_TOMATOS, CONTAINS_VEGETABLE, TASTE };
CONTAINS_TOMATOS, CONTAINS_VEGETABLE, CONTAINS_NUTS, CONTAINS_GENERIC_NUTS, TASTE };
private static final String[] OBSERVED = { CONTAINS_BEEF, CONTAINS_PORK,
CONTAINS_POTATOS, CONTAINS_TOMATOS, };
CONTAINS_POTATOS, CONTAINS_TOMATOS, CONTAINS_GENERIC_NUTS};
public static Ontology createOntology() {
HashSet<String> classes = new HashSet<String>();
classes.add(CONTAINS_MEAT);
classes.add(CONTAINS_NUTS);
classes.add(CONTAINS_VEGETABLE);
Ontology onthology = new Ontology(classes);
Ontology ontology = new Ontology(classes);
onthology.define(CONTAINS_PORK, CONTAINS_MEAT);
onthology.define(CONTAINS_BEEF, CONTAINS_MEAT);
ontology.define(CONTAINS_PORK, CONTAINS_MEAT);
ontology.define(CONTAINS_BEEF, CONTAINS_MEAT);
onthology.define(CONTAINS_TOMATOS, CONTAINS_VEGETABLE);
onthology.define(CONTAINS_POTATOS, CONTAINS_VEGETABLE);
return onthology;
ontology.define(CONTAINS_GENERIC_NUTS, CONTAINS_NUTS);
ontology.define(CONTAINS_TOMATOS, CONTAINS_VEGETABLE);
ontology.define(CONTAINS_POTATOS, CONTAINS_VEGETABLE);
return ontology;
}
private static Assignment build(String varible, String value) {
@@ -86,14 +94,13 @@ public class FoodNetBuilder {
return normPoint;
}
public static DataSet getSurveyDataSet(Survey survey, RecipeBook recipeBook) {
public static DataSet getSurveyDataSet(Survey survey, RecipeBook recipeBook, int startIndex, int endIndex) {
DataSet data = FoodExampleBuilder.examples();
Ontology onto = createOntology();
int nDishes = survey.getDishCount();
for (int dinerIndex = 0; dinerIndex < survey.getDinerCount(); dinerIndex++) {
Diner diner = survey.getDiner(dinerIndex);
for (int dishIndex = 0; dishIndex < nDishes; dishIndex++) {
for (int dishIndex = startIndex; dishIndex < endIndex; dishIndex++) {
data.add(normalize(createDataPoint(recipeBook, survey.getDish(dishIndex), diner.getRating(dishIndex)),onto));
}
}
@@ -118,20 +125,41 @@ public class FoodNetBuilder {
if (ingredients.contains(TYPE.TOMATO)) {
point.add(build(CONTAINS_TOMATOS, TRUE_VALUE));
}
if (ingredients.contains(TYPE.GENERIC_NUTS)) {
point.add(build(CONTAINS_GENERIC_NUTS, TRUE_VALUE));
}
point.add(build(TASTE, "" + weight));
return point;
}
public static ProbabilityDistribution vegi() {
public static ProbabilityDistribution vegi(Survey survey) {
String names[] = { SOMEONE_VEGETARIAN };
ProbabilityTable table = new ProbabilityTable(names);
table.setProbabilityForAssignment("true;", new Probability(0));
table.setProbabilityForAssignment("false;", new Probability(1));
if (survey.isDiner("vegetarian")) {
table.setProbabilityForAssignment("true;", new Probability(1));
table.setProbabilityForAssignment("false;", new Probability(0));
} else {
table.setProbabilityForAssignment("true;", new Probability(0));
table.setProbabilityForAssignment("false;", new Probability(1));
}
return table;
}
public static ProbabilityDistribution allergicNuts(Survey survey) {
String names[] = { SOMEONE_ALLERGIC_NUTS };
ProbabilityTable table = new ProbabilityTable(names);
if (survey.isDiner("allergic-nuts")) {
table.setProbabilityForAssignment("true;", new Probability(1));
table.setProbabilityForAssignment("false;", new Probability(0));
} else {
table.setProbabilityForAssignment("true;", new Probability(0));
table.setProbabilityForAssignment("false;", new Probability(1));
}
return table;
}
public static ProbabilityDistribution beef() {
String names[] = { CONTAINS_MEAT, CONTAINS_BEEF };
ProbabilityTable table = new ProbabilityTable(names);
@@ -143,13 +171,25 @@ public class FoodNetBuilder {
ProbabilityTable table = new ProbabilityTable(names);
return table;
}
public static ProbabilityDistribution meet() {
public static ProbabilityDistribution meat() {
String names[] = { SOMEONE_VEGETARIAN, CONTAINS_MEAT };
ProbabilityTable table = new ProbabilityTable(names);
return table;
}
public static ProbabilityDistribution genericNuts() {
String names[] = { CONTAINS_NUTS, CONTAINS_GENERIC_NUTS };
ProbabilityTable table = new ProbabilityTable(names);
return table;
}
public static ProbabilityDistribution nuts() {
String names[] = { SOMEONE_ALLERGIC_NUTS, CONTAINS_NUTS };
ProbabilityTable table = new ProbabilityTable(names);
return table;
}
public static ProbabilityDistribution tomatos() {
String names[] = { CONTAINS_VEGETABLE, CONTAINS_TOMATOS };
ProbabilityTable table = new ProbabilityTable(names);
@@ -170,27 +210,34 @@ public class FoodNetBuilder {
public static ProbabilityDistribution taste() {
String names[] = { TASTE, CONTAINS_BEEF, CONTAINS_PORK,
CONTAINS_POTATOS, CONTAINS_TOMATOS };
CONTAINS_POTATOS, CONTAINS_TOMATOS, CONTAINS_GENERIC_NUTS };
ContinousDistribution distribution = new ContinousDistribution(names, 0);
return distribution;
}
public static BayesNet createDishNet(Survey survey, RecipeBook recipeBook) {
public static BayesNet createDishNet(Survey survey, RecipeBook recipeBook, int startIndex, int endIndex) {
BayesNet net = new BayesNet(VARIABLES);
net.setDistribution(new Variable(SOMEONE_VEGETARIAN, DOMAIN), vegi());
net.setDistribution(new Variable(CONTAINS_MEAT, DOMAIN), meet());
net.setDistribution(new Variable(SOMEONE_VEGETARIAN, DOMAIN), vegi(survey));
net.setDistribution(new Variable(SOMEONE_ALLERGIC_NUTS, DOMAIN), allergicNuts(survey));
net.setDistribution(new Variable(CONTAINS_MEAT, DOMAIN), meat());
net.setDistribution(new Variable(CONTAINS_NUTS, DOMAIN), nuts());
net.setDistribution(new Variable(CONTAINS_VEGETABLE, DOMAIN),
vegetables());
net.setDistribution(new Variable(CONTAINS_BEEF, DOMAIN), beef());
net.setDistribution(new Variable(CONTAINS_PORK, DOMAIN), pork());
net.setDistribution(new Variable(CONTAINS_POTATOS, DOMAIN), potatos());
net.setDistribution(new Variable(CONTAINS_TOMATOS, DOMAIN), tomatos());
net.setDistribution(new Variable(CONTAINS_GENERIC_NUTS, DOMAIN), genericNuts());
net.setDistribution(new Variable(TASTE, RATING_DOMAIN), taste());
Ontology ontology = createOntology();
for (String category : ontology.getClasses()) {
net.connect(category, SOMEONE_VEGETARIAN);
net.connect(category, SOMEONE_ALLERGIC_NUTS);
}
for (String thing : OBSERVED) {
@@ -198,7 +245,7 @@ public class FoodNetBuilder {
net.connect(TASTE, thing);
}
DataSet dataSet = getSurveyDataSet(survey, recipeBook);
DataSet dataSet = getSurveyDataSet(survey, recipeBook, startIndex, endIndex);
for (String category : ontology.getClasses()) {
MaximumLikelihoodEstimation.estimate(dataSet, net, category);

View File

@@ -4,32 +4,27 @@ import java.util.LinkedList;
public class RootMeanSquareError {
private LinkedList<Double> expected;
private LinkedList<Double> groundTruth;
public RootMeanSquareError() {
expected = new LinkedList<Double>();
groundTruth = new LinkedList<Double>();
}
public void push(double predicted, double actual) {
expected.add(predicted);
groundTruth.add(actual);
}
public double error() {
double mean = 0.0;
for(Double val : groundTruth) {
mean += val;
private LinkedList<Double> expected;
private LinkedList<Double> groundTruth;
public RootMeanSquareError() {
expected = new LinkedList<Double>();
groundTruth = new LinkedList<Double>();
}
mean /= groundTruth.size();
double err = 0.0;
for(Double val : expected) {
err += Math.pow(mean - val, 2);
public void push(double predicted, double actual) {
expected.add(predicted);
groundTruth.add(actual);
}
return Math.sqrt(err);
}
public double error() {
double mean = 0.0;
for (int i = 0; i < expected.size(); i++){
mean += Math.pow(expected.get(i) - groundTruth.get(i), 2);
}
mean /= groundTruth.size();
return Math.sqrt(mean);
}
}