From fb6766c7a2b8e5f996f28c8ffc0a8d3db17d2e82 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Sun, 9 Jun 2024 00:18:55 -0700
Subject: [PATCH] Add rrules for common functions

---
 src/basis.jl        | 40 ++++++++++++++++++++++++++++++++++++++--
 test/layer_tests.jl |  2 +-
 2 files changed, 39 insertions(+), 3 deletions(-)

diff --git a/src/basis.jl b/src/basis.jl
index 973da6b..51c81c6 100644
--- a/src/basis.jl
+++ b/src/basis.jl
@@ -1,9 +1,13 @@
 module Basis
 
 using ..Boltz: _unsqueeze1
+using ChainRulesCore: ChainRulesCore, NoTangent
 using ConcreteStructs: @concrete
 using Markdown: @doc_str
 
+const CRC = ChainRulesCore
+
+# The rrules in this file are hardcoded to be used exclusively with GeneralBasisFunction
 @concrete struct GeneralBasisFunction{name}
     f
     n::Int
@@ -65,12 +69,32 @@
 $F_j(x) = j is even ? cos((j÷2)x) : sin((j÷2)x)$ => $[F_0(x), F_1(x), \dots, F_{n-1}(x)]$
 """
 Fourier(n) = GeneralBasisFunction{:Fourier}(__fourier, n)
 
-@inline function __fourier(i, x)
-    # Don't use @fastmasth here, fast mode needs float but Zygote uses Duals for broadcast
+@inline @fastmath function __fourier(i, x::AbstractFloat)
+    s, c = sincos(i * x / 2)
+    return ifelse(iseven(i), c, s)
+end
+
+@inline function __fourier(i, x) # No FastMath for non abstract floats
     s, c = sincos(i * x / 2)
     return ifelse(iseven(i), c, s)
 end
 
+@fastmath function CRC.rrule(::typeof(Broadcast.broadcasted), ::typeof(__fourier), i, x)
+    ix_by_2 = @. i * x / 2
+    s = @. sin(ix_by_2)
+    c = @. cos(ix_by_2)
+    y = @. ifelse(iseven(i), c, s)
+
+    ∇fourier = let s = s, c = c, i = i
+        Δ -> begin
+            return (NoTangent(), NoTangent(), NoTangent(),
+                dropdims(sum((i / 2) .* ifelse.(iseven.(i), -s, c) .* Δ; dims=1); dims=1))
+        end
+    end
+
+    return y, ∇fourier
+end
+
 @doc doc"""
     Legendre(n)
@@ -112,4 +136,16 @@ Polynomial(n) = GeneralBasisFunction{:Polynomial}(__polynomial, n)
 
 @inline __polynomial(i, x) = x^(i - 1)
 
+function CRC.rrule(::typeof(Broadcast.broadcasted), ::typeof(__polynomial), i, x)
+    y_m1 = x .^ (i .- 2)
+    y = y_m1 .* x
+    ∇polynomial = let y_m1 = y_m1, i = i
+        Δ -> begin
+            return (NoTangent(), NoTangent(), NoTangent(),
+                dropdims(sum((i .- 1) .* y_m1 .* Δ; dims=1); dims=1))
+        end
+    end
+    return y, ∇polynomial
+end
+
 end
diff --git a/test/layer_tests.jl b/test/layer_tests.jl
index 32d987d..e3ca832 100644
--- a/test/layer_tests.jl
+++ b/test/layer_tests.jl
@@ -100,7 +100,7 @@ end
         y, st = tensor_project(x, ps, st)
         @test size(y) == (2, 4, 5)
 
-        @jet tensor_project(x, ps, st)
+        @jet tensor_project(x, ps, st) opt_broken=!(ongpu) # Due to recursion
 
         __f = (x, ps) -> sum(abs2, first(tensor_project(x, ps, st)))
         @eval @test_gradients $(__f) $x $ps gpu_testing=$(ongpu) atol=1e-3 rtol=1e-3 skip_tracker=true