From eb7c349795804c1e5df72980a81e4b1e1c4be4f0 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Fri, 26 Jul 2024 15:02:15 -0400 Subject: [PATCH] Add dynamic_one_shot tensorflow support + expanded testing (#5973) **Context:** As name says. Gradient workflows no longer raise errors after the merge of #5791 , but their correctness is yet to be verified. **Description of the Change:** * Updated casting rules in `dynamic_one_shot`'s processing function for tensorflow. * For the changes to be fully integrated, the way that the interface is passed around when calling a QNode needed to be changed, so the following changes were made: * `QNode` has updated behaviour for how `mcm_config` is used during execution. In `QNode._execution_component`, a copy of `self.execute_kwargs["mcm_config"]` is the source of truth, and in `qml.execute`, `config.mcm_config` is the source of truth. * Added a private `pad-invalid-samples` `postselect_mode`. The `postselect_mode` is switched to this automatically in `qml.execute` if executing with jax and shots and `postselect_mode == "hw-like"`. This way we standardize how the MCM transforms determine if jax is being used. * Updates to `capture` module to accommodate the above changes. **Benefits:** * `dynamic_one_shot` doesn't cast to interfaces inside the ML boundary * `dynamic_one_shot` works with tensorflow * Expanded tests **Possible Drawbacks:** **Related GitHub Issues:** Fixes #5736, #5710 Duplicate of #5861 which was closed due to release branch merge stuff. --------- Co-authored-by: Jay Soni Co-authored-by: Astral Cai Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Yushao Chen (Jerry) Co-authored-by: Christina Lee Co-authored-by: Thomas R. Bromley <49409390+trbromley@users.noreply.github.com> Co-authored-by: soranjh <40344468+soranjh@users.noreply.github.com> Co-authored-by: Pietropaolo Frisoni Co-authored-by: Ahmed Darwish Co-authored-by: Utkarsh Co-authored-by: David Wierichs Co-authored-by: Christina Lee Co-authored-by: Mikhail Andrenkov Co-authored-by: Diego <67476785+DSGuala@users.noreply.github.com> Co-authored-by: Josh Izaac Co-authored-by: Diego Co-authored-by: Vincent Michaud-Rioux Co-authored-by: lillian542 <38584660+lillian542@users.noreply.github.com> Co-authored-by: Jack Brown Co-authored-by: Paul Finlay <50180049+doctorperceptron@users.noreply.github.com> Co-authored-by: David Ittah Co-authored-by: Cristian Emiliano Godinez Ramirez <57567043+EmilianoG-byte@users.noreply.github.com> Co-authored-by: Vincent Michaud-Rioux --- doc/releases/changelog-0.37.0.md | 2 +- doc/releases/changelog-dev.md | 12 +- pennylane/devices/default_qubit.py | 5 +- pennylane/devices/execution_config.py | 2 +- pennylane/devices/preprocess.py | 5 +- pennylane/devices/qubit/apply_operation.py | 2 +- pennylane/math/single_dispatch.py | 1 + pennylane/measurements/mid_measure.py | 2 +- pennylane/transforms/dynamic_one_shot.py | 96 +++-- pennylane/workflow/execution.py | 43 +- pennylane/workflow/qnode.py | 27 +- .../test_default_qubit_native_mcm.py | 272 ++++++------- tests/test_qnode.py | 4 +- tests/transforms/test_dynamic_one_shot.py | 375 +++++++++++------- 14 files changed, 493 insertions(+), 355 deletions(-) diff --git a/doc/releases/changelog-0.37.0.md b/doc/releases/changelog-0.37.0.md index 969b5673c85..936034c98e9 100644 --- a/doc/releases/changelog-0.37.0.md +++ b/doc/releases/changelog-0.37.0.md @@ -929,4 +929,4 @@ Kenya Sakka, Jay Soni, Kazuki Tsuoka, Haochen Paul Wang, -David Wierichs. +David Wierichs. \ No newline at end of file diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 089d8db61e8..3b9c85ace17 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -4,12 +4,15 @@

New features since last release

+* A new method `process_density_matrix` has been added to the `ProbabilityMP` and `DensityMatrixMP` + classes, allowing for more efficient handling of quantum density matrices, particularly with batch + processing support. This method simplifies the calculation of probabilities from quantum states + represented as density matrices. + [(#5830)](https://github.com/PennyLaneAI/pennylane/pull/5830) + * Resolved the bug in `qml.ThermalRelaxationError` where there was a typo from `tq` to `tg`. [(#5988)](https://github.com/PennyLaneAI/pennylane/issues/5988) -* A new method `process_density_matrix` has been added to the `ProbabilityMP` and `DensityMatrixMP` classes, allowing for more efficient handling of quantum density matrices, particularly with batch processing support. This method simplifies the calculation of probabilities from quantum states represented as density matrices. - [(#5830)](https://github.com/PennyLaneAI/pennylane/pull/5830) - * The `qml.PrepSelPrep` template is added. The template implements a block-encoding of a linear combination of unitaries. [(#5756)](https://github.com/PennyLaneAI/pennylane/pull/5756) @@ -57,6 +60,9 @@ * `QuantumScript.hash` is now cached, leading to performance improvements. [(#5919)](https://github.com/PennyLaneAI/pennylane/pull/5919) +* `qml.dynamic_one_shot` now supports circuits using the `"tensorflow"` interface. + [(#5973)](https://github.com/PennyLaneAI/pennylane/pull/5973) + * The representation for `Wires` has now changed to be more copy-paste friendly. [(#5958)](https://github.com/PennyLaneAI/pennylane/pull/5958) diff --git a/pennylane/devices/default_qubit.py b/pennylane/devices/default_qubit.py index bcd97c003b8..d7805f79d28 100644 --- a/pennylane/devices/default_qubit.py +++ b/pennylane/devices/default_qubit.py @@ -532,10 +532,7 @@ def preprocess( transform_program.add_transform(validate_device_wires, self.wires, name=self.name) transform_program.add_transform( - mid_circuit_measurements, - device=self, - mcm_config=config.mcm_config, - interface=config.interface, + mid_circuit_measurements, device=self, mcm_config=config.mcm_config ) transform_program.add_transform( decompose, diff --git a/pennylane/devices/execution_config.py b/pennylane/devices/execution_config.py index 5ed57238ed5..5b7af096d81 100644 --- a/pennylane/devices/execution_config.py +++ b/pennylane/devices/execution_config.py @@ -50,7 +50,7 @@ def __post_init__(self): None, ): raise ValueError(f"Invalid mid-circuit measurements method '{self.mcm_method}'.") - if self.postselect_mode not in ("hw-like", "fill-shots", None): + if self.postselect_mode not in ("hw-like", "fill-shots", "pad-invalid-samples", None): raise ValueError(f"Invalid postselection mode '{self.postselect_mode}'.") diff --git a/pennylane/devices/preprocess.py b/pennylane/devices/preprocess.py index ec7d337bc9e..819e375df54 100644 --- a/pennylane/devices/preprocess.py +++ b/pennylane/devices/preprocess.py @@ -151,14 +151,13 @@ def mid_circuit_measurements( tape: qml.tape.QuantumTape, device, mcm_config=MCMConfig(), - interface=None, + **kwargs, # pylint: disable=unused-argument ) -> tuple[QuantumTapeBatch, PostprocessingFn]: """Provide the transform to handle mid-circuit measurements. If the tape or device uses finite-shot, use the native implementation (i.e. no transform), and use the ``qml.defer_measurements`` transform otherwise. """ - if isinstance(mcm_config, dict): mcm_config = MCMConfig(**mcm_config) mcm_method = mcm_config.mcm_method @@ -166,7 +165,7 @@ def mid_circuit_measurements( mcm_method = "one-shot" if tape.shots else "deferred" if mcm_method == "one-shot": - return qml.dynamic_one_shot(tape, interface=interface) + return qml.dynamic_one_shot(tape, postselect_mode=mcm_config.postselect_mode) if mcm_method == "tree-traversal": return (tape,), null_postprocessing return qml.defer_measurements(tape, device=device) diff --git a/pennylane/devices/qubit/apply_operation.py b/pennylane/devices/qubit/apply_operation.py index e51cad39d9e..8e082ee9820 100644 --- a/pennylane/devices/qubit/apply_operation.py +++ b/pennylane/devices/qubit/apply_operation.py @@ -328,7 +328,7 @@ def apply_mid_measure( axis = wire.toarray()[0] slices = [slice(None)] * qml.math.ndim(state) slices[axis] = 0 - prob0 = qml.math.norm(state[tuple(slices)]) ** 2 + prob0 = qml.math.real(qml.math.norm(state[tuple(slices)])) ** 2 if prng_key is not None: # pylint: disable=import-outside-toplevel diff --git a/pennylane/math/single_dispatch.py b/pennylane/math/single_dispatch.py index 892e5387830..aef2fa96bb4 100644 --- a/pennylane/math/single_dispatch.py +++ b/pennylane/math/single_dispatch.py @@ -237,6 +237,7 @@ def _take_autograd(tensor, indices, axis=None): ar.autoray._SUBMODULE_ALIASES["tensorflow", "arctan"] = "tensorflow.math" ar.autoray._SUBMODULE_ALIASES["tensorflow", "arctan2"] = "tensorflow.math" ar.autoray._SUBMODULE_ALIASES["tensorflow", "mod"] = "tensorflow.math" +ar.autoray._SUBMODULE_ALIASES["tensorflow", "logical_and"] = "tensorflow.math" ar.autoray._SUBMODULE_ALIASES["tensorflow", "kron"] = "tensorflow.experimental.numpy" ar.autoray._SUBMODULE_ALIASES["tensorflow", "moveaxis"] = "tensorflow.experimental.numpy" ar.autoray._SUBMODULE_ALIASES["tensorflow", "sinc"] = "tensorflow.experimental.numpy" diff --git a/pennylane/measurements/mid_measure.py b/pennylane/measurements/mid_measure.py index a1ef30e61af..9b626ffd378 100644 --- a/pennylane/measurements/mid_measure.py +++ b/pennylane/measurements/mid_measure.py @@ -443,7 +443,7 @@ def __mul__(self, other): return self._transform_bin_op(lambda a, b: a * b, other) def __rmul__(self, other): - return self._apply(lambda v: other * v) + return self._apply(lambda v: other * qml.math.cast_like(v, other)) def __truediv__(self, other): return self._transform_bin_op(lambda a, b: a / b, other) diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index 2eb326d29d1..5ff62e1815e 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -125,7 +125,7 @@ def func(x, y): aux_tapes = [init_auxiliary_tape(t) for t in tapes] - interface = kwargs.get("interface", None) + postselect_mode = kwargs.get("postselect_mode", None) def reshape_data(array): return qml.math.squeeze(qml.math.vstack(array)) @@ -161,7 +161,9 @@ def processing_fn(results, has_partitioned_shots=None, batched_results=None): results = [ reshape_data(tuple(res[i] for res in results)) for i, _ in enumerate(results[0]) ] - return parse_native_mid_circuit_measurements(tape, aux_tapes, results, interface=interface) + return parse_native_mid_circuit_measurements( + tape, aux_tapes, results, postselect_mode=postselect_mode + ) return aux_tapes, processing_fn @@ -227,7 +229,7 @@ def parse_native_mid_circuit_measurements( circuit: qml.tape.QuantumScript, aux_tapes: qml.tape.QuantumScript, results: TensorLike, - interface=None, + postselect_mode=None, ): """Combines, gathers and normalizes the results of native mid-circuit measurement runs. @@ -247,20 +249,27 @@ def measurement_with_no_shots(measurement): else np.nan ) - interface = interface or qml.math.get_deep_interface(circuit.data) + interface = qml.math.get_deep_interface(results) interface = "numpy" if interface == "builtins" else interface + interface = "tensorflow" if interface == "tf" else interface active_qjit = qml.compiler.active() all_mcms = [op for op in aux_tapes[0].operations if is_mcm(op)] n_mcms = len(all_mcms) - mcm_samples = qml.math.hstack(tuple(res.reshape((-1, 1)) for res in results[-n_mcms:])) + mcm_samples = qml.math.hstack( + tuple(qml.math.reshape(res, (-1, 1)) for res in results[-n_mcms:]) + ) mcm_samples = qml.math.array(mcm_samples, like=interface) # Can't use boolean dtype array with tf, hence why conditionally setting items to 0 or 1 has_postselect = qml.math.array( - [[int(op.postselect is not None) for op in all_mcms]], like=interface + [[op.postselect is not None for op in all_mcms]], + like=interface, + dtype=mcm_samples.dtype, ) postselect = qml.math.array( - [[0 if op.postselect is None else op.postselect for op in all_mcms]], like=interface + [[0 if op.postselect is None else op.postselect for op in all_mcms]], + like=interface, + dtype=mcm_samples.dtype, ) is_valid = qml.math.all(mcm_samples * has_postselect == postselect, axis=1) has_valid = qml.math.any(is_valid) @@ -277,9 +286,11 @@ def measurement_with_no_shots(measurement): if interface != "jax" and m.mv and not has_valid: meas = measurement_with_no_shots(m) elif m.mv and active_qjit: - meas = gather_mcm_qjit(m, mcm_samples, is_valid) # pragma: no cover + meas = gather_mcm_qjit( + m, mcm_samples, is_valid, postselect_mode=postselect_mode + ) # pragma: no cover elif m.mv: - meas = gather_mcm(m, mcm_samples, is_valid) + meas = gather_mcm(m, mcm_samples, is_valid, postselect_mode=postselect_mode) elif interface != "jax" and not has_valid: meas = measurement_with_no_shots(m) m_count += 1 @@ -296,12 +307,15 @@ def measurement_with_no_shots(measurement): # We return the sum of counts (`result[1]`) weighting by `is_valid`, which is `0` for invalid samples if isinstance(m, CountsMP): normalized_meas.append( - (result[0][0], qml.math.sum(result[1] * is_valid.reshape((-1, 1)), axis=0)) + ( + result[0][0], + qml.math.sum(result[1] * qml.math.reshape(is_valid, (-1, 1)), axis=0), + ) ) m_count += 1 continue result = qml.math.squeeze(result) - meas = gather_non_mcm(m, result, is_valid) + meas = gather_non_mcm(m, result, is_valid, postselect_mode=postselect_mode) m_count += 1 if isinstance(m, SampleMP): meas = qml.math.squeeze(meas) @@ -310,7 +324,7 @@ def measurement_with_no_shots(measurement): return tuple(normalized_meas) if len(normalized_meas) > 1 else normalized_meas[0] -def gather_mcm_qjit(measurement, samples, is_valid): # pragma: no cover +def gather_mcm_qjit(measurement, samples, is_valid, postselect_mode=None): # pragma: no cover """Process MCM measurements when the Catalyst compiler is active. Args: @@ -331,7 +345,7 @@ def gather_mcm_qjit(measurement, samples, is_valid): # pragma: no cover raise LookupError("MCM not found") meas = qml.math.squeeze(meas) if isinstance(measurement, (CountsMP, ProbabilityMP)): - interface = qml.math.get_deep_interface(is_valid) + interface = qml.math.get_interface(is_valid) sum_valid = qml.math.sum(is_valid) count_1 = qml.math.sum(meas * is_valid) if isinstance(measurement, CountsMP): @@ -341,10 +355,10 @@ def gather_mcm_qjit(measurement, samples, is_valid): # pragma: no cover if isinstance(measurement, ProbabilityMP): counts = qml.math.array([sum_valid - count_1, count_1], like=interface) return counts / sum_valid - return gather_non_mcm(measurement, meas, is_valid) + return gather_non_mcm(measurement, meas, is_valid, postselect_mode=postselect_mode) -def gather_non_mcm(measurement, samples, is_valid): +def gather_non_mcm(measurement, samples, is_valid, postselect_mode=None): """Combines, gathers and normalizes several measurements with trivial measurement values. Args: @@ -365,25 +379,39 @@ def gather_non_mcm(measurement, samples, is_valid): if not measurement.all_outcomes: tmp = Counter({k: v for k, v in tmp.items() if v > 0}) return dict(sorted(tmp.items())) - if isinstance(measurement, ExpectationMP): - return qml.math.sum(samples * is_valid) / qml.math.sum(is_valid) - if isinstance(measurement, ProbabilityMP): - return qml.math.sum(samples * is_valid.reshape((-1, 1)), axis=0) / qml.math.sum(is_valid) + if isinstance(measurement, SampleMP): - is_interface_jax = qml.math.get_deep_interface(is_valid) == "jax" - if is_interface_jax and samples.ndim == 2: - is_valid = is_valid.reshape((-1, 1)) + if postselect_mode == "pad-invalid-samples" and samples.ndim == 2: + is_valid = qml.math.reshape(is_valid, (-1, 1)) return ( qml.math.where(is_valid, samples, fill_in_value) - if is_interface_jax + if postselect_mode == "pad-invalid-samples" else samples[is_valid] ) + + if (interface := qml.math.get_interface(is_valid)) == "tensorflow": + # Tensorflow requires arrays that are used for arithmetic with each other to have the + # same dtype. We don't cast if measuring samples as float tf.Tensors cannot be used to + # index other tf.Tensors (is_valid is used to index valid samples). + is_valid = qml.math.cast_like(is_valid, samples) + + if isinstance(measurement, ExpectationMP): + return qml.math.sum(samples * is_valid) / qml.math.sum(is_valid) + if isinstance(measurement, ProbabilityMP): + return qml.math.sum(samples * qml.math.reshape(is_valid, (-1, 1)), axis=0) / qml.math.sum( + is_valid + ) + # VarianceMP expval = qml.math.sum(samples * is_valid) / qml.math.sum(is_valid) + if interface == "tensorflow": + # Casting needed for tensorflow + samples = qml.math.cast_like(samples, expval) + is_valid = qml.math.cast_like(is_valid, expval) return qml.math.sum((samples - expval) ** 2 * is_valid) / qml.math.sum(is_valid) -def gather_mcm(measurement, samples, is_valid): +def gather_mcm(measurement, samples, is_valid, postselect_mode=None): """Combines, gathers and normalizes several measurements with non-trivial measurement values. Args: @@ -404,20 +432,30 @@ def gather_mcm(measurement, samples, is_valid): if isinstance(measurement, ProbabilityMP): values = [list(m.branches.values()) for m in mv] values = list(itertools.product(*values)) - values = [qml.math.array([v], like=interface) for v in values] + values = [qml.math.array([v], like=interface, dtype=mcm_samples.dtype) for v in values] + # Need to use boolean functions explicitly as Tensorflow does not allow integer math + # on boolean arrays counts = [ - qml.math.sum(qml.math.all(mcm_samples == v, axis=1) * is_valid) for v in values + qml.math.count_nonzero( + qml.math.logical_and(qml.math.all(mcm_samples == v, axis=1), is_valid) + ) + for v in values ] counts = qml.math.array(counts, like=interface) return counts / qml.math.sum(counts) if isinstance(measurement, CountsMP): mcm_samples = [{"".join(str(int(v)) for v in tuple(s)): 1} for s in mcm_samples] - return gather_non_mcm(measurement, mcm_samples, is_valid) + return gather_non_mcm(measurement, mcm_samples, is_valid, postselect_mode=postselect_mode) mcm_samples = qml.math.ravel(qml.math.array(mv.concretize(samples), like=interface)) if isinstance(measurement, ProbabilityMP): - counts = [qml.math.sum((mcm_samples == v) * is_valid) for v in list(mv.branches.values())] + # Need to use boolean functions explicitly as Tensorflow does not allow integer math + # on boolean arrays + counts = [ + qml.math.count_nonzero(qml.math.logical_and((mcm_samples == v), is_valid)) + for v in list(mv.branches.values()) + ] counts = qml.math.array(counts, like=interface) return counts / qml.math.sum(counts) if isinstance(measurement, CountsMP): mcm_samples = [{float(s): 1} for s in mcm_samples] - return gather_non_mcm(measurement, mcm_samples, is_valid) + return gather_non_mcm(measurement, mcm_samples, is_valid, postselect_mode=postselect_mode) diff --git a/pennylane/workflow/execution.py b/pennylane/workflow/execution.py index cac6aafd2fe..1e9560b6571 100644 --- a/pennylane/workflow/execution.py +++ b/pennylane/workflow/execution.py @@ -467,6 +467,28 @@ def _deprecated_arguments_warnings( return tapes, override_shots, expand_fn, max_expansion, device_batch_transform +def _update_mcm_config(mcm_config: "qml.devices.MCMConfig", interface: str, finite_shots: bool): + """Helper function to update the mid-circuit measurements configuration based on + execution parameters""" + if interface == "jax-jit" and mcm_config.mcm_method == "deferred": + # This is a current limitation of defer_measurements. "hw-like" behaviour is + # not yet accessible. + if mcm_config.postselect_mode == "hw-like": + raise ValueError( + "Using postselect_mode='hw-like' is not supported with jax-jit when using " + "mcm_method='deferred'." + ) + mcm_config.postselect_mode = "fill-shots" + + if ( + finite_shots + and "jax" in interface + and mcm_config.mcm_method in (None, "one-shot") + and mcm_config.postselect_mode in (None, "hw-like") + ): + mcm_config.postselect_mode = "pad-invalid-samples" + + def execute( tapes: QuantumTapeBatch, device: device_type, @@ -697,16 +719,17 @@ def cost_fn(params, x): ) # Mid-circuit measurement configuration validation - mcm_interface = _get_interface_name(tapes, "auto") if interface is None else interface - if mcm_interface == "jax-jit" and config.mcm_config.mcm_method == "deferred": - # This is a current limitation of defer_measurements. "hw-like" behaviour is - # not yet accessible. - if config.mcm_config.postselect_mode == "hw-like": - raise ValueError( - "Using postselect_mode='hw-like' is not supported with jax-jit when using " - "mcm_method='deferred'." - ) - config.mcm_config.postselect_mode = "fill-shots" + mcm_interface = interface or _get_interface_name(tapes, "auto") + finite_shots = ( + ( + qml.measurements.Shots(device.shots) + if isinstance(device, qml.devices.LegacyDevice) + else device.shots + ) + if override_shots is False + else override_shots + ) + _update_mcm_config(config.mcm_config, mcm_interface, finite_shots) is_gradient_transform = isinstance(gradient_fn, qml.transforms.core.TransformDispatcher) transform_program, inner_transform = _make_transform_programs( diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py index 71eb11a9f80..640f6c795e6 100644 --- a/pennylane/workflow/qnode.py +++ b/pennylane/workflow/qnode.py @@ -64,7 +64,7 @@ def _get_device_shots(device) -> Shots: def _make_execution_config( - circuit: Optional["QNode"], diff_method=None + circuit: Optional["QNode"], diff_method=None, mcm_config=None ) -> "qml.devices.ExecutionConfig": if diff_method is None or isinstance(diff_method, str): _gradient_method = diff_method @@ -76,14 +76,13 @@ def _make_execution_config( grad_on_execution = False elif grad_on_execution == "best": grad_on_execution = None - mcm_config = execute_kwargs.get("mcm_config", {}) return qml.devices.ExecutionConfig( interface=getattr(circuit, "interface", None), gradient_method=_gradient_method, grad_on_execution=grad_on_execution, use_device_jacobian_product=execute_kwargs.get("device_vjp", False), - mcm_config=mcm_config, + mcm_config=mcm_config or qml.devices.MCMConfig(), ) @@ -555,9 +554,8 @@ def __init__( self.diff_method = diff_method self.expansion_strategy = expansion_strategy self.max_expansion = max_expansion - cache = (max_diff > 1) if cache == "auto" else cache - mcm_config = qml.devices.MCMConfig(mcm_method=mcm_method, postselect_mode=postselect_mode) + cache = (max_diff > 1) if cache == "auto" else cache # execution keyword arguments self.execute_kwargs = { @@ -1080,8 +1078,9 @@ def _execution_component(self, args: tuple, kwargs: dict, override_shots) -> qml ) self._tape_cached = using_custom_cache and self.tape.hash in cache - mcm_config = copy.copy(self.execute_kwargs["mcm_config"]) - finite_shots = _get_device_shots if override_shots is False else override_shots + execute_kwargs = copy.copy(self.execute_kwargs) + mcm_config = copy.copy(execute_kwargs["mcm_config"]) + finite_shots = _get_device_shots(self.device) if override_shots is False else override_shots if not finite_shots: mcm_config.postselect_mode = None if mcm_config.mcm_method in ("one-shot", "tree-traversal"): @@ -1097,7 +1096,7 @@ def _execution_component(self, args: tuple, kwargs: dict, override_shots) -> qml if isinstance(self.device, qml.devices.Device): - config = _make_execution_config(self, self.gradient_fn) + config = _make_execution_config(self, self.gradient_fn, mcm_config) device_transform_program, config = self.device.preprocess(execution_config=config) if config.use_device_gradient: @@ -1115,13 +1114,9 @@ def _execution_component(self, args: tuple, kwargs: dict, override_shots) -> qml qml.devices.preprocess.mid_circuit_measurements, device=self.device, mcm_config=mcm_config, - interface=self.interface, ) elif hasattr(self.device, "capabilities"): - inner_transform_program.add_transform( - qml.defer_measurements, - device=self.device, - ) + inner_transform_program.add_transform(qml.defer_measurements, device=self.device) # Add the gradient expand to the program if necessary if getattr(self.gradient_fn, "expand_transform", False): @@ -1134,8 +1129,10 @@ def _execution_component(self, args: tuple, kwargs: dict, override_shots) -> qml full_transform_program.set_classical_component(self, args, kwargs) _prune_dynamic_transform(full_transform_program, inner_transform_program) + execute_kwargs["mcm_config"] = mcm_config + with warnings.catch_warnings(): - # TODO: remove this once the cycle for the arguements have finished, i.e. 0.39. + # TODO: remove this once the cycle for the arguments have finished, i.e. 0.39. warnings.filterwarnings( action="ignore", message=r".*argument is deprecated and will be removed in version 0.39.*", @@ -1152,7 +1149,7 @@ def _execution_component(self, args: tuple, kwargs: dict, override_shots) -> qml config=config, gradient_kwargs=self.gradient_kwargs, override_shots=override_shots, - **self.execute_kwargs, + **execute_kwargs, ) res = res[0] diff --git a/tests/devices/default_qubit/test_default_qubit_native_mcm.py b/tests/devices/default_qubit/test_default_qubit_native_mcm.py index be0f8574d97..42b34d3246f 100644 --- a/tests/devices/default_qubit/test_default_qubit_native_mcm.py +++ b/tests/devices/default_qubit/test_default_qubit_native_mcm.py @@ -13,6 +13,7 @@ # limitations under the License. """Tests for default qubit preprocessing.""" from collections.abc import Sequence +from unittest.mock import patch import mcm_utils import numpy as np @@ -424,156 +425,141 @@ def circuit(x): _ = circuit([0.1, 0.2]) -# pylint: disable=not-an-iterable -@pytest.mark.jax +@pytest.mark.all_interfaces +@pytest.mark.parametrize("interface", ["torch", "tensorflow", "jax", "autograd"]) @pytest.mark.parametrize("mcm_method", ["one-shot", "tree-traversal"]) -@pytest.mark.parametrize("shots", [100, [100, 101], [100, 100, 101]]) -@pytest.mark.parametrize("postselect", [None, 0, 1]) -def test_sample_with_prng_key(mcm_method, shots, postselect): - """Test that setting a PRNGKey gives the expected behaviour. With separate calls - to DefaultQubit.execute, the same results are expected when using a PRNGKey""" - # pylint: disable=import-outside-toplevel - from jax.random import PRNGKey - - dev = get_device(shots=shots, seed=PRNGKey(678)) - params = [np.pi / 4, np.pi / 3] - obs = qml.PauliZ(0) @ qml.PauliZ(1) +def test_finite_diff_in_transform_program(interface, mcm_method): + """Test that finite diff is in the transform program of a qnode containing + mid-circuit measurements""" - def func(x, y): - obs_tape(x, y, None, postselect=postselect) - return qml.sample(op=obs) - - func0 = qml.QNode(func, dev, mcm_method=mcm_method) - results0 = func0(*params) - results1 = qml.QNode(func, dev, mcm_method="deferred")(*params) + dev = get_device(shots=10) - mcm_utils.validate_measurements(qml.sample, shots, results1, results0, batch_size=None) + @qml.qnode(dev, mcm_method=mcm_method, diff_method="finite-diff") + def circuit(x): + qml.RX(x, 0) + qml.measure(0) + return qml.expval(qml.Z(0)) - evals = obs.eigvals() - for eig in evals: - # When comparing with the results from a circuit with deferred measurements - # we're not always expected to have the functions used to sample are different - if isinstance(shots, list): - for r in results1: - assert not np.all(np.isclose(r, eig)) - else: - assert not np.all(np.isclose(results1, eig)) + x = qml.math.array(1.5, like=interface) + with patch("pennylane.execute") as mock_execute: + circuit(x) + mock_execute.assert_called() + _, kwargs = mock_execute.call_args + transform_program = kwargs["transform_program"] - results0_2 = func0(*params) - # Same result expected with multiple executions - if isinstance(shots, list): - for r0, r0_2 in zip(results0, results0_2): - assert np.allclose(r0, r0_2) - else: - assert np.allclose(results0, results0_2) + # pylint: disable=protected-access + assert transform_program[0]._transform == qml.gradients.finite_diff.expand_transform # pylint: disable=import-outside-toplevel, not-an-iterable @pytest.mark.jax -@pytest.mark.parametrize("diff_method", [None, "best"]) -@pytest.mark.parametrize("postselect", [None, 1]) -@pytest.mark.parametrize("reset", [False, True]) -def test_jax_jit(diff_method, postselect, reset): - """Tests that DefaultQubit handles a circuit with a single mid-circuit measurement and a - conditional gate. A single measurement of a common observable is performed at the end.""" - import jax - - shots = 10 - - dev = get_device(shots=shots, seed=jax.random.PRNGKey(678)) - params = [np.pi / 2.5, np.pi / 3, -np.pi / 3.5] - obs = qml.PauliY(0) - - @qml.qnode(dev, diff_method=diff_method) - def func(x, y, z): - m0, m1 = obs_tape(x, y, z, reset=reset, postselect=postselect) - return ( - qml.probs(wires=[1]), - qml.probs(wires=[0, 1]), - qml.sample(wires=[1]), - qml.sample(wires=[0, 1]), - qml.expval(obs), - qml.probs(obs), - qml.sample(obs), - qml.var(obs), - qml.expval(op=m0 + 2 * m1), - qml.probs(op=m0), - qml.sample(op=m0 + 2 * m1), - qml.var(op=m0 + 2 * m1), - qml.probs(op=[m0, m1]), - ) - - func1 = func - results1 = func1(*params) - - jaxpr = str(jax.make_jaxpr(func)(*params)) - if diff_method == "best": - assert "pure_callback" in jaxpr - pytest.xfail("QNode with diff_method='best' cannot be compiled with jax.jit.") - else: - assert "pure_callback" not in jaxpr - - func2 = jax.jit(func) - results2 = func2(*params) - - measures = [ - qml.probs, - qml.probs, - qml.sample, - qml.sample, - qml.expval, - qml.probs, - qml.sample, - qml.var, - qml.expval, - qml.probs, - qml.sample, - qml.var, - qml.probs, - ] - for measure_f, r1, r2 in zip(measures, results1, results2): - r1, r2 = np.array(r1).ravel(), np.array(r2).ravel() - if measure_f == qml.sample: - r2 = r2[r2 != fill_in_value] - np.allclose(r1, r2) - - -@pytest.mark.torch -@pytest.mark.parametrize("postselect", [None, 1]) -@pytest.mark.parametrize("diff_method", [None, "best"]) -@pytest.mark.parametrize("measure_f", [qml.probs, qml.sample, qml.expval, qml.var]) -@pytest.mark.parametrize("meas_obj", [qml.PauliZ(1), [0, 1], "composite_mcm", "mcm_list"]) -def test_torch_integration(postselect, diff_method, measure_f, meas_obj): - """Test that native MCM circuits are executed correctly with Torch""" - if measure_f in (qml.expval, qml.var) and ( - isinstance(meas_obj, list) or meas_obj == "mcm_list" - ): - pytest.skip("Can't use wires/mcm lists with var or expval") - - import torch - - shots = 7000 - dev = get_device(shots=shots, seed=123456789) - param = torch.tensor(np.pi / 3, dtype=torch.float64) - - @qml.qnode(dev, diff_method=diff_method) - def func(x): - qml.RX(x, 0) - m0 = qml.measure(0) - qml.RX(0.5 * x, 1) - m1 = qml.measure(1, postselect=postselect) - qml.cond((m0 + m1) == 2, qml.RY)(2.0 * x, 0) - m2 = qml.measure(0) - - mid_measure = 0.5 * m2 if meas_obj == "composite_mcm" else [m1, m2] - measurement_key = "wires" if isinstance(meas_obj, list) else "op" - measurement_value = mid_measure if isinstance(meas_obj, str) else meas_obj - return measure_f(**{measurement_key: measurement_value}) - - func1 = func - func2 = qml.defer_measurements(func) - - results1 = func1(param) - results2 = func2(param) +class TestJaxIntegration: + """Integration tests for dynamic_one_shot with jax""" + + @pytest.mark.parametrize("mcm_method", ["one-shot", "tree-traversal"]) + @pytest.mark.parametrize("shots", [100, [100, 101], [100, 100, 101]]) + @pytest.mark.parametrize("postselect", [None, 0, 1]) + def test_sample_with_prng_key(self, mcm_method, shots, postselect): + """Test that setting a PRNGKey gives the expected behaviour. With separate calls + to DefaultQubit.execute, the same results are expected when using a PRNGKey""" + # pylint: disable=import-outside-toplevel + from jax.random import PRNGKey + + dev = get_device(shots=shots, seed=PRNGKey(678)) + params = [np.pi / 4, np.pi / 3] + obs = qml.PauliZ(0) @ qml.PauliZ(1) + + def func(x, y): + obs_tape(x, y, None, postselect=postselect) + return qml.sample(op=obs) + + func0 = qml.QNode(func, dev, mcm_method=mcm_method) + results0 = func0(*params) + results1 = qml.QNode(func, dev, mcm_method="deferred")(*params) + + mcm_utils.validate_measurements(qml.sample, shots, results1, results0, batch_size=None) + + evals = obs.eigvals() + for eig in evals: + # When comparing with the results from a circuit with deferred measurements + # we're not always expected to have the functions used to sample are different + if isinstance(shots, list): + for r in results1: + assert not np.all(np.isclose(r, eig)) + else: + assert not np.all(np.isclose(results1, eig)) - mcm_utils.validate_measurements(measure_f, shots, results1, results2) + results0_2 = func0(*params) + # Same result expected with multiple executions + if isinstance(shots, list): + for r0, r0_2 in zip(results0, results0_2): + assert np.allclose(r0, r0_2) + else: + assert np.allclose(results0, results0_2) + + @pytest.mark.parametrize("diff_method", [None, "best"]) + @pytest.mark.parametrize("postselect", [None, 1]) + @pytest.mark.parametrize("reset", [False, True]) + def test_jax_jit(self, diff_method, postselect, reset): + """Tests that DefaultQubit handles a circuit with a single mid-circuit measurement and a + conditional gate. A single measurement of a common observable is performed at the end.""" + import jax + + shots = 10 + + dev = get_device(shots=shots, seed=jax.random.PRNGKey(678)) + params = [np.pi / 2.5, np.pi / 3, -np.pi / 3.5] + obs = qml.PauliY(0) + + @qml.qnode(dev, diff_method=diff_method) + def func(x, y, z): + m0, m1 = obs_tape(x, y, z, reset=reset, postselect=postselect) + return ( + qml.probs(wires=[1]), + qml.probs(wires=[0, 1]), + qml.sample(wires=[1]), + qml.sample(wires=[0, 1]), + qml.expval(obs), + qml.probs(obs), + qml.sample(obs), + qml.var(obs), + qml.expval(op=m0 + 2 * m1), + qml.probs(op=m0), + qml.sample(op=m0 + 2 * m1), + qml.var(op=m0 + 2 * m1), + qml.probs(op=[m0, m1]), + ) + + func1 = func + results1 = func1(*params) + + jaxpr = str(jax.make_jaxpr(func)(*params)) + if diff_method == "best": + assert "pure_callback" in jaxpr + pytest.xfail("QNode with diff_method='best' cannot be compiled with jax.jit.") + else: + assert "pure_callback" not in jaxpr + + func2 = jax.jit(func) + results2 = func2(*params) + + measures = [ + qml.probs, + qml.probs, + qml.sample, + qml.sample, + qml.expval, + qml.probs, + qml.sample, + qml.var, + qml.expval, + qml.probs, + qml.sample, + qml.var, + qml.probs, + ] + for measure_f, r1, r2 in zip(measures, results1, results2): + r1, r2 = np.array(r1).ravel(), np.array(r2).ravel() + if measure_f == qml.sample: + r2 = r2[r2 != fill_in_value] + np.allclose(r1, r2) diff --git a/tests/test_qnode.py b/tests/test_qnode.py index 52d9155e448..9f452bfe853 100644 --- a/tests/test_qnode.py +++ b/tests/test_qnode.py @@ -16,7 +16,7 @@ # pylint: disable=import-outside-toplevel, protected-access, no-member import warnings -from dataclasses import asdict, replace +from dataclasses import replace from functools import partial import numpy as np @@ -1853,7 +1853,7 @@ def test_execution_does_not_mutate_config(self, mcm_method, postselect_mode): postselect_mode=postselect_mode, mcm_method=mcm_method ) - @qml.qnode(dev, **asdict(original_config)) + @qml.qnode(dev, postselect_mode=postselect_mode, mcm_method=mcm_method) def circuit(x, mp): qml.RX(x, 0) qml.measure(0, postselect=1) diff --git a/tests/transforms/test_dynamic_one_shot.py b/tests/transforms/test_dynamic_one_shot.py index 16ac830cd11..1556e6538dd 100644 --- a/tests/transforms/test_dynamic_one_shot.py +++ b/tests/transforms/test_dynamic_one_shot.py @@ -28,7 +28,10 @@ ProbabilityMP, SampleMP, ) -from pennylane.transforms.dynamic_one_shot import parse_native_mid_circuit_measurements +from pennylane.transforms.dynamic_one_shot import ( + fill_in_value, + parse_native_mid_circuit_measurements, +) @pytest.mark.parametrize( @@ -215,157 +218,245 @@ def assert_results(res, shots, n_mcms): # that samples are generated correctly. -@pytest.mark.jax -@pytest.mark.parametrize("measure_f", (qml.expval, qml.probs, qml.sample, qml.var)) -@pytest.mark.parametrize("shots", [20, [20, 21]]) -@pytest.mark.parametrize("n_mcms", [1, 3]) -def test_tape_results_jax(shots, n_mcms, measure_f): - """Test that the simulation results of a tape are correct with jax parameters""" - import jax - - dev = qml.device("default.qubit", wires=4, shots=shots, seed=jax.random.PRNGKey(123)) - param = jax.numpy.array(np.pi / 2) - - mv = qml.measure(0) - mp = mv.measurements[0] - - tape = qml.tape.QuantumScript( - [qml.RX(param, 0), mp] + [MidMeasureMP(0, id=str(i)) for i in range(n_mcms - 1)], - [measure_f(op=qml.PauliZ(0)), measure_f(op=mv)], - shots=shots, - ) - - tapes, _ = qml.dynamic_one_shot(tape) - results = dev.execute(tapes)[0] - - # The transformed tape never has a shot vector - if isinstance(shots, list): - shots = sum(shots) +def generate_dummy_raw_results(measure_f, n_mcms, shots, postselect, interface): + """Helper function for generating dummy raw results. Raw results are the output(s) of + executing the transformed tape(s) that are given to the processing function. For + ``dynamic_one_shot``, the first items in the measurements for the transformed tape will + be all the measurements of the original tape that were applied to wires/observables, + and the rest will be SampleMPs on all the mid-circuit measurements in the original tape. - assert_results(results, shots, n_mcms) + In this unit test suite, the original tape will have one measurement on wires/observables, + so the transformed tape will have one measurement on wires/observables, and ``n_mcms`` + ``SampleMP`` measurements on MCMs. + The raw results will be all 1s with the appropriate shape for tests without postselection. + For tests with postselection, the results for the wires/observable measurement and first + ``SampleMP(mcm)`` will be alternating with valid results at odd indices and invalid results + at even indices. The results for the rest of the ``SampleMP(mcm)`` will be all 1s.""" -@pytest.mark.jax -@pytest.mark.parametrize( - "measure_f, expected1, expected2", - [ - (qml.expval, 1.0, 1.0), - (qml.probs, [1, 0], [0, 1]), - (qml.sample, 1, 1), - (qml.var, 0.0, 0.0), - ], -) -@pytest.mark.parametrize("shots", [20, [20, 21]]) -@pytest.mark.parametrize("n_mcms", [1, 3]) -def test_jax_results_processing(shots, n_mcms, measure_f, expected1, expected2): - """Test that the results of tapes are processed correctly for tapes with jax parameters""" - import jax.numpy as jnp + if postselect is None: + # First raw result for a single shot, i.e, result of wires/obs measurement + obs_res_single_shot = qml.math.array( + [1.0, 0.0] if measure_f == qml.probs else 1.0, like=interface + ) + # Result of SampleMP on mid-circuit measurements + rest_single_shot = qml.math.array(1, like=interface) + single_shot_res = (obs_res_single_shot,) + (rest_single_shot,) * n_mcms + # Raw results for each shot are (sample_for_first_measurement,) + (sample for 1st MCM, sample for 2nd MCM, ...) + raw_results = (single_shot_res,) * shots - mv = qml.measure(0) - mp = mv.measurements[0] + else: + # When postselecting, we start by creating results for two shots as alternating indices + # will have valid results. + # Alternating tuple. Only the values at odd indices are valid + obs_res_two_shot = ( + (qml.math.array([1.0, 0.0], like=interface), qml.math.array([0.0, 1.0], like=interface)) + if measure_f == qml.probs + else (qml.math.array(1.0, like=interface), qml.math.array(0.0, like=interface)) + ) + obs_res = obs_res_two_shot * (shots // 2) + # Tuple of alternating 1s and 0s. + postselect_res = ( + qml.math.array(int(postselect), like=interface), + qml.math.array(int(not postselect), like=interface), + ) * (shots // 2) + rest = (qml.math.array(1, like=interface),) * shots + # Raw results for each shot are (sample_for_first_measurement, sample for 1st MCM, sample for 2nd MCM) + raw_results = tuple(zip(obs_res, postselect_res, rest)) + + # Wrap in 1-tuple as there is a single transformed tape unless broadcasting + return (raw_results,) + + +# pylint: disable=too-many-arguments, import-outside-toplevel +@pytest.mark.all_interfaces +@pytest.mark.parametrize("interface", ["autograd", "jax", "tensorflow", "torch", "numpy", None]) +@pytest.mark.parametrize("use_interface_for_results", [True, False]) +class TestInterfaces: + """Unit tests for ML interfaces with dynamic_one_shot""" + + @pytest.mark.parametrize("measure_f", (qml.expval, qml.probs, qml.sample, qml.var)) + @pytest.mark.parametrize("shots", [20, [20, 21]]) + @pytest.mark.parametrize("n_mcms", [1, 3]) + def test_interface_tape_results( + self, shots, n_mcms, measure_f, interface, use_interface_for_results + ): # pylint: disable=unused-argument + """Test that the simulation results of a tape are correct with interface parameters""" + if interface == "jax": + from jax.random import PRNGKey + + seed = PRNGKey(123) + else: + seed = 123 + + dev = qml.device("default.qubit", wires=4, shots=shots, seed=seed) + param = qml.math.array(np.pi / 2, like=interface) + + mv = qml.measure(0) + mcms = [mv.measurements[0]] + [MidMeasureMP(0, id=str(i)) for i in range(n_mcms - 1)] + + tape = qml.tape.QuantumScript( + [qml.RX(param, 0)] + mcms, + [measure_f(op=qml.PauliZ(0)), measure_f(op=mv)], + shots=shots, + ) - tape = qml.tape.QuantumScript( - [qml.RX(1.5, 0), mp] + [MidMeasureMP(0)] * (n_mcms - 1), - [measure_f(op=qml.PauliZ(0)), measure_f(op=mv)], - shots=shots, + tapes, _ = qml.dynamic_one_shot(tape) + results = dev.execute(tapes)[0] + + # The transformed tape never has a shot vector + if isinstance(shots, list): + shots = sum(shots) + + assert_results(results, shots, n_mcms) + + @pytest.mark.parametrize( + "measure_f, expected1, expected2", + [ + (qml.expval, 1.0, 1.0), + (qml.probs, [1, 0], [0, 1]), + ( + qml.sample, + # The expected results provided for qml.sample are + # just the result of a single shot + 1, + 1, + ), + (qml.var, 0.0, 0.0), + ], ) - _, fn = qml.dynamic_one_shot(tape) - all_shots = sum(shots) if isinstance(shots, list) else shots - - first_res = jnp.array([1.0, 0.0]) if measure_f == qml.probs else jnp.array(1.0) - rest = jnp.array(1, dtype=int) - single_shot_res = (first_res,) + (rest,) * n_mcms - # Raw results for each shot are (sample_for_first_measurement,) + (sample for 1st MCM, sample for 2nd MCM, ...) - raw_results = (single_shot_res,) * all_shots - raw_results = (raw_results,) - res = fn(raw_results) - - if measure_f is qml.sample: - # All samples 1 - expected1 = ( - [[expected1] * s for s in shots] if isinstance(shots, list) else [expected1] * shots + @pytest.mark.parametrize("shots", [20, [20, 21]]) + @pytest.mark.parametrize("n_mcms", [1, 3]) + def test_interface_results_processing( + self, shots, n_mcms, measure_f, expected1, expected2, interface, use_interface_for_results + ): + """Test that the results of tapes are processed correctly for tapes with interface + parameters""" + param = qml.math.array(1.5, like=interface) + mv = qml.measure(0) + mcms = [mv.measurements[0]] + [MidMeasureMP(0)] * (n_mcms - 1) + ops = [qml.RX(param, 0)] + mcms + + tape = qml.tape.QuantumScript( + ops, [measure_f(op=qml.PauliZ(0)), measure_f(op=mv)], shots=shots ) - expected2 = ( - [[expected2] * s for s in shots] if isinstance(shots, list) else [expected2] * shots + _, fn = qml.dynamic_one_shot(tape) + total_shots = sum(shots) if isinstance(shots, list) else shots + + raw_results = generate_dummy_raw_results( + measure_f=measure_f, + n_mcms=n_mcms, + shots=total_shots, + postselect=None, + interface=interface if use_interface_for_results else None, ) - else: - expected1 = [expected1 for _ in shots] if isinstance(shots, list) else expected1 - expected2 = [expected2 for _ in shots] if isinstance(shots, list) else expected2 - - if isinstance(shots, list): - assert len(res) == len(shots) - for r, e1, e2 in zip(res, expected1, expected2): + processed_results = fn(raw_results) + + if measure_f is qml.sample: + # All samples 1 + expected1 = ( + [[expected1] * s for s in shots] if isinstance(shots, list) else [expected1] * shots + ) + expected2 = ( + [[expected2] * s for s in shots] if isinstance(shots, list) else [expected2] * shots + ) + else: + expected1 = [expected1 for _ in shots] if isinstance(shots, list) else expected1 + expected2 = [expected2 for _ in shots] if isinstance(shots, list) else expected2 + + if use_interface_for_results: + expected_interface = "numpy" if interface in (None, "autograd") else interface + assert qml.math.get_deep_interface(processed_results) == expected_interface + else: + assert qml.math.get_deep_interface(processed_results) == "numpy" + + if isinstance(shots, list): + assert len(processed_results) == len(shots) + for r, e1, e2 in zip(processed_results, expected1, expected2): + # Expected result is 2-list since we have two measurements in the tape + assert qml.math.allclose(r, [e1, e2]) + else: # Expected result is 2-list since we have two measurements in the tape - assert qml.math.allclose(r, [e1, e2]) - else: - # Expected result is 2-list since we have two measurements in the tape - assert qml.math.allclose(res, [expected1, expected2]) - - -@pytest.mark.jax -@pytest.mark.parametrize( - "measure_f, expected1, expected2", - [ - (qml.expval, 1.0, 1.0), - (qml.probs, [1, 0], [0, 1]), - (qml.sample, 1, 1), - (qml.var, 0.0, 0.0), - ], -) -@pytest.mark.parametrize("shots", [20, [20, 22]]) -def test_jax_results_postselection_processing(shots, measure_f, expected1, expected2): - """Test that the results of tapes are processed correctly for tapes with jax parameters - when postselecting""" - import jax.numpy as jnp - - param = jnp.array(np.pi / 2) - fill_value = np.iinfo(np.int32).min - mv = qml.measure(0, postselect=1) - mp = mv.measurements[0] - - tape = qml.tape.QuantumScript( - [qml.RX(param, 0), mp, MidMeasureMP(0)], - [measure_f(op=qml.PauliZ(0)), measure_f(op=mv)], - shots=shots, - ) - _, fn = qml.dynamic_one_shot(tape) - all_shots = sum(shots) if isinstance(shots, list) else shots - - # Alternating tuple. Only the values at odd indices are valid - first_res_two_shot = ( - (jnp.array([1.0, 0.0]), jnp.array([0.0, 1.0])) - if measure_f == qml.probs - else (jnp.array(1.0), jnp.array(0.0)) + assert qml.math.allclose(processed_results, [expected1, expected2]) + + @pytest.mark.parametrize( + "measure_f, expected1, expected2", + [ + (qml.expval, 1.0, 1.0), + (qml.probs, [1, 0], [0, 1]), + ( + qml.sample, + # The expected results provided for qml.sample are + # just the result of a single shot + 1, + 1, + ), + (qml.var, 0.0, 0.0), + ], ) - first_res = first_res_two_shot * (all_shots // 2) - # Tuple of alternating 1s and 0s. Zero is invalid as postselecting on 1 - postselect_res = (jnp.array(1, dtype=int), jnp.array(0, dtype=int)) * (all_shots // 2) - rest = (jnp.array(1, dtype=int),) * all_shots - # Raw results for each shot are (sample_for_first_measurement, sample for 1st MCM, sample for 2nd MCM) - raw_results = tuple(zip(first_res, postselect_res, rest)) - raw_results = (raw_results,) - res = fn(raw_results) - - if measure_f is qml.sample: - expected1 = ( - [[expected1, fill_value] * (s // 2) for s in shots] - if isinstance(shots, list) - else [expected1, fill_value] * (shots // 2) + @pytest.mark.parametrize("shots", [20, [20, 22]]) + def test_interface_results_postselection_processing( + self, shots, measure_f, expected1, expected2, interface, use_interface_for_results + ): + """Test that the results of tapes are processed correctly for tapes with interface + parameters when postselecting""" + postselect = 1 + param = qml.math.array(np.pi / 2, like=interface) + mv = qml.measure(0, postselect=postselect) + mp = mv.measurements[0] + + tape = qml.tape.QuantumScript( + [qml.RX(param, 0), mp, MidMeasureMP(0)], + [measure_f(op=qml.PauliZ(0)), measure_f(op=mv)], + shots=shots, ) - expected2 = ( - [[expected2, fill_value] * (s // 2) for s in shots] - if isinstance(shots, list) - else [expected2, fill_value] * (shots // 2) + _, fn = qml.dynamic_one_shot( + tape, postselect_mode="pad-invalid-samples" if interface == "jax" else None ) - else: - expected1 = [expected1 for _ in shots] if isinstance(shots, list) else expected1 - expected2 = [expected2 for _ in shots] if isinstance(shots, list) else expected2 - - if isinstance(shots, list): - assert len(res) == len(shots) - for r, e1, e2 in zip(res, expected1, expected2): + total_shots = sum(shots) if isinstance(shots, list) else shots + + raw_results = generate_dummy_raw_results( + measure_f=measure_f, + n_mcms=2, + shots=total_shots, + postselect=postselect, + interface=interface if use_interface_for_results else None, + ) + processed_results = fn(raw_results) + + if measure_f is qml.sample: + if interface == "jax": + expected1 = [expected1, fill_in_value] + expected2 = [expected2, fill_in_value] + else: + expected1 = [expected1] + expected2 = [expected2] + expected1 = ( + [expected1 * (s // 2) for s in shots] + if isinstance(shots, list) + else expected1 * (shots // 2) + ) + expected2 = ( + [expected2 * (s // 2) for s in shots] + if isinstance(shots, list) + else expected2 * (shots // 2) + ) + + else: + expected1 = [expected1 for _ in shots] if isinstance(shots, list) else expected1 + expected2 = [expected2 for _ in shots] if isinstance(shots, list) else expected2 + + if use_interface_for_results: + expected_interface = "numpy" if interface in (None, "autograd") else interface + assert qml.math.get_deep_interface(processed_results) == expected_interface + else: + assert qml.math.get_deep_interface(processed_results) == "numpy" + + if isinstance(shots, list): + assert len(processed_results) == len(shots) + for r, e1, e2 in zip(processed_results, expected1, expected2): + # Expected result is 2-list since we have two measurements in the tape + assert qml.math.allclose(r, [e1, e2]) + else: # Expected result is 2-list since we have two measurements in the tape - assert qml.math.allclose(r, [e1, e2]) - else: - # Expected result is 2-list since we have two measurements in the tape - assert qml.math.allclose(res, [expected1, expected2]) + assert qml.math.allclose(processed_results, [expected1, expected2])