
Stack is not differentiable #1423

Open
bicycle1885 opened this issue May 23, 2023 · 3 comments
Labels
ChainRules adjoint -> rrule, and further integration

Comments

@bicycle1885

I want to use the `stack` function introduced in Julia 1.9 in my model, but Flux.jl (or its AD backend) cannot differentiate it.

```julia
using Flux

nn = Dense(3 => 2)
x = randn(Float32, 3, 5)

slicestack(x) = stack((x for x in eachslice(x, dims = 1)), dims = 1)
slicecat(x) = reduce(vcat, (x' for x in eachslice(x, dims = 1)))
@assert slicestack(nn(x)) == slicecat(nn(x))
Flux.withgradient(nn -> sum(slicecat(nn(x))), nn)  # this works
Flux.withgradient(nn -> sum(slicestack(nn(x))), nn)  # but this doesn't
```

error (truncated):

```
kenta@KS-MBP ~/tmp> julia stack.jl
ERROR: LoadError: Mutating arrays is not supported -- called copyto!(SubArray{Float32, 1, Matrix{Float32}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, true}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] _throw_mutation_error(f::Function, args::SubArray{Float32, 1, Matrix{Float32}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, true})
    @ Zygote ~/.julia/packages/Zygote/HTsWj/src/lib/array.jl:88
  [3] (::Zygote.var"#555#556"{SubArray{Float32, 1, Matrix{Float32}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, true}})(#unused#::Nothing)
    @ Zygote ~/.julia/packages/Zygote/HTsWj/src/lib/array.jl:103
  [4] (::Zygote.var"#2653#back#557"{Zygote.var"#555#556"{SubArray{Float32, 1, Matrix{Float32}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, true}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71
...
```

Environment:

```
julia> versioninfo()
Julia Version 1.9.0
Commit 8e630552924 (2023-05-07 11:25 UTC)
Platform Info:
  OS: macOS (arm64-apple-darwin22.4.0)
  CPU: 8 × Apple M1 Pro
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-14.0.6 (ORCJIT, apple-m1)
  Threads: 1 on 6 virtual cores
Environment:
  JULIA_PROJECT = @.

(tmp) pkg> status
Status `~/tmp/Project.toml`
  [587475ba] Flux v0.13.16
```
@ToucheSir
Member

`stack` is only differentiable when applied to arrays. Since Zygote does pretty poorly with general (lazy) iterators, you'll want to use that path anyhow.
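A minimal sketch of the "array path" for the example in the report. It assumes, as the comment says, that ChainRules provides a differentiation rule for `stack` when its argument is an array of arrays, and that Zygote carries the gradient back through an ordinary comprehension. The helper name `rowstack` is hypothetical; it builds the rows as a `Vector` instead of handing `stack` a lazy generator. Treat it as a sketch under those assumptions, not a confirmed fix.

```julia
using Flux  # Flux.withgradient / Flux.gradient use Zygote under the hood

nn = Dense(3 => 2)
x = randn(Float32, 3, 5)

# Hypothetical helper: collect the rows of `y` into a Vector with a
# comprehension, so `stack` receives an array instead of a lazy generator.
rowstack(y) = stack([y[i, :] for i in axes(y, 1)], dims = 1)

# Stacking the rows back along dims = 1 rebuilds `y`, just like the
# generator-based `slicestack` in the report.
@assert rowstack(nn(x)) == nn(x)

# Gradient with respect to the layer, mirroring the failing call in the report
# (assumes a ChainRules rrule for `stack` on arrays of arrays is available).
val, grads = Flux.withgradient(m -> sum(rowstack(m(x))), nn)

# Gradient with respect to a plain matrix: rowstack(y) == y, so the
# gradient of sum(rowstack(y)) should be a matrix of ones.
y = nn(x)
gy = only(Flux.gradient(z -> sum(rowstack(z)), y))
@assert size(gy) == size(y) && all(==(1), gy)
```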

ToucheSir transferred this issue from FluxML/Flux.jl May 23, 2023
ToucheSir added the ChainRules adjoint -> rrule, and further integration label May 23, 2023
@bicycle1885
Author

Thank you for the tip! That's actually very helpful. It might be nice to mention this somewhere in the docs if this feature is hard to support.

@mcabbott
Member

Full support for `stack` was part of the motivation for JuliaDiff/ChainRules.jl#671, but it's not done yet.
