Skip to content

Commit

Permalink
add controlled matrix support
Browse files Browse the repository at this point in the history
  • Loading branch information
josephleekl committed Oct 21, 2024
1 parent 87b7d0c commit b24f223
Show file tree
Hide file tree
Showing 9 changed files with 650 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -489,8 +489,75 @@ class StateVectorKokkos final
PL_ABORT_IF(gate_matrix.empty(),
std::string("Operation does not exist for ") + opName +
std::string(" and no matrix provided."));
PL_ABORT("Controlled matrix operation not yet supported");
return;
return applyNCMultiQubitOp(vector2view(gate_matrix),
controlled_wires, controlled_values,
wires, inverse);
}
}

/**
* @brief Apply a controlled-multi qubit operator to the state vector using
* a matrix
*
* @param matrix Kokkos gate matrix in the device space.
* @param controlled_wires Control wires.
* @param controlled_values Control values (true or false).
* @param wires Wires to apply gate to.
* @param inverse Indicates whether to use adjoint of gate.
*/
void applyNCMultiQubitOp(const KokkosVector matrix,
const std::vector<std::size_t> &controlled_wires,
const std::vector<bool> &controlled_values,
const std::vector<std::size_t> &wires,
bool inverse = false) {

auto &&num_qubits = this->getNumQubits();
std::size_t two2N =
std::exp2(num_qubits - wires.size() - controlled_wires.size());
std::size_t dim = std::exp2(wires.size());
KokkosVector matrix_trans("matrix_trans", matrix.size());

if (inverse) {
Kokkos::MDRangePolicy<DoubleLoopRank> policy_2d({0, 0}, {dim, dim});
Kokkos::parallel_for(
policy_2d,
KOKKOS_LAMBDA(const std::size_t i, const std::size_t j) {
matrix_trans(i + j * dim) = conj(matrix(i * dim + j));
});
} else {
matrix_trans = matrix;
}

switch (wires.size()) {
case 1:
Kokkos::parallel_for(two2N, applyNC1QubitOpFunctor<fp_t>(
*data_, num_qubits, matrix_trans,
controlled_wires, controlled_values,
wires));
break;
case 2:
Kokkos::parallel_for(two2N, applyNC2QubitOpFunctor<fp_t>(
*data_, num_qubits, matrix_trans,
controlled_wires, controlled_values,
wires));
break;
case 3:
Kokkos::parallel_for(two2N, applyNC3QubitOpFunctor<fp_t>(
*data_, num_qubits, matrix_trans,
controlled_wires, controlled_values,
wires));
break;
default:
std::size_t scratch_size = ScratchViewComplex::shmem_size(dim) +
ScratchViewSizeT::shmem_size(dim);
Kokkos::parallel_for(
"multiNCQubitOpFunctor",
TeamPolicy(two2N, Kokkos::AUTO, dim)
.set_scratch_size(0, Kokkos::PerTeam(scratch_size)),
NCMultiQubitOpFunctor<PrecisionT>(
*data_, num_qubits, matrix_trans, controlled_wires,
controlled_values, wires));
break;
}
}

Expand Down Expand Up @@ -546,7 +613,56 @@ class StateVectorKokkos final
"number of wires");
applyMatrix(matrix.data(), wires, inverse);
}
inline void applyControlledMatrix(
ComplexT *matrix, const std::vector<std::size_t> &controlled_wires,
const std::vector<bool> &controlled_values,
const std::vector<std::size_t> &wires, bool inverse = false) {
PL_ABORT_IF(wires.empty(), "Number of wires must be larger than 0");
std::size_t n = static_cast<std::size_t>(1U) << wires.size();
KokkosVector matrix_(matrix, n * n);
applyNCMultiQubitOp(matrix_, controlled_wires, controlled_values, wires,
inverse);
}

inline void
applyControlledMatrix(const ComplexT *matrix,
const std::vector<std::size_t> &controlled_wires,
const std::vector<bool> &controlled_values,
const std::vector<std::size_t> &wires,
bool inverse = false) {
PL_ABORT_IF(wires.empty(), "Number of wires must be larger than 0");
std::size_t n = static_cast<std::size_t>(1U) << wires.size();
std::size_t n2 = n * n;
KokkosVector matrix_("matrix_", n2);
Kokkos::deep_copy(matrix_, UnmanagedConstComplexHostView(matrix, n2));
applyNCMultiQubitOp(matrix_, controlled_wires, controlled_values, wires,
inverse);
}

/**
* @brief Apply a given controlled-matrix directly to the statevector.
*
* @param matrix Vector containing the statevector data (in row-major
* format).
* @param controlled_wires Control wires.
* @param controlled_values Control values (false or true).
* @param wires Wires to apply gate to.
* @param inverse Indicate whether inverse should be taken.
*/
inline void
applyControlledMatrix(const std::vector<ComplexT> &matrix,
const std::vector<std::size_t> &controlled_wires,
const std::vector<bool> &controlled_values,
const std::vector<std::size_t> &wires,
bool inverse = false) {
PL_ABORT_IF(wires.empty(), "Number of wires must be larger than 0");
PL_ABORT_IF(matrix.size() != exp2(2 * wires.size()),
"The size of matrix does not match with the given "
"number of wires");
applyControlledMatrix(matrix.data(), controlled_wires,
controlled_values, wires, inverse);
}

/**
* @brief Apply a single generator to the state vector using the given
* kernel.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,26 @@ namespace Pennylane::LightningKokkos {
using StateVectorBackends =
Pennylane::Util::TypeList<StateVectorKokkos<float>,
StateVectorKokkos<double>, void>;
/**
* @brief Register controlled matrix kernel.
*/
template <class StateVectorT>
void applyControlledMatrix(
StateVectorT &st,
const py::array_t<std::complex<typename StateVectorT::PrecisionT>,
py::array::c_style | py::array::forcecast> &matrix,
const std::vector<std::size_t> &controlled_wires,
const std::vector<bool> &controlled_values,
const std::vector<std::size_t> &wires, bool inverse = false) {
using ComplexT = typename StateVectorT::ComplexT;
st.applyControlledMatrix(
static_cast<const ComplexT *>(matrix.request().ptr), controlled_wires,
controlled_values, wires, inverse);
}

/**
* @brief Register controlled gates.
*/
template <class StateVectorT, class PyClass>
void registerControlledGate(PyClass &pyclass) {
using PrecisionT =
Expand Down Expand Up @@ -172,7 +191,9 @@ void registerBackendClassSpecificBindings(PyClass &pyclass) {
.def("collapse", &StateVectorT::collapse,
"Collapse the statevector onto the 0 or 1 branch of a given wire.")
.def("normalize", &StateVectorT::normalize,
"Normalize the statevector to norm 1.");
"Normalize the statevector to norm 1.")
.def("applyControlledMatrix", &applyControlledMatrix<StateVectorT>,
"Apply controlled operation");
}

/**
Expand Down
Loading

0 comments on commit b24f223

Please sign in to comment.