diff --git a/src/net/woodyfolsom/msproj/ann/AbstractNeuralNetFilter.java b/src/net/woodyfolsom/msproj/ann/AbstractNeuralNetFilter.java index 698197f..865d6e9 100644 --- a/src/net/woodyfolsom/msproj/ann/AbstractNeuralNetFilter.java +++ b/src/net/woodyfolsom/msproj/ann/AbstractNeuralNetFilter.java @@ -5,6 +5,7 @@ import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; +import org.encog.ml.data.MLData; import org.encog.neural.networks.BasicNetwork; import org.encog.neural.networks.PersistBasicNetwork; @@ -13,6 +14,11 @@ public abstract class AbstractNeuralNetFilter implements NeuralNetFilter { protected int actualTrainingEpochs = 0; protected int maxTrainingEpochs = 1000; + @Override + public MLData compute(MLData input) { + return this.neuralNetwork.compute(input); + } + public int getActualTrainingEpochs() { return actualTrainingEpochs; } diff --git a/src/net/woodyfolsom/msproj/ann/ErrorCalculation.java b/src/net/woodyfolsom/msproj/ann/ErrorCalculation.java deleted file mode 100644 index 7823a8f..0000000 --- a/src/net/woodyfolsom/msproj/ann/ErrorCalculation.java +++ /dev/null @@ -1,95 +0,0 @@ -package net.woodyfolsom.msproj.ann; - -import org.encog.mathutil.error.ErrorCalculationMode; - -/* - Initial version of this class was a verbatim copy from Encog framework. - */ - -public class ErrorCalculation { - - private static ErrorCalculationMode mode = ErrorCalculationMode.MSE; - - public static ErrorCalculationMode getMode() { - return ErrorCalculation.mode; - } - - public static void setMode(final ErrorCalculationMode theMode) { - ErrorCalculation.mode = theMode; - } - - private double globalError; - - private int setSize; - - public final double calculate() { - if (this.setSize == 0) { - return 0; - } - - switch (ErrorCalculation.getMode()) { - case RMS: - return calculateRMS(); - case MSE: - return calculateMSE(); - case ESS: - return calculateESS(); - default: - return calculateMSE(); - } - - } - - public final double calculateMSE() { - if (this.setSize == 0) { - return 0; - } - final double err = this.globalError / this.setSize; - return err; - - } - - public final double calculateESS() { - if (this.setSize == 0) { - return 0; - } - final double err = this.globalError / 2; - return err; - - } - - public final double calculateRMS() { - if (this.setSize == 0) { - return 0; - } - final double err = Math.sqrt(this.globalError / this.setSize); - return err; - } - - public final void reset() { - this.globalError = 0; - this.setSize = 0; - } - - public final void updateError(final double actual, final double ideal) { - - double delta = ideal - actual; - - this.globalError += delta * delta; - - this.setSize++; - - } - - public final void updateError(final double[] actual, final double[] ideal, - final double significance) { - for (int i = 0; i < actual.length; i++) { - double delta = (ideal[i] - actual[i]) * significance; - - this.globalError += delta * delta; - } - - this.setSize += ideal.length; - } - -} \ No newline at end of file diff --git a/src/net/woodyfolsom/msproj/ann/GameStateMLData.java b/src/net/woodyfolsom/msproj/ann/GameStateMLData.java deleted file mode 100644 index 47d588c..0000000 --- a/src/net/woodyfolsom/msproj/ann/GameStateMLData.java +++ /dev/null @@ -1,25 +0,0 @@ -package net.woodyfolsom.msproj.ann; - -import net.woodyfolsom.msproj.GameState; - -import org.encog.ml.data.basic.BasicMLData; - -public class GameStateMLData extends BasicMLData { - - /** - * - */ - private static final long serialVersionUID = 1L; - - private GameState
gameState; - - public GameStateMLData(double[] d, GameState gameState) { - super(d); - // TODO Auto-generated constructor stub - this.gameState = gameState; - } - - public GameState getGameState() { - return gameState; - } -} diff --git a/src/net/woodyfolsom/msproj/ann/GameStateMLDataPair.java b/src/net/woodyfolsom/msproj/ann/GameStateMLDataPair.java index 8f41126..d7c1316 100644 --- a/src/net/woodyfolsom/msproj/ann/GameStateMLDataPair.java +++ b/src/net/woodyfolsom/msproj/ann/GameStateMLDataPair.java @@ -11,16 +11,12 @@ import org.encog.ml.data.basic.BasicMLDataPair; import org.encog.util.kmeans.Centroid; public class GameStateMLDataPair implements MLDataPair { - //private final String[] inputs = { "BlackScore", "WhiteScore" }; - //private final String[] outputs = { "BlackWins", "WhiteWins" }; - private BasicMLDataPair mlDataPairDelegate; private GameState gameState; public GameStateMLDataPair(GameState gameState) { this.gameState = gameState; - mlDataPairDelegate = new BasicMLDataPair( - new GameStateMLData(createInput(), gameState), new BasicMLData(createIdeal())); + mlDataPairDelegate = new BasicMLDataPair(new BasicMLData(createInput()), new BasicMLData(createIdeal())); } public GameStateMLDataPair(GameStateMLDataPair that) { @@ -118,4 +114,4 @@ public class GameStateMLDataPair implements MLDataPair { mlDataPairDelegate.setSignificance(arg0); } -} +} \ No newline at end of file diff --git a/src/net/woodyfolsom/msproj/ann/GradientWorker.java b/src/net/woodyfolsom/msproj/ann/GradientWorker.java deleted file mode 100644 index fbb6ec6..0000000 --- a/src/net/woodyfolsom/msproj/ann/GradientWorker.java +++ /dev/null @@ -1,193 +0,0 @@ -package net.woodyfolsom.msproj.ann; - -/* - * Class copied verbatim from Encog framework due to dependency on Propagation - * implementation. - * - * Encog(tm) Core v3.2 - Java Version - * http://www.heatonresearch.com/encog/ - * http://code.google.com/p/encog-java/ - - * Copyright 2008-2012 Heaton Research, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- * - * For more information on Heaton Research copyrights, licenses - * and trademarks visit: - * http://www.heatonresearch.com/copyright - */ - -import java.util.ArrayList; -import java.util.List; -import java.util.Set; - -import org.encog.engine.network.activation.ActivationFunction; -import org.encog.ml.data.MLDataPair; -import org.encog.ml.data.basic.BasicMLDataPair; -import org.encog.neural.error.ErrorFunction; -import org.encog.neural.flat.FlatNetwork; -import org.encog.util.EngineArray; -import org.encog.util.concurrency.EngineTask; - -public class GradientWorker implements EngineTask { - - private final FlatNetwork network; - private final ErrorCalculation errorCalculation = new ErrorCalculation(); - private final List actuals; - private final double[] layerDelta; - private final int[] layerCounts; - private final int[] layerFeedCounts; - private final int[] layerIndex; - private final int[] weightIndex; - private final double[] layerOutput; - private final double[] layerSums; - private final double[] gradients; - private final double[] weights; - private final MLDataPair pairPrototype; - private final Set> training; - //private final int low; - //private final int high; - private final TemporalDifferenceLearning owner; - private double[] flatSpot; - private final ErrorFunction errorFunction; - - public GradientWorker(final FlatNetwork theNetwork, - final TemporalDifferenceLearning theOwner, - final Set> theTraining, final int theLow, - final int theHigh, final double[] flatSpot, ErrorFunction ef) { - this.network = theNetwork; - this.training = theTraining; - //this.low = theLow; - //this.high = theHigh; - this.owner = theOwner; - this.flatSpot = flatSpot; - this.errorFunction = ef; - - this.layerDelta = new double[network.getLayerOutput().length]; - this.gradients = new double[network.getWeights().length]; - this.actuals = new ArrayList(); - - this.weights = network.getWeights(); - this.layerIndex = network.getLayerIndex(); - this.layerCounts = network.getLayerCounts(); - this.weightIndex = network.getWeightIndex(); - this.layerOutput = network.getLayerOutput(); - this.layerSums = network.getLayerSums(); - this.layerFeedCounts = network.getLayerFeedCounts(); - - this.pairPrototype = BasicMLDataPair.createPair( - network.getInputCount(), network.getOutputCount()); - } - - public FlatNetwork getNetwork() { - return this.network; - } - - public double[] getWeights() { - return this.weights; - } - - private void process(List trainingSequence) { - actuals.clear(); - - for (int trainingIdx = 0; trainingIdx < trainingSequence.size(); trainingIdx++) { - MLDataPair mlDataPair = trainingSequence.get(trainingIdx); - MLDataPair dataPairCopy = this.pairPrototype; - dataPairCopy.setInputArray(mlDataPair.getInputArray()); - if (dataPairCopy.getIdealArray() != null) { - dataPairCopy.setIdealArray(mlDataPair.getIdealArray()); - } - - double[] input = dataPairCopy.getInputArray(); - double[] ideal = dataPairCopy.getIdealArray(); - double significance = dataPairCopy.getSignificance(); - - actuals.add(trainingIdx, new double[ideal.length]); - this.network.compute(input, actuals.get(trainingIdx)); - - // For now, only calculate deltas for the final data pair - // For final TDL algorithm, deltas won't be used at all, instead the - // List of Actual vectors will. 
- if (trainingIdx < trainingSequence.size() - 1) { - continue; - } - - this.errorCalculation.updateError(actuals.get(trainingIdx), ideal, - significance); - this.errorFunction.calculateError(ideal, actuals.get(trainingIdx), - this.layerDelta); - - for (int i = 0; i < actuals.get(trainingIdx).length; i++) { - this.layerDelta[i] = ((this.network.getActivationFunctions()[0] - .derivativeFunction(this.layerSums[i], - this.layerOutput[i]) + this.flatSpot[0])) - * (this.layerDelta[i] * significance); - } - - for (int i = this.network.getBeginTraining(); i < this.network - .getEndTraining(); i++) { - processLevel(i); - } - - } - } - - private void processLevel(final int currentLevel) { - final int fromLayerIndex = this.layerIndex[currentLevel + 1]; - final int toLayerIndex = this.layerIndex[currentLevel]; - final int fromLayerSize = this.layerCounts[currentLevel + 1]; - final int toLayerSize = this.layerFeedCounts[currentLevel]; - - final int index = this.weightIndex[currentLevel]; - final ActivationFunction activation = this.network - .getActivationFunctions()[currentLevel]; - final double currentFlatSpot = this.flatSpot[currentLevel + 1]; - - // handle weights - int yi = fromLayerIndex; - for (int y = 0; y < fromLayerSize; y++) { - final double output = this.layerOutput[yi]; - double sum = 0; - int xi = toLayerIndex; - int wi = index + y; - for (int x = 0; x < toLayerSize; x++) { - this.gradients[wi] += output * this.layerDelta[xi]; - sum += this.weights[wi] * this.layerDelta[xi]; - wi += fromLayerSize; - xi++; - } - - this.layerDelta[yi] = sum - * (activation.derivativeFunction(this.layerSums[yi], - this.layerOutput[yi]) + currentFlatSpot); - yi++; - } - } - - public final void run() { - try { - this.errorCalculation.reset(); - - for (List<MLDataPair> trainingSequence : training) { - process(trainingSequence); - } - - final double error = this.errorCalculation.calculate(); - this.owner.report(this.gradients, error, null); - EngineArray.fill(this.gradients, 0); - } catch (final Throwable ex) { - this.owner.report(null, 0, ex); - } - } - -} \ No newline at end of file diff --git a/src/net/woodyfolsom/msproj/ann/NeuralNetFilter.java b/src/net/woodyfolsom/msproj/ann/NeuralNetFilter.java index 64d40c7..73ce084 100644 --- a/src/net/woodyfolsom/msproj/ann/NeuralNetFilter.java +++ b/src/net/woodyfolsom/msproj/ann/NeuralNetFilter.java @@ -11,21 +11,20 @@ import org.encog.neural.networks.BasicNetwork; public interface NeuralNetFilter { BasicNetwork getNeuralNetwork(); + + int getActualTrainingEpochs(); + int getInputSize(); + int getMaxTrainingEpochs(); + int getOutputSize(); + + void learn(MLDataSet trainingSet); + void learn(Set<List<MLDataPair>> trainingSet); + + void load(String fileName) throws IOException; + void reset(); + void reset(int seed); + void save(String fileName) throws IOException; + void setMaxTrainingEpochs(int max); - public int getActualTrainingEpochs(); - public int getInputSize(); - public int getMaxTrainingEpochs(); - public int getOutputSize(); - - public double computeValue(MLData input); - public double[] computeVector(MLData input); - - public void learn(MLDataSet trainingSet); - public void learn(Set<List<MLDataPair>> trainingSet); - - public void load(String fileName) throws IOException; - public void reset(); - public void reset(int seed); - public void save(String fileName) throws IOException; - public void setMaxTrainingEpochs(int max); + MLData compute(MLData input); } \ No newline at end of file diff --git a/src/net/woodyfolsom/msproj/ann/TemporalDifference.java
b/src/net/woodyfolsom/msproj/ann/TemporalDifference.java new file mode 100644 index 0000000..991608b --- /dev/null +++ b/src/net/woodyfolsom/msproj/ann/TemporalDifference.java @@ -0,0 +1,30 @@ +package net.woodyfolsom.msproj.ann; + +import org.encog.ml.data.MLDataSet; +import org.encog.neural.networks.ContainsFlat; +import org.encog.neural.networks.training.propagation.back.Backpropagation; + +public class TemporalDifference extends Backpropagation { + private final double lambda; + + public TemporalDifference(ContainsFlat network, MLDataSet training, + double theLearnRate, double theMomentum, double lambda) { + super(network, training, theLearnRate, theMomentum); + this.lambda = lambda; + } + + public double getLambda() { + return lambda; + } + + @Override + public double updateWeight(final double[] gradients, + final double[] lastGradient, final int index) { + double alpha = this.getLearningRate(); + + //TODO fill in weight update for TD(lambda) + + return 0.0; + } + +} \ No newline at end of file diff --git a/src/net/woodyfolsom/msproj/ann/TemporalDifferenceLearning.java b/src/net/woodyfolsom/msproj/ann/TemporalDifferenceLearning.java deleted file mode 100644 index ea03bbe..0000000 --- a/src/net/woodyfolsom/msproj/ann/TemporalDifferenceLearning.java +++ /dev/null @@ -1,487 +0,0 @@ -package net.woodyfolsom.msproj.ann; - -import java.util.ArrayList; -import java.util.HashSet; -import java.util.List; -import java.util.Set; - -import org.encog.EncogError; -import org.encog.engine.network.activation.ActivationFunction; -import org.encog.engine.network.activation.ActivationSigmoid; -import org.encog.mathutil.IntRange; -import org.encog.ml.MLMethod; -import org.encog.ml.TrainingImplementationType; -import org.encog.ml.data.MLDataPair; -import org.encog.ml.data.MLDataSet; -import org.encog.ml.train.MLTrain; -import org.encog.ml.train.strategy.Strategy; -import org.encog.ml.train.strategy.end.EndTrainingStrategy; -import org.encog.neural.error.ErrorFunction; -import org.encog.neural.error.LinearErrorFunction; -import org.encog.neural.flat.FlatNetwork; -import org.encog.neural.networks.ContainsFlat; -import org.encog.neural.networks.training.LearningRate; -import org.encog.neural.networks.training.Momentum; -import org.encog.neural.networks.training.Train; -import org.encog.neural.networks.training.TrainingError; -import org.encog.neural.networks.training.propagation.TrainingContinuation; -import org.encog.neural.networks.training.propagation.back.Backpropagation; -import org.encog.neural.networks.training.strategy.SmartLearningRate; -import org.encog.neural.networks.training.strategy.SmartMomentum; -import org.encog.util.EncogValidate; -import org.encog.util.EngineArray; -import org.encog.util.concurrency.DetermineWorkload; -import org.encog.util.concurrency.EngineConcurrency; -import org.encog.util.concurrency.MultiThreadable; -import org.encog.util.concurrency.TaskGroup; -import org.encog.util.logging.EncogLogging; - -/** - * This class started as a verbatim copy of BackPropagation from the open-source - * Encog framework. It was merged with its super-classes to access protected - * fields without resorting to reflection.
- */ -public class TemporalDifferenceLearning implements MLTrain, Momentum, - LearningRate, Train, MultiThreadable { - // New fields for TD(lambda) - private final double lambda; - // end new fields - - // BackProp - public static final String LAST_DELTA = "LAST_DELTA"; - private double learningRate; - private double momentum; - private double[] lastDelta; - // End BackProp - - // Propagation - private FlatNetwork currentFlatNetwork; - private int numThreads; - protected double[] gradients; - private double[] lastGradient; - protected ContainsFlat network; - // private MLDataSet indexable; - private Set> indexable; - private GradientWorker[] workers; - private double totalError; - protected double lastError; - private Throwable reportedException; - private double[] flatSpot; - private boolean shouldFixFlatSpot; - private ErrorFunction ef = new LinearErrorFunction(); - // End Propagation - - // BasicTraining - private final List strategies = new ArrayList(); - //private Set> training; - private double error; - private int iteration; - private TrainingImplementationType implementationType; - - // End BasicTraining - - public TemporalDifferenceLearning(final ContainsFlat network, - final Set> training, double lambda) { - this(network, training, 0, 0, lambda); - addStrategy(new SmartLearningRate()); - addStrategy(new SmartMomentum()); - } - - public TemporalDifferenceLearning(final ContainsFlat network, - Set> training, final double theLearnRate, - final double theMomentum, double lambda) { - initPropagation(network, training); - // TODO consider how to re-implement validation - // ValidateNetwork.validateMethodToData(network, training); - this.momentum = theMomentum; - this.learningRate = theLearnRate; - this.lastDelta = new double[network.getFlat().getWeights().length]; - this.lambda = lambda; - } - - private void initPropagation(final ContainsFlat network, - final Set> training) { - initBasicTraining(TrainingImplementationType.Iterative); - this.network = network; - this.currentFlatNetwork = network.getFlat(); - //setTraining(training); - - this.gradients = new double[this.currentFlatNetwork.getWeights().length]; - this.lastGradient = new double[this.currentFlatNetwork.getWeights().length]; - - this.indexable = training; - this.numThreads = 0; - this.reportedException = null; - this.shouldFixFlatSpot = true; - } - - private void initBasicTraining(TrainingImplementationType implementationType) { - this.implementationType = implementationType; - } - - // Methods from BackPropagation - @Override - public boolean canContinue() { - return false; - } - - public double[] getLastDelta() { - return this.lastDelta; - } - - @Override - public double getLearningRate() { - return this.learningRate; - } - - @Override - public double getMomentum() { - return this.momentum; - } - - public boolean isValidResume(final TrainingContinuation state) { - if (!state.getContents().containsKey(Backpropagation.LAST_DELTA)) { - return false; - } - - if (!state.getTrainingType().equals(getClass().getSimpleName())) { - return false; - } - - final double[] d = (double[]) state.get(Backpropagation.LAST_DELTA); - return d.length == ((ContainsFlat) getMethod()).getFlat().getWeights().length; - } - - @Override - public TrainingContinuation pause() { - final TrainingContinuation result = new TrainingContinuation(); - result.setTrainingType(this.getClass().getSimpleName()); - result.set(Backpropagation.LAST_DELTA, this.lastDelta); - return result; - } - - @Override - public void resume(final TrainingContinuation state) { - if 
(!isValidResume(state)) { - throw new TrainingError("Invalid training resume data length"); - } - - this.lastDelta = ((double[]) state.get(Backpropagation.LAST_DELTA)); - } - - @Override - public void setLearningRate(final double rate) { - this.learningRate = rate; - } - - @Override - public void setMomentum(final double m) { - this.momentum = m; - } - - public double updateWeight(final double[] gradients, - final double[] lastGradient, final int index) { - - final double delta = (gradients[index] * this.learningRate) - + (this.lastDelta[index] * this.momentum); - - this.lastDelta[index] = delta; - - System.out.println("Updating weights for connection: " + index - + " with lambda: " + lambda); - - return delta; - } - - public void initOthers() { - } - - // End methods from BackPropagation - - // Methods from Propagation - public void finishTraining() { - basicFinishTraining(); - } - - public FlatNetwork getCurrentFlatNetwork() { - return this.currentFlatNetwork; - } - - public MLMethod getMethod() { - return this.network; - } - - public void iteration() { - iteration(1); - } - - public void rollIteration() { - this.iteration++; - } - - public void iteration(final int count) { - - try { - for (int i = 0; i < count; i++) { - - preIteration(); - - rollIteration(); - - calculateGradients(); - - if (this.currentFlatNetwork.isLimited()) { - learnLimited(); - } else { - learn(); - } - - this.lastError = this.getError(); - - for (final GradientWorker worker : this.workers) { - EngineArray.arrayCopy(this.currentFlatNetwork.getWeights(), - 0, worker.getWeights(), 0, - this.currentFlatNetwork.getWeights().length); - } - - if (this.currentFlatNetwork.getHasContext()) { - copyContexts(); - } - - if (this.reportedException != null) { - throw (new EncogError(this.reportedException)); - } - - postIteration(); - - EncogLogging.log(EncogLogging.LEVEL_INFO, - "Training iteration done, error: " + getError()); - - } - } catch (final ArrayIndexOutOfBoundsException ex) { - EncogValidate.validateNetworkForTraining(this.network, - getTraining()); - throw new EncogError(ex); - } - } - - public void setThreadCount(final int numThreads) { - this.numThreads = numThreads; - } - - @Override - public int getThreadCount() { - return this.numThreads; - } - - public void fixFlatSpot(boolean b) { - this.shouldFixFlatSpot = b; - } - - public void setErrorFunction(ErrorFunction ef) { - this.ef = ef; - } - - public void calculateGradients() { - if (this.workers == null) { - init(); - } - - if (this.currentFlatNetwork.getHasContext()) { - this.workers[0].getNetwork().clearContext(); - } - - this.totalError = 0; - - if (this.workers.length > 1) { - - final TaskGroup group = EngineConcurrency.getInstance() - .createTaskGroup(); - - for (final GradientWorker worker : this.workers) { - EngineConcurrency.getInstance().processTask(worker, group); - } - - group.waitForComplete(); - } else { - this.workers[0].run(); - } - - this.setError(this.totalError / this.workers.length); - - } - - /** - * Copy the contexts to keep them consistent with multithreaded training. 
- */ - private void copyContexts() { - - // copy the contexts(layer outputO from each group to the next group - for (int i = 0; i < (this.workers.length - 1); i++) { - final double[] src = this.workers[i].getNetwork().getLayerOutput(); - final double[] dst = this.workers[i + 1].getNetwork() - .getLayerOutput(); - EngineArray.arrayCopy(src, dst); - } - - // copy the contexts from the final group to the real network - EngineArray.arrayCopy(this.workers[this.workers.length - 1] - .getNetwork().getLayerOutput(), this.currentFlatNetwork - .getLayerOutput()); - } - - private void init() { - // fix flat spot, if needed - this.flatSpot = new double[this.currentFlatNetwork - .getActivationFunctions().length]; - - if (this.shouldFixFlatSpot) { - for (int i = 0; i < this.currentFlatNetwork - .getActivationFunctions().length; i++) { - final ActivationFunction af = this.currentFlatNetwork - .getActivationFunctions()[i]; - - if (af instanceof ActivationSigmoid) { - this.flatSpot[i] = 0.1; - } else { - this.flatSpot[i] = 0.0; - } - } - } else { - EngineArray.fill(this.flatSpot, 0.0); - } - - // setup workers - final DetermineWorkload determine = new DetermineWorkload( - this.numThreads, (int) this.indexable.size()); - // this.numThreads, (int) this.indexable.getRecordCount()); - - this.workers = new GradientWorker[determine.getThreadCount()]; - - int index = 0; - - // handle CPU - for (final IntRange r : determine.calculateWorkers()) { - this.workers[index++] = new GradientWorker( - this.currentFlatNetwork.clone(), this, new HashSet( - this.indexable), r.getLow(), r.getHigh(), - this.flatSpot, this.ef); - } - - initOthers(); - } - - public void report(final double[] gradients, final double error, - final Throwable ex) { - synchronized (this) { - if (ex == null) { - - for (int i = 0; i < gradients.length; i++) { - this.gradients[i] += gradients[i]; - } - this.totalError += error; - } else { - this.reportedException = ex; - } - } - } - - protected void learn() { - final double[] weights = this.currentFlatNetwork.getWeights(); - for (int i = 0; i < this.gradients.length; i++) { - weights[i] += updateWeight(this.gradients, this.lastGradient, i); - this.gradients[i] = 0; - } - } - - protected void learnLimited() { - final double limit = this.currentFlatNetwork.getConnectionLimit(); - final double[] weights = this.currentFlatNetwork.getWeights(); - for (int i = 0; i < this.gradients.length; i++) { - if (Math.abs(weights[i]) < limit) { - weights[i] = 0; - } else { - weights[i] += updateWeight(this.gradients, this.lastGradient, i); - } - this.gradients[i] = 0; - } - } - - public double[] getLastGradient() { - return lastGradient; - } - - // End methods from Propagation - - // Methods from BasicTraining/ - public void addStrategy(final Strategy strategy) { - strategy.init(this); - this.strategies.add(strategy); - } - - public void basicFinishTraining() { - } - - public double getError() { - return this.error; - } - - public int getIteration() { - return this.iteration; - } - - public List getStrategies() { - return this.strategies; - } - - public MLDataSet getTraining() { - throw new UnsupportedOperationException( - "This learning method operates on Set>, not MLDataSet"); - } - - public boolean isTrainingDone() { - for (Strategy strategy : this.strategies) { - if (strategy instanceof EndTrainingStrategy) { - EndTrainingStrategy end = (EndTrainingStrategy) strategy; - if (end.shouldStop()) { - return true; - } - } - } - - return false; - } - - public void postIteration() { - for (final Strategy strategy : 
this.strategies) { - strategy.postIteration(); - } - } - - public void preIteration() { - - this.iteration++; - - for (final Strategy strategy : this.strategies) { - strategy.preIteration(); - } - } - - public void setError(final double error) { - this.error = error; - } - - public void setIteration(final int iteration) { - this.iteration = iteration; - } - - public void setTraining(final Set<List<MLDataPair>> training) { - //this.training = training; - throw new UnsupportedOperationException(); - } - - public TrainingImplementationType getImplementationType() { - return this.implementationType; - } - // End Methods from BasicTraining -} \ No newline at end of file diff --git a/src/net/woodyfolsom/msproj/ann/WinFilter.java b/src/net/woodyfolsom/msproj/ann/WinFilter.java index d52083e..c042691 100644 --- a/src/net/woodyfolsom/msproj/ann/WinFilter.java +++ b/src/net/woodyfolsom/msproj/ann/WinFilter.java @@ -3,16 +3,14 @@ package net.woodyfolsom.msproj.ann; import java.util.List; import java.util.Set; -import net.woodyfolsom.msproj.GameState; -import net.woodyfolsom.msproj.Player; - import org.encog.engine.network.activation.ActivationSigmoid; -import org.encog.ml.data.MLData; import org.encog.ml.data.MLDataPair; import org.encog.ml.data.MLDataSet; +import org.encog.ml.data.basic.BasicMLDataSet; import org.encog.ml.train.MLTrain; import org.encog.neural.networks.BasicNetwork; import org.encog.neural.networks.layers.BasicLayer; +import org.encog.neural.networks.training.propagation.back.Backpropagation; public class WinFilter extends AbstractNeuralNetFilter implements NeuralNetFilter { @@ -29,55 +27,46 @@ public class WinFilter extends AbstractNeuralNetFilter implements this.neuralNetwork = network; } - @Override - public double computeValue(MLData input) { - if (input instanceof GameStateMLData) { - double[] idealVector = computeVector(input); - GameState gameState = ((GameStateMLData) input).getGameState(); - Player playerToMove = gameState.getPlayerToMove(); - if (playerToMove == Player.BLACK) { - return idealVector[0]; - } else if (playerToMove == Player.WHITE) { - return idealVector[1]; - } else { - throw new RuntimeException("Invalid GameState.playerToMove: " - + playerToMove); - } - } else { - throw new UnsupportedOperationException( - "This NeuralNetFilter only accepts GameStates as input."); - } - } - - @Override - public double[] computeVector(MLData input) { - if (input instanceof GameStateMLData) { - return neuralNetwork.compute(input).getData(); - } else { - throw new UnsupportedOperationException( - "This NeuralNetFilter only accepts GameStates as input."); - } - } - @Override public void learn(MLDataSet trainingData) { - throw new UnsupportedOperationException("This filter learns a Set<List<MLDataPair>>, not an MLDataSet"); + throw new UnsupportedOperationException( + "This filter learns a Set<List<MLDataPair>>, not an MLDataSet"); } - + /** - * Method is necessary because with temporal difference learning, some of the MLDataPairs are related by being a sequence - * of moves within a particular game. + * Method is necessary because with temporal difference learning, some of + * the MLDataPairs are related by being a sequence of moves within a + * particular game.
 */ @Override public void learn(Set<List<MLDataPair>> trainingSet) { + MLDataSet mlDataset = new BasicMLDataSet(); + + for (List<MLDataPair> gameRecord : trainingSet) { + for (int t = 0; t < gameRecord.size() - 1; t++) { + mlDataset.add(gameRecord.get(t).getInput(), this.neuralNetwork.compute(gameRecord.get(t) + .getInput())); + } + mlDataset.add(gameRecord.get(gameRecord.size() - 1)); + } // train the neural network - final MLTrain train = new TemporalDifferenceLearning(neuralNetwork, - trainingSet, 0.7, 0.8, 0.25); - + final MLTrain train = new TemporalDifference(neuralNetwork, mlDataset, 0.7, 0.8, 0.25); + //final MLTrain train = new Backpropagation(neuralNetwork, mlDataset, 0.7, 0.8); actualTrainingEpochs = 0; do { + if (actualTrainingEpochs > 0) { + int gameStateIndex = 0; + for (List<MLDataPair> gameRecord : trainingSet) { + for (int t = 0; t < gameRecord.size() - 1; t++) { + MLDataPair oldDataPair = mlDataset.get(gameStateIndex); + // refresh the stored ideal with the network's current + // prediction so the next epoch bootstraps from it + oldDataPair.getIdeal().setData( + this.neuralNetwork.compute(oldDataPair.getInput()).getData()); + gameStateIndex++; + } + gameStateIndex++; + } + } train.iteration(); System.out.println("Epoch #" + actualTrainingEpochs + " Error:" + train.getError()); diff --git a/src/net/woodyfolsom/msproj/ann/XORFilter.java b/src/net/woodyfolsom/msproj/ann/XORFilter.java index e6be979..1021dc5 100644 --- a/src/net/woodyfolsom/msproj/ann/XORFilter.java +++ b/src/net/woodyfolsom/msproj/ann/XORFilter.java @@ -7,7 +7,7 @@ import org.encog.engine.network.activation.ActivationSigmoid; import org.encog.ml.data.MLData; import org.encog.ml.data.MLDataPair; import org.encog.ml.data.MLDataSet; -import org.encog.ml.data.basic.BasicMLDataSet; +import org.encog.ml.data.basic.BasicMLData; import org.encog.ml.train.MLTrain; import org.encog.neural.networks.BasicNetwork; import org.encog.neural.networks.layers.BasicLayer; @@ -21,7 +21,7 @@ import org.encog.neural.networks.training.propagation.back.Backpropagation; */ public class XORFilter extends AbstractNeuralNetFilter implements NeuralNetFilter { - + public XORFilter() { // create a neural network, without using a factory BasicNetwork network = new BasicNetwork(); @@ -34,32 +34,10 @@ public class XORFilter extends AbstractNeuralNetFilter implements this.neuralNetwork = network; } - @Override - public void learn(MLDataSet trainingSet) { - - // train the neural network - final MLTrain train = new Backpropagation(neuralNetwork, - trainingSet, 0.7, 0.8); - - actualTrainingEpochs = 0; - - do { - train.iteration(); - System.out.println("Epoch #" + actualTrainingEpochs + " Error:" - + train.getError()); - actualTrainingEpochs++; - } while (train.getError() > 0.01 - && actualTrainingEpochs <= maxTrainingEpochs); + public double compute(double x, double y) { + return compute(new BasicMLData(new double[]{x,y})).getData(0); } - - @Override - public double[] computeVector(MLData mlData) { - MLDataSet dataset = new BasicMLDataSet(new double[][] { mlData.getData() }, - new double[][] { new double[getOutputSize()] }); - MLData output = neuralNetwork.compute(dataset.get(0).getInput()); - return output.getData(); - } - + @Override public int getInputSize() { return 2; @@ -72,12 +50,26 @@ public class XORFilter extends AbstractNeuralNetFilter implements } @Override - public double computeValue(MLData input) { - return computeVector(input)[0]; - } + public void learn(MLDataSet trainingSet) { + // train the neural network + final MLTrain train = new Backpropagation(neuralNetwork, trainingSet, + 0.7, 0.8); + + actualTrainingEpochs = 0; + + do { + train.iteration(); + System.out.println("Epoch #" + actualTrainingEpochs + " Error:" + +
train.getError()); + actualTrainingEpochs++; + } while (train.getError() > 0.01 + && actualTrainingEpochs <= maxTrainingEpochs); + } + @Override public void learn(Set<List<MLDataPair>> trainingSet) { - throw new UnsupportedOperationException("This Filter learns an MLDataSet, not a Set<List<MLDataPair>>."); + throw new UnsupportedOperationException( + "This Filter learns an MLDataSet, not a Set<List<MLDataPair>>."); } } \ No newline at end of file diff --git a/src/net/woodyfolsom/msproj/ann2/ActivationFunction.java b/src/net/woodyfolsom/msproj/ann2/ActivationFunction.java new file mode 100644 index 0000000..ffe4454 --- /dev/null +++ b/src/net/woodyfolsom/msproj/ann2/ActivationFunction.java @@ -0,0 +1,5 @@ +package net.woodyfolsom.msproj.ann2; + +public interface ActivationFunction { + double calculate(double arg); +} diff --git a/src/net/woodyfolsom/msproj/ann2/Layer.java b/src/net/woodyfolsom/msproj/ann2/Layer.java new file mode 100644 index 0000000..3bd0830 --- /dev/null +++ b/src/net/woodyfolsom/msproj/ann2/Layer.java @@ -0,0 +1,53 @@ +package net.woodyfolsom.msproj.ann2; + +import java.util.Arrays; + +public class Layer { + private Neuron[] neurons; + + public Layer() { + //default constructor for JAXB + } + + public Layer(int numNeurons, int numWeights, ActivationFunction activationFunction) { + neurons = new Neuron[numNeurons]; + for (int neuronIndex = 0; neuronIndex < numNeurons; neuronIndex++) { + neurons[neuronIndex] = new Neuron(activationFunction, numWeights); + } + } + + public int size() { + return neurons.length; + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + Arrays.hashCode(neurons); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + Layer other = (Layer) obj; + if (!Arrays.equals(neurons, other.neurons)) + return false; + return true; + } + + public Neuron[] getNeurons() { + return neurons; + } + + public void setNeurons(Neuron[] neurons) { + this.neurons = neurons; + } + +} \ No newline at end of file diff --git a/src/net/woodyfolsom/msproj/ann2/MultiLayerPerceptron.java b/src/net/woodyfolsom/msproj/ann2/MultiLayerPerceptron.java new file mode 100644 index 0000000..93ae15f --- /dev/null +++ b/src/net/woodyfolsom/msproj/ann2/MultiLayerPerceptron.java @@ -0,0 +1,175 @@ +package net.woodyfolsom.msproj.ann2; + +import java.io.InputStream; +import java.io.OutputStream; +import java.util.Arrays; + +import javax.xml.bind.JAXBContext; +import javax.xml.bind.JAXBException; +import javax.xml.bind.Marshaller; +import javax.xml.bind.Unmarshaller; +import javax.xml.bind.annotation.XmlAttribute; +import javax.xml.bind.annotation.XmlElement; +import javax.xml.bind.annotation.XmlRootElement; + +@XmlRootElement +public class MultiLayerPerceptron extends NeuralNetwork { + private ActivationFunction activationFunction; + private boolean biased; + private Layer[] layers; + + public MultiLayerPerceptron() { + this(false, 1, 1); + } + + public MultiLayerPerceptron(boolean biased, int...
layerSizes) { + int numLayers = layerSizes.length; + + if (numLayers < 2) { + throw new IllegalArgumentException("# of layers must be >= 2"); + } + + this.activationFunction = Sigmoid.function; + this.biased = biased; + this.layers = new Layer[numLayers]; + + int numWeights; + + for (int layerIndex = 0; layerIndex < numLayers; layerIndex++) { + int layerSize = layerSizes[layerIndex]; + + if (layerSize < 1) { + throw new IllegalArgumentException("Layer size must be >= 1"); + } + + if (layerIndex == 0) { + numWeights = 0; + if (biased) { + layerSize++; + } + } else { + numWeights = layers[layerIndex - 1].size(); + } + + layers[layerIndex] = new Layer(layerSize, numWeights, + activationFunction); + } + } + + @XmlElement(type=Sigmoid.class) + public ActivationFunction getActivationFunction() { + return activationFunction; + } + + @XmlElement + public Layer[] getLayers() { + return layers; + } + + @Override + protected double[] getOutput() { + // TODO Auto-generated method stub + return null; + } + + @Override + protected Neuron[] getNeurons() { + // TODO Auto-generated method stub + return null; + } + + @XmlAttribute + public boolean isBiased() { + return biased; + } + + public void setActivationFunction(ActivationFunction activationFunction) { + this.activationFunction = activationFunction; + } + + @Override + protected void setInput(double[] input) { + // TODO Auto-generated method stub + + } + + public void setBiased(boolean biased) { + this.biased = biased; + } + + public void setLayers(Layer[] layers) { + this.layers = layers; + } + + @Override + public boolean load(InputStream is) { + try { + JAXBContext jc = JAXBContext + .newInstance(MultiLayerPerceptron.class); + + // unmarshal from foo.xml + Unmarshaller u = jc.createUnmarshaller(); + MultiLayerPerceptron mlp = (MultiLayerPerceptron) u.unmarshal(is); + + this.activationFunction = mlp.activationFunction; + this.biased = mlp.biased; + this.layers = mlp.layers; + + return true; + } catch (JAXBException je) { + je.printStackTrace(); + return false; + } + } + + @Override + public boolean save(OutputStream os) { + try { + JAXBContext jc = JAXBContext + .newInstance(MultiLayerPerceptron.class); + + Marshaller m = jc.createMarshaller(); + m.setProperty(Marshaller.JAXB_FORMATTED_OUTPUT, true); + m.marshal(this, os); + m.marshal(this, System.out); + return true; + } catch (JAXBException je) { + je.printStackTrace(); + return false; + } + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime + * result + + ((activationFunction == null) ? 0 : activationFunction + .hashCode()); + result = prime * result + (biased ? 
1231 : 1237); + result = prime * result + Arrays.hashCode(layers); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + MultiLayerPerceptron other = (MultiLayerPerceptron) obj; + if (activationFunction == null) { + if (other.activationFunction != null) + return false; + } else if (!activationFunction.equals(other.activationFunction)) + return false; + if (biased != other.biased) + return false; + if (!Arrays.equals(layers, other.layers)) + return false; + return true; + } +} \ No newline at end of file diff --git a/src/net/woodyfolsom/msproj/ann2/NNData.java b/src/net/woodyfolsom/msproj/ann2/NNData.java new file mode 100644 index 0000000..b182765 --- /dev/null +++ b/src/net/woodyfolsom/msproj/ann2/NNData.java @@ -0,0 +1,29 @@ +package net.woodyfolsom.msproj.ann2; + +public class NNData { + private final double[] values; + private final String[] fields; + + public NNData(String[] fields, double[] values) { + this.fields = fields; + this.values = values; + } + + public double[] getValues() { + return values; + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder("["); + + for (int i = 0; i < fields.length; i++) { + if (i > 0) { + sb.append(", " ); + } + sb.append(fields[i] + "=" + values[i]); + } + sb.append("]"); + return sb.toString(); + } +} diff --git a/src/net/woodyfolsom/msproj/ann2/NNDataPair.java b/src/net/woodyfolsom/msproj/ann2/NNDataPair.java new file mode 100644 index 0000000..b6f76ea --- /dev/null +++ b/src/net/woodyfolsom/msproj/ann2/NNDataPair.java @@ -0,0 +1,19 @@ +package net.woodyfolsom.msproj.ann2; + +public class NNDataPair { + private final NNData actual; + private final NNData ideal; + + public NNDataPair(NNData actual, NNData ideal) { + this.actual = actual; + this.ideal = ideal; + } + + public NNData getActual() { + return actual; + } + + public NNData getIdeal() { + return ideal; + } +} diff --git a/src/net/woodyfolsom/msproj/ann2/NeuralNetwork.java b/src/net/woodyfolsom/msproj/ann2/NeuralNetwork.java new file mode 100644 index 0000000..9679d6f --- /dev/null +++ b/src/net/woodyfolsom/msproj/ann2/NeuralNetwork.java @@ -0,0 +1,53 @@ +package net.woodyfolsom.msproj.ann2; + +import java.io.InputStream; +import java.io.OutputStream; + +import javax.xml.bind.JAXBException; + +/** + * A NeuralNetwork is simply an ordered set of Neurons. + * + * Functions which rely on knowledge of input neurons, output neurons and layers + * are delegated to MultiLayerPerceptron. + * + * The primary function implemented in this abstract class is feedforward. + * This function depends only on getNeurons() returning Neurons in feedforward order, + * and the returned Neurons must have the correct number of weights for the NeuralNetwork + * configuration.
+ * + * @author Woody + * + */ +public abstract class NeuralNetwork { + public NeuralNetwork() { + } + + public double[] calculate(double[] input) { + zeroInputs(); + setInput(input); + feedforward(); + return getOutput(); + } + + protected void feedforward() { + Neuron[] neurons = getNeurons(); + } + + protected abstract double[] getOutput(); + + protected abstract Neuron[] getNeurons(); + + public abstract boolean load(InputStream is); + public abstract boolean save(OutputStream os); + + protected abstract void setInput(double[] input); + + protected void zeroInputs() { + for (Neuron neuron : getNeurons()) { + neuron.setInput(0.0); + } + } + + +} \ No newline at end of file diff --git a/src/net/woodyfolsom/msproj/ann2/Neuron.java b/src/net/woodyfolsom/msproj/ann2/Neuron.java new file mode 100644 index 0000000..f32ea0a --- /dev/null +++ b/src/net/woodyfolsom/msproj/ann2/Neuron.java @@ -0,0 +1,92 @@ +package net.woodyfolsom.msproj.ann2; + +import java.util.Arrays; + +import javax.xml.bind.Unmarshaller; +import javax.xml.bind.annotation.XmlElement; +import javax.xml.bind.annotation.XmlTransient; + +public class Neuron { + private ActivationFunction activationFunction; + private double[] weights; + + private transient double input = 0.0; + + public Neuron() { + //no-arg constructor for JAXB + } + + public Neuron(ActivationFunction activationFunction, int numWeights) { + this.activationFunction = activationFunction; + this.weights = new double[numWeights]; + } + + @XmlElement(type=Sigmoid.class) + public ActivationFunction getActivationFunction() { + return activationFunction; + } + + void afterUnmarshal(Unmarshaller aUnmarshaller, Object aParent) + { + if (weights == null) { + weights = new double[0]; + } + } + + @XmlTransient + public double getInput() { + return input; + } + + public double getOutput() { + return activationFunction.calculate(input); + } + + @XmlElement + public double[] getWeights() { + return weights; + } + + public void setInput(double input) { + this.input = input; + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime + * result + + ((activationFunction == null) ? 
0 : activationFunction + .hashCode()); + result = prime * result + Arrays.hashCode(weights); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + Neuron other = (Neuron) obj; + if (activationFunction == null) { + if (other.activationFunction != null) + return false; + } else if (!activationFunction.equals(other.activationFunction)) + return false; + if (!Arrays.equals(weights, other.weights)) + return false; + return true; + } + + public void setActivationFunction(ActivationFunction activationFunction) { + this.activationFunction = activationFunction; + } + + public void setWeights(double[] weights) { + this.weights = weights; + } +} \ No newline at end of file diff --git a/src/net/woodyfolsom/msproj/ann2/ObjectiveFunction.java b/src/net/woodyfolsom/msproj/ann2/ObjectiveFunction.java new file mode 100644 index 0000000..4eac691 --- /dev/null +++ b/src/net/woodyfolsom/msproj/ann2/ObjectiveFunction.java @@ -0,0 +1,5 @@ +package net.woodyfolsom.msproj.ann2; + +public class ObjectiveFunction { + +} diff --git a/src/net/woodyfolsom/msproj/ann2/Sigmoid.java b/src/net/woodyfolsom/msproj/ann2/Sigmoid.java new file mode 100644 index 0000000..8629496 --- /dev/null +++ b/src/net/woodyfolsom/msproj/ann2/Sigmoid.java @@ -0,0 +1,48 @@ +package net.woodyfolsom.msproj.ann2; + +public class Sigmoid implements ActivationFunction{ + public static final Sigmoid function = new Sigmoid(); + private String name; + + private Sigmoid() { + this.name = "Sigmoid"; + } + + public double calculate(double arg) { + return 1.0 / (1 + Math.pow(Math.E, -1.0 * arg)); + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((name == null) ? 
0 : name.hashCode()); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + Sigmoid other = (Sigmoid) obj; + if (name == null) { + if (other.name != null) + return false; + } else if (!name.equals(other.name)) + return false; + return true; + } + +} \ No newline at end of file diff --git a/src/net/woodyfolsom/msproj/ann2/Tanh.java b/src/net/woodyfolsom/msproj/ann2/Tanh.java new file mode 100644 index 0000000..7277d67 --- /dev/null +++ b/src/net/woodyfolsom/msproj/ann2/Tanh.java @@ -0,0 +1,10 @@ +package net.woodyfolsom.msproj.ann2; + +public class Tanh implements ActivationFunction{ + + @Override + public double calculate(double arg) { + return Math.tanh(arg); + } + +} diff --git a/test/net/woodyfolsom/msproj/ann/WinFilterTest.java b/test/net/woodyfolsom/msproj/ann/WinFilterTest.java index 52eaf09..bcb0e18 100644 --- a/test/net/woodyfolsom/msproj/ann/WinFilterTest.java +++ b/test/net/woodyfolsom/msproj/ann/WinFilterTest.java @@ -50,7 +50,6 @@ public class WinFilterTest { winFilter.learn(trainingData); for (List trainingSequence : trainingData) { - //for (MLDataPair mlDataPair : trainingSequence) { for (int stateIndex = 0; stateIndex < trainingSequence.size(); stateIndex++) { if (stateIndex > 0 && stateIndex < trainingSequence.size()-1) { continue; @@ -58,9 +57,8 @@ public class WinFilterTest { MLData input = trainingSequence.get(stateIndex).getInput(); System.out.println("Turn " + stateIndex + ": " + input + " => " - + winFilter.computeValue(input)); + + winFilter.compute(input)); } - //} } } } diff --git a/test/net/woodyfolsom/msproj/ann/XORFilterTest.java b/test/net/woodyfolsom/msproj/ann/XORFilterTest.java index e0c2977..8a39c15 100644 --- a/test/net/woodyfolsom/msproj/ann/XORFilterTest.java +++ b/test/net/woodyfolsom/msproj/ann/XORFilterTest.java @@ -73,7 +73,7 @@ public class XORFilterTest { private void testNetwork(NeuralNetFilter nnLearner, double[][] validationSet) { for (int valIndex = 0; valIndex < validationSet.length; valIndex++) { DoublePair dp = new DoublePair(validationSet[valIndex][0],validationSet[valIndex][1]); - System.out.println(dp + " => " + nnLearner.computeValue(dp)); + System.out.println(dp + " => " + nnLearner.compute(dp)); } } } \ No newline at end of file diff --git a/test/net/woodyfolsom/msproj/ann2/MultiLayerPerceptronTest.java b/test/net/woodyfolsom/msproj/ann2/MultiLayerPerceptronTest.java new file mode 100644 index 0000000..344f026 --- /dev/null +++ b/test/net/woodyfolsom/msproj/ann2/MultiLayerPerceptronTest.java @@ -0,0 +1,62 @@ +package net.woodyfolsom.msproj.ann2; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; + +import javax.xml.bind.JAXBException; + +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +public class MultiLayerPerceptronTest { + static final File TEST_FILE = new File("data/test/mlp.net"); + + @BeforeClass + public static void setUp() { + if (TEST_FILE.exists()) { + TEST_FILE.delete(); + } + } + + @AfterClass + public static void tearDown() { + if (TEST_FILE.exists()) { + TEST_FILE.delete(); + } + } + + @Test + public void testConstructor() { + new MultiLayerPerceptron(true, 2, 4, 1); + new MultiLayerPerceptron(false, 2, 1); + } + + @Test(expected = IllegalArgumentException.class) + public void 
testConstructorTooFewLayers() { + new MultiLayerPerceptron(true, 2); + } + + @Test(expected = IllegalArgumentException.class) + public void testConstructorTooFewNeurons() { + new MultiLayerPerceptron(true, 2, 4, 0, 1); + } + + @Test + public void testPersistence() throws JAXBException, IOException { + NeuralNetwork mlp = new MultiLayerPerceptron(true, 2, 4, 1); + FileOutputStream fos = new FileOutputStream(TEST_FILE); + assertTrue(mlp.save(fos)); + fos.close(); + FileInputStream fis = new FileInputStream(TEST_FILE); + NeuralNetwork mlp2 = new MultiLayerPerceptron(); + assertTrue(mlp2.load(fis)); + assertEquals(mlp, mlp2); + fis.close(); + } +} diff --git a/test/net/woodyfolsom/msproj/ann2/SigmoidTest.java b/test/net/woodyfolsom/msproj/ann2/SigmoidTest.java new file mode 100644 index 0000000..a8e3350 --- /dev/null +++ b/test/net/woodyfolsom/msproj/ann2/SigmoidTest.java @@ -0,0 +1,18 @@ +package net.woodyfolsom.msproj.ann2; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import org.junit.Test; + +public class SigmoidTest { + @Test + public void testCalculate() { + double EPS = 0.001; + + ActivationFunction sigmoid = Sigmoid.function; + assertEquals(0.5,sigmoid.calculate(0.0),EPS); + assertTrue(sigmoid.calculate(100.0) > 1.0 - EPS); + assertTrue(sigmoid.calculate(-9000.0) < EPS); + } +} diff --git a/test/net/woodyfolsom/msproj/ann2/TanhTest.java b/test/net/woodyfolsom/msproj/ann2/TanhTest.java new file mode 100644 index 0000000..abd3874 --- /dev/null +++ b/test/net/woodyfolsom/msproj/ann2/TanhTest.java @@ -0,0 +1,18 @@ +package net.woodyfolsom.msproj.ann2; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import org.junit.Test; + +public class TanhTest { + @Test + public void testCalculate() { + double EPS = 0.001; + + ActivationFunction sigmoid = new Tanh(); + assertEquals(0.0,sigmoid.calculate(0.0),EPS); + assertTrue(sigmoid.calculate(100.0) > 0.5 - EPS); + assertTrue(sigmoid.calculate(-9000.0) < -0.5+EPS); + } +}
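
Note on the TODO left in TemporalDifference.updateWeight: below is a minimal sketch of what an eligibility-trace-based TD(lambda) update could look like, assuming one trace per weight, a precomputed TD error, and gradients holding dV/dw for the current state. The class name, fields, and update() signature are illustrative only, not part of this patch, and any discount factor is treated as folded into lambda.

package net.woodyfolsom.msproj.ann;

/** Illustrative TD(lambda) weight update with eligibility traces. */
public class TDLambdaUpdateSketch {
    private final double learningRate;  // alpha
    private final double lambda;        // trace-decay parameter
    private final double[] eligibility; // one trace per weight

    public TDLambdaUpdateSketch(double learningRate, double lambda, int numWeights) {
        this.learningRate = learningRate;
        this.lambda = lambda;
        this.eligibility = new double[numWeights];
    }

    /**
     * @param gradients dV/dw for each weight at the current state
     * @param tdError   V(next state) minus V(current state), plus any reward
     * @param weights   network weights, updated in place
     */
    public void update(double[] gradients, double tdError, double[] weights) {
        for (int i = 0; i < weights.length; i++) {
            // decay the old trace, then accumulate the new gradient
            eligibility[i] = lambda * eligibility[i] + gradients[i];
            // step each weight along its trace, scaled by the TD error
            weights[i] += learningRate * tdError * eligibility[i];
        }
    }
}

Since Encog's Propagation drives updateWeight once per weight index, the trace array would most naturally live in TemporalDifference alongside the lastDelta array inherited from Backpropagation.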
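NeuralNetwork.feedforward() in the new ann2 package is still a stub (it fetches getNeurons() and does nothing with them). The following is a minimal sketch of the layer-by-layer pass the javadoc describes, under the assumption that neuron.getWeights()[j] is the connection from neuron j of the previous layer, which matches how Layer(numNeurons, numWeights, ...) sizes its weight arrays. With this reading the input layer's values also pass through the activation function, which may or may not be intended.

package net.woodyfolsom.msproj.ann2;

/** Illustrative feedforward pass over the ann2 Layer/Neuron classes. */
class FeedforwardSketch {
    static void feedforward(Layer[] layers) {
        for (int l = 1; l < layers.length; l++) {
            Neuron[] prev = layers[l - 1].getNeurons();
            for (Neuron neuron : layers[l].getNeurons()) {
                double sum = 0.0;
                double[] w = neuron.getWeights();
                for (int j = 0; j < w.length; j++) {
                    // weighted sum of the previous layer's activations
                    sum += w[j] * prev[j].getOutput();
                }
                neuron.setInput(sum); // getOutput() applies the activation
            }
        }
    }
}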
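Callers migrating from the removed computeValue/computeVector methods to the unified NeuralNetFilter.compute(MLData) introduced by this patch would follow the pattern below; the filter argument and input values are placeholders, and index selection (previously done inside WinFilter.computeValue) is now the caller's responsibility.

package net.woodyfolsom.msproj.ann;

import org.encog.ml.data.MLData;
import org.encog.ml.data.basic.BasicMLData;

class ComputeMigrationExample {
    static void demo(NeuralNetFilter filter) {
        MLData input = new BasicMLData(new double[] { 1.0, 0.0 });
        MLData output = filter.compute(input);
        double value = output.getData(0);   // was computeValue(input); caller now picks the index
        double[] vector = output.getData(); // was computeVector(input)
        System.out.println(value + " / " + vector.length);
    }
}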