diff --git a/laser/lux/ast/ast_codegen.nim b/laser/lux/ast/ast_codegen.nim index 0170166..8638e46 100644 --- a/laser/lux/ast/ast_codegen.nim +++ b/laser/lux/ast/ast_codegen.nim @@ -23,15 +23,9 @@ proc codegen*( of Input: return params[ast.symId] of IntImm: - if arch == ArchGeneric: - return newLit(ast.intVal) - else: - return newCall(SimdTable[arch][simdBroadcast], newLit(ast.intVal)) + return newCall(SimdTable[arch][simdBroadcast], newLit(ast.intVal)) of FloatImm: - if arch == ArchGeneric: - return newLit(ast.floatVal) - else: - return newCall(SimdTable[arch][simdBroadcast], newLit(ast.intVal)) + return newCall(SimdTable[arch][simdBroadcast], newLit(ast.intVal)) of Output, LVal: let sym = newIdentNode(ast.symLVal) if ast.id in visited: @@ -92,16 +86,10 @@ proc codegen*( stmts.add lhsStmt stmts.add rhsStmt - if arch == ArchGeneric: - case ast.kind - of Add: callStmt.add newidentNode"+" - of Mul: callStmt.add newidentNode"*" - else: raise newException(ValueError, "Unreachable code") - else: - case ast.kind - of Add: callStmt.add SimdTable[arch][simdAdd] - of Mul: callStmt.add SimdTable[arch][simdMul] - else: raise newException(ValueError, "Unreachable code") + case ast.kind + of Add: callStmt.add SimdTable[arch][simdAdd] + of Mul: callStmt.add SimdTable[arch][simdMul] + else: raise newException(ValueError, "Unreachable code") callStmt.add lhs callStmt.add rhs diff --git a/laser/lux/platforms.nim b/laser/lux/platforms.nim index 45b7c99..2e4f67f 100644 --- a/laser/lux/platforms.nim +++ b/laser/lux/platforms.nim @@ -3,6 +3,13 @@ # Distributed under the Apache v2 License (license terms are at http://www.apache.org/licenses/LICENSE-2.0). # This file may not be copied, modified, or distributed except according to those terms. +import ./platforms/platform_common + when defined(i386) or defined(x86_64): import ./platforms/platform_x86 - export platform_x86 + +export SimdPrimitives +export SimdArch, SimdAlignment, SimdTable + +func elemsPerVector*(arch: SimdArch, T: typedesc): int = + SimdWidth[arch] div sizeof(T) diff --git a/laser/lux/platforms/platform_common.nim b/laser/lux/platforms/platform_common.nim new file mode 100644 index 0000000..64b203e --- /dev/null +++ b/laser/lux/platforms/platform_common.nim @@ -0,0 +1,43 @@ +# Laser +# Copyright (c) 2018 Mamy André-Ratsimbazafy +# Distributed under the Apache v2 License (license terms are at http://www.apache.org/licenses/LICENSE-2.0). +# This file may not be copied, modified, or distributed except according to those terms. + +import + # Standard library + macros + +type + SimdPrimitives* = enum + simdSetZero + simdBroadcast + simdLoadA + simdLoadU + simdStoreA + simdStoreU + simdAdd + simdMul + simdFma + simdType + +template noop*(scalar: untyped): untyped = + scalar + +template unreachable*(): untyped = + {.error: "Unreachable".} + +proc genericPrimitives*: array[SimdPrimitives, NimNode] = + # Use a proc instead of a const + # to workaround https://github.com/nim-lang/Nim/issues/11668 + [ + simdSetZero: bindSym"unreachable", + simdBroadcast: bindSym"noop", + simdLoadA: bindSym"unreachable", + simdLoadU: bindSym"unreachable", + simdStoreA: bindSym"unreachable", + simdStoreU: bindSym"unreachable", + simdAdd: bindSym"+", + simdMul: bindSym"*", + simdFma: bindSym"unreachable", + simdType: bindSym"unreachable" + ] diff --git a/laser/lux/platforms/platform_x86.nim b/laser/lux/platforms/platform_x86.nim index d1caa3c..9cb361d 100644 --- a/laser/lux/platforms/platform_x86.nim +++ b/laser/lux/platforms/platform_x86.nim @@ -9,21 +9,10 @@ import # Standard library macros, # Internal + ./platform_common, ../../simd type - SimdPrimitives* = enum - simdSetZero - simdBroadcast - simdLoadA - simdLoadU - simdStoreA - simdStoreU - simdAdd - simdMul - simdFma - simdType - SimdArch* = enum ArchGeneric, x86_SSE, @@ -58,7 +47,7 @@ template avx_fma_fallback(a, b, c: m128): m128 = proc genSimdTableX86(): array[SimdArch, array[SimdPrimitives, NimNode]] = - let sse: array[SimdPrimitives, NimNode] = [ + let sse = [ simdSetZero: bindSym"mm_setzero_ps", simdBroadcast: bindSym"mm_set1_ps", simdLoadA: bindSym"mm_load_ps", @@ -88,7 +77,7 @@ proc genSimdTableX86(): array[SimdArch, array[SimdPrimitives, NimNode]] = avx_fma[simdFma] = bindSym"mm256_fmadd_ps" result = [ - ArchGeneric: default(array[SimdPrimitives, NimNode]), + ArchGeneric: genericPrimitives(), x86_SSE: sse, x86_AVX: avx, x86_AVX_FMA: avx_fma