Skip to content

Commit

Permalink
[Lux] Use the node ID as key instead of the Ast itself
Browse files Browse the repository at this point in the history
  • Loading branch information
mratsim committed Jul 5, 2019
1 parent a88a858 commit 20026fb
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 51 deletions.
38 changes: 12 additions & 26 deletions laser/lux/ast/ast_codegen.nim
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ proc codegen*(
ast: LuxNode,
arch: SimdArch,
params: seq[NimNode],
visited: var Table[LuxNode, NimNode],
visited: var Table[Id, NimNode],
stmts: var NimNode): NimNode =
## Recursively walk the AST
## Append the corresponding Nim AST for generic instructions
Expand All @@ -34,13 +34,13 @@ proc codegen*(
return newCall(SimdTable[arch][simdBroadcast], newLit(ast.intVal))
of Output, LVal:
let sym = newIdentNode(ast.symLVal)
if ast in visited:
if ast.id in visited:
return sym
elif ast.prev_version.isNil:
visited[ast] = sym
visited[ast.id] = sym
return sym
else:
visited[ast] = sym
visited[ast.id] = sym
var blck = newStmtList()
let expression = codegen(ast.prev_version, arch, params, visited, blck)
stmts.add blck
Expand All @@ -51,22 +51,15 @@ proc codegen*(
)
return newIdentNode(ast.symLVal)
of Assign:
if ast in visited:
return visited[ast]

# Workaround compileTime table not finding keys
# https://github.com/mratsim/compute-graph-optim/issues/1
for key in visited.keys():
if hash(key) == hash(ast):
{.warning: "Triggered compile-time table 'Key not found' workaround".}
return visited[key]
if ast.id in visited:
return visited[ast.id]

var varAssign = false

if ast.lhs notin visited and
if ast.lhs.id notin visited and
ast.lhs.kind == LVal and
ast.lhs.prev_version.isNil and
ast.rhs notin visited:
ast.rhs.id notin visited:
varAssign = true

var rhsStmt = newStmtList()
Expand All @@ -86,15 +79,8 @@ proc codegen*(
return lhs

of Add, Mul:
if ast in visited:
return visited[ast]

# Workaround compileTime table not finding keys
# https://github.com/mratsim/compute-graph-optim/issues/1
for key in visited.keys():
if hash(key) == hash(ast):
{.warning: "Triggered compile-time table 'Key not found' workaround".}
return visited[key]
if ast.id in visited:
return visited[ast.id]

var callStmt = nnkCall.newTree()
var lhsStmt = newStmtList()
Expand Down Expand Up @@ -122,7 +108,7 @@ proc codegen*(

let memloc = genSym(nskLet, "memloc_")
stmts.add newLetStmt(memloc, callStmt)
visited[ast] = memloc
visited[ast.id] = memloc
return memloc

proc bodyGen*(
Expand All @@ -133,7 +119,7 @@ proc bodyGen*(
): NimNode =
# Does topological ordering and dead-code elimination
result = newStmtList()
var visitedNodes = initTable[LuxNode, NimNode]()
var visitedNodes = initTable[Id, NimNode]()

for i, inOutVar in io:
if inOutVar.kind != Input:
Expand Down
45 changes: 20 additions & 25 deletions laser/lux/ast/ast_definition.nim
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import
# Standard library
hashes, random, tables
random, tables

# ###########################################
#
Expand All @@ -29,7 +29,10 @@ type
LVal # Temporary allocated node
Assign # Assignment statement

Id* = int

LuxNode* = ref object
id*: Id
case kind*: LuxNodeKind
of Input:
symId*: int
Expand All @@ -44,8 +47,6 @@ type
of Assign, Add, Mul:
lhs*, rhs*: LuxNode

ctHash*: Hash # Compile-Time only Hash (TODO)

# ###########################################
#
# Routine definitions
Expand All @@ -56,37 +57,31 @@ var luxNodeRng {.compileTime.} = initRand(0x42)
## Workaround for having no UUID for LuxNodes
## at compile-time - https://github.com/nim-lang/RFCs/issues/131

proc genHash(): Hash =
Hash luxNodeRng.rand(high(int))

proc hash*(x: LuxNode): Hash {.inline.} =
when nimvm:
x.cthash
else: # Take its address
cast[Hash](x)
proc genId(): int =
luxNodeRng.rand(high(int))

proc input*(id: int): LuxNode =
when nimvm:
LuxNode(ctHash: genHash(), kind: Input, symId: id)
LuxNode(id: genId(), kind: Input, symId: id)
else:
LuxNode(kind: Input, symId: id)

proc `+`*(a, b: LuxNode): LuxNode =
when nimvm:
LuxNode(ctHash: genHash(), kind: Add, lhs: a, rhs: b)
LuxNode(id: genId(), kind: Add, lhs: a, rhs: b)
else:
LuxNode(kind: Add, lhs: a, rhs: b)

proc `*`*(a, b: LuxNode): LuxNode =
when nimvm:
LuxNode(ctHash: genHash(), kind: Mul, lhs: a, rhs: b)
LuxNode(id: genId(), kind: Mul, lhs: a, rhs: b)
else:
LuxNode(ctHash: genHash(), kind: Mul, lhs: a, rhs: b)
LuxNode(id: genId(), kind: Mul, lhs: a, rhs: b)

proc `*`*(a: LuxNode, b: SomeInteger): LuxNode =
when nimvm:
LuxNode(
ctHash: genHash(),
id: genId(),
kind: Mul,
lhs: a,
rhs: LuxNode(kind: IntImm, intVal: b)
Expand All @@ -102,17 +97,17 @@ proc `+=`*(a: var LuxNode, b: LuxNode) =
assert a.kind notin {Input, IntImm, FloatImm}
if a.kind notin {Output, LVal}:
a = LuxNode(
ctHash: genHash(),
id: genId(),
kind: LVal,
symLVal: "localvar__" & $a.ctHash, # Generate unique symbol
symLVal: "localvar__" & $a.id, # Generate unique symbol
version: 1,
prev_version: LuxNode(
cthash: a.ctHash,
id: a.id,
kind: Assign,
lhs: LuxNode(
ctHash: a.ctHash, # Keep the hash
id: a.id, # Keep the hash
kind: LVal,
symLVal: "localvar__" & $a.ctHash, # Generate unique symbol
symLVal: "localvar__" & $a.id, # Generate unique symbol
version: 0,
prev_version: nil,
),
Expand All @@ -121,25 +116,25 @@ proc `+=`*(a: var LuxNode, b: LuxNode) =
)
if a.kind == Output:
a = LuxNode(
ctHash: genHash(),
id: genId(),
kind: Output,
symLVal: a.symLVal, # Keep original unique symbol
version: a.version + 1,
prev_version: LuxNode(
ctHash: a.ctHash,
id: a.id,
kind: Assign,
lhs: a,
rhs: a + b
)
)
else:
a = LuxNode(
ctHash: genHash(),
id: genId(),
kind: LVal,
symLVal: a.symLVal, # Keep original unique symbol
version: a.version + 1,
prev_version: LuxNode(
ctHash: a.ctHash,
id: a.id,
kind: Assign,
lhs: a,
rhs: a + b
Expand Down

0 comments on commit 20026fb

Please sign in to comment.