Skip to content

Commit

Permalink
Upgrade to SparseMatrixColorings v0.4 (#405)
Browse files Browse the repository at this point in the history
* Upgrade to SparseMatrixColorings v0.4 (still in dev)

* Semicolons

* Correct branch

* Remove up

* Install SMC in test env for DIT

* Fix imports

* Type params

* Typo

* Use main

* Add ColoringProblem

* Typo

* Use coloring

* Result type

* Make ColoringProblem type-stable

* SMC v0.4 is registered

* Re-add SMC to test deps
  • Loading branch information
gdalle authored Aug 15, 2024
1 parent 2ef423e commit 9182912
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 158 deletions.
2 changes: 1 addition & 1 deletion DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ PolyesterForwardDiff = "0.1.1"
ReverseDiff = "1.15.1"
SparseArrays = "<0.0.1,1"
SparseConnectivityTracer = "0.5.0,0.6"
SparseMatrixColorings = "0.3.5"
SparseMatrixColorings = "0.4.0"
Symbolics = "5.27.1, 6"
Tapir = "0.2.4"
Tracker = "0.2.33"
Expand Down
3 changes: 1 addition & 2 deletions DifferentiationInterface/docs/src/operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,7 @@ For this to work, three ingredients are needed (read [this survey](https://epubs
- [`TracerSparsityDetector`](@extref SparseConnectivityTracer.TracerSparsityDetector) from [SparseConnectivityTracer.jl](https://github.com/adrhill/SparseConnectivityTracer.jl)
- [`SymbolicsSparsityDetector`](@extref Symbolics.SymbolicsSparsityDetector) from [Symbolics.jl](https://github.com/JuliaSymbolics/Symbolics.jl)
- [`DenseSparsityDetector`](@ref) from DifferentiationInterface.jl (beware that this detector only gives a locally valid pattern)
3. A coloring algorithm like:
- [`GreedyColoringAlgorithm`](@extref SparseMatrixColorings.GreedyColoringAlgorithm) from [SparseMatrixColorings.jl](https://github.com/gdalle/SparseMatrixColorings.jl)
3. A coloring algorithm: [`GreedyColoringAlgorithm`](@extref SparseMatrixColorings.GreedyColoringAlgorithm) from [SparseMatrixColorings.jl](https://github.com/gdalle/SparseMatrixColorings.jl) is the only one we support.

These ingredients can be combined within the [`AutoSparse`](@extref ADTypes.AutoSparse) wrapper, which DifferentiationInterface.jl re-exports.
Note that for sparse Hessians, you need to put the `SecondOrder` backend inside `AutoSparse`, and not the other way around.
Expand Down
21 changes: 10 additions & 11 deletions DifferentiationInterface/src/DifferentiationInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ module DifferentiationInterface
using ADTypes: ADTypes, AbstractADType
using ADTypes: mode, ForwardMode, ForwardOrReverseMode, ReverseMode, SymbolicMode
using ADTypes: AutoSparse, dense_ad
using ADTypes: coloring_algorithm, column_coloring, row_coloring
using ADTypes: coloring_algorithm
using ADTypes: sparsity_detector, jacobian_sparsity, hessian_sparsity
using ADTypes:
AutoChainRules,
Expand All @@ -35,16 +35,16 @@ using LinearAlgebra: Symmetric, Transpose, dot, parent, transpose
using PackageExtensionCompat: @require_extensions
using SparseArrays: SparseMatrixCSC, nonzeros, nzrange, rowvals, sparse
using SparseMatrixColorings:
AbstractColoringResult,
ColoringProblem,
GreedyColoringAlgorithm,
color_groups,
decompress_columns,
decompress_columns!,
decompress_rows,
decompress_rows!,
decompress_symmetric,
decompress_symmetric!,
symmetric_coloring_detailed,
StarSet
coloring,
column_colors,
row_colors,
column_groups,
row_groups,
decompress,
decompress!

abstract type Extras end

Expand Down Expand Up @@ -74,7 +74,6 @@ include("second_order/hessian.jl")
include("fallbacks/no_extras.jl")

include("sparse/fallbacks.jl")
include("sparse/matrices.jl")
include("sparse/jacobian.jl")
include("sparse/hessian.jl")

Expand Down
80 changes: 35 additions & 45 deletions DifferentiationInterface/src/sparse/hessian.jl
Original file line number Diff line number Diff line change
@@ -1,35 +1,31 @@
struct SparseHessianExtras{
B,S<:AbstractMatrix{Bool},C<:AbstractMatrix{<:Real},D,R,E2<:HVPExtras,E1<:GradientExtras
B,
C<:AbstractColoringResult{:symmetric,:column},
M<:AbstractMatrix{<:Real},
D,
R,
E2<:HVPExtras,
E1<:GradientExtras,
} <: HessianExtras
sparsity::S
colors::Vector{Int}
star_set::StarSet
groups::Vector{Vector{Int}}
compressed::C
coloring_result::C
compressed_matrix::M
batched_seeds::Vector{Batch{B,D}}
batched_results::Vector{Batch{B,R}}
hvp_batched_extras::E2
gradient_extras::E1
end

function SparseHessianExtras{B}(;
sparsity::S,
colors,
star_set,
groups,
compressed::C,
coloring_result::C,
compressed_matrix::M,
batched_seeds::Vector{Batch{B,D}},
batched_results::Vector{Batch{B,R}},
hvp_batched_extras::E2,
gradient_extras::E1,
) where {B,S,C,D,R,E2,E1}
@assert size(sparsity, 1) == size(sparsity, 2) == size(compressed, 1) == length(colors)
return SparseHessianExtras{B,S,C,D,R,E2,E1}(
sparsity,
colors,
star_set,
groups,
compressed,
) where {B,C,M,D,R,E2,E1}
return SparseHessianExtras{B,C,M,D,R,E2,E1}(
coloring_result,
compressed_matrix,
batched_seeds,
batched_results,
hvp_batched_extras,
Expand All @@ -41,14 +37,16 @@ end

function prepare_hessian(f::F, backend::AutoSparse, x) where {F}
dense_backend = dense_ad(backend)
initial_sparsity = hessian_sparsity(f, x, sparsity_detector(backend))
sparsity = col_major(initial_sparsity)
colors, star_set = symmetric_coloring_detailed(sparsity, coloring_algorithm(backend))
groups = color_groups(colors)
sparsity = hessian_sparsity(f, x, sparsity_detector(backend))
problem = ColoringProblem{:symmetric,:column}()
coloring_result = coloring(
sparsity, problem, coloring_algorithm(backend); decompression_eltype=eltype(x)
)
groups = column_groups(coloring_result)
Ng = length(groups)
B = pick_batchsize(maybe_outer(dense_backend), Ng)
seeds = map(group -> make_seed(x, group), groups)
compressed = stack(_ -> vec(similar(x)), groups; dims=2)
compressed_matrix = stack(_ -> vec(similar(x)), groups; dims=2)
batched_seeds =
Batch.([
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % Ng], Val(B)) for
Expand All @@ -58,11 +56,8 @@ function prepare_hessian(f::F, backend::AutoSparse, x) where {F}
hvp_batched_extras = prepare_hvp_batched(f, dense_backend, x, batched_seeds[1])
gradient_extras = prepare_gradient(f, maybe_inner(dense_backend), x)
return SparseHessianExtras{B}(;
sparsity,
colors,
star_set,
groups,
compressed,
coloring_result,
compressed_matrix,
batched_seeds,
batched_results,
hvp_batched_extras,
Expand All @@ -71,11 +66,9 @@ function prepare_hessian(f::F, backend::AutoSparse, x) where {F}
end

function hessian(f::F, backend::AutoSparse, x, extras::SparseHessianExtras{B}) where {F,B}
@compat (;
sparsity, compressed, colors, star_set, groups, batched_seeds, hvp_batched_extras
) = extras
@compat (; coloring_result, batched_seeds, hvp_batched_extras) = extras
dense_backend = dense_ad(backend)
Ng = length(groups)
Ng = length(column_groups(coloring_result))

hvp_batched_extras_same = prepare_hvp_batched_same_point(
f, dense_backend, x, batched_seeds[1], hvp_batched_extras
Expand All @@ -86,28 +79,25 @@ function hessian(f::F, backend::AutoSparse, x, extras::SparseHessianExtras{B}) w
stack(vec, dg_batch.elements; dims=2)
end

compressed = reduce(hcat, compressed_blocks)
if Ng < size(compressed, 2)
compressed = compressed[:, 1:Ng]
compressed_matrix = reduce(hcat, compressed_blocks)
if Ng < size(compressed_matrix, 2)
compressed_matrix = compressed_matrix[:, 1:Ng]
end
return decompress_symmetric(sparsity, compressed, colors, star_set)
return decompress(compressed_matrix, coloring_result)
end

function hessian!(
f::F, hess, backend::AutoSparse, x, extras::SparseHessianExtras{B}
) where {F,B}
@compat (;
sparsity,
compressed,
colors,
star_set,
groups,
coloring_result,
compressed_matrix,
batched_seeds,
batched_results,
hvp_batched_extras,
) = extras
dense_backend = dense_ad(backend)
Ng = length(groups)
Ng = length(column_groups(coloring_result))

hvp_batched_extras_same = prepare_hvp_batched_same_point(
f, dense_backend, x, batched_seeds[1], hvp_batched_extras
Expand All @@ -125,13 +115,13 @@ function hessian!(

for b in eachindex(batched_results[a].elements)
copyto!(
view(compressed, :, 1 + ((a - 1) * B + (b - 1)) % Ng),
view(compressed_matrix, :, 1 + ((a - 1) * B + (b - 1)) % Ng),
vec(batched_results[a].elements[b]),
)
end
end

decompress_symmetric!(hess, sparsity, compressed, colors, star_set)
decompress!(hess, compressed_matrix, coloring_result)
return hess
end

Expand Down
Loading

0 comments on commit 9182912

Please sign in to comment.