diff --git a/ext/FluxAMDGPUExt/conv.jl b/ext/FluxAMDGPUExt/conv.jl
index 29a6042c81..681a38db6d 100644
--- a/ext/FluxAMDGPUExt/conv.jl
+++ b/ext/FluxAMDGPUExt/conv.jl
@@ -8,7 +8,7 @@ function Flux.conv_transpose_dims(c::ConvTranspose, x::T) where T <: ROCArray
     # Calculate size of "input", from ∇conv_data()'s perspective...
     combined_pad = (c.pad[1:2:end] .+ c.pad[2:2:end])
     I = (size(x)[1:end - 2] .- 1) .* c.stride .+ 1 .+
-        (size(c.weight)[1:end - 2] .- 1) .* c.dilation .- combined_pad
+        (size(c.weight)[1:end - 2] .- 1) .* c.dilation .- combined_pad .+ c.outpad
     C_in = size(c.weight)[end - 1] * c.groups
     batch_size = size(x)[end]
 
diff --git a/ext/FluxAMDGPUExt/functor.jl b/ext/FluxAMDGPUExt/functor.jl
index f025dd65ed..8dfb6c0d5d 100644
--- a/ext/FluxAMDGPUExt/functor.jl
+++ b/ext/FluxAMDGPUExt/functor.jl
@@ -81,6 +81,9 @@ function _amd(id::Union{Nothing, Int}, x)
     fmap(x -> Adapt.adapt(FluxAMDGPUAdaptor(id), x), x; exclude=_exclude)
 end
 
+_other_args(m::Conv) = [m.stride, m.pad, m.dilation, m.groups]
+_other_args(m::ConvTranspose) = [m.stride, m.pad, m.outpad, m.dilation, m.groups]
+
 # CPU -> GPU
 function Adapt.adapt_structure(to::FluxAMDGPUAdaptor, m::CPU_CONV)
 
@@ -89,7 +92,7 @@ function Adapt.adapt_structure(to::FluxAMDGPUAdaptor, m::CPU_CONV)
         Adapt.adapt(to, m.σ),
         Adapt.adapt(to, flipped_weight),
         Adapt.adapt(to, m.bias),
-        m.stride, m.pad, m.dilation, m.groups)
+        _other_args(m)...)
 end
 
 # Don't adapt again.
@@ -102,7 +105,7 @@ function Adapt.adapt_structure(to::FluxCPUAdaptor, m::AMDGPU_CONV)
     dims = ntuple(i -> i, ndims(m.weight) - 2)
     _conv_basetype(m)(
         Adapt.adapt(to, m.σ), reverse(Adapt.adapt(to, m.weight); dims),
-        Adapt.adapt(to, m.bias), m.stride, m.pad, m.dilation, m.groups)
+        Adapt.adapt(to, m.bias), _other_args(m)...)
 end
 
 function Flux.get_device(::Val{:AMDGPU}, id::Int) # id should start from 0
diff --git a/test/ext_amdgpu/basic.jl b/test/ext_amdgpu/basic.jl
index b7bbb286e5..86b1cccf37 100644
--- a/test/ext_amdgpu/basic.jl
+++ b/test/ext_amdgpu/basic.jl
@@ -46,6 +46,18 @@ end
     end
 end
 
+@testset "ConvTranspose output padding" begin
+    x = randn(Float32, 10, 11, 3, 2)
+    m = ConvTranspose((3, 5), 3=>6, stride=3, outpad=(1, 0))
+    md, xd = Flux.gpu.((m, x))
+    @test size(m(x)) == size(md(xd))
+
+    x = randn(Float32, 10, 11, 12, 3, 2)
+    m = ConvTranspose((3, 5, 3), 3=>6, stride=3, outpad=(1, 0, 1))
+    md, xd = Flux.gpu.((m, x))
+    @test size(m(x)) == size(md(xd))
+end
+
 @testset "Chain(Conv)" begin
     m = Chain(Conv((3, 3), 3 => 3)) |> f32
     x = rand(Float32, 10, 10, 3, 2)