Skip to content

Commit

Permalink
GenericCPU now also uses the SIMDTable (+ workaround nim-lang/Nim#11668)
Browse files Browse the repository at this point in the history
  • Loading branch information
mratsim committed Jul 5, 2019
1 parent 20026fb commit 0b4ac7d
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 33 deletions.
24 changes: 6 additions & 18 deletions laser/lux/ast/ast_codegen.nim
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,9 @@ proc codegen*(
of Input:
return params[ast.symId]
of IntImm:
if arch == ArchGeneric:
return newLit(ast.intVal)
else:
return newCall(SimdTable[arch][simdBroadcast], newLit(ast.intVal))
return newCall(SimdTable[arch][simdBroadcast], newLit(ast.intVal))
of FloatImm:
if arch == ArchGeneric:
return newLit(ast.floatVal)
else:
return newCall(SimdTable[arch][simdBroadcast], newLit(ast.intVal))
return newCall(SimdTable[arch][simdBroadcast], newLit(ast.intVal))
of Output, LVal:
let sym = newIdentNode(ast.symLVal)
if ast.id in visited:
Expand Down Expand Up @@ -92,16 +86,10 @@ proc codegen*(
stmts.add lhsStmt
stmts.add rhsStmt

if arch == ArchGeneric:
case ast.kind
of Add: callStmt.add newidentNode"+"
of Mul: callStmt.add newidentNode"*"
else: raise newException(ValueError, "Unreachable code")
else:
case ast.kind
of Add: callStmt.add SimdTable[arch][simdAdd]
of Mul: callStmt.add SimdTable[arch][simdMul]
else: raise newException(ValueError, "Unreachable code")
case ast.kind
of Add: callStmt.add SimdTable[arch][simdAdd]
of Mul: callStmt.add SimdTable[arch][simdMul]
else: raise newException(ValueError, "Unreachable code")

callStmt.add lhs
callStmt.add rhs
Expand Down
9 changes: 8 additions & 1 deletion laser/lux/platforms.nim
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,13 @@
# Distributed under the Apache v2 License (license terms are at http://www.apache.org/licenses/LICENSE-2.0).
# This file may not be copied, modified, or distributed except according to those terms.

import ./platforms/platform_common

when defined(i386) or defined(x86_64):
import ./platforms/platform_x86
export platform_x86

export SimdPrimitives
export SimdArch, SimdAlignment, SimdTable

func elemsPerVector*(arch: SimdArch, T: typedesc): int =
SimdWidth[arch] div sizeof(T)
43 changes: 43 additions & 0 deletions laser/lux/platforms/platform_common.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Laser
# Copyright (c) 2018 Mamy André-Ratsimbazafy
# Distributed under the Apache v2 License (license terms are at http://www.apache.org/licenses/LICENSE-2.0).
# This file may not be copied, modified, or distributed except according to those terms.

import
# Standard library
macros

type
SimdPrimitives* = enum
simdSetZero
simdBroadcast
simdLoadA
simdLoadU
simdStoreA
simdStoreU
simdAdd
simdMul
simdFma
simdType

template noop*(scalar: untyped): untyped =
scalar

template unreachable*(): untyped =
{.error: "Unreachable".}

proc genericPrimitives*: array[SimdPrimitives, NimNode] =
# Use a proc instead of a const
# to workaround https://github.com/nim-lang/Nim/issues/11668
[
simdSetZero: bindSym"unreachable",
simdBroadcast: bindSym"noop",
simdLoadA: bindSym"unreachable",
simdLoadU: bindSym"unreachable",
simdStoreA: bindSym"unreachable",
simdStoreU: bindSym"unreachable",
simdAdd: bindSym"+",
simdMul: bindSym"*",
simdFma: bindSym"unreachable",
simdType: bindSym"unreachable"
]
17 changes: 3 additions & 14 deletions laser/lux/platforms/platform_x86.nim
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,10 @@ import
# Standard library
macros,
# Internal
./platform_common,
../../simd

type
SimdPrimitives* = enum
simdSetZero
simdBroadcast
simdLoadA
simdLoadU
simdStoreA
simdStoreU
simdAdd
simdMul
simdFma
simdType

SimdArch* = enum
ArchGeneric,
x86_SSE,
Expand Down Expand Up @@ -58,7 +47,7 @@ template avx_fma_fallback(a, b, c: m128): m128 =

proc genSimdTableX86(): array[SimdArch, array[SimdPrimitives, NimNode]] =

let sse: array[SimdPrimitives, NimNode] = [
let sse = [
simdSetZero: bindSym"mm_setzero_ps",
simdBroadcast: bindSym"mm_set1_ps",
simdLoadA: bindSym"mm_load_ps",
Expand Down Expand Up @@ -88,7 +77,7 @@ proc genSimdTableX86(): array[SimdArch, array[SimdPrimitives, NimNode]] =
avx_fma[simdFma] = bindSym"mm256_fmadd_ps"

result = [
ArchGeneric: default(array[SimdPrimitives, NimNode]),
ArchGeneric: genericPrimitives(),
x86_SSE: sse,
x86_AVX: avx,
x86_AVX_FMA: avx_fma
Expand Down

0 comments on commit 0b4ac7d

Please sign in to comment.