Compare commits
10 Commits
b723e2666e
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
| f59de76865 | |||
| 606f069cff | |||
| 70c7d7d9b1 | |||
| 14bc769493 | |||
| 28dc44b61e | |||
| d24e7aee97 | |||
| 214bdcd032 | |||
| 790b5666a8 | |||
| 874847f41b | |||
| 2e36b01363 |
@@ -7,6 +7,5 @@
|
||||
<classpathentry kind="lib" path="lib/log4j-1.2.16.jar"/>
|
||||
<classpathentry kind="lib" path="lib/kgsGtp.jar"/>
|
||||
<classpathentry kind="lib" path="lib/antlrworks-1.4.3.jar"/>
|
||||
<classpathentry kind="lib" path="lib/encog-java-core.jar" sourcepath="lib/encog-java-core-sources.jar"/>
|
||||
<classpathentry kind="output" path="bin"/>
|
||||
</classpath>
|
||||
|
||||
0
GoGame.log
Normal file
0
GoGame.log
Normal file
20
build.xml
20
build.xml
@@ -3,6 +3,7 @@
|
||||
<description>Simple Framework for Testing Tree Search and Monte-Carlo Go</description>
|
||||
|
||||
<property name="src" location="src" />
|
||||
<property name="reports" location="reports" />
|
||||
<property name="build" location="build" />
|
||||
<property name="dist" location="dist" />
|
||||
<property name="test" location="test" />
|
||||
@@ -33,9 +34,16 @@
|
||||
</target>
|
||||
|
||||
<target name="copy-resources">
|
||||
<copy todir="${dist}/data">
|
||||
<copy todir="${dist}" file="connect.bat" />
|
||||
<copy todir="${dist}" file="rrt.bat" />
|
||||
<copy todir="${dist}" file="data/log4j.xml" />
|
||||
<copy todir="${dist}" file="data/kgsGtp.ini" />
|
||||
<copy todir="${dist}" file="data/gogame.cfg" />
|
||||
|
||||
<!--copy todir="${dist}/data">
|
||||
<fileset dir="data" />
|
||||
</copy>
|
||||
</copy-->
|
||||
|
||||
<copy todir="${build}/net/woodyfolsom/msproj/gui">
|
||||
<fileset dir="${src}/net/woodyfolsom/msproj/gui">
|
||||
<exclude name="**/*.java"/>
|
||||
@@ -58,6 +66,7 @@
|
||||
<!-- Delete the ${build} and ${dist} directory trees -->
|
||||
<delete dir="${build}" />
|
||||
<delete dir="${dist}" />
|
||||
<delete dir="${reports}" />
|
||||
</target>
|
||||
|
||||
<target name="dist" depends="compile,copy-resources,copy-libs" description="generate the distribution">
|
||||
@@ -83,16 +92,17 @@
|
||||
<target name="init">
|
||||
<!-- Create the build directory structure used by compile -->
|
||||
<mkdir dir="${build}" />
|
||||
<mkdir dir="${reports}" />
|
||||
</target>
|
||||
|
||||
<target name="test" depends="compile-test">
|
||||
<junit haltonfailure="true">
|
||||
<classpath refid="classpath.test" />
|
||||
<formatter type="brief" usefile="false" />
|
||||
<batchtest>
|
||||
<formatter type="xml" />
|
||||
<batchtest todir="${reports}">
|
||||
<fileset dir="${build}" includes="**/*Test.class" />
|
||||
</batchtest>
|
||||
</junit>
|
||||
</target>
|
||||
|
||||
</project>
|
||||
</project>
|
||||
|
||||
1
connect.bat
Normal file
1
connect.bat
Normal file
@@ -0,0 +1 @@
|
||||
java -cp GoGame.jar;antlrworks-1.4.3.jar;kgsGtp.jar;log4j-1.2.16.jar net.woodyfolsom.msproj.GoGame
|
||||
@@ -1,10 +1,10 @@
|
||||
PlayerOne=RANDOM
|
||||
PlayerTwo=RANDOM
|
||||
PlayerOne=ROOT_PAR_AMAF //HUMAN, HUMAN_GUI, ROOT_PAR, UCT, RANDOM, RAVE, SMAF, ROOT_PAR_AMAF
|
||||
PlayerTwo=Random
|
||||
GUIDelay=1000 //1 second
|
||||
BoardSize=9
|
||||
Komi=6.5
|
||||
NumGames=1000 //Games for each color per player
|
||||
TurnTime=1000 //seconds per player per turn
|
||||
SpectatorBoardShown=false;
|
||||
WhiteMoveLogged=false;
|
||||
BlackMoveLogged=false;
|
||||
BoardSize=13 //9, 13 or 19
|
||||
Komi=6.5 //suggested 6.5
|
||||
NumGames=1 //Games for each color per player
|
||||
TurnTime=6000 //seconds per player per turn
|
||||
SpectatorBoardShown=true //set to true for modes which otherwise wouldn't show GUI. false for HUMAN_GUI player.
|
||||
WhiteMoveLogged=false
|
||||
BlackMoveLogged=true
|
||||
@@ -1,12 +1,11 @@
|
||||
engine=java -cp GoGame.jar;log4j-1.2.16.jar net.woodyfolsom.msproj.GoGame montecarlo
|
||||
name=whf4cs6999
|
||||
password=6id39p
|
||||
name=whf4human
|
||||
password=t3snxf
|
||||
room=whf4cs6999
|
||||
mode=custom
|
||||
mode=auto
|
||||
talk=I'm a Monte Carlo tree search bot.
|
||||
opponent=whf4human
|
||||
reconnect=t
|
||||
automatch.rank=25k
|
||||
rules=chinese
|
||||
rules.boardSize=9
|
||||
rules.time=0
|
||||
rules.time=0
|
||||
opponent=whf4cs6999
|
||||
3
gofree.txt
Normal file
3
gofree.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
UCT-RAVE vs GoFree
|
||||
level 1 (black) 2/2
|
||||
level 2 (black) 1/1
|
||||
11
kgsGtp.ini
Normal file
11
kgsGtp.ini
Normal file
@@ -0,0 +1,11 @@
|
||||
engine=java -cp GoGame.jar;log4j-1.2.16.jar net.woodyfolsom.msproj.GoGame montecarlo
|
||||
name=whf4human
|
||||
password=t3snxf
|
||||
room=whf4cs6999
|
||||
mode=auto
|
||||
talk=I'm a Monte Carlo tree search bot.
|
||||
reconnect=t
|
||||
rules=chinese
|
||||
rules.boardSize=13
|
||||
rules.time=0
|
||||
opponent=whf4cs6999
|
||||
BIN
lib/activation-1.0.2.jar
Normal file
BIN
lib/activation-1.0.2.jar
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
lib/jaxb-api-2.3.1.jar
Normal file
BIN
lib/jaxb-api-2.3.1.jar
Normal file
Binary file not shown.
19
log4j.xml
Normal file
19
log4j.xml
Normal file
@@ -0,0 +1,19 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE log4j:configuration SYSTEM "log4j.dtd">
|
||||
|
||||
<log4j:configuration xmlns:log4j="http://jakarta.apache.org/log4j/" debug="false">
|
||||
|
||||
<appender name="fileAppender" class="org.apache.log4j.RollingFileAppender">
|
||||
<param name="Threshold" value="INFO" />
|
||||
<param name="File" value="GoGame.log"/>
|
||||
<layout class="org.apache.log4j.PatternLayout">
|
||||
<param name="ConversionPattern" value="%d %-5p [%c{1}] %m %n" />
|
||||
</layout>
|
||||
</appender>
|
||||
|
||||
<logger name="cs6601.p1" additivity="false" >
|
||||
<level value="INFO" />
|
||||
<appender-ref ref="fileAppender"/>
|
||||
</logger>
|
||||
|
||||
</log4j:configuration>
|
||||
42
pass.net
Normal file
42
pass.net
Normal file
@@ -0,0 +1,42 @@
|
||||
<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
|
||||
<multiLayerPerceptron biased="true" name="PassFilter9">
|
||||
<activationFunction name="Sigmoid"/>
|
||||
<connections dest="3" src="0" weight="1.9322455294572656"/>
|
||||
<connections dest="3" src="1" weight="0.0859943747020325"/>
|
||||
<connections dest="3" src="2" weight="0.8394414489841715"/>
|
||||
<connections dest="4" src="0" weight="1.5831613048108952"/>
|
||||
<connections dest="4" src="1" weight="0.8667080746254153"/>
|
||||
<connections dest="4" src="2" weight="-3.204930958551688"/>
|
||||
<connections dest="5" src="0" weight="-1.4223906119706369"/>
|
||||
<connections dest="5" src="3" weight="-2.1292730695450857"/>
|
||||
<connections dest="5" src="4" weight="-2.5861434868493607"/>
|
||||
<neurons id="0">
|
||||
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
|
||||
</neurons>
|
||||
<neurons id="1">
|
||||
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
|
||||
</neurons>
|
||||
<neurons id="2">
|
||||
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
|
||||
</neurons>
|
||||
<neurons id="3">
|
||||
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Tanh"/>
|
||||
</neurons>
|
||||
<neurons id="4">
|
||||
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Tanh"/>
|
||||
</neurons>
|
||||
<neurons id="5">
|
||||
<activationFunction name="Sigmoid"/>
|
||||
</neurons>
|
||||
<layers>
|
||||
<neuronIds>1</neuronIds>
|
||||
<neuronIds>2</neuronIds>
|
||||
</layers>
|
||||
<layers>
|
||||
<neuronIds>3</neuronIds>
|
||||
<neuronIds>4</neuronIds>
|
||||
</layers>
|
||||
<layers>
|
||||
<neuronIds>5</neuronIds>
|
||||
</layers>
|
||||
</multiLayerPerceptron>
|
||||
1
rrt.bat
Normal file
1
rrt.bat
Normal file
@@ -0,0 +1 @@
|
||||
java -Xms256m -Xmx4096m -cp GoGame.jar;antlrworks-1.4.3.jar;kgsGtp.jar;log4j-1.2.16.jar net.woodyfolsom.msproj.RoundRobin
|
||||
85
rrt.rootpar-amaf.black.txt
Normal file
85
rrt.rootpar-amaf.black.txt
Normal file
@@ -0,0 +1,85 @@
|
||||
|
||||
C:\workspace\msproj\dist>java -Xms256m -Xmx4096m -cp GoGame.jar;antlrworks-1.4.3.jar;kgsGtp.jar;log4j-1.2.16.jar net.woodyfolsom.msproj.RoundRobin
|
||||
Beginning round-robin tournament.
|
||||
Initializing policies...
|
||||
Game over. Result: B+5.5
|
||||
RootParallelization (Black) vs Random (White) : B+5.5
|
||||
Game over. Result: B+30.5
|
||||
RootParallelization (Black) vs Random (White) : B+30.5
|
||||
Game over. Result: B+0.5
|
||||
RootParallelization (Black) vs Random (White) : B+0.5
|
||||
Game over. Result: B+18.5
|
||||
RootParallelization (Black) vs Random (White) : B+18.5
|
||||
Game over. Result: B+18.5
|
||||
RootParallelization (Black) vs Random (White) : B+18.5
|
||||
Game over. Result: B+3.5
|
||||
RootParallelization (Black) vs Alpha-Beta (White) : B+3.5
|
||||
Game over. Result: B+0.5
|
||||
RootParallelization (Black) vs Alpha-Beta (White) : B+0.5
|
||||
Game over. Result: B+46.5
|
||||
RootParallelization (Black) vs Alpha-Beta (White) : B+46.5
|
||||
Game over. Result: B+44.5
|
||||
RootParallelization (Black) vs Alpha-Beta (White) : B+44.5
|
||||
Game over. Result: B+53.5
|
||||
RootParallelization (Black) vs Alpha-Beta (White) : B+53.5
|
||||
Game over. Result: B+14.5
|
||||
RootParallelization (Black) vs MonteCarloUCT (White) : B+14.5
|
||||
Game over. Result: B+30.5
|
||||
RootParallelization (Black) vs MonteCarloUCT (White) : B+30.5
|
||||
Game over. Result: B+9.5
|
||||
RootParallelization (Black) vs MonteCarloUCT (White) : B+9.5
|
||||
Game over. Result: B+44.5
|
||||
RootParallelization (Black) vs MonteCarloUCT (White) : B+44.5
|
||||
Game over. Result: B+29.5
|
||||
RootParallelization (Black) vs MonteCarloUCT (White) : B+29.5
|
||||
Game over. Result: B+4.5
|
||||
RootParallelization (Black) vs UCT-RAVE (White) : B+4.5
|
||||
Game over. Result: B+27.5
|
||||
RootParallelization (Black) vs UCT-RAVE (White) : B+27.5
|
||||
Game over. Result: B+29.5
|
||||
RootParallelization (Black) vs UCT-RAVE (White) : B+29.5
|
||||
Game over. Result: B+22.5
|
||||
RootParallelization (Black) vs UCT-RAVE (White) : B+22.5
|
||||
Game over. Result: B+36.5
|
||||
RootParallelization (Black) vs UCT-RAVE (White) : B+36.5
|
||||
Game over. Result: B+50.5
|
||||
RootParallelization (Black) vs MonteCarloSMAF (White) : B+50.5
|
||||
Game over. Result: B+42.5
|
||||
RootParallelization (Black) vs MonteCarloSMAF (White) : B+42.5
|
||||
Game over. Result: B+28.5
|
||||
RootParallelization (Black) vs MonteCarloSMAF (White) : B+28.5
|
||||
Game over. Result: B+38.5
|
||||
RootParallelization (Black) vs MonteCarloSMAF (White) : B+38.5
|
||||
Game over. Result: B+23.5
|
||||
RootParallelization (Black) vs MonteCarloSMAF (White) : B+23.5
|
||||
Game over. Result: B+7.5
|
||||
RootParallelization (Black) vs RootParallelization (White) : B+7.5
|
||||
Game over. Result: W+16.5
|
||||
RootParallelization (Black) vs RootParallelization (White) : W+16.5
|
||||
Game over. Result: B+0.5
|
||||
RootParallelization (Black) vs RootParallelization (White) : B+0.5
|
||||
Game over. Result: B+8.5
|
||||
RootParallelization (Black) vs RootParallelization (White) : B+8.5
|
||||
Game over. Result: W+19.5
|
||||
RootParallelization (Black) vs RootParallelization (White) : W+19.5
|
||||
Game over. Result: B+13.5
|
||||
RootParallelization (Black) vs RootParallelization-NeuralNet (White) : B+13.5
|
||||
Game over. Result: B+2.5
|
||||
RootParallelization (Black) vs RootParallelization-NeuralNet (White) : B+2.5
|
||||
Game over. Result: B+16.5
|
||||
RootParallelization (Black) vs RootParallelization-NeuralNet (White) : B+16.5
|
||||
Game over. Result: B+32.5
|
||||
RootParallelization (Black) vs RootParallelization-NeuralNet (White) : B+32.5
|
||||
Game over. Result: B+8.5
|
||||
RootParallelization (Black) vs RootParallelization-NeuralNet (White) : B+8.5
|
||||
|
||||
Tournament Win Rates
|
||||
====================
|
||||
RootParallelization (Black) vs Random (White) : 100%
|
||||
RootParallelization (Black) vs Alpha-Beta (White) : 100%
|
||||
RootParallelization (Black) vs MonteCarloUCT (White) : 100%
|
||||
RootParallelization (Black) vs UCT-RAVE (White) : 100%
|
||||
RootParallelization (Black) vs MonteCarloSMAF (White) : 100%
|
||||
RootParallelization (Black) vs RootParallelization (White) : 60%
|
||||
RootParallelization (Black) vs RootParallelization-NeuralNet (White) : 100%
|
||||
Tournament lasted 1597.948 seconds.
|
||||
84
rrt/rrt.alphabeta.black.txt
Normal file
84
rrt/rrt.alphabeta.black.txt
Normal file
@@ -0,0 +1,84 @@
|
||||
|
||||
C:\workspace\msproj\dist>java -Xms256m -Xmx4096m -cp GoGame.jar;antlrworks-1.4.3.jar;kgsGtp.jar;log4j-1.2.16.jar net.woodyfolsom.msproj.RoundRobin
|
||||
Beginning round-robin tournament.
|
||||
Initializing policies...
|
||||
Game over. Result: B+2.5
|
||||
Alpha-Beta (Black) vs Random (White) : B+2.5
|
||||
Game over. Result: W+3.5
|
||||
Alpha-Beta (Black) vs Random (White) : W+3.5
|
||||
Game over. Result: W+4.5
|
||||
Alpha-Beta (Black) vs Random (White) : W+4.5
|
||||
Game over. Result: W+1.5
|
||||
Alpha-Beta (Black) vs Random (White) : W+1.5
|
||||
Game over. Result: W+2.5
|
||||
Alpha-Beta (Black) vs Random (White) : W+2.5
|
||||
Game over. Result: B+40.5
|
||||
Alpha-Beta (Black) vs Alpha-Beta (White) : B+40.5
|
||||
Game over. Result: B+40.5
|
||||
Alpha-Beta (Black) vs Alpha-Beta (White) : B+40.5
|
||||
Game over. Result: B+40.5
|
||||
Alpha-Beta (Black) vs Alpha-Beta (White) : B+40.5
|
||||
Game over. Result: B+40.5
|
||||
Alpha-Beta (Black) vs Alpha-Beta (White) : B+40.5
|
||||
Game over. Result: B+40.5
|
||||
Alpha-Beta (Black) vs Alpha-Beta (White) : B+40.5
|
||||
Game over. Result: W+17.5
|
||||
Alpha-Beta (Black) vs MonteCarloUCT (White) : W+17.5
|
||||
Game over. Result: W+40.5
|
||||
Alpha-Beta (Black) vs MonteCarloUCT (White) : W+40.5
|
||||
Game over. Result: W+18.5
|
||||
Alpha-Beta (Black) vs MonteCarloUCT (White) : W+18.5
|
||||
Game over. Result: W+30.5
|
||||
Alpha-Beta (Black) vs MonteCarloUCT (White) : W+30.5
|
||||
Game over. Result: W+33.5
|
||||
Alpha-Beta (Black) vs MonteCarloUCT (White) : W+33.5
|
||||
Game over. Result: W+32.5
|
||||
Alpha-Beta (Black) vs UCT-RAVE (White) : W+32.5
|
||||
Game over. Result: W+41.5
|
||||
Alpha-Beta (Black) vs UCT-RAVE (White) : W+41.5
|
||||
Game over. Result: W+36.5
|
||||
Alpha-Beta (Black) vs UCT-RAVE (White) : W+36.5
|
||||
Game over. Result: W+40.5
|
||||
Alpha-Beta (Black) vs UCT-RAVE (White) : W+40.5
|
||||
Game over. Result: W+34.5
|
||||
Alpha-Beta (Black) vs UCT-RAVE (White) : W+34.5
|
||||
Game over. Result: W+6.5
|
||||
Alpha-Beta (Black) vs MonteCarloSMAF (White) : W+6.5
|
||||
Game over. Result: W+23.5
|
||||
Alpha-Beta (Black) vs MonteCarloSMAF (White) : W+23.5
|
||||
Game over. Result: W+18.5
|
||||
Alpha-Beta (Black) vs MonteCarloSMAF (White) : W+18.5
|
||||
Game over. Result: W+33.5
|
||||
Alpha-Beta (Black) vs MonteCarloSMAF (White) : W+33.5
|
||||
Game over. Result: W+40.5
|
||||
Alpha-Beta (Black) vs MonteCarloSMAF (White) : W+40.5
|
||||
Game over. Result: W+1.5
|
||||
Alpha-Beta (Black) vs RootParallelization (White) : W+1.5
|
||||
Game over. Result: W+4.5
|
||||
Alpha-Beta (Black) vs RootParallelization (White) : W+4.5
|
||||
Game over. Result: W+0.5
|
||||
Alpha-Beta (Black) vs RootParallelization (White) : W+0.5
|
||||
Game over. Result: W+0.5
|
||||
Alpha-Beta (Black) vs RootParallelization (White) : W+0.5
|
||||
Game over. Result: W+35.5
|
||||
Alpha-Beta (Black) vs RootParallelization (White) : W+35.5
|
||||
Game over. Result: W+3.5
|
||||
Alpha-Beta (Black) vs RootParallelization-NeuralNet (White) : W+3.5
|
||||
Game over. Result: W+0.5
|
||||
Alpha-Beta (Black) vs RootParallelization-NeuralNet (White) : W+0.5
|
||||
Game over. Result: W+1.5
|
||||
Alpha-Beta (Black) vs RootParallelization-NeuralNet (White) : W+1.5
|
||||
Game over. Result: W+4.5
|
||||
Alpha-Beta (Black) vs RootParallelization-NeuralNet (White) : W+4.5
|
||||
Game over. Result: W+4.5
|
||||
Alpha-Beta (Black) vs RootParallelization-NeuralNet (White) : W+4.5
|
||||
|
||||
Tournament Win Rates
|
||||
====================
|
||||
Alpha-Beta (Black) vs Random (White) : 20%
|
||||
Alpha-Beta (Black) vs Alpha-Beta (White) : 100%
|
||||
Alpha-Beta (Black) vs MonteCarloUCT (White) : 00%
|
||||
Alpha-Beta (Black) vs UCT-RAVE (White) : 00%
|
||||
Alpha-Beta (Black) vs MonteCarloSMAF (White) : 00%
|
||||
Alpha-Beta (Black) vs RootParallelization (White) : 00%
|
||||
Alpha-Beta (Black) vs RootParallelization-NeuralNet (White) : 00%
|
||||
85
rrt/rrt.amaf.black.txt
Normal file
85
rrt/rrt.amaf.black.txt
Normal file
@@ -0,0 +1,85 @@
|
||||
|
||||
C:\workspace\msproj\dist>java -Xms256m -Xmx4096m -cp GoGame.jar;antlrworks-1.4.3.jar;kgsGtp.jar;log4j-1.2.16.jar net.woodyfolsom.msproj.RoundRobin
|
||||
Beginning round-robin tournament.
|
||||
Initializing policies...
|
||||
Game over. Result: B+40.5
|
||||
UCT-RAVE (Black) vs Random (White) : B+40.5
|
||||
Game over. Result: B+3.5
|
||||
UCT-RAVE (Black) vs Random (White) : B+3.5
|
||||
Game over. Result: B+15.5
|
||||
UCT-RAVE (Black) vs Random (White) : B+15.5
|
||||
Game over. Result: B+54.5
|
||||
UCT-RAVE (Black) vs Random (White) : B+54.5
|
||||
Game over. Result: B+18.5
|
||||
UCT-RAVE (Black) vs Random (White) : B+18.5
|
||||
Game over. Result: B+31.5
|
||||
UCT-RAVE (Black) vs Alpha-Beta (White) : B+31.5
|
||||
Game over. Result: B+17.5
|
||||
UCT-RAVE (Black) vs Alpha-Beta (White) : B+17.5
|
||||
Game over. Result: W+9.5
|
||||
UCT-RAVE (Black) vs Alpha-Beta (White) : W+9.5
|
||||
Game over. Result: B+34.5
|
||||
UCT-RAVE (Black) vs Alpha-Beta (White) : B+34.5
|
||||
Game over. Result: W+9.5
|
||||
UCT-RAVE (Black) vs Alpha-Beta (White) : W+9.5
|
||||
Game over. Result: B+2.5
|
||||
UCT-RAVE (Black) vs MonteCarloUCT (White) : B+2.5
|
||||
Game over. Result: B+36.5
|
||||
UCT-RAVE (Black) vs MonteCarloUCT (White) : B+36.5
|
||||
Game over. Result: B+9.5
|
||||
UCT-RAVE (Black) vs MonteCarloUCT (White) : B+9.5
|
||||
Game over. Result: W+2.5
|
||||
UCT-RAVE (Black) vs MonteCarloUCT (White) : W+2.5
|
||||
Game over. Result: B+1.5
|
||||
UCT-RAVE (Black) vs MonteCarloUCT (White) : B+1.5
|
||||
Game over. Result: B+22.5
|
||||
UCT-RAVE (Black) vs UCT-RAVE (White) : B+22.5
|
||||
Game over. Result: B+5.5
|
||||
UCT-RAVE (Black) vs UCT-RAVE (White) : B+5.5
|
||||
Game over. Result: B+2.5
|
||||
UCT-RAVE (Black) vs UCT-RAVE (White) : B+2.5
|
||||
Game over. Result: B+11.5
|
||||
UCT-RAVE (Black) vs UCT-RAVE (White) : B+11.5
|
||||
Game over. Result: W+11.5
|
||||
UCT-RAVE (Black) vs UCT-RAVE (White) : W+11.5
|
||||
Game over. Result: B+7.5
|
||||
UCT-RAVE (Black) vs MonteCarloSMAF (White) : B+7.5
|
||||
Game over. Result: B+39.5
|
||||
UCT-RAVE (Black) vs MonteCarloSMAF (White) : B+39.5
|
||||
Game over. Result: W+15.5
|
||||
UCT-RAVE (Black) vs MonteCarloSMAF (White) : W+15.5
|
||||
Game over. Result: W+22.5
|
||||
UCT-RAVE (Black) vs MonteCarloSMAF (White) : W+22.5
|
||||
Game over. Result: W+3.5
|
||||
UCT-RAVE (Black) vs MonteCarloSMAF (White) : W+3.5
|
||||
Game over. Result: B+20.5
|
||||
UCT-RAVE (Black) vs RootParallelization (White) : B+20.5
|
||||
Game over. Result: B+29.5
|
||||
UCT-RAVE (Black) vs RootParallelization (White) : B+29.5
|
||||
Game over. Result: B+41.5
|
||||
UCT-RAVE (Black) vs RootParallelization (White) : B+41.5
|
||||
Game over. Result: B+36.5
|
||||
UCT-RAVE (Black) vs RootParallelization (White) : B+36.5
|
||||
Game over. Result: B+18.5
|
||||
UCT-RAVE (Black) vs RootParallelization (White) : B+18.5
|
||||
Game over. Result: B+54.5
|
||||
UCT-RAVE (Black) vs RootParallelization-NeuralNet (White) : B+54.5
|
||||
Game over. Result: B+7.5
|
||||
UCT-RAVE (Black) vs RootParallelization-NeuralNet (White) : B+7.5
|
||||
Game over. Result: B+19.5
|
||||
UCT-RAVE (Black) vs RootParallelization-NeuralNet (White) : B+19.5
|
||||
Game over. Result: B+9.5
|
||||
UCT-RAVE (Black) vs RootParallelization-NeuralNet (White) : B+9.5
|
||||
Game over. Result: B+3.5
|
||||
UCT-RAVE (Black) vs RootParallelization-NeuralNet (White) : B+3.5
|
||||
|
||||
Tournament Win Rates
|
||||
====================
|
||||
UCT-RAVE (Black) vs Random (White) : 100%
|
||||
UCT-RAVE (Black) vs Alpha-Beta (White) : 60%
|
||||
UCT-RAVE (Black) vs MonteCarloUCT (White) : 80%
|
||||
UCT-RAVE (Black) vs UCT-RAVE (White) : 80%
|
||||
UCT-RAVE (Black) vs MonteCarloSMAF (White) : 40%
|
||||
UCT-RAVE (Black) vs RootParallelization (White) : 100%
|
||||
UCT-RAVE (Black) vs RootParallelization-NeuralNet (White) : 100%
|
||||
Tournament lasted 1476.893 seconds.
|
||||
86
rrt/rrt.random.black.txt
Normal file
86
rrt/rrt.random.black.txt
Normal file
@@ -0,0 +1,86 @@
|
||||
|
||||
C:\workspace\msproj\dist>java -cp GoGame.jar;antlrworks-1.4.3.jar;kgsGtp.jar;log4j-1.2.16.jar net.woodyfolsom.msproj.RoundRobin
|
||||
Beginning round-robin tournament.
|
||||
Initializing policies...
|
||||
Game over. Result: W+6.5
|
||||
Random (Black) vs Random (White) : W+6.5
|
||||
Game over. Result: B+1.5
|
||||
Random (Black) vs Random (White) : B+1.5
|
||||
Game over. Result: B+7.5
|
||||
Random (Black) vs Random (White) : B+7.5
|
||||
Game over. Result: W+0.5
|
||||
Random (Black) vs Random (White) : W+0.5
|
||||
Game over. Result: B+1.5
|
||||
Random (Black) vs Random (White) : B+1.5
|
||||
Game over. Result: B+28.5
|
||||
Random (Black) vs Alpha-Beta (White) : B+28.5
|
||||
Game over. Result: B+1.5
|
||||
Random (Black) vs Alpha-Beta (White) : B+1.5
|
||||
Game over. Result: B+29.5
|
||||
Random (Black) vs Alpha-Beta (White) : B+29.5
|
||||
Game over. Result: B+47.5
|
||||
Random (Black) vs Alpha-Beta (White) : B+47.5
|
||||
Game over. Result: B+22.5
|
||||
Random (Black) vs Alpha-Beta (White) : B+22.5
|
||||
Game over. Result: W+22.5
|
||||
Random (Black) vs MonteCarloUCT (White) : W+22.5
|
||||
Game over. Result: W+6.5
|
||||
Random (Black) vs MonteCarloUCT (White) : W+6.5
|
||||
Game over. Result: W+5.5
|
||||
Random (Black) vs MonteCarloUCT (White) : W+5.5
|
||||
Game over. Result: W+12.5
|
||||
Random (Black) vs MonteCarloUCT (White) : W+12.5
|
||||
Game over. Result: W+35.5
|
||||
Random (Black) vs MonteCarloUCT (White) : W+35.5
|
||||
Game over. Result: W+14.5
|
||||
Random (Black) vs UCT-RAVE (White) : W+14.5
|
||||
Game over. Result: W+18.5
|
||||
Random (Black) vs UCT-RAVE (White) : W+18.5
|
||||
Game over. Result: W+3.5
|
||||
Random (Black) vs UCT-RAVE (White) : W+3.5
|
||||
Game over. Result: W+5.5
|
||||
Random (Black) vs UCT-RAVE (White) : W+5.5
|
||||
Game over. Result: W+32.5
|
||||
Random (Black) vs UCT-RAVE (White) : W+32.5
|
||||
Game over. Result: W+19.5
|
||||
Random (Black) vs MonteCarloSMAF (White) : W+19.5
|
||||
Game over. Result: W+26.5
|
||||
Random (Black) vs MonteCarloSMAF (White) : W+26.5
|
||||
Game over. Result: W+19.5
|
||||
Random (Black) vs MonteCarloSMAF (White) : W+19.5
|
||||
Game over. Result: W+8.5
|
||||
Random (Black) vs MonteCarloSMAF (White) : W+8.5
|
||||
Game over. Result: W+13.5
|
||||
Random (Black) vs MonteCarloSMAF (White) : W+13.5
|
||||
Game over. Result: W+9.5
|
||||
Random (Black) vs RootParallelization (White) : W+9.5
|
||||
Game over. Result: W+4.5
|
||||
Random (Black) vs RootParallelization (White) : W+4.5
|
||||
Game over. Result: W+8.5
|
||||
Random (Black) vs RootParallelization (White) : W+8.5
|
||||
Game over. Result: W+39.5
|
||||
Random (Black) vs RootParallelization (White) : W+39.5
|
||||
Game over. Result: W+0.5
|
||||
Random (Black) vs RootParallelization (White) : W+0.5
|
||||
Game over. Result: W+10.5
|
||||
Random (Black) vs RootParallelization-NeuralNet (White) : W+10.5
|
||||
Game over. Result: W+11.5
|
||||
Random (Black) vs RootParallelization-NeuralNet (White) : W+11.5
|
||||
Game over. Result: W+1.5
|
||||
Random (Black) vs RootParallelization-NeuralNet (White) : W+1.5
|
||||
Game over. Result: W+3.5
|
||||
Random (Black) vs RootParallelization-NeuralNet (White) : W+3.5
|
||||
Game over. Result: W+10.5
|
||||
Random (Black) vs RootParallelization-NeuralNet (White) : W+10.5
|
||||
Game over. Result: W+40.5
|
||||
|
||||
|
||||
Tournament Win Rates
|
||||
====================
|
||||
Random (Black) vs Random (White) : 40%
|
||||
Random (Black) vs Alpha-Beta (White) : 100%
|
||||
Random (Black) vs MonteCarloUCT (White) : 00%
|
||||
Random (Black) vs UCT-RAVE (White) : 00%
|
||||
Random (Black) vs MonteCarloSMAF (White) : 00%
|
||||
Random (Black) vs RootParallelization (White) : 00%
|
||||
Random (Black) vs RootParallelization-NeuralNet (White) : 00%
|
||||
83
rrt/rrt.rave.black.txt
Normal file
83
rrt/rrt.rave.black.txt
Normal file
@@ -0,0 +1,83 @@
|
||||
Beginning round-robin tournament.
|
||||
Initializing policies...
|
||||
Game over. Result: B+8.5
|
||||
UCT-RAVE (Black) vs Random (White) : B+8.5
|
||||
Game over. Result: B+31.5
|
||||
UCT-RAVE (Black) vs Random (White) : B+31.5
|
||||
Game over. Result: B+16.5
|
||||
UCT-RAVE (Black) vs Random (White) : B+16.5
|
||||
Game over. Result: B+9.5
|
||||
UCT-RAVE (Black) vs Random (White) : B+9.5
|
||||
Game over. Result: B+16.5
|
||||
UCT-RAVE (Black) vs Random (White) : B+16.5
|
||||
Game over. Result: B+48.5
|
||||
UCT-RAVE (Black) vs Alpha-Beta (White) : B+48.5
|
||||
Game over. Result: W+5.5
|
||||
UCT-RAVE (Black) vs Alpha-Beta (White) : W+5.5
|
||||
Game over. Result: B+13.5
|
||||
UCT-RAVE (Black) vs Alpha-Beta (White) : B+13.5
|
||||
Game over. Result: B+34.5
|
||||
UCT-RAVE (Black) vs Alpha-Beta (White) : B+34.5
|
||||
Game over. Result: B+1.5
|
||||
UCT-RAVE (Black) vs Alpha-Beta (White) : B+1.5
|
||||
Game over. Result: B+2.5
|
||||
UCT-RAVE (Black) vs MonteCarloUCT (White) : B+2.5
|
||||
Game over. Result: B+7.5
|
||||
UCT-RAVE (Black) vs MonteCarloUCT (White) : B+7.5
|
||||
Game over. Result: W+4.5
|
||||
UCT-RAVE (Black) vs MonteCarloUCT (White) : W+4.5
|
||||
Game over. Result: B+3.5
|
||||
UCT-RAVE (Black) vs MonteCarloUCT (White) : B+3.5
|
||||
Game over. Result: B+6.5
|
||||
UCT-RAVE (Black) vs MonteCarloUCT (White) : B+6.5
|
||||
Game over. Result: B+3.5
|
||||
UCT-RAVE (Black) vs UCT-RAVE (White) : B+3.5
|
||||
Game over. Result: B+2.5
|
||||
UCT-RAVE (Black) vs UCT-RAVE (White) : B+2.5
|
||||
Game over. Result: B+9.5
|
||||
UCT-RAVE (Black) vs UCT-RAVE (White) : B+9.5
|
||||
Game over. Result: B+0.5
|
||||
UCT-RAVE (Black) vs UCT-RAVE (White) : B+0.5
|
||||
Game over. Result: W+13.5
|
||||
UCT-RAVE (Black) vs UCT-RAVE (White) : W+13.5
|
||||
Game over. Result: W+0.5
|
||||
UCT-RAVE (Black) vs MonteCarloSMAF (White) : W+0.5
|
||||
Game over. Result: B+1.5
|
||||
UCT-RAVE (Black) vs MonteCarloSMAF (White) : B+1.5
|
||||
Game over. Result: W+0.5
|
||||
UCT-RAVE (Black) vs MonteCarloSMAF (White) : W+0.5
|
||||
Game over. Result: B+9.5
|
||||
UCT-RAVE (Black) vs MonteCarloSMAF (White) : B+9.5
|
||||
Game over. Result: W+20.5
|
||||
UCT-RAVE (Black) vs MonteCarloSMAF (White) : W+20.5
|
||||
Game over. Result: B+13.5
|
||||
UCT-RAVE (Black) vs RootParallelization (White) : B+13.5
|
||||
Game over. Result: W+16.5
|
||||
UCT-RAVE (Black) vs RootParallelization (White) : W+16.5
|
||||
Game over. Result: B+28.5
|
||||
UCT-RAVE (Black) vs RootParallelization (White) : B+28.5
|
||||
Game over. Result: B+25.5
|
||||
UCT-RAVE (Black) vs RootParallelization (White) : B+25.5
|
||||
Game over. Result: B+25.5
|
||||
UCT-RAVE (Black) vs RootParallelization (White) : B+25.5
|
||||
Game over. Result: B+48.5
|
||||
UCT-RAVE (Black) vs RootParallelization-NeuralNet (White) : B+48.5
|
||||
Game over. Result: B+6.5
|
||||
UCT-RAVE (Black) vs RootParallelization-NeuralNet (White) : B+6.5
|
||||
Game over. Result: B+9.5
|
||||
UCT-RAVE (Black) vs RootParallelization-NeuralNet (White) : B+9.5
|
||||
Game over. Result: B+55.5
|
||||
UCT-RAVE (Black) vs RootParallelization-NeuralNet (White) : B+55.5
|
||||
Game over. Result: B+42.5
|
||||
UCT-RAVE (Black) vs RootParallelization-NeuralNet (White) : B+42.5
|
||||
|
||||
Tournament Win Rates
|
||||
====================
|
||||
UCT-RAVE (Black) vs Random (White) : 100%
|
||||
UCT-RAVE (Black) vs Alpha-Beta (White) : 80%
|
||||
UCT-RAVE (Black) vs MonteCarloUCT (White) : 80%
|
||||
UCT-RAVE (Black) vs UCT-RAVE (White) : 80%
|
||||
UCT-RAVE (Black) vs MonteCarloSMAF (White) : 40%
|
||||
UCT-RAVE (Black) vs RootParallelization (White) : 80%
|
||||
UCT-RAVE (Black) vs RootParallelization-NeuralNet (White) : 100%
|
||||
Tournament lasted 1458.494 seconds.
|
||||
85
rrt/rrt.rootpar-nn.black.txt
Normal file
85
rrt/rrt.rootpar-nn.black.txt
Normal file
@@ -0,0 +1,85 @@
|
||||
|
||||
C:\workspace\msproj\dist>java -Xms256m -Xmx4096m -cp GoGame.jar;antlrworks-1.4.3.jar;kgsGtp.jar;log4j-1.2.16.jar net.woodyfolsom.msproj.RoundRobin
|
||||
Beginning round-robin tournament.
|
||||
Initializing policies...
|
||||
Game over. Result: B+6.5
|
||||
RootParallelization-NeuralNet (Black) vs Random (White) : B+6.5
|
||||
Game over. Result: B+0.5
|
||||
RootParallelization-NeuralNet (Black) vs Random (White) : B+0.5
|
||||
Game over. Result: B+5.5
|
||||
RootParallelization-NeuralNet (Black) vs Random (White) : B+5.5
|
||||
Game over. Result: B+19.5
|
||||
RootParallelization-NeuralNet (Black) vs Random (White) : B+19.5
|
||||
Game over. Result: B+2.5
|
||||
RootParallelization-NeuralNet (Black) vs Random (White) : B+2.5
|
||||
Game over. Result: B+21.5
|
||||
RootParallelization-NeuralNet (Black) vs Alpha-Beta (White) : B+21.5
|
||||
Game over. Result: W+12.5
|
||||
RootParallelization-NeuralNet (Black) vs Alpha-Beta (White) : W+12.5
|
||||
Game over. Result: B+23.5
|
||||
RootParallelization-NeuralNet (Black) vs Alpha-Beta (White) : B+23.5
|
||||
Game over. Result: B+23.5
|
||||
RootParallelization-NeuralNet (Black) vs Alpha-Beta (White) : B+23.5
|
||||
Game over. Result: W+9.5
|
||||
RootParallelization-NeuralNet (Black) vs Alpha-Beta (White) : W+9.5
|
||||
Game over. Result: B+29.5
|
||||
RootParallelization-NeuralNet (Black) vs MonteCarloUCT (White) : B+29.5
|
||||
Game over. Result: B+9.5
|
||||
RootParallelization-NeuralNet (Black) vs MonteCarloUCT (White) : B+9.5
|
||||
Game over. Result: W+50.5
|
||||
RootParallelization-NeuralNet (Black) vs MonteCarloUCT (White) : W+50.5
|
||||
Game over. Result: B+9.5
|
||||
RootParallelization-NeuralNet (Black) vs MonteCarloUCT (White) : B+9.5
|
||||
Game over. Result: B+7.5
|
||||
RootParallelization-NeuralNet (Black) vs MonteCarloUCT (White) : B+7.5
|
||||
Game over. Result: W+12.5
|
||||
RootParallelization-NeuralNet (Black) vs UCT-RAVE (White) : W+12.5
|
||||
Game over. Result: W+9.5
|
||||
RootParallelization-NeuralNet (Black) vs UCT-RAVE (White) : W+9.5
|
||||
Game over. Result: W+29.5
|
||||
RootParallelization-NeuralNet (Black) vs UCT-RAVE (White) : W+29.5
|
||||
Game over. Result: W+10.5
|
||||
RootParallelization-NeuralNet (Black) vs UCT-RAVE (White) : W+10.5
|
||||
Game over. Result: W+27.5
|
||||
RootParallelization-NeuralNet (Black) vs UCT-RAVE (White) : W+27.5
|
||||
Game over. Result: W+2.5
|
||||
RootParallelization-NeuralNet (Black) vs MonteCarloSMAF (White) : W+2.5
|
||||
Game over. Result: W+22.5
|
||||
RootParallelization-NeuralNet (Black) vs MonteCarloSMAF (White) : W+22.5
|
||||
Game over. Result: W+10.5
|
||||
RootParallelization-NeuralNet (Black) vs MonteCarloSMAF (White) : W+10.5
|
||||
Game over. Result: W+41.5
|
||||
RootParallelization-NeuralNet (Black) vs MonteCarloSMAF (White) : W+41.5
|
||||
Game over. Result: W+18.5
|
||||
RootParallelization-NeuralNet (Black) vs MonteCarloSMAF (White) : W+18.5
|
||||
Game over. Result: B+3.5
|
||||
RootParallelization-NeuralNet (Black) vs RootParallelization (White) : B+3.5
|
||||
Game over. Result: W+10.5
|
||||
RootParallelization-NeuralNet (Black) vs RootParallelization (White) : W+10.5
|
||||
Game over. Result: W+14.5
|
||||
RootParallelization-NeuralNet (Black) vs RootParallelization (White) : W+14.5
|
||||
Game over. Result: W+5.5
|
||||
RootParallelization-NeuralNet (Black) vs RootParallelization (White) : W+5.5
|
||||
Game over. Result: W+6.5
|
||||
RootParallelization-NeuralNet (Black) vs RootParallelization (White) : W+6.5
|
||||
Game over. Result: W+8.5
|
||||
RootParallelization-NeuralNet (Black) vs RootParallelization-NeuralNet (White) : W+8.5
|
||||
Game over. Result: W+11.5
|
||||
RootParallelization-NeuralNet (Black) vs RootParallelization-NeuralNet (White) : W+11.5
|
||||
Game over. Result: W+6.5
|
||||
RootParallelization-NeuralNet (Black) vs RootParallelization-NeuralNet (White) : W+6.5
|
||||
Game over. Result: B+2.5
|
||||
RootParallelization-NeuralNet (Black) vs RootParallelization-NeuralNet (White) : B+2.5
|
||||
Game over. Result: B+21.5
|
||||
RootParallelization-NeuralNet (Black) vs RootParallelization-NeuralNet (White) : B+21.5
|
||||
|
||||
Tournament Win Rates
|
||||
====================
|
||||
RootParallelization-NeuralNet (Black) vs Random (White) : 100%
|
||||
RootParallelization-NeuralNet (Black) vs Alpha-Beta (White) : 60%
|
||||
RootParallelization-NeuralNet (Black) vs MonteCarloUCT (White) : 80%
|
||||
RootParallelization-NeuralNet (Black) vs UCT-RAVE (White) : 00%
|
||||
RootParallelization-NeuralNet (Black) vs MonteCarloSMAF (White) : 00%
|
||||
RootParallelization-NeuralNet (Black) vs RootParallelization (White) : 20%
|
||||
RootParallelization-NeuralNet (Black) vs RootParallelization-NeuralNet (White) : 40%
|
||||
Tournament lasted 1400.277 seconds.
|
||||
85
rrt/rrt.rootpar.black.txt
Normal file
85
rrt/rrt.rootpar.black.txt
Normal file
@@ -0,0 +1,85 @@
|
||||
|
||||
C:\workspace\msproj\dist>java -Xms256m -Xmx4096m -cp GoGame.jar;antlrworks-1.4.3.jar;kgsGtp.jar;log4j-1.2.16.jar net.woodyfolsom.msproj.RoundRobin
|
||||
Beginning round-robin tournament.
|
||||
Initializing policies...
|
||||
Game over. Result: B+4.5
|
||||
RootParallelization (Black) vs Random (White) : B+4.5
|
||||
Game over. Result: B+4.5
|
||||
RootParallelization (Black) vs Random (White) : B+4.5
|
||||
Game over. Result: B+1.5
|
||||
RootParallelization (Black) vs Random (White) : B+1.5
|
||||
Game over. Result: B+1.5
|
||||
RootParallelization (Black) vs Random (White) : B+1.5
|
||||
Game over. Result: B+0.5
|
||||
RootParallelization (Black) vs Random (White) : B+0.5
|
||||
Game over. Result: B+20.5
|
||||
RootParallelization (Black) vs Alpha-Beta (White) : B+20.5
|
||||
Game over. Result: B+23.5
|
||||
RootParallelization (Black) vs Alpha-Beta (White) : B+23.5
|
||||
Game over. Result: W+9.5
|
||||
RootParallelization (Black) vs Alpha-Beta (White) : W+9.5
|
||||
Game over. Result: W+7.5
|
||||
RootParallelization (Black) vs Alpha-Beta (White) : W+7.5
|
||||
Game over. Result: B+25.5
|
||||
RootParallelization (Black) vs Alpha-Beta (White) : B+25.5
|
||||
Game over. Result: B+0.5
|
||||
RootParallelization (Black) vs MonteCarloUCT (White) : B+0.5
|
||||
Game over. Result: B+11.5
|
||||
RootParallelization (Black) vs MonteCarloUCT (White) : B+11.5
|
||||
Game over. Result: W+0.5
|
||||
RootParallelization (Black) vs MonteCarloUCT (White) : W+0.5
|
||||
Game over. Result: B+1.5
|
||||
RootParallelization (Black) vs MonteCarloUCT (White) : B+1.5
|
||||
Game over. Result: B+0.5
|
||||
RootParallelization (Black) vs MonteCarloUCT (White) : B+0.5
|
||||
Game over. Result: W+22.5
|
||||
RootParallelization (Black) vs UCT-RAVE (White) : W+22.5
|
||||
Game over. Result: W+63.5
|
||||
RootParallelization (Black) vs UCT-RAVE (White) : W+63.5
|
||||
Game over. Result: W+29.5
|
||||
RootParallelization (Black) vs UCT-RAVE (White) : W+29.5
|
||||
Game over. Result: W+58.5
|
||||
RootParallelization (Black) vs UCT-RAVE (White) : W+58.5
|
||||
Game over. Result: W+30.5
|
||||
RootParallelization (Black) vs UCT-RAVE (White) : W+30.5
|
||||
Game over. Result: W+15.5
|
||||
RootParallelization (Black) vs MonteCarloSMAF (White) : W+15.5
|
||||
Game over. Result: W+62.5
|
||||
RootParallelization (Black) vs MonteCarloSMAF (White) : W+62.5
|
||||
Game over. Result: W+57.5
|
||||
RootParallelization (Black) vs MonteCarloSMAF (White) : W+57.5
|
||||
Game over. Result: W+57.5
|
||||
RootParallelization (Black) vs MonteCarloSMAF (White) : W+57.5
|
||||
Game over. Result: W+12.5
|
||||
RootParallelization (Black) vs MonteCarloSMAF (White) : W+12.5
|
||||
Game over. Result: B+2.5
|
||||
RootParallelization (Black) vs RootParallelization (White) : B+2.5
|
||||
Game over. Result: W+6.5
|
||||
RootParallelization (Black) vs RootParallelization (White) : W+6.5
|
||||
Game over. Result: B+2.5
|
||||
RootParallelization (Black) vs RootParallelization (White) : B+2.5
|
||||
Game over. Result: W+5.5
|
||||
RootParallelization (Black) vs RootParallelization (White) : W+5.5
|
||||
Game over. Result: B+2.5
|
||||
RootParallelization (Black) vs RootParallelization (White) : B+2.5
|
||||
Game over. Result: W+8.5
|
||||
RootParallelization (Black) vs RootParallelization-NeuralNet (White) : W+8.5
|
||||
Game over. Result: W+6.5
|
||||
RootParallelization (Black) vs RootParallelization-NeuralNet (White) : W+6.5
|
||||
Game over. Result: W+6.5
|
||||
RootParallelization (Black) vs RootParallelization-NeuralNet (White) : W+6.5
|
||||
Game over. Result: B+3.5
|
||||
RootParallelization (Black) vs RootParallelization-NeuralNet (White) : B+3.5
|
||||
Game over. Result: W+13.5
|
||||
RootParallelization (Black) vs RootParallelization-NeuralNet (White) : W+13.5
|
||||
|
||||
Tournament Win Rates
|
||||
====================
|
||||
RootParallelization (Black) vs Random (White) : 100%
|
||||
RootParallelization (Black) vs Alpha-Beta (White) : 60%
|
||||
RootParallelization (Black) vs MonteCarloUCT (White) : 80%
|
||||
RootParallelization (Black) vs UCT-RAVE (White) : 00%
|
||||
RootParallelization (Black) vs MonteCarloSMAF (White) : 00%
|
||||
RootParallelization (Black) vs RootParallelization (White) : 60%
|
||||
RootParallelization (Black) vs RootParallelization-NeuralNet (White) : 20%
|
||||
Tournament lasted 1367.523 seconds.
|
||||
85
rrt/rrt.smaf.black.txt
Normal file
85
rrt/rrt.smaf.black.txt
Normal file
@@ -0,0 +1,85 @@
|
||||
|
||||
C:\workspace\msproj\dist>java -Xms256m -Xmx4096m -cp GoGame.jar;antlrworks-1.4.3.jar;kgsGtp.jar;log4j-1.2.16.jar net.woodyfolsom.msproj.RoundRobin
|
||||
Beginning round-robin tournament.
|
||||
Initializing policies...
|
||||
Game over. Result: B+8.5
|
||||
UCT-RAVE (Black) vs Random (White) : B+8.5
|
||||
Game over. Result: B+31.5
|
||||
UCT-RAVE (Black) vs Random (White) : B+31.5
|
||||
Game over. Result: B+16.5
|
||||
UCT-RAVE (Black) vs Random (White) : B+16.5
|
||||
Game over. Result: B+9.5
|
||||
UCT-RAVE (Black) vs Random (White) : B+9.5
|
||||
Game over. Result: B+16.5
|
||||
UCT-RAVE (Black) vs Random (White) : B+16.5
|
||||
Game over. Result: B+48.5
|
||||
UCT-RAVE (Black) vs Alpha-Beta (White) : B+48.5
|
||||
Game over. Result: W+5.5
|
||||
UCT-RAVE (Black) vs Alpha-Beta (White) : W+5.5
|
||||
Game over. Result: B+13.5
|
||||
UCT-RAVE (Black) vs Alpha-Beta (White) : B+13.5
|
||||
Game over. Result: B+34.5
|
||||
UCT-RAVE (Black) vs Alpha-Beta (White) : B+34.5
|
||||
Game over. Result: B+1.5
|
||||
UCT-RAVE (Black) vs Alpha-Beta (White) : B+1.5
|
||||
Game over. Result: B+2.5
|
||||
UCT-RAVE (Black) vs MonteCarloUCT (White) : B+2.5
|
||||
Game over. Result: B+7.5
|
||||
UCT-RAVE (Black) vs MonteCarloUCT (White) : B+7.5
|
||||
Game over. Result: W+4.5
|
||||
UCT-RAVE (Black) vs MonteCarloUCT (White) : W+4.5
|
||||
Game over. Result: B+3.5
|
||||
UCT-RAVE (Black) vs MonteCarloUCT (White) : B+3.5
|
||||
Game over. Result: B+6.5
|
||||
UCT-RAVE (Black) vs MonteCarloUCT (White) : B+6.5
|
||||
Game over. Result: B+3.5
|
||||
UCT-RAVE (Black) vs UCT-RAVE (White) : B+3.5
|
||||
Game over. Result: B+2.5
|
||||
UCT-RAVE (Black) vs UCT-RAVE (White) : B+2.5
|
||||
Game over. Result: B+9.5
|
||||
UCT-RAVE (Black) vs UCT-RAVE (White) : B+9.5
|
||||
Game over. Result: B+0.5
|
||||
UCT-RAVE (Black) vs UCT-RAVE (White) : B+0.5
|
||||
Game over. Result: W+13.5
|
||||
UCT-RAVE (Black) vs UCT-RAVE (White) : W+13.5
|
||||
Game over. Result: W+0.5
|
||||
UCT-RAVE (Black) vs MonteCarloSMAF (White) : W+0.5
|
||||
Game over. Result: B+1.5
|
||||
UCT-RAVE (Black) vs MonteCarloSMAF (White) : B+1.5
|
||||
Game over. Result: W+0.5
|
||||
UCT-RAVE (Black) vs MonteCarloSMAF (White) : W+0.5
|
||||
Game over. Result: B+9.5
|
||||
UCT-RAVE (Black) vs MonteCarloSMAF (White) : B+9.5
|
||||
Game over. Result: W+20.5
|
||||
UCT-RAVE (Black) vs MonteCarloSMAF (White) : W+20.5
|
||||
Game over. Result: B+13.5
|
||||
UCT-RAVE (Black) vs RootParallelization (White) : B+13.5
|
||||
Game over. Result: W+16.5
|
||||
UCT-RAVE (Black) vs RootParallelization (White) : W+16.5
|
||||
Game over. Result: B+28.5
|
||||
UCT-RAVE (Black) vs RootParallelization (White) : B+28.5
|
||||
Game over. Result: B+25.5
|
||||
UCT-RAVE (Black) vs RootParallelization (White) : B+25.5
|
||||
Game over. Result: B+25.5
|
||||
UCT-RAVE (Black) vs RootParallelization (White) : B+25.5
|
||||
Game over. Result: B+48.5
|
||||
UCT-RAVE (Black) vs RootParallelization-NeuralNet (White) : B+48.5
|
||||
Game over. Result: B+6.5
|
||||
UCT-RAVE (Black) vs RootParallelization-NeuralNet (White) : B+6.5
|
||||
Game over. Result: B+9.5
|
||||
UCT-RAVE (Black) vs RootParallelization-NeuralNet (White) : B+9.5
|
||||
Game over. Result: B+55.5
|
||||
UCT-RAVE (Black) vs RootParallelization-NeuralNet (White) : B+55.5
|
||||
Game over. Result: B+42.5
|
||||
UCT-RAVE (Black) vs RootParallelization-NeuralNet (White) : B+42.5
|
||||
|
||||
Tournament Win Rates
|
||||
====================
|
||||
UCT-RAVE (Black) vs Random (White) : 100%
|
||||
UCT-RAVE (Black) vs Alpha-Beta (White) : 80%
|
||||
UCT-RAVE (Black) vs MonteCarloUCT (White) : 80%
|
||||
UCT-RAVE (Black) vs UCT-RAVE (White) : 80%
|
||||
UCT-RAVE (Black) vs MonteCarloSMAF (White) : 40%
|
||||
UCT-RAVE (Black) vs RootParallelization (White) : 80%
|
||||
UCT-RAVE (Black) vs RootParallelization-NeuralNet (White) : 100%
|
||||
Tournament lasted 1458.494 seconds.
|
||||
85
rrt/rrt.uct.black
Normal file
85
rrt/rrt.uct.black
Normal file
@@ -0,0 +1,85 @@
|
||||
|
||||
C:\workspace\msproj\dist>java -Xms256m -Xmx4096m -cp GoGame.jar;antlrworks-1.4.3.jar;kgsGtp.jar;log4j-1.2.16.jar net.woodyfolsom.msproj.RoundRobin
|
||||
Beginning round-robin tournament.
|
||||
Initializing policies...
|
||||
Game over. Result: B+0.5
|
||||
MonteCarloUCT (Black) vs Random (White) : B+0.5
|
||||
Game over. Result: B+9.5
|
||||
MonteCarloUCT (Black) vs Random (White) : B+9.5
|
||||
Game over. Result: B+24.5
|
||||
MonteCarloUCT (Black) vs Random (White) : B+24.5
|
||||
Game over. Result: B+10.5
|
||||
MonteCarloUCT (Black) vs Random (White) : B+10.5
|
||||
Game over. Result: B+4.5
|
||||
MonteCarloUCT (Black) vs Random (White) : B+4.5
|
||||
Game over. Result: B+15.5
|
||||
MonteCarloUCT (Black) vs Alpha-Beta (White) : B+15.5
|
||||
Game over. Result: B+22.5
|
||||
MonteCarloUCT (Black) vs Alpha-Beta (White) : B+22.5
|
||||
Game over. Result: B+32.5
|
||||
MonteCarloUCT (Black) vs Alpha-Beta (White) : B+32.5
|
||||
Game over. Result: W+12.5
|
||||
MonteCarloUCT (Black) vs Alpha-Beta (White) : W+12.5
|
||||
Game over. Result: B+23.5
|
||||
MonteCarloUCT (Black) vs Alpha-Beta (White) : B+23.5
|
||||
Game over. Result: B+0.5
|
||||
MonteCarloUCT (Black) vs MonteCarloUCT (White) : B+0.5
|
||||
Game over. Result: W+13.5
|
||||
MonteCarloUCT (Black) vs MonteCarloUCT (White) : W+13.5
|
||||
Game over. Result: W+11.5
|
||||
MonteCarloUCT (Black) vs MonteCarloUCT (White) : W+11.5
|
||||
Game over. Result: W+8.5
|
||||
MonteCarloUCT (Black) vs MonteCarloUCT (White) : W+8.5
|
||||
Game over. Result: W+9.5
|
||||
MonteCarloUCT (Black) vs MonteCarloUCT (White) : W+9.5
|
||||
Game over. Result: W+9.5
|
||||
MonteCarloUCT (Black) vs UCT-RAVE (White) : W+9.5
|
||||
Game over. Result: W+16.5
|
||||
MonteCarloUCT (Black) vs UCT-RAVE (White) : W+16.5
|
||||
Game over. Result: W+8.5
|
||||
MonteCarloUCT (Black) vs UCT-RAVE (White) : W+8.5
|
||||
Game over. Result: W+11.5
|
||||
MonteCarloUCT (Black) vs UCT-RAVE (White) : W+11.5
|
||||
Game over. Result: W+5.5
|
||||
MonteCarloUCT (Black) vs UCT-RAVE (White) : W+5.5
|
||||
Game over. Result: W+8.5
|
||||
MonteCarloUCT (Black) vs MonteCarloSMAF (White) : W+8.5
|
||||
Game over. Result: W+9.5
|
||||
MonteCarloUCT (Black) vs MonteCarloSMAF (White) : W+9.5
|
||||
Game over. Result: W+15.5
|
||||
MonteCarloUCT (Black) vs MonteCarloSMAF (White) : W+15.5
|
||||
Game over. Result: W+14.5
|
||||
MonteCarloUCT (Black) vs MonteCarloSMAF (White) : W+14.5
|
||||
Game over. Result: W+13.5
|
||||
MonteCarloUCT (Black) vs MonteCarloSMAF (White) : W+13.5
|
||||
Game over. Result: W+15.5
|
||||
MonteCarloUCT (Black) vs RootParallelization (White) : W+15.5
|
||||
Game over. Result: W+14.5
|
||||
MonteCarloUCT (Black) vs RootParallelization (White) : W+14.5
|
||||
Game over. Result: W+6.5
|
||||
MonteCarloUCT (Black) vs RootParallelization (White) : W+6.5
|
||||
Game over. Result: W+6.5
|
||||
MonteCarloUCT (Black) vs RootParallelization (White) : W+6.5
|
||||
Game over. Result: W+11.5
|
||||
MonteCarloUCT (Black) vs RootParallelization (White) : W+11.5
|
||||
Game over. Result: W+26.5
|
||||
MonteCarloUCT (Black) vs RootParallelization-NeuralNet (White) : W+26.5
|
||||
Game over. Result: W+11.5
|
||||
MonteCarloUCT (Black) vs RootParallelization-NeuralNet (White) : W+11.5
|
||||
Game over. Result: W+47.5
|
||||
MonteCarloUCT (Black) vs RootParallelization-NeuralNet (White) : W+47.5
|
||||
Game over. Result: W+13.5
|
||||
MonteCarloUCT (Black) vs RootParallelization-NeuralNet (White) : W+13.5
|
||||
Game over. Result: B+33.5
|
||||
MonteCarloUCT (Black) vs RootParallelization-NeuralNet (White) : B+33.5
|
||||
|
||||
Tournament Win Rates
|
||||
====================
|
||||
MonteCarloUCT (Black) vs Random (White) : 100%
|
||||
MonteCarloUCT (Black) vs Alpha-Beta (White) : 80%
|
||||
MonteCarloUCT (Black) vs MonteCarloUCT (White) : 20%
|
||||
MonteCarloUCT (Black) vs UCT-RAVE (White) : 00%
|
||||
MonteCarloUCT (Black) vs MonteCarloSMAF (White) : 00%
|
||||
MonteCarloUCT (Black) vs RootParallelization (White) : 00%
|
||||
MonteCarloUCT (Black) vs RootParallelization-NeuralNet (White) : 20%
|
||||
Tournament lasted 1355.668 seconds.
|
||||
@@ -59,6 +59,14 @@ public class GameRecord {
|
||||
return gameStates.get(0).getGameConfig();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the game state for the most recent ply.
|
||||
* @return
|
||||
*/
|
||||
public GameState getGameState() {
|
||||
return gameStates.get(getNumTurns());
|
||||
}
|
||||
|
||||
public GameState getGameState(Integer turn) {
|
||||
return gameStates.get(turn);
|
||||
}
|
||||
|
||||
@@ -119,6 +119,14 @@ public class GameState {
|
||||
return whitePrisoners;
|
||||
}
|
||||
|
||||
public boolean isPrevPlyPass() {
|
||||
if (moveHistory.size() == 0) {
|
||||
return false;
|
||||
} else {
|
||||
return moveHistory.get(moveHistory.size()-1).isPass();
|
||||
}
|
||||
}
|
||||
|
||||
public boolean isSelfFill(Action action, Player player) {
|
||||
return gameBoard.isSelfFill(action, player);
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import net.woodyfolsom.msproj.policy.Minimax;
|
||||
import net.woodyfolsom.msproj.policy.MonteCarloUCT;
|
||||
import net.woodyfolsom.msproj.policy.Policy;
|
||||
import net.woodyfolsom.msproj.policy.RandomMovePolicy;
|
||||
import net.woodyfolsom.msproj.policy.RootParAMAF;
|
||||
|
||||
import org.apache.log4j.Logger;
|
||||
import org.apache.log4j.xml.DOMConfigurator;
|
||||
@@ -80,10 +81,11 @@ public class GoGame implements Runnable {
|
||||
public static void main(String[] args) throws IOException {
|
||||
configureLogging();
|
||||
if (args.length == 0) {
|
||||
Policy defaultMoveGenerator = new MonteCarloUCT(new RandomMovePolicy(), 5000L);
|
||||
LOGGER.info("No MoveGenerator specified. Using default: " + defaultMoveGenerator.toString());
|
||||
Policy policy = new RootParAMAF(4, 10000L);
|
||||
policy.setLogging(true);
|
||||
LOGGER.info("No MoveGenerator specified. Using default: " + policy.getName());
|
||||
|
||||
GoGame goGame = new GoGame(defaultMoveGenerator, PROPS_FILE);
|
||||
GoGame goGame = new GoGame(policy, PROPS_FILE);
|
||||
new Thread(goGame).start();
|
||||
|
||||
System.out.println("Creating GtpClient");
|
||||
@@ -111,7 +113,9 @@ public class GoGame implements Runnable {
|
||||
} else if ("alphabeta".equals(policyName)) {
|
||||
return new AlphaBeta();
|
||||
} else if ("montecarlo".equals(policyName)) {
|
||||
return new MonteCarloUCT(new RandomMovePolicy(), 5000L);
|
||||
return new MonteCarloUCT(new RandomMovePolicy(), 10000L);
|
||||
} else if ("root_par_amaf".equals(policyName)) {
|
||||
return new RootParAMAF(4, 10000L);
|
||||
} else {
|
||||
LOGGER.info("Unable to create Policy for unsupported name: " + policyName);
|
||||
System.exit(INVALID_MOVE_GENERATOR);
|
||||
|
||||
@@ -91,8 +91,7 @@ public class Referee {
|
||||
while (!gameRecord.isFinished()) {
|
||||
GameState gameState = gameRecord.getGameState(gameRecord
|
||||
.getNumTurns());
|
||||
// System.out.println(gameState);
|
||||
|
||||
|
||||
Player playerToMove = gameRecord.getPlayerToMove();
|
||||
Policy policy = getPolicy(playerToMove);
|
||||
Action action = policy.getAction(gameConfig, gameState,
|
||||
@@ -108,6 +107,11 @@ public class Referee {
|
||||
} else {
|
||||
System.out.println("Move rejected - try again.");
|
||||
}
|
||||
|
||||
if (policy.isLogging()) {
|
||||
System.out.println(gameState);
|
||||
}
|
||||
|
||||
}
|
||||
} catch (Exception ex) {
|
||||
System.out
|
||||
|
||||
115
src/net/woodyfolsom/msproj/RoundRobin.java
Normal file
115
src/net/woodyfolsom/msproj/RoundRobin.java
Normal file
@@ -0,0 +1,115 @@
|
||||
package net.woodyfolsom.msproj;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.text.DecimalFormat;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import net.woodyfolsom.msproj.policy.AlphaBeta;
|
||||
import net.woodyfolsom.msproj.policy.MonteCarloAMAF;
|
||||
import net.woodyfolsom.msproj.policy.MonteCarloSMAF;
|
||||
import net.woodyfolsom.msproj.policy.MonteCarloUCT;
|
||||
import net.woodyfolsom.msproj.policy.NeuralNetPolicy;
|
||||
import net.woodyfolsom.msproj.policy.Policy;
|
||||
import net.woodyfolsom.msproj.policy.RandomMovePolicy;
|
||||
import net.woodyfolsom.msproj.policy.RootParAMAF;
|
||||
import net.woodyfolsom.msproj.policy.RootParallelization;
|
||||
|
||||
public class RoundRobin {
|
||||
public static final int EXIT_USER_QUIT = 1;
|
||||
public static final int EXIT_NOMINAL = 0;
|
||||
public static final int EXIT_IO_EXCEPTION = -1;
|
||||
|
||||
public static void main(String[] args) throws IOException {
|
||||
long startTime = System.currentTimeMillis();
|
||||
|
||||
System.out.println("Beginning round-robin tournament.");
|
||||
System.out.println("Initializing policies...");
|
||||
List<Policy> policies = new ArrayList<Policy>();
|
||||
|
||||
policies.add(new RandomMovePolicy());
|
||||
//policies.add(new Minimax(1));
|
||||
policies.add(new AlphaBeta(1));
|
||||
policies.add(new MonteCarloUCT(new RandomMovePolicy(), 500L));
|
||||
policies.add(new MonteCarloAMAF(new RandomMovePolicy(), 500L));
|
||||
policies.add(new MonteCarloSMAF(new RandomMovePolicy(), 500L, 4));
|
||||
policies.add(new RootParAMAF(4, 500L));
|
||||
policies.add(new RootParallelization(4, new NeuralNetPolicy(), 500L));
|
||||
|
||||
RoundRobin rr = new RoundRobin();
|
||||
|
||||
List<List<Double>> tourneyWinRates = new ArrayList<List<Double>>();
|
||||
|
||||
int gamesPerMatch = 5;
|
||||
|
||||
for (int i = 0; i < policies.size(); i++) {
|
||||
List<Double> roundWinRates = new ArrayList<Double>();
|
||||
if (i != 5) {
|
||||
tourneyWinRates.add(roundWinRates);
|
||||
continue;
|
||||
}
|
||||
for (int j = 0; j < policies.size(); j++) {
|
||||
Policy policy1 = policies.get(i);
|
||||
policy1.setLogging(false);
|
||||
Policy policy2 = policies.get(j);
|
||||
policy2.setLogging(false);
|
||||
|
||||
List<GameResult> gameResults = rr.playGame(policy1, policy2, 9, 6.5, gamesPerMatch, false, false, false);
|
||||
|
||||
double wins = 0.0;
|
||||
double games = 0.0;
|
||||
for(GameResult gr : gameResults) {
|
||||
wins += gr.isWinner(Player.BLACK) ? 1.0 : 0.0;
|
||||
games += 1.0;
|
||||
}
|
||||
roundWinRates.add(100.0 * wins / games);
|
||||
}
|
||||
tourneyWinRates.add(roundWinRates);
|
||||
}
|
||||
|
||||
System.out.println("");
|
||||
System.out.println("Tournament Win Rates");
|
||||
System.out.println("====================");
|
||||
|
||||
DecimalFormat df = new DecimalFormat("00.#");
|
||||
for (int i = 0; i < policies.size(); i++) {
|
||||
for (int j = 0; j < policies.size(); j++) {
|
||||
if (i == 5)
|
||||
System.out.println(policies.get(i).getName() + " (Black) vs " + policies.get(j).getName() + " (White) : " + df.format(tourneyWinRates.get(i).get(j)) + "%");
|
||||
}
|
||||
}
|
||||
|
||||
long endTime = System.currentTimeMillis();
|
||||
System.out.println("Tournament lasted " + (endTime-startTime)/1000.0 + " seconds.");
|
||||
}
|
||||
|
||||
public List<GameResult> playGame(Policy player1Policy, Policy player2Policy, int size,
|
||||
double komi, int rounds, boolean showSpectatorBoard,
|
||||
boolean blackMoveLogged, boolean whiteMoveLogged) {
|
||||
|
||||
GameConfig gameConfig = new GameConfig(size);
|
||||
gameConfig.setKomi(komi);
|
||||
|
||||
Referee referee = new Referee();
|
||||
referee.setPolicy(Player.BLACK, player1Policy);
|
||||
referee.setPolicy(Player.WHITE, player2Policy);
|
||||
|
||||
List<GameResult> roundResults = new ArrayList<GameResult>();
|
||||
|
||||
boolean logGameRecords = false;
|
||||
|
||||
int gameNo = 1;
|
||||
|
||||
for (int round = 0; round < rounds; round++) {
|
||||
gameNo++;
|
||||
GameResult gameResult = referee.play(gameConfig, gameNo,
|
||||
showSpectatorBoard, logGameRecords);
|
||||
roundResults.add(gameResult);
|
||||
|
||||
System.out.println(player1Policy.getName() + " (Black) vs "
|
||||
+ player2Policy.getName() + " (White) : " + gameResult);
|
||||
roundResults.add(gameResult);
|
||||
}
|
||||
return roundResults;
|
||||
}
|
||||
}
|
||||
@@ -13,26 +13,28 @@ import net.woodyfolsom.msproj.gui.Goban;
|
||||
import net.woodyfolsom.msproj.policy.HumanGuiInput;
|
||||
import net.woodyfolsom.msproj.policy.HumanKeyboardInput;
|
||||
import net.woodyfolsom.msproj.policy.MonteCarloAMAF;
|
||||
import net.woodyfolsom.msproj.policy.MonteCarloSMAF;
|
||||
import net.woodyfolsom.msproj.policy.MonteCarloUCT;
|
||||
import net.woodyfolsom.msproj.policy.Policy;
|
||||
import net.woodyfolsom.msproj.policy.RandomMovePolicy;
|
||||
import net.woodyfolsom.msproj.policy.RootParAMAF;
|
||||
import net.woodyfolsom.msproj.policy.RootParallelization;
|
||||
|
||||
public class StandAloneGame {
|
||||
public static final int EXIT_USER_QUIT = 1;
|
||||
public static final int EXIT_NOMINAL = 0;
|
||||
public static final int EXIT_IO_EXCEPTION = -1;
|
||||
|
||||
|
||||
private int gameNo = 0;
|
||||
|
||||
|
||||
enum PLAYER_TYPE {
|
||||
HUMAN, HUMAN_GUI, ROOT_PAR, UCT, RANDOM, RAVE
|
||||
HUMAN, HUMAN_GUI, ROOT_PAR, UCT, RANDOM, RAVE, SMAF, ROOT_PAR_AMAF
|
||||
};
|
||||
|
||||
public static void main(String[] args) {
|
||||
public static void main(String[] args) throws IOException {
|
||||
try {
|
||||
GameSettings gameSettings = GameSettings
|
||||
.createGameSetings("data/gogame.cfg");
|
||||
.createGameSetings("gogame.cfg");
|
||||
System.out.println("Game Settings: " + gameSettings);
|
||||
System.out.println("Successfully parsed game settings.");
|
||||
new StandAloneGame().playGame(
|
||||
@@ -41,7 +43,10 @@ public class StandAloneGame {
|
||||
gameSettings.getBoardSize(), gameSettings.getKomi(),
|
||||
gameSettings.getNumGames(), gameSettings.getTurnTime(),
|
||||
gameSettings.isSpectatorBoardShown(),
|
||||
gameSettings.isBlackMoveLogged(), gameSettings.isWhiteMoveLogged());
|
||||
gameSettings.isBlackMoveLogged(),
|
||||
gameSettings.isWhiteMoveLogged());
|
||||
System.out.println("Press <Enter> or CTRL-C to exit");
|
||||
System.in.read(new byte[80]);
|
||||
} catch (IOException ioe) {
|
||||
ioe.printStackTrace();
|
||||
System.exit(EXIT_IO_EXCEPTION);
|
||||
@@ -62,14 +67,19 @@ public class StandAloneGame {
|
||||
return PLAYER_TYPE.RANDOM;
|
||||
} else if ("RAVE".equalsIgnoreCase(playerTypeStr)) {
|
||||
return PLAYER_TYPE.RAVE;
|
||||
} else if ("SMAF".equalsIgnoreCase(playerTypeStr)) {
|
||||
return PLAYER_TYPE.SMAF;
|
||||
} else if ("ROOT_PAR_AMAF".equalsIgnoreCase(playerTypeStr)) {
|
||||
return PLAYER_TYPE.ROOT_PAR_AMAF;
|
||||
} else {
|
||||
throw new RuntimeException("Unknown player type: " + playerTypeStr);
|
||||
}
|
||||
}
|
||||
|
||||
public void playGame(PLAYER_TYPE playerType1, PLAYER_TYPE playerType2,
|
||||
int size, double komi, int rounds, long turnLength, boolean showSpectatorBoard,
|
||||
boolean blackMoveLogged, boolean whiteMoveLogged) {
|
||||
int size, double komi, int rounds, long turnLength,
|
||||
boolean showSpectatorBoard, boolean blackMoveLogged,
|
||||
boolean whiteMoveLogged) {
|
||||
|
||||
long startTime = System.currentTimeMillis();
|
||||
|
||||
@@ -77,28 +87,38 @@ public class StandAloneGame {
|
||||
gameConfig.setKomi(komi);
|
||||
|
||||
Referee referee = new Referee();
|
||||
referee.setPolicy(Player.BLACK,
|
||||
getPolicy(playerType1, gameConfig, Player.BLACK, turnLength, blackMoveLogged));
|
||||
referee.setPolicy(Player.WHITE,
|
||||
getPolicy(playerType2, gameConfig, Player.WHITE, turnLength, whiteMoveLogged));
|
||||
referee.setPolicy(
|
||||
Player.BLACK,
|
||||
getPolicy(playerType1, gameConfig, Player.BLACK, turnLength,
|
||||
blackMoveLogged));
|
||||
referee.setPolicy(
|
||||
Player.WHITE,
|
||||
getPolicy(playerType2, gameConfig, Player.WHITE, turnLength,
|
||||
whiteMoveLogged));
|
||||
|
||||
List<GameResult> round1results = new ArrayList<GameResult>();
|
||||
|
||||
|
||||
boolean logGameRecords = rounds <= 50;
|
||||
for (int round = 0; round < rounds; round++) {
|
||||
gameNo++;
|
||||
round1results.add(referee.play(gameConfig, gameNo, showSpectatorBoard, logGameRecords));
|
||||
round1results.add(referee.play(gameConfig, gameNo,
|
||||
showSpectatorBoard, logGameRecords));
|
||||
}
|
||||
|
||||
List<GameResult> round2results = new ArrayList<GameResult>();
|
||||
|
||||
referee.setPolicy(Player.BLACK,
|
||||
getPolicy(playerType2, gameConfig, Player.BLACK, turnLength, blackMoveLogged));
|
||||
referee.setPolicy(Player.WHITE,
|
||||
getPolicy(playerType1, gameConfig, Player.WHITE, turnLength, whiteMoveLogged));
|
||||
referee.setPolicy(
|
||||
Player.BLACK,
|
||||
getPolicy(playerType2, gameConfig, Player.BLACK, turnLength,
|
||||
blackMoveLogged));
|
||||
referee.setPolicy(
|
||||
Player.WHITE,
|
||||
getPolicy(playerType1, gameConfig, Player.WHITE, turnLength,
|
||||
whiteMoveLogged));
|
||||
for (int round = 0; round < rounds; round++) {
|
||||
gameNo++;
|
||||
round2results.add(referee.play(gameConfig, gameNo, showSpectatorBoard, logGameRecords));
|
||||
round2results.add(referee.play(gameConfig, gameNo,
|
||||
showSpectatorBoard, logGameRecords));
|
||||
}
|
||||
|
||||
long endTime = System.currentTimeMillis();
|
||||
@@ -111,13 +131,14 @@ public class StandAloneGame {
|
||||
|
||||
try {
|
||||
if (!logGameRecords) {
|
||||
System.out.println("Each player is set to play more than 50 rounds as each color; omitting individual game .sgf log file output.");
|
||||
System.out
|
||||
.println("Each player is set to play more than 50 rounds as each color; omitting individual game .sgf log file output.");
|
||||
}
|
||||
|
||||
|
||||
logResults(writer, round1results, playerType1.toString(),
|
||||
playerType2.toString());
|
||||
playerType2.toString());
|
||||
logResults(writer, round2results, playerType2.toString(),
|
||||
playerType1.toString());
|
||||
playerType1.toString());
|
||||
writer.write("Elapsed Time: " + (endTime - startTime) / 1000.0
|
||||
+ " seconds.");
|
||||
System.out.println("Game tournament saved as "
|
||||
@@ -155,25 +176,41 @@ public class StandAloneGame {
|
||||
|
||||
private Policy getPolicy(PLAYER_TYPE playerType, GameConfig gameConfig,
|
||||
Player player, long turnLength, boolean moveLogged) {
|
||||
|
||||
Policy policy;
|
||||
|
||||
switch (playerType) {
|
||||
case HUMAN:
|
||||
return new HumanKeyboardInput();
|
||||
policy = new HumanKeyboardInput();
|
||||
break;
|
||||
case HUMAN_GUI:
|
||||
return new HumanGuiInput(new Goban(gameConfig, player,""));
|
||||
policy = new HumanGuiInput(new Goban(gameConfig, player, ""));
|
||||
break;
|
||||
case ROOT_PAR:
|
||||
return new RootParallelization(4, turnLength);
|
||||
policy = new RootParallelization(4, turnLength);
|
||||
break;
|
||||
case ROOT_PAR_AMAF:
|
||||
policy = new RootParAMAF(4, turnLength);
|
||||
break;
|
||||
case UCT:
|
||||
return new MonteCarloUCT(new RandomMovePolicy(), turnLength);
|
||||
policy = new MonteCarloUCT(new RandomMovePolicy(), turnLength);
|
||||
break;
|
||||
case SMAF:
|
||||
policy = new MonteCarloSMAF(new RandomMovePolicy(), turnLength, 4);
|
||||
break;
|
||||
case RANDOM:
|
||||
RandomMovePolicy randomMovePolicy = new RandomMovePolicy();
|
||||
randomMovePolicy.setLogging(moveLogged);
|
||||
return randomMovePolicy;
|
||||
policy = new RandomMovePolicy();
|
||||
break;
|
||||
case RAVE:
|
||||
return new MonteCarloAMAF(new RandomMovePolicy(), turnLength);
|
||||
policy = new MonteCarloAMAF(new RandomMovePolicy(), turnLength);
|
||||
break;
|
||||
default:
|
||||
throw new IllegalArgumentException("Invalid PLAYER_TYPE: "
|
||||
+ playerType);
|
||||
}
|
||||
|
||||
policy.setLogging(moveLogged);
|
||||
return policy;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,54 +1,110 @@
|
||||
package net.woodyfolsom.msproj.ann;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
|
||||
import org.encog.neural.networks.BasicNetwork;
|
||||
import org.encog.neural.networks.PersistBasicNetwork;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
import java.util.List;
|
||||
|
||||
public abstract class AbstractNeuralNetFilter implements NeuralNetFilter {
|
||||
protected BasicNetwork neuralNetwork;
|
||||
protected int actualTrainingEpochs = 0;
|
||||
protected int maxTrainingEpochs = 1000;
|
||||
private final FeedforwardNetwork neuralNetwork;
|
||||
private final TrainingMethod trainingMethod;
|
||||
|
||||
private double maxError;
|
||||
private int actualTrainingEpochs = 0;
|
||||
private int maxTrainingEpochs;
|
||||
|
||||
AbstractNeuralNetFilter(FeedforwardNetwork neuralNetwork, TrainingMethod trainingMethod, int maxTrainingEpochs, double maxError) {
|
||||
this.neuralNetwork = neuralNetwork;
|
||||
this.trainingMethod = trainingMethod;
|
||||
this.maxError = maxError;
|
||||
this.maxTrainingEpochs = maxTrainingEpochs;
|
||||
}
|
||||
|
||||
@Override
|
||||
public NNData compute(NNDataPair input) {
|
||||
return this.neuralNetwork.compute(input);
|
||||
}
|
||||
|
||||
public int getActualTrainingEpochs() {
|
||||
return actualTrainingEpochs;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getInputSize() {
|
||||
return 2;
|
||||
}
|
||||
|
||||
public int getMaxTrainingEpochs() {
|
||||
return maxTrainingEpochs;
|
||||
}
|
||||
|
||||
@Override
|
||||
public BasicNetwork getNeuralNetwork() {
|
||||
protected FeedforwardNetwork getNeuralNetwork() {
|
||||
return neuralNetwork;
|
||||
}
|
||||
|
||||
public void load(String filename) throws IOException {
|
||||
FileInputStream fis = new FileInputStream(new File(filename));
|
||||
neuralNetwork = (BasicNetwork) new PersistBasicNetwork().read(fis);
|
||||
fis.close();
|
||||
@Override
|
||||
public void learnPatterns(List<NNDataPair> trainingSet) {
|
||||
actualTrainingEpochs = 0;
|
||||
double error;
|
||||
neuralNetwork.initWeights();
|
||||
|
||||
error = trainingMethod.computePatternError(neuralNetwork,trainingSet);
|
||||
|
||||
if (error <= maxError) {
|
||||
System.out.println("Initial error: " + error);
|
||||
return;
|
||||
}
|
||||
|
||||
do {
|
||||
trainingMethod.iteratePatterns(neuralNetwork,trainingSet);
|
||||
error = trainingMethod.computePatternError(neuralNetwork,trainingSet);
|
||||
System.out.println("Epoch #" + actualTrainingEpochs + " Error:"
|
||||
+ error);
|
||||
actualTrainingEpochs++;
|
||||
System.out.println("MSSE after epoch " + actualTrainingEpochs + ": " + error);
|
||||
} while (error > maxError && actualTrainingEpochs < maxTrainingEpochs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void reset() {
|
||||
neuralNetwork.reset();
|
||||
public void learnSequences(List<List<NNDataPair>> trainingSet) {
|
||||
actualTrainingEpochs = 0;
|
||||
double error;
|
||||
neuralNetwork.initWeights();
|
||||
|
||||
error = trainingMethod.computeSequenceError(neuralNetwork,trainingSet);
|
||||
|
||||
if (error <= maxError) {
|
||||
System.out.println("Initial error: " + error);
|
||||
return;
|
||||
}
|
||||
|
||||
do {
|
||||
trainingMethod.iterateSequences(neuralNetwork,trainingSet);
|
||||
error = trainingMethod.computeSequenceError(neuralNetwork,trainingSet);
|
||||
if (Double.isNaN(error)) {
|
||||
error = trainingMethod.computeSequenceError(neuralNetwork,trainingSet);
|
||||
}
|
||||
System.out.println("Epoch #" + actualTrainingEpochs + " Error:"
|
||||
+ error);
|
||||
actualTrainingEpochs++;
|
||||
System.out.println("MSSE after epoch " + actualTrainingEpochs + ": " + error);
|
||||
} while (error > maxError && actualTrainingEpochs < maxTrainingEpochs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void reset(int seed) {
|
||||
neuralNetwork.reset(seed);
|
||||
public boolean load(InputStream input) {
|
||||
return neuralNetwork.load(input);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean save(OutputStream output) {
|
||||
return neuralNetwork.save(output);
|
||||
}
|
||||
|
||||
public void save(String filename) throws IOException {
|
||||
FileOutputStream fos = new FileOutputStream(new File(filename));
|
||||
new PersistBasicNetwork().save(fos, getNeuralNetwork());
|
||||
fos.close();
|
||||
public void setMaxError(double maxError) {
|
||||
this.maxError = maxError;
|
||||
}
|
||||
|
||||
public void setMaxTrainingEpochs(int max) {
|
||||
this.maxTrainingEpochs = max;
|
||||
}
|
||||
}
|
||||
}
|
||||
132
src/net/woodyfolsom/msproj/ann/BackPropagation.java
Normal file
132
src/net/woodyfolsom/msproj/ann/BackPropagation.java
Normal file
@@ -0,0 +1,132 @@
|
||||
package net.woodyfolsom.msproj.ann;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import net.woodyfolsom.msproj.ann.math.ErrorFunction;
|
||||
import net.woodyfolsom.msproj.ann.math.MSSE;
|
||||
|
||||
public class BackPropagation extends TrainingMethod {
|
||||
private final ErrorFunction errorFunction;
|
||||
private final double learningRate;
|
||||
private final double momentum;
|
||||
|
||||
public BackPropagation(double learningRate, double momentum) {
|
||||
this.errorFunction = MSSE.function;
|
||||
this.learningRate = learningRate;
|
||||
this.momentum = momentum;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void iteratePatterns(FeedforwardNetwork neuralNetwork,
|
||||
List<NNDataPair> trainingSet) {
|
||||
System.out.println("Learningrate: " + learningRate);
|
||||
System.out.println("Momentum: " + momentum);
|
||||
|
||||
for (NNDataPair trainingPair : trainingSet) {
|
||||
zeroGradients(neuralNetwork);
|
||||
|
||||
System.out.println("Training with: " + trainingPair.getInput());
|
||||
|
||||
NNData ideal = trainingPair.getIdeal();
|
||||
NNData actual = neuralNetwork.compute(trainingPair);
|
||||
|
||||
System.out.println("Updating weights. Ideal Output: " + ideal);
|
||||
System.out.println("Actual Output: " + actual);
|
||||
|
||||
//backpropagate the gradients w.r.t. output error
|
||||
backPropagate(neuralNetwork, ideal);
|
||||
|
||||
updateWeights(neuralNetwork);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public double computePatternError(FeedforwardNetwork neuralNetwork,
|
||||
List<NNDataPair> trainingSet) {
|
||||
int numDataPairs = trainingSet.size();
|
||||
int outputSize = neuralNetwork.getOutput().length;
|
||||
int totalOutputSize = outputSize * numDataPairs;
|
||||
|
||||
double[] actuals = new double[totalOutputSize];
|
||||
double[] ideals = new double[totalOutputSize];
|
||||
for (int dataPair = 0; dataPair < numDataPairs; dataPair++) {
|
||||
NNDataPair nnDataPair = trainingSet.get(dataPair);
|
||||
double[] actual = neuralNetwork.compute(nnDataPair.getInput()
|
||||
.getValues());
|
||||
double[] ideal = nnDataPair.getIdeal().getValues();
|
||||
int offset = dataPair * outputSize;
|
||||
|
||||
System.arraycopy(actual, 0, actuals, offset, outputSize);
|
||||
System.arraycopy(ideal, 0, ideals, offset, outputSize);
|
||||
}
|
||||
|
||||
double MSSE = errorFunction.compute(ideals, actuals);
|
||||
return MSSE;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected
|
||||
void backPropagate(FeedforwardNetwork neuralNetwork, NNData ideal) {
|
||||
Neuron[] outputNeurons = neuralNetwork.getOutputNeurons();
|
||||
double[] idealValues = ideal.getValues();
|
||||
|
||||
for (int i = 0; i < idealValues.length; i++) {
|
||||
double input = outputNeurons[i].getInput();
|
||||
double derivative = outputNeurons[i].getActivationFunction()
|
||||
.derivative(input);
|
||||
outputNeurons[i].setGradient(outputNeurons[i].getGradient() + derivative * (idealValues[i] - outputNeurons[i].getOutput()));
|
||||
}
|
||||
// walking down the list of Neurons in reverse order, propagate the
|
||||
// error
|
||||
Neuron[] neurons = neuralNetwork.getNeurons();
|
||||
|
||||
for (int n = neurons.length - 1; n >= 0; n--) {
|
||||
|
||||
Neuron neuron = neurons[n];
|
||||
double error = neuron.getGradient();
|
||||
|
||||
Connection[] connectionsFromN = neuralNetwork
|
||||
.getConnectionsFrom(neuron.getId());
|
||||
if (connectionsFromN.length > 0) {
|
||||
|
||||
double derivative = neuron.getActivationFunction().derivative(
|
||||
neuron.getInput());
|
||||
for (Connection connection : connectionsFromN) {
|
||||
error += derivative * connection.getWeight() * neuralNetwork.getNeuron(connection.getDest()).getGradient();
|
||||
}
|
||||
}
|
||||
neuron.setGradient(error);
|
||||
}
|
||||
}
|
||||
|
||||
private void updateWeights(FeedforwardNetwork neuralNetwork) {
|
||||
for (Connection connection : neuralNetwork.getConnections()) {
|
||||
Neuron srcNeuron = neuralNetwork.getNeuron(connection.getSrc());
|
||||
Neuron destNeuron = neuralNetwork.getNeuron(connection.getDest());
|
||||
double delta = learningRate * srcNeuron.getOutput() * destNeuron.getGradient();
|
||||
//TODO allow for momentum
|
||||
//double lastDelta = connection.getLastDelta();
|
||||
connection.addDelta(delta);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void iterateSequences(FeedforwardNetwork neuralNetwork,
|
||||
List<List<NNDataPair>> trainingSet) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public double computeSequenceError(FeedforwardNetwork neuralNetwork,
|
||||
List<List<NNDataPair>> trainingSet) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void iteratePattern(FeedforwardNetwork neuralNetwork,
|
||||
NNDataPair statePair, NNData nextReward) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
100
src/net/woodyfolsom/msproj/ann/Connection.java
Normal file
100
src/net/woodyfolsom/msproj/ann/Connection.java
Normal file
@@ -0,0 +1,100 @@
|
||||
package net.woodyfolsom.msproj.ann;
|
||||
|
||||
import javax.xml.bind.annotation.XmlAttribute;
|
||||
import javax.xml.bind.annotation.XmlTransient;
|
||||
|
||||
public class Connection {
|
||||
private int src;
|
||||
private int dest;
|
||||
private double weight;
|
||||
//private transient double lastDelta = 0.0;
|
||||
private transient double trace = 0.0;
|
||||
|
||||
public Connection() {
|
||||
//no-arg constructor for JAXB
|
||||
}
|
||||
|
||||
public Connection(int src, int dest, double weight) {
|
||||
this.src = src;
|
||||
this.dest = dest;
|
||||
this.weight = weight;
|
||||
}
|
||||
|
||||
public void addDelta(double delta) {
|
||||
this.trace = delta;
|
||||
this.weight += delta;
|
||||
//this.lastDelta = delta;
|
||||
}
|
||||
|
||||
@XmlAttribute
|
||||
public int getDest() {
|
||||
return dest;
|
||||
}
|
||||
|
||||
@XmlAttribute
|
||||
public int getSrc() {
|
||||
return src;
|
||||
}
|
||||
|
||||
public double getTrace() {
|
||||
return trace;
|
||||
}
|
||||
|
||||
@XmlAttribute
|
||||
public double getWeight() {
|
||||
return weight;
|
||||
}
|
||||
|
||||
public void setDest(int dest) {
|
||||
this.dest = dest;
|
||||
}
|
||||
|
||||
public void setSrc(int src) {
|
||||
this.src = src;
|
||||
}
|
||||
|
||||
@XmlTransient
|
||||
public void setTrace(double trace) {
|
||||
this.trace = trace;
|
||||
}
|
||||
|
||||
public void setWeight(double weight) {
|
||||
this.weight = weight;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
final int prime = 31;
|
||||
int result = 1;
|
||||
result = prime * result + dest;
|
||||
result = prime * result + src;
|
||||
long temp;
|
||||
temp = Double.doubleToLongBits(weight);
|
||||
result = prime * result + (int) (temp ^ (temp >>> 32));
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object obj) {
|
||||
if (this == obj)
|
||||
return true;
|
||||
if (obj == null)
|
||||
return false;
|
||||
if (getClass() != obj.getClass())
|
||||
return false;
|
||||
Connection other = (Connection) obj;
|
||||
if (dest != other.dest)
|
||||
return false;
|
||||
if (src != other.src)
|
||||
return false;
|
||||
if (Double.doubleToLongBits(weight) != Double
|
||||
.doubleToLongBits(other.weight))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "Connection(src: " + src + ",dest: " + dest + ", trace:" + trace +"), weight: " + weight;
|
||||
}
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
package net.woodyfolsom.msproj.ann;
|
||||
|
||||
import org.encog.ml.data.basic.BasicMLData;
|
||||
|
||||
public class DoublePair extends BasicMLData {
|
||||
// private final double x;
|
||||
// private final double y;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
public DoublePair(double x, double y) {
|
||||
super(new double[] { x, y });
|
||||
}
|
||||
}
|
||||
@@ -1,95 +0,0 @@
|
||||
package net.woodyfolsom.msproj.ann;
|
||||
|
||||
import org.encog.mathutil.error.ErrorCalculationMode;
|
||||
|
||||
/*
|
||||
Initial erison of this class was a verbatim copy from Encog framework.
|
||||
*/
|
||||
|
||||
public class ErrorCalculation {
|
||||
|
||||
private static ErrorCalculationMode mode = ErrorCalculationMode.MSE;
|
||||
|
||||
public static ErrorCalculationMode getMode() {
|
||||
return ErrorCalculation.mode;
|
||||
}
|
||||
|
||||
public static void setMode(final ErrorCalculationMode theMode) {
|
||||
ErrorCalculation.mode = theMode;
|
||||
}
|
||||
|
||||
private double globalError;
|
||||
|
||||
private int setSize;
|
||||
|
||||
public final double calculate() {
|
||||
if (this.setSize == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
switch (ErrorCalculation.getMode()) {
|
||||
case RMS:
|
||||
return calculateRMS();
|
||||
case MSE:
|
||||
return calculateMSE();
|
||||
case ESS:
|
||||
return calculateESS();
|
||||
default:
|
||||
return calculateMSE();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public final double calculateMSE() {
|
||||
if (this.setSize == 0) {
|
||||
return 0;
|
||||
}
|
||||
final double err = this.globalError / this.setSize;
|
||||
return err;
|
||||
|
||||
}
|
||||
|
||||
public final double calculateESS() {
|
||||
if (this.setSize == 0) {
|
||||
return 0;
|
||||
}
|
||||
final double err = this.globalError / 2;
|
||||
return err;
|
||||
|
||||
}
|
||||
|
||||
public final double calculateRMS() {
|
||||
if (this.setSize == 0) {
|
||||
return 0;
|
||||
}
|
||||
final double err = Math.sqrt(this.globalError / this.setSize);
|
||||
return err;
|
||||
}
|
||||
|
||||
public final void reset() {
|
||||
this.globalError = 0;
|
||||
this.setSize = 0;
|
||||
}
|
||||
|
||||
public final void updateError(final double actual, final double ideal) {
|
||||
|
||||
double delta = ideal - actual;
|
||||
|
||||
this.globalError += delta * delta;
|
||||
|
||||
this.setSize++;
|
||||
|
||||
}
|
||||
|
||||
public final void updateError(final double[] actual, final double[] ideal,
|
||||
final double significance) {
|
||||
for (int i = 0; i < actual.length; i++) {
|
||||
double delta = (ideal[i] - actual[i]) * significance;
|
||||
|
||||
this.globalError += delta * delta;
|
||||
}
|
||||
|
||||
this.setSize += ideal.length;
|
||||
}
|
||||
|
||||
}
|
||||
313
src/net/woodyfolsom/msproj/ann/FeedforwardNetwork.java
Normal file
313
src/net/woodyfolsom/msproj/ann/FeedforwardNetwork.java
Normal file
@@ -0,0 +1,313 @@
|
||||
package net.woodyfolsom.msproj.ann;
|
||||
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import javax.xml.bind.annotation.XmlAttribute;
|
||||
import javax.xml.bind.annotation.XmlElement;
|
||||
import javax.xml.bind.annotation.XmlTransient;
|
||||
|
||||
import net.woodyfolsom.msproj.ann.math.ActivationFunction;
|
||||
import net.woodyfolsom.msproj.ann.math.Linear;
|
||||
import net.woodyfolsom.msproj.ann.math.Sigmoid;
|
||||
|
||||
public abstract class FeedforwardNetwork {
|
||||
private ActivationFunction activationFunction;
|
||||
private boolean biased;
|
||||
private List<Connection> connections;
|
||||
private List<Neuron> neurons;
|
||||
private String name;
|
||||
|
||||
private transient int biasNeuronId;
|
||||
private transient Map<Integer, List<Connection>> connectionsFrom;
|
||||
private transient Map<Integer, List<Connection>> connectionsTo;
|
||||
|
||||
public FeedforwardNetwork() {
|
||||
this(false);
|
||||
}
|
||||
|
||||
public FeedforwardNetwork(boolean biased) {
|
||||
//No-arg constructor for JAXB
|
||||
this.activationFunction = Sigmoid.function;
|
||||
this.connections = new ArrayList<Connection>();
|
||||
this.connectionsFrom = new HashMap<Integer,List<Connection>>();
|
||||
this.connectionsTo = new HashMap<Integer,List<Connection>>();
|
||||
this.neurons = new ArrayList<Neuron>();
|
||||
this.name = "UNDEFINED";
|
||||
this.biasNeuronId = -1;
|
||||
setBiased(biased);
|
||||
}
|
||||
|
||||
public void addConnection(Connection connection) {
|
||||
connections.add(connection);
|
||||
|
||||
int src = connection.getSrc();
|
||||
int dest = connection.getDest();
|
||||
|
||||
if (!connectionsFrom.containsKey(src)) {
|
||||
connectionsFrom.put(src, new ArrayList<Connection>());
|
||||
}
|
||||
|
||||
if (!connectionsTo.containsKey(dest)) {
|
||||
connectionsTo.put(dest, new ArrayList<Connection>());
|
||||
}
|
||||
|
||||
connectionsFrom.get(src).add(connection);
|
||||
connectionsTo.get(dest).add(connection);
|
||||
}
|
||||
|
||||
public NNData compute(NNDataPair nnDataPair) {
|
||||
NNData actual = new NNData(nnDataPair.getIdeal().getFields(),
|
||||
compute(nnDataPair.getInput().getValues()));
|
||||
return actual;
|
||||
}
|
||||
|
||||
public double[] compute(double[] input) {
|
||||
zeroInputs();
|
||||
setInput(input);
|
||||
feedforward();
|
||||
return getOutput();
|
||||
}
|
||||
|
||||
void createBiasConnection(int neuronId, double weight) {
|
||||
if (!biased) {
|
||||
throw new UnsupportedOperationException("Not a biased network");
|
||||
}
|
||||
addConnection(new Connection(biasNeuronId, neuronId, weight));
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds a new neuron with a unique id to this FeedforwardNetwork.
|
||||
* @return
|
||||
*/
|
||||
Neuron createNeuron(boolean input, ActivationFunction afunc) {
|
||||
Neuron neuron;
|
||||
if (input) {
|
||||
neuron = new Neuron(Linear.function, neurons.size());
|
||||
} else {
|
||||
neuron = new Neuron(afunc, neurons.size());
|
||||
}
|
||||
neurons.add(neuron);
|
||||
return neuron;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object obj) {
|
||||
if (this == obj)
|
||||
return true;
|
||||
if (obj == null)
|
||||
return false;
|
||||
if (getClass() != obj.getClass())
|
||||
return false;
|
||||
FeedforwardNetwork other = (FeedforwardNetwork) obj;
|
||||
if (activationFunction == null) {
|
||||
if (other.activationFunction != null)
|
||||
return false;
|
||||
} else if (!activationFunction.equals(other.activationFunction))
|
||||
return false;
|
||||
if (connections == null) {
|
||||
if (other.connections != null)
|
||||
return false;
|
||||
} else if (!connections.equals(other.connections))
|
||||
return false;
|
||||
if (name == null) {
|
||||
if (other.name != null)
|
||||
return false;
|
||||
} else if (!name.equals(other.name))
|
||||
return false;
|
||||
if (neurons == null) {
|
||||
if (other.neurons != null)
|
||||
return false;
|
||||
} else if (!neurons.equals(other.neurons))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
protected void feedforward() {
|
||||
for (int i = 0; i < neurons.size(); i++) {
|
||||
Neuron src = neurons.get(i);
|
||||
for (Connection connection : getConnectionsFrom(src.getId())) {
|
||||
Neuron dest = getNeuron(connection.getDest());
|
||||
dest.addInput(src.getOutput() * connection.getWeight());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@XmlElement(type = Sigmoid.class)
|
||||
public ActivationFunction getActivationFunction() {
|
||||
return activationFunction;
|
||||
}
|
||||
|
||||
protected abstract double[] getOutput();
|
||||
protected abstract Neuron[] getOutputNeurons();
|
||||
|
||||
@XmlAttribute
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
|
||||
protected Neuron getNeuron(int id) {
|
||||
return neurons.get(id);
|
||||
}
|
||||
|
||||
public Connection getConnection(int index) {
|
||||
return connections.get(index);
|
||||
}
|
||||
|
||||
@XmlElement
|
||||
protected Connection[] getConnections() {
|
||||
return connections.toArray(new Connection[connections.size()]);
|
||||
}
|
||||
|
||||
protected Connection[] getConnectionsFrom(int neuronId) {
|
||||
List<Connection> connList = connectionsFrom.get(neuronId);
|
||||
|
||||
if (connList == null) {
|
||||
return new Connection[0];
|
||||
} else {
|
||||
return connList.toArray(new Connection[connList.size()]);
|
||||
}
|
||||
}
|
||||
|
||||
protected Connection[] getConnectionsTo(int neuronId) {
|
||||
List<Connection> connList = connectionsTo.get(neuronId);
|
||||
|
||||
if (connList == null) {
|
||||
return new Connection[0];
|
||||
} else {
|
||||
return connList.toArray(new Connection[connList.size()]);
|
||||
}
|
||||
}
|
||||
|
||||
public double[] getGradients() {
|
||||
double[] gradients = new double[neurons.size()];
|
||||
for (int n = 0; n < gradients.length; n++) {
|
||||
gradients[n] = neurons.get(n).getGradient();
|
||||
}
|
||||
return gradients;
|
||||
}
|
||||
|
||||
public double[] getWeights() {
|
||||
double[] weights = new double[connections.size()];
|
||||
for (int i = 0; i < connections.size(); i++) {
|
||||
weights[i] = connections.get(i).getWeight();
|
||||
}
|
||||
return weights;
|
||||
}
|
||||
|
||||
@XmlAttribute
|
||||
public boolean isBiased() {
|
||||
return biased;
|
||||
}
|
||||
|
||||
@XmlElement
|
||||
protected Neuron[] getNeurons() {
|
||||
return neurons.toArray(new Neuron[neurons.size()]);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
final int prime = 31;
|
||||
int result = 1;
|
||||
result = prime
|
||||
* result
|
||||
+ ((activationFunction == null) ? 0 : activationFunction
|
||||
.hashCode());
|
||||
result = prime * result
|
||||
+ ((connections == null) ? 0 : connections.hashCode());
|
||||
result = prime * result + ((name == null) ? 0 : name.hashCode());
|
||||
result = prime * result + ((neurons == null) ? 0 : neurons.hashCode());
|
||||
return result;
|
||||
}
|
||||
|
||||
public void initWeights() {
|
||||
for (Connection connection : connections) {
|
||||
connection.setWeight(1.0-Math.random()*2.0);
|
||||
}
|
||||
}
|
||||
|
||||
public abstract boolean load(InputStream is);
|
||||
|
||||
public abstract boolean save(OutputStream os);
|
||||
|
||||
public void setActivationFunction(ActivationFunction activationFunction) {
|
||||
this.activationFunction = activationFunction;
|
||||
}
|
||||
|
||||
public void setBiased(boolean biased) {
|
||||
|
||||
if (this.biased == biased) {
|
||||
return;
|
||||
}
|
||||
|
||||
this.biased = biased;
|
||||
|
||||
if (biased) {
|
||||
Neuron biasNeuron = createNeuron(true, activationFunction);
|
||||
biasNeuron.setInput(1.0);
|
||||
biasNeuronId = biasNeuron.getId();
|
||||
} else {
|
||||
//This is an inefficient but concise way to remove all connections involving the bias Neuron
|
||||
//from the global
|
||||
|
||||
//Remove all connections from biasId from this index
|
||||
List<Connection> connectionsFromBias = connectionsFrom.remove(biasNeuronId);
|
||||
|
||||
//Remove all connections to all nodes from biasId from this index
|
||||
for (Map.Entry<Integer,List<Connection>> mapEntry : connectionsTo.entrySet()) {
|
||||
mapEntry.getValue().removeAll(connectionsFromBias);
|
||||
}
|
||||
|
||||
//Finally, remove from the (serialized) list of non-indexed connections
|
||||
connections.remove(connectionsFromBias);
|
||||
|
||||
biasNeuronId = -1;
|
||||
}
|
||||
}
|
||||
|
||||
protected void setConnections(Connection[] connections) {
|
||||
this.connections.clear();
|
||||
this.connectionsFrom.clear();
|
||||
this.connectionsTo.clear();
|
||||
for (Connection connection : connections) {
|
||||
addConnection(connection);
|
||||
}
|
||||
}
|
||||
|
||||
protected abstract void setInput(double[] input);
|
||||
|
||||
public void setName(String name) {
|
||||
this.name = name;
|
||||
}
|
||||
|
||||
protected void setNeurons(Neuron[] neurons) {
|
||||
this.neurons.clear();
|
||||
for (Neuron neuron : neurons) {
|
||||
this.neurons.add(neuron);
|
||||
}
|
||||
}
|
||||
|
||||
@XmlTransient
|
||||
public void setWeights(double[] weights) {
|
||||
if (weights.length != connections.size()) {
|
||||
throw new IllegalArgumentException("# of weights must == # of connections");
|
||||
}
|
||||
|
||||
for (int i = 0; i < connections.size(); i++) {
|
||||
connections.get(i).setWeight(weights[i]);
|
||||
}
|
||||
}
|
||||
|
||||
protected void zeroInputs() {
|
||||
for (Neuron neuron : neurons) {
|
||||
if (neuron.getId() != biasNeuronId){
|
||||
neuron.setInput(0.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
292
src/net/woodyfolsom/msproj/ann/FusekiFilterTrainer.java
Normal file
292
src/net/woodyfolsom/msproj/ann/FusekiFilterTrainer.java
Normal file
@@ -0,0 +1,292 @@
|
||||
package net.woodyfolsom.msproj.ann;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import net.woodyfolsom.msproj.Action;
|
||||
import net.woodyfolsom.msproj.GameConfig;
|
||||
import net.woodyfolsom.msproj.GameRecord;
|
||||
import net.woodyfolsom.msproj.GameResult;
|
||||
import net.woodyfolsom.msproj.GameState;
|
||||
import net.woodyfolsom.msproj.Player;
|
||||
import net.woodyfolsom.msproj.policy.NeuralNetPolicy;
|
||||
import net.woodyfolsom.msproj.policy.Policy;
|
||||
import net.woodyfolsom.msproj.policy.RandomMovePolicy;
|
||||
import net.woodyfolsom.msproj.tictactoe.NNDataSetFactory;
|
||||
|
||||
public class FusekiFilterTrainer { // implements epsilon-greedy trainer? online
|
||||
// version of NeuralNetFilter
|
||||
|
||||
private boolean training = true;
|
||||
|
||||
public static void main(String[] args) throws IOException {
|
||||
double alpha = 0.50;
|
||||
double lambda = 0.90;
|
||||
int maxGames = 1000;
|
||||
|
||||
new FusekiFilterTrainer().trainNetwork(alpha, lambda, maxGames);
|
||||
}
|
||||
|
||||
public void trainNetwork(double alpha, double lambda, int maxGames)
|
||||
throws IOException {
|
||||
|
||||
FeedforwardNetwork neuralNetwork;
|
||||
|
||||
GameConfig gameConfig = new GameConfig(9);
|
||||
|
||||
if (training) {
|
||||
neuralNetwork = new MultiLayerPerceptron(true, 81, 18, 1);
|
||||
neuralNetwork.setName("FusekiFilter" + gameConfig.getSize());
|
||||
neuralNetwork.initWeights();
|
||||
TrainingMethod trainer = new TemporalDifference(alpha, lambda);
|
||||
|
||||
System.out.println("Playing untrained games.");
|
||||
|
||||
for (int i = 0; i < 10; i++) {
|
||||
GameRecord gameRecord = new GameRecord(gameConfig);
|
||||
System.out.println("" + (i + 1) + ". "
|
||||
+ playOptimal(neuralNetwork, gameRecord).getResult());
|
||||
}
|
||||
|
||||
System.out.println("Learning from " + maxGames
|
||||
+ " games of random self-play");
|
||||
|
||||
int gamesPlayed = 0;
|
||||
List<GameResult> results = new ArrayList<GameResult>();
|
||||
do {
|
||||
GameRecord gameRecord = new GameRecord(gameConfig);
|
||||
playEpsilonGreedy(0.50, neuralNetwork, trainer, gameRecord);
|
||||
System.out.println("Winner: " + gameRecord.getResult());
|
||||
gamesPlayed++;
|
||||
results.add(gameRecord.getResult());
|
||||
} while (gamesPlayed < maxGames);
|
||||
|
||||
System.out.println("Results of every 10th training game:");
|
||||
|
||||
for (int i = 0; i < results.size(); i++) {
|
||||
if (i % 10 == 0) {
|
||||
System.out.println("" + (i + 1) + ". " + results.get(i));
|
||||
}
|
||||
}
|
||||
|
||||
System.out.println("Learned network after " + maxGames
|
||||
+ " training games.");
|
||||
} else {
|
||||
System.out.println("Loading TicTacToe network from file.");
|
||||
neuralNetwork = new MultiLayerPerceptron();
|
||||
FileInputStream fis = new FileInputStream(new File("pass.net"));
|
||||
if (!new MultiLayerPerceptron().load(fis)) {
|
||||
System.out.println("Error loading pass.net from file.");
|
||||
return;
|
||||
}
|
||||
fis.close();
|
||||
}
|
||||
|
||||
evalTestCases(gameConfig, neuralNetwork);
|
||||
|
||||
System.out.println("Playing optimal games.");
|
||||
List<GameResult> gameResults = new ArrayList<GameResult>();
|
||||
for (int i = 0; i < 10; i++) {
|
||||
GameRecord gameRecord = new GameRecord(gameConfig);
|
||||
gameResults.add(playOptimal(neuralNetwork, gameRecord).getResult());
|
||||
}
|
||||
|
||||
boolean suboptimalPlay = false;
|
||||
System.out.println("Optimal game summary: ");
|
||||
for (int i = 0; i < gameResults.size(); i++) {
|
||||
GameResult result = gameResults.get(i);
|
||||
System.out.println("" + (i + 1) + ". " + result);
|
||||
}
|
||||
|
||||
File output = new File("pass.net");
|
||||
|
||||
FileOutputStream fos = new FileOutputStream(output);
|
||||
|
||||
neuralNetwork.save(fos);
|
||||
|
||||
System.out.println("Playing optimal vs random games.");
|
||||
for (int i = 0; i < 10; i++) {
|
||||
GameRecord gameRecord = new GameRecord(gameConfig);
|
||||
System.out.println(""
|
||||
+ (i + 1)
|
||||
+ ". "
|
||||
+ playOptimalVsRandom(neuralNetwork, gameRecord)
|
||||
.getResult());
|
||||
}
|
||||
|
||||
if (suboptimalPlay) {
|
||||
System.out.println("Suboptimal play detected!");
|
||||
}
|
||||
}
|
||||
|
||||
private double[] createBoard(GameConfig gameConfig, Action... actions) {
|
||||
GameRecord gameRec = new GameRecord(gameConfig);
|
||||
for (Action action : actions) {
|
||||
gameRec.play(gameRec.getPlayerToMove(), action);
|
||||
}
|
||||
return NNDataSetFactory.createDataPair(gameRec.getGameState(), FusekiFilterTrainer.class).getInput().getValues();
|
||||
}
|
||||
|
||||
private void evalTestCases(GameConfig gameConfig, FeedforwardNetwork neuralNetwork) {
|
||||
double[][] validationSet = new double[1][];
|
||||
|
||||
// start state: black has 0, white has 0 + komi, neither has passed
|
||||
validationSet[0] = createBoard(gameConfig, Action.getInstance("C3"));
|
||||
|
||||
String[] inputNames = NNDataSetFactory.getInputFields(FusekiFilterTrainer.class);
|
||||
String[] outputNames = NNDataSetFactory.getOutputFields(FusekiFilterTrainer.class);
|
||||
|
||||
System.out.println("Output from eval set (learned network):");
|
||||
testNetwork(neuralNetwork, validationSet, inputNames, outputNames);
|
||||
}
|
||||
|
||||
private GameRecord playOptimalVsRandom(FeedforwardNetwork neuralNetwork,
|
||||
GameRecord gameRecord) {
|
||||
NeuralNetPolicy neuralNetPolicy = new NeuralNetPolicy();
|
||||
neuralNetPolicy.setMoveFilter(neuralNetwork);
|
||||
|
||||
Policy randomPolicy = new RandomMovePolicy();
|
||||
|
||||
GameConfig gameConfig = gameRecord.getGameConfig();
|
||||
GameState gameState = gameRecord.getGameState();
|
||||
|
||||
Policy[] policies = new Policy[] { neuralNetPolicy, randomPolicy };
|
||||
int turnNo = 0;
|
||||
do {
|
||||
Action action;
|
||||
GameState nextState;
|
||||
|
||||
Player playerToMove = gameState.getPlayerToMove();
|
||||
action = policies[turnNo % 2].getAction(gameConfig, gameState,
|
||||
playerToMove);
|
||||
|
||||
if (!gameRecord.play(playerToMove, action)) {
|
||||
throw new RuntimeException("Illegal move: " + action);
|
||||
}
|
||||
|
||||
nextState = gameRecord.getGameState();
|
||||
|
||||
//System.out.println("Action " + action + " selected by policy "
|
||||
// + policies[turnNo % 2].getName());
|
||||
//System.out.println("Next board state: " + nextState);
|
||||
gameState = nextState;
|
||||
turnNo++;
|
||||
} while (!gameState.isTerminal());
|
||||
return gameRecord;
|
||||
}
|
||||
|
||||
private GameRecord playOptimal(FeedforwardNetwork neuralNetwork,
|
||||
GameRecord gameRecord) {
|
||||
|
||||
NeuralNetPolicy neuralNetPolicy = new NeuralNetPolicy();
|
||||
neuralNetPolicy.setMoveFilter(neuralNetwork);
|
||||
|
||||
if (gameRecord.getNumTurns() > 0) {
|
||||
throw new RuntimeException(
|
||||
"PlayOptimal requires a new GameRecord with no turns played.");
|
||||
}
|
||||
|
||||
GameState gameState;
|
||||
|
||||
do {
|
||||
Action action;
|
||||
GameState nextState;
|
||||
|
||||
Player playerToMove = gameRecord.getPlayerToMove();
|
||||
action = neuralNetPolicy.getAction(gameRecord.getGameConfig(),
|
||||
gameRecord.getGameState(), playerToMove);
|
||||
|
||||
if (!gameRecord.play(playerToMove, action)) {
|
||||
throw new RuntimeException("Invalid move played: " + action);
|
||||
}
|
||||
nextState = gameRecord.getGameState();
|
||||
|
||||
//System.out.println("Action " + action + " selected by policy "
|
||||
// + neuralNetPolicy.getName());
|
||||
//System.out.println("Next board state: " + nextState);
|
||||
gameState = nextState;
|
||||
} while (!gameState.isTerminal());
|
||||
|
||||
return gameRecord;
|
||||
}
|
||||
|
||||
private GameRecord playEpsilonGreedy(double epsilon,
|
||||
FeedforwardNetwork neuralNetwork, TrainingMethod trainer,
|
||||
GameRecord gameRecord) {
|
||||
Policy randomPolicy = new RandomMovePolicy();
|
||||
NeuralNetPolicy neuralNetPolicy = new NeuralNetPolicy();
|
||||
neuralNetPolicy.setMoveFilter(neuralNetwork);
|
||||
|
||||
if (gameRecord.getNumTurns() > 0) {
|
||||
throw new RuntimeException(
|
||||
"PlayOptimal requires a new GameRecord with no turns played.");
|
||||
}
|
||||
|
||||
GameState gameState = gameRecord.getGameState();
|
||||
NNDataPair statePair;
|
||||
|
||||
Policy selectedPolicy;
|
||||
trainer.zeroTraces(neuralNetwork);
|
||||
|
||||
do {
|
||||
Action action;
|
||||
GameState nextState;
|
||||
|
||||
Player playerToMove = gameRecord.getPlayerToMove();
|
||||
|
||||
if (Math.random() < epsilon) {
|
||||
selectedPolicy = randomPolicy;
|
||||
action = selectedPolicy
|
||||
.getAction(gameRecord.getGameConfig(),
|
||||
gameRecord.getGameState(),
|
||||
gameRecord.getPlayerToMove());
|
||||
if (!gameRecord.play(playerToMove, action)) {
|
||||
throw new RuntimeException("Illegal move played: " + action);
|
||||
}
|
||||
nextState = gameRecord.getGameState();
|
||||
} else {
|
||||
selectedPolicy = neuralNetPolicy;
|
||||
action = selectedPolicy
|
||||
.getAction(gameRecord.getGameConfig(),
|
||||
gameRecord.getGameState(),
|
||||
gameRecord.getPlayerToMove());
|
||||
|
||||
if (!gameRecord.play(playerToMove, action)) {
|
||||
throw new RuntimeException("Illegal move played: " + action);
|
||||
}
|
||||
nextState = gameRecord.getGameState();
|
||||
|
||||
statePair = NNDataSetFactory.createDataPair(gameState, FusekiFilterTrainer.class);
|
||||
NNDataPair nextStatePair = NNDataSetFactory
|
||||
.createDataPair(nextState, FusekiFilterTrainer.class);
|
||||
|
||||
trainer.iteratePattern(neuralNetwork, statePair,
|
||||
nextStatePair.getIdeal());
|
||||
}
|
||||
|
||||
gameState = nextState;
|
||||
} while (!gameState.isTerminal());
|
||||
|
||||
// finally, reinforce the actual reward
|
||||
statePair = NNDataSetFactory.createDataPair(gameState, FusekiFilterTrainer.class);
|
||||
trainer.iteratePattern(neuralNetwork, statePair, statePair.getIdeal());
|
||||
|
||||
return gameRecord;
|
||||
}
|
||||
|
||||
private void testNetwork(FeedforwardNetwork neuralNetwork,
|
||||
double[][] validationSet, String[] inputNames, String[] outputNames) {
|
||||
for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
|
||||
NNDataPair dp = new NNDataPair(new NNData(inputNames,
|
||||
validationSet[valIndex]), new NNData(outputNames,
|
||||
new double[] { 0.0 }));
|
||||
System.out.println(dp);
|
||||
System.out.println(" => ");
|
||||
System.out.println(neuralNetwork.compute(dp));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
package net.woodyfolsom.msproj.ann;
|
||||
|
||||
import net.woodyfolsom.msproj.GameState;
|
||||
|
||||
import org.encog.ml.data.basic.BasicMLData;
|
||||
|
||||
public class GameStateMLData extends BasicMLData {
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
private GameState gameState;
|
||||
|
||||
public GameStateMLData(double[] d, GameState gameState) {
|
||||
super(d);
|
||||
// TODO Auto-generated constructor stub
|
||||
this.gameState = gameState;
|
||||
}
|
||||
|
||||
public GameState getGameState() {
|
||||
return gameState;
|
||||
}
|
||||
}
|
||||
@@ -1,121 +0,0 @@
|
||||
package net.woodyfolsom.msproj.ann;
|
||||
|
||||
import net.woodyfolsom.msproj.GameResult;
|
||||
import net.woodyfolsom.msproj.GameState;
|
||||
import net.woodyfolsom.msproj.Player;
|
||||
|
||||
import org.encog.ml.data.MLData;
|
||||
import org.encog.ml.data.MLDataPair;
|
||||
import org.encog.ml.data.basic.BasicMLData;
|
||||
import org.encog.ml.data.basic.BasicMLDataPair;
|
||||
import org.encog.util.kmeans.Centroid;
|
||||
|
||||
public class GameStateMLDataPair implements MLDataPair {
|
||||
//private final String[] inputs = { "BlackScore", "WhiteScore" };
|
||||
//private final String[] outputs = { "BlackWins", "WhiteWins" };
|
||||
|
||||
private BasicMLDataPair mlDataPairDelegate;
|
||||
private GameState gameState;
|
||||
|
||||
public GameStateMLDataPair(GameState gameState) {
|
||||
this.gameState = gameState;
|
||||
mlDataPairDelegate = new BasicMLDataPair(
|
||||
new GameStateMLData(createInput(), gameState), new BasicMLData(createIdeal()));
|
||||
}
|
||||
|
||||
public GameStateMLDataPair(GameStateMLDataPair that) {
|
||||
this.gameState = new GameState(that.gameState);
|
||||
mlDataPairDelegate = new BasicMLDataPair(
|
||||
that.mlDataPairDelegate.getInput(),
|
||||
that.mlDataPairDelegate.getIdeal());
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLDataPair clone() {
|
||||
return new GameStateMLDataPair(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Centroid<MLDataPair> createCentroid() {
|
||||
return mlDataPairDelegate.createCentroid();
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a vector of normalized scores from GameState.
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
private double[] createInput() {
|
||||
|
||||
GameResult result = gameState.getResult();
|
||||
|
||||
double maxScore = gameState.getGameConfig().getSize()
|
||||
* gameState.getGameConfig().getSize();
|
||||
|
||||
double whiteScore = Math.min(1.0, result.getWhiteScore() / maxScore);
|
||||
double blackScore = Math.min(1.0, result.getBlackScore() / maxScore);
|
||||
|
||||
return new double[] { blackScore, whiteScore };
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a vector of values indicating strength of black/white win output
|
||||
* from network.
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
private double[] createIdeal() {
|
||||
GameResult result = gameState.getResult();
|
||||
|
||||
double blackWinner = result.isWinner(Player.BLACK) ? 1.0 : 0.0;
|
||||
double whiteWinner = result.isWinner(Player.WHITE) ? 1.0 : 0.0;
|
||||
|
||||
return new double[] { blackWinner, whiteWinner };
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLData getIdeal() {
|
||||
return mlDataPairDelegate.getIdeal();
|
||||
}
|
||||
|
||||
@Override
|
||||
public double[] getIdealArray() {
|
||||
return mlDataPairDelegate.getIdealArray();
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLData getInput() {
|
||||
return mlDataPairDelegate.getInput();
|
||||
}
|
||||
|
||||
@Override
|
||||
public double[] getInputArray() {
|
||||
return mlDataPairDelegate.getInputArray();
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getSignificance() {
|
||||
return mlDataPairDelegate.getSignificance();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isSupervised() {
|
||||
return mlDataPairDelegate.isSupervised();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setIdealArray(double[] arg0) {
|
||||
mlDataPairDelegate.setIdealArray(arg0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setInputArray(double[] arg0) {
|
||||
mlDataPairDelegate.setInputArray(arg0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setSignificance(double arg0) {
|
||||
mlDataPairDelegate.setSignificance(arg0);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,172 +0,0 @@
|
||||
package net.woodyfolsom.msproj.ann;
|
||||
/*
|
||||
* Class copied verbatim from Encog framework due to dependency on Propagation
|
||||
* implementation.
|
||||
*
|
||||
* Encog(tm) Core v3.2 - Java Version
|
||||
* http://www.heatonresearch.com/encog/
|
||||
* http://code.google.com/p/encog-java/
|
||||
|
||||
* Copyright 2008-2012 Heaton Research, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
* For more information on Heaton Research copyrights, licenses
|
||||
* and trademarks visit:
|
||||
* http://www.heatonresearch.com/copyright
|
||||
*/
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
import org.encog.engine.network.activation.ActivationFunction;
|
||||
import org.encog.ml.data.MLDataPair;
|
||||
import org.encog.ml.data.MLDataSet;
|
||||
import org.encog.ml.data.basic.BasicMLDataPair;
|
||||
import org.encog.neural.error.ErrorFunction;
|
||||
import org.encog.neural.flat.FlatNetwork;
|
||||
import org.encog.util.EngineArray;
|
||||
import org.encog.util.concurrency.EngineTask;
|
||||
|
||||
public class GradientWorker implements EngineTask {
|
||||
|
||||
private final FlatNetwork network;
|
||||
private final ErrorCalculation errorCalculation = new ErrorCalculation();
|
||||
private final double[] actual;
|
||||
private final double[] layerDelta;
|
||||
private final int[] layerCounts;
|
||||
private final int[] layerFeedCounts;
|
||||
private final int[] layerIndex;
|
||||
private final int[] weightIndex;
|
||||
private final double[] layerOutput;
|
||||
private final double[] layerSums;
|
||||
private final double[] gradients;
|
||||
private final double[] weights;
|
||||
private final MLDataPair pair;
|
||||
private final Set<List<MLDataPair>> training;
|
||||
private final int low;
|
||||
private final int high;
|
||||
private final TemporalDifferenceLearning owner;
|
||||
private double[] flatSpot;
|
||||
private final ErrorFunction errorFunction;
|
||||
|
||||
public GradientWorker(final FlatNetwork theNetwork,
|
||||
final TemporalDifferenceLearning theOwner,
|
||||
final Set<List<MLDataPair>> theTraining, final int theLow,
|
||||
final int theHigh, final double[] flatSpot,
|
||||
ErrorFunction ef) {
|
||||
this.network = theNetwork;
|
||||
this.training = theTraining;
|
||||
this.low = theLow;
|
||||
this.high = theHigh;
|
||||
this.owner = theOwner;
|
||||
this.flatSpot = flatSpot;
|
||||
this.errorFunction = ef;
|
||||
|
||||
this.layerDelta = new double[network.getLayerOutput().length];
|
||||
this.gradients = new double[network.getWeights().length];
|
||||
this.actual = new double[network.getOutputCount()];
|
||||
|
||||
this.weights = network.getWeights();
|
||||
this.layerIndex = network.getLayerIndex();
|
||||
this.layerCounts = network.getLayerCounts();
|
||||
this.weightIndex = network.getWeightIndex();
|
||||
this.layerOutput = network.getLayerOutput();
|
||||
this.layerSums = network.getLayerSums();
|
||||
this.layerFeedCounts = network.getLayerFeedCounts();
|
||||
|
||||
this.pair = BasicMLDataPair.createPair(network.getInputCount(), network
|
||||
.getOutputCount());
|
||||
}
|
||||
|
||||
public FlatNetwork getNetwork() {
|
||||
return this.network;
|
||||
}
|
||||
|
||||
public double[] getWeights() {
|
||||
return this.weights;
|
||||
}
|
||||
|
||||
private void process(final double[] input, final double[] ideal, double s) {
|
||||
this.network.compute(input, this.actual);
|
||||
|
||||
this.errorCalculation.updateError(this.actual, ideal, s);
|
||||
this.errorFunction.calculateError(ideal, actual, this.layerDelta);
|
||||
|
||||
for (int i = 0; i < this.actual.length; i++) {
|
||||
|
||||
this.layerDelta[i] = ((this.network.getActivationFunctions()[0]
|
||||
.derivativeFunction(this.layerSums[i],this.layerOutput[i]) + this.flatSpot[0]))
|
||||
* (this.layerDelta[i] * s);
|
||||
}
|
||||
|
||||
for (int i = this.network.getBeginTraining(); i < this.network
|
||||
.getEndTraining(); i++) {
|
||||
processLevel(i);
|
||||
}
|
||||
}
|
||||
|
||||
private void processLevel(final int currentLevel) {
|
||||
final int fromLayerIndex = this.layerIndex[currentLevel + 1];
|
||||
final int toLayerIndex = this.layerIndex[currentLevel];
|
||||
final int fromLayerSize = this.layerCounts[currentLevel + 1];
|
||||
final int toLayerSize = this.layerFeedCounts[currentLevel];
|
||||
|
||||
final int index = this.weightIndex[currentLevel];
|
||||
final ActivationFunction activation = this.network
|
||||
.getActivationFunctions()[currentLevel];
|
||||
final double currentFlatSpot = this.flatSpot[currentLevel + 1];
|
||||
|
||||
// handle weights
|
||||
int yi = fromLayerIndex;
|
||||
for (int y = 0; y < fromLayerSize; y++) {
|
||||
final double output = this.layerOutput[yi];
|
||||
double sum = 0;
|
||||
int xi = toLayerIndex;
|
||||
int wi = index + y;
|
||||
for (int x = 0; x < toLayerSize; x++) {
|
||||
this.gradients[wi] += output * this.layerDelta[xi];
|
||||
sum += this.weights[wi] * this.layerDelta[xi];
|
||||
wi += fromLayerSize;
|
||||
xi++;
|
||||
}
|
||||
|
||||
this.layerDelta[yi] = sum
|
||||
* (activation.derivativeFunction(this.layerSums[yi],this.layerOutput[yi])+currentFlatSpot);
|
||||
yi++;
|
||||
}
|
||||
}
|
||||
|
||||
public final void run() {
|
||||
try {
|
||||
this.errorCalculation.reset();
|
||||
//for (int i = this.low; i <= this.high; i++) {
|
||||
for (List<MLDataPair> trainingSequence : training) {
|
||||
MLDataPair mldp = trainingSequence.get(trainingSequence.size()-1);
|
||||
this.pair.setInputArray(mldp.getInputArray());
|
||||
if (this.pair.getIdealArray() != null) {
|
||||
this.pair.setIdealArray(mldp.getIdealArray());
|
||||
}
|
||||
//this.training.getRecord(i, this.pair);
|
||||
process(this.pair.getInputArray(), this.pair.getIdealArray(),pair.getSignificance());
|
||||
}
|
||||
//}
|
||||
final double error = this.errorCalculation.calculate();
|
||||
this.owner.report(this.gradients, error, null);
|
||||
EngineArray.fill(this.gradients, 0);
|
||||
} catch (final Throwable ex) {
|
||||
this.owner.report(null, 0, ex);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,5 +0,0 @@
|
||||
package net.woodyfolsom.msproj.ann;
|
||||
|
||||
public class JosekiLearner {
|
||||
|
||||
}
|
||||
64
src/net/woodyfolsom/msproj/ann/Layer.java
Normal file
64
src/net/woodyfolsom/msproj/ann/Layer.java
Normal file
@@ -0,0 +1,64 @@
|
||||
package net.woodyfolsom.msproj.ann;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
import javax.xml.bind.annotation.XmlElement;
|
||||
|
||||
public class Layer {
|
||||
private int[] neuronIds;
|
||||
|
||||
public Layer() {
|
||||
neuronIds = new int[0];
|
||||
}
|
||||
|
||||
public Layer(int numNeurons) {
|
||||
neuronIds = new int[numNeurons];
|
||||
}
|
||||
|
||||
public int size() {
|
||||
return neuronIds.length;
|
||||
}
|
||||
|
||||
public int getNeuronId(int index) {
|
||||
return neuronIds[index];
|
||||
}
|
||||
|
||||
@XmlElement
|
||||
public int[] getNeuronIds() {
|
||||
int[] safeCopy = new int[neuronIds.length];
|
||||
System.arraycopy(neuronIds, 0, safeCopy, 0, neuronIds.length);
|
||||
return safeCopy;
|
||||
}
|
||||
|
||||
public void setNeuronId(int index, int id) {
|
||||
neuronIds[index] = id;
|
||||
}
|
||||
|
||||
public void setNeuronIds(int[] neuronIds) {
|
||||
this.neuronIds = new int[neuronIds.length];
|
||||
System.arraycopy(neuronIds, 0, this.neuronIds, 0, neuronIds.length);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
final int prime = 31;
|
||||
int result = 1;
|
||||
result = prime * result + Arrays.hashCode(neuronIds);
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object obj) {
|
||||
if (this == obj)
|
||||
return true;
|
||||
if (obj == null)
|
||||
return false;
|
||||
if (getClass() != obj.getClass())
|
||||
return false;
|
||||
Layer other = (Layer) obj;
|
||||
if (!Arrays.equals(neuronIds, other.neuronIds))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
}
|
||||
157
src/net/woodyfolsom/msproj/ann/MultiLayerPerceptron.java
Normal file
157
src/net/woodyfolsom/msproj/ann/MultiLayerPerceptron.java
Normal file
@@ -0,0 +1,157 @@
|
||||
package net.woodyfolsom.msproj.ann;
|
||||
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
|
||||
import javax.xml.bind.JAXBContext;
|
||||
import javax.xml.bind.JAXBException;
|
||||
import javax.xml.bind.Marshaller;
|
||||
import javax.xml.bind.Unmarshaller;
|
||||
import javax.xml.bind.annotation.XmlElement;
|
||||
import javax.xml.bind.annotation.XmlRootElement;
|
||||
|
||||
import net.woodyfolsom.msproj.ann.math.ActivationFunction;
|
||||
import net.woodyfolsom.msproj.ann.math.Sigmoid;
|
||||
import net.woodyfolsom.msproj.ann.math.Tanh;
|
||||
|
||||
@XmlRootElement
|
||||
public class MultiLayerPerceptron extends FeedforwardNetwork {
|
||||
private boolean biased;
|
||||
private Layer[] layers;
|
||||
|
||||
public MultiLayerPerceptron() {
|
||||
this(false, 1, 1);
|
||||
}
|
||||
|
||||
public MultiLayerPerceptron(boolean biased, int... layerSizes) {
|
||||
super(biased);
|
||||
|
||||
int numLayers = layerSizes.length;
|
||||
|
||||
if (numLayers < 2) {
|
||||
throw new IllegalArgumentException("# of layers must be >= 2");
|
||||
}
|
||||
|
||||
this.layers = new Layer[numLayers];
|
||||
|
||||
for (int layerIndex = 0; layerIndex < numLayers; layerIndex++) {
|
||||
int layerSize = layerSizes[layerIndex];
|
||||
|
||||
if (layerSize < 1) {
|
||||
throw new IllegalArgumentException("Layer size must be >= 1");
|
||||
}
|
||||
|
||||
|
||||
Layer newLayer;
|
||||
if (layerIndex == numLayers - 1) {
|
||||
newLayer = createNewLayer(layerIndex, layerSize, Sigmoid.function);
|
||||
} else {
|
||||
newLayer = createNewLayer(layerIndex, layerSize, Tanh.function);
|
||||
}
|
||||
|
||||
if (layerIndex > 0) {
|
||||
Layer prevLayer = layers[layerIndex - 1];
|
||||
for (int j = 0; j < newLayer.size(); j++) {
|
||||
if (biased) {
|
||||
createBiasConnection(newLayer.getNeuronId(j),0.0);
|
||||
}
|
||||
for (int i = 0; i < prevLayer.size(); i++) {
|
||||
addConnection(new Connection(prevLayer.getNeuronId(i),
|
||||
newLayer.getNeuronId(j), 0.0));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private Layer createNewLayer(int layerIndex, int layerSize, ActivationFunction afunc) {
|
||||
Layer layer = new Layer(layerSize);
|
||||
layers[layerIndex] = layer;
|
||||
for (int n = 0; n < layerSize; n++) {
|
||||
Neuron neuron = createNeuron(layerIndex == 0, afunc);
|
||||
layer.setNeuronId(n, neuron.getId());
|
||||
}
|
||||
return layer;
|
||||
}
|
||||
|
||||
@XmlElement
|
||||
public Layer[] getLayers() {
|
||||
return layers;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected double[] getOutput() {
|
||||
Layer outputLayer = layers[layers.length - 1];
|
||||
double output[] = new double[outputLayer.size()];
|
||||
for (int n = 0; n < outputLayer.size(); n++) {
|
||||
output[n] = getNeuron(outputLayer.getNeuronId(n)).getOutput();
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Neuron[] getOutputNeurons() {
|
||||
Layer outputLayer = layers[layers.length - 1];
|
||||
Neuron[] outputNeurons = new Neuron[outputLayer.size()];
|
||||
for (int i = 0; i < outputLayer.size(); i++) {
|
||||
outputNeurons[i] = getNeuron(outputLayer.getNeuronId(i));
|
||||
}
|
||||
return outputNeurons;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void setInput(double[] input) {
|
||||
Layer inputLayer = layers[0];
|
||||
for (int n = 0; n < inputLayer.size(); n++) {
|
||||
try {
|
||||
getNeuron(inputLayer.getNeuronId(n)).setInput(input[n]);
|
||||
} catch (NullPointerException npe) {
|
||||
npe.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public void setLayers(Layer[] layers) {
|
||||
this.layers = layers;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean load(InputStream is) {
|
||||
try {
|
||||
JAXBContext jc = JAXBContext
|
||||
.newInstance(MultiLayerPerceptron.class);
|
||||
|
||||
Unmarshaller u = jc.createUnmarshaller();
|
||||
MultiLayerPerceptron mlp = (MultiLayerPerceptron) u.unmarshal(is);
|
||||
|
||||
super.setActivationFunction(mlp.getActivationFunction());
|
||||
super.setConnections(mlp.getConnections());
|
||||
super.setNeurons(mlp.getNeurons());
|
||||
this.biased = mlp.biased;
|
||||
this.layers = mlp.layers;
|
||||
|
||||
return true;
|
||||
} catch (JAXBException je) {
|
||||
je.printStackTrace();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean save(OutputStream os) {
|
||||
try {
|
||||
JAXBContext jc = JAXBContext
|
||||
.newInstance(MultiLayerPerceptron.class);
|
||||
|
||||
Marshaller m = jc.createMarshaller();
|
||||
m.setProperty(Marshaller.JAXB_FORMATTED_OUTPUT, true);
|
||||
m.marshal(this, os);
|
||||
//m.marshal(this, System.out);
|
||||
return true;
|
||||
} catch (JAXBException je) {
|
||||
je.printStackTrace();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
38
src/net/woodyfolsom/msproj/ann/NNData.java
Normal file
38
src/net/woodyfolsom/msproj/ann/NNData.java
Normal file
@@ -0,0 +1,38 @@
|
||||
package net.woodyfolsom.msproj.ann;
|
||||
|
||||
public class NNData {
|
||||
private final double[] values;
|
||||
private final String[] fields;
|
||||
|
||||
public NNData(String[] fields, double[] values) {
|
||||
this.fields = fields;
|
||||
this.values = values;
|
||||
}
|
||||
|
||||
public NNData(NNData that) {
|
||||
this.fields = that.fields;
|
||||
this.values = that.values;
|
||||
}
|
||||
|
||||
public String[] getFields() {
|
||||
return fields;
|
||||
}
|
||||
|
||||
public double[] getValues() {
|
||||
return values;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
StringBuilder sb = new StringBuilder("[");
|
||||
|
||||
for (int i = 0; i < fields.length; i++) {
|
||||
if (i > 0) {
|
||||
sb.append(", " );
|
||||
}
|
||||
sb.append(fields[i] + "=" + values[i]);
|
||||
}
|
||||
sb.append("]");
|
||||
return sb.toString();
|
||||
}
|
||||
}
|
||||
24
src/net/woodyfolsom/msproj/ann/NNDataPair.java
Normal file
24
src/net/woodyfolsom/msproj/ann/NNDataPair.java
Normal file
@@ -0,0 +1,24 @@
|
||||
package net.woodyfolsom.msproj.ann;
|
||||
|
||||
public class NNDataPair {
|
||||
private final NNData input;
|
||||
private final NNData ideal;
|
||||
|
||||
public NNDataPair(NNData actual, NNData ideal) {
|
||||
this.input = actual;
|
||||
this.ideal = ideal;
|
||||
}
|
||||
|
||||
public NNData getInput() {
|
||||
return input;
|
||||
}
|
||||
|
||||
public NNData getIdeal() {
|
||||
return ideal;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return input.toString() + " => " + ideal.toString();
|
||||
}
|
||||
}
|
||||
@@ -1,31 +1,29 @@
|
||||
package net.woodyfolsom.msproj.ann;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
import org.encog.ml.data.MLData;
|
||||
import org.encog.ml.data.MLDataPair;
|
||||
import org.encog.ml.data.MLDataSet;
|
||||
import org.encog.neural.networks.BasicNetwork;
|
||||
|
||||
public interface NeuralNetFilter {
|
||||
BasicNetwork getNeuralNetwork();
|
||||
int getActualTrainingEpochs();
|
||||
|
||||
public int getActualTrainingEpochs();
|
||||
public int getInputSize();
|
||||
public int getMaxTrainingEpochs();
|
||||
public int getOutputSize();
|
||||
int getInputSize();
|
||||
|
||||
int getMaxTrainingEpochs();
|
||||
|
||||
int getOutputSize();
|
||||
|
||||
boolean load(InputStream input);
|
||||
|
||||
boolean save(OutputStream output);
|
||||
|
||||
void setMaxTrainingEpochs(int max);
|
||||
|
||||
NNData compute(NNDataPair input);
|
||||
|
||||
//Due to Java type erasure, overloading a method
|
||||
//simply named 'learn' which takes Lists would be problematic
|
||||
|
||||
public double computeValue(MLData input);
|
||||
public double[] computeVector(MLData input);
|
||||
|
||||
public void learn(MLDataSet trainingSet);
|
||||
public void learn(Set<List<MLDataPair>> trainingSet);
|
||||
|
||||
public void load(String fileName) throws IOException;
|
||||
public void reset();
|
||||
public void reset(int seed);
|
||||
public void save(String fileName) throws IOException;
|
||||
public void setMaxTrainingEpochs(int max);
|
||||
void learnPatterns(List<NNDataPair> trainingSet);
|
||||
void learnSequences(List<List<NNDataPair>> trainingSet);
|
||||
}
|
||||
107
src/net/woodyfolsom/msproj/ann/Neuron.java
Normal file
107
src/net/woodyfolsom/msproj/ann/Neuron.java
Normal file
@@ -0,0 +1,107 @@
|
||||
package net.woodyfolsom.msproj.ann;
|
||||
|
||||
import javax.xml.bind.annotation.XmlAttribute;
|
||||
import javax.xml.bind.annotation.XmlElement;
|
||||
import javax.xml.bind.annotation.XmlElements;
|
||||
import javax.xml.bind.annotation.XmlTransient;
|
||||
|
||||
import net.woodyfolsom.msproj.ann.math.ActivationFunction;
|
||||
import net.woodyfolsom.msproj.ann.math.Linear;
|
||||
import net.woodyfolsom.msproj.ann.math.Sigmoid;
|
||||
import net.woodyfolsom.msproj.ann.math.Tanh;
|
||||
|
||||
public class Neuron {
|
||||
|
||||
private ActivationFunction activationFunction;
|
||||
private int id;
|
||||
private transient double input = 0.0;
|
||||
private transient double gradient = 0.0;
|
||||
|
||||
public Neuron() {
|
||||
//no-arg constructor for JAXB
|
||||
}
|
||||
|
||||
public Neuron(ActivationFunction activationFunction, int id) {
|
||||
this.activationFunction = activationFunction;
|
||||
this.id = id;
|
||||
}
|
||||
|
||||
public void addInput(double value) {
|
||||
input += value;
|
||||
}
|
||||
|
||||
|
||||
@XmlElements({
|
||||
@XmlElement(name="LinearActivationFunction",type=Linear.class),
|
||||
@XmlElement(name="SigmoidActivationFunction",type=Sigmoid.class),
|
||||
@XmlElement(name="TanhActivationFunction",type=Tanh.class)
|
||||
})
|
||||
public ActivationFunction getActivationFunction() {
|
||||
return activationFunction;
|
||||
}
|
||||
|
||||
@XmlAttribute
|
||||
public int getId() {
|
||||
return id;
|
||||
}
|
||||
|
||||
@XmlTransient
|
||||
public double getGradient() {
|
||||
return gradient;
|
||||
}
|
||||
|
||||
@XmlTransient
|
||||
public double getInput() {
|
||||
return input;
|
||||
}
|
||||
|
||||
public double getOutput() {
|
||||
return activationFunction.calculate(input);
|
||||
}
|
||||
|
||||
public void setGradient(double value) {
|
||||
this.gradient = value;
|
||||
}
|
||||
|
||||
public void setInput(double input) {
|
||||
this.input = input;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
final int prime = 31;
|
||||
int result = 1;
|
||||
result = prime
|
||||
* result
|
||||
+ ((activationFunction == null) ? 0 : activationFunction
|
||||
.hashCode());
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object obj) {
|
||||
if (this == obj)
|
||||
return true;
|
||||
if (obj == null)
|
||||
return false;
|
||||
if (getClass() != obj.getClass())
|
||||
return false;
|
||||
Neuron other = (Neuron) obj;
|
||||
if (activationFunction == null) {
|
||||
if (other.activationFunction != null)
|
||||
return false;
|
||||
} else if (!activationFunction.equals(other.activationFunction))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
public void setActivationFunction(ActivationFunction activationFunction) {
|
||||
this.activationFunction = activationFunction;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "Neuron #" + id +", input: " + input + ", gradient: " + gradient;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
package net.woodyfolsom.msproj.ann;
|
||||
|
||||
public class FusekiLearner {
|
||||
|
||||
}
|
||||
package net.woodyfolsom.msproj.ann;
|
||||
|
||||
public class ObjectiveFunction {
|
||||
|
||||
}
|
||||
298
src/net/woodyfolsom/msproj/ann/PassFilterTrainer.java
Normal file
298
src/net/woodyfolsom/msproj/ann/PassFilterTrainer.java
Normal file
@@ -0,0 +1,298 @@
|
||||
package net.woodyfolsom.msproj.ann;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import net.woodyfolsom.msproj.Action;
|
||||
import net.woodyfolsom.msproj.GameConfig;
|
||||
import net.woodyfolsom.msproj.GameRecord;
|
||||
import net.woodyfolsom.msproj.GameResult;
|
||||
import net.woodyfolsom.msproj.GameState;
|
||||
import net.woodyfolsom.msproj.Player;
|
||||
import net.woodyfolsom.msproj.policy.NeuralNetPolicy;
|
||||
import net.woodyfolsom.msproj.policy.Policy;
|
||||
import net.woodyfolsom.msproj.policy.RandomMovePolicy;
|
||||
import net.woodyfolsom.msproj.tictactoe.NNDataSetFactory;
|
||||
|
||||
public class PassFilterTrainer { // implements epsilon-greedy trainer? online
|
||||
// version of NeuralNetFilter
|
||||
|
||||
private boolean training = true;
|
||||
|
||||
public static void main(String[] args) throws IOException {
|
||||
double alpha = 0.50;
|
||||
double lambda = 0.1;
|
||||
int maxGames = 1500;
|
||||
|
||||
new PassFilterTrainer().trainNetwork(alpha, lambda, maxGames);
|
||||
}
|
||||
|
||||
public void trainNetwork(double alpha, double lambda, int maxGames)
|
||||
throws IOException {
|
||||
|
||||
FeedforwardNetwork neuralNetwork;
|
||||
|
||||
GameConfig gameConfig = new GameConfig(9);
|
||||
|
||||
if (training) {
|
||||
neuralNetwork = new MultiLayerPerceptron(true, 2, 2, 1);
|
||||
neuralNetwork.setName("PassFilter" + gameConfig.getSize());
|
||||
neuralNetwork.initWeights();
|
||||
TrainingMethod trainer = new TemporalDifference(alpha, lambda);
|
||||
|
||||
System.out.println("Playing untrained games.");
|
||||
|
||||
for (int i = 0; i < 10; i++) {
|
||||
GameRecord gameRecord = new GameRecord(gameConfig);
|
||||
System.out.println("" + (i + 1) + ". "
|
||||
+ playOptimal(neuralNetwork, gameRecord).getResult());
|
||||
}
|
||||
|
||||
System.out.println("Learning from " + maxGames
|
||||
+ " games of random self-play");
|
||||
|
||||
int gamesPlayed = 0;
|
||||
List<GameResult> results = new ArrayList<GameResult>();
|
||||
do {
|
||||
GameRecord gameRecord = new GameRecord(gameConfig);
|
||||
playEpsilonGreedy(0.5, neuralNetwork, trainer, gameRecord);
|
||||
System.out.println("Winner: " + gameRecord.getResult());
|
||||
gamesPlayed++;
|
||||
results.add(gameRecord.getResult());
|
||||
} while (gamesPlayed < maxGames);
|
||||
|
||||
System.out.println("Results of every 10th training game:");
|
||||
|
||||
for (int i = 0; i < results.size(); i++) {
|
||||
if (i % 10 == 0) {
|
||||
System.out.println("" + (i + 1) + ". " + results.get(i));
|
||||
}
|
||||
}
|
||||
|
||||
System.out.println("Learned network after " + maxGames
|
||||
+ " training games.");
|
||||
} else {
|
||||
System.out.println("Loading TicTacToe network from file.");
|
||||
neuralNetwork = new MultiLayerPerceptron();
|
||||
FileInputStream fis = new FileInputStream(new File("pass.net"));
|
||||
if (!new MultiLayerPerceptron().load(fis)) {
|
||||
System.out.println("Error loading pass.net from file.");
|
||||
return;
|
||||
}
|
||||
fis.close();
|
||||
}
|
||||
|
||||
evalTestCases(neuralNetwork);
|
||||
|
||||
System.out.println("Playing optimal games.");
|
||||
List<GameResult> gameResults = new ArrayList<GameResult>();
|
||||
for (int i = 0; i < 10; i++) {
|
||||
GameRecord gameRecord = new GameRecord(gameConfig);
|
||||
gameResults.add(playOptimal(neuralNetwork, gameRecord).getResult());
|
||||
}
|
||||
|
||||
boolean suboptimalPlay = false;
|
||||
System.out.println("Optimal game summary: ");
|
||||
for (int i = 0; i < gameResults.size(); i++) {
|
||||
GameResult result = gameResults.get(i);
|
||||
System.out.println("" + (i + 1) + ". " + result);
|
||||
}
|
||||
|
||||
File output = new File("pass.net");
|
||||
|
||||
FileOutputStream fos = new FileOutputStream(output);
|
||||
|
||||
neuralNetwork.save(fos);
|
||||
|
||||
System.out.println("Playing optimal vs random games.");
|
||||
for (int i = 0; i < 10; i++) {
|
||||
GameRecord gameRecord = new GameRecord(gameConfig);
|
||||
System.out.println(""
|
||||
+ (i + 1)
|
||||
+ ". "
|
||||
+ playOptimalVsRandom(neuralNetwork, gameRecord)
|
||||
.getResult());
|
||||
}
|
||||
|
||||
if (suboptimalPlay) {
|
||||
System.out.println("Suboptimal play detected!");
|
||||
}
|
||||
}
|
||||
|
||||
private void evalTestCases(FeedforwardNetwork neuralNetwork) {
|
||||
double[][] validationSet = new double[4][];
|
||||
|
||||
//losing, opponent did not pass
|
||||
//don't pass
|
||||
//(0.0 1.0 0.0) => 0.0
|
||||
validationSet[0] = new double[] { -1.0, -1.0 };
|
||||
|
||||
//winning, opponent did not pass
|
||||
//maybe pass?
|
||||
//(1.0 0.0 0.0) => ?
|
||||
validationSet[1] = new double[] { 1.0, -1.0 };
|
||||
|
||||
//winning, opponent passed
|
||||
//pass!
|
||||
//(1.0 0.0 1.0) => 1.0
|
||||
validationSet[2] = new double[] { 1.0, 1.0 };
|
||||
|
||||
//losing, opponent passed
|
||||
//don't pass!
|
||||
//(0.0 1.0 1.0) => 1.0
|
||||
validationSet[3] = new double[] { -1.0, 1.0 };
|
||||
|
||||
String[] inputNames = NNDataSetFactory.getInputFields(PassFilterTrainer.class);
|
||||
String[] outputNames = NNDataSetFactory.getOutputFields(PassFilterTrainer.class);
|
||||
|
||||
System.out.println("Output from eval set (learned network):");
|
||||
testNetwork(neuralNetwork, validationSet, inputNames, outputNames);
|
||||
}
|
||||
|
||||
private GameRecord playOptimalVsRandom(FeedforwardNetwork neuralNetwork,
|
||||
GameRecord gameRecord) {
|
||||
NeuralNetPolicy neuralNetPolicy = new NeuralNetPolicy();
|
||||
neuralNetPolicy.setPassFilter(neuralNetwork);
|
||||
Policy randomPolicy = new RandomMovePolicy();
|
||||
|
||||
GameConfig gameConfig = gameRecord.getGameConfig();
|
||||
GameState gameState = gameRecord.getGameState();
|
||||
|
||||
Policy[] policies = new Policy[] { neuralNetPolicy, randomPolicy };
|
||||
int turnNo = 0;
|
||||
do {
|
||||
Action action;
|
||||
GameState nextState;
|
||||
|
||||
Player playerToMove = gameState.getPlayerToMove();
|
||||
action = policies[turnNo % 2].getAction(gameConfig, gameState,
|
||||
playerToMove);
|
||||
|
||||
if (!gameRecord.play(playerToMove, action)) {
|
||||
throw new RuntimeException("Illegal move: " + action);
|
||||
}
|
||||
|
||||
nextState = gameRecord.getGameState();
|
||||
|
||||
//System.out.println("Action " + action + " selected by policy "
|
||||
// + policies[turnNo % 2].getName());
|
||||
//System.out.println("Next board state: " + nextState);
|
||||
gameState = nextState;
|
||||
turnNo++;
|
||||
} while (!gameState.isTerminal());
|
||||
return gameRecord;
|
||||
}
|
||||
|
||||
private GameRecord playOptimal(FeedforwardNetwork neuralNetwork,
|
||||
GameRecord gameRecord) {
|
||||
|
||||
NeuralNetPolicy neuralNetPolicy = new NeuralNetPolicy();
|
||||
neuralNetPolicy.setPassFilter(neuralNetwork);
|
||||
|
||||
if (gameRecord.getNumTurns() > 0) {
|
||||
throw new RuntimeException(
|
||||
"PlayOptimal requires a new GameRecord with no turns played.");
|
||||
}
|
||||
|
||||
GameState gameState;
|
||||
|
||||
do {
|
||||
Action action;
|
||||
GameState nextState;
|
||||
|
||||
Player playerToMove = gameRecord.getPlayerToMove();
|
||||
action = neuralNetPolicy.getAction(gameRecord.getGameConfig(),
|
||||
gameRecord.getGameState(), playerToMove);
|
||||
|
||||
if (!gameRecord.play(playerToMove, action)) {
|
||||
throw new RuntimeException("Invalid move played: " + action);
|
||||
}
|
||||
nextState = gameRecord.getGameState();
|
||||
|
||||
//System.out.println("Action " + action + " selected by policy "
|
||||
// + neuralNetPolicy.getName());
|
||||
//System.out.println("Next board state: " + nextState);
|
||||
gameState = nextState;
|
||||
} while (!gameState.isTerminal());
|
||||
|
||||
return gameRecord;
|
||||
}
|
||||
|
||||
private GameRecord playEpsilonGreedy(double epsilon,
|
||||
FeedforwardNetwork neuralNetwork, TrainingMethod trainer,
|
||||
GameRecord gameRecord) {
|
||||
Policy randomPolicy = new RandomMovePolicy();
|
||||
NeuralNetPolicy neuralNetPolicy = new NeuralNetPolicy();
|
||||
neuralNetPolicy.setPassFilter(neuralNetwork);
|
||||
|
||||
if (gameRecord.getNumTurns() > 0) {
|
||||
throw new RuntimeException(
|
||||
"PlayOptimal requires a new GameRecord with no turns played.");
|
||||
}
|
||||
|
||||
GameState gameState = gameRecord.getGameState();
|
||||
NNDataPair statePair;
|
||||
|
||||
Policy selectedPolicy;
|
||||
trainer.zeroTraces(neuralNetwork);
|
||||
|
||||
do {
|
||||
Action action;
|
||||
GameState nextState;
|
||||
|
||||
Player playerToMove = gameRecord.getPlayerToMove();
|
||||
|
||||
if (Math.random() < epsilon) {
|
||||
selectedPolicy = randomPolicy;
|
||||
action = selectedPolicy
|
||||
.getAction(gameRecord.getGameConfig(),
|
||||
gameRecord.getGameState(),
|
||||
gameRecord.getPlayerToMove());
|
||||
if (!gameRecord.play(playerToMove, action)) {
|
||||
throw new RuntimeException("Illegal move played: " + action);
|
||||
}
|
||||
nextState = gameRecord.getGameState();
|
||||
} else {
|
||||
selectedPolicy = neuralNetPolicy;
|
||||
action = selectedPolicy
|
||||
.getAction(gameRecord.getGameConfig(),
|
||||
gameRecord.getGameState(),
|
||||
gameRecord.getPlayerToMove());
|
||||
|
||||
if (!gameRecord.play(playerToMove, action)) {
|
||||
throw new RuntimeException("Illegal move played: " + action);
|
||||
}
|
||||
nextState = gameRecord.getGameState();
|
||||
|
||||
statePair = NNDataSetFactory.createDataPair(gameState, PassFilterTrainer.class);
|
||||
NNDataPair nextStatePair = NNDataSetFactory
|
||||
.createDataPair(nextState, PassFilterTrainer.class);
|
||||
|
||||
trainer.iteratePattern(neuralNetwork, statePair,
|
||||
nextStatePair.getIdeal());
|
||||
}
|
||||
|
||||
gameState = nextState;
|
||||
} while (!gameState.isTerminal());
|
||||
|
||||
// finally, reinforce the actual reward
|
||||
statePair = NNDataSetFactory.createDataPair(gameState, PassFilterTrainer.class);
|
||||
trainer.iteratePattern(neuralNetwork, statePair, statePair.getIdeal());
|
||||
|
||||
return gameRecord;
|
||||
}
|
||||
|
||||
private void testNetwork(FeedforwardNetwork neuralNetwork,
|
||||
double[][] validationSet, String[] inputNames, String[] outputNames) {
|
||||
for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
|
||||
NNDataPair dp = new NNDataPair(new NNData(inputNames,
|
||||
validationSet[valIndex]), new NNData(outputNames,
|
||||
new double[] { 0.0 }));
|
||||
System.out.println(dp + " => " + neuralNetwork.compute(dp));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,5 +0,0 @@
|
||||
package net.woodyfolsom.msproj.ann;
|
||||
|
||||
public class ShapeLearner {
|
||||
|
||||
}
|
||||
261
src/net/woodyfolsom/msproj/ann/TTTFilterTrainer.java
Normal file
261
src/net/woodyfolsom/msproj/ann/TTTFilterTrainer.java
Normal file
@@ -0,0 +1,261 @@
|
||||
package net.woodyfolsom.msproj.ann;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import net.woodyfolsom.msproj.tictactoe.Action;
|
||||
import net.woodyfolsom.msproj.tictactoe.GameRecord;
|
||||
import net.woodyfolsom.msproj.tictactoe.GameRecord.RESULT;
|
||||
import net.woodyfolsom.msproj.tictactoe.NNDataSetFactory;
|
||||
import net.woodyfolsom.msproj.tictactoe.NeuralNetPolicy;
|
||||
import net.woodyfolsom.msproj.tictactoe.Policy;
|
||||
import net.woodyfolsom.msproj.tictactoe.RandomPolicy;
|
||||
import net.woodyfolsom.msproj.tictactoe.State;
|
||||
|
||||
public class TTTFilterTrainer { // implements epsilon-greedy trainer? online
|
||||
// version of NeuralNetFilter
|
||||
|
||||
private boolean training = true;
|
||||
|
||||
public static void main(String[] args) throws IOException {
|
||||
double alpha = 0.025;
|
||||
double lambda = .10;
|
||||
int maxGames = 100000;
|
||||
|
||||
new TTTFilterTrainer().trainNetwork(alpha, lambda, maxGames);
|
||||
}
|
||||
|
||||
public void trainNetwork(double alpha, double lambda, int maxGames)
|
||||
throws IOException {
|
||||
|
||||
FeedforwardNetwork neuralNetwork;
|
||||
if (training) {
|
||||
neuralNetwork = new MultiLayerPerceptron(true, 9, 9, 1);
|
||||
neuralNetwork.setName("TicTacToe");
|
||||
neuralNetwork.initWeights();
|
||||
TrainingMethod trainer = new TemporalDifference(alpha, lambda);
|
||||
|
||||
System.out.println("Playing untrained games.");
|
||||
for (int i = 0; i < 10; i++) {
|
||||
System.out.println("" + (i + 1) + ". "
|
||||
+ playOptimal(neuralNetwork).getResult());
|
||||
}
|
||||
|
||||
System.out.println("Learning from " + maxGames
|
||||
+ " games of random self-play");
|
||||
|
||||
int gamesPlayed = 0;
|
||||
List<RESULT> results = new ArrayList<RESULT>();
|
||||
do {
|
||||
GameRecord gameRecord = playEpsilonGreedy(0.9, neuralNetwork,
|
||||
trainer);
|
||||
System.out.println("Winner: " + gameRecord.getResult());
|
||||
gamesPlayed++;
|
||||
results.add(gameRecord.getResult());
|
||||
} while (gamesPlayed < maxGames);
|
||||
|
||||
System.out.println("Results of every 10th training game:");
|
||||
|
||||
for (int i = 0; i < results.size(); i++) {
|
||||
if (i % 10 == 0) {
|
||||
System.out.println("" + (i + 1) + ". " + results.get(i));
|
||||
}
|
||||
}
|
||||
|
||||
System.out.println("Learned network after " + maxGames
|
||||
+ " training games.");
|
||||
} else {
|
||||
System.out.println("Loading TicTacToe network from file.");
|
||||
neuralNetwork = new MultiLayerPerceptron();
|
||||
FileInputStream fis = new FileInputStream(new File("ttt.net"));
|
||||
if (!new MultiLayerPerceptron().load(fis)) {
|
||||
System.out.println("Error loading ttt.net from file.");
|
||||
return;
|
||||
}
|
||||
fis.close();
|
||||
}
|
||||
|
||||
evalTestCases(neuralNetwork);
|
||||
|
||||
System.out.println("Playing optimal games.");
|
||||
List<RESULT> gameResults = new ArrayList<RESULT>();
|
||||
for (int i = 0; i < 10; i++) {
|
||||
gameResults.add(playOptimal(neuralNetwork).getResult());
|
||||
}
|
||||
|
||||
boolean suboptimalPlay = false;
|
||||
System.out.println("Optimal game summary: ");
|
||||
for (int i = 0; i < gameResults.size(); i++) {
|
||||
RESULT result = gameResults.get(i);
|
||||
System.out.println("" + (i + 1) + ". " + result);
|
||||
if (result != RESULT.X_WINS) {
|
||||
suboptimalPlay = true;
|
||||
}
|
||||
}
|
||||
|
||||
File output = new File("ttt.net");
|
||||
|
||||
FileOutputStream fos = new FileOutputStream(output);
|
||||
|
||||
neuralNetwork.save(fos);
|
||||
|
||||
System.out.println("Playing optimal vs random games.");
|
||||
for (int i = 0; i < 10; i++) {
|
||||
System.out.println("" + (i + 1) + ". "
|
||||
+ playOptimalVsRandom(neuralNetwork).getResult());
|
||||
}
|
||||
|
||||
if (suboptimalPlay) {
|
||||
System.out.println("Suboptimal play detected!");
|
||||
}
|
||||
}
|
||||
|
||||
private void evalTestCases(FeedforwardNetwork neuralNetwork) {
|
||||
double[][] validationSet = new double[8][];
|
||||
|
||||
// empty board
|
||||
validationSet[0] = new double[] { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0 };
|
||||
// center
|
||||
validationSet[1] = new double[] { 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,
|
||||
0.0, 0.0 };
|
||||
// top edge
|
||||
validationSet[2] = new double[] { 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0 };
|
||||
// left edge
|
||||
validationSet[3] = new double[] { 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0 };
|
||||
// corner
|
||||
validationSet[4] = new double[] { 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0 };
|
||||
// win
|
||||
validationSet[5] = new double[] { 1.0, 1.0, 1.0, -1.0, -1.0, 0.0, 0.0,
|
||||
-1.0, 0.0 };
|
||||
// loss
|
||||
validationSet[6] = new double[] { -1.0, 1.0, 0.0, 1.0, -1.0, 1.0, 0.0,
|
||||
0.0, -1.0 };
|
||||
|
||||
// about to win
|
||||
validationSet[7] = new double[] { -1.0, 1.0, 1.0, 1.0, -1.0, 1.0, -1.0,
|
||||
-1.0, 0.0 };
|
||||
|
||||
String[] inputNames = new String[] { "00", "01", "02", "10", "11",
|
||||
"12", "20", "21", "22" };
|
||||
String[] outputNames = new String[] { "values" };
|
||||
|
||||
System.out.println("Output from eval set (learned network):");
|
||||
testNetwork(neuralNetwork, validationSet, inputNames, outputNames);
|
||||
}
|
||||
|
||||
private GameRecord playOptimalVsRandom(FeedforwardNetwork neuralNetwork) {
|
||||
GameRecord gameRecord = new GameRecord();
|
||||
|
||||
Policy neuralNetPolicy = new NeuralNetPolicy(neuralNetwork);
|
||||
Policy randomPolicy = new RandomPolicy();
|
||||
|
||||
State state = gameRecord.getState();
|
||||
|
||||
Policy[] policies = new Policy[] { neuralNetPolicy, randomPolicy };
|
||||
int turnNo = 0;
|
||||
do {
|
||||
Action action;
|
||||
State nextState;
|
||||
|
||||
action = policies[turnNo % 2].getAction(gameRecord.getState());
|
||||
|
||||
nextState = gameRecord.apply(action);
|
||||
System.out.println("Action " + action + " selected by policy "
|
||||
+ policies[turnNo % 2].getName());
|
||||
System.out.println("Next board state: " + nextState);
|
||||
state = nextState;
|
||||
turnNo++;
|
||||
} while (!state.isTerminal());
|
||||
return gameRecord;
|
||||
}
|
||||
|
||||
private GameRecord playOptimal(FeedforwardNetwork neuralNetwork) {
|
||||
GameRecord gameRecord = new GameRecord();
|
||||
|
||||
Policy neuralNetPolicy = new NeuralNetPolicy(neuralNetwork);
|
||||
|
||||
State state = gameRecord.getState();
|
||||
|
||||
do {
|
||||
Action action;
|
||||
State nextState;
|
||||
|
||||
action = neuralNetPolicy.getAction(gameRecord.getState());
|
||||
|
||||
nextState = gameRecord.apply(action);
|
||||
System.out.println("Action " + action + " selected by policy "
|
||||
+ neuralNetPolicy.getName());
|
||||
System.out.println("Next board state: " + nextState);
|
||||
state = nextState;
|
||||
} while (!state.isTerminal());
|
||||
|
||||
return gameRecord;
|
||||
}
|
||||
|
||||
private GameRecord playEpsilonGreedy(double epsilon,
|
||||
FeedforwardNetwork neuralNetwork, TrainingMethod trainer) {
|
||||
GameRecord gameRecord = new GameRecord();
|
||||
|
||||
Policy randomPolicy = new RandomPolicy();
|
||||
Policy neuralNetPolicy = new NeuralNetPolicy(neuralNetwork);
|
||||
|
||||
// System.out.println("Playing epsilon-greedy game.");
|
||||
|
||||
State state = gameRecord.getState();
|
||||
NNDataPair statePair;
|
||||
|
||||
Policy selectedPolicy;
|
||||
trainer.zeroTraces(neuralNetwork);
|
||||
|
||||
do {
|
||||
Action action;
|
||||
State nextState;
|
||||
|
||||
if (Math.random() < epsilon) {
|
||||
selectedPolicy = randomPolicy;
|
||||
action = selectedPolicy.getAction(gameRecord.getState());
|
||||
nextState = gameRecord.apply(action);
|
||||
} else {
|
||||
selectedPolicy = neuralNetPolicy;
|
||||
action = selectedPolicy.getAction(gameRecord.getState());
|
||||
|
||||
nextState = gameRecord.apply(action);
|
||||
statePair = NNDataSetFactory.createDataPair(state);
|
||||
NNDataPair nextStatePair = NNDataSetFactory
|
||||
.createDataPair(nextState);
|
||||
trainer.iteratePattern(neuralNetwork, statePair,
|
||||
nextStatePair.getIdeal());
|
||||
}
|
||||
// System.out.println("Action " + action + " selected by policy " +
|
||||
// selectedPolicy.getName());
|
||||
|
||||
// System.out.println("Next board state: " + nextState);
|
||||
|
||||
state = nextState;
|
||||
} while (!state.isTerminal());
|
||||
|
||||
// finally, reinforce the actual reward
|
||||
statePair = NNDataSetFactory.createDataPair(state);
|
||||
trainer.iteratePattern(neuralNetwork, statePair, statePair.getIdeal());
|
||||
|
||||
return gameRecord;
|
||||
}
|
||||
|
||||
private void testNetwork(FeedforwardNetwork neuralNetwork,
|
||||
double[][] validationSet, String[] inputNames, String[] outputNames) {
|
||||
for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
|
||||
NNDataPair dp = new NNDataPair(new NNData(inputNames,
|
||||
validationSet[valIndex]), new NNData(outputNames,
|
||||
new double[] { 0.0 }));
|
||||
System.out.println(dp + " => " + neuralNetwork.compute(dp));
|
||||
}
|
||||
}
|
||||
}
|
||||
123
src/net/woodyfolsom/msproj/ann/TemporalDifference.java
Normal file
123
src/net/woodyfolsom/msproj/ann/TemporalDifference.java
Normal file
@@ -0,0 +1,123 @@
|
||||
package net.woodyfolsom.msproj.ann;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class TemporalDifference extends TrainingMethod {
|
||||
private final double gamma;
|
||||
private final double lambda;
|
||||
|
||||
public TemporalDifference(double alpha, double lambda) {
|
||||
this.gamma = alpha;
|
||||
this.lambda = lambda;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void iteratePatterns(FeedforwardNetwork neuralNetwork,
|
||||
List<NNDataPair> trainingSet) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public double computePatternError(FeedforwardNetwork neuralNetwork,
|
||||
List<NNDataPair> trainingSet) {
|
||||
int numDataPairs = trainingSet.size();
|
||||
int outputSize = neuralNetwork.getOutput().length;
|
||||
int totalOutputSize = outputSize * numDataPairs;
|
||||
|
||||
double[] actuals = new double[totalOutputSize];
|
||||
double[] ideals = new double[totalOutputSize];
|
||||
for (int dataPair = 0; dataPair < numDataPairs; dataPair++) {
|
||||
NNDataPair nnDataPair = trainingSet.get(dataPair);
|
||||
double[] actual = neuralNetwork.compute(nnDataPair.getInput()
|
||||
.getValues());
|
||||
double[] ideal = nnDataPair.getIdeal().getValues();
|
||||
int offset = dataPair * outputSize;
|
||||
|
||||
System.arraycopy(actual, 0, actuals, offset, outputSize);
|
||||
System.arraycopy(ideal, 0, ideals, offset, outputSize);
|
||||
}
|
||||
|
||||
double MSSE = errorFunction.compute(ideals, actuals);
|
||||
return MSSE;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void backPropagate(FeedforwardNetwork neuralNetwork, NNData ideal) {
|
||||
Neuron[] outputNeurons = neuralNetwork.getOutputNeurons();
|
||||
double[] idealValues = ideal.getValues();
|
||||
|
||||
for (int i = 0; i < idealValues.length; i++) {
|
||||
double input = outputNeurons[i].getInput();
|
||||
double derivative = outputNeurons[i].getActivationFunction()
|
||||
.derivative(input);
|
||||
outputNeurons[i].setGradient(outputNeurons[i].getGradient()
|
||||
+ derivative
|
||||
* (idealValues[i] - outputNeurons[i].getOutput()));
|
||||
}
|
||||
|
||||
// walking down the list of Neurons in reverse order, propagate the
|
||||
// error
|
||||
Neuron[] neurons = neuralNetwork.getNeurons();
|
||||
|
||||
for (int n = neurons.length - 1; n >= 0; n--) {
|
||||
|
||||
Neuron neuron = neurons[n];
|
||||
double error = neuron.getGradient();
|
||||
|
||||
Connection[] connectionsFromN = neuralNetwork
|
||||
.getConnectionsFrom(neuron.getId());
|
||||
if (connectionsFromN.length > 0) {
|
||||
|
||||
double derivative = neuron.getActivationFunction().derivative(
|
||||
neuron.getInput());
|
||||
for (Connection connection : connectionsFromN) {
|
||||
error += derivative
|
||||
* connection.getWeight()
|
||||
* neuralNetwork.getNeuron(connection.getDest())
|
||||
.getGradient();
|
||||
}
|
||||
}
|
||||
neuron.setGradient(error);
|
||||
}
|
||||
}
|
||||
|
||||
private void updateWeights(FeedforwardNetwork neuralNetwork,
|
||||
double predictionError) {
|
||||
for (Connection connection : neuralNetwork.getConnections()) {
|
||||
Neuron srcNeuron = neuralNetwork.getNeuron(connection.getSrc());
|
||||
Neuron destNeuron = neuralNetwork.getNeuron(connection.getDest());
|
||||
double delta = gamma * srcNeuron.getOutput()
|
||||
* destNeuron.getGradient() + connection.getTrace() * lambda;
|
||||
connection.addDelta(delta);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void iterateSequences(FeedforwardNetwork neuralNetwork,
|
||||
List<List<NNDataPair>> trainingSet) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public double computeSequenceError(FeedforwardNetwork neuralNetwork,
|
||||
List<List<NNDataPair>> trainingSet) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void iteratePattern(FeedforwardNetwork neuralNetwork,
|
||||
NNDataPair statePair, NNData nextReward) {
|
||||
zeroGradients(neuralNetwork);
|
||||
|
||||
NNData ideal = nextReward;
|
||||
NNData actual = neuralNetwork.compute(statePair);
|
||||
|
||||
// backpropagate the gradients w.r.t. output error
|
||||
backPropagate(neuralNetwork, ideal);
|
||||
|
||||
double predictionError = statePair.getIdeal().getValues()[0] // reward_t
|
||||
+ actual.getValues()[0] - nextReward.getValues()[0];
|
||||
|
||||
updateWeights(neuralNetwork, predictionError);
|
||||
}
|
||||
}
|
||||
@@ -1,484 +0,0 @@
|
||||
package net.woodyfolsom.msproj.ann;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
import org.encog.EncogError;
|
||||
import org.encog.engine.network.activation.ActivationFunction;
|
||||
import org.encog.engine.network.activation.ActivationSigmoid;
|
||||
import org.encog.mathutil.IntRange;
|
||||
import org.encog.ml.MLMethod;
|
||||
import org.encog.ml.TrainingImplementationType;
|
||||
import org.encog.ml.data.MLDataPair;
|
||||
import org.encog.ml.data.MLDataSet;
|
||||
import org.encog.ml.train.MLTrain;
|
||||
import org.encog.ml.train.strategy.Strategy;
|
||||
import org.encog.ml.train.strategy.end.EndTrainingStrategy;
|
||||
import org.encog.neural.error.ErrorFunction;
|
||||
import org.encog.neural.error.LinearErrorFunction;
|
||||
import org.encog.neural.flat.FlatNetwork;
|
||||
import org.encog.neural.networks.ContainsFlat;
|
||||
import org.encog.neural.networks.training.LearningRate;
|
||||
import org.encog.neural.networks.training.Momentum;
|
||||
import org.encog.neural.networks.training.Train;
|
||||
import org.encog.neural.networks.training.TrainingError;
|
||||
import org.encog.neural.networks.training.propagation.TrainingContinuation;
|
||||
import org.encog.neural.networks.training.propagation.back.Backpropagation;
|
||||
import org.encog.neural.networks.training.strategy.SmartLearningRate;
|
||||
import org.encog.neural.networks.training.strategy.SmartMomentum;
|
||||
import org.encog.util.EncogValidate;
|
||||
import org.encog.util.EngineArray;
|
||||
import org.encog.util.concurrency.DetermineWorkload;
|
||||
import org.encog.util.concurrency.EngineConcurrency;
|
||||
import org.encog.util.concurrency.MultiThreadable;
|
||||
import org.encog.util.concurrency.TaskGroup;
|
||||
import org.encog.util.logging.EncogLogging;
|
||||
|
||||
/**
|
||||
* This class started as a verbatim copy of BackPropagation from the open-source
|
||||
* Encog framework. It was merged with its super-classes to access protected
|
||||
* fields without resorting to reflection.
|
||||
*/
|
||||
public class TemporalDifferenceLearning implements MLTrain, Momentum,
|
||||
LearningRate, Train, MultiThreadable {
|
||||
// New fields for TD(lambda)
|
||||
private final double lambda;
|
||||
// end new fields
|
||||
|
||||
// BackProp
|
||||
public static final String LAST_DELTA = "LAST_DELTA";
|
||||
private double learningRate;
|
||||
private double momentum;
|
||||
private double[] lastDelta;
|
||||
// End BackProp
|
||||
|
||||
// Propagation
|
||||
private FlatNetwork currentFlatNetwork;
|
||||
private int numThreads;
|
||||
protected double[] gradients;
|
||||
private double[] lastGradient;
|
||||
protected ContainsFlat network;
|
||||
// private MLDataSet indexable;
|
||||
private Set<List<MLDataPair>> indexable;
|
||||
private GradientWorker[] workers;
|
||||
private double totalError;
|
||||
protected double lastError;
|
||||
private Throwable reportedException;
|
||||
private double[] flatSpot;
|
||||
private boolean shouldFixFlatSpot;
|
||||
private ErrorFunction ef = new LinearErrorFunction();
|
||||
// End Propagation
|
||||
|
||||
// BasicTraining
|
||||
private final List<Strategy> strategies = new ArrayList<Strategy>();
|
||||
private Set<List<MLDataPair>> training;
|
||||
private double error;
|
||||
private int iteration;
|
||||
private TrainingImplementationType implementationType;
|
||||
|
||||
// End BasicTraining
|
||||
|
||||
public TemporalDifferenceLearning(final ContainsFlat network,
|
||||
final Set<List<MLDataPair>> training, double lambda) {
|
||||
this(network, training, 0, 0, lambda);
|
||||
addStrategy(new SmartLearningRate());
|
||||
addStrategy(new SmartMomentum());
|
||||
}
|
||||
|
||||
public TemporalDifferenceLearning(final ContainsFlat network,
|
||||
Set<List<MLDataPair>> training, final double theLearnRate,
|
||||
final double theMomentum, double lambda) {
|
||||
initPropagation(network, training);
|
||||
// TODO consider how to re-implement validation
|
||||
// ValidateNetwork.validateMethodToData(network, training);
|
||||
this.momentum = theMomentum;
|
||||
this.learningRate = theLearnRate;
|
||||
this.lastDelta = new double[network.getFlat().getWeights().length];
|
||||
this.lambda = lambda;
|
||||
}
|
||||
|
||||
private void initPropagation(final ContainsFlat network,
|
||||
final Set<List<MLDataPair>> training) {
|
||||
initBasicTraining(TrainingImplementationType.Iterative);
|
||||
this.network = network;
|
||||
this.currentFlatNetwork = network.getFlat();
|
||||
setTraining(training);
|
||||
|
||||
this.gradients = new double[this.currentFlatNetwork.getWeights().length];
|
||||
this.lastGradient = new double[this.currentFlatNetwork.getWeights().length];
|
||||
|
||||
this.indexable = training;
|
||||
this.numThreads = 0;
|
||||
this.reportedException = null;
|
||||
this.shouldFixFlatSpot = true;
|
||||
}
|
||||
|
||||
private void initBasicTraining(TrainingImplementationType implementationType) {
|
||||
this.implementationType = implementationType;
|
||||
}
|
||||
|
||||
// Methods from BackPropagation
|
||||
@Override
|
||||
public boolean canContinue() {
|
||||
return false;
|
||||
}
|
||||
|
||||
public double[] getLastDelta() {
|
||||
return this.lastDelta;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getLearningRate() {
|
||||
return this.learningRate;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getMomentum() {
|
||||
return this.momentum;
|
||||
}
|
||||
|
||||
public boolean isValidResume(final TrainingContinuation state) {
|
||||
if (!state.getContents().containsKey(Backpropagation.LAST_DELTA)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!state.getTrainingType().equals(getClass().getSimpleName())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
final double[] d = (double[]) state.get(Backpropagation.LAST_DELTA);
|
||||
return d.length == ((ContainsFlat) getMethod()).getFlat().getWeights().length;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TrainingContinuation pause() {
|
||||
final TrainingContinuation result = new TrainingContinuation();
|
||||
result.setTrainingType(this.getClass().getSimpleName());
|
||||
result.set(Backpropagation.LAST_DELTA, this.lastDelta);
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void resume(final TrainingContinuation state) {
|
||||
if (!isValidResume(state)) {
|
||||
throw new TrainingError("Invalid training resume data length");
|
||||
}
|
||||
|
||||
this.lastDelta = ((double[]) state.get(Backpropagation.LAST_DELTA));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setLearningRate(final double rate) {
|
||||
this.learningRate = rate;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setMomentum(final double m) {
|
||||
this.momentum = m;
|
||||
}
|
||||
|
||||
public double updateWeight(final double[] gradients,
|
||||
final double[] lastGradient, final int index) {
|
||||
final double delta = (gradients[index] * this.learningRate)
|
||||
+ (this.lastDelta[index] * this.momentum);
|
||||
this.lastDelta[index] = delta;
|
||||
|
||||
System.out.println("Updating weights for connection: " + index
|
||||
+ " with lambda: " + lambda);
|
||||
|
||||
return delta;
|
||||
}
|
||||
|
||||
public void initOthers() {
|
||||
}
|
||||
|
||||
// End methods from BackPropagation
|
||||
|
||||
// Methods from Propagation
|
||||
public void finishTraining() {
|
||||
basicFinishTraining();
|
||||
}
|
||||
|
||||
public FlatNetwork getCurrentFlatNetwork() {
|
||||
return this.currentFlatNetwork;
|
||||
}
|
||||
|
||||
public MLMethod getMethod() {
|
||||
return this.network;
|
||||
}
|
||||
|
||||
public void iteration() {
|
||||
iteration(1);
|
||||
}
|
||||
|
||||
public void rollIteration() {
|
||||
this.iteration++;
|
||||
}
|
||||
|
||||
public void iteration(final int count) {
|
||||
|
||||
try {
|
||||
for (int i = 0; i < count; i++) {
|
||||
|
||||
preIteration();
|
||||
|
||||
rollIteration();
|
||||
|
||||
calculateGradients();
|
||||
|
||||
if (this.currentFlatNetwork.isLimited()) {
|
||||
learnLimited();
|
||||
} else {
|
||||
learn();
|
||||
}
|
||||
|
||||
this.lastError = this.getError();
|
||||
|
||||
for (final GradientWorker worker : this.workers) {
|
||||
EngineArray.arrayCopy(this.currentFlatNetwork.getWeights(),
|
||||
0, worker.getWeights(), 0,
|
||||
this.currentFlatNetwork.getWeights().length);
|
||||
}
|
||||
|
||||
if (this.currentFlatNetwork.getHasContext()) {
|
||||
copyContexts();
|
||||
}
|
||||
|
||||
if (this.reportedException != null) {
|
||||
throw (new EncogError(this.reportedException));
|
||||
}
|
||||
|
||||
postIteration();
|
||||
|
||||
EncogLogging.log(EncogLogging.LEVEL_INFO,
|
||||
"Training iteration done, error: " + getError());
|
||||
|
||||
}
|
||||
} catch (final ArrayIndexOutOfBoundsException ex) {
|
||||
EncogValidate.validateNetworkForTraining(this.network,
|
||||
getTraining());
|
||||
throw new EncogError(ex);
|
||||
}
|
||||
}
|
||||
|
||||
public void setThreadCount(final int numThreads) {
|
||||
this.numThreads = numThreads;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getThreadCount() {
|
||||
return this.numThreads;
|
||||
}
|
||||
|
||||
public void fixFlatSpot(boolean b) {
|
||||
this.shouldFixFlatSpot = b;
|
||||
}
|
||||
|
||||
public void setErrorFunction(ErrorFunction ef) {
|
||||
this.ef = ef;
|
||||
}
|
||||
|
||||
public void calculateGradients() {
|
||||
if (this.workers == null) {
|
||||
init();
|
||||
}
|
||||
|
||||
if (this.currentFlatNetwork.getHasContext()) {
|
||||
this.workers[0].getNetwork().clearContext();
|
||||
}
|
||||
|
||||
this.totalError = 0;
|
||||
|
||||
if (this.workers.length > 1) {
|
||||
|
||||
final TaskGroup group = EngineConcurrency.getInstance()
|
||||
.createTaskGroup();
|
||||
|
||||
for (final GradientWorker worker : this.workers) {
|
||||
EngineConcurrency.getInstance().processTask(worker, group);
|
||||
}
|
||||
|
||||
group.waitForComplete();
|
||||
} else {
|
||||
this.workers[0].run();
|
||||
}
|
||||
|
||||
this.setError(this.totalError / this.workers.length);
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Copy the contexts to keep them consistent with multithreaded training.
|
||||
*/
|
||||
private void copyContexts() {
|
||||
|
||||
// copy the contexts(layer outputO from each group to the next group
|
||||
for (int i = 0; i < (this.workers.length - 1); i++) {
|
||||
final double[] src = this.workers[i].getNetwork().getLayerOutput();
|
||||
final double[] dst = this.workers[i + 1].getNetwork()
|
||||
.getLayerOutput();
|
||||
EngineArray.arrayCopy(src, dst);
|
||||
}
|
||||
|
||||
// copy the contexts from the final group to the real network
|
||||
EngineArray.arrayCopy(this.workers[this.workers.length - 1]
|
||||
.getNetwork().getLayerOutput(), this.currentFlatNetwork
|
||||
.getLayerOutput());
|
||||
}
|
||||
|
||||
private void init() {
|
||||
// fix flat spot, if needed
|
||||
this.flatSpot = new double[this.currentFlatNetwork
|
||||
.getActivationFunctions().length];
|
||||
|
||||
if (this.shouldFixFlatSpot) {
|
||||
for (int i = 0; i < this.currentFlatNetwork
|
||||
.getActivationFunctions().length; i++) {
|
||||
final ActivationFunction af = this.currentFlatNetwork
|
||||
.getActivationFunctions()[i];
|
||||
|
||||
if (af instanceof ActivationSigmoid) {
|
||||
this.flatSpot[i] = 0.1;
|
||||
} else {
|
||||
this.flatSpot[i] = 0.0;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
EngineArray.fill(this.flatSpot, 0.0);
|
||||
}
|
||||
|
||||
// setup workers
|
||||
final DetermineWorkload determine = new DetermineWorkload(
|
||||
this.numThreads, (int) this.indexable.size());
|
||||
// this.numThreads, (int) this.indexable.getRecordCount());
|
||||
|
||||
this.workers = new GradientWorker[determine.getThreadCount()];
|
||||
|
||||
int index = 0;
|
||||
|
||||
// handle CPU
|
||||
for (final IntRange r : determine.calculateWorkers()) {
|
||||
this.workers[index++] = new GradientWorker(
|
||||
this.currentFlatNetwork.clone(), this, new HashSet(
|
||||
this.indexable), r.getLow(), r.getHigh(),
|
||||
this.flatSpot, this.ef);
|
||||
}
|
||||
|
||||
initOthers();
|
||||
}
|
||||
|
||||
public void report(final double[] gradients, final double error,
|
||||
final Throwable ex) {
|
||||
synchronized (this) {
|
||||
if (ex == null) {
|
||||
|
||||
for (int i = 0; i < gradients.length; i++) {
|
||||
this.gradients[i] += gradients[i];
|
||||
}
|
||||
this.totalError += error;
|
||||
} else {
|
||||
this.reportedException = ex;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
protected void learn() {
|
||||
final double[] weights = this.currentFlatNetwork.getWeights();
|
||||
for (int i = 0; i < this.gradients.length; i++) {
|
||||
weights[i] += updateWeight(this.gradients, this.lastGradient, i);
|
||||
this.gradients[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
protected void learnLimited() {
|
||||
final double limit = this.currentFlatNetwork.getConnectionLimit();
|
||||
final double[] weights = this.currentFlatNetwork.getWeights();
|
||||
for (int i = 0; i < this.gradients.length; i++) {
|
||||
if (Math.abs(weights[i]) < limit) {
|
||||
weights[i] = 0;
|
||||
} else {
|
||||
weights[i] += updateWeight(this.gradients, this.lastGradient, i);
|
||||
}
|
||||
this.gradients[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
public double[] getLastGradient() {
|
||||
return lastGradient;
|
||||
}
|
||||
|
||||
// End methods from Propagation
|
||||
|
||||
// Methods from BasicTraining/
|
||||
public void addStrategy(final Strategy strategy) {
|
||||
strategy.init(this);
|
||||
this.strategies.add(strategy);
|
||||
}
|
||||
|
||||
public void basicFinishTraining() {
|
||||
}
|
||||
|
||||
public double getError() {
|
||||
return this.error;
|
||||
}
|
||||
|
||||
public int getIteration() {
|
||||
return this.iteration;
|
||||
}
|
||||
|
||||
public List<Strategy> getStrategies() {
|
||||
return this.strategies;
|
||||
}
|
||||
|
||||
public MLDataSet getTraining() {
|
||||
throw new UnsupportedOperationException(
|
||||
"This learning method operates on Set<List<MLData>>, not MLDataSet");
|
||||
}
|
||||
|
||||
public boolean isTrainingDone() {
|
||||
for (Strategy strategy : this.strategies) {
|
||||
if (strategy instanceof EndTrainingStrategy) {
|
||||
EndTrainingStrategy end = (EndTrainingStrategy) strategy;
|
||||
if (end.shouldStop()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
public void postIteration() {
|
||||
for (final Strategy strategy : this.strategies) {
|
||||
strategy.postIteration();
|
||||
}
|
||||
}
|
||||
|
||||
public void preIteration() {
|
||||
|
||||
this.iteration++;
|
||||
|
||||
for (final Strategy strategy : this.strategies) {
|
||||
strategy.preIteration();
|
||||
}
|
||||
}
|
||||
|
||||
public void setError(final double error) {
|
||||
this.error = error;
|
||||
}
|
||||
|
||||
public void setIteration(final int iteration) {
|
||||
this.iteration = iteration;
|
||||
}
|
||||
|
||||
public void setTraining(final Set<List<MLDataPair>> training) {
|
||||
this.training = training;
|
||||
}
|
||||
|
||||
public TrainingImplementationType getImplementationType() {
|
||||
return this.implementationType;
|
||||
}
|
||||
// End Methods from BasicTraining
|
||||
}
|
||||
43
src/net/woodyfolsom/msproj/ann/TrainingMethod.java
Normal file
43
src/net/woodyfolsom/msproj/ann/TrainingMethod.java
Normal file
@@ -0,0 +1,43 @@
|
||||
package net.woodyfolsom.msproj.ann;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import net.woodyfolsom.msproj.ann.math.ErrorFunction;
|
||||
import net.woodyfolsom.msproj.ann.math.MSSE;
|
||||
|
||||
public abstract class TrainingMethod {
|
||||
protected final ErrorFunction errorFunction;
|
||||
|
||||
public TrainingMethod() {
|
||||
this.errorFunction = MSSE.function;
|
||||
}
|
||||
|
||||
protected abstract void iteratePattern(FeedforwardNetwork neuralNetwork,
|
||||
NNDataPair statePair, NNData nextReward);
|
||||
|
||||
protected abstract void iteratePatterns(FeedforwardNetwork neuralNetwork,
|
||||
List<NNDataPair> trainingSet);
|
||||
|
||||
protected abstract double computePatternError(FeedforwardNetwork neuralNetwork,
|
||||
List<NNDataPair> trainingSet);
|
||||
|
||||
protected abstract void iterateSequences(FeedforwardNetwork neuralNetwork,
|
||||
List<List<NNDataPair>> trainingSet);
|
||||
|
||||
protected abstract void backPropagate(FeedforwardNetwork neuralNetwork, NNData output);
|
||||
|
||||
protected abstract double computeSequenceError(FeedforwardNetwork neuralNetwork,
|
||||
List<List<NNDataPair>> trainingSet);
|
||||
|
||||
protected void zeroGradients(FeedforwardNetwork neuralNetwork) {
|
||||
for (Neuron neuron : neuralNetwork.getNeurons()) {
|
||||
neuron.setGradient(0.0);
|
||||
}
|
||||
}
|
||||
|
||||
protected void zeroTraces(FeedforwardNetwork neuralNetwork) {
|
||||
for (Connection conn : neuralNetwork.getConnections()) {
|
||||
conn.setTrace(0.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,112 +0,0 @@
|
||||
package net.woodyfolsom.msproj.ann;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
import net.woodyfolsom.msproj.GameState;
|
||||
import net.woodyfolsom.msproj.Player;
|
||||
|
||||
import org.encog.engine.network.activation.ActivationSigmoid;
|
||||
import org.encog.ml.data.MLData;
|
||||
import org.encog.ml.data.MLDataPair;
|
||||
import org.encog.ml.data.MLDataSet;
|
||||
import org.encog.ml.train.MLTrain;
|
||||
import org.encog.neural.networks.BasicNetwork;
|
||||
import org.encog.neural.networks.layers.BasicLayer;
|
||||
|
||||
public class WinFilter extends AbstractNeuralNetFilter implements
|
||||
NeuralNetFilter {
|
||||
|
||||
public WinFilter() {
|
||||
// create a neural network, without using a factory
|
||||
BasicNetwork network = new BasicNetwork();
|
||||
network.addLayer(new BasicLayer(null, false, 2));
|
||||
network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 4));
|
||||
network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 2));
|
||||
network.getStructure().finalizeStructure();
|
||||
network.reset();
|
||||
|
||||
this.neuralNetwork = network;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double computeValue(MLData input) {
|
||||
if (input instanceof GameStateMLData) {
|
||||
double[] idealVector = computeVector(input);
|
||||
GameState gameState = ((GameStateMLData) input).getGameState();
|
||||
Player playerToMove = gameState.getPlayerToMove();
|
||||
if (playerToMove == Player.BLACK) {
|
||||
return idealVector[0];
|
||||
} else if (playerToMove == Player.WHITE) {
|
||||
return idealVector[1];
|
||||
} else {
|
||||
throw new RuntimeException("Invalid GameState.playerToMove: "
|
||||
+ playerToMove);
|
||||
}
|
||||
} else {
|
||||
throw new UnsupportedOperationException(
|
||||
"This NeuralNetFilter only accepts GameStates as input.");
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public double[] computeVector(MLData input) {
|
||||
if (input instanceof GameStateMLData) {
|
||||
return neuralNetwork.compute(input).getData();
|
||||
} else {
|
||||
throw new UnsupportedOperationException(
|
||||
"This NeuralNetFilter only accepts GameStates as input.");
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void learn(MLDataSet trainingData) {
|
||||
throw new UnsupportedOperationException("This filter learns a Set<List<MLData>>, not an MLDataSet");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void learn(Set<List<MLDataPair>> trainingSet) {
|
||||
|
||||
// train the neural network
|
||||
final MLTrain train = new TemporalDifferenceLearning(neuralNetwork,
|
||||
trainingSet, 0.7, 0.8, 0.25);
|
||||
|
||||
actualTrainingEpochs = 0;
|
||||
|
||||
do {
|
||||
train.iteration();
|
||||
System.out.println("Epoch #" + actualTrainingEpochs + " Error:"
|
||||
+ train.getError());
|
||||
actualTrainingEpochs++;
|
||||
} while (train.getError() > 0.01
|
||||
&& actualTrainingEpochs <= maxTrainingEpochs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void reset() {
|
||||
neuralNetwork.reset();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void reset(int seed) {
|
||||
neuralNetwork.reset(seed);
|
||||
}
|
||||
|
||||
@Override
|
||||
public BasicNetwork getNeuralNetwork() {
|
||||
// TODO Auto-generated method stub
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getInputSize() {
|
||||
// TODO Auto-generated method stub
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getOutputSize() {
|
||||
// TODO Auto-generated method stub
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
@@ -1,18 +1,5 @@
|
||||
package net.woodyfolsom.msproj.ann;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
import org.encog.engine.network.activation.ActivationSigmoid;
|
||||
import org.encog.ml.data.MLData;
|
||||
import org.encog.ml.data.MLDataPair;
|
||||
import org.encog.ml.data.MLDataSet;
|
||||
import org.encog.ml.data.basic.BasicMLDataSet;
|
||||
import org.encog.ml.train.MLTrain;
|
||||
import org.encog.neural.networks.BasicNetwork;
|
||||
import org.encog.neural.networks.layers.BasicLayer;
|
||||
import org.encog.neural.networks.training.propagation.back.Backpropagation;
|
||||
|
||||
/**
|
||||
* Based on sample code from http://neuroph.sourceforge.net
|
||||
*
|
||||
@@ -21,63 +8,31 @@ import org.encog.neural.networks.training.propagation.back.Backpropagation;
|
||||
*/
|
||||
public class XORFilter extends AbstractNeuralNetFilter implements
|
||||
NeuralNetFilter {
|
||||
|
||||
|
||||
private static final int INPUT_SIZE = 2;
|
||||
private static final int OUTPUT_SIZE = 1;
|
||||
|
||||
public XORFilter() {
|
||||
// create a neural network, without using a factory
|
||||
BasicNetwork network = new BasicNetwork();
|
||||
network.addLayer(new BasicLayer(null, false, 2));
|
||||
network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 3));
|
||||
network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 1));
|
||||
network.getStructure().finalizeStructure();
|
||||
network.reset();
|
||||
|
||||
this.neuralNetwork = network;
|
||||
this(0.8,0.7);
|
||||
}
|
||||
|
||||
public XORFilter(double learningRate, double momentum) {
|
||||
super( new MultiLayerPerceptron(true, INPUT_SIZE, 2, OUTPUT_SIZE),
|
||||
new BackPropagation(learningRate, momentum), 1000, 0.001);
|
||||
super.getNeuralNetwork().setName("XORFilter");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void learn(MLDataSet trainingSet) {
|
||||
|
||||
// train the neural network
|
||||
final MLTrain train = new Backpropagation(neuralNetwork,
|
||||
trainingSet, 0.7, 0.8);
|
||||
|
||||
actualTrainingEpochs = 0;
|
||||
|
||||
do {
|
||||
train.iteration();
|
||||
System.out.println("Epoch #" + actualTrainingEpochs + " Error:"
|
||||
+ train.getError());
|
||||
actualTrainingEpochs++;
|
||||
} while (train.getError() > 0.01
|
||||
&& actualTrainingEpochs <= maxTrainingEpochs);
|
||||
public double compute(double x, double y) {
|
||||
return getNeuralNetwork().compute(new double[]{x,y})[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public double[] computeVector(MLData mlData) {
|
||||
MLDataSet dataset = new BasicMLDataSet(new double[][] { mlData.getData() },
|
||||
new double[][] { new double[getOutputSize()] });
|
||||
MLData output = neuralNetwork.compute(dataset.get(0).getInput());
|
||||
return output.getData();
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public int getInputSize() {
|
||||
return 2;
|
||||
return INPUT_SIZE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getOutputSize() {
|
||||
// TODO Auto-generated method stub
|
||||
return 1;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double computeValue(MLData input) {
|
||||
return computeVector(input)[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public void learn(Set<List<MLDataPair>> trainingSet) {
|
||||
throw new UnsupportedOperationException("This Filter learns an MLDataSet, not a Set<List<MLData>>.");
|
||||
return OUTPUT_SIZE;
|
||||
}
|
||||
}
|
||||
53
src/net/woodyfolsom/msproj/ann/math/ActivationFunction.java
Normal file
53
src/net/woodyfolsom/msproj/ann/math/ActivationFunction.java
Normal file
@@ -0,0 +1,53 @@
|
||||
package net.woodyfolsom.msproj.ann.math;
|
||||
|
||||
import javax.xml.bind.annotation.XmlAttribute;
|
||||
import javax.xml.bind.annotation.XmlTransient;
|
||||
|
||||
@XmlTransient
|
||||
public abstract class ActivationFunction {
|
||||
private String name;
|
||||
|
||||
public abstract double calculate(double arg);
|
||||
public abstract double derivative(double arg);
|
||||
|
||||
public ActivationFunction() {
|
||||
//no-arg constructor for JAXB
|
||||
}
|
||||
|
||||
public ActivationFunction(String name) {
|
||||
this.name = name;
|
||||
}
|
||||
|
||||
@XmlAttribute
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
|
||||
public void setName(String name) {
|
||||
this.name = name;
|
||||
}
|
||||
@Override
|
||||
public int hashCode() {
|
||||
final int prime = 31;
|
||||
int result = 1;
|
||||
result = prime * result + ((name == null) ? 0 : name.hashCode());
|
||||
return result;
|
||||
}
|
||||
@Override
|
||||
public boolean equals(Object obj) {
|
||||
if (this == obj)
|
||||
return true;
|
||||
if (obj == null)
|
||||
return false;
|
||||
if (getClass() != obj.getClass())
|
||||
return false;
|
||||
ActivationFunction other = (ActivationFunction) obj;
|
||||
if (name == null) {
|
||||
if (other.name != null)
|
||||
return false;
|
||||
} else if (!name.equals(other.name))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
}
|
||||
5
src/net/woodyfolsom/msproj/ann/math/ErrorFunction.java
Normal file
5
src/net/woodyfolsom/msproj/ann/math/ErrorFunction.java
Normal file
@@ -0,0 +1,5 @@
|
||||
package net.woodyfolsom.msproj.ann.math;
|
||||
|
||||
public interface ErrorFunction {
|
||||
double compute(double[] ideal, double[] actual);
|
||||
}
|
||||
19
src/net/woodyfolsom/msproj/ann/math/Linear.java
Normal file
19
src/net/woodyfolsom/msproj/ann/math/Linear.java
Normal file
@@ -0,0 +1,19 @@
|
||||
package net.woodyfolsom.msproj.ann.math;
|
||||
|
||||
import javax.xml.bind.annotation.XmlRootElement;
|
||||
|
||||
public class Linear extends ActivationFunction{
|
||||
public static final Linear function = new Linear();
|
||||
|
||||
private Linear() {
|
||||
super("Linear");
|
||||
}
|
||||
|
||||
public double calculate(double arg) {
|
||||
return arg;
|
||||
}
|
||||
|
||||
public double derivative(double arg) {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
22
src/net/woodyfolsom/msproj/ann/math/MSSE.java
Normal file
22
src/net/woodyfolsom/msproj/ann/math/MSSE.java
Normal file
@@ -0,0 +1,22 @@
|
||||
package net.woodyfolsom.msproj.ann.math;
|
||||
|
||||
public class MSSE implements ErrorFunction{
|
||||
public static final ErrorFunction function = new MSSE();
|
||||
|
||||
public double compute(double[] ideal, double[] actual) {
|
||||
int idealSize = ideal.length;
|
||||
int actualSize = actual.length;
|
||||
|
||||
if (idealSize != actualSize) {
|
||||
throw new IllegalArgumentException("actualSize != idealSize");
|
||||
}
|
||||
|
||||
double SSE = 0.0;
|
||||
|
||||
for (int i = 0; i < idealSize; i++) {
|
||||
SSE += Math.pow(ideal[i] - actual[i], 2);
|
||||
}
|
||||
|
||||
return SSE / idealSize;
|
||||
}
|
||||
}
|
||||
20
src/net/woodyfolsom/msproj/ann/math/Sigmoid.java
Normal file
20
src/net/woodyfolsom/msproj/ann/math/Sigmoid.java
Normal file
@@ -0,0 +1,20 @@
|
||||
package net.woodyfolsom.msproj.ann.math;
|
||||
|
||||
import javax.xml.bind.annotation.XmlRootElement;
|
||||
|
||||
public class Sigmoid extends ActivationFunction{
|
||||
public static final Sigmoid function = new Sigmoid();
|
||||
|
||||
private Sigmoid() {
|
||||
super("Sigmoid");
|
||||
}
|
||||
|
||||
public double calculate(double arg) {
|
||||
return 1.0 / (1 + Math.pow(Math.E, -1.0 * arg));
|
||||
}
|
||||
|
||||
public double derivative(double arg) {
|
||||
double eX = Math.exp(arg);
|
||||
return eX / (Math.pow((1+eX), 2));
|
||||
}
|
||||
}
|
||||
23
src/net/woodyfolsom/msproj/ann/math/Tanh.java
Normal file
23
src/net/woodyfolsom/msproj/ann/math/Tanh.java
Normal file
@@ -0,0 +1,23 @@
|
||||
package net.woodyfolsom.msproj.ann.math;
|
||||
|
||||
import javax.xml.bind.annotation.XmlRootElement;
|
||||
|
||||
public class Tanh extends ActivationFunction{
|
||||
public static final Tanh function = new Tanh();
|
||||
|
||||
public Tanh() {
|
||||
super("Tanh");
|
||||
}
|
||||
|
||||
@Override
|
||||
public double calculate(double arg) {
|
||||
return Math.tanh(arg);
|
||||
}
|
||||
|
||||
@Override
|
||||
public double derivative(double arg) {
|
||||
double tanh = Math.tanh(arg);
|
||||
return 1 - Math.pow(tanh, 2);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -16,6 +16,15 @@ public class AlphaBeta implements Policy {
|
||||
|
||||
private final ValidMoveGenerator validMoveGenerator = new ValidMoveGenerator();
|
||||
|
||||
private boolean logging = false;
|
||||
public boolean isLogging() {
|
||||
return logging;
|
||||
}
|
||||
|
||||
public void setLogging(boolean logging) {
|
||||
this.logging = logging;
|
||||
}
|
||||
|
||||
private int lookAhead;
|
||||
private int numStateEvaluations = 0;
|
||||
|
||||
@@ -182,4 +191,9 @@ public class AlphaBeta implements Policy {
|
||||
// TODO Auto-generated method stub
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return "Alpha-Beta";
|
||||
}
|
||||
}
|
||||
@@ -9,6 +9,15 @@ import net.woodyfolsom.msproj.Player;
|
||||
import net.woodyfolsom.msproj.gui.Goban;
|
||||
|
||||
public class HumanGuiInput implements Policy {
|
||||
private boolean logging;
|
||||
public boolean isLogging() {
|
||||
return logging;
|
||||
}
|
||||
|
||||
public void setLogging(boolean logging) {
|
||||
this.logging = logging;
|
||||
}
|
||||
|
||||
private Goban goban;
|
||||
|
||||
public HumanGuiInput(Goban goban) {
|
||||
@@ -52,4 +61,9 @@ public class HumanGuiInput implements Policy {
|
||||
goban.setGameState(gameState);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return "HumanGUI";
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -9,6 +9,15 @@ import net.woodyfolsom.msproj.GameState;
|
||||
import net.woodyfolsom.msproj.Player;
|
||||
|
||||
public class HumanKeyboardInput implements Policy {
|
||||
private boolean logging = false;
|
||||
|
||||
public boolean isLogging() {
|
||||
return logging;
|
||||
}
|
||||
|
||||
public void setLogging(boolean logging) {
|
||||
this.logging = logging;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Action getAction(GameConfig gameConfig, GameState gameState,
|
||||
@@ -76,4 +85,9 @@ public class HumanKeyboardInput implements Policy {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return "HumanKeyboard";
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -16,6 +16,15 @@ public class Minimax implements Policy {
|
||||
|
||||
private final ValidMoveGenerator validMoveGenerator = new ValidMoveGenerator();
|
||||
|
||||
private boolean logging = false;
|
||||
public boolean isLogging() {
|
||||
return logging;
|
||||
}
|
||||
|
||||
public void setLogging(boolean logging) {
|
||||
this.logging = logging;
|
||||
}
|
||||
|
||||
private int lookAhead;
|
||||
private int numStateEvaluations = 0;
|
||||
|
||||
@@ -152,7 +161,10 @@ public class Minimax implements Policy {
|
||||
|
||||
@Override
|
||||
public void setState(GameState gameState) {
|
||||
// TODO Auto-generated method stub
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return "Minimax";
|
||||
}
|
||||
}
|
||||
@@ -15,6 +15,15 @@ import net.woodyfolsom.msproj.tree.MonteCarloProperties;
|
||||
public abstract class MonteCarlo implements Policy {
|
||||
protected static final int ROLLOUT_DEPTH_LIMIT = 250;
|
||||
|
||||
private boolean logging = false;
|
||||
public boolean isLogging() {
|
||||
return logging;
|
||||
}
|
||||
|
||||
public void setLogging(boolean logging) {
|
||||
this.logging = logging;
|
||||
}
|
||||
|
||||
protected int numStateEvaluations = 0;
|
||||
protected Policy movePolicy;
|
||||
|
||||
|
||||
@@ -63,6 +63,45 @@ public class MonteCarloAMAF extends MonteCarloUCT {
|
||||
rootGameState, new AMAFProperties());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Action getBestAction(GameTreeNode<MonteCarloProperties> node) {
|
||||
Action bestAction = Action.NONE;
|
||||
double bestScore = Double.NEGATIVE_INFINITY;
|
||||
GameTreeNode<MonteCarloProperties> bestChild = null;
|
||||
|
||||
for (Action action : node.getActions()) {
|
||||
GameTreeNode<MonteCarloProperties> childNode = node
|
||||
.getChild(action);
|
||||
|
||||
AMAFProperties childProps = (AMAFProperties)childNode.getProperties();
|
||||
double childScore = childProps.getAmafWins() / (double)childProps.getAmafVisits();
|
||||
|
||||
if (childScore >= bestScore) {
|
||||
bestScore = childScore;
|
||||
bestAction = action;
|
||||
bestChild = childNode;
|
||||
}
|
||||
}
|
||||
|
||||
if (bestAction == Action.NONE) {
|
||||
System.out
|
||||
.println(getName() + " failed - no actions were found for the current game state (not even PASS).");
|
||||
} else {
|
||||
if (isLogging()) {
|
||||
System.out.println("Action " + bestAction + " selected for "
|
||||
+ node.getGameState().getPlayerToMove()
|
||||
+ " with simulated win ratio of "
|
||||
+ (bestScore * 100.0 + "%"));
|
||||
System.out.println("It was visited "
|
||||
+ bestChild.getProperties().getVisits() + " times out of "
|
||||
+ node.getProperties().getVisits() + " rollouts among "
|
||||
+ node.getNumChildren()
|
||||
+ " valid actions from the current state.");
|
||||
}
|
||||
}
|
||||
return bestAction;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected double getNodeScore(GameTreeNode<MonteCarloProperties> gameTreeNode) {
|
||||
//double nodeVisits = gameTreeNode.getParent().getProperties().getVisits();
|
||||
@@ -72,16 +111,8 @@ public class MonteCarloAMAF extends MonteCarloUCT {
|
||||
if (gameTreeNode.getGameState().isTerminal()) {
|
||||
nodeScore = 0.0;
|
||||
} else {
|
||||
/*
|
||||
MonteCarloProperties properties = gameTreeNode.getProperties();
|
||||
nodeScore = (double) (properties.getWins() / properties
|
||||
.getVisits())
|
||||
+ (TUNING_CONSTANT * Math.sqrt(Math.log(nodeVisits)
|
||||
/ gameTreeNode.getProperties().getVisits()));
|
||||
*
|
||||
*/
|
||||
AMAFProperties properties = (AMAFProperties) gameTreeNode.getProperties();
|
||||
nodeScore = (double) (properties.getAmafWins() / properties
|
||||
nodeScore = (properties.getAmafWins() / (double) properties
|
||||
.getAmafVisits())
|
||||
+ (TUNING_CONSTANT * Math.sqrt(Math.log(parentAmafVisits)
|
||||
/ properties.getAmafVisits()));
|
||||
@@ -103,4 +134,9 @@ public class MonteCarloAMAF extends MonteCarloUCT {
|
||||
node.addChild(action, newChild);
|
||||
return newChildren;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return "UCT-RAVE";
|
||||
}
|
||||
}
|
||||
63
src/net/woodyfolsom/msproj/policy/MonteCarloSMAF.java
Normal file
63
src/net/woodyfolsom/msproj/policy/MonteCarloSMAF.java
Normal file
@@ -0,0 +1,63 @@
|
||||
package net.woodyfolsom.msproj.policy;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import net.woodyfolsom.msproj.Action;
|
||||
import net.woodyfolsom.msproj.Player;
|
||||
import net.woodyfolsom.msproj.tree.AMAFProperties;
|
||||
import net.woodyfolsom.msproj.tree.GameTreeNode;
|
||||
import net.woodyfolsom.msproj.tree.MonteCarloProperties;
|
||||
|
||||
public class MonteCarloSMAF extends MonteCarloAMAF {
|
||||
private int horizon;
|
||||
|
||||
public MonteCarloSMAF(Policy movePolicy, long searchTimeLimit, int horizon) {
|
||||
super(movePolicy, searchTimeLimit);
|
||||
this.horizon = horizon;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void update(GameTreeNode<MonteCarloProperties> node, Rollout rollout) {
|
||||
GameTreeNode<MonteCarloProperties> currentNode = node;
|
||||
//List<Action> subTreeActions = new ArrayList<Action>(rollout.getPlayout());
|
||||
|
||||
List<Action> playout = rollout.getPlayout();
|
||||
int reward = rollout.getReward();
|
||||
while (currentNode != null) {
|
||||
AMAFProperties nodeProperties = (AMAFProperties)currentNode.getProperties();
|
||||
|
||||
//Always update props for the current node
|
||||
nodeProperties.setWins(nodeProperties.getWins() + reward);
|
||||
nodeProperties.setVisits(nodeProperties.getVisits() + 1);
|
||||
nodeProperties.setAmafWins(nodeProperties.getAmafWins() + reward);
|
||||
nodeProperties.setAmafVisits(nodeProperties.getAmafVisits() + 1);
|
||||
|
||||
GameTreeNode<MonteCarloProperties> parentNode = currentNode.getParent();
|
||||
if (parentNode != null) {
|
||||
Player playerToMove = parentNode.getGameState().getPlayerToMove();
|
||||
for (Action actionFromParent : parentNode.getActions()) {
|
||||
if (playout.subList(0, Math.max(horizon,playout.size())).contains(actionFromParent)) {
|
||||
GameTreeNode<MonteCarloProperties> subTreeChild = parentNode.getChild(actionFromParent);
|
||||
//Don't count AMAF properties for the current node twice
|
||||
if (subTreeChild == currentNode) {
|
||||
continue;
|
||||
}
|
||||
|
||||
AMAFProperties siblingProperties = (AMAFProperties)subTreeChild.getProperties();
|
||||
//Only update AMAF properties if the sibling is reached by the same action with the same player to move
|
||||
if (rollout.hasPlay(playerToMove,actionFromParent)) {
|
||||
siblingProperties.setAmafWins(siblingProperties.getAmafWins() + reward);
|
||||
siblingProperties.setAmafVisits(siblingProperties.getAmafVisits() + 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
currentNode = currentNode.getParent();
|
||||
}
|
||||
}
|
||||
|
||||
public String getName() {
|
||||
return "MonteCarloSMAF";
|
||||
}
|
||||
}
|
||||
@@ -90,11 +90,8 @@ public class MonteCarloUCT extends MonteCarlo {
|
||||
GameTreeNode<MonteCarloProperties> childNode = node
|
||||
.getChild(action);
|
||||
|
||||
//MonteCarloProperties properties = childNode.getProperties();
|
||||
//double childScore = (double) properties.getWins()
|
||||
// / properties.getVisits();
|
||||
|
||||
double childScore = getNodeScore(childNode);
|
||||
MonteCarloProperties childProps = childNode.getProperties();
|
||||
double childScore = childProps.getWins() / (double)childProps.getVisits();
|
||||
|
||||
if (childScore >= bestScore) {
|
||||
bestScore = childScore;
|
||||
@@ -105,8 +102,9 @@ public class MonteCarloUCT extends MonteCarlo {
|
||||
|
||||
if (bestAction == Action.NONE) {
|
||||
System.out
|
||||
.println("MonteCarloUCT failed - no actions were found for the current game state (not even PASS).");
|
||||
.println(getName() + " failed - no actions were found for the current game state (not even PASS).");
|
||||
} else {
|
||||
if (isLogging()) {
|
||||
System.out.println("Action " + bestAction + " selected for "
|
||||
+ node.getGameState().getPlayerToMove()
|
||||
+ " with simulated win ratio of "
|
||||
@@ -116,6 +114,7 @@ public class MonteCarloUCT extends MonteCarlo {
|
||||
+ node.getProperties().getVisits() + " rollouts among "
|
||||
+ node.getNumChildren()
|
||||
+ " valid actions from the current state.");
|
||||
}
|
||||
}
|
||||
return bestAction;
|
||||
}
|
||||
@@ -233,4 +232,9 @@ public class MonteCarloUCT extends MonteCarlo {
|
||||
// TODO Auto-generated method stub
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return "MonteCarloUCT";
|
||||
}
|
||||
}
|
||||
144
src/net/woodyfolsom/msproj/policy/NeuralNetPolicy.java
Normal file
144
src/net/woodyfolsom/msproj/policy/NeuralNetPolicy.java
Normal file
@@ -0,0 +1,144 @@
|
||||
package net.woodyfolsom.msproj.policy;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
|
||||
import net.woodyfolsom.msproj.Action;
|
||||
import net.woodyfolsom.msproj.GameConfig;
|
||||
import net.woodyfolsom.msproj.GameState;
|
||||
import net.woodyfolsom.msproj.Player;
|
||||
import net.woodyfolsom.msproj.ann.FeedforwardNetwork;
|
||||
import net.woodyfolsom.msproj.ann.NNDataPair;
|
||||
import net.woodyfolsom.msproj.ann.PassFilterTrainer;
|
||||
import net.woodyfolsom.msproj.tictactoe.NNDataSetFactory;
|
||||
|
||||
public class NeuralNetPolicy implements Policy {
|
||||
|
||||
private FeedforwardNetwork moveFilter;
|
||||
private FeedforwardNetwork passFilter;
|
||||
|
||||
private ValidMoveGenerator validMoveGenerator = new ValidMoveGenerator();
|
||||
private Policy randomMovePolicy = new RandomMovePolicy();
|
||||
|
||||
public NeuralNetPolicy() {
|
||||
}
|
||||
|
||||
@Override
|
||||
public Action getAction(GameConfig gameConfig, GameState gameState,
|
||||
Player player) {
|
||||
//If passFilter != null, check for a strong PASS signal.
|
||||
if (passFilter != null) {
|
||||
GameState stateAfterPass = new GameState(gameState);
|
||||
if (!stateAfterPass.playStone(player, Action.PASS)) {
|
||||
throw new RuntimeException("Pass should always be valid, but playStone(" + player +", Action.PASS) failed.");
|
||||
}
|
||||
|
||||
NNDataPair passData = NNDataSetFactory.createDataPair(gameState,PassFilterTrainer.class);
|
||||
double estimatedValue = passFilter.compute(passData).getValues()[0];
|
||||
|
||||
//if losing and opponent passed, never pass
|
||||
//if (passData.getInput().getValues()[0] == -1.0 && passData.getInput().getValues()[1] == 1.0) {
|
||||
// estimatedValue = 0.0;
|
||||
//}
|
||||
|
||||
if (player == Player.BLACK && 0.6 < estimatedValue) {
|
||||
//System.out.println("NeuralNetwork estimates value of PASS at > 0.95 (BLACK) for " + passData.getInput());
|
||||
return Action.PASS;
|
||||
}
|
||||
if (player == Player.WHITE && 0.4 > estimatedValue) {
|
||||
//System.out.println("NeuralNetwork estimates value of PASS at > 0.95 (BLACK) for " + passData.getInput());
|
||||
return Action.PASS;
|
||||
}
|
||||
}
|
||||
//If moveFilter != null, calculate action estimates and return the best one.
|
||||
|
||||
//max # valid moves is 19x19+2 (any coord plus pass, resign).
|
||||
List<Action> validMoves = validMoveGenerator.getActions(gameConfig, gameState, player, 363);
|
||||
|
||||
if (moveFilter != null) {
|
||||
if (player == Player.BLACK) {
|
||||
double bestValue = Double.NEGATIVE_INFINITY;
|
||||
Action bestAction = Action.NONE;
|
||||
for (Action actionToTry : validMoves) {
|
||||
GameState stateAfterAction = new GameState(gameState);
|
||||
if (!stateAfterAction.playStone(player, actionToTry)) {
|
||||
throw new RuntimeException("Invalid move: " + actionToTry);
|
||||
}
|
||||
NNDataPair passData = NNDataSetFactory.createDataPair(stateAfterAction,PassFilterTrainer.class);
|
||||
double estimatedValue = passFilter.compute(passData).getValues()[0];
|
||||
|
||||
if (estimatedValue > bestValue) {
|
||||
bestAction = actionToTry;
|
||||
}
|
||||
}
|
||||
if (bestValue > 0.95) {
|
||||
return bestAction;
|
||||
}
|
||||
} else if (player == Player.WHITE) {
|
||||
double bestValue = Double.POSITIVE_INFINITY;
|
||||
Action bestAction = Action.NONE;
|
||||
for (Action actionToTry : validMoves) {
|
||||
GameState stateAfterAction = new GameState(gameState);
|
||||
if (!stateAfterAction.playStone(player, actionToTry)) {
|
||||
throw new RuntimeException("Invalid move: " + actionToTry);
|
||||
}
|
||||
NNDataPair passData = NNDataSetFactory.createDataPair(stateAfterAction,PassFilterTrainer.class);
|
||||
double estimatedValue = passFilter.compute(passData).getValues()[0];
|
||||
|
||||
if (estimatedValue > bestValue) {
|
||||
bestAction = actionToTry;
|
||||
}
|
||||
}
|
||||
if (bestValue > 0.95) {
|
||||
return bestAction;
|
||||
}
|
||||
} else {
|
||||
throw new RuntimeException("Invalid player: " + player);
|
||||
}
|
||||
}
|
||||
|
||||
//If no moves make the cutoff, just return a random move.
|
||||
return randomMovePolicy.getAction(gameConfig, gameState, player);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Action getAction(GameConfig gameConfig, GameState gameState,
|
||||
Collection<Action> prohibitedActions, Player player) {
|
||||
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getNumStateEvaluations() {
|
||||
return randomMovePolicy.getNumStateEvaluations();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setState(GameState gameState) {
|
||||
randomMovePolicy.setState(gameState);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isLogging() {
|
||||
return randomMovePolicy.isLogging();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setLogging(boolean logging) {
|
||||
randomMovePolicy.setLogging(logging);
|
||||
}
|
||||
|
||||
public void setMoveFilter(FeedforwardNetwork ffn) {
|
||||
this.moveFilter = ffn;
|
||||
}
|
||||
|
||||
public void setPassFilter(FeedforwardNetwork ffn) {
|
||||
this.passFilter = ffn;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return "NeuralNet" + (passFilter != null ? "-" + passFilter.getName() : "")
|
||||
+ (passFilter != null ? "-" + passFilter.getName() : "");
|
||||
}
|
||||
}
|
||||
@@ -8,13 +8,19 @@ import net.woodyfolsom.msproj.GameState;
|
||||
import net.woodyfolsom.msproj.Player;
|
||||
|
||||
public interface Policy {
|
||||
public Action getAction(GameConfig gameConfig, GameState gameState,
|
||||
Action getAction(GameConfig gameConfig, GameState gameState,
|
||||
Player player);
|
||||
|
||||
public Action getAction(GameConfig gameConfig, GameState gameState,
|
||||
Action getAction(GameConfig gameConfig, GameState gameState,
|
||||
Collection<Action> prohibitedActions, Player player);
|
||||
|
||||
public int getNumStateEvaluations();
|
||||
String getName();
|
||||
|
||||
public void setState(GameState gameState);
|
||||
int getNumStateEvaluations();
|
||||
|
||||
void setState(GameState gameState);
|
||||
|
||||
boolean isLogging();
|
||||
|
||||
void setLogging(boolean logging);
|
||||
}
|
||||
14
src/net/woodyfolsom/msproj/policy/PolicyFactory.java
Normal file
14
src/net/woodyfolsom/msproj/policy/PolicyFactory.java
Normal file
@@ -0,0 +1,14 @@
|
||||
package net.woodyfolsom.msproj.policy;
|
||||
|
||||
public class PolicyFactory {
|
||||
|
||||
public static Policy createNew(Policy policyPrototype) {
|
||||
if (policyPrototype instanceof RandomMovePolicy) {
|
||||
return new RandomMovePolicy();
|
||||
} else if (policyPrototype instanceof NeuralNetPolicy) {
|
||||
return new NeuralNetPolicy();
|
||||
} else {
|
||||
throw new UnsupportedOperationException("Can only create new NeuralNetPolicy, not " + policyPrototype.getName());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -110,6 +110,7 @@ public class RandomMovePolicy implements Policy, ActionGenerator {
|
||||
return randomAction;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isLogging() {
|
||||
return logging;
|
||||
}
|
||||
@@ -122,4 +123,9 @@ public class RandomMovePolicy implements Policy, ActionGenerator {
|
||||
public void setState(GameState gameState) {
|
||||
// TODO Auto-generated method stub
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return "Random";
|
||||
}
|
||||
}
|
||||
186
src/net/woodyfolsom/msproj/policy/RootParAMAF.java
Normal file
186
src/net/woodyfolsom/msproj/policy/RootParAMAF.java
Normal file
@@ -0,0 +1,186 @@
|
||||
package net.woodyfolsom.msproj.policy;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import net.woodyfolsom.msproj.Action;
|
||||
import net.woodyfolsom.msproj.GameConfig;
|
||||
import net.woodyfolsom.msproj.GameState;
|
||||
import net.woodyfolsom.msproj.Player;
|
||||
import net.woodyfolsom.msproj.tree.AMAFProperties;
|
||||
import net.woodyfolsom.msproj.tree.MonteCarloProperties;
|
||||
|
||||
public class RootParAMAF implements Policy {
|
||||
private boolean logging = false;
|
||||
private int numTrees = 1;
|
||||
private Policy rolloutPolicy;
|
||||
|
||||
public boolean isLogging() {
|
||||
return logging;
|
||||
}
|
||||
|
||||
public void setLogging(boolean logging) {
|
||||
this.logging = logging;
|
||||
}
|
||||
|
||||
private long timeLimit = 1000L;
|
||||
|
||||
public RootParAMAF(int numTrees, long timeLimit) {
|
||||
this.numTrees = numTrees;
|
||||
this.timeLimit = timeLimit;
|
||||
this.rolloutPolicy = new RandomMovePolicy();
|
||||
}
|
||||
|
||||
public RootParAMAF(int numTrees, Policy policyPrototype, long timeLimit) {
|
||||
this.numTrees = numTrees;
|
||||
this.timeLimit = timeLimit;
|
||||
this.rolloutPolicy = policyPrototype;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Action getAction(GameConfig gameConfig, GameState gameState,
|
||||
Player player) {
|
||||
Action bestAction = Action.NONE;
|
||||
|
||||
List<PolicyRunner> policyRunners = new ArrayList<PolicyRunner>();
|
||||
List<Thread> simulationThreads = new ArrayList<Thread>();
|
||||
|
||||
for (int i = 0; i < numTrees; i++) {
|
||||
|
||||
MonteCarlo policy = new MonteCarloAMAF(
|
||||
PolicyFactory.createNew(rolloutPolicy), timeLimit);
|
||||
|
||||
//policy.setLogging(true);
|
||||
|
||||
PolicyRunner policyRunner = new PolicyRunner(policy, gameConfig, gameState,
|
||||
player);
|
||||
policyRunners.add(policyRunner);
|
||||
|
||||
Thread simThread = new Thread(policyRunner);
|
||||
simulationThreads.add(simThread);
|
||||
}
|
||||
|
||||
for (Thread simThread : simulationThreads) {
|
||||
simThread.start();
|
||||
}
|
||||
|
||||
for (Thread simThread : simulationThreads) {
|
||||
try {
|
||||
simThread.join();
|
||||
} catch (InterruptedException ie) {
|
||||
System.out
|
||||
.println("Interrupted while waiting for Monte Carlo simulations to finish.");
|
||||
}
|
||||
}
|
||||
|
||||
Map<Action,Integer> totalReward = new HashMap<Action,Integer>();
|
||||
Map<Action,Integer> numSims = new HashMap<Action,Integer>();
|
||||
|
||||
for (PolicyRunner policyRunner : policyRunners) {
|
||||
Map<Action, MonteCarloProperties> qValues = policyRunner.getQvalues();
|
||||
for (Action action : qValues.keySet()) {
|
||||
if (totalReward.containsKey(action)) {
|
||||
totalReward.put(action, totalReward.get(action) + ((AMAFProperties)qValues.get(action)).getAmafWins());
|
||||
} else {
|
||||
totalReward.put(action, ((AMAFProperties)qValues.get(action)).getAmafWins());
|
||||
}
|
||||
if (numSims.containsKey(action)) {
|
||||
numSims.put(action, numSims.get(action) + ((AMAFProperties)qValues.get(action)).getAmafVisits());
|
||||
} else {
|
||||
numSims.put(action, ((AMAFProperties)qValues.get(action)).getAmafVisits());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
double bestValue = 0.0;
|
||||
int totalRollouts = 0;
|
||||
int bestWins = 0;
|
||||
int bestSims = 0;
|
||||
|
||||
for (Action action : totalReward.keySet())
|
||||
{
|
||||
int totalWins = totalReward.get(action);
|
||||
int totalSims = numSims.get(action);
|
||||
|
||||
totalRollouts += totalSims;
|
||||
|
||||
double value = ((double)totalWins) / ((double)totalSims);
|
||||
|
||||
if (bestAction.isNone() || bestValue < value) {
|
||||
bestAction = action;
|
||||
bestValue = value;
|
||||
bestWins = totalWins;
|
||||
bestSims = totalSims;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if(isLogging()) {
|
||||
System.out.println("Action " + bestAction + " selected for "
|
||||
+ player
|
||||
+ " with simulated win ratio of "
|
||||
+ (bestValue * 100.0 + "% among " + numTrees + " parallel simulations."));
|
||||
System.out.println("It won "
|
||||
+ bestWins + " out of " + bestSims
|
||||
+ " rollouts among " + totalRollouts
|
||||
+ " total rollouts (" + totalReward.size()
|
||||
+ " possible moves evaluated) from the current state.");
|
||||
}
|
||||
return bestAction;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Action getAction(GameConfig gameConfig, GameState gameState,
|
||||
Collection<Action> prohibitedActions, Player player) {
|
||||
throw new UnsupportedOperationException(
|
||||
"Prohibited actions not supported by this class.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getNumStateEvaluations() {
|
||||
// TODO Auto-generated method stub
|
||||
return 0;
|
||||
}
|
||||
|
||||
class PolicyRunner implements Runnable {
|
||||
Map<Action,MonteCarloProperties> qValues;
|
||||
|
||||
private GameConfig gameConfig;
|
||||
private GameState gameState;
|
||||
private Player player;
|
||||
private MonteCarlo policy;
|
||||
|
||||
public PolicyRunner(MonteCarlo policy, GameConfig gameConfig,
|
||||
GameState gameState, Player player) {
|
||||
this.policy = policy;
|
||||
this.gameConfig = gameConfig;
|
||||
this.gameState = gameState;
|
||||
this.player = player;
|
||||
}
|
||||
|
||||
public Map<Action,MonteCarloProperties> getQvalues() {
|
||||
return qValues;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
qValues = policy.getQvalues(gameConfig, gameState, player);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setState(GameState gameState) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
if (rolloutPolicy.getName() == "Random") {
|
||||
return "RootParallelization";
|
||||
} else {
|
||||
return "RootParallelization-" + rolloutPolicy.getName();
|
||||
}
|
||||
}
|
||||
}
|
||||
186
src/net/woodyfolsom/msproj/policy/RootParSMAF.java
Normal file
186
src/net/woodyfolsom/msproj/policy/RootParSMAF.java
Normal file
@@ -0,0 +1,186 @@
|
||||
package net.woodyfolsom.msproj.policy;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import net.woodyfolsom.msproj.Action;
|
||||
import net.woodyfolsom.msproj.GameConfig;
|
||||
import net.woodyfolsom.msproj.GameState;
|
||||
import net.woodyfolsom.msproj.Player;
|
||||
import net.woodyfolsom.msproj.tree.AMAFProperties;
|
||||
import net.woodyfolsom.msproj.tree.MonteCarloProperties;
|
||||
|
||||
public class RootParSMAF implements Policy {
|
||||
private boolean logging = false;
|
||||
private int numTrees = 1;
|
||||
private Policy rolloutPolicy;
|
||||
|
||||
public boolean isLogging() {
|
||||
return logging;
|
||||
}
|
||||
|
||||
public void setLogging(boolean logging) {
|
||||
this.logging = logging;
|
||||
}
|
||||
|
||||
private long timeLimit = 1000L;
|
||||
|
||||
public RootParSMAF(int numTrees, long timeLimit) {
|
||||
this.numTrees = numTrees;
|
||||
this.timeLimit = timeLimit;
|
||||
this.rolloutPolicy = new RandomMovePolicy();
|
||||
}
|
||||
|
||||
public RootParSMAF(int numTrees, Policy policyPrototype, long timeLimit) {
|
||||
this.numTrees = numTrees;
|
||||
this.timeLimit = timeLimit;
|
||||
this.rolloutPolicy = policyPrototype;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Action getAction(GameConfig gameConfig, GameState gameState,
|
||||
Player player) {
|
||||
Action bestAction = Action.NONE;
|
||||
|
||||
List<PolicyRunner> policyRunners = new ArrayList<PolicyRunner>();
|
||||
List<Thread> simulationThreads = new ArrayList<Thread>();
|
||||
|
||||
for (int i = 0; i < numTrees; i++) {
|
||||
|
||||
MonteCarlo policy = new MonteCarloSMAF(
|
||||
PolicyFactory.createNew(rolloutPolicy), timeLimit, 4);
|
||||
|
||||
//policy.setLogging(true);
|
||||
|
||||
PolicyRunner policyRunner = new PolicyRunner(policy, gameConfig, gameState,
|
||||
player);
|
||||
policyRunners.add(policyRunner);
|
||||
|
||||
Thread simThread = new Thread(policyRunner);
|
||||
simulationThreads.add(simThread);
|
||||
}
|
||||
|
||||
for (Thread simThread : simulationThreads) {
|
||||
simThread.start();
|
||||
}
|
||||
|
||||
for (Thread simThread : simulationThreads) {
|
||||
try {
|
||||
simThread.join();
|
||||
} catch (InterruptedException ie) {
|
||||
System.out
|
||||
.println("Interrupted while waiting for Monte Carlo simulations to finish.");
|
||||
}
|
||||
}
|
||||
|
||||
Map<Action,Integer> totalReward = new HashMap<Action,Integer>();
|
||||
Map<Action,Integer> numSims = new HashMap<Action,Integer>();
|
||||
|
||||
for (PolicyRunner policyRunner : policyRunners) {
|
||||
Map<Action, MonteCarloProperties> qValues = policyRunner.getQvalues();
|
||||
for (Action action : qValues.keySet()) {
|
||||
if (totalReward.containsKey(action)) {
|
||||
totalReward.put(action, totalReward.get(action) + ((AMAFProperties)qValues.get(action)).getAmafWins());
|
||||
} else {
|
||||
totalReward.put(action, ((AMAFProperties)qValues.get(action)).getAmafWins());
|
||||
}
|
||||
if (numSims.containsKey(action)) {
|
||||
numSims.put(action, numSims.get(action) + ((AMAFProperties)qValues.get(action)).getAmafVisits());
|
||||
} else {
|
||||
numSims.put(action, ((AMAFProperties)qValues.get(action)).getAmafVisits());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
double bestValue = 0.0;
|
||||
int totalRollouts = 0;
|
||||
int bestWins = 0;
|
||||
int bestSims = 0;
|
||||
|
||||
for (Action action : totalReward.keySet())
|
||||
{
|
||||
int totalWins = totalReward.get(action);
|
||||
int totalSims = numSims.get(action);
|
||||
|
||||
totalRollouts += totalSims;
|
||||
|
||||
double value = ((double)totalWins) / ((double)totalSims);
|
||||
|
||||
if (bestAction.isNone() || bestValue < value) {
|
||||
bestAction = action;
|
||||
bestValue = value;
|
||||
bestWins = totalWins;
|
||||
bestSims = totalSims;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if(isLogging()) {
|
||||
System.out.println("Action " + bestAction + " selected for "
|
||||
+ player
|
||||
+ " with simulated win ratio of "
|
||||
+ (bestValue * 100.0 + "% among " + numTrees + " parallel simulations."));
|
||||
System.out.println("It won "
|
||||
+ bestWins + " out of " + bestSims
|
||||
+ " rollouts among " + totalRollouts
|
||||
+ " total rollouts (" + totalReward.size()
|
||||
+ " possible moves evaluated) from the current state.");
|
||||
}
|
||||
return bestAction;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Action getAction(GameConfig gameConfig, GameState gameState,
|
||||
Collection<Action> prohibitedActions, Player player) {
|
||||
throw new UnsupportedOperationException(
|
||||
"Prohibited actions not supported by this class.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getNumStateEvaluations() {
|
||||
// TODO Auto-generated method stub
|
||||
return 0;
|
||||
}
|
||||
|
||||
class PolicyRunner implements Runnable {
|
||||
Map<Action,MonteCarloProperties> qValues;
|
||||
|
||||
private GameConfig gameConfig;
|
||||
private GameState gameState;
|
||||
private Player player;
|
||||
private MonteCarlo policy;
|
||||
|
||||
public PolicyRunner(MonteCarlo policy, GameConfig gameConfig,
|
||||
GameState gameState, Player player) {
|
||||
this.policy = policy;
|
||||
this.gameConfig = gameConfig;
|
||||
this.gameState = gameState;
|
||||
this.player = player;
|
||||
}
|
||||
|
||||
public Map<Action,MonteCarloProperties> getQvalues() {
|
||||
return qValues;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
qValues = policy.getQvalues(gameConfig, gameState, player);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setState(GameState gameState) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
if (rolloutPolicy.getName() == "Random") {
|
||||
return "RootParallelization";
|
||||
} else {
|
||||
return "RootParallelization-" + rolloutPolicy.getName();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -13,14 +13,32 @@ import net.woodyfolsom.msproj.Player;
|
||||
import net.woodyfolsom.msproj.tree.MonteCarloProperties;
|
||||
|
||||
public class RootParallelization implements Policy {
|
||||
private boolean logging = false;
|
||||
private int numTrees = 1;
|
||||
private Policy rolloutPolicy;
|
||||
|
||||
public boolean isLogging() {
|
||||
return logging;
|
||||
}
|
||||
|
||||
public void setLogging(boolean logging) {
|
||||
this.logging = logging;
|
||||
}
|
||||
|
||||
private long timeLimit = 1000L;
|
||||
|
||||
public RootParallelization(int numTrees, long timeLimit) {
|
||||
this.numTrees = numTrees;
|
||||
this.timeLimit = timeLimit;
|
||||
this.rolloutPolicy = new RandomMovePolicy();
|
||||
}
|
||||
|
||||
public RootParallelization(int numTrees, Policy policyPrototype, long timeLimit) {
|
||||
this.numTrees = numTrees;
|
||||
this.timeLimit = timeLimit;
|
||||
this.rolloutPolicy = policyPrototype;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Action getAction(GameConfig gameConfig, GameState gameState,
|
||||
Player player) {
|
||||
@@ -31,7 +49,7 @@ public class RootParallelization implements Policy {
|
||||
|
||||
for (int i = 0; i < numTrees; i++) {
|
||||
PolicyRunner policyRunner = new PolicyRunner(new MonteCarloUCT(
|
||||
new RandomMovePolicy(), timeLimit), gameConfig, gameState,
|
||||
PolicyFactory.createNew(rolloutPolicy), timeLimit), gameConfig, gameState,
|
||||
player);
|
||||
policyRunners.add(policyRunner);
|
||||
|
||||
@@ -94,6 +112,7 @@ public class RootParallelization implements Policy {
|
||||
|
||||
}
|
||||
|
||||
if(isLogging()) {
|
||||
System.out.println("Action " + bestAction + " selected for "
|
||||
+ player
|
||||
+ " with simulated win ratio of "
|
||||
@@ -103,7 +122,7 @@ public class RootParallelization implements Policy {
|
||||
+ " rollouts among " + totalRollouts
|
||||
+ " total rollouts (" + totalReward.size()
|
||||
+ " possible moves evaluated) from the current state.");
|
||||
|
||||
}
|
||||
return bestAction;
|
||||
}
|
||||
|
||||
@@ -147,8 +166,15 @@ public class RootParallelization implements Policy {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setState(GameState gameState) {
|
||||
// TODO Auto-generated method stub
|
||||
|
||||
public void setState(GameState gameState) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
if (rolloutPolicy.getName() == "Random") {
|
||||
return "RootParallelization";
|
||||
} else {
|
||||
return "RootParallelization-" + rolloutPolicy.getName();
|
||||
}
|
||||
}
|
||||
}
|
||||
54
src/net/woodyfolsom/msproj/tictactoe/Action.java
Normal file
54
src/net/woodyfolsom/msproj/tictactoe/Action.java
Normal file
@@ -0,0 +1,54 @@
|
||||
package net.woodyfolsom.msproj.tictactoe;
|
||||
|
||||
import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;
|
||||
|
||||
public class Action {
|
||||
public static final Action NONE = new Action(PLAYER.NONE, -1, -1);
|
||||
|
||||
private Game.PLAYER player;
|
||||
private int row;
|
||||
private int column;
|
||||
|
||||
public static Action getInstance(PLAYER player, int row, int column) {
|
||||
return new Action(player,row,column);
|
||||
}
|
||||
|
||||
private Action(PLAYER player, int row, int column) {
|
||||
this.player = player;
|
||||
this.row = row;
|
||||
this.column = column;
|
||||
}
|
||||
|
||||
public Game.PLAYER getPlayer() {
|
||||
return player;
|
||||
}
|
||||
|
||||
public int getColumn() {
|
||||
return column;
|
||||
}
|
||||
|
||||
public int getRow() {
|
||||
return row;
|
||||
}
|
||||
|
||||
public boolean isNone() {
|
||||
return this == Action.NONE;
|
||||
}
|
||||
|
||||
public void setPlayer(Game.PLAYER player) {
|
||||
this.player = player;
|
||||
}
|
||||
|
||||
public void setRow(int row) {
|
||||
this.row = row;
|
||||
}
|
||||
|
||||
public void setColumn(int column) {
|
||||
this.column = column;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return player + "(" + row + ", " + column + ")";
|
||||
}
|
||||
}
|
||||
5
src/net/woodyfolsom/msproj/tictactoe/Game.java
Normal file
5
src/net/woodyfolsom/msproj/tictactoe/Game.java
Normal file
@@ -0,0 +1,5 @@
|
||||
package net.woodyfolsom.msproj.tictactoe;
|
||||
|
||||
public class Game {
|
||||
public enum PLAYER {X,O,NONE}
|
||||
}
|
||||
63
src/net/woodyfolsom/msproj/tictactoe/GameRecord.java
Normal file
63
src/net/woodyfolsom/msproj/tictactoe/GameRecord.java
Normal file
@@ -0,0 +1,63 @@
|
||||
package net.woodyfolsom.msproj.tictactoe;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;
|
||||
|
||||
public class GameRecord {
|
||||
public enum RESULT {X_WINS, O_WINS, TIE_GAME, IN_PROGRESS}
|
||||
|
||||
private List<Action> actions = new ArrayList<Action>();
|
||||
private List<State> states = new ArrayList<State>();
|
||||
|
||||
private RESULT result = RESULT.IN_PROGRESS;
|
||||
|
||||
public GameRecord() {
|
||||
actions.add(Action.NONE);
|
||||
states.add(new State());
|
||||
}
|
||||
|
||||
public void addState(State state) {
|
||||
states.add(state);
|
||||
}
|
||||
|
||||
public State apply(Action action) {
|
||||
State nextState = getState().apply(action);
|
||||
if (nextState.isValid()) {
|
||||
states.add(nextState);
|
||||
actions.add(action);
|
||||
}
|
||||
|
||||
if (nextState.isTerminal()) {
|
||||
if (nextState.isWinner(PLAYER.X)) {
|
||||
result = RESULT.X_WINS;
|
||||
} else if (nextState.isWinner(PLAYER.O)) {
|
||||
result = RESULT.O_WINS;
|
||||
} else {
|
||||
result = RESULT.TIE_GAME;
|
||||
}
|
||||
}
|
||||
return nextState;
|
||||
}
|
||||
|
||||
public int getNumStates() {
|
||||
return states.size();
|
||||
}
|
||||
|
||||
public RESULT getResult() {
|
||||
return result;
|
||||
}
|
||||
|
||||
public void setResult(RESULT result) {
|
||||
this.result = result;
|
||||
}
|
||||
|
||||
public State getState() {
|
||||
return states.get(states.size()-1);
|
||||
}
|
||||
|
||||
public State getState(int index) {
|
||||
return states.get(index);
|
||||
}
|
||||
}
|
||||
20
src/net/woodyfolsom/msproj/tictactoe/MoveGenerator.java
Normal file
20
src/net/woodyfolsom/msproj/tictactoe/MoveGenerator.java
Normal file
@@ -0,0 +1,20 @@
|
||||
package net.woodyfolsom.msproj.tictactoe;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;
|
||||
|
||||
public class MoveGenerator {
|
||||
public List<Action> getValidActions(State state) {
|
||||
PLAYER playerToMove = state.getPlayerToMove();
|
||||
List<Action> validActions = new ArrayList<Action>();
|
||||
for (int i = 0; i < 3; i++) {
|
||||
for (int j = 0; j < 3; j++) {
|
||||
if (state.isEmpty(i,j))
|
||||
validActions.add(Action.getInstance(playerToMove, i, j));
|
||||
}
|
||||
}
|
||||
return validActions;
|
||||
}
|
||||
}
|
||||
211
src/net/woodyfolsom/msproj/tictactoe/NNDataSetFactory.java
Normal file
211
src/net/woodyfolsom/msproj/tictactoe/NNDataSetFactory.java
Normal file
@@ -0,0 +1,211 @@
|
||||
package net.woodyfolsom.msproj.tictactoe;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import net.woodyfolsom.msproj.GameBoard;
|
||||
import net.woodyfolsom.msproj.GameResult;
|
||||
import net.woodyfolsom.msproj.GameState;
|
||||
import net.woodyfolsom.msproj.Player;
|
||||
import net.woodyfolsom.msproj.ann.FusekiFilterTrainer;
|
||||
import net.woodyfolsom.msproj.ann.NNData;
|
||||
import net.woodyfolsom.msproj.ann.NNDataPair;
|
||||
import net.woodyfolsom.msproj.ann.PassFilterTrainer;
|
||||
import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;
|
||||
|
||||
public class NNDataSetFactory {
|
||||
public static final String[] TTT_INPUT_FIELDS = {"00","01","02","10","11","12","20","21","22"};
|
||||
public static final String[] TTT_OUTPUT_FIELDS = {"VALUE"};
|
||||
|
||||
public static final String[] PASS_INPUT_FIELDS = {"WINNING","PREV_PLY_PASS"};
|
||||
public static final String[] PASS_OUTPUT_FIELDS = {"SHOULD_PASS"};
|
||||
|
||||
public static final String[] FUSEKI_INPUT_FIELDS = {
|
||||
"00","11","22","33","44","55","66","77","88",
|
||||
"10","11","22","33","44","55","66","77","88",
|
||||
"20","11","22","33","44","55","66","77","88",
|
||||
"30","11","22","33","44","55","66","77","88",
|
||||
"40","11","22","33","44","55","66","77","88",
|
||||
"50","11","22","33","44","55","66","77","88",
|
||||
"60","11","22","33","44","55","66","77","88",
|
||||
"70","11","22","33","44","55","66","77","88",
|
||||
"70","11","22","33","44","55","66","77","88"};
|
||||
public static final String[] FUSEKI_OUTPUT_FIELDS = {"VALUE"};
|
||||
|
||||
public static List<List<NNDataPair>> createDataSet(List<GameRecord> tttGames) {
|
||||
|
||||
List<List<NNDataPair>> nnDataSet = new ArrayList<List<NNDataPair>>();
|
||||
|
||||
for (GameRecord tttGame : tttGames) {
|
||||
List<NNDataPair> gameData = createDataPairList(tttGame);
|
||||
|
||||
|
||||
nnDataSet.add(gameData);
|
||||
}
|
||||
|
||||
return nnDataSet;
|
||||
}
|
||||
|
||||
public static String[] getInputFields(Object clazz) {
|
||||
if (clazz == PassFilterTrainer.class) {
|
||||
return PASS_INPUT_FIELDS;
|
||||
} else if (clazz == FusekiFilterTrainer.class) {
|
||||
return FUSEKI_INPUT_FIELDS;
|
||||
} else {
|
||||
throw new RuntimeException("Don't know how to return inputFields for NeuralNetwork Trainer of type: " + clazz.getClass().getName());
|
||||
}
|
||||
}
|
||||
|
||||
public static String[] getOutputFields(Object clazz) {
|
||||
if (clazz == PassFilterTrainer.class) {
|
||||
return PASS_OUTPUT_FIELDS;
|
||||
} else if (clazz == FusekiFilterTrainer.class) {
|
||||
return FUSEKI_OUTPUT_FIELDS;
|
||||
} else {
|
||||
throw new RuntimeException("Don't know how to return inputFields for NeuralNetwork Trainer of type: " + clazz.getClass().getName());
|
||||
}
|
||||
}
|
||||
|
||||
public static List<NNDataPair> createDataPairList(GameRecord gameRecord) {
|
||||
List<NNDataPair> gameData = new ArrayList<NNDataPair>();
|
||||
|
||||
for (int i = 0; i < gameRecord.getNumStates(); i++) {
|
||||
gameData.add(createDataPair(gameRecord.getState(i)));
|
||||
}
|
||||
|
||||
return gameData;
|
||||
}
|
||||
|
||||
public static NNDataPair createDataPair(GameState goState, Object clazz) {
|
||||
if (clazz == PassFilterTrainer.class) {
|
||||
return createPassFilterDataPair(goState);
|
||||
} else if (clazz == FusekiFilterTrainer.class) {
|
||||
return createFusekiFilterDataPair(goState);
|
||||
} else {
|
||||
throw new RuntimeException("Don't know how to create DataPair for NeuralNetwork Trainer of type: " + clazz.getClass().getName());
|
||||
}
|
||||
}
|
||||
|
||||
private static NNDataPair createFusekiFilterDataPair(GameState goState) {
|
||||
double value;
|
||||
|
||||
if (goState.isTerminal()) {
|
||||
if (goState.getResult().isWinner(Player.BLACK)) {
|
||||
value = 1.0; // win for black
|
||||
} else if (goState.getResult().isWinner(Player.WHITE)) {
|
||||
value = 0.0; // loss for black
|
||||
//value = -1.0;
|
||||
} else {// tie
|
||||
value = 0.5;
|
||||
//value = 0.0; //tie
|
||||
}
|
||||
} else {
|
||||
value = 0.0;
|
||||
}
|
||||
|
||||
int size = goState.getGameConfig().getSize();
|
||||
double[] inputValues = new double[size * size];
|
||||
for (int i = 0; i < size; i++) {
|
||||
for (int j = 0; j < size; j++) {
|
||||
//col then row
|
||||
char symbol = goState.getGameBoard().getSymbolAt(j, i);
|
||||
switch (symbol) {
|
||||
case GameBoard.EMPTY_INTERSECTION : inputValues[i*size+j] = 0.0;
|
||||
break;
|
||||
case GameBoard.BLACK_STONE : inputValues[i*size+j] = 1.0;
|
||||
break;
|
||||
case GameBoard.WHITE_STONE : inputValues[i*size+j] = -1.0;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return new NNDataPair(new NNData(FUSEKI_INPUT_FIELDS,inputValues),new NNData(FUSEKI_OUTPUT_FIELDS,new double[]{value}));
|
||||
}
|
||||
|
||||
private static NNDataPair createPassFilterDataPair(GameState goState) {
|
||||
double value;
|
||||
|
||||
GameResult result = goState.getResult();
|
||||
if (goState.isTerminal()) {
|
||||
if (result.isWinner(Player.BLACK)) {
|
||||
value = 1.0; // win for black
|
||||
} else if (result.isWinner(Player.WHITE)) {
|
||||
value = 0.0; // loss for black
|
||||
} else {// tie
|
||||
value = 0.5;
|
||||
//value = 0.0; //tie
|
||||
}
|
||||
} else {
|
||||
value = 0.0;
|
||||
}
|
||||
|
||||
double[] inputValues = new double[4];
|
||||
inputValues[0] = result.isWinner(goState.getPlayerToMove()) ? 1.0 : -1.0;
|
||||
//inputValues[1] = result.isWinner(goState.getPlayerToMove()) ? -1.0 : 1.0;
|
||||
inputValues[1] = goState.isPrevPlyPass() ? 1.0 : 0.0;
|
||||
|
||||
return new NNDataPair(new NNData(PASS_INPUT_FIELDS,inputValues),new NNData(PASS_OUTPUT_FIELDS,new double[]{value}));
|
||||
}
|
||||
|
||||
/*
|
||||
private static double getNormalizedScore(GameState goState, Player player) {
|
||||
GameResult gameResult = goState.getResult();
|
||||
GameConfig gameConfig = goState.getGameConfig();
|
||||
|
||||
double maxPoints = Math.pow(gameConfig.getSize(),2);
|
||||
double komi = gameConfig.getKomi();
|
||||
|
||||
if (player == Player.BLACK) {
|
||||
return gameResult.getBlackScore() / maxPoints;
|
||||
} else if (player == Player.WHITE) {
|
||||
return gameResult.getWhiteScore() / (maxPoints + komi);
|
||||
} else {
|
||||
throw new RuntimeException("Invalid player");
|
||||
}
|
||||
}*/
|
||||
|
||||
public static NNDataPair createDataPair(State tttState) {
|
||||
double value;
|
||||
if (tttState.isTerminal()) {
|
||||
if (tttState.isWinner(PLAYER.X)) {
|
||||
value = 1.0; // win for black
|
||||
} else if (tttState.isWinner(PLAYER.O)) {
|
||||
value = 0.0; // loss for black
|
||||
//value = -1.0;
|
||||
} else {
|
||||
value = 0.5;
|
||||
//value = 0.0; //tie
|
||||
}
|
||||
} else {
|
||||
value = 0.0;
|
||||
}
|
||||
|
||||
double[] inputValues = new double[9];
|
||||
char[] boardCopy = tttState.getBoard();
|
||||
inputValues[0] = getTicTacToeInput(boardCopy, 0, 0);
|
||||
inputValues[1] = getTicTacToeInput(boardCopy, 0, 1);
|
||||
inputValues[2] = getTicTacToeInput(boardCopy, 0, 2);
|
||||
inputValues[3] = getTicTacToeInput(boardCopy, 1, 0);
|
||||
inputValues[4] = getTicTacToeInput(boardCopy, 1, 1);
|
||||
inputValues[5] = getTicTacToeInput(boardCopy, 1, 2);
|
||||
inputValues[6] = getTicTacToeInput(boardCopy, 2, 0);
|
||||
inputValues[7] = getTicTacToeInput(boardCopy, 2, 1);
|
||||
inputValues[8] = getTicTacToeInput(boardCopy, 2, 2);
|
||||
|
||||
return new NNDataPair(new NNData(TTT_INPUT_FIELDS,inputValues),new NNData(TTT_OUTPUT_FIELDS,new double[]{value}));
|
||||
}
|
||||
|
||||
private static double getTicTacToeInput(char[] board, int row, int column) {
|
||||
switch (board[row*3+column]) {
|
||||
case 'X' :
|
||||
return 1.0;
|
||||
case 'O' :
|
||||
return -1.0;
|
||||
case '.' :
|
||||
return 0.0;
|
||||
default:
|
||||
throw new RuntimeException("Invalid board symbol at " + row +", " + column);
|
||||
}
|
||||
}
|
||||
}
|
||||
68
src/net/woodyfolsom/msproj/tictactoe/NeuralNetPolicy.java
Normal file
68
src/net/woodyfolsom/msproj/tictactoe/NeuralNetPolicy.java
Normal file
@@ -0,0 +1,68 @@
|
||||
package net.woodyfolsom.msproj.tictactoe;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import net.woodyfolsom.msproj.ann.FeedforwardNetwork;
|
||||
import net.woodyfolsom.msproj.ann.NNDataPair;
|
||||
import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;
|
||||
|
||||
public class NeuralNetPolicy extends Policy {
|
||||
private FeedforwardNetwork neuralNet;
|
||||
private MoveGenerator moveGenerator = new MoveGenerator();
|
||||
|
||||
public NeuralNetPolicy(FeedforwardNetwork neuralNet) {
|
||||
super("NeuralNet-" + neuralNet.getName());
|
||||
this.neuralNet = neuralNet;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Action getAction(State state) {
|
||||
List<Action> validMoves = moveGenerator.getValidActions(state);
|
||||
Map<Action, Double> scores = new HashMap<Action, Double>();
|
||||
|
||||
for (Action action : validMoves) {
|
||||
State nextState = state.apply(action);
|
||||
//NNDataPair dataPair = NNDataSetFactory.createDataPair(state);
|
||||
NNDataPair dataPair = NNDataSetFactory.createDataPair(nextState);
|
||||
//estimated reward for X
|
||||
scores.put(action, neuralNet.compute(dataPair).getValues()[0]);
|
||||
}
|
||||
|
||||
PLAYER playerToMove = state.getPlayerToMove();
|
||||
|
||||
if (playerToMove == PLAYER.X) {
|
||||
return returnMaxAction(scores);
|
||||
} else if (playerToMove == PLAYER.O) {
|
||||
return returnMinAction(scores);
|
||||
} else {
|
||||
throw new IllegalArgumentException("Invalid playerToMove: " + playerToMove);
|
||||
}
|
||||
//return validMoves.get((int)(Math.random() * validMoves.size()));
|
||||
}
|
||||
|
||||
private Action returnMaxAction(Map<Action,Double> scores) {
|
||||
Action bestAction = null;
|
||||
Double bestScore = Double.NEGATIVE_INFINITY;
|
||||
for (Map.Entry<Action,Double> entry : scores.entrySet()) {
|
||||
if (entry.getValue() > bestScore) {
|
||||
bestScore = entry.getValue();
|
||||
bestAction = entry.getKey();
|
||||
}
|
||||
}
|
||||
return bestAction;
|
||||
}
|
||||
|
||||
private Action returnMinAction(Map<Action,Double> scores) {
|
||||
Action bestAction = null;
|
||||
Double bestScore = Double.POSITIVE_INFINITY;
|
||||
for (Map.Entry<Action,Double> entry : scores.entrySet()) {
|
||||
if (entry.getValue() < bestScore) {
|
||||
bestScore = entry.getValue();
|
||||
bestAction = entry.getKey();
|
||||
}
|
||||
}
|
||||
return bestAction;
|
||||
}
|
||||
}
|
||||
15
src/net/woodyfolsom/msproj/tictactoe/Policy.java
Normal file
15
src/net/woodyfolsom/msproj/tictactoe/Policy.java
Normal file
@@ -0,0 +1,15 @@
|
||||
package net.woodyfolsom.msproj.tictactoe;
|
||||
|
||||
public abstract class Policy {
|
||||
private String name;
|
||||
|
||||
protected Policy(String name) {
|
||||
this.name = name;
|
||||
}
|
||||
|
||||
public abstract Action getAction(State state);
|
||||
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
}
|
||||
18
src/net/woodyfolsom/msproj/tictactoe/RandomPolicy.java
Normal file
18
src/net/woodyfolsom/msproj/tictactoe/RandomPolicy.java
Normal file
@@ -0,0 +1,18 @@
|
||||
package net.woodyfolsom.msproj.tictactoe;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class RandomPolicy extends Policy {
|
||||
private MoveGenerator moveGenerator = new MoveGenerator();
|
||||
|
||||
public RandomPolicy() {
|
||||
super("Random");
|
||||
}
|
||||
|
||||
@Override
|
||||
public Action getAction(State state) {
|
||||
List<Action> validMoves = moveGenerator.getValidActions(state);
|
||||
return validMoves.get((int)(Math.random() * validMoves.size()));
|
||||
}
|
||||
|
||||
}
|
||||
43
src/net/woodyfolsom/msproj/tictactoe/Referee.java
Normal file
43
src/net/woodyfolsom/msproj/tictactoe/Referee.java
Normal file
@@ -0,0 +1,43 @@
|
||||
package net.woodyfolsom.msproj.tictactoe;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class Referee {
|
||||
|
||||
public static void main(String[] args) {
|
||||
new Referee().play(50);
|
||||
}
|
||||
|
||||
public List<GameRecord> play(int nGames) {
|
||||
Policy policy = new RandomPolicy();
|
||||
|
||||
List<GameRecord> tournament = new ArrayList<GameRecord>();
|
||||
|
||||
for (int i = 0; i < nGames; i++) {
|
||||
GameRecord gameRecord = new GameRecord();
|
||||
|
||||
System.out.println("Playing game #" +(i+1));
|
||||
|
||||
State state;
|
||||
do {
|
||||
Action action = policy.getAction(gameRecord.getState());
|
||||
System.out.println("Action " + action + " selected by policy " + policy.getName());
|
||||
state = gameRecord.apply(action);
|
||||
System.out.println("Next board state:");
|
||||
System.out.println(gameRecord.getState());
|
||||
} while (!state.isTerminal());
|
||||
System.out.println("Game #" + (i+1) + " is finished. Result: " + gameRecord.getResult());
|
||||
tournament.add(gameRecord);
|
||||
}
|
||||
|
||||
System.out.println("Played " + tournament.size() + " random games.");
|
||||
System.out.println("Results:");
|
||||
for (int i = 0; i < tournament.size(); i++) {
|
||||
GameRecord gameRecord = tournament.get(i);
|
||||
System.out.println((i+1) + ". " + gameRecord.getResult());
|
||||
}
|
||||
|
||||
return tournament;
|
||||
}
|
||||
}
|
||||
116
src/net/woodyfolsom/msproj/tictactoe/State.java
Normal file
116
src/net/woodyfolsom/msproj/tictactoe/State.java
Normal file
@@ -0,0 +1,116 @@
|
||||
package net.woodyfolsom.msproj.tictactoe;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;
|
||||
|
||||
public class State {
|
||||
public static final State INVALID = new State();
|
||||
public static char EMPTY_SQUARE = '.';
|
||||
|
||||
private char[] board;
|
||||
private PLAYER playerToMove;
|
||||
|
||||
public State() {
|
||||
playerToMove = Game.PLAYER.X;
|
||||
board = new char[9];
|
||||
Arrays.fill(board,'.');
|
||||
}
|
||||
|
||||
private State(State that) {
|
||||
this.board = Arrays.copyOf(that.board, that.board.length);
|
||||
this.playerToMove = that.playerToMove;
|
||||
}
|
||||
|
||||
public State apply(Action action) {
|
||||
if (action.getPlayer() != playerToMove) {
|
||||
System.out.println("It is not " + action.getPlayer() +"'s turn.");
|
||||
return State.INVALID;
|
||||
}
|
||||
State nextState = new State(this);
|
||||
|
||||
int row = action.getRow();
|
||||
int column = action.getColumn();
|
||||
int dest = row * 3 + column;
|
||||
|
||||
if (board[dest] != EMPTY_SQUARE) {
|
||||
System.out.println("Invalid move " + action + ", coordinate not empty.");
|
||||
return State.INVALID;
|
||||
}
|
||||
switch (playerToMove) {
|
||||
case X : nextState.board[dest] = 'X';
|
||||
break;
|
||||
case O : nextState.board[dest] = 'O';
|
||||
break;
|
||||
default:
|
||||
throw new RuntimeException("Invalid playerToMove");
|
||||
}
|
||||
|
||||
if (playerToMove == PLAYER.X) {
|
||||
nextState.playerToMove = PLAYER.O;
|
||||
} else {
|
||||
nextState.playerToMove = PLAYER.X;
|
||||
}
|
||||
return nextState;
|
||||
}
|
||||
|
||||
public char[] getBoard() {
|
||||
return Arrays.copyOf(board, board.length);
|
||||
}
|
||||
|
||||
public PLAYER getPlayerToMove() {
|
||||
return playerToMove;
|
||||
}
|
||||
|
||||
public boolean isEmpty(int row, int column) {
|
||||
return board[row*3+column] == EMPTY_SQUARE;
|
||||
}
|
||||
|
||||
public boolean isFull(char mark1, char mark2, char mark3) {
|
||||
return mark1 != '.' && mark2 != '.' && mark3 != '.';
|
||||
}
|
||||
|
||||
public boolean isWinner(PLAYER player) {
|
||||
return isWin(player,board[0],board[1],board[2]) ||
|
||||
isWin(player,board[3],board[4],board[5]) ||
|
||||
isWin(player,board[6],board[7],board[8]) ||
|
||||
isWin(player,board[0],board[3],board[6]) ||
|
||||
isWin(player,board[1],board[4],board[7]) ||
|
||||
isWin(player,board[2],board[5],board[8]) ||
|
||||
isWin(player,board[0],board[4],board[8]) ||
|
||||
isWin(player,board[2],board[4],board[6]);
|
||||
}
|
||||
|
||||
public boolean isWin(PLAYER player, char mark1, char mark2, char mark3) {
|
||||
if (isFull(mark1,mark2,mark3)) {
|
||||
switch (player) {
|
||||
case X : return mark1 == 'X' && mark2 == 'X' && mark3 == 'X';
|
||||
case O : return mark1 == 'O' && mark2 == 'O' && mark3 == 'O';
|
||||
default :
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public boolean isTerminal() {
|
||||
return isWinner(PLAYER.X) || isWinner(PLAYER.O) ||
|
||||
(isFull(board[0],board[1], board[2]) &&
|
||||
isFull(board[3],board[4], board[5]) &&
|
||||
isFull(board[6],board[7], board[8]));
|
||||
}
|
||||
|
||||
public boolean isValid() {
|
||||
return this != INVALID;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
StringBuilder sb = new StringBuilder("TicTacToe state ("+playerToMove + " to move):\n");
|
||||
sb.append(""+board[0] + board[1] + board[2] + "\n");
|
||||
sb.append(""+board[3] + board[4] + board[5] + "\n");
|
||||
sb.append(""+board[6] + board[7] + board[8] + "\n");
|
||||
return sb.toString();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,95 @@
|
||||
package net.woodyfolsom.msproj.ann;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
|
||||
import javax.xml.bind.JAXBException;
|
||||
|
||||
import net.woodyfolsom.msproj.ann.Connection;
|
||||
import net.woodyfolsom.msproj.ann.FeedforwardNetwork;
|
||||
import net.woodyfolsom.msproj.ann.MultiLayerPerceptron;
|
||||
import net.woodyfolsom.msproj.ann.NNData;
|
||||
import net.woodyfolsom.msproj.ann.NNDataPair;
|
||||
|
||||
import org.junit.AfterClass;
|
||||
import org.junit.BeforeClass;
|
||||
import org.junit.Test;
|
||||
|
||||
public class MultiLayerPerceptronTest {
|
||||
static final File TEST_FILE = new File("data/test/mlp.net");
|
||||
static final double EPS = 0.001;
|
||||
|
||||
@BeforeClass
|
||||
public static void setUp() {
|
||||
if (TEST_FILE.exists()) {
|
||||
TEST_FILE.delete();
|
||||
}
|
||||
}
|
||||
|
||||
@AfterClass
|
||||
public static void tearDown() {
|
||||
if (TEST_FILE.exists()) {
|
||||
TEST_FILE.delete();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testConstructor() {
|
||||
new MultiLayerPerceptron(true, 2, 4, 1);
|
||||
new MultiLayerPerceptron(false, 2, 1);
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void testConstructorTooFewLayers() {
|
||||
new MultiLayerPerceptron(true, 2);
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void testConstructorTooFewNeurons() {
|
||||
new MultiLayerPerceptron(true, 2, 4, 0, 1);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPersistence() throws JAXBException, IOException {
|
||||
FeedforwardNetwork mlp = new MultiLayerPerceptron(true, 2, 4, 1);
|
||||
FileOutputStream fos = new FileOutputStream(TEST_FILE);
|
||||
assertTrue(mlp.save(fos));
|
||||
fos.close();
|
||||
FileInputStream fis = new FileInputStream(TEST_FILE);
|
||||
FeedforwardNetwork mlp2 = new MultiLayerPerceptron();
|
||||
assertTrue(mlp2.load(fis));
|
||||
assertEquals(mlp, mlp2);
|
||||
fis.close();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCompute() {
|
||||
FeedforwardNetwork mlp = new MultiLayerPerceptron(true, 2, 2, 1);
|
||||
NNDataPair expected = new NNDataPair(new NNData(new String[]{"x","y"}, new double[]{0.0,0.0}),new NNData(new String[]{"xor"}, new double[]{0.5}));
|
||||
NNDataPair actual = new NNDataPair(new NNData(new String[]{"x","y"}, new double[]{0.0,0.0}),new NNData(new String[]{"xor"}, new double[]{0.5}));
|
||||
NNData actualOutput = mlp.compute(actual);
|
||||
assertEquals(expected.getIdeal().getValues()[0], actualOutput.getValues()[0], EPS);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testXORnetwork() {
|
||||
FeedforwardNetwork mlp = new MultiLayerPerceptron(true, 2, 2, 1);
|
||||
mlp.setWeights(new double[] {
|
||||
0.341232, 0.129952, -0.923123, //hidden neuron 1 from input0, input1, bias
|
||||
-0.115223, 0.570345, -0.328932, //hidden neuron 2 from input0, input1, bias
|
||||
-0.993423, 0.164732, 0.752621}); //output
|
||||
|
||||
for (Connection connection : mlp.getConnections()) {
|
||||
System.out.println(connection);
|
||||
}
|
||||
NNDataPair expected = new NNDataPair(new NNData(new String[]{"x","y"}, new double[]{0.0,0.0}),new NNData(new String[]{"xor"}, new double[]{0.263932}));
|
||||
NNDataPair actual = new NNDataPair(new NNData(new String[]{"x","y"}, new double[]{0.0,0.0}),new NNData(new String[]{"xor"}, new double[]{0.0}));
|
||||
NNData actualOutput = mlp.compute(actual);
|
||||
assertEquals(expected.getIdeal().getValues()[0], actualOutput.getValues()[0], EPS);
|
||||
}
|
||||
}
|
||||
@@ -1,66 +0,0 @@
|
||||
package net.woodyfolsom.msproj.ann;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileFilter;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
import net.woodyfolsom.msproj.GameRecord;
|
||||
import net.woodyfolsom.msproj.Referee;
|
||||
|
||||
import org.antlr.runtime.RecognitionException;
|
||||
import org.encog.ml.data.MLData;
|
||||
import org.encog.ml.data.MLDataPair;
|
||||
import org.junit.Test;
|
||||
|
||||
public class WinFilterTest {
|
||||
|
||||
@Test
|
||||
public void testLearnSaveLoad() throws IOException, RecognitionException {
|
||||
File[] sgfFiles = new File("data/games/random_vs_random")
|
||||
.listFiles(new FileFilter() {
|
||||
@Override
|
||||
public boolean accept(File pathname) {
|
||||
return pathname.getName().endsWith(".sgf");
|
||||
}
|
||||
});
|
||||
|
||||
Set<List<MLDataPair>> trainingData = new HashSet<List<MLDataPair>>();
|
||||
|
||||
for (File file : sgfFiles) {
|
||||
FileInputStream fis = new FileInputStream(file);
|
||||
GameRecord gameRecord = Referee.replay(fis);
|
||||
|
||||
List<MLDataPair> gameData = new ArrayList<MLDataPair>();
|
||||
for (int i = 0; i <= gameRecord.getNumTurns(); i++) {
|
||||
gameData.add(new GameStateMLDataPair(gameRecord.getGameState(i)));
|
||||
}
|
||||
|
||||
trainingData.add(gameData);
|
||||
|
||||
fis.close();
|
||||
}
|
||||
|
||||
WinFilter winFilter = new WinFilter();
|
||||
|
||||
winFilter.learn(trainingData);
|
||||
|
||||
for (List<MLDataPair> trainingSequence : trainingData) {
|
||||
//for (MLDataPair mlDataPair : trainingSequence) {
|
||||
for (int stateIndex = 0; stateIndex < trainingSequence.size(); stateIndex++) {
|
||||
if (stateIndex > 0 && stateIndex < trainingSequence.size()-1) {
|
||||
continue;
|
||||
}
|
||||
MLData input = trainingSequence.get(stateIndex).getInput();
|
||||
|
||||
System.out.println("Turn " + stateIndex + ": " + input + " => "
|
||||
+ winFilter.computeValue(input));
|
||||
}
|
||||
//}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,19 @@
|
||||
package net.woodyfolsom.msproj.ann;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import net.woodyfolsom.msproj.ann.NNData;
|
||||
import net.woodyfolsom.msproj.ann.NNDataPair;
|
||||
import net.woodyfolsom.msproj.ann.NeuralNetFilter;
|
||||
import net.woodyfolsom.msproj.ann.XORFilter;
|
||||
|
||||
import org.encog.ml.data.MLDataSet;
|
||||
import org.encog.ml.data.basic.BasicMLDataSet;
|
||||
import org.junit.AfterClass;
|
||||
import org.junit.BeforeClass;
|
||||
import org.junit.Test;
|
||||
@@ -29,10 +38,51 @@ public class XORFilterTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLearnSaveLoad() throws IOException {
|
||||
NeuralNetFilter nnLearner = new XORFilter();
|
||||
public void testLearn() throws IOException {
|
||||
NeuralNetFilter nnLearner = new XORFilter(0.5,0.0);
|
||||
|
||||
// create training set (logical XOR function)
|
||||
int size = 1;
|
||||
double[][] trainingInput = new double[4 * size][];
|
||||
double[][] trainingOutput = new double[4 * size][];
|
||||
for (int i = 0; i < size; i++) {
|
||||
trainingInput[i * 4 + 0] = new double[] { 0, 0 };
|
||||
trainingInput[i * 4 + 1] = new double[] { 0, 1 };
|
||||
trainingInput[i * 4 + 2] = new double[] { 1, 0 };
|
||||
trainingInput[i * 4 + 3] = new double[] { 1, 1 };
|
||||
trainingOutput[i * 4 + 0] = new double[] { 0 };
|
||||
trainingOutput[i * 4 + 1] = new double[] { 1 };
|
||||
trainingOutput[i * 4 + 2] = new double[] { 1 };
|
||||
trainingOutput[i * 4 + 3] = new double[] { 0 };
|
||||
}
|
||||
|
||||
// create training data
|
||||
List<NNDataPair> trainingSet = new ArrayList<NNDataPair>();
|
||||
String[] inputNames = new String[] {"x","y"};
|
||||
String[] outputNames = new String[] {"XOR"};
|
||||
for (int i = 0; i < 4*size; i++) {
|
||||
trainingSet.add(new NNDataPair(new NNData(inputNames,trainingInput[i]),new NNData(outputNames,trainingOutput[i])));
|
||||
}
|
||||
|
||||
nnLearner.setMaxTrainingEpochs(20000);
|
||||
nnLearner.learnPatterns(trainingSet);
|
||||
System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");
|
||||
|
||||
double[][] validationSet = new double[4][2];
|
||||
|
||||
validationSet[0] = new double[] { 0, 0 };
|
||||
validationSet[1] = new double[] { 0, 1 };
|
||||
validationSet[2] = new double[] { 1, 0 };
|
||||
validationSet[3] = new double[] { 1, 1 };
|
||||
|
||||
System.out.println("Output from eval set (learned network):");
|
||||
testNetwork(nnLearner, validationSet, inputNames, outputNames);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLearnSaveLoad() throws IOException {
|
||||
NeuralNetFilter nnLearner = new XORFilter(0.05,0.0);
|
||||
|
||||
// create training set (logical XOR function)
|
||||
int size = 1;
|
||||
double[][] trainingInput = new double[4 * size][];
|
||||
@@ -49,10 +99,17 @@ public class XORFilterTest {
|
||||
}
|
||||
|
||||
// create training data
|
||||
MLDataSet trainingSet = new BasicMLDataSet(trainingInput, trainingOutput);
|
||||
List<NNDataPair> trainingSet = new ArrayList<NNDataPair>();
|
||||
String[] inputNames = new String[] {"x","y"};
|
||||
String[] outputNames = new String[] {"XOR"};
|
||||
for (int i = 0; i < 4*size; i++) {
|
||||
trainingSet.add(new NNDataPair(new NNData(inputNames,trainingInput[i]),new NNData(outputNames,trainingOutput[i])));
|
||||
}
|
||||
|
||||
nnLearner.setMaxTrainingEpochs(10000);
|
||||
nnLearner.learnPatterns(trainingSet);
|
||||
System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");
|
||||
|
||||
nnLearner.learn(trainingSet);
|
||||
|
||||
double[][] validationSet = new double[4][2];
|
||||
|
||||
validationSet[0] = new double[] { 0, 0 };
|
||||
@@ -61,19 +118,24 @@ public class XORFilterTest {
|
||||
validationSet[3] = new double[] { 1, 1 };
|
||||
|
||||
System.out.println("Output from eval set (learned network, pre-serialization):");
|
||||
testNetwork(nnLearner, validationSet);
|
||||
|
||||
nnLearner.save(FILENAME);
|
||||
nnLearner.load(FILENAME);
|
||||
testNetwork(nnLearner, validationSet, inputNames, outputNames);
|
||||
|
||||
FileOutputStream fos = new FileOutputStream(FILENAME);
|
||||
assertTrue(nnLearner.save(fos));
|
||||
fos.close();
|
||||
|
||||
FileInputStream fis = new FileInputStream(FILENAME);
|
||||
assertTrue(nnLearner.load(fis));
|
||||
fis.close();
|
||||
|
||||
System.out.println("Output from eval set (learned network, post-serialization):");
|
||||
testNetwork(nnLearner, validationSet);
|
||||
testNetwork(nnLearner, validationSet, inputNames, outputNames);
|
||||
}
|
||||
|
||||
private void testNetwork(NeuralNetFilter nnLearner, double[][] validationSet) {
|
||||
private void testNetwork(NeuralNetFilter nnLearner, double[][] validationSet, String[] inputNames, String[] outputNames) {
|
||||
for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
|
||||
DoublePair dp = new DoublePair(validationSet[valIndex][0],validationSet[valIndex][1]);
|
||||
System.out.println(dp + " => " + nnLearner.computeValue(dp));
|
||||
NNDataPair dp = new NNDataPair(new NNData(inputNames,validationSet[valIndex]), new NNData(outputNames,validationSet[valIndex]));
|
||||
System.out.println(dp + " => " + nnLearner.compute(dp));
|
||||
}
|
||||
}
|
||||
}
|
||||
29
test/net/woodyfolsom/msproj/ann/math/SigmoidTest.java
Normal file
29
test/net/woodyfolsom/msproj/ann/math/SigmoidTest.java
Normal file
@@ -0,0 +1,29 @@
|
||||
package net.woodyfolsom.msproj.ann.math;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import net.woodyfolsom.msproj.ann.math.ActivationFunction;
|
||||
import net.woodyfolsom.msproj.ann.math.Sigmoid;
|
||||
import net.woodyfolsom.msproj.ann.math.Tanh;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
public class SigmoidTest {
|
||||
static double EPS = 0.001;
|
||||
|
||||
@Test
|
||||
public void testCalculate() {
|
||||
|
||||
ActivationFunction sigmoid = Sigmoid.function;
|
||||
assertEquals(0.5,sigmoid.calculate(0.0),EPS);
|
||||
assertTrue(sigmoid.calculate(100.0) > 1.0 - EPS);
|
||||
assertTrue(sigmoid.calculate(-9000.0) < EPS);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDerivative() {
|
||||
ActivationFunction sigmoid = new Tanh();
|
||||
assertEquals(1.0,sigmoid.derivative(0.0), EPS);
|
||||
}
|
||||
}
|
||||
28
test/net/woodyfolsom/msproj/ann/math/TanhTest.java
Normal file
28
test/net/woodyfolsom/msproj/ann/math/TanhTest.java
Normal file
@@ -0,0 +1,28 @@
|
||||
package net.woodyfolsom.msproj.ann.math;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import net.woodyfolsom.msproj.ann.math.ActivationFunction;
|
||||
import net.woodyfolsom.msproj.ann.math.Tanh;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
public class TanhTest {
|
||||
static double EPS = 0.001;
|
||||
|
||||
@Test
|
||||
public void testCalculate() {
|
||||
|
||||
ActivationFunction tanh = new Tanh();
|
||||
assertEquals(0.0,tanh.calculate(0.0),EPS);
|
||||
assertTrue(tanh.calculate(100.0) > 0.5 - EPS);
|
||||
assertTrue(tanh.calculate(-9000.0) < -0.5 + EPS);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDerivative() {
|
||||
ActivationFunction tanh = new Tanh();
|
||||
assertEquals(1.0,tanh.derivative(0.0), EPS);
|
||||
}
|
||||
}
|
||||
73
test/net/woodyfolsom/msproj/tictactoe/GameRecordTest.java
Normal file
73
test/net/woodyfolsom/msproj/tictactoe/GameRecordTest.java
Normal file
@@ -0,0 +1,73 @@
|
||||
package net.woodyfolsom.msproj.tictactoe;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;
|
||||
|
||||
public class GameRecordTest {
|
||||
|
||||
@Test
|
||||
public void testGetResultXwins() {
|
||||
GameRecord gameRecord = new GameRecord();
|
||||
gameRecord.apply(Action.getInstance(PLAYER.X, 1, 0));
|
||||
gameRecord.apply(Action.getInstance(PLAYER.O, 0, 0));
|
||||
gameRecord.apply(Action.getInstance(PLAYER.X, 1, 1));
|
||||
gameRecord.apply(Action.getInstance(PLAYER.O, 0, 1));
|
||||
gameRecord.apply(Action.getInstance(PLAYER.X, 1, 2));
|
||||
State finalState = gameRecord.getState();
|
||||
System.out.println("Final state:");
|
||||
System.out.println(finalState);
|
||||
assertTrue(finalState.isValid());
|
||||
assertTrue(finalState.isTerminal());
|
||||
assertTrue(finalState.isWinner(PLAYER.X));
|
||||
assertEquals(GameRecord.RESULT.X_WINS,gameRecord.getResult());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGetResultOwins() {
|
||||
GameRecord gameRecord = new GameRecord();
|
||||
gameRecord.apply(Action.getInstance(PLAYER.X, 0, 0));
|
||||
gameRecord.apply(Action.getInstance(PLAYER.O, 0, 2));
|
||||
gameRecord.apply(Action.getInstance(PLAYER.X, 0, 1));
|
||||
gameRecord.apply(Action.getInstance(PLAYER.O, 1, 1));
|
||||
gameRecord.apply(Action.getInstance(PLAYER.X, 1, 0));
|
||||
gameRecord.apply(Action.getInstance(PLAYER.O, 2, 0));
|
||||
|
||||
State finalState = gameRecord.getState();
|
||||
System.out.println("Final state:");
|
||||
System.out.println(finalState);
|
||||
assertTrue(finalState.isValid());
|
||||
assertTrue(finalState.isTerminal());
|
||||
assertTrue(finalState.isWinner(PLAYER.O));
|
||||
assertEquals(GameRecord.RESULT.O_WINS,gameRecord.getResult());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGetResultTieGame() {
|
||||
GameRecord gameRecord = new GameRecord();
|
||||
gameRecord.apply(Action.getInstance(PLAYER.X, 0, 0));
|
||||
gameRecord.apply(Action.getInstance(PLAYER.O, 0, 2));
|
||||
gameRecord.apply(Action.getInstance(PLAYER.X, 0, 1));
|
||||
|
||||
gameRecord.apply(Action.getInstance(PLAYER.O, 1, 0));
|
||||
gameRecord.apply(Action.getInstance(PLAYER.X, 1, 2));
|
||||
gameRecord.apply(Action.getInstance(PLAYER.O, 1, 1));
|
||||
|
||||
gameRecord.apply(Action.getInstance(PLAYER.X, 2, 0));
|
||||
gameRecord.apply(Action.getInstance(PLAYER.O, 2, 2));
|
||||
gameRecord.apply(Action.getInstance(PLAYER.X, 2, 1));
|
||||
|
||||
State finalState = gameRecord.getState();
|
||||
System.out.println("Final state:");
|
||||
System.out.println(finalState);
|
||||
assertTrue(finalState.isValid());
|
||||
assertTrue(finalState.isTerminal());
|
||||
assertFalse(finalState.isWinner(PLAYER.X));
|
||||
assertFalse(finalState.isWinner(PLAYER.O));
|
||||
assertEquals(GameRecord.RESULT.TIE_GAME,gameRecord.getResult());
|
||||
}
|
||||
}
|
||||
12
test/net/woodyfolsom/msproj/tictactoe/RefereeTest.java
Normal file
12
test/net/woodyfolsom/msproj/tictactoe/RefereeTest.java
Normal file
@@ -0,0 +1,12 @@
|
||||
package net.woodyfolsom.msproj.tictactoe;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
public class RefereeTest {
|
||||
|
||||
@Test
|
||||
public void testPlay100Games() {
|
||||
new Referee().play(100);
|
||||
}
|
||||
|
||||
}
|
||||
189
ttt.net
Normal file
189
ttt.net
Normal file
@@ -0,0 +1,189 @@
|
||||
<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<!--
  Serialized multi-layer perceptron "TicTacToe" (trained network weights).
  Machine-written file: regenerate via the serializer rather than editing
  weights by hand.

  Topology, as recorded in the <layers> elements below:
    layer 1: neurons 1-9   (Linear activation)
    layer 2: neurons 10-18 (Tanh activation)
    layer 3: neuron 19     (Sigmoid activation)
  Neuron 0 (Linear) belongs to no layer and appears only as a connection
  source; given biased="true" it is presumably the bias unit.
  NOTE(review): confirm the neuron-0 role against the serializing code.
-->
<multiLayerPerceptron biased="true" name="TicTacToe">
<!-- Activation recorded at network level; its role relative to the
     per-neuron activation functions below is not shown here. -->
<activationFunction name="Sigmoid"/>
<!-- Weights into hidden neurons 10-18 from source neurons 0-9. -->
<connections dest="10" src="0" weight="-1.139430876029846"/>
<connections dest="10" src="1" weight="3.091584814276022"/>
<connections dest="10" src="2" weight="-0.2551933137016801"/>
<connections dest="10" src="3" weight="-0.7615637398659946"/>
<connections dest="10" src="4" weight="-0.6548680752915276"/>
<connections dest="10" src="5" weight="0.08510244492139961"/>
<connections dest="10" src="6" weight="-0.8138255062915528"/>
<connections dest="10" src="7" weight="-0.2048642154445006"/>
<connections dest="10" src="8" weight="0.4118548734860931"/>
<connections dest="10" src="9" weight="-0.3418643333437593"/>
<connections dest="11" src="0" weight="0.985258631016843"/>
<connections dest="11" src="1" weight="0.5585206829273895"/>
<connections dest="11" src="2" weight="-0.00710478214319128"/>
<connections dest="11" src="3" weight="0.4458768855799938"/>
<connections dest="11" src="4" weight="0.699630908274699"/>
<connections dest="11" src="5" weight="-0.291692394014361"/>
<connections dest="11" src="6" weight="-0.3968126140382831"/>
<connections dest="11" src="7" weight="1.5110166318959362"/>
<connections dest="11" src="8" weight="-0.18225892993024753"/>
<connections dest="11" src="9" weight="-0.7602259764999595"/>
<connections dest="12" src="0" weight="1.7429897430035988"/>
<connections dest="12" src="1" weight="-0.28322509888402325"/>
<connections dest="12" src="2" weight="0.5040019819001578"/>
<connections dest="12" src="3" weight="0.9359376456777513"/>
<connections dest="12" src="4" weight="-0.15284920922664844"/>
<connections dest="12" src="5" weight="-3.4788220747438667"/>
<connections dest="12" src="6" weight="0.7547569837163356"/>
<connections dest="12" src="7" weight="-0.32085504506413287"/>
<connections dest="12" src="8" weight="0.518047606917643"/>
<connections dest="12" src="9" weight="0.8207811849267582"/>
<connections dest="13" src="0" weight="0.0768646322526947"/>
<connections dest="13" src="1" weight="0.675759240542896"/>
<connections dest="13" src="2" weight="-1.04758516390251"/>
<connections dest="13" src="3" weight="-1.1207097351854434"/>
<connections dest="13" src="4" weight="-1.0663558249243994"/>
<connections dest="13" src="5" weight="0.40669746747595964"/>
<connections dest="13" src="6" weight="-0.8040553688830026"/>
<connections dest="13" src="7" weight="-0.810063503984392"/>
<connections dest="13" src="8" weight="-0.63726821466013"/>
<connections dest="13" src="9" weight="-0.062253116036353605"/>
<connections dest="14" src="0" weight="-0.2569035720861068"/>
<connections dest="14" src="1" weight="-0.23868649547740917"/>
<connections dest="14" src="2" weight="0.3319329593778122"/>
<connections dest="14" src="3" weight="0.22285129465763973"/>
<connections dest="14" src="4" weight="-1.1932177045246797"/>
<connections dest="14" src="5" weight="-0.8246033698516325"/>
<connections dest="14" src="6" weight="-1.1522063213004192"/>
<connections dest="14" src="7" weight="-0.08295162498206299"/>
<connections dest="14" src="8" weight="0.45121422208738693"/>
<connections dest="14" src="9" weight="0.1344210997671879"/>
<connections dest="15" src="0" weight="-0.19080274015172097"/>
<connections dest="15" src="1" weight="-0.08751712180997395"/>
<connections dest="15" src="2" weight="0.6338301857587448"/>
<connections dest="15" src="3" weight="-0.9971509770232028"/>
<connections dest="15" src="4" weight="0.37406630555233944"/>
<connections dest="15" src="5" weight="1.7040252761510988"/>
<connections dest="15" src="6" weight="0.43507827352032896"/>
<connections dest="15" src="7" weight="-1.030255483779959"/>
<connections dest="15" src="8" weight="0.6425158958005772"/>
<connections dest="15" src="9" weight="-0.2768699175127161"/>
<connections dest="16" src="0" weight="0.383162191474126"/>
<connections dest="16" src="1" weight="-0.316758353560207"/>
<connections dest="16" src="2" weight="-0.40398044046890863"/>
<connections dest="16" src="3" weight="-0.4103150897933657"/>
<connections dest="16" src="4" weight="-0.2110512886314012"/>
<connections dest="16" src="5" weight="-0.7537325411227123"/>
<connections dest="16" src="6" weight="-0.277432410233177"/>
<connections dest="16" src="7" weight="0.6523906983042057"/>
<connections dest="16" src="8" weight="0.8237246362854799"/>
<connections dest="16" src="9" weight="0.6450796646675565"/>
<connections dest="17" src="0" weight="-1.3222355951033131"/>
<connections dest="17" src="1" weight="-0.6775300244272042"/>
<connections dest="17" src="2" weight="-0.9101223420136262"/>
<connections dest="17" src="3" weight="-0.8913218057082705"/>
<connections dest="17" src="4" weight="-0.3228919507773142"/>
<connections dest="17" src="5" weight="0.6156397776974011"/>
<connections dest="17" src="6" weight="-0.6008468597628974"/>
<connections dest="17" src="7" weight="-0.3094929421425772"/>
<connections dest="17" src="8" weight="1.4800051973199828"/>
<connections dest="17" src="9" weight="-0.26820420703433634"/>
<connections dest="18" src="0" weight="-0.4724752146627139"/>
<connections dest="18" src="1" weight="-0.17278878268217254"/>
<connections dest="18" src="2" weight="-0.3213530770778259"/>
<connections dest="18" src="3" weight="-0.4343270409319928"/>
<connections dest="18" src="4" weight="0.5864291732809569"/>
<connections dest="18" src="5" weight="0.4944358358169582"/>
<connections dest="18" src="6" weight="0.8432289820265341"/>
<connections dest="18" src="7" weight="0.7294985221790254"/>
<connections dest="18" src="8" weight="0.19741936496860893"/>
<connections dest="18" src="9" weight="1.0649680979002503"/>
<!-- Weights into output neuron 19 from source 0 and hidden neurons 10-18. -->
<connections dest="19" src="0" weight="2.6635011409263543"/>
<connections dest="19" src="10" weight="2.185963006026185"/>
<connections dest="19" src="11" weight="-1.401987872790659"/>
<connections dest="19" src="12" weight="-2.572264670917092"/>
<connections dest="19" src="13" weight="-2.719351802228293"/>
<connections dest="19" src="14" weight="2.14428000554082"/>
<connections dest="19" src="15" weight="-2.5948425406968325"/>
<connections dest="19" src="16" weight="-2.593589676600079"/>
<connections dest="19" src="17" weight="1.2492257986319857"/>
<connections dest="19" src="18" weight="2.7912530986331845"/>
<!-- Neurons 0-9: Linear activation (0 is listed in no layer below). -->
<neurons id="0">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
</neurons>
<neurons id="1">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
</neurons>
<neurons id="2">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
</neurons>
<neurons id="3">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
</neurons>
<neurons id="4">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
</neurons>
<neurons id="5">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
</neurons>
<neurons id="6">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
</neurons>
<neurons id="7">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
</neurons>
<neurons id="8">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
</neurons>
<neurons id="9">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
</neurons>
<!-- Hidden neurons 10-18: Tanh activation. -->
<neurons id="10">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Tanh"/>
</neurons>
<neurons id="11">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Tanh"/>
</neurons>
<neurons id="12">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Tanh"/>
</neurons>
<neurons id="13">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Tanh"/>
</neurons>
<neurons id="14">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Tanh"/>
</neurons>
<neurons id="15">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Tanh"/>
</neurons>
<neurons id="16">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Tanh"/>
</neurons>
<neurons id="17">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Tanh"/>
</neurons>
<neurons id="18">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Tanh"/>
</neurons>
<!-- Output neuron 19: Sigmoid activation. -->
<neurons id="19">
<activationFunction name="Sigmoid"/>
</neurons>
<!-- Layer membership, in network order: inputs, hidden, output. -->
<layers>
<neuronIds>1</neuronIds>
<neuronIds>2</neuronIds>
<neuronIds>3</neuronIds>
<neuronIds>4</neuronIds>
<neuronIds>5</neuronIds>
<neuronIds>6</neuronIds>
<neuronIds>7</neuronIds>
<neuronIds>8</neuronIds>
<neuronIds>9</neuronIds>
</layers>
<layers>
<neuronIds>10</neuronIds>
<neuronIds>11</neuronIds>
<neuronIds>12</neuronIds>
<neuronIds>13</neuronIds>
<neuronIds>14</neuronIds>
<neuronIds>15</neuronIds>
<neuronIds>16</neuronIds>
<neuronIds>17</neuronIds>
<neuronIds>18</neuronIds>
</layers>
<layers>
<neuronIds>19</neuronIds>
</layers>
</multiLayerPerceptron>
|
||||
Reference in New Issue
Block a user