Skip to content

Commit

Permalink
Also absorb eltwise prod
Browse files Browse the repository at this point in the history
  • Loading branch information
chriselrod committed Sep 19, 2023
1 parent efdfe76 commit a5141d9
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions include/Math/Math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -533,21 +533,21 @@ static_assert(AbstractMatrix<Elementwise<std::negate<>, PtrMatrix<int64_t>>>);
static_assert(AbstractMatrix<Array<int64_t, SquareDims>>);
static_assert(AbstractMatrix<ManagedArray<int64_t, SquareDims>>);

constexpr auto operator*(const AbstractMatrix auto &a, const VecOrMat auto &b) {
constexpr auto operator*(const VecOrMat auto &a, const VecOrMat auto &b) {
auto AA{a.view()};
auto BB{b.view()};
if constexpr (AbstractVector<decltype(BB)>) {
invariant(ptrdiff_t(AA.numCol()) == BB.size());
return MatVecMul<decltype(AA), decltype(BB)>{.a = AA, .b = BB};
if constexpr (AbstractVector<decltype(AA)>) {
ElementwiseBinaryOp(std::multiplies<>{}, AA, BB);
} else {
invariant(ptrdiff_t(AA.numCol()) == BB.size());
return MatVecMul<decltype(AA), decltype(BB)>{.a = AA, .b = BB};
}
} else {
invariant(ptrdiff_t(AA.numCol()) == ptrdiff_t(BB.numRow()));
return MatMatMul<decltype(AA), decltype(BB)>{.a = AA, .b = BB};
}
}
constexpr auto operator*(const AbstractVector auto &a,
const AbstractVector auto &b) {
return ElementwiseBinaryOp(std::multiplies<>{}, view(a), view(b));
}
template <AbstractVector M, utils::ElementOf<M> S>
constexpr auto operator*(const M &b, S a) {
return ElementwiseBinaryOp(std::multiplies<>{}, view(b), view(a));
Expand Down

0 comments on commit a5141d9

Please sign in to comment.