Skip to content

Commit

Permalink
Support smaller statically sized BitSets
Browse files Browse the repository at this point in the history
  • Loading branch information
chriselrod committed Feb 15, 2024
1 parent 2a408f7 commit f662145
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 58 deletions.
119 changes: 61 additions & 58 deletions include/Containers/BitSets.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#pragma once
#include "Math/Array.hpp"
#include "Utilities/Invariant.hpp"
#include "Utilities/TypePromotion.hpp"
#include <bit>
#include <concepts>
#include <cstddef>
#include <cstdint>
#include <cstdio>
Expand All @@ -10,6 +13,7 @@
#include <string>

namespace poly::containers {
using utils::invariant;

struct EndSentinel {
[[nodiscard]] constexpr auto operator-(auto it) -> ptrdiff_t {
Expand All @@ -24,16 +28,15 @@ struct EndSentinel {
/// Satisfied when `t.resize(0)` is a valid expression for an object `t` of
/// type `T`. Used below to choose between resizing the storage and merely
/// checking that it is large enough.
template <typename T>
concept CanResize = requires(T t) { t.resize(0); };

class BitSetIterator {
[[no_unique_address]] const uint64_t *it;
[[no_unique_address]] const uint64_t *end;
[[no_unique_address]] uint64_t istate;
template <std::unsigned_integral U> class BitSetIterator {
[[no_unique_address]] const U *it;
[[no_unique_address]] const U *end;
[[no_unique_address]] U istate;
[[no_unique_address]] ptrdiff_t cstate0{-1};
[[no_unique_address]] ptrdiff_t cstate1{0};

public:
constexpr explicit BitSetIterator(const uint64_t *_it, const uint64_t *_end,
uint64_t _istate)
constexpr explicit BitSetIterator(const U *_it, const U *_end, U _istate)
: it{_it}, end{_end}, istate{_istate} {}
using value_type = ptrdiff_t;
using difference_type = ptrdiff_t;
Expand All @@ -43,7 +46,7 @@ class BitSetIterator {
if (++it == end) return *this;
istate = *it;
cstate0 = -1;
cstate1 += 64;
cstate1 += 8 * sizeof(U);
}
ptrdiff_t tzp1 = std::countr_zero(istate);
cstate0 += ++tzp1;
Expand Down Expand Up @@ -84,6 +87,10 @@ concept Collection = requires(T t) {
/// A set of `ptrdiff_t` elements.
/// Initially constructed
template <Collection T = math::Vector<uint64_t, 1>> struct BitSet {
/// Word type, obtained from `utils::eltype_t<T>`; the methods below treat it
/// as the element type of `data`.
using U = utils::eltype_t<T>;
/// Number of bits in one `U` word.
static constexpr U usize = 8 * sizeof(U);
/// Mask selecting the bit position within a word (`usize - 1`).
static constexpr U umask = usize - 1;
/// Shift that turns a bit index into a word index: `countr_zero(usize)`,
/// which equals log2(usize) when `usize` is a power of two.
static constexpr U ushift = std::countr_zero(usize);
[[no_unique_address]] T data{};
// ptrdiff_t operator[](ptrdiff_t i) const {
// return data[i];
Expand All @@ -92,53 +99,53 @@ template <Collection T = math::Vector<uint64_t, 1>> struct BitSet {
/// Construct from an existing backing collection by moving it into `data`.
constexpr explicit BitSet(T &&_data) : data{std::move(_data)} {}
/// Construct from an existing backing collection by copying it into `data`.
constexpr explicit BitSet(const T &_data) : data{_data} {}
/// Number of storage words needed to hold `N` bits, rounded up.
/// NOTE(review): this is a diff view; the `63`/`6` line appears to be the
/// pre-change (64-bit only) version of the `usize`/`ushift` line after it.
static constexpr auto numElementsNeeded(ptrdiff_t N) -> unsigned {
return unsigned(((N + 63) >> 6));
return unsigned(((N + usize - 1) >> ushift));
}
/// Construct with room for `N` bits: `data` is brace-initialized from
/// `{numElementsNeeded(N), 0}` — presumably a word count and a zero fill
/// value; confirm against `T`'s constructors.
constexpr explicit BitSet(ptrdiff_t N) : data{numElementsNeeded(N), 0} {}
/// Resize the backing storage to `N` *words* (not bits). When `T` cannot be
/// resized, only checks that `N` does not exceed the number of words in `data`.
/// NOTE(review): this is a diff view; `resize64` appears to be the pre-change
/// name of `resizeData`.
constexpr void resize64(ptrdiff_t N) {
constexpr void resizeData(ptrdiff_t N) {
if constexpr (CanResize<T>) data.resize(N);
else invariant(N <= std::ssize(data));
}
/// Size the storage for `N` bits.
///
/// When `T` is resizable, `data` is resized to exactly
/// `numElementsNeeded(N)` words. Otherwise the storage is left untouched and
/// we only check that it can hold `N` bits.
constexpr void resize(ptrdiff_t N) {
  if constexpr (CanResize<T>) data.resize(numElementsNeeded(N));
  // Cast `usize` so the comparison stays in signed arithmetic for every
  // word type, matching the check in `maybeResize`.
  else invariant(N <= std::ssize(data) * ptrdiff_t(usize));
}
constexpr void resize(ptrdiff_t N, uint64_t x) {
constexpr void resize(ptrdiff_t N, U x) {
if constexpr (CanResize<T>) data.resize(numElementsNeeded(N), x);
else {
invariant(N <= std::ssize(data) * 64);
invariant(N <= std::ssize(data) * usize);
std::fill(data.begin(), data.end(), x);
}
}
/// Grow the storage, if needed, so it can hold `N` bits; never shrinks.
/// When `T` cannot be resized, only checks that `N` bits fit.
/// NOTE(review): this is a diff view; the `* 64` line appears to be the
/// pre-change version of the `* ptrdiff_t(usize)` line after it.
constexpr void maybeResize(ptrdiff_t N) {
if constexpr (CanResize<T>) {
ptrdiff_t M = numElementsNeeded(N);
if (M > std::ssize(data)) data.resize(M);
} else invariant(N <= std::ssize(data) * 64);
} else invariant(N <= std::ssize(data) * ptrdiff_t(usize));
}
/// Build a set containing every index in `[0, N)`.
///
/// All `numElementsNeeded(N)` words are filled with ones; if `N` is not a
/// multiple of the word size, the final word is then replaced by a mask with
/// only its low `N % usize` bits set. Returns an empty set when `N` needs no
/// words.
static constexpr auto dense(ptrdiff_t N) -> BitSet {
  BitSet b;
  ptrdiff_t M = numElementsNeeded(N);
  if (!M) return b;
  constexpr U maxval = std::numeric_limits<U>::max();
  if constexpr (CanResize<T>) b.data.resize(M, maxval);
  else {
    invariant(M <= std::ssize(b.data));
    // Fill all M words, including the last: when N is an exact multiple of
    // the word size no partial mask is applied below, so the final word must
    // already be all ones.
    for (ptrdiff_t i = 0; i < M; ++i) b.data[i] = maxval;
  }
  // The remainder is taken with `umask` (usize - 1), not `usize`; masking
  // with `usize` yields only 0 or `usize` and would shift by the full word
  // width. The mask is built in `U` to avoid signed overflow at the top bit.
  if (ptrdiff_t rem = N & ptrdiff_t(umask))
    b.data[M - 1] = U((U(1) << rem) - 1);
  return b;
}
/// Returns 0 when there are no words. Otherwise returns the total bit
/// capacity minus the number of leading zero bits of the *last* word, i.e.
/// one past the index of the highest set bit in that word.
/// NOTE(review): this is a diff view; the `64 *` line appears to be the
/// pre-change version of the `usize *` line after it.
[[nodiscard]] constexpr auto maxValue() const -> ptrdiff_t {
ptrdiff_t N = std::ssize(data);
return N ? (64 * N - std::countl_zero(data[N - 1])) : 0;
return N ? (usize * N - std::countl_zero(data[N - 1])) : 0;
}
// BitSet::Iterator(std::vector<std::uint64_t> &seta)
// BitSet::Iterator(std::vector<U> &seta)
// : set(seta), didx(0), offset(0), state(seta[0]), count(0) {};
/// Iterator over the set. With no words, returns an iterator whose current
/// and end pointers are equal and whose state is 0. Otherwise builds an
/// iterator with the first word loaded and advances it once with `++it`
/// before returning (presumably landing on the first set bit — see
/// `BitSetIterator::operator++`).
/// NOTE(review): this is a diff view; the `uint64_t`/untemplated lines appear
/// to be the pre-change versions of the `U` lines that follow them.
[[nodiscard]] constexpr auto begin() const -> BitSetIterator {
[[nodiscard]] constexpr auto begin() const -> BitSetIterator<U> {
auto be = data.begin();
auto de = data.end();
const uint64_t *b{be};
const uint64_t *e{de};
if (b == e) return BitSetIterator{b, e, 0};
const U *b{be};
const U *e{de};
if (b == e) return BitSetIterator<U>{b, e, 0};
BitSetIterator it{b, e, *b};
return ++it;
}
Expand All @@ -147,73 +154,69 @@ template <Collection T = math::Vector<uint64_t, 1>> struct BitSet {
};
/// Index of the lowest set bit: scans the words in order and, for the first
/// non-zero one, returns its base index plus its trailing-zero count.
/// Returns `std::numeric_limits<ptrdiff_t>::max()` if every word is zero.
/// NOTE(review): this is a diff view; the `64 * i` line appears to be the
/// pre-change version of the `usize * i` line after it.
[[nodiscard]] constexpr auto front() const -> ptrdiff_t {
for (ptrdiff_t i = 0; i < std::ssize(data); ++i)
if (data[i]) return 64 * i + std::countr_zero(data[i]);
if (data[i]) return usize * i + std::countr_zero(data[i]);
return std::numeric_limits<ptrdiff_t>::max();
}
/// Membership test on a raw word view. Returns 0 if `data` is empty;
/// otherwise returns the word holding bit `x` masked down to that single bit,
/// so the result is non-zero exactly when `x` is present (it is not
/// normalized to 0/1).
/// NOTE(review): only emptiness is checked — the word index `d` is not
/// compared against the length of `data` before indexing.
/// NOTE(review): this is a diff view; the `uint64_t`/`6`/`63` lines appear to
/// be the pre-change versions of the `U`/`ushift`/`umask` lines.
static constexpr auto contains(math::PtrVector<uint64_t> data, ptrdiff_t x)
-> uint64_t {
static constexpr auto contains(math::PtrVector<U> data, ptrdiff_t x) -> U {
if (data.empty()) return 0;
ptrdiff_t d = x >> ptrdiff_t(6);
uint64_t r = uint64_t(x) & uint64_t(63);
uint64_t mask = uint64_t(1) << r;
ptrdiff_t d = x >> ptrdiff_t(ushift);
U r = U(x) & umask;
U mask = U(1) << r;
return (data[d] & (mask));
}
/// Membership test for index `i` on this set; forwards to the static
/// `contains(data, i)` and returns its (non-zero if present) result.
/// NOTE(review): this is a diff view; the `-> uint64_t` line appears to be
/// the pre-change version of the `-> U` line after it.
[[nodiscard]] constexpr auto contains(ptrdiff_t i) const -> uint64_t {
[[nodiscard]] constexpr auto contains(ptrdiff_t i) const -> U {
return contains(data, i);
}
/// Callable membership predicate: holds a reference to a backing collection
/// and, when called with an index, forwards to `contains(d, i)`.
/// It stores a reference, so it must not outlive the collection it refers to.
/// NOTE(review): this is a diff view; the three-line `-> uint64_t` operator
/// appears to be the pre-change version of the one-line `-> U` operator.
struct Contains {
const T &d;
constexpr auto operator()(ptrdiff_t i) const -> uint64_t {
return contains(d, i);
}
constexpr auto operator()(ptrdiff_t i) const -> U { return contains(d, i); }
};
/// Returns a `Contains` predicate bound by reference to this set's `data`.
[[nodiscard]] constexpr auto contains() const -> Contains {
return Contains{data};
}
/// Insert index `x`, first calling the word-level resize if the target word
/// lies beyond the current storage.
/// @return `true` if `x` was *already* present before the call, `false` if
///         this call added it.
/// NOTE(review): this is a diff view; the `uint64_t`/`6`/`63`/`resize64` lines
/// appear to be the pre-change versions of the `U`/`ushift`/`umask`/
/// `resizeData` lines.
constexpr auto insert(ptrdiff_t x) -> bool {
ptrdiff_t d = x >> ptrdiff_t(6);
uint64_t r = uint64_t(x) & uint64_t(63);
uint64_t mask = uint64_t(1) << r;
if (d >= std::ssize(data)) resize64(d + 1);
ptrdiff_t d = x >> ptrdiff_t(ushift);
U r = U(x) & umask;
U mask = U(1) << r;
if (d >= std::ssize(data)) resizeData(d + 1);
bool contained = ((data[d] & mask) != 0);
if (!contained) data[d] |= (mask);
return contained;
}
/// Set bit `x` without testing or reporting whether it was already set.
/// The storage is still resized (word-level) if the target word lies beyond
/// it, exactly as in `insert`.
/// NOTE(review): this is a diff view; the `uint64_t`/`6`/`63`/`resize64` lines
/// appear to be the pre-change versions of the `U`/`ushift`/`umask`/
/// `resizeData` lines.
constexpr void uncheckedInsert(ptrdiff_t x) {
ptrdiff_t d = x >> ptrdiff_t(6);
uint64_t r = uint64_t(x) & uint64_t(63);
uint64_t mask = uint64_t(1) << r;
if (d >= std::ssize(data)) resize64(d + 1);
ptrdiff_t d = x >> ushift;
U r = U(x) & umask;
U mask = U(1) << r;
if (d >= std::ssize(data)) resizeData(d + 1);
data[d] |= (mask);
}

constexpr auto remove(ptrdiff_t x) -> bool {
ptrdiff_t d = x >> ptrdiff_t(6);
uint64_t r = uint64_t(x) & uint64_t(63);
uint64_t mask = uint64_t(1) << r;
ptrdiff_t d = x >> ushift;
U r = U(x) & umask;
U mask = U(1) << r;
bool contained = ((data[d] & mask) != 0);
if (contained) data[d] &= (~mask);
return contained;
}
/// Set (`b == true`) or clear (`b == false`) bit `r` of the single word `d`.
/// Returns without writing when the bit already has the requested value.
/// NOTE(review): this is a diff view; the `uint64_t` lines appear to be the
/// pre-change versions of the `U` lines after them.
static constexpr void set(uint64_t &d, ptrdiff_t r, bool b) {
uint64_t mask = uint64_t(1) << r;
static constexpr void set(U &d, ptrdiff_t r, bool b) {
U mask = U(1) << r;
if (b == ((d & mask) != 0)) return;
if (b) d |= mask;
else d &= (~mask);
}
/// Set or clear bit `x` in a raw mutable word view: splits `x` into a word
/// index and a bit position, then forwards to the single-word `set`.
/// NOTE(review): the word index is not checked against the length of `data`.
/// NOTE(review): this is a diff view; the `uint64_t`/`6`/`63` lines appear to
/// be the pre-change versions of the `U`/`ushift`/`umask` lines.
static constexpr void set(math::MutPtrVector<uint64_t> data, ptrdiff_t x,
bool b) {
ptrdiff_t d = x >> ptrdiff_t(6);
uint64_t r = uint64_t(x) & uint64_t(63);
static constexpr void set(math::MutPtrVector<U> data, ptrdiff_t x, bool b) {
ptrdiff_t d = x >> ushift;
U r = U(x) & umask;
set(data[d], r, b);
}

class Reference {
[[no_unique_address]] math::MutPtrVector<uint64_t> data;
[[no_unique_address]] math::MutPtrVector<U> data;
[[no_unique_address]] ptrdiff_t i;

public:
constexpr explicit Reference(math::MutPtrVector<uint64_t> dd, ptrdiff_t ii)
constexpr explicit Reference(math::MutPtrVector<U> dd, ptrdiff_t ii)
: data(dd), i(ii) {}
constexpr explicit operator bool() const { return contains(data, i); }
constexpr auto operator=(bool b) -> Reference & {
Expand All @@ -227,7 +230,7 @@ template <Collection T = math::Vector<uint64_t, 1>> struct BitSet {
}
/// Mutable element access: first calls `maybeResize(i + 1)` so bit `i` fits,
/// then returns a `Reference` proxy built from a mutable view of `data` and
/// the index `i`.
/// NOTE(review): this is a diff view; the `MutPtrVector<uint64_t>` line
/// appears to be the pre-change version of the `MutPtrVector<U>` line.
constexpr auto operator[](ptrdiff_t i) -> Reference {
maybeResize(i + 1);
math::MutPtrVector<uint64_t> d{data};
math::MutPtrVector<U> d{data};
return Reference{d, i};
}
[[nodiscard]] constexpr auto size() const -> ptrdiff_t {
Expand All @@ -246,25 +249,25 @@ template <Collection T = math::Vector<uint64_t, 1>> struct BitSet {
}
/// In-place union with `bs`: grows this set's storage (word-level) if `bs`
/// has more words, then ORs each of `bs`'s words into the matching word here.
/// NOTE(review): this is a diff view; the `resize64`/`uint64_t` lines appear
/// to be the pre-change versions of the `resizeData`/`U` lines.
constexpr void setUnion(const BitSet &bs) {
ptrdiff_t O = std::ssize(bs.data), N = std::ssize(data);
if (O > N) resize64(O);
if (O > N) resizeData(O);
for (ptrdiff_t i = 0; i < O; ++i) {
uint64_t d = data[i] | bs.data[i];
U d = data[i] | bs.data[i];
data[i] = d;
}
}
/// In-place intersection with `bs`. If `bs` has fewer words, this set's
/// storage is first cut down to that length via the word-level resize
/// (indices beyond `bs`'s storage cannot be in the intersection); then each
/// remaining word is ANDed with the matching word of `bs`.
/// NOTE(review): this is a diff view; the `resize64` line appears to be the
/// pre-change version of the `resizeData` line after it.
constexpr auto operator&=(const BitSet &bs) -> BitSet & {
if (std::ssize(bs.data) < std::ssize(data)) resize64(std::ssize(bs.data));
if (std::ssize(bs.data) < std::ssize(data)) resizeData(std::ssize(bs.data));
for (ptrdiff_t i = 0; i < std::ssize(data); ++i) data[i] &= bs.data[i];
return *this;
}
// &!
/// In-place set difference: removes from this set every index that is
/// present in `bs`.
///
/// Only the words both sets have in common are touched. Words of this set
/// beyond the end of `bs`'s storage are kept as they are: `bs` cannot
/// contain those indices, so they all survive the difference. (Shrinking the
/// storage to `bs`'s length, as intersection does, would wrongly drop them.)
constexpr auto operator-=(const BitSet &bs) -> BitSet & {
  ptrdiff_t N = std::ssize(data), O = std::ssize(bs.data);
  ptrdiff_t M = O < N ? O : N;
  for (ptrdiff_t i = 0; i < M; ++i) data[i] &= (~bs.data[i]);
  return *this;
}
/// In-place union with `bs`: grows this set's storage (word-level) if `bs`
/// has more words, then ORs each of `bs`'s words into the matching word here.
/// NOTE(review): this is a diff view; the `resize64` line appears to be the
/// pre-change version of the `resizeData` line after it.
constexpr auto operator|=(const BitSet &bs) -> BitSet & {
if (std::ssize(bs.data) > std::ssize(data)) resize64(std::ssize(bs.data));
if (std::ssize(bs.data) > std::ssize(data)) resizeData(std::ssize(bs.data));
for (ptrdiff_t i = 0; i < std::ssize(bs.data); ++i) data[i] |= bs.data[i];
return *this;
}
Expand Down Expand Up @@ -315,7 +318,7 @@ template <typename T, typename B = BitSet<>> struct BitSliceView {
[[no_unique_address]] const B &i;
struct Iterator {
[[no_unique_address]] math::MutPtrVector<T> a;
[[no_unique_address]] BitSetIterator it;
[[no_unique_address]] BitSetIterator<uint64_t> it;
constexpr auto operator==(EndSentinel) const -> bool {
return it == EndSentinel{};
}
Expand All @@ -336,7 +339,7 @@ template <typename T, typename B = BitSet<>> struct BitSliceView {
constexpr auto begin() -> Iterator { return {a, i.begin()}; }
struct ConstIterator {
[[no_unique_address]] math::PtrVector<T> a;
[[no_unique_address]] BitSetIterator it;
[[no_unique_address]] BitSetIterator<uint64_t> it;
constexpr auto operator==(EndSentinel) const -> bool {
return it == EndSentinel{};
}
Expand Down
17 changes: 17 additions & 0 deletions test/bitset_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,20 @@ TEST(FixedSizeBitSetTest, BasicAssertions) {
EXPECT_EQ(sv[0], 4);
EXPECT_EQ(sv[1], 10);
}
// NOLINTNEXTLINE(modernize-use-trailing-return-type)
TEST(FixedSizeSmallBitSetTest, BasicAssertions) {
// A bit set backed by a single 16-bit word must occupy exactly two bytes.
static_assert(sizeof(BitSet<std::array<uint16_t, 1>>) == 2);
BitSet<std::array<uint16_t, 1>> bs;
// Exercise both insertion paths: the `operator[]` proxy and `insert`.
bs[4] = true;
bs[10] = true;
bs[7] = true;
bs.insert(5);
// Bits 4, 5, 7 and 10 set: 16 + 32 + 128 + 1024 == 1200.
EXPECT_EQ(bs.data[0], 1200);
// Iteration must yield the set indices in increasing order.
Vector<size_t> sv;
for (auto i : bs) sv.push_back(i);
EXPECT_EQ(sv.size(), 4);
EXPECT_EQ(sv[0], 4);
EXPECT_EQ(sv[1], 5);
EXPECT_EQ(sv[2], 7);
EXPECT_EQ(sv[3], 10);
}

0 comments on commit f662145

Please sign in to comment.