Incremental update. Removing some experimental classes after commit.

2012-11-22 16:07:03 -05:00
parent b723e2666e
commit 2e36b01363
9 changed files with 78 additions and 68 deletions

data/gogame.cfg

@@ -1,10 +1,10 @@
-PlayerOne=RANDOM
+PlayerOne=ROOT_PAR
 PlayerTwo=RANDOM
 GUIDelay=1000 //1 second
 BoardSize=9
 Komi=6.5
-NumGames=1000 //Games for each color per player
-TurnTime=1000 //seconds per player per turn
-SpectatorBoardShown=false;
-WhiteMoveLogged=false;
-BlackMoveLogged=false;
+NumGames=1 //Games for each color per player
+TurnTime=2000 //seconds per player per turn
+SpectatorBoardShown=true
+WhiteMoveLogged=false
+BlackMoveLogged=false
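
A note on the hunk above: besides switching PlayerOne from the RANDOM engine to ROOT_PAR, the commit drops the stray trailing semicolons from the boolean flags, leaving a plain key=value format with optional // comments. GameSettings.createGameSetings(...) is not part of this diff, so the following is only a hypothetical sketch of how such a file could be read, not the project's actual parser:

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.HashMap;
import java.util.Map;

public class ConfigSketch {
	// Hypothetical reader for key=value lines with optional // comments.
	public static Map<String, String> parse(String path) throws IOException {
		Map<String, String> settings = new HashMap<String, String>();
		for (String line : Files.readAllLines(Paths.get(path), StandardCharsets.UTF_8)) {
			int comment = line.indexOf("//");
			if (comment >= 0) {
				line = line.substring(0, comment); // strip trailing comment
			}
			line = line.trim();
			int eq = line.indexOf('=');
			if (eq > 0) {
				settings.put(line.substring(0, eq).trim(),
						line.substring(eq + 1).trim());
			}
		}
		return settings;
	}
}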

StandAloneGame.java

@@ -29,7 +29,7 @@ public class StandAloneGame {
 		HUMAN, HUMAN_GUI, ROOT_PAR, UCT, RANDOM, RAVE
 	};
-	public static void main(String[] args) {
+	public static void main(String[] args) throws IOException {
 		try {
 			GameSettings gameSettings = GameSettings
 					.createGameSetings("data/gogame.cfg");
@@ -42,6 +42,8 @@ public class StandAloneGame {
 					gameSettings.getNumGames(), gameSettings.getTurnTime(),
 					gameSettings.isSpectatorBoardShown(),
 					gameSettings.isBlackMoveLogged(), gameSettings.isWhiteMoveLogged());
+			System.out.println("Press <Enter> or CTRL-C to exit");
+			System.in.read(new byte[80]);
 		} catch (IOException ioe) {
 			ioe.printStackTrace();
 			System.exit(EXIT_IO_EXCEPTION);
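
The two added lines keep the process alive: main blocks reading stdin so the match threads and spectator GUI stay up until the operator presses Enter (or CTRL-C kills the JVM), which is also why the signature gains throws IOException. The idiom in isolation, as a minimal self-contained sketch:

import java.io.IOException;

public class KeepAlive {
	public static void main(String[] args) throws IOException {
		System.out.println("Press <Enter> or CTRL-C to exit");
		// read(byte[]) blocks until input arrives; up to 80 bytes are
		// consumed and discarded, then the program falls through and exits.
		System.in.read(new byte[80]);
	}
}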

net/woodyfolsom/msproj/ann/DoublePair.java

@@ -3,12 +3,7 @@ package net.woodyfolsom.msproj.ann;
 import org.encog.ml.data.basic.BasicMLData;
 public class DoublePair extends BasicMLData {
-	// private final double x;
-	// private final double y;
-	/**
-	 *
-	 */
 	private static final long serialVersionUID = 1L;
 	public DoublePair(double x, double y) {

net/woodyfolsom/msproj/ann/FusekiLearner.java

@@ -1,5 +0,0 @@
-package net.woodyfolsom.msproj.ann;
-public class FusekiLearner {
-}

net/woodyfolsom/msproj/ann/GradientWorker.java

@@ -1,4 +1,5 @@
 package net.woodyfolsom.msproj.ann;
 /*
  * Class copied verbatim from Encog framework due to dependency on Propagation
  * implementation.
@@ -6,7 +7,7 @@ package net.woodyfolsom.msproj.ann;
  * Encog(tm) Core v3.2 - Java Version
  * http://www.heatonresearch.com/encog/
  * http://code.google.com/p/encog-java/
  * Copyright 2008-2012 Heaton Research, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -26,12 +27,12 @@ package net.woodyfolsom.msproj.ann;
  * http://www.heatonresearch.com/copyright
  */
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Set;
 import org.encog.engine.network.activation.ActivationFunction;
 import org.encog.ml.data.MLDataPair;
-import org.encog.ml.data.MLDataSet;
 import org.encog.ml.data.basic.BasicMLDataPair;
 import org.encog.neural.error.ErrorFunction;
 import org.encog.neural.flat.FlatNetwork;
@@ -42,7 +43,7 @@ public class GradientWorker implements EngineTask {
 	private final FlatNetwork network;
 	private final ErrorCalculation errorCalculation = new ErrorCalculation();
-	private final double[] actual;
+	private final List<double[]> actuals;
 	private final double[] layerDelta;
 	private final int[] layerCounts;
 	private final int[] layerFeedCounts;
@@ -52,30 +53,29 @@ public class GradientWorker implements EngineTask {
 	private final double[] layerSums;
 	private final double[] gradients;
 	private final double[] weights;
-	private final MLDataPair pair;
+	private final MLDataPair pairPrototype;
 	private final Set<List<MLDataPair>> training;
-	private final int low;
-	private final int high;
+	//private final int low;
+	//private final int high;
 	private final TemporalDifferenceLearning owner;
 	private double[] flatSpot;
 	private final ErrorFunction errorFunction;
 	public GradientWorker(final FlatNetwork theNetwork,
 			final TemporalDifferenceLearning theOwner,
-			final Set<List<MLDataPair>> theTraining, final int theLow,
-			final int theHigh, final double[] flatSpot,
-			ErrorFunction ef) {
+			final Set<List<MLDataPair>> theTraining, final int theLow,
+			final int theHigh, final double[] flatSpot, ErrorFunction ef) {
 		this.network = theNetwork;
 		this.training = theTraining;
-		this.low = theLow;
-		this.high = theHigh;
+		//this.low = theLow;
+		//this.high = theHigh;
 		this.owner = theOwner;
 		this.flatSpot = flatSpot;
 		this.errorFunction = ef;
 		this.layerDelta = new double[network.getLayerOutput().length];
 		this.gradients = new double[network.getWeights().length];
-		this.actual = new double[network.getOutputCount()];
+		this.actuals = new ArrayList<double[]>();
 		this.weights = network.getWeights();
 		this.layerIndex = network.getLayerIndex();
@@ -85,8 +85,8 @@ public class GradientWorker implements EngineTask {
 		this.layerSums = network.getLayerSums();
 		this.layerFeedCounts = network.getLayerFeedCounts();
-		this.pair = BasicMLDataPair.createPair(network.getInputCount(), network
-				.getOutputCount());
+		this.pairPrototype = BasicMLDataPair.createPair(
+				network.getInputCount(), network.getOutputCount());
 	}
 	public FlatNetwork getNetwork() {
@@ -97,22 +97,48 @@ public class GradientWorker implements EngineTask {
 		return this.weights;
 	}
-	private void process(final double[] input, final double[] ideal, double s) {
-		this.network.compute(input, this.actual);
-		this.errorCalculation.updateError(this.actual, ideal, s);
-		this.errorFunction.calculateError(ideal, actual, this.layerDelta);
-		for (int i = 0; i < this.actual.length; i++) {
-			this.layerDelta[i] = ((this.network.getActivationFunctions()[0]
-					.derivativeFunction(this.layerSums[i],this.layerOutput[i]) + this.flatSpot[0]))
-					* (this.layerDelta[i] * s);
-		}
-		for (int i = this.network.getBeginTraining(); i < this.network
-				.getEndTraining(); i++) {
-			processLevel(i);
-		}
+	private void process(List<MLDataPair> trainingSequence) {
+		actuals.clear();
+		for (int trainingIdx = 0; trainingIdx < trainingSequence.size(); trainingIdx++) {
+			MLDataPair mlDataPair = trainingSequence.get(trainingIdx);
+			MLDataPair dataPairCopy = this.pairPrototype;
+			dataPairCopy.setInputArray(mlDataPair.getInputArray());
+			if (dataPairCopy.getIdealArray() != null) {
+				dataPairCopy.setIdealArray(mlDataPair.getIdealArray());
+			}
+			double[] input = dataPairCopy.getInputArray();
+			double[] ideal = dataPairCopy.getIdealArray();
+			double significance = dataPairCopy.getSignificance();
+			actuals.add(trainingIdx, new double[ideal.length]);
+			this.network.compute(input, actuals.get(trainingIdx));
+			// For now, only calculate deltas for the final data pair
+			// For final TDL algorithm, deltas won't be used at all, instead the
+			// List of Actual vectors will.
+			if (trainingIdx < trainingSequence.size() - 1) {
+				continue;
+			}
+			this.errorCalculation.updateError(actuals.get(trainingIdx), ideal,
+					significance);
+			this.errorFunction.calculateError(ideal, actuals.get(trainingIdx),
+					this.layerDelta);
+			for (int i = 0; i < actuals.get(trainingIdx).length; i++) {
+				this.layerDelta[i] = ((this.network.getActivationFunctions()[0]
+						.derivativeFunction(this.layerSums[i],
+								this.layerOutput[i]) + this.flatSpot[0]))
+						* (this.layerDelta[i] * significance);
+			}
+			for (int i = this.network.getBeginTraining(); i < this.network
+					.getEndTraining(); i++) {
+				processLevel(i);
+			}
+		}
 	}
@@ -142,7 +168,8 @@ public class GradientWorker implements EngineTask {
 			}
 			this.layerDelta[yi] = sum
-					* (activation.derivativeFunction(this.layerSums[yi],this.layerOutput[yi])+currentFlatSpot);
+					* (activation.derivativeFunction(this.layerSums[yi],
+							this.layerOutput[yi]) + currentFlatSpot);
 			yi++;
 		}
 	}
@@ -150,17 +177,11 @@ public class GradientWorker implements EngineTask {
 	public final void run() {
 		try {
 			this.errorCalculation.reset();
-			//for (int i = this.low; i <= this.high; i++) {
 			for (List<MLDataPair> trainingSequence : training) {
-				MLDataPair mldp = trainingSequence.get(trainingSequence.size()-1);
-				this.pair.setInputArray(mldp.getInputArray());
-				if (this.pair.getIdealArray() != null) {
-					this.pair.setIdealArray(mldp.getIdealArray());
-				}
-				//this.training.getRecord(i, this.pair);
-				process(this.pair.getInputArray(), this.pair.getIdealArray(),pair.getSignificance());
+				process(trainingSequence);
 			}
-			//}
 			final double error = this.errorCalculation.calculate();
 			this.owner.report(this.gradients, error, null);
 			EngineArray.fill(this.gradients, 0);
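
The substance of this file's change: process() now consumes a whole game as a List<MLDataPair> instead of a single input/ideal/significance triple. Every position in the sequence is fed through the network and its output stored in actuals, but, as the in-line comments note, only the terminal pair still drives the error and gradient computation; the eventual temporal-difference update is meant to operate on the whole list of output vectors. The implied shape of the training data is a Set of per-game sequences. A minimal sketch of building one such sequence (the 81-element state vector and single-output win label are assumptions standing in for the project's real board encoding):

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.ml.data.basic.BasicMLDataPair;

public class TrainingSetSketch {
	public static void main(String[] args) {
		// One List<MLDataPair> per recorded game, one pair per position,
		// in move order; the terminal pair carries the game outcome.
		Set<List<MLDataPair>> training = new HashSet<List<MLDataPair>>();
		List<MLDataPair> game = new ArrayList<MLDataPair>();
		double[] stateVector = new double[81];          // assumed 9x9 encoding
		double[] outcomeVector = new double[] { 1.0 };  // assumed win label
		game.add(new BasicMLDataPair(new BasicMLData(stateVector),
				new BasicMLData(outcomeVector)));
		training.add(game);
		// Each GradientWorker receives the Set and calls process(game)
		// once per sequence, as in the run() method above.
	}
}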

net/woodyfolsom/msproj/ann/JosekiLearner.java

@@ -1,5 +0,0 @@
-package net.woodyfolsom.msproj.ann;
-public class JosekiLearner {
-}

net/woodyfolsom/msproj/ann/ShapeLearner.java

@@ -1,5 +0,0 @@
-package net.woodyfolsom.msproj.ann;
-public class ShapeLearner {
-}

net/woodyfolsom/msproj/ann/TemporalDifferenceLearning.java

@@ -73,7 +73,7 @@ public class TemporalDifferenceLearning implements MLTrain, Momentum,
 	// BasicTraining
 	private final List<Strategy> strategies = new ArrayList<Strategy>();
-	private Set<List<MLDataPair>> training;
+	//private Set<List<MLDataPair>> training;
 	private double error;
 	private int iteration;
 	private TrainingImplementationType implementationType;
@@ -104,7 +104,7 @@ public class TemporalDifferenceLearning implements MLTrain, Momentum,
 		initBasicTraining(TrainingImplementationType.Iterative);
 		this.network = network;
 		this.currentFlatNetwork = network.getFlat();
-		setTraining(training);
+		//setTraining(training);
 		this.gradients = new double[this.currentFlatNetwork.getWeights().length];
 		this.lastGradient = new double[this.currentFlatNetwork.getWeights().length];
@@ -181,8 +181,10 @@ public class TemporalDifferenceLearning implements MLTrain, Momentum,
 	public double updateWeight(final double[] gradients,
 			final double[] lastGradient, final int index) {
 		final double delta = (gradients[index] * this.learningRate)
 				+ (this.lastDelta[index] * this.momentum);
 		this.lastDelta[index] = delta;
+		System.out.println("Updating weights for connection: " + index
@@ -474,7 +476,8 @@ public class TemporalDifferenceLearning implements MLTrain, Momentum,
 	}
 	public void setTraining(final Set<List<MLDataPair>> training) {
-		this.training = training;
+		//this.training = training;
+		throw new UnsupportedOperationException();
 	}
 	public TrainingImplementationType getImplementationType() {
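
Two things happen in this file. First, updateWeight computes a standard gradient step with momentum, delta(t) = learningRate * gradient + momentum * delta(t-1), caching the result in lastDelta for the next iteration; the println added here traces that value per connection, which will be extremely verbose on a full-size network. Second, the training field and setTraining are disabled: the trainer no longer accepts data through this interface, and callers hit an immediate UnsupportedOperationException. The momentum arithmetic in isolation, as a runnable sketch (weight application itself is outside the hunks shown):

public class MomentumSketch {
	// delta(t) = learningRate * gradient(t) + momentum * delta(t-1)
	static double updateDelta(double gradient, double lastDelta,
			double learningRate, double momentum) {
		return gradient * learningRate + lastDelta * momentum;
	}

	public static void main(String[] args) {
		double lastDelta = 0.0;
		double learningRate = 0.1;
		double momentum = 0.9;
		double[] gradients = { 0.5, 0.5, 0.5 };
		for (double g : gradients) {
			lastDelta = updateDelta(g, lastDelta, learningRate, momentum);
			// prints 0.05, then 0.095, then 0.1355: the step grows while
			// the gradient keeps pointing the same way.
			System.out.println("delta = " + lastDelta);
		}
	}
}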

WinFilter.java

@@ -61,9 +61,13 @@ public class WinFilter extends AbstractNeuralNetFilter implements
 	@Override
 	public void learn(MLDataSet trainingData) {
-		throw new UnsupportedOperationException("This filter learns a Set<List<MLData>>, not an MLDataSet");
+		throw new UnsupportedOperationException("This filter learns a Set<List<MLDataPair>>, not an MLDataSet");
 	}
+	/**
+	 * Method is necessary because with temporal difference learning, some of the MLDataPairs are related by being a sequence
+	 * of moves within a particular game.
+	 */
 	@Override
 	public void learn(Set<List<MLDataPair>> trainingSet) {
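
Net effect on the filter's contract: the flat MLDataSet overload always throws (its message corrected from Set<List<MLData>> to Set<List<MLDataPair>>), and training must go through the sequence-aware overload so pairs from the same game stay grouped, as the new javadoc explains. A usage fragment, assuming a constructed WinFilter named filter and the per-game list from the TrainingSetSketch above:

Set<List<MLDataPair>> games = new HashSet<List<MLDataPair>>();
games.add(game);             // the per-game sequence built earlier
filter.learn(games);         // supported: sequence-aware training path
// filter.learn(dataSet);    // MLDataSet overload: throws immediately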