From f6621451a74fc0cb40be031cff0a7ed924e588a1 Mon Sep 17 00:00:00 2001 From: Chris Elrod Date: Thu, 15 Feb 2024 01:58:26 -0500 Subject: [PATCH] Support smaller statically sized `BitSet`s --- include/Containers/BitSets.hpp | 119 +++++++++++++++++---------------- test/bitset_test.cpp | 17 +++++ 2 files changed, 78 insertions(+), 58 deletions(-) diff --git a/include/Containers/BitSets.hpp b/include/Containers/BitSets.hpp index 7cc81be..8c941c9 100644 --- a/include/Containers/BitSets.hpp +++ b/include/Containers/BitSets.hpp @@ -1,6 +1,9 @@ #pragma once #include "Math/Array.hpp" #include "Utilities/Invariant.hpp" +#include "Utilities/TypePromotion.hpp" +#include +#include #include #include #include @@ -10,6 +13,7 @@ #include namespace poly::containers { +using utils::invariant; struct EndSentinel { [[nodiscard]] constexpr auto operator-(auto it) -> ptrdiff_t { @@ -24,16 +28,15 @@ struct EndSentinel { template concept CanResize = requires(T t) { t.resize(0); }; -class BitSetIterator { - [[no_unique_address]] const uint64_t *it; - [[no_unique_address]] const uint64_t *end; - [[no_unique_address]] uint64_t istate; +template class BitSetIterator { + [[no_unique_address]] const U *it; + [[no_unique_address]] const U *end; + [[no_unique_address]] U istate; [[no_unique_address]] ptrdiff_t cstate0{-1}; [[no_unique_address]] ptrdiff_t cstate1{0}; public: - constexpr explicit BitSetIterator(const uint64_t *_it, const uint64_t *_end, - uint64_t _istate) + constexpr explicit BitSetIterator(const U *_it, const U *_end, U _istate) : it{_it}, end{_end}, istate{_istate} {} using value_type = ptrdiff_t; using difference_type = ptrdiff_t; @@ -43,7 +46,7 @@ class BitSetIterator { if (++it == end) return *this; istate = *it; cstate0 = -1; - cstate1 += 64; + cstate1 += 8 * sizeof(U); } ptrdiff_t tzp1 = std::countr_zero(istate); cstate0 += ++tzp1; @@ -84,6 +87,10 @@ concept Collection = requires(T t) { /// A set of `ptrdiff_t` elements. /// Initially constructed template > struct BitSet { + using U = utils::eltype_t; + static constexpr U usize = 8 * sizeof(U); + static constexpr U umask = usize - 1; + static constexpr U ushift = std::countr_zero(usize); [[no_unique_address]] T data{}; // ptrdiff_t operator[](ptrdiff_t i) const { // return data[i]; @@ -92,21 +99,21 @@ template > struct BitSet { constexpr explicit BitSet(T &&_data) : data{std::move(_data)} {} constexpr explicit BitSet(const T &_data) : data{_data} {} static constexpr auto numElementsNeeded(ptrdiff_t N) -> unsigned { - return unsigned(((N + 63) >> 6)); + return unsigned(((N + usize - 1) >> ushift)); } constexpr explicit BitSet(ptrdiff_t N) : data{numElementsNeeded(N), 0} {} - constexpr void resize64(ptrdiff_t N) { + constexpr void resizeData(ptrdiff_t N) { if constexpr (CanResize) data.resize(N); else invariant(N <= std::ssize(data)); } constexpr void resize(ptrdiff_t N) { if constexpr (CanResize) data.resize(numElementsNeeded(N)); - else invariant(N <= std::ssize(data) * 64); + else invariant(N <= std::ssize(data) * usize); } - constexpr void resize(ptrdiff_t N, uint64_t x) { + constexpr void resize(ptrdiff_t N, U x) { if constexpr (CanResize) data.resize(numElementsNeeded(N), x); else { - invariant(N <= std::ssize(data) * 64); + invariant(N <= std::ssize(data) * usize); std::fill(data.begin(), data.end(), x); } } @@ -114,31 +121,31 @@ template > struct BitSet { if constexpr (CanResize) { ptrdiff_t M = numElementsNeeded(N); if (M > std::ssize(data)) data.resize(M); - } else invariant(N <= std::ssize(data) * 64); + } else invariant(N <= std::ssize(data) * ptrdiff_t(usize)); } static constexpr auto dense(ptrdiff_t N) -> BitSet { BitSet b; ptrdiff_t M = numElementsNeeded(N); if (!M) return b; - uint64_t maxval = std::numeric_limits::max(); + U maxval = std::numeric_limits::max(); if constexpr (CanResize) b.data.resize(M, maxval); else for (ptrdiff_t i = 0; i < M - 1; ++i) b.data[i] = maxval; - if (ptrdiff_t rem = N & 63) b.data[M - 1] = (ptrdiff_t(1) << rem) - 1; + if (ptrdiff_t rem = N & usize) b.data[M - 1] = (ptrdiff_t(1) << rem) - 1; return b; } [[nodiscard]] constexpr auto maxValue() const -> ptrdiff_t { ptrdiff_t N = std::ssize(data); - return N ? (64 * N - std::countl_zero(data[N - 1])) : 0; + return N ? (usize * N - std::countl_zero(data[N - 1])) : 0; } - // BitSet::Iterator(std::vector &seta) + // BitSet::Iterator(std::vector &seta) // : set(seta), didx(0), offset(0), state(seta[0]), count(0) {}; - [[nodiscard]] constexpr auto begin() const -> BitSetIterator { + [[nodiscard]] constexpr auto begin() const -> BitSetIterator { auto be = data.begin(); auto de = data.end(); - const uint64_t *b{be}; - const uint64_t *e{de}; - if (b == e) return BitSetIterator{b, e, 0}; + const U *b{be}; + const U *e{de}; + if (b == e) return BitSetIterator{b, e, 0}; BitSetIterator it{b, e, *b}; return ++it; } @@ -147,73 +154,69 @@ template > struct BitSet { }; [[nodiscard]] constexpr auto front() const -> ptrdiff_t { for (ptrdiff_t i = 0; i < std::ssize(data); ++i) - if (data[i]) return 64 * i + std::countr_zero(data[i]); + if (data[i]) return usize * i + std::countr_zero(data[i]); return std::numeric_limits::max(); } - static constexpr auto contains(math::PtrVector data, ptrdiff_t x) - -> uint64_t { + static constexpr auto contains(math::PtrVector data, ptrdiff_t x) -> U { if (data.empty()) return 0; - ptrdiff_t d = x >> ptrdiff_t(6); - uint64_t r = uint64_t(x) & uint64_t(63); - uint64_t mask = uint64_t(1) << r; + ptrdiff_t d = x >> ptrdiff_t(ushift); + U r = U(x) & umask; + U mask = U(1) << r; return (data[d] & (mask)); } - [[nodiscard]] constexpr auto contains(ptrdiff_t i) const -> uint64_t { + [[nodiscard]] constexpr auto contains(ptrdiff_t i) const -> U { return contains(data, i); } struct Contains { const T &d; - constexpr auto operator()(ptrdiff_t i) const -> uint64_t { - return contains(d, i); - } + constexpr auto operator()(ptrdiff_t i) const -> U { return contains(d, i); } }; [[nodiscard]] constexpr auto contains() const -> Contains { return Contains{data}; } constexpr auto insert(ptrdiff_t x) -> bool { - ptrdiff_t d = x >> ptrdiff_t(6); - uint64_t r = uint64_t(x) & uint64_t(63); - uint64_t mask = uint64_t(1) << r; - if (d >= std::ssize(data)) resize64(d + 1); + ptrdiff_t d = x >> ptrdiff_t(ushift); + U r = U(x) & umask; + U mask = U(1) << r; + if (d >= std::ssize(data)) resizeData(d + 1); bool contained = ((data[d] & mask) != 0); if (!contained) data[d] |= (mask); return contained; } constexpr void uncheckedInsert(ptrdiff_t x) { - ptrdiff_t d = x >> ptrdiff_t(6); - uint64_t r = uint64_t(x) & uint64_t(63); - uint64_t mask = uint64_t(1) << r; - if (d >= std::ssize(data)) resize64(d + 1); + ptrdiff_t d = x >> ushift; + U r = U(x) & umask; + U mask = U(1) << r; + if (d >= std::ssize(data)) resizeData(d + 1); data[d] |= (mask); } constexpr auto remove(ptrdiff_t x) -> bool { - ptrdiff_t d = x >> ptrdiff_t(6); - uint64_t r = uint64_t(x) & uint64_t(63); - uint64_t mask = uint64_t(1) << r; + ptrdiff_t d = x >> ushift; + U r = U(x) & umask; + U mask = U(1) << r; bool contained = ((data[d] & mask) != 0); if (contained) data[d] &= (~mask); return contained; } - static constexpr void set(uint64_t &d, ptrdiff_t r, bool b) { - uint64_t mask = uint64_t(1) << r; + static constexpr void set(U &d, ptrdiff_t r, bool b) { + U mask = U(1) << r; if (b == ((d & mask) != 0)) return; if (b) d |= mask; else d &= (~mask); } - static constexpr void set(math::MutPtrVector data, ptrdiff_t x, - bool b) { - ptrdiff_t d = x >> ptrdiff_t(6); - uint64_t r = uint64_t(x) & uint64_t(63); + static constexpr void set(math::MutPtrVector data, ptrdiff_t x, bool b) { + ptrdiff_t d = x >> ushift; + U r = U(x) & umask; set(data[d], r, b); } class Reference { - [[no_unique_address]] math::MutPtrVector data; + [[no_unique_address]] math::MutPtrVector data; [[no_unique_address]] ptrdiff_t i; public: - constexpr explicit Reference(math::MutPtrVector dd, ptrdiff_t ii) + constexpr explicit Reference(math::MutPtrVector dd, ptrdiff_t ii) : data(dd), i(ii) {} constexpr explicit operator bool() const { return contains(data, i); } constexpr auto operator=(bool b) -> Reference & { @@ -227,7 +230,7 @@ template > struct BitSet { } constexpr auto operator[](ptrdiff_t i) -> Reference { maybeResize(i + 1); - math::MutPtrVector d{data}; + math::MutPtrVector d{data}; return Reference{d, i}; } [[nodiscard]] constexpr auto size() const -> ptrdiff_t { @@ -246,25 +249,25 @@ template > struct BitSet { } constexpr void setUnion(const BitSet &bs) { ptrdiff_t O = std::ssize(bs.data), N = std::ssize(data); - if (O > N) resize64(O); + if (O > N) resizeData(O); for (ptrdiff_t i = 0; i < O; ++i) { - uint64_t d = data[i] | bs.data[i]; + U d = data[i] | bs.data[i]; data[i] = d; } } constexpr auto operator&=(const BitSet &bs) -> BitSet & { - if (std::ssize(bs.data) < std::ssize(data)) resize64(std::ssize(bs.data)); + if (std::ssize(bs.data) < std::ssize(data)) resizeData(std::ssize(bs.data)); for (ptrdiff_t i = 0; i < std::ssize(data); ++i) data[i] &= bs.data[i]; return *this; } // &! constexpr auto operator-=(const BitSet &bs) -> BitSet & { - if (std::ssize(bs.data) < std::ssize(data)) resize64(std::ssize(bs.data)); + if (std::ssize(bs.data) < std::ssize(data)) resizeData(std::ssize(bs.data)); for (ptrdiff_t i = 0; i < std::ssize(data); ++i) data[i] &= (~bs.data[i]); return *this; } constexpr auto operator|=(const BitSet &bs) -> BitSet & { - if (std::ssize(bs.data) > std::ssize(data)) resize64(std::ssize(bs.data)); + if (std::ssize(bs.data) > std::ssize(data)) resizeData(std::ssize(bs.data)); for (ptrdiff_t i = 0; i < std::ssize(bs.data); ++i) data[i] |= bs.data[i]; return *this; } @@ -315,7 +318,7 @@ template > struct BitSliceView { [[no_unique_address]] const B &i; struct Iterator { [[no_unique_address]] math::MutPtrVector a; - [[no_unique_address]] BitSetIterator it; + [[no_unique_address]] BitSetIterator it; constexpr auto operator==(EndSentinel) const -> bool { return it == EndSentinel{}; } @@ -336,7 +339,7 @@ template > struct BitSliceView { constexpr auto begin() -> Iterator { return {a, i.begin()}; } struct ConstIterator { [[no_unique_address]] math::PtrVector a; - [[no_unique_address]] BitSetIterator it; + [[no_unique_address]] BitSetIterator it; constexpr auto operator==(EndSentinel) const -> bool { return it == EndSentinel{}; } diff --git a/test/bitset_test.cpp b/test/bitset_test.cpp index 4d29902..0e28a3a 100644 --- a/test/bitset_test.cpp +++ b/test/bitset_test.cpp @@ -82,3 +82,20 @@ TEST(FixedSizeBitSetTest, BasicAssertions) { EXPECT_EQ(sv[0], 4); EXPECT_EQ(sv[1], 10); } +// NOLINTNEXTLINE(modernize-use-trailing-return-type) +TEST(FixedSizeSmallBitSetTest, BasicAssertions) { + static_assert(sizeof(BitSet>) == 2); + BitSet> bs; + bs[4] = true; + bs[10] = true; + bs[7] = true; + bs.insert(5); + EXPECT_EQ(bs.data[0], 1200); + Vector sv; + for (auto i : bs) sv.push_back(i); + EXPECT_EQ(sv.size(), 4); + EXPECT_EQ(sv[0], 4); + EXPECT_EQ(sv[1], 5); + EXPECT_EQ(sv[2], 7); + EXPECT_EQ(sv[3], 10); +}