Skip to content

Commit

Permalink
Attempt to generalize to higher order derivatives. Work in progress.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyT1994 committed Apr 12, 2024
1 parent 098c8f6 commit 837c16b
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 150 deletions.
15 changes: 13 additions & 2 deletions examples/2d_laplace_solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ println(
"Starting DMRG to find eigensolution of 2D Laplace operator. Initial energy is $init_energy",
)

dmrg_kwargs = (nsweeps=10, normalize=true, maxdim=15, cutoff=1e-10, outputlevel=1, nsites=2)
dmrg_kwargs = (nsweeps=15, normalize=true, maxdim=30, cutoff=1e-12, outputlevel=1, nsites=2)
ϕ_fxy = dmrg(∇, ttn(itensornetwork(ψ_fxy)); dmrg_kwargs...)
ϕ_fxy = ITensorNetworkFunction(ITensorNetwork(ϕ_fxy), bit_map)

ϕ_fxy = truncate(ϕ_fxy; cutoff=1e-8)
ϕ_fxy = truncate(ϕ_fxy; cutoff=1e-10)

final_energy = inner(ttn(itensornetwork(ϕ_fxy))', ∇, ttn(itensornetwork(ϕ_fxy)))
println(
Expand All @@ -55,3 +55,14 @@ end

println("Here is the heatmap of the 2D function")
show(heatmap(vals; xfact=0.01, yfact=0.01, xoffset=0, yoffset=0, colormap=:inferno))

n_grid = 100
x_vals = grid_points(bit_map, n_grid, 1)
y = 0.5
vals = zeros(length(x_vals))
for (i, x) in enumerate(x_vals)
vals[i] = real(calculate_fxyz(ϕ_fxy, [x, y]))
end

println("Here is a cut of the function at y = $y")
show(lineplot(x_vals, vals))
5 changes: 4 additions & 1 deletion src/ITensorNumericalAnalysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ export const_itensornetwork,
polynomial_itensornetwork,
random_itensornetworkfunction,
laplacian_operator,
derivative_operator,
first_derivative_operator,
second_derivative_operator,
third_derivative_operator,
fourth_derivative_operator,
identity_operator
export const_itn,
poly_itn, cosh_itn, sinh_itn, tanh_itn, exp_itn, sin_itn, cos_itn, rand_itn
Expand Down
90 changes: 60 additions & 30 deletions src/elementary_operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,58 +46,68 @@ function ITensors.op(::OpName"Dup", ::SiteType"Digit", s::Index)
end

function plus_shift_ttn(
s::IndsNetwork, bit_map; dimension=default_dimension(), boundary=default_boundary()
s::IndsNetwork, bit_map; dimension=default_dimension(), boundary=default_boundary(), n::Int64 = 0
)
@assert is_tree(s)
@assert base(bit_map) == 2
ttn_op = OpSum()
dim_vertices = vertices(bit_map, dimension)
L = length(dim_vertices)

string_site = [("D+", vertex(bit_map, dimension, L))]
add!(ttn_op, 1.0, "D+", vertex(bit_map, dimension, L))
for i in L:-1:2
string_site = [("D+", vertex(bit_map, dimension, L - n))]
add!(ttn_op, 1.0, "D+", vertex(bit_map, dimension, L - n))
for i in (L-n):-1:2
pop!(string_site)
push!(string_site, ("D-", vertex(bit_map, dimension, i)))
push!(string_site, ("D+", vertex(bit_map, dimension, i - 1)))
add!(ttn_op, 1.0, (string_site...)...)
end

if boundary == "Neumann"
string_site = [("Dup", vertex(bit_map, dimension, i)) for i in 1:L]
add!(ttn_op, 1.0, (string_site...)...)
#TODO: Not convinced this is right....
for i in 0:n
string_site = [j <= (L -i) ? ("Dup", vertex(bit_map, dimension, j)) : ("Ddn", vertex(bit_map, dimension, j)) for j in 1:L]
add!(ttn_op, 1.0, (string_site...)...)
end
elseif boundary == "Periodic"
string_site = [("D-", vertex(bit_map, dimension, i)) for i in 1:L]
add!(ttn_op, 1.0, (string_site...)...)
#TODO: Not convinced this is right....
for i in 0:n
string_site = [j <= (L -i) ? ("D-", vertex(bit_map, dimension, j)) : ("D+", vertex(bit_map, dimension, j)) for j in 1:L]
add!(ttn_op, 1.0, (string_site...)...)
end
end

return ttn(ttn_op, s; algorithm="svd")
end

function minus_shift_ttn(
s::IndsNetwork, bit_map; dimension=default_dimension(), boundary=default_boundary()
s::IndsNetwork, bit_map; dimension=default_dimension(), boundary=default_boundary(), n::Int64 = 0
)
@assert is_tree(s)
@assert base(bit_map) == 2
ttn_op = OpSum()
dim_vertices = vertices(bit_map, dimension)
L = length(dim_vertices)

string_site = [("D-", vertex(bit_map, dimension, L))]
add!(ttn_op, 1.0, "D-", vertex(bit_map, dimension, L))
for i in L:-1:2
string_site = [("D-", vertex(bit_map, dimension, L - n))]
add!(ttn_op, 1.0, "D-", vertex(bit_map, dimension, L - n))
for i in (L-n):-1:2
pop!(string_site)
push!(string_site, ("D+", vertex(bit_map, dimension, i)))
push!(string_site, ("D-", vertex(bit_map, dimension, i - 1)))
add!(ttn_op, 1.0, (string_site...)...)
end

if boundary == "Neumann"
string_site = [("Ddn", vertex(bit_map, dimension, i)) for i in 1:L]
add!(ttn_op, 1.0, (string_site...)...)
for i in 0:n
string_site = [j <= (L -i) ? ("Ddn", vertex(bit_map, dimension, j)) : ("Dup", vertex(bit_map, dimension, j)) for j in 1:L]
add!(ttn_op, 1.0, (string_site...)...)
end
elseif boundary == "Periodic"
string_site = [("D+", vertex(bit_map, dimension, i)) for i in 1:L]
add!(ttn_op, 1.0, (string_site...)...)
for i in 0:n
string_site = [j <= (L -i) ? ("D+", vertex(bit_map, dimension, j)) : ("D-", vertex(bit_map, dimension, j)) for j in 1:L]
add!(ttn_op, 1.0, (string_site...)...)
end
end

return ttn(ttn_op, s; algorithm="svd")
Expand All @@ -121,14 +131,22 @@ function stencil(
scale=true,
truncate_kwargs...,
)
@assert length(shifts) == 3
plus_shift =
first(shifts) * plus_shift_ttn(s, bit_map; dimension, boundary=right_boundary)
minus_shift =
last(shifts) * minus_shift_ttn(s, bit_map; dimension, boundary=left_boundary)
no_shift = shifts[2] * no_shift_ttn(s)

stencil_op = plus_shift + minus_shift + no_shift
@assert length(shifts) == 5
stencil_op = shifts[3] * no_shift_ttn(s)
for i in [1,2]
n = i == 1 ? 1 : 0
if !iszero(i)
stencil_op += shifts[i] * plus_shift_ttn(s, bit_map; dimension, boundary=right_boundary, n)
end
end

for i in [4,5]
n = i == 5 ? 1 : 0
if !iszero(i)
stencil_op += shifts[i] * minus_shift_ttn(s, bit_map; dimension, boundary=left_boundary, n)
end
end

stencil_op = truncate(stencil_op; truncate_kwargs...)

if scale
Expand All @@ -140,22 +158,34 @@ function stencil(
return stencil_op
end

function first_derivative_operator(s::IndsNetwork, bit_map; kwargs...)
return stencil(s, bit_map, [0.0, 0.5, 0.0, -0.5, 0.0], 1; kwargs...)
end

function second_derivative_operator(s::IndsNetwork, bit_map; kwargs...)
return stencil(s, bit_map, [0.0, 1.0, -2.0, 1.0, 0.0], 2; kwargs...)
end

function third_derivative_operator(s::IndsNetwork, bit_map; kwargs...)
return stencil(s, bit_map, [-0.5, 1.0, 0.0, -1.0, 0.5], 3; kwargs...)
end

function fourth_derivative_operator(s::IndsNetwork, bit_map; kwargs...)
return stencil(s, bit_map, [1.0, -4.0, 6.0, -4.0, 1.0], 4; kwargs...)
end

function laplacian_operator(
s::IndsNetwork, bit_map; dimensions=[i for i in 1:dimension(bit_map)], kwargs...
)
remaining_dims = copy(dimensions)
= stencil(s, bit_map, [1.0, -2.0, 1.0], 2; dimension=first(remaining_dims), kwargs...)
= second_derivative_operator(s, bit_map; dimension=first(remaining_dims), kwargs...)
popfirst!(remaining_dims)
for rd in remaining_dims
+= stencil(s, bit_map, [1.0, -2.0, 1.0], 2; dimension=rd, kwargs...)
+= second_derivative_operator(s, bit_map; dimension=rd, kwargs...)
end
return
end

function derivative_operator(s::IndsNetwork, bit_map; kwargs...)
return 0.5 * stencil(s, bit_map, [1.0, 0.0, -1.0], 1; kwargs...)
end

function identity_operator(s::IndsNetwork, bit_map; kwargs...)
return stencil(s, bit_map, [0.0, 1.0, 0.0], 0; kwargs...)
end
Expand Down
Loading

0 comments on commit 837c16b

Please sign in to comment.