diff --git a/doc/changelog.md b/doc/changelog.md
index 77244a7170..1c64ef6cd4 100644
--- a/doc/changelog.md
+++ b/doc/changelog.md
@@ -114,6 +114,20 @@
```
+* Exponential extrapolation is now a supported method of extrapolation when using `mitigate_with_zne`.
+ [(#953)](https://github.com/PennyLaneAI/catalyst/pull/953)
+
+ This new functionality fits the data from noise-scaled circuits with an exponential function,
+ and returns the zero-noise value. This functionality is available through the pennylane module
+ as follows
+ ```py
+ from pennylane.transforms import exponential_extrapolate
+
+ catalyst.mitigate_with_zne(
+ circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=exponential_extrapolate
+ )
+ ```
+
Improvements
* Catalyst is now compatible with Enzyme `v0.0.130`
@@ -182,6 +196,12 @@
* Support for TOML files in Schema 1 has been disabled.
[(#960)](https://github.com/PennyLaneAI/catalyst/pull/960)
+* The `mitigate_with_zne` function no longer accepts a `degree` parameter for polynomial fitting
+ and instead accepts a callable to perform extrapolation. Any qjit-compatible extrapolation
+ function is valid. Keyword arguments can be passed to this function using the
+ `extrapolate_kwargs` keyword argument in `mitigate_with_zne`.
+ [(#806)](https://github.com/PennyLaneAI/catalyst/pull/806)
+
Bug fixes
* Static arguments can now be passed through a QNode when specified
@@ -248,6 +268,10 @@
Internal changes
+* llvm O2 and Enzyme passes are only run when needed (gradients presents). Async execution of QNodes triggers now triggers a
+ Coroutine lowering pass.
+ [(#968)](https://github.com/PennyLaneAI/catalyst/pull/968)
+
* The function `inactive_callback` was renamed `__catalyst_inactive_callback`.
[(#899)](https://github.com/PennyLaneAI/catalyst/pull/899)
@@ -284,6 +308,7 @@ Mehrdad Malekmohammadi,
Romain Moyard,
Erick Ochoa,
Mudit Pandey,
+nate stemen,
Raul Torres,
Tzung-Han Juang,
Paul Haochen Wang,
@@ -799,12 +824,6 @@ Paul Haochen Wang,
Breaking changes
-* The `mitigate_with_zne` function no longer accepts a `degree` parameter for polynomial fitting
- and instead accepts a callable to perform extrapolation. Any qjit-compatible extrapolation
- function is valid. Keyword arguments can be passed to this function using the
- `extrapolate_kwargs` keyword argument in `mitigate_with_zne`.
- [(#806)](https://github.com/PennyLaneAI/catalyst/pull/806)
-
* Binary distributions for Linux are now based on `manylinux_2_28` instead of `manylinux_2014`.
As a result, Catalyst will only be compatible on systems with `glibc` versions `2.28` and above
(e.g., Ubuntu 20.04 and above).
diff --git a/frontend/catalyst/compiler.py b/frontend/catalyst/compiler.py
index ce6361aee6..8dc24f34dc 100644
--- a/frontend/catalyst/compiler.py
+++ b/frontend/catalyst/compiler.py
@@ -531,6 +531,7 @@ def run_from_ir(self, ir: str, module_name: str, workspace: Directory):
str(workspace),
module_name,
keep_intermediate=self.options.keep_intermediate,
+ async_qnodes=self.options.async_qnodes,
verbose=self.options.verbose,
pipelines=self.options.get_pipelines(),
lower_to_llvm=lower_to_llvm,
diff --git a/frontend/test/pytest/test_mitigation.py b/frontend/test/pytest/test_mitigation.py
index 97ec785525..ec0dfda7a3 100644
--- a/frontend/test/pytest/test_mitigation.py
+++ b/frontend/test/pytest/test_mitigation.py
@@ -18,6 +18,7 @@
import numpy as np
import pennylane as qml
import pytest
+from pennylane.transforms import exponential_extrapolate
import catalyst
from catalyst.api_extensions.error_mitigation import polynomial_extrapolation
@@ -25,9 +26,18 @@
quadratic_extrapolation = polynomial_extrapolation(2)
+def skip_if_exponential_extrapolation_unstable(circuit_param, extrapolation_func):
+ """skip test if exponential extrapolation will be unstable"""
+ if circuit_param < 0.3 and extrapolation_func == exponential_extrapolate:
+ pytest.skip("Exponential extrapolation unstable in this region.")
+
+
@pytest.mark.parametrize("params", [0.1, 0.2, 0.3, 0.4, 0.5])
-def test_single_measurement(params):
+@pytest.mark.parametrize("extrapolation", [quadratic_extrapolation, exponential_extrapolate])
+def test_single_measurement(params, extrapolation):
"""Test that without noise the same results are returned for single measurements."""
+ skip_if_exponential_extrapolation_unstable(params, extrapolation)
+
dev = qml.device("lightning.qubit", wires=2)
@qml.qnode(device=dev)
@@ -42,15 +52,18 @@ def circuit(x):
@catalyst.qjit
def mitigated_qnode(args):
return catalyst.mitigate_with_zne(
- circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=quadratic_extrapolation
+ circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation
)(args)
assert np.allclose(mitigated_qnode(params), circuit(params))
@pytest.mark.parametrize("params", [0.1, 0.2, 0.3, 0.4, 0.5])
-def test_multiple_measurements(params):
+@pytest.mark.parametrize("extrapolation", [quadratic_extrapolation, exponential_extrapolate])
+def test_multiple_measurements(params, extrapolation):
"""Test that without noise the same results are returned for multiple measurements"""
+ skip_if_exponential_extrapolation_unstable(params, extrapolation)
+
dev = qml.device("lightning.qubit", wires=2)
@qml.qnode(device=dev)
@@ -65,7 +78,7 @@ def circuit(x):
@catalyst.qjit
def mitigated_qnode(args):
return catalyst.mitigate_with_zne(
- circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=quadratic_extrapolation
+ circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation
)(args)
assert np.allclose(mitigated_qnode(params), circuit(params))
@@ -121,7 +134,8 @@ def mitigated_function(args):
mitigated_function(0.1)
-def test_dtype_error():
+@pytest.mark.parametrize("extrapolation", [quadratic_extrapolation, exponential_extrapolate])
+def test_dtype_error(extrapolation):
"""Test that an error is raised when multiple results do not have the same dtype."""
dev = qml.device("lightning.qubit", wires=2)
@@ -137,7 +151,7 @@ def circuit(x):
@catalyst.qjit
def mitigated_qnode(args):
return catalyst.mitigate_with_zne(
- circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=quadratic_extrapolation
+ circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation
)(args)
with pytest.raises(
@@ -146,7 +160,8 @@ def mitigated_qnode(args):
mitigated_qnode(0.1)
-def test_dtype_not_float_error():
+@pytest.mark.parametrize("extrapolation", [quadratic_extrapolation, exponential_extrapolate])
+def test_dtype_not_float_error(extrapolation):
"""Test that an error is raised when results are not float."""
dev = qml.device("lightning.qubit", wires=2)
@@ -162,7 +177,7 @@ def circuit(x):
@catalyst.qjit
def mitigated_qnode(args):
return catalyst.mitigate_with_zne(
- circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=quadratic_extrapolation
+ circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation
)(args)
with pytest.raises(
@@ -171,7 +186,8 @@ def mitigated_qnode(args):
mitigated_qnode(0.1)
-def test_shape_error():
+@pytest.mark.parametrize("extrapolation", [quadratic_extrapolation, exponential_extrapolate])
+def test_shape_error(extrapolation):
"""Test that an error is raised when results have shape."""
dev = qml.device("lightning.qubit", wires=2)
@@ -187,7 +203,7 @@ def circuit(x):
@catalyst.qjit
def mitigated_qnode(args):
return catalyst.mitigate_with_zne(
- circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=quadratic_extrapolation
+ circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation
)(args)
with pytest.raises(
@@ -229,8 +245,11 @@ def mitigated_qnode():
@pytest.mark.parametrize("params", [0.1, 0.2, 0.3, 0.4, 0.5])
-def test_zne_usage_patterns(params):
+@pytest.mark.parametrize("extrapolation", [quadratic_extrapolation, exponential_extrapolate])
+def test_zne_usage_patterns(params, extrapolation):
"""Test usage patterns of catalyst.zne."""
+ skip_if_exponential_extrapolation_unstable(params, extrapolation)
+
dev = qml.device("lightning.qubit", wires=2)
@qml.qnode(device=dev)
@@ -245,13 +264,13 @@ def fn(x):
@catalyst.qjit
def mitigated_qnode_fn_as_argument(args):
return catalyst.mitigate_with_zne(
- fn, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=quadratic_extrapolation
+ fn, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation
)(args)
@catalyst.qjit
def mitigated_qnode_partial(args):
return catalyst.mitigate_with_zne(
- scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=quadratic_extrapolation
+ scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation
)(fn)(args)
assert np.allclose(mitigated_qnode_fn_as_argument(params), fn(params))
@@ -271,13 +290,13 @@ def circuit():
qml.Hadamard(wires=1)
return qml.expval(qml.PauliY(wires=0))
- def jax_extrap(scale_factors, results):
+ def jax_extrapolation(scale_factors, results):
return jax.numpy.polyfit(scale_factors, results, 2)[-1]
@catalyst.qjit
def mitigated_qnode():
return catalyst.mitigate_with_zne(
- circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=jax_extrap
+ circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=jax_extrapolation
)()
assert np.allclose(mitigated_qnode(), circuit())
@@ -308,5 +327,30 @@ def mitigated_qnode():
assert np.allclose(mitigated_qnode(), circuit())
+def test_exponential_extrapolation_with_kwargs():
+ """test mitigate_with_zne with keyword arguments for exponential extrapolation function"""
+ dev = qml.device("lightning.qubit", wires=2)
+
+ @qml.qnode(device=dev)
+ def circuit():
+ qml.Hadamard(wires=0)
+ qml.RZ(0.1, wires=0)
+ qml.RZ(0.2, wires=0)
+ qml.CNOT(wires=[1, 0])
+ qml.Hadamard(wires=1)
+ return qml.expval(qml.PauliY(wires=0))
+
+ @catalyst.qjit
+ def mitigated_qnode():
+ return catalyst.mitigate_with_zne(
+ circuit,
+ scale_factors=jax.numpy.array([1, 2, 3]),
+ extrapolate=qml.transforms.exponential_extrapolate,
+ extrapolate_kwargs={"asymptote": 3},
+ )()
+
+ assert np.allclose(mitigated_qnode(), circuit())
+
+
if __name__ == "__main__":
pytest.main(["-x", __file__])
diff --git a/mlir/include/Driver/CompilerDriver.h b/mlir/include/Driver/CompilerDriver.h
index cdf3602ceb..4724131dfe 100644
--- a/mlir/include/Driver/CompilerDriver.h
+++ b/mlir/include/Driver/CompilerDriver.h
@@ -71,6 +71,8 @@ struct CompilerOptions {
llvm::raw_ostream &diagnosticStream;
/// If true, the driver will output the module at intermediate points.
bool keepIntermediate;
+ /// If true, the llvm.coroutine will be lowered.
+ bool asyncQnodes;
/// Sets the verbosity level to use when printing messages.
Verbosity verbosity;
/// Ordered list of named pipelines to execute, each pipeline is described by a list of MLIR
diff --git a/mlir/lib/Driver/CompilerDriver.cpp b/mlir/lib/Driver/CompilerDriver.cpp
index 02c193a622..886a5c6ecb 100644
--- a/mlir/lib/Driver/CompilerDriver.cpp
+++ b/mlir/lib/Driver/CompilerDriver.cpp
@@ -46,6 +46,11 @@
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/TargetParser/Host.h"
+#include "llvm/Transforms/Coroutines/CoroCleanup.h"
+#include "llvm/Transforms/Coroutines/CoroConditionalWrapper.h"
+#include "llvm/Transforms/Coroutines/CoroEarly.h"
+#include "llvm/Transforms/Coroutines/CoroSplit.h"
+#include "llvm/Transforms/IPO/GlobalDCE.h"
#include "Catalyst/IR/CatalystDialect.h"
#include "Catalyst/Transforms/Passes.h"
@@ -53,6 +58,7 @@
#include "Driver/CompilerDriver.h"
#include "Driver/Support.h"
#include "Gradient/IR/GradientDialect.h"
+#include "Gradient/IR/GradientInterfaces.h"
#include "Gradient/Transforms/Passes.h"
#include "Mitigation/IR/MitigationDialect.h"
#include "Mitigation/Transforms/Passes.h"
@@ -258,6 +264,17 @@ OwningOpRef parseMLIRSource(MLIRContext *ctx, const llvm::SourceMgr &s
return parseSourceFile(sourceMgr, parserConfig);
}
+/// From the MLIR module it checks if gradients operations are in the program.
+bool containsGradients(mlir::ModuleOp moduleOp)
+{
+ bool contain = false;
+ moduleOp.walk([&](catalyst::gradient::GradientOpInterface op) {
+ contain = true;
+ return WalkResult::interrupt();
+ });
+ return contain;
+}
+
/// Parse an LLVM module given in textual representation. Any parse errors will be output to
/// the provided SMDiagnostic.
std::shared_ptr parseLLVMSource(llvm::LLVMContext &context, StringRef source,
@@ -360,8 +377,49 @@ LogicalResult inferMLIRReturnTypes(MLIRContext *ctx, llvm::Type *returnType,
return failure();
}
-LogicalResult runLLVMPasses(const CompilerOptions &options,
- std::shared_ptr llvmModule, CompilerOutput &output)
+LogicalResult runCoroLLVMPasses(const CompilerOptions &options,
+ std::shared_ptr llvmModule, CompilerOutput &output)
+{
+
+ auto &outputs = output.pipelineOutputs;
+
+ // Create a pass to lower LLVM coroutines (similar to what happens in O0)
+ llvm::ModulePassManager CoroPM;
+ CoroPM.addPass(llvm::CoroEarlyPass());
+ llvm::CGSCCPassManager CGPM;
+ CGPM.addPass(llvm::CoroSplitPass());
+ CoroPM.addPass(llvm::createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM)));
+ CoroPM.addPass(llvm::CoroCleanupPass());
+ CoroPM.addPass(llvm::GlobalDCEPass());
+
+ // Create the analysis managers.
+ llvm::LoopAnalysisManager LAM;
+ llvm::FunctionAnalysisManager FAM;
+ llvm::CGSCCAnalysisManager CGAM;
+ llvm::ModuleAnalysisManager MAM;
+
+ llvm::PassBuilder PB;
+ PB.registerModuleAnalyses(MAM);
+ PB.registerCGSCCAnalyses(CGAM);
+ PB.registerFunctionAnalyses(FAM);
+ PB.registerLoopAnalyses(LAM);
+ PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
+
+ // Optimize the IR!
+ CoroPM.run(*llvmModule.get(), MAM);
+
+ if (options.keepIntermediate) {
+ llvm::raw_string_ostream rawStringOstream{outputs["CoroOpt"]};
+ llvmModule->print(rawStringOstream, nullptr);
+ auto outFile = output.nextPipelineDumpFilename("CoroOpt", ".ll");
+ dumpToFile(options, outFile, outputs["CoroOpt"]);
+ }
+
+ return success();
+}
+
+LogicalResult runO2LLVMPasses(const CompilerOptions &options,
+ std::shared_ptr llvmModule, CompilerOutput &output)
{
// opt -O2
// As seen here:
@@ -393,10 +451,10 @@ LogicalResult runLLVMPasses(const CompilerOptions &options,
MPM.run(*llvmModule.get(), MAM);
if (options.keepIntermediate) {
- llvm::raw_string_ostream rawStringOstream{outputs["PreEnzymeOpt"]};
+ llvm::raw_string_ostream rawStringOstream{outputs["O2Opt"]};
llvmModule->print(rawStringOstream, nullptr);
- auto outFile = output.nextPipelineDumpFilename("PreEnzymeOpt", ".ll");
- dumpToFile(options, outFile, outputs["PreEnzymeOpt"]);
+ auto outFile = output.nextPipelineDumpFilename("O2Opt", ".ll");
+ dumpToFile(options, outFile, outputs["O2Opt"]);
}
return success();
@@ -572,8 +630,9 @@ LogicalResult QuantumDriverMain(const CompilerOptions &options, CompilerOutput &
OwningOpRef op =
timer::timer(parseMLIRSource, "parseMLIRSource", /* add_endl */ false, &ctx, *sourceMgr);
catalyst::utils::LinesCount::ModuleOp(*op);
-
+ bool enzymeRun = false;
if (op) {
+ enzymeRun = containsGradients(*op);
if (failed(runLowering(options, &ctx, *op, output))) {
CO_MSG(options, Verbosity::Urgent, "Failed to lower MLIR module\n");
return failure();
@@ -634,19 +693,28 @@ LogicalResult QuantumDriverMain(const CompilerOptions &options, CompilerOutput &
llvmModule->setDataLayout(targetMachine->createDataLayout());
llvmModule->setTargetTriple(targetTriple);
- if (failed(timer::timer(runLLVMPasses, "runLLVMPasses", /* add_endl */ false, options,
- llvmModule, output))) {
- return failure();
- }
-
catalyst::utils::LinesCount::Module(*llvmModule.get());
- if (failed(timer::timer(runEnzymePasses, "runEnzymePasses", /* add_endl */ false, options,
- llvmModule, output))) {
- return failure();
+ if (options.asyncQnodes) {
+ if (failed(timer::timer(runCoroLLVMPasses, "runCoroLLVMPasses", /* add_endl */ false,
+ options, llvmModule, output))) {
+ return failure();
+ }
+ catalyst::utils::LinesCount::Module(*llvmModule.get());
}
+ if (enzymeRun) {
+ if (failed(timer::timer(runO2LLVMPasses, "runO2LLVMPasses", /* add_endl */ false,
+ options, llvmModule, output))) {
+ return failure();
+ }
+ catalyst::utils::LinesCount::Module(*llvmModule.get());
- catalyst::utils::LinesCount::Module(*llvmModule.get());
+ if (failed(timer::timer(runEnzymePasses, "runEnzymePasses", /* add_endl */ false,
+ options, llvmModule, output))) {
+ return failure();
+ }
+ catalyst::utils::LinesCount::Module(*llvmModule.get());
+ }
output.outIR.clear();
outIRStream << *llvmModule;
@@ -693,4 +761,4 @@ LogicalResult QuantumDriverMain(const CompilerOptions &options, CompilerOutput &
output.objectFilename = outfile;
}
return success();
-}
+}
\ No newline at end of file
diff --git a/mlir/python/PyCompilerDriver.cpp b/mlir/python/PyCompilerDriver.cpp
index 5934a9c754..707b24205c 100644
--- a/mlir/python/PyCompilerDriver.cpp
+++ b/mlir/python/PyCompilerDriver.cpp
@@ -77,7 +77,7 @@ PYBIND11_MODULE(compiler_driver, m)
m.def(
"run_compiler_driver",
[](const char *source, const char *workspace, const char *moduleName, bool keepIntermediate,
- bool verbose, py::list pipelines,
+ bool asyncQnodes, bool verbose, py::list pipelines,
bool lower_to_llvm) -> std::unique_ptr {
// Install signal handler to catch user interrupts (e.g. CTRL-C).
signal(SIGINT,
@@ -93,6 +93,7 @@ PYBIND11_MODULE(compiler_driver, m)
.moduleName = moduleName,
.diagnosticStream = errStream,
.keepIntermediate = keepIntermediate,
+ .asyncQnodes = asyncQnodes,
.verbosity = verbose ? Verbosity::All : Verbosity::Urgent,
.pipelinesCfg = parseCompilerSpec(pipelines),
.lowerToLLVM = lower_to_llvm};
@@ -105,6 +106,7 @@ PYBIND11_MODULE(compiler_driver, m)
return output;
},
py::arg("source"), py::arg("workspace"), py::arg("module_name") = "jit source",
- py::arg("keep_intermediate") = false, py::arg("verbose") = false,
- py::arg("pipelines") = py::list(), py::arg("lower_to_llvm") = true);
+ py::arg("keep_intermediate") = false, py::arg("async_qnodes") = false,
+ py::arg("verbose") = false, py::arg("pipelines") = py::list(),
+ py::arg("lower_to_llvm") = true);
}