Skip to content

Commit

Permalink
Get higher derivatives working
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyT1994 committed Apr 15, 2024
1 parent 837c16b commit ac3a650
Show file tree
Hide file tree
Showing 3 changed files with 280 additions and 199 deletions.
110 changes: 66 additions & 44 deletions src/elementary_operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,12 @@ function ITensors.op(::OpName"Dup", ::SiteType"Digit", s::Index)
return ITensor(o, s, s')
end

function plus_shift_ttn(
s::IndsNetwork, bit_map; dimension=default_dimension(), boundary=default_boundary(), n::Int64 = 0
function plus_shift_opsum(
s::IndsNetwork,
bit_map;
dimension=default_dimension(),
boundary=default_boundary(),
n::Int64=0,
)
@assert is_tree(s)
@assert base(bit_map) == 2
Expand All @@ -56,32 +60,42 @@ function plus_shift_ttn(

string_site = [("D+", vertex(bit_map, dimension, L - n))]
add!(ttn_op, 1.0, "D+", vertex(bit_map, dimension, L - n))
for i in (L-n):-1:2
for i in (L - n):-1:2
pop!(string_site)
push!(string_site, ("D-", vertex(bit_map, dimension, i)))
push!(string_site, ("D+", vertex(bit_map, dimension, i - 1)))
add!(ttn_op, 1.0, (string_site...)...)
end

if boundary == "Neumann"
#TODO: Not convinced this is right....
for i in 0:n
string_site = [j <= (L -i) ? ("Dup", vertex(bit_map, dimension, j)) : ("Ddn", vertex(bit_map, dimension, j)) for j in 1:L]
add!(ttn_op, 1.0, (string_site...)...)
end
string_site = [
if j <= (L - n)
("Dup", vertex(bit_map, dimension, j))
else
("I", vertex(bit_map, dimension, j))
end for j in 1:L
]
add!(ttn_op, 1.0, (string_site...)...)
elseif boundary == "Periodic"
#TODO: Not convinced this is right....
for i in 0:n
string_site = [j <= (L -i) ? ("D-", vertex(bit_map, dimension, j)) : ("D+", vertex(bit_map, dimension, j)) for j in 1:L]
add!(ttn_op, 1.0, (string_site...)...)
end
string_site = [
if j <= (L - n)
("D-", vertex(bit_map, dimension, j))
else
("I", vertex(bit_map, dimension, j))
end for j in 1:L
]
add!(ttn_op, 1.0, (string_site...)...)
end

return ttn(ttn_op, s; algorithm="svd")
return ttn_op
end

function minus_shift_ttn(
s::IndsNetwork, bit_map; dimension=default_dimension(), boundary=default_boundary(), n::Int64 = 0
function minus_shift_opsum(
s::IndsNetwork,
bit_map;
dimension=default_dimension(),
boundary=default_boundary(),
n::Int64=0,
)
@assert is_tree(s)
@assert base(bit_map) == 2
Expand All @@ -91,33 +105,41 @@ function minus_shift_ttn(

string_site = [("D-", vertex(bit_map, dimension, L - n))]
add!(ttn_op, 1.0, "D-", vertex(bit_map, dimension, L - n))
for i in (L-n):-1:2
for i in (L - n):-1:2
pop!(string_site)
push!(string_site, ("D+", vertex(bit_map, dimension, i)))
push!(string_site, ("D-", vertex(bit_map, dimension, i - 1)))
add!(ttn_op, 1.0, (string_site...)...)
end

if boundary == "Neumann"
for i in 0:n
string_site = [j <= (L -i) ? ("Ddn", vertex(bit_map, dimension, j)) : ("Dup", vertex(bit_map, dimension, j)) for j in 1:L]
add!(ttn_op, 1.0, (string_site...)...)
end
string_site = [
if j <= (L - n)
("Ddn", vertex(bit_map, dimension, j))
else
("I", vertex(bit_map, dimension, j))
end for j in 1:L
]
add!(ttn_op, 1.0, (string_site...)...)
elseif boundary == "Periodic"
for i in 0:n
string_site = [j <= (L -i) ? ("D+", vertex(bit_map, dimension, j)) : ("D-", vertex(bit_map, dimension, j)) for j in 1:L]
add!(ttn_op, 1.0, (string_site...)...)
end
string_site = [
if j <= (L - n)
("D+", vertex(bit_map, dimension, j))
else
("I", vertex(bit_map, dimension, j))
end for j in 1:L
]
add!(ttn_op, 1.0, (string_site...)...)
end

return ttn(ttn_op, s; algorithm="svd")
return ttn_op
end

function no_shift_ttn(s::IndsNetwork)
function no_shift_opsum(s::IndsNetwork)
ttn_op = OpSum()
string_site_full = [("I", v) for v in vertices(s)]
add!(ttn_op, 1.0, (string_site_full...)...)
return ttn(ttn_op, s; algorithm="svd")
return ttn_op
end

function stencil(
Expand All @@ -129,25 +151,28 @@ function stencil(
left_boundary=default_boundary(),
right_boundary=default_boundary(),
scale=true,
truncate_kwargs...,
truncate_op=true,
kwargs...,
)
@assert length(shifts) == 5
stencil_op = shifts[3] * no_shift_ttn(s)
for i in [1,2]
n = i == 1 ? 1 : 0
if !iszero(i)
stencil_op += shifts[i] * plus_shift_ttn(s, bit_map; dimension, boundary=right_boundary, n)
stencil_opsum = shifts[3] * no_shift_opsum(s)
for i in [1, 2]
n = i == 1 ? 1 : 0
if !iszero(shifts[i])
stencil_opsum +=
shifts[i] * plus_shift_opsum(s, bit_map; dimension, boundary=right_boundary, n)
end
end

for i in [4,5]
n = i == 5 ? 1 : 0
if !iszero(i)
stencil_op += shifts[i] * minus_shift_ttn(s, bit_map; dimension, boundary=left_boundary, n)
for i in [4, 5]
n = i == 5 ? 1 : 0
if !iszero(shifts[i])
stencil_opsum +=
shifts[i] * minus_shift_opsum(s, bit_map; dimension, boundary=left_boundary, n)
end
end

stencil_op = truncate(stencil_op; truncate_kwargs...)
stencil_op = ttn(stencil_opsum, s; algorithm="svd", kwargs...)

if scale
for v in vertices(bit_map, dimension)
Expand All @@ -167,7 +192,7 @@ function second_derivative_operator(s::IndsNetwork, bit_map; kwargs...)
end

function third_derivative_operator(s::IndsNetwork, bit_map; kwargs...)
return stencil(s, bit_map, [-0.5, 1.0, 0.0, -1.0, 0.5], 3; kwargs...)
return stencil(s, bit_map, [0.5, -1.0, 0.0, 1.0, -0.5], 3; kwargs...)
end

function fourth_derivative_operator(s::IndsNetwork, bit_map; kwargs...)
Expand Down Expand Up @@ -227,12 +252,9 @@ end

Base.:*(fs::ITensorNetworkFunction...) = multiply(fs...)

function operate(
operator::TreeTensorNetwork, ψ::ITensorNetworkFunction; truncate_kwargs=(;), kwargs...
)
function operate(operator::TreeTensorNetwork, ψ::ITensorNetworkFunction; kwargs...)
ψ_tn = ttn(itensornetwork(ψ))
ψO_tn = noprime(contract(operator, ψ_tn; init=prime(copy(ψ_tn)), kwargs...))
ψO_tn = truncate(ψO_tn; truncate_kwargs...)

return ITensorNetworkFunction(ITensorNetwork(ψO_tn), bit_map(ψ))
end
Expand Down
Loading

0 comments on commit ac3a650

Please sign in to comment.