Skip to content

Commit

Permalink
Add MCM support in LQ (new device API) (#650)
Browse files Browse the repository at this point in the history
* remove old comment

* some review suggestions

* Auto update version

* remove print

* update state vector class

* add state vector class tests

* adding measurement tests

* update state vector and tests

* move and rename test files, and format

* Auto update version

* skip measurements class for other devices and in the absence of binaries

* format

* add LightningQubit2 to init and format

* update measurements class

* expand measurement class testing

* garbage collection

* typo

* update coverage and StateVector class

* expand measurements class coverage

* Auto update version

* add coverage for n-controlled operations

* add map to standard wires to get_final_state for safety

* update jax config import

* Auto update version

* trigger CI

* update state vector class and tests for improved coverage

* update measurement class tests

* update dev version

* add cpp binary available variable

* remove device definition

* update dev version

* Auto update version

* reduce dependency on DefaultQubit for tests

* update LightningQubit2

* clean test_measurements_class.py

* isort+black

* review suggestion

* fix docs

* Add qml.var support.

* Add probs support.

* increase tolerance

* Auto update version

* isort

* Add double-obs tests.

* Pin pytest version (#624)

* update dev version

* update changelog

* pin pytest version in requirement files

* add a requirements file for tests against Pennylane master

* update wheels' workflows

* Version Bump (#626)

* post release version bump

* trigger CI

---------

Co-authored-by: AmintorDusko <[email protected]>
Co-authored-by: AmintorDusko <[email protected]>

* increase tolerance

* Introduce isort. (#623)

* Introduce isort.

* Auto update version

* Update changelog

* Auto update version

* Update changelog.

* trigger ci

---------

Co-authored-by: Dev version update bot <github-actions[bot]@users.noreply.github.com>

* Auto update version

* isort

* Add qml.var support.

* Add probs support.

* Add measurement tests with wires.

* review suggestions

* remove unused imports

* Introduce _new_API and fix/skip few tests.

* Fix few more tests.

* Skip shots, adjoint, vjp with new API.

* remove diagonalization gate application from state vector

* pytest.skip tests

* Auto update version

* Fix format

* Fix no-bin interface.

* WIP

* Initial shots support + fix test_measurement tests.

* update

* adding tests from add-simulate branch

* merge conflicts

* create state vector on initialization

* remove import of modifier from lightning

* Update pennylane_lightning/lightning_qubit/lightning_qubit2.py

* minor test updates

* register with setup.py, state vector fixes

* add LightningQubit2 to init and format

* add cpp binary available variable

* reduce dependency on DefaultQubit for tests

* update LightningQubit2

* Fixing rebase artifacts

* Add fewLQ2 tests.

* remove adjoint diff support from supports derivatives

* Remove print from test_apply

* Add expval/var tests.

* Remove duplicate class data.

* Include LQ2 in linux ests.

* Add _group_measurements support.

* --cov-append

* Add mcmc capability + tests.

* Auto update version

* update dev version

* add LightningAdjointJacobian class

* add unit tests for the LightningAdjointJacobian class

* format

* add changelog for PR #613

* [skip ci] Added skeleton file for LQ2 unit tests

* update changelog

* update adjoint Jacobian

* Auto update version

* codefactor

* Add shots tests and fix bugs in LQ, LQ2.

* Lightning qubit2 upgrade api (#628)

* update

* adding tests from add-simulate branch

* merge conflicts

* create state vector on initialization

* remove import of modifier from lightning

* Update pennylane_lightning/lightning_qubit/lightning_qubit2.py

* minor test updates

* register with setup.py, state vector fixes

* add LightningQubit2 to init and format

* add cpp binary available variable

* Auto update version

* reduce dependency on DefaultQubit for tests

* update LightningQubit2

* Introduce _new_API and fix/skip few tests.

* Fix few more tests.

* Skip shots, adjoint, vjp with new API.

* Fix no-bin interface.

* Remove duplicate class data.

* Include LQ2 in linux ests.

* --cov-append

---------

Co-authored-by: albi3ro <[email protected]>
Co-authored-by: AmintorDusko <[email protected]>
Co-authored-by: Dev version update bot <github-actions[bot]@users.noreply.github.com>

* fix processing_fn_expval

* make a proper new_tape

* Added init tests; Added skeleton tests for helpers

* Fix more bug with shots.

* trigger CI

* Change pennylane branch for CI.

* Update .github/CHANGELOG.md

Co-authored-by: Vincent Michaud-Rioux <[email protected]>

* Update pennylane_lightning/lightning_qubit/_adjoint_jacobian.py

Co-authored-by: Vincent Michaud-Rioux <[email protected]>

* Update pennylane_lightning/lightning_qubit/_adjoint_jacobian.py

Co-authored-by: Vincent Michaud-Rioux <[email protected]>

* Add probs support.

* Add double-obs tests.

* Add qml.var support.

* Add probs support.

* Add measurement tests with wires.

* pytest.skip tests

* Fix format

* update

* adding tests from add-simulate branch

* merge conflicts

* create state vector on initialization

* remove import of modifier from lightning

* Update pennylane_lightning/lightning_qubit/lightning_qubit2.py

* minor test updates

* register with setup.py, state vector fixes

* add LightningQubit2 to init and format

* add cpp binary available variable

* reduce dependency on DefaultQubit for tests

* update LightningQubit2

* Fixing rebase artifacts

* remove adjoint diff support from supports derivatives

* [skip ci] Added skeleton file for LQ2 unit tests

* Lightning qubit2 upgrade api (#628)

* update

* adding tests from add-simulate branch

* merge conflicts

* create state vector on initialization

* remove import of modifier from lightning

* Update pennylane_lightning/lightning_qubit/lightning_qubit2.py

* minor test updates

* register with setup.py, state vector fixes

* add LightningQubit2 to init and format

* add cpp binary available variable

* Auto update version

* reduce dependency on DefaultQubit for tests

* update LightningQubit2

* Introduce _new_API and fix/skip few tests.

* Fix few more tests.

* Skip shots, adjoint, vjp with new API.

* Fix no-bin interface.

* Remove duplicate class data.

* Include LQ2 in linux ests.

* --cov-append

---------

Co-authored-by: albi3ro <[email protected]>
Co-authored-by: AmintorDusko <[email protected]>
Co-authored-by: Dev version update bot <github-actions[bot]@users.noreply.github.com>

* Added init tests; Added skeleton tests for helpers

* Resolving rebase artifacts

* Refactor shots test.

* Added tests; integrated jacobian

* Update pennylane_lightning/lightning_qubit/lightning_qubit2.py

Co-authored-by: Amintor Dusko <[email protected]>

* Auto update version

* Small update to simulate_and_jacobian

* Auto update version

* Rerun isort.

* Uncomment integration tests.

* Reformat

* Delete symlink

* Fix pylint.

* Run linux tests in parallel (when possible).

* Run double obs tests with shots.

* Revert linux tests

* Fix bg in diag_gates.

* Call isort/black with python -m

* update dev version

* Add docstrings, rm C_DTYPE.

* Auto update version

* comment isort check

* trigger ci

* Update tests/test_expval.py

Co-authored-by: Amintor Dusko <[email protected]>

* Init mcmc params to None in measurements.

* Reformat with python3.11

* Reformat black

* Auto update version

* update QuantumScriptSerializer

* remove LightningQubit2 from init

* update setup.py

* remove lightning.qubit2 from tests configuration

* remove extra tests for lightning.qubit2

* migrate lightning.qubit2 to lightning.qubit on tests

* make lightning.qubit2 the new lightning.qubit

* add device name (necessary for pl-device-test)

* Add _measure_hamiltonian_with_samples _measure_sum_with_samples

* fix tests without binary

* check for jac size before reshaping

* remove obsolete tests

* organize tests

* fix test for Windows wheels

* fix the fix for LightningKokkos

* Add MCM support initial work.

* Update changelog.

* Fix test_preprocess

* Try & parallelize pytest

* Increase timeout to 60 min.

* Limit OMP_NUM_THREADS.

* Do not parallelize Kokkos tests.

* Update tests/test_native_mcm.py

Co-authored-by: Amintor Dusko <[email protected]>

* Update tests/test_native_mcm.py

Co-authored-by: Amintor Dusko <[email protected]>

* Update tests/test_native_mcm.py

Co-authored-by: Amintor Dusko <[email protected]>

* Update tests/test_native_mcm.py

Co-authored-by: Amintor Dusko <[email protected]>

* Auto update version

* Update tests/test_native_mcm.py

Co-authored-by: Amintor Dusko <[email protected]>

* Update tests/test_native_mcm.py

Co-authored-by: Amintor Dusko <[email protected]>

* Update tests/test_native_mcm.py

Co-authored-by: Amintor Dusko <[email protected]>

* Update tests/test_native_mcm.py

Co-authored-by: Amintor Dusko <[email protected]>

* Update tests/test_native_mcm.py

Co-authored-by: Amintor Dusko <[email protected]>

* format

* Auto update version

* trigger ci

* Update Makefile

Co-authored-by: Ali Asadi <[email protected]>

* Update tests/test_native_mcm.py

Co-authored-by: Ali Asadi <[email protected]>

* Update tests/test_native_mcm.py

Co-authored-by: Ali Asadi <[email protected]>

* Fix docstrings.

* Auto update version

* Update format.yml

* Update format.yml

---------

Co-authored-by: Christina Lee <[email protected]>
Co-authored-by: AmintorDusko <[email protected]>
Co-authored-by: Amintor Dusko <[email protected]>
Co-authored-by: Dev version update bot <github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: AmintorDusko <[email protected]>
Co-authored-by: Mudit Pandey <[email protected]>
Co-authored-by: Ali Asadi <[email protected]>
  • Loading branch information
9 people authored Mar 22, 2024
1 parent 46fe8df commit 5f95a2f
Show file tree
Hide file tree
Showing 14 changed files with 443 additions and 49 deletions.
3 changes: 3 additions & 0 deletions .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

### New features since last release

* `lightning.qubit` supports mid-circuit measurements.
[(#650)](https://github.com/PennyLaneAI/pennylane-lightning/pull/650)

* Add finite shots support in `lightning.qubit2`.
[(#630)](https://github.com/PennyLaneAI/pennylane-lightning/pull/630)

Expand Down
8 changes: 4 additions & 4 deletions .github/workflows/format.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ jobs:
with:
python-version: '3.11'

- name: Install dependencies
run:
python -m pip install click==8.0.4 black==23.7.0 isort==5.13.2

- name: Checkout PennyLane-Lightning
uses: actions/checkout@v3

- name: Install dependencies
run:
python -m pip install -r requirements-dev.txt

- name: Run isort & black --check
run: make format-python check=1 verbose=1

Expand Down
27 changes: 15 additions & 12 deletions .github/workflows/tests_linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
matrix:
os: [ubuntu-22.04]
pl_backend: ["lightning_qubit"]
timeout-minutes: 30
timeout-minutes: 60
name: C++ tests
runs-on: ${{ matrix.os }}

Expand Down Expand Up @@ -106,7 +106,7 @@ jobs:
matrix:
os: [ubuntu-22.04]
pl_backend: ["lightning_qubit"]
timeout-minutes: 30
timeout-minutes: 60
name: Python tests
runs-on: ${{ matrix.os }}

Expand Down Expand Up @@ -184,7 +184,8 @@ jobs:
run: |
cd main/
DEVICENAME=`echo ${{ matrix.pl_backend }} | sed "s/_/./g"`
PL_DEVICE=${DEVICENAME} python -m pytest tests/ $COVERAGE_FLAGS
OMP_NUM_THREADS=1 PL_DEVICE=${DEVICENAME} python -m pytest -n auto tests/ -k "not unitary_correct" $COVERAGE_FLAGS
PL_DEVICE=${DEVICENAME} python -m pytest tests/ -k "unitary_correct" $COVERAGE_FLAGS --cov-append
pl-device-test --device ${DEVICENAME} --skip-ops --shots=20000 $COVERAGE_FLAGS --cov-append
pl-device-test --device ${DEVICENAME} --shots=None --skip-ops $COVERAGE_FLAGS --cov-append
mv .coverage .coverage-${{ github.job }}-${{ matrix.pl_backend }}
Expand All @@ -203,7 +204,7 @@ jobs:
matrix:
os: [ubuntu-22.04]
pl_backend: ["lightning_qubit"]
timeout-minutes: 30
timeout-minutes: 60
name: C++ tests (OpenBLAS)
runs-on: ${{ matrix.os }}

Expand Down Expand Up @@ -259,7 +260,7 @@ jobs:
matrix:
os: [ubuntu-22.04]
pl_backend: ["lightning_qubit"]
timeout-minutes: 30
timeout-minutes: 60
name: C++ tests (OpenBLAS without LAPACK)
runs-on: ${{ matrix.os }}

Expand Down Expand Up @@ -314,7 +315,7 @@ jobs:
matrix:
os: [ubuntu-22.04]
pl_backend: ["lightning_qubit"]
timeout-minutes: 30
timeout-minutes: 60
name: Python tests with OpenBLAS
runs-on: ${{ matrix.os }}

Expand Down Expand Up @@ -392,7 +393,8 @@ jobs:
run: |
cd main/
DEVICENAME=`echo ${{ matrix.pl_backend }} | sed "s/_/./g"`
PL_DEVICE=${DEVICENAME} python -m pytest tests/ $COVERAGE_FLAGS
OMP_NUM_THREADS=1 PL_DEVICE=${DEVICENAME} python -m pytest -n auto tests/ -k "not unitary_correct" $COVERAGE_FLAGS
PL_DEVICE=${DEVICENAME} python -m pytest tests/ -k "unitary_correct" $COVERAGE_FLAGS --cov-append
pl-device-test --device ${DEVICENAME} --skip-ops --shots=20000 $COVERAGE_FLAGS --cov-append
pl-device-test --device ${DEVICENAME} --shots=None --skip-ops $COVERAGE_FLAGS --cov-append
mv .coverage .coverage-${{ github.job }}-${{ matrix.pl_backend }}
Expand All @@ -419,7 +421,7 @@ jobs:
pl_backend: ["lightning_kokkos"]
exec_model: ${{ fromJson(needs.build_and_cache_Kokkos.outputs.exec_model) }}
kokkos_version: ${{ fromJson(needs.build_and_cache_Kokkos.outputs.kokkos_version) }}
timeout-minutes: 30
timeout-minutes: 60
name: C++ tests (Kokkos)
runs-on: ${{ matrix.os }}

Expand Down Expand Up @@ -494,7 +496,7 @@ jobs:
exclude:
- pl_backend: ["all"]
exec_model: OPENMP
timeout-minutes: 30
timeout-minutes: 60
name: Python tests with Kokkos
runs-on: ${{ matrix.os }}

Expand Down Expand Up @@ -609,10 +611,11 @@ jobs:
if: ${{ matrix.pl_backend == 'all' }}
run: |
cd main/
PL_DEVICE=lightning.qubit python -m pytest tests/ $COVERAGE_FLAGS
OMP_NUM_THREADS=1 PL_DEVICE=lightning.qubit python -m pytest -n auto tests/ -k "not unitary_correct" $COVERAGE_FLAGS
PL_DEVICE=lightning.qubit python -m pytest tests/ -k "unitary_correct" $COVERAGE_FLAGS --cov-append
pl-device-test --device lightning.qubit --skip-ops --shots=20000 $COVERAGE_FLAGS --cov-append
pl-device-test --device lightning.qubit --shots=None --skip-ops $COVERAGE_FLAGS --cov-append
PL_DEVICE=lightning.kokkos python -m pytest tests/ $COVERAGE_FLAGS
PL_DEVICE=lightning.kokkos python -m pytest tests/ $COVERAGE_FLAGS --cov-append
pl-device-test --device lightning.kokkos --skip-ops --shots=20000 $COVERAGE_FLAGS --cov-append
pl-device-test --device lightning.kokkos --shots=None --skip-ops $COVERAGE_FLAGS --cov-append
mv .coverage .coverage-${{ github.job }}-${{ matrix.pl_backend }}
Expand Down Expand Up @@ -682,7 +685,7 @@ jobs:
os: [ubuntu-22.04]
exec_model: ${{ fromJson(needs.build_and_cache_Kokkos.outputs.exec_model) }}
kokkos_version: ${{ fromJson(needs.build_and_cache_Kokkos.outputs.kokkos_version) }}
timeout-minutes: 30
timeout-minutes: 60
name: C++ tests (multiple backends)
runs-on: ${{ matrix.os }}

Expand Down
8 changes: 2 additions & 6 deletions mpitests/test_adjoint_jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -1396,9 +1396,7 @@ def test_qubit_unitary(dev, n_targets):

np.random.seed(1337)
par = 2 * np.pi * np.random.rand(n_wires)
U = np.random.rand(2**n_targets, 2**n_targets) + 1j * np.random.rand(
2**n_targets, 2**n_targets
)
U = np.random.rand(2**n_targets, 2**n_targets) + 1j * np.random.rand(2**n_targets, 2**n_targets)
U, _ = np.linalg.qr(U)
init_state = np.random.rand(2**n_wires) + 1j * np.random.rand(2**n_wires)
init_state /= np.sqrt(np.dot(np.conj(init_state), init_state))
Expand Down Expand Up @@ -1446,9 +1444,7 @@ def test_diff_qubit_unitary(dev, n_targets):

np.random.seed(1337)
par = 2 * np.pi * np.random.rand(n_wires)
U = np.random.rand(2**n_targets, 2**n_targets) + 1j * np.random.rand(
2**n_targets, 2**n_targets
)
U = np.random.rand(2**n_targets, 2**n_targets) + 1j * np.random.rand(2**n_targets, 2**n_targets)
U, _ = np.linalg.qr(U)
init_state = np.random.rand(2**n_wires) + 1j * np.random.rand(2**n_wires)
init_state /= np.sqrt(np.dot(np.conj(init_state), init_state))
Expand Down
4 changes: 1 addition & 3 deletions pennylane_lightning/core/_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,9 +298,7 @@ def serialize_observables(self, tape: QuantumTape, wires_map: dict = None) -> Li
offset_indices.append(offset_indices[-1] + 1)
return serialized_obs, offset_indices

def serialize_ops(
self, tape: QuantumTape, wires_map: dict = None
) -> Tuple[
def serialize_ops(self, tape: QuantumTape, wires_map: dict = None) -> Tuple[
List[List[str]],
List[np.ndarray],
List[List[int]],
Expand Down
2 changes: 1 addition & 1 deletion pennylane_lightning/core/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@
Version number (major.minor.patch[-label])
"""

__version__ = "0.36.0-dev16"
__version__ = "0.36.0-dev17"
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,10 @@ void registerBackendClassSpecificBindings(PyClass &pyclass) {
}
},
"Copy StateVector data into a Numpy array.")
.def("collapse", &StateVectorT::collapse,
"Collapse the statevector onto the 0 or 1 branch of a given wire.")
.def("normalize", &StateVectorT::normalize,
"Normalizes the statevector to norm 1.")
.def("applyControlledMatrix", &applyControlledMatrix<StateVectorT>,
"Apply controlled operation")
.def("kernel_map", &svKernelMap<StateVectorT>,
Expand Down
4 changes: 1 addition & 3 deletions pennylane_lightning/lightning_qubit/_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@

from pennylane_lightning.core._serialize import QuantumScriptSerializer

from ._state_vector import LightningStateVector


class LightningMeasurements:
"""Lightning Measurements class
Expand All @@ -69,7 +67,7 @@ class LightningMeasurements:

def __init__(
self,
qubit_state: LightningStateVector,
qubit_state,
mcmc: bool = None,
kernel_name: str = None,
num_burnin: int = None,
Expand Down
50 changes: 41 additions & 9 deletions pennylane_lightning/lightning_qubit/_state_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,14 @@
import numpy as np
import pennylane as qml
from pennylane import BasisState, DeviceError, StatePrep
from pennylane.measurements import MidMeasureMP
from pennylane.ops import Conditional
from pennylane.ops.op_math import Adjoint
from pennylane.tape import QuantumScript
from pennylane.wires import Wires

from ._measurements import LightningMeasurements


class LightningStateVector:
"""Lightning state-vector class.
Expand Down Expand Up @@ -220,10 +224,10 @@ def _apply_lightning_controlled(self, operation):
"""Apply an arbitrary controlled operation to the state tensor.
Args:
operation (~pennylane.operation.Operation): operation to apply
operation (~pennylane.operation.Operation): controlled operation to apply
Returns:
array[complex]: the output state tensor
None
"""
state = self.state_vector

Expand All @@ -246,14 +250,36 @@ def _apply_lightning_controlled(self, operation):
False,
)

def _apply_lightning(self, operations):
def _apply_lightning_midmeasure(self, operation: MidMeasureMP, mid_measurements: dict):
"""Execute a MidMeasureMP operation and return the sample in mid_measurements.
Args:
operation (~pennylane.operation.Operation): mid-circuit measurement
mid_measurements (None, dict): Dictionary of mid-circuit measurements
Returns:
None
"""
wires = self.wires.indices(operation.wires)
wire = list(wires)[0]
circuit = QuantumScript([], [qml.sample(wires=operation.wires)], shots=1)
sample = LightningMeasurements(self).measure_final_state(circuit)
sample = np.squeeze(sample)
if operation.postselect is not None and sample != operation.postselect:
mid_measurements[operation] = -1
return
mid_measurements[operation] = sample
getattr(self.state_vector, "collapse")(wire, bool(sample))
if operation.reset and bool(sample):
self.apply_operations([qml.PauliX(operation.wires)], mid_measurements=mid_measurements)

def _apply_lightning(self, operations, mid_measurements: dict = None):
"""Apply a list of operations to the state tensor.
Args:
operations (list[~pennylane.operation.Operation]): operations to apply
mid_measurements (None, dict): Dictionary of mid-circuit measurements
Returns:
array[complex]: the output state tensor
None
"""
state = self.state_vector

Expand All @@ -271,7 +297,12 @@ def _apply_lightning(self, operations):
method = getattr(state, name, None)
wires = list(operation.wires)

if method is not None: # apply specialized gate
if isinstance(operation, Conditional):
if operation.meas_val.concretize(mid_measurements):
self._apply_lightning([operation.then_op])
elif isinstance(operation, MidMeasureMP):
self._apply_lightning_midmeasure(operation, mid_measurements)
elif method is not None: # apply specialized gate
param = operation.parameters
method(wires, invert_param, param)
elif isinstance(operation, qml.ops.Controlled): # apply n-controlled gate
Expand All @@ -286,7 +317,7 @@ def _apply_lightning(self, operations):
# To support older versions of PL
method(operation.matrix, wires, False)

def apply_operations(self, operations):
def apply_operations(self, operations, mid_measurements: dict = None):
"""Applies operations to the state vector."""
# State preparation is currently done in Python
if operations: # make sure operations[0] exists
Expand All @@ -297,21 +328,22 @@ def apply_operations(self, operations):
self._apply_basis_state(operations[0].parameters[0], operations[0].wires)
operations = operations[1:]

self._apply_lightning(operations)
self._apply_lightning(operations, mid_measurements=mid_measurements)

def get_final_state(self, circuit: QuantumScript):
def get_final_state(self, circuit: QuantumScript, mid_measurements: dict = None):
"""
Get the final state that results from executing the given quantum script.
This is an internal function that will be called by the successor to ``lightning.qubit``.
Args:
circuit (QuantumScript): The single circuit to simulate
mid_measurements (None, dict): Dictionary of mid-circuit measurements
Returns:
LightningStateVector: Lightning final state class.
"""
self.apply_operations(circuit.operations)
self.apply_operations(circuit.operations, mid_measurements=mid_measurements)

return self
16 changes: 15 additions & 1 deletion pennylane_lightning/lightning_qubit/lightning_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@
from pennylane.devices.modifiers import simulator_tracking, single_tape_support
from pennylane.devices.preprocess import (
decompose,
mid_circuit_measurements,
no_sampling,
validate_adjoint_trainable_params,
validate_device_wires,
validate_measurements,
validate_observables,
)
from pennylane.measurements import MidMeasureMP
from pennylane.tape import QuantumScript, QuantumTape
from pennylane.transforms.core import TransformProgram
from pennylane.typing import Result, ResultBatch
Expand Down Expand Up @@ -71,6 +73,16 @@ def simulate(circuit: QuantumScript, state: LightningStateVector, mcmc: dict = N
if mcmc is None:
mcmc = {}
state.reset_state()
has_mcm = any(isinstance(op, MidMeasureMP) for op in circuit.operations)
if circuit.shots and has_mcm:
mid_measurements = {}
final_state = state.get_final_state(circuit, mid_measurements=mid_measurements)
if any(v == -1 for v in mid_measurements.values()):
return None, mid_measurements
return (
LightningMeasurements(final_state, **mcmc).measure_final_state(circuit),
mid_measurements,
)
final_state = state.get_final_state(circuit)
return LightningMeasurements(final_state, **mcmc).measure_final_state(circuit)

Expand Down Expand Up @@ -200,6 +212,8 @@ def simulate_and_jacobian(circuit: QuantumTape, state: LightningStateVector, bat
"QFT",
"ECR",
"BlockEncode",
"MidMeasureMP",
"Conditional",
}
)
# The set of supported operations.
Expand Down Expand Up @@ -432,7 +446,7 @@ def preprocess(self, execution_config: ExecutionConfig = DefaultExecutionConfig)
program.add_transform(validate_measurements, name=self.name)
program.add_transform(validate_observables, accepted_observables, name=self.name)
program.add_transform(validate_device_wires, self.wires, name=self.name)
program.add_transform(qml.defer_measurements, device=self)
program.add_transform(mid_circuit_measurements, device=self)
program.add_transform(decompose, stopping_condition=stopping_condition, name=self.name)
program.add_transform(qml.transforms.broadcast_expand)

Expand Down
6 changes: 3 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def validate_samples(shots, results1, results2):
np.allclose(np.sum(results1), np.sum(results2), rtol=20, atol=0.2)


def validate_expval(shots, results1, results2):
def validate_others(shots, results1, results2):
"""Compares two expval, probs or var.
If the results are ``Sequence``s, validate the average of items.
Expand All @@ -210,7 +210,7 @@ def validate_expval(shots, results1, results2):
assert len(results1) == len(results2)
results1 = reduce(lambda x, y: x + y, results1) / len(results1)
results2 = reduce(lambda x, y: x + y, results2) / len(results2)
validate_expval(shots, results1, results2)
validate_others(shots, results1, results2)
return
if shots is None:
assert np.allclose(results1, results2)
Expand All @@ -228,4 +228,4 @@ def validate_measurements(func, shots, results1, results2):
validate_samples(shots, results1, results2)
return

validate_expval(shots, results1, results2)
validate_others(shots, results1, results2)
Loading

0 comments on commit 5f95a2f

Please sign in to comment.