diff --git a/.github/CHANGELOG.md b/.github/CHANGELOG.md
index deffd4fb50..027277e16c 100644
--- a/.github/CHANGELOG.md
+++ b/.github/CHANGELOG.md
@@ -15,6 +15,9 @@
 * Optimize gate cache recording for `lightning.tensor` C++ layer.
   [(#879)](https://github.com/PennyLaneAI/pennylane-lightning/pull/879)
 
+* Smarter defaults for the `split_obs` argument in the serializer. The serializer splits linear combinations into chunks instead of all their terms.
+  [(#873)](https://github.com/PennyLaneAI/pennylane-lightning/pull/873/)
+
 ### Documentation
 
 ### Bug fixes
@@ -23,7 +26,7 @@
 
 This release contains contributions from (in alphabetical order):
 
-Amintor Dusko, Luis Alfredo Nuñez Meneses, Shuli Shu
+Amintor Dusko, Luis Alfredo Nuñez Meneses, Vincent Michaud-Rioux, Shuli Shu
 
 ---
 
diff --git a/.github/workflows/tests_lgpu_cpp.yml b/.github/workflows/tests_lgpu_cpp.yml
index cba7cc9c35..d828c23107 100644
--- a/.github/workflows/tests_lgpu_cpp.yml
+++ b/.github/workflows/tests_lgpu_cpp.yml
@@ -68,7 +68,7 @@ jobs:
     strategy:
       matrix:
         pl_backend: ["lightning_gpu"]
-        enable_lapack: ["OFF", "ON"]
+        enable_lapack: ["OFF"]
         cuda_version: ["12"]
 
     name: C++ Tests (${{ matrix.pl_backend }}, cuda-${{ matrix.cuda_version }}, enable_lapack=${{ matrix.enable_lapack }})
diff --git a/pennylane_lightning/core/_serialize.py b/pennylane_lightning/core/_serialize.py
index 1ec5e5ddd1..edc60b667c 100644
--- a/pennylane_lightning/core/_serialize.py
+++ b/pennylane_lightning/core/_serialize.py
@@ -54,7 +54,7 @@ class QuantumScriptSerializer:
         device_name: device shortname.
         use_csingle (bool): whether to use np.complex64 instead of np.complex128
         use_mpi (bool, optional): If using MPI to accelerate calculation. Defaults to False.
-        split_obs (bool, optional): If splitting the observables in a list. Defaults to False.
+        split_obs (Union[bool, int], optional): Whether to split the observables into a list. If an int, the number of chunks Hamiltonian terms are split into. Defaults to False.
 
     """
 
@@ -214,17 +214,40 @@ def _tensor_ob(self, observable, wires_map: dict = None):
         obs = observable.obs if isinstance(observable, Tensor) else observable.operands
         return self.tensor_obs([self._ob(o, wires_map) for o in obs])
 
+    def _chunk_ham_terms(self, coeffs, ops, split_num: int = 1) -> List:
+        "Create split_num sub-Hamiltonians from a single high term-count Hamiltonian"
+        num_terms = len(coeffs)
+        iperm = np.argsort(np.array([len(op.get_wires()) for op in ops]))
+        coeffs = [coeffs[i] for i in iperm]
+        ops = [ops[i] for i in iperm]
+        c_coeffs = [
+            tuple(coeffs[slice(i, num_terms, split_num)]) for i in range(min(num_terms, split_num))
+        ]
+        c_ops = [
+            tuple(ops[slice(i, num_terms, split_num)]) for i in range(min(num_terms, split_num))
+        ]
+        return c_coeffs, c_ops
+
     def _hamiltonian(self, observable, wires_map: dict = None):
         coeffs, ops = observable.terms()
         coeffs = np.array(unwrap(coeffs)).astype(self.rtype)
+        if self.split_obs:
+            ops_l = []
+            for t in ops:
+                term_cpp = self._ob(t, wires_map)
+                if isinstance(term_cpp, Sequence):
+                    ops_l.extend(term_cpp)
+                else:
+                    ops_l.append(term_cpp)
+            c, o = self._chunk_ham_terms(coeffs, ops_l, self.split_obs)
+            hams = [self.hamiltonian_obs(c_coeffs, c_obs) for (c_coeffs, c_obs) in zip(c, o)]
+            return hams
+
         terms = [self._ob(t, wires_map) for t in ops]
         # TODO: This is in case `_hamiltonian` is called recursively which would cause a list
         # to be passed where `_ob` expects an observable.
         terms = [t[0] if isinstance(t, Sequence) and len(t) == 1 else t for t in terms]
 
-        if self.split_obs:
-            return [self.hamiltonian_obs([c], [t]) for (c, t) in zip(coeffs, terms)]
-
         return self.hamiltonian_obs(coeffs, terms)
 
     def _sparse_hamiltonian(self, observable, wires_map: dict = None):
@@ -282,11 +305,14 @@ def _pauli_sentence(self, observable, wires_map: dict = None):
         terms = [self._pauli_word(pw, wires_map) for pw in pwords]
         coeffs = np.array(coeffs).astype(self.rtype)
 
+        if self.split_obs:
+            c, o = self._chunk_ham_terms(coeffs, terms, self.split_obs)
+            psentences = [self.hamiltonian_obs(c_coeffs, c_obs) for (c_coeffs, c_obs) in zip(c, o)]
+            return psentences
+
         if len(terms) == 1 and coeffs[0] == 1.0:
             return terms[0]
 
-        if self.split_obs:
-            return [self.hamiltonian_obs([c], [t]) for (c, t) in zip(coeffs, terms)]
         return self.hamiltonian_obs(coeffs, terms)
 
     # pylint: disable=protected-access, too-many-return-statements
@@ -326,17 +352,17 @@ def serialize_observables(self, tape: QuantumTape, wires_map: dict = None) -> Li
         """
         serialized_obs = []
-        offset_indices = [0]
+        obs_indices = []
 
-        for observable in tape.observables:
+        for i, observable in enumerate(tape.observables):
             ser_ob = self._ob(observable, wires_map)
             if isinstance(ser_ob, list):
                 serialized_obs.extend(ser_ob)
-                offset_indices.append(offset_indices[-1] + len(ser_ob))
+                obs_indices.extend([i] * len(ser_ob))
             else:
                 serialized_obs.append(ser_ob)
-                offset_indices.append(offset_indices[-1] + 1)
-        return serialized_obs, offset_indices
+                obs_indices.append(i)
+        return serialized_obs, obs_indices
 
     def serialize_ops(self, tape: QuantumTape, wires_map: dict = None) -> Tuple[
         List[List[str]],
diff --git a/pennylane_lightning/core/_version.py b/pennylane_lightning/core/_version.py
index 6eae0c1c89..c4edd41392 100644
--- a/pennylane_lightning/core/_version.py
+++ b/pennylane_lightning/core/_version.py
@@ -16,4 +16,4 @@
 Version number (major.minor.patch[-label])
 """
 
-__version__ = "0.39.0-dev3"
+__version__ = "0.39.0-dev4"
diff --git a/pennylane_lightning/core/lightning_base.py b/pennylane_lightning/core/lightning_base.py
index 66a632fc73..76b56369ed 100644
--- a/pennylane_lightning/core/lightning_base.py
+++ b/pennylane_lightning/core/lightning_base.py
@@ -16,8 +16,8 @@
 This module contains the base class for all PennyLane Lightning simulator devices,
 and interfaces with C++ for improved performance.
 """
-from itertools import islice, product
-from typing import List
+from itertools import product
+from typing import List, Union
 
 import numpy as np
 import pennylane as qml
@@ -31,12 +31,6 @@
 
 from ._version import __version__
 
 
-def _chunk_iterable(iteration, num_chunks):
-    "Lazy-evaluated chunking of given iterable from https://stackoverflow.com/a/22045226"
-    iteration = iter(iteration)
-    return iter(lambda: tuple(islice(iteration, num_chunks)), ())
-
-
 class LightningBase(QubitDevice):
     """PennyLane Lightning Base device.
@@ -262,11 +256,16 @@ def _get_basis_state_index(self, state, wires):
 
     # pylint: disable=too-many-function-args, assignment-from-no-return, too-many-arguments
     def _process_jacobian_tape(
-        self, tape, starting_state, use_device_state, use_mpi: bool = False, split_obs: bool = False
+        self,
+        tape,
+        starting_state,
+        use_device_state,
+        use_mpi: bool = False,
+        split_obs: Union[bool, int] = False,
     ):
         state_vector = self._init_process_jacobian_tape(tape, starting_state, use_device_state)
 
-        obs_serialized, obs_idx_offsets = QuantumScriptSerializer(
+        obs_serialized, obs_indices = QuantumScriptSerializer(
             self.short_name, self.use_csingle, use_mpi, split_obs
         ).serialize_observables(tape, self.wire_map)
 
@@ -309,7 +308,7 @@ def _process_jacobian_tape(
             "tp_shift": tp_shift,
             "record_tp_rows": record_tp_rows,
             "all_params": all_params,
-            "obs_idx_offsets": obs_idx_offsets,
+            "obs_indices": obs_indices,
         }
 
     @staticmethod
diff --git a/pennylane_lightning/core/src/simulators/lightning_gpu/algorithms/AdjointJacobianGPU.hpp b/pennylane_lightning/core/src/simulators/lightning_gpu/algorithms/AdjointJacobianGPU.hpp
index 79f5ee40d4..976bdc71b3 100644
--- a/pennylane_lightning/core/src/simulators/lightning_gpu/algorithms/AdjointJacobianGPU.hpp
+++ b/pennylane_lightning/core/src/simulators/lightning_gpu/algorithms/AdjointJacobianGPU.hpp
@@ -288,6 +288,7 @@ class AdjointJacobian final
                 H_lambda.emplace_back(lambda.getNumQubits(), dt_local, true,
                                       cusvhandle, cublascaller, cusparsehandle);
             }
+
         BaseType::applyObservables(H_lambda, lambda, obs);
 
         StateVectorT mu(lambda.getNumQubits(), dt_local, true, cusvhandle,
diff --git a/pennylane_lightning/core/src/simulators/lightning_gpu/observables/ObservablesGPU.hpp b/pennylane_lightning/core/src/simulators/lightning_gpu/observables/ObservablesGPU.hpp
index 60ffd6ac4c..102ed27988 100644
--- a/pennylane_lightning/core/src/simulators/lightning_gpu/observables/ObservablesGPU.hpp
+++ b/pennylane_lightning/core/src/simulators/lightning_gpu/observables/ObservablesGPU.hpp
@@ -194,10 +194,11 @@ class Hamiltonian final : public HamiltonianBase<StateVectorT> {
             std::make_unique<DataBuffer<CFP_t>>(sv.getDataBuffer().getLength(),
                                                 sv.getDataBuffer().getDevTag());
         buffer->zeroInit();
+        StateVectorT tmp(sv);
 
         for (std::size_t term_idx = 0; term_idx < this->coeffs_.size();
              term_idx++) {
-            StateVectorT tmp(sv);
+            tmp.updateData(sv);
             this->obs_[term_idx]->applyInPlace(tmp);
             scaleAndAddC_CUDA(
                 std::complex<PrecisionT>{this->coeffs_[term_idx], 0.0},
diff --git a/pennylane_lightning/core/src/simulators/lightning_kokkos/observables/ObservablesKokkos.hpp b/pennylane_lightning/core/src/simulators/lightning_kokkos/observables/ObservablesKokkos.hpp
index 8d35653918..8086de9a45 100644
--- a/pennylane_lightning/core/src/simulators/lightning_kokkos/observables/ObservablesKokkos.hpp
+++ b/pennylane_lightning/core/src/simulators/lightning_kokkos/observables/ObservablesKokkos.hpp
@@ -189,9 +189,10 @@ class Hamiltonian final : public HamiltonianBase<StateVectorT> {
     void applyInPlace(StateVectorT &sv) const override {
         StateVectorT buffer{sv.getNumQubits()};
         buffer.initZeros();
+        StateVectorT tmp{sv};
         for (std::size_t term_idx = 0; term_idx < this->coeffs_.size();
              term_idx++) {
-            StateVectorT tmp{sv};
+            tmp.updateData(sv.getView());
             this->obs_[term_idx]->applyInPlace(tmp);
             LightningKokkos::Util::axpy_Kokkos(
                 ComplexT{this->coeffs_[term_idx], 0.0}, tmp.getView(),
diff --git a/pennylane_lightning/core/src/simulators/lightning_qubit/StateVectorLQubitManaged.hpp b/pennylane_lightning/core/src/simulators/lightning_qubit/StateVectorLQubitManaged.hpp
index 618c21a732..00b7c5d8c7 100644
--- a/pennylane_lightning/core/src/simulators/lightning_qubit/StateVectorLQubitManaged.hpp
+++ b/pennylane_lightning/core/src/simulators/lightning_qubit/StateVectorLQubitManaged.hpp
@@ -57,6 +57,7 @@ class StateVectorLQubitManaged final
     using ComplexT = std::complex<PrecisionT>;
     using CFP_t = ComplexT;
     using MemoryStorageT = Pennylane::Util::MemoryStorageLocation::Internal;
+    using StateVectorT = StateVectorLQubitManaged<PrecisionT>;
 
   private:
     using BaseType =
diff --git a/pennylane_lightning/core/src/simulators/lightning_qubit/observables/ObservablesLQubit.hpp b/pennylane_lightning/core/src/simulators/lightning_qubit/observables/ObservablesLQubit.hpp
index 1bc1d6156d..178ffeef24 100644
--- a/pennylane_lightning/core/src/simulators/lightning_qubit/observables/ObservablesLQubit.hpp
+++ b/pennylane_lightning/core/src/simulators/lightning_qubit/observables/ObservablesLQubit.hpp
@@ -165,9 +165,10 @@ template <class StateVectorT, bool use_openmp> struct HamiltonianApplyInPlace {
         auto allocator = sv.allocator();
         std::vector<ComplexT, decltype(allocator)> res(
             sv.getLength(), ComplexT{0.0, 0.0}, allocator);
+        StateVectorT tmp(sv);
 
         for (std::size_t term_idx = 0; term_idx < coeffs.size(); term_idx++) {
-            StateVectorT tmp(sv);
+            tmp.updateData(sv.getDataVector());
             terms[term_idx]->applyInPlace(tmp);
             scaleAndAdd(tmp.getLength(), ComplexT{coeffs[term_idx], 0.0},
                         tmp.getData(), res.data());
diff --git a/pennylane_lightning/lightning_gpu/lightning_gpu.py b/pennylane_lightning/lightning_gpu/lightning_gpu.py
index 117e9840de..0e55c5c9e1 100644
--- a/pennylane_lightning/lightning_gpu/lightning_gpu.py
+++ b/pennylane_lightning/lightning_gpu/lightning_gpu.py
@@ -30,6 +30,7 @@
 from pennylane.measurements import Expectation, State
 from pennylane.ops.op_math import Adjoint
 from pennylane.wires import Wires
+from scipy.sparse import csr_matrix
 
 from pennylane_lightning.core._serialize import QuantumScriptSerializer, global_phase_diagonal
 from pennylane_lightning.core._version import __version__
@@ -633,8 +634,12 @@ def adjoint_jacobian(self, tape, starting_state=None, use_device_state=False):
         # Check adjoint diff support
         self._check_adjdiff_supported_operations(tape.operations)
 
+        if self._mpi:
+            split_obs = False  # with MPI, batched means computing the Jacobian one observable at a time, so there is no point splitting linear combinations
+        else:
+            split_obs = self._dp.getTotalDevices() if self._batch_obs else False
         processed_data = self._process_jacobian_tape(
-            tape, starting_state, use_device_state, self._mpi, self._batch_obs
+            tape, starting_state, use_device_state, self._mpi, split_obs
         )
 
         if not processed_data:  # training_params is empty
@@ -653,31 +658,12 @@ def adjoint_jacobian(self, tape, starting_state=None, use_device_state=False):
         adjoint_jacobian = _adj_dtype(self.use_csingle, self._mpi)()
 
         if self._batch_obs:  # Batching of Measurements
-            if not self._mpi:  # Single-node path, controlled batching over available GPUs
-                num_obs = len(processed_data["obs_serialized"])
-                batch_size = (
-                    num_obs
-                    if isinstance(self._batch_obs, bool)
-                    else self._batch_obs * self._dp.getTotalDevices()
-                )
-                jac = []
-                for chunk in range(0, num_obs, batch_size):
-                    obs_chunk = processed_data["obs_serialized"][chunk : chunk + batch_size]
-                    jac_chunk = adjoint_jacobian.batched(
-                        self._gpu_state,
-                        obs_chunk,
-                        processed_data["ops_serialized"],
-                        trainable_params,
-                    )
-                    jac.extend(jac_chunk)
-            else:  # MPI path, restrict memory per known GPUs
-                jac = adjoint_jacobian.batched(
-                    self._gpu_state,
-                    processed_data["obs_serialized"],
-                    processed_data["ops_serialized"],
-                    trainable_params,
-                )
-
+            jac = adjoint_jacobian.batched(
+                self._gpu_state,
+                processed_data["obs_serialized"],
+                processed_data["ops_serialized"],
+                trainable_params,
+            )
         else:
             jac = adjoint_jacobian(
                 self._gpu_state,
@@ -685,36 +671,18 @@ def adjoint_jacobian(self, tape, starting_state=None, use_device_state=False):
                 processed_data["ops_serialized"],
                 trainable_params,
             )
-
-        jac = np.array(jac)  # only for parameters differentiable with the adjoint method
-        jac = jac.reshape(-1, len(trainable_params))
-        jac_r = np.zeros((len(tape.observables), processed_data["all_params"]))
-        if not self._batch_obs:
-            jac_r[:, processed_data["record_tp_rows"]] = jac
-        else:
-            # Reduce over decomposed expval(H), if required.
-            for idx in range(len(processed_data["obs_idx_offsets"][0:-1])):
-                if (
-                    processed_data["obs_idx_offsets"][idx + 1]
-                    - processed_data["obs_idx_offsets"][idx]
-                ) > 1:
-                    jac_r[idx, :] = np.sum(
-                        jac[
-                            processed_data["obs_idx_offsets"][idx] : processed_data[
-                                "obs_idx_offsets"
-                            ][idx + 1],
-                            :,
-                        ],
-                        axis=0,
-                    )
-                else:
-                    jac_r[idx, :] = jac[
-                        processed_data["obs_idx_offsets"][idx] : processed_data["obs_idx_offsets"][
-                            idx + 1
-                        ],
-                        :,
-                    ]
-
+        jac = np.array(jac)
+        has_shape0 = bool(len(jac))
+
+        num_obs = len(np.unique(processed_data["obs_indices"]))
+        rows = processed_data["obs_indices"]
+        cols = np.arange(len(rows), dtype=int)
+        data = np.ones(len(rows))
+        red_mat = csr_matrix((data, (rows, cols)), shape=(num_obs, len(rows)))
+        jac = red_mat @ jac.reshape((len(rows), -1))
+        jac = jac.reshape(-1, len(trainable_params)) if has_shape0 else jac
+        jac_r = np.zeros((jac.shape[0], processed_data["all_params"]))
+        jac_r[:, processed_data["record_tp_rows"]] = jac
         return self._adjoint_jacobian_processing(jac_r)
 
     # pylint: disable=inconsistent-return-statements, line-too-long, missing-function-docstring
diff --git a/pennylane_lightning/lightning_kokkos/_adjoint_jacobian.py b/pennylane_lightning/lightning_kokkos/_adjoint_jacobian.py
index bc79ade355..f9e9d9ad2b 100644
--- a/pennylane_lightning/lightning_kokkos/_adjoint_jacobian.py
+++ b/pennylane_lightning/lightning_kokkos/_adjoint_jacobian.py
@@ -37,7 +37,6 @@
 
 # pylint: disable=import-error, no-name-in-module, ungrouped-imports
 from pennylane_lightning.core._serialize import QuantumScriptSerializer
-from pennylane_lightning.core.lightning_base import _chunk_iterable
 
 from ._state_vector import LightningKokkosStateVector
 
@@ -243,30 +242,12 @@ def calculate_jacobian(self, tape: QuantumTape):
             return np.array([], dtype=self._dtype)
 
         trainable_params = processed_data["tp_shift"]
-
-        # If requested batching over observables, chunk into OMP_NUM_THREADS sized chunks.
-        # This will allow use of Lightning with adjoint for large-qubit numbers AND large
-        # numbers of observables, enabling choice between compute time and memory use.
-        requested_threads = int(getenv("OMP_NUM_THREADS", "1"))
-
-        if self._batch_obs and requested_threads > 1:
-            obs_partitions = _chunk_iterable(processed_data["obs_serialized"], requested_threads)
-            jac = []
-            for obs_chunk in obs_partitions:
-                jac_local = self._jacobian_lightning(
-                    processed_data["state_vector"],
-                    obs_chunk,
-                    processed_data["ops_serialized"],
-                    trainable_params,
-                )
-                jac.extend(jac_local)
-        else:
-            jac = self._jacobian_lightning(
-                processed_data["state_vector"],
-                processed_data["obs_serialized"],
-                processed_data["ops_serialized"],
-                trainable_params,
-            )
+        jac = self._jacobian_lightning(
+            processed_data["state_vector"],
+            processed_data["obs_serialized"],
+            processed_data["ops_serialized"],
+            trainable_params,
+        )
         jac = np.array(jac)
         jac = jac.reshape(-1, len(trainable_params)) if len(jac) else jac
         jac_r = np.zeros((jac.shape[0], processed_data["all_params"]))
diff --git a/pennylane_lightning/lightning_qubit/_adjoint_jacobian.py b/pennylane_lightning/lightning_qubit/_adjoint_jacobian.py
index 5aecfcbb6c..cf96a47e77 100644
--- a/pennylane_lightning/lightning_qubit/_adjoint_jacobian.py
+++ b/pennylane_lightning/lightning_qubit/_adjoint_jacobian.py
@@ -23,9 +23,9 @@
 from pennylane.measurements import Expectation, MeasurementProcess, State
 from pennylane.operation import Operation
 from pennylane.tape import QuantumTape
+from scipy.sparse import csr_matrix
 
 from pennylane_lightning.core._serialize import QuantumScriptSerializer
-from pennylane_lightning.core.lightning_base import _chunk_iterable
 
 # pylint: disable=import-error, no-name-in-module, ungrouped-imports
 try:
@@ -112,7 +112,7 @@ def _process_jacobian_tape(
         """
         use_csingle = self._dtype == np.complex64
 
-        obs_serialized, obs_idx_offsets = QuantumScriptSerializer(
+        obs_serialized, obs_indices = QuantumScriptSerializer(
             self._qubit_state.device_name, use_csingle, use_mpi, split_obs
         ).serialize_observables(tape)
 
@@ -155,7 +155,7 @@ def _process_jacobian_tape(
             "tp_shift": tp_shift,
             "record_tp_rows": record_tp_rows,
             "all_params": all_params,
-            "obs_idx_offsets": obs_idx_offsets,
+            "obs_indices": obs_indices,
         }
 
     @staticmethod
@@ -214,41 +214,39 @@ def calculate_jacobian(self, tape: QuantumTape):
                 "mixed with other return types"
             )
 
-        processed_data = self._process_jacobian_tape(tape)
+        split_obs = (
+            len(tape.measurements) > 1
+        )  # lightning already parallelizes applying a single Hamiltonian
+        if split_obs:
+            # split linear combinations into num_threads chunks;
+            # this isn't the best load balance in general, but well-rounded enough
+            split_obs = getenv("OMP_NUM_THREADS", None) if self._batch_obs else False
+            split_obs = int(split_obs) if split_obs else False
+        processed_data = self._process_jacobian_tape(tape, split_obs=split_obs)
 
         if not processed_data:  # training_params is empty
             return np.array([], dtype=self._dtype)
 
         trainable_params = processed_data["tp_shift"]
 
-        # If requested batching over observables, chunk into OMP_NUM_THREADS sized chunks.
-        # This will allow use of Lightning with adjoint for large-qubit numbers AND large
-        # numbers of observables, enabling choice between compute time and memory use.
-        requested_threads = int(getenv("OMP_NUM_THREADS", "1"))
-
-        if self._batch_obs and requested_threads > 1:
-            obs_partitions = _chunk_iterable(processed_data["obs_serialized"], requested_threads)
-            jac = []
-            for obs_chunk in obs_partitions:
-                jac_local = self._jacobian_lightning(
-                    processed_data["state_vector"],
-                    obs_chunk,
-                    processed_data["ops_serialized"],
-                    trainable_params,
-                )
-                jac.extend(jac_local)
-        else:
-            jac = self._jacobian_lightning(
-                processed_data["state_vector"],
-                processed_data["obs_serialized"],
-                processed_data["ops_serialized"],
-                trainable_params,
-            )
+        jac = self._jacobian_lightning(
+            processed_data["state_vector"],
+            processed_data["obs_serialized"],
+            processed_data["ops_serialized"],
+            trainable_params,
+        )
         jac = np.array(jac)
-        jac = jac.reshape(-1, len(trainable_params)) if len(jac) else jac
+        has_shape0 = bool(len(jac))
+
+        num_obs = len(np.unique(processed_data["obs_indices"]))
+        rows = processed_data["obs_indices"]
+        cols = np.arange(len(rows), dtype=int)
+        data = np.ones(len(rows))
+        red_mat = csr_matrix((data, (rows, cols)), shape=(num_obs, len(rows)))
+        jac = red_mat @ jac.reshape((len(rows), -1))
+        jac = jac.reshape(-1, len(trainable_params)) if has_shape0 else jac
         jac_r = np.zeros((jac.shape[0], processed_data["all_params"]))
         jac_r[:, processed_data["record_tp_rows"]] = jac
-
         return self._adjoint_jacobian_processing(jac_r)
 
     # pylint: disable=inconsistent-return-statements
diff --git a/pennylane_lightning/lightning_qubit/_state_vector.py b/pennylane_lightning/lightning_qubit/_state_vector.py
index ac2774f204..97a6a85711 100644
--- a/pennylane_lightning/lightning_qubit/_state_vector.py
+++ b/pennylane_lightning/lightning_qubit/_state_vector.py
@@ -215,7 +215,7 @@ def _apply_lightning_midmeasure(
 
     def _apply_lightning(
         self, operations, mid_measurements: dict = None, postselect_mode: str = None
-    ):
+    ):  # pylint: disable=protected-access
        """Apply a list of operations to the state tensor.
 
        Args:
@@ -255,7 +255,7 @@ def _apply_lightning(
                 method = getattr(state, "applyPauliRot")
                 paulis = operation._hyperparameters["pauli_word"]
                 wires = [i for i, w in zip(wires, paulis) if w != "I"]
-                word = "".join(p for p in paulis if p != "I")  # pylint: disable=protected-access
+                word = "".join(p for p in paulis if p != "I")
                 method(wires, invert_param, operation.parameters, word)
             elif method is not None:  # apply specialized gate
                 param = operation.parameters
diff --git a/tests/test_adjoint_jacobian.py b/tests/test_adjoint_jacobian.py
index 841e46ac28..57f007a45f 100644
--- a/tests/test_adjoint_jacobian.py
+++ b/tests/test_adjoint_jacobian.py
@@ -1133,8 +1133,9 @@ def circuit(params):
 
     circuit_ld = qml.QNode(circuit, dev_ld, diff_method="adjoint")
     circuit_dq = qml.QNode(circuit, dev_dq, diff_method="parameter-shift")
-
-    assert np.allclose(qml.grad(circuit_ld)(params), qml.grad(circuit_dq)(params), tol)
+    res = qml.grad(circuit_ld)(params)
+    ref = qml.grad(circuit_dq)(params)
+    assert np.allclose(res, ref, tol)
 
 
 @pytest.mark.usefixtures("use_legacy_and_new_opmath")
@@ -1244,7 +1245,7 @@ def test_integration(returns):
 
     def circuit(params):
         circuit_ansatz(params, wires=range(4))
-        return qml.expval(returns), qml.expval(qml.PauliY(1))
+        return np.array([qml.expval(returns), qml.expval(qml.PauliY(1))])
 
     n_params = 30
     params = np.linspace(0, 10, n_params)
@@ -1252,27 +1253,37 @@ def circuit(params):
     qnode_def = qml.QNode(circuit, dev_def)
     qnode_lightning = qml.QNode(circuit, dev_lightning, diff_method="adjoint")
 
-    def casted_to_array_def(params):
-        return np.array(qnode_def(params))
-
-    def casted_to_array_lightning(params):
-        return np.array(qnode_lightning(params))
-
-    j_def = qml.jacobian(casted_to_array_def)(params)
-    j_lightning = qml.jacobian(casted_to_array_lightning)(params)
+    j_def = qml.jacobian(qnode_def)(params)
+    j_lightning = qml.jacobian(qnode_lightning)(params)
 
     assert np.allclose(j_def, j_lightning)
 
 
 def test_integration_chunk_observables():
     """Integration tests that compare to default.qubit for a large circuit with multiple expectation values. Expvals are generated in parallelized chunks."""
-    dev_def = qml.device("default.qubit", wires=range(4))
-    dev_lightning = qml.device(device_name, wires=range(4))
-    dev_lightning_batched = qml.device(device_name, wires=range(4), batch_obs=True)
+    num_qubits = 4
+
+    dev_def = qml.device("default.qubit", wires=range(num_qubits))
+    dev_lightning = qml.device(device_name, wires=range(num_qubits))
+    dev_lightning_batched = qml.device(device_name, wires=range(num_qubits), batch_obs=True)
 
     def circuit(params):
-        circuit_ansatz(params, wires=range(4))
-        return [qml.expval(qml.PauliZ(i)) for i in range(4)]
+        circuit_ansatz(params, wires=range(num_qubits))
+        return np.array(
+            [qml.expval(qml.PauliZ(i)) for i in range(num_qubits)]
+            + [
+                qml.expval(
+                    qml.Hamiltonian(
+                        np.arange(1, num_qubits + 1),
+                        [
+                            qml.PauliZ(i % num_qubits) @ qml.PauliY((i + 1) % num_qubits)
+                            for i in range(num_qubits)
+                        ],
+                    )
+                )
+            ]
+            + [qml.expval(qml.PauliY(i)) for i in range(num_qubits)]
+        )
 
     n_params = 30
     params = np.linspace(0, 10, n_params)
@@ -1281,18 +1292,9 @@ def circuit(params):
     qnode_lightning = qml.QNode(circuit, dev_lightning, diff_method="adjoint")
     qnode_lightning_batched = qml.QNode(circuit, dev_lightning_batched, diff_method="adjoint")
 
-    def casted_to_array_def(params):
-        return np.array(qnode_def(params))
-
-    def casted_to_array_lightning(params):
-        return np.array(qnode_lightning(params))
-
-    def casted_to_array_batched(params):
-        return np.array(qnode_lightning_batched(params))
-
-    j_def = qml.jacobian(casted_to_array_def)(params)
-    j_lightning = qml.jacobian(casted_to_array_lightning)(params)
-    j_lightning_batched = qml.jacobian(casted_to_array_batched)(params)
+    j_def = qml.jacobian(qnode_def)(params)
+    j_lightning = qml.jacobian(qnode_lightning)(params)
+    j_lightning_batched = qml.jacobian(qnode_lightning_batched)(params)
 
     assert np.allclose(j_def, j_lightning)
     assert np.allclose(j_def, j_lightning_batched)
@@ -1327,7 +1329,7 @@ def test_integration_custom_wires(returns):
 
     def circuit(params):
         circuit_ansatz(params, wires=custom_wires)
-        return qml.expval(returns), qml.expval(qml.PauliY(custom_wires[1]))
+        return np.array([qml.expval(returns), qml.expval(qml.PauliY(custom_wires[1]))])
 
     n_params = 30
     params = np.linspace(0, 10, n_params)
@@ -1335,21 +1337,15 @@ def circuit(params):
     qnode_def = qml.QNode(circuit, dev_def)
     qnode_lightning = qml.QNode(circuit, dev_lightning, diff_method="adjoint")
 
-    def casted_to_array_def(params):
-        return np.array(qnode_def(params))
-
-    def casted_to_array_lightning(params):
-        return np.array(qnode_lightning(params))
-
-    j_def = qml.jacobian(casted_to_array_def)(params)
-    j_lightning = qml.jacobian(casted_to_array_lightning)(params)
+    j_def = qml.jacobian(qnode_def)(params)
+    j_lightning = qml.jacobian(qnode_lightning)(params)
 
     assert np.allclose(j_def, j_lightning)
 
 
 @pytest.mark.skipif(
-    device_name != "lightning.gpu",
-    reason="Tests only for lightning.gpu",
+    device_name not in ("lightning.qubit", "lightning.gpu"),
+    reason="Tests only for lightning.qubit and lightning.gpu",
 )
 @pytest.mark.parametrize(
     "returns",
@@ -1375,7 +1371,7 @@ def test_integration_custom_wires_batching(returns):
     operations and when using custom wire labels"""
 
     dev_def = qml.device("default.qubit", wires=custom_wires)
-    dev_gpu = qml.device("lightning.gpu", wires=custom_wires, batch_obs=True)
+    dev_gpu = qml.device(device_name, wires=custom_wires, batch_obs=True)
 
     def circuit(params):
         circuit_ansatz(params, wires=custom_wires)
@@ -1401,8 +1397,8 @@ def convert_to_array_def(params):
 
 
 @pytest.mark.skipif(
-    device_name != "lightning.gpu",
-    reason="Tests only for lightning.gpu",
+    device_name not in ("lightning.qubit", "lightning.gpu"),
+    reason="Tests only for lightning.qubit and lightning.gpu",
 )
 @pytest.mark.parametrize(
     "returns",
diff --git a/tests/test_serialize_chunk_obs.py b/tests/test_serialize_chunk_obs.py
index 58562657c1..760633551f 100644
--- a/tests/test_serialize_chunk_obs.py
+++ b/tests/test_serialize_chunk_obs.py
@@ -14,13 +14,11 @@
 """
 Unit tests for the serialization helper functions.
 """
-import numpy as np
 import pennylane as qml
 import pytest
 from conftest import LightningDevice as ld
 from conftest import device_name
 
-import pennylane_lightning
 from pennylane_lightning.core._serialize import QuantumScriptSerializer
 
 if not ld._CPP_BINARY_AVAILABLE:
@@ -33,8 +31,8 @@ class TestSerializeObs:
     wires_dict = {i: i for i in range(10)}
 
     @pytest.mark.parametrize("use_csingle", [True, False])
-    @pytest.mark.parametrize("obs_chunk", list(range(1, 5)))
-    def test_chunk_obs(self, use_csingle, obs_chunk):
+    @pytest.mark.parametrize("obs_chunk, expected", [(1, 5), (2, 6), (3, 7), (7, 7)])
+    def test_chunk_obs(self, use_csingle, obs_chunk, expected):
         """Test chunking of observable array"""
         with qml.tape.QuantumTape() as tape:
             qml.expval(
@@ -46,9 +44,8 @@ def test_chunk_obs(self, use_csingle, obs_chunk):
             qml.expval(qml.PauliY(wires=1))
             qml.expval(qml.PauliX(0) @ qml.Hermitian([[0, 1], [1, 0]], wires=3) @ qml.Hadamard(2))
             qml.expval(qml.Hermitian(qml.PauliZ.compute_matrix(), wires=0) @ qml.Identity(1))
-        s, offsets = QuantumScriptSerializer(
-            device_name, use_csingle, split_obs=True
+        s, obs_idx = QuantumScriptSerializer(
+            device_name, use_csingle, split_obs=obs_chunk
         ).serialize_observables(tape, self.wires_dict)
-        obtained_chunks = pennylane_lightning.core.lightning_base._chunk_iterable(s, obs_chunk)
-        assert len(list(obtained_chunks)) == int(np.ceil(len(s) / obs_chunk))
-        assert [0, 3, 4, 5, 6, 7] == offsets
+        assert expected == len(s)
+        assert [0] * (expected - 4) + [1, 2, 3, 4] == obs_idx
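
Below is a minimal, standalone sketch (not part of the patch above) of the two ideas the diff introduces: the round-robin chunking that `_chunk_ham_terms` performs on a linear combination, and the `csr_matrix` reduction that `calculate_jacobian`/`adjoint_jacobian` use to sum per-chunk Jacobian rows back into per-observable rows. The helper name `chunk_terms` and the toy coefficients, operator labels, and Jacobian values are made up for illustration; the real `_chunk_ham_terms` additionally sorts terms by wire count before chunking and operates on serialized observables rather than strings.

import numpy as np
from scipy.sparse import csr_matrix


def chunk_terms(coeffs, ops, split_num):
    """Split one linear combination into at most split_num round-robin chunks."""
    num_terms = len(coeffs)
    c_coeffs = [tuple(coeffs[i::split_num]) for i in range(min(num_terms, split_num))]
    c_ops = [tuple(ops[i::split_num]) for i in range(min(num_terms, split_num))]
    return c_coeffs, c_ops


# A 5-term observable (tape index 0) split into 3 chunks, alongside one unsplit observable (index 1).
coeffs = [0.1, 0.2, 0.3, 0.4, 0.5]
ops = ["Z0", "Z1", "Z2", "Z3", "Z4"]
c_coeffs, c_ops = chunk_terms(coeffs, ops, 3)
print(c_coeffs)  # [(0.1, 0.4), (0.2, 0.5), (0.3,)]

# obs_indices maps each serialized chunk back to the tape observable it came from,
# so the per-chunk Jacobian rows can be summed into per-observable rows.
obs_indices = [0, 0, 0, 1]
jac = np.arange(8, dtype=float).reshape(len(obs_indices), 2)  # 4 chunks x 2 trainable params

num_obs = len(np.unique(obs_indices))
cols = np.arange(len(obs_indices))
red_mat = csr_matrix(
    (np.ones(len(obs_indices)), (obs_indices, cols)), shape=(num_obs, len(obs_indices))
)
print(red_mat @ jac)  # [[6. 9.] [6. 7.]]: row 0 sums the three chunk rows of observable 0

The round-robin assignment keeps the chunk sizes within one term of each other, and the sparse reduction replaces the per-observable slicing bookkeeping that the old obs_idx_offsets code performed.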