diff --git a/Project.toml b/Project.toml index 51675f82..d3a74bee 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "GraphTensorNetworks" uuid = "0978c8c2-34f6-49c7-9826-ea2cc20dabd2" authors = ["GiggleLiu and contributors"] -version = "0.1.1" +version = "0.1.2" [deps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" diff --git a/src/arithematics.jl b/src/arithematics.jl index a6866ed4..4635b43e 100644 --- a/src/arithematics.jl +++ b/src/arithematics.jl @@ -219,4 +219,6 @@ function onehotv(::Type{CountingTropical{TV,BS}}, x, v) where {TV,BS} CountingTropical{TV,BS}(one(TV), onehotv(BS, x, v)) end onehotv(::Type{ConfigEnumerator{N,S,C}}, i::Integer, v) where {N,S,C} = ConfigEnumerator([onehotv(StaticElementVector{N,S,C}, i, v)]) -onehotv(::Type{ConfigSampler{N,S,C}}, i::Integer, v) where {N,S,C} = ConfigSampler(onehotv(StaticElementVector{N,S,C}, i, v)) \ No newline at end of file +onehotv(::Type{ConfigSampler{N,S,C}}, i::Integer, v) where {N,S,C} = ConfigSampler(onehotv(StaticElementVector{N,S,C}, i, v)) +Base.transpose(c::ConfigEnumerator) = c +Base.copy(c::ConfigEnumerator) = ConfigEnumerator(copy(c.data)) \ No newline at end of file diff --git a/src/bounding.jl b/src/bounding.jl index 47a43dfc..9213e24c 100644 --- a/src/bounding.jl +++ b/src/bounding.jl @@ -1,5 +1,9 @@ using OMEinsum: DynamicEinCode +struct AllConfigs{K} end +largest_k(::AllConfigs{K}) where K = K +struct SingleConfig end + """ backward_tropical(mode, ixs, xs, iy, y, ymask, size_dict) @@ -17,11 +21,11 @@ function backward_tropical(mode, ixs, @nospecialize(xs::Tuple), iy, @nospecializ nixs = OMEinsum._insertat(ixs, i, iy) nxs = OMEinsum._insertat( xs, i, y) niy = ixs[i] - if mode == :all + if mode isa AllConfigs mask = zeros(Bool, size(xs[i])) - mask .= inv.(einsum(EinCode(nixs, niy), nxs, size_dict)) .== xs[i] + mask .= inv.(einsum(EinCode(nixs, niy), nxs, size_dict)) .<= xs[i] .* Tropical(largest_k(mode)-1) push!(masks, mask) - elseif mode == :single # wrong, need `B` matching `A`. + elseif mode isa SingleConfig A = zeros(eltype(xs[i]), size(xs[i])) A = einsum(EinCode(nixs, niy), nxs, size_dict) push!(masks, onehotmask(A, xs[i])) @@ -65,12 +69,12 @@ function cached_einsum(code::NestedEinsum, @nospecialize(xs), size_dict) end # computed mask tree by back propagation -function generate_masktree(code::NestedEinsum, cache, mask, size_dict, mode=:all) +function generate_masktree(mode, code::NestedEinsum, cache, mask, size_dict) if OMEinsum.isleaf(code) return CacheTree(mask, CacheTree{Bool}[]) else submasks = backward_tropical(mode, getixs(code.eins), (getfield.(cache.siblings, :content)...,), OMEinsum.getiy(code.eins), cache.content, mask, size_dict) - return CacheTree(mask, generate_masktree.(code.args, cache.siblings, submasks, Ref(size_dict), mode)) + return CacheTree(mask, generate_masktree.(Ref(mode), code.args, cache.siblings, submasks, Ref(size_dict))) end end @@ -89,19 +93,20 @@ function masked_einsum(code::NestedEinsum, @nospecialize(xs), masks, size_dict) end """ - bounding_contract(code, xsa, ymask, xsb; size_info=nothing) + bounding_contract(mode, code, xsa, ymask, xsb; size_info=nothing) Contraction method with bounding. + * `mode` is a `AllConfigs{K}` instance, where `MIS-K+1` is the largest IS size that you care about. * `xsa` are input tensors for bounding, e.g. tropical tensors, * `xsb` are input tensors for computing, e.g. tensors elements are counting tropical with set algebra, * `ymask` is the initial gradient mask for the output tensor. """ -function bounding_contract(code::EinCode, @nospecialize(xsa), ymask, @nospecialize(xsb); size_info=nothing) +function bounding_contract(mode::AllConfigs, code::EinCode, @nospecialize(xsa), ymask, @nospecialize(xsb); size_info=nothing) LT = OMEinsum.labeltype(code) - bounding_contract(NestedEinsum(NestedEinsum{DynamicEinCode{LT}}.(1:length(xsa)), code), xsa, ymask, xsb; size_info=size_info) + bounding_contract(mode, NestedEinsum(NestedEinsum{DynamicEinCode{LT}}.(1:length(xsa)), code), xsa, ymask, xsb; size_info=size_info) end -function bounding_contract(code::NestedEinsum, @nospecialize(xsa), ymask, @nospecialize(xsb); size_info=nothing) +function bounding_contract(mode::AllConfigs, code::NestedEinsum, @nospecialize(xsa), ymask, @nospecialize(xsb); size_info=nothing) size_dict = size_info===nothing ? Dict{OMEinsum.labeltype(code.eins),Int}() : copy(size_info) OMEinsum.get_size_dict!(code, xsa, size_dict) # compute intermediate tensors @@ -109,7 +114,7 @@ function bounding_contract(code::NestedEinsum, @nospecialize(xsa), ymask, @nospe c = cached_einsum(code, xsa, size_dict) # compute masks from cached tensors @debug "generating masked tree..." - mt = generate_masktree(code, c, ymask, size_dict, :all) + mt = generate_masktree(mode, code, c, ymask, size_dict) # compute results with masks masked_einsum(code, xsb, mt, size_dict) end @@ -129,7 +134,7 @@ function solution_ad(code::NestedEinsum, @nospecialize(xsa), ymask; size_info=no n = asscalar(c.content) # compute masks from cached tensors @debug "generating masked tree..." - mt = generate_masktree(code, c, ymask, size_dict, :single) + mt = generate_masktree(SingleConfig(), code, c, ymask, size_dict) n, read_config!(code, mt, Dict()) end diff --git a/src/configurations.jl b/src/configurations.jl index 80c3715a..b2501db0 100644 --- a/src/configurations.jl +++ b/src/configurations.jl @@ -1,4 +1,5 @@ export best_solutions, best2_solutions, solutions, all_solutions +export bestk_solutions """ best_solutions(problem; all=false, usecuda=false) @@ -23,7 +24,7 @@ function best_solutions(gp::GraphProblem; all=false, usecuda=false) end if all xs = generate_tensors(l->onehotv(T, vertex_index[l], 1), gp) - return bounding_contract(gp.code, xst, ymask, xs) + return bounding_contract(AllConfigs{1}(), gp.code, xst, ymask, xs) else @assert ndims(ymask) == 0 t, res = solution_ad(gp.code, xst, ymask) @@ -58,6 +59,16 @@ Finding optimal and suboptimal solutions. """ best2_solutions(gp::GraphProblem; all=true, usecuda=false) = solutions(gp, Max2Poly{Float64,Float64}; all=all, usecuda=usecuda) +function bestk_solutions(gp::GraphProblem, k::Int) + syms = symbols(gp) + vertex_index = Dict([s=>i for (i, s) in enumerate(syms)]) + xst = generate_tensors(l->TropicalF64(1.0), gp) + ymask = trues(fill(2, length(_getiy(gp.code)))...) + T = set_type(TruncatedPoly{k,Float64,Float64}, length(syms), bondsize(gp)) + xs = generate_tensors(l->onehotv(T, vertex_index[l], 1), gp) + return bounding_contract(AllConfigs{k}(), gp.code, xst, ymask, xs) +end + """ all_solutions(problem) diff --git a/src/interfaces.jl b/src/interfaces.jl index f6c671dc..f5657917 100644 --- a/src/interfaces.jl +++ b/src/interfaces.jl @@ -51,6 +51,10 @@ function solve(gp::GraphProblem, task; usecuda=false, kwargs...) return best_solutions(gp; all=false, usecuda=usecuda) elseif task == "configs max (bounded)" return best_solutions(gp; all=true, usecuda=usecuda) + elseif task == "configs max2 (bounded)" + return bestk_solutions(gp, 2) + elseif task == "configs max3 (bounded)" + return bestk_solutions(gp, 3) else error("unknown task $task.") end diff --git a/test/configurations.jl b/test/configurations.jl index 53b30ec9..d0d3ef94 100644 --- a/test/configurations.jl +++ b/test/configurations.jl @@ -1,6 +1,7 @@ using GraphTensorNetworks, Test, Graphs using OMEinsum using TropicalNumbers: CountingTropicalF64 +using OMEinsumContractionOrders: uniformsize @testset "Config types" begin T = sampler_type(CountingTropical{Float32}, 5, 2) @@ -45,10 +46,13 @@ end @test res5.n == res0 @test res5.c.data ∈ res2.c.data res6 = best2_solutions(code; all=true)[] + res6_ = bestk_solutions(code, 2)[] res7 = all_solutions(code)[] idp = graph_polynomial(code, Val(:finitefield))[] @test all(x->x ∈ res7.coeffs[end-1].data, res6.coeffs[1].data) @test all(x->x ∈ res7.coeffs[end].data, res6.coeffs[2].data) + @test all(x->x ∈ res7.coeffs[end-1].data, res6_.coeffs[1].data) + @test all(x->x ∈ res7.coeffs[end].data, res6_.coeffs[2].data) for (i, (s, c)) in enumerate(zip(res7.coeffs, idp.coeffs)) @test length(s) == c @test all(x->count_ones(x)==(i-1), s.data) diff --git a/test/interfaces.jl b/test/interfaces.jl index 76106b3a..5130c596 100644 --- a/test/interfaces.jl +++ b/test/interfaces.jl @@ -19,6 +19,8 @@ using Graphs, Test res13 = solve(gp, "configs max (bounded)")[] res14 = solve(gp, "counting max3")[] res15 = solve(gp, "configs max3")[] + res16 = solve(gp, "configs max2 (bounded)")[] + res17 = solve(gp, "configs max3 (bounded)")[] @test res1.n == 4 @test res2 == 76 @test res3.n == 4 && res3.c == 5 @@ -35,6 +37,9 @@ using Graphs, Test @test res14.maxorder == 4 && res14.coeffs[1]==30 && res14.coeffs[2] == 30 && res14.coeffs[3]==5 @test all(x->sum(x) == 2, res15.coeffs[1].data) && all(x->sum(x) == 3, res15.coeffs[2].data) && all(x->sum(x) == 4, res15.coeffs[3].data) && length(res15.coeffs[1].data) == 30 && length(res15.coeffs[2].data) == 30 && length(res15.coeffs[3].data) == 5 + @test all(x->sum(x) == 3, res16.coeffs[1].data) && all(x->sum(x) == 4, res16.coeffs[2].data) && length(res16.coeffs[1].data) == 30 && length(res16.coeffs[2].data) == 5 + @test all(x->sum(x) == 2, res17.coeffs[1].data) && all(x->sum(x) == 3, res17.coeffs[2].data) && all(x->sum(x) == 4, res17.coeffs[3].data) && + length(res17.coeffs[1].data) == 30 && length(res17.coeffs[2].data) == 30 && length(res17.coeffs[3].data) == 5 end @testset "save load" begin