Starting my own ANN implementation.
This commit is contained in:
@@ -5,6 +5,7 @@ import java.io.FileInputStream;
|
|||||||
import java.io.FileOutputStream;
|
import java.io.FileOutputStream;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
|
||||||
|
import org.encog.ml.data.MLData;
|
||||||
import org.encog.neural.networks.BasicNetwork;
|
import org.encog.neural.networks.BasicNetwork;
|
||||||
import org.encog.neural.networks.PersistBasicNetwork;
|
import org.encog.neural.networks.PersistBasicNetwork;
|
||||||
|
|
||||||
@@ -13,6 +14,11 @@ public abstract class AbstractNeuralNetFilter implements NeuralNetFilter {
|
|||||||
protected int actualTrainingEpochs = 0;
|
protected int actualTrainingEpochs = 0;
|
||||||
protected int maxTrainingEpochs = 1000;
|
protected int maxTrainingEpochs = 1000;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MLData compute(MLData input) {
|
||||||
|
return this.neuralNetwork.compute(input);
|
||||||
|
}
|
||||||
|
|
||||||
public int getActualTrainingEpochs() {
|
public int getActualTrainingEpochs() {
|
||||||
return actualTrainingEpochs;
|
return actualTrainingEpochs;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,95 +0,0 @@
|
|||||||
package net.woodyfolsom.msproj.ann;
|
|
||||||
|
|
||||||
import org.encog.mathutil.error.ErrorCalculationMode;
|
|
||||||
|
|
||||||
/*
|
|
||||||
Initial erison of this class was a verbatim copy from Encog framework.
|
|
||||||
*/
|
|
||||||
|
|
||||||
public class ErrorCalculation {
|
|
||||||
|
|
||||||
private static ErrorCalculationMode mode = ErrorCalculationMode.MSE;
|
|
||||||
|
|
||||||
public static ErrorCalculationMode getMode() {
|
|
||||||
return ErrorCalculation.mode;
|
|
||||||
}
|
|
||||||
|
|
||||||
public static void setMode(final ErrorCalculationMode theMode) {
|
|
||||||
ErrorCalculation.mode = theMode;
|
|
||||||
}
|
|
||||||
|
|
||||||
private double globalError;
|
|
||||||
|
|
||||||
private int setSize;
|
|
||||||
|
|
||||||
public final double calculate() {
|
|
||||||
if (this.setSize == 0) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
switch (ErrorCalculation.getMode()) {
|
|
||||||
case RMS:
|
|
||||||
return calculateRMS();
|
|
||||||
case MSE:
|
|
||||||
return calculateMSE();
|
|
||||||
case ESS:
|
|
||||||
return calculateESS();
|
|
||||||
default:
|
|
||||||
return calculateMSE();
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
public final double calculateMSE() {
|
|
||||||
if (this.setSize == 0) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
final double err = this.globalError / this.setSize;
|
|
||||||
return err;
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
public final double calculateESS() {
|
|
||||||
if (this.setSize == 0) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
final double err = this.globalError / 2;
|
|
||||||
return err;
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
public final double calculateRMS() {
|
|
||||||
if (this.setSize == 0) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
final double err = Math.sqrt(this.globalError / this.setSize);
|
|
||||||
return err;
|
|
||||||
}
|
|
||||||
|
|
||||||
public final void reset() {
|
|
||||||
this.globalError = 0;
|
|
||||||
this.setSize = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
public final void updateError(final double actual, final double ideal) {
|
|
||||||
|
|
||||||
double delta = ideal - actual;
|
|
||||||
|
|
||||||
this.globalError += delta * delta;
|
|
||||||
|
|
||||||
this.setSize++;
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
public final void updateError(final double[] actual, final double[] ideal,
|
|
||||||
final double significance) {
|
|
||||||
for (int i = 0; i < actual.length; i++) {
|
|
||||||
double delta = (ideal[i] - actual[i]) * significance;
|
|
||||||
|
|
||||||
this.globalError += delta * delta;
|
|
||||||
}
|
|
||||||
|
|
||||||
this.setSize += ideal.length;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -1,25 +0,0 @@
|
|||||||
package net.woodyfolsom.msproj.ann;
|
|
||||||
|
|
||||||
import net.woodyfolsom.msproj.GameState;
|
|
||||||
|
|
||||||
import org.encog.ml.data.basic.BasicMLData;
|
|
||||||
|
|
||||||
public class GameStateMLData extends BasicMLData {
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
private static final long serialVersionUID = 1L;
|
|
||||||
|
|
||||||
private GameState gameState;
|
|
||||||
|
|
||||||
public GameStateMLData(double[] d, GameState gameState) {
|
|
||||||
super(d);
|
|
||||||
// TODO Auto-generated constructor stub
|
|
||||||
this.gameState = gameState;
|
|
||||||
}
|
|
||||||
|
|
||||||
public GameState getGameState() {
|
|
||||||
return gameState;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -11,16 +11,12 @@ import org.encog.ml.data.basic.BasicMLDataPair;
|
|||||||
import org.encog.util.kmeans.Centroid;
|
import org.encog.util.kmeans.Centroid;
|
||||||
|
|
||||||
public class GameStateMLDataPair implements MLDataPair {
|
public class GameStateMLDataPair implements MLDataPair {
|
||||||
//private final String[] inputs = { "BlackScore", "WhiteScore" };
|
|
||||||
//private final String[] outputs = { "BlackWins", "WhiteWins" };
|
|
||||||
|
|
||||||
private BasicMLDataPair mlDataPairDelegate;
|
private BasicMLDataPair mlDataPairDelegate;
|
||||||
private GameState gameState;
|
private GameState gameState;
|
||||||
|
|
||||||
public GameStateMLDataPair(GameState gameState) {
|
public GameStateMLDataPair(GameState gameState) {
|
||||||
this.gameState = gameState;
|
this.gameState = gameState;
|
||||||
mlDataPairDelegate = new BasicMLDataPair(
|
mlDataPairDelegate = new BasicMLDataPair(new BasicMLData(createInput()), new BasicMLData(createIdeal()));
|
||||||
new GameStateMLData(createInput(), gameState), new BasicMLData(createIdeal()));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public GameStateMLDataPair(GameStateMLDataPair that) {
|
public GameStateMLDataPair(GameStateMLDataPair that) {
|
||||||
@@ -118,4 +114,4 @@ public class GameStateMLDataPair implements MLDataPair {
|
|||||||
mlDataPairDelegate.setSignificance(arg0);
|
mlDataPairDelegate.setSignificance(arg0);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -1,193 +0,0 @@
|
|||||||
package net.woodyfolsom.msproj.ann;
|
|
||||||
|
|
||||||
/*
|
|
||||||
* Class copied verbatim from Encog framework due to dependency on Propagation
|
|
||||||
* implementation.
|
|
||||||
*
|
|
||||||
* Encog(tm) Core v3.2 - Java Version
|
|
||||||
* http://www.heatonresearch.com/encog/
|
|
||||||
* http://code.google.com/p/encog-java/
|
|
||||||
|
|
||||||
* Copyright 2008-2012 Heaton Research, Inc.
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*
|
|
||||||
* For more information on Heaton Research copyrights, licenses
|
|
||||||
* and trademarks visit:
|
|
||||||
* http://www.heatonresearch.com/copyright
|
|
||||||
*/
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Set;
|
|
||||||
|
|
||||||
import org.encog.engine.network.activation.ActivationFunction;
|
|
||||||
import org.encog.ml.data.MLDataPair;
|
|
||||||
import org.encog.ml.data.basic.BasicMLDataPair;
|
|
||||||
import org.encog.neural.error.ErrorFunction;
|
|
||||||
import org.encog.neural.flat.FlatNetwork;
|
|
||||||
import org.encog.util.EngineArray;
|
|
||||||
import org.encog.util.concurrency.EngineTask;
|
|
||||||
|
|
||||||
public class GradientWorker implements EngineTask {
|
|
||||||
|
|
||||||
private final FlatNetwork network;
|
|
||||||
private final ErrorCalculation errorCalculation = new ErrorCalculation();
|
|
||||||
private final List<double[]> actuals;
|
|
||||||
private final double[] layerDelta;
|
|
||||||
private final int[] layerCounts;
|
|
||||||
private final int[] layerFeedCounts;
|
|
||||||
private final int[] layerIndex;
|
|
||||||
private final int[] weightIndex;
|
|
||||||
private final double[] layerOutput;
|
|
||||||
private final double[] layerSums;
|
|
||||||
private final double[] gradients;
|
|
||||||
private final double[] weights;
|
|
||||||
private final MLDataPair pairPrototype;
|
|
||||||
private final Set<List<MLDataPair>> training;
|
|
||||||
//private final int low;
|
|
||||||
//private final int high;
|
|
||||||
private final TemporalDifferenceLearning owner;
|
|
||||||
private double[] flatSpot;
|
|
||||||
private final ErrorFunction errorFunction;
|
|
||||||
|
|
||||||
public GradientWorker(final FlatNetwork theNetwork,
|
|
||||||
final TemporalDifferenceLearning theOwner,
|
|
||||||
final Set<List<MLDataPair>> theTraining, final int theLow,
|
|
||||||
final int theHigh, final double[] flatSpot, ErrorFunction ef) {
|
|
||||||
this.network = theNetwork;
|
|
||||||
this.training = theTraining;
|
|
||||||
//this.low = theLow;
|
|
||||||
//this.high = theHigh;
|
|
||||||
this.owner = theOwner;
|
|
||||||
this.flatSpot = flatSpot;
|
|
||||||
this.errorFunction = ef;
|
|
||||||
|
|
||||||
this.layerDelta = new double[network.getLayerOutput().length];
|
|
||||||
this.gradients = new double[network.getWeights().length];
|
|
||||||
this.actuals = new ArrayList<double[]>();
|
|
||||||
|
|
||||||
this.weights = network.getWeights();
|
|
||||||
this.layerIndex = network.getLayerIndex();
|
|
||||||
this.layerCounts = network.getLayerCounts();
|
|
||||||
this.weightIndex = network.getWeightIndex();
|
|
||||||
this.layerOutput = network.getLayerOutput();
|
|
||||||
this.layerSums = network.getLayerSums();
|
|
||||||
this.layerFeedCounts = network.getLayerFeedCounts();
|
|
||||||
|
|
||||||
this.pairPrototype = BasicMLDataPair.createPair(
|
|
||||||
network.getInputCount(), network.getOutputCount());
|
|
||||||
}
|
|
||||||
|
|
||||||
public FlatNetwork getNetwork() {
|
|
||||||
return this.network;
|
|
||||||
}
|
|
||||||
|
|
||||||
public double[] getWeights() {
|
|
||||||
return this.weights;
|
|
||||||
}
|
|
||||||
|
|
||||||
private void process(List<MLDataPair> trainingSequence) {
|
|
||||||
actuals.clear();
|
|
||||||
|
|
||||||
for (int trainingIdx = 0; trainingIdx < trainingSequence.size(); trainingIdx++) {
|
|
||||||
MLDataPair mlDataPair = trainingSequence.get(trainingIdx);
|
|
||||||
MLDataPair dataPairCopy = this.pairPrototype;
|
|
||||||
dataPairCopy.setInputArray(mlDataPair.getInputArray());
|
|
||||||
if (dataPairCopy.getIdealArray() != null) {
|
|
||||||
dataPairCopy.setIdealArray(mlDataPair.getIdealArray());
|
|
||||||
}
|
|
||||||
|
|
||||||
double[] input = dataPairCopy.getInputArray();
|
|
||||||
double[] ideal = dataPairCopy.getIdealArray();
|
|
||||||
double significance = dataPairCopy.getSignificance();
|
|
||||||
|
|
||||||
actuals.add(trainingIdx, new double[ideal.length]);
|
|
||||||
this.network.compute(input, actuals.get(trainingIdx));
|
|
||||||
|
|
||||||
// For now, only calculate deltas for the final data pair
|
|
||||||
// For final TDL algorithm, deltas won't be used at all, instead the
|
|
||||||
// List of Actual vectors will.
|
|
||||||
if (trainingIdx < trainingSequence.size() - 1) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
this.errorCalculation.updateError(actuals.get(trainingIdx), ideal,
|
|
||||||
significance);
|
|
||||||
this.errorFunction.calculateError(ideal, actuals.get(trainingIdx),
|
|
||||||
this.layerDelta);
|
|
||||||
|
|
||||||
for (int i = 0; i < actuals.get(trainingIdx).length; i++) {
|
|
||||||
this.layerDelta[i] = ((this.network.getActivationFunctions()[0]
|
|
||||||
.derivativeFunction(this.layerSums[i],
|
|
||||||
this.layerOutput[i]) + this.flatSpot[0]))
|
|
||||||
* (this.layerDelta[i] * significance);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int i = this.network.getBeginTraining(); i < this.network
|
|
||||||
.getEndTraining(); i++) {
|
|
||||||
processLevel(i);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private void processLevel(final int currentLevel) {
|
|
||||||
final int fromLayerIndex = this.layerIndex[currentLevel + 1];
|
|
||||||
final int toLayerIndex = this.layerIndex[currentLevel];
|
|
||||||
final int fromLayerSize = this.layerCounts[currentLevel + 1];
|
|
||||||
final int toLayerSize = this.layerFeedCounts[currentLevel];
|
|
||||||
|
|
||||||
final int index = this.weightIndex[currentLevel];
|
|
||||||
final ActivationFunction activation = this.network
|
|
||||||
.getActivationFunctions()[currentLevel];
|
|
||||||
final double currentFlatSpot = this.flatSpot[currentLevel + 1];
|
|
||||||
|
|
||||||
// handle weights
|
|
||||||
int yi = fromLayerIndex;
|
|
||||||
for (int y = 0; y < fromLayerSize; y++) {
|
|
||||||
final double output = this.layerOutput[yi];
|
|
||||||
double sum = 0;
|
|
||||||
int xi = toLayerIndex;
|
|
||||||
int wi = index + y;
|
|
||||||
for (int x = 0; x < toLayerSize; x++) {
|
|
||||||
this.gradients[wi] += output * this.layerDelta[xi];
|
|
||||||
sum += this.weights[wi] * this.layerDelta[xi];
|
|
||||||
wi += fromLayerSize;
|
|
||||||
xi++;
|
|
||||||
}
|
|
||||||
|
|
||||||
this.layerDelta[yi] = sum
|
|
||||||
* (activation.derivativeFunction(this.layerSums[yi],
|
|
||||||
this.layerOutput[yi]) + currentFlatSpot);
|
|
||||||
yi++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public final void run() {
|
|
||||||
try {
|
|
||||||
this.errorCalculation.reset();
|
|
||||||
|
|
||||||
for (List<MLDataPair> trainingSequence : training) {
|
|
||||||
process(trainingSequence);
|
|
||||||
}
|
|
||||||
|
|
||||||
final double error = this.errorCalculation.calculate();
|
|
||||||
this.owner.report(this.gradients, error, null);
|
|
||||||
EngineArray.fill(this.gradients, 0);
|
|
||||||
} catch (final Throwable ex) {
|
|
||||||
this.owner.report(null, 0, ex);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -11,21 +11,20 @@ import org.encog.neural.networks.BasicNetwork;
|
|||||||
|
|
||||||
public interface NeuralNetFilter {
|
public interface NeuralNetFilter {
|
||||||
BasicNetwork getNeuralNetwork();
|
BasicNetwork getNeuralNetwork();
|
||||||
|
|
||||||
|
int getActualTrainingEpochs();
|
||||||
|
int getInputSize();
|
||||||
|
int getMaxTrainingEpochs();
|
||||||
|
int getOutputSize();
|
||||||
|
|
||||||
|
void learn(MLDataSet trainingSet);
|
||||||
|
void learn(Set<List<MLDataPair>> trainingSet);
|
||||||
|
|
||||||
|
void load(String fileName) throws IOException;
|
||||||
|
void reset();
|
||||||
|
void reset(int seed);
|
||||||
|
void save(String fileName) throws IOException;
|
||||||
|
void setMaxTrainingEpochs(int max);
|
||||||
|
|
||||||
public int getActualTrainingEpochs();
|
MLData compute(MLData input);
|
||||||
public int getInputSize();
|
|
||||||
public int getMaxTrainingEpochs();
|
|
||||||
public int getOutputSize();
|
|
||||||
|
|
||||||
public double computeValue(MLData input);
|
|
||||||
public double[] computeVector(MLData input);
|
|
||||||
|
|
||||||
public void learn(MLDataSet trainingSet);
|
|
||||||
public void learn(Set<List<MLDataPair>> trainingSet);
|
|
||||||
|
|
||||||
public void load(String fileName) throws IOException;
|
|
||||||
public void reset();
|
|
||||||
public void reset(int seed);
|
|
||||||
public void save(String fileName) throws IOException;
|
|
||||||
public void setMaxTrainingEpochs(int max);
|
|
||||||
}
|
}
|
||||||
30
src/net/woodyfolsom/msproj/ann/TemporalDifference.java
Normal file
30
src/net/woodyfolsom/msproj/ann/TemporalDifference.java
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann;
|
||||||
|
|
||||||
|
import org.encog.ml.data.MLDataSet;
|
||||||
|
import org.encog.neural.networks.ContainsFlat;
|
||||||
|
import org.encog.neural.networks.training.propagation.back.Backpropagation;
|
||||||
|
|
||||||
|
public class TemporalDifference extends Backpropagation {
|
||||||
|
private final double lambda;
|
||||||
|
|
||||||
|
public TemporalDifference(ContainsFlat network, MLDataSet training,
|
||||||
|
double theLearnRate, double theMomentum, double lambda) {
|
||||||
|
super(network, training, theLearnRate, theMomentum);
|
||||||
|
this.lambda = lambda;
|
||||||
|
}
|
||||||
|
|
||||||
|
public double getLamdba() {
|
||||||
|
return lambda;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double updateWeight(final double[] gradients,
|
||||||
|
final double[] lastGradient, final int index) {
|
||||||
|
double alpha = this.getLearningRate();
|
||||||
|
|
||||||
|
//TODO fill in weight update for TD(lambda)
|
||||||
|
|
||||||
|
return 0.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,487 +0,0 @@
|
|||||||
package net.woodyfolsom.msproj.ann;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.HashSet;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Set;
|
|
||||||
|
|
||||||
import org.encog.EncogError;
|
|
||||||
import org.encog.engine.network.activation.ActivationFunction;
|
|
||||||
import org.encog.engine.network.activation.ActivationSigmoid;
|
|
||||||
import org.encog.mathutil.IntRange;
|
|
||||||
import org.encog.ml.MLMethod;
|
|
||||||
import org.encog.ml.TrainingImplementationType;
|
|
||||||
import org.encog.ml.data.MLDataPair;
|
|
||||||
import org.encog.ml.data.MLDataSet;
|
|
||||||
import org.encog.ml.train.MLTrain;
|
|
||||||
import org.encog.ml.train.strategy.Strategy;
|
|
||||||
import org.encog.ml.train.strategy.end.EndTrainingStrategy;
|
|
||||||
import org.encog.neural.error.ErrorFunction;
|
|
||||||
import org.encog.neural.error.LinearErrorFunction;
|
|
||||||
import org.encog.neural.flat.FlatNetwork;
|
|
||||||
import org.encog.neural.networks.ContainsFlat;
|
|
||||||
import org.encog.neural.networks.training.LearningRate;
|
|
||||||
import org.encog.neural.networks.training.Momentum;
|
|
||||||
import org.encog.neural.networks.training.Train;
|
|
||||||
import org.encog.neural.networks.training.TrainingError;
|
|
||||||
import org.encog.neural.networks.training.propagation.TrainingContinuation;
|
|
||||||
import org.encog.neural.networks.training.propagation.back.Backpropagation;
|
|
||||||
import org.encog.neural.networks.training.strategy.SmartLearningRate;
|
|
||||||
import org.encog.neural.networks.training.strategy.SmartMomentum;
|
|
||||||
import org.encog.util.EncogValidate;
|
|
||||||
import org.encog.util.EngineArray;
|
|
||||||
import org.encog.util.concurrency.DetermineWorkload;
|
|
||||||
import org.encog.util.concurrency.EngineConcurrency;
|
|
||||||
import org.encog.util.concurrency.MultiThreadable;
|
|
||||||
import org.encog.util.concurrency.TaskGroup;
|
|
||||||
import org.encog.util.logging.EncogLogging;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* This class started as a verbatim copy of BackPropagation from the open-source
|
|
||||||
* Encog framework. It was merged with its super-classes to access protected
|
|
||||||
* fields without resorting to reflection.
|
|
||||||
*/
|
|
||||||
public class TemporalDifferenceLearning implements MLTrain, Momentum,
|
|
||||||
LearningRate, Train, MultiThreadable {
|
|
||||||
// New fields for TD(lambda)
|
|
||||||
private final double lambda;
|
|
||||||
// end new fields
|
|
||||||
|
|
||||||
// BackProp
|
|
||||||
public static final String LAST_DELTA = "LAST_DELTA";
|
|
||||||
private double learningRate;
|
|
||||||
private double momentum;
|
|
||||||
private double[] lastDelta;
|
|
||||||
// End BackProp
|
|
||||||
|
|
||||||
// Propagation
|
|
||||||
private FlatNetwork currentFlatNetwork;
|
|
||||||
private int numThreads;
|
|
||||||
protected double[] gradients;
|
|
||||||
private double[] lastGradient;
|
|
||||||
protected ContainsFlat network;
|
|
||||||
// private MLDataSet indexable;
|
|
||||||
private Set<List<MLDataPair>> indexable;
|
|
||||||
private GradientWorker[] workers;
|
|
||||||
private double totalError;
|
|
||||||
protected double lastError;
|
|
||||||
private Throwable reportedException;
|
|
||||||
private double[] flatSpot;
|
|
||||||
private boolean shouldFixFlatSpot;
|
|
||||||
private ErrorFunction ef = new LinearErrorFunction();
|
|
||||||
// End Propagation
|
|
||||||
|
|
||||||
// BasicTraining
|
|
||||||
private final List<Strategy> strategies = new ArrayList<Strategy>();
|
|
||||||
//private Set<List<MLDataPair>> training;
|
|
||||||
private double error;
|
|
||||||
private int iteration;
|
|
||||||
private TrainingImplementationType implementationType;
|
|
||||||
|
|
||||||
// End BasicTraining
|
|
||||||
|
|
||||||
public TemporalDifferenceLearning(final ContainsFlat network,
|
|
||||||
final Set<List<MLDataPair>> training, double lambda) {
|
|
||||||
this(network, training, 0, 0, lambda);
|
|
||||||
addStrategy(new SmartLearningRate());
|
|
||||||
addStrategy(new SmartMomentum());
|
|
||||||
}
|
|
||||||
|
|
||||||
public TemporalDifferenceLearning(final ContainsFlat network,
|
|
||||||
Set<List<MLDataPair>> training, final double theLearnRate,
|
|
||||||
final double theMomentum, double lambda) {
|
|
||||||
initPropagation(network, training);
|
|
||||||
// TODO consider how to re-implement validation
|
|
||||||
// ValidateNetwork.validateMethodToData(network, training);
|
|
||||||
this.momentum = theMomentum;
|
|
||||||
this.learningRate = theLearnRate;
|
|
||||||
this.lastDelta = new double[network.getFlat().getWeights().length];
|
|
||||||
this.lambda = lambda;
|
|
||||||
}
|
|
||||||
|
|
||||||
private void initPropagation(final ContainsFlat network,
|
|
||||||
final Set<List<MLDataPair>> training) {
|
|
||||||
initBasicTraining(TrainingImplementationType.Iterative);
|
|
||||||
this.network = network;
|
|
||||||
this.currentFlatNetwork = network.getFlat();
|
|
||||||
//setTraining(training);
|
|
||||||
|
|
||||||
this.gradients = new double[this.currentFlatNetwork.getWeights().length];
|
|
||||||
this.lastGradient = new double[this.currentFlatNetwork.getWeights().length];
|
|
||||||
|
|
||||||
this.indexable = training;
|
|
||||||
this.numThreads = 0;
|
|
||||||
this.reportedException = null;
|
|
||||||
this.shouldFixFlatSpot = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
private void initBasicTraining(TrainingImplementationType implementationType) {
|
|
||||||
this.implementationType = implementationType;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Methods from BackPropagation
|
|
||||||
@Override
|
|
||||||
public boolean canContinue() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
public double[] getLastDelta() {
|
|
||||||
return this.lastDelta;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double getLearningRate() {
|
|
||||||
return this.learningRate;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double getMomentum() {
|
|
||||||
return this.momentum;
|
|
||||||
}
|
|
||||||
|
|
||||||
public boolean isValidResume(final TrainingContinuation state) {
|
|
||||||
if (!state.getContents().containsKey(Backpropagation.LAST_DELTA)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!state.getTrainingType().equals(getClass().getSimpleName())) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
final double[] d = (double[]) state.get(Backpropagation.LAST_DELTA);
|
|
||||||
return d.length == ((ContainsFlat) getMethod()).getFlat().getWeights().length;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public TrainingContinuation pause() {
|
|
||||||
final TrainingContinuation result = new TrainingContinuation();
|
|
||||||
result.setTrainingType(this.getClass().getSimpleName());
|
|
||||||
result.set(Backpropagation.LAST_DELTA, this.lastDelta);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void resume(final TrainingContinuation state) {
|
|
||||||
if (!isValidResume(state)) {
|
|
||||||
throw new TrainingError("Invalid training resume data length");
|
|
||||||
}
|
|
||||||
|
|
||||||
this.lastDelta = ((double[]) state.get(Backpropagation.LAST_DELTA));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void setLearningRate(final double rate) {
|
|
||||||
this.learningRate = rate;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void setMomentum(final double m) {
|
|
||||||
this.momentum = m;
|
|
||||||
}
|
|
||||||
|
|
||||||
public double updateWeight(final double[] gradients,
|
|
||||||
final double[] lastGradient, final int index) {
|
|
||||||
|
|
||||||
final double delta = (gradients[index] * this.learningRate)
|
|
||||||
+ (this.lastDelta[index] * this.momentum);
|
|
||||||
|
|
||||||
this.lastDelta[index] = delta;
|
|
||||||
|
|
||||||
System.out.println("Updating weights for connection: " + index
|
|
||||||
+ " with lambda: " + lambda);
|
|
||||||
|
|
||||||
return delta;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void initOthers() {
|
|
||||||
}
|
|
||||||
|
|
||||||
// End methods from BackPropagation
|
|
||||||
|
|
||||||
// Methods from Propagation
|
|
||||||
public void finishTraining() {
|
|
||||||
basicFinishTraining();
|
|
||||||
}
|
|
||||||
|
|
||||||
public FlatNetwork getCurrentFlatNetwork() {
|
|
||||||
return this.currentFlatNetwork;
|
|
||||||
}
|
|
||||||
|
|
||||||
public MLMethod getMethod() {
|
|
||||||
return this.network;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void iteration() {
|
|
||||||
iteration(1);
|
|
||||||
}
|
|
||||||
|
|
||||||
public void rollIteration() {
|
|
||||||
this.iteration++;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void iteration(final int count) {
|
|
||||||
|
|
||||||
try {
|
|
||||||
for (int i = 0; i < count; i++) {
|
|
||||||
|
|
||||||
preIteration();
|
|
||||||
|
|
||||||
rollIteration();
|
|
||||||
|
|
||||||
calculateGradients();
|
|
||||||
|
|
||||||
if (this.currentFlatNetwork.isLimited()) {
|
|
||||||
learnLimited();
|
|
||||||
} else {
|
|
||||||
learn();
|
|
||||||
}
|
|
||||||
|
|
||||||
this.lastError = this.getError();
|
|
||||||
|
|
||||||
for (final GradientWorker worker : this.workers) {
|
|
||||||
EngineArray.arrayCopy(this.currentFlatNetwork.getWeights(),
|
|
||||||
0, worker.getWeights(), 0,
|
|
||||||
this.currentFlatNetwork.getWeights().length);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (this.currentFlatNetwork.getHasContext()) {
|
|
||||||
copyContexts();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (this.reportedException != null) {
|
|
||||||
throw (new EncogError(this.reportedException));
|
|
||||||
}
|
|
||||||
|
|
||||||
postIteration();
|
|
||||||
|
|
||||||
EncogLogging.log(EncogLogging.LEVEL_INFO,
|
|
||||||
"Training iteration done, error: " + getError());
|
|
||||||
|
|
||||||
}
|
|
||||||
} catch (final ArrayIndexOutOfBoundsException ex) {
|
|
||||||
EncogValidate.validateNetworkForTraining(this.network,
|
|
||||||
getTraining());
|
|
||||||
throw new EncogError(ex);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setThreadCount(final int numThreads) {
|
|
||||||
this.numThreads = numThreads;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int getThreadCount() {
|
|
||||||
return this.numThreads;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void fixFlatSpot(boolean b) {
|
|
||||||
this.shouldFixFlatSpot = b;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setErrorFunction(ErrorFunction ef) {
|
|
||||||
this.ef = ef;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void calculateGradients() {
|
|
||||||
if (this.workers == null) {
|
|
||||||
init();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (this.currentFlatNetwork.getHasContext()) {
|
|
||||||
this.workers[0].getNetwork().clearContext();
|
|
||||||
}
|
|
||||||
|
|
||||||
this.totalError = 0;
|
|
||||||
|
|
||||||
if (this.workers.length > 1) {
|
|
||||||
|
|
||||||
final TaskGroup group = EngineConcurrency.getInstance()
|
|
||||||
.createTaskGroup();
|
|
||||||
|
|
||||||
for (final GradientWorker worker : this.workers) {
|
|
||||||
EngineConcurrency.getInstance().processTask(worker, group);
|
|
||||||
}
|
|
||||||
|
|
||||||
group.waitForComplete();
|
|
||||||
} else {
|
|
||||||
this.workers[0].run();
|
|
||||||
}
|
|
||||||
|
|
||||||
this.setError(this.totalError / this.workers.length);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Copy the contexts to keep them consistent with multithreaded training.
|
|
||||||
*/
|
|
||||||
private void copyContexts() {
|
|
||||||
|
|
||||||
// copy the contexts(layer outputO from each group to the next group
|
|
||||||
for (int i = 0; i < (this.workers.length - 1); i++) {
|
|
||||||
final double[] src = this.workers[i].getNetwork().getLayerOutput();
|
|
||||||
final double[] dst = this.workers[i + 1].getNetwork()
|
|
||||||
.getLayerOutput();
|
|
||||||
EngineArray.arrayCopy(src, dst);
|
|
||||||
}
|
|
||||||
|
|
||||||
// copy the contexts from the final group to the real network
|
|
||||||
EngineArray.arrayCopy(this.workers[this.workers.length - 1]
|
|
||||||
.getNetwork().getLayerOutput(), this.currentFlatNetwork
|
|
||||||
.getLayerOutput());
|
|
||||||
}
|
|
||||||
|
|
||||||
private void init() {
|
|
||||||
// fix flat spot, if needed
|
|
||||||
this.flatSpot = new double[this.currentFlatNetwork
|
|
||||||
.getActivationFunctions().length];
|
|
||||||
|
|
||||||
if (this.shouldFixFlatSpot) {
|
|
||||||
for (int i = 0; i < this.currentFlatNetwork
|
|
||||||
.getActivationFunctions().length; i++) {
|
|
||||||
final ActivationFunction af = this.currentFlatNetwork
|
|
||||||
.getActivationFunctions()[i];
|
|
||||||
|
|
||||||
if (af instanceof ActivationSigmoid) {
|
|
||||||
this.flatSpot[i] = 0.1;
|
|
||||||
} else {
|
|
||||||
this.flatSpot[i] = 0.0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
EngineArray.fill(this.flatSpot, 0.0);
|
|
||||||
}
|
|
||||||
|
|
||||||
// setup workers
|
|
||||||
final DetermineWorkload determine = new DetermineWorkload(
|
|
||||||
this.numThreads, (int) this.indexable.size());
|
|
||||||
// this.numThreads, (int) this.indexable.getRecordCount());
|
|
||||||
|
|
||||||
this.workers = new GradientWorker[determine.getThreadCount()];
|
|
||||||
|
|
||||||
int index = 0;
|
|
||||||
|
|
||||||
// handle CPU
|
|
||||||
for (final IntRange r : determine.calculateWorkers()) {
|
|
||||||
this.workers[index++] = new GradientWorker(
|
|
||||||
this.currentFlatNetwork.clone(), this, new HashSet(
|
|
||||||
this.indexable), r.getLow(), r.getHigh(),
|
|
||||||
this.flatSpot, this.ef);
|
|
||||||
}
|
|
||||||
|
|
||||||
initOthers();
|
|
||||||
}
|
|
||||||
|
|
||||||
public void report(final double[] gradients, final double error,
|
|
||||||
final Throwable ex) {
|
|
||||||
synchronized (this) {
|
|
||||||
if (ex == null) {
|
|
||||||
|
|
||||||
for (int i = 0; i < gradients.length; i++) {
|
|
||||||
this.gradients[i] += gradients[i];
|
|
||||||
}
|
|
||||||
this.totalError += error;
|
|
||||||
} else {
|
|
||||||
this.reportedException = ex;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
protected void learn() {
|
|
||||||
final double[] weights = this.currentFlatNetwork.getWeights();
|
|
||||||
for (int i = 0; i < this.gradients.length; i++) {
|
|
||||||
weights[i] += updateWeight(this.gradients, this.lastGradient, i);
|
|
||||||
this.gradients[i] = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
protected void learnLimited() {
|
|
||||||
final double limit = this.currentFlatNetwork.getConnectionLimit();
|
|
||||||
final double[] weights = this.currentFlatNetwork.getWeights();
|
|
||||||
for (int i = 0; i < this.gradients.length; i++) {
|
|
||||||
if (Math.abs(weights[i]) < limit) {
|
|
||||||
weights[i] = 0;
|
|
||||||
} else {
|
|
||||||
weights[i] += updateWeight(this.gradients, this.lastGradient, i);
|
|
||||||
}
|
|
||||||
this.gradients[i] = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public double[] getLastGradient() {
|
|
||||||
return lastGradient;
|
|
||||||
}
|
|
||||||
|
|
||||||
// End methods from Propagation
|
|
||||||
|
|
||||||
// Methods from BasicTraining/
|
|
||||||
public void addStrategy(final Strategy strategy) {
|
|
||||||
strategy.init(this);
|
|
||||||
this.strategies.add(strategy);
|
|
||||||
}
|
|
||||||
|
|
||||||
public void basicFinishTraining() {
|
|
||||||
}
|
|
||||||
|
|
||||||
public double getError() {
|
|
||||||
return this.error;
|
|
||||||
}
|
|
||||||
|
|
||||||
public int getIteration() {
|
|
||||||
return this.iteration;
|
|
||||||
}
|
|
||||||
|
|
||||||
public List<Strategy> getStrategies() {
|
|
||||||
return this.strategies;
|
|
||||||
}
|
|
||||||
|
|
||||||
public MLDataSet getTraining() {
|
|
||||||
throw new UnsupportedOperationException(
|
|
||||||
"This learning method operates on Set<List<MLData>>, not MLDataSet");
|
|
||||||
}
|
|
||||||
|
|
||||||
public boolean isTrainingDone() {
|
|
||||||
for (Strategy strategy : this.strategies) {
|
|
||||||
if (strategy instanceof EndTrainingStrategy) {
|
|
||||||
EndTrainingStrategy end = (EndTrainingStrategy) strategy;
|
|
||||||
if (end.shouldStop()) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void postIteration() {
|
|
||||||
for (final Strategy strategy : this.strategies) {
|
|
||||||
strategy.postIteration();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public void preIteration() {
|
|
||||||
|
|
||||||
this.iteration++;
|
|
||||||
|
|
||||||
for (final Strategy strategy : this.strategies) {
|
|
||||||
strategy.preIteration();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setError(final double error) {
|
|
||||||
this.error = error;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setIteration(final int iteration) {
|
|
||||||
this.iteration = iteration;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setTraining(final Set<List<MLDataPair>> training) {
|
|
||||||
//this.training = training;
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
public TrainingImplementationType getImplementationType() {
|
|
||||||
return this.implementationType;
|
|
||||||
}
|
|
||||||
// End Methods from BasicTraining
|
|
||||||
}
|
|
||||||
@@ -3,16 +3,14 @@ package net.woodyfolsom.msproj.ann;
|
|||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|
||||||
import net.woodyfolsom.msproj.GameState;
|
|
||||||
import net.woodyfolsom.msproj.Player;
|
|
||||||
|
|
||||||
import org.encog.engine.network.activation.ActivationSigmoid;
|
import org.encog.engine.network.activation.ActivationSigmoid;
|
||||||
import org.encog.ml.data.MLData;
|
|
||||||
import org.encog.ml.data.MLDataPair;
|
import org.encog.ml.data.MLDataPair;
|
||||||
import org.encog.ml.data.MLDataSet;
|
import org.encog.ml.data.MLDataSet;
|
||||||
|
import org.encog.ml.data.basic.BasicMLDataSet;
|
||||||
import org.encog.ml.train.MLTrain;
|
import org.encog.ml.train.MLTrain;
|
||||||
import org.encog.neural.networks.BasicNetwork;
|
import org.encog.neural.networks.BasicNetwork;
|
||||||
import org.encog.neural.networks.layers.BasicLayer;
|
import org.encog.neural.networks.layers.BasicLayer;
|
||||||
|
import org.encog.neural.networks.training.propagation.back.Backpropagation;
|
||||||
|
|
||||||
public class WinFilter extends AbstractNeuralNetFilter implements
|
public class WinFilter extends AbstractNeuralNetFilter implements
|
||||||
NeuralNetFilter {
|
NeuralNetFilter {
|
||||||
@@ -29,55 +27,46 @@ public class WinFilter extends AbstractNeuralNetFilter implements
|
|||||||
this.neuralNetwork = network;
|
this.neuralNetwork = network;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public double computeValue(MLData input) {
|
|
||||||
if (input instanceof GameStateMLData) {
|
|
||||||
double[] idealVector = computeVector(input);
|
|
||||||
GameState gameState = ((GameStateMLData) input).getGameState();
|
|
||||||
Player playerToMove = gameState.getPlayerToMove();
|
|
||||||
if (playerToMove == Player.BLACK) {
|
|
||||||
return idealVector[0];
|
|
||||||
} else if (playerToMove == Player.WHITE) {
|
|
||||||
return idealVector[1];
|
|
||||||
} else {
|
|
||||||
throw new RuntimeException("Invalid GameState.playerToMove: "
|
|
||||||
+ playerToMove);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
throw new UnsupportedOperationException(
|
|
||||||
"This NeuralNetFilter only accepts GameStates as input.");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double[] computeVector(MLData input) {
|
|
||||||
if (input instanceof GameStateMLData) {
|
|
||||||
return neuralNetwork.compute(input).getData();
|
|
||||||
} else {
|
|
||||||
throw new UnsupportedOperationException(
|
|
||||||
"This NeuralNetFilter only accepts GameStates as input.");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void learn(MLDataSet trainingData) {
|
public void learn(MLDataSet trainingData) {
|
||||||
throw new UnsupportedOperationException("This filter learns a Set<List<MLDataPair>>, not an MLDataSet");
|
throw new UnsupportedOperationException(
|
||||||
|
"This filter learns a Set<List<MLDataPair>>, not an MLDataSet");
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Method is necessary because with temporal difference learning, some of the MLDataPairs are related by being a sequence
|
* Method is necessary because with temporal difference learning, some of
|
||||||
* of moves within a particular game.
|
* the MLDataPairs are related by being a sequence of moves within a
|
||||||
|
* particular game.
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public void learn(Set<List<MLDataPair>> trainingSet) {
|
public void learn(Set<List<MLDataPair>> trainingSet) {
|
||||||
|
MLDataSet mlDataset = new BasicMLDataSet();
|
||||||
|
|
||||||
|
for (List<MLDataPair> gameRecord : trainingSet) {
|
||||||
|
for (int t = 0; t < gameRecord.size() - 1; t++) {
|
||||||
|
mlDataset.add(gameRecord.get(t).getInput(), this.neuralNetwork.compute(gameRecord.get(t)
|
||||||
|
.getInput()));
|
||||||
|
}
|
||||||
|
mlDataset.add(gameRecord.get(gameRecord.size() - 1));
|
||||||
|
}
|
||||||
|
|
||||||
// train the neural network
|
// train the neural network
|
||||||
final MLTrain train = new TemporalDifferenceLearning(neuralNetwork,
|
final MLTrain train = new TemporalDifference(neuralNetwork, mlDataset, 0.7, 0.8, 0.25);
|
||||||
trainingSet, 0.7, 0.8, 0.25);
|
//final MLTrain train = new Backpropagation(neuralNetwork, mlDataset, 0.7, 0.8);
|
||||||
|
|
||||||
actualTrainingEpochs = 0;
|
actualTrainingEpochs = 0;
|
||||||
|
|
||||||
do {
|
do {
|
||||||
|
if (actualTrainingEpochs > 0) {
|
||||||
|
int gameStateIndex = 0;
|
||||||
|
for (List<MLDataPair> gameRecord : trainingSet) {
|
||||||
|
for (int t = 0; t < gameRecord.size() - 1; t++) {
|
||||||
|
MLDataPair oldDataPair = mlDataset.get(gameStateIndex);
|
||||||
|
this.neuralNetwork.compute(oldDataPair.getInput());
|
||||||
|
gameStateIndex++;
|
||||||
|
}
|
||||||
|
gameStateIndex++;
|
||||||
|
}
|
||||||
|
}
|
||||||
train.iteration();
|
train.iteration();
|
||||||
System.out.println("Epoch #" + actualTrainingEpochs + " Error:"
|
System.out.println("Epoch #" + actualTrainingEpochs + " Error:"
|
||||||
+ train.getError());
|
+ train.getError());
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import org.encog.engine.network.activation.ActivationSigmoid;
|
|||||||
import org.encog.ml.data.MLData;
|
import org.encog.ml.data.MLData;
|
||||||
import org.encog.ml.data.MLDataPair;
|
import org.encog.ml.data.MLDataPair;
|
||||||
import org.encog.ml.data.MLDataSet;
|
import org.encog.ml.data.MLDataSet;
|
||||||
import org.encog.ml.data.basic.BasicMLDataSet;
|
import org.encog.ml.data.basic.BasicMLData;
|
||||||
import org.encog.ml.train.MLTrain;
|
import org.encog.ml.train.MLTrain;
|
||||||
import org.encog.neural.networks.BasicNetwork;
|
import org.encog.neural.networks.BasicNetwork;
|
||||||
import org.encog.neural.networks.layers.BasicLayer;
|
import org.encog.neural.networks.layers.BasicLayer;
|
||||||
@@ -21,7 +21,7 @@ import org.encog.neural.networks.training.propagation.back.Backpropagation;
|
|||||||
*/
|
*/
|
||||||
public class XORFilter extends AbstractNeuralNetFilter implements
|
public class XORFilter extends AbstractNeuralNetFilter implements
|
||||||
NeuralNetFilter {
|
NeuralNetFilter {
|
||||||
|
|
||||||
public XORFilter() {
|
public XORFilter() {
|
||||||
// create a neural network, without using a factory
|
// create a neural network, without using a factory
|
||||||
BasicNetwork network = new BasicNetwork();
|
BasicNetwork network = new BasicNetwork();
|
||||||
@@ -34,32 +34,10 @@ public class XORFilter extends AbstractNeuralNetFilter implements
|
|||||||
this.neuralNetwork = network;
|
this.neuralNetwork = network;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
public double compute(double x, double y) {
|
||||||
public void learn(MLDataSet trainingSet) {
|
return compute(new BasicMLData(new double[]{x,y})).getData(0);
|
||||||
|
|
||||||
// train the neural network
|
|
||||||
final MLTrain train = new Backpropagation(neuralNetwork,
|
|
||||||
trainingSet, 0.7, 0.8);
|
|
||||||
|
|
||||||
actualTrainingEpochs = 0;
|
|
||||||
|
|
||||||
do {
|
|
||||||
train.iteration();
|
|
||||||
System.out.println("Epoch #" + actualTrainingEpochs + " Error:"
|
|
||||||
+ train.getError());
|
|
||||||
actualTrainingEpochs++;
|
|
||||||
} while (train.getError() > 0.01
|
|
||||||
&& actualTrainingEpochs <= maxTrainingEpochs);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public double[] computeVector(MLData mlData) {
|
|
||||||
MLDataSet dataset = new BasicMLDataSet(new double[][] { mlData.getData() },
|
|
||||||
new double[][] { new double[getOutputSize()] });
|
|
||||||
MLData output = neuralNetwork.compute(dataset.get(0).getInput());
|
|
||||||
return output.getData();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int getInputSize() {
|
public int getInputSize() {
|
||||||
return 2;
|
return 2;
|
||||||
@@ -72,12 +50,26 @@ public class XORFilter extends AbstractNeuralNetFilter implements
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public double computeValue(MLData input) {
|
public void learn(MLDataSet trainingSet) {
|
||||||
return computeVector(input)[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
|
// train the neural network
|
||||||
|
final MLTrain train = new Backpropagation(neuralNetwork, trainingSet,
|
||||||
|
0.7, 0.8);
|
||||||
|
|
||||||
|
actualTrainingEpochs = 0;
|
||||||
|
|
||||||
|
do {
|
||||||
|
train.iteration();
|
||||||
|
System.out.println("Epoch #" + actualTrainingEpochs + " Error:"
|
||||||
|
+ train.getError());
|
||||||
|
actualTrainingEpochs++;
|
||||||
|
} while (train.getError() > 0.01
|
||||||
|
&& actualTrainingEpochs <= maxTrainingEpochs);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void learn(Set<List<MLDataPair>> trainingSet) {
|
public void learn(Set<List<MLDataPair>> trainingSet) {
|
||||||
throw new UnsupportedOperationException("This Filter learns an MLDataSet, not a Set<List<MLData>>.");
|
throw new UnsupportedOperationException(
|
||||||
|
"This Filter learns an MLDataSet, not a Set<List<MLData>>.");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
5
src/net/woodyfolsom/msproj/ann2/ActivationFunction.java
Normal file
5
src/net/woodyfolsom/msproj/ann2/ActivationFunction.java
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann2;
|
||||||
|
|
||||||
|
public interface ActivationFunction {
|
||||||
|
double calculate(double arg);
|
||||||
|
}
|
||||||
53
src/net/woodyfolsom/msproj/ann2/Layer.java
Normal file
53
src/net/woodyfolsom/msproj/ann2/Layer.java
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann2;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
|
||||||
|
public class Layer {
|
||||||
|
private Neuron[] neurons;
|
||||||
|
|
||||||
|
public Layer() {
|
||||||
|
//default constructor for JAXB
|
||||||
|
}
|
||||||
|
|
||||||
|
public Layer(int numNeurons, int numWeights, ActivationFunction activationFunction) {
|
||||||
|
neurons = new Neuron[numNeurons];
|
||||||
|
for (int neuronIndex = 0; neuronIndex < numNeurons; neuronIndex++) {
|
||||||
|
neurons[neuronIndex] = new Neuron(activationFunction, numWeights);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public int size() {
|
||||||
|
return neurons.length;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
final int prime = 31;
|
||||||
|
int result = 1;
|
||||||
|
result = prime * result + Arrays.hashCode(neurons);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object obj) {
|
||||||
|
if (this == obj)
|
||||||
|
return true;
|
||||||
|
if (obj == null)
|
||||||
|
return false;
|
||||||
|
if (getClass() != obj.getClass())
|
||||||
|
return false;
|
||||||
|
Layer other = (Layer) obj;
|
||||||
|
if (!Arrays.equals(neurons, other.neurons))
|
||||||
|
return false;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Neuron[] getNeurons() {
|
||||||
|
return neurons;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setNeurons(Neuron[] neurons) {
|
||||||
|
this.neurons = neurons;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
175
src/net/woodyfolsom/msproj/ann2/MultiLayerPerceptron.java
Normal file
175
src/net/woodyfolsom/msproj/ann2/MultiLayerPerceptron.java
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann2;
|
||||||
|
|
||||||
|
import java.io.InputStream;
|
||||||
|
import java.io.OutputStream;
|
||||||
|
import java.util.Arrays;
|
||||||
|
|
||||||
|
import javax.xml.bind.JAXBContext;
|
||||||
|
import javax.xml.bind.JAXBException;
|
||||||
|
import javax.xml.bind.Marshaller;
|
||||||
|
import javax.xml.bind.Unmarshaller;
|
||||||
|
import javax.xml.bind.annotation.XmlAttribute;
|
||||||
|
import javax.xml.bind.annotation.XmlElement;
|
||||||
|
import javax.xml.bind.annotation.XmlRootElement;
|
||||||
|
|
||||||
|
@XmlRootElement
|
||||||
|
public class MultiLayerPerceptron extends NeuralNetwork {
|
||||||
|
private ActivationFunction activationFunction;
|
||||||
|
private boolean biased;
|
||||||
|
private Layer[] layers;
|
||||||
|
|
||||||
|
public MultiLayerPerceptron() {
|
||||||
|
this(false, 1, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
public MultiLayerPerceptron(boolean biased, int... layerSizes) {
|
||||||
|
int numLayers = layerSizes.length;
|
||||||
|
|
||||||
|
if (numLayers < 2) {
|
||||||
|
throw new IllegalArgumentException("# of layers must be >= 2");
|
||||||
|
}
|
||||||
|
|
||||||
|
this.activationFunction = Sigmoid.function;
|
||||||
|
this.biased = biased;
|
||||||
|
this.layers = new Layer[numLayers];
|
||||||
|
|
||||||
|
int numWeights;
|
||||||
|
|
||||||
|
for (int layerIndex = 0; layerIndex < numLayers; layerIndex++) {
|
||||||
|
int layerSize = layerSizes[layerIndex];
|
||||||
|
|
||||||
|
if (layerSize < 1) {
|
||||||
|
throw new IllegalArgumentException("Layer size must be >= 1");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (layerIndex == 0) {
|
||||||
|
numWeights = 0;
|
||||||
|
if (biased) {
|
||||||
|
layerSize++;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
numWeights = layers[layerIndex - 1].size();
|
||||||
|
}
|
||||||
|
|
||||||
|
layers[layerIndex] = new Layer(layerSize, numWeights,
|
||||||
|
activationFunction);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@XmlElement(type=Sigmoid.class)
|
||||||
|
public ActivationFunction getActivationFunction() {
|
||||||
|
return activationFunction;
|
||||||
|
}
|
||||||
|
|
||||||
|
@XmlElement
|
||||||
|
public Layer[] getLayers() {
|
||||||
|
return layers;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected double[] getOutput() {
|
||||||
|
// TODO Auto-generated method stub
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Neuron[] getNeurons() {
|
||||||
|
// TODO Auto-generated method stub
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@XmlAttribute
|
||||||
|
public boolean isBiased() {
|
||||||
|
return biased;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setActivationFunction(ActivationFunction activationFunction) {
|
||||||
|
this.activationFunction = activationFunction;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void setInput(double[] input) {
|
||||||
|
// TODO Auto-generated method stub
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setBiased(boolean biased) {
|
||||||
|
this.biased = biased;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setLayers(Layer[] layers) {
|
||||||
|
this.layers = layers;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean load(InputStream is) {
|
||||||
|
try {
|
||||||
|
JAXBContext jc = JAXBContext
|
||||||
|
.newInstance(MultiLayerPerceptron.class);
|
||||||
|
|
||||||
|
// unmarshal from foo.xml
|
||||||
|
Unmarshaller u = jc.createUnmarshaller();
|
||||||
|
MultiLayerPerceptron mlp = (MultiLayerPerceptron) u.unmarshal(is);
|
||||||
|
|
||||||
|
this.activationFunction = mlp.activationFunction;
|
||||||
|
this.biased = mlp.biased;
|
||||||
|
this.layers = mlp.layers;
|
||||||
|
|
||||||
|
return true;
|
||||||
|
} catch (JAXBException je) {
|
||||||
|
je.printStackTrace();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean save(OutputStream os) {
|
||||||
|
try {
|
||||||
|
JAXBContext jc = JAXBContext
|
||||||
|
.newInstance(MultiLayerPerceptron.class);
|
||||||
|
|
||||||
|
Marshaller m = jc.createMarshaller();
|
||||||
|
m.setProperty(Marshaller.JAXB_FORMATTED_OUTPUT, true);
|
||||||
|
m.marshal(this, os);
|
||||||
|
m.marshal(this, System.out);
|
||||||
|
return true;
|
||||||
|
} catch (JAXBException je) {
|
||||||
|
je.printStackTrace();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
final int prime = 31;
|
||||||
|
int result = 1;
|
||||||
|
result = prime
|
||||||
|
* result
|
||||||
|
+ ((activationFunction == null) ? 0 : activationFunction
|
||||||
|
.hashCode());
|
||||||
|
result = prime * result + (biased ? 1231 : 1237);
|
||||||
|
result = prime * result + Arrays.hashCode(layers);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object obj) {
|
||||||
|
if (this == obj)
|
||||||
|
return true;
|
||||||
|
if (obj == null)
|
||||||
|
return false;
|
||||||
|
if (getClass() != obj.getClass())
|
||||||
|
return false;
|
||||||
|
MultiLayerPerceptron other = (MultiLayerPerceptron) obj;
|
||||||
|
if (activationFunction == null) {
|
||||||
|
if (other.activationFunction != null)
|
||||||
|
return false;
|
||||||
|
} else if (!activationFunction.equals(other.activationFunction))
|
||||||
|
return false;
|
||||||
|
if (biased != other.biased)
|
||||||
|
return false;
|
||||||
|
if (!Arrays.equals(layers, other.layers))
|
||||||
|
return false;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
29
src/net/woodyfolsom/msproj/ann2/NNData.java
Normal file
29
src/net/woodyfolsom/msproj/ann2/NNData.java
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann2;
|
||||||
|
|
||||||
|
public class NNData {
|
||||||
|
private final double[] values;
|
||||||
|
private final String[] fields;
|
||||||
|
|
||||||
|
public NNData(String[] fields, double[] values) {
|
||||||
|
this.fields = fields;
|
||||||
|
this.values = values;
|
||||||
|
}
|
||||||
|
|
||||||
|
public double[] getValues() {
|
||||||
|
return values;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
StringBuilder sb = new StringBuilder("[");
|
||||||
|
|
||||||
|
for (int i = 0; i < fields.length; i++) {
|
||||||
|
if (i > 0) {
|
||||||
|
sb.append(", " );
|
||||||
|
}
|
||||||
|
sb.append(fields[i] + "=" + values[i]);
|
||||||
|
}
|
||||||
|
sb.append("]");
|
||||||
|
return sb.toString();
|
||||||
|
}
|
||||||
|
}
|
||||||
19
src/net/woodyfolsom/msproj/ann2/NNDataPair.java
Normal file
19
src/net/woodyfolsom/msproj/ann2/NNDataPair.java
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann2;
|
||||||
|
|
||||||
|
public class NNDataPair {
|
||||||
|
private final NNData actual;
|
||||||
|
private final NNData ideal;
|
||||||
|
|
||||||
|
public NNDataPair(NNData actual, NNData ideal) {
|
||||||
|
this.actual = actual;
|
||||||
|
this.ideal = ideal;
|
||||||
|
}
|
||||||
|
|
||||||
|
public NNData getActual() {
|
||||||
|
return actual;
|
||||||
|
}
|
||||||
|
|
||||||
|
public NNData getIdeal() {
|
||||||
|
return ideal;
|
||||||
|
}
|
||||||
|
}
|
||||||
53
src/net/woodyfolsom/msproj/ann2/NeuralNetwork.java
Normal file
53
src/net/woodyfolsom/msproj/ann2/NeuralNetwork.java
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann2;
|
||||||
|
|
||||||
|
import java.io.InputStream;
|
||||||
|
import java.io.OutputStream;
|
||||||
|
|
||||||
|
import javax.xml.bind.JAXBException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A NeuralNetwork is simply an ordered set of Neurons.
|
||||||
|
*
|
||||||
|
* Functions which rely on knowledge of input neurons, output neurons and layers
|
||||||
|
* are delegated to MultiLayerPerception.
|
||||||
|
*
|
||||||
|
* The primary function implemented in this abstract class is feedfoward.
|
||||||
|
* This function depends only on getNeurons() returning Neurons in feedforward order
|
||||||
|
* and the returned Neurons must have the correct number of weights for the NeuralNetwork
|
||||||
|
* configuration.
|
||||||
|
*
|
||||||
|
* @author Woody
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
public abstract class NeuralNetwork {
|
||||||
|
public NeuralNetwork() {
|
||||||
|
}
|
||||||
|
|
||||||
|
public double[] calculate(double[] input) {
|
||||||
|
zeroInputs();
|
||||||
|
setInput(input);
|
||||||
|
feedforward();
|
||||||
|
return getOutput();
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void feedforward() {
|
||||||
|
Neuron[] neurons = getNeurons();
|
||||||
|
}
|
||||||
|
|
||||||
|
protected abstract double[] getOutput();
|
||||||
|
|
||||||
|
protected abstract Neuron[] getNeurons();
|
||||||
|
|
||||||
|
public abstract boolean load(InputStream is);
|
||||||
|
public abstract boolean save(OutputStream os);
|
||||||
|
|
||||||
|
protected abstract void setInput(double[] input);
|
||||||
|
|
||||||
|
protected void zeroInputs() {
|
||||||
|
for (Neuron neuron : getNeurons()) {
|
||||||
|
neuron.setInput(0.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
92
src/net/woodyfolsom/msproj/ann2/Neuron.java
Normal file
92
src/net/woodyfolsom/msproj/ann2/Neuron.java
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann2;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
|
||||||
|
import javax.xml.bind.Unmarshaller;
|
||||||
|
import javax.xml.bind.annotation.XmlElement;
|
||||||
|
import javax.xml.bind.annotation.XmlTransient;
|
||||||
|
|
||||||
|
public class Neuron {
|
||||||
|
private ActivationFunction activationFunction;
|
||||||
|
private double[] weights;
|
||||||
|
|
||||||
|
private transient double input = 0.0;
|
||||||
|
|
||||||
|
public Neuron() {
|
||||||
|
//no-arg constructor for JAXB
|
||||||
|
}
|
||||||
|
|
||||||
|
public Neuron(ActivationFunction activationFunction, int numWeights) {
|
||||||
|
this.activationFunction = activationFunction;
|
||||||
|
this.weights = new double[numWeights];
|
||||||
|
}
|
||||||
|
|
||||||
|
@XmlElement(type=Sigmoid.class)
|
||||||
|
public ActivationFunction getActivationFunction() {
|
||||||
|
return activationFunction;
|
||||||
|
}
|
||||||
|
|
||||||
|
void afterUnmarshal(Unmarshaller aUnmarshaller, Object aParent)
|
||||||
|
{
|
||||||
|
if (weights == null) {
|
||||||
|
weights = new double[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@XmlTransient
|
||||||
|
public double getInput() {
|
||||||
|
return input;
|
||||||
|
}
|
||||||
|
|
||||||
|
public double getOutput() {
|
||||||
|
return activationFunction.calculate(input);
|
||||||
|
}
|
||||||
|
|
||||||
|
@XmlElement
|
||||||
|
public double[] getWeights() {
|
||||||
|
return weights;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setInput(double input) {
|
||||||
|
this.input = input;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
final int prime = 31;
|
||||||
|
int result = 1;
|
||||||
|
result = prime
|
||||||
|
* result
|
||||||
|
+ ((activationFunction == null) ? 0 : activationFunction
|
||||||
|
.hashCode());
|
||||||
|
result = prime * result + Arrays.hashCode(weights);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object obj) {
|
||||||
|
if (this == obj)
|
||||||
|
return true;
|
||||||
|
if (obj == null)
|
||||||
|
return false;
|
||||||
|
if (getClass() != obj.getClass())
|
||||||
|
return false;
|
||||||
|
Neuron other = (Neuron) obj;
|
||||||
|
if (activationFunction == null) {
|
||||||
|
if (other.activationFunction != null)
|
||||||
|
return false;
|
||||||
|
} else if (!activationFunction.equals(other.activationFunction))
|
||||||
|
return false;
|
||||||
|
if (!Arrays.equals(weights, other.weights))
|
||||||
|
return false;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setActivationFunction(ActivationFunction activationFunction) {
|
||||||
|
this.activationFunction = activationFunction;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setWeights(double[] weights) {
|
||||||
|
this.weights = weights;
|
||||||
|
}
|
||||||
|
}
|
||||||
5
src/net/woodyfolsom/msproj/ann2/ObjectiveFunction.java
Normal file
5
src/net/woodyfolsom/msproj/ann2/ObjectiveFunction.java
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann2;
|
||||||
|
|
||||||
|
public class ObjectiveFunction {
|
||||||
|
|
||||||
|
}
|
||||||
48
src/net/woodyfolsom/msproj/ann2/Sigmoid.java
Normal file
48
src/net/woodyfolsom/msproj/ann2/Sigmoid.java
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann2;
|
||||||
|
|
||||||
|
public class Sigmoid implements ActivationFunction{
|
||||||
|
public static final Sigmoid function = new Sigmoid();
|
||||||
|
private String name;
|
||||||
|
|
||||||
|
private Sigmoid() {
|
||||||
|
this.name = "Sigmoid";
|
||||||
|
}
|
||||||
|
|
||||||
|
public double calculate(double arg) {
|
||||||
|
return 1.0 / (1 + Math.pow(Math.E, -1.0 * arg));
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getName() {
|
||||||
|
return name;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setName(String name) {
|
||||||
|
this.name = name;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
final int prime = 31;
|
||||||
|
int result = 1;
|
||||||
|
result = prime * result + ((name == null) ? 0 : name.hashCode());
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object obj) {
|
||||||
|
if (this == obj)
|
||||||
|
return true;
|
||||||
|
if (obj == null)
|
||||||
|
return false;
|
||||||
|
if (getClass() != obj.getClass())
|
||||||
|
return false;
|
||||||
|
Sigmoid other = (Sigmoid) obj;
|
||||||
|
if (name == null) {
|
||||||
|
if (other.name != null)
|
||||||
|
return false;
|
||||||
|
} else if (!name.equals(other.name))
|
||||||
|
return false;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
10
src/net/woodyfolsom/msproj/ann2/Tanh.java
Normal file
10
src/net/woodyfolsom/msproj/ann2/Tanh.java
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann2;
|
||||||
|
|
||||||
|
public class Tanh implements ActivationFunction{
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double calculate(double arg) {
|
||||||
|
return Math.tanh(arg);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -50,7 +50,6 @@ public class WinFilterTest {
|
|||||||
winFilter.learn(trainingData);
|
winFilter.learn(trainingData);
|
||||||
|
|
||||||
for (List<MLDataPair> trainingSequence : trainingData) {
|
for (List<MLDataPair> trainingSequence : trainingData) {
|
||||||
//for (MLDataPair mlDataPair : trainingSequence) {
|
|
||||||
for (int stateIndex = 0; stateIndex < trainingSequence.size(); stateIndex++) {
|
for (int stateIndex = 0; stateIndex < trainingSequence.size(); stateIndex++) {
|
||||||
if (stateIndex > 0 && stateIndex < trainingSequence.size()-1) {
|
if (stateIndex > 0 && stateIndex < trainingSequence.size()-1) {
|
||||||
continue;
|
continue;
|
||||||
@@ -58,9 +57,8 @@ public class WinFilterTest {
|
|||||||
MLData input = trainingSequence.get(stateIndex).getInput();
|
MLData input = trainingSequence.get(stateIndex).getInput();
|
||||||
|
|
||||||
System.out.println("Turn " + stateIndex + ": " + input + " => "
|
System.out.println("Turn " + stateIndex + ": " + input + " => "
|
||||||
+ winFilter.computeValue(input));
|
+ winFilter.compute(input));
|
||||||
}
|
}
|
||||||
//}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -73,7 +73,7 @@ public class XORFilterTest {
|
|||||||
private void testNetwork(NeuralNetFilter nnLearner, double[][] validationSet) {
|
private void testNetwork(NeuralNetFilter nnLearner, double[][] validationSet) {
|
||||||
for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
|
for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
|
||||||
DoublePair dp = new DoublePair(validationSet[valIndex][0],validationSet[valIndex][1]);
|
DoublePair dp = new DoublePair(validationSet[valIndex][0],validationSet[valIndex][1]);
|
||||||
System.out.println(dp + " => " + nnLearner.computeValue(dp));
|
System.out.println(dp + " => " + nnLearner.compute(dp));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -0,0 +1,62 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann2;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.FileInputStream;
|
||||||
|
import java.io.FileOutputStream;
|
||||||
|
import java.io.IOException;
|
||||||
|
|
||||||
|
import javax.xml.bind.JAXBException;
|
||||||
|
|
||||||
|
import org.junit.AfterClass;
|
||||||
|
import org.junit.BeforeClass;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
public class MultiLayerPerceptronTest {
|
||||||
|
static final File TEST_FILE = new File("data/test/mlp.net");
|
||||||
|
|
||||||
|
@BeforeClass
|
||||||
|
public static void setUp() {
|
||||||
|
if (TEST_FILE.exists()) {
|
||||||
|
TEST_FILE.delete();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@AfterClass
|
||||||
|
public static void tearDown() {
|
||||||
|
if (TEST_FILE.exists()) {
|
||||||
|
TEST_FILE.delete();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testConstructor() {
|
||||||
|
new MultiLayerPerceptron(true, 2, 4, 1);
|
||||||
|
new MultiLayerPerceptron(false, 2, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test(expected = IllegalArgumentException.class)
|
||||||
|
public void testConstructorTooFewLayers() {
|
||||||
|
new MultiLayerPerceptron(true, 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test(expected = IllegalArgumentException.class)
|
||||||
|
public void testConstructorTooFewNeurons() {
|
||||||
|
new MultiLayerPerceptron(true, 2, 4, 0, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testPersistence() throws JAXBException, IOException {
|
||||||
|
NeuralNetwork mlp = new MultiLayerPerceptron(true, 2, 4, 1);
|
||||||
|
FileOutputStream fos = new FileOutputStream(TEST_FILE);
|
||||||
|
assertTrue(mlp.save(fos));
|
||||||
|
fos.close();
|
||||||
|
FileInputStream fis = new FileInputStream(TEST_FILE);
|
||||||
|
NeuralNetwork mlp2 = new MultiLayerPerceptron();
|
||||||
|
assertTrue(mlp2.load(fis));
|
||||||
|
assertEquals(mlp, mlp2);
|
||||||
|
fis.close();
|
||||||
|
}
|
||||||
|
}
|
||||||
18
test/net/woodyfolsom/msproj/ann2/SigmoidTest.java
Normal file
18
test/net/woodyfolsom/msproj/ann2/SigmoidTest.java
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann2;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
public class SigmoidTest {
|
||||||
|
@Test
|
||||||
|
public void testCalculate() {
|
||||||
|
double EPS = 0.001;
|
||||||
|
|
||||||
|
ActivationFunction sigmoid = Sigmoid.function;
|
||||||
|
assertEquals(0.5,sigmoid.calculate(0.0),EPS);
|
||||||
|
assertTrue(sigmoid.calculate(100.0) > 1.0 - EPS);
|
||||||
|
assertTrue(sigmoid.calculate(-9000.0) < EPS);
|
||||||
|
}
|
||||||
|
}
|
||||||
18
test/net/woodyfolsom/msproj/ann2/TanhTest.java
Normal file
18
test/net/woodyfolsom/msproj/ann2/TanhTest.java
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
package net.woodyfolsom.msproj.ann2;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
public class TanhTest {
|
||||||
|
@Test
|
||||||
|
public void testCalculate() {
|
||||||
|
double EPS = 0.001;
|
||||||
|
|
||||||
|
ActivationFunction sigmoid = new Tanh();
|
||||||
|
assertEquals(0.0,sigmoid.calculate(0.0),EPS);
|
||||||
|
assertTrue(sigmoid.calculate(100.0) > 0.5 - EPS);
|
||||||
|
assertTrue(sigmoid.calculate(-9000.0) < -0.5+EPS);
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user