Added ability to generate 1, 2, 3-gram models on a company/date-range basis using <UNK> to represent the initial appearance of a previously unknown word.

This commit is contained in:
Woody Folsom
2012-04-16 14:03:16 -04:00
parent 027adff2dd
commit eec32b19c1
10 changed files with 625 additions and 1 deletions

View File

@@ -0,0 +1,115 @@
package net.woodyfolsom.cs6601.p3;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.text.DateFormat;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.context.support.FileSystemXmlApplicationContext;
import org.springframework.stereotype.Component;
import net.woodyfolsom.cs6601.p3.domain.Company;
import net.woodyfolsom.cs6601.p3.domain.Headline;
import net.woodyfolsom.cs6601.p3.ngram.NGramModel;
import net.woodyfolsom.cs6601.p3.svc.HeadlineService;
/**
 * Command-line driver: loads the Fortune 50 company list from a CSV file,
 * pulls training (Q1 2012) and validation (early April 2012) headlines for
 * each company from the database, and generates/reports an n-gram model per
 * company.
 */
@Component
public class ModelGenerator {
    /** CSV with one "id,name,stockSymbol" line per company. */
    private static final File stockSymbolsCSV = new File("stock_symbols.csv");
    // Process exit codes.
    private static final int INVALID_DATE = 1;
    private static final int IO_EXCEPTION = 2;
    private static final int STOCK_SYMBOL_CSV_NOT_FOUND = 3;
    /** Injected by Spring; MySQL-backed headline lookup. */
    @Autowired
    HeadlineService mySQLHeadlineServiceImpl;
    private NGramModel ngramModel = new NGramModel();

    public static void main(String... args) {
        ApplicationContext context = new FileSystemXmlApplicationContext(
                new String[] { "AppContext.xml" });
        ModelGenerator modelGenerator = context.getBean(ModelGenerator.class);
        DateFormat dateFmt = new SimpleDateFormat("yyyy-MM-dd");
        Date startDate = null;
        Date endDate = null;
        Date valStart = null;
        Date valEnd = null;
        try {
            // Training window: Q1 2012; validation window: 2012-04-01..14.
            startDate = dateFmt.parse("2012-01-01");
            endDate = dateFmt.parse("2012-03-31");
            valStart = dateFmt.parse("2012-04-01");
            valEnd = dateFmt.parse("2012-04-14");
        } catch (ParseException pe) {
            // Fixed: previously exited with no diagnostic at all.
            System.out.println("Invalid hard-coded date: " + pe.getMessage());
            System.exit(INVALID_DATE);
        }
        try {
            List<Company> fortune50 = modelGenerator
                    .getFortune50(stockSymbolsCSV);
            for (Company company : fortune50) {
                System.out.println("Getting headlines for Fortune 50 company #"
                        + company.getId() + " (" + company.getName() + ")...");
                List<Headline> trainingSet = modelGenerator.mySQLHeadlineServiceImpl
                        .getHeadlines(company.getStockSymbol(), startDate, endDate);
                System.out.println("Pulled " + trainingSet.size() + " headlines for "
                        + company.getStockSymbol() + " from " + startDate + " to " + endDate);
                List<Headline> validationSet = modelGenerator.mySQLHeadlineServiceImpl
                        .getHeadlines(company.getStockSymbol(), valStart, valEnd);
                // Skip companies that lack data in either window.
                if (trainingSet.size() == 0) {
                    System.out.println("Training dataset contains 0 headlines for "
                            + company.getName() + ", skipping model generation.");
                    continue;
                }
                if (validationSet.size() == 0) {
                    System.out.println("Validation dataset contains 0 headlines for "
                            + company.getName() + ", skipping model generation.");
                    continue;
                }
                modelGenerator.ngramModel.reportModel(trainingSet, validationSet);
                System.out.println("Finished " + company.getId() + " / 50");
            }
        } catch (FileNotFoundException fnfe) {
            System.out.println("Stock symbol CSV file does not exist: "
                    + stockSymbolsCSV);
            System.exit(STOCK_SYMBOL_CSV_NOT_FOUND);
        } catch (IOException ioe) {
            // Fixed: this branch previously repeated the "does not exist"
            // message even though it reports a read failure, not a missing file.
            System.out.println("Error reading stock symbol CSV file: "
                    + stockSymbolsCSV);
            System.exit(IO_EXCEPTION);
        }
    }

    /**
     * Parses the Fortune 50 CSV: one company per non-empty line, formatted as
     * "id,name,stockSymbol".
     *
     * @param csvFile the CSV file to read (platform default charset —
     *                presumably ASCII symbols/names; confirm)
     * @return one Company per valid CSV line
     * @throws FileNotFoundException if the file is missing
     * @throws IOException on read failure
     * @throws RuntimeException if a line does not contain exactly 3 fields
     */
    private List<Company> getFortune50(File csvFile)
            throws FileNotFoundException, IOException {
        List<Company> fortune50 = new ArrayList<Company>();
        FileInputStream fis = new FileInputStream(csvFile);
        InputStreamReader reader = new InputStreamReader(fis);
        BufferedReader buf = new BufferedReader(reader);
        try {
            String csvline = null;
            while ((csvline = buf.readLine()) != null) {
                if (csvline.length() == 0) {
                    continue;
                }
                String[] fields = csvline.split(",");
                if (fields.length != 3) {
                    // Fixed: message said "csv file name" but it reports a bad line.
                    throw new RuntimeException(
                            "Badly formatted csv line (3 values expected): "
                                    + csvline);
                }
                // parseInt avoids the needless Integer boxing of valueOf.
                int id = Integer.parseInt(fields[0]);
                fortune50.add(new Company(id, fields[1], fields[2]));
            }
        } finally {
            // Fixed: the reader was previously never closed (resource leak).
            buf.close();
        }
        return fortune50;
    }
}

View File

@@ -13,4 +13,5 @@ public interface HeadlineDao {
/** Fetches the headline with the given primary key. */
Headline select(int id);
/** Fetches all headlines for a stock published on the given date. */
List<Headline> select(String stock, Date date);
/** Fetches all headlines for a stock with date in [startDate, endDate] (inclusive). */
List<Headline> select(String stock, Date startDate, Date endDate);
}

View File

@@ -24,7 +24,8 @@ public class HeadlineDaoImpl implements HeadlineDao {
private static final String SELECT_BY_ID_QRY = "SELECT * from headlines WHERE id = ?";
private static final String SELECT_BY_STOCK_QRY = "SELECT * from headlines WHERE stock = ? AND date = ?";
private static final String SELECT_BY_DATE_RANGE_QRY = "SELECT * from headlines WHERE stock = ? AND date >= ? AND date <= ?";
private JdbcTemplate jdbcTemplate;
public int deleteById(int headlineId) {
@@ -66,6 +67,11 @@ public class HeadlineDaoImpl implements HeadlineDao {
new RequestMapper(), stock, date);
}
/**
 * Selects all headlines for {@code stock} whose date falls within
 * [startDate, endDate] — inclusive on both ends, per
 * SELECT_BY_DATE_RANGE_QRY's {@code >=}/{@code <=} predicates.
 */
public List<Headline> select(String stock, Date startDate, Date endDate) {
    RequestMapper rowMapper = new RequestMapper();
    List<Headline> matches = jdbcTemplate.query(SELECT_BY_DATE_RANGE_QRY,
            rowMapper, stock, startDate, endDate);
    return matches;
}
@Autowired
public void createTemplate(DataSource dataSource) {
this.jdbcTemplate = new JdbcTemplate(dataSource);

View File

@@ -0,0 +1,95 @@
package net.woodyfolsom.cs6601.p3.ngram;
import java.util.LinkedList;
/**
 * A bounded sliding window of tokens: holds at most {@code maxLength} words,
 * evicting the oldest on overflow. Used as both the n-gram itself and the
 * rolling history while scanning a sentence.
 * NOTE(review): compareTo orders by the space-joined string, which is not
 * strictly consistent with LinkedList.equals — confirm no caller relies on
 * it inside sorted sets/maps.
 */
public class NGram extends LinkedList<String> implements Comparable<NGram> {
    private static final long serialVersionUID = 1L;

    /** Maximum number of tokens retained. */
    private int maxLength;

    /** Creates an empty n-gram holding at most {@code maxLength} tokens. */
    public NGram(int maxLength) {
        super();
        this.maxLength = maxLength;
    }

    /**
     * Appends a token; when the window is already full, the oldest token is
     * dropped first so the list keeps the most recent {@code maxLength} words.
     */
    @Override
    public boolean add(String word) {
        if (size() == maxLength) {
            remove(0);
        }
        return super.add(word);
    }

    /** Orders n-grams lexicographically by their space-joined string form. */
    public int compareTo(NGram other) {
        return toString().compareTo(other.toString());
    }

    /** Copy with the same capacity and the same tokens (strings shared). */
    public NGram copy() {
        NGram dup = new NGram(maxLength);
        for (String tok : this) {
            dup.add(tok);
        }
        return dup;
    }

    /**
     * Copies the trailing {@code length} tokens into a new n-gram whose
     * capacity is exactly {@code length}.
     *
     * @throws IllegalArgumentException if {@code length} exceeds the current size
     */
    public NGram copy(int length) {
        if (length > size()) {
            throw new IllegalArgumentException();
        }
        NGram tail = new NGram(length);
        for (int pos = size() - length; pos < size(); pos++) {
            tail.add(get(pos));
        }
        return tail;
    }

    /**
     * True when {@code other} has exactly one fewer token than this n-gram
     * and matches it token-for-token from the front — i.e. {@code other} is
     * the (n-1)-gram context this n-gram extends.
     */
    public boolean startsWith(NGram other) {
        if (other.size() != size() - 1) {
            return false;
        }
        int idx = 0;
        for (String prefixToken : other) {
            if (!prefixToken.equals(get(idx))) {
                return false;
            }
            idx++;
        }
        return true;
    }

    /**
     * Suffix starting at {@code index}; the result inherits this n-gram's
     * capacity. {@code index == size()} yields an empty n-gram.
     */
    public NGram subNGram(int index) {
        if (index > size()) {
            throw new IllegalArgumentException();
        }
        NGram suffix = new NGram(maxLength);
        for (int pos = index; pos < size(); pos++) {
            suffix.add(get(pos));
        }
        return suffix;
    }

    /** Tokens in [start, end); the result's capacity is {@code end - start}. */
    public NGram subNGram(int start, int end) {
        if (start > end) {
            throw new IllegalArgumentException();
        }
        NGram slice = new NGram(end - start);
        for (int pos = start; pos < end; pos++) {
            slice.add(get(pos));
        }
        return slice;
    }

    /** Space-separated token string; empty string for an empty n-gram. */
    @Override
    public String toString() {
        StringBuilder joined = new StringBuilder();
        for (String tok : this) {
            if (joined.length() > 0) {
                joined.append(" ");
            }
            joined.append(tok);
        }
        return joined.toString();
    }
}

View File

@@ -0,0 +1,6 @@
package net.woodyfolsom.cs6601.p3.ngram;
import java.util.HashMap;
/**
 * Frequency table mapping each observed NGram to its occurrence count.
 * Subclassed purely as a type alias to keep NGramModel's signatures readable.
 */
public class NGramDistribution extends HashMap<NGram, Integer> {
    private static final long serialVersionUID = 1L;
}

View File

@@ -0,0 +1,361 @@
package net.woodyfolsom.cs6601.p3.ngram;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import net.woodyfolsom.cs6601.p3.domain.Headline;
/**
 * Builds 1-, 2- and 3-gram language models from headline text. Supports
 * reporting the most frequent n-grams, generating random sentences from the
 * model, and measuring perplexity against a validation set. When useUnk is
 * enabled, the first occurrence of every training word is replaced by the
 * {@code <unk>} token so previously unseen validation words still receive
 * probability mass.
 */
public class NGramModel {
    /** Highest n-gram order built (trigrams). */
    static final int MAX_N_GRAM_LENGTH = 3;
    /** Number of tokens produced when generating a random sentence. */
    static final int RANDOM_LENGTH = 100;
    // Sentinel tokens injected by sanitize() and the <unk> substitution.
    private static final String START = "<start>";
    private static final String END = "<end>";
    private static final String UNK = "<unk>";
    /** n -> frequency distribution of n-grams of that length (0..MAX). */
    private Map<Integer, NGramDistribution> nGrams;
    /** totalNGramCounts[n] = total n-gram occurrences seen in training. */
    private int[] totalNGramCounts = new int[MAX_N_GRAM_LENGTH + 1];
    private Pattern wordPattern = Pattern.compile("\\w+");
    /** Cache of tokens already rejected by isWord(), to skip re-matching. */
    private Set<String> nonWords = new HashSet<String>();
    /** Known vocabulary: with useUnk, words seen at least twice in training. */
    private Set<String> words = new HashSet<String>();
    /** Words seen at least once in training; the first sighting becomes <unk>. */
    private Set<String> wordsSeenOnce = new HashSet<String>();

    /**
     * Maximum-likelihood conditional probability of the last token of nGram
     * given its preceding tokens: count(nGram) / count(nMinusOneGram).
     * Unseen n-grams back off to the shorter n-gram with the oldest history
     * token dropped.
     * NOTE(review): if even a unigram is unseen, the recursion reaches an
     * empty n-gram and nGrams.get(-1) would NPE — presumably the <unk>
     * mapping guarantees every validation unigram is known; confirm.
     */
    double getProb(NGram nGram, NGram nMinusOneGram) {
        NGramDistribution nGramDist = nGrams.get(nGram.size());
        double prob;
        if (nGramDist.containsKey(nGram)) {
            // impossible for the bigram not to exist if the trigram does
            NGramDistribution nMinusOneGramDist = nGrams.get(nGram.size() - 1);
            int nMinusOneGramCount = nMinusOneGramDist.get(nMinusOneGram);
            int nGramCount = nGramDist.get(nGram);
            prob = (double) (nGramCount) / nMinusOneGramCount;
        } else {
            // Laplace smoothing
            // prob = 1.0 / ((totalNGramCounts[nGram.size()] + 2.0));
            // Back off: drop the oldest history token and retry.
            NGram backoff = nGram.subNGram(1);
            NGram backoffMinusOne = backoff.subNGram(0, backoff.size() - 1);
            return getProb(backoff, backoffMinusOne);
        }
        return prob;
    }

    /**
     * Draws a random next token, weighted by the observed counts of n-grams
     * that extend nMinusOneGram. If no stored n-gram matches the context,
     * falls back to the (n-1)-gram model with a shortened history.
     */
    static String getRandomToken(Map<Integer, NGramDistribution> nGrams, int n,
            NGram nMinusOneGram) {
        // Collect every n-gram whose first n-1 tokens equal nMinusOneGram.
        List<Entry<NGram, Integer>> matchingNgrams = new ArrayList<Entry<NGram, Integer>>();
        NGramDistribution ngDist = nGrams.get(n);
        for (Entry<NGram, Integer> entry : ngDist.entrySet()) {
            if (entry.getKey().startsWith(nMinusOneGram)) {
                matchingNgrams.add(entry);
            }
        }
        int totalCount = 0;
        for (int i = 0; i < matchingNgrams.size(); i++) {
            totalCount += matchingNgrams.get(i).getValue();
        }
        // Count-weighted random pick: walk the cumulative counts until the
        // random index falls inside the current entry's range.
        double random = Math.random();
        int randomNGramIndex = (int) (random * totalCount);
        totalCount = 0;
        for (int i = 0; i < matchingNgrams.size(); i++) {
            int currentCount = matchingNgrams.get(i).getValue();
            if (randomNGramIndex < totalCount + currentCount) {
                NGram ngram = matchingNgrams.get(i).getKey();
                // The last token of the chosen n-gram is the generated word.
                return ngram.get(ngram.size() - 1);
            }
            totalCount += currentCount;
        }
        // No candidates: back off to a shorter history.
        return getRandomToken(nGrams, n - 1, nMinusOneGram.subNGram(1));
    }

    public NGramModel() {
        // initialize Maps of 0-grams, unigrams, bigrams, trigrams...
        nGrams = new HashMap<Integer, NGramDistribution>();
        for (int i = 0; i <= MAX_N_GRAM_LENGTH; i++) {
            nGrams.put(i, new NGramDistribution());
        }
    }

    /**
     * Increments the count of the trailing nGramLength tokens of nGram.
     * NOTE(review): when nGram is shorter than nGramLength this logs a
     * warning but still falls through to copy(), which then throws
     * IllegalArgumentException — probably intended to return early; callers
     * in generateModel never hit this case.
     */
    private void addNGram(int nGramLength, NGram nGram) {
        if (nGram.size() < nGramLength) {
            System.out.println("Cannot create " + nGramLength + "-gram from: "
                    + nGram);
        }
        Map<NGram, Integer> nGramCounts = nGrams.get(nGramLength);
        NGram nGramCopy = nGram.copy(nGramLength);
        if (nGramCounts.containsKey(nGramCopy)) {
            int nGramCount = nGramCounts.get(nGramCopy);
            nGramCounts.put(nGramCopy, nGramCount + 1);
        } else {
            nGramCounts.put(nGramCopy, 1);
        }
    }

    /**
     * Given an arbitrary String, replaces punctuation with spaces, removes
     * remaining non-alphanumeric characters, prepends the {@code <start>}
     * token and appends the {@code <end>} token.
     *
     * @param rawHeadline raw headline text
     * @return sanitized, sentinel-wrapped headline
     */
    private String sanitize(String rawHeadline) {
        String nonpunctuatedHeadline = rawHeadline.replaceAll(
                "[\'\";:,\\]\\[]", " ");
        String alphaNumericHeadline = nonpunctuatedHeadline.replaceAll(
                "[^A-Za-z0-9 ]", "");
        return START + " " + alphaNumericHeadline + " " + END;
    }

    /**
     * Computes and prints the perplexity of the trained nGramLimit-gram model
     * over the validation headlines. Tokens outside the known vocabulary are
     * mapped to {@code <unk>} when useUnk is true.
     * NOTE(review): the declared FileNotFoundException/IOException are never
     * thrown here — the signature looks left over from a file-based version.
     */
    private void calcPerplexity(List<Headline> validationSet, int nGramLimit,
            boolean useUnk) throws FileNotFoundException, IOException {
        // Rebuild sentences: one string per <start>...<end> span.
        List<String> fileByLines = new ArrayList<String>();
        StringBuilder currentLine = new StringBuilder();
        for (Headline headline : validationSet) {
            String sanitizedLine = sanitize(headline.getText());
            // split on whitespace
            String[] tokens = sanitizedLine.toLowerCase().split("\\s+");
            for (String token : tokens) {
                if (!isWord(token)) {
                    continue;
                }
                String word;
                if (!words.contains(token) && useUnk) {
                    // Unknown validation word: score it as <unk>.
                    word = UNK;
                    // words.add(token);
                } else {
                    word = token;
                }
                if (END.equals(word)) {
                    // Sentence boundary: flush the accumulated line.
                    currentLine.append(word);
                    fileByLines.add(currentLine.toString());
                    currentLine = new StringBuilder();
                } else {
                    currentLine.append(word);
                    currentLine.append(" ");
                }
            }
        }
        // Accumulate total log-probability over all sentences.
        int wordNum = 0;
        int sentNum = 0;
        double logProbT = 0.0;
        for (String str : fileByLines) {
            double logProbS = 0.0;
            NGram currentNgram = new NGram(nGramLimit);
            String[] tokens = str.split("\\s+");
            for (String token : tokens) {
                String word;
                if (!words.contains(token) && useUnk) {
                    word = UNK;
                } else {
                    word = token;
                }
                // Slide the history window and score the new token.
                currentNgram.add(word);
                NGram nMinusOneGram = currentNgram.subNGram(0,
                        currentNgram.size() - 1);
                double prob = getProb(currentNgram, nMinusOneGram);
                logProbS += Math.log(prob);
            }
            logProbT += logProbS;
            wordNum += tokens.length;
            sentNum++;
        }
        System.out.println("Evaluated " + sentNum + " sentences and " + wordNum
                + " words using " + nGramLimit + "-gram model.");
        int N = wordNum + sentNum;
        System.out.println("Total n-grams: " + N);
        // perplexity = e^(-(1/N) * sum of log-probabilities)
        double perplexity = Math.pow(Math.E, (-1.0 / N) * logProbT);
        System.out.println("Perplexity: over " + N
                + " recognized n-grams in verification corpus: " + perplexity);
    }

    /**
     * Tokenizes the training headlines, counts every 0..MAX_N_GRAM_LENGTH-gram,
     * prints the ten most frequent uni/bi/trigrams, and optionally emits one
     * random sentence per model order. When useUnk is true, the first
     * occurrence of each word is recorded as {@code <unk>}; later occurrences
     * enter the known vocabulary.
     * NOTE(review): the declared FileNotFoundException/IOException are never
     * thrown here.
     */
    private void generateModel(List<Headline> traininSet, boolean genRandom,
            boolean useUnk) throws FileNotFoundException, IOException {
        StringBuilder currentLine = new StringBuilder();
        List<String> fileByLines = new ArrayList<String>();
        for (Headline headline : traininSet) {
            String headlineText = headline.getText();
            if (headlineText.length() == 0) {
                continue;
            }
            String sanitizedLine = sanitize(headline.getText());
            // split on whitespace
            String[] tokens = sanitizedLine.toLowerCase().split("\\s+");
            for (String token : tokens) {
                if (!isWord(token)) {
                    continue;
                }
                String word;
                if (!wordsSeenOnce.contains(token) && useUnk) {
                    // First sighting: train it as <unk>, remember the word.
                    word = UNK;
                    wordsSeenOnce.add(token);
                } else {
                    // Seen before (or <unk> disabled): enters the vocabulary.
                    words.add(token);
                    word = token;
                }
                if (END.equals(word)) {
                    // Sentence boundary: flush the accumulated line.
                    currentLine.append(word);
                    fileByLines.add(currentLine.toString());
                    currentLine = new StringBuilder();
                } else {
                    currentLine.append(word);
                    currentLine.append(" ");
                }
            }
        }
        for (String str : fileByLines) {
            System.out.println(str);
            NGram currentNgram = new NGram(MAX_N_GRAM_LENGTH);
            for (String token : str.split("\\s+")) {
                currentNgram.add(token);
                // Count every suffix of the sliding window: the 0-gram up to
                // the full current window length.
                for (int i = 0; i <= currentNgram.size(); i++) {
                    addNGram(currentNgram.size() - i,
                            currentNgram.subNGram(i, currentNgram.size()));
                    totalNGramCounts[currentNgram.size() - i]++;
                }
            }
        }
        System.out.println("Most common words: ");
        // Sort each distribution by descending count for the top-10 report.
        List<Entry<NGram, Integer>> unigrams = new ArrayList<Entry<NGram, Integer>>(
                nGrams.get(1).entrySet());
        Collections.sort(unigrams, new NGramComparator());
        List<Entry<NGram, Integer>> bigrams = new ArrayList<Entry<NGram, Integer>>(
                nGrams.get(2).entrySet());
        Collections.sort(bigrams, new NGramComparator());
        List<Entry<NGram, Integer>> trigrams = new ArrayList<Entry<NGram, Integer>>(
                nGrams.get(3).entrySet());
        Collections.sort(trigrams, new NGramComparator());
        // NOTE(review): the three loops below assume at least 10 distinct
        // n-grams of each order; a tiny training set would throw
        // IndexOutOfBoundsException.
        for (int i = 1; i <= 10; i++) {
            System.out
                    .println(i
                            + ". "
                            + unigrams.get(i - 1).getKey()
                            + " : "
                            + (((double) (unigrams.get(i - 1).getValue()) / totalNGramCounts[1])));
        }
        for (int i = 1; i <= 10; i++) {
            System.out
                    .println(i
                            + ". "
                            + bigrams.get(i - 1).getKey()
                            + " : "
                            + (((double) (bigrams.get(i - 1).getValue()) / totalNGramCounts[2])));
        }
        for (int i = 1; i <= 10; i++) {
            System.out
                    .println(i
                            + ". "
                            + trigrams.get(i - 1).getKey()
                            + " : "
                            + (((double) (trigrams.get(i - 1).getValue()) / totalNGramCounts[3])));
        }
        if (genRandom) {
            for (int nGramLength = 1; nGramLength <= MAX_N_GRAM_LENGTH; nGramLength++) {
                System.out.println("Random sentence of length " + RANDOM_LENGTH
                        + " using " + nGramLength + "-gram language model:");
                StringBuilder randomText = new StringBuilder();
                NGram randomNminusOneGram = new NGram(nGramLength);
                for (int i = 0; i < RANDOM_LENGTH; i++) {
                    String randomToken = getRandomToken(nGrams,
                            randomNminusOneGram.size() + 1, randomNminusOneGram);
                    NGram randomNGram = randomNminusOneGram.copy();
                    randomNGram.add(randomToken);
                    randomText.append(randomNGram.get(randomNGram.size() - 1));
                    randomText.append(" ");
                    // Grow the history until it reaches full length, then slide.
                    if (randomNGram.size() < nGramLength) {
                        randomNminusOneGram = randomNGram;
                    } else {
                        randomNminusOneGram = randomNGram.subNGram(1);
                    }
                }
                System.out.println(randomText);
                System.out.println();
            }
        }
    }

    /**
     * A token counts as a "word" if it is one of the sentinel tokens or
     * contains at least one \w character. Rejected tokens are cached in
     * nonWords so the regex runs at most once per distinct token.
     */
    private boolean isWord(String aString) {
        if (aString == null || aString.length() == 0) {
            return false;
        }
        if (nonWords.contains(aString)) {
            return false;
        }
        if (START.equals(aString) || END.equals(aString) || UNK.equals(aString)) {
            return true;
        }
        Matcher wordMatcher = wordPattern.matcher(aString);
        if (wordMatcher.find()) {
            return true;
        }
        nonWords.add(aString);
        return false;
    }

    /**
     * Trains a model on trainingSet and reports perplexity of the 1-, 2- and
     * 3-gram models over validationSet. A fresh NGramModel is created
     * internally, so this instance's own state is neither used nor mutated
     * and repeated calls do not accumulate counts. Random-sentence generation
     * is disabled here (genRandom == false).
     */
    public void reportModel(List<Headline> trainingSet,
            List<Headline> validationSet) {
        try {
            NGramModel ngm = new NGramModel();
            boolean doCalcPerplexity = true;
            ngm.generateModel(trainingSet, !doCalcPerplexity, doCalcPerplexity);
            if (doCalcPerplexity) {
                for (int i = 1; i <= MAX_N_GRAM_LENGTH; i++) {
                    ngm.calcPerplexity(validationSet, i, true);
                }
            }
        } catch (FileNotFoundException fnfe) {
            fnfe.printStackTrace(System.err);
        } catch (IOException ioe) {
            ioe.printStackTrace(System.err);
        }
    }

    /** Orders distribution entries by descending count (most frequent first). */
    private class NGramComparator implements Comparator<Entry<NGram, Integer>> {
        @Override
        public int compare(Entry<NGram, Integer> o1, Entry<NGram, Integer> o2) {
            return o2.getValue().compareTo(o1.getValue());
        }
    }
}

View File

@@ -9,4 +9,5 @@ public interface HeadlineService {
/** Inserts a single headline. Return value presumably the affected row count — confirm with implementations. */
int insertHeadline(Headline headline);
/** Batch-inserts headlines. Return values presumably per-statement update counts — confirm with implementations. */
int[] insertHeadlines(List<Headline> headline);
/** Fetches all headlines for a stock on a single date. */
List<Headline> getHeadlines(String stock, Date date);
/** Fetches all headlines for a stock with date in [startDate, endDate]; may be unsupported by some implementations. */
List<Headline> getHeadlines(String stock, Date startDate, Date endDate);
}

View File

@@ -32,4 +32,9 @@ public class MySQLHeadlineServiceImpl implements HeadlineService {
public List<Headline> getHeadlines(String stock, Date date) {
return headlineDao.select(stock, date);
}
/** Delegates to the DAO's date-range select (inclusive bounds per its SQL). */
@Override
public List<Headline> getHeadlines(String stock, Date startDate, Date endDate) {
    return headlineDao.select(stock, startDate, endDate);
}
}

View File

@@ -86,4 +86,10 @@ public class YahooHeadlineServiceImpl implements HeadlineService {
String formattedDate = DATE_FORMATTER.format(date);
return QUERY_URL.replaceAll(STOCK_SYMBOL_FIELD, stock).replaceAll(STORY_DATE_FIELD, formattedDate);
}
/**
 * Date-range lookup is not supported by the Yahoo implementation; callers
 * must query one date at a time via {@code getHeadlines(stock, date)}.
 *
 * @throws UnsupportedOperationException always
 */
@Override
public List<Headline> getHeadlines(String stock, Date startDate,
        Date endDate) {
    throw new UnsupportedOperationException("This implementation does not support getting headlines for a date range.");
}
}