Removed dangerous code to populate the database - this must be retrieved from git in order to repopulate the headlines table.
(This should not need to happen). PricePoller and ValidationSetCreator generate the 1, 2, 3-grams.txt and validation.txt files, respectively. MySQLHeadlineDaoImplTest reshuffles the training, validation datasets in 60-40 ratio.
This commit is contained in:
@@ -1,164 +0,0 @@
|
|||||||
package net.woodyfolsom.cs6601.p3;
|
|
||||||
|
|
||||||
import java.io.BufferedReader;
|
|
||||||
import java.io.File;
|
|
||||||
import java.io.FileInputStream;
|
|
||||||
import java.io.FileNotFoundException;
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.io.InputStreamReader;
|
|
||||||
import java.nio.charset.Charset;
|
|
||||||
import java.text.DateFormat;
|
|
||||||
import java.text.ParseException;
|
|
||||||
import java.text.SimpleDateFormat;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Calendar;
|
|
||||||
import java.util.Date;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
|
||||||
import org.springframework.context.ApplicationContext;
|
|
||||||
import org.springframework.context.support.FileSystemXmlApplicationContext;
|
|
||||||
import org.springframework.stereotype.Component;
|
|
||||||
|
|
||||||
import net.woodyfolsom.cs6601.p3.domain.Company;
|
|
||||||
import net.woodyfolsom.cs6601.p3.domain.Headline;
|
|
||||||
import net.woodyfolsom.cs6601.p3.svc.HeadlineService;
|
|
||||||
|
|
||||||
@Component
|
|
||||||
public class HeadlinePuller {
|
|
||||||
private static final File stockSymbolsCSV = new File("stock_symbols.csv");
|
|
||||||
|
|
||||||
private static final int INVALID_END_DATE = 1;
|
|
||||||
private static final int INVALID_MODE = 2;
|
|
||||||
private static final int INVALID_START_DATE = 3;
|
|
||||||
private static final int IO_EXCEPTION = 4;
|
|
||||||
private static final int NO_ARGS = 5;
|
|
||||||
private static final int STOCK_SYMBOL_CSV_NOT_FOUND = 6;
|
|
||||||
|
|
||||||
//@Autowired
|
|
||||||
//HeadlineService mySQLHeadlineServiceImpl;
|
|
||||||
//@Autowired
|
|
||||||
//HeadlineService yahooHeadlineServiceImpl;
|
|
||||||
|
|
||||||
private static void printUsage() {
|
|
||||||
System.out
|
|
||||||
.println("Usage: java -jar cs6601p3.jar [insert|delete] mm/dd/yyyy-mm/dd/yyyy");
|
|
||||||
}
|
|
||||||
|
|
||||||
private enum MODE {
|
|
||||||
insert, invalid, delete
|
|
||||||
}
|
|
||||||
|
|
||||||
public static void main(String... args) {
|
|
||||||
MODE mode = MODE.invalid;
|
|
||||||
if (args.length != 2) {
|
|
||||||
printUsage();
|
|
||||||
System.exit(NO_ARGS);
|
|
||||||
} else {
|
|
||||||
try {
|
|
||||||
mode = MODE.valueOf(args[0]);
|
|
||||||
} catch (Exception ex) {
|
|
||||||
System.out.println("Invalid mode: " + args[0]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (mode == MODE.invalid) {
|
|
||||||
System.exit(INVALID_MODE);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (mode == MODE.delete) {
|
|
||||||
System.out.println("Mode = delete. All data will be purged from HEADLINES table. Continue? [y/n]");
|
|
||||||
byte[] buf = new byte[10];
|
|
||||||
try {
|
|
||||||
int read = System.in.read(buf,0,10);
|
|
||||||
String conf = new String(buf,0,read,Charset.defaultCharset());
|
|
||||||
System.out.println("CONF = '" + conf +"'");
|
|
||||||
if (conf.charAt(0) == 'y') {
|
|
||||||
System.out.println("Delete mode confirmed. Continuing...");
|
|
||||||
System.exit(0);
|
|
||||||
} else {
|
|
||||||
System.out.println("Delete mode cancelled.");
|
|
||||||
System.exit(0);
|
|
||||||
}
|
|
||||||
} catch (IOException ioe) {
|
|
||||||
System.exit(IO_EXCEPTION);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
String[] dateFields = args[1].split("-");
|
|
||||||
DateFormat dateFormat = new SimpleDateFormat("MM/dd/yyyy");
|
|
||||||
Date startDate = null;
|
|
||||||
try {
|
|
||||||
startDate = dateFormat.parse(dateFields[0]);
|
|
||||||
} catch (ParseException pe) {
|
|
||||||
System.out.println("Invalid start date: " + dateFields[0]);
|
|
||||||
System.exit(INVALID_START_DATE);
|
|
||||||
}
|
|
||||||
Date endDate = null;
|
|
||||||
try {
|
|
||||||
endDate = dateFormat.parse(dateFields[1]);
|
|
||||||
} catch (ParseException pe) {
|
|
||||||
System.out.println("Invalid end date: " + dateFields[0]);
|
|
||||||
System.exit(INVALID_END_DATE);
|
|
||||||
}
|
|
||||||
|
|
||||||
ApplicationContext context = new FileSystemXmlApplicationContext(
|
|
||||||
new String[] { "AppContext.xml" });
|
|
||||||
HeadlinePuller headlinePuller = context.getBean(HeadlinePuller.class);
|
|
||||||
Calendar calendar = Calendar.getInstance();
|
|
||||||
try {
|
|
||||||
List<Company> fortune50 = headlinePuller
|
|
||||||
.getFortune50(stockSymbolsCSV);
|
|
||||||
for (Company company : fortune50) {
|
|
||||||
System.out.println("Getting headlines for Fortune 50 company #"
|
|
||||||
+ company.getId() + " (" + company.getName() + ")...");
|
|
||||||
Date today;
|
|
||||||
for (calendar.setTime(startDate); (today = calendar.getTime())
|
|
||||||
.compareTo(endDate) <= 0; calendar
|
|
||||||
.add(Calendar.DATE, 1)) {
|
|
||||||
//List<Headline> headlines = headlinePuller.pullHeadlines(
|
|
||||||
// company.getStockSymbol(), today);
|
|
||||||
//int[] updates = headlinePuller.mySQLHeadlineServiceImpl.insertHeadlines(headlines);
|
|
||||||
//System.out.println(updates.length + " rows updated");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} catch (FileNotFoundException fnfe) {
|
|
||||||
System.out.println("Stock symbol CSV file does not exist: "
|
|
||||||
+ stockSymbolsCSV);
|
|
||||||
System.exit(STOCK_SYMBOL_CSV_NOT_FOUND);
|
|
||||||
} catch (IOException ioe) {
|
|
||||||
System.out.println("Stock symbol CSV file does not exist: "
|
|
||||||
+ stockSymbolsCSV);
|
|
||||||
System.exit(IO_EXCEPTION);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//private List<Headline> pullHeadlines(String stockSymbol, Date date) {
|
|
||||||
//List<Headline> headlines = yahooHeadlineServiceImpl.getHeadlines(
|
|
||||||
// stockSymbol, date);
|
|
||||||
//System.out.println("Pulled " + headlines.size() + " headlines for " + stockSymbol + " on " + date);
|
|
||||||
//return headlines;
|
|
||||||
//}
|
|
||||||
|
|
||||||
private List<Company> getFortune50(File csvFile)
|
|
||||||
throws FileNotFoundException, IOException {
|
|
||||||
List<Company> fortune50 = new ArrayList<Company>();
|
|
||||||
FileInputStream fis = new FileInputStream(csvFile);
|
|
||||||
InputStreamReader reader = new InputStreamReader(fis);
|
|
||||||
BufferedReader buf = new BufferedReader(reader);
|
|
||||||
String csvline = null;
|
|
||||||
while ((csvline = buf.readLine()) != null) {
|
|
||||||
if (csvline.length() == 0) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
String[] fields = csvline.split(",");
|
|
||||||
if (fields.length != 3) {
|
|
||||||
throw new RuntimeException(
|
|
||||||
"Badly formatted csv file name (3 values expected): "
|
|
||||||
+ csvline);
|
|
||||||
}
|
|
||||||
int id = Integer.valueOf(fields[0]);
|
|
||||||
fortune50.add(new Company(id, fields[1], fields[2]));
|
|
||||||
}
|
|
||||||
return fortune50;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,108 +0,0 @@
|
|||||||
package net.woodyfolsom.cs6601.p3;
|
|
||||||
|
|
||||||
import java.io.BufferedReader;
|
|
||||||
import java.io.File;
|
|
||||||
import java.io.FileInputStream;
|
|
||||||
import java.io.FileNotFoundException;
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.io.InputStreamReader;
|
|
||||||
import java.text.DateFormat;
|
|
||||||
import java.text.ParseException;
|
|
||||||
import java.text.SimpleDateFormat;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Date;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
|
||||||
import org.springframework.context.ApplicationContext;
|
|
||||||
import org.springframework.context.support.FileSystemXmlApplicationContext;
|
|
||||||
import org.springframework.stereotype.Component;
|
|
||||||
|
|
||||||
import net.woodyfolsom.cs6601.p3.domain.Company;
|
|
||||||
import net.woodyfolsom.cs6601.p3.domain.Headline;
|
|
||||||
import net.woodyfolsom.cs6601.p3.ngram.NGramModel;
|
|
||||||
import net.woodyfolsom.cs6601.p3.svc.HeadlineService;
|
|
||||||
|
|
||||||
@Component
|
|
||||||
public class ModelGenerator {
|
|
||||||
private static final File stockSymbolsCSV = new File("stock_symbols.csv");
|
|
||||||
|
|
||||||
private static final int INVALID_DATE = 1;
|
|
||||||
private static final int IO_EXCEPTION = 2;
|
|
||||||
private static final int STOCK_SYMBOL_CSV_NOT_FOUND = 3;
|
|
||||||
|
|
||||||
@Autowired
|
|
||||||
HeadlineService mySQLHeadlineServiceImpl;
|
|
||||||
|
|
||||||
private NGramModel ngramModel = new NGramModel();
|
|
||||||
|
|
||||||
public static void main(String... args) {
|
|
||||||
|
|
||||||
ApplicationContext context = new FileSystemXmlApplicationContext(
|
|
||||||
new String[] { "AppContext.xml" });
|
|
||||||
ModelGenerator modelGenerator = context.getBean(ModelGenerator.class);
|
|
||||||
DateFormat dateFmt = new SimpleDateFormat("yyyy-MM-dd");
|
|
||||||
|
|
||||||
Date startDate = null;
|
|
||||||
Date endDate = null;
|
|
||||||
try {
|
|
||||||
startDate = dateFmt.parse("2012-01-01");
|
|
||||||
endDate = dateFmt.parse("2012-04-14");
|
|
||||||
} catch (ParseException pe) {
|
|
||||||
System.exit(INVALID_DATE);
|
|
||||||
}
|
|
||||||
|
|
||||||
List<Headline> trainingSet = new ArrayList<Headline>();
|
|
||||||
//actually, this is the TEST dataset
|
|
||||||
List<Headline> testSet = new ArrayList<Headline>();
|
|
||||||
|
|
||||||
try {
|
|
||||||
List<Company> fortune50 = modelGenerator
|
|
||||||
.getFortune50(stockSymbolsCSV);
|
|
||||||
for (Company company : fortune50) {
|
|
||||||
System.out.println("Getting headlines for Fortune 50 company #"
|
|
||||||
+ company.getId() + " (" + company.getName() + ")...");
|
|
||||||
List<Headline> coTrainingSet = modelGenerator.mySQLHeadlineServiceImpl.getHeadlines(company.getStockSymbol(), startDate, endDate, 1);
|
|
||||||
System.out.println("Pulled " + coTrainingSet.size() + " headlines for "
|
|
||||||
+ company.getStockSymbol() + " from " + startDate + " to " + endDate);
|
|
||||||
List<Headline> coTestSet = modelGenerator.mySQLHeadlineServiceImpl.getHeadlines(company.getStockSymbol(), startDate, endDate, 2);
|
|
||||||
|
|
||||||
trainingSet.addAll(coTrainingSet);
|
|
||||||
testSet.addAll(coTestSet);
|
|
||||||
}
|
|
||||||
} catch (FileNotFoundException fnfe) {
|
|
||||||
System.out.println("Stock symbol CSV file does not exist: "
|
|
||||||
+ stockSymbolsCSV);
|
|
||||||
System.exit(STOCK_SYMBOL_CSV_NOT_FOUND);
|
|
||||||
} catch (IOException ioe) {
|
|
||||||
System.out.println("Stock symbol CSV file does not exist: "
|
|
||||||
+ stockSymbolsCSV);
|
|
||||||
System.exit(IO_EXCEPTION);
|
|
||||||
}
|
|
||||||
|
|
||||||
//modelGenerator.ngramModel.reportModel(trainingSet, testSet);
|
|
||||||
}
|
|
||||||
|
|
||||||
private List<Company> getFortune50(File csvFile)
|
|
||||||
throws FileNotFoundException, IOException {
|
|
||||||
List<Company> fortune50 = new ArrayList<Company>();
|
|
||||||
FileInputStream fis = new FileInputStream(csvFile);
|
|
||||||
InputStreamReader reader = new InputStreamReader(fis);
|
|
||||||
BufferedReader buf = new BufferedReader(reader);
|
|
||||||
String csvline = null;
|
|
||||||
while ((csvline = buf.readLine()) != null) {
|
|
||||||
if (csvline.length() == 0) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
String[] fields = csvline.split(",");
|
|
||||||
if (fields.length != 3) {
|
|
||||||
throw new RuntimeException(
|
|
||||||
"Badly formatted csv file name (3 values expected): "
|
|
||||||
+ csvline);
|
|
||||||
}
|
|
||||||
int id = Integer.valueOf(fields[0]);
|
|
||||||
fortune50.add(new Company(id, fields[1], fields[2]));
|
|
||||||
}
|
|
||||||
return fortune50;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -135,8 +135,8 @@ public class PricePoller {
|
|||||||
List<Headline> coTrainingSet = modelGenerator.mySQLHeadlineServiceImpl.getHeadlines(company.getStockSymbol(), startDate, endDate, 1);
|
List<Headline> coTrainingSet = modelGenerator.mySQLHeadlineServiceImpl.getHeadlines(company.getStockSymbol(), startDate, endDate, 1);
|
||||||
System.out.println("Pulled " + coTrainingSet.size() + " TRAINING headlines for "
|
System.out.println("Pulled " + coTrainingSet.size() + " TRAINING headlines for "
|
||||||
+ company.getStockSymbol() + " from " + startDate + " to " + endDate);
|
+ company.getStockSymbol() + " from " + startDate + " to " + endDate);
|
||||||
List<Headline> coTestSet = modelGenerator.mySQLHeadlineServiceImpl.getHeadlines(company.getStockSymbol(), startDate, endDate, 2);
|
List<Headline> coTestSet = modelGenerator.mySQLHeadlineServiceImpl.getHeadlines(company.getStockSymbol(), startDate, endDate, 3);
|
||||||
System.out.println("Pulled " + coTestSet.size() + " TEST headlines for "
|
System.out.println("Pulled " + coTestSet.size() + " VALIDATION headlines for "
|
||||||
+ company.getStockSymbol() + " from " + startDate + " to " + endDate);
|
+ company.getStockSymbol() + " from " + startDate + " to " + endDate);
|
||||||
|
|
||||||
trainingSet.addAll(coTrainingSet);
|
trainingSet.addAll(coTrainingSet);
|
||||||
|
|||||||
@@ -21,18 +21,16 @@ import java.util.List;
|
|||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|
||||||
|
import net.woodyfolsom.cs6601.p3.domain.Company;
|
||||||
|
import net.woodyfolsom.cs6601.p3.domain.Headline;
|
||||||
|
import net.woodyfolsom.cs6601.p3.domain.StockPrice;
|
||||||
|
import net.woodyfolsom.cs6601.p3.svc.HeadlineService;
|
||||||
|
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.context.ApplicationContext;
|
import org.springframework.context.ApplicationContext;
|
||||||
import org.springframework.context.support.FileSystemXmlApplicationContext;
|
import org.springframework.context.support.FileSystemXmlApplicationContext;
|
||||||
import org.springframework.stereotype.Component;
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
import net.woodyfolsom.cs6601.p3.domain.Company;
|
|
||||||
import net.woodyfolsom.cs6601.p3.domain.Headline;
|
|
||||||
import net.woodyfolsom.cs6601.p3.domain.StockPrice;
|
|
||||||
import net.woodyfolsom.cs6601.p3.ngram.NGram;
|
|
||||||
import net.woodyfolsom.cs6601.p3.ngram.NGramModel;
|
|
||||||
import net.woodyfolsom.cs6601.p3.svc.HeadlineService;
|
|
||||||
|
|
||||||
@Component
|
@Component
|
||||||
public class ValidationSetCreator {
|
public class ValidationSetCreator {
|
||||||
private static final File stockSymbolsCSV = new File("stock_symbols.csv");
|
private static final File stockSymbolsCSV = new File("stock_symbols.csv");
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import java.util.List;
|
|||||||
import net.woodyfolsom.cs6601.p3.domain.Headline;
|
import net.woodyfolsom.cs6601.p3.domain.Headline;
|
||||||
|
|
||||||
public interface HeadlineDao {
|
public interface HeadlineDao {
|
||||||
boolean assignRandomDatasets(int training, int test, int validation);
|
boolean assignRandomDatasets(int training/*, int test*/, int validation);
|
||||||
int getCount();
|
int getCount();
|
||||||
int getCount(int dataset);
|
int getCount(int dataset);
|
||||||
int deleteById(int id);
|
int deleteById(int id);
|
||||||
|
|||||||
@@ -31,20 +31,20 @@ public class HeadlineDaoImpl implements HeadlineDao {
|
|||||||
|
|
||||||
private static final String ASSIGN_RANDOM_PCT_QRY = "update headlines set dataset = (select FLOOR(RAND() * (200 - 101) + 101))";
|
private static final String ASSIGN_RANDOM_PCT_QRY = "update headlines set dataset = (select FLOOR(RAND() * (200 - 101) + 101))";
|
||||||
private static final String REMAP_TRAINING_QRY = "update headlines set dataset = 1 where dataset >= 101 and dataset <= (100 + ?)";
|
private static final String REMAP_TRAINING_QRY = "update headlines set dataset = 1 where dataset >= 101 and dataset <= (100 + ?)";
|
||||||
private static final String REMAP_TEST_QRY = "update headlines set dataset = 2 where dataset >= (100 + ?) and dataset <= (100 + ?)";
|
//private static final String REMAP_TEST_QRY = "update headlines set dataset = 2 where dataset >= (100 + ?) and dataset <= (100 + ?)";
|
||||||
private static final String REMAP_VAL_QRY = "update headlines set dataset = 3 where dataset >= (100 + ?) and dataset <= 200";
|
private static final String REMAP_VAL_QRY = "update headlines set dataset = 3 where dataset > (100 + ?) and dataset <= 200";
|
||||||
|
|
||||||
private JdbcTemplate jdbcTemplate;
|
private JdbcTemplate jdbcTemplate;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean assignRandomDatasets(int training, int test, int validation) {
|
public boolean assignRandomDatasets(int training/*, int test*/, int validation) {
|
||||||
if (training + test + validation != 100) {
|
if (training /*+ test*/ + validation != 100) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
jdbcTemplate.update(ASSIGN_RANDOM_PCT_QRY);
|
jdbcTemplate.update(ASSIGN_RANDOM_PCT_QRY);
|
||||||
jdbcTemplate.update(REMAP_TRAINING_QRY,training);
|
jdbcTemplate.update(REMAP_TRAINING_QRY,training);
|
||||||
jdbcTemplate.update(REMAP_TEST_QRY,training,training+test);
|
//jdbcTemplate.update(REMAP_TEST_QRY,training,training+test);
|
||||||
jdbcTemplate.update(REMAP_VAL_QRY,training+test);
|
jdbcTemplate.update(REMAP_VAL_QRY,training/*+test*/);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import java.util.List;
|
|||||||
import net.woodyfolsom.cs6601.p3.domain.Headline;
|
import net.woodyfolsom.cs6601.p3.domain.Headline;
|
||||||
|
|
||||||
public interface HeadlineService {
|
public interface HeadlineService {
|
||||||
boolean assignRandomDatasets(int training, int test, int validation);
|
boolean assignRandomDatasets(int training/*, int test*/, int validation);
|
||||||
int getCount();
|
int getCount();
|
||||||
int getCount(int dataset);
|
int getCount(int dataset);
|
||||||
int insertHeadline(Headline headline);
|
int insertHeadline(Headline headline);
|
||||||
|
|||||||
@@ -39,8 +39,8 @@ public class MySQLHeadlineServiceImpl implements HeadlineService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean assignRandomDatasets(int training, int test, int validation) {
|
public boolean assignRandomDatasets(int training/*, int test*/, int validation) {
|
||||||
return headlineDao.assignRandomDatasets(training, test, validation);
|
return headlineDao.assignRandomDatasets(training/*, test*/, validation);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import java.io.InputStreamReader;
|
|||||||
import java.net.HttpURLConnection;
|
import java.net.HttpURLConnection;
|
||||||
import java.net.MalformedURLException;
|
import java.net.MalformedURLException;
|
||||||
import java.net.URL;
|
import java.net.URL;
|
||||||
import java.net.URLConnection;
|
|
||||||
import java.text.DateFormat;
|
import java.text.DateFormat;
|
||||||
import java.text.SimpleDateFormat;
|
import java.text.SimpleDateFormat;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
@@ -15,12 +14,10 @@ import java.util.List;
|
|||||||
import java.util.regex.Matcher;
|
import java.util.regex.Matcher;
|
||||||
import java.util.regex.Pattern;
|
import java.util.regex.Pattern;
|
||||||
|
|
||||||
import net.woodyfolsom.cs6601.p3.dao.HeadlineDao;
|
|
||||||
import net.woodyfolsom.cs6601.p3.domain.Headline;
|
import net.woodyfolsom.cs6601.p3.domain.Headline;
|
||||||
|
|
||||||
import org.apache.commons.logging.Log;
|
import org.apache.commons.logging.Log;
|
||||||
import org.apache.commons.logging.LogFactory;
|
import org.apache.commons.logging.LogFactory;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
@Service
|
@Service
|
||||||
@@ -94,7 +91,7 @@ public class YahooHeadlineServiceImpl implements HeadlineService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean assignRandomDatasets(int training, int test, int validation) {
|
public boolean assignRandomDatasets(int training/*, int test*/, int validation) {
|
||||||
throw new UnsupportedOperationException("This implementation does not support this method.");
|
throw new UnsupportedOperationException("This implementation does not support this method.");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +0,0 @@
|
|||||||
package net.woodyfolsom.cs6601.p3;
|
|
||||||
|
|
||||||
import org.junit.Test;
|
|
||||||
|
|
||||||
public class HeadlinePullerTest {
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testGetStartDate() {
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -28,17 +28,17 @@ public class MySQLHeadlineDaoImplTest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
//Change this back to @Test to run it... but beware, it shuffles the datasets. Best done n times for n-fold cross validation.
|
//Change this back to @Test to run it... but beware, it shuffles the datasets. Best done n times for n-fold cross validation.
|
||||||
@Ignore
|
@Test
|
||||||
public void testAssignRandomDatasets() {
|
public void testAssignRandomDatasets() {
|
||||||
|
|
||||||
int trainingPct = 80;
|
int trainingPct = 60;
|
||||||
int testPct = 10;
|
//int testPct = 10;
|
||||||
int valPct = 10;
|
int valPct = 40;
|
||||||
|
|
||||||
//assignment fails if character is ommitted from valPct (80% 10% 1% by accident)
|
//assignment fails if character is ommitted from valPct (80% 10% 1% by accident)
|
||||||
assertFalse(headlineSvc.assignRandomDatasets(trainingPct,testPct,valPct/10));
|
assertFalse(headlineSvc.assignRandomDatasets(trainingPct/*,testPct*/,valPct/10));
|
||||||
//assignment succeeds if requested ratio is 8-1-1
|
//assignment succeeds if requested ratio is 8-1-1
|
||||||
assertTrue(headlineSvc.assignRandomDatasets(trainingPct,testPct,valPct));
|
assertTrue(headlineSvc.assignRandomDatasets(trainingPct/*,testPct*/,valPct));
|
||||||
|
|
||||||
int allCount = headlineSvc.getCount();
|
int allCount = headlineSvc.getCount();
|
||||||
int trainingCount = headlineSvc.getCount(1);
|
int trainingCount = headlineSvc.getCount(1);
|
||||||
@@ -48,7 +48,7 @@ public class MySQLHeadlineDaoImplTest {
|
|||||||
assertEquals(trainingCount + testCount + valCount, allCount);
|
assertEquals(trainingCount + testCount + valCount, allCount);
|
||||||
|
|
||||||
assertEquals((double)trainingCount/allCount,(double)trainingPct / 100.0,0.01);
|
assertEquals((double)trainingCount/allCount,(double)trainingPct / 100.0,0.01);
|
||||||
assertEquals((double)testCount/allCount,(double)testPct / 100.0,0.01);
|
//assertEquals((double)testCount/allCount,(double)testPct / 100.0,0.01);
|
||||||
assertEquals((double)valCount/allCount,(double)valPct / 100.0,0.01);
|
assertEquals((double)valCount/allCount,(double)valPct / 100.0,0.01);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user