Skip to content

Commit

Permalink
Initial working alphazero agent.
Browse files Browse the repository at this point in the history
  • Loading branch information
nelhage committed Nov 6, 2022
2 parents d1ad0a0 + 9f5a49f commit 3d97b37
Show file tree
Hide file tree
Showing 145 changed files with 8,665 additions and 4,117 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,5 @@ jobs:
- name: pytest
run: |
cd python && pytest
env:
TEST_WANDB: true
16 changes: 11 additions & 5 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
PREFIX := github.com/nelhage/taktician

PROTOS := $(wildcard proto/*.proto)
PROTONAMES := $(basename $(notdir $(PROTOS)))
PROTOS := $(wildcard proto/tak/proto/*.proto)
PROTONAMES := $(foreach proto,$(PROTOS), $(basename $(notdir $(proto))))
GOPROTOSRC := $(foreach proto,$(PROTONAMES),pb/$(proto).pb.go)
PYPROTOSRC := $(foreach proto,$(PROTONAMES),python/tak/proto/$(proto)_pb2.py)
GENFILES := ai/feature_string.go $(GOPROTOSRC) $(PYPROTOSRC)
Expand All @@ -13,9 +13,15 @@ ai/feature_string.go: ai/evaluate.go
protoc: $(GOPROTOSRC) $(PYPROTOSRC)

$(GOPROTOSRC) $(PYPROTOSRC): $(PROTOS)
protoc -I proto/ \
--python_out=python/tak/proto/ --go_out=pb \
proto/*.proto
python -m grpc_tools.protoc\
-I proto/ \
--python_out=python/ \
--grpc_python_out=python/ \
--go_out=. \
--go_opt="module=$(PREFIX)" \
--go-grpc_out=. \
--go-grpc_opt="module=$(PREFIX)" \
proto/tak/proto/*.proto

build: $(GENFILES)
go build $(PREFIX)/...
Expand Down
6 changes: 3 additions & 3 deletions ai/evaluate.go
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ func scoreGroups(c *bitboard.Constants, gs []uint64, ws *Weights, other uint64)
return sc
}

func countThreats(c *bitboard.Constants, p *tak.Position) (wp, wt, bp, bt int) {
func CountThreats(c *bitboard.Constants, p *tak.Position) (wp, wt, bp, bt int) {
analysis := p.Analysis()
empty := c.Mask &^ (p.White | p.Black)

Expand Down Expand Up @@ -416,7 +416,7 @@ func scoreThreats(c *bitboard.Constants, ws *Weights, p *tak.Position) int64 {
return 0
}

wp, wt, bp, bt := countThreats(c, p)
wp, wt, bp, bt := CountThreats(c, p)

if wp+wt > 0 && p.ToMove() == tak.White {
return ForcedWin
Expand Down Expand Up @@ -543,7 +543,7 @@ func ExplainScore(m *MinimaxAI, out io.Writer, p *tak.Position) {

fmt.Fprintf(tw, "liberties\t%d\t%d\n", wl, bl)

wp, wt, bp, bt := countThreats(&m.c, p)
wp, wt, bp, bt := CountThreats(&m.c, p)
fmt.Fprintf(tw, "potential\t%d\t%d\n", wp, bp)
fmt.Fprintf(tw, "threat\t%d\t%d\n", wt, bt)

Expand Down
43 changes: 43 additions & 0 deletions ai/feature_string.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 9 additions & 2 deletions ai/mcts/debug.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ func (mc *MonteCarloAI) dumpTreeNode(f io.Writer, t *tree) {
parent := 1
if t.parent != nil {
parent = t.parent.simulations
if t.parent.proven != 0 && t.proven == 0 {
return
}

}
label := fmt.Sprintf("n=%d p=%d v=%.0f+%.0f",
t.simulations,
Expand All @@ -36,17 +40,20 @@ func (mc *MonteCarloAI) dumpTreeNode(f io.Writer, t *tree) {

fmt.Fprintf(f, ` n%p [label="%s"]`, t, label)
fmt.Fprintln(f)
if t.children == nil || t.simulations < mc.cfg.InitialVisits {
if t.children == nil {
return
}

for _, c := range t.children {
if c.simulations < mc.cfg.InitialVisits {
if t.proven > 0 && c.proven >= 0 {
continue
}
fmt.Fprintf(f, ` n%p -> n%p [label="%s"]`,
t, c, ptn.FormatMove(c.move))
fmt.Fprintln(f)
mc.dumpTreeNode(f, c)
if c.proven < 0 {
break
}
}
}
52 changes: 25 additions & 27 deletions ai/mcts/mcts.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ type MCTSConfig struct {
C float64
Seed int64

InitialVisits int
MMDepth int
MaxRollout int
EvalThreshold int64
Expand Down Expand Up @@ -154,21 +153,6 @@ func (ai *MonteCarloAI) GetMove(ctx context.Context, p *tak.Position) tak.Move {
ai.dumpTree(tree)
}

if tree.proven != 0 {
if len(tree.children) == 0 {
return ai.mm.GetMove(ctx, p)
}
best := tree.children[0]
for _, c := range tree.children {
if c.proven < best.proven {
best = c
}
}
if ai.cfg.Debug > 1 {
log.Printf("proven m=%s v=%d", ptn.FormatMove(best.move), -best.proven)
}
return best.move
}
best := tree.children[0]
i := 0
sort.Sort(bySims(tree.children))
Expand All @@ -194,6 +178,21 @@ func (ai *MonteCarloAI) GetMove(ctx context.Context, p *tak.Position) tak.Move {
}
}
}
if tree.proven != 0 {
if len(tree.children) == 0 {
return ai.mm.GetMove(ctx, p)
}
best := tree.children[0]
for _, c := range tree.children {
if c.proven < best.proven {
best = c
}
}
if ai.cfg.Debug > 1 {
log.Printf("proven m=%s v=%d", ptn.FormatMove(best.move), -best.proven)
}
return best.move
}
if ai.cfg.Debug > 0 {
log.Printf("[mcts] evaluated simulations=%d value=%d proven=%d", tree.simulations, tree.value, tree.proven)
}
Expand All @@ -214,14 +213,16 @@ func (mc *MonteCarloAI) printdbg(t *tree) {
}

func (mc *MonteCarloAI) populate(ctx context.Context, t *tree) {
_, v, _ := mc.mm.Analyze(ctx, t.position)
if v > ai.WinThreshold {
t.proven = 1
return
} else if v < -ai.WinThreshold {
t.proven = -1
return
}
/*
_, v, _ := mc.mm.Analyze(ctx, t.position)
if v > ai.WinThreshold {
t.proven = 1
return
} else if v < -ai.WinThreshold {
t.proven = -1
return
}
*/

moves := t.position.AllMoves(nil)
t.children = make([]*tree, 0, len(moves))
Expand Down Expand Up @@ -354,9 +355,6 @@ func NewMonteCarlo(cfg MCTSConfig) *MonteCarloAI {
if mc.cfg.Seed == 0 {
mc.cfg.Seed = time.Now().Unix()
}
if mc.cfg.InitialVisits == 0 {
mc.cfg.InitialVisits = 3
}
if mc.cfg.MMDepth == 0 {
mc.cfg.MMDepth = 3
}
Expand Down
3 changes: 2 additions & 1 deletion ai/random.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ type RandomAI struct {
}

func (r *RandomAI) GetMove(ctx context.Context, p *tak.Position) tak.Move {
moves := p.AllMoves(nil)
var buffer [100]tak.Move
moves := p.AllMoves(buffer[:0])
i := r.r.Int31n(int32(len(moves)))
return moves[i]
}
Expand Down
24 changes: 24 additions & 0 deletions bin/sync
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#!/bin/bash
# sync: copy this checkout to another host with rsync.
#
# Usage: bin/sync TARGET [PATH]
#   TARGET  rsync destination host, used as "TARGET:PATH"
#   PATH    destination directory on TARGET (default: ~/code/taktician)
#
# Build artifacts, caches, VCS metadata and local data directories are
# excluded from the transfer.
set -eu

# Fail with a usage message instead of bash's bare "unbound variable" error
# when TARGET is omitted.
target=${1:?"usage: $0 TARGET [PATH]"}

# The default is single-quoted so the tilde is NOT expanded locally;
# presumably it is meant to be resolved on the receiving side.
path=${2-'~/code/taktician'}

# Run from the repository root (the parent of this script's directory) so
# that "." and the anchored exclude patterns are relative to it.
root=$(dirname "$0")/../
cd "$root"

exec rsync -Pax \
--exclude='__pycache__' \
--exclude='/python/notebooks/' \
--exclude='/python/build/' \
--exclude='/python/dist/' \
--exclude='/python/.pytest_cache/' \
--exclude='/python/*.so' \
--exclude='/.git' \
--exclude='/.direnv' \
--exclude='*.test' \
--exclude='/data' \
--exclude='wandb' \
--exclude='/.envrc' \
--exclude='python/tak.egg-info' \
. "$target:$path"
64 changes: 61 additions & 3 deletions cmd/internal/analyze/analysis.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func (a *pnAnalysis) Analyze(ctx context.Context, p *tak.Position) {
cli.RenderBoard(nil, os.Stdout, p)
}

out := prover.Prove(ctx, p)
out, stats := prover.Prove(ctx, p)
var result string
switch out.Result {
case prove.EvalTrue:
Expand All @@ -131,11 +131,11 @@ func (a *pnAnalysis) Analyze(ctx context.Context, p *tak.Position) {
result,
move,
out.Duration,
out.Stats.Nodes,
stats.Nodes,
out.Proof,
out.Disproof,
out.Depth,
out.Stats.MaxDepth,
stats.MaxDepth,
)

if a.cmd.dumpTree != "" {
Expand All @@ -151,3 +151,61 @@ func (a *pnAnalysis) Analyze(ctx context.Context, p *tak.Position) {
out.Close()
}
}

// dfpnAnalysis analyzes positions with the prove package's DFPN prover
// (prove.NewDFPN), taking its settings from the owning Command.
type dfpnAnalysis struct {
// cmd supplies the options read by Analyze: attacker, quiet, and
// mmopt.Debug / mmopt.TableMem.
cmd *Command
}

// Analyze runs the DFPN prover on p and prints a summary to stdout.
//
// The attacker is taken from a.cmd.attacker: "white", "black", or the empty
// string (mapped to tak.NoColor). Any other value terminates the process via
// log.Fatalf. Unless a.cmd.quiet is set, the board is rendered before the
// search starts.
//
// ctx is currently unused: prover.Prove is called without it.
func (a *dfpnAnalysis) Analyze(ctx context.Context, p *tak.Position) {
	var attacker tak.Color
	switch a.cmd.attacker {
	case "white":
		attacker = tak.White
	case "black":
		attacker = tak.Black
	case "":
		attacker = tak.NoColor
	default:
		log.Fatalf("Cannot parse attacker: %q", a.cmd.attacker)
	}

	prover := prove.NewDFPN(&prove.DFPNConfig{
		Debug:    a.cmd.mmopt.Debug,
		TableMem: a.cmd.mmopt.TableMem,
		Attacker: attacker,
	})

	if !a.cmd.quiet {
		cli.RenderBoard(nil, os.Stdout, p)
	}

	out, stats := prover.Prove(p)

	// result stays empty if the prover reports a value not listed here.
	var result string
	switch out.Result {
	case prove.EvalTrue:
		result = "WIN"
	case prove.EvalFalse:
		result = "DRAW|LOSE"
	case prove.EvalUnknown:
		result = "UNKNOWN"
	}

	// A zero move type is treated as "no move was produced".
	move := "(none)"
	if out.Move.Type != 0 {
		move = ptn.FormatMove(out.Move)
	}

	// Guard the percentage: with no recorded lookups the division would be
	// 0/0 and print "NaN%". Report 0.00% in that case instead.
	lookups := stats.Hits + stats.Miss
	hitRate := 0.0
	if lookups != 0 {
		hitRate = 100 * float64(stats.Hits) / float64(lookups)
	}

	fmt.Printf("PN search analysis:\n")
	fmt.Printf(" value=%s move=%s duration=%s\n",
		result,
		move,
		out.Duration,
	)
	fmt.Printf(" work=%d terminal=%d solved=%d repetition=%d hit=%d/%d (%0.2f%%)\n",
		stats.Work,
		stats.Terminal,
		stats.Solved,
		stats.Repetition,
		stats.Hits,
		lookups,
		hitRate,
	)
}
Loading

0 comments on commit 3d97b37

Please sign in to comment.