Added script to pull historical stock data and resulting data files (1 per company). Added code to generate average price change per 1, 2 and 3-gram. Added code to output average price change per headline for VALIDATION dataset.
This commit is contained in:
@@ -35,10 +35,10 @@ public class HeadlinePuller {
|
||||
private static final int NO_ARGS = 5;
|
||||
private static final int STOCK_SYMBOL_CSV_NOT_FOUND = 6;
|
||||
|
||||
@Autowired
|
||||
HeadlineService mySQLHeadlineServiceImpl;
|
||||
@Autowired
|
||||
HeadlineService yahooHeadlineServiceImpl;
|
||||
//@Autowired
|
||||
//HeadlineService mySQLHeadlineServiceImpl;
|
||||
//@Autowired
|
||||
//HeadlineService yahooHeadlineServiceImpl;
|
||||
|
||||
private static void printUsage() {
|
||||
System.out
|
||||
@@ -115,10 +115,10 @@ public class HeadlinePuller {
|
||||
for (calendar.setTime(startDate); (today = calendar.getTime())
|
||||
.compareTo(endDate) <= 0; calendar
|
||||
.add(Calendar.DATE, 1)) {
|
||||
List<Headline> headlines = headlinePuller.pullHeadlines(
|
||||
company.getStockSymbol(), today);
|
||||
int[] updates = headlinePuller.mySQLHeadlineServiceImpl.insertHeadlines(headlines);
|
||||
System.out.println(updates.length + " rows updated");
|
||||
//List<Headline> headlines = headlinePuller.pullHeadlines(
|
||||
// company.getStockSymbol(), today);
|
||||
//int[] updates = headlinePuller.mySQLHeadlineServiceImpl.insertHeadlines(headlines);
|
||||
//System.out.println(updates.length + " rows updated");
|
||||
}
|
||||
}
|
||||
} catch (FileNotFoundException fnfe) {
|
||||
@@ -132,12 +132,12 @@ public class HeadlinePuller {
|
||||
}
|
||||
}
|
||||
|
||||
private List<Headline> pullHeadlines(String stockSymbol, Date date) {
|
||||
List<Headline> headlines = yahooHeadlineServiceImpl.getHeadlines(
|
||||
stockSymbol, date);
|
||||
System.out.println("Pulled " + headlines.size() + " headlines for " + stockSymbol + " on " + date);
|
||||
return headlines;
|
||||
}
|
||||
//private List<Headline> pullHeadlines(String stockSymbol, Date date) {
|
||||
//List<Headline> headlines = yahooHeadlineServiceImpl.getHeadlines(
|
||||
// stockSymbol, date);
|
||||
//System.out.println("Pulled " + headlines.size() + " headlines for " + stockSymbol + " on " + date);
|
||||
//return headlines;
|
||||
//}
|
||||
|
||||
private List<Company> getFortune50(File csvFile)
|
||||
throws FileNotFoundException, IOException {
|
||||
|
||||
@@ -45,39 +45,30 @@ public class ModelGenerator {
|
||||
|
||||
Date startDate = null;
|
||||
Date endDate = null;
|
||||
Date valStart = null;
|
||||
Date valEnd = null;
|
||||
try {
|
||||
startDate = dateFmt.parse("2012-01-01");
|
||||
endDate = dateFmt.parse("2012-03-31");
|
||||
valStart = dateFmt.parse("2012-04-01");
|
||||
valEnd = dateFmt.parse("2012-04-14");
|
||||
endDate = dateFmt.parse("2012-04-14");
|
||||
} catch (ParseException pe) {
|
||||
System.exit(INVALID_DATE);
|
||||
}
|
||||
|
||||
List<Headline> trainingSet = new ArrayList<Headline>();
|
||||
//actually, this is the TEST dataset
|
||||
List<Headline> testSet = new ArrayList<Headline>();
|
||||
|
||||
try {
|
||||
List<Company> fortune50 = modelGenerator
|
||||
.getFortune50(stockSymbolsCSV);
|
||||
for (Company company : fortune50) {
|
||||
System.out.println("Getting headlines for Fortune 50 company #"
|
||||
+ company.getId() + " (" + company.getName() + ")...");
|
||||
List<Headline> trainingSet = modelGenerator.mySQLHeadlineServiceImpl.getHeadlines(company.getStockSymbol(), startDate, endDate);
|
||||
System.out.println("Pulled " + trainingSet.size() + " headlines for "
|
||||
List<Headline> coTrainingSet = modelGenerator.mySQLHeadlineServiceImpl.getHeadlines(company.getStockSymbol(), startDate, endDate, 1);
|
||||
System.out.println("Pulled " + coTrainingSet.size() + " headlines for "
|
||||
+ company.getStockSymbol() + " from " + startDate + " to " + endDate);
|
||||
List<Headline> validationSet = modelGenerator.mySQLHeadlineServiceImpl.getHeadlines(company.getStockSymbol(), valStart, valEnd);
|
||||
List<Headline> coTestSet = modelGenerator.mySQLHeadlineServiceImpl.getHeadlines(company.getStockSymbol(), startDate, endDate, 2);
|
||||
|
||||
if (trainingSet.size() == 0) {
|
||||
System.out.println("Training dataset contains 0 headlines for " + company.getName() + ", skipping model generation.");
|
||||
continue;
|
||||
}
|
||||
if (validationSet.size() == 0) {
|
||||
System.out.println("Validation dataset contains 0 headlines for " + company.getName() + ", skipping model generation.");
|
||||
continue;
|
||||
}
|
||||
|
||||
modelGenerator.ngramModel.reportModel(trainingSet, validationSet);
|
||||
System.out.println("Finished " + company.getId() + " / 50");
|
||||
trainingSet.addAll(coTrainingSet);
|
||||
testSet.addAll(coTestSet);
|
||||
}
|
||||
} catch (FileNotFoundException fnfe) {
|
||||
System.out.println("Stock symbol CSV file does not exist: "
|
||||
@@ -88,6 +79,8 @@ public class ModelGenerator {
|
||||
+ stockSymbolsCSV);
|
||||
System.exit(IO_EXCEPTION);
|
||||
}
|
||||
|
||||
//modelGenerator.ngramModel.reportModel(trainingSet, testSet);
|
||||
}
|
||||
|
||||
private List<Company> getFortune50(File csvFile)
|
||||
|
||||
180
src/net/woodyfolsom/cs6601/p3/PricePoller.java
Normal file
180
src/net/woodyfolsom/cs6601/p3/PricePoller.java
Normal file
@@ -0,0 +1,180 @@
|
||||
package net.woodyfolsom.cs6601.p3;
|
||||
|
||||
import java.io.BufferedReader;
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.FileNotFoundException;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStreamReader;
|
||||
import java.text.DateFormat;
|
||||
import java.text.ParseException;
|
||||
import java.text.SimpleDateFormat;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Date;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.context.ApplicationContext;
|
||||
import org.springframework.context.support.FileSystemXmlApplicationContext;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import net.woodyfolsom.cs6601.p3.domain.Company;
|
||||
import net.woodyfolsom.cs6601.p3.domain.Headline;
|
||||
import net.woodyfolsom.cs6601.p3.domain.StockPrice;
|
||||
import net.woodyfolsom.cs6601.p3.ngram.NGramModel;
|
||||
import net.woodyfolsom.cs6601.p3.svc.HeadlineService;
|
||||
|
||||
@Component
public class PricePoller {
	/** CSV listing the companies to process; one line per company: id,name,symbol. */
	private static final File stockSymbolsCSV = new File("stock_symbols.csv");

	// Process exit codes.
	private static final int INVALID_DATE = 1;
	private static final int IO_EXCEPTION = 2;
	private static final int STOCK_SYMBOL_CSV_NOT_FOUND = 3;

	// MySQL-backed headline store; injected by Spring from AppContext.xml.
	@Autowired
	HeadlineService mySQLHeadlineServiceImpl;

	private NGramModel ngramModel = new NGramModel();

	/**
	 * Entry point. Loads each company's historical price file from
	 * data/&lt;SYMBOL&gt;.txt into an in-memory symbol -> (date -> StockPrice)
	 * map, pulls the TRAINING (dataset 1) and TEST (dataset 2) headline sets
	 * from the database for the fixed 2012-01-01..2012-04-14 window, and hands
	 * all three to NGramModel.reportModel.
	 */
	public static void main(String... args) {

		ApplicationContext context = new FileSystemXmlApplicationContext(
				new String[] { "AppContext.xml" });
		PricePoller modelGenerator = context.getBean(PricePoller.class);
		DateFormat dateFmt = new SimpleDateFormat("yyyy-MM-dd");

		// Fixed study window; price rows outside it are skipped.
		Date startDate = null;
		Date endDate = null;
		try {
			startDate = dateFmt.parse("2012-01-01");
			endDate = dateFmt.parse("2012-04-14");
		} catch (ParseException pe) {
			System.exit(INVALID_DATE);
		}

		List<Headline> trainingSet = new ArrayList<Headline>();
		//actually, this is the TEST dataset
		List<Headline> testSet = new ArrayList<Headline>();
		// symbol -> (trading day -> that day's price record)
		Map<String,Map<Date,StockPrice>> stockTrends = new HashMap<String,Map<Date,StockPrice>>();
		try {
			List<Company> fortune50 = modelGenerator
					.getFortune50(stockSymbolsCSV);
			for (Company company : fortune50) {
				// Entry is created even if the price file turns out to be missing.
				stockTrends.put(company.getStockSymbol(), new HashMap<Date,StockPrice>());
				System.out.println("Polling price data for " + company.getName());
				File stockPriceFile = new File("data" + File.separator + company.getStockSymbol() + ".txt");
				BufferedReader buf;
				try {
					buf = new BufferedReader(new InputStreamReader(new FileInputStream(stockPriceFile)));
				} catch (FileNotFoundException fnfe) {
					// Missing price file is non-fatal: skip this company.
					System.out.println("Unable to find historical stock data file for: " + company.getStockSymbol());
					continue;
				}
				String line;
				int linesRead = 0;
				try {
					while ((line = buf.readLine()) != null) {
						linesRead++;
						if (linesRead == 1) {
							continue; // header line
						}
						// Expected columns: Date,Open,High,Low,Close,Volume,Adj_Close
						String[] fields = line.trim().split(",");
						Date date;
						try {
							date = dateFmt.parse(fields[0]);
						} catch (ParseException pe) {
							System.out.println("Error parsing date: " + fields[0]);
							continue;
						}
						if (date.compareTo(endDate) > 0) {
							continue;
						}
						// NOTE(review): breaking (not continuing) here assumes the
						// file is sorted newest-first — confirm against the script
						// that downloads these files.
						if (date.compareTo(startDate) < 0) {
							break;
						}

						double open;
						double high;
						double low;
						double close;
						long volume;
						double adjClose;

						try {
							open = Double.parseDouble(fields[1]);
							high = Double.parseDouble(fields[2]);
							low = Double.parseDouble(fields[3]);
							close = Double.parseDouble(fields[4]);
							volume = Long.parseLong(fields[5]);
							adjClose = Double.parseDouble(fields[6]);
						} catch (NumberFormatException nfe) {
							// Malformed numeric field: skip just this row.
							System.out.println(nfe.getMessage());
							continue;
						}

						StockPrice stockPrice = new StockPrice(date,open,high,low,close,volume,adjClose);

						stockTrends.get(company.getStockSymbol()).put(date,stockPrice);
					}
				} catch (IOException ioe) {
					// NOTE(review): this continue skips the buf.close() below,
					// leaking the reader on a read error.
					System.err.println(ioe.getMessage());
					continue;
				}
				try {
					buf.close();
				} catch (IOException ioe) {
					System.err.println(ioe.getMessage());
				}
			}
			// Second pass: pull TRAINING (dataset 1) and TEST (dataset 2)
			// headlines for every company over the same window.
			for (Company company : fortune50) {
				System.out.println("Getting headlines for Fortune 50 company #"
						+ company.getId() + " (" + company.getName() + ")...");
				List<Headline> coTrainingSet = modelGenerator.mySQLHeadlineServiceImpl.getHeadlines(company.getStockSymbol(), startDate, endDate, 1);
				System.out.println("Pulled " + coTrainingSet.size() + " TRAINING headlines for "
						+ company.getStockSymbol() + " from " + startDate + " to " + endDate);
				List<Headline> coTestSet = modelGenerator.mySQLHeadlineServiceImpl.getHeadlines(company.getStockSymbol(), startDate, endDate, 2);
				System.out.println("Pulled " + coTestSet.size() + " TEST headlines for "
						+ company.getStockSymbol() + " from " + startDate + " to " + endDate);

				trainingSet.addAll(coTrainingSet);
				testSet.addAll(coTestSet);
			}
		} catch (FileNotFoundException fnfe) {
			System.out.println("Stock symbol CSV file does not exist: "
					+ stockSymbolsCSV);
			System.exit(STOCK_SYMBOL_CSV_NOT_FOUND);
		} catch (IOException ioe) {
			// NOTE(review): message is misleading — this is a generic read
			// error, not a missing file.
			System.out.println("Stock symbol CSV file does not exist: "
					+ stockSymbolsCSV);
			System.exit(IO_EXCEPTION);
		}

		modelGenerator.ngramModel.reportModel(trainingSet, testSet, stockTrends);
	}

	/**
	 * Parses the stock-symbol CSV into Company records. Blank lines are
	 * skipped; every other line must have exactly 3 comma-separated fields
	 * (id, name, symbol).
	 *
	 * @param csvFile CSV file to read
	 * @return one Company per non-blank line
	 * @throws FileNotFoundException if the file does not exist
	 * @throws IOException on read failure
	 * @throws RuntimeException on a line without exactly 3 fields
	 */
	// NOTE(review): the reader is never closed.
	private List<Company> getFortune50(File csvFile)
			throws FileNotFoundException, IOException {
		List<Company> fortune50 = new ArrayList<Company>();
		FileInputStream fis = new FileInputStream(csvFile);
		InputStreamReader reader = new InputStreamReader(fis);
		BufferedReader buf = new BufferedReader(reader);
		String csvline = null;
		while ((csvline = buf.readLine()) != null) {
			if (csvline.length() == 0) {
				continue;
			}
			String[] fields = csvline.split(",");
			if (fields.length != 3) {
				throw new RuntimeException(
						"Badly formatted csv file name (3 values expected): "
								+ csvline);
			}
			int id = Integer.valueOf(fields[0]);
			fortune50.add(new Company(id, fields[1], fields[2]));
		}
		return fortune50;
	}
}
|
||||
25
src/net/woodyfolsom/cs6601/p3/StockUtil.java
Normal file
25
src/net/woodyfolsom/cs6601/p3/StockUtil.java
Normal file
@@ -0,0 +1,25 @@
|
||||
package net.woodyfolsom.cs6601.p3;
|
||||
|
||||
import java.util.Calendar;
|
||||
import java.util.Date;
|
||||
|
||||
import net.woodyfolsom.cs6601.p3.domain.StockPrice;
|
||||
|
||||
public class StockUtil {
|
||||
public static Date getNextTradingDay(Date date) {
|
||||
Calendar cal = Calendar.getInstance();
|
||||
cal.setTime(date);
|
||||
do {
|
||||
cal.add(Calendar.DATE, 1);
|
||||
} while (cal.get(Calendar.DAY_OF_WEEK) == 1 || cal.get(Calendar.DAY_OF_WEEK) == 7);
|
||||
return cal.getTime();
|
||||
}
|
||||
|
||||
public static double getPercentChange(StockPrice stockPrice) {
|
||||
double close = stockPrice.getClose();
|
||||
double open = stockPrice.getOpen();
|
||||
//If close is 2x open, pct change is 1.0;
|
||||
//If close is 0.9 * open, pct change is -0.10;
|
||||
return 100.0 * ((close / open) - 1.00);
|
||||
}
|
||||
}
|
||||
256
src/net/woodyfolsom/cs6601/p3/ValidationSetCreator.java
Normal file
256
src/net/woodyfolsom/cs6601/p3/ValidationSetCreator.java
Normal file
@@ -0,0 +1,256 @@
|
||||
package net.woodyfolsom.cs6601.p3;
|
||||
|
||||
import java.io.BufferedReader;
|
||||
import java.io.BufferedWriter;
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.FileNotFoundException;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStreamReader;
|
||||
import java.io.OutputStreamWriter;
|
||||
import java.text.DateFormat;
|
||||
import java.text.DecimalFormat;
|
||||
import java.text.ParseException;
|
||||
import java.text.SimpleDateFormat;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Date;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.context.ApplicationContext;
|
||||
import org.springframework.context.support.FileSystemXmlApplicationContext;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import net.woodyfolsom.cs6601.p3.domain.Company;
|
||||
import net.woodyfolsom.cs6601.p3.domain.Headline;
|
||||
import net.woodyfolsom.cs6601.p3.domain.StockPrice;
|
||||
import net.woodyfolsom.cs6601.p3.ngram.NGram;
|
||||
import net.woodyfolsom.cs6601.p3.ngram.NGramModel;
|
||||
import net.woodyfolsom.cs6601.p3.svc.HeadlineService;
|
||||
|
||||
@Component
public class ValidationSetCreator {
	/** CSV listing the companies to process; one line per company: id,name,symbol. */
	private static final File stockSymbolsCSV = new File("stock_symbols.csv");

	// Process exit codes.
	private static final int INVALID_DATE = 1;
	private static final int IO_EXCEPTION = 2;
	private static final int STOCK_SYMBOL_CSV_NOT_FOUND = 3;

	// MySQL-backed headline store; injected by Spring from AppContext.xml.
	@Autowired
	HeadlineService mySQLHeadlineServiceImpl;

	/**
	 * Entry point. Loads per-company price history from data/&lt;SYMBOL&gt;.txt,
	 * pulls the VALIDATION (dataset 3) headlines for 2012-01-01..2012-04-14,
	 * averages the next-trading-day percent price change per distinct headline
	 * text, and writes one line per distinct headline to validation.txt as:
	 * id, stock, date, avgPctChange, sanitizedText
	 */
	public static void main(String... args) {

		ApplicationContext context = new FileSystemXmlApplicationContext(
				new String[] { "AppContext.xml" });
		ValidationSetCreator modelGenerator = context
				.getBean(ValidationSetCreator.class);
		DateFormat dateFmt = new SimpleDateFormat("yyyy-MM-dd");

		// Fixed study window; price rows outside it are skipped.
		Date startDate = null;
		Date endDate = null;
		try {
			startDate = dateFmt.parse("2012-01-01");
			endDate = dateFmt.parse("2012-04-14");
		} catch (ParseException pe) {
			System.exit(INVALID_DATE);
		}

		// symbol -> (trading day -> that day's price record)
		Map<String, Map<Date, StockPrice>> stockTrends = new HashMap<String, Map<Date, StockPrice>>();
		try {
			List<Company> fortune50 = modelGenerator
					.getFortune50(stockSymbolsCSV);
			for (Company company : fortune50) {
				// Entry is created even if the price file turns out to be missing.
				stockTrends.put(company.getStockSymbol(),
						new HashMap<Date, StockPrice>());
				System.out.println("Polling price data for "
						+ company.getName());
				File stockPriceFile = new File("data" + File.separator
						+ company.getStockSymbol() + ".txt");
				BufferedReader buf;
				try {
					buf = new BufferedReader(new InputStreamReader(
							new FileInputStream(stockPriceFile)));
				} catch (FileNotFoundException fnfe) {
					// Missing price file is non-fatal: skip this company.
					System.out
							.println("Unable to find historical stock data file for: "
									+ company.getStockSymbol());
					continue;
				}
				String line;
				int linesRead = 0;
				try {
					while ((line = buf.readLine()) != null) {
						linesRead++;
						if (linesRead == 1) {
							continue; // header line
						}
						// Expected columns: Date,Open,High,Low,Close,Volume,Adj_Close
						String[] fields = line.trim().split(",");
						Date date;
						try {
							date = dateFmt.parse(fields[0]);
						} catch (ParseException pe) {
							System.out.println("Error parsing date: "
									+ fields[0]);
							continue;
						}
						if (date.compareTo(endDate) > 0) {
							continue;
						}
						// NOTE(review): breaking (not continuing) here assumes
						// the file is sorted newest-first — confirm against the
						// download script.
						if (date.compareTo(startDate) < 0) {
							break;
						}

						double open;
						double high;
						double low;
						double close;
						long volume;
						double adjClose;

						try {
							open = Double.parseDouble(fields[1]);
							high = Double.parseDouble(fields[2]);
							low = Double.parseDouble(fields[3]);
							close = Double.parseDouble(fields[4]);
							volume = Long.parseLong(fields[5]);
							adjClose = Double.parseDouble(fields[6]);
						} catch (NumberFormatException nfe) {
							// Malformed numeric field: skip just this row.
							System.out.println(nfe.getMessage());
							continue;
						}

						StockPrice stockPrice = new StockPrice(date, open,
								high, low, close, volume, adjClose);

						stockTrends.get(company.getStockSymbol()).put(date,
								stockPrice);
					}
				} catch (IOException ioe) {
					// NOTE(review): this continue skips the buf.close() below,
					// leaking the reader on a read error.
					System.err.println(ioe.getMessage());
					continue;
				}
				try {
					buf.close();
				} catch (IOException ioe) {
					System.err.println(ioe.getMessage());
				}
			}
			// Pull the VALIDATION (dataset 3) headlines for every company.
			List<Headline> valSet = new ArrayList<Headline>();
			for (Company company : fortune50) {
				System.out.println("Getting headlines for Fortune 50 company #"
						+ company.getId() + " (" + company.getName() + ")...");
				List<Headline> coValSet = modelGenerator.mySQLHeadlineServiceImpl
						.getHeadlines(company.getStockSymbol(), startDate,
								endDate, 3);
				System.out.println("Pulled " + coValSet.size()
						+ " VALIDATION headlines for "
						+ company.getStockSymbol() + " from " + startDate
						+ " to " + endDate);

				valSet.addAll(coValSet);
			}

			File file = new File("validation.txt");
			// NOTE(review): writer is only closed on the happy path; an
			// exception during the loops below leaks it.
			BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(
					new FileOutputStream(file)));
			// Per distinct headline text: number of occurrences, and the sum of
			// next-trading-day percent price changes across those occurrences.
			Map<String,Integer> headlineCount = new HashMap<String,Integer>();
			Map<String,Double> totalPctChange = new HashMap<String,Double>();

			// Pass 1: accumulate count and total percent change per text.
			for (Headline headline : valSet) {
				String text = headline.getText();
				Integer count = headlineCount.get(text);
				Double pctChange = totalPctChange.get(text);
				Date date = headline.getDate();
				String stock = headline.getStock();
				// NOTE(review): assumes every headline's stock symbol has an
				// entry in stockTrends (true when all headline stocks appear in
				// the fortune50 list) — otherwise this get(...) NPEs.
				StockPrice stockPrice = stockTrends.get(stock).get(
						StockUtil.getNextTradingDay(date));
				double pctPriceChange;
				if (stockPrice == null) {
					// No price record for the next trading day: count as 0 change.
					pctPriceChange = 0.0;
				} else {
					pctPriceChange = StockUtil.getPercentChange(stockPrice);
				}
				if (count == null) {
					headlineCount.put(text, 1);
					totalPctChange.put(text, pctPriceChange);
				} else {
					headlineCount.put(text, count+1);
					totalPctChange.put(text, pctChange + pctPriceChange);
				}
			}

			// Pass 2: emit one line per distinct headline text, keeping the
			// id/stock/date of its first occurrence.
			Set<String> processedSet = new HashSet<String>();
			DecimalFormat decFmt = new DecimalFormat("###0.0000");
			for (Headline headline : valSet) {
				String text = headline.getText();
				if (processedSet.contains(text)) {
					continue;
				}
				processedSet.add(text);
				int id = headline.getId();
				String stock = headline.getStock();
				Date date = headline.getDate();
				String dateFormatted = dateFmt.format(date);

				double totalPriceChange = totalPctChange.get(text);
				int totalCount = headlineCount.get(text);
				StringBuilder sb = new StringBuilder();
				sb.append(id);
				sb.append(", ");
				sb.append(stock);
				sb.append(", ");
				sb.append(dateFormatted);
				sb.append(", ");
				// Average change = accumulated total / occurrence count.
				sb.append(decFmt.format(totalPriceChange/totalCount));
				sb.append(", ");

				// Strip quoting/punctuation, then drop any remaining
				// non-alphanumeric characters from the headline text.
				text = text.replaceAll(
						"[\'\";:,\\]\\[]", " ");
				text = text.replaceAll(
						"[^A-Za-z0-9 ]", "");
				sb.append(text);

				sb.append("\n");
				writer.write(sb.toString());
			}
			writer.close();
		} catch (FileNotFoundException fnfe) {
			System.out.println("Stock symbol CSV file does not exist: "
					+ stockSymbolsCSV);
			System.exit(STOCK_SYMBOL_CSV_NOT_FOUND);
		} catch (IOException ioe) {
			// NOTE(review): message is misleading — this is a generic read/write
			// error, not a missing file.
			System.out.println("Stock symbol CSV file does not exist: "
					+ stockSymbolsCSV);
			System.exit(IO_EXCEPTION);
		}
	}

	/**
	 * Parses the stock-symbol CSV into Company records. Blank lines are
	 * skipped; every other line must have exactly 3 comma-separated fields
	 * (id, name, symbol).
	 *
	 * @param csvFile CSV file to read
	 * @return one Company per non-blank line
	 * @throws FileNotFoundException if the file does not exist
	 * @throws IOException on read failure
	 * @throws RuntimeException on a line without exactly 3 fields
	 */
	// NOTE(review): the reader is never closed; duplicated verbatim in
	// PricePoller — candidate for a shared helper.
	private List<Company> getFortune50(File csvFile)
			throws FileNotFoundException, IOException {
		List<Company> fortune50 = new ArrayList<Company>();
		FileInputStream fis = new FileInputStream(csvFile);
		InputStreamReader reader = new InputStreamReader(fis);
		BufferedReader buf = new BufferedReader(reader);
		String csvline = null;
		while ((csvline = buf.readLine()) != null) {
			if (csvline.length() == 0) {
				continue;
			}
			String[] fields = csvline.split(",");
			if (fields.length != 3) {
				throw new RuntimeException(
						"Badly formatted csv file name (3 values expected): "
								+ csvline);
			}
			int id = Integer.valueOf(fields[0]);
			fortune50.add(new Company(id, fields[1], fields[2]));
		}
		return fortune50;
	}
}
|
||||
@@ -6,12 +6,14 @@ import java.util.List;
|
||||
import net.woodyfolsom.cs6601.p3.domain.Headline;
|
||||
|
||||
/** Data-access contract for the headlines table. */
public interface HeadlineDao {

	/**
	 * Randomly partitions all headline rows into datasets 1 (training),
	 * 2 (test) and 3 (validation) by the given percentages.
	 * @return false when the percentages do not sum to 100
	 */
	boolean assignRandomDatasets(int training, int test, int validation);
	/** @return total number of stored headlines. */
	int getCount();
	/** @return number of headlines assigned to the given dataset id. */
	int getCount(int dataset);
	/** @return number of rows deleted (0 or 1). */
	int deleteById(int id);
	/** @return number of rows inserted. */
	int insert(Headline headline);
	/** @return per-statement update counts for the batch. */
	int[] insertBatch(List<Headline> headlines);

	Headline select(int id);
	List<Headline> select(String stock, Date date);
	// NOTE(review): the implementation appears to have replaced the 3-arg
	// range select with the dataset-filtered 4-arg overload — confirm both
	// overloads still have backing implementations.
	List<Headline> select(String stock, Date startDate, Date endDate);
	List<Headline> select(String stock, Date startDate, Date endDate, int dataset);
}
|
||||
|
||||
@@ -18,16 +18,36 @@ import net.woodyfolsom.cs6601.p3.domain.Headline;
|
||||
|
||||
@Repository
|
||||
public class HeadlineDaoImpl implements HeadlineDao {
|
||||
private static final String COUNT_ALL_QRY = "SELECT COUNT(1) FROM headlines";
|
||||
private static final String COUNT_DATASET_QRY = "SELECT COUNT(1) FROM headlines where dataset = ?";
|
||||
|
||||
private static final String DELETE_BY_ID_STMT = "DELETE from headlines WHERE id = ?";
|
||||
|
||||
private static final String INSERT_STMT = "INSERT INTO headlines (text, date, stock, dataset) values (?, ?, ?, ?)";
|
||||
|
||||
private static final String SELECT_BY_ID_QRY = "SELECT * from headlines WHERE id = ?";
|
||||
private static final String SELECT_BY_STOCK_QRY = "SELECT * from headlines WHERE stock = ? AND date = ?";
|
||||
private static final String SELECT_BY_DATE_RANGE_QRY = "SELECT * from headlines WHERE stock = ? AND date >= ? AND date <= ?";
|
||||
private static final String SELECT_BY_STOCK_QRY = "SELECT * from headlines WHERE stock = ? AND date = ? AND dataset = 1";
|
||||
private static final String SELECT_BY_DATE_RANGE_QRY = "SELECT * from headlines WHERE stock = ? AND date >= ? AND date <= ? AND dataset = ?";
|
||||
|
||||
private static final String ASSIGN_RANDOM_PCT_QRY = "update headlines set dataset = (select FLOOR(RAND() * (200 - 101) + 101))";
|
||||
private static final String REMAP_TRAINING_QRY = "update headlines set dataset = 1 where dataset >= 101 and dataset <= (100 + ?)";
|
||||
private static final String REMAP_TEST_QRY = "update headlines set dataset = 2 where dataset >= (100 + ?) and dataset <= (100 + ?)";
|
||||
private static final String REMAP_VAL_QRY = "update headlines set dataset = 3 where dataset >= (100 + ?) and dataset <= 200";
|
||||
|
||||
private JdbcTemplate jdbcTemplate;
|
||||
|
||||
	/**
	 * Randomly re-partitions every headlines row into dataset 1 (training),
	 * 2 (test) or 3 (validation) using the supplied percentages, which must
	 * sum to exactly 100.
	 *
	 * Implementation: first assigns each row a uniform temporary bucket via
	 * ASSIGN_RANDOM_PCT_QRY (FLOOR(RAND()*(200-101)+101), i.e. 101..199),
	 * then remaps bucket ranges to 1/2/3 with the three REMAP_* statements.
	 *
	 * @return false (and no rows touched) when the percentages do not sum to 100
	 */
	@Override
	public boolean assignRandomDatasets(int training, int test, int validation) {
		if (training + test + validation != 100) {
			return false;
		}
		jdbcTemplate.update(ASSIGN_RANDOM_PCT_QRY);
		jdbcTemplate.update(REMAP_TRAINING_QRY,training);
		// NOTE(review): the test/validation ranges reuse the (100 + training)
		// endpoint already consumed by the training remap; the inclusive
		// bounds work only because earlier remaps have emptied the shared
		// boundary bucket — confirm the three ranges partition 101..199.
		jdbcTemplate.update(REMAP_TEST_QRY,training,training+test);
		jdbcTemplate.update(REMAP_VAL_QRY,training+test);
		return true;
	}
|
||||
|
||||
public int deleteById(int headlineId) {
|
||||
return jdbcTemplate.update(DELETE_BY_ID_STMT,
|
||||
new RequestMapper(), headlineId);
|
||||
@@ -64,12 +84,12 @@ public class HeadlineDaoImpl implements HeadlineDao {
|
||||
|
||||
public List<Headline> select(String stock, Date date) {
|
||||
return jdbcTemplate.query(SELECT_BY_STOCK_QRY,
|
||||
new RequestMapper(), stock, date);
|
||||
new RequestMapper(), stock, date, 1);
|
||||
}
|
||||
|
||||
public List<Headline> select(String stock, Date startDate, Date endDate) {
|
||||
public List<Headline> select(String stock, Date startDate, Date endDate, int dataset) {
|
||||
return jdbcTemplate.query(SELECT_BY_DATE_RANGE_QRY,
|
||||
new RequestMapper(), stock, startDate, endDate);
|
||||
new RequestMapper(), stock, startDate, endDate, dataset);
|
||||
}
|
||||
|
||||
@Autowired
|
||||
@@ -82,6 +102,7 @@ public class HeadlineDaoImpl implements HeadlineDao {
|
||||
@Override
|
||||
public Headline mapRow(ResultSet rs, int arg1) throws SQLException {
|
||||
Headline headline = new Headline();
|
||||
headline.setId(rs.getInt("id"));
|
||||
headline.setText(rs.getString("text"));
|
||||
headline.setStock(rs.getString("stock"));
|
||||
headline.setDate(rs.getDate("date"));
|
||||
@@ -90,4 +111,14 @@ public class HeadlineDaoImpl implements HeadlineDao {
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
	/** @return total number of rows in the headlines table. */
	@Override
	public int getCount() {
		// NOTE(review): queryForInt is deprecated in later Spring versions
		// (use queryForObject(sql, Integer.class)) — fine for the version in use.
		return jdbcTemplate.queryForInt(COUNT_ALL_QRY);
	}
|
||||
|
||||
	/**
	 * @param dataset dataset id (1 = training, 2 = test, 3 = validation)
	 * @return number of headlines rows assigned to that dataset
	 */
	@Override
	public int getCount(int dataset) {
		return jdbcTemplate.queryForInt(COUNT_DATASET_QRY,dataset);
	}
|
||||
}
|
||||
57
src/net/woodyfolsom/cs6601/p3/domain/StockPrice.java
Normal file
57
src/net/woodyfolsom/cs6601/p3/domain/StockPrice.java
Normal file
@@ -0,0 +1,57 @@
|
||||
package net.woodyfolsom.cs6601.p3.domain;
|
||||
|
||||
import java.util.Date;
|
||||
|
||||
/**
 * Immutable snapshot of one trading day's data for a single stock, mirroring
 * the historical-data CSV layout:
 *   Date,Open,High,Low,Close,Volume,Adj_Close
 *   e.g. 2012-04-17,0.28,0.33,0.28,0.32,3408100,0.32
 */
public class StockPrice {

	final Date date;

	final double open;
	final double high;
	final double low;
	final double close;
	final double adjClose;

	final long volume;

	/**
	 * @param date     trading day
	 * @param open     opening price
	 * @param high     intraday high
	 * @param low      intraday low
	 * @param close    closing price
	 * @param volume   shares traded
	 * @param adjClose split/dividend-adjusted closing price
	 */
	public StockPrice(Date date, double open, double high, double low,
			double close, long volume, double adjClose) {
		this.date = date;
		this.open = open;
		this.high = high;
		this.low = low;
		this.close = close;
		this.volume = volume;
		this.adjClose = adjClose;
	}

	public Date getDate() {
		return date;
	}

	public double getOpen() {
		return open;
	}

	public double getHigh() {
		return high;
	}

	public double getLow() {
		return low;
	}

	public double getClose() {
		return close;
	}

	public double getAdjClose() {
		return adjClose;
	}

	public long getVolume() {
		return volume;
	}
}
|
||||
@@ -1,10 +1,17 @@
|
||||
package net.woodyfolsom.cs6601.p3.ngram;
|
||||
|
||||
import java.io.BufferedWriter;
|
||||
import java.io.File;
|
||||
import java.io.FileNotFoundException;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.OutputStreamWriter;
|
||||
import java.text.DecimalFormat;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Calendar;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.Date;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
@@ -14,7 +21,9 @@ import java.util.Set;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
|
||||
import net.woodyfolsom.cs6601.p3.StockUtil;
|
||||
import net.woodyfolsom.cs6601.p3.domain.Headline;
|
||||
import net.woodyfolsom.cs6601.p3.domain.StockPrice;
|
||||
|
||||
public class NGramModel {
|
||||
static final int MAX_N_GRAM_LENGTH = 3;
|
||||
@@ -25,6 +34,8 @@ public class NGramModel {
|
||||
private static final String UNK = "<unk>";
|
||||
|
||||
private Map<Integer, NGramDistribution> nGrams;
|
||||
private Map<Integer, Map<NGram,Double>> nGramPriceAvg;
|
||||
|
||||
private int[] totalNGramCounts = new int[MAX_N_GRAM_LENGTH + 1];
|
||||
|
||||
private Pattern wordPattern = Pattern.compile("\\w+");
|
||||
@@ -88,9 +99,13 @@ public class NGramModel {
|
||||
for (int i = 0; i <= MAX_N_GRAM_LENGTH; i++) {
|
||||
nGrams.put(i, new NGramDistribution());
|
||||
}
|
||||
nGramPriceAvg = new HashMap<Integer, Map<NGram,Double>>();
|
||||
for (int i = 0; i <= MAX_N_GRAM_LENGTH; i++) {
|
||||
nGramPriceAvg.put(i, new HashMap<NGram,Double>());
|
||||
}
|
||||
}
|
||||
|
||||
private void addNGram(int nGramLength, NGram nGram) {
|
||||
|
||||
private void addNGram(int nGramLength, NGram nGram, String stockName, Date date, Map<String, Map<Date,StockPrice>> stockTrends) {
|
||||
if (nGram.size() < nGramLength) {
|
||||
System.out.println("Cannot create " + nGramLength + "-gram from: "
|
||||
+ nGram);
|
||||
@@ -105,10 +120,31 @@ public class NGramModel {
|
||||
} else {
|
||||
nGramCounts.put(nGramCopy, 1);
|
||||
}
|
||||
|
||||
Map<NGram, Double> nGramPriceAvgs = nGramPriceAvg.get(nGramLength);
|
||||
|
||||
NGram nGramCopy2 = nGram.copy(nGramLength);
|
||||
|
||||
//TODO GET NEXT TRADING DAY'S DATE
|
||||
Date nextDay = StockUtil.getNextTradingDay(date);
|
||||
StockPrice stockPrice = stockTrends.get(stockName).get(nextDay);
|
||||
double percentChange;
|
||||
if (stockPrice == null) {
|
||||
percentChange = 0.0;
|
||||
} else {
|
||||
percentChange = StockUtil.getPercentChange(stockPrice);
|
||||
}
|
||||
|
||||
if (nGramPriceAvgs.containsKey(nGramCopy2)) {
|
||||
double totalPercentChange = nGramPriceAvgs.get(nGramCopy);
|
||||
nGramPriceAvgs.put(nGramCopy2, totalPercentChange + percentChange);
|
||||
} else {
|
||||
nGramPriceAvgs.put(nGramCopy2, percentChange);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Given an arbitrary String, replace punctutation with spaces, remove
|
||||
* Given an arbitrary String, replace punctuation with spaces, remove
|
||||
* non-alphanumeric characters, prepend with <START> token, append <END>
|
||||
* token.
|
||||
*
|
||||
@@ -193,12 +229,11 @@ public class NGramModel {
|
||||
+ " recognized n-grams in verification corpus: " + perplexity);
|
||||
}
|
||||
|
||||
private void generateModel(List<Headline> traininSet, boolean genRandom,
|
||||
boolean useUnk) throws FileNotFoundException, IOException {
|
||||
StringBuilder currentLine = new StringBuilder();
|
||||
List<String> fileByLines = new ArrayList<String>();
|
||||
private void generateModel(List<Headline> trainingSet, boolean genRandom,
|
||||
boolean useUnk, Map<String,Map<Date,StockPrice>> stockTrends) throws FileNotFoundException, IOException {
|
||||
//List<String> fileByLines = new ArrayList<String>();
|
||||
|
||||
for (Headline headline : traininSet) {
|
||||
for (Headline headline : trainingSet) {
|
||||
String headlineText = headline.getText();
|
||||
if (headlineText.length() == 0) {
|
||||
continue;
|
||||
@@ -206,6 +241,9 @@ public class NGramModel {
|
||||
String sanitizedLine = sanitize(headline.getText());
|
||||
// split on whitespace
|
||||
String[] tokens = sanitizedLine.toLowerCase().split("\\s+");
|
||||
|
||||
StringBuilder currentLine = new StringBuilder();
|
||||
|
||||
for (String token : tokens) {
|
||||
if (!isWord(token)) {
|
||||
continue;
|
||||
@@ -222,67 +260,67 @@ public class NGramModel {
|
||||
|
||||
if (END.equals(word)) {
|
||||
currentLine.append(word);
|
||||
fileByLines.add(currentLine.toString());
|
||||
currentLine = new StringBuilder();
|
||||
//fileByLines.add(currentLine.toString());
|
||||
//currentLine = new StringBuilder();
|
||||
} else {
|
||||
currentLine.append(word);
|
||||
currentLine.append(" ");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (String str : fileByLines) {
|
||||
String str = currentLine.toString();
|
||||
System.out.println(str);
|
||||
NGram currentNgram = new NGram(MAX_N_GRAM_LENGTH);
|
||||
for (String token : str.split("\\s+")) {
|
||||
currentNgram.add(token);
|
||||
for (int i = 0; i <= currentNgram.size(); i++) {
|
||||
addNGram(currentNgram.size() - i,
|
||||
currentNgram.subNGram(i, currentNgram.size()));
|
||||
currentNgram.subNGram(i, currentNgram.size()), headline.getStock(), headline.getDate(), stockTrends);
|
||||
totalNGramCounts[currentNgram.size() - i]++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
System.out.println("Most common words: ");
|
||||
|
||||
List<Entry<NGram, Integer>> unigrams = new ArrayList<Entry<NGram, Integer>>(
|
||||
nGrams.get(1).entrySet());
|
||||
Collections.sort(unigrams, new NGramComparator());
|
||||
|
||||
List<Entry<NGram, Integer>> bigrams = new ArrayList<Entry<NGram, Integer>>(
|
||||
nGrams.get(2).entrySet());
|
||||
Collections.sort(bigrams, new NGramComparator());
|
||||
|
||||
List<Entry<NGram, Integer>> trigrams = new ArrayList<Entry<NGram, Integer>>(
|
||||
nGrams.get(3).entrySet());
|
||||
Collections.sort(trigrams, new NGramComparator());
|
||||
|
||||
for (int i = 1; i <= 10; i++) {
|
||||
System.out
|
||||
.println(i
|
||||
+ ". "
|
||||
+ unigrams.get(i - 1).getKey()
|
||||
+ " : "
|
||||
+ (((double) (unigrams.get(i - 1).getValue()) / totalNGramCounts[1])));
|
||||
}
|
||||
|
||||
for (int i = 1; i <= 10; i++) {
|
||||
System.out
|
||||
.println(i
|
||||
+ ". "
|
||||
+ bigrams.get(i - 1).getKey()
|
||||
+ " : "
|
||||
+ (((double) (bigrams.get(i - 1).getValue()) / totalNGramCounts[2])));
|
||||
}
|
||||
|
||||
for (int i = 1; i <= 10; i++) {
|
||||
System.out
|
||||
.println(i
|
||||
+ ". "
|
||||
+ trigrams.get(i - 1).getKey()
|
||||
+ " : "
|
||||
+ (((double) (trigrams.get(i - 1).getValue()) / totalNGramCounts[3])));
|
||||
DecimalFormat decFmt = new DecimalFormat("###0.0000");
|
||||
for (int modelIndex = 1; modelIndex <= 3; modelIndex++) {
|
||||
File file = new File(modelIndex + "grams.txt");
|
||||
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(file)));
|
||||
List<Entry<NGram, Integer>> ngrams = new ArrayList<Entry<NGram, Integer>>(
|
||||
nGrams.get(modelIndex).entrySet());
|
||||
Collections.sort(ngrams, new NGramComparator());
|
||||
System.out.println("Highest frequency " + modelIndex + "-grams:");
|
||||
for (int i = 1; i <= 10; i++) {
|
||||
System.out
|
||||
.println(i
|
||||
+ ". "
|
||||
+ ngrams.get(i - 1).getKey()
|
||||
+ " : "
|
||||
+ (((double) (ngrams.get(i - 1).getValue()) / totalNGramCounts[1])));
|
||||
}
|
||||
Map<NGram,Double> pricesForModel = nGramPriceAvg.get(modelIndex);
|
||||
for (int nGramIndex = 1; nGramIndex <= ngrams.size(); nGramIndex++) {
|
||||
NGram key = ngrams.get(nGramIndex - 1).getKey();
|
||||
writer.write(key.toString());
|
||||
writer.write(",");
|
||||
int count = ngrams.get(nGramIndex - 1).getValue();
|
||||
writer.write(Integer.toString(count));
|
||||
writer.write(",");
|
||||
double avgPrice;
|
||||
try {
|
||||
avgPrice = pricesForModel.get(key);
|
||||
System.out.println("Avg price for " + modelIndex + "-gram " + key +": " + avgPrice);
|
||||
} catch (NullPointerException npe) {
|
||||
System.out.println("null avgPrice for " + modelIndex + "-gram " + key);
|
||||
avgPrice = 0.0;
|
||||
}
|
||||
writer.write(decFmt.format(avgPrice/(double)count));
|
||||
writer.write("\n");
|
||||
}
|
||||
try {
|
||||
writer.close();
|
||||
} catch (IOException ioe) {
|
||||
System.out.println(ioe.getMessage());
|
||||
}
|
||||
}
|
||||
if (genRandom) {
|
||||
for (int nGramLength = 1; nGramLength <= MAX_N_GRAM_LENGTH; nGramLength++) {
|
||||
@@ -333,11 +371,11 @@ public class NGramModel {
|
||||
}
|
||||
|
||||
public void reportModel(List<Headline> trainingSet,
|
||||
List<Headline> validationSet) {
|
||||
List<Headline> validationSet, Map<String,Map<Date,StockPrice>> stockTrends) {
|
||||
try {
|
||||
NGramModel ngm = new NGramModel();
|
||||
boolean doCalcPerplexity = true;
|
||||
ngm.generateModel(trainingSet, !doCalcPerplexity, doCalcPerplexity);
|
||||
ngm.generateModel(trainingSet, !doCalcPerplexity, doCalcPerplexity, stockTrends);
|
||||
if (doCalcPerplexity) {
|
||||
for (int i = 1; i <= MAX_N_GRAM_LENGTH; i++) {
|
||||
ngm.calcPerplexity(validationSet, i, true);
|
||||
|
||||
@@ -6,8 +6,11 @@ import java.util.List;
|
||||
import net.woodyfolsom.cs6601.p3.domain.Headline;
|
||||
|
||||
public interface HeadlineService {
|
||||
boolean assignRandomDatasets(int training, int test, int validation);
|
||||
int getCount();
|
||||
int getCount(int dataset);
|
||||
int insertHeadline(Headline headline);
|
||||
int[] insertHeadlines(List<Headline> headline);
|
||||
List<Headline> getHeadlines(String stock, Date date);
|
||||
List<Headline> getHeadlines(String stock, Date startDate, Date endDate);
|
||||
List<Headline> getHeadlines(String stock, Date startDate, Date endDate, int dataset);
|
||||
}
|
||||
@@ -34,7 +34,22 @@ public class MySQLHeadlineServiceImpl implements HeadlineService {
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Headline> getHeadlines(String stock, Date startDate, Date endDate) {
|
||||
return headlineDao.select(stock, startDate, endDate);
|
||||
public List<Headline> getHeadlines(String stock, Date startDate, Date endDate, int dataset) {
|
||||
return headlineDao.select(stock, startDate, endDate, dataset);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean assignRandomDatasets(int training, int test, int validation) {
|
||||
return headlineDao.assignRandomDatasets(training, test, validation);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getCount() {
|
||||
return headlineDao.getCount();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getCount(int dataset) {
|
||||
return headlineDao.getCount(dataset);
|
||||
}
|
||||
}
|
||||
@@ -89,7 +89,22 @@ public class YahooHeadlineServiceImpl implements HeadlineService {
|
||||
|
||||
@Override
|
||||
public List<Headline> getHeadlines(String stock, Date startDate,
|
||||
Date endDate) {
|
||||
Date endDate, int dataset) {
|
||||
throw new UnsupportedOperationException("This implementation does not support getting headlines for a date range.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean assignRandomDatasets(int training, int test, int validation) {
|
||||
throw new UnsupportedOperationException("This implementation does not support this method.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getCount() {
|
||||
throw new UnsupportedOperationException("This implementation does not support this method");
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getCount(int dataset) {
|
||||
throw new UnsupportedOperationException("This implementation does not support this method");
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user