From 867d6f55a13e3c423ff55af4ebe9e0ed6b1851b5 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Fri, 6 Oct 2023 10:34:38 +0200 Subject: [PATCH] Split codegen support in storage and arithmetic. --- src/bfloat16.jl | 41 +++++++++++++++++++++++++++++++---------- test/runtests.jl | 2 ++ 2 files changed, 33 insertions(+), 10 deletions(-) diff --git a/src/bfloat16.jl b/src/bfloat16.jl index 03ac981..1dff1da 100644 --- a/src/bfloat16.jl +++ b/src/bfloat16.jl @@ -15,11 +15,32 @@ import Base: isfinite, isnan, precision, iszero, eps, asinh, acosh, atanh, acsch, asech, acoth, bitstring, isinteger -# Julia 1.11 provides codegen support for BFloat16 -const codegen_support = if isdefined(Core, :BFloat16) && - Sys.ARCH in [:x86_64, :i686] +# LLVM 11 added support for BFloat16 in the IR; Julia 1.11 added support for generating +# code that uses the `bfloat` IR type, together with the necessary runtime functions. +# However, not all LLVM targets support `bfloat`. If the target can store/load BFloat16s +# (and supports synthesizing constants) we can use the `bfloat` IR type, otherwise we fall +# back to defining a primitive type that will be represented as an `i16`. If, in addition, +# the target supports BFloat16 arithmetic, we can use LLVM intrinsics. +# - x86: storage and arithmetic support in LLVM 15 +# - aarch64: storage support in LLVM 17 +const llvm_storage = if isdefined(Core, :BFloat16) + if Sys.ARCH in [:x86_64, :i686] && Base.libllvm_version >= v"15" + true + elseif Sys.ARCH == :aarch64 && Base.libllvm_version >= v"17" + true + else + false + end +else + false +end +const llvm_arithmetic = if llvm_storage using Core: BFloat16 - true + if Sys.ARCH in [:x86_64, :i686] && Base.libllvm_version >= v"15" + true + else + false + end else primitive type BFloat16 <: AbstractFloat 16 end false @@ -76,7 +97,7 @@ precision(::Type{BFloat16}) = 8 eps(::Type{BFloat16}) = Base.bitcast(BFloat16, 0x3c00) ## Rounding ## -if codegen_support +if llvm_arithmetic round(x::BFloat16, ::RoundingMode{:ToZero}) = Base.trunc_llvm(x) round(x::BFloat16, ::RoundingMode{:Down}) = Base.floor_llvm(x) round(x::BFloat16, ::RoundingMode{:Up}) = Base.ceil_llvm(x) @@ -118,7 +139,7 @@ Base.trunc(::Type{BFloat16}, x::Float32) = reinterpret(BFloat16, (reinterpret(UInt32, x) >> 16) % UInt16 ) -if codegen_support +if llvm_arithmetic BFloat16(x::Float32) = Base.fptrunc(BFloat16, x) BFloat16(x::Float64) = Base.fptrunc(BFloat16, x) @@ -147,7 +168,7 @@ else end # Conversion from Integer -if codegen_support +if llvm_arithmetic for st in (Int8, Int16, Int32, Int64) @eval begin BFloat16(x::($st)) = Base.sitofp(BFloat16, x) @@ -170,7 +191,7 @@ function Base.Float16(x::BFloat16) Float16(Float32(x)) end -if codegen_support +if llvm_arithmetic Base.Float32(x::BFloat16) = Base.fpext(Float32, x) Base.Float64(x::BFloat16) = Base.fpext(Float64, x) else @@ -186,7 +207,7 @@ else end # Basic arithmetic -if codegen_support +if llvm_arithmetic +(x::T, y::T) where {T<:BFloat16} = Base.add_float(x, y) -(x::T, y::T) where {T<:BFloat16} = Base.sub_float(x, y) *(x::T, y::T) where {T<:BFloat16} = Base.mul_float(x, y) @@ -238,7 +259,7 @@ end Base.widemul(x::BFloat16, y::BFloat16) = widen(x) * widen(y) # Truncation to integer types -if codegen_support +if llvm_arithmetic for Ti in (Int8, Int16, Int32, Int64) @eval begin Base.unsafe_trunc(::Type{$Ti}, x::BFloat16) = Base.fptosi($Ti, x) diff --git a/test/runtests.jl b/test/runtests.jl index 44f4bcd..ab7a504 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,7 @@ using Test, BFloat16s, Printf, Random +@info "Testing BFloat16s" BFloat16s.llvm_storage BFloat16s.llvm_arithmetic + @testset "comparisons" begin @test BFloat16(1) < BFloat16(2) @test BFloat16(1f0) < BFloat16(2f0)