Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix native mcm workflow following dynamic_one_shot refactor. #694

Merged
merged 4 commits into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@

### Bug fixes

* `dynamic_one_shot` was refactored to use `SampleMP` measurements as a way to return the mid-circuit measurement samples. `LightningQubit`'s `simulate` is modified accordingly.
[(#694)](https://github.com/PennyLaneAI/pennylane/pull/694)

* `LightningQubit` correctly decomposes state prep operations when used in the middle of a circuit.
[(#687)](https://github.com/PennyLaneAI/pennylane/pull/687)

Expand Down
22 changes: 19 additions & 3 deletions pennylane_lightning/lightning_qubit/_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,14 +248,15 @@
"""
return self.get_measurement_function(measurementprocess)(measurementprocess)

def measure_final_state(self, circuit: QuantumScript) -> Result:
def measure_final_state(self, circuit: QuantumScript, mid_measurements=None) -> Result:
"""
Perform the measurements required by the circuit on the provided state.

This is an internal function that will be called by the successor to ``lightning.qubit``.

Args:
circuit (QuantumScript): The single circuit to simulate
mid_measurements (None, dict): Dictionary of mid-circuit measurements

Returns:
Tuple[TensorLike]: The measurement results
Expand All @@ -272,6 +273,7 @@
results = self.measure_with_samples(
circuit.measurements,
shots=circuit.shots,
mid_measurements=mid_measurements,
)

if len(circuit.measurements) == 1:
Expand All @@ -285,27 +287,37 @@
# pylint:disable = too-many-arguments
def measure_with_samples(
self,
mps: List[Union[SampleMeasurement, ClassicalShadowMP, ShadowExpvalMP]],
measurements: List[Union[SampleMeasurement, ClassicalShadowMP, ShadowExpvalMP]],
shots: Shots,
mid_measurements=None,
) -> List[TensorLike]:
"""
Returns the samples of the measurement process performed on the given state.
This function assumes that the user-defined wire labels in the measurement process
have already been mapped to integer wires used in the device.

Args:
mps (List[Union[SampleMeasurement, ClassicalShadowMP, ShadowExpvalMP]]):
measurements (List[Union[SampleMeasurement, ClassicalShadowMP, ShadowExpvalMP]]):
The sample measurements to perform
shots (Shots): The number of samples to take
mid_measurements (None, dict): Dictionary of mid-circuit measurements

Returns:
List[TensorLike[Any]]: Sample measurement results
"""
# last N measurements are sampling MCMs in ``dynamic_one_shot`` execution mode
mps = measurements[0 : -len(mid_measurements)] if mid_measurements else measurements
vincentmr marked this conversation as resolved.
Show resolved Hide resolved
skip_measure = (

Check warning on line 310 in pennylane_lightning/lightning_qubit/_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_qubit/_measurements.py#L309-L310

Added lines #L309 - L310 were not covered by tests
any(v == -1 for v in mid_measurements.values()) if mid_measurements else False
)

groups, indices = _group_measurements(mps)

all_res = []
for group in groups:
if skip_measure:
all_res.extend([None] * len(group))
continue

Check warning on line 320 in pennylane_lightning/lightning_qubit/_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_qubit/_measurements.py#L318-L320

Added lines #L318 - L320 were not covered by tests
maliasadi marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(group[0], (ExpectationMP, VarianceMP)) and isinstance(
group[0].obs, SparseHamiltonian
):
Expand Down Expand Up @@ -333,6 +345,10 @@
res for _, res in sorted(list(enumerate(all_res)), key=lambda r: flat_indices[r[0]])
)

# append MCM samples
if mid_measurements:
sorted_res += tuple(mid_measurements.values())

Check warning on line 350 in pennylane_lightning/lightning_qubit/_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_qubit/_measurements.py#L349-L350

Added lines #L349 - L350 were not covered by tests

# put the shot vector axis before the measurement axis
if shots.has_partitioned_shots:
sorted_res = tuple(zip(*sorted_res))
Expand Down
7 changes: 2 additions & 5 deletions pennylane_lightning/lightning_qubit/lightning_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,8 @@
if circuit.shots and has_mcm:
mid_measurements = {}
final_state = state.get_final_state(circuit, mid_measurements=mid_measurements)
if any(v == -1 for v in mid_measurements.values()):
return None, mid_measurements
return (
LightningMeasurements(final_state, **mcmc).measure_final_state(circuit),
mid_measurements,
return LightningMeasurements(final_state, **mcmc).measure_final_state(

Check warning on line 83 in pennylane_lightning/lightning_qubit/lightning_qubit.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_qubit/lightning_qubit.py#L83

Added line #L83 was not covered by tests
circuit, mid_measurements=mid_measurements
)
final_state = state.get_final_state(circuit)
return LightningMeasurements(final_state, **mcmc).measure_final_state(circuit)
Expand Down
Loading