-
Notifications
You must be signed in to change notification settings - Fork 30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding QJIT tests for all supported PL templates #1161
base: main
Are you sure you want to change the base?
Conversation
Hello. You may have forgotten to update the changelog!
|
@@ -34,7 +57,11 @@ def amplitude_embedding(f: jax.core.ShapedArray([4], float)): | |||
params = jax.numpy.array([1 / 2] * 4) | |||
interpreted_fn = qml.QNode(amplitude_embedding, device) | |||
jitted_fn = qjit(interpreted_fn) | |||
assert np.allclose(interpreted_fn(params), jitted_fn(params)) | |||
|
|||
interpreted_result = interpreted_fn(params) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a reason you are putting interpreted_fn(params)
and jitted_fn(params)
into their own variables and then comparing them? Is the current version np.allclose(interpreted_fn(params), jitted_fn(params))
failing?
@@ -49,7 +76,11 @@ def angle_embedding(f: jax.core.ShapedArray([3], int)): | |||
params = jnp.array([1, 2, 3]) | |||
interpreted_fn = qml.QNode(angle_embedding, device) | |||
jitted_fn = qjit(interpreted_fn) | |||
assert np.allclose(interpreted_fn(params), jitted_fn(params)) | |||
|
|||
interpreted_result = interpreted_fn(params) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto
@@ -63,7 +94,53 @@ def basis_embedding(f: jax.core.ShapedArray([3], int)): | |||
params = jax.numpy.array([1, 1, 1]) | |||
interpreted_fn = qml.QNode(basis_embedding, device) | |||
jitted_fn = qjit(interpreted_fn) | |||
assert np.allclose(interpreted_fn(params), jitted_fn(params)) | |||
|
|||
interpreted_result = interpreted_fn(params) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
... and here (and all other places)
assert np.allclose(interpreted_result, jitted_result) | ||
|
||
|
||
@pytest.mark.xfail(reason="Displacement operator not supported on lightning.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think quantum optics is no longer even in pennylane ecosystem's scope anymore, so no need to add tests for these.
For context, there was a software package called StrawberryFields whose purpose was quantum optics simulations, so PennyLane has them as a historical relic; however quantum optics was never officially in pennylane's scope, which means native pennylane software packages (like catalyst and lightning) are not, nor have plans to become, aware of these quantum optical operations.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Anything that inherits from CVOperation can be ignored --- the only device these are supported on is default.gaussian
and it is likely not even tested anymore, and definitely not with JIT compilation.
assert np.allclose(interpreted_result, jitted_result) | ||
|
||
|
||
@pytest.mark.xfail(reason="Beamsplitter is not supported by lightning.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto (quantum optics operators include displacements, squeezes, and beamsplitters)
charge = 1 | ||
|
||
H, qubits = qml.qchem.molecular_hamiltonian(symbols, geometry, charge=charge) | ||
def uccsd(weights): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not a chemistry expert so maybe someone else can chime in or approve, but just want to say that it seems like this test's functionality is being changed here.
|
||
|
||
# Hilbert Schmidt templates take a quantum tape as a parameter. | ||
# Therefore unsuitable for JIT compilation |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's put xfail reasons in the pytest.xfail("message")
itself, instead of free-floating comments ; )
assert np.allclose(interpreted_result, jitted_result) | ||
|
||
|
||
@pytest.mark.xfail(reason="Squeezing operator not supported on lightning.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto (no need to add tests for quantum optics)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for adding the tests 💯 ! Some comments from me.
Glad to see plenty of templates work with qjit out of the box
assert np.allclose(interpreted_result, jitted_result) | ||
|
||
|
||
def test_cosine_window(backend): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I remember seeing you also had a cosine window PR. Since this test is in here, does that PR still need to be merged/reviewed? Or did that PR have some additional purpose as well?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The cosine window PR contains a fix necessary to make this test pass.
from scipy.stats import norm | ||
|
||
from catalyst import for_loop, qjit | ||
|
||
|
||
def test_adder(backend): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can ask pylint to disable too-many-lines, see for example
catalyst/frontend/catalyst/jax_tracer.py
Line 102 in 3fac9c2
# pylint: disable=too-many-lines |
Seems like some of these tests fail in CI and should be marked xfail as well |
This PR adds QJIT compatibility tests for all supported PL templates. Some tests currently fail, and rely on fixes from the following PennyLane PRs: #6305, #6306, #6307. This PR addresses [sc-72625].