From ea75484c1289f7ce197b175d1745633898322aec Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Jul 2024 21:39:40 -0700 Subject: [PATCH] fix: explicit imports --- ext/LuxLibForwardDiffExt.jl | 4 ++-- src/api/conv.jl | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/ext/LuxLibForwardDiffExt.jl b/ext/LuxLibForwardDiffExt.jl index 24622cdc..20ca3054 100644 --- a/ext/LuxLibForwardDiffExt.jl +++ b/ext/LuxLibForwardDiffExt.jl @@ -2,7 +2,7 @@ module LuxLibForwardDiffExt using ForwardDiff: ForwardDiff using LuxLib: LuxLib -using LuxDeviceUtils: AbstractLuxDevice, AbstractLuxGPUDevice +using LuxDeviceUtils: AbstractLuxGPUDevice using NNlib: NNlib LuxLib.__has_dual(::ForwardDiff.Dual) = true @@ -80,6 +80,6 @@ end LuxLib.__value(x::ForwardDiff.Dual) = ForwardDiff.value(x) LuxLib.__value(x::AbstractArray{<:ForwardDiff.Dual}) = ForwardDiff.value.(x) -LuxLib.__value(::Type{<:ForwardDiff.Dual{T}}) where {T} = LuxLib.__value(T) +LuxLib.__value(::Type{<:ForwardDiff.Dual{Tag, T}}) where {Tag, T} = LuxLib.__value(T) end diff --git a/src/api/conv.jl b/src/api/conv.jl index 27223945..f29d3618 100644 --- a/src/api/conv.jl +++ b/src/api/conv.jl @@ -38,7 +38,8 @@ for (check, fop) in ( (false, :_fused_conv_bias_activation_impl), (true, :_generic_conv_bias_activation)) @eval function fused_conv_bias_activation( σ::F, ::Val{$(check)}, weight::AbstractArray{<:Number, N}, - x::AbstractArray{<:Number, N}, b::Nothing, cdims::ConvDims) where {F, N} + x::AbstractArray{<:Number, N}, + b::Optional{<:AbstractArray}, cdims::ConvDims) where {F, N} return $(fop)(σ, weight, x, b, cdims) end end