Skip to content

Commit

Permalink
Add rrules for common functions
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jun 9, 2024
1 parent f10ba33 commit fb6766c
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 3 deletions.
40 changes: 38 additions & 2 deletions src/basis.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
module Basis

using ..Boltz: _unsqueeze1
using ChainRulesCore: ChainRulesCore, NoTangent
using ConcreteStructs: @concrete
using Markdown: @doc_str

const CRC = ChainRulesCore

# The rrules in this file are hardcoded to be used exclusively with GeneralBasisFunction
@concrete struct GeneralBasisFunction{name}
f
n::Int
Expand Down Expand Up @@ -65,12 +69,32 @@ $F_j(x) = j is even ? cos((j÷2)x) : sin((j÷2)x)$ => $[F_0(x), F_1(x), \dots, F
"""
function Fourier(n)
    # Tag the basis with `:Fourier` and pair the scalar kernel `__fourier`
    # with the requested number of basis terms `n`.
    return GeneralBasisFunction{:Fourier}(__fourier, n)
end

@inline function __fourier(i, x)
# Don't use @fastmath here: fast mode needs floats, but Zygote uses Duals for broadcast
# Fast-math specialisation of the Fourier kernel for floating-point inputs.
# Evaluates the half-angle `i * x / 2` once and returns its cosine when `i` is
# even and its sine when `i` is odd.
@inline @fastmath function __fourier(i, x::AbstractFloat)
    half_angle = i * x / 2
    sin_val, cos_val = sincos(half_angle)
    # `ifelse` keeps the selection branch-free; both values are already computed.
    return ifelse(iseven(i), cos_val, sin_val)
end

# Generic fallback of the Fourier kernel. No FastMath here: this method handles
# inputs that are not `AbstractFloat`s.
# Returns cos(i * x / 2) for even `i` and sin(i * x / 2) for odd `i`.
@inline function __fourier(i, x)
    half_angle = i * x / 2
    sin_val, cos_val = sincos(half_angle)
    return ifelse(iseven(i), cos_val, sin_val)
end

# Reverse-mode rule for the broadcast `__fourier.(i, x)`.
#
# Forward pass: evaluates sin and cos of `i * x / 2` elementwise and selects the
# cosine where `i` is even and the sine where `i` is odd.
#
# Pullback: only `x` receives a tangent; `broadcasted`, `__fourier` and `i` get
# `NoTangent()`. The elementwise derivative is `-(i / 2) * sin` for even `i` and
# `(i / 2) * cos` for odd `i`. The contribution is summed over dimension 1 and
# that dimension is then dropped.
# NOTE(review): this presumes `i` varies along dimension 1 and `x` has a leading
# singleton dimension, as when called from GeneralBasisFunction; confirm before
# reusing this rule elsewhere.
@fastmath function CRC.rrule(::typeof(Broadcast.broadcasted), ::typeof(__fourier), i, x)
    half_angle = @. i * x / 2
    sin_vals = @. sin(half_angle)
    cos_vals = @. cos(half_angle)
    y = @. ifelse(iseven(i), cos_vals, sin_vals)

    # The closure only captures bindings that are never reassigned.
    function ∇fourier(Δ)
        ∂x = sum((i / 2) .* ifelse.(iseven.(i), -sin_vals, cos_vals) .* Δ; dims=1)
        return (NoTangent(), NoTangent(), NoTangent(), dropdims(∂x; dims=1))
    end

    return y, ∇fourier
end

@doc doc"""
Legendre(n)
Expand Down Expand Up @@ -112,4 +136,16 @@ Polynomial(n) = GeneralBasisFunction{:Polynomial}(__polynomial, n)

# Scalar kernel of the polynomial basis: the `i`-th term is `x` raised to the
# power `i - 1`, so `i == 1` yields the constant term.
@inline function __polynomial(i, x)
    return x^(i - 1)
end

# Reverse-mode rule for the broadcast `__polynomial.(i, x)`.
#
# Forward pass: returns `x^(i - 1)` elementwise, computed exactly as
# `__polynomial` does, so the value under AD matches the plain call.
# Rebuilding the primal as `x^(i - 2) * x` gave `Inf * 0 = NaN` for the constant
# term (`i == 1`) at `x == 0`, and could differ from 1 by rounding elsewhere.
#
# Pullback: only `x` receives a tangent; `broadcasted`, `__polynomial` and `i`
# get `NoTangent()`. The elementwise derivative `(i - 1) * x^(i - 2)` is
# multiplied by `Δ`, summed over dimension 1, and that dimension is dropped.
# NOTE(review): this presumes `i` varies along dimension 1 and `x` has a leading
# singleton dimension, as when called from GeneralBasisFunction; confirm before
# reusing this rule elsewhere.
function CRC.rrule(::typeof(Broadcast.broadcasted), ::typeof(__polynomial), i, x)
    y = @. x^(i - 1)
    # The constant term has an exactly-zero derivative. Selecting it explicitly
    # avoids `0 * x^(-1)`, which is `0 * Inf = NaN` at `x == 0`.
    ∂y_∂x = @. ifelse(i == 1, zero(x), (i - 1) * x^(i - 2))
    ∇polynomial = let ∂y_∂x = ∂y_∂x
        Δ -> begin
            return (NoTangent(), NoTangent(), NoTangent(),
                dropdims(sum(∂y_∂x .* Δ; dims=1); dims=1))
        end
    end
    return y, ∇polynomial
end

end
2 changes: 1 addition & 1 deletion test/layer_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ end
y, st = tensor_project(x, ps, st)
@test size(y) == (2, 4, 5)

@jet tensor_project(x, ps, st)
@jet tensor_project(x, ps, st) opt_broken=!(ongpu) # Due to recursion

__f = (x, ps) -> sum(abs2, first(tensor_project(x, ps, st)))
@eval @test_gradients $(__f) $x $ps gpu_testing=$(ongpu) atol=1e-3 rtol=1e-3 skip_tracker=true
Expand Down

0 comments on commit fb6766c

Please sign in to comment.