Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
paul0403 committed Oct 1, 2024
1 parent 0b943f1 commit f496df8
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ template <class StateVectorT, class Derived> class MeasurementsBase {
/**
* @brief Set the internal random generator to an already existing instance
*
* @param catalyst_rng Seed
* @param rng An already existing instance of a random number generator
*/
void setRNG(std::mt19937 rng) { this->rng = std::move(rng); }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ using Pennylane::Util::isApproxEqual;
} // namespace
/// @endcond
#include <algorithm>
#include <optional>
#include <random>
#include <string>

#ifdef _ENABLE_PLQUBIT
Expand Down Expand Up @@ -1251,7 +1253,9 @@ TEST_CASE("Var Shot- TensorProdObs", "[MeasurementsBase][Observables]") {
testTensorProdObsVarShot<TestStateVectorBackends>();
}
}
template <typename TypeList> void testSamples() {

template <typename TypeList> void testSamples(
const std::optional<std::mt19937>& rng = std::nullopt) {
if constexpr (!std::is_same_v<TypeList, void>) {
using StateVectorT = typename TypeList::Type;
using PrecisionT = typename StateVectorT::PrecisionT;
Expand Down Expand Up @@ -1281,7 +1285,8 @@ template <typename TypeList> void testSamples() {
std::size_t num_qubits = 3;
std::size_t N = std::pow(2, num_qubits);
std::size_t num_samples = 100000;
auto &&samples = Measurer.generate_samples(num_samples);
auto &&samples = rng.has_value() ?
Measurer.generate_samples(num_samples, rng.value()) : Measurer.generate_samples(num_samples);

std::vector<std::size_t> counts(N, 0);
std::vector<std::size_t> samples_decimal(num_samples, 0);
Expand All @@ -1307,7 +1312,7 @@ template <typename TypeList> void testSamples() {
REQUIRE_THAT(probabilities,
Catch::Approx(expected_probabilities).margin(.05));
}
testSamples<typename TypeList::Next>();
testSamples<typename TypeList::Next>(rng);
}
}

Expand All @@ -1317,6 +1322,13 @@ TEST_CASE("Samples", "[MeasurementsBase]") {
}
}

TEST_CASE("Seeded samples", "[MeasurementsBase]") {
if constexpr (BACKEND_FOUND) {
std::mt19937 rng(37.42);
testSamples<TestStateVectorBackends>(rng);
}
}

template <typename TypeList> void testSamplesCountsObs() {
if constexpr (!std::is_same_v<TypeList, void>) {
using StateVectorT = typename TypeList::Type;
Expand Down Expand Up @@ -1729,4 +1741,4 @@ TEST_CASE("Measure Shot - SparseHObs ", "[MeasurementsBase][Observables]") {
if constexpr (BACKEND_FOUND) {
testSparseHObsMeasureShot<TestStateVectorBackends>();
}
}
}

0 comments on commit f496df8

Please sign in to comment.