From 5bd3134a1db519365078a80d9fa6dd7982980a1b Mon Sep 17 00:00:00 2001 From: Vincent Michaud-Rioux Date: Tue, 23 Apr 2024 20:50:26 -0400 Subject: [PATCH] Fix native mcm workflow following dynamic_one_shot refactor. (#694) * Fix native mcm workflow following dynamic_one_shot refactor. * Auto update version from '0.36.0-dev35' to '0.36.0-dev36' * Update changelog. --------- Co-authored-by: ringo-but-quantum --- .github/CHANGELOG.md | 3 +++ .../lightning_qubit/_measurements.py | 22 ++++++++++++++++--- .../lightning_qubit/lightning_qubit.py | 7 ++---- 3 files changed, 24 insertions(+), 8 deletions(-) diff --git a/.github/CHANGELOG.md b/.github/CHANGELOG.md index eea553de19..2778ea71ba 100644 --- a/.github/CHANGELOG.md +++ b/.github/CHANGELOG.md @@ -86,6 +86,9 @@ ### Bug fixes +* `dynamic_one_shot` was refactored to use `SampleMP` measurements as a way to return the mid-circuit measurement samples. `LightningQubit`'s `simulate` is modified accordingly. + [(#694)](https://github.com/PennyLaneAI/pennylane/pull/694) + * `LightningQubit` correctly decomposes state prep operations when used in the middle of a circuit. [(#687)](https://github.com/PennyLaneAI/pennylane/pull/687) diff --git a/pennylane_lightning/lightning_qubit/_measurements.py b/pennylane_lightning/lightning_qubit/_measurements.py index 264a4b2bda..52d3c154b4 100644 --- a/pennylane_lightning/lightning_qubit/_measurements.py +++ b/pennylane_lightning/lightning_qubit/_measurements.py @@ -248,7 +248,7 @@ def measurement(self, measurementprocess: MeasurementProcess) -> TensorLike: """ return self.get_measurement_function(measurementprocess)(measurementprocess) - def measure_final_state(self, circuit: QuantumScript) -> Result: + def measure_final_state(self, circuit: QuantumScript, mid_measurements=None) -> Result: """ Perform the measurements required by the circuit on the provided state. @@ -256,6 +256,7 @@ def measure_final_state(self, circuit: QuantumScript) -> Result: Args: circuit (QuantumScript): The single circuit to simulate + mid_measurements (None, dict): Dictionary of mid-circuit measurements Returns: Tuple[TensorLike]: The measurement results @@ -272,6 +273,7 @@ def measure_final_state(self, circuit: QuantumScript) -> Result: results = self.measure_with_samples( circuit.measurements, shots=circuit.shots, + mid_measurements=mid_measurements, ) if len(circuit.measurements) == 1: @@ -285,8 +287,9 @@ def measure_final_state(self, circuit: QuantumScript) -> Result: # pylint:disable = too-many-arguments def measure_with_samples( self, - mps: List[Union[SampleMeasurement, ClassicalShadowMP, ShadowExpvalMP]], + measurements: List[Union[SampleMeasurement, ClassicalShadowMP, ShadowExpvalMP]], shots: Shots, + mid_measurements=None, ) -> List[TensorLike]: """ Returns the samples of the measurement process performed on the given state. @@ -294,18 +297,27 @@ def measure_with_samples( have already been mapped to integer wires used in the device. Args: - mps (List[Union[SampleMeasurement, ClassicalShadowMP, ShadowExpvalMP]]): + measurements (List[Union[SampleMeasurement, ClassicalShadowMP, ShadowExpvalMP]]): The sample measurements to perform shots (Shots): The number of samples to take + mid_measurements (None, dict): Dictionary of mid-circuit measurements Returns: List[TensorLike[Any]]: Sample measurement results """ + # last N measurements are sampling MCMs in ``dynamic_one_shot`` execution mode + mps = measurements[0 : -len(mid_measurements)] if mid_measurements else measurements + skip_measure = ( + any(v == -1 for v in mid_measurements.values()) if mid_measurements else False + ) groups, indices = _group_measurements(mps) all_res = [] for group in groups: + if skip_measure: + all_res.extend([None] * len(group)) + continue if isinstance(group[0], (ExpectationMP, VarianceMP)) and isinstance( group[0].obs, SparseHamiltonian ): @@ -333,6 +345,10 @@ def measure_with_samples( res for _, res in sorted(list(enumerate(all_res)), key=lambda r: flat_indices[r[0]]) ) + # append MCM samples + if mid_measurements: + sorted_res += tuple(mid_measurements.values()) + # put the shot vector axis before the measurement axis if shots.has_partitioned_shots: sorted_res = tuple(zip(*sorted_res)) diff --git a/pennylane_lightning/lightning_qubit/lightning_qubit.py b/pennylane_lightning/lightning_qubit/lightning_qubit.py index bfefacb4bd..8350804c18 100644 --- a/pennylane_lightning/lightning_qubit/lightning_qubit.py +++ b/pennylane_lightning/lightning_qubit/lightning_qubit.py @@ -80,11 +80,8 @@ def simulate(circuit: QuantumScript, state: LightningStateVector, mcmc: dict = N if circuit.shots and has_mcm: mid_measurements = {} final_state = state.get_final_state(circuit, mid_measurements=mid_measurements) - if any(v == -1 for v in mid_measurements.values()): - return None, mid_measurements - return ( - LightningMeasurements(final_state, **mcmc).measure_final_state(circuit), - mid_measurements, + return LightningMeasurements(final_state, **mcmc).measure_final_state( + circuit, mid_measurements=mid_measurements ) final_state = state.get_final_state(circuit) return LightningMeasurements(final_state, **mcmc).measure_final_state(circuit)