diff --git a/huffman/BUILD.bazel b/huffman/BUILD.bazel index bd2610a..287a27d 100644 --- a/huffman/BUILD.bazel +++ b/huffman/BUILD.bazel @@ -5,6 +5,7 @@ cc_library( srcs = [ "src/bit.hpp", "src/code.hpp", + "src/detail/base_view.hpp", "src/detail/static_vector.hpp", "src/detail/table_node.hpp", "src/detail/table_storage.hpp", diff --git a/huffman/src/detail/base_view.hpp b/huffman/src/detail/base_view.hpp new file mode 100644 index 0000000..1e217bf --- /dev/null +++ b/huffman/src/detail/base_view.hpp @@ -0,0 +1,171 @@ +#pragma once + +#include +#include +#include +#include + +namespace gpu_deflate::huffman::detail { + +/// A view of elements cast to a base class +/// @tparam V underlying view +/// @tparam B base class +/// +template + requires std::ranges::view and + std::same_as, + std::ranges::sentinel_t> and + std::derived_from, B> +class base_view : public std::ranges::view_interface> +{ + // This is largely adapted from `transform_view` (or other views), although we + // apply some simplifications: + // * V must model `forward_range` instead of `input_range` + // * sentinel_t is the same as iterator_t + // + // https://eel.is/c++draft/range.transform + + V base_{}; + +public: + class iterator + { + using base_iterator = std::ranges::iterator_t; + base_iterator base_{}; + + public: + using iterator_category = + typename std::iterator_traits::iterator_category; + using value_type = std::ranges::range_value_t; + using reference = B&; + using pointer = B*; + using difference_type = std::ranges::range_difference_t; + + iterator() + requires std::default_initializable + = default; + constexpr iterator(base_iterator current) : base_{std::move(current)} {} + + constexpr auto base() const& noexcept -> const base_iterator& + { + return base_; + } + constexpr auto base() && -> base_iterator { return std::move(base_); } + + constexpr auto operator*() const -> reference + { + return static_cast(*base_); + } + constexpr auto operator->() const -> pointer { return &**this; } + + constexpr auto operator++() -> iterator& + { + ++base_; + return *this; + } + constexpr auto operator++(int) -> iterator + { + auto tmp = *this; + ++*this; + return tmp; + } + + constexpr auto operator--() -> iterator& + requires std::ranges::bidirectional_range + { + --base_; + return *this; + } + constexpr auto operator--(int) -> iterator + requires std::ranges::bidirectional_range + { + auto tmp = *this; + --*this; + return tmp; + } + + constexpr auto operator+=(difference_type n) -> iterator& + requires std::ranges::random_access_range + { + base_ += n; + return *this; + } + constexpr auto operator-=(difference_type n) -> iterator& + requires std::ranges::random_access_range + { + base_ -= n; + return *this; + } + + constexpr auto operator[](difference_type n) const -> reference + requires std::ranges::random_access_range + { + return static_cast(base_[n]); + } + + friend constexpr auto + operator<=>(const iterator&, const iterator&) = default; + + friend constexpr auto operator+(iterator i, difference_type n) -> iterator + requires std::ranges::random_access_range + { + return i += n; + } + friend constexpr auto operator+(difference_type n, iterator i) -> iterator + requires std::ranges::random_access_range + { + return i + n; + } + + friend constexpr auto operator-(iterator i, difference_type n) -> iterator + requires std::ranges::random_access_range + { + return i -= n; + } + friend constexpr auto + operator-(const iterator& x, const iterator& y) -> difference_type + requires std::ranges::random_access_range + { + return x.base() - y.base(); + } + }; + + base_view() + requires std::default_initializable + = default; + constexpr explicit base_view(V base) : base_{std::move(base)} {} + + constexpr auto base() const& -> V + requires std::copy_constructible + { + return base_; + } + constexpr auto base() && -> V { return std::move(base_); } + + constexpr auto begin() const -> iterator { return {base().begin()}; } + constexpr auto end() const -> iterator { return {base().end()}; } +}; + +// Use of range_adaptor_clousure requires GCC 13 or LLVM > 16.0.0 +// While this isn't necessary, it does allow a simpler declaration of +// `base_view_` type in the `table` class template. + +// namespace views { +// +// template +// struct base_raco : std::ranges::range_adaptor_closure> +//{ +// template +// constexpr auto operator()(V v) const +// -> base_view +// { +// return base_view{v} +// } +//}; +// +// template +// inline constexpr auto base = base_raco{}; +// +//} + +} // namespace gpu_deflate::huffman::detail diff --git a/huffman/src/table.hpp b/huffman/src/table.hpp index e0b7f92..5fc9eb3 100644 --- a/huffman/src/table.hpp +++ b/huffman/src/table.hpp @@ -1,5 +1,6 @@ #pragma once +#include "huffman/src/detail/base_view.hpp" #include "huffman/src/detail/table_node.hpp" #include "huffman/src/detail/table_storage.hpp" @@ -41,20 +42,18 @@ class table detail::table_storage table_; - // Create a base view member to prevent member call on the temporary created - // by views::transform + // Create a base view member to prevent a member call on the temporary view + // object // // @{ - static constexpr auto as_const_base(const node_type& node) -> const - typename node_type::encoding_type& - { - return static_cast(node); - } - using base_view_type = decltype(std::views::reverse(std::views::transform( - std::declval(), &as_const_base))); - base_view_type base_view_{ - std::views::reverse(std::views::transform(table_, &as_const_base))}; + using base_view_type = std::ranges::reverse_view>, + const typename node_type::encoding_type>>; + + base_view_type base_view_{detail::base_view< + std::ranges::ref_view>, + const typename node_type::encoding_type>{std::views::all(table_)}}; // @} @@ -107,7 +106,7 @@ class table /// Const iterator type /// - using const_iterator = decltype(std::as_const(base_view_).begin()); + using const_iterator = std::ranges::iterator_t; /// Constructs a `table` from a symbol-frequency mapping /// @tparam R sized-range of symbol-frequency 2-tuples diff --git a/huffman/test/table_find_code_test.cpp b/huffman/test/table_find_code_test.cpp index 31850d9..319f83a 100644 --- a/huffman/test/table_find_code_test.cpp +++ b/huffman/test/table_find_code_test.cpp @@ -26,46 +26,46 @@ auto main() -> int // clang-format on test("finds code in table") = [] { - static_assert('e' == (**table1.find(1_c)).symbol); - static_assert('i' == (**table1.find(01_c)).symbol); - static_assert('n' == (**table1.find(001_c)).symbol); - static_assert('q' == (**table1.find(0001_c)).symbol); - static_assert('x' == (**table1.find(00001_c)).symbol); - static_assert('\4' == (**table1.find(00000_c)).symbol); + static_assert('e' == table1.find(1_c).value()->symbol); + static_assert('i' == table1.find(01_c).value()->symbol); + static_assert('n' == table1.find(001_c).value()->symbol); + static_assert('q' == table1.find(0001_c).value()->symbol); + static_assert('x' == table1.find(00001_c).value()->symbol); + static_assert('\4' == table1.find(00000_c).value()->symbol); }; // bitsize values we compare against are derived from the code // NOLINTBEGIN(readability-magic-numbers) test("code not in table but valid prefix") = [] { - static_assert((*table1.find(0_c).error()).symbol == 'i'); - static_assert((*table1.find(0_c).error()).bitsize() == 2); + static_assert(table1.find(0_c).error()->symbol == 'i'); + static_assert(table1.find(0_c).error()->bitsize() == 2); - static_assert((*table1.find(00_c).error()).symbol == 'n'); - static_assert((*table1.find(00_c).error()).bitsize() == 3); + static_assert(table1.find(00_c).error()->symbol == 'n'); + static_assert(table1.find(00_c).error()->bitsize() == 3); - static_assert((*table1.find(000_c).error()).symbol == 'q'); - static_assert((*table1.find(000_c).error()).bitsize() == 4); + static_assert(table1.find(000_c).error()->symbol == 'q'); + static_assert(table1.find(000_c).error()->bitsize() == 4); // ordering of elements with the same bitsize is unspecified - static_assert((*table1.find(0000_c).error()).bitsize() == 5); + static_assert(table1.find(0000_c).error()->bitsize() == 5); }; test("code not in table but valid prefix, using explicit pos") = [] { constexpr auto pos1 = table1.find(0_c).error(); - static_assert((*pos1).symbol == 'i'); - static_assert((*pos1).bitsize() == 2); + static_assert(pos1->symbol == 'i'); + static_assert(pos1->bitsize() == 2); constexpr auto pos2 = table1.find(00_c, pos1).error(); - static_assert((*pos2).symbol == 'n'); - static_assert((*pos2).bitsize() == 3); + static_assert(pos2->symbol == 'n'); + static_assert(pos2->bitsize() == 3); constexpr auto pos3 = table1.find(000_c, pos2).error(); - static_assert((*pos3).symbol == 'q'); - static_assert((*pos3).bitsize() == 4); + static_assert(pos3->symbol == 'q'); + static_assert(pos3->bitsize() == 4); // ordering of elements with the same bitsize is unspecified - static_assert((*table1.find(0000_c, pos3).error()).bitsize() == 5); + static_assert(table1.find(0000_c, pos3).error()->bitsize() == 5); }; // NOLINTEND(readability-magic-numbers)