Skip to content

Commit

Permalink
Fewer decltype(auto)
Browse files Browse the repository at this point in the history
  • Loading branch information
chriselrod committed Jul 16, 2023
1 parent 4bb131d commit f360e58
Showing 1 changed file with 44 additions and 85 deletions.
129 changes: 44 additions & 85 deletions include/Math/Math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,8 @@ template <typename Op, typename A> struct ElementwiseUnaryOp {
using value_type = typename A::value_type;
[[no_unique_address]] Op op;
[[no_unique_address]] A a;
constexpr auto operator[](ptrdiff_t i) const -> decltype(auto) {
return op(a[i]);
}
constexpr auto operator()(ptrdiff_t i, ptrdiff_t j) const -> decltype(auto) {
constexpr auto operator[](ptrdiff_t i) const { return op(a[i]); }
constexpr auto operator()(ptrdiff_t i, ptrdiff_t j) const {
return op(a(i, j));
}

Expand All @@ -65,16 +63,10 @@ template <typename Op, typename A> struct ElementwiseUnaryOp {
[[nodiscard]] constexpr auto view() const { return *this; };
};
// scalars broadcast
constexpr auto get(const auto &A, ptrdiff_t) -> decltype(auto) { return A; }
constexpr auto get(const auto &A, ptrdiff_t, ptrdiff_t) -> decltype(auto) {
return A;
}
constexpr auto get(const AbstractVector auto &A, ptrdiff_t i)
-> decltype(auto) {
return A[i];
}
constexpr auto get(const AbstractMatrix auto &A, ptrdiff_t i, ptrdiff_t j)
-> decltype(auto) {
constexpr auto get(const auto &A, ptrdiff_t) { return A; }
constexpr auto get(const auto &A, ptrdiff_t, ptrdiff_t) { return A; }
constexpr auto get(const AbstractVector auto &A, ptrdiff_t i) { return A[i]; }
constexpr auto get(const AbstractMatrix auto &A, ptrdiff_t i, ptrdiff_t j) {
return A(i, j);
}

Expand Down Expand Up @@ -118,7 +110,7 @@ template <typename Op, Trivial A, Trivial B> struct ElementwiseVectorBinaryOp {
[[no_unique_address]] B b;
constexpr ElementwiseVectorBinaryOp(Op _op, A _a, B _b)
: op(_op), a(_a), b(_b) {}
constexpr auto operator[](ptrdiff_t i) const -> decltype(auto) {
constexpr auto operator[](ptrdiff_t i) const {
return op(get(a, i), get(b, i));
}
[[nodiscard]] constexpr auto size() const -> ptrdiff_t {
Expand All @@ -142,7 +134,7 @@ template <typename Op, Trivial A, Trivial B> struct ElementwiseMatrixBinaryOp {
[[no_unique_address]] B b;
constexpr ElementwiseMatrixBinaryOp(Op _op, A _a, B _b)
: op(_op), a(_a), b(_b) {}
constexpr auto operator()(ptrdiff_t i, ptrdiff_t j) const -> decltype(auto) {
constexpr auto operator()(ptrdiff_t i, ptrdiff_t j) const {
return op(get(a, i, j), get(b, i, j));
}
[[nodiscard]] constexpr auto numRow() const -> Row {
Expand Down Expand Up @@ -259,49 +251,11 @@ template <class T, class S> constexpr auto view(const Array<T, S> &x) {
}
static_assert(!AbstractMatrix<StridedVector<int64_t>>);

struct Negate {
constexpr auto operator()(const auto &x) const -> decltype(-x) { return -x; }
};
struct Abs {
constexpr auto operator()(const auto &x) const -> decltype(abs(x)) {
return abs(x);
}
};
struct Plus {
constexpr auto operator()(const auto &x, const auto &y) const
-> decltype(x + y) {
return x + y;
}
};
struct Minus {
constexpr auto operator()(const auto &x, const auto &y) const
-> decltype(x - y) {
return x - y;
}
};
struct Mul {
constexpr auto operator()(const auto &x, const auto &y) const
-> decltype(x * y) {
return x * y;
}
};
struct Div {
constexpr auto operator()(const auto &x, const auto &y) const
-> decltype(x / y) {
return x / y;
}
};
struct Modulus {
constexpr auto operator()(const auto &x, const auto &y) const
-> decltype(x % y) {
return x % y;
}
};

// static_assert(std::is_trivially_copyable_v<MutStridedVector<int64_t>>);
static_assert(std::is_trivially_copyable_v<
ElementwiseUnaryOp<Negate, StridedVector<int64_t>>>);
static_assert(Trivial<ElementwiseUnaryOp<Negate, StridedVector<int64_t>>>);
ElementwiseUnaryOp<std::negate<>, StridedVector<int64_t>>>);
static_assert(
Trivial<ElementwiseUnaryOp<std::negate<>, StridedVector<int64_t>>>);

constexpr auto allMatch(const AbstractVector auto &x0,
const AbstractVector auto &x1) -> bool {
Expand Down Expand Up @@ -395,8 +349,9 @@ static_assert(std::copy_constructible<PtrMatrix<int64_t>>);
static_assert(std::is_trivially_copyable_v<PtrMatrix<int64_t>>);
static_assert(Trivial<PtrMatrix<int64_t>>);
static_assert(Trivial<int>);
static_assert(TriviallyCopyable<Mul>);
static_assert(Trivial<ElementwiseMatrixBinaryOp<Mul, PtrMatrix<int64_t>, int>>);
static_assert(TriviallyCopyable<std::multiplies<>>);
static_assert(Trivial<ElementwiseMatrixBinaryOp<std::multiplies<>,
PtrMatrix<int64_t>, int>>);
static_assert(Trivial<MatMatMul<PtrMatrix<int64_t>, PtrMatrix<int64_t>>>);

template <TriviallyCopyable OP, Trivial A, Trivial B>
Expand Down Expand Up @@ -443,13 +398,16 @@ inline auto operator<<(std::ostream &os, const T &A) -> std::ostream & {

constexpr auto operator-(const AbstractVector auto &a) {
auto AA{a.view()};
return ElementwiseUnaryOp<Negate, decltype(AA)>{.op = Negate{}, .a = AA};
return ElementwiseUnaryOp<std::negate<>, decltype(AA)>{.op = std::negate<>{},
.a = AA};
}
constexpr auto operator-(const AbstractMatrix auto &a) {
auto AA{a.view()};
return ElementwiseUnaryOp<Negate, decltype(AA)>{.op = Negate{}, .a = AA};
return ElementwiseUnaryOp<std::negate<>, decltype(AA)>{.op = std::negate<>{},
.a = AA};
}
static_assert(AbstractMatrix<ElementwiseUnaryOp<Negate, PtrMatrix<int64_t>>>);
static_assert(
AbstractMatrix<ElementwiseUnaryOp<std::negate<>, PtrMatrix<int64_t>>>);
static_assert(AbstractMatrix<Array<int64_t, SquareDims>>);
static_assert(AbstractMatrix<ManagedArray<int64_t, SquareDims>>);

Expand All @@ -469,97 +427,97 @@ constexpr auto operator*(const AbstractMatrix auto &a,
}
constexpr auto operator*(const AbstractVector auto &a,
const AbstractVector auto &b) {
return ElementwiseVectorBinaryOp(Mul{}, view(a), view(b));
return ElementwiseVectorBinaryOp(std::multiplies<>{}, view(a), view(b));
}

template <AbstractVector M, utils::ElementOf<M> S>
constexpr auto operator+(S a, const M &b) {
return ElementwiseVectorBinaryOp(Plus{}, view(a), view(b));
return ElementwiseVectorBinaryOp(std::plus<>{}, view(a), view(b));
}
template <AbstractVector M, utils::ElementOf<M> S>
constexpr auto operator+(const M &b, S a) {
return ElementwiseVectorBinaryOp(Plus{}, view(b), view(a));
return ElementwiseVectorBinaryOp(std::plus<>{}, view(b), view(a));
}
template <AbstractMatrix M, utils::ElementOf<M> S>
constexpr auto operator+(S a, const M &b) {
return ElementwiseMatrixBinaryOp(Plus{}, view(a), view(b));
return ElementwiseMatrixBinaryOp(std::plus<>{}, view(a), view(b));
}
template <AbstractMatrix M, utils::ElementOf<M> S>
constexpr auto operator+(const M &b, S a) {
return ElementwiseMatrixBinaryOp(Plus{}, view(b), view(a));
return ElementwiseMatrixBinaryOp(std::plus<>{}, view(b), view(a));
}

template <AbstractVector M, utils::ElementOf<M> S>
constexpr auto operator-(S a, const M &b) {
return ElementwiseVectorBinaryOp(Minus{}, view(a), view(b));
return ElementwiseVectorBinaryOp(std::minus<>{}, view(a), view(b));
}
template <AbstractVector M, utils::ElementOf<M> S>
constexpr auto operator-(const M &b, S a) {
return ElementwiseVectorBinaryOp(Minus{}, view(b), view(a));
return ElementwiseVectorBinaryOp(std::minus<>{}, view(b), view(a));
}
template <AbstractMatrix M, utils::ElementOf<M> S>
constexpr auto operator-(S a, const M &b) {
return ElementwiseMatrixBinaryOp(Minus{}, view(a), view(b));
return ElementwiseMatrixBinaryOp(std::minus<>{}, view(a), view(b));
}
template <AbstractMatrix M, utils::ElementOf<M> S>
constexpr auto operator-(const M &b, S a) {
return ElementwiseMatrixBinaryOp(Minus{}, view(b), view(a));
return ElementwiseMatrixBinaryOp(std::minus<>{}, view(b), view(a));
}

template <AbstractVector M, utils::ElementOf<M> S>
constexpr auto operator*(S a, const M &b) {
return ElementwiseVectorBinaryOp(Mul{}, view(a), view(b));
return ElementwiseVectorBinaryOp(std::multiplies<>{}, view(a), view(b));
}
template <AbstractVector M, utils::ElementOf<M> S>
constexpr auto operator*(const M &b, S a) {
return ElementwiseVectorBinaryOp(Mul{}, view(b), view(a));
return ElementwiseVectorBinaryOp(std::multiplies<>{}, view(b), view(a));
}
template <AbstractMatrix M, utils::ElementOf<M> S>
constexpr auto operator*(S a, const M &b) {
return ElementwiseMatrixBinaryOp(Mul{}, view(a), view(b));
return ElementwiseMatrixBinaryOp(std::multiplies<>{}, view(a), view(b));
}
template <AbstractMatrix M, utils::ElementOf<M> S>
constexpr auto operator*(const M &b, S a) {
return ElementwiseMatrixBinaryOp(Mul{}, view(b), view(a));
return ElementwiseMatrixBinaryOp(std::multiplies<>{}, view(b), view(a));
}

template <AbstractVector M, utils::ElementOf<M> S>
constexpr auto operator/(S a, const M &b) {
return ElementwiseVectorBinaryOp(Div{}, view(a), view(b));
return ElementwiseVectorBinaryOp(std::divides<>{}, view(a), view(b));
}
template <AbstractVector M, utils::ElementOf<M> S>
constexpr auto operator/(const M &b, S a) {
return ElementwiseVectorBinaryOp(Div{}, view(b), view(a));
return ElementwiseVectorBinaryOp(std::divides<>{}, view(b), view(a));
}
template <AbstractMatrix M, utils::ElementOf<M> S>
constexpr auto operator/(S a, const M &b) {
return ElementwiseMatrixBinaryOp(Div{}, view(a), view(b));
return ElementwiseMatrixBinaryOp(std::divides<>{}, view(a), view(b));
}
template <AbstractMatrix M, utils::ElementOf<M> S>
constexpr auto operator/(const M &b, S a) {
return ElementwiseMatrixBinaryOp(Div{}, view(b), view(a));
return ElementwiseMatrixBinaryOp(std::divides<>{}, view(b), view(a));
}

constexpr auto operator+(const AbstractVector auto &a,
const AbstractVector auto &b) {
return ElementwiseVectorBinaryOp(Plus{}, view(a), view(b));
return ElementwiseVectorBinaryOp(std::plus<>{}, view(a), view(b));
}
constexpr auto operator+(const AbstractMatrix auto &a,
const AbstractMatrix auto &b) {
return ElementwiseMatrixBinaryOp(Plus{}, view(a), view(b));
return ElementwiseMatrixBinaryOp(std::plus<>{}, view(a), view(b));
}
constexpr auto operator-(const AbstractVector auto &a,
const AbstractVector auto &b) {
return ElementwiseVectorBinaryOp(Minus{}, view(a), view(b));
return ElementwiseVectorBinaryOp(std::minus<>{}, view(a), view(b));
}
constexpr auto operator-(const AbstractMatrix auto &a,
const AbstractMatrix auto &b) {
return ElementwiseMatrixBinaryOp(Minus{}, view(a), view(b));
return ElementwiseMatrixBinaryOp(std::minus<>{}, view(a), view(b));
}

constexpr auto operator/(const AbstractVector auto &a,
const AbstractVector auto &b) {
return ElementwiseVectorBinaryOp(Div{}, view(a), view(b));
return ElementwiseVectorBinaryOp(std::divides<>{}, view(a), view(b));
}

// constexpr auto operator*(AbstractMatrix auto &A, AbstractVector auto &x) {
Expand All @@ -568,7 +526,8 @@ constexpr auto operator/(const AbstractVector auto &a,
// return MatMul<decltype(AA), decltype(xx)>{.a = AA, .b = xx};
// }
static_assert(
AbstractMatrix<ElementwiseMatrixBinaryOp<Mul, PtrMatrix<int64_t>, int>>,
AbstractMatrix<
ElementwiseMatrixBinaryOp<std::multiplies<>, PtrMatrix<int64_t>, int>>,
"ElementwiseBinaryOp isa AbstractMatrix failed");

static_assert(
Expand Down

0 comments on commit f360e58

Please sign in to comment.