diff --git a/doc/changelog.md b/doc/changelog.md index 77244a7170..1c64ef6cd4 100644 --- a/doc/changelog.md +++ b/doc/changelog.md @@ -114,6 +114,20 @@ ``` +* Exponential extrapolation is now a supported method of extrapolation when using `mitigate_with_zne`. + [(#953)](https://github.com/PennyLaneAI/catalyst/pull/953) + + This new functionality fits the data from noise-scaled circuits with an exponential function, + and returns the zero-noise value. This functionality is available through the pennylane module + as follows + ```py + from pennylane.transforms import exponential_extrapolate + + catalyst.mitigate_with_zne( + circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=exponential_extrapolate + ) + ``` +

Improvements

* Catalyst is now compatible with Enzyme `v0.0.130` @@ -182,6 +196,12 @@ * Support for TOML files in Schema 1 has been disabled. [(#960)](https://github.com/PennyLaneAI/catalyst/pull/960) +* The `mitigate_with_zne` function no longer accepts a `degree` parameter for polynomial fitting + and instead accepts a callable to perform extrapolation. Any qjit-compatible extrapolation + function is valid. Keyword arguments can be passed to this function using the + `extrapolate_kwargs` keyword argument in `mitigate_with_zne`. + [(#806)](https://github.com/PennyLaneAI/catalyst/pull/806) +

Bug fixes

* Static arguments can now be passed through a QNode when specified @@ -248,6 +268,10 @@

Internal changes

+* llvm O2 and Enzyme passes are only run when needed (gradients presents). Async execution of QNodes triggers now triggers a + Coroutine lowering pass. + [(#968)](https://github.com/PennyLaneAI/catalyst/pull/968) + * The function `inactive_callback` was renamed `__catalyst_inactive_callback`. [(#899)](https://github.com/PennyLaneAI/catalyst/pull/899) @@ -284,6 +308,7 @@ Mehrdad Malekmohammadi, Romain Moyard, Erick Ochoa, Mudit Pandey, +nate stemen, Raul Torres, Tzung-Han Juang, Paul Haochen Wang, @@ -799,12 +824,6 @@ Paul Haochen Wang,

Breaking changes

-* The `mitigate_with_zne` function no longer accepts a `degree` parameter for polynomial fitting - and instead accepts a callable to perform extrapolation. Any qjit-compatible extrapolation - function is valid. Keyword arguments can be passed to this function using the - `extrapolate_kwargs` keyword argument in `mitigate_with_zne`. - [(#806)](https://github.com/PennyLaneAI/catalyst/pull/806) - * Binary distributions for Linux are now based on `manylinux_2_28` instead of `manylinux_2014`. As a result, Catalyst will only be compatible on systems with `glibc` versions `2.28` and above (e.g., Ubuntu 20.04 and above). diff --git a/frontend/catalyst/compiler.py b/frontend/catalyst/compiler.py index ce6361aee6..8dc24f34dc 100644 --- a/frontend/catalyst/compiler.py +++ b/frontend/catalyst/compiler.py @@ -531,6 +531,7 @@ def run_from_ir(self, ir: str, module_name: str, workspace: Directory): str(workspace), module_name, keep_intermediate=self.options.keep_intermediate, + async_qnodes=self.options.async_qnodes, verbose=self.options.verbose, pipelines=self.options.get_pipelines(), lower_to_llvm=lower_to_llvm, diff --git a/frontend/test/pytest/test_mitigation.py b/frontend/test/pytest/test_mitigation.py index 97ec785525..ec0dfda7a3 100644 --- a/frontend/test/pytest/test_mitigation.py +++ b/frontend/test/pytest/test_mitigation.py @@ -18,6 +18,7 @@ import numpy as np import pennylane as qml import pytest +from pennylane.transforms import exponential_extrapolate import catalyst from catalyst.api_extensions.error_mitigation import polynomial_extrapolation @@ -25,9 +26,18 @@ quadratic_extrapolation = polynomial_extrapolation(2) +def skip_if_exponential_extrapolation_unstable(circuit_param, extrapolation_func): + """skip test if exponential extrapolation will be unstable""" + if circuit_param < 0.3 and extrapolation_func == exponential_extrapolate: + pytest.skip("Exponential extrapolation unstable in this region.") + + @pytest.mark.parametrize("params", [0.1, 0.2, 0.3, 0.4, 0.5]) -def test_single_measurement(params): +@pytest.mark.parametrize("extrapolation", [quadratic_extrapolation, exponential_extrapolate]) +def test_single_measurement(params, extrapolation): """Test that without noise the same results are returned for single measurements.""" + skip_if_exponential_extrapolation_unstable(params, extrapolation) + dev = qml.device("lightning.qubit", wires=2) @qml.qnode(device=dev) @@ -42,15 +52,18 @@ def circuit(x): @catalyst.qjit def mitigated_qnode(args): return catalyst.mitigate_with_zne( - circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=quadratic_extrapolation + circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation )(args) assert np.allclose(mitigated_qnode(params), circuit(params)) @pytest.mark.parametrize("params", [0.1, 0.2, 0.3, 0.4, 0.5]) -def test_multiple_measurements(params): +@pytest.mark.parametrize("extrapolation", [quadratic_extrapolation, exponential_extrapolate]) +def test_multiple_measurements(params, extrapolation): """Test that without noise the same results are returned for multiple measurements""" + skip_if_exponential_extrapolation_unstable(params, extrapolation) + dev = qml.device("lightning.qubit", wires=2) @qml.qnode(device=dev) @@ -65,7 +78,7 @@ def circuit(x): @catalyst.qjit def mitigated_qnode(args): return catalyst.mitigate_with_zne( - circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=quadratic_extrapolation + circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation )(args) assert np.allclose(mitigated_qnode(params), circuit(params)) @@ -121,7 +134,8 @@ def mitigated_function(args): mitigated_function(0.1) -def test_dtype_error(): +@pytest.mark.parametrize("extrapolation", [quadratic_extrapolation, exponential_extrapolate]) +def test_dtype_error(extrapolation): """Test that an error is raised when multiple results do not have the same dtype.""" dev = qml.device("lightning.qubit", wires=2) @@ -137,7 +151,7 @@ def circuit(x): @catalyst.qjit def mitigated_qnode(args): return catalyst.mitigate_with_zne( - circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=quadratic_extrapolation + circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation )(args) with pytest.raises( @@ -146,7 +160,8 @@ def mitigated_qnode(args): mitigated_qnode(0.1) -def test_dtype_not_float_error(): +@pytest.mark.parametrize("extrapolation", [quadratic_extrapolation, exponential_extrapolate]) +def test_dtype_not_float_error(extrapolation): """Test that an error is raised when results are not float.""" dev = qml.device("lightning.qubit", wires=2) @@ -162,7 +177,7 @@ def circuit(x): @catalyst.qjit def mitigated_qnode(args): return catalyst.mitigate_with_zne( - circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=quadratic_extrapolation + circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation )(args) with pytest.raises( @@ -171,7 +186,8 @@ def mitigated_qnode(args): mitigated_qnode(0.1) -def test_shape_error(): +@pytest.mark.parametrize("extrapolation", [quadratic_extrapolation, exponential_extrapolate]) +def test_shape_error(extrapolation): """Test that an error is raised when results have shape.""" dev = qml.device("lightning.qubit", wires=2) @@ -187,7 +203,7 @@ def circuit(x): @catalyst.qjit def mitigated_qnode(args): return catalyst.mitigate_with_zne( - circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=quadratic_extrapolation + circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation )(args) with pytest.raises( @@ -229,8 +245,11 @@ def mitigated_qnode(): @pytest.mark.parametrize("params", [0.1, 0.2, 0.3, 0.4, 0.5]) -def test_zne_usage_patterns(params): +@pytest.mark.parametrize("extrapolation", [quadratic_extrapolation, exponential_extrapolate]) +def test_zne_usage_patterns(params, extrapolation): """Test usage patterns of catalyst.zne.""" + skip_if_exponential_extrapolation_unstable(params, extrapolation) + dev = qml.device("lightning.qubit", wires=2) @qml.qnode(device=dev) @@ -245,13 +264,13 @@ def fn(x): @catalyst.qjit def mitigated_qnode_fn_as_argument(args): return catalyst.mitigate_with_zne( - fn, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=quadratic_extrapolation + fn, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation )(args) @catalyst.qjit def mitigated_qnode_partial(args): return catalyst.mitigate_with_zne( - scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=quadratic_extrapolation + scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation )(fn)(args) assert np.allclose(mitigated_qnode_fn_as_argument(params), fn(params)) @@ -271,13 +290,13 @@ def circuit(): qml.Hadamard(wires=1) return qml.expval(qml.PauliY(wires=0)) - def jax_extrap(scale_factors, results): + def jax_extrapolation(scale_factors, results): return jax.numpy.polyfit(scale_factors, results, 2)[-1] @catalyst.qjit def mitigated_qnode(): return catalyst.mitigate_with_zne( - circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=jax_extrap + circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=jax_extrapolation )() assert np.allclose(mitigated_qnode(), circuit()) @@ -308,5 +327,30 @@ def mitigated_qnode(): assert np.allclose(mitigated_qnode(), circuit()) +def test_exponential_extrapolation_with_kwargs(): + """test mitigate_with_zne with keyword arguments for exponential extrapolation function""" + dev = qml.device("lightning.qubit", wires=2) + + @qml.qnode(device=dev) + def circuit(): + qml.Hadamard(wires=0) + qml.RZ(0.1, wires=0) + qml.RZ(0.2, wires=0) + qml.CNOT(wires=[1, 0]) + qml.Hadamard(wires=1) + return qml.expval(qml.PauliY(wires=0)) + + @catalyst.qjit + def mitigated_qnode(): + return catalyst.mitigate_with_zne( + circuit, + scale_factors=jax.numpy.array([1, 2, 3]), + extrapolate=qml.transforms.exponential_extrapolate, + extrapolate_kwargs={"asymptote": 3}, + )() + + assert np.allclose(mitigated_qnode(), circuit()) + + if __name__ == "__main__": pytest.main(["-x", __file__]) diff --git a/mlir/include/Driver/CompilerDriver.h b/mlir/include/Driver/CompilerDriver.h index cdf3602ceb..4724131dfe 100644 --- a/mlir/include/Driver/CompilerDriver.h +++ b/mlir/include/Driver/CompilerDriver.h @@ -71,6 +71,8 @@ struct CompilerOptions { llvm::raw_ostream &diagnosticStream; /// If true, the driver will output the module at intermediate points. bool keepIntermediate; + /// If true, the llvm.coroutine will be lowered. + bool asyncQnodes; /// Sets the verbosity level to use when printing messages. Verbosity verbosity; /// Ordered list of named pipelines to execute, each pipeline is described by a list of MLIR diff --git a/mlir/lib/Driver/CompilerDriver.cpp b/mlir/lib/Driver/CompilerDriver.cpp index 02c193a622..886a5c6ecb 100644 --- a/mlir/lib/Driver/CompilerDriver.cpp +++ b/mlir/lib/Driver/CompilerDriver.cpp @@ -46,6 +46,11 @@ #include "llvm/Target/TargetMachine.h" #include "llvm/Target/TargetOptions.h" #include "llvm/TargetParser/Host.h" +#include "llvm/Transforms/Coroutines/CoroCleanup.h" +#include "llvm/Transforms/Coroutines/CoroConditionalWrapper.h" +#include "llvm/Transforms/Coroutines/CoroEarly.h" +#include "llvm/Transforms/Coroutines/CoroSplit.h" +#include "llvm/Transforms/IPO/GlobalDCE.h" #include "Catalyst/IR/CatalystDialect.h" #include "Catalyst/Transforms/Passes.h" @@ -53,6 +58,7 @@ #include "Driver/CompilerDriver.h" #include "Driver/Support.h" #include "Gradient/IR/GradientDialect.h" +#include "Gradient/IR/GradientInterfaces.h" #include "Gradient/Transforms/Passes.h" #include "Mitigation/IR/MitigationDialect.h" #include "Mitigation/Transforms/Passes.h" @@ -258,6 +264,17 @@ OwningOpRef parseMLIRSource(MLIRContext *ctx, const llvm::SourceMgr &s return parseSourceFile(sourceMgr, parserConfig); } +/// From the MLIR module it checks if gradients operations are in the program. +bool containsGradients(mlir::ModuleOp moduleOp) +{ + bool contain = false; + moduleOp.walk([&](catalyst::gradient::GradientOpInterface op) { + contain = true; + return WalkResult::interrupt(); + }); + return contain; +} + /// Parse an LLVM module given in textual representation. Any parse errors will be output to /// the provided SMDiagnostic. std::shared_ptr parseLLVMSource(llvm::LLVMContext &context, StringRef source, @@ -360,8 +377,49 @@ LogicalResult inferMLIRReturnTypes(MLIRContext *ctx, llvm::Type *returnType, return failure(); } -LogicalResult runLLVMPasses(const CompilerOptions &options, - std::shared_ptr llvmModule, CompilerOutput &output) +LogicalResult runCoroLLVMPasses(const CompilerOptions &options, + std::shared_ptr llvmModule, CompilerOutput &output) +{ + + auto &outputs = output.pipelineOutputs; + + // Create a pass to lower LLVM coroutines (similar to what happens in O0) + llvm::ModulePassManager CoroPM; + CoroPM.addPass(llvm::CoroEarlyPass()); + llvm::CGSCCPassManager CGPM; + CGPM.addPass(llvm::CoroSplitPass()); + CoroPM.addPass(llvm::createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM))); + CoroPM.addPass(llvm::CoroCleanupPass()); + CoroPM.addPass(llvm::GlobalDCEPass()); + + // Create the analysis managers. + llvm::LoopAnalysisManager LAM; + llvm::FunctionAnalysisManager FAM; + llvm::CGSCCAnalysisManager CGAM; + llvm::ModuleAnalysisManager MAM; + + llvm::PassBuilder PB; + PB.registerModuleAnalyses(MAM); + PB.registerCGSCCAnalyses(CGAM); + PB.registerFunctionAnalyses(FAM); + PB.registerLoopAnalyses(LAM); + PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); + + // Optimize the IR! + CoroPM.run(*llvmModule.get(), MAM); + + if (options.keepIntermediate) { + llvm::raw_string_ostream rawStringOstream{outputs["CoroOpt"]}; + llvmModule->print(rawStringOstream, nullptr); + auto outFile = output.nextPipelineDumpFilename("CoroOpt", ".ll"); + dumpToFile(options, outFile, outputs["CoroOpt"]); + } + + return success(); +} + +LogicalResult runO2LLVMPasses(const CompilerOptions &options, + std::shared_ptr llvmModule, CompilerOutput &output) { // opt -O2 // As seen here: @@ -393,10 +451,10 @@ LogicalResult runLLVMPasses(const CompilerOptions &options, MPM.run(*llvmModule.get(), MAM); if (options.keepIntermediate) { - llvm::raw_string_ostream rawStringOstream{outputs["PreEnzymeOpt"]}; + llvm::raw_string_ostream rawStringOstream{outputs["O2Opt"]}; llvmModule->print(rawStringOstream, nullptr); - auto outFile = output.nextPipelineDumpFilename("PreEnzymeOpt", ".ll"); - dumpToFile(options, outFile, outputs["PreEnzymeOpt"]); + auto outFile = output.nextPipelineDumpFilename("O2Opt", ".ll"); + dumpToFile(options, outFile, outputs["O2Opt"]); } return success(); @@ -572,8 +630,9 @@ LogicalResult QuantumDriverMain(const CompilerOptions &options, CompilerOutput & OwningOpRef op = timer::timer(parseMLIRSource, "parseMLIRSource", /* add_endl */ false, &ctx, *sourceMgr); catalyst::utils::LinesCount::ModuleOp(*op); - + bool enzymeRun = false; if (op) { + enzymeRun = containsGradients(*op); if (failed(runLowering(options, &ctx, *op, output))) { CO_MSG(options, Verbosity::Urgent, "Failed to lower MLIR module\n"); return failure(); @@ -634,19 +693,28 @@ LogicalResult QuantumDriverMain(const CompilerOptions &options, CompilerOutput & llvmModule->setDataLayout(targetMachine->createDataLayout()); llvmModule->setTargetTriple(targetTriple); - if (failed(timer::timer(runLLVMPasses, "runLLVMPasses", /* add_endl */ false, options, - llvmModule, output))) { - return failure(); - } - catalyst::utils::LinesCount::Module(*llvmModule.get()); - if (failed(timer::timer(runEnzymePasses, "runEnzymePasses", /* add_endl */ false, options, - llvmModule, output))) { - return failure(); + if (options.asyncQnodes) { + if (failed(timer::timer(runCoroLLVMPasses, "runCoroLLVMPasses", /* add_endl */ false, + options, llvmModule, output))) { + return failure(); + } + catalyst::utils::LinesCount::Module(*llvmModule.get()); } + if (enzymeRun) { + if (failed(timer::timer(runO2LLVMPasses, "runO2LLVMPasses", /* add_endl */ false, + options, llvmModule, output))) { + return failure(); + } + catalyst::utils::LinesCount::Module(*llvmModule.get()); - catalyst::utils::LinesCount::Module(*llvmModule.get()); + if (failed(timer::timer(runEnzymePasses, "runEnzymePasses", /* add_endl */ false, + options, llvmModule, output))) { + return failure(); + } + catalyst::utils::LinesCount::Module(*llvmModule.get()); + } output.outIR.clear(); outIRStream << *llvmModule; @@ -693,4 +761,4 @@ LogicalResult QuantumDriverMain(const CompilerOptions &options, CompilerOutput & output.objectFilename = outfile; } return success(); -} +} \ No newline at end of file diff --git a/mlir/python/PyCompilerDriver.cpp b/mlir/python/PyCompilerDriver.cpp index 5934a9c754..707b24205c 100644 --- a/mlir/python/PyCompilerDriver.cpp +++ b/mlir/python/PyCompilerDriver.cpp @@ -77,7 +77,7 @@ PYBIND11_MODULE(compiler_driver, m) m.def( "run_compiler_driver", [](const char *source, const char *workspace, const char *moduleName, bool keepIntermediate, - bool verbose, py::list pipelines, + bool asyncQnodes, bool verbose, py::list pipelines, bool lower_to_llvm) -> std::unique_ptr { // Install signal handler to catch user interrupts (e.g. CTRL-C). signal(SIGINT, @@ -93,6 +93,7 @@ PYBIND11_MODULE(compiler_driver, m) .moduleName = moduleName, .diagnosticStream = errStream, .keepIntermediate = keepIntermediate, + .asyncQnodes = asyncQnodes, .verbosity = verbose ? Verbosity::All : Verbosity::Urgent, .pipelinesCfg = parseCompilerSpec(pipelines), .lowerToLLVM = lower_to_llvm}; @@ -105,6 +106,7 @@ PYBIND11_MODULE(compiler_driver, m) return output; }, py::arg("source"), py::arg("workspace"), py::arg("module_name") = "jit source", - py::arg("keep_intermediate") = false, py::arg("verbose") = false, - py::arg("pipelines") = py::list(), py::arg("lower_to_llvm") = true); + py::arg("keep_intermediate") = false, py::arg("async_qnodes") = false, + py::arg("verbose") = false, py::arg("pipelines") = py::list(), + py::arg("lower_to_llvm") = true); }