Skip to content

Commit

Permalink
feat: Preserve Indices When Copying Tracked Arrays (#263)
Browse files Browse the repository at this point in the history
  • Loading branch information
DhairyaLGandhi authored Jun 28, 2024
1 parent 554a9c0 commit 5aaf060
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/similar_convert_copy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ end
function Base.convert(::Type{ComponentArray{T1,N,A1,Ax1}}, x::ComponentArray{T2,N,A2,Ax2}) where {T1,T2,N,A1,A2,Ax1,Ax2}
return T1.(x)
end
function Base.convert(::Type{ComponentArray{T,N,A1,Ax}}, x::ComponentArray{T,N,A2,Ax}) where {T,N,A1,A2,Ax}
return x
end
function Base.convert(::Type{ComponentArray{T,N,A,Ax}}, x::ComponentArray{T,N,A,Ax}) where {T,N,A,Ax}
return x
end
Base.convert(T::Type{<:Array}, x::ComponentArray) = convert(T, getdata(x))

Base.convert(::Type{Cholesky{T1,Matrix{T1}}}, x::Cholesky{T2,<:ComponentArray}) where {T1,T2} = Cholesky(Matrix{T1}(x.factors), x.uplo, x.info)
Expand Down
5 changes: 5 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using ComponentArrays
using BenchmarkTools
using ForwardDiff
using Tracker
using InvertedIndices
using LabelledArrays
using LinearAlgebra
Expand Down Expand Up @@ -400,6 +401,10 @@ end

@test convert(Array, ca) == getdata(ca)
@test convert(Matrix{Float32}, cmat) isa Matrix{Float32}

tr = Tracker.param(ca)
ca_ = convert(typeof(ca), tr)
@test ca_.a == ca.a
end

@testset "Broadcasting" begin
Expand Down

0 comments on commit 5aaf060

Please sign in to comment.