Compare commits

...

10 Commits

96 changed files with 5177 additions and 1317 deletions


@@ -7,6 +7,5 @@
 <classpathentry kind="lib" path="lib/log4j-1.2.16.jar"/>
 <classpathentry kind="lib" path="lib/kgsGtp.jar"/>
 <classpathentry kind="lib" path="lib/antlrworks-1.4.3.jar"/>
-<classpathentry kind="lib" path="lib/encog-java-core.jar" sourcepath="lib/encog-java-core-sources.jar"/>
 <classpathentry kind="output" path="bin"/>
 </classpath>

GoGame.log (new empty file)


@@ -3,6 +3,7 @@
 <description>Simple Framework for Testing Tree Search and Monte-Carlo Go</description>
 <property name="src" location="src" />
+<property name="reports" location="reports" />
 <property name="build" location="build" />
 <property name="dist" location="dist" />
 <property name="test" location="test" />
@@ -33,9 +34,16 @@
 </target>
 <target name="copy-resources">
-<copy todir="${dist}/data">
+<copy todir="${dist}" file="connect.bat" />
+<copy todir="${dist}" file="rrt.bat" />
+<copy todir="${dist}" file="data/log4j.xml" />
+<copy todir="${dist}" file="data/kgsGtp.ini" />
+<copy todir="${dist}" file="data/gogame.cfg" />
+<!--copy todir="${dist}/data">
 <fileset dir="data" />
-</copy>
+</copy-->
 <copy todir="${build}/net/woodyfolsom/msproj/gui">
 <fileset dir="${src}/net/woodyfolsom/msproj/gui">
 <exclude name="**/*.java"/>
@@ -58,6 +66,7 @@
 <!-- Delete the ${build} and ${dist} directory trees -->
 <delete dir="${build}" />
 <delete dir="${dist}" />
+<delete dir="${reports}" />
 </target>
 <target name="dist" depends="compile,copy-resources,copy-libs" description="generate the distribution">
@@ -83,13 +92,14 @@
 <target name="init">
 <!-- Create the build directory structure used by compile -->
 <mkdir dir="${build}" />
+<mkdir dir="${reports}" />
 </target>
 <target name="test" depends="compile-test">
 <junit haltonfailure="true">
 <classpath refid="classpath.test" />
-<formatter type="brief" usefile="false" />
-<batchtest>
+<formatter type="xml" />
+<batchtest todir="${reports}">
 <fileset dir="${build}" includes="**/*Test.class" />
 </batchtest>
 </junit>

connect.bat (new file, 1 line)

@@ -0,0 +1 @@
java -cp GoGame.jar;antlrworks-1.4.3.jar;kgsGtp.jar;log4j-1.2.16.jar net.woodyfolsom.msproj.GoGame


@@ -1,10 +1,10 @@
-PlayerOne=RANDOM
-PlayerTwo=RANDOM
+PlayerOne=ROOT_PAR_AMAF //HUMAN, HUMAN_GUI, ROOT_PAR, UCT, RANDOM, RAVE, SMAF, ROOT_PAR_AMAF
+PlayerTwo=Random
 GUIDelay=1000 //1 second
-BoardSize=9
-Komi=6.5
-NumGames=1000 //Games for each color per player
-TurnTime=1000 //seconds per player per turn
-SpectatorBoardShown=false;
-WhiteMoveLogged=false;
-BlackMoveLogged=false;
+BoardSize=13 //9, 13 or 19
+Komi=6.5 //suggested 6.5
+NumGames=1 //Games for each color per player
+TurnTime=6000 //seconds per player per turn
+SpectatorBoardShown=true //set to true for modes which otherwise wouldn't show GUI. false for HUMAN_GUI player.
+WhiteMoveLogged=false
+BlackMoveLogged=true
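Note on the new values: the player-type strings appear to be matched case-insensitively in StandAloneGame.getPlayerType (see the equalsIgnoreCase branches in that diff further down), so PlayerTwo=Random should resolve to the same RANDOM policy as before, and the trailing comments list the accepted player types and board sizes.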


@@ -1,12 +1,11 @@
 engine=java -cp GoGame.jar;log4j-1.2.16.jar net.woodyfolsom.msproj.GoGame montecarlo
-name=whf4cs6999
-password=6id39p
+name=whf4human
+password=t3snxf
 room=whf4cs6999
-mode=custom
+mode=auto
 talk=I'm a Monte Carlo tree search bot.
-opponent=whf4human
 reconnect=t
-automatch.rank=25k
 rules=chinese
 rules.boardSize=9
 rules.time=0
+opponent=whf4cs6999

gofree.txt (new file, 3 lines)

@@ -0,0 +1,3 @@
UCT-RAVE vs GoFree
level 1 (black) 2/2
level 2 (black) 1/1

kgsGtp.ini (new file, 11 lines)

@@ -0,0 +1,11 @@
engine=java -cp GoGame.jar;log4j-1.2.16.jar net.woodyfolsom.msproj.GoGame montecarlo
name=whf4human
password=t3snxf
room=whf4cs6999
mode=auto
talk=I'm a Monte Carlo tree search bot.
reconnect=t
rules=chinese
rules.boardSize=13
rules.time=0
opponent=whf4cs6999
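This new top-level kgsGtp.ini mirrors the modified kgsGtp.ini above, except that it sets rules.boardSize=13 where the file modified above keeps 9. The updated copy-resources target in build.xml now places data/kgsGtp.ini, data/log4j.xml, data/gogame.cfg and the two .bat launchers directly in ${dist}.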

lib/activation-1.0.2.jar (new binary file, not shown)
(three further binary files, names not shown)
lib/jaxb-api-2.3.1.jar (new binary file, not shown)

log4j.xml (new file, 19 lines)

@@ -0,0 +1,19 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE log4j:configuration SYSTEM "log4j.dtd">
<log4j:configuration xmlns:log4j="http://jakarta.apache.org/log4j/" debug="false">
<appender name="fileAppender" class="org.apache.log4j.RollingFileAppender">
<param name="Threshold" value="INFO" />
<param name="File" value="GoGame.log"/>
<layout class="org.apache.log4j.PatternLayout">
<param name="ConversionPattern" value="%d %-5p [%c{1}] %m %n" />
</layout>
</appender>
<logger name="cs6601.p1" additivity="false" >
<level value="INFO" />
<appender-ref ref="fileAppender"/>
</logger>
</log4j:configuration>
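This configuration sends INFO-and-above messages to a rolling GoGame.log, which also appears above as a new (empty) file in this change set.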

pass.net (new file, 42 lines)

@@ -0,0 +1,42 @@
<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<multiLayerPerceptron biased="true" name="PassFilter9">
<activationFunction name="Sigmoid"/>
<connections dest="3" src="0" weight="1.9322455294572656"/>
<connections dest="3" src="1" weight="0.0859943747020325"/>
<connections dest="3" src="2" weight="0.8394414489841715"/>
<connections dest="4" src="0" weight="1.5831613048108952"/>
<connections dest="4" src="1" weight="0.8667080746254153"/>
<connections dest="4" src="2" weight="-3.204930958551688"/>
<connections dest="5" src="0" weight="-1.4223906119706369"/>
<connections dest="5" src="3" weight="-2.1292730695450857"/>
<connections dest="5" src="4" weight="-2.5861434868493607"/>
<neurons id="0">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
</neurons>
<neurons id="1">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
</neurons>
<neurons id="2">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
</neurons>
<neurons id="3">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Tanh"/>
</neurons>
<neurons id="4">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Tanh"/>
</neurons>
<neurons id="5">
<activationFunction name="Sigmoid"/>
</neurons>
<layers>
<neuronIds>1</neuronIds>
<neuronIds>2</neuronIds>
</layers>
<layers>
<neuronIds>3</neuronIds>
<neuronIds>4</neuronIds>
</layers>
<layers>
<neuronIds>5</neuronIds>
</layers>
</multiLayerPerceptron>
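pass.net serializes a small three-layer perceptron: two linear inputs plus a bias unit, two Tanh hidden neurons, and one Sigmoid output. The following is a minimal, self-contained sketch of a forward pass using the weights above. It assumes neuron 0 is the bias unit (the network is marked biased="true") and neurons 1 and 2 are the data inputs; this is one plausible reading of the XML, not code taken from the repository, and the class name is only for the sketch.

// Standalone sketch (assumptions: neuron 0 = bias, neurons 1-2 = inputs,
// neurons 3-4 = Tanh hidden units, neuron 5 = Sigmoid output).
// Weights are copied from the <connections> elements of pass.net.
public class PassNetSketch {
    static double sigmoid(double x) { return 1.0 / (1.0 + Math.exp(-x)); }

    static double compute(double in1, double in2) {
        double bias = 1.0;
        // hidden neuron 3: fed by neurons 0 (bias), 1, 2
        double h3 = Math.tanh(1.9322455294572656 * bias
                + 0.0859943747020325 * in1
                + 0.8394414489841715 * in2);
        // hidden neuron 4: fed by neurons 0 (bias), 1, 2
        double h4 = Math.tanh(1.5831613048108952 * bias
                + 0.8667080746254153 * in1
                - 3.204930958551688 * in2);
        // output neuron 5: fed by neurons 0 (bias), 3, 4
        return sigmoid(-1.4223906119706369 * bias
                - 2.1292730695450857 * h3
                - 2.5861434868493607 * h4);
    }

    public static void main(String[] args) {
        System.out.println(compute(0.0, 0.0));
        System.out.println(compute(1.0, 1.0));
    }
}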

rrt.bat (new file, 1 line)

@@ -0,0 +1 @@
java -Xms256m -Xmx4096m -cp GoGame.jar;antlrworks-1.4.3.jar;kgsGtp.jar;log4j-1.2.16.jar net.woodyfolsom.msproj.RoundRobin
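Both connect.bat and rrt.bat use the Windows classpath separator (;) and expect the listed jars to sit next to GoGame.jar in the working directory; on Linux or macOS the same java invocations would need : as the separator.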


@@ -0,0 +1,85 @@
C:\workspace\msproj\dist>java -Xms256m -Xmx4096m -cp GoGame.jar;antlrworks-1.4.3.jar;kgsGtp.jar;log4j-1.2.16.jar net.woodyfolsom.msproj.RoundRobin
Beginning round-robin tournament.
Initializing policies...
Game over. Result: B+5.5
RootParallelization (Black) vs Random (White) : B+5.5
Game over. Result: B+30.5
RootParallelization (Black) vs Random (White) : B+30.5
Game over. Result: B+0.5
RootParallelization (Black) vs Random (White) : B+0.5
Game over. Result: B+18.5
RootParallelization (Black) vs Random (White) : B+18.5
Game over. Result: B+18.5
RootParallelization (Black) vs Random (White) : B+18.5
Game over. Result: B+3.5
RootParallelization (Black) vs Alpha-Beta (White) : B+3.5
Game over. Result: B+0.5
RootParallelization (Black) vs Alpha-Beta (White) : B+0.5
Game over. Result: B+46.5
RootParallelization (Black) vs Alpha-Beta (White) : B+46.5
Game over. Result: B+44.5
RootParallelization (Black) vs Alpha-Beta (White) : B+44.5
Game over. Result: B+53.5
RootParallelization (Black) vs Alpha-Beta (White) : B+53.5
Game over. Result: B+14.5
RootParallelization (Black) vs MonteCarloUCT (White) : B+14.5
Game over. Result: B+30.5
RootParallelization (Black) vs MonteCarloUCT (White) : B+30.5
Game over. Result: B+9.5
RootParallelization (Black) vs MonteCarloUCT (White) : B+9.5
Game over. Result: B+44.5
RootParallelization (Black) vs MonteCarloUCT (White) : B+44.5
Game over. Result: B+29.5
RootParallelization (Black) vs MonteCarloUCT (White) : B+29.5
Game over. Result: B+4.5
RootParallelization (Black) vs UCT-RAVE (White) : B+4.5
Game over. Result: B+27.5
RootParallelization (Black) vs UCT-RAVE (White) : B+27.5
Game over. Result: B+29.5
RootParallelization (Black) vs UCT-RAVE (White) : B+29.5
Game over. Result: B+22.5
RootParallelization (Black) vs UCT-RAVE (White) : B+22.5
Game over. Result: B+36.5
RootParallelization (Black) vs UCT-RAVE (White) : B+36.5
Game over. Result: B+50.5
RootParallelization (Black) vs MonteCarloSMAF (White) : B+50.5
Game over. Result: B+42.5
RootParallelization (Black) vs MonteCarloSMAF (White) : B+42.5
Game over. Result: B+28.5
RootParallelization (Black) vs MonteCarloSMAF (White) : B+28.5
Game over. Result: B+38.5
RootParallelization (Black) vs MonteCarloSMAF (White) : B+38.5
Game over. Result: B+23.5
RootParallelization (Black) vs MonteCarloSMAF (White) : B+23.5
Game over. Result: B+7.5
RootParallelization (Black) vs RootParallelization (White) : B+7.5
Game over. Result: W+16.5
RootParallelization (Black) vs RootParallelization (White) : W+16.5
Game over. Result: B+0.5
RootParallelization (Black) vs RootParallelization (White) : B+0.5
Game over. Result: B+8.5
RootParallelization (Black) vs RootParallelization (White) : B+8.5
Game over. Result: W+19.5
RootParallelization (Black) vs RootParallelization (White) : W+19.5
Game over. Result: B+13.5
RootParallelization (Black) vs RootParallelization-NeuralNet (White) : B+13.5
Game over. Result: B+2.5
RootParallelization (Black) vs RootParallelization-NeuralNet (White) : B+2.5
Game over. Result: B+16.5
RootParallelization (Black) vs RootParallelization-NeuralNet (White) : B+16.5
Game over. Result: B+32.5
RootParallelization (Black) vs RootParallelization-NeuralNet (White) : B+32.5
Game over. Result: B+8.5
RootParallelization (Black) vs RootParallelization-NeuralNet (White) : B+8.5
Tournament Win Rates
====================
RootParallelization (Black) vs Random (White) : 100%
RootParallelization (Black) vs Alpha-Beta (White) : 100%
RootParallelization (Black) vs MonteCarloUCT (White) : 100%
RootParallelization (Black) vs UCT-RAVE (White) : 100%
RootParallelization (Black) vs MonteCarloSMAF (White) : 100%
RootParallelization (Black) vs RootParallelization (White) : 60%
RootParallelization (Black) vs RootParallelization-NeuralNet (White) : 100%
Tournament lasted 1597.948 seconds.


@@ -0,0 +1,84 @@
C:\workspace\msproj\dist>java -Xms256m -Xmx4096m -cp GoGame.jar;antlrworks-1.4.3.jar;kgsGtp.jar;log4j-1.2.16.jar net.woodyfolsom.msproj.RoundRobin
Beginning round-robin tournament.
Initializing policies...
Game over. Result: B+2.5
Alpha-Beta (Black) vs Random (White) : B+2.5
Game over. Result: W+3.5
Alpha-Beta (Black) vs Random (White) : W+3.5
Game over. Result: W+4.5
Alpha-Beta (Black) vs Random (White) : W+4.5
Game over. Result: W+1.5
Alpha-Beta (Black) vs Random (White) : W+1.5
Game over. Result: W+2.5
Alpha-Beta (Black) vs Random (White) : W+2.5
Game over. Result: B+40.5
Alpha-Beta (Black) vs Alpha-Beta (White) : B+40.5
Game over. Result: B+40.5
Alpha-Beta (Black) vs Alpha-Beta (White) : B+40.5
Game over. Result: B+40.5
Alpha-Beta (Black) vs Alpha-Beta (White) : B+40.5
Game over. Result: B+40.5
Alpha-Beta (Black) vs Alpha-Beta (White) : B+40.5
Game over. Result: B+40.5
Alpha-Beta (Black) vs Alpha-Beta (White) : B+40.5
Game over. Result: W+17.5
Alpha-Beta (Black) vs MonteCarloUCT (White) : W+17.5
Game over. Result: W+40.5
Alpha-Beta (Black) vs MonteCarloUCT (White) : W+40.5
Game over. Result: W+18.5
Alpha-Beta (Black) vs MonteCarloUCT (White) : W+18.5
Game over. Result: W+30.5
Alpha-Beta (Black) vs MonteCarloUCT (White) : W+30.5
Game over. Result: W+33.5
Alpha-Beta (Black) vs MonteCarloUCT (White) : W+33.5
Game over. Result: W+32.5
Alpha-Beta (Black) vs UCT-RAVE (White) : W+32.5
Game over. Result: W+41.5
Alpha-Beta (Black) vs UCT-RAVE (White) : W+41.5
Game over. Result: W+36.5
Alpha-Beta (Black) vs UCT-RAVE (White) : W+36.5
Game over. Result: W+40.5
Alpha-Beta (Black) vs UCT-RAVE (White) : W+40.5
Game over. Result: W+34.5
Alpha-Beta (Black) vs UCT-RAVE (White) : W+34.5
Game over. Result: W+6.5
Alpha-Beta (Black) vs MonteCarloSMAF (White) : W+6.5
Game over. Result: W+23.5
Alpha-Beta (Black) vs MonteCarloSMAF (White) : W+23.5
Game over. Result: W+18.5
Alpha-Beta (Black) vs MonteCarloSMAF (White) : W+18.5
Game over. Result: W+33.5
Alpha-Beta (Black) vs MonteCarloSMAF (White) : W+33.5
Game over. Result: W+40.5
Alpha-Beta (Black) vs MonteCarloSMAF (White) : W+40.5
Game over. Result: W+1.5
Alpha-Beta (Black) vs RootParallelization (White) : W+1.5
Game over. Result: W+4.5
Alpha-Beta (Black) vs RootParallelization (White) : W+4.5
Game over. Result: W+0.5
Alpha-Beta (Black) vs RootParallelization (White) : W+0.5
Game over. Result: W+0.5
Alpha-Beta (Black) vs RootParallelization (White) : W+0.5
Game over. Result: W+35.5
Alpha-Beta (Black) vs RootParallelization (White) : W+35.5
Game over. Result: W+3.5
Alpha-Beta (Black) vs RootParallelization-NeuralNet (White) : W+3.5
Game over. Result: W+0.5
Alpha-Beta (Black) vs RootParallelization-NeuralNet (White) : W+0.5
Game over. Result: W+1.5
Alpha-Beta (Black) vs RootParallelization-NeuralNet (White) : W+1.5
Game over. Result: W+4.5
Alpha-Beta (Black) vs RootParallelization-NeuralNet (White) : W+4.5
Game over. Result: W+4.5
Alpha-Beta (Black) vs RootParallelization-NeuralNet (White) : W+4.5
Tournament Win Rates
====================
Alpha-Beta (Black) vs Random (White) : 20%
Alpha-Beta (Black) vs Alpha-Beta (White) : 100%
Alpha-Beta (Black) vs MonteCarloUCT (White) : 00%
Alpha-Beta (Black) vs UCT-RAVE (White) : 00%
Alpha-Beta (Black) vs MonteCarloSMAF (White) : 00%
Alpha-Beta (Black) vs RootParallelization (White) : 00%
Alpha-Beta (Black) vs RootParallelization-NeuralNet (White) : 00%

rrt/rrt.amaf.black.txt (new file, 85 lines)

@@ -0,0 +1,85 @@
C:\workspace\msproj\dist>java -Xms256m -Xmx4096m -cp GoGame.jar;antlrworks-1.4.3.jar;kgsGtp.jar;log4j-1.2.16.jar net.woodyfolsom.msproj.RoundRobin
Beginning round-robin tournament.
Initializing policies...
Game over. Result: B+40.5
UCT-RAVE (Black) vs Random (White) : B+40.5
Game over. Result: B+3.5
UCT-RAVE (Black) vs Random (White) : B+3.5
Game over. Result: B+15.5
UCT-RAVE (Black) vs Random (White) : B+15.5
Game over. Result: B+54.5
UCT-RAVE (Black) vs Random (White) : B+54.5
Game over. Result: B+18.5
UCT-RAVE (Black) vs Random (White) : B+18.5
Game over. Result: B+31.5
UCT-RAVE (Black) vs Alpha-Beta (White) : B+31.5
Game over. Result: B+17.5
UCT-RAVE (Black) vs Alpha-Beta (White) : B+17.5
Game over. Result: W+9.5
UCT-RAVE (Black) vs Alpha-Beta (White) : W+9.5
Game over. Result: B+34.5
UCT-RAVE (Black) vs Alpha-Beta (White) : B+34.5
Game over. Result: W+9.5
UCT-RAVE (Black) vs Alpha-Beta (White) : W+9.5
Game over. Result: B+2.5
UCT-RAVE (Black) vs MonteCarloUCT (White) : B+2.5
Game over. Result: B+36.5
UCT-RAVE (Black) vs MonteCarloUCT (White) : B+36.5
Game over. Result: B+9.5
UCT-RAVE (Black) vs MonteCarloUCT (White) : B+9.5
Game over. Result: W+2.5
UCT-RAVE (Black) vs MonteCarloUCT (White) : W+2.5
Game over. Result: B+1.5
UCT-RAVE (Black) vs MonteCarloUCT (White) : B+1.5
Game over. Result: B+22.5
UCT-RAVE (Black) vs UCT-RAVE (White) : B+22.5
Game over. Result: B+5.5
UCT-RAVE (Black) vs UCT-RAVE (White) : B+5.5
Game over. Result: B+2.5
UCT-RAVE (Black) vs UCT-RAVE (White) : B+2.5
Game over. Result: B+11.5
UCT-RAVE (Black) vs UCT-RAVE (White) : B+11.5
Game over. Result: W+11.5
UCT-RAVE (Black) vs UCT-RAVE (White) : W+11.5
Game over. Result: B+7.5
UCT-RAVE (Black) vs MonteCarloSMAF (White) : B+7.5
Game over. Result: B+39.5
UCT-RAVE (Black) vs MonteCarloSMAF (White) : B+39.5
Game over. Result: W+15.5
UCT-RAVE (Black) vs MonteCarloSMAF (White) : W+15.5
Game over. Result: W+22.5
UCT-RAVE (Black) vs MonteCarloSMAF (White) : W+22.5
Game over. Result: W+3.5
UCT-RAVE (Black) vs MonteCarloSMAF (White) : W+3.5
Game over. Result: B+20.5
UCT-RAVE (Black) vs RootParallelization (White) : B+20.5
Game over. Result: B+29.5
UCT-RAVE (Black) vs RootParallelization (White) : B+29.5
Game over. Result: B+41.5
UCT-RAVE (Black) vs RootParallelization (White) : B+41.5
Game over. Result: B+36.5
UCT-RAVE (Black) vs RootParallelization (White) : B+36.5
Game over. Result: B+18.5
UCT-RAVE (Black) vs RootParallelization (White) : B+18.5
Game over. Result: B+54.5
UCT-RAVE (Black) vs RootParallelization-NeuralNet (White) : B+54.5
Game over. Result: B+7.5
UCT-RAVE (Black) vs RootParallelization-NeuralNet (White) : B+7.5
Game over. Result: B+19.5
UCT-RAVE (Black) vs RootParallelization-NeuralNet (White) : B+19.5
Game over. Result: B+9.5
UCT-RAVE (Black) vs RootParallelization-NeuralNet (White) : B+9.5
Game over. Result: B+3.5
UCT-RAVE (Black) vs RootParallelization-NeuralNet (White) : B+3.5
Tournament Win Rates
====================
UCT-RAVE (Black) vs Random (White) : 100%
UCT-RAVE (Black) vs Alpha-Beta (White) : 60%
UCT-RAVE (Black) vs MonteCarloUCT (White) : 80%
UCT-RAVE (Black) vs UCT-RAVE (White) : 80%
UCT-RAVE (Black) vs MonteCarloSMAF (White) : 40%
UCT-RAVE (Black) vs RootParallelization (White) : 100%
UCT-RAVE (Black) vs RootParallelization-NeuralNet (White) : 100%
Tournament lasted 1476.893 seconds.

rrt/rrt.random.black.txt (new file, 86 lines)

@@ -0,0 +1,86 @@
C:\workspace\msproj\dist>java -cp GoGame.jar;antlrworks-1.4.3.jar;kgsGtp.jar;log4j-1.2.16.jar net.woodyfolsom.msproj.RoundRobin
Beginning round-robin tournament.
Initializing policies...
Game over. Result: W+6.5
Random (Black) vs Random (White) : W+6.5
Game over. Result: B+1.5
Random (Black) vs Random (White) : B+1.5
Game over. Result: B+7.5
Random (Black) vs Random (White) : B+7.5
Game over. Result: W+0.5
Random (Black) vs Random (White) : W+0.5
Game over. Result: B+1.5
Random (Black) vs Random (White) : B+1.5
Game over. Result: B+28.5
Random (Black) vs Alpha-Beta (White) : B+28.5
Game over. Result: B+1.5
Random (Black) vs Alpha-Beta (White) : B+1.5
Game over. Result: B+29.5
Random (Black) vs Alpha-Beta (White) : B+29.5
Game over. Result: B+47.5
Random (Black) vs Alpha-Beta (White) : B+47.5
Game over. Result: B+22.5
Random (Black) vs Alpha-Beta (White) : B+22.5
Game over. Result: W+22.5
Random (Black) vs MonteCarloUCT (White) : W+22.5
Game over. Result: W+6.5
Random (Black) vs MonteCarloUCT (White) : W+6.5
Game over. Result: W+5.5
Random (Black) vs MonteCarloUCT (White) : W+5.5
Game over. Result: W+12.5
Random (Black) vs MonteCarloUCT (White) : W+12.5
Game over. Result: W+35.5
Random (Black) vs MonteCarloUCT (White) : W+35.5
Game over. Result: W+14.5
Random (Black) vs UCT-RAVE (White) : W+14.5
Game over. Result: W+18.5
Random (Black) vs UCT-RAVE (White) : W+18.5
Game over. Result: W+3.5
Random (Black) vs UCT-RAVE (White) : W+3.5
Game over. Result: W+5.5
Random (Black) vs UCT-RAVE (White) : W+5.5
Game over. Result: W+32.5
Random (Black) vs UCT-RAVE (White) : W+32.5
Game over. Result: W+19.5
Random (Black) vs MonteCarloSMAF (White) : W+19.5
Game over. Result: W+26.5
Random (Black) vs MonteCarloSMAF (White) : W+26.5
Game over. Result: W+19.5
Random (Black) vs MonteCarloSMAF (White) : W+19.5
Game over. Result: W+8.5
Random (Black) vs MonteCarloSMAF (White) : W+8.5
Game over. Result: W+13.5
Random (Black) vs MonteCarloSMAF (White) : W+13.5
Game over. Result: W+9.5
Random (Black) vs RootParallelization (White) : W+9.5
Game over. Result: W+4.5
Random (Black) vs RootParallelization (White) : W+4.5
Game over. Result: W+8.5
Random (Black) vs RootParallelization (White) : W+8.5
Game over. Result: W+39.5
Random (Black) vs RootParallelization (White) : W+39.5
Game over. Result: W+0.5
Random (Black) vs RootParallelization (White) : W+0.5
Game over. Result: W+10.5
Random (Black) vs RootParallelization-NeuralNet (White) : W+10.5
Game over. Result: W+11.5
Random (Black) vs RootParallelization-NeuralNet (White) : W+11.5
Game over. Result: W+1.5
Random (Black) vs RootParallelization-NeuralNet (White) : W+1.5
Game over. Result: W+3.5
Random (Black) vs RootParallelization-NeuralNet (White) : W+3.5
Game over. Result: W+10.5
Random (Black) vs RootParallelization-NeuralNet (White) : W+10.5
Game over. Result: W+40.5
Tournament Win Rates
====================
Random (Black) vs Random (White) : 40%
Random (Black) vs Alpha-Beta (White) : 100%
Random (Black) vs MonteCarloUCT (White) : 00%
Random (Black) vs UCT-RAVE (White) : 00%
Random (Black) vs MonteCarloSMAF (White) : 00%
Random (Black) vs RootParallelization (White) : 00%
Random (Black) vs RootParallelization-NeuralNet (White) : 00%

rrt/rrt.rave.black.txt (new file, 83 lines)

@@ -0,0 +1,83 @@
Beginning round-robin tournament.
Initializing policies...
Game over. Result: B+8.5
UCT-RAVE (Black) vs Random (White) : B+8.5
Game over. Result: B+31.5
UCT-RAVE (Black) vs Random (White) : B+31.5
Game over. Result: B+16.5
UCT-RAVE (Black) vs Random (White) : B+16.5
Game over. Result: B+9.5
UCT-RAVE (Black) vs Random (White) : B+9.5
Game over. Result: B+16.5
UCT-RAVE (Black) vs Random (White) : B+16.5
Game over. Result: B+48.5
UCT-RAVE (Black) vs Alpha-Beta (White) : B+48.5
Game over. Result: W+5.5
UCT-RAVE (Black) vs Alpha-Beta (White) : W+5.5
Game over. Result: B+13.5
UCT-RAVE (Black) vs Alpha-Beta (White) : B+13.5
Game over. Result: B+34.5
UCT-RAVE (Black) vs Alpha-Beta (White) : B+34.5
Game over. Result: B+1.5
UCT-RAVE (Black) vs Alpha-Beta (White) : B+1.5
Game over. Result: B+2.5
UCT-RAVE (Black) vs MonteCarloUCT (White) : B+2.5
Game over. Result: B+7.5
UCT-RAVE (Black) vs MonteCarloUCT (White) : B+7.5
Game over. Result: W+4.5
UCT-RAVE (Black) vs MonteCarloUCT (White) : W+4.5
Game over. Result: B+3.5
UCT-RAVE (Black) vs MonteCarloUCT (White) : B+3.5
Game over. Result: B+6.5
UCT-RAVE (Black) vs MonteCarloUCT (White) : B+6.5
Game over. Result: B+3.5
UCT-RAVE (Black) vs UCT-RAVE (White) : B+3.5
Game over. Result: B+2.5
UCT-RAVE (Black) vs UCT-RAVE (White) : B+2.5
Game over. Result: B+9.5
UCT-RAVE (Black) vs UCT-RAVE (White) : B+9.5
Game over. Result: B+0.5
UCT-RAVE (Black) vs UCT-RAVE (White) : B+0.5
Game over. Result: W+13.5
UCT-RAVE (Black) vs UCT-RAVE (White) : W+13.5
Game over. Result: W+0.5
UCT-RAVE (Black) vs MonteCarloSMAF (White) : W+0.5
Game over. Result: B+1.5
UCT-RAVE (Black) vs MonteCarloSMAF (White) : B+1.5
Game over. Result: W+0.5
UCT-RAVE (Black) vs MonteCarloSMAF (White) : W+0.5
Game over. Result: B+9.5
UCT-RAVE (Black) vs MonteCarloSMAF (White) : B+9.5
Game over. Result: W+20.5
UCT-RAVE (Black) vs MonteCarloSMAF (White) : W+20.5
Game over. Result: B+13.5
UCT-RAVE (Black) vs RootParallelization (White) : B+13.5
Game over. Result: W+16.5
UCT-RAVE (Black) vs RootParallelization (White) : W+16.5
Game over. Result: B+28.5
UCT-RAVE (Black) vs RootParallelization (White) : B+28.5
Game over. Result: B+25.5
UCT-RAVE (Black) vs RootParallelization (White) : B+25.5
Game over. Result: B+25.5
UCT-RAVE (Black) vs RootParallelization (White) : B+25.5
Game over. Result: B+48.5
UCT-RAVE (Black) vs RootParallelization-NeuralNet (White) : B+48.5
Game over. Result: B+6.5
UCT-RAVE (Black) vs RootParallelization-NeuralNet (White) : B+6.5
Game over. Result: B+9.5
UCT-RAVE (Black) vs RootParallelization-NeuralNet (White) : B+9.5
Game over. Result: B+55.5
UCT-RAVE (Black) vs RootParallelization-NeuralNet (White) : B+55.5
Game over. Result: B+42.5
UCT-RAVE (Black) vs RootParallelization-NeuralNet (White) : B+42.5
Tournament Win Rates
====================
UCT-RAVE (Black) vs Random (White) : 100%
UCT-RAVE (Black) vs Alpha-Beta (White) : 80%
UCT-RAVE (Black) vs MonteCarloUCT (White) : 80%
UCT-RAVE (Black) vs UCT-RAVE (White) : 80%
UCT-RAVE (Black) vs MonteCarloSMAF (White) : 40%
UCT-RAVE (Black) vs RootParallelization (White) : 80%
UCT-RAVE (Black) vs RootParallelization-NeuralNet (White) : 100%
Tournament lasted 1458.494 seconds.


@@ -0,0 +1,85 @@
C:\workspace\msproj\dist>java -Xms256m -Xmx4096m -cp GoGame.jar;antlrworks-1.4.3.jar;kgsGtp.jar;log4j-1.2.16.jar net.woodyfolsom.msproj.RoundRobin
Beginning round-robin tournament.
Initializing policies...
Game over. Result: B+6.5
RootParallelization-NeuralNet (Black) vs Random (White) : B+6.5
Game over. Result: B+0.5
RootParallelization-NeuralNet (Black) vs Random (White) : B+0.5
Game over. Result: B+5.5
RootParallelization-NeuralNet (Black) vs Random (White) : B+5.5
Game over. Result: B+19.5
RootParallelization-NeuralNet (Black) vs Random (White) : B+19.5
Game over. Result: B+2.5
RootParallelization-NeuralNet (Black) vs Random (White) : B+2.5
Game over. Result: B+21.5
RootParallelization-NeuralNet (Black) vs Alpha-Beta (White) : B+21.5
Game over. Result: W+12.5
RootParallelization-NeuralNet (Black) vs Alpha-Beta (White) : W+12.5
Game over. Result: B+23.5
RootParallelization-NeuralNet (Black) vs Alpha-Beta (White) : B+23.5
Game over. Result: B+23.5
RootParallelization-NeuralNet (Black) vs Alpha-Beta (White) : B+23.5
Game over. Result: W+9.5
RootParallelization-NeuralNet (Black) vs Alpha-Beta (White) : W+9.5
Game over. Result: B+29.5
RootParallelization-NeuralNet (Black) vs MonteCarloUCT (White) : B+29.5
Game over. Result: B+9.5
RootParallelization-NeuralNet (Black) vs MonteCarloUCT (White) : B+9.5
Game over. Result: W+50.5
RootParallelization-NeuralNet (Black) vs MonteCarloUCT (White) : W+50.5
Game over. Result: B+9.5
RootParallelization-NeuralNet (Black) vs MonteCarloUCT (White) : B+9.5
Game over. Result: B+7.5
RootParallelization-NeuralNet (Black) vs MonteCarloUCT (White) : B+7.5
Game over. Result: W+12.5
RootParallelization-NeuralNet (Black) vs UCT-RAVE (White) : W+12.5
Game over. Result: W+9.5
RootParallelization-NeuralNet (Black) vs UCT-RAVE (White) : W+9.5
Game over. Result: W+29.5
RootParallelization-NeuralNet (Black) vs UCT-RAVE (White) : W+29.5
Game over. Result: W+10.5
RootParallelization-NeuralNet (Black) vs UCT-RAVE (White) : W+10.5
Game over. Result: W+27.5
RootParallelization-NeuralNet (Black) vs UCT-RAVE (White) : W+27.5
Game over. Result: W+2.5
RootParallelization-NeuralNet (Black) vs MonteCarloSMAF (White) : W+2.5
Game over. Result: W+22.5
RootParallelization-NeuralNet (Black) vs MonteCarloSMAF (White) : W+22.5
Game over. Result: W+10.5
RootParallelization-NeuralNet (Black) vs MonteCarloSMAF (White) : W+10.5
Game over. Result: W+41.5
RootParallelization-NeuralNet (Black) vs MonteCarloSMAF (White) : W+41.5
Game over. Result: W+18.5
RootParallelization-NeuralNet (Black) vs MonteCarloSMAF (White) : W+18.5
Game over. Result: B+3.5
RootParallelization-NeuralNet (Black) vs RootParallelization (White) : B+3.5
Game over. Result: W+10.5
RootParallelization-NeuralNet (Black) vs RootParallelization (White) : W+10.5
Game over. Result: W+14.5
RootParallelization-NeuralNet (Black) vs RootParallelization (White) : W+14.5
Game over. Result: W+5.5
RootParallelization-NeuralNet (Black) vs RootParallelization (White) : W+5.5
Game over. Result: W+6.5
RootParallelization-NeuralNet (Black) vs RootParallelization (White) : W+6.5
Game over. Result: W+8.5
RootParallelization-NeuralNet (Black) vs RootParallelization-NeuralNet (White) : W+8.5
Game over. Result: W+11.5
RootParallelization-NeuralNet (Black) vs RootParallelization-NeuralNet (White) : W+11.5
Game over. Result: W+6.5
RootParallelization-NeuralNet (Black) vs RootParallelization-NeuralNet (White) : W+6.5
Game over. Result: B+2.5
RootParallelization-NeuralNet (Black) vs RootParallelization-NeuralNet (White) : B+2.5
Game over. Result: B+21.5
RootParallelization-NeuralNet (Black) vs RootParallelization-NeuralNet (White) : B+21.5
Tournament Win Rates
====================
RootParallelization-NeuralNet (Black) vs Random (White) : 100%
RootParallelization-NeuralNet (Black) vs Alpha-Beta (White) : 60%
RootParallelization-NeuralNet (Black) vs MonteCarloUCT (White) : 80%
RootParallelization-NeuralNet (Black) vs UCT-RAVE (White) : 00%
RootParallelization-NeuralNet (Black) vs MonteCarloSMAF (White) : 00%
RootParallelization-NeuralNet (Black) vs RootParallelization (White) : 20%
RootParallelization-NeuralNet (Black) vs RootParallelization-NeuralNet (White) : 40%
Tournament lasted 1400.277 seconds.

rrt/rrt.rootpar.black.txt (new file, 85 lines)

@@ -0,0 +1,85 @@
C:\workspace\msproj\dist>java -Xms256m -Xmx4096m -cp GoGame.jar;antlrworks-1.4.3.jar;kgsGtp.jar;log4j-1.2.16.jar net.woodyfolsom.msproj.RoundRobin
Beginning round-robin tournament.
Initializing policies...
Game over. Result: B+4.5
RootParallelization (Black) vs Random (White) : B+4.5
Game over. Result: B+4.5
RootParallelization (Black) vs Random (White) : B+4.5
Game over. Result: B+1.5
RootParallelization (Black) vs Random (White) : B+1.5
Game over. Result: B+1.5
RootParallelization (Black) vs Random (White) : B+1.5
Game over. Result: B+0.5
RootParallelization (Black) vs Random (White) : B+0.5
Game over. Result: B+20.5
RootParallelization (Black) vs Alpha-Beta (White) : B+20.5
Game over. Result: B+23.5
RootParallelization (Black) vs Alpha-Beta (White) : B+23.5
Game over. Result: W+9.5
RootParallelization (Black) vs Alpha-Beta (White) : W+9.5
Game over. Result: W+7.5
RootParallelization (Black) vs Alpha-Beta (White) : W+7.5
Game over. Result: B+25.5
RootParallelization (Black) vs Alpha-Beta (White) : B+25.5
Game over. Result: B+0.5
RootParallelization (Black) vs MonteCarloUCT (White) : B+0.5
Game over. Result: B+11.5
RootParallelization (Black) vs MonteCarloUCT (White) : B+11.5
Game over. Result: W+0.5
RootParallelization (Black) vs MonteCarloUCT (White) : W+0.5
Game over. Result: B+1.5
RootParallelization (Black) vs MonteCarloUCT (White) : B+1.5
Game over. Result: B+0.5
RootParallelization (Black) vs MonteCarloUCT (White) : B+0.5
Game over. Result: W+22.5
RootParallelization (Black) vs UCT-RAVE (White) : W+22.5
Game over. Result: W+63.5
RootParallelization (Black) vs UCT-RAVE (White) : W+63.5
Game over. Result: W+29.5
RootParallelization (Black) vs UCT-RAVE (White) : W+29.5
Game over. Result: W+58.5
RootParallelization (Black) vs UCT-RAVE (White) : W+58.5
Game over. Result: W+30.5
RootParallelization (Black) vs UCT-RAVE (White) : W+30.5
Game over. Result: W+15.5
RootParallelization (Black) vs MonteCarloSMAF (White) : W+15.5
Game over. Result: W+62.5
RootParallelization (Black) vs MonteCarloSMAF (White) : W+62.5
Game over. Result: W+57.5
RootParallelization (Black) vs MonteCarloSMAF (White) : W+57.5
Game over. Result: W+57.5
RootParallelization (Black) vs MonteCarloSMAF (White) : W+57.5
Game over. Result: W+12.5
RootParallelization (Black) vs MonteCarloSMAF (White) : W+12.5
Game over. Result: B+2.5
RootParallelization (Black) vs RootParallelization (White) : B+2.5
Game over. Result: W+6.5
RootParallelization (Black) vs RootParallelization (White) : W+6.5
Game over. Result: B+2.5
RootParallelization (Black) vs RootParallelization (White) : B+2.5
Game over. Result: W+5.5
RootParallelization (Black) vs RootParallelization (White) : W+5.5
Game over. Result: B+2.5
RootParallelization (Black) vs RootParallelization (White) : B+2.5
Game over. Result: W+8.5
RootParallelization (Black) vs RootParallelization-NeuralNet (White) : W+8.5
Game over. Result: W+6.5
RootParallelization (Black) vs RootParallelization-NeuralNet (White) : W+6.5
Game over. Result: W+6.5
RootParallelization (Black) vs RootParallelization-NeuralNet (White) : W+6.5
Game over. Result: B+3.5
RootParallelization (Black) vs RootParallelization-NeuralNet (White) : B+3.5
Game over. Result: W+13.5
RootParallelization (Black) vs RootParallelization-NeuralNet (White) : W+13.5
Tournament Win Rates
====================
RootParallelization (Black) vs Random (White) : 100%
RootParallelization (Black) vs Alpha-Beta (White) : 60%
RootParallelization (Black) vs MonteCarloUCT (White) : 80%
RootParallelization (Black) vs UCT-RAVE (White) : 00%
RootParallelization (Black) vs MonteCarloSMAF (White) : 00%
RootParallelization (Black) vs RootParallelization (White) : 60%
RootParallelization (Black) vs RootParallelization-NeuralNet (White) : 20%
Tournament lasted 1367.523 seconds.

rrt/rrt.smaf.black.txt (new file, 85 lines)

@@ -0,0 +1,85 @@
C:\workspace\msproj\dist>java -Xms256m -Xmx4096m -cp GoGame.jar;antlrworks-1.4.3.jar;kgsGtp.jar;log4j-1.2.16.jar net.woodyfolsom.msproj.RoundRobin
Beginning round-robin tournament.
Initializing policies...
Game over. Result: B+8.5
UCT-RAVE (Black) vs Random (White) : B+8.5
Game over. Result: B+31.5
UCT-RAVE (Black) vs Random (White) : B+31.5
Game over. Result: B+16.5
UCT-RAVE (Black) vs Random (White) : B+16.5
Game over. Result: B+9.5
UCT-RAVE (Black) vs Random (White) : B+9.5
Game over. Result: B+16.5
UCT-RAVE (Black) vs Random (White) : B+16.5
Game over. Result: B+48.5
UCT-RAVE (Black) vs Alpha-Beta (White) : B+48.5
Game over. Result: W+5.5
UCT-RAVE (Black) vs Alpha-Beta (White) : W+5.5
Game over. Result: B+13.5
UCT-RAVE (Black) vs Alpha-Beta (White) : B+13.5
Game over. Result: B+34.5
UCT-RAVE (Black) vs Alpha-Beta (White) : B+34.5
Game over. Result: B+1.5
UCT-RAVE (Black) vs Alpha-Beta (White) : B+1.5
Game over. Result: B+2.5
UCT-RAVE (Black) vs MonteCarloUCT (White) : B+2.5
Game over. Result: B+7.5
UCT-RAVE (Black) vs MonteCarloUCT (White) : B+7.5
Game over. Result: W+4.5
UCT-RAVE (Black) vs MonteCarloUCT (White) : W+4.5
Game over. Result: B+3.5
UCT-RAVE (Black) vs MonteCarloUCT (White) : B+3.5
Game over. Result: B+6.5
UCT-RAVE (Black) vs MonteCarloUCT (White) : B+6.5
Game over. Result: B+3.5
UCT-RAVE (Black) vs UCT-RAVE (White) : B+3.5
Game over. Result: B+2.5
UCT-RAVE (Black) vs UCT-RAVE (White) : B+2.5
Game over. Result: B+9.5
UCT-RAVE (Black) vs UCT-RAVE (White) : B+9.5
Game over. Result: B+0.5
UCT-RAVE (Black) vs UCT-RAVE (White) : B+0.5
Game over. Result: W+13.5
UCT-RAVE (Black) vs UCT-RAVE (White) : W+13.5
Game over. Result: W+0.5
UCT-RAVE (Black) vs MonteCarloSMAF (White) : W+0.5
Game over. Result: B+1.5
UCT-RAVE (Black) vs MonteCarloSMAF (White) : B+1.5
Game over. Result: W+0.5
UCT-RAVE (Black) vs MonteCarloSMAF (White) : W+0.5
Game over. Result: B+9.5
UCT-RAVE (Black) vs MonteCarloSMAF (White) : B+9.5
Game over. Result: W+20.5
UCT-RAVE (Black) vs MonteCarloSMAF (White) : W+20.5
Game over. Result: B+13.5
UCT-RAVE (Black) vs RootParallelization (White) : B+13.5
Game over. Result: W+16.5
UCT-RAVE (Black) vs RootParallelization (White) : W+16.5
Game over. Result: B+28.5
UCT-RAVE (Black) vs RootParallelization (White) : B+28.5
Game over. Result: B+25.5
UCT-RAVE (Black) vs RootParallelization (White) : B+25.5
Game over. Result: B+25.5
UCT-RAVE (Black) vs RootParallelization (White) : B+25.5
Game over. Result: B+48.5
UCT-RAVE (Black) vs RootParallelization-NeuralNet (White) : B+48.5
Game over. Result: B+6.5
UCT-RAVE (Black) vs RootParallelization-NeuralNet (White) : B+6.5
Game over. Result: B+9.5
UCT-RAVE (Black) vs RootParallelization-NeuralNet (White) : B+9.5
Game over. Result: B+55.5
UCT-RAVE (Black) vs RootParallelization-NeuralNet (White) : B+55.5
Game over. Result: B+42.5
UCT-RAVE (Black) vs RootParallelization-NeuralNet (White) : B+42.5
Tournament Win Rates
====================
UCT-RAVE (Black) vs Random (White) : 100%
UCT-RAVE (Black) vs Alpha-Beta (White) : 80%
UCT-RAVE (Black) vs MonteCarloUCT (White) : 80%
UCT-RAVE (Black) vs UCT-RAVE (White) : 80%
UCT-RAVE (Black) vs MonteCarloSMAF (White) : 40%
UCT-RAVE (Black) vs RootParallelization (White) : 80%
UCT-RAVE (Black) vs RootParallelization-NeuralNet (White) : 100%
Tournament lasted 1458.494 seconds.

rrt/rrt.uct.black (new file, 85 lines)

@@ -0,0 +1,85 @@
C:\workspace\msproj\dist>java -Xms256m -Xmx4096m -cp GoGame.jar;antlrworks-1.4.3.jar;kgsGtp.jar;log4j-1.2.16.jar net.woodyfolsom.msproj.RoundRobin
Beginning round-robin tournament.
Initializing policies...
Game over. Result: B+0.5
MonteCarloUCT (Black) vs Random (White) : B+0.5
Game over. Result: B+9.5
MonteCarloUCT (Black) vs Random (White) : B+9.5
Game over. Result: B+24.5
MonteCarloUCT (Black) vs Random (White) : B+24.5
Game over. Result: B+10.5
MonteCarloUCT (Black) vs Random (White) : B+10.5
Game over. Result: B+4.5
MonteCarloUCT (Black) vs Random (White) : B+4.5
Game over. Result: B+15.5
MonteCarloUCT (Black) vs Alpha-Beta (White) : B+15.5
Game over. Result: B+22.5
MonteCarloUCT (Black) vs Alpha-Beta (White) : B+22.5
Game over. Result: B+32.5
MonteCarloUCT (Black) vs Alpha-Beta (White) : B+32.5
Game over. Result: W+12.5
MonteCarloUCT (Black) vs Alpha-Beta (White) : W+12.5
Game over. Result: B+23.5
MonteCarloUCT (Black) vs Alpha-Beta (White) : B+23.5
Game over. Result: B+0.5
MonteCarloUCT (Black) vs MonteCarloUCT (White) : B+0.5
Game over. Result: W+13.5
MonteCarloUCT (Black) vs MonteCarloUCT (White) : W+13.5
Game over. Result: W+11.5
MonteCarloUCT (Black) vs MonteCarloUCT (White) : W+11.5
Game over. Result: W+8.5
MonteCarloUCT (Black) vs MonteCarloUCT (White) : W+8.5
Game over. Result: W+9.5
MonteCarloUCT (Black) vs MonteCarloUCT (White) : W+9.5
Game over. Result: W+9.5
MonteCarloUCT (Black) vs UCT-RAVE (White) : W+9.5
Game over. Result: W+16.5
MonteCarloUCT (Black) vs UCT-RAVE (White) : W+16.5
Game over. Result: W+8.5
MonteCarloUCT (Black) vs UCT-RAVE (White) : W+8.5
Game over. Result: W+11.5
MonteCarloUCT (Black) vs UCT-RAVE (White) : W+11.5
Game over. Result: W+5.5
MonteCarloUCT (Black) vs UCT-RAVE (White) : W+5.5
Game over. Result: W+8.5
MonteCarloUCT (Black) vs MonteCarloSMAF (White) : W+8.5
Game over. Result: W+9.5
MonteCarloUCT (Black) vs MonteCarloSMAF (White) : W+9.5
Game over. Result: W+15.5
MonteCarloUCT (Black) vs MonteCarloSMAF (White) : W+15.5
Game over. Result: W+14.5
MonteCarloUCT (Black) vs MonteCarloSMAF (White) : W+14.5
Game over. Result: W+13.5
MonteCarloUCT (Black) vs MonteCarloSMAF (White) : W+13.5
Game over. Result: W+15.5
MonteCarloUCT (Black) vs RootParallelization (White) : W+15.5
Game over. Result: W+14.5
MonteCarloUCT (Black) vs RootParallelization (White) : W+14.5
Game over. Result: W+6.5
MonteCarloUCT (Black) vs RootParallelization (White) : W+6.5
Game over. Result: W+6.5
MonteCarloUCT (Black) vs RootParallelization (White) : W+6.5
Game over. Result: W+11.5
MonteCarloUCT (Black) vs RootParallelization (White) : W+11.5
Game over. Result: W+26.5
MonteCarloUCT (Black) vs RootParallelization-NeuralNet (White) : W+26.5
Game over. Result: W+11.5
MonteCarloUCT (Black) vs RootParallelization-NeuralNet (White) : W+11.5
Game over. Result: W+47.5
MonteCarloUCT (Black) vs RootParallelization-NeuralNet (White) : W+47.5
Game over. Result: W+13.5
MonteCarloUCT (Black) vs RootParallelization-NeuralNet (White) : W+13.5
Game over. Result: B+33.5
MonteCarloUCT (Black) vs RootParallelization-NeuralNet (White) : B+33.5
Tournament Win Rates
====================
MonteCarloUCT (Black) vs Random (White) : 100%
MonteCarloUCT (Black) vs Alpha-Beta (White) : 80%
MonteCarloUCT (Black) vs MonteCarloUCT (White) : 20%
MonteCarloUCT (Black) vs UCT-RAVE (White) : 00%
MonteCarloUCT (Black) vs MonteCarloSMAF (White) : 00%
MonteCarloUCT (Black) vs RootParallelization (White) : 00%
MonteCarloUCT (Black) vs RootParallelization-NeuralNet (White) : 20%
Tournament lasted 1355.668 seconds.


@@ -59,6 +59,14 @@ public class GameRecord {
 return gameStates.get(0).getGameConfig();
 }
 
+/**
+ * Gets the game state for the most recent ply.
+ * @return
+ */
+public GameState getGameState() {
+return gameStates.get(getNumTurns());
+}
+
 public GameState getGameState(Integer turn) {
 return gameStates.get(turn);
 }


@@ -119,6 +119,14 @@ public class GameState {
 return whitePrisoners;
 }
 
+public boolean isPrevPlyPass() {
+if (moveHistory.size() == 0) {
+return false;
+} else {
+return moveHistory.get(moveHistory.size()-1).isPass();
+}
+}
+
 public boolean isSelfFill(Action action, Player player) {
 return gameBoard.isSelfFill(action, player);
 }


@@ -16,6 +16,7 @@ import net.woodyfolsom.msproj.policy.Minimax;
 import net.woodyfolsom.msproj.policy.MonteCarloUCT;
 import net.woodyfolsom.msproj.policy.Policy;
 import net.woodyfolsom.msproj.policy.RandomMovePolicy;
+import net.woodyfolsom.msproj.policy.RootParAMAF;
 
 import org.apache.log4j.Logger;
 import org.apache.log4j.xml.DOMConfigurator;
@@ -80,10 +81,11 @@ public class GoGame implements Runnable {
 public static void main(String[] args) throws IOException {
 configureLogging();
 if (args.length == 0) {
-Policy defaultMoveGenerator = new MonteCarloUCT(new RandomMovePolicy(), 5000L);
-LOGGER.info("No MoveGenerator specified. Using default: " + defaultMoveGenerator.toString());
-GoGame goGame = new GoGame(defaultMoveGenerator, PROPS_FILE);
+Policy policy = new RootParAMAF(4, 10000L);
+policy.setLogging(true);
+LOGGER.info("No MoveGenerator specified. Using default: " + policy.getName());
+GoGame goGame = new GoGame(policy, PROPS_FILE);
 new Thread(goGame).start();
 System.out.println("Creating GtpClient");
@@ -111,7 +113,9 @@
 } else if ("alphabeta".equals(policyName)) {
 return new AlphaBeta();
 } else if ("montecarlo".equals(policyName)) {
-return new MonteCarloUCT(new RandomMovePolicy(), 5000L);
+return new MonteCarloUCT(new RandomMovePolicy(), 10000L);
+} else if ("root_par_amaf".equals(policyName)) {
+return new RootParAMAF(4, 10000L);
 } else {
 LOGGER.info("Unable to create Policy for unsupported name: " + policyName);
 System.exit(INVALID_MOVE_GENERATOR);
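With this change the kgsGtp.ini files above, which launch the engine with the montecarlo argument, get a MonteCarloUCT policy with a 10000 ms budget (up from 5000 ms), and the new root_par_amaf name selects the RootParAMAF policy that is now also the default when GoGame is started without arguments.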


@@ -91,7 +91,6 @@ public class Referee {
 while (!gameRecord.isFinished()) {
 GameState gameState = gameRecord.getGameState(gameRecord
 .getNumTurns());
-// System.out.println(gameState);
 
 Player playerToMove = gameRecord.getPlayerToMove();
 Policy policy = getPolicy(playerToMove);
@@ -108,6 +107,11 @@
 } else {
 System.out.println("Move rejected - try again.");
 }
+
+if (policy.isLogging()) {
+System.out.println(gameState);
+}
 }
 } catch (Exception ex) {
 System.out


@@ -0,0 +1,115 @@
package net.woodyfolsom.msproj;
import java.io.IOException;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.List;
import net.woodyfolsom.msproj.policy.AlphaBeta;
import net.woodyfolsom.msproj.policy.MonteCarloAMAF;
import net.woodyfolsom.msproj.policy.MonteCarloSMAF;
import net.woodyfolsom.msproj.policy.MonteCarloUCT;
import net.woodyfolsom.msproj.policy.NeuralNetPolicy;
import net.woodyfolsom.msproj.policy.Policy;
import net.woodyfolsom.msproj.policy.RandomMovePolicy;
import net.woodyfolsom.msproj.policy.RootParAMAF;
import net.woodyfolsom.msproj.policy.RootParallelization;
public class RoundRobin {
public static final int EXIT_USER_QUIT = 1;
public static final int EXIT_NOMINAL = 0;
public static final int EXIT_IO_EXCEPTION = -1;
public static void main(String[] args) throws IOException {
long startTime = System.currentTimeMillis();
System.out.println("Beginning round-robin tournament.");
System.out.println("Initializing policies...");
List<Policy> policies = new ArrayList<Policy>();
policies.add(new RandomMovePolicy());
//policies.add(new Minimax(1));
policies.add(new AlphaBeta(1));
policies.add(new MonteCarloUCT(new RandomMovePolicy(), 500L));
policies.add(new MonteCarloAMAF(new RandomMovePolicy(), 500L));
policies.add(new MonteCarloSMAF(new RandomMovePolicy(), 500L, 4));
policies.add(new RootParAMAF(4, 500L));
policies.add(new RootParallelization(4, new NeuralNetPolicy(), 500L));
RoundRobin rr = new RoundRobin();
List<List<Double>> tourneyWinRates = new ArrayList<List<Double>>();
int gamesPerMatch = 5;
for (int i = 0; i < policies.size(); i++) {
List<Double> roundWinRates = new ArrayList<Double>();
if (i != 5) {
tourneyWinRates.add(roundWinRates);
continue;
}
for (int j = 0; j < policies.size(); j++) {
Policy policy1 = policies.get(i);
policy1.setLogging(false);
Policy policy2 = policies.get(j);
policy2.setLogging(false);
List<GameResult> gameResults = rr.playGame(policy1, policy2, 9, 6.5, gamesPerMatch, false, false, false);
double wins = 0.0;
double games = 0.0;
for(GameResult gr : gameResults) {
wins += gr.isWinner(Player.BLACK) ? 1.0 : 0.0;
games += 1.0;
}
roundWinRates.add(100.0 * wins / games);
}
tourneyWinRates.add(roundWinRates);
}
System.out.println("");
System.out.println("Tournament Win Rates");
System.out.println("====================");
DecimalFormat df = new DecimalFormat("00.#");
for (int i = 0; i < policies.size(); i++) {
for (int j = 0; j < policies.size(); j++) {
if (i == 5)
System.out.println(policies.get(i).getName() + " (Black) vs " + policies.get(j).getName() + " (White) : " + df.format(tourneyWinRates.get(i).get(j)) + "%");
}
}
long endTime = System.currentTimeMillis();
System.out.println("Tournament lasted " + (endTime-startTime)/1000.0 + " seconds.");
}
public List<GameResult> playGame(Policy player1Policy, Policy player2Policy, int size,
double komi, int rounds, boolean showSpectatorBoard,
boolean blackMoveLogged, boolean whiteMoveLogged) {
GameConfig gameConfig = new GameConfig(size);
gameConfig.setKomi(komi);
Referee referee = new Referee();
referee.setPolicy(Player.BLACK, player1Policy);
referee.setPolicy(Player.WHITE, player2Policy);
List<GameResult> roundResults = new ArrayList<GameResult>();
boolean logGameRecords = false;
int gameNo = 1;
for (int round = 0; round < rounds; round++) {
gameNo++;
GameResult gameResult = referee.play(gameConfig, gameNo,
showSpectatorBoard, logGameRecords);
roundResults.add(gameResult);
System.out.println(player1Policy.getName() + " (Black) vs "
+ player2Policy.getName() + " (White) : " + gameResult);
roundResults.add(gameResult);
}
return roundResults;
}
}
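Two details of RoundRobin.main explain the shape of the rrt/*.txt logs above: the if (i != 5) guard means a single run only plays and prints the pairings in which policy index 5 (the RootParAMAF entry in the committed list) takes Black, and the win rates are computed from gamesPerMatch = 5 games and formatted with DecimalFormat("00.#"), so they move in 20% steps and a zero win rate prints as 00%. A small standalone check of that formatting (the loop values are illustrative):

import java.text.DecimalFormat;

// Reproduces the win-rate formatting used by RoundRobin.main: with 5 games
// per pairing the rate is always a multiple of 20, and a 0% rate prints as
// "00%" because of the "00.#" pattern.
public class WinRateFormatDemo {
    public static void main(String[] args) {
        DecimalFormat df = new DecimalFormat("00.#");
        int gamesPerMatch = 5;
        for (int wins = 0; wins <= gamesPerMatch; wins++) {
            double rate = 100.0 * wins / gamesPerMatch;
            System.out.println(wins + "/" + gamesPerMatch + " -> " + df.format(rate) + "%");
        }
    }
}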


@@ -13,9 +13,11 @@ import net.woodyfolsom.msproj.gui.Goban;
 import net.woodyfolsom.msproj.policy.HumanGuiInput;
 import net.woodyfolsom.msproj.policy.HumanKeyboardInput;
 import net.woodyfolsom.msproj.policy.MonteCarloAMAF;
+import net.woodyfolsom.msproj.policy.MonteCarloSMAF;
 import net.woodyfolsom.msproj.policy.MonteCarloUCT;
 import net.woodyfolsom.msproj.policy.Policy;
 import net.woodyfolsom.msproj.policy.RandomMovePolicy;
+import net.woodyfolsom.msproj.policy.RootParAMAF;
 import net.woodyfolsom.msproj.policy.RootParallelization;
 
 public class StandAloneGame {
@@ -26,13 +28,13 @@
 private int gameNo = 0;
 
 enum PLAYER_TYPE {
-HUMAN, HUMAN_GUI, ROOT_PAR, UCT, RANDOM, RAVE
+HUMAN, HUMAN_GUI, ROOT_PAR, UCT, RANDOM, RAVE, SMAF, ROOT_PAR_AMAF
 };
 
-public static void main(String[] args) {
+public static void main(String[] args) throws IOException {
 try {
 GameSettings gameSettings = GameSettings
-.createGameSetings("data/gogame.cfg");
+.createGameSetings("gogame.cfg");
 System.out.println("Game Settings: " + gameSettings);
 System.out.println("Successfully parsed game settings.");
 new StandAloneGame().playGame(
@@ -41,7 +43,10 @@
 gameSettings.getBoardSize(), gameSettings.getKomi(),
 gameSettings.getNumGames(), gameSettings.getTurnTime(),
 gameSettings.isSpectatorBoardShown(),
-gameSettings.isBlackMoveLogged(), gameSettings.isWhiteMoveLogged());
+gameSettings.isBlackMoveLogged(),
+gameSettings.isWhiteMoveLogged());
+System.out.println("Press <Enter> or CTRL-C to exit");
+System.in.read(new byte[80]);
 } catch (IOException ioe) {
 ioe.printStackTrace();
 System.exit(EXIT_IO_EXCEPTION);
@@ -62,14 +67,19 @@
 return PLAYER_TYPE.RANDOM;
 } else if ("RAVE".equalsIgnoreCase(playerTypeStr)) {
 return PLAYER_TYPE.RAVE;
+} else if ("SMAF".equalsIgnoreCase(playerTypeStr)) {
+return PLAYER_TYPE.SMAF;
+} else if ("ROOT_PAR_AMAF".equalsIgnoreCase(playerTypeStr)) {
+return PLAYER_TYPE.ROOT_PAR_AMAF;
 } else {
 throw new RuntimeException("Unknown player type: " + playerTypeStr);
 }
 }
 
 public void playGame(PLAYER_TYPE playerType1, PLAYER_TYPE playerType2,
-int size, double komi, int rounds, long turnLength, boolean showSpectatorBoard,
-boolean blackMoveLogged, boolean whiteMoveLogged) {
+int size, double komi, int rounds, long turnLength,
+boolean showSpectatorBoard, boolean blackMoveLogged,
+boolean whiteMoveLogged) {
 
 long startTime = System.currentTimeMillis();
@@ -77,28 +87,38 @@
 gameConfig.setKomi(komi);
 
 Referee referee = new Referee();
-referee.setPolicy(Player.BLACK,
-getPolicy(playerType1, gameConfig, Player.BLACK, turnLength, blackMoveLogged));
-referee.setPolicy(Player.WHITE,
-getPolicy(playerType2, gameConfig, Player.WHITE, turnLength, whiteMoveLogged));
+referee.setPolicy(
+Player.BLACK,
+getPolicy(playerType1, gameConfig, Player.BLACK, turnLength,
+blackMoveLogged));
+referee.setPolicy(
+Player.WHITE,
+getPolicy(playerType2, gameConfig, Player.WHITE, turnLength,
+whiteMoveLogged));
 
 List<GameResult> round1results = new ArrayList<GameResult>();
 boolean logGameRecords = rounds <= 50;
 for (int round = 0; round < rounds; round++) {
 gameNo++;
-round1results.add(referee.play(gameConfig, gameNo, showSpectatorBoard, logGameRecords));
+round1results.add(referee.play(gameConfig, gameNo,
+showSpectatorBoard, logGameRecords));
 }
 
 List<GameResult> round2results = new ArrayList<GameResult>();
-referee.setPolicy(Player.BLACK,
-getPolicy(playerType2, gameConfig, Player.BLACK, turnLength, blackMoveLogged));
-referee.setPolicy(Player.WHITE,
-getPolicy(playerType1, gameConfig, Player.WHITE, turnLength, whiteMoveLogged));
+referee.setPolicy(
+Player.BLACK,
+getPolicy(playerType2, gameConfig, Player.BLACK, turnLength,
+blackMoveLogged));
+referee.setPolicy(
+Player.WHITE,
+getPolicy(playerType1, gameConfig, Player.WHITE, turnLength,
+whiteMoveLogged));
 for (int round = 0; round < rounds; round++) {
 gameNo++;
-round2results.add(referee.play(gameConfig, gameNo, showSpectatorBoard, logGameRecords));
+round2results.add(referee.play(gameConfig, gameNo,
+showSpectatorBoard, logGameRecords));
 }
 
 long endTime = System.currentTimeMillis();
@@ -111,13 +131,14 @@
 try {
 if (!logGameRecords) {
-System.out.println("Each player is set to play more than 50 rounds as each color; omitting individual game .sgf log file output.");
+System.out
+.println("Each player is set to play more than 50 rounds as each color; omitting individual game .sgf log file output.");
 }
 logResults(writer, round1results, playerType1.toString(),
 playerType2.toString());
 logResults(writer, round2results, playerType2.toString(),
 playerType1.toString());
 writer.write("Elapsed Time: " + (endTime - startTime) / 1000.0
 + " seconds.");
 System.out.println("Game tournament saved as "
@@ -155,25 +176,41 @@
 private Policy getPolicy(PLAYER_TYPE playerType, GameConfig gameConfig,
 Player player, long turnLength, boolean moveLogged) {
+Policy policy;
 switch (playerType) {
 case HUMAN:
-return new HumanKeyboardInput();
+policy = new HumanKeyboardInput();
+break;
 case HUMAN_GUI:
-return new HumanGuiInput(new Goban(gameConfig, player,""));
+policy = new HumanGuiInput(new Goban(gameConfig, player, ""));
+break;
 case ROOT_PAR:
-return new RootParallelization(4, turnLength);
+policy = new RootParallelization(4, turnLength);
+break;
+case ROOT_PAR_AMAF:
+policy = new RootParAMAF(4, turnLength);
+break;
 case UCT:
-return new MonteCarloUCT(new RandomMovePolicy(), turnLength);
+policy = new MonteCarloUCT(new RandomMovePolicy(), turnLength);
+break;
+case SMAF:
+policy = new MonteCarloSMAF(new RandomMovePolicy(), turnLength, 4);
+break;
 case RANDOM:
-RandomMovePolicy randomMovePolicy = new RandomMovePolicy();
-randomMovePolicy.setLogging(moveLogged);
-return randomMovePolicy;
+policy = new RandomMovePolicy();
+break;
 case RAVE:
-return new MonteCarloAMAF(new RandomMovePolicy(), turnLength);
+policy = new MonteCarloAMAF(new RandomMovePolicy(), turnLength);
+break;
 default:
 throw new IllegalArgumentException("Invalid PLAYER_TYPE: "
 + playerType);
 }
+policy.setLogging(moveLogged);
+return policy;
 }
 }


@@ -1,51 +1,107 @@
 package net.woodyfolsom.msproj.ann;
 
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.FileOutputStream;
-import java.io.IOException;
-
-import org.encog.neural.networks.BasicNetwork;
-import org.encog.neural.networks.PersistBasicNetwork;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.List;
 
 public abstract class AbstractNeuralNetFilter implements NeuralNetFilter {
-protected BasicNetwork neuralNetwork;
-protected int actualTrainingEpochs = 0;
-protected int maxTrainingEpochs = 1000;
+private final FeedforwardNetwork neuralNetwork;
+private final TrainingMethod trainingMethod;
+private double maxError;
+private int actualTrainingEpochs = 0;
+private int maxTrainingEpochs;
+
+AbstractNeuralNetFilter(FeedforwardNetwork neuralNetwork, TrainingMethod trainingMethod, int maxTrainingEpochs, double maxError) {
+this.neuralNetwork = neuralNetwork;
+this.trainingMethod = trainingMethod;
+this.maxError = maxError;
+this.maxTrainingEpochs = maxTrainingEpochs;
+}
+
+@Override
+public NNData compute(NNDataPair input) {
+return this.neuralNetwork.compute(input);
+}
 
 public int getActualTrainingEpochs() {
 return actualTrainingEpochs;
 }
 
+@Override
+public int getInputSize() {
+return 2;
+}
+
 public int getMaxTrainingEpochs() {
 return maxTrainingEpochs;
 }
 
-@Override
-public BasicNetwork getNeuralNetwork() {
+protected FeedforwardNetwork getNeuralNetwork() {
 return neuralNetwork;
 }
 
-public void load(String filename) throws IOException {
-FileInputStream fis = new FileInputStream(new File(filename));
-neuralNetwork = (BasicNetwork) new PersistBasicNetwork().read(fis);
-fis.close();
+@Override
+public void learnPatterns(List<NNDataPair> trainingSet) {
+actualTrainingEpochs = 0;
+double error;
+neuralNetwork.initWeights();
+error = trainingMethod.computePatternError(neuralNetwork,trainingSet);
+if (error <= maxError) {
+System.out.println("Initial error: " + error);
+return;
+}
+do {
+trainingMethod.iteratePatterns(neuralNetwork,trainingSet);
+error = trainingMethod.computePatternError(neuralNetwork,trainingSet);
+System.out.println("Epoch #" + actualTrainingEpochs + " Error:"
++ error);
+actualTrainingEpochs++;
+System.out.println("MSSE after epoch " + actualTrainingEpochs + ": " + error);
+} while (error > maxError && actualTrainingEpochs < maxTrainingEpochs);
 }
 
 @Override
-public void reset() {
-neuralNetwork.reset();
+public void learnSequences(List<List<NNDataPair>> trainingSet) {
+actualTrainingEpochs = 0;
+double error;
+neuralNetwork.initWeights();
+error = trainingMethod.computeSequenceError(neuralNetwork,trainingSet);
+if (error <= maxError) {
+System.out.println("Initial error: " + error);
+return;
+}
+do {
+trainingMethod.iterateSequences(neuralNetwork,trainingSet);
+error = trainingMethod.computeSequenceError(neuralNetwork,trainingSet);
+if (Double.isNaN(error)) {
+error = trainingMethod.computeSequenceError(neuralNetwork,trainingSet);
+}
+System.out.println("Epoch #" + actualTrainingEpochs + " Error:"
++ error);
+actualTrainingEpochs++;
+System.out.println("MSSE after epoch " + actualTrainingEpochs + ": " + error);
+} while (error > maxError && actualTrainingEpochs < maxTrainingEpochs);
 }
 
 @Override
-public void reset(int seed) {
-neuralNetwork.reset(seed);
+public boolean load(InputStream input) {
+return neuralNetwork.load(input);
 }
 
-public void save(String filename) throws IOException {
-FileOutputStream fos = new FileOutputStream(new File(filename));
-new PersistBasicNetwork().save(fos, getNeuralNetwork());
-fos.close();
+@Override
+public boolean save(OutputStream output) {
+return neuralNetwork.save(output);
+}
+
+public void setMaxError(double maxError) {
+this.maxError = maxError;
 }
 
 public void setMaxTrainingEpochs(int max) {
View File

@@ -0,0 +1,132 @@
package net.woodyfolsom.msproj.ann;
import java.util.List;
import net.woodyfolsom.msproj.ann.math.ErrorFunction;
import net.woodyfolsom.msproj.ann.math.MSSE;
public class BackPropagation extends TrainingMethod {
private final ErrorFunction errorFunction;
private final double learningRate;
private final double momentum;
public BackPropagation(double learningRate, double momentum) {
this.errorFunction = MSSE.function;
this.learningRate = learningRate;
this.momentum = momentum;
}
@Override
public void iteratePatterns(FeedforwardNetwork neuralNetwork,
List<NNDataPair> trainingSet) {
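// One epoch of per-pattern (stochastic) backpropagation: for each training pair, zero the gradients, run the forward pass, backpropagate the output error, and apply the weight deltas immediately.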
System.out.println("Learningrate: " + learningRate);
System.out.println("Momentum: " + momentum);
for (NNDataPair trainingPair : trainingSet) {
zeroGradients(neuralNetwork);
System.out.println("Training with: " + trainingPair.getInput());
NNData ideal = trainingPair.getIdeal();
NNData actual = neuralNetwork.compute(trainingPair);
System.out.println("Updating weights. Ideal Output: " + ideal);
System.out.println("Actual Output: " + actual);
//backpropagate the gradients w.r.t. output error
backPropagate(neuralNetwork, ideal);
updateWeights(neuralNetwork);
}
}
@Override
public double computePatternError(FeedforwardNetwork neuralNetwork,
List<NNDataPair> trainingSet) {
int numDataPairs = trainingSet.size();
int outputSize = neuralNetwork.getOutput().length;
int totalOutputSize = outputSize * numDataPairs;
double[] actuals = new double[totalOutputSize];
double[] ideals = new double[totalOutputSize];
for (int dataPair = 0; dataPair < numDataPairs; dataPair++) {
NNDataPair nnDataPair = trainingSet.get(dataPair);
double[] actual = neuralNetwork.compute(nnDataPair.getInput()
.getValues());
double[] ideal = nnDataPair.getIdeal().getValues();
int offset = dataPair * outputSize;
System.arraycopy(actual, 0, actuals, offset, outputSize);
System.arraycopy(ideal, 0, ideals, offset, outputSize);
}
double MSSE = errorFunction.compute(ideals, actuals);
return MSSE;
}
@Override
protected
void backPropagate(FeedforwardNetwork neuralNetwork, NNData ideal) {
Neuron[] outputNeurons = neuralNetwork.getOutputNeurons();
double[] idealValues = ideal.getValues();
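// Output-layer gradients: delta_k = f'(net_k) * (ideal_k - actual_k), added to the gradient that was zeroed before this pattern.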
for (int i = 0; i < idealValues.length; i++) {
double input = outputNeurons[i].getInput();
double derivative = outputNeurons[i].getActivationFunction()
.derivative(input);
outputNeurons[i].setGradient(outputNeurons[i].getGradient() + derivative * (idealValues[i] - outputNeurons[i].getOutput()));
}
// walking down the list of Neurons in reverse order, propagate the
// error
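// For each neuron with outgoing connections: delta_j += f'(net_j) * sum_k( w_jk * delta_k ); output neurons have no outgoing connections and keep the gradient seeded above.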
Neuron[] neurons = neuralNetwork.getNeurons();
for (int n = neurons.length - 1; n >= 0; n--) {
Neuron neuron = neurons[n];
double error = neuron.getGradient();
Connection[] connectionsFromN = neuralNetwork
.getConnectionsFrom(neuron.getId());
if (connectionsFromN.length > 0) {
double derivative = neuron.getActivationFunction().derivative(
neuron.getInput());
for (Connection connection : connectionsFromN) {
error += derivative * connection.getWeight() * neuralNetwork.getNeuron(connection.getDest()).getGradient();
}
}
neuron.setGradient(error);
}
}
private void updateWeights(FeedforwardNetwork neuralNetwork) {
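// Standard backprop weight update: delta_w = learningRate * sourceOutput * destGradient; the gradient already carries the (ideal - actual) sign, so adding delta_w descends the squared error.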
for (Connection connection : neuralNetwork.getConnections()) {
Neuron srcNeuron = neuralNetwork.getNeuron(connection.getSrc());
Neuron destNeuron = neuralNetwork.getNeuron(connection.getDest());
double delta = learningRate * srcNeuron.getOutput() * destNeuron.getGradient();
//TODO allow for momentum
//double lastDelta = connection.getLastDelta();
connection.addDelta(delta);
}
}
@Override
public void iterateSequences(FeedforwardNetwork neuralNetwork,
List<List<NNDataPair>> trainingSet) {
throw new UnsupportedOperationException();
}
@Override
public double computeSequenceError(FeedforwardNetwork neuralNetwork,
List<List<NNDataPair>> trainingSet) {
throw new UnsupportedOperationException();
}
@Override
protected void iteratePattern(FeedforwardNetwork neuralNetwork,
NNDataPair statePair, NNData nextReward) {
throw new UnsupportedOperationException();
}
}
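A minimal usage sketch of the classes introduced in this changeset (MultiLayerPerceptron, NNData, NNDataPair, BackPropagation), training a tiny XOR network. It is illustrative only, not part of the diff; the class name XorSketch, the field names and the error threshold are made up.

package net.woodyfolsom.msproj.ann;

import java.util.ArrayList;
import java.util.List;

public class XorSketch {
    public static void main(String[] args) {
        // 2 inputs, 3 Tanh hidden units, 1 Sigmoid output, with bias connections
        MultiLayerPerceptron net = new MultiLayerPerceptron(true, 2, 3, 1);
        net.initWeights();

        String[] in = { "a", "b" };
        String[] out = { "xor" };
        List<NNDataPair> patterns = new ArrayList<NNDataPair>();
        patterns.add(new NNDataPair(new NNData(in, new double[] { 0, 0 }), new NNData(out, new double[] { 0 })));
        patterns.add(new NNDataPair(new NNData(in, new double[] { 0, 1 }), new NNData(out, new double[] { 1 })));
        patterns.add(new NNDataPair(new NNData(in, new double[] { 1, 0 }), new NNData(out, new double[] { 1 })));
        patterns.add(new NNDataPair(new NNData(in, new double[] { 1, 1 }), new NNData(out, new double[] { 0 })));

        BackPropagation trainer = new BackPropagation(0.5, 0.0); // learning rate, momentum (momentum not applied yet)
        int epoch = 0;
        while (epoch < 5000 && trainer.computePatternError(net, patterns) > 0.01) {
            trainer.iteratePatterns(net, patterns); // note: this logs every pattern to System.out
            epoch++;
        }
        System.out.println(net.compute(patterns.get(1))); // should approach [xor=1.0]
    }
}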

View File

@@ -0,0 +1,100 @@
package net.woodyfolsom.msproj.ann;
import javax.xml.bind.annotation.XmlAttribute;
import javax.xml.bind.annotation.XmlTransient;
public class Connection {
private int src;
private int dest;
private double weight;
//private transient double lastDelta = 0.0;
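// Transient working value, overwritten on every weight change; presumably consumed as an eligibility trace by the TD(lambda) trainer (TemporalDifference). Not serialized.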
private transient double trace = 0.0;
public Connection() {
//no-arg constructor for JAXB
}
public Connection(int src, int dest, double weight) {
this.src = src;
this.dest = dest;
this.weight = weight;
}
public void addDelta(double delta) {
this.trace = delta;
this.weight += delta;
//this.lastDelta = delta;
}
@XmlAttribute
public int getDest() {
return dest;
}
@XmlAttribute
public int getSrc() {
return src;
}
public double getTrace() {
return trace;
}
@XmlAttribute
public double getWeight() {
return weight;
}
public void setDest(int dest) {
this.dest = dest;
}
public void setSrc(int src) {
this.src = src;
}
@XmlTransient
public void setTrace(double trace) {
this.trace = trace;
}
public void setWeight(double weight) {
this.weight = weight;
}
@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + dest;
result = prime * result + src;
long temp;
temp = Double.doubleToLongBits(weight);
result = prime * result + (int) (temp ^ (temp >>> 32));
return result;
}
@Override
public boolean equals(Object obj) {
if (this == obj)
return true;
if (obj == null)
return false;
if (getClass() != obj.getClass())
return false;
Connection other = (Connection) obj;
if (dest != other.dest)
return false;
if (src != other.src)
return false;
if (Double.doubleToLongBits(weight) != Double
.doubleToLongBits(other.weight))
return false;
return true;
}
@Override
public String toString() {
return "Connection(src: " + src + ",dest: " + dest + ", trace:" + trace +"), weight: " + weight;
}
}

View File

@@ -1,17 +0,0 @@
package net.woodyfolsom.msproj.ann;
import org.encog.ml.data.basic.BasicMLData;
public class DoublePair extends BasicMLData {
// private final double x;
// private final double y;
/**
*
*/
private static final long serialVersionUID = 1L;
public DoublePair(double x, double y) {
super(new double[] { x, y });
}
}

View File

@@ -1,95 +0,0 @@
package net.woodyfolsom.msproj.ann;
import org.encog.mathutil.error.ErrorCalculationMode;
/*
Initial version of this class was a verbatim copy from the Encog framework.
*/
public class ErrorCalculation {
private static ErrorCalculationMode mode = ErrorCalculationMode.MSE;
public static ErrorCalculationMode getMode() {
return ErrorCalculation.mode;
}
public static void setMode(final ErrorCalculationMode theMode) {
ErrorCalculation.mode = theMode;
}
private double globalError;
private int setSize;
public final double calculate() {
if (this.setSize == 0) {
return 0;
}
switch (ErrorCalculation.getMode()) {
case RMS:
return calculateRMS();
case MSE:
return calculateMSE();
case ESS:
return calculateESS();
default:
return calculateMSE();
}
}
public final double calculateMSE() {
if (this.setSize == 0) {
return 0;
}
final double err = this.globalError / this.setSize;
return err;
}
public final double calculateESS() {
if (this.setSize == 0) {
return 0;
}
final double err = this.globalError / 2;
return err;
}
public final double calculateRMS() {
if (this.setSize == 0) {
return 0;
}
final double err = Math.sqrt(this.globalError / this.setSize);
return err;
}
public final void reset() {
this.globalError = 0;
this.setSize = 0;
}
public final void updateError(final double actual, final double ideal) {
double delta = ideal - actual;
this.globalError += delta * delta;
this.setSize++;
}
public final void updateError(final double[] actual, final double[] ideal,
final double significance) {
for (int i = 0; i < actual.length; i++) {
double delta = (ideal[i] - actual[i]) * significance;
this.globalError += delta * delta;
}
this.setSize += ideal.length;
}
}

View File

@@ -0,0 +1,313 @@
package net.woodyfolsom.msproj.ann;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.xml.bind.annotation.XmlAttribute;
import javax.xml.bind.annotation.XmlElement;
import javax.xml.bind.annotation.XmlTransient;
import net.woodyfolsom.msproj.ann.math.ActivationFunction;
import net.woodyfolsom.msproj.ann.math.Linear;
import net.woodyfolsom.msproj.ann.math.Sigmoid;
public abstract class FeedforwardNetwork {
private ActivationFunction activationFunction;
private boolean biased;
private List<Connection> connections;
private List<Neuron> neurons;
private String name;
private transient int biasNeuronId;
private transient Map<Integer, List<Connection>> connectionsFrom;
private transient Map<Integer, List<Connection>> connectionsTo;
public FeedforwardNetwork() {
this(false);
}
public FeedforwardNetwork(boolean biased) {
//No-arg constructor for JAXB
this.activationFunction = Sigmoid.function;
this.connections = new ArrayList<Connection>();
this.connectionsFrom = new HashMap<Integer,List<Connection>>();
this.connectionsTo = new HashMap<Integer,List<Connection>>();
this.neurons = new ArrayList<Neuron>();
this.name = "UNDEFINED";
this.biasNeuronId = -1;
setBiased(biased);
}
public void addConnection(Connection connection) {
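// Register the connection and index it by source and by destination neuron, so feedforward and backpropagation can look up adjacency without scanning the full connection list.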
connections.add(connection);
int src = connection.getSrc();
int dest = connection.getDest();
if (!connectionsFrom.containsKey(src)) {
connectionsFrom.put(src, new ArrayList<Connection>());
}
if (!connectionsTo.containsKey(dest)) {
connectionsTo.put(dest, new ArrayList<Connection>());
}
connectionsFrom.get(src).add(connection);
connectionsTo.get(dest).add(connection);
}
public NNData compute(NNDataPair nnDataPair) {
NNData actual = new NNData(nnDataPair.getIdeal().getFields(),
compute(nnDataPair.getInput().getValues()));
return actual;
}
public double[] compute(double[] input) {
zeroInputs();
setInput(input);
feedforward();
return getOutput();
}
void createBiasConnection(int neuronId, double weight) {
if (!biased) {
throw new UnsupportedOperationException("Not a biased network");
}
addConnection(new Connection(biasNeuronId, neuronId, weight));
}
/**
* Adds a new neuron with a unique id to this FeedforwardNetwork.
* @return the newly created Neuron
*/
Neuron createNeuron(boolean input, ActivationFunction afunc) {
Neuron neuron;
if (input) {
neuron = new Neuron(Linear.function, neurons.size());
} else {
neuron = new Neuron(afunc, neurons.size());
}
neurons.add(neuron);
return neuron;
}
@Override
public boolean equals(Object obj) {
if (this == obj)
return true;
if (obj == null)
return false;
if (getClass() != obj.getClass())
return false;
FeedforwardNetwork other = (FeedforwardNetwork) obj;
if (activationFunction == null) {
if (other.activationFunction != null)
return false;
} else if (!activationFunction.equals(other.activationFunction))
return false;
if (connections == null) {
if (other.connections != null)
return false;
} else if (!connections.equals(other.connections))
return false;
if (name == null) {
if (other.name != null)
return false;
} else if (!name.equals(other.name))
return false;
if (neurons == null) {
if (other.neurons != null)
return false;
} else if (!neurons.equals(other.neurons))
return false;
return true;
}
protected void feedforward() {
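// A single pass in id order pushes each neuron's output along its outgoing connections; this assumes neurons are stored in topological order (bias/input neurons first, outputs last), which holds for networks built by MultiLayerPerceptron.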
for (int i = 0; i < neurons.size(); i++) {
Neuron src = neurons.get(i);
for (Connection connection : getConnectionsFrom(src.getId())) {
Neuron dest = getNeuron(connection.getDest());
dest.addInput(src.getOutput() * connection.getWeight());
}
}
}
@XmlElement(type = Sigmoid.class)
public ActivationFunction getActivationFunction() {
return activationFunction;
}
protected abstract double[] getOutput();
protected abstract Neuron[] getOutputNeurons();
@XmlAttribute
public String getName() {
return name;
}
protected Neuron getNeuron(int id) {
return neurons.get(id);
}
public Connection getConnection(int index) {
return connections.get(index);
}
@XmlElement
protected Connection[] getConnections() {
return connections.toArray(new Connection[connections.size()]);
}
protected Connection[] getConnectionsFrom(int neuronId) {
List<Connection> connList = connectionsFrom.get(neuronId);
if (connList == null) {
return new Connection[0];
} else {
return connList.toArray(new Connection[connList.size()]);
}
}
protected Connection[] getConnectionsTo(int neuronId) {
List<Connection> connList = connectionsTo.get(neuronId);
if (connList == null) {
return new Connection[0];
} else {
return connList.toArray(new Connection[connList.size()]);
}
}
public double[] getGradients() {
double[] gradients = new double[neurons.size()];
for (int n = 0; n < gradients.length; n++) {
gradients[n] = neurons.get(n).getGradient();
}
return gradients;
}
public double[] getWeights() {
double[] weights = new double[connections.size()];
for (int i = 0; i < connections.size(); i++) {
weights[i] = connections.get(i).getWeight();
}
return weights;
}
@XmlAttribute
public boolean isBiased() {
return biased;
}
@XmlElement
protected Neuron[] getNeurons() {
return neurons.toArray(new Neuron[neurons.size()]);
}
@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime
* result
+ ((activationFunction == null) ? 0 : activationFunction
.hashCode());
result = prime * result
+ ((connections == null) ? 0 : connections.hashCode());
result = prime * result + ((name == null) ? 0 : name.hashCode());
result = prime * result + ((neurons == null) ? 0 : neurons.hashCode());
return result;
}
public void initWeights() {
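// Re-randomize every weight uniformly in (-1.0, 1.0].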
for (Connection connection : connections) {
connection.setWeight(1.0-Math.random()*2.0);
}
}
public abstract boolean load(InputStream is);
public abstract boolean save(OutputStream os);
public void setActivationFunction(ActivationFunction activationFunction) {
this.activationFunction = activationFunction;
}
public void setBiased(boolean biased) {
if (this.biased == biased) {
return;
}
this.biased = biased;
if (biased) {
Neuron biasNeuron = createNeuron(true, activationFunction);
biasNeuron.setInput(1.0);
biasNeuronId = biasNeuron.getId();
} else {
//This is an inefficient but concise way to remove all connections involving the bias Neuron
//from the global connection list and from both adjacency indexes
//Remove all connections from biasId from this index
List<Connection> connectionsFromBias = connectionsFrom.remove(biasNeuronId);
//Remove all connections to all nodes from biasId from this index
for (Map.Entry<Integer,List<Connection>> mapEntry : connectionsTo.entrySet()) {
mapEntry.getValue().removeAll(connectionsFromBias);
}
//Finally, remove from the (serialized) list of non-indexed connections
connections.remove(connectionsFromBias);
biasNeuronId = -1;
}
}
protected void setConnections(Connection[] connections) {
this.connections.clear();
this.connectionsFrom.clear();
this.connectionsTo.clear();
for (Connection connection : connections) {
addConnection(connection);
}
}
protected abstract void setInput(double[] input);
public void setName(String name) {
this.name = name;
}
protected void setNeurons(Neuron[] neurons) {
this.neurons.clear();
for (Neuron neuron : neurons) {
this.neurons.add(neuron);
}
}
@XmlTransient
public void setWeights(double[] weights) {
if (weights.length != connections.size()) {
throw new IllegalArgumentException("# of weights must == # of connections");
}
for (int i = 0; i < connections.size(); i++) {
connections.get(i).setWeight(weights[i]);
}
}
protected void zeroInputs() {
for (Neuron neuron : neurons) {
if (neuron.getId() != biasNeuronId){
neuron.setInput(0.0);
}
}
}
}

View File

@@ -0,0 +1,292 @@
package net.woodyfolsom.msproj.ann;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import net.woodyfolsom.msproj.Action;
import net.woodyfolsom.msproj.GameConfig;
import net.woodyfolsom.msproj.GameRecord;
import net.woodyfolsom.msproj.GameResult;
import net.woodyfolsom.msproj.GameState;
import net.woodyfolsom.msproj.Player;
import net.woodyfolsom.msproj.policy.NeuralNetPolicy;
import net.woodyfolsom.msproj.policy.Policy;
import net.woodyfolsom.msproj.policy.RandomMovePolicy;
import net.woodyfolsom.msproj.tictactoe.NNDataSetFactory;
public class FusekiFilterTrainer { // implements epsilon-greedy trainer? online
// version of NeuralNetFilter
private boolean training = true;
public static void main(String[] args) throws IOException {
double alpha = 0.50;
double lambda = 0.90;
int maxGames = 1000;
new FusekiFilterTrainer().trainNetwork(alpha, lambda, maxGames);
}
public void trainNetwork(double alpha, double lambda, int maxGames)
throws IOException {
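// Rough protocol: play a few games with the untrained network as a baseline, train by epsilon-greedy self-play with the TemporalDifference trainer (alpha = step size, lambda = trace decay), then evaluate with greedy self-play, save the network, and finish with greedy-vs-random games.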
FeedforwardNetwork neuralNetwork;
GameConfig gameConfig = new GameConfig(9);
if (training) {
neuralNetwork = new MultiLayerPerceptron(true, 81, 18, 1);
neuralNetwork.setName("FusekiFilter" + gameConfig.getSize());
neuralNetwork.initWeights();
TrainingMethod trainer = new TemporalDifference(alpha, lambda);
System.out.println("Playing untrained games.");
for (int i = 0; i < 10; i++) {
GameRecord gameRecord = new GameRecord(gameConfig);
System.out.println("" + (i + 1) + ". "
+ playOptimal(neuralNetwork, gameRecord).getResult());
}
System.out.println("Learning from " + maxGames
+ " games of random self-play");
int gamesPlayed = 0;
List<GameResult> results = new ArrayList<GameResult>();
do {
GameRecord gameRecord = new GameRecord(gameConfig);
playEpsilonGreedy(0.50, neuralNetwork, trainer, gameRecord);
System.out.println("Winner: " + gameRecord.getResult());
gamesPlayed++;
results.add(gameRecord.getResult());
} while (gamesPlayed < maxGames);
System.out.println("Results of every 10th training game:");
for (int i = 0; i < results.size(); i++) {
if (i % 10 == 0) {
System.out.println("" + (i + 1) + ". " + results.get(i));
}
}
System.out.println("Learned network after " + maxGames
+ " training games.");
} else {
System.out.println("Loading TicTacToe network from file.");
neuralNetwork = new MultiLayerPerceptron();
FileInputStream fis = new FileInputStream(new File("pass.net"));
if (!new MultiLayerPerceptron().load(fis)) {
System.out.println("Error loading pass.net from file.");
return;
}
fis.close();
}
evalTestCases(gameConfig, neuralNetwork);
System.out.println("Playing optimal games.");
List<GameResult> gameResults = new ArrayList<GameResult>();
for (int i = 0; i < 10; i++) {
GameRecord gameRecord = new GameRecord(gameConfig);
gameResults.add(playOptimal(neuralNetwork, gameRecord).getResult());
}
boolean suboptimalPlay = false;
System.out.println("Optimal game summary: ");
for (int i = 0; i < gameResults.size(); i++) {
GameResult result = gameResults.get(i);
System.out.println("" + (i + 1) + ". " + result);
}
File output = new File("pass.net");
FileOutputStream fos = new FileOutputStream(output);
neuralNetwork.save(fos);
System.out.println("Playing optimal vs random games.");
for (int i = 0; i < 10; i++) {
GameRecord gameRecord = new GameRecord(gameConfig);
System.out.println(""
+ (i + 1)
+ ". "
+ playOptimalVsRandom(neuralNetwork, gameRecord)
.getResult());
}
if (suboptimalPlay) {
System.out.println("Suboptimal play detected!");
}
}
private double[] createBoard(GameConfig gameConfig, Action... actions) {
GameRecord gameRec = new GameRecord(gameConfig);
for (Action action : actions) {
gameRec.play(gameRec.getPlayerToMove(), action);
}
return NNDataSetFactory.createDataPair(gameRec.getGameState(), FusekiFilterTrainer.class).getInput().getValues();
}
private void evalTestCases(GameConfig gameConfig, FeedforwardNetwork neuralNetwork) {
double[][] validationSet = new double[1][];
// start state: black has 0, white has 0 + komi, neither has passed
validationSet[0] = createBoard(gameConfig, Action.getInstance("C3"));
String[] inputNames = NNDataSetFactory.getInputFields(FusekiFilterTrainer.class);
String[] outputNames = NNDataSetFactory.getOutputFields(FusekiFilterTrainer.class);
System.out.println("Output from eval set (learned network):");
testNetwork(neuralNetwork, validationSet, inputNames, outputNames);
}
private GameRecord playOptimalVsRandom(FeedforwardNetwork neuralNetwork,
GameRecord gameRecord) {
NeuralNetPolicy neuralNetPolicy = new NeuralNetPolicy();
neuralNetPolicy.setMoveFilter(neuralNetwork);
Policy randomPolicy = new RandomMovePolicy();
GameConfig gameConfig = gameRecord.getGameConfig();
GameState gameState = gameRecord.getGameState();
Policy[] policies = new Policy[] { neuralNetPolicy, randomPolicy };
int turnNo = 0;
do {
Action action;
GameState nextState;
Player playerToMove = gameState.getPlayerToMove();
action = policies[turnNo % 2].getAction(gameConfig, gameState,
playerToMove);
if (!gameRecord.play(playerToMove, action)) {
throw new RuntimeException("Illegal move: " + action);
}
nextState = gameRecord.getGameState();
//System.out.println("Action " + action + " selected by policy "
// + policies[turnNo % 2].getName());
//System.out.println("Next board state: " + nextState);
gameState = nextState;
turnNo++;
} while (!gameState.isTerminal());
return gameRecord;
}
private GameRecord playOptimal(FeedforwardNetwork neuralNetwork,
GameRecord gameRecord) {
NeuralNetPolicy neuralNetPolicy = new NeuralNetPolicy();
neuralNetPolicy.setMoveFilter(neuralNetwork);
if (gameRecord.getNumTurns() > 0) {
throw new RuntimeException(
"PlayOptimal requires a new GameRecord with no turns played.");
}
GameState gameState;
do {
Action action;
GameState nextState;
Player playerToMove = gameRecord.getPlayerToMove();
action = neuralNetPolicy.getAction(gameRecord.getGameConfig(),
gameRecord.getGameState(), playerToMove);
if (!gameRecord.play(playerToMove, action)) {
throw new RuntimeException("Invalid move played: " + action);
}
nextState = gameRecord.getGameState();
//System.out.println("Action " + action + " selected by policy "
// + neuralNetPolicy.getName());
//System.out.println("Next board state: " + nextState);
gameState = nextState;
} while (!gameState.isTerminal());
return gameRecord;
}
private GameRecord playEpsilonGreedy(double epsilon,
FeedforwardNetwork neuralNetwork, TrainingMethod trainer,
GameRecord gameRecord) {
Policy randomPolicy = new RandomMovePolicy();
NeuralNetPolicy neuralNetPolicy = new NeuralNetPolicy();
neuralNetPolicy.setMoveFilter(neuralNetwork);
if (gameRecord.getNumTurns() > 0) {
throw new RuntimeException(
"PlayOptimal requires a new GameRecord with no turns played.");
}
GameState gameState = gameRecord.getGameState();
NNDataPair statePair;
Policy selectedPolicy;
trainer.zeroTraces(neuralNetwork);
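// With probability epsilon play a purely random exploratory move (no learning step); otherwise play the network's move and apply a TD update that moves the previous state's value toward the target derived from the resulting state.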
do {
Action action;
GameState nextState;
Player playerToMove = gameRecord.getPlayerToMove();
if (Math.random() < epsilon) {
selectedPolicy = randomPolicy;
action = selectedPolicy
.getAction(gameRecord.getGameConfig(),
gameRecord.getGameState(),
gameRecord.getPlayerToMove());
if (!gameRecord.play(playerToMove, action)) {
throw new RuntimeException("Illegal move played: " + action);
}
nextState = gameRecord.getGameState();
} else {
selectedPolicy = neuralNetPolicy;
action = selectedPolicy
.getAction(gameRecord.getGameConfig(),
gameRecord.getGameState(),
gameRecord.getPlayerToMove());
if (!gameRecord.play(playerToMove, action)) {
throw new RuntimeException("Illegal move played: " + action);
}
nextState = gameRecord.getGameState();
statePair = NNDataSetFactory.createDataPair(gameState, FusekiFilterTrainer.class);
NNDataPair nextStatePair = NNDataSetFactory
.createDataPair(nextState, FusekiFilterTrainer.class);
trainer.iteratePattern(neuralNetwork, statePair,
nextStatePair.getIdeal());
}
gameState = nextState;
} while (!gameState.isTerminal());
// finally, reinforce the actual reward
statePair = NNDataSetFactory.createDataPair(gameState, FusekiFilterTrainer.class);
trainer.iteratePattern(neuralNetwork, statePair, statePair.getIdeal());
return gameRecord;
}
private void testNetwork(FeedforwardNetwork neuralNetwork,
double[][] validationSet, String[] inputNames, String[] outputNames) {
for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
NNDataPair dp = new NNDataPair(new NNData(inputNames,
validationSet[valIndex]), new NNData(outputNames,
new double[] { 0.0 }));
System.out.println(dp);
System.out.println(" => ");
System.out.println(neuralNetwork.compute(dp));
}
}
}

View File

@@ -1,25 +0,0 @@
package net.woodyfolsom.msproj.ann;
import net.woodyfolsom.msproj.GameState;
import org.encog.ml.data.basic.BasicMLData;
public class GameStateMLData extends BasicMLData {
/**
*
*/
private static final long serialVersionUID = 1L;
private GameState gameState;
public GameStateMLData(double[] d, GameState gameState) {
super(d);
// TODO Auto-generated constructor stub
this.gameState = gameState;
}
public GameState getGameState() {
return gameState;
}
}

View File

@@ -1,121 +0,0 @@
package net.woodyfolsom.msproj.ann;
import net.woodyfolsom.msproj.GameResult;
import net.woodyfolsom.msproj.GameState;
import net.woodyfolsom.msproj.Player;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.ml.data.basic.BasicMLDataPair;
import org.encog.util.kmeans.Centroid;
public class GameStateMLDataPair implements MLDataPair {
//private final String[] inputs = { "BlackScore", "WhiteScore" };
//private final String[] outputs = { "BlackWins", "WhiteWins" };
private BasicMLDataPair mlDataPairDelegate;
private GameState gameState;
public GameStateMLDataPair(GameState gameState) {
this.gameState = gameState;
mlDataPairDelegate = new BasicMLDataPair(
new GameStateMLData(createInput(), gameState), new BasicMLData(createIdeal()));
}
public GameStateMLDataPair(GameStateMLDataPair that) {
this.gameState = new GameState(that.gameState);
mlDataPairDelegate = new BasicMLDataPair(
that.mlDataPairDelegate.getInput(),
that.mlDataPairDelegate.getIdeal());
}
@Override
public MLDataPair clone() {
return new GameStateMLDataPair(this);
}
@Override
public Centroid<MLDataPair> createCentroid() {
return mlDataPairDelegate.createCentroid();
}
/**
* Creates a vector of normalized scores from GameState.
*
* @return
*/
private double[] createInput() {
GameResult result = gameState.getResult();
double maxScore = gameState.getGameConfig().getSize()
* gameState.getGameConfig().getSize();
double whiteScore = Math.min(1.0, result.getWhiteScore() / maxScore);
double blackScore = Math.min(1.0, result.getBlackScore() / maxScore);
return new double[] { blackScore, whiteScore };
}
/**
* Creates a vector of values indicating strength of black/white win output
* from network.
*
* @return
*/
private double[] createIdeal() {
GameResult result = gameState.getResult();
double blackWinner = result.isWinner(Player.BLACK) ? 1.0 : 0.0;
double whiteWinner = result.isWinner(Player.WHITE) ? 1.0 : 0.0;
return new double[] { blackWinner, whiteWinner };
}
@Override
public MLData getIdeal() {
return mlDataPairDelegate.getIdeal();
}
@Override
public double[] getIdealArray() {
return mlDataPairDelegate.getIdealArray();
}
@Override
public MLData getInput() {
return mlDataPairDelegate.getInput();
}
@Override
public double[] getInputArray() {
return mlDataPairDelegate.getInputArray();
}
@Override
public double getSignificance() {
return mlDataPairDelegate.getSignificance();
}
@Override
public boolean isSupervised() {
return mlDataPairDelegate.isSupervised();
}
@Override
public void setIdealArray(double[] arg0) {
mlDataPairDelegate.setIdealArray(arg0);
}
@Override
public void setInputArray(double[] arg0) {
mlDataPairDelegate.setInputArray(arg0);
}
@Override
public void setSignificance(double arg0) {
mlDataPairDelegate.setSignificance(arg0);
}
}

View File

@@ -1,172 +0,0 @@
package net.woodyfolsom.msproj.ann;
/*
* Class copied verbatim from Encog framework due to dependency on Propagation
* implementation.
*
* Encog(tm) Core v3.2 - Java Version
* http://www.heatonresearch.com/encog/
* http://code.google.com/p/encog-java/
* Copyright 2008-2012 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
import java.util.List;
import java.util.Set;
import org.encog.engine.network.activation.ActivationFunction;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataPair;
import org.encog.neural.error.ErrorFunction;
import org.encog.neural.flat.FlatNetwork;
import org.encog.util.EngineArray;
import org.encog.util.concurrency.EngineTask;
public class GradientWorker implements EngineTask {
private final FlatNetwork network;
private final ErrorCalculation errorCalculation = new ErrorCalculation();
private final double[] actual;
private final double[] layerDelta;
private final int[] layerCounts;
private final int[] layerFeedCounts;
private final int[] layerIndex;
private final int[] weightIndex;
private final double[] layerOutput;
private final double[] layerSums;
private final double[] gradients;
private final double[] weights;
private final MLDataPair pair;
private final Set<List<MLDataPair>> training;
private final int low;
private final int high;
private final TemporalDifferenceLearning owner;
private double[] flatSpot;
private final ErrorFunction errorFunction;
public GradientWorker(final FlatNetwork theNetwork,
final TemporalDifferenceLearning theOwner,
final Set<List<MLDataPair>> theTraining, final int theLow,
final int theHigh, final double[] flatSpot,
ErrorFunction ef) {
this.network = theNetwork;
this.training = theTraining;
this.low = theLow;
this.high = theHigh;
this.owner = theOwner;
this.flatSpot = flatSpot;
this.errorFunction = ef;
this.layerDelta = new double[network.getLayerOutput().length];
this.gradients = new double[network.getWeights().length];
this.actual = new double[network.getOutputCount()];
this.weights = network.getWeights();
this.layerIndex = network.getLayerIndex();
this.layerCounts = network.getLayerCounts();
this.weightIndex = network.getWeightIndex();
this.layerOutput = network.getLayerOutput();
this.layerSums = network.getLayerSums();
this.layerFeedCounts = network.getLayerFeedCounts();
this.pair = BasicMLDataPair.createPair(network.getInputCount(), network
.getOutputCount());
}
public FlatNetwork getNetwork() {
return this.network;
}
public double[] getWeights() {
return this.weights;
}
private void process(final double[] input, final double[] ideal, double s) {
this.network.compute(input, this.actual);
this.errorCalculation.updateError(this.actual, ideal, s);
this.errorFunction.calculateError(ideal, actual, this.layerDelta);
for (int i = 0; i < this.actual.length; i++) {
this.layerDelta[i] = ((this.network.getActivationFunctions()[0]
.derivativeFunction(this.layerSums[i],this.layerOutput[i]) + this.flatSpot[0]))
* (this.layerDelta[i] * s);
}
for (int i = this.network.getBeginTraining(); i < this.network
.getEndTraining(); i++) {
processLevel(i);
}
}
private void processLevel(final int currentLevel) {
final int fromLayerIndex = this.layerIndex[currentLevel + 1];
final int toLayerIndex = this.layerIndex[currentLevel];
final int fromLayerSize = this.layerCounts[currentLevel + 1];
final int toLayerSize = this.layerFeedCounts[currentLevel];
final int index = this.weightIndex[currentLevel];
final ActivationFunction activation = this.network
.getActivationFunctions()[currentLevel];
final double currentFlatSpot = this.flatSpot[currentLevel + 1];
// handle weights
int yi = fromLayerIndex;
for (int y = 0; y < fromLayerSize; y++) {
final double output = this.layerOutput[yi];
double sum = 0;
int xi = toLayerIndex;
int wi = index + y;
for (int x = 0; x < toLayerSize; x++) {
this.gradients[wi] += output * this.layerDelta[xi];
sum += this.weights[wi] * this.layerDelta[xi];
wi += fromLayerSize;
xi++;
}
this.layerDelta[yi] = sum
* (activation.derivativeFunction(this.layerSums[yi],this.layerOutput[yi])+currentFlatSpot);
yi++;
}
}
public final void run() {
try {
this.errorCalculation.reset();
//for (int i = this.low; i <= this.high; i++) {
for (List<MLDataPair> trainingSequence : training) {
MLDataPair mldp = trainingSequence.get(trainingSequence.size()-1);
this.pair.setInputArray(mldp.getInputArray());
if (this.pair.getIdealArray() != null) {
this.pair.setIdealArray(mldp.getIdealArray());
}
//this.training.getRecord(i, this.pair);
process(this.pair.getInputArray(), this.pair.getIdealArray(),pair.getSignificance());
}
//}
final double error = this.errorCalculation.calculate();
this.owner.report(this.gradients, error, null);
EngineArray.fill(this.gradients, 0);
} catch (final Throwable ex) {
this.owner.report(null, 0, ex);
}
}
}

View File

@@ -1,5 +0,0 @@
package net.woodyfolsom.msproj.ann;
public class JosekiLearner {
}

View File

@@ -0,0 +1,64 @@
package net.woodyfolsom.msproj.ann;
import java.util.Arrays;
import javax.xml.bind.annotation.XmlElement;
public class Layer {
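// A Layer stores only neuron ids; the Neuron instances themselves live in the owning FeedforwardNetwork.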
private int[] neuronIds;
public Layer() {
neuronIds = new int[0];
}
public Layer(int numNeurons) {
neuronIds = new int[numNeurons];
}
public int size() {
return neuronIds.length;
}
public int getNeuronId(int index) {
return neuronIds[index];
}
@XmlElement
public int[] getNeuronIds() {
int[] safeCopy = new int[neuronIds.length];
System.arraycopy(neuronIds, 0, safeCopy, 0, neuronIds.length);
return safeCopy;
}
public void setNeuronId(int index, int id) {
neuronIds[index] = id;
}
public void setNeuronIds(int[] neuronIds) {
this.neuronIds = new int[neuronIds.length];
System.arraycopy(neuronIds, 0, this.neuronIds, 0, neuronIds.length);
}
@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + Arrays.hashCode(neuronIds);
return result;
}
@Override
public boolean equals(Object obj) {
if (this == obj)
return true;
if (obj == null)
return false;
if (getClass() != obj.getClass())
return false;
Layer other = (Layer) obj;
if (!Arrays.equals(neuronIds, other.neuronIds))
return false;
return true;
}
}

View File

@@ -0,0 +1,157 @@
package net.woodyfolsom.msproj.ann;
import java.io.InputStream;
import java.io.OutputStream;
import javax.xml.bind.JAXBContext;
import javax.xml.bind.JAXBException;
import javax.xml.bind.Marshaller;
import javax.xml.bind.Unmarshaller;
import javax.xml.bind.annotation.XmlElement;
import javax.xml.bind.annotation.XmlRootElement;
import net.woodyfolsom.msproj.ann.math.ActivationFunction;
import net.woodyfolsom.msproj.ann.math.Sigmoid;
import net.woodyfolsom.msproj.ann.math.Tanh;
@XmlRootElement
public class MultiLayerPerceptron extends FeedforwardNetwork {
private boolean biased;
private Layer[] layers;
public MultiLayerPerceptron() {
this(false, 1, 1);
}
public MultiLayerPerceptron(boolean biased, int... layerSizes) {
super(biased);
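// Builds a fully connected feedforward network: input neurons are linear, hidden layers use Tanh, the output layer Sigmoid; when biased, every non-input neuron also gets a bias connection with initial weight 0.0.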
int numLayers = layerSizes.length;
if (numLayers < 2) {
throw new IllegalArgumentException("# of layers must be >= 2");
}
this.layers = new Layer[numLayers];
for (int layerIndex = 0; layerIndex < numLayers; layerIndex++) {
int layerSize = layerSizes[layerIndex];
if (layerSize < 1) {
throw new IllegalArgumentException("Layer size must be >= 1");
}
Layer newLayer;
if (layerIndex == numLayers - 1) {
newLayer = createNewLayer(layerIndex, layerSize, Sigmoid.function);
} else {
newLayer = createNewLayer(layerIndex, layerSize, Tanh.function);
}
if (layerIndex > 0) {
Layer prevLayer = layers[layerIndex - 1];
for (int j = 0; j < newLayer.size(); j++) {
if (biased) {
createBiasConnection(newLayer.getNeuronId(j),0.0);
}
for (int i = 0; i < prevLayer.size(); i++) {
addConnection(new Connection(prevLayer.getNeuronId(i),
newLayer.getNeuronId(j), 0.0));
}
}
}
}
}
private Layer createNewLayer(int layerIndex, int layerSize, ActivationFunction afunc) {
Layer layer = new Layer(layerSize);
layers[layerIndex] = layer;
for (int n = 0; n < layerSize; n++) {
Neuron neuron = createNeuron(layerIndex == 0, afunc);
layer.setNeuronId(n, neuron.getId());
}
return layer;
}
@XmlElement
public Layer[] getLayers() {
return layers;
}
@Override
protected double[] getOutput() {
Layer outputLayer = layers[layers.length - 1];
double output[] = new double[outputLayer.size()];
for (int n = 0; n < outputLayer.size(); n++) {
output[n] = getNeuron(outputLayer.getNeuronId(n)).getOutput();
}
return output;
}
@Override
public Neuron[] getOutputNeurons() {
Layer outputLayer = layers[layers.length - 1];
Neuron[] outputNeurons = new Neuron[outputLayer.size()];
for (int i = 0; i < outputLayer.size(); i++) {
outputNeurons[i] = getNeuron(outputLayer.getNeuronId(i));
}
return outputNeurons;
}
@Override
protected void setInput(double[] input) {
Layer inputLayer = layers[0];
for (int n = 0; n < inputLayer.size(); n++) {
try {
getNeuron(inputLayer.getNeuronId(n)).setInput(input[n]);
} catch (NullPointerException npe) {
npe.printStackTrace();
}
}
}
public void setLayers(Layer[] layers) {
this.layers = layers;
}
@Override
public boolean load(InputStream is) {
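// Networks are persisted as JAXB-mapped XML: unmarshal into a temporary MultiLayerPerceptron and copy its state into this instance.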
try {
JAXBContext jc = JAXBContext
.newInstance(MultiLayerPerceptron.class);
Unmarshaller u = jc.createUnmarshaller();
MultiLayerPerceptron mlp = (MultiLayerPerceptron) u.unmarshal(is);
super.setActivationFunction(mlp.getActivationFunction());
super.setConnections(mlp.getConnections());
super.setNeurons(mlp.getNeurons());
this.biased = mlp.biased;
this.layers = mlp.layers;
return true;
} catch (JAXBException je) {
je.printStackTrace();
return false;
}
}
@Override
public boolean save(OutputStream os) {
try {
JAXBContext jc = JAXBContext
.newInstance(MultiLayerPerceptron.class);
Marshaller m = jc.createMarshaller();
m.setProperty(Marshaller.JAXB_FORMATTED_OUTPUT, true);
m.marshal(this, os);
//m.marshal(this, System.out);
return true;
} catch (JAXBException je) {
je.printStackTrace();
return false;
}
}
}
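A save/load round-trip sketch for the JAXB persistence above; illustrative only, not part of the diff. The class name SaveLoadSketch and the file name example.net are made up.

package net.woodyfolsom.msproj.ann;

import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;

public class SaveLoadSketch {
    public static void main(String[] args) throws IOException {
        MultiLayerPerceptron net = new MultiLayerPerceptron(true, 9, 9, 1);
        net.initWeights();
        net.setName("example");

        FileOutputStream fos = new FileOutputStream("example.net");
        System.out.println("saved: " + net.save(fos)); // marshals the network as formatted XML
        fos.close();

        MultiLayerPerceptron restored = new MultiLayerPerceptron();
        FileInputStream fis = new FileInputStream("example.net");
        System.out.println("loaded: " + restored.load(fis)); // copies neurons, connections, layers and the bias flag; note load() does not copy the name back
        fis.close();

        System.out.println(restored.getWeights().length + " weights restored");
    }
}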

View File

@@ -0,0 +1,38 @@
package net.woodyfolsom.msproj.ann;
public class NNData {
private final double[] values;
private final String[] fields;
public NNData(String[] fields, double[] values) {
this.fields = fields;
this.values = values;
}
public NNData(NNData that) {
this.fields = that.fields;
this.values = that.values;
}
public String[] getFields() {
return fields;
}
public double[] getValues() {
return values;
}
@Override
public String toString() {
StringBuilder sb = new StringBuilder("[");
for (int i = 0; i < fields.length; i++) {
if (i > 0) {
sb.append(", " );
}
sb.append(fields[i] + "=" + values[i]);
}
sb.append("]");
return sb.toString();
}
}

View File

@@ -0,0 +1,24 @@
package net.woodyfolsom.msproj.ann;
public class NNDataPair {
private final NNData input;
private final NNData ideal;
public NNDataPair(NNData actual, NNData ideal) {
this.input = actual;
this.ideal = ideal;
}
public NNData getInput() {
return input;
}
public NNData getIdeal() {
return ideal;
}
@Override
public String toString() {
return input.toString() + " => " + ideal.toString();
}
}

View File

@@ -1,31 +1,29 @@
package net.woodyfolsom.msproj.ann; package net.woodyfolsom.msproj.ann;
import java.io.IOException; import java.io.InputStream;
import java.io.OutputStream;
import java.util.List; import java.util.List;
import java.util.Set;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.neural.networks.BasicNetwork;
public interface NeuralNetFilter { public interface NeuralNetFilter {
BasicNetwork getNeuralNetwork(); int getActualTrainingEpochs();
public int getActualTrainingEpochs(); int getInputSize();
public int getInputSize();
public int getMaxTrainingEpochs();
public int getOutputSize();
public double computeValue(MLData input); int getMaxTrainingEpochs();
public double[] computeVector(MLData input);
public void learn(MLDataSet trainingSet); int getOutputSize();
public void learn(Set<List<MLDataPair>> trainingSet);
public void load(String fileName) throws IOException; boolean load(InputStream input);
public void reset();
public void reset(int seed); boolean save(OutputStream output);
public void save(String fileName) throws IOException;
public void setMaxTrainingEpochs(int max); void setMaxTrainingEpochs(int max);
NNData compute(NNDataPair input);
//Due to Java type erasure, overloading a method
//simply named 'learn' which takes Lists would be problematic
void learnPatterns(List<NNDataPair> trainingSet);
void learnSequences(List<List<NNDataPair>> trainingSet);
} }

View File

@@ -0,0 +1,107 @@
package net.woodyfolsom.msproj.ann;
import javax.xml.bind.annotation.XmlAttribute;
import javax.xml.bind.annotation.XmlElement;
import javax.xml.bind.annotation.XmlElements;
import javax.xml.bind.annotation.XmlTransient;
import net.woodyfolsom.msproj.ann.math.ActivationFunction;
import net.woodyfolsom.msproj.ann.math.Linear;
import net.woodyfolsom.msproj.ann.math.Sigmoid;
import net.woodyfolsom.msproj.ann.math.Tanh;
public class Neuron {
private ActivationFunction activationFunction;
private int id;
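// input and gradient are per-pass working state (forward accumulation and backpropagated error) and are not serialized.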
private transient double input = 0.0;
private transient double gradient = 0.0;
public Neuron() {
//no-arg constructor for JAXB
}
public Neuron(ActivationFunction activationFunction, int id) {
this.activationFunction = activationFunction;
this.id = id;
}
public void addInput(double value) {
input += value;
}
@XmlElements({
@XmlElement(name="LinearActivationFunction",type=Linear.class),
@XmlElement(name="SigmoidActivationFunction",type=Sigmoid.class),
@XmlElement(name="TanhActivationFunction",type=Tanh.class)
})
public ActivationFunction getActivationFunction() {
return activationFunction;
}
@XmlAttribute
public int getId() {
return id;
}
@XmlTransient
public double getGradient() {
return gradient;
}
@XmlTransient
public double getInput() {
return input;
}
public double getOutput() {
return activationFunction.calculate(input);
}
public void setGradient(double value) {
this.gradient = value;
}
public void setInput(double input) {
this.input = input;
}
@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime
* result
+ ((activationFunction == null) ? 0 : activationFunction
.hashCode());
return result;
}
@Override
public boolean equals(Object obj) {
if (this == obj)
return true;
if (obj == null)
return false;
if (getClass() != obj.getClass())
return false;
Neuron other = (Neuron) obj;
if (activationFunction == null) {
if (other.activationFunction != null)
return false;
} else if (!activationFunction.equals(other.activationFunction))
return false;
return true;
}
public void setActivationFunction(ActivationFunction activationFunction) {
this.activationFunction = activationFunction;
}
@Override
public String toString() {
return "Neuron #" + id +", input: " + input + ", gradient: " + gradient;
}
}

View File

@@ -1,5 +1,5 @@
package net.woodyfolsom.msproj.ann; package net.woodyfolsom.msproj.ann;
public class FusekiLearner { public class ObjectiveFunction {
} }

View File

@@ -0,0 +1,298 @@
package net.woodyfolsom.msproj.ann;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import net.woodyfolsom.msproj.Action;
import net.woodyfolsom.msproj.GameConfig;
import net.woodyfolsom.msproj.GameRecord;
import net.woodyfolsom.msproj.GameResult;
import net.woodyfolsom.msproj.GameState;
import net.woodyfolsom.msproj.Player;
import net.woodyfolsom.msproj.policy.NeuralNetPolicy;
import net.woodyfolsom.msproj.policy.Policy;
import net.woodyfolsom.msproj.policy.RandomMovePolicy;
import net.woodyfolsom.msproj.tictactoe.NNDataSetFactory;
public class PassFilterTrainer { // implements epsilon-greedy trainer? online
// version of NeuralNetFilter
private boolean training = true;
public static void main(String[] args) throws IOException {
double alpha = 0.50;
double lambda = 0.1;
int maxGames = 1500;
new PassFilterTrainer().trainNetwork(alpha, lambda, maxGames);
}
public void trainNetwork(double alpha, double lambda, int maxGames)
throws IOException {
FeedforwardNetwork neuralNetwork;
GameConfig gameConfig = new GameConfig(9);
if (training) {
neuralNetwork = new MultiLayerPerceptron(true, 2, 2, 1);
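// 2-2-1 network; the two inputs apparently encode whether the mover is ahead and whether the opponent just passed (each roughly +/-1, see evalTestCases below), and the single output scores how attractive passing is.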
neuralNetwork.setName("PassFilter" + gameConfig.getSize());
neuralNetwork.initWeights();
TrainingMethod trainer = new TemporalDifference(alpha, lambda);
System.out.println("Playing untrained games.");
for (int i = 0; i < 10; i++) {
GameRecord gameRecord = new GameRecord(gameConfig);
System.out.println("" + (i + 1) + ". "
+ playOptimal(neuralNetwork, gameRecord).getResult());
}
System.out.println("Learning from " + maxGames
+ " games of random self-play");
int gamesPlayed = 0;
List<GameResult> results = new ArrayList<GameResult>();
do {
GameRecord gameRecord = new GameRecord(gameConfig);
playEpsilonGreedy(0.5, neuralNetwork, trainer, gameRecord);
System.out.println("Winner: " + gameRecord.getResult());
gamesPlayed++;
results.add(gameRecord.getResult());
} while (gamesPlayed < maxGames);
System.out.println("Results of every 10th training game:");
for (int i = 0; i < results.size(); i++) {
if (i % 10 == 0) {
System.out.println("" + (i + 1) + ". " + results.get(i));
}
}
System.out.println("Learned network after " + maxGames
+ " training games.");
} else {
System.out.println("Loading TicTacToe network from file.");
neuralNetwork = new MultiLayerPerceptron();
FileInputStream fis = new FileInputStream(new File("pass.net"));
if (!new MultiLayerPerceptron().load(fis)) {
System.out.println("Error loading pass.net from file.");
return;
}
fis.close();
}
evalTestCases(neuralNetwork);
System.out.println("Playing optimal games.");
List<GameResult> gameResults = new ArrayList<GameResult>();
for (int i = 0; i < 10; i++) {
GameRecord gameRecord = new GameRecord(gameConfig);
gameResults.add(playOptimal(neuralNetwork, gameRecord).getResult());
}
boolean suboptimalPlay = false;
System.out.println("Optimal game summary: ");
for (int i = 0; i < gameResults.size(); i++) {
GameResult result = gameResults.get(i);
System.out.println("" + (i + 1) + ". " + result);
}
File output = new File("pass.net");
FileOutputStream fos = new FileOutputStream(output);
neuralNetwork.save(fos);
System.out.println("Playing optimal vs random games.");
for (int i = 0; i < 10; i++) {
GameRecord gameRecord = new GameRecord(gameConfig);
System.out.println(""
+ (i + 1)
+ ". "
+ playOptimalVsRandom(neuralNetwork, gameRecord)
.getResult());
}
if (suboptimalPlay) {
System.out.println("Suboptimal play detected!");
}
}
private void evalTestCases(FeedforwardNetwork neuralNetwork) {
double[][] validationSet = new double[4][];
//losing, opponent did not pass
//don't pass
//(0.0 1.0 0.0) => 0.0
validationSet[0] = new double[] { -1.0, -1.0 };
//winning, opponent did not pass
//maybe pass?
//(1.0 0.0 0.0) => ?
validationSet[1] = new double[] { 1.0, -1.0 };
//winning, opponent passed
//pass!
//(1.0 0.0 1.0) => 1.0
validationSet[2] = new double[] { 1.0, 1.0 };
//losing, opponent passed
//don't pass!
//(0.0 1.0 1.0) => 0.0
validationSet[3] = new double[] { -1.0, 1.0 };
String[] inputNames = NNDataSetFactory.getInputFields(PassFilterTrainer.class);
String[] outputNames = NNDataSetFactory.getOutputFields(PassFilterTrainer.class);
System.out.println("Output from eval set (learned network):");
testNetwork(neuralNetwork, validationSet, inputNames, outputNames);
}
private GameRecord playOptimalVsRandom(FeedforwardNetwork neuralNetwork,
GameRecord gameRecord) {
NeuralNetPolicy neuralNetPolicy = new NeuralNetPolicy();
neuralNetPolicy.setPassFilter(neuralNetwork);
Policy randomPolicy = new RandomMovePolicy();
GameConfig gameConfig = gameRecord.getGameConfig();
GameState gameState = gameRecord.getGameState();
Policy[] policies = new Policy[] { neuralNetPolicy, randomPolicy };
int turnNo = 0;
do {
Action action;
GameState nextState;
Player playerToMove = gameState.getPlayerToMove();
action = policies[turnNo % 2].getAction(gameConfig, gameState,
playerToMove);
if (!gameRecord.play(playerToMove, action)) {
throw new RuntimeException("Illegal move: " + action);
}
nextState = gameRecord.getGameState();
//System.out.println("Action " + action + " selected by policy "
// + policies[turnNo % 2].getName());
//System.out.println("Next board state: " + nextState);
gameState = nextState;
turnNo++;
} while (!gameState.isTerminal());
return gameRecord;
}
private GameRecord playOptimal(FeedforwardNetwork neuralNetwork,
GameRecord gameRecord) {
NeuralNetPolicy neuralNetPolicy = new NeuralNetPolicy();
neuralNetPolicy.setPassFilter(neuralNetwork);
if (gameRecord.getNumTurns() > 0) {
throw new RuntimeException(
"PlayOptimal requires a new GameRecord with no turns played.");
}
GameState gameState;
do {
Action action;
GameState nextState;
Player playerToMove = gameRecord.getPlayerToMove();
action = neuralNetPolicy.getAction(gameRecord.getGameConfig(),
gameRecord.getGameState(), playerToMove);
if (!gameRecord.play(playerToMove, action)) {
throw new RuntimeException("Invalid move played: " + action);
}
nextState = gameRecord.getGameState();
//System.out.println("Action " + action + " selected by policy "
// + neuralNetPolicy.getName());
//System.out.println("Next board state: " + nextState);
gameState = nextState;
} while (!gameState.isTerminal());
return gameRecord;
}
private GameRecord playEpsilonGreedy(double epsilon,
FeedforwardNetwork neuralNetwork, TrainingMethod trainer,
GameRecord gameRecord) {
Policy randomPolicy = new RandomMovePolicy();
NeuralNetPolicy neuralNetPolicy = new NeuralNetPolicy();
neuralNetPolicy.setPassFilter(neuralNetwork);
if (gameRecord.getNumTurns() > 0) {
throw new RuntimeException(
"PlayOptimal requires a new GameRecord with no turns played.");
}
GameState gameState = gameRecord.getGameState();
NNDataPair statePair;
Policy selectedPolicy;
trainer.zeroTraces(neuralNetwork);
do {
Action action;
GameState nextState;
Player playerToMove = gameRecord.getPlayerToMove();
if (Math.random() < epsilon) {
selectedPolicy = randomPolicy;
action = selectedPolicy
.getAction(gameRecord.getGameConfig(),
gameRecord.getGameState(),
gameRecord.getPlayerToMove());
if (!gameRecord.play(playerToMove, action)) {
throw new RuntimeException("Illegal move played: " + action);
}
nextState = gameRecord.getGameState();
} else {
selectedPolicy = neuralNetPolicy;
action = selectedPolicy
.getAction(gameRecord.getGameConfig(),
gameRecord.getGameState(),
gameRecord.getPlayerToMove());
if (!gameRecord.play(playerToMove, action)) {
throw new RuntimeException("Illegal move played: " + action);
}
nextState = gameRecord.getGameState();
statePair = NNDataSetFactory.createDataPair(gameState, PassFilterTrainer.class);
NNDataPair nextStatePair = NNDataSetFactory
.createDataPair(nextState, PassFilterTrainer.class);
trainer.iteratePattern(neuralNetwork, statePair,
nextStatePair.getIdeal());
}
gameState = nextState;
} while (!gameState.isTerminal());
// finally, reinforce the actual reward
statePair = NNDataSetFactory.createDataPair(gameState, PassFilterTrainer.class);
trainer.iteratePattern(neuralNetwork, statePair, statePair.getIdeal());
return gameRecord;
}
private void testNetwork(FeedforwardNetwork neuralNetwork,
double[][] validationSet, String[] inputNames, String[] outputNames) {
for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
NNDataPair dp = new NNDataPair(new NNData(inputNames,
validationSet[valIndex]), new NNData(outputNames,
new double[] { 0.0 }));
System.out.println(dp + " => " + neuralNetwork.compute(dp));
}
}
}

View File

@@ -1,5 +0,0 @@
package net.woodyfolsom.msproj.ann;
public class ShapeLearner {
}

View File

@@ -0,0 +1,261 @@
package net.woodyfolsom.msproj.ann;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import net.woodyfolsom.msproj.tictactoe.Action;
import net.woodyfolsom.msproj.tictactoe.GameRecord;
import net.woodyfolsom.msproj.tictactoe.GameRecord.RESULT;
import net.woodyfolsom.msproj.tictactoe.NNDataSetFactory;
import net.woodyfolsom.msproj.tictactoe.NeuralNetPolicy;
import net.woodyfolsom.msproj.tictactoe.Policy;
import net.woodyfolsom.msproj.tictactoe.RandomPolicy;
import net.woodyfolsom.msproj.tictactoe.State;
public class TTTFilterTrainer { // implements epsilon-greedy trainer? online
// version of NeuralNetFilter
private boolean training = true;
public static void main(String[] args) throws IOException {
double alpha = 0.025;
double lambda = .10;
int maxGames = 100000;
new TTTFilterTrainer().trainNetwork(alpha, lambda, maxGames);
}
public void trainNetwork(double alpha, double lambda, int maxGames)
throws IOException {
FeedforwardNetwork neuralNetwork;
if (training) {
neuralNetwork = new MultiLayerPerceptron(true, 9, 9, 1);
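// One input per square of the 3x3 board (apparently +1.0 for the learner's marks, -1.0 for the opponent's, 0.0 for empty), nine hidden units, and a single output estimating the value of the position (see evalTestCases).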
neuralNetwork.setName("TicTacToe");
neuralNetwork.initWeights();
TrainingMethod trainer = new TemporalDifference(alpha, lambda);
System.out.println("Playing untrained games.");
for (int i = 0; i < 10; i++) {
System.out.println("" + (i + 1) + ". "
+ playOptimal(neuralNetwork).getResult());
}
System.out.println("Learning from " + maxGames
+ " games of random self-play");
int gamesPlayed = 0;
List<RESULT> results = new ArrayList<RESULT>();
do {
GameRecord gameRecord = playEpsilonGreedy(0.9, neuralNetwork,
trainer);
System.out.println("Winner: " + gameRecord.getResult());
gamesPlayed++;
results.add(gameRecord.getResult());
} while (gamesPlayed < maxGames);
System.out.println("Results of every 10th training game:");
for (int i = 0; i < results.size(); i++) {
if (i % 10 == 0) {
System.out.println("" + (i + 1) + ". " + results.get(i));
}
}
System.out.println("Learned network after " + maxGames
+ " training games.");
} else {
System.out.println("Loading TicTacToe network from file.");
MultiLayerPerceptron loadedNetwork = new MultiLayerPerceptron();
FileInputStream fis = new FileInputStream(new File("ttt.net"));
// load into the instance that will actually be used, not into a throwaway copy
if (!loadedNetwork.load(fis)) {
System.out.println("Error loading ttt.net from file.");
fis.close();
return;
}
neuralNetwork = loadedNetwork;
fis.close();
}
evalTestCases(neuralNetwork);
System.out.println("Playing optimal games.");
List<RESULT> gameResults = new ArrayList<RESULT>();
for (int i = 0; i < 10; i++) {
gameResults.add(playOptimal(neuralNetwork).getResult());
}
boolean suboptimalPlay = false;
System.out.println("Optimal game summary: ");
for (int i = 0; i < gameResults.size(); i++) {
RESULT result = gameResults.get(i);
System.out.println("" + (i + 1) + ". " + result);
if (result != RESULT.X_WINS) {
suboptimalPlay = true;
}
}
File output = new File("ttt.net");
FileOutputStream fos = new FileOutputStream(output);
neuralNetwork.save(fos);
System.out.println("Playing optimal vs random games.");
for (int i = 0; i < 10; i++) {
System.out.println("" + (i + 1) + ". "
+ playOptimalVsRandom(neuralNetwork).getResult());
}
if (suboptimalPlay) {
System.out.println("Suboptimal play detected!");
}
}
private void evalTestCases(FeedforwardNetwork neuralNetwork) {
double[][] validationSet = new double[8][];
// empty board
validationSet[0] = new double[] { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0 };
// center
validationSet[1] = new double[] { 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,
0.0, 0.0 };
// top edge
validationSet[2] = new double[] { 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0 };
// left edge
validationSet[3] = new double[] { 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0,
0.0, 0.0 };
// corner
validationSet[4] = new double[] { 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0 };
// win
validationSet[5] = new double[] { 1.0, 1.0, 1.0, -1.0, -1.0, 0.0, 0.0,
-1.0, 0.0 };
// loss
validationSet[6] = new double[] { -1.0, 1.0, 0.0, 1.0, -1.0, 1.0, 0.0,
0.0, -1.0 };
// about to win
validationSet[7] = new double[] { -1.0, 1.0, 1.0, 1.0, -1.0, 1.0, -1.0,
-1.0, 0.0 };
String[] inputNames = new String[] { "00", "01", "02", "10", "11",
"12", "20", "21", "22" };
String[] outputNames = new String[] { "values" };
System.out.println("Output from eval set (learned network):");
testNetwork(neuralNetwork, validationSet, inputNames, outputNames);
}
private GameRecord playOptimalVsRandom(FeedforwardNetwork neuralNetwork) {
GameRecord gameRecord = new GameRecord();
Policy neuralNetPolicy = new NeuralNetPolicy(neuralNetwork);
Policy randomPolicy = new RandomPolicy();
State state = gameRecord.getState();
Policy[] policies = new Policy[] { neuralNetPolicy, randomPolicy };
int turnNo = 0;
do {
Action action;
State nextState;
action = policies[turnNo % 2].getAction(gameRecord.getState());
nextState = gameRecord.apply(action);
System.out.println("Action " + action + " selected by policy "
+ policies[turnNo % 2].getName());
System.out.println("Next board state: " + nextState);
state = nextState;
turnNo++;
} while (!state.isTerminal());
return gameRecord;
}
private GameRecord playOptimal(FeedforwardNetwork neuralNetwork) {
GameRecord gameRecord = new GameRecord();
Policy neuralNetPolicy = new NeuralNetPolicy(neuralNetwork);
State state = gameRecord.getState();
do {
Action action;
State nextState;
action = neuralNetPolicy.getAction(gameRecord.getState());
nextState = gameRecord.apply(action);
System.out.println("Action " + action + " selected by policy "
+ neuralNetPolicy.getName());
System.out.println("Next board state: " + nextState);
state = nextState;
} while (!state.isTerminal());
return gameRecord;
}
private GameRecord playEpsilonGreedy(double epsilon,
FeedforwardNetwork neuralNetwork, TrainingMethod trainer) {
GameRecord gameRecord = new GameRecord();
Policy randomPolicy = new RandomPolicy();
Policy neuralNetPolicy = new NeuralNetPolicy(neuralNetwork);
// System.out.println("Playing epsilon-greedy game.");
State state = gameRecord.getState();
NNDataPair statePair;
Policy selectedPolicy;
trainer.zeroTraces(neuralNetwork);
do {
Action action;
State nextState;
if (Math.random() < epsilon) {
selectedPolicy = randomPolicy;
action = selectedPolicy.getAction(gameRecord.getState());
nextState = gameRecord.apply(action);
} else {
selectedPolicy = neuralNetPolicy;
action = selectedPolicy.getAction(gameRecord.getState());
nextState = gameRecord.apply(action);
statePair = NNDataSetFactory.createDataPair(state);
NNDataPair nextStatePair = NNDataSetFactory
.createDataPair(nextState);
trainer.iteratePattern(neuralNetwork, statePair,
nextStatePair.getIdeal());
}
// System.out.println("Action " + action + " selected by policy " +
// selectedPolicy.getName());
// System.out.println("Next board state: " + nextState);
state = nextState;
} while (!state.isTerminal());
// finally, reinforce the actual reward
statePair = NNDataSetFactory.createDataPair(state);
trainer.iteratePattern(neuralNetwork, statePair, statePair.getIdeal());
return gameRecord;
}
private void testNetwork(FeedforwardNetwork neuralNetwork,
double[][] validationSet, String[] inputNames, String[] outputNames) {
for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
NNDataPair dp = new NNDataPair(new NNData(inputNames,
validationSet[valIndex]), new NNData(outputNames,
new double[] { 0.0 }));
System.out.println(dp + " => " + neuralNetwork.compute(dp));
}
}
}

View File

@@ -0,0 +1,123 @@
package net.woodyfolsom.msproj.ann;
import java.util.List;
public class TemporalDifference extends TrainingMethod {
private final double alpha; // learning rate (step size)
private final double lambda;
public TemporalDifference(double alpha, double lambda) {
this.alpha = alpha;
this.lambda = lambda;
}
@Override
public void iteratePatterns(FeedforwardNetwork neuralNetwork,
List<NNDataPair> trainingSet) {
throw new UnsupportedOperationException();
}
@Override
public double computePatternError(FeedforwardNetwork neuralNetwork,
List<NNDataPair> trainingSet) {
int numDataPairs = trainingSet.size();
int outputSize = neuralNetwork.getOutput().length;
int totalOutputSize = outputSize * numDataPairs;
double[] actuals = new double[totalOutputSize];
double[] ideals = new double[totalOutputSize];
for (int dataPair = 0; dataPair < numDataPairs; dataPair++) {
NNDataPair nnDataPair = trainingSet.get(dataPair);
double[] actual = neuralNetwork.compute(nnDataPair.getInput()
.getValues());
double[] ideal = nnDataPair.getIdeal().getValues();
int offset = dataPair * outputSize;
System.arraycopy(actual, 0, actuals, offset, outputSize);
System.arraycopy(ideal, 0, ideals, offset, outputSize);
}
double MSSE = errorFunction.compute(ideals, actuals);
return MSSE;
}
@Override
protected void backPropagate(FeedforwardNetwork neuralNetwork, NNData ideal) {
Neuron[] outputNeurons = neuralNetwork.getOutputNeurons();
double[] idealValues = ideal.getValues();
for (int i = 0; i < idealValues.length; i++) {
double input = outputNeurons[i].getInput();
double derivative = outputNeurons[i].getActivationFunction()
.derivative(input);
outputNeurons[i].setGradient(outputNeurons[i].getGradient()
+ derivative
* (idealValues[i] - outputNeurons[i].getOutput()));
}
// walking down the list of Neurons in reverse order, propagate the
// error
Neuron[] neurons = neuralNetwork.getNeurons();
for (int n = neurons.length - 1; n >= 0; n--) {
Neuron neuron = neurons[n];
double error = neuron.getGradient();
Connection[] connectionsFromN = neuralNetwork
.getConnectionsFrom(neuron.getId());
if (connectionsFromN.length > 0) {
double derivative = neuron.getActivationFunction().derivative(
neuron.getInput());
for (Connection connection : connectionsFromN) {
error += derivative
* connection.getWeight()
* neuralNetwork.getNeuron(connection.getDest())
.getGradient();
}
}
neuron.setGradient(error);
}
}
private void updateWeights(FeedforwardNetwork neuralNetwork,
double predictionError) {
for (Connection connection : neuralNetwork.getConnections()) {
Neuron srcNeuron = neuralNetwork.getNeuron(connection.getSrc());
Neuron destNeuron = neuralNetwork.getNeuron(connection.getDest());
double delta = alpha * srcNeuron.getOutput()
* destNeuron.getGradient() + connection.getTrace() * lambda;
connection.addDelta(delta);
}
}
@Override
public void iterateSequences(FeedforwardNetwork neuralNetwork,
List<List<NNDataPair>> trainingSet) {
throw new UnsupportedOperationException();
}
@Override
public double computeSequenceError(FeedforwardNetwork neuralNetwork,
List<List<NNDataPair>> trainingSet) {
throw new UnsupportedOperationException();
}
@Override
protected void iteratePattern(FeedforwardNetwork neuralNetwork,
NNDataPair statePair, NNData nextReward) {
zeroGradients(neuralNetwork);
NNData ideal = nextReward;
NNData actual = neuralNetwork.compute(statePair);
// backpropagate the gradients w.r.t. output error
backPropagate(neuralNetwork, ideal);
double predictionError = statePair.getIdeal().getValues()[0] // reward_t
+ actual.getValues()[0] - nextReward.getValues()[0];
updateWeights(neuralNetwork, predictionError);
}
}
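For comparison with the trainer above: textbook TD(lambda) keeps one eligibility trace per weight, decays it by lambda each step, and moves each weight by the learning rate times the TD error times that trace. The minimal, self-contained sketch below shows that arithmetic with illustrative values only; it is not the project's API, and note that the class above folds the output error into the neuron gradients rather than multiplying by an explicit TD error.

public class TDLambdaSketch {
	public static void main(String[] args) {
		double alpha = 0.025;                 // learning rate
		double lambda = 0.10;                 // eligibility-trace decay
		double[] weights   = { 0.10, -0.20 };
		double[] traces    = new double[weights.length];
		double[] gradients = { 0.50, 0.30 };  // dV(s_t)/dw, assumed already computed
		double reward = 0.0;                  // reward observed on this transition
		double valueNow = 0.60, valueNext = 0.70;
		double tdError = reward + valueNext - valueNow;

		for (int i = 0; i < weights.length; i++) {
			traces[i] = lambda * traces[i] + gradients[i]; // decay old credit, add new
			weights[i] += alpha * tdError * traces[i];     // move each weight along its trace
		}
		System.out.println(java.util.Arrays.toString(weights));
	}
}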

View File

@@ -1,484 +0,0 @@
package net.woodyfolsom.msproj.ann;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.encog.EncogError;
import org.encog.engine.network.activation.ActivationFunction;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.mathutil.IntRange;
import org.encog.ml.MLMethod;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.train.MLTrain;
import org.encog.ml.train.strategy.Strategy;
import org.encog.ml.train.strategy.end.EndTrainingStrategy;
import org.encog.neural.error.ErrorFunction;
import org.encog.neural.error.LinearErrorFunction;
import org.encog.neural.flat.FlatNetwork;
import org.encog.neural.networks.ContainsFlat;
import org.encog.neural.networks.training.LearningRate;
import org.encog.neural.networks.training.Momentum;
import org.encog.neural.networks.training.Train;
import org.encog.neural.networks.training.TrainingError;
import org.encog.neural.networks.training.propagation.TrainingContinuation;
import org.encog.neural.networks.training.propagation.back.Backpropagation;
import org.encog.neural.networks.training.strategy.SmartLearningRate;
import org.encog.neural.networks.training.strategy.SmartMomentum;
import org.encog.util.EncogValidate;
import org.encog.util.EngineArray;
import org.encog.util.concurrency.DetermineWorkload;
import org.encog.util.concurrency.EngineConcurrency;
import org.encog.util.concurrency.MultiThreadable;
import org.encog.util.concurrency.TaskGroup;
import org.encog.util.logging.EncogLogging;
/**
* This class started as a verbatim copy of BackPropagation from the open-source
* Encog framework. It was merged with its super-classes to access protected
* fields without resorting to reflection.
*/
public class TemporalDifferenceLearning implements MLTrain, Momentum,
LearningRate, Train, MultiThreadable {
// New fields for TD(lambda)
private final double lambda;
// end new fields
// BackProp
public static final String LAST_DELTA = "LAST_DELTA";
private double learningRate;
private double momentum;
private double[] lastDelta;
// End BackProp
// Propagation
private FlatNetwork currentFlatNetwork;
private int numThreads;
protected double[] gradients;
private double[] lastGradient;
protected ContainsFlat network;
// private MLDataSet indexable;
private Set<List<MLDataPair>> indexable;
private GradientWorker[] workers;
private double totalError;
protected double lastError;
private Throwable reportedException;
private double[] flatSpot;
private boolean shouldFixFlatSpot;
private ErrorFunction ef = new LinearErrorFunction();
// End Propagation
// BasicTraining
private final List<Strategy> strategies = new ArrayList<Strategy>();
private Set<List<MLDataPair>> training;
private double error;
private int iteration;
private TrainingImplementationType implementationType;
// End BasicTraining
public TemporalDifferenceLearning(final ContainsFlat network,
final Set<List<MLDataPair>> training, double lambda) {
this(network, training, 0, 0, lambda);
addStrategy(new SmartLearningRate());
addStrategy(new SmartMomentum());
}
public TemporalDifferenceLearning(final ContainsFlat network,
Set<List<MLDataPair>> training, final double theLearnRate,
final double theMomentum, double lambda) {
initPropagation(network, training);
// TODO consider how to re-implement validation
// ValidateNetwork.validateMethodToData(network, training);
this.momentum = theMomentum;
this.learningRate = theLearnRate;
this.lastDelta = new double[network.getFlat().getWeights().length];
this.lambda = lambda;
}
private void initPropagation(final ContainsFlat network,
final Set<List<MLDataPair>> training) {
initBasicTraining(TrainingImplementationType.Iterative);
this.network = network;
this.currentFlatNetwork = network.getFlat();
setTraining(training);
this.gradients = new double[this.currentFlatNetwork.getWeights().length];
this.lastGradient = new double[this.currentFlatNetwork.getWeights().length];
this.indexable = training;
this.numThreads = 0;
this.reportedException = null;
this.shouldFixFlatSpot = true;
}
private void initBasicTraining(TrainingImplementationType implementationType) {
this.implementationType = implementationType;
}
// Methods from BackPropagation
@Override
public boolean canContinue() {
return false;
}
public double[] getLastDelta() {
return this.lastDelta;
}
@Override
public double getLearningRate() {
return this.learningRate;
}
@Override
public double getMomentum() {
return this.momentum;
}
public boolean isValidResume(final TrainingContinuation state) {
if (!state.getContents().containsKey(Backpropagation.LAST_DELTA)) {
return false;
}
if (!state.getTrainingType().equals(getClass().getSimpleName())) {
return false;
}
final double[] d = (double[]) state.get(Backpropagation.LAST_DELTA);
return d.length == ((ContainsFlat) getMethod()).getFlat().getWeights().length;
}
@Override
public TrainingContinuation pause() {
final TrainingContinuation result = new TrainingContinuation();
result.setTrainingType(this.getClass().getSimpleName());
result.set(Backpropagation.LAST_DELTA, this.lastDelta);
return result;
}
@Override
public void resume(final TrainingContinuation state) {
if (!isValidResume(state)) {
throw new TrainingError("Invalid training resume data length");
}
this.lastDelta = ((double[]) state.get(Backpropagation.LAST_DELTA));
}
@Override
public void setLearningRate(final double rate) {
this.learningRate = rate;
}
@Override
public void setMomentum(final double m) {
this.momentum = m;
}
public double updateWeight(final double[] gradients,
final double[] lastGradient, final int index) {
final double delta = (gradients[index] * this.learningRate)
+ (this.lastDelta[index] * this.momentum);
this.lastDelta[index] = delta;
System.out.println("Updating weights for connection: " + index
+ " with lambda: " + lambda);
return delta;
}
public void initOthers() {
}
// End methods from BackPropagation
// Methods from Propagation
public void finishTraining() {
basicFinishTraining();
}
public FlatNetwork getCurrentFlatNetwork() {
return this.currentFlatNetwork;
}
public MLMethod getMethod() {
return this.network;
}
public void iteration() {
iteration(1);
}
public void rollIteration() {
this.iteration++;
}
public void iteration(final int count) {
try {
for (int i = 0; i < count; i++) {
preIteration();
rollIteration();
calculateGradients();
if (this.currentFlatNetwork.isLimited()) {
learnLimited();
} else {
learn();
}
this.lastError = this.getError();
for (final GradientWorker worker : this.workers) {
EngineArray.arrayCopy(this.currentFlatNetwork.getWeights(),
0, worker.getWeights(), 0,
this.currentFlatNetwork.getWeights().length);
}
if (this.currentFlatNetwork.getHasContext()) {
copyContexts();
}
if (this.reportedException != null) {
throw (new EncogError(this.reportedException));
}
postIteration();
EncogLogging.log(EncogLogging.LEVEL_INFO,
"Training iteration done, error: " + getError());
}
} catch (final ArrayIndexOutOfBoundsException ex) {
EncogValidate.validateNetworkForTraining(this.network,
getTraining());
throw new EncogError(ex);
}
}
public void setThreadCount(final int numThreads) {
this.numThreads = numThreads;
}
@Override
public int getThreadCount() {
return this.numThreads;
}
public void fixFlatSpot(boolean b) {
this.shouldFixFlatSpot = b;
}
public void setErrorFunction(ErrorFunction ef) {
this.ef = ef;
}
public void calculateGradients() {
if (this.workers == null) {
init();
}
if (this.currentFlatNetwork.getHasContext()) {
this.workers[0].getNetwork().clearContext();
}
this.totalError = 0;
if (this.workers.length > 1) {
final TaskGroup group = EngineConcurrency.getInstance()
.createTaskGroup();
for (final GradientWorker worker : this.workers) {
EngineConcurrency.getInstance().processTask(worker, group);
}
group.waitForComplete();
} else {
this.workers[0].run();
}
this.setError(this.totalError / this.workers.length);
}
/**
* Copy the contexts to keep them consistent with multithreaded training.
*/
private void copyContexts() {
// copy the contexts(layer outputO from each group to the next group
for (int i = 0; i < (this.workers.length - 1); i++) {
final double[] src = this.workers[i].getNetwork().getLayerOutput();
final double[] dst = this.workers[i + 1].getNetwork()
.getLayerOutput();
EngineArray.arrayCopy(src, dst);
}
// copy the contexts from the final group to the real network
EngineArray.arrayCopy(this.workers[this.workers.length - 1]
.getNetwork().getLayerOutput(), this.currentFlatNetwork
.getLayerOutput());
}
private void init() {
// fix flat spot, if needed
this.flatSpot = new double[this.currentFlatNetwork
.getActivationFunctions().length];
if (this.shouldFixFlatSpot) {
for (int i = 0; i < this.currentFlatNetwork
.getActivationFunctions().length; i++) {
final ActivationFunction af = this.currentFlatNetwork
.getActivationFunctions()[i];
if (af instanceof ActivationSigmoid) {
this.flatSpot[i] = 0.1;
} else {
this.flatSpot[i] = 0.0;
}
}
} else {
EngineArray.fill(this.flatSpot, 0.0);
}
// setup workers
final DetermineWorkload determine = new DetermineWorkload(
this.numThreads, (int) this.indexable.size());
// this.numThreads, (int) this.indexable.getRecordCount());
this.workers = new GradientWorker[determine.getThreadCount()];
int index = 0;
// handle CPU
for (final IntRange r : determine.calculateWorkers()) {
this.workers[index++] = new GradientWorker(
this.currentFlatNetwork.clone(), this, new HashSet(
this.indexable), r.getLow(), r.getHigh(),
this.flatSpot, this.ef);
}
initOthers();
}
public void report(final double[] gradients, final double error,
final Throwable ex) {
synchronized (this) {
if (ex == null) {
for (int i = 0; i < gradients.length; i++) {
this.gradients[i] += gradients[i];
}
this.totalError += error;
} else {
this.reportedException = ex;
}
}
}
protected void learn() {
final double[] weights = this.currentFlatNetwork.getWeights();
for (int i = 0; i < this.gradients.length; i++) {
weights[i] += updateWeight(this.gradients, this.lastGradient, i);
this.gradients[i] = 0;
}
}
protected void learnLimited() {
final double limit = this.currentFlatNetwork.getConnectionLimit();
final double[] weights = this.currentFlatNetwork.getWeights();
for (int i = 0; i < this.gradients.length; i++) {
if (Math.abs(weights[i]) < limit) {
weights[i] = 0;
} else {
weights[i] += updateWeight(this.gradients, this.lastGradient, i);
}
this.gradients[i] = 0;
}
}
public double[] getLastGradient() {
return lastGradient;
}
// End methods from Propagation
// Methods from BasicTraining/
public void addStrategy(final Strategy strategy) {
strategy.init(this);
this.strategies.add(strategy);
}
public void basicFinishTraining() {
}
public double getError() {
return this.error;
}
public int getIteration() {
return this.iteration;
}
public List<Strategy> getStrategies() {
return this.strategies;
}
public MLDataSet getTraining() {
throw new UnsupportedOperationException(
"This learning method operates on Set<List<MLData>>, not MLDataSet");
}
public boolean isTrainingDone() {
for (Strategy strategy : this.strategies) {
if (strategy instanceof EndTrainingStrategy) {
EndTrainingStrategy end = (EndTrainingStrategy) strategy;
if (end.shouldStop()) {
return true;
}
}
}
return false;
}
public void postIteration() {
for (final Strategy strategy : this.strategies) {
strategy.postIteration();
}
}
public void preIteration() {
this.iteration++;
for (final Strategy strategy : this.strategies) {
strategy.preIteration();
}
}
public void setError(final double error) {
this.error = error;
}
public void setIteration(final int iteration) {
this.iteration = iteration;
}
public void setTraining(final Set<List<MLDataPair>> training) {
this.training = training;
}
public TrainingImplementationType getImplementationType() {
return this.implementationType;
}
// End Methods from BasicTraining
}

View File

@@ -0,0 +1,43 @@
package net.woodyfolsom.msproj.ann;
import java.util.List;
import net.woodyfolsom.msproj.ann.math.ErrorFunction;
import net.woodyfolsom.msproj.ann.math.MSSE;
public abstract class TrainingMethod {
protected final ErrorFunction errorFunction;
public TrainingMethod() {
this.errorFunction = MSSE.function;
}
protected abstract void iteratePattern(FeedforwardNetwork neuralNetwork,
NNDataPair statePair, NNData nextReward);
protected abstract void iteratePatterns(FeedforwardNetwork neuralNetwork,
List<NNDataPair> trainingSet);
protected abstract double computePatternError(FeedforwardNetwork neuralNetwork,
List<NNDataPair> trainingSet);
protected abstract void iterateSequences(FeedforwardNetwork neuralNetwork,
List<List<NNDataPair>> trainingSet);
protected abstract void backPropagate(FeedforwardNetwork neuralNetwork, NNData output);
protected abstract double computeSequenceError(FeedforwardNetwork neuralNetwork,
List<List<NNDataPair>> trainingSet);
protected void zeroGradients(FeedforwardNetwork neuralNetwork) {
for (Neuron neuron : neuralNetwork.getNeurons()) {
neuron.setGradient(0.0);
}
}
protected void zeroTraces(FeedforwardNetwork neuralNetwork) {
for (Connection conn : neuralNetwork.getConnections()) {
conn.setTrace(0.0);
}
}
}

View File

@@ -1,112 +0,0 @@
package net.woodyfolsom.msproj.ann;
import java.util.List;
import java.util.Set;
import net.woodyfolsom.msproj.GameState;
import net.woodyfolsom.msproj.Player;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.train.MLTrain;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;
public class WinFilter extends AbstractNeuralNetFilter implements
NeuralNetFilter {
public WinFilter() {
// create a neural network, without using a factory
BasicNetwork network = new BasicNetwork();
network.addLayer(new BasicLayer(null, false, 2));
network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 4));
network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 2));
network.getStructure().finalizeStructure();
network.reset();
this.neuralNetwork = network;
}
@Override
public double computeValue(MLData input) {
if (input instanceof GameStateMLData) {
double[] idealVector = computeVector(input);
GameState gameState = ((GameStateMLData) input).getGameState();
Player playerToMove = gameState.getPlayerToMove();
if (playerToMove == Player.BLACK) {
return idealVector[0];
} else if (playerToMove == Player.WHITE) {
return idealVector[1];
} else {
throw new RuntimeException("Invalid GameState.playerToMove: "
+ playerToMove);
}
} else {
throw new UnsupportedOperationException(
"This NeuralNetFilter only accepts GameStates as input.");
}
}
@Override
public double[] computeVector(MLData input) {
if (input instanceof GameStateMLData) {
return neuralNetwork.compute(input).getData();
} else {
throw new UnsupportedOperationException(
"This NeuralNetFilter only accepts GameStates as input.");
}
}
@Override
public void learn(MLDataSet trainingData) {
throw new UnsupportedOperationException("This filter learns a Set<List<MLData>>, not an MLDataSet");
}
@Override
public void learn(Set<List<MLDataPair>> trainingSet) {
// train the neural network
final MLTrain train = new TemporalDifferenceLearning(neuralNetwork,
trainingSet, 0.7, 0.8, 0.25);
actualTrainingEpochs = 0;
do {
train.iteration();
System.out.println("Epoch #" + actualTrainingEpochs + " Error:"
+ train.getError());
actualTrainingEpochs++;
} while (train.getError() > 0.01
&& actualTrainingEpochs <= maxTrainingEpochs);
}
@Override
public void reset() {
neuralNetwork.reset();
}
@Override
public void reset(int seed) {
neuralNetwork.reset(seed);
}
@Override
public BasicNetwork getNeuralNetwork() {
// TODO Auto-generated method stub
return null;
}
@Override
public int getInputSize() {
// TODO Auto-generated method stub
return 0;
}
@Override
public int getOutputSize() {
// TODO Auto-generated method stub
return 0;
}
}

View File

@@ -1,18 +1,5 @@
package net.woodyfolsom.msproj.ann; package net.woodyfolsom.msproj.ann;
import java.util.List;
import java.util.Set;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.ml.train.MLTrain;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;
import org.encog.neural.networks.training.propagation.back.Backpropagation;
/** /**
* Based on sample code from http://neuroph.sourceforge.net * Based on sample code from http://neuroph.sourceforge.net
* *
@@ -22,62 +9,30 @@ import org.encog.neural.networks.training.propagation.back.Backpropagation;
public class XORFilter extends AbstractNeuralNetFilter implements public class XORFilter extends AbstractNeuralNetFilter implements
NeuralNetFilter { NeuralNetFilter {
private static final int INPUT_SIZE = 2;
private static final int OUTPUT_SIZE = 1;
public XORFilter() { public XORFilter() {
// create a neural network, without using a factory this(0.8,0.7);
BasicNetwork network = new BasicNetwork();
network.addLayer(new BasicLayer(null, false, 2));
network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 3));
network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 1));
network.getStructure().finalizeStructure();
network.reset();
this.neuralNetwork = network;
} }
@Override public XORFilter(double learningRate, double momentum) {
public void learn(MLDataSet trainingSet) { super( new MultiLayerPerceptron(true, INPUT_SIZE, 2, OUTPUT_SIZE),
new BackPropagation(learningRate, momentum), 1000, 0.001);
// train the neural network super.getNeuralNetwork().setName("XORFilter");
final MLTrain train = new Backpropagation(neuralNetwork,
trainingSet, 0.7, 0.8);
actualTrainingEpochs = 0;
do {
train.iteration();
System.out.println("Epoch #" + actualTrainingEpochs + " Error:"
+ train.getError());
actualTrainingEpochs++;
} while (train.getError() > 0.01
&& actualTrainingEpochs <= maxTrainingEpochs);
} }
@Override public double compute(double x, double y) {
public double[] computeVector(MLData mlData) { return getNeuralNetwork().compute(new double[]{x,y})[0];
MLDataSet dataset = new BasicMLDataSet(new double[][] { mlData.getData() },
new double[][] { new double[getOutputSize()] });
MLData output = neuralNetwork.compute(dataset.get(0).getInput());
return output.getData();
} }
@Override @Override
public int getInputSize() { public int getInputSize() {
return 2; return INPUT_SIZE;
} }
@Override @Override
public int getOutputSize() { public int getOutputSize() {
// TODO Auto-generated method stub return OUTPUT_SIZE;
return 1;
}
@Override
public double computeValue(MLData input) {
return computeVector(input)[0];
}
@Override
public void learn(Set<List<MLDataPair>> trainingSet) {
throw new UnsupportedOperationException("This Filter learns an MLDataSet, not a Set<List<MLData>>.");
} }
} }

View File

@@ -0,0 +1,53 @@
package net.woodyfolsom.msproj.ann.math;
import javax.xml.bind.annotation.XmlAttribute;
import javax.xml.bind.annotation.XmlTransient;
@XmlTransient
public abstract class ActivationFunction {
private String name;
public abstract double calculate(double arg);
public abstract double derivative(double arg);
public ActivationFunction() {
//no-arg constructor for JAXB
}
public ActivationFunction(String name) {
this.name = name;
}
@XmlAttribute
public String getName() {
return name;
}
public void setName(String name) {
this.name = name;
}
@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + ((name == null) ? 0 : name.hashCode());
return result;
}
@Override
public boolean equals(Object obj) {
if (this == obj)
return true;
if (obj == null)
return false;
if (getClass() != obj.getClass())
return false;
ActivationFunction other = (ActivationFunction) obj;
if (name == null) {
if (other.name != null)
return false;
} else if (!name.equals(other.name))
return false;
return true;
}
}

View File

@@ -0,0 +1,5 @@
package net.woodyfolsom.msproj.ann.math;
public interface ErrorFunction {
double compute(double[] ideal, double[] actual);
}

View File

@@ -0,0 +1,19 @@
package net.woodyfolsom.msproj.ann.math;
import javax.xml.bind.annotation.XmlRootElement;
public class Linear extends ActivationFunction{
public static final Linear function = new Linear();
private Linear() {
super("Linear");
}
public double calculate(double arg) {
return arg;
}
public double derivative(double arg) {
return 1; // the derivative of the identity function is 1
}
}

View File

@@ -0,0 +1,22 @@
package net.woodyfolsom.msproj.ann.math;
public class MSSE implements ErrorFunction{
public static final ErrorFunction function = new MSSE();
public double compute(double[] ideal, double[] actual) {
int idealSize = ideal.length;
int actualSize = actual.length;
if (idealSize != actualSize) {
throw new IllegalArgumentException("actualSize != idealSize");
}
double SSE = 0.0;
for (int i = 0; i < idealSize; i++) {
SSE += Math.pow(ideal[i] - actual[i], 2);
}
return SSE / idealSize;
}
}
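A quick worked example of the function above: MSSE.function.compute(new double[] {1.0, 0.0}, new double[] {0.8, 0.1}) returns ((1.0 - 0.8)^2 + (0.0 - 0.1)^2) / 2 = (0.04 + 0.01) / 2 = 0.025.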

View File

@@ -0,0 +1,20 @@
package net.woodyfolsom.msproj.ann.math;
import javax.xml.bind.annotation.XmlRootElement;
public class Sigmoid extends ActivationFunction{
public static final Sigmoid function = new Sigmoid();
private Sigmoid() {
super("Sigmoid");
}
public double calculate(double arg) {
return 1.0 / (1 + Math.pow(Math.E, -1.0 * arg));
}
public double derivative(double arg) {
double eX = Math.exp(arg);
return eX / (Math.pow((1+eX), 2));
}
}

View File

@@ -0,0 +1,23 @@
package net.woodyfolsom.msproj.ann.math;
import javax.xml.bind.annotation.XmlRootElement;
public class Tanh extends ActivationFunction{
public static final Tanh function = new Tanh();
public Tanh() {
super("Tanh");
}
@Override
public double calculate(double arg) {
return Math.tanh(arg);
}
@Override
public double derivative(double arg) {
double tanh = Math.tanh(arg);
return 1 - Math.pow(tanh, 2);
}
}
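As a quick numeric check of the activation functions in this package: Linear is the identity with derivative 1 everywhere; Sigmoid.function.calculate(0.0) = 0.5 and its derivative at 0 is 0.25; Tanh.function.calculate(0.0) = 0.0 and its derivative at 0 is 1.0.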

View File

@@ -16,6 +16,15 @@ public class AlphaBeta implements Policy {
private final ValidMoveGenerator validMoveGenerator = new ValidMoveGenerator(); private final ValidMoveGenerator validMoveGenerator = new ValidMoveGenerator();
private boolean logging = false;
public boolean isLogging() {
return logging;
}
public void setLogging(boolean logging) {
this.logging = logging;
}
private int lookAhead; private int lookAhead;
private int numStateEvaluations = 0; private int numStateEvaluations = 0;
@@ -182,4 +191,9 @@ public class AlphaBeta implements Policy {
// TODO Auto-generated method stub // TODO Auto-generated method stub
} }
@Override
public String getName() {
return "Alpha-Beta";
}
} }

View File

@@ -9,6 +9,15 @@ import net.woodyfolsom.msproj.Player;
import net.woodyfolsom.msproj.gui.Goban; import net.woodyfolsom.msproj.gui.Goban;
public class HumanGuiInput implements Policy { public class HumanGuiInput implements Policy {
private boolean logging;
public boolean isLogging() {
return logging;
}
public void setLogging(boolean logging) {
this.logging = logging;
}
private Goban goban; private Goban goban;
public HumanGuiInput(Goban goban) { public HumanGuiInput(Goban goban) {
@@ -52,4 +61,9 @@ public class HumanGuiInput implements Policy {
goban.setGameState(gameState); goban.setGameState(gameState);
} }
@Override
public String getName() {
return "HumanGUI";
}
} }

View File

@@ -9,6 +9,15 @@ import net.woodyfolsom.msproj.GameState;
import net.woodyfolsom.msproj.Player; import net.woodyfolsom.msproj.Player;
public class HumanKeyboardInput implements Policy { public class HumanKeyboardInput implements Policy {
private boolean logging = false;
public boolean isLogging() {
return logging;
}
public void setLogging(boolean logging) {
this.logging = logging;
}
@Override @Override
public Action getAction(GameConfig gameConfig, GameState gameState, public Action getAction(GameConfig gameConfig, GameState gameState,
@@ -76,4 +85,9 @@ public class HumanKeyboardInput implements Policy {
} }
@Override
public String getName() {
return "HumanKeyboard";
}
} }

View File

@@ -16,6 +16,15 @@ public class Minimax implements Policy {
private final ValidMoveGenerator validMoveGenerator = new ValidMoveGenerator(); private final ValidMoveGenerator validMoveGenerator = new ValidMoveGenerator();
private boolean logging = false;
public boolean isLogging() {
return logging;
}
public void setLogging(boolean logging) {
this.logging = logging;
}
private int lookAhead; private int lookAhead;
private int numStateEvaluations = 0; private int numStateEvaluations = 0;
@@ -152,7 +161,10 @@ public class Minimax implements Policy {
@Override @Override
public void setState(GameState gameState) { public void setState(GameState gameState) {
// TODO Auto-generated method stub }
@Override
public String getName() {
return "Minimax";
} }
} }

View File

@@ -15,6 +15,15 @@ import net.woodyfolsom.msproj.tree.MonteCarloProperties;
public abstract class MonteCarlo implements Policy { public abstract class MonteCarlo implements Policy {
protected static final int ROLLOUT_DEPTH_LIMIT = 250; protected static final int ROLLOUT_DEPTH_LIMIT = 250;
private boolean logging = false;
public boolean isLogging() {
return logging;
}
public void setLogging(boolean logging) {
this.logging = logging;
}
protected int numStateEvaluations = 0; protected int numStateEvaluations = 0;
protected Policy movePolicy; protected Policy movePolicy;

View File

@@ -63,6 +63,45 @@ public class MonteCarloAMAF extends MonteCarloUCT {
rootGameState, new AMAFProperties()); rootGameState, new AMAFProperties());
} }
@Override
public Action getBestAction(GameTreeNode<MonteCarloProperties> node) {
Action bestAction = Action.NONE;
double bestScore = Double.NEGATIVE_INFINITY;
GameTreeNode<MonteCarloProperties> bestChild = null;
for (Action action : node.getActions()) {
GameTreeNode<MonteCarloProperties> childNode = node
.getChild(action);
AMAFProperties childProps = (AMAFProperties)childNode.getProperties();
double childScore = childProps.getAmafWins() / (double)childProps.getAmafVisits();
if (childScore >= bestScore) {
bestScore = childScore;
bestAction = action;
bestChild = childNode;
}
}
if (bestAction == Action.NONE) {
System.out
.println(getName() + " failed - no actions were found for the current game state (not even PASS).");
} else {
if (isLogging()) {
System.out.println("Action " + bestAction + " selected for "
+ node.getGameState().getPlayerToMove()
+ " with simulated win ratio of "
+ (bestScore * 100.0 + "%"));
System.out.println("It was visited "
+ bestChild.getProperties().getVisits() + " times out of "
+ node.getProperties().getVisits() + " rollouts among "
+ node.getNumChildren()
+ " valid actions from the current state.");
}
}
return bestAction;
}
@Override @Override
protected double getNodeScore(GameTreeNode<MonteCarloProperties> gameTreeNode) { protected double getNodeScore(GameTreeNode<MonteCarloProperties> gameTreeNode) {
//double nodeVisits = gameTreeNode.getParent().getProperties().getVisits(); //double nodeVisits = gameTreeNode.getParent().getProperties().getVisits();
@@ -72,16 +111,8 @@ public class MonteCarloAMAF extends MonteCarloUCT {
if (gameTreeNode.getGameState().isTerminal()) { if (gameTreeNode.getGameState().isTerminal()) {
nodeScore = 0.0; nodeScore = 0.0;
} else { } else {
/*
MonteCarloProperties properties = gameTreeNode.getProperties();
nodeScore = (double) (properties.getWins() / properties
.getVisits())
+ (TUNING_CONSTANT * Math.sqrt(Math.log(nodeVisits)
/ gameTreeNode.getProperties().getVisits()));
*
*/
AMAFProperties properties = (AMAFProperties) gameTreeNode.getProperties(); AMAFProperties properties = (AMAFProperties) gameTreeNode.getProperties();
nodeScore = (double) (properties.getAmafWins() / properties nodeScore = (properties.getAmafWins() / (double) properties
.getAmafVisits()) .getAmafVisits())
+ (TUNING_CONSTANT * Math.sqrt(Math.log(parentAmafVisits) + (TUNING_CONSTANT * Math.sqrt(Math.log(parentAmafVisits)
/ properties.getAmafVisits())); / properties.getAmafVisits()));
@@ -103,4 +134,9 @@ public class MonteCarloAMAF extends MonteCarloUCT {
node.addChild(action, newChild); node.addChild(action, newChild);
return newChildren; return newChildren;
} }
@Override
public String getName() {
return "UCT-RAVE";
}
} }
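Written out, the selection score computed in getNodeScore above is the child's AMAF win rate plus a UCB-style exploration bonus: amafWins / amafVisits + TUNING_CONSTANT * sqrt(ln(parentAmafVisits) / amafVisits), so children with few AMAF visits still receive exploration credit.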

View File

@@ -0,0 +1,63 @@
package net.woodyfolsom.msproj.policy;
import java.util.List;
import net.woodyfolsom.msproj.Action;
import net.woodyfolsom.msproj.Player;
import net.woodyfolsom.msproj.tree.AMAFProperties;
import net.woodyfolsom.msproj.tree.GameTreeNode;
import net.woodyfolsom.msproj.tree.MonteCarloProperties;
public class MonteCarloSMAF extends MonteCarloAMAF {
private int horizon;
public MonteCarloSMAF(Policy movePolicy, long searchTimeLimit, int horizon) {
super(movePolicy, searchTimeLimit);
this.horizon = horizon;
}
@Override
public void update(GameTreeNode<MonteCarloProperties> node, Rollout rollout) {
GameTreeNode<MonteCarloProperties> currentNode = node;
//List<Action> subTreeActions = new ArrayList<Action>(rollout.getPlayout());
List<Action> playout = rollout.getPlayout();
int reward = rollout.getReward();
while (currentNode != null) {
AMAFProperties nodeProperties = (AMAFProperties)currentNode.getProperties();
//Always update props for the current node
nodeProperties.setWins(nodeProperties.getWins() + reward);
nodeProperties.setVisits(nodeProperties.getVisits() + 1);
nodeProperties.setAmafWins(nodeProperties.getAmafWins() + reward);
nodeProperties.setAmafVisits(nodeProperties.getAmafVisits() + 1);
GameTreeNode<MonteCarloProperties> parentNode = currentNode.getParent();
if (parentNode != null) {
Player playerToMove = parentNode.getGameState().getPlayerToMove();
for (Action actionFromParent : parentNode.getActions()) {
// only credit actions within the first 'horizon' moves of the playout; Math.min keeps the sublist within bounds
if (playout.subList(0, Math.min(horizon, playout.size())).contains(actionFromParent)) {
GameTreeNode<MonteCarloProperties> subTreeChild = parentNode.getChild(actionFromParent);
//Don't count AMAF properties for the current node twice
if (subTreeChild == currentNode) {
continue;
}
AMAFProperties siblingProperties = (AMAFProperties)subTreeChild.getProperties();
//Only update AMAF properties if the sibling is reached by the same action with the same player to move
if (rollout.hasPlay(playerToMove,actionFromParent)) {
siblingProperties.setAmafWins(siblingProperties.getAmafWins() + reward);
siblingProperties.setAmafVisits(siblingProperties.getAmafVisits() + 1);
}
}
}
}
currentNode = currentNode.getParent();
}
}
public String getName() {
return "MonteCarloSMAF";
}
}

View File

@@ -90,11 +90,8 @@ public class MonteCarloUCT extends MonteCarlo {
GameTreeNode<MonteCarloProperties> childNode = node GameTreeNode<MonteCarloProperties> childNode = node
.getChild(action); .getChild(action);
//MonteCarloProperties properties = childNode.getProperties(); MonteCarloProperties childProps = childNode.getProperties();
//double childScore = (double) properties.getWins() double childScore = childProps.getWins() / (double)childProps.getVisits();
// / properties.getVisits();
double childScore = getNodeScore(childNode);
if (childScore >= bestScore) { if (childScore >= bestScore) {
bestScore = childScore; bestScore = childScore;
@@ -105,8 +102,9 @@ public class MonteCarloUCT extends MonteCarlo {
if (bestAction == Action.NONE) { if (bestAction == Action.NONE) {
System.out System.out
.println("MonteCarloUCT failed - no actions were found for the current game state (not even PASS)."); .println(getName() + " failed - no actions were found for the current game state (not even PASS).");
} else { } else {
if (isLogging()) {
System.out.println("Action " + bestAction + " selected for " System.out.println("Action " + bestAction + " selected for "
+ node.getGameState().getPlayerToMove() + node.getGameState().getPlayerToMove()
+ " with simulated win ratio of " + " with simulated win ratio of "
@@ -116,6 +114,7 @@ public class MonteCarloUCT extends MonteCarlo {
+ node.getProperties().getVisits() + " rollouts among " + node.getProperties().getVisits() + " rollouts among "
+ node.getNumChildren() + node.getNumChildren()
+ " valid actions from the current state."); + " valid actions from the current state.");
}
} }
return bestAction; return bestAction;
} }
@@ -233,4 +232,9 @@ public class MonteCarloUCT extends MonteCarlo {
// TODO Auto-generated method stub // TODO Auto-generated method stub
} }
@Override
public String getName() {
return "MonteCarloUCT";
}
} }

View File

@@ -0,0 +1,144 @@
package net.woodyfolsom.msproj.policy;
import java.util.Collection;
import java.util.List;
import net.woodyfolsom.msproj.Action;
import net.woodyfolsom.msproj.GameConfig;
import net.woodyfolsom.msproj.GameState;
import net.woodyfolsom.msproj.Player;
import net.woodyfolsom.msproj.ann.FeedforwardNetwork;
import net.woodyfolsom.msproj.ann.NNDataPair;
import net.woodyfolsom.msproj.ann.PassFilterTrainer;
import net.woodyfolsom.msproj.tictactoe.NNDataSetFactory;
public class NeuralNetPolicy implements Policy {
private FeedforwardNetwork moveFilter;
private FeedforwardNetwork passFilter;
private ValidMoveGenerator validMoveGenerator = new ValidMoveGenerator();
private Policy randomMovePolicy = new RandomMovePolicy();
public NeuralNetPolicy() {
}
@Override
public Action getAction(GameConfig gameConfig, GameState gameState,
Player player) {
//If passFilter != null, check for a strong PASS signal.
if (passFilter != null) {
GameState stateAfterPass = new GameState(gameState);
if (!stateAfterPass.playStone(player, Action.PASS)) {
throw new RuntimeException("Pass should always be valid, but playStone(" + player +", Action.PASS) failed.");
}
NNDataPair passData = NNDataSetFactory.createDataPair(gameState,PassFilterTrainer.class);
double estimatedValue = passFilter.compute(passData).getValues()[0];
//if losing and opponent passed, never pass
//if (passData.getInput().getValues()[0] == -1.0 && passData.getInput().getValues()[1] == 1.0) {
// estimatedValue = 0.0;
//}
if (player == Player.BLACK && 0.6 < estimatedValue) {
//System.out.println("NeuralNetwork estimates value of PASS at > 0.95 (BLACK) for " + passData.getInput());
return Action.PASS;
}
if (player == Player.WHITE && 0.4 > estimatedValue) {
//System.out.println("NeuralNetwork estimates value of PASS at > 0.95 (BLACK) for " + passData.getInput());
return Action.PASS;
}
}
//If moveFilter != null, calculate action estimates and return the best one.
//max # valid moves is 19x19+2 (any coord plus pass, resign).
List<Action> validMoves = validMoveGenerator.getActions(gameConfig, gameState, player, 363);
if (moveFilter != null) {
if (player == Player.BLACK) {
double bestValue = Double.NEGATIVE_INFINITY;
Action bestAction = Action.NONE;
for (Action actionToTry : validMoves) {
GameState stateAfterAction = new GameState(gameState);
if (!stateAfterAction.playStone(player, actionToTry)) {
throw new RuntimeException("Invalid move: " + actionToTry);
}
NNDataPair passData = NNDataSetFactory.createDataPair(stateAfterAction,PassFilterTrainer.class);
double estimatedValue = passFilter.compute(passData).getValues()[0];
if (estimatedValue > bestValue) {
bestAction = actionToTry;
bestValue = estimatedValue; // remember the best score so the 0.95 cutoff below can fire
}
}
if (bestValue > 0.95) {
return bestAction;
}
} else if (player == Player.WHITE) {
double bestValue = Double.POSITIVE_INFINITY;
Action bestAction = Action.NONE;
for (Action actionToTry : validMoves) {
GameState stateAfterAction = new GameState(gameState);
if (!stateAfterAction.playStone(player, actionToTry)) {
throw new RuntimeException("Invalid move: " + actionToTry);
}
NNDataPair passData = NNDataSetFactory.createDataPair(stateAfterAction,PassFilterTrainer.class);
double estimatedValue = passFilter.compute(passData).getValues()[0];
if (estimatedValue < bestValue) { // a low estimated value favors WHITE
bestAction = actionToTry;
bestValue = estimatedValue;
}
}
if (bestValue < 0.05) { // mirrors BLACK's 0.95 cutoff
return bestAction;
}
} else {
throw new RuntimeException("Invalid player: " + player);
}
}
//If no moves make the cutoff, just return a random move.
return randomMovePolicy.getAction(gameConfig, gameState, player);
}
@Override
public Action getAction(GameConfig gameConfig, GameState gameState,
Collection<Action> prohibitedActions, Player player) {
throw new UnsupportedOperationException();
}
@Override
public int getNumStateEvaluations() {
return randomMovePolicy.getNumStateEvaluations();
}
@Override
public void setState(GameState gameState) {
randomMovePolicy.setState(gameState);
}
@Override
public boolean isLogging() {
return randomMovePolicy.isLogging();
}
@Override
public void setLogging(boolean logging) {
randomMovePolicy.setLogging(logging);
}
public void setMoveFilter(FeedforwardNetwork ffn) {
this.moveFilter = ffn;
}
public void setPassFilter(FeedforwardNetwork ffn) {
this.passFilter = ffn;
}
@Override
public String getName() {
return "NeuralNet" + (passFilter != null ? "-" + passFilter.getName() : "")
+ (moveFilter != null ? "-" + moveFilter.getName() : "");
}
}

View File

@@ -8,13 +8,19 @@ import net.woodyfolsom.msproj.GameState;
import net.woodyfolsom.msproj.Player; import net.woodyfolsom.msproj.Player;
public interface Policy { public interface Policy {
public Action getAction(GameConfig gameConfig, GameState gameState, Action getAction(GameConfig gameConfig, GameState gameState,
Player player); Player player);
public Action getAction(GameConfig gameConfig, GameState gameState, Action getAction(GameConfig gameConfig, GameState gameState,
Collection<Action> prohibitedActions, Player player); Collection<Action> prohibitedActions, Player player);
public int getNumStateEvaluations(); String getName();
public void setState(GameState gameState); int getNumStateEvaluations();
void setState(GameState gameState);
boolean isLogging();
void setLogging(boolean logging);
} }
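With the interface trimmed as above, the contract every policy satisfies can be illustrated by the minimal sketch below; PassOnlyPolicy is a hypothetical class used only to show the required methods and is not part of this changeset.

package net.woodyfolsom.msproj.policy;

import java.util.Collection;

import net.woodyfolsom.msproj.Action;
import net.woodyfolsom.msproj.GameConfig;
import net.woodyfolsom.msproj.GameState;
import net.woodyfolsom.msproj.Player;

public class PassOnlyPolicy implements Policy {
	private boolean logging = false;

	@Override
	public Action getAction(GameConfig gameConfig, GameState gameState, Player player) {
		return Action.PASS; // always passes; a real policy selects a move here
	}

	@Override
	public Action getAction(GameConfig gameConfig, GameState gameState,
			Collection<Action> prohibitedActions, Player player) {
		return Action.PASS; // ignores prohibitedActions in this sketch
	}

	@Override
	public String getName() {
		return "PassOnly";
	}

	@Override
	public int getNumStateEvaluations() {
		return 0; // never evaluates a game state
	}

	@Override
	public void setState(GameState gameState) {
		// stateless; nothing to do
	}

	@Override
	public boolean isLogging() {
		return logging;
	}

	@Override
	public void setLogging(boolean logging) {
		this.logging = logging;
	}
}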

View File

@@ -0,0 +1,14 @@
package net.woodyfolsom.msproj.policy;
public class PolicyFactory {
public static Policy createNew(Policy policyPrototype) {
if (policyPrototype instanceof RandomMovePolicy) {
return new RandomMovePolicy();
} else if (policyPrototype instanceof NeuralNetPolicy) {
return new NeuralNetPolicy();
} else {
throw new UnsupportedOperationException("Can only create new NeuralNetPolicy, not " + policyPrototype.getName());
}
}
}

View File

@@ -110,6 +110,7 @@ public class RandomMovePolicy implements Policy, ActionGenerator {
return randomAction; return randomAction;
} }
@Override
public boolean isLogging() { public boolean isLogging() {
return logging; return logging;
} }
@@ -122,4 +123,9 @@ public class RandomMovePolicy implements Policy, ActionGenerator {
public void setState(GameState gameState) { public void setState(GameState gameState) {
// TODO Auto-generated method stub // TODO Auto-generated method stub
} }
@Override
public String getName() {
return "Random";
}
} }

View File

@@ -0,0 +1,186 @@
package net.woodyfolsom.msproj.policy;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import net.woodyfolsom.msproj.Action;
import net.woodyfolsom.msproj.GameConfig;
import net.woodyfolsom.msproj.GameState;
import net.woodyfolsom.msproj.Player;
import net.woodyfolsom.msproj.tree.AMAFProperties;
import net.woodyfolsom.msproj.tree.MonteCarloProperties;
public class RootParAMAF implements Policy {
private boolean logging = false;
private int numTrees = 1;
private Policy rolloutPolicy;
public boolean isLogging() {
return logging;
}
public void setLogging(boolean logging) {
this.logging = logging;
}
private long timeLimit = 1000L;
public RootParAMAF(int numTrees, long timeLimit) {
this.numTrees = numTrees;
this.timeLimit = timeLimit;
this.rolloutPolicy = new RandomMovePolicy();
}
public RootParAMAF(int numTrees, Policy policyPrototype, long timeLimit) {
this.numTrees = numTrees;
this.timeLimit = timeLimit;
this.rolloutPolicy = policyPrototype;
}
@Override
public Action getAction(GameConfig gameConfig, GameState gameState,
Player player) {
Action bestAction = Action.NONE;
List<PolicyRunner> policyRunners = new ArrayList<PolicyRunner>();
List<Thread> simulationThreads = new ArrayList<Thread>();
for (int i = 0; i < numTrees; i++) {
MonteCarlo policy = new MonteCarloAMAF(
PolicyFactory.createNew(rolloutPolicy), timeLimit);
//policy.setLogging(true);
PolicyRunner policyRunner = new PolicyRunner(policy, gameConfig, gameState,
player);
policyRunners.add(policyRunner);
Thread simThread = new Thread(policyRunner);
simulationThreads.add(simThread);
}
for (Thread simThread : simulationThreads) {
simThread.start();
}
for (Thread simThread : simulationThreads) {
try {
simThread.join();
} catch (InterruptedException ie) {
System.out
.println("Interrupted while waiting for Monte Carlo simulations to finish.");
}
}
Map<Action,Integer> totalReward = new HashMap<Action,Integer>();
Map<Action,Integer> numSims = new HashMap<Action,Integer>();
for (PolicyRunner policyRunner : policyRunners) {
Map<Action, MonteCarloProperties> qValues = policyRunner.getQvalues();
for (Action action : qValues.keySet()) {
if (totalReward.containsKey(action)) {
totalReward.put(action, totalReward.get(action) + ((AMAFProperties)qValues.get(action)).getAmafWins());
} else {
totalReward.put(action, ((AMAFProperties)qValues.get(action)).getAmafWins());
}
if (numSims.containsKey(action)) {
numSims.put(action, numSims.get(action) + ((AMAFProperties)qValues.get(action)).getAmafVisits());
} else {
numSims.put(action, ((AMAFProperties)qValues.get(action)).getAmafVisits());
}
}
}
double bestValue = 0.0;
int totalRollouts = 0;
int bestWins = 0;
int bestSims = 0;
for (Action action : totalReward.keySet())
{
int totalWins = totalReward.get(action);
int totalSims = numSims.get(action);
totalRollouts += totalSims;
double value = ((double)totalWins) / ((double)totalSims);
if (bestAction.isNone() || bestValue < value) {
bestAction = action;
bestValue = value;
bestWins = totalWins;
bestSims = totalSims;
}
}
if(isLogging()) {
System.out.println("Action " + bestAction + " selected for "
+ player
+ " with simulated win ratio of "
+ (bestValue * 100.0 + "% among " + numTrees + " parallel simulations."));
System.out.println("It won "
+ bestWins + " out of " + bestSims
+ " rollouts among " + totalRollouts
+ " total rollouts (" + totalReward.size()
+ " possible moves evaluated) from the current state.");
}
return bestAction;
}
@Override
public Action getAction(GameConfig gameConfig, GameState gameState,
Collection<Action> prohibitedActions, Player player) {
throw new UnsupportedOperationException(
"Prohibited actions not supported by this class.");
}
@Override
public int getNumStateEvaluations() {
// TODO Auto-generated method stub
return 0;
}
class PolicyRunner implements Runnable {
Map<Action,MonteCarloProperties> qValues;
private GameConfig gameConfig;
private GameState gameState;
private Player player;
private MonteCarlo policy;
public PolicyRunner(MonteCarlo policy, GameConfig gameConfig,
GameState gameState, Player player) {
this.policy = policy;
this.gameConfig = gameConfig;
this.gameState = gameState;
this.player = player;
}
public Map<Action,MonteCarloProperties> getQvalues() {
return qValues;
}
@Override
public void run() {
qValues = policy.getQvalues(gameConfig, gameState, player);
}
}
@Override
public void setState(GameState gameState) {
}
@Override
public String getName() {
if ("Random".equals(rolloutPolicy.getName())) {
return "RootParAMAF";
} else {
return "RootParAMAF-" + rolloutPolicy.getName();
}
}
}

View File

@@ -0,0 +1,186 @@
package net.woodyfolsom.msproj.policy;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import net.woodyfolsom.msproj.Action;
import net.woodyfolsom.msproj.GameConfig;
import net.woodyfolsom.msproj.GameState;
import net.woodyfolsom.msproj.Player;
import net.woodyfolsom.msproj.tree.AMAFProperties;
import net.woodyfolsom.msproj.tree.MonteCarloProperties;
public class RootParSMAF implements Policy {
private boolean logging = false;
private int numTrees = 1;
private Policy rolloutPolicy;
public boolean isLogging() {
return logging;
}
public void setLogging(boolean logging) {
this.logging = logging;
}
private long timeLimit = 1000L;
public RootParSMAF(int numTrees, long timeLimit) {
this.numTrees = numTrees;
this.timeLimit = timeLimit;
this.rolloutPolicy = new RandomMovePolicy();
}
public RootParSMAF(int numTrees, Policy policyPrototype, long timeLimit) {
this.numTrees = numTrees;
this.timeLimit = timeLimit;
this.rolloutPolicy = policyPrototype;
}
@Override
public Action getAction(GameConfig gameConfig, GameState gameState,
Player player) {
Action bestAction = Action.NONE;
List<PolicyRunner> policyRunners = new ArrayList<PolicyRunner>();
List<Thread> simulationThreads = new ArrayList<Thread>();
for (int i = 0; i < numTrees; i++) {
MonteCarlo policy = new MonteCarloSMAF(
PolicyFactory.createNew(rolloutPolicy), timeLimit, 4);
//policy.setLogging(true);
PolicyRunner policyRunner = new PolicyRunner(policy, gameConfig, gameState,
player);
policyRunners.add(policyRunner);
Thread simThread = new Thread(policyRunner);
simulationThreads.add(simThread);
}
for (Thread simThread : simulationThreads) {
simThread.start();
}
for (Thread simThread : simulationThreads) {
try {
simThread.join();
} catch (InterruptedException ie) {
System.out
.println("Interrupted while waiting for Monte Carlo simulations to finish.");
}
}
Map<Action,Integer> totalReward = new HashMap<Action,Integer>();
Map<Action,Integer> numSims = new HashMap<Action,Integer>();
for (PolicyRunner policyRunner : policyRunners) {
Map<Action, MonteCarloProperties> qValues = policyRunner.getQvalues();
for (Action action : qValues.keySet()) {
if (totalReward.containsKey(action)) {
totalReward.put(action, totalReward.get(action) + ((AMAFProperties)qValues.get(action)).getAmafWins());
} else {
totalReward.put(action, ((AMAFProperties)qValues.get(action)).getAmafWins());
}
if (numSims.containsKey(action)) {
numSims.put(action, numSims.get(action) + ((AMAFProperties)qValues.get(action)).getAmafVisits());
} else {
numSims.put(action, ((AMAFProperties)qValues.get(action)).getAmafVisits());
}
}
}
double bestValue = 0.0;
int totalRollouts = 0;
int bestWins = 0;
int bestSims = 0;
for (Action action : totalReward.keySet())
{
int totalWins = totalReward.get(action);
int totalSims = numSims.get(action);
totalRollouts += totalSims;
double value = ((double)totalWins) / ((double)totalSims);
if (bestAction.isNone() || bestValue < value) {
bestAction = action;
bestValue = value;
bestWins = totalWins;
bestSims = totalSims;
}
}
if(isLogging()) {
System.out.println("Action " + bestAction + " selected for "
+ player
+ " with simulated win ratio of "
+ (bestValue * 100.0 + "% among " + numTrees + " parallel simulations."));
System.out.println("It won "
+ bestWins + " out of " + bestSims
+ " rollouts among " + totalRollouts
+ " total rollouts (" + totalReward.size()
+ " possible moves evaluated) from the current state.");
}
return bestAction;
}
@Override
public Action getAction(GameConfig gameConfig, GameState gameState,
Collection<Action> prohibitedActions, Player player) {
throw new UnsupportedOperationException(
"Prohibited actions not supported by this class.");
}
@Override
public int getNumStateEvaluations() {
// TODO Auto-generated method stub
return 0;
}
class PolicyRunner implements Runnable {
Map<Action,MonteCarloProperties> qValues;
private GameConfig gameConfig;
private GameState gameState;
private Player player;
private MonteCarlo policy;
public PolicyRunner(MonteCarlo policy, GameConfig gameConfig,
GameState gameState, Player player) {
this.policy = policy;
this.gameConfig = gameConfig;
this.gameState = gameState;
this.player = player;
}
public Map<Action,MonteCarloProperties> getQvalues() {
return qValues;
}
@Override
public void run() {
qValues = policy.getQvalues(gameConfig, gameState, player);
}
}
@Override
public void setState(GameState gameState) {
}
@Override
public String getName() {
if ("Random".equals(rolloutPolicy.getName())) {
return "RootParSMAF";
} else {
return "RootParSMAF-" + rolloutPolicy.getName();
}
}
}

View File

@@ -13,12 +13,30 @@ import net.woodyfolsom.msproj.Player;
 import net.woodyfolsom.msproj.tree.MonteCarloProperties;
 public class RootParallelization implements Policy {
+	private boolean logging = false;
 	private int numTrees = 1;
+	private Policy rolloutPolicy;
+	public boolean isLogging() {
+		return logging;
+	}
+	public void setLogging(boolean logging) {
+		this.logging = logging;
+	}
 	private long timeLimit = 1000L;
 	public RootParallelization(int numTrees, long timeLimit) {
 		this.numTrees = numTrees;
 		this.timeLimit = timeLimit;
+		this.rolloutPolicy = new RandomMovePolicy();
+	}
+	public RootParallelization(int numTrees, Policy policyPrototype, long timeLimit) {
+		this.numTrees = numTrees;
+		this.timeLimit = timeLimit;
+		this.rolloutPolicy = policyPrototype;
 	}
 	@Override
@@ -31,7 +49,7 @@ public class RootParallelization implements Policy {
 		for (int i = 0; i < numTrees; i++) {
 			PolicyRunner policyRunner = new PolicyRunner(new MonteCarloUCT(
-					new RandomMovePolicy(), timeLimit), gameConfig, gameState,
+					PolicyFactory.createNew(rolloutPolicy), timeLimit), gameConfig, gameState,
 					player);
 			policyRunners.add(policyRunner);
@@ -94,6 +112,7 @@ public class RootParallelization implements Policy {
 		}
+		if(isLogging()) {
 		System.out.println("Action " + bestAction + " selected for "
 				+ player
 				+ " with simulated win ratio of "
@@ -103,7 +122,7 @@ public class RootParallelization implements Policy {
 				+ " rollouts among " + totalRollouts
 				+ " total rollouts (" + totalReward.size()
 				+ " possible moves evaluated) from the current state.");
+		}
 		return bestAction;
 	}
@@ -148,7 +167,14 @@ public class RootParallelization implements Policy {
 	@Override
 	public void setState(GameState gameState) {
-		// TODO Auto-generated method stub
+	}
+	@Override
+	public String getName() {
+		if (rolloutPolicy.getName() == "Random") {
+			return "RootParallelization";
+		} else {
+			return "RootParallelization-" + rolloutPolicy.getName();
+		}
 	}
 }
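As a usage note, not part of the diff: a minimal sketch of driving the new constructor and logging switch. The three-argument getAction(GameConfig, GameState, Player) overload is assumed from the surrounding code, and gameConfig/gameState stand in for the caller's current game.

// Sketch only: four independent trees, each given a clone of the rollout
// policy via PolicyFactory.createNew and sharing a 1000 ms per-move budget.
RootParallelization policy = new RootParallelization(4, new RandomMovePolicy(), 1000L);
policy.setLogging(true);
Action bestMove = policy.getAction(gameConfig, gameState, Player.BLACK);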

View File

@@ -0,0 +1,54 @@
package net.woodyfolsom.msproj.tictactoe;
import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;
public class Action {
public static final Action NONE = new Action(PLAYER.NONE, -1, -1);
private Game.PLAYER player;
private int row;
private int column;
public static Action getInstance(PLAYER player, int row, int column) {
return new Action(player,row,column);
}
private Action(PLAYER player, int row, int column) {
this.player = player;
this.row = row;
this.column = column;
}
public Game.PLAYER getPlayer() {
return player;
}
public int getColumn() {
return column;
}
public int getRow() {
return row;
}
public boolean isNone() {
return this == Action.NONE;
}
public void setPlayer(Game.PLAYER player) {
this.player = player;
}
public void setRow(int row) {
this.row = row;
}
public void setColumn(int column) {
this.column = column;
}
@Override
public String toString() {
return player + "(" + row + ", " + column + ")";
}
}

View File

@@ -0,0 +1,5 @@
package net.woodyfolsom.msproj.tictactoe;
public class Game {
public enum PLAYER {X,O,NONE}
}

View File

@@ -0,0 +1,63 @@
package net.woodyfolsom.msproj.tictactoe;
import java.util.ArrayList;
import java.util.List;
import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;
public class GameRecord {
public enum RESULT {X_WINS, O_WINS, TIE_GAME, IN_PROGRESS}
private List<Action> actions = new ArrayList<Action>();
private List<State> states = new ArrayList<State>();
private RESULT result = RESULT.IN_PROGRESS;
public GameRecord() {
actions.add(Action.NONE);
states.add(new State());
}
public void addState(State state) {
states.add(state);
}
public State apply(Action action) {
State nextState = getState().apply(action);
if (nextState.isValid()) {
states.add(nextState);
actions.add(action);
}
if (nextState.isTerminal()) {
if (nextState.isWinner(PLAYER.X)) {
result = RESULT.X_WINS;
} else if (nextState.isWinner(PLAYER.O)) {
result = RESULT.O_WINS;
} else {
result = RESULT.TIE_GAME;
}
}
return nextState;
}
public int getNumStates() {
return states.size();
}
public RESULT getResult() {
return result;
}
public void setResult(RESULT result) {
this.result = result;
}
public State getState() {
return states.get(states.size()-1);
}
public State getState(int index) {
return states.get(index);
}
}

View File

@@ -0,0 +1,20 @@
package net.woodyfolsom.msproj.tictactoe;
import java.util.ArrayList;
import java.util.List;
import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;
public class MoveGenerator {
public List<Action> getValidActions(State state) {
PLAYER playerToMove = state.getPlayerToMove();
List<Action> validActions = new ArrayList<Action>();
for (int i = 0; i < 3; i++) {
for (int j = 0; j < 3; j++) {
if (state.isEmpty(i,j))
validActions.add(Action.getInstance(playerToMove, i, j));
}
}
return validActions;
}
}

View File

@@ -0,0 +1,211 @@
package net.woodyfolsom.msproj.tictactoe;
import java.util.ArrayList;
import java.util.List;
import net.woodyfolsom.msproj.GameBoard;
import net.woodyfolsom.msproj.GameResult;
import net.woodyfolsom.msproj.GameState;
import net.woodyfolsom.msproj.Player;
import net.woodyfolsom.msproj.ann.FusekiFilterTrainer;
import net.woodyfolsom.msproj.ann.NNData;
import net.woodyfolsom.msproj.ann.NNDataPair;
import net.woodyfolsom.msproj.ann.PassFilterTrainer;
import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;
public class NNDataSetFactory {
public static final String[] TTT_INPUT_FIELDS = {"00","01","02","10","11","12","20","21","22"};
public static final String[] TTT_OUTPUT_FIELDS = {"VALUE"};
public static final String[] PASS_INPUT_FIELDS = {"WINNING","PREV_PLY_PASS"};
public static final String[] PASS_OUTPUT_FIELDS = {"SHOULD_PASS"};
public static final String[] FUSEKI_INPUT_FIELDS = {
"00","11","22","33","44","55","66","77","88",
"10","11","22","33","44","55","66","77","88",
"20","11","22","33","44","55","66","77","88",
"30","11","22","33","44","55","66","77","88",
"40","11","22","33","44","55","66","77","88",
"50","11","22","33","44","55","66","77","88",
"60","11","22","33","44","55","66","77","88",
"70","11","22","33","44","55","66","77","88",
"70","11","22","33","44","55","66","77","88"};
public static final String[] FUSEKI_OUTPUT_FIELDS = {"VALUE"};
public static List<List<NNDataPair>> createDataSet(List<GameRecord> tttGames) {
List<List<NNDataPair>> nnDataSet = new ArrayList<List<NNDataPair>>();
for (GameRecord tttGame : tttGames) {
List<NNDataPair> gameData = createDataPairList(tttGame);
nnDataSet.add(gameData);
}
return nnDataSet;
}
public static String[] getInputFields(Object clazz) {
if (clazz == PassFilterTrainer.class) {
return PASS_INPUT_FIELDS;
} else if (clazz == FusekiFilterTrainer.class) {
return FUSEKI_INPUT_FIELDS;
} else {
throw new RuntimeException("Don't know how to return inputFields for NeuralNetwork Trainer of type: " + clazz.getClass().getName());
}
}
public static String[] getOutputFields(Object clazz) {
if (clazz == PassFilterTrainer.class) {
return PASS_OUTPUT_FIELDS;
} else if (clazz == FusekiFilterTrainer.class) {
return FUSEKI_OUTPUT_FIELDS;
} else {
throw new RuntimeException("Don't know how to return inputFields for NeuralNetwork Trainer of type: " + clazz.getClass().getName());
}
}
public static List<NNDataPair> createDataPairList(GameRecord gameRecord) {
List<NNDataPair> gameData = new ArrayList<NNDataPair>();
for (int i = 0; i < gameRecord.getNumStates(); i++) {
gameData.add(createDataPair(gameRecord.getState(i)));
}
return gameData;
}
public static NNDataPair createDataPair(GameState goState, Object clazz) {
if (clazz == PassFilterTrainer.class) {
return createPassFilterDataPair(goState);
} else if (clazz == FusekiFilterTrainer.class) {
return createFusekiFilterDataPair(goState);
} else {
throw new RuntimeException("Don't know how to create DataPair for NeuralNetwork Trainer of type: " + clazz.getClass().getName());
}
}
private static NNDataPair createFusekiFilterDataPair(GameState goState) {
double value;
if (goState.isTerminal()) {
if (goState.getResult().isWinner(Player.BLACK)) {
value = 1.0; // win for black
} else if (goState.getResult().isWinner(Player.WHITE)) {
value = 0.0; // loss for black
//value = -1.0;
} else {// tie
value = 0.5;
//value = 0.0; //tie
}
} else {
value = 0.0;
}
int size = goState.getGameConfig().getSize();
double[] inputValues = new double[size * size];
for (int i = 0; i < size; i++) {
for (int j = 0; j < size; j++) {
//col then row
char symbol = goState.getGameBoard().getSymbolAt(j, i);
switch (symbol) {
case GameBoard.EMPTY_INTERSECTION : inputValues[i*size+j] = 0.0;
break;
case GameBoard.BLACK_STONE : inputValues[i*size+j] = 1.0;
break;
case GameBoard.WHITE_STONE : inputValues[i*size+j] = -1.0;
break;
}
}
}
return new NNDataPair(new NNData(FUSEKI_INPUT_FIELDS,inputValues),new NNData(FUSEKI_OUTPUT_FIELDS,new double[]{value}));
}
private static NNDataPair createPassFilterDataPair(GameState goState) {
double value;
GameResult result = goState.getResult();
if (goState.isTerminal()) {
if (result.isWinner(Player.BLACK)) {
value = 1.0; // win for black
} else if (result.isWinner(Player.WHITE)) {
value = 0.0; // loss for black
} else {// tie
value = 0.5;
//value = 0.0; //tie
}
} else {
value = 0.0;
}
double[] inputValues = new double[2];
inputValues[0] = result.isWinner(goState.getPlayerToMove()) ? 1.0 : -1.0;
//inputValues[1] = result.isWinner(goState.getPlayerToMove()) ? -1.0 : 1.0;
inputValues[1] = goState.isPrevPlyPass() ? 1.0 : 0.0;
return new NNDataPair(new NNData(PASS_INPUT_FIELDS,inputValues),new NNData(PASS_OUTPUT_FIELDS,new double[]{value}));
}
/*
private static double getNormalizedScore(GameState goState, Player player) {
GameResult gameResult = goState.getResult();
GameConfig gameConfig = goState.getGameConfig();
double maxPoints = Math.pow(gameConfig.getSize(),2);
double komi = gameConfig.getKomi();
if (player == Player.BLACK) {
return gameResult.getBlackScore() / maxPoints;
} else if (player == Player.WHITE) {
return gameResult.getWhiteScore() / (maxPoints + komi);
} else {
throw new RuntimeException("Invalid player");
}
}*/
public static NNDataPair createDataPair(State tttState) {
double value;
if (tttState.isTerminal()) {
if (tttState.isWinner(PLAYER.X)) {
value = 1.0; // win for X
} else if (tttState.isWinner(PLAYER.O)) {
value = 0.0; // loss for X
//value = -1.0;
} else {
value = 0.5;
//value = 0.0; //tie
}
} else {
value = 0.0;
}
double[] inputValues = new double[9];
char[] boardCopy = tttState.getBoard();
inputValues[0] = getTicTacToeInput(boardCopy, 0, 0);
inputValues[1] = getTicTacToeInput(boardCopy, 0, 1);
inputValues[2] = getTicTacToeInput(boardCopy, 0, 2);
inputValues[3] = getTicTacToeInput(boardCopy, 1, 0);
inputValues[4] = getTicTacToeInput(boardCopy, 1, 1);
inputValues[5] = getTicTacToeInput(boardCopy, 1, 2);
inputValues[6] = getTicTacToeInput(boardCopy, 2, 0);
inputValues[7] = getTicTacToeInput(boardCopy, 2, 1);
inputValues[8] = getTicTacToeInput(boardCopy, 2, 2);
return new NNDataPair(new NNData(TTT_INPUT_FIELDS,inputValues),new NNData(TTT_OUTPUT_FIELDS,new double[]{value}));
}
private static double getTicTacToeInput(char[] board, int row, int column) {
switch (board[row*3+column]) {
case 'X' :
return 1.0;
case 'O' :
return -1.0;
case '.' :
return 0.0;
default:
throw new RuntimeException("Invalid board symbol at " + row +", " + column);
}
}
}
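A small, self-contained sketch of feeding self-play tic-tac-toe games through the factory above (the DataSetSketch class name is illustrative):

import java.util.List;
import net.woodyfolsom.msproj.ann.NNDataPair;
import net.woodyfolsom.msproj.tictactoe.GameRecord;
import net.woodyfolsom.msproj.tictactoe.NNDataSetFactory;
import net.woodyfolsom.msproj.tictactoe.Referee;

public class DataSetSketch {
	public static void main(String[] args) {
		// Play ten random self-play games and encode every position for training.
		List<GameRecord> games = new Referee().play(10);
		List<List<NNDataPair>> dataSet = NNDataSetFactory.createDataSet(games);
		// The ideal value of the final position is 1.0 (X wins), 0.0 (O wins) or 0.5 (tie).
		for (List<NNDataPair> gameData : dataSet) {
			NNDataPair last = gameData.get(gameData.size() - 1);
			System.out.println(last.getIdeal().getValues()[0]);
		}
	}
}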

View File

@@ -0,0 +1,68 @@
package net.woodyfolsom.msproj.tictactoe;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import net.woodyfolsom.msproj.ann.FeedforwardNetwork;
import net.woodyfolsom.msproj.ann.NNDataPair;
import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;
public class NeuralNetPolicy extends Policy {
private FeedforwardNetwork neuralNet;
private MoveGenerator moveGenerator = new MoveGenerator();
public NeuralNetPolicy(FeedforwardNetwork neuralNet) {
super("NeuralNet-" + neuralNet.getName());
this.neuralNet = neuralNet;
}
@Override
public Action getAction(State state) {
List<Action> validMoves = moveGenerator.getValidActions(state);
Map<Action, Double> scores = new HashMap<Action, Double>();
for (Action action : validMoves) {
State nextState = state.apply(action);
//NNDataPair dataPair = NNDataSetFactory.createDataPair(state);
NNDataPair dataPair = NNDataSetFactory.createDataPair(nextState);
//estimated reward for X
scores.put(action, neuralNet.compute(dataPair).getValues()[0]);
}
PLAYER playerToMove = state.getPlayerToMove();
if (playerToMove == PLAYER.X) {
return returnMaxAction(scores);
} else if (playerToMove == PLAYER.O) {
return returnMinAction(scores);
} else {
throw new IllegalArgumentException("Invalid playerToMove: " + playerToMove);
}
//return validMoves.get((int)(Math.random() * validMoves.size()));
}
private Action returnMaxAction(Map<Action,Double> scores) {
Action bestAction = null;
Double bestScore = Double.NEGATIVE_INFINITY;
for (Map.Entry<Action,Double> entry : scores.entrySet()) {
if (entry.getValue() > bestScore) {
bestScore = entry.getValue();
bestAction = entry.getKey();
}
}
return bestAction;
}
private Action returnMinAction(Map<Action,Double> scores) {
Action bestAction = null;
Double bestScore = Double.POSITIVE_INFINITY;
for (Map.Entry<Action,Double> entry : scores.entrySet()) {
if (entry.getValue() < bestScore) {
bestScore = entry.getValue();
bestAction = entry.getKey();
}
}
return bestAction;
}
}

View File

@@ -0,0 +1,15 @@
package net.woodyfolsom.msproj.tictactoe;
public abstract class Policy {
private String name;
protected Policy(String name) {
this.name = name;
}
public abstract Action getAction(State state);
public String getName() {
return name;
}
}

View File

@@ -0,0 +1,18 @@
package net.woodyfolsom.msproj.tictactoe;
import java.util.List;
public class RandomPolicy extends Policy {
private MoveGenerator moveGenerator = new MoveGenerator();
public RandomPolicy() {
super("Random");
}
@Override
public Action getAction(State state) {
List<Action> validMoves = moveGenerator.getValidActions(state);
return validMoves.get((int)(Math.random() * validMoves.size()));
}
}

View File

@@ -0,0 +1,43 @@
package net.woodyfolsom.msproj.tictactoe;
import java.util.ArrayList;
import java.util.List;
public class Referee {
public static void main(String[] args) {
new Referee().play(50);
}
public List<GameRecord> play(int nGames) {
Policy policy = new RandomPolicy();
List<GameRecord> tournament = new ArrayList<GameRecord>();
for (int i = 0; i < nGames; i++) {
GameRecord gameRecord = new GameRecord();
System.out.println("Playing game #" +(i+1));
State state;
do {
Action action = policy.getAction(gameRecord.getState());
System.out.println("Action " + action + " selected by policy " + policy.getName());
state = gameRecord.apply(action);
System.out.println("Next board state:");
System.out.println(gameRecord.getState());
} while (!state.isTerminal());
System.out.println("Game #" + (i+1) + " is finished. Result: " + gameRecord.getResult());
tournament.add(gameRecord);
}
System.out.println("Played " + tournament.size() + " random games.");
System.out.println("Results:");
for (int i = 0; i < tournament.size(); i++) {
GameRecord gameRecord = tournament.get(i);
System.out.println((i+1) + ". " + gameRecord.getResult());
}
return tournament;
}
}

View File

@@ -0,0 +1,116 @@
package net.woodyfolsom.msproj.tictactoe;
import java.util.Arrays;
import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;
public class State {
public static final State INVALID = new State();
public static char EMPTY_SQUARE = '.';
private char[] board;
private PLAYER playerToMove;
public State() {
playerToMove = Game.PLAYER.X;
board = new char[9];
Arrays.fill(board,'.');
}
private State(State that) {
this.board = Arrays.copyOf(that.board, that.board.length);
this.playerToMove = that.playerToMove;
}
public State apply(Action action) {
if (action.getPlayer() != playerToMove) {
System.out.println("It is not " + action.getPlayer() +"'s turn.");
return State.INVALID;
}
State nextState = new State(this);
int row = action.getRow();
int column = action.getColumn();
int dest = row * 3 + column;
if (board[dest] != EMPTY_SQUARE) {
System.out.println("Invalid move " + action + ", coordinate not empty.");
return State.INVALID;
}
switch (playerToMove) {
case X : nextState.board[dest] = 'X';
break;
case O : nextState.board[dest] = 'O';
break;
default:
throw new RuntimeException("Invalid playerToMove");
}
if (playerToMove == PLAYER.X) {
nextState.playerToMove = PLAYER.O;
} else {
nextState.playerToMove = PLAYER.X;
}
return nextState;
}
public char[] getBoard() {
return Arrays.copyOf(board, board.length);
}
public PLAYER getPlayerToMove() {
return playerToMove;
}
public boolean isEmpty(int row, int column) {
return board[row*3+column] == EMPTY_SQUARE;
}
public boolean isFull(char mark1, char mark2, char mark3) {
return mark1 != '.' && mark2 != '.' && mark3 != '.';
}
public boolean isWinner(PLAYER player) {
return isWin(player,board[0],board[1],board[2]) ||
isWin(player,board[3],board[4],board[5]) ||
isWin(player,board[6],board[7],board[8]) ||
isWin(player,board[0],board[3],board[6]) ||
isWin(player,board[1],board[4],board[7]) ||
isWin(player,board[2],board[5],board[8]) ||
isWin(player,board[0],board[4],board[8]) ||
isWin(player,board[2],board[4],board[6]);
}
public boolean isWin(PLAYER player, char mark1, char mark2, char mark3) {
if (isFull(mark1,mark2,mark3)) {
switch (player) {
case X : return mark1 == 'X' && mark2 == 'X' && mark3 == 'X';
case O : return mark1 == 'O' && mark2 == 'O' && mark3 == 'O';
default :
return false;
}
} else {
return false;
}
}
public boolean isTerminal() {
return isWinner(PLAYER.X) || isWinner(PLAYER.O) ||
(isFull(board[0],board[1], board[2]) &&
isFull(board[3],board[4], board[5]) &&
isFull(board[6],board[7], board[8]));
}
public boolean isValid() {
return this != INVALID;
}
@Override
public String toString() {
StringBuilder sb = new StringBuilder("TicTacToe state ("+playerToMove + " to move):\n");
sb.append(""+board[0] + board[1] + board[2] + "\n");
sb.append(""+board[3] + board[4] + board[5] + "\n");
sb.append(""+board[6] + board[7] + board[8] + "\n");
return sb.toString();
}
}

View File

@@ -0,0 +1,95 @@
package net.woodyfolsom.msproj.ann;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import javax.xml.bind.JAXBException;
import net.woodyfolsom.msproj.ann.Connection;
import net.woodyfolsom.msproj.ann.FeedforwardNetwork;
import net.woodyfolsom.msproj.ann.MultiLayerPerceptron;
import net.woodyfolsom.msproj.ann.NNData;
import net.woodyfolsom.msproj.ann.NNDataPair;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
public class MultiLayerPerceptronTest {
static final File TEST_FILE = new File("data/test/mlp.net");
static final double EPS = 0.001;
@BeforeClass
public static void setUp() {
if (TEST_FILE.exists()) {
TEST_FILE.delete();
}
}
@AfterClass
public static void tearDown() {
if (TEST_FILE.exists()) {
TEST_FILE.delete();
}
}
@Test
public void testConstructor() {
new MultiLayerPerceptron(true, 2, 4, 1);
new MultiLayerPerceptron(false, 2, 1);
}
@Test(expected = IllegalArgumentException.class)
public void testConstructorTooFewLayers() {
new MultiLayerPerceptron(true, 2);
}
@Test(expected = IllegalArgumentException.class)
public void testConstructorTooFewNeurons() {
new MultiLayerPerceptron(true, 2, 4, 0, 1);
}
@Test
public void testPersistence() throws JAXBException, IOException {
FeedforwardNetwork mlp = new MultiLayerPerceptron(true, 2, 4, 1);
FileOutputStream fos = new FileOutputStream(TEST_FILE);
assertTrue(mlp.save(fos));
fos.close();
FileInputStream fis = new FileInputStream(TEST_FILE);
FeedforwardNetwork mlp2 = new MultiLayerPerceptron();
assertTrue(mlp2.load(fis));
assertEquals(mlp, mlp2);
fis.close();
}
@Test
public void testCompute() {
FeedforwardNetwork mlp = new MultiLayerPerceptron(true, 2, 2, 1);
NNDataPair expected = new NNDataPair(new NNData(new String[]{"x","y"}, new double[]{0.0,0.0}),new NNData(new String[]{"xor"}, new double[]{0.5}));
NNDataPair actual = new NNDataPair(new NNData(new String[]{"x","y"}, new double[]{0.0,0.0}),new NNData(new String[]{"xor"}, new double[]{0.5}));
NNData actualOutput = mlp.compute(actual);
assertEquals(expected.getIdeal().getValues()[0], actualOutput.getValues()[0], EPS);
}
@Test
public void testXORnetwork() {
FeedforwardNetwork mlp = new MultiLayerPerceptron(true, 2, 2, 1);
mlp.setWeights(new double[] {
0.341232, 0.129952, -0.923123, //hidden neuron 1 from input0, input1, bias
-0.115223, 0.570345, -0.328932, //hidden neuron 2 from input0, input1, bias
-0.993423, 0.164732, 0.752621}); //output
for (Connection connection : mlp.getConnections()) {
System.out.println(connection);
}
NNDataPair expected = new NNDataPair(new NNData(new String[]{"x","y"}, new double[]{0.0,0.0}),new NNData(new String[]{"xor"}, new double[]{0.263932}));
NNDataPair actual = new NNDataPair(new NNData(new String[]{"x","y"}, new double[]{0.0,0.0}),new NNData(new String[]{"xor"}, new double[]{0.0}));
NNData actualOutput = mlp.compute(actual);
assertEquals(expected.getIdeal().getValues()[0], actualOutput.getValues()[0], EPS);
}
}

View File

@@ -1,66 +0,0 @@
package net.woodyfolsom.msproj.ann;
import java.io.File;
import java.io.FileFilter;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import net.woodyfolsom.msproj.GameRecord;
import net.woodyfolsom.msproj.Referee;
import org.antlr.runtime.RecognitionException;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.junit.Test;
public class WinFilterTest {
@Test
public void testLearnSaveLoad() throws IOException, RecognitionException {
File[] sgfFiles = new File("data/games/random_vs_random")
.listFiles(new FileFilter() {
@Override
public boolean accept(File pathname) {
return pathname.getName().endsWith(".sgf");
}
});
Set<List<MLDataPair>> trainingData = new HashSet<List<MLDataPair>>();
for (File file : sgfFiles) {
FileInputStream fis = new FileInputStream(file);
GameRecord gameRecord = Referee.replay(fis);
List<MLDataPair> gameData = new ArrayList<MLDataPair>();
for (int i = 0; i <= gameRecord.getNumTurns(); i++) {
gameData.add(new GameStateMLDataPair(gameRecord.getGameState(i)));
}
trainingData.add(gameData);
fis.close();
}
WinFilter winFilter = new WinFilter();
winFilter.learn(trainingData);
for (List<MLDataPair> trainingSequence : trainingData) {
//for (MLDataPair mlDataPair : trainingSequence) {
for (int stateIndex = 0; stateIndex < trainingSequence.size(); stateIndex++) {
if (stateIndex > 0 && stateIndex < trainingSequence.size()-1) {
continue;
}
MLData input = trainingSequence.get(stateIndex).getInput();
System.out.println("Turn " + stateIndex + ": " + input + " => "
+ winFilter.computeValue(input));
}
//}
}
}
}

View File

@@ -1,10 +1,19 @@
 package net.woodyfolsom.msproj.ann;
-import java.io.File;
-import java.io.IOException;
-import org.encog.ml.data.MLDataSet;
-import org.encog.ml.data.basic.BasicMLDataSet;
+import static org.junit.Assert.assertTrue;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import net.woodyfolsom.msproj.ann.NNData;
+import net.woodyfolsom.msproj.ann.NNDataPair;
+import net.woodyfolsom.msproj.ann.NeuralNetFilter;
+import net.woodyfolsom.msproj.ann.XORFilter;
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
 import org.junit.Test;
@@ -29,10 +38,51 @@ public class XORFilterTest {
 	}
 	@Test
-	public void testLearnSaveLoad() throws IOException {
-		NeuralNetFilter nnLearner = new XORFilter();
+	public void testLearn() throws IOException {
+		NeuralNetFilter nnLearner = new XORFilter(0.5,0.0);
+		// create training set (logical XOR function)
+		int size = 1;
+		double[][] trainingInput = new double[4 * size][];
+		double[][] trainingOutput = new double[4 * size][];
+		for (int i = 0; i < size; i++) {
+			trainingInput[i * 4 + 0] = new double[] { 0, 0 };
+			trainingInput[i * 4 + 1] = new double[] { 0, 1 };
+			trainingInput[i * 4 + 2] = new double[] { 1, 0 };
+			trainingInput[i * 4 + 3] = new double[] { 1, 1 };
+			trainingOutput[i * 4 + 0] = new double[] { 0 };
+			trainingOutput[i * 4 + 1] = new double[] { 1 };
+			trainingOutput[i * 4 + 2] = new double[] { 1 };
+			trainingOutput[i * 4 + 3] = new double[] { 0 };
+		}
+		// create training data
+		List<NNDataPair> trainingSet = new ArrayList<NNDataPair>();
+		String[] inputNames = new String[] {"x","y"};
+		String[] outputNames = new String[] {"XOR"};
+		for (int i = 0; i < 4*size; i++) {
+			trainingSet.add(new NNDataPair(new NNData(inputNames,trainingInput[i]),new NNData(outputNames,trainingOutput[i])));
+		}
+		nnLearner.setMaxTrainingEpochs(20000);
+		nnLearner.learnPatterns(trainingSet);
 		System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");
+		double[][] validationSet = new double[4][2];
+		validationSet[0] = new double[] { 0, 0 };
+		validationSet[1] = new double[] { 0, 1 };
+		validationSet[2] = new double[] { 1, 0 };
+		validationSet[3] = new double[] { 1, 1 };
+		System.out.println("Output from eval set (learned network):");
+		testNetwork(nnLearner, validationSet, inputNames, outputNames);
+	}
+	@Test
+	public void testLearnSaveLoad() throws IOException {
+		NeuralNetFilter nnLearner = new XORFilter(0.05,0.0);
 		// create training set (logical XOR function)
 		int size = 1;
 		double[][] trainingInput = new double[4 * size][];
@@ -49,9 +99,16 @@ public class XORFilterTest {
 		}
 		// create training data
-		MLDataSet trainingSet = new BasicMLDataSet(trainingInput, trainingOutput);
+		List<NNDataPair> trainingSet = new ArrayList<NNDataPair>();
+		String[] inputNames = new String[] {"x","y"};
+		String[] outputNames = new String[] {"XOR"};
+		for (int i = 0; i < 4*size; i++) {
+			trainingSet.add(new NNDataPair(new NNData(inputNames,trainingInput[i]),new NNData(outputNames,trainingOutput[i])));
+		}
-		nnLearner.learn(trainingSet);
+		nnLearner.setMaxTrainingEpochs(10000);
+		nnLearner.learnPatterns(trainingSet);
+		System.out.println("Learned network after " + nnLearner.getActualTrainingEpochs() + " training epochs.");
 		double[][] validationSet = new double[4][2];
@@ -61,19 +118,24 @@ public class XORFilterTest {
 		validationSet[3] = new double[] { 1, 1 };
 		System.out.println("Output from eval set (learned network, pre-serialization):");
-		testNetwork(nnLearner, validationSet);
+		testNetwork(nnLearner, validationSet, inputNames, outputNames);
-		nnLearner.save(FILENAME);
-		nnLearner.load(FILENAME);
+		FileOutputStream fos = new FileOutputStream(FILENAME);
+		assertTrue(nnLearner.save(fos));
+		fos.close();
+		FileInputStream fis = new FileInputStream(FILENAME);
+		assertTrue(nnLearner.load(fis));
+		fis.close();
 		System.out.println("Output from eval set (learned network, post-serialization):");
-		testNetwork(nnLearner, validationSet);
+		testNetwork(nnLearner, validationSet, inputNames, outputNames);
 	}
-	private void testNetwork(NeuralNetFilter nnLearner, double[][] validationSet) {
+	private void testNetwork(NeuralNetFilter nnLearner, double[][] validationSet, String[] inputNames, String[] outputNames) {
 		for (int valIndex = 0; valIndex < validationSet.length; valIndex++) {
-			DoublePair dp = new DoublePair(validationSet[valIndex][0],validationSet[valIndex][1]);
+			NNDataPair dp = new NNDataPair(new NNData(inputNames,validationSet[valIndex]), new NNData(outputNames,validationSet[valIndex]));
-			System.out.println(dp + " => " + nnLearner.computeValue(dp));
+			System.out.println(dp + " => " + nnLearner.compute(dp));
 		}
 	}
 }

View File

@@ -0,0 +1,29 @@
package net.woodyfolsom.msproj.ann.math;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import net.woodyfolsom.msproj.ann.math.ActivationFunction;
import net.woodyfolsom.msproj.ann.math.Sigmoid;
import net.woodyfolsom.msproj.ann.math.Tanh;
import org.junit.Test;
public class SigmoidTest {
static double EPS = 0.001;
@Test
public void testCalculate() {
ActivationFunction sigmoid = Sigmoid.function;
assertEquals(0.5,sigmoid.calculate(0.0),EPS);
assertTrue(sigmoid.calculate(100.0) > 1.0 - EPS);
assertTrue(sigmoid.calculate(-9000.0) < EPS);
}
@Test
public void testDerivative() {
ActivationFunction sigmoid = Sigmoid.function;
// sigmoid'(0) = sigmoid(0) * (1 - sigmoid(0)) = 0.25, assuming derivative() is taken with respect to the raw input
assertEquals(0.25, sigmoid.derivative(0.0), EPS);
}
}

View File

@@ -0,0 +1,28 @@
package net.woodyfolsom.msproj.ann.math;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import net.woodyfolsom.msproj.ann.math.ActivationFunction;
import net.woodyfolsom.msproj.ann.math.Tanh;
import org.junit.Test;
public class TanhTest {
static double EPS = 0.001;
@Test
public void testCalculate() {
ActivationFunction tanh = new Tanh();
assertEquals(0.0,tanh.calculate(0.0),EPS);
assertTrue(tanh.calculate(100.0) > 0.5 - EPS);
assertTrue(tanh.calculate(-9000.0) < -0.5 + EPS);
}
@Test
public void testDerivative() {
ActivationFunction tanh = new Tanh();
assertEquals(1.0,tanh.derivative(0.0), EPS);
}
}
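As a sanity check on the values asserted in these two tests, a finite-difference sketch (assuming calculate evaluates the activation at a raw input) agrees with sigmoid'(0) = 0.25 and tanh'(0) = 1:

// Central differences: ~0.25 for the logistic function at zero, ~1.0 for tanh at zero.
ActivationFunction sigmoid = Sigmoid.function;
ActivationFunction tanh = new Tanh();
double h = 1e-6;
double sigmoidSlope = (sigmoid.calculate(h) - sigmoid.calculate(-h)) / (2 * h);
double tanhSlope = (tanh.calculate(h) - tanh.calculate(-h)) / (2 * h);
System.out.println(sigmoidSlope + " " + tanhSlope);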

View File

@@ -0,0 +1,73 @@
package net.woodyfolsom.msproj.tictactoe;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import org.junit.Test;
import net.woodyfolsom.msproj.tictactoe.Game.PLAYER;
public class GameRecordTest {
@Test
public void testGetResultXwins() {
GameRecord gameRecord = new GameRecord();
gameRecord.apply(Action.getInstance(PLAYER.X, 1, 0));
gameRecord.apply(Action.getInstance(PLAYER.O, 0, 0));
gameRecord.apply(Action.getInstance(PLAYER.X, 1, 1));
gameRecord.apply(Action.getInstance(PLAYER.O, 0, 1));
gameRecord.apply(Action.getInstance(PLAYER.X, 1, 2));
State finalState = gameRecord.getState();
System.out.println("Final state:");
System.out.println(finalState);
assertTrue(finalState.isValid());
assertTrue(finalState.isTerminal());
assertTrue(finalState.isWinner(PLAYER.X));
assertEquals(GameRecord.RESULT.X_WINS,gameRecord.getResult());
}
@Test
public void testGetResultOwins() {
GameRecord gameRecord = new GameRecord();
gameRecord.apply(Action.getInstance(PLAYER.X, 0, 0));
gameRecord.apply(Action.getInstance(PLAYER.O, 0, 2));
gameRecord.apply(Action.getInstance(PLAYER.X, 0, 1));
gameRecord.apply(Action.getInstance(PLAYER.O, 1, 1));
gameRecord.apply(Action.getInstance(PLAYER.X, 1, 0));
gameRecord.apply(Action.getInstance(PLAYER.O, 2, 0));
State finalState = gameRecord.getState();
System.out.println("Final state:");
System.out.println(finalState);
assertTrue(finalState.isValid());
assertTrue(finalState.isTerminal());
assertTrue(finalState.isWinner(PLAYER.O));
assertEquals(GameRecord.RESULT.O_WINS,gameRecord.getResult());
}
@Test
public void testGetResultTieGame() {
GameRecord gameRecord = new GameRecord();
gameRecord.apply(Action.getInstance(PLAYER.X, 0, 0));
gameRecord.apply(Action.getInstance(PLAYER.O, 0, 2));
gameRecord.apply(Action.getInstance(PLAYER.X, 0, 1));
gameRecord.apply(Action.getInstance(PLAYER.O, 1, 0));
gameRecord.apply(Action.getInstance(PLAYER.X, 1, 2));
gameRecord.apply(Action.getInstance(PLAYER.O, 1, 1));
gameRecord.apply(Action.getInstance(PLAYER.X, 2, 0));
gameRecord.apply(Action.getInstance(PLAYER.O, 2, 2));
gameRecord.apply(Action.getInstance(PLAYER.X, 2, 1));
State finalState = gameRecord.getState();
System.out.println("Final state:");
System.out.println(finalState);
assertTrue(finalState.isValid());
assertTrue(finalState.isTerminal());
assertFalse(finalState.isWinner(PLAYER.X));
assertFalse(finalState.isWinner(PLAYER.O));
assertEquals(GameRecord.RESULT.TIE_GAME,gameRecord.getResult());
}
}

View File

@@ -0,0 +1,12 @@
package net.woodyfolsom.msproj.tictactoe;
import org.junit.Test;
public class RefereeTest {
@Test
public void testPlay100Games() {
new Referee().play(100);
}
}

189
ttt.net Normal file
View File

@@ -0,0 +1,189 @@
<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<multiLayerPerceptron biased="true" name="TicTacToe">
<activationFunction name="Sigmoid"/>
<connections dest="10" src="0" weight="-1.139430876029846"/>
<connections dest="10" src="1" weight="3.091584814276022"/>
<connections dest="10" src="2" weight="-0.2551933137016801"/>
<connections dest="10" src="3" weight="-0.7615637398659946"/>
<connections dest="10" src="4" weight="-0.6548680752915276"/>
<connections dest="10" src="5" weight="0.08510244492139961"/>
<connections dest="10" src="6" weight="-0.8138255062915528"/>
<connections dest="10" src="7" weight="-0.2048642154445006"/>
<connections dest="10" src="8" weight="0.4118548734860931"/>
<connections dest="10" src="9" weight="-0.3418643333437593"/>
<connections dest="11" src="0" weight="0.985258631016843"/>
<connections dest="11" src="1" weight="0.5585206829273895"/>
<connections dest="11" src="2" weight="-0.00710478214319128"/>
<connections dest="11" src="3" weight="0.4458768855799938"/>
<connections dest="11" src="4" weight="0.699630908274699"/>
<connections dest="11" src="5" weight="-0.291692394014361"/>
<connections dest="11" src="6" weight="-0.3968126140382831"/>
<connections dest="11" src="7" weight="1.5110166318959362"/>
<connections dest="11" src="8" weight="-0.18225892993024753"/>
<connections dest="11" src="9" weight="-0.7602259764999595"/>
<connections dest="12" src="0" weight="1.7429897430035988"/>
<connections dest="12" src="1" weight="-0.28322509888402325"/>
<connections dest="12" src="2" weight="0.5040019819001578"/>
<connections dest="12" src="3" weight="0.9359376456777513"/>
<connections dest="12" src="4" weight="-0.15284920922664844"/>
<connections dest="12" src="5" weight="-3.4788220747438667"/>
<connections dest="12" src="6" weight="0.7547569837163356"/>
<connections dest="12" src="7" weight="-0.32085504506413287"/>
<connections dest="12" src="8" weight="0.518047606917643"/>
<connections dest="12" src="9" weight="0.8207811849267582"/>
<connections dest="13" src="0" weight="0.0768646322526947"/>
<connections dest="13" src="1" weight="0.675759240542896"/>
<connections dest="13" src="2" weight="-1.04758516390251"/>
<connections dest="13" src="3" weight="-1.1207097351854434"/>
<connections dest="13" src="4" weight="-1.0663558249243994"/>
<connections dest="13" src="5" weight="0.40669746747595964"/>
<connections dest="13" src="6" weight="-0.8040553688830026"/>
<connections dest="13" src="7" weight="-0.810063503984392"/>
<connections dest="13" src="8" weight="-0.63726821466013"/>
<connections dest="13" src="9" weight="-0.062253116036353605"/>
<connections dest="14" src="0" weight="-0.2569035720861068"/>
<connections dest="14" src="1" weight="-0.23868649547740917"/>
<connections dest="14" src="2" weight="0.3319329593778122"/>
<connections dest="14" src="3" weight="0.22285129465763973"/>
<connections dest="14" src="4" weight="-1.1932177045246797"/>
<connections dest="14" src="5" weight="-0.8246033698516325"/>
<connections dest="14" src="6" weight="-1.1522063213004192"/>
<connections dest="14" src="7" weight="-0.08295162498206299"/>
<connections dest="14" src="8" weight="0.45121422208738693"/>
<connections dest="14" src="9" weight="0.1344210997671879"/>
<connections dest="15" src="0" weight="-0.19080274015172097"/>
<connections dest="15" src="1" weight="-0.08751712180997395"/>
<connections dest="15" src="2" weight="0.6338301857587448"/>
<connections dest="15" src="3" weight="-0.9971509770232028"/>
<connections dest="15" src="4" weight="0.37406630555233944"/>
<connections dest="15" src="5" weight="1.7040252761510988"/>
<connections dest="15" src="6" weight="0.43507827352032896"/>
<connections dest="15" src="7" weight="-1.030255483779959"/>
<connections dest="15" src="8" weight="0.6425158958005772"/>
<connections dest="15" src="9" weight="-0.2768699175127161"/>
<connections dest="16" src="0" weight="0.383162191474126"/>
<connections dest="16" src="1" weight="-0.316758353560207"/>
<connections dest="16" src="2" weight="-0.40398044046890863"/>
<connections dest="16" src="3" weight="-0.4103150897933657"/>
<connections dest="16" src="4" weight="-0.2110512886314012"/>
<connections dest="16" src="5" weight="-0.7537325411227123"/>
<connections dest="16" src="6" weight="-0.277432410233177"/>
<connections dest="16" src="7" weight="0.6523906983042057"/>
<connections dest="16" src="8" weight="0.8237246362854799"/>
<connections dest="16" src="9" weight="0.6450796646675565"/>
<connections dest="17" src="0" weight="-1.3222355951033131"/>
<connections dest="17" src="1" weight="-0.6775300244272042"/>
<connections dest="17" src="2" weight="-0.9101223420136262"/>
<connections dest="17" src="3" weight="-0.8913218057082705"/>
<connections dest="17" src="4" weight="-0.3228919507773142"/>
<connections dest="17" src="5" weight="0.6156397776974011"/>
<connections dest="17" src="6" weight="-0.6008468597628974"/>
<connections dest="17" src="7" weight="-0.3094929421425772"/>
<connections dest="17" src="8" weight="1.4800051973199828"/>
<connections dest="17" src="9" weight="-0.26820420703433634"/>
<connections dest="18" src="0" weight="-0.4724752146627139"/>
<connections dest="18" src="1" weight="-0.17278878268217254"/>
<connections dest="18" src="2" weight="-0.3213530770778259"/>
<connections dest="18" src="3" weight="-0.4343270409319928"/>
<connections dest="18" src="4" weight="0.5864291732809569"/>
<connections dest="18" src="5" weight="0.4944358358169582"/>
<connections dest="18" src="6" weight="0.8432289820265341"/>
<connections dest="18" src="7" weight="0.7294985221790254"/>
<connections dest="18" src="8" weight="0.19741936496860893"/>
<connections dest="18" src="9" weight="1.0649680979002503"/>
<connections dest="19" src="0" weight="2.6635011409263543"/>
<connections dest="19" src="10" weight="2.185963006026185"/>
<connections dest="19" src="11" weight="-1.401987872790659"/>
<connections dest="19" src="12" weight="-2.572264670917092"/>
<connections dest="19" src="13" weight="-2.719351802228293"/>
<connections dest="19" src="14" weight="2.14428000554082"/>
<connections dest="19" src="15" weight="-2.5948425406968325"/>
<connections dest="19" src="16" weight="-2.593589676600079"/>
<connections dest="19" src="17" weight="1.2492257986319857"/>
<connections dest="19" src="18" weight="2.7912530986331845"/>
<neurons id="0">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
</neurons>
<neurons id="1">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
</neurons>
<neurons id="2">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
</neurons>
<neurons id="3">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
</neurons>
<neurons id="4">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
</neurons>
<neurons id="5">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
</neurons>
<neurons id="6">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
</neurons>
<neurons id="7">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
</neurons>
<neurons id="8">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
</neurons>
<neurons id="9">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Linear"/>
</neurons>
<neurons id="10">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Tanh"/>
</neurons>
<neurons id="11">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Tanh"/>
</neurons>
<neurons id="12">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Tanh"/>
</neurons>
<neurons id="13">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Tanh"/>
</neurons>
<neurons id="14">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Tanh"/>
</neurons>
<neurons id="15">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Tanh"/>
</neurons>
<neurons id="16">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Tanh"/>
</neurons>
<neurons id="17">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Tanh"/>
</neurons>
<neurons id="18">
<activationFunction xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="activationFunction" name="Tanh"/>
</neurons>
<neurons id="19">
<activationFunction name="Sigmoid"/>
</neurons>
<layers>
<neuronIds>1</neuronIds>
<neuronIds>2</neuronIds>
<neuronIds>3</neuronIds>
<neuronIds>4</neuronIds>
<neuronIds>5</neuronIds>
<neuronIds>6</neuronIds>
<neuronIds>7</neuronIds>
<neuronIds>8</neuronIds>
<neuronIds>9</neuronIds>
</layers>
<layers>
<neuronIds>10</neuronIds>
<neuronIds>11</neuronIds>
<neuronIds>12</neuronIds>
<neuronIds>13</neuronIds>
<neuronIds>14</neuronIds>
<neuronIds>15</neuronIds>
<neuronIds>16</neuronIds>
<neuronIds>17</neuronIds>
<neuronIds>18</neuronIds>
</layers>
<layers>
<neuronIds>19</neuronIds>
</layers>
</multiLayerPerceptron>
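The serialized network above can be read back through the stream-based load exercised in MultiLayerPerceptronTest and dropped into NeuralNetPolicy. A hedged usage sketch (the TttNetSketch class name, file path, and game loop wiring are assumptions):

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import net.woodyfolsom.msproj.ann.FeedforwardNetwork;
import net.woodyfolsom.msproj.ann.MultiLayerPerceptron;
import net.woodyfolsom.msproj.tictactoe.Game;
import net.woodyfolsom.msproj.tictactoe.GameRecord;
import net.woodyfolsom.msproj.tictactoe.NeuralNetPolicy;
import net.woodyfolsom.msproj.tictactoe.Policy;
import net.woodyfolsom.msproj.tictactoe.RandomPolicy;
import net.woodyfolsom.msproj.tictactoe.State;

public class TttNetSketch {
	public static void main(String[] args) throws IOException {
		// Load the TicTacToe value network serialized in ttt.net.
		FeedforwardNetwork net = new MultiLayerPerceptron();
		FileInputStream fis = new FileInputStream(new File("ttt.net"));
		net.load(fis);
		fis.close();

		// Pit the network-backed policy (as X) against a random opponent.
		Policy xPolicy = new NeuralNetPolicy(net);
		Policy oPolicy = new RandomPolicy();
		GameRecord game = new GameRecord();
		State state = game.getState();
		while (!state.isTerminal()) {
			Policy toMove = (state.getPlayerToMove() == Game.PLAYER.X) ? xPolicy : oPolicy;
			state = game.apply(toMove.getAction(game.getState()));
		}
		System.out.println(game.getResult());
	}
}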