Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
GiggleLiu committed Nov 28, 2021
2 parents a2be767 + d9065f1 commit 3647afe
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 14 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "GraphTensorNetworks"
uuid = "0978c8c2-34f6-49c7-9826-ea2cc20dabd2"
authors = ["GiggleLiu <[email protected]> and contributors"]
version = "0.1.1"
version = "0.1.2"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Expand Down
4 changes: 3 additions & 1 deletion src/arithematics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,4 +219,6 @@ function onehotv(::Type{CountingTropical{TV,BS}}, x, v) where {TV,BS}
CountingTropical{TV,BS}(one(TV), onehotv(BS, x, v))
end
onehotv(::Type{ConfigEnumerator{N,S,C}}, i::Integer, v) where {N,S,C} = ConfigEnumerator([onehotv(StaticElementVector{N,S,C}, i, v)])
onehotv(::Type{ConfigSampler{N,S,C}}, i::Integer, v) where {N,S,C} = ConfigSampler(onehotv(StaticElementVector{N,S,C}, i, v))
onehotv(::Type{ConfigSampler{N,S,C}}, i::Integer, v) where {N,S,C} = ConfigSampler(onehotv(StaticElementVector{N,S,C}, i, v))
Base.transpose(c::ConfigEnumerator) = c
Base.copy(c::ConfigEnumerator) = ConfigEnumerator(copy(c.data))
27 changes: 16 additions & 11 deletions src/bounding.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
using OMEinsum: DynamicEinCode

struct AllConfigs{K} end
largest_k(::AllConfigs{K}) where K = K
struct SingleConfig end

"""
backward_tropical(mode, ixs, xs, iy, y, ymask, size_dict)
Expand All @@ -17,11 +21,11 @@ function backward_tropical(mode, ixs, @nospecialize(xs::Tuple), iy, @nospecializ
nixs = OMEinsum._insertat(ixs, i, iy)
nxs = OMEinsum._insertat( xs, i, y)
niy = ixs[i]
if mode == :all
if mode isa AllConfigs
mask = zeros(Bool, size(xs[i]))
mask .= inv.(einsum(EinCode(nixs, niy), nxs, size_dict)) .== xs[i]
mask .= inv.(einsum(EinCode(nixs, niy), nxs, size_dict)) .<= xs[i] .* Tropical(largest_k(mode)-1)
push!(masks, mask)
elseif mode == :single # wrong, need `B` matching `A`.
elseif mode isa SingleConfig
A = zeros(eltype(xs[i]), size(xs[i]))
A = einsum(EinCode(nixs, niy), nxs, size_dict)
push!(masks, onehotmask(A, xs[i]))
Expand Down Expand Up @@ -65,12 +69,12 @@ function cached_einsum(code::NestedEinsum, @nospecialize(xs), size_dict)
end

# computed mask tree by back propagation
function generate_masktree(code::NestedEinsum, cache, mask, size_dict, mode=:all)
function generate_masktree(mode, code::NestedEinsum, cache, mask, size_dict)
if OMEinsum.isleaf(code)
return CacheTree(mask, CacheTree{Bool}[])
else
submasks = backward_tropical(mode, getixs(code.eins), (getfield.(cache.siblings, :content)...,), OMEinsum.getiy(code.eins), cache.content, mask, size_dict)
return CacheTree(mask, generate_masktree.(code.args, cache.siblings, submasks, Ref(size_dict), mode))
return CacheTree(mask, generate_masktree.(Ref(mode), code.args, cache.siblings, submasks, Ref(size_dict)))
end
end

Expand All @@ -89,27 +93,28 @@ function masked_einsum(code::NestedEinsum, @nospecialize(xs), masks, size_dict)
end

"""
bounding_contract(code, xsa, ymask, xsb; size_info=nothing)
bounding_contract(mode, code, xsa, ymask, xsb; size_info=nothing)
Contraction method with bounding.
* `mode` is a `AllConfigs{K}` instance, where `MIS-K+1` is the largest IS size that you care about.
* `xsa` are input tensors for bounding, e.g. tropical tensors,
* `xsb` are input tensors for computing, e.g. tensors elements are counting tropical with set algebra,
* `ymask` is the initial gradient mask for the output tensor.
"""
function bounding_contract(code::EinCode, @nospecialize(xsa), ymask, @nospecialize(xsb); size_info=nothing)
function bounding_contract(mode::AllConfigs, code::EinCode, @nospecialize(xsa), ymask, @nospecialize(xsb); size_info=nothing)
LT = OMEinsum.labeltype(code)
bounding_contract(NestedEinsum(NestedEinsum{DynamicEinCode{LT}}.(1:length(xsa)), code), xsa, ymask, xsb; size_info=size_info)
bounding_contract(mode, NestedEinsum(NestedEinsum{DynamicEinCode{LT}}.(1:length(xsa)), code), xsa, ymask, xsb; size_info=size_info)
end
function bounding_contract(code::NestedEinsum, @nospecialize(xsa), ymask, @nospecialize(xsb); size_info=nothing)
function bounding_contract(mode::AllConfigs, code::NestedEinsum, @nospecialize(xsa), ymask, @nospecialize(xsb); size_info=nothing)
size_dict = size_info===nothing ? Dict{OMEinsum.labeltype(code.eins),Int}() : copy(size_info)
OMEinsum.get_size_dict!(code, xsa, size_dict)
# compute intermediate tensors
@debug "caching einsum..."
c = cached_einsum(code, xsa, size_dict)
# compute masks from cached tensors
@debug "generating masked tree..."
mt = generate_masktree(code, c, ymask, size_dict, :all)
mt = generate_masktree(mode, code, c, ymask, size_dict)
# compute results with masks
masked_einsum(code, xsb, mt, size_dict)
end
Expand All @@ -129,7 +134,7 @@ function solution_ad(code::NestedEinsum, @nospecialize(xsa), ymask; size_info=no
n = asscalar(c.content)
# compute masks from cached tensors
@debug "generating masked tree..."
mt = generate_masktree(code, c, ymask, size_dict, :single)
mt = generate_masktree(SingleConfig(), code, c, ymask, size_dict)
n, read_config!(code, mt, Dict())
end

Expand Down
13 changes: 12 additions & 1 deletion src/configurations.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
export best_solutions, best2_solutions, solutions, all_solutions
export bestk_solutions

"""
best_solutions(problem; all=false, usecuda=false)
Expand All @@ -23,7 +24,7 @@ function best_solutions(gp::GraphProblem; all=false, usecuda=false)
end
if all
xs = generate_tensors(l->onehotv(T, vertex_index[l], 1), gp)
return bounding_contract(gp.code, xst, ymask, xs)
return bounding_contract(AllConfigs{1}(), gp.code, xst, ymask, xs)
else
@assert ndims(ymask) == 0
t, res = solution_ad(gp.code, xst, ymask)
Expand Down Expand Up @@ -58,6 +59,16 @@ Finding optimal and suboptimal solutions.
"""
best2_solutions(gp::GraphProblem; all=true, usecuda=false) = solutions(gp, Max2Poly{Float64,Float64}; all=all, usecuda=usecuda)

function bestk_solutions(gp::GraphProblem, k::Int)
syms = symbols(gp)
vertex_index = Dict([s=>i for (i, s) in enumerate(syms)])
xst = generate_tensors(l->TropicalF64(1.0), gp)
ymask = trues(fill(2, length(_getiy(gp.code)))...)
T = set_type(TruncatedPoly{k,Float64,Float64}, length(syms), bondsize(gp))
xs = generate_tensors(l->onehotv(T, vertex_index[l], 1), gp)
return bounding_contract(AllConfigs{k}(), gp.code, xst, ymask, xs)
end

"""
all_solutions(problem)
Expand Down
4 changes: 4 additions & 0 deletions src/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ function solve(gp::GraphProblem, task; usecuda=false, kwargs...)
return best_solutions(gp; all=false, usecuda=usecuda)
elseif task == "configs max (bounded)"
return best_solutions(gp; all=true, usecuda=usecuda)
elseif task == "configs max2 (bounded)"
return bestk_solutions(gp, 2)
elseif task == "configs max3 (bounded)"
return bestk_solutions(gp, 3)
else
error("unknown task $task.")
end
Expand Down
4 changes: 4 additions & 0 deletions test/configurations.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using GraphTensorNetworks, Test, Graphs
using OMEinsum
using TropicalNumbers: CountingTropicalF64
using OMEinsumContractionOrders: uniformsize

@testset "Config types" begin
T = sampler_type(CountingTropical{Float32}, 5, 2)
Expand Down Expand Up @@ -45,10 +46,13 @@ end
@test res5.n == res0
@test res5.c.data res2.c.data
res6 = best2_solutions(code; all=true)[]
res6_ = bestk_solutions(code, 2)[]
res7 = all_solutions(code)[]
idp = graph_polynomial(code, Val(:finitefield))[]
@test all(x->x res7.coeffs[end-1].data, res6.coeffs[1].data)
@test all(x->x res7.coeffs[end].data, res6.coeffs[2].data)
@test all(x->x res7.coeffs[end-1].data, res6_.coeffs[1].data)
@test all(x->x res7.coeffs[end].data, res6_.coeffs[2].data)
for (i, (s, c)) in enumerate(zip(res7.coeffs, idp.coeffs))
@test length(s) == c
@test all(x->count_ones(x)==(i-1), s.data)
Expand Down
5 changes: 5 additions & 0 deletions test/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ using Graphs, Test
res13 = solve(gp, "configs max (bounded)")[]
res14 = solve(gp, "counting max3")[]
res15 = solve(gp, "configs max3")[]
res16 = solve(gp, "configs max2 (bounded)")[]
res17 = solve(gp, "configs max3 (bounded)")[]
@test res1.n == 4
@test res2 == 76
@test res3.n == 4 && res3.c == 5
Expand All @@ -35,6 +37,9 @@ using Graphs, Test
@test res14.maxorder == 4 && res14.coeffs[1]==30 && res14.coeffs[2] == 30 && res14.coeffs[3]==5
@test all(x->sum(x) == 2, res15.coeffs[1].data) && all(x->sum(x) == 3, res15.coeffs[2].data) && all(x->sum(x) == 4, res15.coeffs[3].data) &&
length(res15.coeffs[1].data) == 30 && length(res15.coeffs[2].data) == 30 && length(res15.coeffs[3].data) == 5
@test all(x->sum(x) == 3, res16.coeffs[1].data) && all(x->sum(x) == 4, res16.coeffs[2].data) && length(res16.coeffs[1].data) == 30 && length(res16.coeffs[2].data) == 5
@test all(x->sum(x) == 2, res17.coeffs[1].data) && all(x->sum(x) == 3, res17.coeffs[2].data) && all(x->sum(x) == 4, res17.coeffs[3].data) &&
length(res17.coeffs[1].data) == 30 && length(res17.coeffs[2].data) == 30 && length(res17.coeffs[3].data) == 5
end

@testset "save load" begin
Expand Down

0 comments on commit 3647afe

Please sign in to comment.