Skip to content

Commit

Permalink
Add native PauliRot implementation in LightningKokkos [sc-71642] (#855)
Browse files Browse the repository at this point in the history
### Before submitting

Please complete the following checklist when submitting a PR:

- [x] All new features must include a unit test.
If you've fixed a bug or added code that should be tested, add a test to
the
      [`tests`](../tests) directory!

- [x] All new functions and code must be clearly commented and
documented.
If you do make documentation changes, make sure that the docs build and
      render correctly by running `make docs`.

- [x] Ensure that the test suite passes, by running `make test`.

- [x] Add a new entry to the `.github/CHANGELOG.md` file, summarizing
the
      change, and including a link back to the PR.

- [x] Ensure that code is properly formatted by running `make format`. 

When all the above are checked, delete everything above the dashed
line and fill in the pull request template.


------------------------------------------------------------------------------------------------------------

**Context:**
Pauli rotations come up in many places, and importantly in the time
evolution of qchem Hamiltonians. It is therefore worth considering ways
to accelerate their execution.

**Description of the Change:**
Implement `applyPauliRot`. Invoke `applyPauliRot` directly from the SV
class and add bindings to the Python layer.

**Benefits:**
Faster Pauli rotations. I performed a benchmark on random
`PauliRotation`s (runtime > 1.0 sec and at least 5 of them) through the
Python layer. The data remains noisy with 5 samples because the
performance varies depending on the specific "XYZ" sequence (which
translates into more or less predictable memory access patterns).
Overall, we see an advantage for 3+ qubits and up.


![speedup_vs_ntargets_lk_omp16](https://github.com/user-attachments/assets/0fe2bc86-dae2-48c2-9af1-647efe753bae)

I performed the same benchmark on an A100 card with the Kokkos-CUDA
backend, but using at least 500 samples since the absolute timings quite
small and get the following speed-ups.


![speedup_vs_ntargets_lk_cuda](https://github.com/user-attachments/assets/ec842b49-cd93-478b-946e-6f208f97d4de)

Using a full workflow such as 
```
    @qml.qnode(dev, diff_method=None)
    def circuit():
        qml.TrotterProduct(ham, time=1.0, n=1, order=2)
        return qml.state()
```
to benchmark, we obtain timings as follows


![time_vs_mol](https://github.com/user-attachments/assets/7ed7a3db-e71d-42b3-84bb-e0325dceea68)

For large enough molecules (>= 20 qubits, >= 1000 terms), the new
PauliRot kernels have a clear advantage which only grows with molecular
size. It is worth noting that with L-Kokkos-CUDA, even at the (24/10k)
scale, evaluating the circuit is not the main bottleneck which is why it
takes about the same time simulating HCN (2.64 sec. `apply_lightning` vs
32.5 sec. `QNode`) and N2N2 (7.51 sec. `apply_lightning` vs 36.4 sec.
`QNode`).

**Possible Drawbacks:**

**Related GitHub Issues:**
[sc-69801]

---------

Co-authored-by: ringo-but-quantum <[email protected]>
Co-authored-by: Luis Alfredo Nuñez Meneses <[email protected]>
  • Loading branch information
3 people authored Sep 10, 2024
1 parent d5ffb0c commit 43374cc
Show file tree
Hide file tree
Showing 10 changed files with 166 additions and 16 deletions.
3 changes: 3 additions & 0 deletions .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@
* Unify Lightning-Kokkos device and Lightning-Qubit device under a Lightning Base device.
[(#876)](https://github.com/PennyLaneAI/pennylane-lightning/pull/876)

* LightningKokkos gains native support for the `PauliRot` gate.
[(#855)](https://github.com/PennyLaneAI/pennylane-lightning/pull/855)

### Documentation

### Bug fixes
Expand Down
2 changes: 1 addition & 1 deletion pennylane_lightning/core/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@
Version number (major.minor.patch[-label])
"""

__version__ = "0.39.0-dev14"
__version__ = "0.39.0-dev15"
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,25 @@ class StateVectorKokkos final
}
}

/**
* @brief Apply a PauliRot gate to the state-vector.
*
* @param wires Wires to apply gate to.
* @param inverse Indicates whether to use inverse of gate.
* @param params Rotation angle.
* @param word A Pauli word (e.g. "XYYX").
*/
void applyPauliRot(const std::vector<std::size_t> &wires,
const bool inverse,
const std::vector<PrecisionT> &params,
const std::string &word) {
PL_ABORT_IF_NOT(wires.size() == word.size(),
"wires and word have incompatible dimensions.");
Pennylane::LightningKokkos::Functors::applyPauliRot<KokkosExecSpace,
PrecisionT>(
getView(), this->getNumQubits(), wires, inverse, params[0], word);
}

template <bool inverse = false>
void applyControlledGlobalPhase(const std::vector<ComplexT> &diagonal) {
auto diagonal_ = vector2view(diagonal);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,14 @@ void registerBackendClassSpecificBindings(PyClass &pyclass) {

registerGatesForStateVector<StateVectorT>(pyclass);

pyclass.def(
"applyPauliRot",
[](StateVectorT &sv, const std::vector<std::size_t> &wires,
const bool inverse, const std::vector<ParamT> &params,
const std::string &word) {
sv.applyPauliRot(wires, inverse, params, word);
},
"Apply a Pauli rotation.");
pyclass
.def(py::init([](std::size_t num_qubits) {
return new StateVectorT(num_qubits);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@
#include "BitUtil.hpp"
#include "GateOperation.hpp"
#include "Gates.hpp"
#include "Util.hpp" // exp2, INVSQRT2
#include "UtilKokkos.hpp"

/// @cond DEV
namespace {
using namespace Pennylane::Util;
using Kokkos::kokkos_swap;
using Pennylane::Gates::GateOperation;
using Pennylane::LightningKokkos::Util::vector2view;
} // namespace
/// @endcond

Expand Down Expand Up @@ -1092,6 +1095,59 @@ void applyMultiRZ(Kokkos::View<Kokkos::complex<PrecisionT> *> arr_,
});
}

template <class ExecutionSpace, class PrecisionT>
void applyPauliRot(Kokkos::View<Kokkos::complex<PrecisionT> *> arr_,
const std::size_t num_qubits,
const std::vector<std::size_t> &wires, const bool inverse,
const PrecisionT angle, const std::string &word) {
using ComplexT = Kokkos::complex<PrecisionT>;
constexpr auto IMAG = Pennylane::Util::IMAG<PrecisionT>();
PL_ABORT_IF_NOT(wires.size() == word.size(),
"wires and word have incompatible dimensions.")
if (std::find_if_not(word.begin(), word.end(),
[](const int w) { return w == 'Z'; }) == word.end()) {
applyMultiRZ<ExecutionSpace>(arr_, num_qubits, wires, inverse,
std::vector<PrecisionT>{angle});
return;
}
const PrecisionT c = std::cos(angle / 2);
const ComplexT s = ((inverse) ? IMAG : -IMAG) * std::sin(angle / 2);
const std::vector<ComplexT> sines = {s, IMAG * s, -s, -IMAG * s};
auto d_sines = vector2view(sines);
auto get_mask =
[num_qubits, &wires](
[[maybe_unused]] const std::function<bool(const int)> &condition) {
std::size_t mask{0U};
for (std::size_t iw = 0; iw < wires.size(); iw++) {
const auto bit = static_cast<std::size_t>(condition(iw));
mask |= bit << (num_qubits - 1 - wires[iw]);
}
return mask;
};
const std::size_t mask_xy =
get_mask([&word](const int a) { return word[a] != 'Z'; });
const std::size_t mask_y =
get_mask([&word](const int a) { return word[a] == 'Y'; });
const std::size_t mask_z =
get_mask([&word](const int a) { return word[a] == 'Z'; });
const auto count_mask_y = std::popcount(mask_y);
Kokkos::parallel_for(
Kokkos::RangePolicy<ExecutionSpace>(0, exp2(num_qubits)),
KOKKOS_LAMBDA(const std::size_t i0) {
const std::size_t i1 = i0 ^ mask_xy;
if (i0 <= i1) {
const auto count_y = Kokkos::Impl::bit_count(i0 & mask_y) * 2;
const auto count_z = Kokkos::Impl::bit_count(i0 & mask_z) * 2;
const auto sign_i0 = count_z + count_mask_y * 3 - count_y;
const auto sign_i1 = count_z + count_mask_y + count_y;
const ComplexT v0 = arr_(i0);
const ComplexT v1 = arr_(i1);
arr_(i0) = c * v0 + d_sines(sign_i0 % 4) * v1;
arr_(i1) = c * v1 + d_sines(sign_i1 % 4) * v0;
}
});
}

template <class ExecutionSpace, class PrecisionT>
void applyGlobalPhase(Kokkos::View<Kokkos::complex<PrecisionT> *> arr_,
const std::size_t num_qubits,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,22 @@ TEMPLATE_TEST_CASE("StateVectorKokkosManaged::applyIsingXX",
}
}
}
SECTION("PauliRotXX 0,1") {
for (std::size_t index = 0; index < angles.size(); index++) {
StateVectorKokkos<TestType> kokkos_sv{num_qubits};
kokkos_sv.applyPauliRot({0, 1}, adjoint, {angles[index]}, "XX");
std::vector<ComplexT> result_sv(kokkos_sv.getLength(), {0, 0});
kokkos_sv.DeviceToHost(result_sv.data(), kokkos_sv.getLength());

for (std::size_t j = 0; j < exp2(num_qubits); j++) {
CHECK((real(result_sv[j])) ==
Approx(real(expected_results[index][j])));
CHECK((imag(result_sv[j])) ==
Approx(((adjoint) ? -1.0 : 1.0) *
imag(expected_results[index][j])));
}
}
}
SECTION("IsingXX 0,2") {
for (std::size_t index = 0; index < angles.size(); index++) {
StateVectorKokkos<TestType> kokkos_sv{num_qubits};
Expand Down Expand Up @@ -645,6 +661,22 @@ TEMPLATE_TEST_CASE("StateVectorKokkosManaged::applyIsingYY",
}
}
}
SECTION("PauliRotYY 0,1") {
for (std::size_t index = 0; index < angles.size(); index++) {
StateVectorKokkos<TestType> kokkos_sv{num_qubits};
kokkos_sv.applyPauliRot({0, 1}, adjoint, {angles[index]}, "YY");
std::vector<ComplexT> result_sv(kokkos_sv.getLength(), {0, 0});
kokkos_sv.DeviceToHost(result_sv.data(), kokkos_sv.getLength());

for (std::size_t j = 0; j < exp2(num_qubits); j++) {
CHECK((real(result_sv[j])) ==
Approx(real(expected_results[index][j])));
CHECK((imag(result_sv[j])) ==
Approx(((adjoint) ? -1.0 : 1.0) *
imag(expected_results[index][j])));
}
}
}
SECTION("IsingYY 0,2") {
for (std::size_t index = 0; index < angles.size(); index++) {
StateVectorKokkos<TestType> kokkos_sv{num_qubits};
Expand Down Expand Up @@ -700,6 +732,22 @@ TEMPLATE_TEST_CASE("StateVectorKokkosManaged::applyIsingZZ",
}
}
}
SECTION("PauliRotZZ 0,1") {
for (std::size_t index = 0; index < angles.size(); index++) {
StateVectorKokkos<TestType> kokkos_sv{num_qubits};
kokkos_sv.applyPauliRot({0, 1}, adjoint, {angles[index]}, "ZZ");
std::vector<ComplexT> result_sv(kokkos_sv.getLength(), {0, 0});
kokkos_sv.DeviceToHost(result_sv.data(), kokkos_sv.getLength());

for (std::size_t j = 0; j < exp2(num_qubits); j++) {
CHECK((real(result_sv[j])) ==
Approx(real(expected_results[index][j])));
CHECK((imag(result_sv[j])) ==
Approx(((adjoint) ? -1.0 : 1.0) *
imag(expected_results[index][j])));
}
}
}
SECTION("IsingZZ 0,2") {
for (std::size_t index = 0; index < angles.size(); index++) {
StateVectorKokkos<TestType> kokkos_sv{num_qubits};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,12 @@ class GateImplementationsLM : public PauliGenerator<GateImplementationsLM> {
constexpr auto IMAG = Pennylane::Util::IMAG<PrecisionT>();
PL_ABORT_IF_NOT(wires.size() == word.size(),
"wires and word have incompatible dimensions.")
if (std::find_if_not(word.begin(), word.end(), [](const int w) {
return w == 'Z';
}) == word.end()) {
applyMultiRZ(arr, num_qubits, wires, inverse, angle);
return;
}
const PrecisionT c = std::cos(angle / 2);
const ComplexT s = ((inverse) ? IMAG : -IMAG) * std::sin(angle / 2);
const std::array<ComplexT, 4> sines{s, IMAG * s, -s, -IMAG * s};
Expand Down Expand Up @@ -615,15 +621,11 @@ class GateImplementationsLM : public PauliGenerator<GateImplementationsLM> {
const auto count_y = std::popcount(i0 & mask_y) * 2;
const auto count_z = std::popcount(i0 & mask_z) * 2;
const auto sign_i0 = count_z + count_mask_y * 3 - count_y;
if (mask_xy) [[likely]] {
const auto sign_i1 = count_z + count_mask_y + count_y;
const ComplexT v0 = arr[i0];
const ComplexT v1 = arr[i1];
arr[i0] = c * v0 + sines[sign_i0 % 4] * v1;
arr[i1] = c * v1 + sines[sign_i1 % 4] * v0;
} else [[unlikely]] {
arr[i0] *= c + sines[sign_i0 % 4];
}
const auto sign_i1 = count_z + count_mask_y + count_y;
const ComplexT v0 = arr[i0];
const ComplexT v1 = arr[i1];
arr[i0] = c * v0 + sines[sign_i0 % 4] * v1;
arr[i1] = c * v1 + sines[sign_i1 % 4] * v0;
}
}

Expand Down
6 changes: 6 additions & 0 deletions pennylane_lightning/lightning_kokkos/_state_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,12 @@ def _apply_lightning(
self._apply_lightning_midmeasure(
operation, mid_measurements, postselect_mode=postselect_mode
)
elif isinstance(operation, qml.PauliRot):
method = getattr(state, "applyPauliRot")
paulis = operation._hyperparameters["pauli_word"]
wires = [i for i, w in zip(wires, paulis) if w != "I"]
word = "".join(p for p in paulis if p != "I") # pylint: disable=protected-access
method(wires, invert_param, operation.parameters, word)
elif method is not None: # apply specialized gate
param = operation.parameters
method(wires, invert_param, param)
Expand Down
8 changes: 6 additions & 2 deletions pennylane_lightning/lightning_kokkos/lightning_kokkos.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import os
import sys
from dataclasses import replace
from functools import reduce
from pathlib import Path
from typing import Optional
from warnings import warn
Expand Down Expand Up @@ -156,7 +157,10 @@ def stopping_condition(op: Operator) -> bool:
return len(op.wires) < 10
if isinstance(op, qml.GroverOperator):
return len(op.wires) < 13

if isinstance(op, qml.PauliRot):
word = op._hyperparameters["pauli_word"] # pylint: disable=protected-access
# decomposes to IsingXX, etc. for n <= 2
return reduce(lambda x, y: x + (y != "I"), word, 0) > 2
return op.name in _operations


Expand Down Expand Up @@ -212,7 +216,7 @@ def _supports_adjoint(circuit):

def _adjoint_ops(op: qml.operation.Operator) -> bool:
"""Specify whether or not an Operator is supported by adjoint differentiation."""
return adjoint_ops(op)
return not isinstance(op, qml.PauliRot) and adjoint_ops(op)


def _add_adjoint_transforms(program: TransformProgram) -> None:
Expand Down
12 changes: 8 additions & 4 deletions tests/test_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,8 +523,8 @@ def circuit(x):


@pytest.mark.skipif(
device_name != "lightning.qubit",
reason="PauliRot operations only implemented in lightning.qubit.",
device_name not in ("lightning.qubit", "lightning.kokkos"),
reason="PauliRot operations only implemented in lightning.qubit and lightning.kokkos.",
)
@pytest.mark.parametrize("n_wires", [1, 2, 3, 4, 5, 10, 15])
@pytest.mark.parametrize("n_targets", [1, 2, 3, 4, 5, 10, 15])
Expand All @@ -540,8 +540,12 @@ def test_paulirot(n_wires, n_targets, tol):
init_state /= np.linalg.norm(init_state)
theta = 0.3

for _ in range(10):
word = "".join(pws[w] for w in np.random.randint(0, 3, n_targets))
for i in range(10):
word = (
"Z" * n_targets
if i == 0
else "".join(pws[w] for w in np.random.randint(0, 3, n_targets))
)
wires = np.random.permutation(n_wires)[0:n_targets]
stateprep = qml.StatePrep(init_state, wires=range(n_wires))
op = qml.PauliRot(theta, word, wires=wires)
Expand Down

0 comments on commit 43374cc

Please sign in to comment.