
[BUG] Executing a qjit'd QNode with abstracted_axes twice with the same arguments fails #1231

Open · isaacdevlugt opened this issue on Oct 25, 2024 · 0 comments
Labels: bug (Something isn't working)

isaacdevlugt (Contributor) commented on Oct 25, 2024

Issue description

Executing a qjit'd QNode with abstracted_axes twice with the same arguments fails on the second execution.

  • Expected behavior: Both calls succeed.

  • Actual behavior: The first call succeeds and the second call raises a ValueError. The second call fails even if it passes arguments with a new shape.

  • Reproduces how often: 100%
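The traceback below suggests why only the second call fails: `QJIT.jit_compile` calls `CompilationCache.lookup`, and `get_function_status_and_key` reaches `typecheck_signatures` only after fetching an existing entry (`entry = self.cache[key]`). On the first call there is presumably no entry yet, so the failing check never runs. The toy model below is a minimal plain-Python sketch of that control flow, assuming one cache entry per function; `ToyCache`, `ok_check`, and `broken_check` are hypothetical names, not Catalyst APIs.

```python
from typing import Callable, Dict, Hashable

# Hypothetical toy model of a compilation cache; none of these names are Catalyst APIs.
Signature = tuple


class ToyCache:
    """Caches one compiled entry per key and type-checks signatures only on hits."""

    def __init__(self, check: Callable[[Signature, Signature], bool]):
        self.check = check  # runs only when an entry already exists
        self.entries: Dict[Hashable, Signature] = {}
        self.compilations = 0

    def call(self, key: Hashable, signature: Signature) -> str:
        entry = self.entries.get(key)
        if entry is None:
            # First call: the cache is empty, so the signature check is skipped.
            self.compilations += 1
            self.entries[key] = signature
            return "compiled"
        # Later calls: compare the stored signature with the runtime one.
        if not self.check(entry, signature):
            self.compilations += 1
            self.entries[key] = signature
            return "recompiled"
        return "cache hit"


def ok_check(compiled: Signature, runtime: Signature) -> bool:
    return compiled == runtime


def broken_check(compiled: Signature, runtime: Signature) -> bool:
    # Stands in for a check that crashes instead of returning True/False.
    raise ValueError("Tuple arity mismatch")


sig = (("float64", 3), ("float64", ()))  # shapes of x1 and x2 in the issue

good = ToyCache(ok_check)
print(good.call("circuit", sig))  # compiled
print(good.call("circuit", sig))  # cache hit

bad = ToyCache(broken_check)
print(bad.call("circuit", sig))  # compiled: the check has not run yet
try:
    bad.call("circuit", sig)
except ValueError as err:
    print("second call failed:", err)
```

If `n` is abstracted, a call with a differently shaped `x1` plausibly maps to the same cache entry as well, which would be consistent with the report that new shapes also fail on the second call.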

  • System information (output of import pennylane as qml; qml.about()):

Name: PennyLane
Version: 0.38.1
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: /Users/isaac/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, toml, typing-extensions
Required-by: PennyLane-Catalyst, PennyLane_Lightning

Platform info:           macOS-14.7-arm64-arm-64bit
Python version:          3.11.9
Numpy version:           1.26.4
Scipy version:           1.12.0
Installed devices:
- default.clifford (PennyLane-0.39.0.dev39)
- default.gaussian (PennyLane-0.39.0.dev39)
- default.mixed (PennyLane-0.39.0.dev39)
- default.qubit (PennyLane-0.39.0.dev39)
- default.qutrit (PennyLane-0.39.0.dev39)
- default.qutrit.mixed (PennyLane-0.39.0.dev39)
- default.tensor (PennyLane-0.39.0.dev39)
- null.qubit (PennyLane-0.39.0.dev39)
- reference.qubit (PennyLane-0.39.0.dev39)
- nvidia.custatevec (PennyLane-Catalyst-0.8.1)
- nvidia.cutensornet (PennyLane-Catalyst-0.8.1)
- oqc.cloud (PennyLane-Catalyst-0.8.1)
- softwareq.qpp (PennyLane-Catalyst-0.8.1)
- lightning.qubit (PennyLane_Lightning-0.38.0)

Source code and tracebacks

import jax.numpy as jnp
import pennylane as qml

# Axis 0 of x1 is dynamic (named 'n'); x2 is a scalar with no abstracted axes.
@qml.qjit(abstracted_axes=(('n',), ()))
@qml.qnode(qml.device("lightning.qubit", wires=2))
def circuit(x1, x2):

    @qml.for_loop(0, jnp.shape(x1)[0], 1)
    def loop_block(i):
        qml.Hadamard(0)
        qml.RX(x1[i], 0)
        qml.CNOT(wires=[0, 1])
        qml.Hadamard(1)

    loop_block()
    qml.RY(x2, 1)
    return qml.expval(qml.Z(1))

x1 = jnp.array([0.1, 0.2, 0.3])
x2 = 0.1967

print(circuit(x1, x2)) # works
print(circuit(x1, x2)) # fails
0.9611678060162306
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[3], line 30
     27 x2 = 0.1967
     29 print(circuit(x1, x2)) # works
---> 30 print(circuit(x1, x2)) # fails

File ~/Documents/pennylane/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/jit.py:457, in QJIT.__call__(self, *args, **kwargs)
    453         kwargs = {"static_argnums": self.compile_options.static_argnums, **kwargs}
    455     return self.user_function(*args, **kwargs)
--> 457 requires_promotion = self.jit_compile(args, **kwargs)
    459 # If we receive tracers as input, dispatch to the JAX integration.
    460 if any(isinstance(arg, jax.core.Tracer) for arg in tree_flatten(args)[0]):

File ~/Documents/pennylane/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/jit.py:508, in QJIT.jit_compile(self, args, **kwargs)
    496 @debug_logger
    497 def jit_compile(self, args, **kwargs):
    498     """Compile Python function on invocation using the provided arguments.
    499 
    500     Args:
   (...)
    505               function
    506     """
--> 508     cached_fn, requires_promotion = self.fn_cache.lookup(args)
    510     if cached_fn is None:
    511         if self.user_sig and not self.compile_options.static_argnums:

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/compiled_functions.py:451, in CompilationCache.lookup(self, args)
    440 def lookup(self, args):
    441     """Get a function (if present) that matches the provided argument signature. Also computes
    442     whether promotion is necessary.
    443 
   (...)
    449         bool: whether the matched entry requires argument promotion
    450     """
--> 451     action, key = self.get_function_status_and_key(args)
    453     if action == TypeCompatibility.NEEDS_COMPILATION:
    454         return None, None

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/compiled_functions.py:437, in CompilationCache.get_function_status_and_key(self, args)
    435 entry = self.cache[key]
    436 runtime_signature = tree_unflatten(treedef, flat_runtime_sig)
--> 437 action = typecheck_signatures(entry.signature, runtime_signature, self.abstracted_axes)
    438 return action, key

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/tracing/type_signatures.py:238, in typecheck_signatures(compiled_signature, runtime_signature, abstracted_axes)
    230 # We first check signature equality considering dynamic axes, allowing the shape of an array
    231 # to be different if it was compiled with a dynamical shape.
    232 # TODO: unify this with the promotion checks, allowing the dtype to change for a dynamic axis
    233 with Patcher(
    234     # pylint: disable=protected-access
    235     (jax._src.interpreters.partial_eval, "get_aval", get_aval2),
    236 ):
    237     # TODO: do away with private jax functions
--> 238     axes_specs_compile = _flat_axes_specs(abstracted_axes, *compiled_signature, {})
    239     axes_specs_runtime = _flat_axes_specs(abstracted_axes, *runtime_signature, {})
    240     in_type_compiled = infer_lambda_input_type(axes_specs_compile, flat_compiled_sig)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/jax/_src/pjit.py:672, in _flat_axes_specs(abstracted_axes, *args, **kwargs)
    669 def ax_leaf(l):
    670   return (isinstance(l, dict) and all_leaves(l.values()) or
    671           isinstance(l, tuple) and all_leaves(l, lambda x: x is None))
--> 672 return broadcast_prefix(abstracted_axes, args, ax_leaf)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/jax/_src/tree_util.py:582, in broadcast_prefix(prefix_tree, full_tree, is_leaf)
    580 num_leaves = lambda t: tree_structure(t).num_leaves
    581 add_leaves = lambda x, subtree: result.extend([x] * num_leaves(subtree))
--> 582 tree_map(add_leaves, prefix_tree, full_tree, is_leaf=is_leaf)
    583 return result

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/jax/_src/tree_util.py:319, in tree_map(f, tree, is_leaf, *rest)
    282 """Maps a multi-input function over pytree args to produce a new pytree.
    283 
    284 Args:
   (...)
    316   - :func:`jax.tree.reduce`
    317 """
    318 leaves, treedef = tree_flatten(tree, is_leaf)
--> 319 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
    320 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/jax/_src/tree_util.py:319, in <listcomp>(.0)
    282 """Maps a multi-input function over pytree args to produce a new pytree.
    283 
    284 Args:
   (...)
    316   - :func:`jax.tree.reduce`
    317 """
    318 leaves, treedef = tree_flatten(tree, is_leaf)
--> 319 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
    320 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

ValueError: Tuple arity mismatch: 3 != 2; tuple: (ShapedArray(float64[3]), ShapedArray(float64[], weak_type=True), {}).
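The error message lines up with the Catalyst frame in `typecheck_signatures`, which calls `_flat_axes_specs(abstracted_axes, *compiled_signature, {})`. Per the frame header above, the JAX function is declared as `_flat_axes_specs(abstracted_axes, *args, **kwargs)`, so a `{}` passed positionally appears to be collected into `*args` as a third element. The tuple in the error, `(ShapedArray(float64[3]), ShapedArray(float64[], weak_type=True), {})`, has exactly that form, while `abstracted_axes=(('n',), ())` has only two entries. The sketch below reproduces only that arity arithmetic in plain Python; `flat_axes_specs` is a hypothetical simplification, not JAX's implementation, and is not offered as a patch.

```python
def flat_axes_specs(abstracted_axes, *args, **kwargs):
    """Hypothetical, simplified stand-in for jax._src.pjit._flat_axes_specs.

    It keeps only the arity check that fails in the traceback: the top-level
    tuple of axis specs must have one entry per positional argument.
    """
    del kwargs  # keyword arguments are not part of the positional tuple
    if len(abstracted_axes) != len(args):
        raise ValueError(
            f"Tuple arity mismatch: {len(args)} != {len(abstracted_axes)}; tuple: {args}."
        )
    return list(zip(abstracted_axes, args))


abstracted_axes = (("n",), ())  # as passed to qjit in the issue
# String stand-ins for ShapedArray(float64[3]) and ShapedArray(float64[]).
compiled_signature = ("float64[3]", "float64[]")

# A trailing {} passed positionally lands in *args, giving arity 3 vs. 2.
try:
    flat_axes_specs(abstracted_axes, *compiled_signature, {})
except ValueError as err:
    print(err)  # Tuple arity mismatch: 3 != 2; tuple: ('float64[3]', 'float64[]', {}).

# Without the extra positional argument the arities agree.
print(flat_axes_specs(abstracted_axes, *compiled_signature))
# [(('n',), 'float64[3]'), ((), 'float64[]')]
```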


isaacdevlugt added the bug (Something isn't working) label on Oct 25, 2024