From 12b4bab59d7e665bb0ca7a84d9fb7c9edb546ed8 Mon Sep 17 00:00:00 2001 From: dkohl Date: Mon, 12 Mar 2012 21:01:30 -0400 Subject: [PATCH] Fix --- src/dkohl/bayes/builders/FoodNetBuilder.java | 452 ++++++++++--------- src/dkohl/bayes/inference/EnumerateAll.java | 5 +- src/dkohl/bayes/statistic/DataSet.java | 165 ++++--- 3 files changed, 324 insertions(+), 298 deletions(-) diff --git a/src/dkohl/bayes/builders/FoodNetBuilder.java b/src/dkohl/bayes/builders/FoodNetBuilder.java index 8d428f4..6974a4f 100644 --- a/src/dkohl/bayes/builders/FoodNetBuilder.java +++ b/src/dkohl/bayes/builders/FoodNetBuilder.java @@ -22,248 +22,256 @@ import dkohl.bayes.statistic.DataSet; import dkohl.onthology.Ontology; public class FoodNetBuilder { - public static final String TASTE = "Taste"; - - public static final String SOMEONE_VEGETARIAN = "vegetarian"; - //public static final String SOMEONE_ALLERGIC_NUTS = "allergic-nuts"; - - public static final String CONTAINS_MEAT = "Meat"; - public static final String CONTAINS_VEGETABLE = "Vegetable"; - public static final String CONTAINS_BEEF = TYPE.BEEF.toString(); - public static final String CONTAINS_PORK = TYPE.PORK.toString(); - public static final String CONTAINS_TOMATOS = TYPE.TOMATO.toString(); - public static final String CONTAINS_POTATOS = TYPE.POTATO.toString(); - //public static final String CONTAINS_NUTS = "Nuts"; - //public static final String CONTAINS_GENERIC_NUTS = TYPE.GENERIC_NUTS.toString(); - - public static final String TRUE_VALUE = "true"; - public static final String FALSE_VALUE = "false"; + public static final String TASTE = "Taste"; - public static final String DOMAIN[] = { TRUE_VALUE, FALSE_VALUE }; + public static final String SOMEONE_VEGETARIAN = "vegetarian"; + // public static final String SOMEONE_ALLERGIC_NUTS = "allergic-nuts"; - public static final String RATING_DOMAIN[] = { "1", "2", "3", "4", "5", - "6", "7", "8", "9", "10" }; + public static final String CONTAINS_MEAT = "Meat"; + public static final String CONTAINS_VEGETABLE = "Vegetable"; + public static final String CONTAINS_BEEF = TYPE.BEEF.toString(); + public static final String CONTAINS_PORK = TYPE.PORK.toString(); + public static final String CONTAINS_TOMATOS = TYPE.TOMATO.toString(); + public static final String CONTAINS_POTATOS = TYPE.POTATO.toString(); + // public static final String CONTAINS_NUTS = "Nuts"; + // public static final String CONTAINS_GENERIC_NUTS = + // TYPE.GENERIC_NUTS.toString(); - private static final String[] VARIABLES = { SOMEONE_VEGETARIAN, /*SOMEONE_ALLERGIC_NUTS,*/ - CONTAINS_BEEF, CONTAINS_MEAT, CONTAINS_PORK, CONTAINS_POTATOS, - CONTAINS_TOMATOS, CONTAINS_VEGETABLE, /*CONTAINS_NUTS, CONTAINS_GENERIC_NUTS,*/ TASTE }; + public static final String TRUE_VALUE = "true"; + public static final String FALSE_VALUE = "false"; - private static final String[] OBSERVED = { CONTAINS_BEEF, CONTAINS_PORK, - CONTAINS_POTATOS, CONTAINS_TOMATOS/*, CONTAINS_GENERIC_NUTS*/}; + public static final String DOMAIN[] = { TRUE_VALUE, FALSE_VALUE }; - public static Ontology createOntology() { - HashSet classes = new HashSet(); - - classes.add(CONTAINS_MEAT); - //classes.add(CONTAINS_NUTS); - classes.add(CONTAINS_VEGETABLE); - - Ontology ontology = new Ontology(classes); - - ontology.define(CONTAINS_PORK, CONTAINS_MEAT); - ontology.define(CONTAINS_BEEF, CONTAINS_MEAT); + public static final String RATING_DOMAIN[] = { "1", "2", "3", "4", "5", + "6", "7", "8", "9", "10" }; - //ontology.define(CONTAINS_GENERIC_NUTS, CONTAINS_NUTS); - - ontology.define(CONTAINS_TOMATOS, CONTAINS_VEGETABLE); - ontology.define(CONTAINS_POTATOS, CONTAINS_VEGETABLE); - return ontology; + private static final String[] VARIABLES = { SOMEONE_VEGETARIAN, /* + * SOMEONE_ALLERGIC_NUTS + * , + */ + CONTAINS_BEEF, CONTAINS_MEAT, CONTAINS_PORK, CONTAINS_POTATOS, + CONTAINS_TOMATOS, CONTAINS_VEGETABLE, /* + * CONTAINS_NUTS, + * CONTAINS_GENERIC_NUTS, + */TASTE }; + + private static final String[] OBSERVED = { CONTAINS_BEEF, CONTAINS_PORK, + CONTAINS_POTATOS, CONTAINS_TOMATOS /* , CONTAINS_GENERIC_NUTS */}; + + public static Ontology createOntology() { + HashSet classes = new HashSet(); + + classes.add(CONTAINS_MEAT); + // classes.add(CONTAINS_NUTS); + classes.add(CONTAINS_VEGETABLE); + + Ontology ontology = new Ontology(classes); + + ontology.define(CONTAINS_PORK, CONTAINS_MEAT); + ontology.define(CONTAINS_BEEF, CONTAINS_MEAT); + + // ontology.define(CONTAINS_GENERIC_NUTS, CONTAINS_NUTS); + + ontology.define(CONTAINS_TOMATOS, CONTAINS_VEGETABLE); + ontology.define(CONTAINS_POTATOS, CONTAINS_VEGETABLE); + return ontology; + } + + private static Assignment build(String varible, String value) { + return new Assignment(new Variable(varible, DOMAIN), value); + } + + public static DataPoint normalize(DataPoint point, Ontology onto) { + // resolve onthology + DataPoint normPoint = new DataPoint(point); + for (String key : point.keySet()) { + if (onto.getInheritance().containsKey(key)) { + normPoint + .add(build(onto.getInheritance().get(key), TRUE_VALUE)); + } } - private static Assignment build(String varible, String value) { - return new Assignment(new Variable(varible, DOMAIN), value); + // implement closed world assumption + // everything unknown is false + for (String variable : VARIABLES) { + if (!normPoint.containsKey(variable)) { + normPoint.add(build(variable, FALSE_VALUE)); + } } + return normPoint; + } - public static DataPoint normalize(DataPoint point, Ontology onto) { - // resolve onthology - DataPoint normPoint = new DataPoint(point); - for (String key : point.keySet()) { - if (onto.getInheritance().containsKey(key)) { - normPoint - .add(build(onto.getInheritance().get(key), TRUE_VALUE)); - } - } + public static DataSet getSurveyDataSet(Survey survey, + RecipeBook recipeBook, int startIndex, int endIndex) { + DataSet data = FoodExampleBuilder.examples(); + Ontology onto = createOntology(); - // implement closed world assumption - // everything unknown is false - for (String variable : VARIABLES) { - if (!normPoint.containsKey(variable)) { - normPoint.add(build(variable, FALSE_VALUE)); - } - } - return normPoint; + for (int dinerIndex = 0; dinerIndex < survey.getDinerCount(); dinerIndex++) { + Diner diner = survey.getDiner(dinerIndex); + for (int dishIndex = startIndex; dishIndex < endIndex; dishIndex++) { + data.add(normalize( + createDataPoint(recipeBook, survey.getDish(dishIndex), + diner.getRating(dishIndex)), onto)); + } } + return data; + } - public static DataSet getSurveyDataSet(Survey survey, RecipeBook recipeBook, int startIndex, int endIndex) { - DataSet data = FoodExampleBuilder.examples(); - Ontology onto = createOntology(); - - for (int dinerIndex = 0; dinerIndex < survey.getDinerCount(); dinerIndex++) { - Diner diner = survey.getDiner(dinerIndex); - for (int dishIndex = startIndex; dishIndex < endIndex; dishIndex++) { - data.add(normalize(createDataPoint(recipeBook, survey.getDish(dishIndex), diner.getRating(dishIndex)),onto)); - } - } - return data; - } - - public static DataPoint createDataPoint(RecipeBook recipeBook, String recipeName, int weight) { - DataPoint point = new DataPoint(); - - Recipe recipe = recipeBook.getRecipe(recipeName); - Ingredients ingredients = recipe.getIngredients(); - - if (ingredients.contains(TYPE.BEEF)) { - point.add(build(CONTAINS_BEEF, TRUE_VALUE)); - } - if (ingredients.contains(TYPE.PORK)) { - point.add(build(CONTAINS_PORK, TRUE_VALUE)); - } - if (ingredients.contains(TYPE.POTATO)) { - point.add(build(CONTAINS_POTATOS, TRUE_VALUE)); - } - if (ingredients.contains(TYPE.TOMATO)) { - point.add(build(CONTAINS_TOMATOS, TRUE_VALUE)); - } - /* - if (ingredients.contains(TYPE.GENERIC_NUTS)) { - point.add(build(CONTAINS_GENERIC_NUTS, TRUE_VALUE)); - }*/ - - point.add(build(TASTE, "" + weight)); - - return point; - } + public static DataPoint createDataPoint(RecipeBook recipeBook, + String recipeName, int weight) { + DataPoint point = new DataPoint(); - public static ProbabilityDistribution vegi(Survey survey) { - String names[] = { SOMEONE_VEGETARIAN }; - ProbabilityTable table = new ProbabilityTable(names); - //if (survey.isDiner("vegetarian")) { - // table.setProbabilityForAssignment("true;", new Probability(1)); - // table.setProbabilityForAssignment("false;", new Probability(0)); - //} else { - table.setProbabilityForAssignment("true;", new Probability(0)); - table.setProbabilityForAssignment("false;", new Probability(1)); - //} - return table; - } + Recipe recipe = recipeBook.getRecipe(recipeName); + Ingredients ingredients = recipe.getIngredients(); + if (ingredients.contains(TYPE.BEEF)) { + point.add(build(CONTAINS_BEEF, TRUE_VALUE)); + } + if (ingredients.contains(TYPE.PORK)) { + point.add(build(CONTAINS_PORK, TRUE_VALUE)); + } + if (ingredients.contains(TYPE.POTATO)) { + point.add(build(CONTAINS_POTATOS, TRUE_VALUE)); + } + if (ingredients.contains(TYPE.TOMATO)) { + point.add(build(CONTAINS_TOMATOS, TRUE_VALUE)); + } /* - public static ProbabilityDistribution allergicNuts(Survey survey) { - String names[] = { SOMEONE_ALLERGIC_NUTS }; - ProbabilityTable table = new ProbabilityTable(names); - if (survey.isDiner("allergic-nuts")) { - table.setProbabilityForAssignment("true;", new Probability(1)); - table.setProbabilityForAssignment("false;", new Probability(0)); - } else { - table.setProbabilityForAssignment("true;", new Probability(0)); - table.setProbabilityForAssignment("false;", new Probability(1)); - } - return table; - }*/ - - public static ProbabilityDistribution beef() { - String names[] = { CONTAINS_MEAT, CONTAINS_BEEF }; - ProbabilityTable table = new ProbabilityTable(names); - return table; + * if (ingredients.contains(TYPE.GENERIC_NUTS)) { + * point.add(build(CONTAINS_GENERIC_NUTS, TRUE_VALUE)); } + */ + + point.add(build(TASTE, "" + weight)); + + return point; + } + + public static ProbabilityDistribution vegi(Survey survey) { + String names[] = {CONTAINS_MEAT, SOMEONE_VEGETARIAN}; + ProbabilityTable table = new ProbabilityTable(names); + table.setProbabilityForAssignment("true;true;", new Probability(0)); + table.setProbabilityForAssignment("false;true;", new Probability(1)); + table.setProbabilityForAssignment("true;false;", new Probability(1)); + table.setProbabilityForAssignment("false;false;", new Probability(1)); + return table; + } + + /* + * public static ProbabilityDistribution allergicNuts(Survey survey) { + * String names[] = { SOMEONE_ALLERGIC_NUTS }; ProbabilityTable table = new + * ProbabilityTable(names); if (survey.isDiner("allergic-nuts")) { + * table.setProbabilityForAssignment("true;", new Probability(1)); + * table.setProbabilityForAssignment("false;", new Probability(0)); } else { + * table.setProbabilityForAssignment("true;", new Probability(0)); + * table.setProbabilityForAssignment("false;", new Probability(1)); } return + * table; } + */ + + public static ProbabilityDistribution beef() { + String names[] = { CONTAINS_MEAT, CONTAINS_BEEF }; + ProbabilityTable table = new ProbabilityTable(names); + return table; + } + + public static ProbabilityDistribution pork() { + String names[] = { CONTAINS_MEAT, CONTAINS_PORK }; + ProbabilityTable table = new ProbabilityTable(names); + return table; + } + + public static ProbabilityDistribution meat() { + String names[] = { CONTAINS_MEAT }; + ProbabilityTable table = new ProbabilityTable(names); + return table; + } + + /* + * public static ProbabilityDistribution genericNuts() { String names[] = { + * CONTAINS_NUTS, CONTAINS_GENERIC_NUTS }; ProbabilityTable table = new + * ProbabilityTable(names); return table; } + * + * public static ProbabilityDistribution nuts() { String names[] = { + * SOMEONE_ALLERGIC_NUTS, CONTAINS_NUTS }; ProbabilityTable table = new + * ProbabilityTable(names); return table; } + */ + + public static ProbabilityDistribution tomatos() { + String names[] = { CONTAINS_VEGETABLE, CONTAINS_TOMATOS }; + ProbabilityTable table = new ProbabilityTable(names); + return table; + } + + public static ProbabilityDistribution potatos() { + String names[] = { CONTAINS_VEGETABLE, CONTAINS_POTATOS }; + ProbabilityTable table = new ProbabilityTable(names); + return table; + } + + public static ProbabilityDistribution vegetables() { + String names[] = { CONTAINS_VEGETABLE }; + ProbabilityTable table = new ProbabilityTable(names); + return table; + } + + public static ProbabilityDistribution taste() { + String names[] = { TASTE, CONTAINS_BEEF, CONTAINS_PORK, + CONTAINS_POTATOS, CONTAINS_TOMATOS /* , CONTAINS_GENERIC_NUTS */}; + ContinousDistribution distribution = new ContinousDistribution(names, 0); + return distribution; + } + + public static BayesNet createDishNet(Survey survey, RecipeBook recipeBook, + int startIndex, int endIndex) { + BayesNet net = new BayesNet(VARIABLES); + + // net.setDistribution(new Variable(SOMEONE_VEGETARIAN, DOMAIN), + // vegi(survey)); + // net.setDistribution(new Variable(SOMEONE_ALLERGIC_NUTS, DOMAIN), + // allergicNuts(survey)); + + net.setDistribution(new Variable(CONTAINS_MEAT, DOMAIN), meat()); + // net.setDistribution(new Variable(CONTAINS_NUTS, DOMAIN), nuts()); + net.setDistribution(new Variable(SOMEONE_VEGETARIAN, DOMAIN), + vegi(survey)); + net.setDistribution(new Variable(CONTAINS_VEGETABLE, DOMAIN), + vegetables()); + net.setDistribution(new Variable(CONTAINS_BEEF, DOMAIN), beef()); + net.setDistribution(new Variable(CONTAINS_PORK, DOMAIN), pork()); + net.setDistribution(new Variable(CONTAINS_POTATOS, DOMAIN), potatos()); + net.setDistribution(new Variable(CONTAINS_TOMATOS, DOMAIN), tomatos()); + // net.setDistribution(new Variable(CONTAINS_GENERIC_NUTS, DOMAIN), + // genericNuts()); + + net.setDistribution(new Variable(TASTE, RATING_DOMAIN), taste()); + + Ontology ontology = createOntology(); + // for (String category : ontology.getClasses()) { + // net.connect(category, SOMEONE_VEGETARIAN); + // /*net.connect(category, SOMEONE_ALLERGIC_NUTS);*/ + // } + net.connect(SOMEONE_VEGETARIAN, CONTAINS_MEAT); + + for (String thing : OBSERVED) { + net.connect(thing, ontology.getInheritance().get(thing)); + net.connect(TASTE, thing); } - public static ProbabilityDistribution pork() { - String names[] = { CONTAINS_MEAT, CONTAINS_PORK }; - ProbabilityTable table = new ProbabilityTable(names); - return table; - } - - public static ProbabilityDistribution meat() { - String names[] = { SOMEONE_VEGETARIAN, CONTAINS_MEAT }; - ProbabilityTable table = new ProbabilityTable(names); - return table; - } - - /* - public static ProbabilityDistribution genericNuts() { - String names[] = { CONTAINS_NUTS, CONTAINS_GENERIC_NUTS }; - ProbabilityTable table = new ProbabilityTable(names); - return table; - } - - public static ProbabilityDistribution nuts() { - String names[] = { SOMEONE_ALLERGIC_NUTS, CONTAINS_NUTS }; - ProbabilityTable table = new ProbabilityTable(names); - return table; - }*/ - - public static ProbabilityDistribution tomatos() { - String names[] = { CONTAINS_VEGETABLE, CONTAINS_TOMATOS }; - ProbabilityTable table = new ProbabilityTable(names); - return table; + DataSet dataSet = getSurveyDataSet(survey, recipeBook, startIndex, + endIndex); + + for (String category : ontology.getClasses()) { + MaximumLikelihoodEstimation.estimate(dataSet, net, category); + for (String thing : ontology.getClasses2thing().get(category)) { + MaximumLikelihoodEstimation.estimate(dataSet, net, thing); + } } - public static ProbabilityDistribution potatos() { - String names[] = { CONTAINS_VEGETABLE, CONTAINS_POTATOS }; - ProbabilityTable table = new ProbabilityTable(names); - return table; - } + MaximumLikelihoodEstimation.estimate(dataSet, net, TASTE); + ContinousDistribution distribturion = (ContinousDistribution) net + .getNodes().get(TASTE); - public static ProbabilityDistribution vegetables() { - String names[] = { CONTAINS_VEGETABLE, SOMEONE_VEGETARIAN }; - ProbabilityTable table = new ProbabilityTable(names); - return table; - } + distribturion.estimate(); - public static ProbabilityDistribution taste() { - String names[] = { TASTE, CONTAINS_BEEF, CONTAINS_PORK, - CONTAINS_POTATOS, CONTAINS_TOMATOS/*, CONTAINS_GENERIC_NUTS*/ }; - ContinousDistribution distribution = new ContinousDistribution(names, 0); - return distribution; - } - - public static BayesNet createDishNet(Survey survey, RecipeBook recipeBook, int startIndex, int endIndex) { - BayesNet net = new BayesNet(VARIABLES); - -// net.setDistribution(new Variable(SOMEONE_VEGETARIAN, DOMAIN), vegi(survey)); - //net.setDistribution(new Variable(SOMEONE_ALLERGIC_NUTS, DOMAIN), allergicNuts(survey)); - - net.setDistribution(new Variable(CONTAINS_MEAT, DOMAIN), meat()); - //net.setDistribution(new Variable(CONTAINS_NUTS, DOMAIN), nuts()); - net.setDistribution(new Variable(SOMEONE_VEGETARIAN, DOMAIN), vegi(survey)); - net.setDistribution(new Variable(CONTAINS_VEGETABLE, DOMAIN), - vegetables()); - net.setDistribution(new Variable(CONTAINS_BEEF, DOMAIN), beef()); - net.setDistribution(new Variable(CONTAINS_PORK, DOMAIN), pork()); - net.setDistribution(new Variable(CONTAINS_POTATOS, DOMAIN), potatos()); - net.setDistribution(new Variable(CONTAINS_TOMATOS, DOMAIN), tomatos()); - //net.setDistribution(new Variable(CONTAINS_GENERIC_NUTS, DOMAIN), genericNuts()); - - net.setDistribution(new Variable(TASTE, RATING_DOMAIN), taste()); - - Ontology ontology = createOntology(); -// for (String category : ontology.getClasses()) { -// net.connect(category, SOMEONE_VEGETARIAN); -// /*net.connect(category, SOMEONE_ALLERGIC_NUTS);*/ -// } - net.connect(SOMEONE_VEGETARIAN, CONTAINS_MEAT); - - for (String thing : OBSERVED) { - net.connect(thing, ontology.getInheritance().get(thing)); - net.connect(TASTE, thing); - } - - DataSet dataSet = getSurveyDataSet(survey, recipeBook, startIndex, endIndex); - - for (String category : ontology.getClasses()) { - MaximumLikelihoodEstimation.estimate(dataSet, net, category); - for (String thing : ontology.getClasses2thing().get(category)) { - MaximumLikelihoodEstimation.estimate(dataSet, net, thing); - } - } - - MaximumLikelihoodEstimation.estimate(dataSet, net, TASTE); - ContinousDistribution distribturion = (ContinousDistribution) net - .getNodes().get(TASTE); - - distribturion.estimate(); - - return net; - } + return net; + } } \ No newline at end of file diff --git a/src/dkohl/bayes/inference/EnumerateAll.java b/src/dkohl/bayes/inference/EnumerateAll.java index 47bab52..fefdeae 100644 --- a/src/dkohl/bayes/inference/EnumerateAll.java +++ b/src/dkohl/bayes/inference/EnumerateAll.java @@ -98,7 +98,10 @@ public class EnumerateAll { LinkedList temp = new LinkedList(); temp.addAll(assignments); temp.add(new Assignment(variable, value)); - + + if(net.getNodes().get(variable.getName()).eval(temp) == null) { + System.out.println(variable.getName()); + } // then evaluate this variable double val = net.getNodes().get(variable.getName()).eval(temp) .getProbability(); diff --git a/src/dkohl/bayes/statistic/DataSet.java b/src/dkohl/bayes/statistic/DataSet.java index 44aaa3f..2028c89 100644 --- a/src/dkohl/bayes/statistic/DataSet.java +++ b/src/dkohl/bayes/statistic/DataSet.java @@ -13,87 +13,102 @@ import dkohl.bayes.probability.Probability; */ public class DataSet extends Vector { - private static final long serialVersionUID = 1L; + private static final long serialVersionUID = 1L; - public LinkedList> getAssignmentMatchesForQuery( - LinkedList given) { - LinkedList> assignments = new LinkedList>(); - for (DataPoint point : this) { - boolean insert = true; - for (Assignment assignment : given) { - if (!match(point, assignment)) { - insert = false; - } - } - if (insert) { - assignments.add(new LinkedList(point.values())); - } + public LinkedList> getAssignmentMatchesForQuery( + LinkedList given) { + LinkedList> assignments = new LinkedList>(); + for (DataPoint point : this) { + boolean insert = true; + for (Assignment assignment : given) { + if (!match(point, assignment)) { + insert = false; } - return assignments; + } + if (insert) { + assignments.add(new LinkedList(point.values())); + } } + return assignments; + } - /** - * Is the assignment equal to a data point / observation ? - * - * @param point - * the point / observation - * @param query - * the assignment - * @return - */ - private boolean match(String queryName, DataPoint point, - LinkedList query) { - boolean queryFound = false; - for (Assignment assignment : query) { - if (!match(point, assignment)) { - return false; - } - if (point.containsKey(queryName)) { - queryFound = true; - } - } - if (!queryFound) { - return false; - } - return true; - } - - private boolean match(DataPoint point, Assignment query) { - String name = query.getVariable().getName(); - String value = query.getValue(); - if (point.containsKey(name)) { - if (point.get(name).getValue().equals(value)) { - return true; - } - } + /** + * Is the assignment equal to a data point / observation ? + * + * @param point + * the point / observation + * @param query + * the assignment + * @return + */ + private boolean match(String queryName, DataPoint point, + LinkedList query) { + boolean queryFound = false; + for (Assignment assignment : query) { + if (!match(point, assignment)) { return false; + } + if (point.containsKey(queryName)) { + queryFound = true; + } } - - /** - * Estimating probability for: P(Query | given_1 .... given_N) = #(Query | - * given_1 .... given_N) / #(given_1 .... given_N) - * - * @param query - * @param given - * @return - */ - public Probability prob(Assignment query, LinkedList given) { - int matches = 0; - int num_query_given = 0; - for (DataPoint point : this) { - if (match(query.getVariable().getName(), point, given)) { - matches += 1; - if (match(point, query)) { - num_query_given += 1; // point.getWeight(); - } - } - } - - if (matches == 0) { - return new Probability(0); - } - - return new Probability(num_query_given / ((double) matches)); + if (!queryFound) { + return false; } + return true; + } + + private boolean match(DataPoint point, Assignment query) { + String name = query.getVariable().getName(); + String value = query.getValue(); + if (point.containsKey(name)) { + if (point.get(name).getValue().equals(value)) { + return true; + } + } + return false; + } + + /** + * Estimating probability for: P(Query | given_1 .... given_N) = #(Query | + * given_1 .... given_N) / #(given_1 .... given_N) + * + * @param query + * @param given + * @return + */ + public Probability prob(Assignment query, LinkedList given) { + int matches = 0; + int num_query_given = 0; + + if (given.size() == 0) { + for (DataPoint point : this) { + if (point.containsKey(query.getVariable().getName())) { + matches += 1; + if (point.get(query.getVariable().getName()).getValue() + .equals(query.getValue())) { + num_query_given += 1; + } + } + } + return new Probability(num_query_given / ((double) matches)); + } else { + + for (DataPoint point : this) { + if (match(query.getVariable().getName(), point, given)) { + matches += 1; + if (match(point, query)) { + num_query_given += 1; // point.getWeight(); + } + } + } + + if (matches == 0) { + return new Probability(0); + } + + return new Probability(num_query_given / ((double) matches)); + } + } }