Skip to content

Commit

Permalink
Merge pull request #408 from JuliaHealth/optimize-arbitrary-view
Browse files Browse the repository at this point in the history
Optimize ArbitraryMotion (continuation)
  • Loading branch information
cncastillo authored Jun 27, 2024
2 parents 56fcedc + 5dcc8b8 commit 3128e9f
Show file tree
Hide file tree
Showing 32 changed files with 1,306 additions and 364 deletions.
4 changes: 2 additions & 2 deletions KomaMRIBase/src/KomaMRIBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ include("datatypes/Phantom.jl")
include("datatypes/simulation/DiscreteSequence.jl")
include("timing/TimeStepCalculation.jl")
include("timing/TrapezoidalIntegration.jl")
include("timing/UnitTime.jl")

# Main
export γ # gyro-magnetic ratio [Hz/T]
Expand All @@ -51,8 +52,7 @@ export NoMotion, SimpleMotion, ArbitraryMotion
export SimpleMotionType
export Translation, Rotation, HeartBeat
export PeriodicTranslation, PeriodicRotation, PeriodicHeartBeat
export get_spin_coords, sort_motions!
export LinearInterpolator
export get_spin_coords
# Secondary
export get_kspace, rotx, roty, rotz
# Additionals
Expand Down
21 changes: 10 additions & 11 deletions KomaMRIBase/src/datatypes/Phantom.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,7 @@ Base.getindex(x::Phantom, i::Integer) = x[i:i]
"""Compare two phantoms"""
Base.:(==)(obj1::Phantom, obj2::Phantom) = reduce(
&,
[
getfield(obj1, field) == getfield(obj2, field) for
field in Iterators.filter(x -> !(x == :name), fieldnames(Phantom))
],
[getfield(obj1, field) == getfield(obj2, field) for field in Iterators.filter(x -> !(x == :name), fieldnames(Phantom))],
)
Base.:()(obj1::Phantom, obj2::Phantom) = reduce(&, [getfield(obj1, field) getfield(obj2, field) for field in Iterators.filter(x -> !(x == :name), fieldnames(Phantom))])
Base.:(==)(m1::MotionModel, m2::MotionModel) = false
Expand All @@ -88,7 +85,13 @@ Base.getindex(obj::Phantom, p::Union{AbstractRange,AbstractVector,Colon}) = begi
end

"""Separate object spins in a sub-group (lightweigth)."""
Base.view(obj::Phantom, p::Union{AbstractRange,AbstractVector,Colon}) = @views obj[p]
Base.view(obj::Phantom, p::Union{AbstractRange,AbstractVector,Colon}) = begin
fields = []
for field in Iterators.filter(x -> !(x == :name), fieldnames(Phantom))
push!(fields, (field, @view(getfield(obj, field)[p])))
end
return Phantom(; name=obj.name, fields...)
end

"""Addition of phantoms"""
+(obj1::Phantom, obj2::Phantom) = begin
Expand Down Expand Up @@ -117,10 +120,6 @@ function get_dims(obj::Phantom)
return dims
end

function sort_motions!(motion::MotionModel)
return nothing
end

"""
obj = heart_phantom(...)
Expand Down Expand Up @@ -178,7 +177,7 @@ function heart_phantom(
Dλ1=Dλ1[ρ .!= 0],
Dλ2=Dλ2[ρ .!= 0],
=Dθ[ρ .!= 0],
motion=SimpleMotion([
motion=SimpleMotion(
PeriodicHeartBeat(;
period=period,
asymmetry=asymmetry,
Expand All @@ -189,7 +188,7 @@ function heart_phantom(
PeriodicRotation(;
period=period, asymmetry=asymmetry, yaw=rotation_angle, pitch=0.0, roll=0.0
),
]),
),
)
return phantom
end
Expand Down
153 changes: 70 additions & 83 deletions KomaMRIBase/src/datatypes/phantom/motion/ArbitraryMotion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,24 @@
# Interpolator{T,Degree,ETPType},
# Degree = Linear,Cubic....
# ETPType = Periodic, Flat...
const LinearInterpolator = Interpolations.Extrapolation{
T,
1,
Interpolations.GriddedInterpolation{T,1,V,Gridded{Linear{Throw{OnGrid}}},Tuple{V}},
Gridded{Linear{Throw{OnGrid}}},
Periodic{Nothing},
} where {T<:Real,V<:AbstractVector{T}}

const Interpolator1D = Interpolations.GriddedInterpolation{
T,1,V,Itp,K
} where {
T<:Real,
V<:AbstractArray{T},
Itp<:Interpolations.Gridded{Linear{Throw{OnGrid}}},
K<:Tuple{AbstractVector{T}},
}

const Interpolator2D = Interpolations.GriddedInterpolation{
T,2,V,Itp,K
} where {
T<:Real,
V<:AbstractArray{T},
Itp<:Interpolations.Gridded{Linear{Throw{OnGrid}}},
K<:Tuple{AbstractVector{T}, AbstractVector{T}},
}

"""
motion = ArbitraryMotion(period_durations, dx, dy, dz)
Expand Down Expand Up @@ -43,106 +54,82 @@ julia> motion = ArbitraryMotion(
)
```
"""
struct ArbitraryMotion{T<:Real,V<:AbstractVector{T}} <: MotionModel{T}
period_durations::Vector{T}
dx::Array{T,2}
dy::Array{T,2}
dz::Array{T,2}
ux::Vector{LinearInterpolator{T,V}}
uy::Vector{LinearInterpolator{T,V}}
uz::Vector{LinearInterpolator{T,V}}
end

function ArbitraryMotion(
period_durations::AbstractVector{T},
dx::AbstractArray{T,2},
dy::AbstractArray{T,2},
dz::AbstractArray{T,2},
) where {T<:Real}
@warn "Note that ArbitraryMotion is under development so it is not optimized so far" maxlog = 1
Ns = size(dx)[1]
num_pieces = size(dx)[2] + 1
limits = times(period_durations, num_pieces)

#! format: off
Δ = zeros(Ns,length(limits),4)
Δ[:,:,1] = hcat(repeat(hcat(zeros(Ns,1),dx),1,length(period_durations)),zeros(Ns,1))
Δ[:,:,2] = hcat(repeat(hcat(zeros(Ns,1),dy),1,length(period_durations)),zeros(Ns,1))
Δ[:,:,3] = hcat(repeat(hcat(zeros(Ns,1),dz),1,length(period_durations)),zeros(Ns,1))

etpx = [extrapolate(interpolate((limits,), Δ[i,:,1], Gridded(Linear())), Periodic()) for i in 1:Ns]
etpy = [extrapolate(interpolate((limits,), Δ[i,:,2], Gridded(Linear())), Periodic()) for i in 1:Ns]
etpz = [extrapolate(interpolate((limits,), Δ[i,:,3], Gridded(Linear())), Periodic()) for i in 1:Ns]
#! format: on

return ArbitraryMotion(period_durations, dx, dy, dz, etpx, etpy, etpz)
struct ArbitraryMotion{T} <: MotionModel{T}
t_start::T
t_end::T
dx::AbstractArray{T}
dy::AbstractArray{T}
dz::AbstractArray{T}
end

function Base.getindex(
motion::ArbitraryMotion, p::Union{AbstractRange,AbstractVector,Colon}
)
fields = []
for field in fieldnames(ArbitraryMotion)
if field in (:dx, :dy, :dz)
push!(fields, getfield(motion, field)[p, :])
elseif field in (:ux, :uy, :uz)
push!(fields, getfield(motion, field)[p])
else
push!(fields, getfield(motion, field))
end
end
return ArbitraryMotion(fields...)
return ArbitraryMotion(motion.t_start, motion.t_end, motion.dx[p,:], motion.dy[p,:], motion.dz[p,:])
end
function Base.view(
motion::ArbitraryMotion, p::Union{AbstractRange,AbstractVector,Colon}
)
return ArbitraryMotion(motion.t_start, motion.t_end, @view(motion.dx[p,:]), @view(motion.dy[p,:]), @view(motion.dz[p,:]))
end

Base.:(==)(m1::ArbitraryMotion, m2::ArbitraryMotion) = reduce(&, [getfield(m1, field) == getfield(m2, field) for field in fieldnames(ArbitraryMotion)])
Base.:()(m1::ArbitraryMotion, m2::ArbitraryMotion) = reduce(&, [getfield(m1, field) getfield(m2, field) for field in fieldnames(ArbitraryMotion)])

function Base.vcat(m1::ArbitraryMotion, m2::ArbitraryMotion)
fields = []
@assert m1.period_durations == m2.period_durations "period_durations of both ArbitraryMotions must be the same"
for field in
Iterators.filter(x -> !(x == :period_durations), fieldnames(ArbitraryMotion))
push!(fields, [getfield(m1, field); getfield(m2, field)])
end
return ArbitraryMotion(m1.period_durations, fields...)
@assert (m1.t_start == m2.t_start) && (m1.t_end == m2.t_end) "Starting and ending times must be the same"
return ArbitraryMotion(m1.t_start, m1.t_end, [m1.dx; m2.dx], [m1.dy; m2.dy], [m1.dz; m2.dz])
end

"""
limits = times(obj.motion)
"""
function times(motion::ArbitraryMotion)
period_durations = motion.period_durations
num_pieces = size(motion.dx)[2] + 1
return times(period_durations, num_pieces)
return range(motion.t_start, motion.t_end, length=size(motion.dx, 2))
end

function GriddedInterpolation(nodes, A, ITP)
return Interpolations.GriddedInterpolation{eltype(A), length(nodes), typeof(A), typeof(ITP), typeof(nodes)}(nodes, A, ITP)
end

function interpolate(motion::ArbitraryMotion{T}, Ns::Val{1}) where {T<:Real}
_, Nt = size(motion.dx)
t = similar(motion.dx, Nt); copyto!(t, collect(range(zero(T), oneunit(T), Nt)))
itpx = GriddedInterpolation((t, ), motion.dx[:], Gridded(Linear()))
itpy = GriddedInterpolation((t, ), motion.dy[:], Gridded(Linear()))
itpz = GriddedInterpolation((t, ), motion.dz[:], Gridded(Linear()))
return itpx, itpy, itpz
end

function interpolate(motion::ArbitraryMotion{T}, Ns::Val) where {T<:Real}
Ns, Nt = size(motion.dx)
id = similar(motion.dx, Ns); copyto!(id, collect(range(oneunit(T), T(Ns), Ns)))
t = similar(motion.dx, Nt); copyto!(t, collect(range(zero(T), oneunit(T), Nt)))
itpx = GriddedInterpolation((id, t), motion.dx, Gridded(Linear()))
itpy = GriddedInterpolation((id, t), motion.dy, Gridded(Linear()))
itpz = GriddedInterpolation((id, t), motion.dz, Gridded(Linear()))
return itpx, itpy, itpz
end

function resample(itpx::Interpolator1D{T}, itpy::Interpolator1D{T}, itpz::Interpolator1D{T}, t::AbstractArray{T}) where {T<:Real}
return itpx.(t), itpy.(t), itpz.(t)
end

function times(period_durations::AbstractVector, num_pieces::Int)
# Pre-allocating memory
limits = zeros(eltype(period_durations), num_pieces * length(period_durations) + 1)

idx = 1
for i in 1:length(period_durations)
segment_increment = period_durations[i] / num_pieces
cumulative_sum = limits[idx] # Start from the last computed value in limits
for j in 1:num_pieces
cumulative_sum += segment_increment
limits[idx + 1] = cumulative_sum
idx += 1
end
end
return limits
function resample(itpx::Interpolator2D{T}, itpy::Interpolator2D{T}, itpz::Interpolator2D{T}, t::AbstractArray{T}) where {T<:Real}
Ns = size(itpx.coefs, 1)
id = similar(itpx.coefs, Ns)
copyto!(id, collect(range(oneunit(T), T(Ns), Ns)))
return itpx.(id, t), itpy.(id, t), itpz.(id, t)
end

# TODO: Calculate interpolation functions "on the fly"
function get_spin_coords(
motion::ArbitraryMotion{T},
x::AbstractVector{T},
y::AbstractVector{T},
z::AbstractVector{T},
t::AbstractArray{T},
t::AbstractArray{T}
) where {T<:Real}
xt = x .+ reduce(vcat, [etp.(t) for etp in motion.ux])
yt = y .+ reduce(vcat, [etp.(t) for etp in motion.uy])
zt = z .+ reduce(vcat, [etp.(t) for etp in motion.uz])
return xt, yt, zt
end
motion_functions = interpolate(motion, Val(size(x,1)))
ux, uy, uz = resample(motion_functions..., unit_time(t, motion.t_start, motion.t_end))
return x .+ ux, y .+ uy, z .+ uz
end
3 changes: 2 additions & 1 deletion KomaMRIBase/src/datatypes/phantom/motion/NoMotion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ x = x
struct NoMotion{T<:Real} <: MotionModel{T} end

Base.getindex(motion::NoMotion, p::Union{AbstractRange,AbstractVector,Colon}) = motion
Base.view(motion::NoMotion, p::Union{AbstractRange,AbstractVector,Colon}) = motion

Base.:(==)(m1::NoMotion, m2::NoMotion) = true
Base.:()(m1::NoMotion, m2::NoMotion) = true
Expand All @@ -18,7 +19,7 @@ function get_spin_coords(
x::AbstractVector{T},
y::AbstractVector{T},
z::AbstractVector{T},
t::AbstractArray{T},
t::AbstractArray{T}
) where {T<:Real}
return x, y, z
end
Expand Down
Loading

0 comments on commit 3128e9f

Please sign in to comment.