-
Notifications
You must be signed in to change notification settings - Fork 598
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add base autograph transformer and if/else support #6406
base: master
Are you sure you want to change the base?
Conversation
…e into autograph_ctrl_flow
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #6406 +/- ##
========================================
Coverage 99.38% 99.39%
========================================
Files 452 455 +3
Lines 42789 42914 +125
========================================
+ Hits 42527 42653 +126
+ Misses 262 261 -1 ☔ View full report in Codecov by Sentry. |
Co-authored-by: Christina Lee <[email protected]>
def passthrough_wrapper(*args, **kwargs): | ||
return converted_call(wrapped_fn, args, kwargs, caller_fn_scope, options) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def passthrough_wrapper(*args, **kwargs): | |
return converted_call(wrapped_fn, args, kwargs, caller_fn_scope, options) | |
def passthrough_wrapper(*inner_args, **inner_kwargs): | |
return converted_call(wrapped_fn, inner_args, inner_kwargs, caller_fn_scope, options) |
Or some other sort of varation on args
and kwargs
, since those name are already taken in the scope.
assert args and callable(args[0]) | ||
wrapped_fn = args[0] | ||
|
||
def passthrough_wrapper(*args, **kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def passthrough_wrapper(*args, **kwargs): | |
@functools.wraps(wrapped_fn) | |
def passthrough_wrapper(*args, **kwargs): |
Should we use a wrapper here too?
from malt.impl.api import PyToPy | ||
|
||
import pennylane as qml | ||
from pennylane.capture.autograph.ag_primitives import AutoGraphError |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Relative import not work?
pass | ||
elif callable(obj): | ||
# pylint: disable=unnecessary-lambda,unnecessary-lambda-assignment | ||
fn = lambda *args, **kwargs: obj(*args, **kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the purpose of this assignment?
@@ -0,0 +1,178 @@ | |||
# Copyright 2023 Xanadu Quantum Technologies Inc. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# Copyright 2023 Xanadu Quantum Technologies Inc. | |
# Copyright 2024 Xanadu Quantum Technologies Inc. |
@@ -0,0 +1,227 @@ | |||
# Copyright 2023 Xanadu Quantum Technologies Inc. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# Copyright 2023 Xanadu Quantum Technologies Inc. | |
# Copyright 2024 Xanadu Quantum Technologies Inc. |
@@ -0,0 +1,308 @@ | |||
# Copyright 2023 Xanadu Quantum Technologies Inc. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# Copyright 2023 Xanadu Quantum Technologies Inc. | |
# Copyright 2024 Xanadu Quantum Technologies Inc. |
@@ -0,0 +1,478 @@ | |||
# Copyright 2023 Xanadu Quantum Technologies Inc. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# Copyright 2023 Xanadu Quantum Technologies Inc. | |
# Copyright 2024 Xanadu Quantum Technologies Inc. |
@@ -147,6 +147,7 @@ class MyCustomOp(qml.operation.Operator): | |||
def _(*args, **kwargs): | |||
return type.__call__(MyCustomOp, *args, **kwargs) | |||
""" | |||
from pennylane.capture import autograph |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this done so that we can import qml.capture.autograph
? I feel like that shouldn't be necessary right?
**Example** | ||
|
||
.. code-block:: python | ||
|
||
def decide(x): | ||
if x < 5: | ||
y = 15 | ||
else: | ||
y = 1 | ||
return y | ||
|
||
@qjit(autograph=True) | ||
def func(x: int): | ||
y = decide(x) | ||
return y ** 2 | ||
|
||
>>> print(autograph_source(decide)) | ||
def decide_1(x): | ||
with ag__.FunctionScope('decide', 'fscope', ag__.STD) as fscope: | ||
def get_state(): | ||
return (y,) | ||
def set_state(vars_): | ||
nonlocal y | ||
(y,) = vars_ | ||
def if_body(): | ||
nonlocal y | ||
y = 15 | ||
def else_body(): | ||
nonlocal y | ||
y = 1 | ||
y = ag__.Undefined('y') | ||
ag__.if_stmt(x < 5, if_body, else_body, get_state, set_state, ('y',), 1) | ||
return y | ||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Might make sense to replace this with an example that doesn't use qjit
.
self._cache.has(fn, TOPLEVEL_OPTIONS) | ||
or self._cache.has(fn, NESTED_OPTIONS) | ||
or self._cache.has(fn, STANDARD_OPTIONS) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What are TOP_LEVEL_OPTIONS
, etc, and where do they come from? It doesn't look like they are imported up top 🤔
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Line 210 below
First of 3 PRs to add support for python control flow with capture enabled by mapping the pythonic control-flow to JAX-compatible implementations in PennyLane, using
autograph
. The follow up PRs will add support forwhile
loops andfor
loops.This is a modification for PennyLane of the
autograph
module implemented in catalyst.Context:
Without autograph, calling this circuit:
raises a
TracerBoolConversionError
error because theif
/else
logic isn't JAX-compatible.Description of the Change:
autograph
to convert pythonic control flow is addedif
/else
to@qml.cond
for JAX compatibility is addedmalt
package is added as a dependencyBenefits:
For the circuit above, we can run
and now the circuit can execute.
[sc-71821]