From a88a8581f0e2a72332efbfa8b6d60d9a33f21ea2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mamy=20Andr=C3=A9-Ratsimbazafy?= Date: Thu, 4 Jul 2019 14:15:12 +0200 Subject: [PATCH] Initial working version of Lux --- laser/lux/ast.nim | 62 ++++ laser/lux/ast/ast_codegen.nim | 4 +- laser/lux/ast/ast_codegen_transfo.nim | 11 +- laser/lux/ast/ast_compiler.nim | 267 ++++++++++++++++++ laser/lux/ast/ast_definition.nim | 8 +- laser/lux/ast/ast_sigmatch.nim | 51 ++++ .../platform_common.nim => platforms.nim} | 2 +- laser/lux/platforms/platform_x86.nim | 54 ++-- 8 files changed, 428 insertions(+), 31 deletions(-) create mode 100644 laser/lux/ast/ast_compiler.nim create mode 100644 laser/lux/ast/ast_sigmatch.nim rename laser/lux/{platforms/platform_common.nim => platforms.nim} (90%) diff --git a/laser/lux/ast.nim b/laser/lux/ast.nim index e69de29..21cc148 100644 --- a/laser/lux/ast.nim +++ b/laser/lux/ast.nim @@ -0,0 +1,62 @@ +# Laser +# Copyright (c) 2018 Mamy André-Ratsimbazafy +# Distributed under the Apache v2 License (license terms are at http://www.apache.org/licenses/LICENSE-2.0). +# This file may not be copied, modified, or distributed except according to those terms. 
+ +import + ./ast/ast_compiler, + ./ast/ast_definition + +# ########################### +# +# Tests +# +# ########################### + +when isMainModule: + + proc foobar(a: LuxNode, b, c: LuxNode): tuple[bar: LuxNode, baz, buzz: LuxNode] = + + let foo = a + b + c + + # Don't use in-place updates + # https://github.com/nim-lang/Nim/issues/11637 + let bar = foo * 2 + + var baz = foo * 3 + var buzz = baz + + buzz += a * 1000 + baz += b + buzz += b + + result.bar = bar + result.baz = baz + result.buzz = buzz + + proc foobar(a: int, b, c: int): tuple[bar, baz, buzz: int] = + echo "Overloaded proc to test bindings" + discard + + generate foobar: + proc foobar(a: seq[float32], b, c: seq[float32]): tuple[bar: seq[float32], baz, buzz: seq[float32]] + + # Note to use aligned store, SSE requires 16-byte alignment and AVX 32-byte alignment + # Unfortunately there is no way with normal seq to specify that (pending destructors) + # As a hack, we use the unaligned load and store simd, and a required alignment of 4, + # in practice we define our own tensor type + # with aligned allocator + + import sequtils + + let + len = 10 + u = newSeqWith(len, 1'f32) + v = newSeqWith(len, 2'f32) + w = newSeqWith(len, 3'f32) + + let (pim, pam, poum) = foobar(u, v, w) + + echo pim # 12 + echo pam # 20 + echo poum # 1020 diff --git a/laser/lux/ast/ast_codegen.nim b/laser/lux/ast/ast_codegen.nim index 263c051..f8d9ed6 100644 --- a/laser/lux/ast/ast_codegen.nim +++ b/laser/lux/ast/ast_codegen.nim @@ -8,7 +8,7 @@ import macros, tables, # Internal ./ast_definition, - ../platforms/[platform_common] + ../platforms proc codegen*( ast: LuxNode, @@ -125,7 +125,7 @@ proc codegen*( visited[ast] = memloc return memloc -proc bodyGen( +proc bodyGen*( genSimd: bool, arch: SimdArch, io: varargs[LuxNode], ids: seq[NimNode], diff --git a/laser/lux/ast/ast_codegen_transfo.nim b/laser/lux/ast/ast_codegen_transfo.nim index 9470782..afeb9a5 100644 --- a/laser/lux/ast/ast_codegen_transfo.nim +++ 
b/laser/lux/ast/ast_codegen_transfo.nim
@@ -3,7 +3,14 @@
 # Distributed under the Apache v2 License (license terms are at http://www.apache.org/licenses/LICENSE-2.0).
 # This file may not be copied, modified, or distributed except according to those terms.
 
-proc vectorize(
+import
+  # Standard library
+  macros,
+  # Internals
+  ../platforms,
+  ../../private/align_unroller
+
+proc vectorize*(
       funcName: NimNode,
       ptrs, simds: tuple[inParams, outParams: seq[NimNode]],
       len: NimNode,
@@ -177,7 +184,7 @@ proc vectorize(
   block: # Aligned part
     let idx = newIdentNode("idx_")
     result.add quote do:
-      let `unroll_stop` = round_down_power_of_2(
+      let `unroll_stop` = round_step_down(
         `len` - `idxPeeling`, `unroll_factor`)
 
     let (fcall, dst, dst_init, dst_assign) = elems(idx, simd = true)
diff --git a/laser/lux/ast/ast_compiler.nim b/laser/lux/ast/ast_compiler.nim
new file mode 100644
index 0000000..5fca26f
--- /dev/null
+++ b/laser/lux/ast/ast_compiler.nim
@@ -0,0 +1,285 @@
+# Laser
+# Copyright (c) 2018 Mamy André-Ratsimbazafy
+# Distributed under the Apache v2 License (license terms are at http://www.apache.org/licenses/LICENSE-2.0).
+# This file may not be copied, modified, or distributed except according to those terms.
+
+import
+  # Standard library
+  macros,
+  # Internal
+  ../platforms,
+  ./ast_definition,
+  ./ast_sigmatch,
+  ./ast_codegen,
+  ./ast_codegen_transfo,
+  ./macro_utils
+
+# TODO: Do we need both compile and generate?
+
+proc initParams(
+       procDef,
+       resultType: NimNode
+       ): tuple[
+         ids: seq[NimNode],
+         ptrs, simds: tuple[inParams, outParams: seq[NimNode]],
+         length: NimNode,
+         initStmt: NimNode
+       ] =
+  ## Collects the identifiers of `procDef` and builds the wrapper prologue.
+  ##
+  ## - `ids`: input then output identifiers, in declaration order
+  ## - `ptrs`, `simds`: the `<ident>_raw_ptr` and `<ident>_simd` identifiers
+  ## - `length`: the `len` expression of the first input
+  ## - `initStmt`: length checks, allocation of the `result` fields
+  ##   and definitions of the raw pointers
+  # Get the idents from proc definition. We order the same as proc def
+  # Start with non-result
+  # We work at simd vector level
+  result.initStmt = newStmtList()
+  # Element type, expressed as `type(firstParam[0])`
+  let type0 = newCall(
+    newIdentNode"type",
+    nnkBracketExpr.newTree(
+      procDef[0][3][1][0],
+      newLit 0
+    )
+  )
+
+  for i in 1 ..< procDef[0][3].len: # Proc formal params
+    let iddefs = procDef[0][3][i]
+    for j in 0 ..< iddefs.len - 2:
+      let ident = iddefs[j]
+      result.ids.add ident
+      let raw_ptr = newIdentNode($ident & "_raw_ptr")
+      result.ptrs.inParams.add raw_ptr
+
+      if result.length.isNil:
+        # Only the very first input defines the reference length.
+        # Testing `j == 0` would reset it for each `a, b: T` group
+        # and skip the length check between groups.
+        result.length = quote do: `ident`.len
+      else:
+        let len0 = result.length
+        result.initStmt.add quote do:
+          assert `len0` == `ident`.len
+      result.initStmt.add quote do:
+        let `raw_ptr` = cast[ptr UncheckedArray[`type0`]](`ident`[0].unsafeAddr)
+      result.simds.inParams.add newIdentNode($ident & "_simd")
+
+  # Now add the result idents
+  # We work at simd vector level
+  let len0 = result.length
+
+  if resultType.kind == nnkEmpty:
+    discard
+  elif resultType.kind == nnkTupleTy:
+    for i in 0 ..< resultType.len:
+      let iddefs = resultType[i]
+      for j in 0 ..< iddefs.len - 2:
+        let ident = iddefs[j]
+        result.ids.add ident
+        let raw_ptr = newIdentNode($ident & "_raw_ptr")
+        result.ptrs.outParams.add raw_ptr
+
+        let res = nnkDotExpr.newTree(
+                    newIdentNode"result",
+                    iddefs[j]
+                  )
+        result.initStmt.add quote do:
+          `res` = newSeq[`type0`](`len0`)
+          let `raw_ptr` = cast[ptr UncheckedArray[`type0`]](`res`[0].unsafeAddr)
+
+        result.simds.outParams.add newIdentNode($ident & "_simd")
+
+macro compile(arch: static SimdArch, io: static varargs[LuxNode], procDef: untyped): untyped =
+  ## Fills the body of the proc declared in `procDef` with: the prologue from
+  ## `initParams`, a scalar and a SIMD inner proc generated from the `io`
+  ## nodes, and the loop produced by `vectorize`.
+  # Note: io must be an array - https://github.com/nim-lang/Nim/issues/10691
+
+  # compile([a, b, c, bar, baz, buzz]):
+  #   proc foobar[T](a, b, c: T): tuple[bar, baz, buzz: T]
+  #
+  # StmtList
+  #   ProcDef
+  #     Ident "foobar"
+  #     Empty
+  #     GenericParams
+  #       IdentDefs
+  #         Ident "T"
+  #         Empty
+  #         Empty
+  #     FormalParams
+  #       TupleTy
+  #         IdentDefs
+  #           Ident "bar"
+  #           Ident "baz"
+  #           Ident "buzz"
+  #           Ident "T"
+  #           Empty
+  #       IdentDefs
+  #         Ident "a"
+  #         Ident "b"
+  #         Ident "c"
+  #         Ident "T"
+  #         Empty
+  #     Empty
+  #     Empty
+  #     Empty
+
+  # echo procDef.treerepr
+
+  ## Sanity checks
+  procDef.expectkind(nnkStmtList)
+  assert procDef.len == 1, "Only 1 statement is allowed, the function definition"
+  procDef[0].expectkind({nnkProcDef, nnkFuncDef})
+  # TODO: check that the function inputs are in a symbol table?
+  procDef[0][6].expectKind(nnkEmpty)
+
+  let resultTy = procDef[0][3][0]
+  let (ids, ptrs, simds, length, initParams) = initParams(procDef, resultTy)
+
+  # echo initParams.toStrLit()
+
+  let seqT = nnkBracketExpr.newTree(
+    newIdentNode"seq", newIdentNode"float32"
+  )
+
+  # We create the inner SIMD proc, specialized to a SIMD architecture
+  # In the inner proc we shadow the original idents ids.
+  let simdBody = bodyGen(
+    genSimd = true,
+    arch = arch,
+    io = io,
+    ids = ids,
+    resultType = resultTy
+  )
+
+  var simdProc = procDef[0].replaceType(seqT, SimdTable[arch][simdType])
+
+  simdProc[6] = simdBody # Assign to proc body
+  echo simdProc.toStrLit
+
+  # We create the inner generic proc
+  let genericBody = bodyGen(
+    genSimd = false,
+    arch = ArchGeneric,
+    io = io,
+    ids = ids,
+    resultType = resultTy
+  )
+
+  var genericProc = procDef[0].replaceType(seqT, newIdentNode"float32")
+  genericProc[6] = genericBody # Assign to proc body
+  echo genericProc.toStrLit
+
+  # We vectorize the inner proc to apply to an contiguous array
+  var vecBody: NimNode
+  if arch == x86_SSE:
+    vecBody = vectorize(
+      procDef[0][0],
+      ptrs, simds,
+      length,
+      arch, 4, 4 # We require 4 alignment as a hack to keep seq[T] and use unaligned load/store in code
+    )
+  else:
+    vecBody = vectorize(
+      procDef[0][0],
+      ptrs, simds,
+      length,
+      arch, 4, 8 # We require 4 alignment as a hack to keep seq[T] and use unaligned load/store in code
+    )
+
+  result = procDef.copyNimTree()
+  let resBody = newStmtList()
+  resBody.add initParams
+  resBody.add genericProc
+  resBody.add simdProc
+  resBody.add vecBody
+  result[0][6] = resBody
+
+  # echo result.toStrLit
+
+macro generate*(ast_routine: typed, signature: untyped): untyped =
+  ## Resolves `ast_routine` against the proc declared in `signature`, runs it
+  ## at compile-time on symbolic `input` nodes and passes the resulting
+  ## input/output nodes to `compile`, hardcoded to x86_SSE.
+  let formalParams = signature[0][3]
+  let ast = ast_routine.resolveASToverload(formalParams)
+
+  # Get the routine signature
+  let sig = ast.getImpl[3]
+  sig.expectKind(nnkFormalParams)
+
+  # Get all inputs
+  var inputs: seq[NimNode]
+  for idx_identdef in 1 ..< sig.len:
+    let identdef = sig[idx_identdef]
+    doAssert identdef[^2].eqIdent"LuxNode"
+    identdef[^1].expectKind(nnkEmpty)
+    for idx_ident in 0 .. identdef.len-3:
+      inputs.add genSym(nskLet, $identdef[idx_ident] & "_")
+
+  # Allocate inputs
+  result = newStmtList()
+  proc ct(ident: NimNode): NimNode =
+    nnkPragmaExpr.newTree(
+      ident,
+      nnkPragma.newTree(
+        ident"compileTime"
+      )
+    )
+
+  for i, in_ident in inputs:
+    result.add newLetStmt(
+      ct(in_ident),
+      newCall("input", newLit i)
+    )
+
+  # Call the AST routine
+  let call = newCall(ast, inputs)
+  var callAssign: NimNode
+  case sig[0].kind
+  of nnkEmpty: # Case 1: no result
+    result.add call
+  # Compile-time tuple destructuring is bugged - https://github.com/nim-lang/Nim/issues/11634
+  # of nnkTupleTy: # Case 2: tuple result
+  #   callAssign = nnkVarTuple.newTree()
+  #   for identdef in sig[0]:
+  #     doAssert identdef[^2].eqIdent"LuxNode"
+  #     identdef[^1].expectKind(nnkEmpty)
+  #     for idx_ident in 0 .. identdef.len-3:
+  #       callAssign.add ct(identdef[idx_ident])
+  #   callAssign.add newEmptyNode()
+  #   callAssign.add call
+  #   result.add nnkLetSection.newTree(
+  #     callAssign
+  #   )
+  else: # Case 3: single return value
+    callAssign = ct(genSym(nskLet, "callResult_"))
+    result.add newLetStmt(
+      callAssign, call
+    )
+
+  # Collect all the input/output idents
+  var io = inputs
+  case sig[0].kind
+  of nnkEmpty:
+    discard
+  of nnkTupleTy:
+    var idx = 0
+    for identdef in sig[0]:
+      for idx_ident in 0 .. identdef.len-3:
+        io.add nnkBracketExpr.newTree(
+          callAssign[0],
+          newLit idx
+        )
+        inc idx
+  else:
+    # callAssign is `ident {.compileTime.}`, only the ident is an expression
+    io.add callAssign[0]
+
+  result.add quote do:
+    compile(x86_SSE, `io`, `signature`)
+
+  echo result.toStrlit
diff --git a/laser/lux/ast/ast_definition.nim b/laser/lux/ast/ast_definition.nim
index df4bdc2..f004439 100644
--- a/laser/lux/ast/ast_definition.nim
+++ b/laser/lux/ast/ast_definition.nim
@@ -52,12 +52,12 @@ type
 #
 # ###########################################
 
-var astNodeRng {.compileTime.} = initRand(0x42)
+var luxNodeRng {.compileTime.} = initRand(0x42)
   ## Workaround for having no UUID for LuxNodes
   ## at compile-time - https://github.com/nim-lang/RFCs/issues/131
 
 proc genHash(): Hash =
-  Hash astNodeRng.rand(high(int))
+  Hash luxNodeRng.rand(high(int))
 
 proc hash*(x: LuxNode): Hash {.inline.} =
   when nimvm:
@@ -89,13 +89,13 @@ proc `*`*(a: LuxNode, b: SomeInteger): LuxNode =
       ctHash: genHash(),
       kind: Mul,
       lhs: a,
-      rhs: LuxNode(kind: IntScalar, intVal: b)
+      rhs: LuxNode(kind: IntImm, intVal: b)
     )
   else:
     LuxNode(
       kind: Mul,
       lhs: a,
-      rhs: LuxNode(kind: IntScalar, intVal: b)
+      rhs: LuxNode(kind: IntImm, intVal: b)
     )
 
 proc `+=`*(a: var LuxNode, b: LuxNode) =
diff --git a/laser/lux/ast/ast_sigmatch.nim b/laser/lux/ast/ast_sigmatch.nim
new file mode 100644
index 0000000..63282bb
--- /dev/null
+++ b/laser/lux/ast/ast_sigmatch.nim
@@ -0,0 +1,51 @@
+# Laser
+# Copyright (c) 2018 Mamy André-Ratsimbazafy
+# Distributed under the Apache v2 License (license terms are at http://www.apache.org/licenses/LICENSE-2.0).
+# This file may not be copied, modified, or distributed except according to those terms.
+
+import
+  # Standard library
+  macros
+
+proc matchAST(overload, signature: NimNode): bool =
+  ## True when the formal params `overload` structurally match `signature`.
+  ## A `LuxNode` type in `overload` is a wildcard that matches any type.
+  proc inspect(overload, signature: NimNode, match: var bool) =
+    ## Walks both trees in lockstep, clearing `match` on the first mismatch.
+    if overload.kind in {nnkIdent, nnkSym} and overload.eqIdent("LuxNode"):
+      # LuxNode matches any type, e.g. seq[T] which is an nnkBracketExpr
+      return
+
+    # Return early when not matching
+    if not match or overload.kind != signature.kind or
+        overload.len != signature.len:
+      match = false
+      return
+
+    case overload.kind:
+    of {nnkIdent, nnkSym}:
+      match = eqIdent(overload, signature)
+    of nnkEmpty:
+      discard
+    else:
+      for i in 0 ..< overload.len:
+        inspect(overload[i], signature[i], match)
+
+  result = true
+  inspect(overload, signature, result)
+
+proc resolveASToverload*(overloads, formalParams: NimNode): NimNode =
+  ## Returns the symbol of the routine in `overloads` whose formal params
+  ## match `formalParams`. A single symbol is returned without matching.
+  if overloads.kind == nnkSym:
+    # Return the symbol and not its implementation, like the overloaded
+    # branch, so that the caller can apply `getImpl` to either result.
+    overloads.getImpl[3].expectKind nnkFormalParams
+    return overloads
+  overloads.expectKind(nnkClosedSymChoice)
+  for o in overloads:
+    let implSig = o.getImpl[3]
+    implSig.expectKind nnkFormalParams
+    if implSig.matchAST(formalParams):
+      return o
+  raise newException(ValueError, "no matching overload found")
diff --git a/laser/lux/platforms/platform_common.nim b/laser/lux/platforms.nim
similarity index 90%
rename from laser/lux/platforms/platform_common.nim
rename to laser/lux/platforms.nim
index ff2356a..45b7c99 100644
--- a/laser/lux/platforms/platform_common.nim
+++ b/laser/lux/platforms.nim
@@ -4,5 +4,5 @@
 # This file may not be copied, modified, or distributed except according to those terms.
 
when defined(i386) or defined(x86_64): - import ./platform_x86 + import ./platforms/platform_x86 export platform_x86 diff --git a/laser/lux/platforms/platform_x86.nim b/laser/lux/platforms/platform_x86.nim index 3786a65..d1caa3c 100644 --- a/laser/lux/platforms/platform_x86.nim +++ b/laser/lux/platforms/platform_x86.nim @@ -6,7 +6,10 @@ # TODO: merge with laser/primitives/matrix_multiplication/gemm_tiling import - macros + # Standard library + macros, + # Internal + ../../simd type SimdPrimitives* = enum @@ -47,38 +50,45 @@ const SimdAlignment* = [ x86_AVX_FMA: 32, ] +template sse_fma_fallback(a, b, c: m128): m128 = + mm_add_ps(mm_mul_ps(a, b), c) + +template avx_fma_fallback(a, b, c: m128): m128 = + mm256_add_ps(mm256_mul_ps(a, b), c) + proc genSimdTableX86(): array[SimdArch, array[SimdPrimitives, NimNode]] = let sse: array[SimdPrimitives, NimNode] = [ - simdSetZero: ident"mm_setzero_ps", - simdBroadcast: ident"mm_set1_ps", - simdLoadA: ident"mm_load_ps", - simdLoadU: ident"mm_loadu_ps", - simdStoreA: ident"mm_store_ps", - simdStoreU: ident"mm_storeu_ps", - simdAdd: ident"mm_add_ps", - simdMul: ident"mm_mul_ps", - simdFma: ident"sse_fma_fallback", - simdType: ident"m128" + simdSetZero: bindSym"mm_setzero_ps", + simdBroadcast: bindSym"mm_set1_ps", + simdLoadA: bindSym"mm_load_ps", + simdLoadU: bindSym"mm_loadu_ps", + simdStoreA: bindSym"mm_store_ps", + simdStoreU: bindSym"mm_storeu_ps", + simdAdd: bindSym"mm_add_ps", + simdMul: bindSym"mm_mul_ps", + simdFma: bindSym"sse_fma_fallback", + simdType: bindSym"m128" ] let avx: array[SimdPrimitives, NimNode] = [ - simdSetZero: ident"mm256_setzero_ps", - simdBroadcast: ident"mm256_set1_ps", - simdLoadA: ident"mm256_load_ps", - simdLoadU: ident"mm256_loadu_ps", - simdStoreA: ident"mm256_store_ps", - simdStoreU: ident"mm256_storeu_ps", - simdAdd: ident"mm256_add_ps", - simdMul: ident"mm256_mul_ps", - simdFma: ident"avx_fma_fallback", - simdType: ident"m256" + simdSetZero: bindSym"mm256_setzero_ps", + simdBroadcast: 
bindSym"mm256_set1_ps", + simdLoadA: bindSym"mm256_load_ps", + simdLoadU: bindSym"mm256_loadu_ps", + simdStoreA: bindSym"mm256_store_ps", + simdStoreU: bindSym"mm256_storeu_ps", + simdAdd: bindSym"mm256_add_ps", + simdMul: bindSym"mm256_mul_ps", + simdFma: bindSym"avx_fma_fallback", + simdType: bindSym"m256" ] var avx_fma = avx - avx_fma[simdFma] = ident"mm256_fmadd_ps" + avx_fma[simdFma] = bindSym"mm256_fmadd_ps" result = [ + ArchGeneric: default(array[SimdPrimitives, NimNode]), x86_SSE: sse, x86_AVX: avx, x86_AVX_FMA: avx_fma