Skip to content

Commit

Permalink
Resolve most of the issues in comments
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanglw0521 committed Oct 5, 2023
1 parent a115a63 commit b363596
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 7 deletions.
12 changes: 5 additions & 7 deletions src/ConstLinearLayer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,15 @@ import ChainRulesCore: rrule
using LuxCore
using LuxCore: AbstractExplicitLayer

struct ConstLinearLayer{T} <: AbstractExplicitLayer # where {in_dim,out_dim,T}
W::AbstractMatrix{T}
struct ConstLinearLayer <: AbstractExplicitLayer # where {in_dim,out_dim,T}
W #::AbstractMatrix{T}
position::Union{Vector{Int64}, UnitRange{Int64}}
in_dim::Integer
out_dim::Integer
end

ConstLinearLayer(W::AbstractMatrix{T}) where T = ConstLinearLayer(W,1:size(W,2),size(W,2),size(W,1))
ConstLinearLayer(W::AbstractMatrix{T}, pos::Union{Vector{Int64}, UnitRange{Int64}}) where T = ConstLinearLayer(W,pos,size(W,2),size(W,1))
ConstLinearLayer(W) where T = ConstLinearLayer(W,1:size(W,2))
# ConstLinearLayer(W, pos::Union{Vector{Int64}, UnitRange{Int64}}) = ConstLinearLayer(W,pos)

(l::ConstLinearLayer)(x::AbstractVector) = l.in_dim == length(x[l.position]) ? l.W * x[l.position] : error("x (or the position index) has a wrong length!")
(l::ConstLinearLayer)(x::AbstractVector) = l.W * x[l.position]

(l::ConstLinearLayer)(x::AbstractMatrix) = begin
Tmp = l(x[1,:])
Expand Down
1 change: 1 addition & 0 deletions src/builder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories=
C = _rpi_A2B_matrix(cgen, spec_nlm)
end

#TODO:make use [ C[i], pos[i] ] to generate another sparse matrix so that...
l_sym = islong ? Lux.Parallel(nothing, [ConstLinearLayer(C[i],pos[i]) for i = 1:L+1]... ) : ConstLinearLayer(C)
# C - A2Bmap
luxchain = append_layer(luxchain_tmp, l_sym; l_name = :BB)
Expand Down

0 comments on commit b363596

Please sign in to comment.