Add fastmath flag to PTXCompilerTarget (#492)
Zentrik authored Aug 14, 2023
1 parent 15f0077 commit b649ef0
Showing 1 changed file with 6 additions and 2 deletions.
src/ptx.jl: 6 additions & 2 deletions
@@ -18,6 +18,8 @@ Base.@kwdef struct PTXCompilerTarget <: AbstractCompilerTarget
     blocks_per_sm::Union{Nothing,Int} = nothing
     maxregs::Union{Nothing,Int} = nothing
 
+    fastmath::Bool = Base.JLOptions().fast_math == 1
+
     # deprecated; remove with next major version
     exitable::Union{Nothing,Bool} = nothing
     unreachable::Union{Nothing,Bool} = nothing
@@ -33,6 +35,7 @@ function Base.hash(target::PTXCompilerTarget, h::UInt)
     h = hash(target.maxthreads, h)
     h = hash(target.blocks_per_sm, h)
     h = hash(target.maxregs, h)
+    h = hash(target.fastmath, h)
 
     h
 end
@@ -82,6 +85,7 @@ function Base.show(io::IO, @nospecialize(job::CompilerJob{PTXCompilerTarget}))
     job.config.target.maxthreads !== nothing && print(io, ", maxthreads=$(job.config.target.maxthreads)")
     job.config.target.blocks_per_sm !== nothing && print(io, ", blocks_per_sm=$(job.config.target.blocks_per_sm)")
     job.config.target.maxregs !== nothing && print(io, ", maxregs=$(job.config.target.maxregs)")
+    job.config.target.fastmath && print(io, ", fast math enabled")
 end
 
 const ptx_intrinsics = ("vprintf", "__assertfail", "malloc", "free")
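
As a quick aside, a minimal usage sketch of the new keyword (the cap keyword used here is assumed from GPUCompiler's PTX target and is not part of this diff):

    # Sketch only: request fast math for a single target instead of relying on
    # Julia's global --math-mode=fast setting.
    target = PTXCompilerTarget(; cap=v"7.5", fastmath=true)
    # Left unspecified, fastmath defaults to Base.JLOptions().fast_math == 1.
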
@@ -424,7 +428,7 @@ function nvvm_reflect!(fun::LLVM.Function)
         # handle possible cases
         # XXX: put some of these property in the compiler job?
         #      and/or first set the "nvvm-reflect-*" module flag like Clang does?
-        fast_math = Base.JLOptions().fast_math == 1
+        fast_math = current_job.config.target.fastmath
         # NOTE: we follow nvcc's --use_fast_math
         reflect_val = if reflect_arg == "__CUDA_FTZ"
             # single-precision denormals support
@@ -433,7 +437,7 @@ function nvvm_reflect!(fun::LLVM.Function)
             # single-precision floating-point division and reciprocals.
             ConstantInt(reflect_typ, fast_math ? 0 : 1)
         elseif reflect_arg == "__CUDA_PREC_SQRT"
-            # single-precision denormals support
+            # single-precision floating point square roots.
             ConstantInt(reflect_typ, fast_math ? 0 : 1)
         elseif reflect_arg == "__CUDA_FMAD"
             # contraction of floating-point multiplies and adds/subtracts into
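
For context, a small sketch of the reflect-value convention the hunk above follows. Only the __CUDA_PREC_DIV and __CUDA_PREC_SQRT answers are visible in this diff; the __CUDA_FTZ answer below is an assumption based on nvcc's --use_fast_math, which also enables flush-to-zero:

    # Sketch only: the __nvvm_reflect answers implied by the fastmath flag,
    # mirroring nvcc's --use_fast_math as noted in the code above.
    function reflect_value(reflect_arg::String, fast_math::Bool)
        if reflect_arg == "__CUDA_PREC_DIV"       # precise fp division/reciprocals
            fast_math ? 0 : 1                     # shown in the diff
        elseif reflect_arg == "__CUDA_PREC_SQRT"  # precise fp square roots
            fast_math ? 0 : 1                     # shown in the diff
        elseif reflect_arg == "__CUDA_FTZ"        # flush denormals to zero
            fast_math ? 1 : 0                     # assumed from --use_fast_math
        else
            0                                     # assumed default for other queries
        end
    end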
