Chunk Hamiltonian, PauliSentence, LinearCombination [sc-65680] (#873)
### Before submitting

Please complete the following checklist when submitting a PR:

- [x] All new features must include a unit test.
      If you've fixed a bug or added code that should be tested, add a test to the
      [`tests`](../tests) directory!

- [x] All new functions and code must be clearly commented and documented.
      If you do make documentation changes, make sure that the docs build and
      render correctly by running `make docs`.

- [x] Ensure that the test suite passes, by running `make test`.

- [x] Add a new entry to the `.github/CHANGELOG.md` file, summarizing the
      change and including a link back to the PR.

- [x] Ensure that code is properly formatted by running `make format`.

When all the above are checked, delete everything above the dashed line and
fill in the pull request template.

------------------------------------------------------------------------------------------------------------

**Context:**
Parallelizing over observables can accelerate the backward pass of adjoint
Jacobian calculations. This PR revisits our implementation for L-Qubit and
L-GPU, the two devices that support it. Certain observables, such as
Hamiltonian, PauliSentence, and LinearCombination, can be split into many
observables, which distributes the cost of computing expectation values. This
strategy is initiated by the serializer, which partitions the observables
whenever `split_obs` is not `False`. The serializer performs a complete
partitioning, meaning a 1000-PauliWord PauliSentence is partitioned into 1000
individual PauliWords. We note in passing that L-Qubit does not split
observables at all, since it does not pass a `split_obs` value to
`_process_jacobian_tape`. Complete partitioning is wasteful because we end up
in one of two situations:

- The Jacobian is computed N observables at a time (one per process, thread,
  device, etc.), so the forward/backward passes are repeated in every process
  and the results combined, duplicating a lot of computation;
- The Jacobian is parallelized over all observables at once, each of which
  requires its own state-vector copy, increasing the memory requirements
  proportionally (see the worked example below).
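
As a rough illustration of the memory point, take a 24-qubit state and a
1000-term observable (both sizes are assumptions chosen for the arithmetic):

```
# Rough arithmetic for the memory cost of full partitioning (assumed sizes).
num_qubits, num_terms = 24, 1000
bytes_per_copy = (2**num_qubits) * 16  # complex128 amplitude vector
print(f"{bytes_per_copy / 2**20:.0f} MiB per state-vector copy")  # 256 MiB
print(f"{num_terms * bytes_per_copy / 2**30:.0f} GiB if all terms run at once")  # 250 GiB
```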

We explore chunking instead of full partitioning for LinearCombination-like
objects: a 1000-PauliWord PauliSentence is instead partitioned into four
250-PauliWord PauliSentences when we parallelize over 4 processes.
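
A minimal sketch of the chunking scheme, written independently of the
serializer (the new `_chunk_ham_terms` in the diff below uses the same
round-robin slicing, after first sorting terms by wire count):

```
def chunk_round_robin(items, num_chunks):
    # Round-robin split: chunk i receives items i, i + num_chunks, i + 2*num_chunks, ...
    n = len(items)
    return [items[i:n:num_chunks] for i in range(min(n, num_chunks))]

terms = [f"pw_{i}" for i in range(1000)]  # stand-ins for 1000 PauliWords
chunks = chunk_round_robin(terms, 4)
print([len(chunk) for chunk in chunks])   # [250, 250, 250, 250]
```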

**Description of the Change:**
Modify the serializer to chunk LinearCombination-like objects if
`self.split_obs` is truthy.
Correctly route `_batch_obs` so that L-Qubit also splits observables.
Enhance/adapt tests.
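
Two sketches of the plumbing, distilled from the diffs below (the stand-in
values are hypothetical): first, how L-GPU now picks a chunk count; second, how
the per-observable indices returned by the serializer let the device sum the
Jacobian rows of split observables back together.

```
import numpy as np
from scipy.sparse import csr_matrix

# 1. Routing sketch (L-GPU): chunk count = number of devices, unless MPI is on.
use_mpi, batch_obs, num_devices = False, True, 2  # stand-in values
split_obs = False if use_mpi else (num_devices if batch_obs else False)

# 2. Reduction sketch: obs_indices[i] is the original observable that
#    serialized observable i came from, e.g. a Hamiltonian split into
#    3 chunks followed by a lone PauliZ(0).
obs_indices = [0, 0, 0, 1]
jac = np.arange(8.0).reshape(4, 2)  # one Jacobian row per serialized observable
rows, cols = obs_indices, np.arange(len(obs_indices))
red_mat = csr_matrix((np.ones(len(rows)), (rows, cols)), shape=(2, len(rows)))
print(red_mat @ jac)  # chunk rows summed back per original observable
```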

**Analysis:**
**Lightning-Qubit**

`applyObservable` is a bottleneck for somewhat large linear combinations (say,
hundreds or thousands of terms). Chunking isn't helpful for a circuit like
```
    @qml.qnode(dev, diff_method="adjoint")
    def c(weights):
        qml.templates.AllSinglesDoubles(weights, wires, hf_state, singles, doubles)
        return qml.expval(ham)
```
because L-Qubit's `applyObservable` method is parallelized over terms for a
single `Hamiltonian` observable. Chunking in this case is counterproductive
because it requires extra state vectors, extra backward passes, etc.

For a circuit like the following, however,
```
    @qml.qnode(dev, diff_method="adjoint")
    def c(weights):
        qml.templates.AllSinglesDoubles(weights, wires, hf_state, singles, doubles)
        return np.array([qml.expval(ham), qml.expval(qml.PauliZ(0))])
```
`applyObservable` is parallelized over observables, which scales to only 2
threads, and with poor load balance. In this case, it is better to split the
observable, which is what the current changes do.
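
To make the load-balance point concrete, here is a toy schedule in which the
cost of an observable is proportional to its term count (an assumption):

```
import heapq

def makespan(costs, threads):
    # Greedy longest-processing-time schedule: give each task to the
    # least-loaded thread, then report the busiest thread's total load.
    loads = [0] * threads
    heapq.heapify(loads)
    for cost in sorted(costs, reverse=True):
        heapq.heappush(loads, heapq.heappop(loads) + cost)
    return max(loads)

print(makespan([1000, 1], threads=4))        # 1000: ham dominates one thread
print(makespan([250] * 4 + [1], threads=4))  # 251: ham chunked into 4 pieces
```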

| mol | master-serial | master-batched | chunk-serial | chunk-batched |
| --- | ------------- | -------------- | ------------ | ------------- |
| CH4 | 1.793e+01     | 1.330e+01      | 1.819e+01    | 8.040e+00     |
| Li2 | 5.333e+01     | 3.354e+01      | 5.289e+01    | 1.839e+01     |
| CO  | 9.817e+01     | 5.945e+01      | 9.619e+01    | 2.559e+01     |
| H10 | 1.220e+02     | 7.317e+01      | 1.182e+02    | 3.305e+01     |

So for this circuit the current PR yields speed-ups ranging from 1.5x to over
2x by using obs-batching + chunking (compared with the previous obs-batching).
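
The quoted range can be checked directly from the table, taking the ratio of
master-batched to chunk-batched timings:

```
master_batched = {"CH4": 13.30, "Li2": 33.54, "CO": 59.45, "H10": 73.17}
chunk_batched = {"CH4": 8.040, "Li2": 18.39, "CO": 25.59, "H10": 33.05}
print({mol: round(master_batched[mol] / t, 2) for mol, t in chunk_batched.items()})
# {'CH4': 1.65, 'Li2': 1.82, 'CO': 2.32, 'H10': 2.21}
```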

**Lightning-GPU** 

Lightning-GPU splits the observables as soon as `batch_obs` is true. The
current code splits a Hamiltonian into all its individual terms, which is
quite inefficient and induces a lot of redundant backward passes. This is
visible when benchmarking the circuit
```
    @qml.qnode(dev, diff_method="adjoint")
    def c(weights):
        qml.templates.AllSinglesDoubles(weights, wires, hf_state, singles, doubles)
        return qml.expval(ham)
```
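
A back-of-the-envelope model of why full splitting hurts here, assuming one
adjoint sweep per serialized observable, distributed evenly over the GPUs
(the numbers are illustrative):

```
import math

def sweep_rounds(num_serialized_obs, gpus):
    # Each round runs one batched adjoint sweep per GPU; every sweep repeats
    # the forward/backward state-vector work for its observable.
    return math.ceil(num_serialized_obs / gpus)

terms, gpus = 1000, 2
print(sweep_rounds(terms, gpus))  # full split: 500 rounds of redundant passes
print(sweep_rounds(gpus, gpus))   # chunked, one observable per GPU: 1 round
```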

| mol | master-serial | master-batched | chunk-serial | chunk-batched |
| --- | ------------- | -------------- | ------------ | ------------- |
| CH4 | 1.463e+01     | forever        | 5.583e+00    | 3.405e+00     |
| Li2 | 1.201e+01     | forever        | 5.284e+00    | 2.658e+00     |
| CO  | 2.357e+01     | forever        | 4.716e+00    | 4.577e+00     |
| H10 | 2.992e+01     | forever        | 5.476e+00    | 5.469e+00     |
| HCN | 8.622e+01     | forever        | 3.144e+01    | 2.452e+01     |

The batched L-GPU runs use 2 x A100 GPUs on ISAIC. The speed-ups for batched
versus serial are OK, but most important is the optimization of
`Hamiltonian::applyInPlace`, which brings speed-ups of roughly 2.3x to 5.5x
between master and this PR (comparing the serial columns).
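
That optimization, visible in the C++ diffs below, hoists the temporary state
vector out of the term loop and refreshes its buffer in place instead of
reallocating it per term. A toy NumPy re-enactment of the pattern (stand-in
arrays, not the actual C++ API):

```
import numpy as np

rng = np.random.default_rng(0)
sv = rng.standard_normal(8) + 1j * rng.standard_normal(8)  # toy 3-qubit state
coeffs = [0.5, -0.2]
terms = [np.diag([1.0, -1.0] * 4), np.eye(8)]              # toy observable terms

acc = np.zeros_like(sv)
tmp = np.empty_like(sv)          # hoisted: one allocation for the whole loop
for coeff, obs in zip(coeffs, terms):
    np.matmul(obs, sv, out=tmp)  # tmp <- obs @ sv, reusing the buffer
    acc += coeff * tmp           # axpy-style accumulation
print(np.allclose(acc, sum(c * (o @ sv) for c, o in zip(coeffs, terms))))  # True
```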
 
**Related GitHub Issues:**

---------

Co-authored-by: ringo-but-quantum <[email protected]>
Co-authored-by: Amintor Dusko <[email protected]>
Co-authored-by: AmintorDusko <[email protected]>
4 people authored Sep 5, 2024
1 parent 00ebcdf commit f4a6114
Showing 16 changed files with 164 additions and 191 deletions.
5 changes: 4 additions & 1 deletion .github/CHANGELOG.md
@@ -15,6 +15,9 @@
* Optimize gate cache recording for `lightning.tensor` C++ layer.
[(#879)](https://github.com/PennyLaneAI/pennylane-lightning/pull/879)

* Smarter defaults for the `split_obs` argument in the serializer. The serializer splits linear combinations into chunks instead of all their terms.
[(#873)](https://github.com/PennyLaneAI/pennylane-lightning/pull/873/)

### Documentation

### Bug fixes
@@ -23,7 +26,7 @@

This release contains contributions from (in alphabetical order):

Amintor Dusko, Luis Alfredo Nuñez Meneses, Shuli Shu
Amintor Dusko, Luis Alfredo Nuñez Meneses, Vincent Michaud-Rioux, Shuli Shu

---

2 changes: 1 addition & 1 deletion .github/workflows/tests_lgpu_cpp.yml
@@ -68,7 +68,7 @@ jobs:
strategy:
matrix:
pl_backend: ["lightning_gpu"]
enable_lapack: ["OFF", "ON"]
enable_lapack: ["OFF"]
cuda_version: ["12"]

name: C++ Tests (${{ matrix.pl_backend }}, cuda-${{ matrix.cuda_version }}, enable_lapack=${{ matrix.enable_lapack }})
48 changes: 37 additions & 11 deletions pennylane_lightning/core/_serialize.py
@@ -54,7 +54,7 @@ class QuantumScriptSerializer:
device_name: device shortname.
use_csingle (bool): whether to use np.complex64 instead of np.complex128
use_mpi (bool, optional): If using MPI to accelerate calculation. Defaults to False.
split_obs (bool, optional): If splitting the observables in a list. Defaults to False.
split_obs (Union[bool, int], optional): If splitting the observables in a list. Defaults to False.
"""

@@ -214,17 +214,40 @@ def _tensor_ob(self, observable, wires_map: dict = None):
obs = observable.obs if isinstance(observable, Tensor) else observable.operands
return self.tensor_obs([self._ob(o, wires_map) for o in obs])

def _chunk_ham_terms(self, coeffs, ops, split_num: int = 1) -> List:
"Create split_num sub-Hamiltonians from a single high term-count Hamiltonian"
num_terms = len(coeffs)
iperm = np.argsort(np.array([len(op.get_wires()) for op in ops]))
coeffs = [coeffs[i] for i in iperm]
ops = [ops[i] for i in iperm]
c_coeffs = [
tuple(coeffs[slice(i, num_terms, split_num)]) for i in range(min(num_terms, split_num))
]
c_ops = [
tuple(ops[slice(i, num_terms, split_num)]) for i in range(min(num_terms, split_num))
]
return c_coeffs, c_ops

def _hamiltonian(self, observable, wires_map: dict = None):
coeffs, ops = observable.terms()
coeffs = np.array(unwrap(coeffs)).astype(self.rtype)
if self.split_obs:
ops_l = []
for t in ops:
term_cpp = self._ob(t, wires_map)
if isinstance(term_cpp, Sequence):
ops_l.extend(term_cpp)
else:
ops_l.append(term_cpp)
c, o = self._chunk_ham_terms(coeffs, ops_l, self.split_obs)
hams = [self.hamiltonian_obs(c_coeffs, c_obs) for (c_coeffs, c_obs) in zip(c, o)]
return hams

terms = [self._ob(t, wires_map) for t in ops]
# TODO: This is in case `_hamiltonian` is called recursively which would cause a list
# to be passed where `_ob` expects an observable.
terms = [t[0] if isinstance(t, Sequence) and len(t) == 1 else t for t in terms]

if self.split_obs:
return [self.hamiltonian_obs([c], [t]) for (c, t) in zip(coeffs, terms)]

return self.hamiltonian_obs(coeffs, terms)

def _sparse_hamiltonian(self, observable, wires_map: dict = None):
@@ -282,11 +305,14 @@ def _pauli_sentence(self, observable, wires_map: dict = None):
terms = [self._pauli_word(pw, wires_map) for pw in pwords]
coeffs = np.array(coeffs).astype(self.rtype)

if self.split_obs:
c, o = self._chunk_ham_terms(coeffs, terms, self.split_obs)
psentences = [self.hamiltonian_obs(c_coeffs, c_obs) for (c_coeffs, c_obs) in zip(c, o)]
return psentences

if len(terms) == 1 and coeffs[0] == 1.0:
return terms[0]

if self.split_obs:
return [self.hamiltonian_obs([c], [t]) for (c, t) in zip(coeffs, terms)]
return self.hamiltonian_obs(coeffs, terms)

# pylint: disable=protected-access, too-many-return-statements
@@ -326,17 +352,17 @@ def serialize_observables(self, tape: QuantumTape, wires_map: dict = None) -> Li
"""

serialized_obs = []
offset_indices = [0]
obs_indices = []

for observable in tape.observables:
for i, observable in enumerate(tape.observables):
ser_ob = self._ob(observable, wires_map)
if isinstance(ser_ob, list):
serialized_obs.extend(ser_ob)
offset_indices.append(offset_indices[-1] + len(ser_ob))
obs_indices.extend([i] * len(ser_ob))
else:
serialized_obs.append(ser_ob)
offset_indices.append(offset_indices[-1] + 1)
return serialized_obs, offset_indices
obs_indices.append(i)
return serialized_obs, obs_indices

def serialize_ops(self, tape: QuantumTape, wires_map: dict = None) -> Tuple[
List[List[str]],
2 changes: 1 addition & 1 deletion pennylane_lightning/core/_version.py
@@ -16,4 +16,4 @@
Version number (major.minor.patch[-label])
"""

__version__ = "0.39.0-dev3"
__version__ = "0.39.0-dev4"
21 changes: 10 additions & 11 deletions pennylane_lightning/core/lightning_base.py
@@ -16,8 +16,8 @@
This module contains the base class for all PennyLane Lightning simulator devices,
and interfaces with C++ for improved performance.
"""
from itertools import islice, product
from typing import List
from itertools import product
from typing import List, Union

import numpy as np
import pennylane as qml
@@ -31,12 +31,6 @@
from ._version import __version__


def _chunk_iterable(iteration, num_chunks):
"Lazy-evaluated chunking of given iterable from https://stackoverflow.com/a/22045226"
iteration = iter(iteration)
return iter(lambda: tuple(islice(iteration, num_chunks)), ())


class LightningBase(QubitDevice):
"""PennyLane Lightning Base device.
@@ -262,11 +256,16 @@ def _get_basis_state_index(self, state, wires):

# pylint: disable=too-many-function-args, assignment-from-no-return, too-many-arguments
def _process_jacobian_tape(
self, tape, starting_state, use_device_state, use_mpi: bool = False, split_obs: bool = False
self,
tape,
starting_state,
use_device_state,
use_mpi: bool = False,
split_obs: Union[bool, int] = False,
):
state_vector = self._init_process_jacobian_tape(tape, starting_state, use_device_state)

obs_serialized, obs_idx_offsets = QuantumScriptSerializer(
obs_serialized, obs_indices = QuantumScriptSerializer(
self.short_name, self.use_csingle, use_mpi, split_obs
).serialize_observables(tape, self.wire_map)

@@ -309,7 +308,7 @@ def _process_jacobian_tape(
"tp_shift": tp_shift,
"record_tp_rows": record_tp_rows,
"all_params": all_params,
"obs_idx_offsets": obs_idx_offsets,
"obs_indices": obs_indices,
}

@staticmethod
@@ -288,6 +288,7 @@ class AdjointJacobian final
H_lambda.emplace_back(lambda.getNumQubits(), dt_local, true,
cusvhandle, cublascaller, cusparsehandle);
}

BaseType::applyObservables(H_lambda, lambda, obs);

StateVectorT mu(lambda.getNumQubits(), dt_local, true, cusvhandle,
@@ -194,10 +194,11 @@ class Hamiltonian final : public HamiltonianBase<StateVectorT> {
std::make_unique<DataBuffer<CFP_t>>(sv.getDataBuffer().getLength(),
sv.getDataBuffer().getDevTag());
buffer->zeroInit();
StateVectorT tmp(sv);

for (std::size_t term_idx = 0; term_idx < this->coeffs_.size();
term_idx++) {
StateVectorT tmp(sv);
tmp.updateData(sv);
this->obs_[term_idx]->applyInPlace(tmp);
scaleAndAddC_CUDA(
std::complex<PrecisionT>{this->coeffs_[term_idx], 0.0},
@@ -189,9 +189,10 @@ class Hamiltonian final : public HamiltonianBase<StateVectorT> {
void applyInPlace(StateVectorT &sv) const override {
StateVectorT buffer{sv.getNumQubits()};
buffer.initZeros();
StateVectorT tmp{sv};
for (std::size_t term_idx = 0; term_idx < this->coeffs_.size();
term_idx++) {
StateVectorT tmp{sv};
tmp.updateData(sv.getView());
this->obs_[term_idx]->applyInPlace(tmp);
LightningKokkos::Util::axpy_Kokkos<PrecisionT>(
ComplexT{this->coeffs_[term_idx], 0.0}, tmp.getView(),
@@ -57,6 +57,7 @@ class StateVectorLQubitManaged final
using ComplexT = std::complex<PrecisionT>;
using CFP_t = ComplexT;
using MemoryStorageT = Pennylane::Util::MemoryStorageLocation::Internal;
using StateVectorT = StateVectorLQubitManaged<PrecisionT>;

private:
using BaseType =
@@ -165,9 +165,10 @@ template <class StateVectorT, bool use_openmp> struct HamiltonianApplyInPlace {
auto allocator = sv.allocator();
std::vector<ComplexT, decltype(allocator)> res(
sv.getLength(), ComplexT{0.0, 0.0}, allocator);
StateVectorT tmp(sv);
for (std::size_t term_idx = 0; term_idx < coeffs.size();
term_idx++) {
StateVectorT tmp(sv);
tmp.updateData(sv.getDataVector());
terms[term_idx]->applyInPlace(tmp);
scaleAndAdd(tmp.getLength(), ComplexT{coeffs[term_idx], 0.0},
tmp.getData(), res.data());
80 changes: 24 additions & 56 deletions pennylane_lightning/lightning_gpu/lightning_gpu.py
@@ -30,6 +30,7 @@
from pennylane.measurements import Expectation, State
from pennylane.ops.op_math import Adjoint
from pennylane.wires import Wires
from scipy.sparse import csr_matrix

from pennylane_lightning.core._serialize import QuantumScriptSerializer, global_phase_diagonal
from pennylane_lightning.core._version import __version__
@@ -633,8 +634,12 @@ def adjoint_jacobian(self, tape, starting_state=None, use_device_state=False):
# Check adjoint diff support
self._check_adjdiff_supported_operations(tape.operations)

if self._mpi:
split_obs = False # with MPI, batching computes the Jacobian one observable at a time, so there is no point splitting linear combinations
else:
split_obs = self._dp.getTotalDevices() if self._batch_obs else False
processed_data = self._process_jacobian_tape(
tape, starting_state, use_device_state, self._mpi, self._batch_obs
tape, starting_state, use_device_state, self._mpi, split_obs
)

if not processed_data: # training_params is empty
@@ -653,68 +658,31 @@
adjoint_jacobian = _adj_dtype(self.use_csingle, self._mpi)()

if self._batch_obs: # Batching of Measurements
if not self._mpi: # Single-node path, controlled batching over available GPUs
num_obs = len(processed_data["obs_serialized"])
batch_size = (
num_obs
if isinstance(self._batch_obs, bool)
else self._batch_obs * self._dp.getTotalDevices()
)
jac = []
for chunk in range(0, num_obs, batch_size):
obs_chunk = processed_data["obs_serialized"][chunk : chunk + batch_size]
jac_chunk = adjoint_jacobian.batched(
self._gpu_state,
obs_chunk,
processed_data["ops_serialized"],
trainable_params,
)
jac.extend(jac_chunk)
else: # MPI path, restrict memory per known GPUs
jac = adjoint_jacobian.batched(
self._gpu_state,
processed_data["obs_serialized"],
processed_data["ops_serialized"],
trainable_params,
)

jac = adjoint_jacobian.batched(
self._gpu_state,
processed_data["obs_serialized"],
processed_data["ops_serialized"],
trainable_params,
)
else:
jac = adjoint_jacobian(
self._gpu_state,
processed_data["obs_serialized"],
processed_data["ops_serialized"],
trainable_params,
)

jac = np.array(jac) # only for parameters differentiable with the adjoint method
jac = jac.reshape(-1, len(trainable_params))
jac_r = np.zeros((len(tape.observables), processed_data["all_params"]))
if not self._batch_obs:
jac_r[:, processed_data["record_tp_rows"]] = jac
else:
# Reduce over decomposed expval(H), if required.
for idx in range(len(processed_data["obs_idx_offsets"][0:-1])):
if (
processed_data["obs_idx_offsets"][idx + 1]
- processed_data["obs_idx_offsets"][idx]
) > 1:
jac_r[idx, :] = np.sum(
jac[
processed_data["obs_idx_offsets"][idx] : processed_data[
"obs_idx_offsets"
][idx + 1],
:,
],
axis=0,
)
else:
jac_r[idx, :] = jac[
processed_data["obs_idx_offsets"][idx] : processed_data["obs_idx_offsets"][
idx + 1
],
:,
]

jac = np.array(jac)
has_shape0 = bool(len(jac))

num_obs = len(np.unique(processed_data["obs_indices"]))
rows = processed_data["obs_indices"]
cols = np.arange(len(rows), dtype=int)
data = np.ones(len(rows))
red_mat = csr_matrix((data, (rows, cols)), shape=(num_obs, len(rows)))
jac = red_mat @ jac.reshape((len(rows), -1))
jac = jac.reshape(-1, len(trainable_params)) if has_shape0 else jac
jac_r = np.zeros((jac.shape[0], processed_data["all_params"]))
jac_r[:, processed_data["record_tp_rows"]] = jac
return self._adjoint_jacobian_processing(jac_r)

# pylint: disable=inconsistent-return-statements, line-too-long, missing-function-docstring
31 changes: 6 additions & 25 deletions pennylane_lightning/lightning_kokkos/_adjoint_jacobian.py
@@ -37,7 +37,6 @@

# pylint: disable=import-error, no-name-in-module, ungrouped-imports
from pennylane_lightning.core._serialize import QuantumScriptSerializer
from pennylane_lightning.core.lightning_base import _chunk_iterable

from ._state_vector import LightningKokkosStateVector

@@ -243,30 +242,12 @@ def calculate_jacobian(self, tape: QuantumTape):
return np.array([], dtype=self._dtype)

trainable_params = processed_data["tp_shift"]

# If requested batching over observables, chunk into OMP_NUM_THREADS sized chunks.
# This will allow use of Lightning with adjoint for large-qubit numbers AND large
# numbers of observables, enabling choice between compute time and memory use.
requested_threads = int(getenv("OMP_NUM_THREADS", "1"))

if self._batch_obs and requested_threads > 1:
obs_partitions = _chunk_iterable(processed_data["obs_serialized"], requested_threads)
jac = []
for obs_chunk in obs_partitions:
jac_local = self._jacobian_lightning(
processed_data["state_vector"],
obs_chunk,
processed_data["ops_serialized"],
trainable_params,
)
jac.extend(jac_local)
else:
jac = self._jacobian_lightning(
processed_data["state_vector"],
processed_data["obs_serialized"],
processed_data["ops_serialized"],
trainable_params,
)
jac = self._jacobian_lightning(
processed_data["state_vector"],
processed_data["obs_serialized"],
processed_data["ops_serialized"],
trainable_params,
)
jac = np.array(jac)
jac = jac.reshape(-1, len(trainable_params)) if len(jac) else jac
jac_r = np.zeros((jac.shape[0], processed_data["all_params"]))