Skip to content

Commit

Permalink
add interface for the viz tool by OMEinsumContractionOrders v0.9 (#173)
Browse files Browse the repository at this point in the history
  • Loading branch information
ArrogantGao authored Aug 6, 2024
1 parent 68c2736 commit 90c85b8
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 3 deletions.
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ CUDA = "4, 5"
ChainRulesCore = "1"
Combinatorics = "1.0"
MacroTools = "0.5"
OMEinsumContractionOrders = "0.8, 0.9"
OMEinsumContractionOrders = "0.9"
TupleTools = "1.2, 1.3"
julia = "1"

Expand All @@ -39,6 +39,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DoubleFloats = "497a8b3b-efae-58df-a0af-a86822472b78"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LuxorGraphPlot = "1f49bdf2-22a7-4bc4-978b-948dc219fbbc"
Polynomials = "f27b6e38-b328-58d1-80ce-0feddd5e7a45"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -48,4 +49,4 @@ TropicalNumbers = "b3a74e9c-7526-4576-a4eb-79c0d4c32334"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "CUDA", "Documenter", "LinearAlgebra", "ProgressMeter", "SymEngine", "Random", "Zygote", "DoubleFloats", "TropicalNumbers", "ForwardDiff", "Polynomials"]
test = ["Test", "CUDA", "Documenter", "LinearAlgebra", "ProgressMeter", "SymEngine", "Random", "Zygote", "DoubleFloats", "TropicalNumbers", "ForwardDiff", "Polynomials", "LuxorGraphPlot"]
4 changes: 3 additions & 1 deletion src/OMEinsum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ export CodeOptimizer, CodeSimplifier,
peak_memory, timespace_complexity, timespacereadwrite_complexity, flop, contraction_complexity,
# file io
writejson, readjson,
label_elimination_order
label_elimination_order,
# visualization
viz_eins, viz_contraction

include("Core.jl")
include("loop_einsum.jl")
Expand Down
3 changes: 3 additions & 0 deletions src/contractionorder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ OMEinsumContractionOrders.contraction_complexity(code::AbstractEinsum, size_dict
OMEinsumContractionOrders.uniformsize(code::AbstractEinsum, size) = Dict([l=>size for l in uniquelabels(code)])
OMEinsumContractionOrders.label_elimination_order(code::AbstractEinsum) = label_elimination_order(rawcode(code))

OMEinsumContractionOrders.viz_eins(code::AbstractEinsum, args...; kwargs...) = viz_eins(rawcode(code), args...; kwargs...)
OMEinsumContractionOrders.viz_contraction(code::AbstractEinsum, args...; kwargs...) = viz_contraction(rawcode(code), args...; kwargs...)

# save load
function writejson(filename::AbstractString, ne::Union{NestedEinsum, SlicedEinsum})
OMEinsumContractionOrders.writejson(filename, rawcode(ne))
Expand Down
19 changes: 19 additions & 0 deletions test/contractionorder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,4 +100,23 @@ end
@test optcode == code2
end
end
end

using LuxorGraphPlot

@testset "visualization tool" begin
eincode = ein"ab,acd,bcef,e,df->"
nested_ein = optein"ab,acd,bcef,e,df->"

graph_1 = viz_eins(eincode)
@test graph_1 isa LuxorGraphPlot.Luxor.Drawing

graph_2 = viz_eins(nested_ein)
@test graph_2 isa LuxorGraphPlot.Luxor.Drawing

gif = viz_contraction(nested_ein, filename = tempname() * ".gif")
@test gif isa String

video = viz_contraction(nested_ein)
@test video isa String
end

0 comments on commit 90c85b8

Please sign in to comment.