diff --git a/src/ConstLinearLayer.jl b/src/ConstLinearLayer.jl index 5dcbb0d..d03ccd6 100644 --- a/src/ConstLinearLayer.jl +++ b/src/ConstLinearLayer.jl @@ -2,15 +2,11 @@ import ChainRulesCore: rrule using LuxCore using LuxCore: AbstractExplicitLayer -struct ConstLinearLayer <: AbstractExplicitLayer # where {in_dim,out_dim,T} - op #::AbstractMatrix{T} - position::Union{Vector{Int64}, UnitRange{Int64}} +struct ConstLinearLayer <: AbstractExplicitLayer + op end -ConstLinearLayer(op) = ConstLinearLayer(op,1:size(op,2)) -# ConstLinearLayer(op, pos::Union{Vector{Int64}, UnitRange{Int64}}) = ConstLinearLayer(op,pos) - -(l::ConstLinearLayer)(x::AbstractVector) = l.op * x[l.position] +(l::ConstLinearLayer)(x::AbstractVector) = l.op * x[1:size(l.op,2)] (l::ConstLinearLayer)(x::AbstractMatrix) = begin Tmp = l(x[1,:]) @@ -20,6 +16,9 @@ ConstLinearLayer(op) = ConstLinearLayer(op,1:size(op,2)) return Tmp' end + (l::ConstLinearLayer)(x::AbstractArray,ps,st) = (l(x), st) + + # NOTE: the following rrule is kept because there is a issue with SparseArray function rrule(::typeof(LuxCore.apply), l::ConstLinearLayer, x::AbstractVector) val = l(x) function pb(A) @@ -28,20 +27,10 @@ function rrule(::typeof(LuxCore.apply), l::ConstLinearLayer, x::AbstractVector) return val, pb end -(l::ConstLinearLayer)(x::AbstractArray,ps,st) = (l(x), st) - function rrule(::typeof(LuxCore.apply), l::ConstLinearLayer, x::AbstractArray,ps,st) val = l(x,ps,st) function pb(A) return NoTangent(), NoTangent(), l.op' * A[1], (op = A[1] * x',), NoTangent() end return val, pb -end - -# function rrule(::typeof(LuxCore.apply), l::ConstLinearLayer, x::AbstractMatrix, ps, st) -# val = l(x, ps, st) -# function pb(A) -# return NoTangent(), NoTangent(), l.op' * A[1], (op = A[1] * x',), NoTangent() -# end -# return val, pb -# end \ No newline at end of file +end \ No newline at end of file diff --git a/src/builder.jl b/src/builder.jl index 2bbab0c..c215411 100644 --- a/src/builder.jl +++ b/src/builder.jl @@ -261,7 +261,6 @@ function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories= C = _rpi_A2B_matrix(cgen, spec_nlm) end - #TODO:make use [ C[i], pos[i] ] to generate another sparse matrix so that... l_sym = islong ? Lux.Parallel(nothing, [ConstLinearLayer(new_sparse_matrix(C[i],pos[i])) for i = 1:L+1]... ) : ConstLinearLayer(C) # C - A2Bmap luxchain = append_layer(luxchain_tmp, l_sym; l_name = :BB) @@ -278,7 +277,7 @@ function new_sparse_matrix(C,pos) for i = 1:size(C,1) C_new[i,pos] = C[i,:] end - return C_new + return sparse(C_new) end # more constructors equivariant_model