Skip to content

Commit

Permalink
Merge branch 'main' into raultorres/remove_kokkos_plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
rauletorresc authored Jul 30, 2024
2 parents a0a9eec + 0055f9e commit 260ce32
Show file tree
Hide file tree
Showing 6 changed files with 176 additions and 40 deletions.
31 changes: 25 additions & 6 deletions doc/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,20 @@

```

* Exponential extrapolation is now a supported method of extrapolation when using `mitigate_with_zne`.
[(#953)](https://github.com/PennyLaneAI/catalyst/pull/953)

This new functionality fits the data from noise-scaled circuits with an exponential function,
and returns the zero-noise value. This functionality is available through the pennylane module
as follows
```py
from pennylane.transforms import exponential_extrapolate

catalyst.mitigate_with_zne(
circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=exponential_extrapolate
)
```

<h3>Improvements</h3>

* Catalyst is now compatible with Enzyme `v0.0.130`
Expand Down Expand Up @@ -182,6 +196,12 @@
* Support for TOML files in Schema 1 has been disabled.
[(#960)](https://github.com/PennyLaneAI/catalyst/pull/960)

* The `mitigate_with_zne` function no longer accepts a `degree` parameter for polynomial fitting
and instead accepts a callable to perform extrapolation. Any qjit-compatible extrapolation
function is valid. Keyword arguments can be passed to this function using the
`extrapolate_kwargs` keyword argument in `mitigate_with_zne`.
[(#806)](https://github.com/PennyLaneAI/catalyst/pull/806)

<h3>Bug fixes</h3>

* Static arguments can now be passed through a QNode when specified
Expand Down Expand Up @@ -248,6 +268,10 @@

<h3>Internal changes</h3>

* llvm O2 and Enzyme passes are only run when needed (gradients presents). Async execution of QNodes triggers now triggers a
Coroutine lowering pass.
[(#968)](https://github.com/PennyLaneAI/catalyst/pull/968)

* The function `inactive_callback` was renamed `__catalyst_inactive_callback`.
[(#899)](https://github.com/PennyLaneAI/catalyst/pull/899)

Expand Down Expand Up @@ -284,6 +308,7 @@ Mehrdad Malekmohammadi,
Romain Moyard,
Erick Ochoa,
Mudit Pandey,
nate stemen,
Raul Torres,
Tzung-Han Juang,
Paul Haochen Wang,
Expand Down Expand Up @@ -799,12 +824,6 @@ Paul Haochen Wang,

<h3>Breaking changes</h3>

* The `mitigate_with_zne` function no longer accepts a `degree` parameter for polynomial fitting
and instead accepts a callable to perform extrapolation. Any qjit-compatible extrapolation
function is valid. Keyword arguments can be passed to this function using the
`extrapolate_kwargs` keyword argument in `mitigate_with_zne`.
[(#806)](https://github.com/PennyLaneAI/catalyst/pull/806)

* Binary distributions for Linux are now based on `manylinux_2_28` instead of `manylinux_2014`.
As a result, Catalyst will only be compatible on systems with `glibc` versions `2.28` and above
(e.g., Ubuntu 20.04 and above).
Expand Down
1 change: 1 addition & 0 deletions frontend/catalyst/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,7 @@ def run_from_ir(self, ir: str, module_name: str, workspace: Directory):
str(workspace),
module_name,
keep_intermediate=self.options.keep_intermediate,
async_qnodes=self.options.async_qnodes,
verbose=self.options.verbose,
pipelines=self.options.get_pipelines(),
lower_to_llvm=lower_to_llvm,
Expand Down
74 changes: 59 additions & 15 deletions frontend/test/pytest/test_mitigation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,26 @@
import numpy as np
import pennylane as qml
import pytest
from pennylane.transforms import exponential_extrapolate

import catalyst
from catalyst.api_extensions.error_mitigation import polynomial_extrapolation

quadratic_extrapolation = polynomial_extrapolation(2)


def skip_if_exponential_extrapolation_unstable(circuit_param, extrapolation_func):
"""skip test if exponential extrapolation will be unstable"""
if circuit_param < 0.3 and extrapolation_func == exponential_extrapolate:
pytest.skip("Exponential extrapolation unstable in this region.")


@pytest.mark.parametrize("params", [0.1, 0.2, 0.3, 0.4, 0.5])
def test_single_measurement(params):
@pytest.mark.parametrize("extrapolation", [quadratic_extrapolation, exponential_extrapolate])
def test_single_measurement(params, extrapolation):
"""Test that without noise the same results are returned for single measurements."""
skip_if_exponential_extrapolation_unstable(params, extrapolation)

dev = qml.device("lightning.qubit", wires=2)

@qml.qnode(device=dev)
Expand All @@ -42,15 +52,18 @@ def circuit(x):
@catalyst.qjit
def mitigated_qnode(args):
return catalyst.mitigate_with_zne(
circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=quadratic_extrapolation
circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation
)(args)

assert np.allclose(mitigated_qnode(params), circuit(params))


@pytest.mark.parametrize("params", [0.1, 0.2, 0.3, 0.4, 0.5])
def test_multiple_measurements(params):
@pytest.mark.parametrize("extrapolation", [quadratic_extrapolation, exponential_extrapolate])
def test_multiple_measurements(params, extrapolation):
"""Test that without noise the same results are returned for multiple measurements"""
skip_if_exponential_extrapolation_unstable(params, extrapolation)

dev = qml.device("lightning.qubit", wires=2)

@qml.qnode(device=dev)
Expand All @@ -65,7 +78,7 @@ def circuit(x):
@catalyst.qjit
def mitigated_qnode(args):
return catalyst.mitigate_with_zne(
circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=quadratic_extrapolation
circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation
)(args)

assert np.allclose(mitigated_qnode(params), circuit(params))
Expand Down Expand Up @@ -121,7 +134,8 @@ def mitigated_function(args):
mitigated_function(0.1)


def test_dtype_error():
@pytest.mark.parametrize("extrapolation", [quadratic_extrapolation, exponential_extrapolate])
def test_dtype_error(extrapolation):
"""Test that an error is raised when multiple results do not have the same dtype."""
dev = qml.device("lightning.qubit", wires=2)

Expand All @@ -137,7 +151,7 @@ def circuit(x):
@catalyst.qjit
def mitigated_qnode(args):
return catalyst.mitigate_with_zne(
circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=quadratic_extrapolation
circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation
)(args)

with pytest.raises(
Expand All @@ -146,7 +160,8 @@ def mitigated_qnode(args):
mitigated_qnode(0.1)


def test_dtype_not_float_error():
@pytest.mark.parametrize("extrapolation", [quadratic_extrapolation, exponential_extrapolate])
def test_dtype_not_float_error(extrapolation):
"""Test that an error is raised when results are not float."""
dev = qml.device("lightning.qubit", wires=2)

Expand All @@ -162,7 +177,7 @@ def circuit(x):
@catalyst.qjit
def mitigated_qnode(args):
return catalyst.mitigate_with_zne(
circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=quadratic_extrapolation
circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation
)(args)

with pytest.raises(
Expand All @@ -171,7 +186,8 @@ def mitigated_qnode(args):
mitigated_qnode(0.1)


def test_shape_error():
@pytest.mark.parametrize("extrapolation", [quadratic_extrapolation, exponential_extrapolate])
def test_shape_error(extrapolation):
"""Test that an error is raised when results have shape."""
dev = qml.device("lightning.qubit", wires=2)

Expand All @@ -187,7 +203,7 @@ def circuit(x):
@catalyst.qjit
def mitigated_qnode(args):
return catalyst.mitigate_with_zne(
circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=quadratic_extrapolation
circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation
)(args)

with pytest.raises(
Expand Down Expand Up @@ -229,8 +245,11 @@ def mitigated_qnode():


@pytest.mark.parametrize("params", [0.1, 0.2, 0.3, 0.4, 0.5])
def test_zne_usage_patterns(params):
@pytest.mark.parametrize("extrapolation", [quadratic_extrapolation, exponential_extrapolate])
def test_zne_usage_patterns(params, extrapolation):
"""Test usage patterns of catalyst.zne."""
skip_if_exponential_extrapolation_unstable(params, extrapolation)

dev = qml.device("lightning.qubit", wires=2)

@qml.qnode(device=dev)
Expand All @@ -245,13 +264,13 @@ def fn(x):
@catalyst.qjit
def mitigated_qnode_fn_as_argument(args):
return catalyst.mitigate_with_zne(
fn, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=quadratic_extrapolation
fn, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation
)(args)

@catalyst.qjit
def mitigated_qnode_partial(args):
return catalyst.mitigate_with_zne(
scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=quadratic_extrapolation
scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation
)(fn)(args)

assert np.allclose(mitigated_qnode_fn_as_argument(params), fn(params))
Expand All @@ -271,13 +290,13 @@ def circuit():
qml.Hadamard(wires=1)
return qml.expval(qml.PauliY(wires=0))

def jax_extrap(scale_factors, results):
def jax_extrapolation(scale_factors, results):
return jax.numpy.polyfit(scale_factors, results, 2)[-1]

@catalyst.qjit
def mitigated_qnode():
return catalyst.mitigate_with_zne(
circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=jax_extrap
circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=jax_extrapolation
)()

assert np.allclose(mitigated_qnode(), circuit())
Expand Down Expand Up @@ -308,5 +327,30 @@ def mitigated_qnode():
assert np.allclose(mitigated_qnode(), circuit())


def test_exponential_extrapolation_with_kwargs():
"""test mitigate_with_zne with keyword arguments for exponential extrapolation function"""
dev = qml.device("lightning.qubit", wires=2)

@qml.qnode(device=dev)
def circuit():
qml.Hadamard(wires=0)
qml.RZ(0.1, wires=0)
qml.RZ(0.2, wires=0)
qml.CNOT(wires=[1, 0])
qml.Hadamard(wires=1)
return qml.expval(qml.PauliY(wires=0))

@catalyst.qjit
def mitigated_qnode():
return catalyst.mitigate_with_zne(
circuit,
scale_factors=jax.numpy.array([1, 2, 3]),
extrapolate=qml.transforms.exponential_extrapolate,
extrapolate_kwargs={"asymptote": 3},
)()

assert np.allclose(mitigated_qnode(), circuit())


if __name__ == "__main__":
pytest.main(["-x", __file__])
2 changes: 2 additions & 0 deletions mlir/include/Driver/CompilerDriver.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ struct CompilerOptions {
llvm::raw_ostream &diagnosticStream;
/// If true, the driver will output the module at intermediate points.
bool keepIntermediate;
/// If true, the llvm.coroutine will be lowered.
bool asyncQnodes;
/// Sets the verbosity level to use when printing messages.
Verbosity verbosity;
/// Ordered list of named pipelines to execute, each pipeline is described by a list of MLIR
Expand Down
Loading

0 comments on commit 260ce32

Please sign in to comment.