Added ability to generate 1, 2, 3-gram models on a company/date-range basis using <UNK> to represent the initial appearance of a previously unknown word.
This commit is contained in:
361
src/net/woodyfolsom/cs6601/p3/ngram/NGramModel.java
Normal file
361
src/net/woodyfolsom/cs6601/p3/ngram/NGramModel.java
Normal file
@@ -0,0 +1,361 @@
|
||||
package net.woodyfolsom.cs6601.p3.ngram;
|
||||
|
||||
import java.io.FileNotFoundException;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Set;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
|
||||
import net.woodyfolsom.cs6601.p3.domain.Headline;
|
||||
|
||||
public class NGramModel {
|
||||
static final int MAX_N_GRAM_LENGTH = 3;
|
||||
static final int RANDOM_LENGTH = 100;
|
||||
|
||||
private static final String START = "<start>";
|
||||
private static final String END = "<end>";
|
||||
private static final String UNK = "<unk>";
|
||||
|
||||
private Map<Integer, NGramDistribution> nGrams;
|
||||
private int[] totalNGramCounts = new int[MAX_N_GRAM_LENGTH + 1];
|
||||
|
||||
private Pattern wordPattern = Pattern.compile("\\w+");
|
||||
private Set<String> nonWords = new HashSet<String>();
|
||||
private Set<String> words = new HashSet<String>();
|
||||
private Set<String> wordsSeenOnce = new HashSet<String>();
|
||||
|
||||
double getProb(NGram nGram, NGram nMinusOneGram) {
|
||||
NGramDistribution nGramDist = nGrams.get(nGram.size());
|
||||
double prob;
|
||||
if (nGramDist.containsKey(nGram)) {
|
||||
// impossible for the bigram not to exist if the trigram does
|
||||
NGramDistribution nMinusOneGramDist = nGrams.get(nGram.size() - 1);
|
||||
int nMinusOneGramCount = nMinusOneGramDist.get(nMinusOneGram);
|
||||
int nGramCount = nGramDist.get(nGram);
|
||||
prob = (double) (nGramCount) / nMinusOneGramCount;
|
||||
} else {
|
||||
// Laplace smoothing
|
||||
// prob = 1.0 / ((totalNGramCounts[nGram.size()] + 2.0));
|
||||
NGram backoff = nGram.subNGram(1);
|
||||
NGram backoffMinusOne = backoff.subNGram(0, backoff.size() - 1);
|
||||
return getProb(backoff, backoffMinusOne);
|
||||
}
|
||||
|
||||
return prob;
|
||||
}
|
||||
|
||||
static String getRandomToken(Map<Integer, NGramDistribution> nGrams, int n,
|
||||
NGram nMinusOneGram) {
|
||||
List<Entry<NGram, Integer>> matchingNgrams = new ArrayList<Entry<NGram, Integer>>();
|
||||
NGramDistribution ngDist = nGrams.get(n);
|
||||
|
||||
for (Entry<NGram, Integer> entry : ngDist.entrySet()) {
|
||||
if (entry.getKey().startsWith(nMinusOneGram)) {
|
||||
matchingNgrams.add(entry);
|
||||
}
|
||||
}
|
||||
|
||||
int totalCount = 0;
|
||||
for (int i = 0; i < matchingNgrams.size(); i++) {
|
||||
totalCount += matchingNgrams.get(i).getValue();
|
||||
}
|
||||
double random = Math.random();
|
||||
int randomNGramIndex = (int) (random * totalCount);
|
||||
|
||||
totalCount = 0;
|
||||
for (int i = 0; i < matchingNgrams.size(); i++) {
|
||||
int currentCount = matchingNgrams.get(i).getValue();
|
||||
if (randomNGramIndex < totalCount + currentCount) {
|
||||
NGram ngram = matchingNgrams.get(i).getKey();
|
||||
return ngram.get(ngram.size() - 1);
|
||||
}
|
||||
totalCount += currentCount;
|
||||
}
|
||||
return getRandomToken(nGrams, n - 1, nMinusOneGram.subNGram(1));
|
||||
}
|
||||
|
||||
public NGramModel() {
|
||||
// initialize Maps of 0-grams, unigrams, bigrams, trigrams...
|
||||
nGrams = new HashMap<Integer, NGramDistribution>();
|
||||
for (int i = 0; i <= MAX_N_GRAM_LENGTH; i++) {
|
||||
nGrams.put(i, new NGramDistribution());
|
||||
}
|
||||
}
|
||||
|
||||
private void addNGram(int nGramLength, NGram nGram) {
|
||||
if (nGram.size() < nGramLength) {
|
||||
System.out.println("Cannot create " + nGramLength + "-gram from: "
|
||||
+ nGram);
|
||||
}
|
||||
|
||||
Map<NGram, Integer> nGramCounts = nGrams.get(nGramLength);
|
||||
NGram nGramCopy = nGram.copy(nGramLength);
|
||||
|
||||
if (nGramCounts.containsKey(nGramCopy)) {
|
||||
int nGramCount = nGramCounts.get(nGramCopy);
|
||||
nGramCounts.put(nGramCopy, nGramCount + 1);
|
||||
} else {
|
||||
nGramCounts.put(nGramCopy, 1);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Given an arbitrary String, replace punctutation with spaces, remove
|
||||
* non-alphanumeric characters, prepend with <START> token, append <END>
|
||||
* token.
|
||||
*
|
||||
* @param rawHeadline
|
||||
* @return
|
||||
*/
|
||||
private String sanitize(String rawHeadline) {
|
||||
String nonpunctuatedHeadline = rawHeadline.replaceAll(
|
||||
"[\'\";:,\\]\\[]", " ");
|
||||
String alphaNumericHeadline = nonpunctuatedHeadline.replaceAll(
|
||||
"[^A-Za-z0-9 ]", "");
|
||||
return START + " " + alphaNumericHeadline + " " + END;
|
||||
}
|
||||
|
||||
private void calcPerplexity(List<Headline> validationSet, int nGramLimit,
|
||||
boolean useUnk) throws FileNotFoundException, IOException {
|
||||
List<String> fileByLines = new ArrayList<String>();
|
||||
StringBuilder currentLine = new StringBuilder();
|
||||
|
||||
for (Headline headline : validationSet) {
|
||||
String sanitizedLine = sanitize(headline.getText());
|
||||
// split on whitespace
|
||||
String[] tokens = sanitizedLine.toLowerCase().split("\\s+");
|
||||
for (String token : tokens) {
|
||||
if (!isWord(token)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
String word;
|
||||
if (!words.contains(token) && useUnk) {
|
||||
word = UNK;
|
||||
// words.add(token);
|
||||
} else {
|
||||
word = token;
|
||||
}
|
||||
|
||||
if (END.equals(word)) {
|
||||
currentLine.append(word);
|
||||
fileByLines.add(currentLine.toString());
|
||||
currentLine = new StringBuilder();
|
||||
} else {
|
||||
currentLine.append(word);
|
||||
currentLine.append(" ");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int wordNum = 0;
|
||||
int sentNum = 0;
|
||||
|
||||
double logProbT = 0.0;
|
||||
for (String str : fileByLines) {
|
||||
double logProbS = 0.0;
|
||||
NGram currentNgram = new NGram(nGramLimit);
|
||||
String[] tokens = str.split("\\s+");
|
||||
for (String token : tokens) {
|
||||
String word;
|
||||
if (!words.contains(token) && useUnk) {
|
||||
word = UNK;
|
||||
} else {
|
||||
word = token;
|
||||
}
|
||||
currentNgram.add(word);
|
||||
NGram nMinusOneGram = currentNgram.subNGram(0,
|
||||
currentNgram.size() - 1);
|
||||
|
||||
double prob = getProb(currentNgram, nMinusOneGram);
|
||||
|
||||
logProbS += Math.log(prob);
|
||||
}
|
||||
logProbT += logProbS;
|
||||
wordNum += tokens.length;
|
||||
sentNum++;
|
||||
}
|
||||
|
||||
System.out.println("Evaluated " + sentNum + " sentences and " + wordNum
|
||||
+ " words using " + nGramLimit + "-gram model.");
|
||||
int N = wordNum + sentNum;
|
||||
System.out.println("Total n-grams: " + N);
|
||||
double perplexity = Math.pow(Math.E, (-1.0 / N) * logProbT);
|
||||
System.out.println("Perplexity: over " + N
|
||||
+ " recognized n-grams in verification corpus: " + perplexity);
|
||||
}
|
||||
|
||||
private void generateModel(List<Headline> traininSet, boolean genRandom,
|
||||
boolean useUnk) throws FileNotFoundException, IOException {
|
||||
StringBuilder currentLine = new StringBuilder();
|
||||
List<String> fileByLines = new ArrayList<String>();
|
||||
|
||||
for (Headline headline : traininSet) {
|
||||
String headlineText = headline.getText();
|
||||
if (headlineText.length() == 0) {
|
||||
continue;
|
||||
}
|
||||
String sanitizedLine = sanitize(headline.getText());
|
||||
// split on whitespace
|
||||
String[] tokens = sanitizedLine.toLowerCase().split("\\s+");
|
||||
for (String token : tokens) {
|
||||
if (!isWord(token)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
String word;
|
||||
if (!wordsSeenOnce.contains(token) && useUnk) {
|
||||
word = UNK;
|
||||
wordsSeenOnce.add(token);
|
||||
} else {
|
||||
words.add(token);
|
||||
word = token;
|
||||
}
|
||||
|
||||
if (END.equals(word)) {
|
||||
currentLine.append(word);
|
||||
fileByLines.add(currentLine.toString());
|
||||
currentLine = new StringBuilder();
|
||||
} else {
|
||||
currentLine.append(word);
|
||||
currentLine.append(" ");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (String str : fileByLines) {
|
||||
System.out.println(str);
|
||||
NGram currentNgram = new NGram(MAX_N_GRAM_LENGTH);
|
||||
for (String token : str.split("\\s+")) {
|
||||
currentNgram.add(token);
|
||||
for (int i = 0; i <= currentNgram.size(); i++) {
|
||||
addNGram(currentNgram.size() - i,
|
||||
currentNgram.subNGram(i, currentNgram.size()));
|
||||
totalNGramCounts[currentNgram.size() - i]++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
System.out.println("Most common words: ");
|
||||
|
||||
List<Entry<NGram, Integer>> unigrams = new ArrayList<Entry<NGram, Integer>>(
|
||||
nGrams.get(1).entrySet());
|
||||
Collections.sort(unigrams, new NGramComparator());
|
||||
|
||||
List<Entry<NGram, Integer>> bigrams = new ArrayList<Entry<NGram, Integer>>(
|
||||
nGrams.get(2).entrySet());
|
||||
Collections.sort(bigrams, new NGramComparator());
|
||||
|
||||
List<Entry<NGram, Integer>> trigrams = new ArrayList<Entry<NGram, Integer>>(
|
||||
nGrams.get(3).entrySet());
|
||||
Collections.sort(trigrams, new NGramComparator());
|
||||
|
||||
for (int i = 1; i <= 10; i++) {
|
||||
System.out
|
||||
.println(i
|
||||
+ ". "
|
||||
+ unigrams.get(i - 1).getKey()
|
||||
+ " : "
|
||||
+ (((double) (unigrams.get(i - 1).getValue()) / totalNGramCounts[1])));
|
||||
}
|
||||
|
||||
for (int i = 1; i <= 10; i++) {
|
||||
System.out
|
||||
.println(i
|
||||
+ ". "
|
||||
+ bigrams.get(i - 1).getKey()
|
||||
+ " : "
|
||||
+ (((double) (bigrams.get(i - 1).getValue()) / totalNGramCounts[2])));
|
||||
}
|
||||
|
||||
for (int i = 1; i <= 10; i++) {
|
||||
System.out
|
||||
.println(i
|
||||
+ ". "
|
||||
+ trigrams.get(i - 1).getKey()
|
||||
+ " : "
|
||||
+ (((double) (trigrams.get(i - 1).getValue()) / totalNGramCounts[3])));
|
||||
}
|
||||
if (genRandom) {
|
||||
for (int nGramLength = 1; nGramLength <= MAX_N_GRAM_LENGTH; nGramLength++) {
|
||||
System.out.println("Random sentence of length " + RANDOM_LENGTH
|
||||
+ " using " + nGramLength + "-gram language model:");
|
||||
StringBuilder randomText = new StringBuilder();
|
||||
|
||||
NGram randomNminusOneGram = new NGram(nGramLength);
|
||||
for (int i = 0; i < RANDOM_LENGTH; i++) {
|
||||
String randomToken = getRandomToken(nGrams,
|
||||
randomNminusOneGram.size() + 1, randomNminusOneGram);
|
||||
NGram randomNGram = randomNminusOneGram.copy();
|
||||
randomNGram.add(randomToken);
|
||||
randomText.append(randomNGram.get(randomNGram.size() - 1));
|
||||
randomText.append(" ");
|
||||
|
||||
if (randomNGram.size() < nGramLength) {
|
||||
randomNminusOneGram = randomNGram;
|
||||
} else {
|
||||
randomNminusOneGram = randomNGram.subNGram(1);
|
||||
}
|
||||
}
|
||||
System.out.println(randomText);
|
||||
System.out.println();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private boolean isWord(String aString) {
|
||||
if (aString == null || aString.length() == 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (nonWords.contains(aString)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (START.equals(aString) || END.equals(aString) || UNK.equals(aString)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
Matcher wordMatcher = wordPattern.matcher(aString);
|
||||
if (wordMatcher.find()) {
|
||||
return true;
|
||||
}
|
||||
nonWords.add(aString);
|
||||
return false;
|
||||
}
|
||||
|
||||
public void reportModel(List<Headline> trainingSet,
|
||||
List<Headline> validationSet) {
|
||||
try {
|
||||
NGramModel ngm = new NGramModel();
|
||||
boolean doCalcPerplexity = true;
|
||||
ngm.generateModel(trainingSet, !doCalcPerplexity, doCalcPerplexity);
|
||||
if (doCalcPerplexity) {
|
||||
for (int i = 1; i <= MAX_N_GRAM_LENGTH; i++) {
|
||||
ngm.calcPerplexity(validationSet, i, true);
|
||||
}
|
||||
}
|
||||
} catch (FileNotFoundException fnfe) {
|
||||
fnfe.printStackTrace(System.err);
|
||||
} catch (IOException ioe) {
|
||||
ioe.printStackTrace(System.err);
|
||||
}
|
||||
}
|
||||
|
||||
private class NGramComparator implements Comparator<Entry<NGram, Integer>> {
|
||||
|
||||
@Override
|
||||
public int compare(Entry<NGram, Integer> o1, Entry<NGram, Integer> o2) {
|
||||
return o2.getValue().compareTo(o1.getValue());
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user