diff --git a/.github/CHANGELOG.md b/.github/CHANGELOG.md index f7e6aee83c..8513928203 100644 --- a/.github/CHANGELOG.md +++ b/.github/CHANGELOG.md @@ -46,6 +46,9 @@ ### Improvements +* Add joint check for the N-controlled wires support in `lightning.qubit`. + [(#949)](https://github.com/PennyLaneAI/pennylane-lightning/pull/949) + * Optimize `GlobalPhase` and `C(GlobalPhase)` gate implementation for `lightning.gpu`. [(#946)](https://github.com/PennyLaneAI/pennylane-lightning/pull/946) diff --git a/pennylane_lightning/core/src/simulators/lightning_qubit/StateVectorLQubit.hpp b/pennylane_lightning/core/src/simulators/lightning_qubit/StateVectorLQubit.hpp index 78521ebe8b..2901160a2f 100644 --- a/pennylane_lightning/core/src/simulators/lightning_qubit/StateVectorLQubit.hpp +++ b/pennylane_lightning/core/src/simulators/lightning_qubit/StateVectorLQubit.hpp @@ -369,6 +369,11 @@ class StateVectorLQubit : public StateVectorBase { const std::vector &wires, bool inverse = false, const std::vector ¶ms = {}) { + // Add disjoint check. + PL_ABORT_IF_NOT( + areVecsDisjoint(controlled_wires, wires), + "`controlled_wires` and `target wires` must be disjoint."); + PL_ABORT_IF_NOT(controlled_wires.size() == controlled_values.size(), "`controlled_wires` must have the same size as " "`controlled_values`."); @@ -420,6 +425,11 @@ class StateVectorLQubit : public StateVectorBase { const std::vector &wires, bool inverse, const std::vector ¶ms, const std::vector &matrix) { + // Add disjoint check. + PL_ABORT_IF_NOT( + areVecsDisjoint(controlled_wires, wires), + "`controlled_wires` and `target wires` must be disjoint."); + PL_ABORT_IF_NOT(controlled_wires.size() == controlled_values.size(), "`controlled_wires` must have the same size as " "`controlled_values`."); @@ -569,6 +579,10 @@ class StateVectorLQubit : public StateVectorBase { const std::vector &controlled_values, const std::vector &wires, bool inverse = false) { + // Add disjoint check. + PL_ABORT_IF_NOT( + areVecsDisjoint(controlled_wires, wires), + "`controlled_wires` and `target wires` must be disjoint."); applyControlledMatrix(matrix.data(), controlled_wires, controlled_values, wires, inverse); } diff --git a/pennylane_lightning/core/src/simulators/lightning_qubit/gates/tests/Test_GateImplementations_Nonparam.cpp b/pennylane_lightning/core/src/simulators/lightning_qubit/gates/tests/Test_GateImplementations_Nonparam.cpp index 305d161f2e..afd371f7aa 100644 --- a/pennylane_lightning/core/src/simulators/lightning_qubit/gates/tests/Test_GateImplementations_Nonparam.cpp +++ b/pennylane_lightning/core/src/simulators/lightning_qubit/gates/tests/Test_GateImplementations_Nonparam.cpp @@ -674,6 +674,14 @@ TEMPLATE_TEST_CASE("StateVectorLQubitManaged::applyOperation non-param " << "controls = {" << control << "} " << ", wires = {" << wire << "} - " << PrecisionToName::value) { + if (control == wire) { + REQUIRE_THROWS_AS( + sv0.applyOperation("PauliX", std::vector{control}, + std::vector{true}, + std::vector{wire}), + LightningException); + } + if (control != wire) { auto st0 = createRandomStateVectorData(re, num_qubits); sv0.updateData(st0); @@ -687,7 +695,7 @@ TEMPLATE_TEST_CASE("StateVectorLQubitManaged::applyOperation non-param " approx(sv1.getDataVector()).margin(margin)); } - if (control != 0 && wire != 0) { + if (control != 0 && wire != 0 && control != wire) { sv0.applyOperation("Toffoli", {0, control, wire}); sv1.applyOperation("PauliX", std::vector{0, control}, std::vector{true, true}, @@ -760,6 +768,15 @@ TEMPLATE_TEST_CASE("StateVectorLQubitManaged::applyOperation non-param " REQUIRE(sv0.getDataVector() == approx(sv1.getDataVector()).margin(margin)); } + + if (control == wire) { + const auto matrix = getHadamard(); + REQUIRE_THROWS_AS(sv0.applyControlledMatrix( + matrix, std::vector{control}, + std::vector{true}, + std::vector{wire}), + LightningException); + } } DYNAMIC_SECTION("N-controlled S - " << "controls = {" << control << "} " diff --git a/pennylane_lightning/core/src/utils/Util.hpp b/pennylane_lightning/core/src/utils/Util.hpp index 5478cdbdcb..22d0a3b8de 100644 --- a/pennylane_lightning/core/src/utils/Util.hpp +++ b/pennylane_lightning/core/src/utils/Util.hpp @@ -574,4 +574,18 @@ std::vector cast_vector(const std::vector &vec) { return result; } +/** + * @brief Check if two vectors are disjoint. + * @tparam T Data type. + * @param v1 First vector. + * @param v2 Second vector. + * + * @return bool True if the vectors are disjoint, false otherwise. + */ +template +bool areVecsDisjoint(const std::vector &v1, const std::vector &v2) { + std::set s0(v1.begin(), v1.end()); + s0.insert(v2.begin(), v2.end()); + return s0.size() == v1.size() + v2.size(); +} } // namespace Pennylane::Util diff --git a/pennylane_lightning/core/src/utils/tests/Test_Util.cpp b/pennylane_lightning/core/src/utils/tests/Test_Util.cpp index 4f52318fdd..2ba66c779f 100644 --- a/pennylane_lightning/core/src/utils/tests/Test_Util.cpp +++ b/pennylane_lightning/core/src/utils/tests/Test_Util.cpp @@ -217,3 +217,19 @@ TEMPLATE_TEST_CASE("Util::kronProd", "[Util][LinearAlgebra]", float, double) { CHECK(vec == expected); } } + +TEST_CASE("Util::areVecsDisjoint", "[Util][LinearAlgebra]") { + SECTION("Test for disjoint vectors") { + std::vector vec0{0, 1, 2}; + std::vector vec1{3, 4, 5}; + + REQUIRE(areVecsDisjoint(vec0, vec1) == true); + } + + SECTION("Test for joint vectors") { + std::vector vec0{0, 1, 2}; + std::vector vec1{2, 4, 5}; + + REQUIRE(areVecsDisjoint(vec0, vec1) == false); + } +}