Added script to pull historical stock data and resulting data files (1 per company). Added code to generate average price change per 1, 2 and 3-gram. Added code to output average price change per headline for VALIDATION dataset.

This commit is contained in:
Woody Folsom
2012-04-20 21:22:54 -04:00
parent eec32b19c1
commit 6e3680426e
65 changed files with 360216 additions and 102 deletions

View File

@@ -35,10 +35,10 @@ public class HeadlinePuller {
private static final int NO_ARGS = 5;
private static final int STOCK_SYMBOL_CSV_NOT_FOUND = 6;
@Autowired
HeadlineService mySQLHeadlineServiceImpl;
@Autowired
HeadlineService yahooHeadlineServiceImpl;
//@Autowired
//HeadlineService mySQLHeadlineServiceImpl;
//@Autowired
//HeadlineService yahooHeadlineServiceImpl;
private static void printUsage() {
System.out
@@ -115,10 +115,10 @@ public class HeadlinePuller {
for (calendar.setTime(startDate); (today = calendar.getTime())
.compareTo(endDate) <= 0; calendar
.add(Calendar.DATE, 1)) {
List<Headline> headlines = headlinePuller.pullHeadlines(
company.getStockSymbol(), today);
int[] updates = headlinePuller.mySQLHeadlineServiceImpl.insertHeadlines(headlines);
System.out.println(updates.length + " rows updated");
//List<Headline> headlines = headlinePuller.pullHeadlines(
// company.getStockSymbol(), today);
//int[] updates = headlinePuller.mySQLHeadlineServiceImpl.insertHeadlines(headlines);
//System.out.println(updates.length + " rows updated");
}
}
} catch (FileNotFoundException fnfe) {
@@ -132,12 +132,12 @@ public class HeadlinePuller {
}
}
private List<Headline> pullHeadlines(String stockSymbol, Date date) {
List<Headline> headlines = yahooHeadlineServiceImpl.getHeadlines(
stockSymbol, date);
System.out.println("Pulled " + headlines.size() + " headlines for " + stockSymbol + " on " + date);
return headlines;
}
//private List<Headline> pullHeadlines(String stockSymbol, Date date) {
//List<Headline> headlines = yahooHeadlineServiceImpl.getHeadlines(
// stockSymbol, date);
//System.out.println("Pulled " + headlines.size() + " headlines for " + stockSymbol + " on " + date);
//return headlines;
//}
private List<Company> getFortune50(File csvFile)
throws FileNotFoundException, IOException {

View File

@@ -45,39 +45,30 @@ public class ModelGenerator {
Date startDate = null;
Date endDate = null;
Date valStart = null;
Date valEnd = null;
try {
startDate = dateFmt.parse("2012-01-01");
endDate = dateFmt.parse("2012-03-31");
valStart = dateFmt.parse("2012-04-01");
valEnd = dateFmt.parse("2012-04-14");
endDate = dateFmt.parse("2012-04-14");
} catch (ParseException pe) {
System.exit(INVALID_DATE);
}
List<Headline> trainingSet = new ArrayList<Headline>();
//actually, this is the TEST dataset
List<Headline> testSet = new ArrayList<Headline>();
try {
List<Company> fortune50 = modelGenerator
.getFortune50(stockSymbolsCSV);
for (Company company : fortune50) {
System.out.println("Getting headlines for Fortune 50 company #"
+ company.getId() + " (" + company.getName() + ")...");
List<Headline> trainingSet = modelGenerator.mySQLHeadlineServiceImpl.getHeadlines(company.getStockSymbol(), startDate, endDate);
System.out.println("Pulled " + trainingSet.size() + " headlines for "
List<Headline> coTrainingSet = modelGenerator.mySQLHeadlineServiceImpl.getHeadlines(company.getStockSymbol(), startDate, endDate, 1);
System.out.println("Pulled " + coTrainingSet.size() + " headlines for "
+ company.getStockSymbol() + " from " + startDate + " to " + endDate);
List<Headline> validationSet = modelGenerator.mySQLHeadlineServiceImpl.getHeadlines(company.getStockSymbol(), valStart, valEnd);
List<Headline> coTestSet = modelGenerator.mySQLHeadlineServiceImpl.getHeadlines(company.getStockSymbol(), startDate, endDate, 2);
if (trainingSet.size() == 0) {
System.out.println("Training dataset contains 0 headlines for " + company.getName() + ", skipping model generation.");
continue;
}
if (validationSet.size() == 0) {
System.out.println("Validation dataset contains 0 headlines for " + company.getName() + ", skipping model generation.");
continue;
}
modelGenerator.ngramModel.reportModel(trainingSet, validationSet);
System.out.println("Finished " + company.getId() + " / 50");
trainingSet.addAll(coTrainingSet);
testSet.addAll(coTestSet);
}
} catch (FileNotFoundException fnfe) {
System.out.println("Stock symbol CSV file does not exist: "
@@ -88,6 +79,8 @@ public class ModelGenerator {
+ stockSymbolsCSV);
System.exit(IO_EXCEPTION);
}
//modelGenerator.ngramModel.reportModel(trainingSet, testSet);
}
private List<Company> getFortune50(File csvFile)

View File

@@ -0,0 +1,180 @@
package net.woodyfolsom.cs6601.p3;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.text.DateFormat;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.context.support.FileSystemXmlApplicationContext;
import org.springframework.stereotype.Component;
import net.woodyfolsom.cs6601.p3.domain.Company;
import net.woodyfolsom.cs6601.p3.domain.Headline;
import net.woodyfolsom.cs6601.p3.domain.StockPrice;
import net.woodyfolsom.cs6601.p3.ngram.NGramModel;
import net.woodyfolsom.cs6601.p3.svc.HeadlineService;
@Component
public class PricePoller {
	private static final File stockSymbolsCSV = new File("stock_symbols.csv");

	// Process exit codes.
	private static final int INVALID_DATE = 1;
	private static final int IO_EXCEPTION = 2;
	private static final int STOCK_SYMBOL_CSV_NOT_FOUND = 3;

	@Autowired
	HeadlineService mySQLHeadlineServiceImpl;

	private NGramModel ngramModel = new NGramModel();

	/**
	 * Loads historical stock prices from data/&lt;SYMBOL&gt;.txt for each Fortune 50
	 * company, pulls the TRAINING (dataset 1) and TEST (dataset 2) headlines
	 * from the database for 2012-01-01 .. 2012-04-14, and reports the n-gram
	 * model over both together with the per-date price trends.
	 */
	public static void main(String... args) {
		ApplicationContext context = new FileSystemXmlApplicationContext(
				new String[] { "AppContext.xml" });
		PricePoller pricePoller = context.getBean(PricePoller.class);

		DateFormat dateFmt = new SimpleDateFormat("yyyy-MM-dd");
		Date startDate = null;
		Date endDate = null;
		try {
			startDate = dateFmt.parse("2012-01-01");
			endDate = dateFmt.parse("2012-04-14");
		} catch (ParseException pe) {
			System.exit(INVALID_DATE);
		}

		List<Headline> trainingSet = new ArrayList<Headline>();
		//actually, this is the TEST dataset
		List<Headline> testSet = new ArrayList<Headline>();
		// stock symbol -> (date -> that day's price record)
		Map<String, Map<Date, StockPrice>> stockTrends = new HashMap<String, Map<Date, StockPrice>>();
		try {
			List<Company> fortune50 = pricePoller.getFortune50(stockSymbolsCSV);
			for (Company company : fortune50) {
				stockTrends.put(company.getStockSymbol(),
						new HashMap<Date, StockPrice>());
				System.out.println("Polling price data for " + company.getName());
				loadStockPrices(company,
						stockTrends.get(company.getStockSymbol()), dateFmt,
						startDate, endDate);
			}
			for (Company company : fortune50) {
				System.out.println("Getting headlines for Fortune 50 company #"
						+ company.getId() + " (" + company.getName() + ")...");
				List<Headline> coTrainingSet = pricePoller.mySQLHeadlineServiceImpl
						.getHeadlines(company.getStockSymbol(), startDate, endDate, 1);
				System.out.println("Pulled " + coTrainingSet.size() + " TRAINING headlines for "
						+ company.getStockSymbol() + " from " + startDate + " to " + endDate);
				List<Headline> coTestSet = pricePoller.mySQLHeadlineServiceImpl
						.getHeadlines(company.getStockSymbol(), startDate, endDate, 2);
				System.out.println("Pulled " + coTestSet.size() + " TEST headlines for "
						+ company.getStockSymbol() + " from " + startDate + " to " + endDate);
				trainingSet.addAll(coTrainingSet);
				testSet.addAll(coTestSet);
			}
		} catch (FileNotFoundException fnfe) {
			System.out.println("Stock symbol CSV file does not exist: "
					+ stockSymbolsCSV);
			System.exit(STOCK_SYMBOL_CSV_NOT_FOUND);
		} catch (IOException ioe) {
			System.out.println("Stock symbol CSV file does not exist: "
					+ stockSymbolsCSV);
			System.exit(IO_EXCEPTION);
		}
		pricePoller.ngramModel.reportModel(trainingSet, testSet, stockTrends);
	}

	/**
	 * Parses one company's historical price file ("data/&lt;SYMBOL&gt;.txt", a CSV
	 * whose first line is a header and whose rows appear newest-first) into the
	 * supplied date-to-price map. Rows dated after endDate are skipped; parsing
	 * stops at the first row dated before startDate (relies on the newest-first
	 * ordering). Malformed rows are reported and skipped; a missing file is
	 * reported and leaves the map empty.
	 */
	private static void loadStockPrices(Company company,
			Map<Date, StockPrice> prices, DateFormat dateFmt, Date startDate,
			Date endDate) {
		File stockPriceFile = new File("data" + File.separator
				+ company.getStockSymbol() + ".txt");
		// try-with-resources: the original code leaked the reader when an
		// IOException interrupted the read loop (the catch skipped close()).
		try (BufferedReader buf = new BufferedReader(new InputStreamReader(
				new FileInputStream(stockPriceFile)))) {
			String line;
			int linesRead = 0;
			while ((line = buf.readLine()) != null) {
				linesRead++;
				if (linesRead == 1) {
					continue; // header line
				}
				String[] fields = line.trim().split(",");
				Date date;
				try {
					date = dateFmt.parse(fields[0]);
				} catch (ParseException pe) {
					System.out.println("Error parsing date: " + fields[0]);
					continue;
				}
				if (date.compareTo(endDate) > 0) {
					continue; // too recent for the window of interest
				}
				if (date.compareTo(startDate) < 0) {
					break; // rows are newest-first; everything below is older
				}
				try {
					double open = Double.parseDouble(fields[1]);
					double high = Double.parseDouble(fields[2]);
					double low = Double.parseDouble(fields[3]);
					double close = Double.parseDouble(fields[4]);
					long volume = Long.parseLong(fields[5]);
					double adjClose = Double.parseDouble(fields[6]);
					prices.put(date, new StockPrice(date, open, high, low,
							close, volume, adjClose));
				} catch (NumberFormatException nfe) {
					System.out.println(nfe.getMessage());
				}
			}
		} catch (FileNotFoundException fnfe) {
			System.out.println("Unable to find historical stock data file for: "
					+ company.getStockSymbol());
		} catch (IOException ioe) {
			System.err.println(ioe.getMessage());
		}
	}

	/**
	 * Reads the Fortune 50 companies from a CSV file whose lines have the form
	 * "id,name,stockSymbol". Blank lines are ignored; any other malformed line
	 * aborts with a RuntimeException.
	 *
	 * @throws FileNotFoundException if the CSV file does not exist
	 * @throws IOException on any read failure
	 */
	private List<Company> getFortune50(File csvFile)
			throws FileNotFoundException, IOException {
		List<Company> fortune50 = new ArrayList<Company>();
		// try-with-resources: the original never closed this reader.
		try (BufferedReader buf = new BufferedReader(new InputStreamReader(
				new FileInputStream(csvFile)))) {
			String csvline;
			while ((csvline = buf.readLine()) != null) {
				if (csvline.length() == 0) {
					continue;
				}
				String[] fields = csvline.split(",");
				if (fields.length != 3) {
					throw new RuntimeException(
							"Badly formatted csv file name (3 values expected): "
									+ csvline);
				}
				// parseInt avoids the needless boxing of Integer.valueOf.
				int id = Integer.parseInt(fields[0]);
				fortune50.add(new Company(id, fields[1], fields[2]));
			}
		}
		return fortune50;
	}
}

View File

@@ -0,0 +1,25 @@
package net.woodyfolsom.cs6601.p3;
import java.util.Calendar;
import java.util.Date;
import net.woodyfolsom.cs6601.p3.domain.StockPrice;
/**
 * Static helpers for stock-price calculations. Not instantiable.
 */
public class StockUtil {
	private StockUtil() {
		// utility class: no instances
	}

	/**
	 * Returns the next weekday strictly after the given date, skipping
	 * Saturdays and Sundays. (Market holidays are not accounted for.)
	 */
	public static Date getNextTradingDay(Date date) {
		Calendar cal = Calendar.getInstance();
		cal.setTime(date);
		do {
			cal.add(Calendar.DATE, 1);
			// Named constants instead of the magic numbers 1 and 7.
		} while (cal.get(Calendar.DAY_OF_WEEK) == Calendar.SUNDAY
				|| cal.get(Calendar.DAY_OF_WEEK) == Calendar.SATURDAY);
		return cal.getTime();
	}

	/**
	 * Returns the open-to-close change of the day as a percentage:
	 * if close is 2x open the result is 100.0; if close is 0.9x open the
	 * result is -10.0. (The original comment claimed fractional values,
	 * which contradicted the code.)
	 * NOTE(review): an open of 0.0 yields Infinity/NaN — presumably never
	 * occurs in real price data, but worth confirming upstream.
	 */
	public static double getPercentChange(StockPrice stockPrice) {
		double close = stockPrice.getClose();
		double open = stockPrice.getOpen();
		return 100.0 * ((close / open) - 1.00);
	}
}

View File

@@ -0,0 +1,256 @@
package net.woodyfolsom.cs6601.p3;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.text.DateFormat;
import java.text.DecimalFormat;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.context.support.FileSystemXmlApplicationContext;
import org.springframework.stereotype.Component;
import net.woodyfolsom.cs6601.p3.domain.Company;
import net.woodyfolsom.cs6601.p3.domain.Headline;
import net.woodyfolsom.cs6601.p3.domain.StockPrice;
import net.woodyfolsom.cs6601.p3.ngram.NGram;
import net.woodyfolsom.cs6601.p3.ngram.NGramModel;
import net.woodyfolsom.cs6601.p3.svc.HeadlineService;
@Component
public class ValidationSetCreator {
	private static final File stockSymbolsCSV = new File("stock_symbols.csv");

	// Process exit codes.
	private static final int INVALID_DATE = 1;
	private static final int IO_EXCEPTION = 2;
	private static final int STOCK_SYMBOL_CSV_NOT_FOUND = 3;

	@Autowired
	HeadlineService mySQLHeadlineServiceImpl;

	/**
	 * Loads per-company historical prices, pulls the VALIDATION headlines
	 * (dataset 3) from the database for 2012-01-01 .. 2012-04-14, and writes
	 * "validation.txt" with one row per distinct headline text: id, stock,
	 * date, average next-trading-day percent price change over all
	 * occurrences, and the sanitized headline text.
	 */
	public static void main(String... args) {
		ApplicationContext context = new FileSystemXmlApplicationContext(
				new String[] { "AppContext.xml" });
		ValidationSetCreator creator = context
				.getBean(ValidationSetCreator.class);

		DateFormat dateFmt = new SimpleDateFormat("yyyy-MM-dd");
		Date startDate = null;
		Date endDate = null;
		try {
			startDate = dateFmt.parse("2012-01-01");
			endDate = dateFmt.parse("2012-04-14");
		} catch (ParseException pe) {
			System.exit(INVALID_DATE);
		}

		// stock symbol -> (date -> that day's price record)
		Map<String, Map<Date, StockPrice>> stockTrends = new HashMap<String, Map<Date, StockPrice>>();
		try {
			List<Company> fortune50 = creator.getFortune50(stockSymbolsCSV);
			for (Company company : fortune50) {
				stockTrends.put(company.getStockSymbol(),
						new HashMap<Date, StockPrice>());
				System.out.println("Polling price data for "
						+ company.getName());
				loadStockPrices(company,
						stockTrends.get(company.getStockSymbol()), dateFmt,
						startDate, endDate);
			}
			List<Headline> valSet = new ArrayList<Headline>();
			for (Company company : fortune50) {
				System.out.println("Getting headlines for Fortune 50 company #"
						+ company.getId() + " (" + company.getName() + ")...");
				List<Headline> coValSet = creator.mySQLHeadlineServiceImpl
						.getHeadlines(company.getStockSymbol(), startDate,
								endDate, 3);
				System.out.println("Pulled " + coValSet.size()
						+ " VALIDATION headlines for "
						+ company.getStockSymbol() + " from " + startDate
						+ " to " + endDate);
				valSet.addAll(coValSet);
			}
			writeValidationFile(valSet, stockTrends, dateFmt);
		} catch (FileNotFoundException fnfe) {
			System.out.println("Stock symbol CSV file does not exist: "
					+ stockSymbolsCSV);
			System.exit(STOCK_SYMBOL_CSV_NOT_FOUND);
		} catch (IOException ioe) {
			System.out.println("Stock symbol CSV file does not exist: "
					+ stockSymbolsCSV);
			System.exit(IO_EXCEPTION);
		}
	}

	/**
	 * Writes "validation.txt". Pass 1 accumulates, per distinct headline text,
	 * the occurrence count and the summed next-trading-day percent price
	 * change. Pass 2 emits one CSV-ish line per distinct text using the
	 * id/stock/date of its first occurrence and the averaged price change.
	 */
	private static void writeValidationFile(List<Headline> valSet,
			Map<String, Map<Date, StockPrice>> stockTrends, DateFormat dateFmt)
			throws IOException {
		Map<String, Integer> headlineCount = new HashMap<String, Integer>();
		Map<String, Double> totalPctChange = new HashMap<String, Double>();
		for (Headline headline : valSet) {
			String text = headline.getText();
			StockPrice stockPrice = stockTrends.get(headline.getStock()).get(
					StockUtil.getNextTradingDay(headline.getDate()));
			// No price data for that day (holiday, gap, or missing file)
			// counts as a 0.0% change.
			double pctPriceChange = (stockPrice == null) ? 0.0
					: StockUtil.getPercentChange(stockPrice);
			Integer count = headlineCount.get(text);
			if (count == null) {
				headlineCount.put(text, 1);
				totalPctChange.put(text, pctPriceChange);
			} else {
				headlineCount.put(text, count + 1);
				totalPctChange.put(text,
						totalPctChange.get(text) + pctPriceChange);
			}
		}
		Set<String> processedSet = new HashSet<String>();
		DecimalFormat decFmt = new DecimalFormat("###0.0000");
		File file = new File("validation.txt");
		// try-with-resources: the original leaked the writer if an
		// IOException was thrown mid-write.
		try (BufferedWriter writer = new BufferedWriter(
				new OutputStreamWriter(new FileOutputStream(file)))) {
			for (Headline headline : valSet) {
				String text = headline.getText();
				if (processedSet.contains(text)) {
					continue; // only the first occurrence of each text is emitted
				}
				processedSet.add(text);
				StringBuilder sb = new StringBuilder();
				sb.append(headline.getId());
				sb.append(", ");
				sb.append(headline.getStock());
				sb.append(", ");
				sb.append(dateFmt.format(headline.getDate()));
				sb.append(", ");
				sb.append(decFmt.format(totalPctChange.get(text)
						/ headlineCount.get(text)));
				sb.append(", ");
				// Replace quoting/listing punctuation with spaces, then drop
				// every remaining non-alphanumeric character.
				text = text.replaceAll("[\'\";:,\\]\\[]", " ");
				text = text.replaceAll("[^A-Za-z0-9 ]", "");
				sb.append(text);
				sb.append("\n");
				writer.write(sb.toString());
			}
		}
	}

	/**
	 * Parses one company's historical price file ("data/&lt;SYMBOL&gt;.txt", a CSV
	 * whose first line is a header and whose rows appear newest-first) into the
	 * supplied date-to-price map. Rows dated after endDate are skipped; parsing
	 * stops at the first row dated before startDate. Malformed rows are
	 * reported and skipped; a missing file is reported and leaves the map empty.
	 */
	private static void loadStockPrices(Company company,
			Map<Date, StockPrice> prices, DateFormat dateFmt, Date startDate,
			Date endDate) {
		File stockPriceFile = new File("data" + File.separator
				+ company.getStockSymbol() + ".txt");
		// try-with-resources: the original code leaked the reader when an
		// IOException interrupted the read loop (the catch skipped close()).
		try (BufferedReader buf = new BufferedReader(new InputStreamReader(
				new FileInputStream(stockPriceFile)))) {
			String line;
			int linesRead = 0;
			while ((line = buf.readLine()) != null) {
				linesRead++;
				if (linesRead == 1) {
					continue; // header line
				}
				String[] fields = line.trim().split(",");
				Date date;
				try {
					date = dateFmt.parse(fields[0]);
				} catch (ParseException pe) {
					System.out.println("Error parsing date: " + fields[0]);
					continue;
				}
				if (date.compareTo(endDate) > 0) {
					continue; // too recent for the window of interest
				}
				if (date.compareTo(startDate) < 0) {
					break; // rows are newest-first; everything below is older
				}
				try {
					double open = Double.parseDouble(fields[1]);
					double high = Double.parseDouble(fields[2]);
					double low = Double.parseDouble(fields[3]);
					double close = Double.parseDouble(fields[4]);
					long volume = Long.parseLong(fields[5]);
					double adjClose = Double.parseDouble(fields[6]);
					prices.put(date, new StockPrice(date, open, high, low,
							close, volume, adjClose));
				} catch (NumberFormatException nfe) {
					System.out.println(nfe.getMessage());
				}
			}
		} catch (FileNotFoundException fnfe) {
			System.out
					.println("Unable to find historical stock data file for: "
							+ company.getStockSymbol());
		} catch (IOException ioe) {
			System.err.println(ioe.getMessage());
		}
	}

	/**
	 * Reads the Fortune 50 companies from a CSV file whose lines have the form
	 * "id,name,stockSymbol". Blank lines are ignored; any other malformed line
	 * aborts with a RuntimeException.
	 *
	 * @throws FileNotFoundException if the CSV file does not exist
	 * @throws IOException on any read failure
	 */
	private List<Company> getFortune50(File csvFile)
			throws FileNotFoundException, IOException {
		List<Company> fortune50 = new ArrayList<Company>();
		// try-with-resources: the original never closed this reader.
		try (BufferedReader buf = new BufferedReader(new InputStreamReader(
				new FileInputStream(csvFile)))) {
			String csvline;
			while ((csvline = buf.readLine()) != null) {
				if (csvline.length() == 0) {
					continue;
				}
				String[] fields = csvline.split(",");
				if (fields.length != 3) {
					throw new RuntimeException(
							"Badly formatted csv file name (3 values expected): "
									+ csvline);
				}
				// parseInt avoids the needless boxing of Integer.valueOf.
				int id = Integer.parseInt(fields[0]);
				fortune50.add(new Company(id, fields[1], fields[2]));
			}
		}
		return fortune50;
	}
}

View File

@@ -6,12 +6,14 @@ import java.util.List;
import net.woodyfolsom.cs6601.p3.domain.Headline;
public interface HeadlineDao {
boolean assignRandomDatasets(int training, int test, int validation);
int getCount();
int getCount(int dataset);
int deleteById(int id);
int insert(Headline headline);
int[] insertBatch(List<Headline> headlines);
Headline select(int id);
List<Headline> select(String stock, Date date);
List<Headline> select(String stock, Date startDate, Date endDate);
List<Headline> select(String stock, Date startDate, Date endDate, int dataset);
}

View File

@@ -18,16 +18,36 @@ import net.woodyfolsom.cs6601.p3.domain.Headline;
@Repository
public class HeadlineDaoImpl implements HeadlineDao {
private static final String COUNT_ALL_QRY = "SELECT COUNT(1) FROM headlines";
private static final String COUNT_DATASET_QRY = "SELECT COUNT(1) FROM headlines where dataset = ?";
private static final String DELETE_BY_ID_STMT = "DELETE from headlines WHERE id = ?";
private static final String INSERT_STMT = "INSERT INTO headlines (text, date, stock, dataset) values (?, ?, ?, ?)";
private static final String SELECT_BY_ID_QRY = "SELECT * from headlines WHERE id = ?";
private static final String SELECT_BY_STOCK_QRY = "SELECT * from headlines WHERE stock = ? AND date = ?";
private static final String SELECT_BY_DATE_RANGE_QRY = "SELECT * from headlines WHERE stock = ? AND date >= ? AND date <= ?";
private static final String SELECT_BY_STOCK_QRY = "SELECT * from headlines WHERE stock = ? AND date = ? AND dataset = 1";
private static final String SELECT_BY_DATE_RANGE_QRY = "SELECT * from headlines WHERE stock = ? AND date >= ? AND date <= ? AND dataset = ?";
private static final String ASSIGN_RANDOM_PCT_QRY = "update headlines set dataset = (select FLOOR(RAND() * (200 - 101) + 101))";
private static final String REMAP_TRAINING_QRY = "update headlines set dataset = 1 where dataset >= 101 and dataset <= (100 + ?)";
private static final String REMAP_TEST_QRY = "update headlines set dataset = 2 where dataset >= (100 + ?) and dataset <= (100 + ?)";
private static final String REMAP_VAL_QRY = "update headlines set dataset = 3 where dataset >= (100 + ?) and dataset <= 200";
private JdbcTemplate jdbcTemplate;
@Override
public boolean assignRandomDatasets(int training, int test, int validation) {
if (training + test + validation != 100) {
return false;
}
jdbcTemplate.update(ASSIGN_RANDOM_PCT_QRY);
jdbcTemplate.update(REMAP_TRAINING_QRY,training);
jdbcTemplate.update(REMAP_TEST_QRY,training,training+test);
jdbcTemplate.update(REMAP_VAL_QRY,training+test);
return true;
}
public int deleteById(int headlineId) {
return jdbcTemplate.update(DELETE_BY_ID_STMT,
new RequestMapper(), headlineId);
@@ -64,12 +84,12 @@ public class HeadlineDaoImpl implements HeadlineDao {
public List<Headline> select(String stock, Date date) {
return jdbcTemplate.query(SELECT_BY_STOCK_QRY,
new RequestMapper(), stock, date);
new RequestMapper(), stock, date, 1);
}
public List<Headline> select(String stock, Date startDate, Date endDate) {
public List<Headline> select(String stock, Date startDate, Date endDate, int dataset) {
return jdbcTemplate.query(SELECT_BY_DATE_RANGE_QRY,
new RequestMapper(), stock, startDate, endDate);
new RequestMapper(), stock, startDate, endDate, dataset);
}
@Autowired
@@ -82,6 +102,7 @@ public class HeadlineDaoImpl implements HeadlineDao {
@Override
public Headline mapRow(ResultSet rs, int arg1) throws SQLException {
Headline headline = new Headline();
headline.setId(rs.getInt("id"));
headline.setText(rs.getString("text"));
headline.setStock(rs.getString("stock"));
headline.setDate(rs.getDate("date"));
@@ -90,4 +111,14 @@ public class HeadlineDaoImpl implements HeadlineDao {
}
}
@Override
public int getCount() {
return jdbcTemplate.queryForInt(COUNT_ALL_QRY);
}
@Override
public int getCount(int dataset) {
return jdbcTemplate.queryForInt(COUNT_DATASET_QRY,dataset);
}
}

View File

@@ -0,0 +1,57 @@
package net.woodyfolsom.cs6601.p3.domain;
import java.util.Date;
/**
 * Immutable snapshot of one day's trading data for a stock, mirroring one
 * row of a historical-price CSV file:
 * Date,Open,High,Low,Close,Volume,Adj_Close
 * e.g. 2012-04-17,0.28,0.33,0.28,0.32,3408100,0.32
 */
public class StockPrice {
	// java.util.Date is mutable, so the stored instance is a defensive copy.
	final Date date;
	final double open;
	final double high;
	final double low;
	final double close;
	final double adjClose;
	final long volume;

	/**
	 * @param date trading day (copied defensively; must not be null)
	 * @param open opening price
	 * @param high intraday high
	 * @param low intraday low
	 * @param close closing price
	 * @param volume shares traded
	 * @param adjClose split/dividend-adjusted closing price
	 */
	public StockPrice(Date date, double open, double high, double low,
			double close, long volume, double adjClose) {
		// Defensive copy: without it a caller mutating its Date would
		// silently change this otherwise-immutable object.
		this.date = new Date(date.getTime());
		this.open = open;
		this.high = high;
		this.low = low;
		this.close = close;
		this.volume = volume;
		this.adjClose = adjClose;
	}

	public Date getDate() {
		// Return a copy; exposing the internal Date would let callers
		// mutate this object's state.
		return new Date(date.getTime());
	}

	public double getOpen() {
		return open;
	}

	public double getClose() {
		return close;
	}

	public double getHigh() {
		return high;
	}

	public double getLow() {
		return low;
	}

	public long getVolume() {
		return volume;
	}

	public double getAdjClose() {
		return adjClose;
	}
}

View File

@@ -1,10 +1,17 @@
package net.woodyfolsom.cs6601.p3.ngram;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Calendar;
import java.util.Collections;
import java.util.Comparator;
import java.util.Date;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
@@ -14,7 +21,9 @@ import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import net.woodyfolsom.cs6601.p3.StockUtil;
import net.woodyfolsom.cs6601.p3.domain.Headline;
import net.woodyfolsom.cs6601.p3.domain.StockPrice;
public class NGramModel {
static final int MAX_N_GRAM_LENGTH = 3;
@@ -25,6 +34,8 @@ public class NGramModel {
private static final String UNK = "<unk>";
private Map<Integer, NGramDistribution> nGrams;
private Map<Integer, Map<NGram,Double>> nGramPriceAvg;
private int[] totalNGramCounts = new int[MAX_N_GRAM_LENGTH + 1];
private Pattern wordPattern = Pattern.compile("\\w+");
@@ -88,9 +99,13 @@ public class NGramModel {
for (int i = 0; i <= MAX_N_GRAM_LENGTH; i++) {
nGrams.put(i, new NGramDistribution());
}
nGramPriceAvg = new HashMap<Integer, Map<NGram,Double>>();
for (int i = 0; i <= MAX_N_GRAM_LENGTH; i++) {
nGramPriceAvg.put(i, new HashMap<NGram,Double>());
}
}
private void addNGram(int nGramLength, NGram nGram) {
private void addNGram(int nGramLength, NGram nGram, String stockName, Date date, Map<String, Map<Date,StockPrice>> stockTrends) {
if (nGram.size() < nGramLength) {
System.out.println("Cannot create " + nGramLength + "-gram from: "
+ nGram);
@@ -105,10 +120,31 @@ public class NGramModel {
} else {
nGramCounts.put(nGramCopy, 1);
}
Map<NGram, Double> nGramPriceAvgs = nGramPriceAvg.get(nGramLength);
NGram nGramCopy2 = nGram.copy(nGramLength);
//TODO GET NEXT TRADING DAY'S DATE
Date nextDay = StockUtil.getNextTradingDay(date);
StockPrice stockPrice = stockTrends.get(stockName).get(nextDay);
double percentChange;
if (stockPrice == null) {
percentChange = 0.0;
} else {
percentChange = StockUtil.getPercentChange(stockPrice);
}
if (nGramPriceAvgs.containsKey(nGramCopy2)) {
double totalPercentChange = nGramPriceAvgs.get(nGramCopy);
nGramPriceAvgs.put(nGramCopy2, totalPercentChange + percentChange);
} else {
nGramPriceAvgs.put(nGramCopy2, percentChange);
}
}
/**
* Given an arbitrary String, replace punctutation with spaces, remove
* Given an arbitrary String, replace punctuation with spaces, remove
* non-alphanumeric characters, prepend with <START> token, append <END>
* token.
*
@@ -193,12 +229,11 @@ public class NGramModel {
+ " recognized n-grams in verification corpus: " + perplexity);
}
private void generateModel(List<Headline> traininSet, boolean genRandom,
boolean useUnk) throws FileNotFoundException, IOException {
StringBuilder currentLine = new StringBuilder();
List<String> fileByLines = new ArrayList<String>();
private void generateModel(List<Headline> trainingSet, boolean genRandom,
boolean useUnk, Map<String,Map<Date,StockPrice>> stockTrends) throws FileNotFoundException, IOException {
//List<String> fileByLines = new ArrayList<String>();
for (Headline headline : traininSet) {
for (Headline headline : trainingSet) {
String headlineText = headline.getText();
if (headlineText.length() == 0) {
continue;
@@ -206,6 +241,9 @@ public class NGramModel {
String sanitizedLine = sanitize(headline.getText());
// split on whitespace
String[] tokens = sanitizedLine.toLowerCase().split("\\s+");
StringBuilder currentLine = new StringBuilder();
for (String token : tokens) {
if (!isWord(token)) {
continue;
@@ -222,67 +260,67 @@ public class NGramModel {
if (END.equals(word)) {
currentLine.append(word);
fileByLines.add(currentLine.toString());
currentLine = new StringBuilder();
//fileByLines.add(currentLine.toString());
//currentLine = new StringBuilder();
} else {
currentLine.append(word);
currentLine.append(" ");
}
}
}
for (String str : fileByLines) {
String str = currentLine.toString();
System.out.println(str);
NGram currentNgram = new NGram(MAX_N_GRAM_LENGTH);
for (String token : str.split("\\s+")) {
currentNgram.add(token);
for (int i = 0; i <= currentNgram.size(); i++) {
addNGram(currentNgram.size() - i,
currentNgram.subNGram(i, currentNgram.size()));
currentNgram.subNGram(i, currentNgram.size()), headline.getStock(), headline.getDate(), stockTrends);
totalNGramCounts[currentNgram.size() - i]++;
}
}
}
System.out.println("Most common words: ");
List<Entry<NGram, Integer>> unigrams = new ArrayList<Entry<NGram, Integer>>(
nGrams.get(1).entrySet());
Collections.sort(unigrams, new NGramComparator());
List<Entry<NGram, Integer>> bigrams = new ArrayList<Entry<NGram, Integer>>(
nGrams.get(2).entrySet());
Collections.sort(bigrams, new NGramComparator());
List<Entry<NGram, Integer>> trigrams = new ArrayList<Entry<NGram, Integer>>(
nGrams.get(3).entrySet());
Collections.sort(trigrams, new NGramComparator());
for (int i = 1; i <= 10; i++) {
System.out
.println(i
+ ". "
+ unigrams.get(i - 1).getKey()
+ " : "
+ (((double) (unigrams.get(i - 1).getValue()) / totalNGramCounts[1])));
}
for (int i = 1; i <= 10; i++) {
System.out
.println(i
+ ". "
+ bigrams.get(i - 1).getKey()
+ " : "
+ (((double) (bigrams.get(i - 1).getValue()) / totalNGramCounts[2])));
}
for (int i = 1; i <= 10; i++) {
System.out
.println(i
+ ". "
+ trigrams.get(i - 1).getKey()
+ " : "
+ (((double) (trigrams.get(i - 1).getValue()) / totalNGramCounts[3])));
DecimalFormat decFmt = new DecimalFormat("###0.0000");
for (int modelIndex = 1; modelIndex <= 3; modelIndex++) {
File file = new File(modelIndex + "grams.txt");
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(file)));
List<Entry<NGram, Integer>> ngrams = new ArrayList<Entry<NGram, Integer>>(
nGrams.get(modelIndex).entrySet());
Collections.sort(ngrams, new NGramComparator());
System.out.println("Highest frequency " + modelIndex + "-grams:");
for (int i = 1; i <= 10; i++) {
System.out
.println(i
+ ". "
+ ngrams.get(i - 1).getKey()
+ " : "
+ (((double) (ngrams.get(i - 1).getValue()) / totalNGramCounts[1])));
}
Map<NGram,Double> pricesForModel = nGramPriceAvg.get(modelIndex);
for (int nGramIndex = 1; nGramIndex <= ngrams.size(); nGramIndex++) {
NGram key = ngrams.get(nGramIndex - 1).getKey();
writer.write(key.toString());
writer.write(",");
int count = ngrams.get(nGramIndex - 1).getValue();
writer.write(Integer.toString(count));
writer.write(",");
double avgPrice;
try {
avgPrice = pricesForModel.get(key);
System.out.println("Avg price for " + modelIndex + "-gram " + key +": " + avgPrice);
} catch (NullPointerException npe) {
System.out.println("null avgPrice for " + modelIndex + "-gram " + key);
avgPrice = 0.0;
}
writer.write(decFmt.format(avgPrice/(double)count));
writer.write("\n");
}
try {
writer.close();
} catch (IOException ioe) {
System.out.println(ioe.getMessage());
}
}
if (genRandom) {
for (int nGramLength = 1; nGramLength <= MAX_N_GRAM_LENGTH; nGramLength++) {
@@ -333,11 +371,11 @@ public class NGramModel {
}
public void reportModel(List<Headline> trainingSet,
List<Headline> validationSet) {
List<Headline> validationSet, Map<String,Map<Date,StockPrice>> stockTrends) {
try {
NGramModel ngm = new NGramModel();
boolean doCalcPerplexity = true;
ngm.generateModel(trainingSet, !doCalcPerplexity, doCalcPerplexity);
ngm.generateModel(trainingSet, !doCalcPerplexity, doCalcPerplexity, stockTrends);
if (doCalcPerplexity) {
for (int i = 1; i <= MAX_N_GRAM_LENGTH; i++) {
ngm.calcPerplexity(validationSet, i, true);

View File

@@ -6,8 +6,11 @@ import java.util.List;
import net.woodyfolsom.cs6601.p3.domain.Headline;
public interface HeadlineService {
boolean assignRandomDatasets(int training, int test, int validation);
int getCount();
int getCount(int dataset);
int insertHeadline(Headline headline);
int[] insertHeadlines(List<Headline> headline);
List<Headline> getHeadlines(String stock, Date date);
List<Headline> getHeadlines(String stock, Date startDate, Date endDate);
List<Headline> getHeadlines(String stock, Date startDate, Date endDate, int dataset);
}

View File

@@ -34,7 +34,22 @@ public class MySQLHeadlineServiceImpl implements HeadlineService {
}
@Override
public List<Headline> getHeadlines(String stock, Date startDate, Date endDate) {
return headlineDao.select(stock, startDate, endDate);
public List<Headline> getHeadlines(String stock, Date startDate, Date endDate, int dataset) {
return headlineDao.select(stock, startDate, endDate, dataset);
}
@Override
public boolean assignRandomDatasets(int training, int test, int validation) {
return headlineDao.assignRandomDatasets(training, test, validation);
}
@Override
public int getCount() {
return headlineDao.getCount();
}
@Override
public int getCount(int dataset) {
return headlineDao.getCount(dataset);
}
}

View File

@@ -89,7 +89,22 @@ public class YahooHeadlineServiceImpl implements HeadlineService {
@Override
public List<Headline> getHeadlines(String stock, Date startDate,
Date endDate) {
Date endDate, int dataset) {
throw new UnsupportedOperationException("This implementation does not support getting headlines for a date range.");
}
@Override
public boolean assignRandomDatasets(int training, int test, int validation) {
throw new UnsupportedOperationException("This implementation does not support this method.");
}
@Override
public int getCount() {
throw new UnsupportedOperationException("This implementation does not support this method");
}
@Override
public int getCount(int dataset) {
throw new UnsupportedOperationException("This implementation does not support this method");
}
}