Skip to content

Commit

Permalink
[Random] Add more comments and a helper function in Xoshiro code (#56144
Browse files Browse the repository at this point in the history
)

Follow up to #55994 and #55997. This should basically be a
non-functional change and I see no performance difference, but the
comments and the definition of a helper function should make the code
easier to follow (I initially struggled in #55997) and extend to other
types.
  • Loading branch information
giordano authored Oct 15, 2024
1 parent 9f92989 commit d09abe5
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 16 deletions.
21 changes: 13 additions & 8 deletions stdlib/Random/src/Xoshiro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -296,11 +296,16 @@ rand(r::Union{TaskLocalRNG, Xoshiro}, ::SamplerTrivial{UInt52Raw{UInt64}}) = ran
# Draw one full `UInt64` from `r` and discard its 12 lowest bits, leaving a 52-bit value.
function rand(r::Union{TaskLocalRNG, Xoshiro}, ::SamplerTrivial{UInt52{UInt64}})
    bits = rand(r, UInt64)
    return bits >>> 12
end
# Sampling `UInt104{UInt128}` forwards to the `UInt104Raw()` request on the same generator.
function rand(r::Union{TaskLocalRNG, Xoshiro}, ::SamplerTrivial{UInt104{UInt128}})
    return rand(r, UInt104Raw())
end

# Float16 in [0, 1): keep the top 11 bits of a random `UInt16` and scale them by 2^-11.
function rand(r::Union{TaskLocalRNG, Xoshiro}, ::SamplerTrivial{CloseOpen01{Float16}})
    kept = rand(r, UInt16) >>> 5
    return Float16(kept) * Float16(0x1.0p-11)
end

# Float32 in [0, 1): keep the top 24 bits of a random `UInt32` and scale them by 2^-24.
function rand(r::Union{TaskLocalRNG, Xoshiro}, ::SamplerTrivial{CloseOpen01{Float32}})
    kept = rand(r, UInt32) >>> 8
    return Float32(kept) * Float32(0x1.0p-24)
end

# Float64 in [0, 1): keep the top 53 bits of a random `UInt64` and scale them by 2^-53.
function rand(r::Union{TaskLocalRNG, Xoshiro}, ::SamplerTrivial{CloseOpen01_64})
    kept = rand(r, UInt64) >>> 11
    return Float64(kept) * 0x1.0p-53
end
for FT in (Float16, Float32, Float64)
    UT = Base.uinttype(FT)
    # Number of low-order bits to drop so that only `precision(FT)` bits remain, and the
    # power-of-two factor that maps the remaining integer into [0, 1).
    shift = 8 * sizeof(FT) - precision(FT)
    scale = FT(2) ^ -precision(FT)

    # Helper function: scale an unsigned integer to a floating point number of the same
    # size in the interval [0, 1). This is equivalent to, but more easily extensible than
    #     Float16(i >>>  5) * Float16(0x1.0p-11)
    #     Float32(i >>>  8) * Float32(0x1.0p-24)
    #     Float64(i >>> 11) * Float64(0x1.0p-53)
    @eval @inline function _uint2float(i::$(UT), ::Type{$(FT)})
        return $(FT)(i >>> $(shift)) * $(scale)
    end

    # Uniform `FT` in [0, 1): draw an unsigned integer of the same size and scale it.
    @eval function rand(r::Union{TaskLocalRNG, Xoshiro}, ::SamplerTrivial{CloseOpen01{$(FT)}})
        return _uint2float(rand(r, $(UT)), $(FT))
    end
end
21 changes: 13 additions & 8 deletions stdlib/Random/src/XoshiroSimd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
module XoshiroSimd
# Getting the xoroshiro RNG to reliably vectorize is somewhat of a hassle without Simd.jl.
import ..Random: rand!
using ..Random: TaskLocalRNG, rand, Xoshiro, CloseOpen01, UnsafeView, SamplerType, SamplerTrivial, getstate, setstate!
using ..Random: TaskLocalRNG, rand, Xoshiro, CloseOpen01, UnsafeView, SamplerType, SamplerTrivial, getstate, setstate!, _uint2float
using Base: BitInteger_types
using Base.Libc: memcpy
using Core.Intrinsics: llvmcall
Expand All @@ -30,7 +30,12 @@ simdThreshold(::Type{Bool}) = 640
Tuple{UInt64, Int64},
x, y)

# Float64 method: keep the top 53 bits of `x`, scale them by 2^-53 to a Float64 in [0, 1),
# and return that Float64's bit pattern as a `UInt64`.
@inline function _bits2float(x::UInt64, ::Type{Float64})
    scaled = Float64(x >>> 11) * 0x1.0p-53
    return reinterpret(UInt64, scaled)
end
# `_bits2float(x::UInt64, T)` takes `x::UInt64` as input and splits it into `N` parts, where
# `N = sizeof(UInt64) / sizeof(T)` (`N = 1` for `Float64`, `N = 2` for `Float32`, etc.). It
# truncates each part to the unsigned type of the same size as `T`, scales each of these
# numbers to a value of type `T` in the range [0, 1) with `_uint2float`, and then
# recomposes a `UInt64` from the bit patterns of all these parts.
# Float64 method: scale the whole of `x` with `_uint2float` and return the resulting
# Float64's bit pattern as a `UInt64`.
@inline function _bits2float(x::UInt64, ::Type{Float64})
    scaled = _uint2float(x, Float64)
    return reinterpret(UInt64, scaled)
end
# Float32 method: split `x` into its upper and lower 32-bit halves, scale each half to a
# Float32 with `_uint2float`, and pack the two Float32 bit patterns back into one `UInt64`
# (value from the upper half in the high 32 bits, value from the lower half in the low 32).
@inline function _bits2float(x::UInt64, ::Type{Float32})
#=
# this implementation uses more high bits, but is harder to vectorize
Expand All @@ -40,19 +45,19 @@ simdThreshold(::Type{Bool}) = 640
=#
ui = (x>>>32) % UInt32
li = x % UInt32
# Explicit form of the scaling; overwritten by the `_uint2float` assignments below.
u = Float32(ui >>> 8) * Float32(0x1.0p-24)
l = Float32(li >>> 8) * Float32(0x1.0p-24)
u = _uint2float(ui, Float32)
# The low lane must be derived from `li`: passing `ui` here would duplicate the high lane
# and throw away the low 32 bits of `x`.
l = _uint2float(li, Float32)
(UInt64(reinterpret(UInt32, u)) << 32) | UInt64(reinterpret(UInt32, l))
end
# Float16 method: split `x` into four 16-bit lanes (highest lane first), scale each lane to a
# Float16, and pack the four Float16 bit patterns into a `UInt64` in the same lane order.
@inline function _bits2float(x::UInt64, ::Type{Float16})
i1 = (x>>>48) % UInt16
i2 = (x>>>32) % UInt16
i3 = (x>>>16) % UInt16
i4 = x % UInt16
# Explicit form of the scaling: keep the top 11 bits of each lane and multiply by 2^-11.
# These values are overwritten by the `_uint2float` assignments that follow.
# NOTE(review): these look like the pre-change lines of the diff — confirm.
f1 = Float16(i1 >>> 5) * Float16(0x1.0p-11)
f2 = Float16(i2 >>> 5) * Float16(0x1.0p-11)
f3 = Float16(i3 >>> 5) * Float16(0x1.0p-11)
f4 = Float16(i4 >>> 5) * Float16(0x1.0p-11)
# Helper-based form; these are the values that reach the `return`.
f1 = _uint2float(i1, Float16)
f2 = _uint2float(i2, Float16)
f3 = _uint2float(i3, Float16)
f4 = _uint2float(i4, Float16)
# Reassemble: lane 1 occupies bits 48-63, lane 2 bits 32-47, lane 3 bits 16-31, lane 4 bits 0-15.
return (UInt64(reinterpret(UInt16, f1)) << 48) | (UInt64(reinterpret(UInt16, f2)) << 32) | (UInt64(reinterpret(UInt16, f3)) << 16) | UInt64(reinterpret(UInt16, f4))
end

Expand Down

0 comments on commit d09abe5

Please sign in to comment.