Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanglw0521 committed Oct 5, 2023
1 parent 149447d commit 42ad3a9
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 20 deletions.
25 changes: 7 additions & 18 deletions src/ConstLinearLayer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,11 @@ import ChainRulesCore: rrule
using LuxCore
using LuxCore: AbstractExplicitLayer

struct ConstLinearLayer <: AbstractExplicitLayer # where {in_dim,out_dim,T}
op #::AbstractMatrix{T}
position::Union{Vector{Int64}, UnitRange{Int64}}
struct ConstLinearLayer <: AbstractExplicitLayer
op
end

ConstLinearLayer(op) = ConstLinearLayer(op,1:size(op,2))
# ConstLinearLayer(op, pos::Union{Vector{Int64}, UnitRange{Int64}}) = ConstLinearLayer(op,pos)

(l::ConstLinearLayer)(x::AbstractVector) = l.op * x[l.position]
(l::ConstLinearLayer)(x::AbstractVector) = l.op * x[1:size(l.op,2)]

(l::ConstLinearLayer)(x::AbstractMatrix) = begin
Tmp = l(x[1,:])
Expand All @@ -20,6 +16,9 @@ ConstLinearLayer(op) = ConstLinearLayer(op,1:size(op,2))
return Tmp'
end

(l::ConstLinearLayer)(x::AbstractArray,ps,st) = (l(x), st)

# NOTE: the following rrule is kept because there is a issue with SparseArray
function rrule(::typeof(LuxCore.apply), l::ConstLinearLayer, x::AbstractVector)
val = l(x)
function pb(A)
Expand All @@ -28,20 +27,10 @@ function rrule(::typeof(LuxCore.apply), l::ConstLinearLayer, x::AbstractVector)
return val, pb
end

(l::ConstLinearLayer)(x::AbstractArray,ps,st) = (l(x), st)

function rrule(::typeof(LuxCore.apply), l::ConstLinearLayer, x::AbstractArray,ps,st)
val = l(x,ps,st)
function pb(A)
return NoTangent(), NoTangent(), l.op' * A[1], (op = A[1] * x',), NoTangent()
end
return val, pb
end

# function rrule(::typeof(LuxCore.apply), l::ConstLinearLayer, x::AbstractMatrix, ps, st)
# val = l(x, ps, st)
# function pb(A)
# return NoTangent(), NoTangent(), l.op' * A[1], (op = A[1] * x',), NoTangent()
# end
# return val, pb
# end
end
3 changes: 1 addition & 2 deletions src/builder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,6 @@ function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories=
C = _rpi_A2B_matrix(cgen, spec_nlm)
end

#TODO:make use [ C[i], pos[i] ] to generate another sparse matrix so that...
l_sym = islong ? Lux.Parallel(nothing, [ConstLinearLayer(new_sparse_matrix(C[i],pos[i])) for i = 1:L+1]... ) : ConstLinearLayer(C)
# C - A2Bmap
luxchain = append_layer(luxchain_tmp, l_sym; l_name = :BB)
Expand All @@ -278,7 +277,7 @@ function new_sparse_matrix(C,pos)
for i = 1:size(C,1)
C_new[i,pos] = C[i,:]
end
return C_new
return sparse(C_new)
end

# more constructors equivariant_model
Expand Down

0 comments on commit 42ad3a9

Please sign in to comment.