Incremental update: remove the experimental FusekiLearner, JosekiLearner, and ShapeLearner stub classes, rework GradientWorker and TemporalDifferenceLearning to train on whole game sequences, and adjust the game configuration.
data/gogame.cfg
@@ -1,10 +1,10 @@
-PlayerOne=RANDOM
+PlayerOne=ROOT_PAR
 PlayerTwo=RANDOM
 GUIDelay=1000 //1 second
 BoardSize=9
 Komi=6.5
-NumGames=1000 //Games for each color per player
-TurnTime=1000 //seconds per player per turn
-SpectatorBoardShown=false;
-WhiteMoveLogged=false;
-BlackMoveLogged=false;
+NumGames=1 //Games for each color per player
+TurnTime=2000 //seconds per player per turn
+SpectatorBoardShown=true
+WhiteMoveLogged=false
+BlackMoveLogged=false
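Because the values carry trailing //-comments, this file is not directly loadable with java.util.Properties (which recognizes # and ! comments). GameSettings.createGameSetings is not shown in this diff, so the reader below is a hypothetical stand-in sketching how this key=value format could be parsed:

    import java.io.BufferedReader;
    import java.io.FileReader;
    import java.io.IOException;
    import java.util.HashMap;
    import java.util.Map;

    // Hypothetical reader for the key=value format above; not the project's
    // actual GameSettings implementation (which this diff does not show).
    public class ConfigSketch {
        static Map<String, String> loadConfig(String path) throws IOException {
            Map<String, String> settings = new HashMap<String, String>();
            BufferedReader reader = new BufferedReader(new FileReader(path));
            try {
                String line;
                while ((line = reader.readLine()) != null) {
                    int comment = line.indexOf("//"); // strip "//1 second" etc.
                    if (comment >= 0) {
                        line = line.substring(0, comment);
                    }
                    int eq = line.indexOf('=');
                    if (eq > 0) {
                        settings.put(line.substring(0, eq).trim(),
                                line.substring(eq + 1).trim());
                    }
                }
            } finally {
                reader.close();
            }
            return settings;
        }
    }

A caller would then convert types as needed, e.g. Double.parseDouble(settings.get("Komi")).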
StandAloneGame.java
@@ -29,7 +29,7 @@ public class StandAloneGame {
 		HUMAN, HUMAN_GUI, ROOT_PAR, UCT, RANDOM, RAVE
 	};

-	public static void main(String[] args) {
+	public static void main(String[] args) throws IOException {
 		try {
 			GameSettings gameSettings = GameSettings
 					.createGameSetings("data/gogame.cfg");
@@ -42,6 +42,8 @@ public class StandAloneGame {
 				gameSettings.getNumGames(), gameSettings.getTurnTime(),
 				gameSettings.isSpectatorBoardShown(),
 				gameSettings.isBlackMoveLogged(), gameSettings.isWhiteMoveLogged());
+			System.out.println("Press <Enter> or CTRL-C to exit");
+			System.in.read(new byte[80]);
 		} catch (IOException ioe) {
 			ioe.printStackTrace();
 			System.exit(EXIT_IO_EXCEPTION);
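The two added lines keep the JVM from exiting while games run, presumably on background threads: main blocks on standard input until the user presses <Enter> or kills the process. In isolation the idiom looks like this (a minimal standalone sketch, not the project's class):

    import java.io.IOException;

    // Minimal sketch of the keep-alive idiom added above.
    public class KeepAlive {
        public static void main(String[] args) throws IOException {
            System.out.println("Press <Enter> or CTRL-C to exit");
            // Blocks until input is available (up to 80 bytes) or stdin hits EOF.
            System.in.read(new byte[80]);
        }
    }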
DoublePair.java
@@ -3,12 +3,7 @@ package net.woodyfolsom.msproj.ann;
 import org.encog.ml.data.basic.BasicMLData;

 public class DoublePair extends BasicMLData {
-	// private final double x;
-	// private final double y;

-	/**
-	 *
-	 */
 	private static final long serialVersionUID = 1L;

 	public DoublePair(double x, double y) {
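The constructor body lies outside this hunk; presumably it forwards the pair to the BasicMLData superclass. A hedged reconstruction of the whole class (the super call is an assumption, not shown in the diff):

    import org.encog.ml.data.basic.BasicMLData;

    // Hedged reconstruction of DoublePair: wrap an (x, y) pair as a
    // two-element Encog data vector.
    public class DoublePair extends BasicMLData {
        private static final long serialVersionUID = 1L;

        public DoublePair(double x, double y) {
            super(new double[] { x, y }); // assumed implementation
        }
    }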
FusekiLearner.java (deleted)
@@ -1,5 +0,0 @@
-package net.woodyfolsom.msproj.ann;
-
-public class FusekiLearner {
-
-}
GradientWorker.java
@@ -1,4 +1,5 @@
 package net.woodyfolsom.msproj.ann;
+
 /*
  * Class copied verbatim from Encog framework due to dependency on Propagation
  * implementation.
@@ -6,7 +7,7 @@ package net.woodyfolsom.msproj.ann;
  * Encog(tm) Core v3.2 - Java Version
  * http://www.heatonresearch.com/encog/
  * http://code.google.com/p/encog-java/

  * Copyright 2008-2012 Heaton Research, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -26,12 +27,12 @@ package net.woodyfolsom.msproj.ann;
  * http://www.heatonresearch.com/copyright
  */

+import java.util.ArrayList;
 import java.util.List;
 import java.util.Set;

 import org.encog.engine.network.activation.ActivationFunction;
 import org.encog.ml.data.MLDataPair;
-import org.encog.ml.data.MLDataSet;
 import org.encog.ml.data.basic.BasicMLDataPair;
 import org.encog.neural.error.ErrorFunction;
 import org.encog.neural.flat.FlatNetwork;
@@ -42,7 +43,7 @@ public class GradientWorker implements EngineTask {

 	private final FlatNetwork network;
 	private final ErrorCalculation errorCalculation = new ErrorCalculation();
-	private final double[] actual;
+	private final List<double[]> actuals;
 	private final double[] layerDelta;
 	private final int[] layerCounts;
 	private final int[] layerFeedCounts;
@@ -52,30 +53,29 @@ public class GradientWorker implements EngineTask {
 	private final double[] layerSums;
 	private final double[] gradients;
 	private final double[] weights;
-	private final MLDataPair pair;
+	private final MLDataPair pairPrototype;
 	private final Set<List<MLDataPair>> training;
-	private final int low;
-	private final int high;
+	//private final int low;
+	//private final int high;
 	private final TemporalDifferenceLearning owner;
 	private double[] flatSpot;
 	private final ErrorFunction errorFunction;

 	public GradientWorker(final FlatNetwork theNetwork,
 			final TemporalDifferenceLearning theOwner,
 			final Set<List<MLDataPair>> theTraining, final int theLow,
-			final int theHigh, final double[] flatSpot,
-			ErrorFunction ef) {
+			final int theHigh, final double[] flatSpot, ErrorFunction ef) {
 		this.network = theNetwork;
 		this.training = theTraining;
-		this.low = theLow;
-		this.high = theHigh;
+		//this.low = theLow;
+		//this.high = theHigh;
 		this.owner = theOwner;
 		this.flatSpot = flatSpot;
 		this.errorFunction = ef;

 		this.layerDelta = new double[network.getLayerOutput().length];
 		this.gradients = new double[network.getWeights().length];
-		this.actual = new double[network.getOutputCount()];
+		this.actuals = new ArrayList<double[]>();

 		this.weights = network.getWeights();
 		this.layerIndex = network.getLayerIndex();
@@ -85,8 +85,8 @@ public class GradientWorker implements EngineTask {
 		this.layerSums = network.getLayerSums();
 		this.layerFeedCounts = network.getLayerFeedCounts();

-		this.pair = BasicMLDataPair.createPair(network.getInputCount(), network
-				.getOutputCount());
+		this.pairPrototype = BasicMLDataPair.createPair(
+				network.getInputCount(), network.getOutputCount());
 	}

 	public FlatNetwork getNetwork() {
@@ -97,22 +97,48 @@ public class GradientWorker implements EngineTask {
 		return this.weights;
 	}

-	private void process(final double[] input, final double[] ideal, double s) {
-		this.network.compute(input, this.actual);
-
-		this.errorCalculation.updateError(this.actual, ideal, s);
-		this.errorFunction.calculateError(ideal, actual, this.layerDelta);
-
-		for (int i = 0; i < this.actual.length; i++) {
-			this.layerDelta[i] = ((this.network.getActivationFunctions()[0]
-					.derivativeFunction(this.layerSums[i],this.layerOutput[i]) + this.flatSpot[0]))
-					* (this.layerDelta[i] * s);
-		}
-
-		for (int i = this.network.getBeginTraining(); i < this.network
-				.getEndTraining(); i++) {
-			processLevel(i);
+	private void process(List<MLDataPair> trainingSequence) {
+		actuals.clear();
+
+		for (int trainingIdx = 0; trainingIdx < trainingSequence.size(); trainingIdx++) {
+			MLDataPair mlDataPair = trainingSequence.get(trainingIdx);
+			MLDataPair dataPairCopy = this.pairPrototype;
+			dataPairCopy.setInputArray(mlDataPair.getInputArray());
+			if (dataPairCopy.getIdealArray() != null) {
+				dataPairCopy.setIdealArray(mlDataPair.getIdealArray());
+			}
+
+			double[] input = dataPairCopy.getInputArray();
+			double[] ideal = dataPairCopy.getIdealArray();
+			double significance = dataPairCopy.getSignificance();
+
+			actuals.add(trainingIdx, new double[ideal.length]);
+			this.network.compute(input, actuals.get(trainingIdx));
+
+			// For now, only calculate deltas for the final data pair. In the
+			// final TDL algorithm, deltas won't be used at all; the list of
+			// actual output vectors will be used instead.
+			if (trainingIdx < trainingSequence.size() - 1) {
+				continue;
+			}
+
+			this.errorCalculation.updateError(actuals.get(trainingIdx), ideal,
+					significance);
+			this.errorFunction.calculateError(ideal, actuals.get(trainingIdx),
+					this.layerDelta);
+
+			for (int i = 0; i < actuals.get(trainingIdx).length; i++) {
+				this.layerDelta[i] = ((this.network.getActivationFunctions()[0]
+						.derivativeFunction(this.layerSums[i],
+								this.layerOutput[i]) + this.flatSpot[0]))
+						* (this.layerDelta[i] * significance);
+			}
+
+			for (int i = this.network.getBeginTraining(); i < this.network
+					.getEndTraining(); i++) {
+				processLevel(i);
+			}
 		}
 	}
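The comments above mark the per-pair deltas as an interim step: the eventual temporal-difference update is meant to be driven by the list of actual output vectors, one per position in the game. The sketch below shows one way such a list can produce TD(0)-style training targets; it assumes a single win-probability output and an outcome known only at the game's end, and is an illustration, not this commit's implementation:

    import java.util.ArrayList;
    import java.util.List;

    public class TdTargetsSketch {
        /** outcome: 1.0 for a win, 0.0 for a loss (assumed encoding). */
        static List<double[]> tdTargets(List<double[]> actuals, double outcome) {
            List<double[]> targets = new ArrayList<double[]>();
            for (int t = 0; t < actuals.size(); t++) {
                boolean last = (t == actuals.size() - 1);
                // Bootstrap each position's target from the next position's
                // prediction; the terminal position targets the real outcome.
                double target = last ? outcome : actuals.get(t + 1)[0];
                targets.add(new double[] { target });
            }
            return targets;
        }
    }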
@@ -142,7 +168,8 @@ public class GradientWorker implements EngineTask {
 			}

 			this.layerDelta[yi] = sum
-					* (activation.derivativeFunction(this.layerSums[yi],this.layerOutput[yi])+currentFlatSpot);
+					* (activation.derivativeFunction(this.layerSums[yi],
+							this.layerOutput[yi]) + currentFlatSpot);
 			yi++;
 		}
 	}
@@ -150,17 +177,11 @@ public class GradientWorker implements EngineTask {
 	public final void run() {
 		try {
 			this.errorCalculation.reset();
-			//for (int i = this.low; i <= this.high; i++) {
 			for (List<MLDataPair> trainingSequence : training) {
-				MLDataPair mldp = trainingSequence.get(trainingSequence.size()-1);
-				this.pair.setInputArray(mldp.getInputArray());
-				if (this.pair.getIdealArray() != null) {
-					this.pair.setIdealArray(mldp.getIdealArray());
-				}
-				//this.training.getRecord(i, this.pair);
-				process(this.pair.getInputArray(), this.pair.getIdealArray(),pair.getSignificance());
+				process(trainingSequence);
 			}
-			//}
 			final double error = this.errorCalculation.calculate();
 			this.owner.report(this.gradients, error, null);
 			EngineArray.fill(this.gradients, 0);
JosekiLearner.java (deleted)
@@ -1,5 +0,0 @@
-package net.woodyfolsom.msproj.ann;
-
-public class JosekiLearner {
-
-}
ShapeLearner.java (deleted)
@@ -1,5 +0,0 @@
-package net.woodyfolsom.msproj.ann;
-
-public class ShapeLearner {
-
-}
TemporalDifferenceLearning.java
@@ -73,7 +73,7 @@ public class TemporalDifferenceLearning implements MLTrain, Momentum,

 	// BasicTraining
 	private final List<Strategy> strategies = new ArrayList<Strategy>();
-	private Set<List<MLDataPair>> training;
+	//private Set<List<MLDataPair>> training;
 	private double error;
 	private int iteration;
 	private TrainingImplementationType implementationType;
@@ -104,7 +104,7 @@ public class TemporalDifferenceLearning implements MLTrain, Momentum,
 		initBasicTraining(TrainingImplementationType.Iterative);
 		this.network = network;
 		this.currentFlatNetwork = network.getFlat();
-		setTraining(training);
+		//setTraining(training);

 		this.gradients = new double[this.currentFlatNetwork.getWeights().length];
 		this.lastGradient = new double[this.currentFlatNetwork.getWeights().length];
@@ -181,8 +181,10 @@ public class TemporalDifferenceLearning implements MLTrain, Momentum,

 	public double updateWeight(final double[] gradients,
 			final double[] lastGradient, final int index) {
+
 		final double delta = (gradients[index] * this.learningRate)
 				+ (this.lastDelta[index] * this.momentum);
+
 		this.lastDelta[index] = delta;

 		System.out.println("Updating weights for connection: " + index
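updateWeight applies the classic momentum rule visible above: each new step is the learning-rate-scaled gradient plus the momentum-scaled previous step. A self-contained illustration (the eta and mu values are arbitrary example values, not the project's settings):

    public class MomentumSketch {
        public static void main(String[] args) {
            double eta = 0.1; // learning rate (arbitrary example value)
            double mu = 0.9;  // momentum (arbitrary example value)
            double lastDelta = 0.0;
            double weight = 0.0;
            // A constant gradient of 0.5 makes the accelerating effect visible.
            for (int step = 0; step < 3; step++) {
                double delta = (0.5 * eta) + (lastDelta * mu);
                lastDelta = delta;
                weight += delta;
                System.out.println("step=" + step + " delta=" + delta + " weight=" + weight);
            }
        }
    }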
@@ -474,7 +476,8 @@ public class TemporalDifferenceLearning implements MLTrain, Momentum,
 	}

 	public void setTraining(final Set<List<MLDataPair>> training) {
-		this.training = training;
+		//this.training = training;
+		throw new UnsupportedOperationException();
 	}

 	public TrainingImplementationType getImplementationType() {
WinFilter.java
@@ -61,9 +61,13 @@ public class WinFilter extends AbstractNeuralNetFilter implements

 	@Override
 	public void learn(MLDataSet trainingData) {
-		throw new UnsupportedOperationException("This filter learns a Set<List<MLData>>, not an MLDataSet");
+		throw new UnsupportedOperationException("This filter learns a Set<List<MLDataPair>>, not an MLDataSet");
 	}

+	/**
+	 * This method is necessary because, with temporal difference learning,
+	 * some of the MLDataPairs are related by being a sequence of moves within
+	 * a particular game.
+	 */
 	@Override
 	public void learn(Set<List<MLDataPair>> trainingSet) {

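The new javadoc pins down the shape the filter learns from: the outer Set holds one entry per game, and each inner List is that game's ordered sequence of data pairs. A small sketch of building such a training set (the two-input position encoding and ideal values are toy stand-ins, not the project's real encoding):

    import java.util.ArrayList;
    import java.util.HashSet;
    import java.util.List;
    import java.util.Set;

    import org.encog.ml.data.MLDataPair;
    import org.encog.ml.data.basic.BasicMLData;
    import org.encog.ml.data.basic.BasicMLDataPair;

    public class TrainingSetSketch {
        public static void main(String[] args) {
            Set<List<MLDataPair>> trainingSet = new HashSet<List<MLDataPair>>();
            List<MLDataPair> game = new ArrayList<MLDataPair>();
            // Each pair is (position encoding, evaluation) for one move.
            game.add(new BasicMLDataPair(new BasicMLData(new double[] { 0.0, 1.0 }),
                    new BasicMLData(new double[] { 0.5 })));
            game.add(new BasicMLDataPair(new BasicMLData(new double[] { 1.0, 1.0 }),
                    new BasicMLData(new double[] { 1.0 })));
            trainingSet.add(game); // one ordered sequence per game
        }
    }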