From ce72cfefb1af1929a760cb831a10602ab22b1213 Mon Sep 17 00:00:00 2001 From: Chris Elrod Date: Fri, 16 Feb 2024 01:35:50 -0500 Subject: [PATCH] add BitSet::fromMask constructor, and have Printable concept also require countDigits be implemented (special casing `double`, where we have a print method that does not use it) --- include/Containers/BitSets.hpp | 1 + include/Math/Array.hpp | 83 +++++++++++++++++----------------- include/Math/ExpDual.hpp | 11 ++++- test/bitset_test.cpp | 6 ++- 4 files changed, 56 insertions(+), 45 deletions(-) diff --git a/include/Containers/BitSets.hpp b/include/Containers/BitSets.hpp index 8c941c9..61749da 100644 --- a/include/Containers/BitSets.hpp +++ b/include/Containers/BitSets.hpp @@ -102,6 +102,7 @@ template > struct BitSet { return unsigned(((N + usize - 1) >> ushift)); } constexpr explicit BitSet(ptrdiff_t N) : data{numElementsNeeded(N), 0} {} + static constexpr auto fromMask(U u) -> BitSet { return BitSet{T{u}}; } constexpr void resizeData(ptrdiff_t N) { if constexpr (CanResize) data.resize(N); else invariant(N <= std::ssize(data)); diff --git a/include/Math/Array.hpp b/include/Math/Array.hpp index 16d8946..f14e1ea 100644 --- a/include/Math/Array.hpp +++ b/include/Math/Array.hpp @@ -68,9 +68,50 @@ template >> struct ManagedArray; +template static constexpr auto maxPow10() -> size_t { + if constexpr (sizeof(T) == 1) return 3; + else if constexpr (sizeof(T) == 2) return 5; + else if constexpr (sizeof(T) == 4) return 10; + else if constexpr (std::signed_integral) return 19; + else return 20; +} + +template constexpr auto countDigits(T x) { + std::array() + 1> powers; + powers[0] = 0; + powers[1] = 10; + for (ptrdiff_t i = 2; i < std::ssize(powers); i++) + powers[i] = powers[i - 1] * 10; + std::array bits; + if constexpr (sizeof(T) == 8) { + bits = {1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, + 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 10, 10, + 11, 11, 11, 12, 12, 12, 13, 13, 13, 13, 14, 14, 14, 15, 15, 15, 16, + 16, 16, 16, 17, 17, 17, 18, 18, 18, 19, 19, 19, 19, 20}; + } else if constexpr (sizeof(T) == 4) { + bits = {1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, + 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 10}; + } else if constexpr (sizeof(T) == 2) { + bits = {1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5}; + } else if constexpr (sizeof(T) == 1) { + bits = {1, 1, 1, 1, 2, 2, 2, 3, 3}; + } + T digits = bits[8 * sizeof(T) - std::countl_zero(x)]; + return std::make_signed_t(digits - (x < powers[digits - 1])); +} +template constexpr auto countDigits(T x) -> T { + using U = std::make_unsigned_t; + if (x == std::numeric_limits::min()) return T(sizeof(T) == 8 ? 20 : 11); + return countDigits(U(std::abs(x))) + T{x < 0}; +} +constexpr auto countDigits(Rational x) -> ptrdiff_t { + ptrdiff_t num = countDigits(x.numerator); + return (x.denominator == 1) ? num : num + countDigits(x.denominator) + 2; +} template -concept Printable = requires(std::ostream &os, T x) { +concept Printable = std::same_as || requires(std::ostream &os, T x) { { os << x } -> std::convertible_to; + { countDigits(x) }; }; static_assert(Printable); void print_obj(std::ostream &os, Printable auto x) { os << x; }; @@ -1816,46 +1857,6 @@ inline auto operator<<(std::ostream &os, const T &A) -> std::ostream & { else B << A.t(); return printVector(os, B); } -template static constexpr auto maxPow10() -> size_t { - if constexpr (sizeof(T) == 1) return 3; - else if constexpr (sizeof(T) == 2) return 5; - else if constexpr (sizeof(T) == 4) return 10; - else if constexpr (std::signed_integral) return 19; - else return 20; -} - -template constexpr auto countDigits(T x) { - std::array() + 1> powers; - powers[0] = 0; - powers[1] = 10; - for (ptrdiff_t i = 2; i < std::ssize(powers); i++) - powers[i] = powers[i - 1] * 10; - std::array bits; - if constexpr (sizeof(T) == 8) { - bits = {1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, - 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 10, 10, - 11, 11, 11, 12, 12, 12, 13, 13, 13, 13, 14, 14, 14, 15, 15, 15, 16, - 16, 16, 16, 17, 17, 17, 18, 18, 18, 19, 19, 19, 19, 20}; - } else if constexpr (sizeof(T) == 4) { - bits = {1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, - 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 10}; - } else if constexpr (sizeof(T) == 2) { - bits = {1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5}; - } else if constexpr (sizeof(T) == 1) { - bits = {1, 1, 1, 1, 2, 2, 2, 3, 3}; - } - T digits = bits[8 * sizeof(T) - std::countl_zero(x)]; - return std::make_signed_t(digits - (x < powers[digits - 1])); -} -template constexpr auto countDigits(T x) { - using U = std::make_unsigned_t; - if (x == std::numeric_limits::min()) return T(sizeof(T) == 8 ? 20 : 11); - return countDigits(U(std::abs(x))) + T{x < 0}; -} -constexpr auto countDigits(Rational x) -> ptrdiff_t { - ptrdiff_t num = countDigits(x.numerator); - return (x.denominator == 1) ? num : num + countDigits(x.denominator) + 2; -} /// \brief Returns the maximum number of digits per column of a matrix. constexpr auto getMaxDigits(PtrMatrix A) -> Vector { ptrdiff_t M = ptrdiff_t(A.numRow()); diff --git a/include/Math/ExpDual.hpp b/include/Math/ExpDual.hpp index 39ef494..27dad0d 100644 --- a/include/Math/ExpDual.hpp +++ b/include/Math/ExpDual.hpp @@ -5,8 +5,15 @@ namespace poly::math { template constexpr auto smax(auto x, auto y, auto z) { double m = std::max(std::max(value(x), value(y)), value(z)); - constexpr double f = l, i = 1 / f; - return m + log(exp(f * (x - m)) + exp(f * (y - m)) + exp(f * (z - m))) * i; + static constexpr double f = l, i = 1 / f; + return m + i * log(exp(f * (x - m)) + exp(f * (y - m)) + exp(f * (z - m))); +} +template constexpr auto smax(auto w, auto x, auto y, auto z) { + double m = + std::max(std::max(value(w), value(y)), std::max(value(x), value(z))); + static constexpr double f = l, i = 1 / f; + return m + i * log(exp(f * (w - m)) + exp(f * (x - m)) + exp(f * (y - m)) + + exp(f * (z - m))); } } // namespace poly::math diff --git a/test/bitset_test.cpp b/test/bitset_test.cpp index 0e28a3a..040b627 100644 --- a/test/bitset_test.cpp +++ b/test/bitset_test.cpp @@ -84,8 +84,9 @@ TEST(FixedSizeBitSetTest, BasicAssertions) { } // NOLINTNEXTLINE(modernize-use-trailing-return-type) TEST(FixedSizeSmallBitSetTest, BasicAssertions) { - static_assert(sizeof(BitSet>) == 2); - BitSet> bs; + using SB = BitSet>; + static_assert(sizeof(SB) == 2); + SB bs; bs[4] = true; bs[10] = true; bs[7] = true; @@ -98,4 +99,5 @@ TEST(FixedSizeSmallBitSetTest, BasicAssertions) { EXPECT_EQ(sv[1], 5); EXPECT_EQ(sv[2], 7); EXPECT_EQ(sv[3], 10); + EXPECT_EQ(SB::fromMask(1200).data[0], 1200); }