Removed the dangerous code that populates the database - it must be retrieved from git history if the headlines table ever needs to be repopulated.

(This should not need to happen).
PricePoller and ValidationSetCreator generate the 1, 2, 3-grams.txt and validation.txt files, respectively.
MySQLHeadlineDaoImplTest reshuffles the training and validation datasets in a 60-40 ratio.
This commit is contained in:
Woody Folsom
2012-04-22 21:24:01 -04:00
parent 6e3680426e
commit 5270359b10
11 changed files with 25 additions and 313 deletions

View File

@@ -1,164 +0,0 @@
package net.woodyfolsom.cs6601.p3;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.Charset;
import java.text.DateFormat;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Calendar;
import java.util.Date;
import java.util.List;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.context.support.FileSystemXmlApplicationContext;
import org.springframework.stereotype.Component;
import net.woodyfolsom.cs6601.p3.domain.Company;
import net.woodyfolsom.cs6601.p3.domain.Headline;
import net.woodyfolsom.cs6601.p3.svc.HeadlineService;
/**
 * Command-line entry point that walks every company listed in
 * {@code stock_symbols.csv} and every day of a date range given on the
 * command line.
 *
 * <p>Usage: {@code java -jar cs6601p3.jar [insert|delete] mm/dd/yyyy-mm/dd/yyyy}
 *
 * <p>The code that actually pulled headlines and wrote them to the database is
 * commented out below, so the per-day loop currently performs no work, and
 * delete mode only asks for confirmation and then exits.
 *
 * <p>Failures terminate the JVM with one of the exit codes declared as
 * constants in this class.
 */
@Component
public class HeadlinePuller {
    private static final File stockSymbolsCSV = new File("stock_symbols.csv");

    // Process exit codes.
    private static final int INVALID_END_DATE = 1;
    private static final int INVALID_MODE = 2;
    private static final int INVALID_START_DATE = 3;
    private static final int IO_EXCEPTION = 4;
    private static final int NO_ARGS = 5;
    private static final int STOCK_SYMBOL_CSV_NOT_FOUND = 6;
    // The date-range argument did not contain both a start and an end date.
    private static final int INVALID_DATE_RANGE = 7;

    //@Autowired
    //HeadlineService mySQLHeadlineServiceImpl;
    //@Autowired
    //HeadlineService yahooHeadlineServiceImpl;

    /** Prints the expected command-line syntax to standard output. */
    private static void printUsage() {
        System.out
                .println("Usage: java -jar cs6601p3.jar [insert|delete] mm/dd/yyyy-mm/dd/yyyy");
    }

    /** Operating mode selected by the first command-line argument. */
    private enum MODE {
        insert, invalid, delete
    }

    /**
     * Parses the arguments, optionally asks for delete confirmation, then
     * iterates over all companies and all days in the requested range.
     *
     * @param args exactly two values: the mode ({@code insert} or
     *             {@code delete}) and a date range formatted as
     *             {@code MM/dd/yyyy-MM/dd/yyyy}
     */
    public static void main(String... args) {
        MODE mode = MODE.invalid;
        if (args.length != 2) {
            printUsage();
            System.exit(NO_ARGS);
        } else {
            try {
                mode = MODE.valueOf(args[0]);
            } catch (Exception ex) {
                System.out.println("Invalid mode: " + args[0]);
            }
        }

        if (mode == MODE.invalid) {
            System.exit(INVALID_MODE);
        }

        if (mode == MODE.delete) {
            System.out.println("Mode = delete. All data will be purged from HEADLINES table. Continue? [y/n]");
            byte[] buf = new byte[10];
            try {
                int read = System.in.read(buf, 0, buf.length);
                // read is -1 when standard input is at end of stream; treat
                // that as an empty (i.e. negative) answer instead of crashing.
                String conf = read > 0
                        ? new String(buf, 0, read, Charset.defaultCharset())
                        : "";
                System.out.println("CONF = '" + conf + "'");
                if (conf.startsWith("y")) {
                    System.out.println("Delete mode confirmed. Continuing...");
                    // The purge itself has been removed, so there is nothing
                    // further to do after confirmation.
                    System.exit(0);
                } else {
                    System.out.println("Delete mode cancelled.");
                    System.exit(0);
                }
            } catch (IOException ioe) {
                System.exit(IO_EXCEPTION);
            }
        }

        String[] dateFields = args[1].split("-");
        if (dateFields.length < 2) {
            // Without this check a range lacking '-' would fail with an
            // ArrayIndexOutOfBoundsException below.
            System.out.println("Invalid date range: " + args[1]);
            printUsage();
            System.exit(INVALID_DATE_RANGE);
        }

        DateFormat dateFormat = new SimpleDateFormat("MM/dd/yyyy");

        Date startDate = null;
        try {
            startDate = dateFormat.parse(dateFields[0]);
        } catch (ParseException pe) {
            System.out.println("Invalid start date: " + dateFields[0]);
            System.exit(INVALID_START_DATE);
        }

        Date endDate = null;
        try {
            endDate = dateFormat.parse(dateFields[1]);
        } catch (ParseException pe) {
            // Report the end-date field (index 1), not the start-date field.
            System.out.println("Invalid end date: " + dateFields[1]);
            System.exit(INVALID_END_DATE);
        }

        ApplicationContext context = new FileSystemXmlApplicationContext(
                new String[] { "AppContext.xml" });
        HeadlinePuller headlinePuller = context.getBean(HeadlinePuller.class);
        Calendar calendar = Calendar.getInstance();

        try {
            List<Company> fortune50 = headlinePuller
                    .getFortune50(stockSymbolsCSV);
            for (Company company : fortune50) {
                System.out.println("Getting headlines for Fortune 50 company #"
                        + company.getId() + " (" + company.getName() + ")...");
                Date today;
                // Step one calendar day at a time from startDate through
                // endDate, inclusive.
                for (calendar.setTime(startDate); (today = calendar.getTime())
                        .compareTo(endDate) <= 0; calendar
                        .add(Calendar.DATE, 1)) {
                    //List<Headline> headlines = headlinePuller.pullHeadlines(
                    //        company.getStockSymbol(), today);
                    //int[] updates = headlinePuller.mySQLHeadlineServiceImpl.insertHeadlines(headlines);
                    //System.out.println(updates.length + " rows updated");
                }
            }
        } catch (FileNotFoundException fnfe) {
            System.out.println("Stock symbol CSV file does not exist: "
                    + stockSymbolsCSV);
            System.exit(STOCK_SYMBOL_CSV_NOT_FOUND);
        } catch (IOException ioe) {
            // The file exists but could not be read; say so rather than
            // repeating the "does not exist" message.
            System.out.println("Error reading stock symbol CSV file: "
                    + stockSymbolsCSV);
            System.exit(IO_EXCEPTION);
        }
    }

    //private List<Headline> pullHeadlines(String stockSymbol, Date date) {
    //    List<Headline> headlines = yahooHeadlineServiceImpl.getHeadlines(
    //            stockSymbol, date);
    //    System.out.println("Pulled " + headlines.size() + " headlines for " + stockSymbol + " on " + date);
    //    return headlines;
    //}

    /**
     * Reads the company list from a CSV file with exactly three
     * comma-separated values per non-empty line. The first value must be an
     * integer id; the remaining two are passed to the {@link Company}
     * constructor unchanged.
     *
     * @param csvFile the CSV file to read
     * @return one {@link Company} per non-empty line, in file order
     * @throws FileNotFoundException if {@code csvFile} cannot be opened
     * @throws IOException if reading the file fails
     * @throws RuntimeException if a line does not contain exactly three values
     */
    private List<Company> getFortune50(File csvFile)
            throws FileNotFoundException, IOException {
        List<Company> fortune50 = new ArrayList<Company>();
        BufferedReader buf = new BufferedReader(new InputStreamReader(
                new FileInputStream(csvFile)));
        try {
            String csvline = null;
            while ((csvline = buf.readLine()) != null) {
                if (csvline.length() == 0) {
                    continue;
                }
                String[] fields = csvline.split(",");
                if (fields.length != 3) {
                    throw new RuntimeException(
                            "Badly formatted csv file name (3 values expected): "
                                    + csvline);
                }
                int id = Integer.parseInt(fields[0]);
                fortune50.add(new Company(id, fields[1], fields[2]));
            }
        } finally {
            // Always release the file handle, including when a malformed
            // line aborts parsing.
            buf.close();
        }
        return fortune50;
    }
}

View File

@@ -1,108 +0,0 @@
package net.woodyfolsom.cs6601.p3;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.text.DateFormat;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.context.support.FileSystemXmlApplicationContext;
import org.springframework.stereotype.Component;
import net.woodyfolsom.cs6601.p3.domain.Company;
import net.woodyfolsom.cs6601.p3.domain.Headline;
import net.woodyfolsom.cs6601.p3.ngram.NGramModel;
import net.woodyfolsom.cs6601.p3.svc.HeadlineService;
/**
 * Command-line entry point that loads, for every company listed in
 * {@code stock_symbols.csv}, the stored headlines between 2012-01-01 and
 * 2012-04-14 for dataset ids 1 and 2, and accumulates them into two lists.
 *
 * <p>The call that would hand those lists to the n-gram model is commented
 * out, so the collected headlines are currently not used further.
 */
@Component
public class ModelGenerator {
    private static final File stockSymbolsCSV = new File("stock_symbols.csv");

    // Process exit codes.
    private static final int INVALID_DATE = 1;
    private static final int IO_EXCEPTION = 2;
    private static final int STOCK_SYMBOL_CSV_NOT_FOUND = 3;

    @Autowired
    HeadlineService mySQLHeadlineServiceImpl;

    private NGramModel ngramModel = new NGramModel();

    /**
     * Builds the Spring context from {@code AppContext.xml}, then pulls the
     * headlines of each company for the fixed date range.
     *
     * @param args ignored
     */
    public static void main(String... args) {
        ApplicationContext context = new FileSystemXmlApplicationContext(
                new String[] { "AppContext.xml" });
        ModelGenerator modelGenerator = context.getBean(ModelGenerator.class);

        DateFormat dateFmt = new SimpleDateFormat("yyyy-MM-dd");
        Date startDate = null;
        Date endDate = null;
        try {
            startDate = dateFmt.parse("2012-01-01");
            endDate = dateFmt.parse("2012-04-14");
        } catch (ParseException pe) {
            System.exit(INVALID_DATE);
        }

        List<Headline> trainingSet = new ArrayList<Headline>();
        //actually, this is the TEST dataset
        List<Headline> testSet = new ArrayList<Headline>();

        try {
            List<Company> fortune50 = modelGenerator
                    .getFortune50(stockSymbolsCSV);
            for (Company company : fortune50) {
                System.out.println("Getting headlines for Fortune 50 company #"
                        + company.getId() + " (" + company.getName() + ")...");
                // Last argument is the dataset id (1 here, 2 below).
                List<Headline> coTrainingSet = modelGenerator.mySQLHeadlineServiceImpl
                        .getHeadlines(company.getStockSymbol(), startDate, endDate, 1);
                System.out.println("Pulled " + coTrainingSet.size() + " headlines for "
                        + company.getStockSymbol() + " from " + startDate + " to " + endDate);
                List<Headline> coTestSet = modelGenerator.mySQLHeadlineServiceImpl
                        .getHeadlines(company.getStockSymbol(), startDate, endDate, 2);
                trainingSet.addAll(coTrainingSet);
                testSet.addAll(coTestSet);
            }
        } catch (FileNotFoundException fnfe) {
            System.out.println("Stock symbol CSV file does not exist: "
                    + stockSymbolsCSV);
            System.exit(STOCK_SYMBOL_CSV_NOT_FOUND);
        } catch (IOException ioe) {
            // The file exists but could not be read; say so rather than
            // repeating the "does not exist" message.
            System.out.println("Error reading stock symbol CSV file: "
                    + stockSymbolsCSV);
            System.exit(IO_EXCEPTION);
        }
        //modelGenerator.ngramModel.reportModel(trainingSet, testSet);
    }

    /**
     * Reads the company list from a CSV file with exactly three
     * comma-separated values per non-empty line. The first value must be an
     * integer id; the remaining two are passed to the {@link Company}
     * constructor unchanged.
     *
     * @param csvFile the CSV file to read
     * @return one {@link Company} per non-empty line, in file order
     * @throws FileNotFoundException if {@code csvFile} cannot be opened
     * @throws IOException if reading the file fails
     * @throws RuntimeException if a line does not contain exactly three values
     */
    private List<Company> getFortune50(File csvFile)
            throws FileNotFoundException, IOException {
        List<Company> fortune50 = new ArrayList<Company>();
        BufferedReader buf = new BufferedReader(new InputStreamReader(
                new FileInputStream(csvFile)));
        try {
            String csvline = null;
            while ((csvline = buf.readLine()) != null) {
                if (csvline.length() == 0) {
                    continue;
                }
                String[] fields = csvline.split(",");
                if (fields.length != 3) {
                    throw new RuntimeException(
                            "Badly formatted csv file name (3 values expected): "
                                    + csvline);
                }
                int id = Integer.parseInt(fields[0]);
                fortune50.add(new Company(id, fields[1], fields[2]));
            }
        } finally {
            // Always release the file handle, including when a malformed
            // line aborts parsing.
            buf.close();
        }
        return fortune50;
    }
}

View File

@@ -135,8 +135,8 @@ public class PricePoller {
List<Headline> coTrainingSet = modelGenerator.mySQLHeadlineServiceImpl.getHeadlines(company.getStockSymbol(), startDate, endDate, 1);
System.out.println("Pulled " + coTrainingSet.size() + " TRAINING headlines for "
+ company.getStockSymbol() + " from " + startDate + " to " + endDate);
List<Headline> coTestSet = modelGenerator.mySQLHeadlineServiceImpl.getHeadlines(company.getStockSymbol(), startDate, endDate, 2);
System.out.println("Pulled " + coTestSet.size() + " TEST headlines for "
List<Headline> coTestSet = modelGenerator.mySQLHeadlineServiceImpl.getHeadlines(company.getStockSymbol(), startDate, endDate, 3);
System.out.println("Pulled " + coTestSet.size() + " VALIDATION headlines for "
+ company.getStockSymbol() + " from " + startDate + " to " + endDate);
trainingSet.addAll(coTrainingSet);

View File

@@ -21,18 +21,16 @@ import java.util.List;
import java.util.Map;
import java.util.Set;
import net.woodyfolsom.cs6601.p3.domain.Company;
import net.woodyfolsom.cs6601.p3.domain.Headline;
import net.woodyfolsom.cs6601.p3.domain.StockPrice;
import net.woodyfolsom.cs6601.p3.svc.HeadlineService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.context.support.FileSystemXmlApplicationContext;
import org.springframework.stereotype.Component;
import net.woodyfolsom.cs6601.p3.domain.Company;
import net.woodyfolsom.cs6601.p3.domain.Headline;
import net.woodyfolsom.cs6601.p3.domain.StockPrice;
import net.woodyfolsom.cs6601.p3.ngram.NGram;
import net.woodyfolsom.cs6601.p3.ngram.NGramModel;
import net.woodyfolsom.cs6601.p3.svc.HeadlineService;
@Component
public class ValidationSetCreator {
private static final File stockSymbolsCSV = new File("stock_symbols.csv");

View File

@@ -6,7 +6,7 @@ import java.util.List;
import net.woodyfolsom.cs6601.p3.domain.Headline;
public interface HeadlineDao {
boolean assignRandomDatasets(int training, int test, int validation);
boolean assignRandomDatasets(int training/*, int test*/, int validation);
int getCount();
int getCount(int dataset);
int deleteById(int id);

View File

@@ -31,20 +31,20 @@ public class HeadlineDaoImpl implements HeadlineDao {
private static final String ASSIGN_RANDOM_PCT_QRY = "update headlines set dataset = (select FLOOR(RAND() * (200 - 101) + 101))";
private static final String REMAP_TRAINING_QRY = "update headlines set dataset = 1 where dataset >= 101 and dataset <= (100 + ?)";
private static final String REMAP_TEST_QRY = "update headlines set dataset = 2 where dataset >= (100 + ?) and dataset <= (100 + ?)";
private static final String REMAP_VAL_QRY = "update headlines set dataset = 3 where dataset >= (100 + ?) and dataset <= 200";
//private static final String REMAP_TEST_QRY = "update headlines set dataset = 2 where dataset >= (100 + ?) and dataset <= (100 + ?)";
private static final String REMAP_VAL_QRY = "update headlines set dataset = 3 where dataset > (100 + ?) and dataset <= 200";
private JdbcTemplate jdbcTemplate;
@Override
public boolean assignRandomDatasets(int training, int test, int validation) {
if (training + test + validation != 100) {
public boolean assignRandomDatasets(int training/*, int test*/, int validation) {
if (training /*+ test*/ + validation != 100) {
return false;
}
jdbcTemplate.update(ASSIGN_RANDOM_PCT_QRY);
jdbcTemplate.update(REMAP_TRAINING_QRY,training);
jdbcTemplate.update(REMAP_TEST_QRY,training,training+test);
jdbcTemplate.update(REMAP_VAL_QRY,training+test);
//jdbcTemplate.update(REMAP_TEST_QRY,training,training+test);
jdbcTemplate.update(REMAP_VAL_QRY,training/*+test*/);
return true;
}

View File

@@ -6,7 +6,7 @@ import java.util.List;
import net.woodyfolsom.cs6601.p3.domain.Headline;
public interface HeadlineService {
boolean assignRandomDatasets(int training, int test, int validation);
boolean assignRandomDatasets(int training/*, int test*/, int validation);
int getCount();
int getCount(int dataset);
int insertHeadline(Headline headline);

View File

@@ -39,8 +39,8 @@ public class MySQLHeadlineServiceImpl implements HeadlineService {
}
@Override
public boolean assignRandomDatasets(int training, int test, int validation) {
return headlineDao.assignRandomDatasets(training, test, validation);
public boolean assignRandomDatasets(int training/*, int test*/, int validation) {
return headlineDao.assignRandomDatasets(training/*, test*/, validation);
}
@Override

View File

@@ -6,7 +6,6 @@ import java.io.InputStreamReader;
import java.net.HttpURLConnection;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLConnection;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
@@ -15,12 +14,10 @@ import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import net.woodyfolsom.cs6601.p3.dao.HeadlineDao;
import net.woodyfolsom.cs6601.p3.domain.Headline;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
@Service
@@ -94,7 +91,7 @@ public class YahooHeadlineServiceImpl implements HeadlineService {
}
@Override
public boolean assignRandomDatasets(int training, int test, int validation) {
public boolean assignRandomDatasets(int training/*, int test*/, int validation) {
throw new UnsupportedOperationException("This implementation does not support this method.");
}