Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add base autograph transformer and if/else support #6406

Open
wants to merge 47 commits into
base: master
Choose a base branch
from

Conversation

lillian542
Copy link
Contributor

@lillian542 lillian542 commented Oct 16, 2024

First of 3 PRs to add support for python control flow with capture enabled by mapping the pythonic control-flow to JAX-compatible implementations in PennyLane, using autograph. The follow up PRs will add support for while loops and for loops.

This is a modification for PennyLane of the autograph module implemented in catalyst.

Context:
Without autograph, calling this circuit:

import pennylane as qml

qml.capture.enable()

dev = qml.device("default.qubit", wires=3)

@qml.qnode(dev)
def circuit(x):
    if x > 1.967:
        qml.Hadamard(2)
    else:
        qml.Y(1)
    return qml.state()

raises a TracerBoolConversionError error because the if/else logic isn't JAX-compatible.

Description of the Change:

  • The initial infrastructure for using autograph to convert pythonic control flow is added
  • The primitive for converting if/else to @qml.cond for JAX compatibility is added
  • The malt package is added as a dependency

Benefits:
For the circuit above, we can run

>>> circuit = run_autograph(circuit)
>>> circuit(2.3)

and now the circuit can execute.

[sc-71821]

@lillian542 lillian542 marked this pull request as ready for review October 16, 2024 20:12
Copy link

codecov bot commented Oct 16, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 99.39%. Comparing base (661dd52) to head (d5eaa8d).
Report is 1 commits behind head on master.

Additional details and impacted files
@@           Coverage Diff            @@
##           master    #6406    +/-   ##
========================================
  Coverage   99.38%   99.39%            
========================================
  Files         452      455     +3     
  Lines       42789    42914   +125     
========================================
+ Hits        42527    42653   +126     
+ Misses        262      261     -1     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Comment on lines +157 to +158
def passthrough_wrapper(*args, **kwargs):
return converted_call(wrapped_fn, args, kwargs, caller_fn_scope, options)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def passthrough_wrapper(*args, **kwargs):
return converted_call(wrapped_fn, args, kwargs, caller_fn_scope, options)
def passthrough_wrapper(*inner_args, **inner_kwargs):
return converted_call(wrapped_fn, inner_args, inner_kwargs, caller_fn_scope, options)

Or some other sort of varation on args and kwargs, since those name are already taken in the scope.

assert args and callable(args[0])
wrapped_fn = args[0]

def passthrough_wrapper(*args, **kwargs):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def passthrough_wrapper(*args, **kwargs):
@functools.wraps(wrapped_fn)
def passthrough_wrapper(*args, **kwargs):

Should we use a wrapper here too?

from malt.impl.api import PyToPy

import pennylane as qml
from pennylane.capture.autograph.ag_primitives import AutoGraphError
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Relative import not work?

pass
elif callable(obj):
# pylint: disable=unnecessary-lambda,unnecessary-lambda-assignment
fn = lambda *args, **kwargs: obj(*args, **kwargs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the purpose of this assignment?

@@ -0,0 +1,178 @@
# Copyright 2023 Xanadu Quantum Technologies Inc.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Copyright 2023 Xanadu Quantum Technologies Inc.
# Copyright 2024 Xanadu Quantum Technologies Inc.

@@ -0,0 +1,227 @@
# Copyright 2023 Xanadu Quantum Technologies Inc.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Copyright 2023 Xanadu Quantum Technologies Inc.
# Copyright 2024 Xanadu Quantum Technologies Inc.

@@ -0,0 +1,308 @@
# Copyright 2023 Xanadu Quantum Technologies Inc.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Copyright 2023 Xanadu Quantum Technologies Inc.
# Copyright 2024 Xanadu Quantum Technologies Inc.

@@ -0,0 +1,478 @@
# Copyright 2023 Xanadu Quantum Technologies Inc.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Copyright 2023 Xanadu Quantum Technologies Inc.
# Copyright 2024 Xanadu Quantum Technologies Inc.

@@ -147,6 +147,7 @@ class MyCustomOp(qml.operation.Operator):
def _(*args, **kwargs):
return type.__call__(MyCustomOp, *args, **kwargs)
"""
from pennylane.capture import autograph
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this done so that we can import qml.capture.autograph? I feel like that shouldn't be necessary right?

Comment on lines +157 to +190
**Example**

.. code-block:: python

def decide(x):
if x < 5:
y = 15
else:
y = 1
return y

@qjit(autograph=True)
def func(x: int):
y = decide(x)
return y ** 2

>>> print(autograph_source(decide))
def decide_1(x):
with ag__.FunctionScope('decide', 'fscope', ag__.STD) as fscope:
def get_state():
return (y,)
def set_state(vars_):
nonlocal y
(y,) = vars_
def if_body():
nonlocal y
y = 15
def else_body():
nonlocal y
y = 1
y = ag__.Undefined('y')
ag__.if_stmt(x < 5, if_body, else_body, get_state, set_state, ('y',), 1)
return y
"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might make sense to replace this with an example that doesn't use qjit.

Comment on lines +89 to +91
self._cache.has(fn, TOPLEVEL_OPTIONS)
or self._cache.has(fn, NESTED_OPTIONS)
or self._cache.has(fn, STANDARD_OPTIONS)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are TOP_LEVEL_OPTIONS, etc, and where do they come from? It doesn't look like they are imported up top 🤔

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line 210 below

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants