From f360e58603836deaba6b3a7de973fa7432eac3a4 Mon Sep 17 00:00:00 2001 From: chriselrod Date: Sun, 16 Jul 2023 01:24:59 -0400 Subject: [PATCH] Fewer decltype(auto) --- include/Math/Math.hpp | 129 ++++++++++++++---------------------------- 1 file changed, 44 insertions(+), 85 deletions(-) diff --git a/include/Math/Math.hpp b/include/Math/Math.hpp index 100a310..58b6044 100644 --- a/include/Math/Math.hpp +++ b/include/Math/Math.hpp @@ -51,10 +51,8 @@ template struct ElementwiseUnaryOp { using value_type = typename A::value_type; [[no_unique_address]] Op op; [[no_unique_address]] A a; - constexpr auto operator[](ptrdiff_t i) const -> decltype(auto) { - return op(a[i]); - } - constexpr auto operator()(ptrdiff_t i, ptrdiff_t j) const -> decltype(auto) { + constexpr auto operator[](ptrdiff_t i) const { return op(a[i]); } + constexpr auto operator()(ptrdiff_t i, ptrdiff_t j) const { return op(a(i, j)); } @@ -65,16 +63,10 @@ template struct ElementwiseUnaryOp { [[nodiscard]] constexpr auto view() const { return *this; }; }; // scalars broadcast -constexpr auto get(const auto &A, ptrdiff_t) -> decltype(auto) { return A; } -constexpr auto get(const auto &A, ptrdiff_t, ptrdiff_t) -> decltype(auto) { - return A; -} -constexpr auto get(const AbstractVector auto &A, ptrdiff_t i) - -> decltype(auto) { - return A[i]; -} -constexpr auto get(const AbstractMatrix auto &A, ptrdiff_t i, ptrdiff_t j) - -> decltype(auto) { +constexpr auto get(const auto &A, ptrdiff_t) { return A; } +constexpr auto get(const auto &A, ptrdiff_t, ptrdiff_t) { return A; } +constexpr auto get(const AbstractVector auto &A, ptrdiff_t i) { return A[i]; } +constexpr auto get(const AbstractMatrix auto &A, ptrdiff_t i, ptrdiff_t j) { return A(i, j); } @@ -118,7 +110,7 @@ template struct ElementwiseVectorBinaryOp { [[no_unique_address]] B b; constexpr ElementwiseVectorBinaryOp(Op _op, A _a, B _b) : op(_op), a(_a), b(_b) {} - constexpr auto operator[](ptrdiff_t i) const -> decltype(auto) { + constexpr auto operator[](ptrdiff_t i) const { return op(get(a, i), get(b, i)); } [[nodiscard]] constexpr auto size() const -> ptrdiff_t { @@ -142,7 +134,7 @@ template struct ElementwiseMatrixBinaryOp { [[no_unique_address]] B b; constexpr ElementwiseMatrixBinaryOp(Op _op, A _a, B _b) : op(_op), a(_a), b(_b) {} - constexpr auto operator()(ptrdiff_t i, ptrdiff_t j) const -> decltype(auto) { + constexpr auto operator()(ptrdiff_t i, ptrdiff_t j) const { return op(get(a, i, j), get(b, i, j)); } [[nodiscard]] constexpr auto numRow() const -> Row { @@ -259,49 +251,11 @@ template constexpr auto view(const Array &x) { } static_assert(!AbstractMatrix>); -struct Negate { - constexpr auto operator()(const auto &x) const -> decltype(-x) { return -x; } -}; -struct Abs { - constexpr auto operator()(const auto &x) const -> decltype(abs(x)) { - return abs(x); - } -}; -struct Plus { - constexpr auto operator()(const auto &x, const auto &y) const - -> decltype(x + y) { - return x + y; - } -}; -struct Minus { - constexpr auto operator()(const auto &x, const auto &y) const - -> decltype(x - y) { - return x - y; - } -}; -struct Mul { - constexpr auto operator()(const auto &x, const auto &y) const - -> decltype(x * y) { - return x * y; - } -}; -struct Div { - constexpr auto operator()(const auto &x, const auto &y) const - -> decltype(x / y) { - return x / y; - } -}; -struct Modulus { - constexpr auto operator()(const auto &x, const auto &y) const - -> decltype(x % y) { - return x % y; - } -}; - // static_assert(std::is_trivially_copyable_v>); static_assert(std::is_trivially_copyable_v< - ElementwiseUnaryOp>>); -static_assert(Trivial>>); + ElementwiseUnaryOp, StridedVector>>); +static_assert( + Trivial, StridedVector>>); constexpr auto allMatch(const AbstractVector auto &x0, const AbstractVector auto &x1) -> bool { @@ -395,8 +349,9 @@ static_assert(std::copy_constructible>); static_assert(std::is_trivially_copyable_v>); static_assert(Trivial>); static_assert(Trivial); -static_assert(TriviallyCopyable); -static_assert(Trivial, int>>); +static_assert(TriviallyCopyable>); +static_assert(Trivial, + PtrMatrix, int>>); static_assert(Trivial, PtrMatrix>>); template @@ -443,13 +398,16 @@ inline auto operator<<(std::ostream &os, const T &A) -> std::ostream & { constexpr auto operator-(const AbstractVector auto &a) { auto AA{a.view()}; - return ElementwiseUnaryOp{.op = Negate{}, .a = AA}; + return ElementwiseUnaryOp, decltype(AA)>{.op = std::negate<>{}, + .a = AA}; } constexpr auto operator-(const AbstractMatrix auto &a) { auto AA{a.view()}; - return ElementwiseUnaryOp{.op = Negate{}, .a = AA}; + return ElementwiseUnaryOp, decltype(AA)>{.op = std::negate<>{}, + .a = AA}; } -static_assert(AbstractMatrix>>); +static_assert( + AbstractMatrix, PtrMatrix>>); static_assert(AbstractMatrix>); static_assert(AbstractMatrix>); @@ -469,97 +427,97 @@ constexpr auto operator*(const AbstractMatrix auto &a, } constexpr auto operator*(const AbstractVector auto &a, const AbstractVector auto &b) { - return ElementwiseVectorBinaryOp(Mul{}, view(a), view(b)); + return ElementwiseVectorBinaryOp(std::multiplies<>{}, view(a), view(b)); } template S> constexpr auto operator+(S a, const M &b) { - return ElementwiseVectorBinaryOp(Plus{}, view(a), view(b)); + return ElementwiseVectorBinaryOp(std::plus<>{}, view(a), view(b)); } template S> constexpr auto operator+(const M &b, S a) { - return ElementwiseVectorBinaryOp(Plus{}, view(b), view(a)); + return ElementwiseVectorBinaryOp(std::plus<>{}, view(b), view(a)); } template S> constexpr auto operator+(S a, const M &b) { - return ElementwiseMatrixBinaryOp(Plus{}, view(a), view(b)); + return ElementwiseMatrixBinaryOp(std::plus<>{}, view(a), view(b)); } template S> constexpr auto operator+(const M &b, S a) { - return ElementwiseMatrixBinaryOp(Plus{}, view(b), view(a)); + return ElementwiseMatrixBinaryOp(std::plus<>{}, view(b), view(a)); } template S> constexpr auto operator-(S a, const M &b) { - return ElementwiseVectorBinaryOp(Minus{}, view(a), view(b)); + return ElementwiseVectorBinaryOp(std::minus<>{}, view(a), view(b)); } template S> constexpr auto operator-(const M &b, S a) { - return ElementwiseVectorBinaryOp(Minus{}, view(b), view(a)); + return ElementwiseVectorBinaryOp(std::minus<>{}, view(b), view(a)); } template S> constexpr auto operator-(S a, const M &b) { - return ElementwiseMatrixBinaryOp(Minus{}, view(a), view(b)); + return ElementwiseMatrixBinaryOp(std::minus<>{}, view(a), view(b)); } template S> constexpr auto operator-(const M &b, S a) { - return ElementwiseMatrixBinaryOp(Minus{}, view(b), view(a)); + return ElementwiseMatrixBinaryOp(std::minus<>{}, view(b), view(a)); } template S> constexpr auto operator*(S a, const M &b) { - return ElementwiseVectorBinaryOp(Mul{}, view(a), view(b)); + return ElementwiseVectorBinaryOp(std::multiplies<>{}, view(a), view(b)); } template S> constexpr auto operator*(const M &b, S a) { - return ElementwiseVectorBinaryOp(Mul{}, view(b), view(a)); + return ElementwiseVectorBinaryOp(std::multiplies<>{}, view(b), view(a)); } template S> constexpr auto operator*(S a, const M &b) { - return ElementwiseMatrixBinaryOp(Mul{}, view(a), view(b)); + return ElementwiseMatrixBinaryOp(std::multiplies<>{}, view(a), view(b)); } template S> constexpr auto operator*(const M &b, S a) { - return ElementwiseMatrixBinaryOp(Mul{}, view(b), view(a)); + return ElementwiseMatrixBinaryOp(std::multiplies<>{}, view(b), view(a)); } template S> constexpr auto operator/(S a, const M &b) { - return ElementwiseVectorBinaryOp(Div{}, view(a), view(b)); + return ElementwiseVectorBinaryOp(std::divides<>{}, view(a), view(b)); } template S> constexpr auto operator/(const M &b, S a) { - return ElementwiseVectorBinaryOp(Div{}, view(b), view(a)); + return ElementwiseVectorBinaryOp(std::divides<>{}, view(b), view(a)); } template S> constexpr auto operator/(S a, const M &b) { - return ElementwiseMatrixBinaryOp(Div{}, view(a), view(b)); + return ElementwiseMatrixBinaryOp(std::divides<>{}, view(a), view(b)); } template S> constexpr auto operator/(const M &b, S a) { - return ElementwiseMatrixBinaryOp(Div{}, view(b), view(a)); + return ElementwiseMatrixBinaryOp(std::divides<>{}, view(b), view(a)); } constexpr auto operator+(const AbstractVector auto &a, const AbstractVector auto &b) { - return ElementwiseVectorBinaryOp(Plus{}, view(a), view(b)); + return ElementwiseVectorBinaryOp(std::plus<>{}, view(a), view(b)); } constexpr auto operator+(const AbstractMatrix auto &a, const AbstractMatrix auto &b) { - return ElementwiseMatrixBinaryOp(Plus{}, view(a), view(b)); + return ElementwiseMatrixBinaryOp(std::plus<>{}, view(a), view(b)); } constexpr auto operator-(const AbstractVector auto &a, const AbstractVector auto &b) { - return ElementwiseVectorBinaryOp(Minus{}, view(a), view(b)); + return ElementwiseVectorBinaryOp(std::minus<>{}, view(a), view(b)); } constexpr auto operator-(const AbstractMatrix auto &a, const AbstractMatrix auto &b) { - return ElementwiseMatrixBinaryOp(Minus{}, view(a), view(b)); + return ElementwiseMatrixBinaryOp(std::minus<>{}, view(a), view(b)); } constexpr auto operator/(const AbstractVector auto &a, const AbstractVector auto &b) { - return ElementwiseVectorBinaryOp(Div{}, view(a), view(b)); + return ElementwiseVectorBinaryOp(std::divides<>{}, view(a), view(b)); } // constexpr auto operator*(AbstractMatrix auto &A, AbstractVector auto &x) { @@ -568,7 +526,8 @@ constexpr auto operator/(const AbstractVector auto &a, // return MatMul{.a = AA, .b = xx}; // } static_assert( - AbstractMatrix, int>>, + AbstractMatrix< + ElementwiseMatrixBinaryOp, PtrMatrix, int>>, "ElementwiseBinaryOp isa AbstractMatrix failed"); static_assert(