Skip to content

Commit

Permalink
fix: eltype fix for wrapper types
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jul 13, 2024
1 parent a60f5ee commit f2e563a
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 17 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LuxLib"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.3.30"
version = "0.3.31-DEV"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
1 change: 1 addition & 0 deletions ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ end

function CRC.rrule(::typeof(LuxLib.batchnorm_cudnn), running_mean, running_var, scale,
bias, x, momentum, epsilon, t::Val{training}) where {training}
# TODO: Transition this to an error in the future
!training && @warn "`training=Val(false)` but gradient was called." maxlog=1
y, xmean, xivar = LuxLib.batchnorm_cudnn(
running_mean, running_var, scale, bias, x, momentum, epsilon, t)
Expand Down
24 changes: 12 additions & 12 deletions src/impl/fused_conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,28 +48,28 @@ function __conv!(::Type{<:AbstractLuxGPUDevice}, y::AbstractArray{yT, N},
__materialize_subarray(_ofeltype_array(yT, weight)), cdims)
end

function __conv(x_::AbstractArray, weight_::AbstractArray, cdims::ConvDims)
x, weight = __get_conv_input_weight(
get_device_type((x_, weight_)), eltype(x_), eltype(weight_), x_, weight_)
function __conv(
x_::AbstractArray{xT}, weight_::AbstractArray{wT}, cdims::ConvDims) where {xT, wT}
x, weight = __get_conv_input_weight(get_device_type((x_, weight_)), xT, wT, x_, weight_)
return conv(x, weight, cdims)
end

function __∇conv_data(x_::AbstractArray, weight_::AbstractArray, cdims::ConvDims)
x, weight = __get_conv_input_weight(
get_device_type((x_, weight_)), eltype(x_), eltype(weight_), x_, weight_)
function __∇conv_data(
x_::AbstractArray{xT}, weight_::AbstractArray{wT}, cdims::ConvDims) where {xT, wT}
x, weight = __get_conv_input_weight(get_device_type((x_, weight_)), xT, wT, x_, weight_)
return ∇conv_data(x, weight, cdims)
end

function __∇conv_filter(x_::AbstractArray, y_::AbstractArray, cdims::ConvDims)
x, y = __get_conv_input_weight(
get_device_type((x_, y_)), eltype(x_), eltype(y_), x_, y_)
function __∇conv_filter(
x_::AbstractArray{xT}, y_::AbstractArray{yT}, cdims::ConvDims) where {xT, yT}
x, y = __get_conv_input_weight(get_device_type((x_, y_)), xT, yT, x_, y_)
return ∇conv_filter(x, y, cdims)
end

function __conv_bias_act(x_::AbstractArray, weight_::AbstractArray, cdims::ConvDims,
bias_::Optional{<:AbstractArray}, act::F) where {F}
function __conv_bias_act(x_::AbstractArray{xT}, weight_::AbstractArray{wT}, cdims::ConvDims,
bias_::Optional{<:AbstractArray}, act::F) where {xT, wT, F}
dev = get_device_type((x_, weight_, bias_))
x, weight = __get_conv_input_weight(dev, eltype(x_), eltype(weight_), x_, weight_)
x, weight = __get_conv_input_weight(dev, xT, wT, x_, weight_)
bias = _ofeltype_array(eltype(x), bias_)
return __conv_bias_act_impl(dev, x, weight, cdims, bias, act)
end
Expand Down
7 changes: 4 additions & 3 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
return ntuple(i -> i == N - 1 ? ly : 1, N)
elseif N > 2 && ly == sx[N - 1] * sx[N - 2]
return ntuple(i -> i == (N - 1) || i == (N - 2) ? sx[i] : 1, N)
else
throw(ArgumentError("Invalid Dimensions!"))
end
throw(ArgumentError("Invalid Dimensions!"))
end

CRC.@non_differentiable _get_reshape_dims(::Any...)
Expand Down Expand Up @@ -194,6 +193,8 @@ __value(::Type{T}) where {T <: Number} = T

__value(x::ForwardDiff.Dual) = ForwardDiff.value(x)
__value(x::AbstractArray{<:ForwardDiff.Dual}) = ForwardDiff.value.(x)
__value(::Type{<:ForwardDiff.Dual{Tag, T}}) where {Tag, T} = LuxLib.__value(T)
__value(::Type{<:ForwardDiff.Dual{Tag, T}}) where {Tag, T} = __value(T)

__value(::Nothing) = nothing

__aos_to_soa(x::AbstractArray) = x # FIXME: Upstream this to ArrayInterface.jl
5 changes: 4 additions & 1 deletion test/others/qa_tests.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
@testitem "Aqua: Quality Assurance" tags=[:others] begin
using Aqua

Aqua.test_all(LuxLib)
Aqua.test_all(LuxLib; ambiguities=false, piracies=false)
Aqua.test_ambiguities(
LuxLib; recursive=false, exclude=[conv, ∇conv_data, ∇conv_filter, depthwiseconv])
Aqua.test_piracies(LuxLib; treat_as_own=[conv, ∇conv_data, ∇conv_filter, depthwiseconv])
end

@testitem "Explicit Imports" tags=[:others] begin
Expand Down

0 comments on commit f2e563a

Please sign in to comment.