Added ability to generate 1, 2, 3-gram models on a company/date-range basis using <UNK> to represent the initial appearance of a previously unknown word.
This commit is contained in:
28
AppContext.xml
Normal file
28
AppContext.xml
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<beans xmlns="http://www.springframework.org/schema/beans"
|
||||||
|
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||||
|
xmlns:aop="http://www.springframework.org/schema/aop"
|
||||||
|
xmlns:context="http://www.springframework.org/schema/context"
|
||||||
|
xsi:schemaLocation="http://www.springframework.org/schema/beans
|
||||||
|
http://www.springframework.org/schema/beans/spring-beans-2.5.xsd
|
||||||
|
http://www.springframework.org/schema/aop
|
||||||
|
http://www.springframework.org/schema/aop/spring-aop-2.5.xsd
|
||||||
|
http://www.springframework.org/schema/context
|
||||||
|
http://www.springframework.org/schema/context/spring-context-2.5.xsd"
|
||||||
|
default-autowire="byName">
|
||||||
|
|
||||||
|
<bean id="dmdataSource"
|
||||||
|
class="org.springframework.jdbc.datasource.DriverManagerDataSource">
|
||||||
|
<property name="driverClassName" value="com.mysql.jdbc.Driver" />
|
||||||
|
<property name="url" value="jdbc:mysql://woodyfolsom.net:3306/cs6601p3" />
|
||||||
|
<property name="username" value="cs6601" />
|
||||||
|
<property name="password" value="n0nst@p" />
|
||||||
|
</bean>
|
||||||
|
|
||||||
|
<bean id="mySQLHeadlineSvc" class="net.woodyfolsom.cs6601.p3.svc.MySQLHeadlineServiceImpl" />
|
||||||
|
<bean id="yahooHeadlineSvc" class="net.woodyfolsom.cs6601.p3.svc.YahooHeadlineServiceImpl" />
|
||||||
|
|
||||||
|
<context:annotation-config />
|
||||||
|
<context:component-scan base-package="net.woodyfolsom.cs6601.p3"/>
|
||||||
|
|
||||||
|
</beans>
|
||||||
115
src/net/woodyfolsom/cs6601/p3/ModelGenerator.java
Normal file
115
src/net/woodyfolsom/cs6601/p3/ModelGenerator.java
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
package net.woodyfolsom.cs6601.p3;
|
||||||
|
|
||||||
|
import java.io.BufferedReader;
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.FileInputStream;
|
||||||
|
import java.io.FileNotFoundException;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.InputStreamReader;
|
||||||
|
import java.text.DateFormat;
|
||||||
|
import java.text.ParseException;
|
||||||
|
import java.text.SimpleDateFormat;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Date;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.context.ApplicationContext;
|
||||||
|
import org.springframework.context.support.FileSystemXmlApplicationContext;
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
|
import net.woodyfolsom.cs6601.p3.domain.Company;
|
||||||
|
import net.woodyfolsom.cs6601.p3.domain.Headline;
|
||||||
|
import net.woodyfolsom.cs6601.p3.ngram.NGramModel;
|
||||||
|
import net.woodyfolsom.cs6601.p3.svc.HeadlineService;
|
||||||
|
|
||||||
|
@Component
|
||||||
|
public class ModelGenerator {
|
||||||
|
private static final File stockSymbolsCSV = new File("stock_symbols.csv");
|
||||||
|
|
||||||
|
private static final int INVALID_DATE = 1;
|
||||||
|
private static final int IO_EXCEPTION = 2;
|
||||||
|
private static final int STOCK_SYMBOL_CSV_NOT_FOUND = 3;
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
HeadlineService mySQLHeadlineServiceImpl;
|
||||||
|
|
||||||
|
private NGramModel ngramModel = new NGramModel();
|
||||||
|
|
||||||
|
public static void main(String... args) {
|
||||||
|
|
||||||
|
ApplicationContext context = new FileSystemXmlApplicationContext(
|
||||||
|
new String[] { "AppContext.xml" });
|
||||||
|
ModelGenerator modelGenerator = context.getBean(ModelGenerator.class);
|
||||||
|
DateFormat dateFmt = new SimpleDateFormat("yyyy-MM-dd");
|
||||||
|
|
||||||
|
Date startDate = null;
|
||||||
|
Date endDate = null;
|
||||||
|
Date valStart = null;
|
||||||
|
Date valEnd = null;
|
||||||
|
try {
|
||||||
|
startDate = dateFmt.parse("2012-01-01");
|
||||||
|
endDate = dateFmt.parse("2012-03-31");
|
||||||
|
valStart = dateFmt.parse("2012-04-01");
|
||||||
|
valEnd = dateFmt.parse("2012-04-14");
|
||||||
|
} catch (ParseException pe) {
|
||||||
|
System.exit(INVALID_DATE);
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
List<Company> fortune50 = modelGenerator
|
||||||
|
.getFortune50(stockSymbolsCSV);
|
||||||
|
for (Company company : fortune50) {
|
||||||
|
System.out.println("Getting headlines for Fortune 50 company #"
|
||||||
|
+ company.getId() + " (" + company.getName() + ")...");
|
||||||
|
List<Headline> trainingSet = modelGenerator.mySQLHeadlineServiceImpl.getHeadlines(company.getStockSymbol(), startDate, endDate);
|
||||||
|
System.out.println("Pulled " + trainingSet.size() + " headlines for "
|
||||||
|
+ company.getStockSymbol() + " from " + startDate + " to " + endDate);
|
||||||
|
List<Headline> validationSet = modelGenerator.mySQLHeadlineServiceImpl.getHeadlines(company.getStockSymbol(), valStart, valEnd);
|
||||||
|
|
||||||
|
if (trainingSet.size() == 0) {
|
||||||
|
System.out.println("Training dataset contains 0 headlines for " + company.getName() + ", skipping model generation.");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (validationSet.size() == 0) {
|
||||||
|
System.out.println("Validation dataset contains 0 headlines for " + company.getName() + ", skipping model generation.");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
modelGenerator.ngramModel.reportModel(trainingSet, validationSet);
|
||||||
|
System.out.println("Finished " + company.getId() + " / 50");
|
||||||
|
}
|
||||||
|
} catch (FileNotFoundException fnfe) {
|
||||||
|
System.out.println("Stock symbol CSV file does not exist: "
|
||||||
|
+ stockSymbolsCSV);
|
||||||
|
System.exit(STOCK_SYMBOL_CSV_NOT_FOUND);
|
||||||
|
} catch (IOException ioe) {
|
||||||
|
System.out.println("Stock symbol CSV file does not exist: "
|
||||||
|
+ stockSymbolsCSV);
|
||||||
|
System.exit(IO_EXCEPTION);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<Company> getFortune50(File csvFile)
|
||||||
|
throws FileNotFoundException, IOException {
|
||||||
|
List<Company> fortune50 = new ArrayList<Company>();
|
||||||
|
FileInputStream fis = new FileInputStream(csvFile);
|
||||||
|
InputStreamReader reader = new InputStreamReader(fis);
|
||||||
|
BufferedReader buf = new BufferedReader(reader);
|
||||||
|
String csvline = null;
|
||||||
|
while ((csvline = buf.readLine()) != null) {
|
||||||
|
if (csvline.length() == 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
String[] fields = csvline.split(",");
|
||||||
|
if (fields.length != 3) {
|
||||||
|
throw new RuntimeException(
|
||||||
|
"Badly formatted csv file name (3 values expected): "
|
||||||
|
+ csvline);
|
||||||
|
}
|
||||||
|
int id = Integer.valueOf(fields[0]);
|
||||||
|
fortune50.add(new Company(id, fields[1], fields[2]));
|
||||||
|
}
|
||||||
|
return fortune50;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -13,4 +13,5 @@ public interface HeadlineDao {
|
|||||||
|
|
||||||
Headline select(int id);
|
Headline select(int id);
|
||||||
List<Headline> select(String stock, Date date);
|
List<Headline> select(String stock, Date date);
|
||||||
|
List<Headline> select(String stock, Date startDate, Date endDate);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,7 +24,8 @@ public class HeadlineDaoImpl implements HeadlineDao {
|
|||||||
|
|
||||||
private static final String SELECT_BY_ID_QRY = "SELECT * from headlines WHERE id = ?";
|
private static final String SELECT_BY_ID_QRY = "SELECT * from headlines WHERE id = ?";
|
||||||
private static final String SELECT_BY_STOCK_QRY = "SELECT * from headlines WHERE stock = ? AND date = ?";
|
private static final String SELECT_BY_STOCK_QRY = "SELECT * from headlines WHERE stock = ? AND date = ?";
|
||||||
|
private static final String SELECT_BY_DATE_RANGE_QRY = "SELECT * from headlines WHERE stock = ? AND date >= ? AND date <= ?";
|
||||||
|
|
||||||
private JdbcTemplate jdbcTemplate;
|
private JdbcTemplate jdbcTemplate;
|
||||||
|
|
||||||
public int deleteById(int headlineId) {
|
public int deleteById(int headlineId) {
|
||||||
@@ -66,6 +67,11 @@ public class HeadlineDaoImpl implements HeadlineDao {
|
|||||||
new RequestMapper(), stock, date);
|
new RequestMapper(), stock, date);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public List<Headline> select(String stock, Date startDate, Date endDate) {
|
||||||
|
return jdbcTemplate.query(SELECT_BY_DATE_RANGE_QRY,
|
||||||
|
new RequestMapper(), stock, startDate, endDate);
|
||||||
|
}
|
||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
public void createTemplate(DataSource dataSource) {
|
public void createTemplate(DataSource dataSource) {
|
||||||
this.jdbcTemplate = new JdbcTemplate(dataSource);
|
this.jdbcTemplate = new JdbcTemplate(dataSource);
|
||||||
|
|||||||
95
src/net/woodyfolsom/cs6601/p3/ngram/NGram.java
Normal file
95
src/net/woodyfolsom/cs6601/p3/ngram/NGram.java
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
package net.woodyfolsom.cs6601.p3.ngram;
|
||||||
|
|
||||||
|
import java.util.LinkedList;
|
||||||
|
|
||||||
|
/**
 * A bounded sliding window of tokens. Appending past the capacity evicts the
 * oldest token, so the list always holds the most recent {@code maxLength}
 * tokens. Ordering (compareTo) and printing use the space-joined token text.
 */
public class NGram extends LinkedList<String> implements Comparable<NGram> {
	private static final long serialVersionUID = 1L;

	/** Window capacity; adding beyond it drops the oldest token first. */
	private int maxLength;

	/** Creates an empty n-gram window holding at most {@code maxLength} tokens. */
	public NGram(int maxLength) {
		super();
		this.maxLength = maxLength;
	}

	/**
	 * Appends a token, first evicting the oldest one when the window is full.
	 *
	 * @return true (as specified by Collection.add)
	 */
	@Override
	public boolean add(String word) {
		boolean atCapacity = super.size() == maxLength;
		if (atCapacity) {
			super.remove(0);
		}
		return super.add(word);
	}

	/** Orders n-grams lexicographically by their space-joined text. */
	public int compareTo(NGram other) {
		String mine = this.toString();
		String theirs = other.toString();
		return mine.compareTo(theirs);
	}

	/** @return an independent copy with the same capacity and tokens. */
	public NGram copy() {
		NGram duplicate = new NGram(maxLength);
		for (int i = 0; i < super.size(); i++) {
			duplicate.add(super.get(i));
		}
		return duplicate;
	}

	/**
	 * Copies the trailing {@code length} tokens into a new n-gram whose
	 * capacity is {@code length}.
	 *
	 * @throws IllegalArgumentException if {@code length} exceeds the current size
	 */
	public NGram copy(int length) {
		if (length > super.size()) {
			throw new IllegalArgumentException();
		}
		NGram tail = new NGram(length);
		int firstIndex = super.size() - length;
		for (int i = firstIndex; i < super.size(); i++) {
			tail.add(super.get(i));
		}
		return tail;
	}

	/**
	 * True iff {@code other} is exactly one token shorter than this n-gram and
	 * matches its leading tokens (i.e. this extends {@code other} by one token).
	 */
	public boolean startsWith(NGram other) {
		if (this.size() - 1 != other.size()) {
			return false;
		}
		int i = 0;
		while (i < other.size()) {
			if (!this.get(i).equals(other.get(i))) {
				return false;
			}
			i++;
		}
		return true;
	}

	/**
	 * Returns the tokens from {@code index} to the end as a new n-gram that
	 * keeps this n-gram's capacity.
	 *
	 * @throws IllegalArgumentException if {@code index} exceeds the current size
	 */
	public NGram subNGram(int index) {
		if (index > super.size()) {
			throw new IllegalArgumentException();
		}
		NGram suffix = new NGram(maxLength);
		for (int pos = index; pos < super.size(); pos++) {
			suffix.add(super.get(pos));
		}
		return suffix;
	}

	/**
	 * Returns the tokens in [start, end) as a new n-gram whose capacity is the
	 * span length.
	 *
	 * @throws IllegalArgumentException if {@code start > end}
	 */
	public NGram subNGram(int start, int end) {
		if (start > end) {
			throw new IllegalArgumentException();
		}
		NGram span = new NGram(end - start);
		for (int pos = start; pos < end; pos++) {
			span.add(super.get(pos));
		}
		return span;
	}

	/** @return the tokens joined by single spaces; "" when empty. */
	@Override
	public String toString() {
		if (super.isEmpty()) {
			return "";
		}
		StringBuilder joined = new StringBuilder(super.get(0));
		for (int i = 1; i < super.size(); i++) {
			joined.append(" ");
			joined.append(super.get(i));
		}
		return joined.toString();
	}
}
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
package net.woodyfolsom.cs6601.p3.ngram;
|
||||||
|
import java.util.HashMap;
|
||||||
|
|
||||||
|
/**
 * Frequency table mapping an n-gram to the number of times it was observed in
 * the training corpus. Subclassing HashMap adds no behavior; it exists only to
 * name the map type used throughout the model code.
 */
public class NGramDistribution extends HashMap<NGram, Integer> {
	private static final long serialVersionUID = 1L;
}
|
||||||
361
src/net/woodyfolsom/cs6601/p3/ngram/NGramModel.java
Normal file
361
src/net/woodyfolsom/cs6601/p3/ngram/NGramModel.java
Normal file
@@ -0,0 +1,361 @@
|
|||||||
|
package net.woodyfolsom.cs6601.p3.ngram;
|
||||||
|
|
||||||
|
import java.io.FileNotFoundException;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.Comparator;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Map.Entry;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.regex.Matcher;
|
||||||
|
import java.util.regex.Pattern;
|
||||||
|
|
||||||
|
import net.woodyfolsom.cs6601.p3.domain.Headline;
|
||||||
|
|
||||||
|
/**
 * Count-based 1/2/3-gram language model built from news headlines.
 *
 * <p>Training ({@link #generateModel}) sanitizes and tokenizes each headline,
 * maps the FIRST occurrence of every token to {@code <unk>}, and accumulates
 * n-gram counts for every order 0..{@code MAX_N_GRAM_LENGTH}. Evaluation
 * ({@link #calcPerplexity}) scores a validation corpus with count-ratio
 * conditional probabilities (backing off to shorter histories for unseen
 * n-grams) and prints the perplexity.
 */
public class NGramModel {
	/** Highest n-gram order tracked (trigrams). */
	static final int MAX_N_GRAM_LENGTH = 3;
	/** Number of tokens emitted per randomly generated sentence. */
	static final int RANDOM_LENGTH = 100;

	// Sentinel tokens. sanitize() strips '<' and '>' from headline text before
	// wrapping, so these cannot collide with real corpus tokens.
	private static final String START = "<start>";
	private static final String END = "<end>";
	private static final String UNK = "<unk>";

	// Order n -> (n-gram -> observed count) for n in [0, MAX_N_GRAM_LENGTH].
	// The 0-gram distribution maps the empty n-gram to the total token count,
	// which getProb() uses as the denominator for unigram probabilities.
	private Map<Integer, NGramDistribution> nGrams;
	// totalNGramCounts[n] = total number of order-n observations (with repeats).
	private int[] totalNGramCounts = new int[MAX_N_GRAM_LENGTH + 1];

	// A token counts as a word if it contains at least one \w character.
	private Pattern wordPattern = Pattern.compile("\\w+");
	// Cache of tokens already rejected by isWord().
	private Set<String> nonWords = new HashSet<String>();
	// Vocabulary: training tokens seen again after their first (UNK-mapped) sighting.
	private Set<String> words = new HashSet<String>();
	// Every training token seen at least once.
	private Set<String> wordsSeenOnce = new HashSet<String>();

	/**
	 * Conditional probability P(last token | preceding tokens), computed as
	 * count(nGram) / count(nMinusOneGram). When the n-gram was never observed
	 * in training, backs off to the (n-1)-gram obtained by dropping the
	 * leading token.
	 *
	 * NOTE(review): if even the final unigram is unseen, the backoff recurses
	 * with an empty n-gram and subNGram(0, -1) throws
	 * IllegalArgumentException — callers must UNK-map unknown tokens first
	 * (calcPerplexity does this when useUnk is true).
	 */
	double getProb(NGram nGram, NGram nMinusOneGram) {
		NGramDistribution nGramDist = nGrams.get(nGram.size());
		double prob;
		if (nGramDist.containsKey(nGram)) {
			// impossible for the (n-1)-gram not to exist if the n-gram does
			NGramDistribution nMinusOneGramDist = nGrams.get(nGram.size() - 1);
			int nMinusOneGramCount = nMinusOneGramDist.get(nMinusOneGram);
			int nGramCount = nGramDist.get(nGram);
			prob = (double) (nGramCount) / nMinusOneGramCount;
		} else {
			// Laplace smoothing (disabled in favor of backoff):
			// prob = 1.0 / ((totalNGramCounts[nGram.size()] + 2.0));
			NGram backoff = nGram.subNGram(1);
			NGram backoffMinusOne = backoff.subNGram(0, backoff.size() - 1);
			return getProb(backoff, backoffMinusOne);
		}

		return prob;
	}

	/**
	 * Samples the next token given the (n-1)-token history, weighting each
	 * candidate n-gram by its training count. When no candidate is selected
	 * (e.g. the history has no observed continuation), recurses with a
	 * shorter history. Non-deterministic: uses Math.random().
	 */
	static String getRandomToken(Map<Integer, NGramDistribution> nGrams, int n,
			NGram nMinusOneGram) {
		List<Entry<NGram, Integer>> matchingNgrams = new ArrayList<Entry<NGram, Integer>>();
		NGramDistribution ngDist = nGrams.get(n);

		// Collect every order-n n-gram whose leading tokens equal the history.
		for (Entry<NGram, Integer> entry : ngDist.entrySet()) {
			if (entry.getKey().startsWith(nMinusOneGram)) {
				matchingNgrams.add(entry);
			}
		}

		int totalCount = 0;
		for (int i = 0; i < matchingNgrams.size(); i++) {
			totalCount += matchingNgrams.get(i).getValue();
		}
		double random = Math.random();
		// Uniform draw in [0, totalCount); each candidate owns a slice of the
		// range proportional to its count.
		int randomNGramIndex = (int) (random * totalCount);

		totalCount = 0;
		for (int i = 0; i < matchingNgrams.size(); i++) {
			int currentCount = matchingNgrams.get(i).getValue();
			if (randomNGramIndex < totalCount + currentCount) {
				NGram ngram = matchingNgrams.get(i).getKey();
				return ngram.get(ngram.size() - 1);
			}
			totalCount += currentCount;
		}
		// Nothing matched/selected: back off to a shorter history.
		return getRandomToken(nGrams, n - 1, nMinusOneGram.subNGram(1));
	}

	/** Creates an empty model with one distribution per order 0..3. */
	public NGramModel() {
		// initialize Maps of 0-grams, unigrams, bigrams, trigrams...
		nGrams = new HashMap<Integer, NGramDistribution>();
		for (int i = 0; i <= MAX_N_GRAM_LENGTH; i++) {
			nGrams.put(i, new NGramDistribution());
		}
	}

	/**
	 * Increments the count of the trailing {@code nGramLength} tokens of
	 * {@code nGram} in the distribution of that order.
	 *
	 * NOTE(review): when nGram.size() < nGramLength this only logs, and the
	 * following copy(nGramLength) then throws IllegalArgumentException for the
	 * same condition; current callers always pass an n-gram of exactly the
	 * requested length, so the guard never trips in practice.
	 */
	private void addNGram(int nGramLength, NGram nGram) {
		if (nGram.size() < nGramLength) {
			System.out.println("Cannot create " + nGramLength + "-gram from: "
					+ nGram);
		}

		Map<NGram, Integer> nGramCounts = nGrams.get(nGramLength);
		NGram nGramCopy = nGram.copy(nGramLength);

		if (nGramCounts.containsKey(nGramCopy)) {
			int nGramCount = nGramCounts.get(nGramCopy);
			nGramCounts.put(nGramCopy, nGramCount + 1);
		} else {
			nGramCounts.put(nGramCopy, 1);
		}
	}

	/**
	 * Normalizes a raw headline: listed punctuation becomes spaces, all other
	 * non-alphanumeric characters are removed, and the result is wrapped in
	 * the {@code <start>} and {@code <end>} sentinel tokens.
	 *
	 * @param rawHeadline headline text as stored
	 * @return sanitized, sentinel-wrapped text
	 */
	private String sanitize(String rawHeadline) {
		String nonpunctuatedHeadline = rawHeadline.replaceAll(
				"[\'\";:,\\]\\[]", " ");
		String alphaNumericHeadline = nonpunctuatedHeadline.replaceAll(
				"[^A-Za-z0-9 ]", "");
		return START + " " + alphaNumericHeadline + " " + END;
	}

	/**
	 * Scores the validation headlines with the trained model at order
	 * {@code nGramLimit} and prints corpus perplexity exp(-logProb / N).
	 *
	 * @param validationSet headlines held out from training
	 * @param nGramLimit model order to evaluate (1..MAX_N_GRAM_LENGTH)
	 * @param useUnk map out-of-vocabulary tokens to {@code <unk>}
	 * @throws FileNotFoundException never actually thrown; legacy declaration
	 * @throws IOException never actually thrown; legacy declaration
	 */
	private void calcPerplexity(List<Headline> validationSet, int nGramLimit,
			boolean useUnk) throws FileNotFoundException, IOException {
		List<String> fileByLines = new ArrayList<String>();
		StringBuilder currentLine = new StringBuilder();

		// Pass 1: rebuild the validation corpus as one sentence per line,
		// replacing out-of-vocabulary tokens with <unk>.
		for (Headline headline : validationSet) {
			String sanitizedLine = sanitize(headline.getText());
			// split on whitespace
			String[] tokens = sanitizedLine.toLowerCase().split("\\s+");
			for (String token : tokens) {
				if (!isWord(token)) {
					continue;
				}

				String word;
				if (!words.contains(token) && useUnk) {
					word = UNK;
					// words.add(token);
				} else {
					word = token;
				}

				// <end> closes the current sentence.
				if (END.equals(word)) {
					currentLine.append(word);
					fileByLines.add(currentLine.toString());
					currentLine = new StringBuilder();
				} else {
					currentLine.append(word);
					currentLine.append(" ");
				}
			}
		}

		int wordNum = 0;
		int sentNum = 0;

		// Pass 2: accumulate the total log-probability of the corpus.
		double logProbT = 0.0;
		for (String str : fileByLines) {
			double logProbS = 0.0;
			// Sliding window over the sentence, capped at nGramLimit tokens.
			NGram currentNgram = new NGram(nGramLimit);
			String[] tokens = str.split("\\s+");
			for (String token : tokens) {
				String word;
				if (!words.contains(token) && useUnk) {
					word = UNK;
				} else {
					word = token;
				}
				currentNgram.add(word);
				NGram nMinusOneGram = currentNgram.subNGram(0,
						currentNgram.size() - 1);

				double prob = getProb(currentNgram, nMinusOneGram);

				logProbS += Math.log(prob);
			}
			logProbT += logProbS;
			wordNum += tokens.length;
			sentNum++;
		}

		System.out.println("Evaluated " + sentNum + " sentences and " + wordNum
				+ " words using " + nGramLimit + "-gram model.");
		// NOTE(review): N adds one per sentence on top of the token count —
		// presumably accounting for an extra sentence-boundary event; confirm
		// against the intended perplexity definition.
		int N = wordNum + sentNum;
		System.out.println("Total n-grams: " + N);
		// Perplexity = exp(-(1/N) * total log-probability).
		double perplexity = Math.pow(Math.E, (-1.0 / N) * logProbT);
		System.out.println("Perplexity: over " + N
				+ " recognized n-grams in verification corpus: " + perplexity);
	}

	/**
	 * Trains this model: tokenizes the training headlines (the first sighting
	 * of each token becomes {@code <unk>} when useUnk is set), accumulates
	 * n-gram counts for every order up to MAX_N_GRAM_LENGTH, prints the ten
	 * most frequent uni/bi/trigrams, and optionally emits random sentences.
	 *
	 * @param traininSet training headlines (parameter name kept as-is)
	 * @param genRandom also generate RANDOM_LENGTH-token random sentences
	 * @param useUnk enable the first-occurrence {@code <unk>} substitution
	 * @throws FileNotFoundException never actually thrown; legacy declaration
	 * @throws IOException never actually thrown; legacy declaration
	 */
	private void generateModel(List<Headline> traininSet, boolean genRandom,
			boolean useUnk) throws FileNotFoundException, IOException {
		StringBuilder currentLine = new StringBuilder();
		List<String> fileByLines = new ArrayList<String>();

		// Pass 1: build the training corpus, one sentence per line.
		for (Headline headline : traininSet) {
			String headlineText = headline.getText();
			if (headlineText.length() == 0) {
				continue;
			}
			String sanitizedLine = sanitize(headline.getText());
			// split on whitespace
			String[] tokens = sanitizedLine.toLowerCase().split("\\s+");
			for (String token : tokens) {
				if (!isWord(token)) {
					continue;
				}

				String word;
				// First sighting trains <unk>; later sightings train the
				// token itself and add it to the vocabulary.
				if (!wordsSeenOnce.contains(token) && useUnk) {
					word = UNK;
					wordsSeenOnce.add(token);
				} else {
					words.add(token);
					word = token;
				}

				if (END.equals(word)) {
					currentLine.append(word);
					fileByLines.add(currentLine.toString());
					currentLine = new StringBuilder();
				} else {
					currentLine.append(word);
					currentLine.append(" ");
				}
			}
		}

		// Pass 2: slide a MAX_N_GRAM_LENGTH window over each sentence and
		// register every suffix of the window, from the full window down to
		// the 0-gram (whose count becomes the unigram denominator in getProb).
		for (String str : fileByLines) {
			System.out.println(str);
			NGram currentNgram = new NGram(MAX_N_GRAM_LENGTH);
			for (String token : str.split("\\s+")) {
				currentNgram.add(token);
				for (int i = 0; i <= currentNgram.size(); i++) {
					addNGram(currentNgram.size() - i,
							currentNgram.subNGram(i, currentNgram.size()));
					totalNGramCounts[currentNgram.size() - i]++;
				}
			}
		}

		System.out.println("Most common words: ");

		// Sort each order's entries by descending count for the top-10 report.
		List<Entry<NGram, Integer>> unigrams = new ArrayList<Entry<NGram, Integer>>(
				nGrams.get(1).entrySet());
		Collections.sort(unigrams, new NGramComparator());

		List<Entry<NGram, Integer>> bigrams = new ArrayList<Entry<NGram, Integer>>(
				nGrams.get(2).entrySet());
		Collections.sort(bigrams, new NGramComparator());

		List<Entry<NGram, Integer>> trigrams = new ArrayList<Entry<NGram, Integer>>(
				nGrams.get(3).entrySet());
		Collections.sort(trigrams, new NGramComparator());

		for (int i = 1; i <= 10; i++) {
			System.out
					.println(i
							+ ". "
							+ unigrams.get(i - 1).getKey()
							+ " : "
							+ (((double) (unigrams.get(i - 1).getValue()) / totalNGramCounts[1])));
		}

		for (int i = 1; i <= 10; i++) {
			System.out
					.println(i
							+ ". "
							+ bigrams.get(i - 1).getKey()
							+ " : "
							+ (((double) (bigrams.get(i - 1).getValue()) / totalNGramCounts[2])));
		}

		for (int i = 1; i <= 10; i++) {
			System.out
					.println(i
							+ ". "
							+ trigrams.get(i - 1).getKey()
							+ " : "
							+ (((double) (trigrams.get(i - 1).getValue()) / totalNGramCounts[3])));
		}
		if (genRandom) {
			// Emit one random RANDOM_LENGTH-token sentence per model order.
			for (int nGramLength = 1; nGramLength <= MAX_N_GRAM_LENGTH; nGramLength++) {
				System.out.println("Random sentence of length " + RANDOM_LENGTH
						+ " using " + nGramLength + "-gram language model:");
				StringBuilder randomText = new StringBuilder();

				NGram randomNminusOneGram = new NGram(nGramLength);
				for (int i = 0; i < RANDOM_LENGTH; i++) {
					String randomToken = getRandomToken(nGrams,
							randomNminusOneGram.size() + 1, randomNminusOneGram);
					NGram randomNGram = randomNminusOneGram.copy();
					randomNGram.add(randomToken);
					randomText.append(randomNGram.get(randomNGram.size() - 1));
					randomText.append(" ");

					// Grow the history until it reaches nGramLength, then slide.
					if (randomNGram.size() < nGramLength) {
						randomNminusOneGram = randomNGram;
					} else {
						randomNminusOneGram = randomNGram.subNGram(1);
					}
				}
				System.out.println(randomText);
				System.out.println();
			}
		}
	}

	/**
	 * True when the token is a sentinel or contains at least one word
	 * character. Rejections are cached in {@link #nonWords} so the regex runs
	 * at most once per distinct non-word token.
	 */
	private boolean isWord(String aString) {
		if (aString == null || aString.length() == 0) {
			return false;
		}

		if (nonWords.contains(aString)) {
			return false;
		}

		if (START.equals(aString) || END.equals(aString) || UNK.equals(aString)) {
			return true;
		}

		Matcher wordMatcher = wordPattern.matcher(aString);
		if (wordMatcher.find()) {
			return true;
		}
		nonWords.add(aString);
		return false;
	}

	/**
	 * Trains a model on {@code trainingSet} and reports its perplexity over
	 * {@code validationSet} for each order 1..MAX_N_GRAM_LENGTH.
	 *
	 * NOTE(review): trains a brand-new NGramModel instance rather than this
	 * one, so this object's own state is never populated by reportModel.
	 */
	public void reportModel(List<Headline> trainingSet,
			List<Headline> validationSet) {
		try {
			NGramModel ngm = new NGramModel();
			boolean doCalcPerplexity = true;
			// genRandom and doCalcPerplexity are mutually exclusive here.
			ngm.generateModel(trainingSet, !doCalcPerplexity, doCalcPerplexity);
			if (doCalcPerplexity) {
				for (int i = 1; i <= MAX_N_GRAM_LENGTH; i++) {
					ngm.calcPerplexity(validationSet, i, true);
				}
			}
		} catch (FileNotFoundException fnfe) {
			fnfe.printStackTrace(System.err);
		} catch (IOException ioe) {
			ioe.printStackTrace(System.err);
		}
	}

	/** Sorts n-gram map entries by descending count. */
	private class NGramComparator implements Comparator<Entry<NGram, Integer>> {

		@Override
		public int compare(Entry<NGram, Integer> o1, Entry<NGram, Integer> o2) {
			return o2.getValue().compareTo(o1.getValue());
		}

	}
}
|
||||||
@@ -9,4 +9,5 @@ public interface HeadlineService {
|
|||||||
int insertHeadline(Headline headline);
|
int insertHeadline(Headline headline);
|
||||||
int[] insertHeadlines(List<Headline> headline);
|
int[] insertHeadlines(List<Headline> headline);
|
||||||
List<Headline> getHeadlines(String stock, Date date);
|
List<Headline> getHeadlines(String stock, Date date);
|
||||||
|
List<Headline> getHeadlines(String stock, Date startDate, Date endDate);
|
||||||
}
|
}
|
||||||
@@ -32,4 +32,9 @@ public class MySQLHeadlineServiceImpl implements HeadlineService {
|
|||||||
public List<Headline> getHeadlines(String stock, Date date) {
|
public List<Headline> getHeadlines(String stock, Date date) {
|
||||||
return headlineDao.select(stock, date);
|
return headlineDao.select(stock, date);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
	/**
	 * Fetches every headline for {@code stock} dated within the inclusive
	 * range [startDate, endDate], delegating directly to the DAO.
	 */
	@Override
	public List<Headline> getHeadlines(String stock, Date startDate, Date endDate) {
		return headlineDao.select(stock, startDate, endDate);
	}
|
||||||
}
|
}
|
||||||
@@ -86,4 +86,10 @@ public class YahooHeadlineServiceImpl implements HeadlineService {
|
|||||||
String formattedDate = DATE_FORMATTER.format(date);
|
String formattedDate = DATE_FORMATTER.format(date);
|
||||||
return QUERY_URL.replaceAll(STOCK_SYMBOL_FIELD, stock).replaceAll(STORY_DATE_FIELD, formattedDate);
|
return QUERY_URL.replaceAll(STOCK_SYMBOL_FIELD, stock).replaceAll(STORY_DATE_FIELD, formattedDate);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
	/**
	 * Date-range lookups are not supported by the Yahoo-backed implementation;
	 * it can only fetch headlines for a single day.
	 *
	 * @throws UnsupportedOperationException always
	 */
	@Override
	public List<Headline> getHeadlines(String stock, Date startDate,
			Date endDate) {
		throw new UnsupportedOperationException("This implementation does not support getting headlines for a date range.");
	}
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user