Skip to content

Commit

Permalink
[WebAssembly] Implement prototype f16x8.splat instruction. (llvm#93228)
Browse files Browse the repository at this point in the history
Adds a builtin and intrinsic for the f16x8.splat instruction.

Specified at:

https://github.com/WebAssembly/half-precision/blob/29a9b9462c9285d4ccc1a5dc39214ddfd1892658/proposals/half-precision/Overview.md

Note: the current spec has f16x8.splat as opcode 0x123, but this is
incorrect and will be changed to 0x120 soon.
  • Loading branch information
brendandahl authored May 24, 2024
1 parent d1d9545 commit 09c5525
Show file tree
Hide file tree
Showing 11 changed files with 54 additions and 4 deletions.
1 change: 1 addition & 0 deletions clang/include/clang/Basic/BuiltinsWebAssembly.def
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ TARGET_BUILTIN(__builtin_wasm_relaxed_dot_bf16x8_add_f32_f32x4, "V4fV8UsV8UsV4f"
// Half-Precision (fp16)
TARGET_BUILTIN(__builtin_wasm_loadf16_f32, "fh*", "nU", "half-precision")
TARGET_BUILTIN(__builtin_wasm_storef16_f32, "vfh*", "n", "half-precision")
TARGET_BUILTIN(__builtin_wasm_splat_f16x8, "V8hf", "nc", "half-precision")

// Reference Types builtins
// Some builtins are custom type-checked - see 't' as part of the third argument,
Expand Down
3 changes: 3 additions & 0 deletions clang/lib/Basic/Targets/WebAssembly.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ class LLVM_LIBRARY_VISIBILITY WebAssemblyTargetInfo : public TargetInfo {

StringRef getABI() const override;
bool setABI(const std::string &Name) override;
bool useFP16ConversionIntrinsics() const override {
return !HasHalfPrecision;
}

protected:
void getTargetDefines(const LangOptions &Opts,
Expand Down
5 changes: 5 additions & 0 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21230,6 +21230,11 @@ Value *CodeGenFunction::EmitWebAssemblyBuiltinExpr(unsigned BuiltinID,
Function *Callee = CGM.getIntrinsic(Intrinsic::wasm_storef16_f32);
return Builder.CreateCall(Callee, {Val, Addr});
}
case WebAssembly::BI__builtin_wasm_splat_f16x8: {
Value *Val = EmitScalarExpr(E->getArg(0));
Function *Callee = CGM.getIntrinsic(Intrinsic::wasm_splat_f16x8);
return Builder.CreateCall(Callee, {Val});
}
case WebAssembly::BI__builtin_wasm_table_get: {
assert(E->getArg(0)->getType()->isArrayType());
Value *Table = EmitArrayToPointerDecay(E->getArg(0)).emitRawPointer(*this);
Expand Down
6 changes: 6 additions & 0 deletions clang/test/CodeGen/builtins-wasm.c
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ typedef unsigned char u8x16 __attribute((vector_size(16)));
typedef unsigned short u16x8 __attribute((vector_size(16)));
typedef unsigned int u32x4 __attribute((vector_size(16)));
typedef unsigned long long u64x2 __attribute((vector_size(16)));
typedef __fp16 f16x8 __attribute((vector_size(16)));
typedef float f32x4 __attribute((vector_size(16)));
typedef double f64x2 __attribute((vector_size(16)));

Expand Down Expand Up @@ -813,6 +814,11 @@ void store_f16_f32(float val, __fp16 *addr) {
// WEBASSEMBLY-NEXT: ret
}

f16x8 splat_f16x8(float a) {
// WEBASSEMBLY: %0 = tail call <8 x half> @llvm.wasm.splat.f16x8(float %a)
// WEBASSEMBLY-NEXT: ret <8 x half> %0
return __builtin_wasm_splat_f16x8(a);
}
__externref_t externref_null() {
return __builtin_wasm_ref_null_extern();
// WEBASSEMBLY: tail call ptr addrspace(10) @llvm.wasm.ref.null.extern()
Expand Down
4 changes: 4 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsWebAssembly.td
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,10 @@ def int_wasm_storef16_f32:
[llvm_float_ty, llvm_ptr_ty],
[IntrWriteMem, IntrArgMemOnly],
"", [SDNPMemOperand]>;
def int_wasm_splat_f16x8:
DefaultAttrsIntrinsic<[llvm_v8f16_ty],
[llvm_float_ty],
[IntrNoMem, IntrSpeculatable]>;


//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ wasm::ValType WebAssembly::toValType(MVT Type) {
case MVT::v8i16:
case MVT::v4i32:
case MVT::v2i64:
case MVT::v8f16:
case MVT::v4f32:
case MVT::v2f64:
return wasm::ValType::V128;
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
addRegisterClass(MVT::v2i64, &WebAssembly::V128RegClass);
addRegisterClass(MVT::v2f64, &WebAssembly::V128RegClass);
}
if (Subtarget->hasHalfPrecision()) {
addRegisterClass(MVT::v8f16, &WebAssembly::V128RegClass);
}
if (Subtarget->hasReferenceTypes()) {
addRegisterClass(MVT::externref, &WebAssembly::EXTERNREFRegClass);
addRegisterClass(MVT::funcref, &WebAssembly::FUNCREFRegClass);
Expand Down
15 changes: 15 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ multiclass RELAXED_I<dag oops_r, dag iops_r, dag oops_s, dag iops_s,
asmstr_s, simdop, HasRelaxedSIMD>;
}

multiclass HALF_PRECISION_I<dag oops_r, dag iops_r, dag oops_s, dag iops_s,
list<dag> pattern_r, string asmstr_r = "",
string asmstr_s = "", bits<32> simdop = -1> {
defm "" : ABSTRACT_SIMD_I<oops_r, iops_r, oops_s, iops_s, pattern_r, asmstr_r,
asmstr_s, simdop, HasHalfPrecision>;
}


defm "" : ARGUMENT<V128, v16i8>;
defm "" : ARGUMENT<V128, v8i16>;
Expand Down Expand Up @@ -591,6 +598,14 @@ defm "" : Splat<I64x2, 18>;
defm "" : Splat<F32x4, 19>;
defm "" : Splat<F64x2, 20>;

// Half values are not fully supported so an intrinsic is used instead of a
// regular Splat pattern as above.
defm SPLAT_F16x8 :
HALF_PRECISION_I<(outs V128:$dst), (ins F32:$x),
(outs), (ins),
[(set (v8f16 V128:$dst), (int_wasm_splat_f16x8 F32:$x))],
"f16x8.splat\t$dst, $x", "f16x8.splat", 0x120>;

// scalar_to_vector leaves high lanes undefined, so can be a splat
foreach vec = AllVecs in
def : Pat<(vec.vt (scalar_to_vector (vec.lane_vt vec.lane_rc:$x))),
Expand Down
5 changes: 3 additions & 2 deletions llvm/lib/Target/WebAssembly/WebAssemblyRegisterInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ def I32 : WebAssemblyRegClass<[i32], 32, (add FP32, SP32, I32_0)>;
def I64 : WebAssemblyRegClass<[i64], 64, (add FP64, SP64, I64_0)>;
def F32 : WebAssemblyRegClass<[f32], 32, (add F32_0)>;
def F64 : WebAssemblyRegClass<[f64], 64, (add F64_0)>;
def V128 : WebAssemblyRegClass<[v4f32, v2f64, v2i64, v4i32, v16i8, v8i16], 128,
(add V128_0)>;
def V128 : WebAssemblyRegClass<[v8f16, v4f32, v2f64, v2i64, v4i32, v16i8,
v8i16],
128, (add V128_0)>;
def FUNCREF : WebAssemblyRegClass<[funcref], 0, (add FUNCREF_0)>;
def EXTERNREF : WebAssemblyRegClass<[externref], 0, (add EXTERNREF_0)>;
12 changes: 10 additions & 2 deletions llvm/test/CodeGen/WebAssembly/half-precision.ll
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
; RUN: llc < %s --mtriple=wasm32-unknown-unknown -asm-verbose=false -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+half-precision | FileCheck %s
; RUN: llc < %s --mtriple=wasm64-unknown-unknown -asm-verbose=false -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+half-precision | FileCheck %s
; RUN: llc < %s --mtriple=wasm32-unknown-unknown -asm-verbose=false -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+half-precision,+simd128 | FileCheck %s
; RUN: llc < %s --mtriple=wasm64-unknown-unknown -asm-verbose=false -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+half-precision,+simd128 | FileCheck %s

declare float @llvm.wasm.loadf32.f16(ptr)
declare void @llvm.wasm.storef16.f32(float, ptr)
Expand All @@ -19,3 +19,11 @@ define void @stf16_32(float %v, ptr %p) {
tail call void @llvm.wasm.storef16.f32(float %v, ptr %p)
ret void
}

; CHECK-LABEL: splat_v8f16:
; CHECK: f16x8.splat $push0=, $0
; CHECK-NEXT: return $pop0
define <8 x half> @splat_v8f16(float %x) {
%v = call <8 x half> @llvm.wasm.splat.f16x8(float %x)
ret <8 x half> %v
}
3 changes: 3 additions & 0 deletions llvm/test/MC/WebAssembly/simd-encodings.s
Original file line number Diff line number Diff line change
Expand Up @@ -845,4 +845,7 @@ main:
# CHECK: f32.store_f16 32 # encoding: [0xfc,0x31,0x01,0x20]
f32.store_f16 32

# CHECK: f16x8.splat # encoding: [0xfd,0xa0,0x02]
f16x8.splat

end_function

0 comments on commit 09c5525

Please sign in to comment.