Skip to content

Commit

Permalink
New customization points for Identity specialization (#117)
Browse files Browse the repository at this point in the history
  • Loading branch information
mateuszbaran authored Jul 16, 2022
1 parent 0c81e38 commit 1f022f9
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 12 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ManifoldsBase"
uuid = "3362f125-f0bb-47a3-aa74-596ffd7ef2fb"
authors = ["Seth Axen <[email protected]>", "Mateusz Baran <[email protected]>", "Ronny Bergmann <[email protected]>", "Antoine Levitt <[email protected]>"]
version = "0.13.12"
version = "0.13.13"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
14 changes: 13 additions & 1 deletion src/ManifoldsBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,17 @@ function allocate(::AbstractManifold, a, T::Type, dim1::Integer, dims::Integer..
end
allocate(::AbstractManifold, a, T::Type, dims::Tuple) = allocate(a, T, dims)

"""
_pick_basic_allocation_argument(::AbstractManifold, f, x...)
Pick which one of elements of `x` should be used as a basis for allocation in the
`allocate_result(M::AbstractManifold, f, x...)` method. This can be specialized to, for
example, skip `Identity` arguments in Manifolds.jl group-related functions.
"""
function _pick_basic_allocation_argument(::AbstractManifold, f, x...)
return x[1]
end

"""
allocate_result(M::AbstractManifold, f, x...)
Expand All @@ -88,7 +99,8 @@ isomorphisms.
"""
@inline function allocate_result(M::AbstractManifold, f, x...)
T = allocate_result_type(M, f, x)
return allocate(M, x[1], T)
picked = _pick_basic_allocation_argument(M, f, x...)
return allocate(M, picked, T)
end
@inline function allocate_result(M::AbstractManifold, f)
T = allocate_result_type(M, f, ())
Expand Down
25 changes: 15 additions & 10 deletions src/PowerManifold.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,11 @@ const PowerManifoldNestedReplacing = AbstractPowerManifold{
NestedReplacingPowerRepresentation,
} where {𝔽}

_access_nested(x, i::Int) = x[i]
_access_nested(x, i::Tuple) = x[i...]
# _access_nested(::AbstractManifold, x, i::Tuple) can be overloaded to achieve
# manifold-specific nested element access (for example to `Identity` on power manifolds).
@inline _access_nested(M::AbstractManifold, x, i::Int) = _access_nested(M, x, (i,))
@inline _access_nested(::AbstractManifold, x, i::Tuple) = _access_nested(x, i)
@inline _access_nested(x, i::Tuple) = x[i...]

function Base.:^(
M::PowerManifold{
Expand All @@ -150,7 +153,7 @@ function allocate_result(M::PowerManifoldNested, f, x...)
return allocate(M, x[1])
else
return [
allocate_result(M.manifold, f, map(y -> _access_nested(y, i), x)...) for
allocate_result(M.manifold, f, map(y -> _access_nested(M, y, i), x)...) for
i in get_iterator(M)
]
end
Expand Down Expand Up @@ -199,7 +202,9 @@ function allocate_result(
)
end
function allocate_result(M::PowerManifoldNested, f::typeof(get_vector), p, X)
return [allocate_result(M.manifold, f, _access_nested(p, i)) for i in get_iterator(M)]
return [
allocate_result(M.manifold, f, _access_nested(M, p, i)) for i in get_iterator(M)
]
end
function allocate_result(::PowerManifoldNestedReplacing, ::typeof(get_vector), p, X)
return copy(p)
Expand Down Expand Up @@ -452,7 +457,7 @@ function get_coordinates(
M.manifold,
_read(M, rep_size, p, i),
_read(M, rep_size, X, i),
_access_nested(B.data.bases, i),
_access_nested(M, B.data.bases, i),
) for i in get_iterator(M)
]
return reduce(vcat, reshape(vs, length(vs)))
Expand Down Expand Up @@ -492,7 +497,7 @@ function get_coordinates!(
view(c, v_iter:(v_iter + dim - 1)),
_read(M, rep_size, p, i),
_read(M, rep_size, X, i),
_access_nested(B.data.bases, i),
_access_nested(M, B.data.bases, i),
)
v_iter += dim
end
Expand Down Expand Up @@ -532,7 +537,7 @@ function get_vector!(
_write(M, rep_size, Y, i),
_read(M, rep_size, p, i),
c[v_iter:(v_iter + dim - 1)],
_access_nested(B.data.bases, i),
_access_nested(M, B.data.bases, i),
)
v_iter += dim
end
Expand All @@ -553,7 +558,7 @@ function get_vector!(
M.manifold,
_read(M, rep_size, p, i),
c[v_iter:(v_iter + dim - 1)],
_access_nested(B.data.bases, i),
_access_nested(M, B.data.bases, i),
)
v_iter += dim
end
Expand Down Expand Up @@ -604,7 +609,7 @@ function _get_vectors(
rep_size = representation_size(M.manifold)
vs = typeof(zero_tv)[]
for i in get_iterator(M)
b_i = _access_nested(B.data.bases, i)
b_i = _access_nested(M, B.data.bases, i)
p_i = _read(M, rep_size, p, i)
for v in b_i.data
new_v = copy(M, p, zero_tv)
Expand All @@ -622,7 +627,7 @@ function _get_vectors(
zero_tv = zero_vector(M, p)
vs = typeof(zero_tv)[]
for i in get_iterator(M)
b_i = _access_nested(B.data.bases, i)
b_i = _access_nested(M, B.data.bases, i)
for v in b_i.data
new_v = copy(M, p, zero_tv)
new_v[i...] = v
Expand Down

2 comments on commit 1f022f9

@mateuszbaran
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/64374

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.13.13 -m "<description of version>" 1f022f92facc89a5ebd0a79159376c327465fafd
git push origin v0.13.13

Please sign in to comment.