diff --git a/src/ConstLinearLayer.jl b/src/ConstLinearLayer.jl index 52ee758..f73529c 100644 --- a/src/ConstLinearLayer.jl +++ b/src/ConstLinearLayer.jl @@ -18,7 +18,7 @@ end (l::ConstLinearLayer)(x::AbstractArray,ps,st) = (l(x), st) - # NOTE: the following rrule is kept because there is a issue with SparseArray + # NOTE: the following rrule is kept because there is an issue with SparseArray function rrule(::typeof(LuxCore.apply), l::ConstLinearLayer, x::AbstractVector) val = l(x) function pb(A) diff --git a/src/radial.jl b/src/radial.jl index 68ca587..e20df7f 100644 --- a/src/radial.jl +++ b/src/radial.jl @@ -31,8 +31,7 @@ function simple_radial_basis(basis::ScalarPoly4MLBasis,f_cut::Function=r->1,f_tr end end - - _norm(x) = try norm(x); catch; norm(x.rr); end + _norm(x) = norm(x.rr) return Radial_basis(Chain(trans = WrappedFunction(x -> f_trans.(_norm.(x))), evaluation = Lux.BranchLayer(poly = lux(basis), cutoff = WrappedFunction(x -> f_cut.(x))), env = WrappedFunction(x -> x[1].*x[2]), ), spec) end \ No newline at end of file