From f496df8943efbe80b738ce7c5e6e1da5ad0b110a Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Tue, 1 Oct 2024 11:43:47 -0400 Subject: [PATCH] test --- .../src/measurements/MeasurementsBase.hpp | 2 +- .../tests/Test_MeasurementsBase.cpp | 20 +++++++++++++++---- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/pennylane_lightning/core/src/measurements/MeasurementsBase.hpp b/pennylane_lightning/core/src/measurements/MeasurementsBase.hpp index 9c8912d6de..7cd8110f6b 100644 --- a/pennylane_lightning/core/src/measurements/MeasurementsBase.hpp +++ b/pennylane_lightning/core/src/measurements/MeasurementsBase.hpp @@ -86,7 +86,7 @@ template class MeasurementsBase { /** * @brief Set the internal random generator to an already existing instance * - * @param catalyst_rng Seed + * @param rng An already existing instance of a random number generator */ void setRNG(std::mt19937 rng) { this->rng = std::move(rng); } diff --git a/pennylane_lightning/core/src/measurements/tests/Test_MeasurementsBase.cpp b/pennylane_lightning/core/src/measurements/tests/Test_MeasurementsBase.cpp index 674659a9cc..389f16fb36 100644 --- a/pennylane_lightning/core/src/measurements/tests/Test_MeasurementsBase.cpp +++ b/pennylane_lightning/core/src/measurements/tests/Test_MeasurementsBase.cpp @@ -20,6 +20,8 @@ using Pennylane::Util::isApproxEqual; } // namespace /// @endcond #include +#include +#include #include #ifdef _ENABLE_PLQUBIT @@ -1251,7 +1253,9 @@ TEST_CASE("Var Shot- TensorProdObs", "[MeasurementsBase][Observables]") { testTensorProdObsVarShot(); } } -template void testSamples() { + +template void testSamples( + const std::optional& rng = std::nullopt) { if constexpr (!std::is_same_v) { using StateVectorT = typename TypeList::Type; using PrecisionT = typename StateVectorT::PrecisionT; @@ -1281,7 +1285,8 @@ template void testSamples() { std::size_t num_qubits = 3; std::size_t N = std::pow(2, num_qubits); std::size_t num_samples = 100000; - auto &&samples = Measurer.generate_samples(num_samples); + auto &&samples = rng.has_value() ? + Measurer.generate_samples(num_samples, rng.value()) : Measurer.generate_samples(num_samples); std::vector counts(N, 0); std::vector samples_decimal(num_samples, 0); @@ -1307,7 +1312,7 @@ template void testSamples() { REQUIRE_THAT(probabilities, Catch::Approx(expected_probabilities).margin(.05)); } - testSamples(); + testSamples(rng); } } @@ -1317,6 +1322,13 @@ TEST_CASE("Samples", "[MeasurementsBase]") { } } +TEST_CASE("Seeded samples", "[MeasurementsBase]") { + if constexpr (BACKEND_FOUND) { + std::mt19937 rng(37.42); + testSamples(rng); + } +} + template void testSamplesCountsObs() { if constexpr (!std::is_same_v) { using StateVectorT = typename TypeList::Type; @@ -1729,4 +1741,4 @@ TEST_CASE("Measure Shot - SparseHObs ", "[MeasurementsBase][Observables]") { if constexpr (BACKEND_FOUND) { testSparseHObsMeasureShot(); } -} \ No newline at end of file +}