This repository has been archived by the owner on Nov 7, 2024. It is now read-only.

SVD on the JAX backend, and thus split_node, cannot be jitted when max_truncation_err is set #953

Open
refraction-ray opened this issue Nov 18, 2021 · 0 comments


@refraction-ray
Contributor

SVD and split_node work fine on the TensorFlow backend inside tf.function:

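```python
import tensorflow as tf
import tensornetwork as tn

tn.set_default_backend("tensorflow")

@tf.function
def f(b):
    a = tn.Node(b)
    # Truncated SVD split: the kept rank depends on the singular values.
    n1, n2, _ = tn.split_node(a, left_edges=a[:2], right_edges=a[2:], max_truncation_err=0.5)
    return n1.tensor

f(tf.ones([2, 2, 2, 2]))  # works
```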

But the equivalent function fails on the JAX backend under jax.jit:

```python
import jax
from jax import numpy as jnp
import tensornetwork as tn

tn.set_default_backend("jax")

@jax.jit
def f(b):
    a = tn.Node(b)
    n1, n2, _ = tn.split_node(a, left_edges=a[:2], right_edges=a[2:], max_truncation_err=0.5)
    return n1.tensor

f(jnp.ones([2, 2, 2, 2]))  # raises ConcretizationTypeError
```

The error is raised by the svd function in backends/numpy/decompositions.py, at the line `num_sing_vals_keep = min(max_singular_values, num_sing_vals_err)`, with `ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected`.
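To see why that particular line fails, here is a minimal pure-JAX sketch of the same pattern. `num_to_keep` is a hypothetical, simplified version of the truncation count, not the library's exact code: the number of singular values to keep is computed from the data, and Python's `min()` then has to turn a comparison on that array into a Python `bool`. Eagerly this is fine; under `jax.jit` the count is an abstract tracer with no concrete value, so the conversion raises.

```python
import jax
import jax.numpy as jnp


def num_to_keep(s, max_singular_values, max_truncation_err):
    """Hypothetical, simplified truncation count (not tensornetwork's code)."""
    # Errors incurred by discarding the 1, 2, ... smallest singular values.
    trunc_errs = jnp.sqrt(jnp.cumsum(jnp.sort(s) ** 2))
    # Number of values that cannot be discarded without exceeding the threshold.
    num_sing_vals_err = jnp.count_nonzero(trunc_errs > max_truncation_err)
    # min() calls bool() on a comparison with an array: fine eagerly,
    # but under jit the array is an abstract tracer and this raises.
    return min(max_singular_values, num_sing_vals_err)


s = jnp.array([3.0, 2.0, 1.0, 0.1])
print(int(num_to_keep(s, 10, 0.5)))  # 3: eager mode works

try:
    jax.jit(lambda x: num_to_keep(x, 10, 0.5))(s)
except jax.errors.ConcretizationTypeError as err:
    print("jit failed with", type(err).__name__)
```

Newer JAX versions may report the more specific `TracerBoolConversionError`, which is a subclass of `ConcretizationTypeError`, so the `except` clause above catches either.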

I actually expected this error before trying: a jax.jit-compiled function only accepts and returns arrays with static shapes, so it supports only a subset of what tf.function can do. Since split_node with max_truncation_err returns nodes of varying shape (the final bond dimension depends on the singular values), it seems to be fundamentally incompatible with JAX's jit mechanism.
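The shape constraint can be seen by contrast: if the number of kept singular values is a compile-time constant, the same SVD split jits without trouble. The sketch below is pure JAX and does not use tensornetwork; `split_fixed_rank` is a hypothetical helper that splits a rank-4 tensor across the (0, 1) | (2, 3) bipartition, marks `max_singular_values` as static via `static_argnums`, and (as an illustrative choice) shares the square roots of the singular values between the two factors.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=1)
def split_fixed_rank(t, max_singular_values):
    """Hypothetical fixed-rank split; a pure-JAX sketch, not the library's code."""
    d0, d1, d2, d3 = t.shape
    # Matricize: left edges (0, 1) become rows, right edges (2, 3) columns.
    u, s, vh = jnp.linalg.svd(t.reshape(d0 * d1, d2 * d3), full_matrices=False)
    k = max_singular_values  # a Python int, so every slice bound is static
    sqrt_s = jnp.sqrt(s[:k])
    left = (u[:, :k] * sqrt_s).reshape(d0, d1, k)
    right = (sqrt_s[:, None] * vh[:k, :]).reshape(k, d2, d3)
    # Discarded singular values; their count also depends only on k.
    return left, right, s[k:]


t = jnp.ones([2, 2, 2, 2])
left, right, discarded = split_fixed_rank(t, 2)
print(left.shape, right.shape)  # (2, 2, 2) (2, 2, 2)
```

Each distinct value of `max_singular_values` triggers a separate compilation, which is the price of keeping shapes static.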

Any thoughts or workarounds? I believe applying split_node with truncation is very common in tensornetwork-related algorithms, and it would be great if such algorithms could be jitted.
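One possible workaround, sketched here under stated assumptions and not as tensornetwork's API, is to make the shape static and the truncation dynamic: always keep `max_singular_values` slots, and zero out the slots that the `max_truncation_err` criterion would discard. The hypothetical `split_masked` below does this in pure JAX for a rank-4 tensor split across (0, 1) | (2, 3); it returns the number of effectively kept values and the resulting truncation error as traced scalars instead of changing the output shape.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=1)
def split_masked(t, max_singular_values, max_truncation_err):
    """Hypothetical fixed-shape truncated split (pure-JAX sketch)."""
    d0, d1, d2, d3 = t.shape
    u, s, vh = jnp.linalg.svd(t.reshape(d0 * d1, d2 * d3), full_matrices=False)
    k = max_singular_values  # static: fixes the bond dimension of the output
    # tail_err[i] = norm of s[i:], i.e. the error of keeping only s[:i].
    tail_err = jnp.sqrt(jnp.cumsum((s ** 2)[::-1])[::-1])
    # Keep slot i only if dropping it (and all later values) would exceed the
    # threshold. The mask is data-dependent, but its shape (k,) is not.
    mask = tail_err[:k] > max_truncation_err
    # Double `where`: masked slots never evaluate sqrt(0), whose derivative
    # is infinite, so gradients through this step stay finite.
    safe_s = jnp.where(mask, s[:k], 1.0)
    sqrt_s = jnp.where(mask, jnp.sqrt(safe_s), 0.0)
    left = (u[:, :k] * sqrt_s).reshape(d0, d1, k)
    right = (sqrt_s[:, None] * vh[:k, :]).reshape(k, d2, d3)
    # Norm of everything discarded: masked-out slots plus s[k:].
    dropped = jnp.where(mask, 0.0, s[:k])
    trunc_err = jnp.sqrt(jnp.sum(dropped ** 2) + jnp.sum(s[k:] ** 2))
    return left, right, jnp.sum(mask), trunc_err


t = jnp.ones([2, 2, 2, 2])  # matricizes to a rank-1 matrix
left, right, num_kept, err = split_masked(t, 3, 0.5)
print(left.shape, int(num_kept))  # (2, 2, 3) 1
```

The trade-off is that the zeroed slots still occupy memory and compute, so `max_singular_values` acts as a hard upper bound on the bond dimension that has to be chosen up front.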
