package net.woodyfolsom.cs6601.p3.ngram;

import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import net.woodyfolsom.cs6601.p3.domain.Headline;

/**
 * Back-off n-gram language model (up to trigrams) trained on headline text.
 *
 * <p>The model counts 0-grams through {@value #MAX_N_GRAM_LENGTH}-grams over a
 * training corpus, can generate random sentences from those counts, and can
 * compute the perplexity of a validation corpus.
 *
 * <p>NOTE(review): the checked-in source had all generic type parameters and
 * the angle-bracketed pseudo-tokens stripped (apparently by an HTML-unaware
 * tool). Generics have been reconstructed from usage and the pseudo-tokens
 * restored to the conventional {@code <s>}, {@code </s>}, {@code <unk>} —
 * confirm against the original repository.
 */
public class NGramModel {

    /** Highest n-gram order tracked (trigrams). */
    static final int MAX_N_GRAM_LENGTH = 3;
    /** Number of tokens emitted per randomly generated sentence. */
    static final int RANDOM_LENGTH = 100;

    // Sentence-boundary / unknown-word pseudo-tokens. These were empty strings
    // in the mangled source; with an empty END the END.equals(word) checks
    // below could never fire, so no sentence was ever flushed. Restored.
    private static final String START = "<s>";
    private static final String END = "</s>";
    private static final String UNK = "<unk>";

    /** n -> distribution of observed n-gram counts (index 0 holds the empty 0-gram). */
    private Map<Integer, NGramDistribution> nGrams;
    /** totalNGramCounts[n] = total number of n-gram tokens observed during training. */
    private int[] totalNGramCounts = new int[MAX_N_GRAM_LENGTH + 1];
    /** A token counts as a word if it contains at least one \w character. */
    private Pattern wordPattern = Pattern.compile("\\w+");
    /** Memo of tokens already rejected by {@link #isWord(String)}. */
    private Set<String> nonWords = new HashSet<String>();
    /** Vocabulary: tokens seen at least twice during training. */
    private Set<String> words = new HashSet<String>();
    /** Tokens seen exactly once so far (their first occurrence is mapped to UNK). */
    private Set<String> wordsSeenOnce = new HashSet<String>();

    /** Creates an empty model with one (empty) distribution per n-gram order. */
    public NGramModel() {
        // Initialize maps of 0-grams, unigrams, bigrams, trigrams...
        nGrams = new HashMap<Integer, NGramDistribution>();
        for (int i = 0; i <= MAX_N_GRAM_LENGTH; i++) {
            nGrams.put(i, new NGramDistribution());
        }
    }

    /**
     * Conditional probability of the last word of {@code nGram} given its
     * prefix {@code nMinusOneGram}: count(nGram) / count(nMinusOneGram).
     * If the n-gram was never observed, backs off to the (n-1)-gram estimate.
     *
     * @param nGram the full n-gram whose final word is being predicted
     * @param nMinusOneGram the (n-1)-word prefix of {@code nGram}
     * @return the (back-off) maximum-likelihood probability
     */
    double getProb(NGram nGram, NGram nMinusOneGram) {
        NGramDistribution nGramDist = nGrams.get(nGram.size());
        if (nGramDist.containsKey(nGram)) {
            // If the n-gram exists, its (n-1)-gram prefix must also exist.
            NGramDistribution nMinusOneGramDist = nGrams.get(nGram.size() - 1);
            int nMinusOneGramCount = nMinusOneGramDist.get(nMinusOneGram);
            int nGramCount = nGramDist.get(nGram);
            return (double) nGramCount / nMinusOneGramCount;
        }
        // Unseen n-gram: back off to the (n-1)-gram distribution instead of
        // Laplace smoothing. Recursion terminates because the empty 0-gram is
        // always present in the 0-gram distribution after training.
        NGram backoff = nGram.subNGram(1);
        NGram backoffMinusOne = backoff.subNGram(0, backoff.size() - 1);
        return getProb(backoff, backoffMinusOne);
    }

    /**
     * Samples the next token given the context {@code nMinusOneGram}, choosing
     * among all observed n-grams that extend the context, weighted by count.
     * Backs off to a shorter context when no n-gram matches.
     *
     * @param nGrams map from n-gram order to its count distribution
     * @param n the n-gram order to sample from
     * @param nMinusOneGram the (n-1)-token context
     * @return a randomly chosen continuation token
     */
    static String getRandomToken(Map<Integer, NGramDistribution> nGrams, int n,
            NGram nMinusOneGram) {
        // Collect every observed n-gram whose prefix matches the context.
        List<Entry<NGram, Integer>> matchingNgrams =
                new ArrayList<Entry<NGram, Integer>>();
        NGramDistribution ngDist = nGrams.get(n);
        for (Entry<NGram, Integer> entry : ngDist.entrySet()) {
            if (entry.getKey().startsWith(nMinusOneGram)) {
                matchingNgrams.add(entry);
            }
        }
        int totalCount = 0;
        for (int i = 0; i < matchingNgrams.size(); i++) {
            totalCount += matchingNgrams.get(i).getValue();
        }
        // Pick a count-weighted random index into the matching n-grams.
        double random = Math.random();
        int randomNGramIndex = (int) (random * totalCount);
        totalCount = 0;
        for (int i = 0; i < matchingNgrams.size(); i++) {
            int currentCount = matchingNgrams.get(i).getValue();
            if (randomNGramIndex < totalCount + currentCount) {
                NGram ngram = matchingNgrams.get(i).getKey();
                return ngram.get(ngram.size() - 1);
            }
            totalCount += currentCount;
        }
        // No match at this order: back off to a one-token-shorter context.
        return getRandomToken(nGrams, n - 1, nMinusOneGram.subNGram(1));
    }

    /**
     * Increments the count of the trailing {@code nGramLength}-gram of
     * {@code nGram} in the appropriate distribution.
     *
     * @param nGramLength the order of the n-gram to record
     * @param nGram the source n-gram (its copy of the requested length is counted)
     */
    private void addNGram(int nGramLength, NGram nGram) {
        if (nGram.size() < nGramLength) {
            // Diagnostic only; processing continues as in the original code.
            System.out.println("Cannot create " + nGramLength + "-gram from: "
                    + nGram);
        }
        Map<NGram, Integer> nGramCounts = nGrams.get(nGramLength);
        NGram nGramCopy = nGram.copy(nGramLength);
        if (nGramCounts.containsKey(nGramCopy)) {
            int nGramCount = nGramCounts.get(nGramCopy);
            nGramCounts.put(nGramCopy, nGramCount + 1);
        } else {
            nGramCounts.put(nGramCopy, 1);
        }
    }

    /**
     * Given an arbitrary String, replaces punctuation with spaces, removes
     * non-alphanumeric characters, prepends the {@code <s>} token and appends
     * the {@code </s>} token.
     *
     * @param rawHeadline raw headline text
     * @return sanitized, boundary-marked text
     */
    private String sanitize(String rawHeadline) {
        String nonpunctuatedHeadline = rawHeadline.replaceAll(
                "[\'\";:,\\]\\[]", " ");
        String alphaNumericHeadline = nonpunctuatedHeadline.replaceAll(
                "[^A-Za-z0-9 ]", "");
        return START + " " + alphaNumericHeadline + " " + END;
    }

    /**
     * Computes and prints the perplexity of {@code validationSet} under the
     * {@code nGramLimit}-gram model.
     *
     * @param validationSet held-out headlines
     * @param nGramLimit n-gram order to evaluate with
     * @param useUnk if true, out-of-vocabulary tokens are mapped to UNK
     * @throws FileNotFoundException declared for interface compatibility
     * @throws IOException declared for interface compatibility
     */
    private void calcPerplexity(List<Headline> validationSet, int nGramLimit,
            boolean useUnk) throws FileNotFoundException, IOException {
        List<String> fileByLines = new ArrayList<String>();
        StringBuilder currentLine = new StringBuilder();
        for (Headline headline : validationSet) {
            String sanitizedLine = sanitize(headline.getText());
            // Split on whitespace.
            String[] tokens = sanitizedLine.toLowerCase().split("\\s+");
            for (String token : tokens) {
                if (!isWord(token)) {
                    continue;
                }
                // Map out-of-vocabulary tokens to UNK when requested.
                String word;
                if (!words.contains(token) && useUnk) {
                    word = UNK;
                } else {
                    word = token;
                }
                if (END.equals(word)) {
                    // End of sentence: flush the accumulated line.
                    currentLine.append(word);
                    fileByLines.add(currentLine.toString());
                    currentLine = new StringBuilder();
                } else {
                    currentLine.append(word);
                    currentLine.append(" ");
                }
            }
        }
        int wordNum = 0;
        int sentNum = 0;
        double logProbT = 0.0;
        for (String str : fileByLines) {
            double logProbS = 0.0;
            NGram currentNgram = new NGram(nGramLimit);
            String[] tokens = str.split("\\s+");
            for (String token : tokens) {
                String word;
                if (!words.contains(token) && useUnk) {
                    word = UNK;
                } else {
                    word = token;
                }
                currentNgram.add(word);
                NGram nMinusOneGram = currentNgram.subNGram(0,
                        currentNgram.size() - 1);
                double prob = getProb(currentNgram, nMinusOneGram);
                // Sum log-probabilities to avoid underflow.
                logProbS += Math.log(prob);
            }
            logProbT += logProbS;
            wordNum += tokens.length;
            sentNum++;
        }
        System.out.println("Evaluated " + sentNum + " sentences and " + wordNum
                + " words using " + nGramLimit + "-gram model.");
        int N = wordNum + sentNum;
        System.out.println("Total n-grams: " + N);
        // Perplexity = exp(-(1/N) * total log probability).
        double perplexity = Math.exp((-1.0 / N) * logProbT);
        System.out.println("Perplexity: over " + N
                + " recognized n-grams in verification corpus: " + perplexity);
    }

    /**
     * Builds the n-gram counts from {@code trainingSet}, prints the ten most
     * frequent uni/bi/trigrams, and optionally generates random sentences.
     *
     * @param trainingSet training headlines
     * @param genRandom if true, prints a random sentence per n-gram order
     * @param useUnk if true, the first occurrence of each token is counted as UNK
     * @throws FileNotFoundException declared for interface compatibility
     * @throws IOException declared for interface compatibility
     */
    private void generateModel(List<Headline> trainingSet, boolean genRandom,
            boolean useUnk) throws FileNotFoundException, IOException {
        StringBuilder currentLine = new StringBuilder();
        List<String> fileByLines = new ArrayList<String>();
        for (Headline headline : trainingSet) {
            String headlineText = headline.getText();
            if (headlineText.length() == 0) {
                continue;
            }
            String sanitizedLine = sanitize(headline.getText());
            // Split on whitespace.
            String[] tokens = sanitizedLine.toLowerCase().split("\\s+");
            for (String token : tokens) {
                if (!isWord(token)) {
                    continue;
                }
                String word;
                if (!wordsSeenOnce.contains(token) && useUnk) {
                    // First sighting: count as UNK so unseen validation words
                    // get non-zero probability mass.
                    word = UNK;
                    wordsSeenOnce.add(token);
                } else {
                    // Second-or-later sighting: token enters the vocabulary.
                    words.add(token);
                    word = token;
                }
                if (END.equals(word)) {
                    currentLine.append(word);
                    fileByLines.add(currentLine.toString());
                    currentLine = new StringBuilder();
                } else {
                    currentLine.append(word);
                    currentLine.append(" ");
                }
            }
        }
        for (String str : fileByLines) {
            System.out.println(str);
            NGram currentNgram = new NGram(MAX_N_GRAM_LENGTH);
            for (String token : str.split("\\s+")) {
                currentNgram.add(token);
                // Record every suffix of the sliding window (0-gram .. n-gram).
                for (int i = 0; i <= currentNgram.size(); i++) {
                    addNGram(currentNgram.size() - i,
                            currentNgram.subNGram(i, currentNgram.size()));
                    totalNGramCounts[currentNgram.size() - i]++;
                }
            }
        }
        System.out.println("Most common words: ");
        List<Entry<NGram, Integer>> unigrams = new ArrayList<Entry<NGram, Integer>>(
                nGrams.get(1).entrySet());
        Collections.sort(unigrams, new NGramComparator());
        List<Entry<NGram, Integer>> bigrams = new ArrayList<Entry<NGram, Integer>>(
                nGrams.get(2).entrySet());
        Collections.sort(bigrams, new NGramComparator());
        List<Entry<NGram, Integer>> trigrams = new ArrayList<Entry<NGram, Integer>>(
                nGrams.get(3).entrySet());
        Collections.sort(trigrams, new NGramComparator());
        // Clamp to the list size so tiny corpora don't throw
        // IndexOutOfBoundsException (the original assumed >= 10 entries).
        for (int i = 1; i <= Math.min(10, unigrams.size()); i++) {
            System.out.println(i + ". " + unigrams.get(i - 1).getKey() + " : "
                    + ((double) (unigrams.get(i - 1).getValue())
                            / totalNGramCounts[1]));
        }
        for (int i = 1; i <= Math.min(10, bigrams.size()); i++) {
            System.out.println(i + ". " + bigrams.get(i - 1).getKey() + " : "
                    + ((double) (bigrams.get(i - 1).getValue())
                            / totalNGramCounts[2]));
        }
        for (int i = 1; i <= Math.min(10, trigrams.size()); i++) {
            System.out.println(i + ". " + trigrams.get(i - 1).getKey() + " : "
                    + ((double) (trigrams.get(i - 1).getValue())
                            / totalNGramCounts[3]));
        }
        if (genRandom) {
            for (int nGramLength = 1; nGramLength <= MAX_N_GRAM_LENGTH; nGramLength++) {
                System.out.println("Random sentence of length " + RANDOM_LENGTH
                        + " using " + nGramLength + "-gram language model:");
                StringBuilder randomText = new StringBuilder();
                NGram randomNminusOneGram = new NGram(nGramLength);
                for (int i = 0; i < RANDOM_LENGTH; i++) {
                    String randomToken = getRandomToken(nGrams,
                            randomNminusOneGram.size() + 1, randomNminusOneGram);
                    NGram randomNGram = randomNminusOneGram.copy();
                    randomNGram.add(randomToken);
                    randomText.append(randomNGram.get(randomNGram.size() - 1));
                    randomText.append(" ");
                    // Grow the context until it reaches n-1 tokens, then slide.
                    if (randomNGram.size() < nGramLength) {
                        randomNminusOneGram = randomNGram;
                    } else {
                        randomNminusOneGram = randomNGram.subNGram(1);
                    }
                }
                System.out.println(randomText);
                System.out.println();
            }
        }
    }

    /**
     * Returns true when {@code aString} is a usable token: a pseudo-token or
     * any string containing at least one word character. Rejections are memoized
     * in {@link #nonWords}.
     *
     * @param aString candidate token (may be null)
     * @return true if the token should be kept
     */
    private boolean isWord(String aString) {
        if (aString == null || aString.length() == 0) {
            return false;
        }
        if (nonWords.contains(aString)) {
            return false;
        }
        if (START.equals(aString) || END.equals(aString)
                || UNK.equals(aString)) {
            return true;
        }
        Matcher wordMatcher = wordPattern.matcher(aString);
        if (wordMatcher.find()) {
            return true;
        }
        nonWords.add(aString);
        return false;
    }

    /**
     * Trains a model on {@code trainingSet} and reports its perplexity on
     * {@code validationSet} for each n-gram order.
     *
     * <p>NOTE(review): this builds a fresh {@code NGramModel} internally and
     * ignores the receiver's own state — preserved as-is from the original.
     *
     * @param trainingSet headlines to train on
     * @param validationSet headlines to evaluate perplexity on
     */
    public void reportModel(List<Headline> trainingSet,
            List<Headline> validationSet) {
        try {
            NGramModel ngm = new NGramModel();
            boolean doCalcPerplexity = true;
            ngm.generateModel(trainingSet, !doCalcPerplexity, doCalcPerplexity);
            if (doCalcPerplexity) {
                for (int i = 1; i <= MAX_N_GRAM_LENGTH; i++) {
                    ngm.calcPerplexity(validationSet, i, true);
                }
            }
        } catch (FileNotFoundException fnfe) {
            fnfe.printStackTrace(System.err);
        } catch (IOException ioe) {
            ioe.printStackTrace(System.err);
        }
    }

    /** Orders n-gram entries by descending count. */
    private class NGramComparator implements
            Comparator<Entry<NGram, Integer>> {
        @Override
        public int compare(Entry<NGram, Integer> o1,
                Entry<NGram, Integer> o2) {
            return o2.getValue().compareTo(o1.getValue());
        }
    }
}