Skip to content

Commit

Permalink
Merge pull request #71 from milankl/mk/rand
Browse files Browse the repository at this point in the history
rand, randn, zeros, ones constructors
  • Loading branch information
milankl authored Jul 14, 2023
2 parents 448cd74 + 3495b1a commit 5e053ec
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 1 deletion.
1 change: 0 additions & 1 deletion src/StochasticRounding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,4 @@ module StochasticRounding
include("general.jl")
include("promotion.jl")
include("conversions.jl")

end
9 changes: 9 additions & 0 deletions src/bfloat16sr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ Base.isnan(x::BFloat16sr) = isnan(Float32(x))
Base.precision(::Type{BFloat16sr}) = 8
Base.one(::Type{BFloat16sr}) = reinterpret(BFloat16sr,0x3f80)
Base.zero(::Type{BFloat16sr}) = reinterpret(BFloat16sr,0x0000)
Base.one(::BFloat16sr) = one(BFloat16sr)
Base.zero(::BFloat16sr) = zero(BFloat16sr)
Base.rand(::Type{BFloat16sr}) = reinterpret(BFloat16sr,rand(BFloat16))
Base.randn(::Type{BFloat16sr}) = reinterpret(BFloat16sr,randn(BFloat16))

const InfB16sr = reinterpret(BFloat16sr, 0x7f80)
const NaNB16sr = reinterpret(BFloat16sr, 0x7fc0)
Expand Down Expand Up @@ -131,6 +135,11 @@ for t in (Int8, Int16, Int32, Int64, Int128, UInt8, UInt16, UInt32, UInt64, UInt
@eval Base.promote_rule(::Type{BFloat16sr}, ::Type{$t}) = BFloat16sr
end

Base.rand(::Type{BFloat16sr},dims::Integer...) = reinterpret.(BFloat16sr,rand(BFloat16,dims...))
Base.randn(::Type{BFloat16sr},dims::Integer...) = reinterpret.(BFloat16sr,randn(BFloat16,dims...))
Base.zeros(::Type{BFloat16sr},dims::Integer...) = reinterpret.(BFloat16sr,zeros(BFloat16,dims...))
Base.ones(::Type{BFloat16sr},dims::Integer...) = reinterpret.(BFloat16sr,ones(BFloat16,dims...))

Base.show(io::IO, x::BFloat16sr) = show(io,BFloat16(x))
Base.bitstring(x::BFloat16sr) = bitstring(reinterpret(UInt16,x))

Expand Down
8 changes: 8 additions & 0 deletions src/float16sr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ Base.one(::Type{Float16sr}) = reinterpret(Float16sr,0x3c00)
Base.zero(::Type{Float16sr}) = reinterpret(Float16sr,0x0000)
Base.one(::Float16sr) = one(Float16sr)
Base.zero(::Float16sr) = zero(Float16sr)
Base.rand(::Type{Float16sr}) = reinterpret(Float16sr,rand(Float16))
Base.randn(::Type{Float16sr}) = reinterpret(Float16sr,randn(Float16))

Base.typemin(::Type{Float16sr}) = Float16sr(typemin(Float16))
Base.typemax(::Type{Float16sr}) = Float16sr(typemax(Float16))
Expand Down Expand Up @@ -161,6 +163,12 @@ function Base.sincos(x::Float16sr)
return (Float16_stochastic_round(s),Float16_stochastic_round(c))
end

# array generators
Base.rand(::Type{Float16sr},dims::Integer...) = reinterpret.(Float16sr,rand(Float16,dims...))
Base.randn(::Type{Float16sr},dims::Integer...) = reinterpret.(Float16sr,randn(Float16,dims...))
Base.zeros(::Type{Float16sr},dims::Integer...) = reinterpret.(Float16sr,zeros(Float16,dims...))
Base.ones(::Type{Float16sr},dims::Integer...) = reinterpret.(Float16sr,ones(Float16,dims...))

Base.show(io::IO, x::Float16sr) = show(io,Float16(x))
Base.bitstring(x::Float16sr) = bitstring(reinterpret(UInt16,x))

Expand Down
8 changes: 8 additions & 0 deletions src/float32sr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ Base.one(::Type{Float32sr}) = reinterpret(Float32sr,one(Float32))
Base.zero(::Type{Float32sr}) = reinterpret(Float32sr,0x0000_0000)
Base.one(::Float32sr) = one(Float32sr)
Base.zero(::Float32sr) = zero(Float32sr)
Base.rand(::Type{Float32sr}) = reinterpret(Float32sr,rand(Float32))
Base.randn(::Type{Float32sr}) = reinterpret(Float32sr,randn(Float32))

Base.typemin(::Type{Float32sr}) = Float32sr(typemin(Float32))
Base.typemax(::Type{Float32sr}) = Float32sr(typemax(Float32))
Expand Down Expand Up @@ -160,6 +162,12 @@ function Base.sincos(x::Float32sr)
return (Float32_stochastic_round(s),Float32_stochastic_round(c))
end

# array generators
Base.rand(::Type{Float32sr},dims::Integer...) = reinterpret.(Float32sr,rand(Float32,dims...))
Base.randn(::Type{Float32sr},dims::Integer...) = reinterpret.(Float32sr,randn(Float32,dims...))
Base.zeros(::Type{Float32sr},dims::Integer...) = reinterpret.(Float32sr,zeros(Float32,dims...))
Base.ones(::Type{Float32sr},dims::Integer...) = reinterpret.(Float32sr,ones(Float32,dims...))

Base.show(io::IO, x::Float32sr) = show(io,Float32(x))
Base.bitstring(x::Float32sr) = bitstring(reinterpret(UInt32,x))

Expand Down
7 changes: 7 additions & 0 deletions src/float64sr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ Base.one(::Type{Float64sr}) = reinterpret(Float64sr,one(Float64))
Base.zero(::Type{Float64sr}) = reinterpret(Float64sr,zero(Float64))
Base.one(::Float64sr) = one(Float64sr)
Base.zero(::Float64sr) = zero(Float64sr)
Base.rand(::Type{Float64sr}) = reinterpret(Float64sr,rand(Float64))
Base.randn(::Type{Float64sr}) = reinterpret(Float64sr,randn(Float64))

Base.typemin(::Type{Float64sr}) = Float64sr(typemin(Float64))
Base.typemax(::Type{Float64sr}) = Float64sr(typemax(Float64))
Expand Down Expand Up @@ -114,6 +116,11 @@ for func in (:atan,:hypot)
end
end

# array generators
Base.rand(::Type{Float64sr},dims::Integer...) = reinterpret.(Float64sr,rand(Float64,dims...))
Base.randn(::Type{Float64sr},dims::Integer...) = reinterpret.(Float64sr,randn(Float64,dims...))
Base.zeros(::Type{Float64sr},dims::Integer...) = reinterpret.(Float64sr,zeros(Float64,dims...))
Base.ones(::Type{Float64sr},dims::Integer...) = reinterpret.(Float64sr,ones(Float64,dims...))

# Showing
Base.show(io::IO, x::Float64sr) = show(io,Float64(x))
Expand Down

0 comments on commit 5e053ec

Please sign in to comment.