From eec32b19c18a26ae7f29a5022007791738429c19 Mon Sep 17 00:00:00 2001 From: Woody Folsom Date: Mon, 16 Apr 2012 14:03:16 -0400 Subject: [PATCH] Added ability to generate 1, 2, 3-gram models on a company/date-range basis using to represent the initial appearance of a previously unknown word. --- AppContext.xml | 28 ++ .../woodyfolsom/cs6601/p3/ModelGenerator.java | 115 ++++++ .../cs6601/p3/dao/HeadlineDao.java | 1 + .../cs6601/p3/dao/HeadlineDaoImpl.java | 8 +- .../woodyfolsom/cs6601/p3/ngram/NGram.java | 95 +++++ .../cs6601/p3/ngram/NGramDistribution.java | 6 + .../cs6601/p3/ngram/NGramModel.java | 361 ++++++++++++++++++ .../cs6601/p3/svc/HeadlineService.java | 1 + .../p3/svc/MySQLHeadlineServiceImpl.java | 5 + .../p3/svc/YahooHeadlineServiceImpl.java | 6 + 10 files changed, 625 insertions(+), 1 deletion(-) create mode 100644 AppContext.xml create mode 100644 src/net/woodyfolsom/cs6601/p3/ModelGenerator.java create mode 100644 src/net/woodyfolsom/cs6601/p3/ngram/NGram.java create mode 100644 src/net/woodyfolsom/cs6601/p3/ngram/NGramDistribution.java create mode 100644 src/net/woodyfolsom/cs6601/p3/ngram/NGramModel.java diff --git a/AppContext.xml b/AppContext.xml new file mode 100644 index 0000000..30a1ee5 --- /dev/null +++ b/AppContext.xml @@ -0,0 +1,28 @@ + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/src/net/woodyfolsom/cs6601/p3/ModelGenerator.java b/src/net/woodyfolsom/cs6601/p3/ModelGenerator.java new file mode 100644 index 0000000..352fb4d --- /dev/null +++ b/src/net/woodyfolsom/cs6601/p3/ModelGenerator.java @@ -0,0 +1,115 @@ +package net.woodyfolsom.cs6601.p3; + +import java.io.BufferedReader; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStreamReader; +import java.text.DateFormat; +import java.text.ParseException; +import java.text.SimpleDateFormat; +import java.util.ArrayList; +import java.util.Date; +import java.util.List; + 
+import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.ApplicationContext; +import org.springframework.context.support.FileSystemXmlApplicationContext; +import org.springframework.stereotype.Component; + +import net.woodyfolsom.cs6601.p3.domain.Company; +import net.woodyfolsom.cs6601.p3.domain.Headline; +import net.woodyfolsom.cs6601.p3.ngram.NGramModel; +import net.woodyfolsom.cs6601.p3.svc.HeadlineService; + +@Component +public class ModelGenerator { + private static final File stockSymbolsCSV = new File("stock_symbols.csv"); + + private static final int INVALID_DATE = 1; + private static final int IO_EXCEPTION = 2; + private static final int STOCK_SYMBOL_CSV_NOT_FOUND = 3; + + @Autowired + HeadlineService mySQLHeadlineServiceImpl; + + private NGramModel ngramModel = new NGramModel(); + + public static void main(String... args) { + + ApplicationContext context = new FileSystemXmlApplicationContext( + new String[] { "AppContext.xml" }); + ModelGenerator modelGenerator = context.getBean(ModelGenerator.class); + DateFormat dateFmt = new SimpleDateFormat("yyyy-MM-dd"); + + Date startDate = null; + Date endDate = null; + Date valStart = null; + Date valEnd = null; + try { + startDate = dateFmt.parse("2012-01-01"); + endDate = dateFmt.parse("2012-03-31"); + valStart = dateFmt.parse("2012-04-01"); + valEnd = dateFmt.parse("2012-04-14"); + } catch (ParseException pe) { + System.exit(INVALID_DATE); + } + + try { + List fortune50 = modelGenerator + .getFortune50(stockSymbolsCSV); + for (Company company : fortune50) { + System.out.println("Getting headlines for Fortune 50 company #" + + company.getId() + " (" + company.getName() + ")..."); + List trainingSet = modelGenerator.mySQLHeadlineServiceImpl.getHeadlines(company.getStockSymbol(), startDate, endDate); + System.out.println("Pulled " + trainingSet.size() + " headlines for " + + company.getStockSymbol() + " from " + startDate + " to " + endDate); + List validationSet = 
modelGenerator.mySQLHeadlineServiceImpl.getHeadlines(company.getStockSymbol(), valStart, valEnd); + + if (trainingSet.size() == 0) { + System.out.println("Training dataset contains 0 headlines for " + company.getName() + ", skipping model generation."); + continue; + } + if (validationSet.size() == 0) { + System.out.println("Validation dataset contains 0 headlines for " + company.getName() + ", skipping model generation."); + continue; + } + + modelGenerator.ngramModel.reportModel(trainingSet, validationSet); + System.out.println("Finished " + company.getId() + " / 50"); + } + } catch (FileNotFoundException fnfe) { + System.out.println("Stock symbol CSV file does not exist: " + + stockSymbolsCSV); + System.exit(STOCK_SYMBOL_CSV_NOT_FOUND); + } catch (IOException ioe) { + System.out.println("Stock symbol CSV file does not exist: " + + stockSymbolsCSV); + System.exit(IO_EXCEPTION); + } + } + + private List getFortune50(File csvFile) + throws FileNotFoundException, IOException { + List fortune50 = new ArrayList(); + FileInputStream fis = new FileInputStream(csvFile); + InputStreamReader reader = new InputStreamReader(fis); + BufferedReader buf = new BufferedReader(reader); + String csvline = null; + while ((csvline = buf.readLine()) != null) { + if (csvline.length() == 0) { + continue; + } + String[] fields = csvline.split(","); + if (fields.length != 3) { + throw new RuntimeException( + "Badly formatted csv file name (3 values expected): " + + csvline); + } + int id = Integer.valueOf(fields[0]); + fortune50.add(new Company(id, fields[1], fields[2])); + } + return fortune50; + } +} diff --git a/src/net/woodyfolsom/cs6601/p3/dao/HeadlineDao.java b/src/net/woodyfolsom/cs6601/p3/dao/HeadlineDao.java index 798ac52..6df3bb4 100644 --- a/src/net/woodyfolsom/cs6601/p3/dao/HeadlineDao.java +++ b/src/net/woodyfolsom/cs6601/p3/dao/HeadlineDao.java @@ -13,4 +13,5 @@ public interface HeadlineDao { Headline select(int id); List select(String stock, Date date); + List select(String 
/**
 * A fixed-capacity sliding window of word tokens representing an n-gram.
 * Adding a token when the window is full evicts the oldest token, so the
 * list always holds the most recent {@code maxLength} tokens. Equality is
 * inherited from LinkedList (element-wise); {@code maxLength} is not part
 * of equality, matching its use as a map key keyed on token content.
 *
 * Fix: restored the stripped generic parameters — with a raw
 * {@code LinkedList}, {@code add(String)} would be an overload of
 * {@code add(Object)} and the {@code @Override} would not compile.
 */
public class NGram extends LinkedList<String> implements Comparable<NGram> {
	private static final long serialVersionUID = 1L;

	// Maximum number of tokens retained by the sliding window.
	private int maxLength;

	public NGram(int maxLength) {
		super();
		this.maxLength = maxLength;
	}

	/**
	 * Appends a word, evicting the oldest token first when the window
	 * is already at capacity.
	 */
	@Override
	public boolean add(String word) {
		if (super.size() == maxLength) {
			super.remove(0);
		}
		return super.add(word);
	}

	/** Orders n-grams lexicographically by their space-joined string form. */
	public int compareTo(NGram other) {
		return this.toString().compareTo(other.toString());
	}

	/** Returns a full copy with the same capacity and tokens. */
	public NGram copy() {
		NGram nGram = new NGram(this.maxLength);
		for (String token : this) {
			nGram.add(token);
		}
		return nGram;
	}

	/**
	 * Returns a copy of the LAST {@code length} tokens (the suffix), with
	 * capacity {@code length}.
	 *
	 * @throws IllegalArgumentException if {@code length} exceeds the size
	 */
	public NGram copy(int length) {
		if (length > super.size()) {
			throw new IllegalArgumentException();
		}

		NGram nGramCopy = new NGram(length);

		for (int i = super.size() - length; i < size(); i++) {
			nGramCopy.add(super.get(i));
		}

		return nGramCopy;
	}

	/**
	 * True iff {@code other} has exactly one fewer token than this n-gram
	 * and matches this n-gram's leading tokens — i.e. {@code other} is the
	 * (n-1)-gram history that this n-gram extends.
	 */
	public boolean startsWith(NGram other) {
		if (other.size() != this.size() - 1) {
			return false;
		}
		for (int i = 0; i < other.size(); i++) {
			if (other.get(i).equals(this.get(i)) == false) {
				return false;
			}
		}
		return true;
	}

	/**
	 * Returns the tokens from {@code index} to the end as a new NGram
	 * (same capacity as this one).
	 *
	 * @throws IllegalArgumentException if {@code index} exceeds the size
	 */
	public NGram subNGram(int index) {
		if (index > super.size()) {
			throw new IllegalArgumentException();
		}
		NGram subNGram = new NGram(maxLength);
		for (int i = index; i < super.size(); i++) {
			subNGram.add(super.get(i));
		}
		return subNGram;
	}

	/**
	 * Returns the half-open token range [start, end) as a new NGram with
	 * capacity {@code end - start}.
	 *
	 * @throws IllegalArgumentException if {@code start > end}
	 */
	public NGram subNGram(int start, int end) {
		if (start > end) {
			throw new IllegalArgumentException();
		}
		NGram subNGram = new NGram(end - start);
		for (int i = start; i < end; i++) {
			subNGram.add(super.get(i));
		}
		return subNGram;
	}

	/** Space-joined token string, e.g. "the quick fox"; "" when empty. */
	@Override
	public String toString() {
		if (super.size() == 0) {
			return "";
		}
		StringBuilder sb = new StringBuilder();
		for (int i = 0; i < super.size() - 1; i++) {
			sb.append(super.get(i));
			sb.append(" ");
		}
		sb.append(super.get(super.size() - 1));
		return sb.toString();
	}
}
b/src/net/woodyfolsom/cs6601/p3/ngram/NGramModel.java new file mode 100644 index 0000000..43e6063 --- /dev/null +++ b/src/net/woodyfolsom/cs6601/p3/ngram/NGramModel.java @@ -0,0 +1,361 @@ +package net.woodyfolsom.cs6601.p3.ngram; + +import java.io.FileNotFoundException; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import net.woodyfolsom.cs6601.p3.domain.Headline; + +public class NGramModel { + static final int MAX_N_GRAM_LENGTH = 3; + static final int RANDOM_LENGTH = 100; + + private static final String START = ""; + private static final String END = ""; + private static final String UNK = ""; + + private Map nGrams; + private int[] totalNGramCounts = new int[MAX_N_GRAM_LENGTH + 1]; + + private Pattern wordPattern = Pattern.compile("\\w+"); + private Set nonWords = new HashSet(); + private Set words = new HashSet(); + private Set wordsSeenOnce = new HashSet(); + + double getProb(NGram nGram, NGram nMinusOneGram) { + NGramDistribution nGramDist = nGrams.get(nGram.size()); + double prob; + if (nGramDist.containsKey(nGram)) { + // impossible for the bigram not to exist if the trigram does + NGramDistribution nMinusOneGramDist = nGrams.get(nGram.size() - 1); + int nMinusOneGramCount = nMinusOneGramDist.get(nMinusOneGram); + int nGramCount = nGramDist.get(nGram); + prob = (double) (nGramCount) / nMinusOneGramCount; + } else { + // Laplace smoothing + // prob = 1.0 / ((totalNGramCounts[nGram.size()] + 2.0)); + NGram backoff = nGram.subNGram(1); + NGram backoffMinusOne = backoff.subNGram(0, backoff.size() - 1); + return getProb(backoff, backoffMinusOne); + } + + return prob; + } + + static String getRandomToken(Map nGrams, int n, + NGram nMinusOneGram) { + List> matchingNgrams = 
new ArrayList>(); + NGramDistribution ngDist = nGrams.get(n); + + for (Entry entry : ngDist.entrySet()) { + if (entry.getKey().startsWith(nMinusOneGram)) { + matchingNgrams.add(entry); + } + } + + int totalCount = 0; + for (int i = 0; i < matchingNgrams.size(); i++) { + totalCount += matchingNgrams.get(i).getValue(); + } + double random = Math.random(); + int randomNGramIndex = (int) (random * totalCount); + + totalCount = 0; + for (int i = 0; i < matchingNgrams.size(); i++) { + int currentCount = matchingNgrams.get(i).getValue(); + if (randomNGramIndex < totalCount + currentCount) { + NGram ngram = matchingNgrams.get(i).getKey(); + return ngram.get(ngram.size() - 1); + } + totalCount += currentCount; + } + return getRandomToken(nGrams, n - 1, nMinusOneGram.subNGram(1)); + } + + public NGramModel() { + // initialize Maps of 0-grams, unigrams, bigrams, trigrams... + nGrams = new HashMap(); + for (int i = 0; i <= MAX_N_GRAM_LENGTH; i++) { + nGrams.put(i, new NGramDistribution()); + } + } + + private void addNGram(int nGramLength, NGram nGram) { + if (nGram.size() < nGramLength) { + System.out.println("Cannot create " + nGramLength + "-gram from: " + + nGram); + } + + Map nGramCounts = nGrams.get(nGramLength); + NGram nGramCopy = nGram.copy(nGramLength); + + if (nGramCounts.containsKey(nGramCopy)) { + int nGramCount = nGramCounts.get(nGramCopy); + nGramCounts.put(nGramCopy, nGramCount + 1); + } else { + nGramCounts.put(nGramCopy, 1); + } + } + + /** + * Given an arbitrary String, replace punctutation with spaces, remove + * non-alphanumeric characters, prepend with token, append + * token. 
+ * + * @param rawHeadline + * @return + */ + private String sanitize(String rawHeadline) { + String nonpunctuatedHeadline = rawHeadline.replaceAll( + "[\'\";:,\\]\\[]", " "); + String alphaNumericHeadline = nonpunctuatedHeadline.replaceAll( + "[^A-Za-z0-9 ]", ""); + return START + " " + alphaNumericHeadline + " " + END; + } + + private void calcPerplexity(List validationSet, int nGramLimit, + boolean useUnk) throws FileNotFoundException, IOException { + List fileByLines = new ArrayList(); + StringBuilder currentLine = new StringBuilder(); + + for (Headline headline : validationSet) { + String sanitizedLine = sanitize(headline.getText()); + // split on whitespace + String[] tokens = sanitizedLine.toLowerCase().split("\\s+"); + for (String token : tokens) { + if (!isWord(token)) { + continue; + } + + String word; + if (!words.contains(token) && useUnk) { + word = UNK; + // words.add(token); + } else { + word = token; + } + + if (END.equals(word)) { + currentLine.append(word); + fileByLines.add(currentLine.toString()); + currentLine = new StringBuilder(); + } else { + currentLine.append(word); + currentLine.append(" "); + } + } + } + + int wordNum = 0; + int sentNum = 0; + + double logProbT = 0.0; + for (String str : fileByLines) { + double logProbS = 0.0; + NGram currentNgram = new NGram(nGramLimit); + String[] tokens = str.split("\\s+"); + for (String token : tokens) { + String word; + if (!words.contains(token) && useUnk) { + word = UNK; + } else { + word = token; + } + currentNgram.add(word); + NGram nMinusOneGram = currentNgram.subNGram(0, + currentNgram.size() - 1); + + double prob = getProb(currentNgram, nMinusOneGram); + + logProbS += Math.log(prob); + } + logProbT += logProbS; + wordNum += tokens.length; + sentNum++; + } + + System.out.println("Evaluated " + sentNum + " sentences and " + wordNum + + " words using " + nGramLimit + "-gram model."); + int N = wordNum + sentNum; + System.out.println("Total n-grams: " + N); + double perplexity = Math.pow(Math.E, 
(-1.0 / N) * logProbT); + System.out.println("Perplexity: over " + N + + " recognized n-grams in verification corpus: " + perplexity); + } + + private void generateModel(List traininSet, boolean genRandom, + boolean useUnk) throws FileNotFoundException, IOException { + StringBuilder currentLine = new StringBuilder(); + List fileByLines = new ArrayList(); + + for (Headline headline : traininSet) { + String headlineText = headline.getText(); + if (headlineText.length() == 0) { + continue; + } + String sanitizedLine = sanitize(headline.getText()); + // split on whitespace + String[] tokens = sanitizedLine.toLowerCase().split("\\s+"); + for (String token : tokens) { + if (!isWord(token)) { + continue; + } + + String word; + if (!wordsSeenOnce.contains(token) && useUnk) { + word = UNK; + wordsSeenOnce.add(token); + } else { + words.add(token); + word = token; + } + + if (END.equals(word)) { + currentLine.append(word); + fileByLines.add(currentLine.toString()); + currentLine = new StringBuilder(); + } else { + currentLine.append(word); + currentLine.append(" "); + } + } + } + + for (String str : fileByLines) { + System.out.println(str); + NGram currentNgram = new NGram(MAX_N_GRAM_LENGTH); + for (String token : str.split("\\s+")) { + currentNgram.add(token); + for (int i = 0; i <= currentNgram.size(); i++) { + addNGram(currentNgram.size() - i, + currentNgram.subNGram(i, currentNgram.size())); + totalNGramCounts[currentNgram.size() - i]++; + } + } + } + + System.out.println("Most common words: "); + + List> unigrams = new ArrayList>( + nGrams.get(1).entrySet()); + Collections.sort(unigrams, new NGramComparator()); + + List> bigrams = new ArrayList>( + nGrams.get(2).entrySet()); + Collections.sort(bigrams, new NGramComparator()); + + List> trigrams = new ArrayList>( + nGrams.get(3).entrySet()); + Collections.sort(trigrams, new NGramComparator()); + + for (int i = 1; i <= 10; i++) { + System.out + .println(i + + ". 
" + + unigrams.get(i - 1).getKey() + + " : " + + (((double) (unigrams.get(i - 1).getValue()) / totalNGramCounts[1]))); + } + + for (int i = 1; i <= 10; i++) { + System.out + .println(i + + ". " + + bigrams.get(i - 1).getKey() + + " : " + + (((double) (bigrams.get(i - 1).getValue()) / totalNGramCounts[2]))); + } + + for (int i = 1; i <= 10; i++) { + System.out + .println(i + + ". " + + trigrams.get(i - 1).getKey() + + " : " + + (((double) (trigrams.get(i - 1).getValue()) / totalNGramCounts[3]))); + } + if (genRandom) { + for (int nGramLength = 1; nGramLength <= MAX_N_GRAM_LENGTH; nGramLength++) { + System.out.println("Random sentence of length " + RANDOM_LENGTH + + " using " + nGramLength + "-gram language model:"); + StringBuilder randomText = new StringBuilder(); + + NGram randomNminusOneGram = new NGram(nGramLength); + for (int i = 0; i < RANDOM_LENGTH; i++) { + String randomToken = getRandomToken(nGrams, + randomNminusOneGram.size() + 1, randomNminusOneGram); + NGram randomNGram = randomNminusOneGram.copy(); + randomNGram.add(randomToken); + randomText.append(randomNGram.get(randomNGram.size() - 1)); + randomText.append(" "); + + if (randomNGram.size() < nGramLength) { + randomNminusOneGram = randomNGram; + } else { + randomNminusOneGram = randomNGram.subNGram(1); + } + } + System.out.println(randomText); + System.out.println(); + } + } + } + + private boolean isWord(String aString) { + if (aString == null || aString.length() == 0) { + return false; + } + + if (nonWords.contains(aString)) { + return false; + } + + if (START.equals(aString) || END.equals(aString) || UNK.equals(aString)) { + return true; + } + + Matcher wordMatcher = wordPattern.matcher(aString); + if (wordMatcher.find()) { + return true; + } + nonWords.add(aString); + return false; + } + + public void reportModel(List trainingSet, + List validationSet) { + try { + NGramModel ngm = new NGramModel(); + boolean doCalcPerplexity = true; + ngm.generateModel(trainingSet, !doCalcPerplexity, 
doCalcPerplexity); + if (doCalcPerplexity) { + for (int i = 1; i <= MAX_N_GRAM_LENGTH; i++) { + ngm.calcPerplexity(validationSet, i, true); + } + } + } catch (FileNotFoundException fnfe) { + fnfe.printStackTrace(System.err); + } catch (IOException ioe) { + ioe.printStackTrace(System.err); + } + } + + private class NGramComparator implements Comparator> { + + @Override + public int compare(Entry o1, Entry o2) { + return o2.getValue().compareTo(o1.getValue()); + } + + } +} \ No newline at end of file diff --git a/src/net/woodyfolsom/cs6601/p3/svc/HeadlineService.java b/src/net/woodyfolsom/cs6601/p3/svc/HeadlineService.java index d9a0783..db84742 100644 --- a/src/net/woodyfolsom/cs6601/p3/svc/HeadlineService.java +++ b/src/net/woodyfolsom/cs6601/p3/svc/HeadlineService.java @@ -9,4 +9,5 @@ public interface HeadlineService { int insertHeadline(Headline headline); int[] insertHeadlines(List headline); List getHeadlines(String stock, Date date); + List getHeadlines(String stock, Date startDate, Date endDate); } \ No newline at end of file diff --git a/src/net/woodyfolsom/cs6601/p3/svc/MySQLHeadlineServiceImpl.java b/src/net/woodyfolsom/cs6601/p3/svc/MySQLHeadlineServiceImpl.java index 15efddd..adb8f01 100644 --- a/src/net/woodyfolsom/cs6601/p3/svc/MySQLHeadlineServiceImpl.java +++ b/src/net/woodyfolsom/cs6601/p3/svc/MySQLHeadlineServiceImpl.java @@ -32,4 +32,9 @@ public class MySQLHeadlineServiceImpl implements HeadlineService { public List getHeadlines(String stock, Date date) { return headlineDao.select(stock, date); } + + @Override + public List getHeadlines(String stock, Date startDate, Date endDate) { + return headlineDao.select(stock, startDate, endDate); + } } \ No newline at end of file diff --git a/src/net/woodyfolsom/cs6601/p3/svc/YahooHeadlineServiceImpl.java b/src/net/woodyfolsom/cs6601/p3/svc/YahooHeadlineServiceImpl.java index 9844eca..64cddf5 100644 --- a/src/net/woodyfolsom/cs6601/p3/svc/YahooHeadlineServiceImpl.java +++ 
b/src/net/woodyfolsom/cs6601/p3/svc/YahooHeadlineServiceImpl.java @@ -86,4 +86,10 @@ public class YahooHeadlineServiceImpl implements HeadlineService { String formattedDate = DATE_FORMATTER.format(date); return QUERY_URL.replaceAll(STOCK_SYMBOL_FIELD, stock).replaceAll(STORY_DATE_FIELD, formattedDate); } + + @Override + public List getHeadlines(String stock, Date startDate, + Date endDate) { + throw new UnsupportedOperationException("This implementation does not support getting headlines for a date range."); + } } \ No newline at end of file