Skip to content

Commit

Permalink
dispatch table: work around gcc bug
Browse files Browse the repository at this point in the history
  • Loading branch information
foolnotion committed Apr 13, 2024
1 parent 0096166 commit 05f7cdc
Showing 1 changed file with 18 additions and 10 deletions.
28 changes: 18 additions & 10 deletions include/operon/interpreter/dispatch_table.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ struct Noop {
template<NodeType Type, typename T, std::size_t S>
static inline void DiffOp(Operon::Vector<Node> const& nodes, Backend::View<T const, S> primal, Backend::View<T, S> trace, int i, int j) {
Diff<T, Type, S>{}(nodes, primal, trace, i, j);
};
}

template<NodeType Type, typename T, std::size_t S>
static constexpr auto MakeFunctionCall() -> Dispatch::Callable<T, S>
Expand Down Expand Up @@ -151,21 +151,29 @@ namespace detail {
{ T::size() };
{ std::is_array_v<T> };
};

template<typename Tup, std::size_t N>
struct ExtractTypes {
static auto constexpr Extract() {
return []<auto... Seq>(std::index_sequence<Seq...>){
return std::make_tuple(std::tuple_element_t<Seq, Tup>{}...);
}(std::make_index_sequence<N>{});
}

using Type = decltype(Extract());
};
} // namespace detail

template<typename... Ts>
struct DispatchTable {

private:
using Tup = std::tuple<Ts...>; // make the type parameters into a tuple
static auto constexpr N = std::tuple_size_v<Tup>;

// retrieve the last type in the template parameter pack
using Lst = std::tuple_element_t<sizeof...(Ts)-1, Tup>;

using Typ = std::conditional_t<detail::ExtentsLike<Lst>, decltype([]<auto... Idx>(std::index_sequence<Idx...>){
return std::make_tuple(std::tuple_element_t<Idx, Tup>{}...);
}(std::make_index_sequence<sizeof...(Ts)-1>{})), Tup>;

using Lst = std::tuple_element_t<N-1, Tup>;
using Typ = std::conditional_t<detail::ExtentsLike<Lst>, typename detail::ExtractTypes<Tup, N-1>::Type, Tup>;
using Ext = std::conditional_t<detail::ExtentsLike<Lst>, Lst, std::index_sequence<Dispatch::DefaultBatchSize<Ts>...>>;

template<typename T, auto SZ = std::tuple_size_v<Typ>>
Expand Down Expand Up @@ -193,12 +201,12 @@ struct DispatchTable {
template<NodeType Type, typename T>
static constexpr auto MakeFunction() {
return Dispatch::MakeFunctionCall<Type, T, BatchSize<T>>();
};
}

template<NodeType Type, typename T>
static constexpr auto MakeDerivative() {
return Dispatch::MakeDiffCall<Type, T, BatchSize<T>>();
};
}

template<NodeType Type>
static constexpr auto MakeTuple()
Expand All @@ -209,7 +217,7 @@ struct DispatchTable {
std::make_tuple(MakeDerivative<Type, std::tuple_element_t<Idx, Tup>>()...)
);
}(std::index_sequence_for<Typ>{});
};
}

using TFun = decltype([]<auto... Idx>(std::index_sequence<Idx...>){
return std::make_tuple(Callable<std::tuple_element_t<Idx, Typ>>{}...);
Expand Down

0 comments on commit 05f7cdc

Please sign in to comment.