Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding QJIT tests for all supported PL templates #1161

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open

Conversation

willjmax
Copy link

This PR adds QJIT compatibility tests for all supported PL templates. Some tests currently fail, and rely on fixes from the following PennyLane PRs: #6305, #6306, #6307. This PR addresses [sc-72625].

Copy link
Contributor

github-actions bot commented Oct 1, 2024

Hello. You may have forgotten to update the changelog!
Please edit doc/releases/changelog-dev.md on your branch with:

  • A one-to-two sentence description of the change. You may include a small working example for new features.
  • A link back to this PR.
  • Your name (or GitHub username) in the contributors section.

@willjmax willjmax requested review from soranjh and paul0403 and removed request for soranjh October 18, 2024 17:33
@@ -34,7 +57,11 @@ def amplitude_embedding(f: jax.core.ShapedArray([4], float)):
params = jax.numpy.array([1 / 2] * 4)
interpreted_fn = qml.QNode(amplitude_embedding, device)
jitted_fn = qjit(interpreted_fn)
assert np.allclose(interpreted_fn(params), jitted_fn(params))

interpreted_result = interpreted_fn(params)
Copy link
Contributor

@paul0403 paul0403 Oct 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason you are putting interpreted_fn(params) and jitted_fn(params) into their own variables and then comparing them? Is the current version np.allclose(interpreted_fn(params), jitted_fn(params)) failing?

@@ -49,7 +76,11 @@ def angle_embedding(f: jax.core.ShapedArray([3], int)):
params = jnp.array([1, 2, 3])
interpreted_fn = qml.QNode(angle_embedding, device)
jitted_fn = qjit(interpreted_fn)
assert np.allclose(interpreted_fn(params), jitted_fn(params))

interpreted_result = interpreted_fn(params)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

@@ -63,7 +94,53 @@ def basis_embedding(f: jax.core.ShapedArray([3], int)):
params = jax.numpy.array([1, 1, 1])
interpreted_fn = qml.QNode(basis_embedding, device)
jitted_fn = qjit(interpreted_fn)
assert np.allclose(interpreted_fn(params), jitted_fn(params))

interpreted_result = interpreted_fn(params)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... and here (and all other places)

assert np.allclose(interpreted_result, jitted_result)


@pytest.mark.xfail(reason="Displacement operator not supported on lightning.")
Copy link
Contributor

@paul0403 paul0403 Oct 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think quantum optics is no longer even in pennylane ecosystem's scope anymore, so no need to add tests for these.

For context, there was a software package called StrawberryFields whose purpose was quantum optics simulations, so PennyLane has them as a historical relic; however quantum optics was never officially in pennylane's scope, which means native pennylane software packages (like catalyst and lightning) are not, nor have plans to become, aware of these quantum optical operations.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Anything that inherits from CVOperation can be ignored --- the only device these are supported on is default.gaussian and it is likely not even tested anymore, and definitely not with JIT compilation.

assert np.allclose(interpreted_result, jitted_result)


@pytest.mark.xfail(reason="Beamsplitter is not supported by lightning.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto (quantum optics operators include displacements, squeezes, and beamsplitters)

charge = 1

H, qubits = qml.qchem.molecular_hamiltonian(symbols, geometry, charge=charge)
def uccsd(weights):
Copy link
Contributor

@paul0403 paul0403 Oct 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not a chemistry expert so maybe someone else can chime in or approve, but just want to say that it seems like this test's functionality is being changed here.



# Hilbert Schmidt templates take a quantum tape as a parameter.
# Therefore unsuitable for JIT compilation
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's put xfail reasons in the pytest.xfail("message") itself, instead of free-floating comments ; )

assert np.allclose(interpreted_result, jitted_result)


@pytest.mark.xfail(reason="Squeezing operator not supported on lightning.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto (no need to add tests for quantum optics)

Copy link
Contributor

@paul0403 paul0403 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding the tests 💯 ! Some comments from me.

Glad to see plenty of templates work with qjit out of the box

assert np.allclose(interpreted_result, jitted_result)


def test_cosine_window(backend):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I remember seeing you also had a cosine window PR. Since this test is in here, does that PR still need to be merged/reviewed? Or did that PR have some additional purpose as well?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The cosine window PR contains a fix necessary to make this test pass.

from scipy.stats import norm

from catalyst import for_loop, qjit


def test_adder(backend):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can ask pylint to disable too-many-lines, see for example

# pylint: disable=too-many-lines

@paul0403
Copy link
Contributor

Seems like some of these tests fail in CI and should be marked xfail as well

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants