Skip to content

Commit

Permalink
Merge branch 'ttg-device-support-master-coro' of github.com:devreal/t…
Browse files Browse the repository at this point in the history
…tg into ttg-device-support-master-coro
  • Loading branch information
therault committed Mar 8, 2023
2 parents 61aff73 + e07ebe5 commit dffa4a2
Show file tree
Hide file tree
Showing 21 changed files with 2,410 additions and 456 deletions.
11 changes: 11 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,17 @@ if (TARGET MADworld)
message(STATUS "MADNESS_FOUND=1")
endif(TARGET MADworld)


##########################
#### CUDA
##########################
check_language(CUDA)
if(CMAKE_CUDA_COMPILER)
enable_language(CUDA)
endif(CMAKE_CUDA_COMPILER)
set(TTG_HAVE_CUDA ${CMAKE_CUDA_COMPILER} CACHE BOOL "True if TTG supports compiling .cu files")


##########################
#### Examples
##########################
Expand Down
40 changes: 40 additions & 0 deletions examples/madness/mrattg.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,46 @@ auto make_project(functorT& f,
return ttg::make_tt(F, edges(fuse(refine, ctl)), edges(refine, result), name, {"control"}, {"refine", "result"});
}


/// Returns an std::unique_ptr to the object
template <typename functorT, typename T, size_t K, Dimension NDIM>
auto make_project_device(functorT& f,
const T thresh, /// should be scalar value not complex
ctlEdge<NDIM>& ctl, rnodeEdge<T, K, NDIM>& result, const std::string& name = "project") {
auto F = [f, thresh](const Key<NDIM>& key, std::tuple<ctlOut<NDIM>, rnodeOut<T, K, NDIM>>& out) {
FunctionReconstructedNode<T, K, NDIM> node(key); // Our eventual result
auto& coeffs = node.coeffs; // Need to clean up OO design
bool is_leaf;

if (key.level() < initial_level(f)) {
for (auto child : children(key)) ttg::sendk<0>(child, out);
coeffs = T(1e7); // set to obviously bad value to detect incorrect use
is_leaf = false;
} else if (is_negligible<functorT, T, NDIM>(f, Domain<NDIM>::template bounding_box<T>(key),
truncate_tol(key, thresh))) {
coeffs = T(0.0);
is_leaf = true;
} else {
auto node_view = ttg::make_view(node, ttg::ViewScope::Out); // no need to move node onto the device
auto is_leaf_view = ttg::make_view(is_leaf, ttg::ViewScope::Out);
co_await ttg::device::wait_views{};
fcoeffs<functorT, T, K>(f, key, thresh,
node_view.get_device_ptr<0>(),
is_leaf_view.get_device_ptr<0>()); // cannot deduce K
co_await ttg::device::wait_kernel{};
if (!is_leaf) {
for (auto child : children(key)) ttg::sendk<0>(child, out); // should be broadcast ?
}
}
node.is_leaf = is_leaf;
ttg::send<1>(key, node, out); // always produce a result
};
ctlEdge<NDIM> refine("refine");
return ttg::make_tt(F, edges(fuse(refine, ctl)), edges(refine, result), name, {"control"}, {"refine", "result"});
}



namespace detail {
template <typename T, size_t K, Dimension NDIM>
struct tree_types {};
Expand Down
18 changes: 10 additions & 8 deletions tests/unit/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@ include(AddTTGExecutable)

# TT unit test: core TTG ops
set(ut_src
fibonacci.cc
device_coro.cc
ranges.cc
tt.cc
unit_main.cpp)
#fibonacci.cc
#ranges.cc
#tt.cc
unit_main.cpp
)
set(ut_libs Catch2::Catch2)
if (TARGET std::coroutine)
list(APPEND ut_src fibonacci-coro.cc)
#if (TARGET std::coroutine)
#list(APPEND ut_src fibonacci-coro.cc)
list(APPEND ut_src device_coro.cc)
list(APPEND ut_src cuda_kernel.cu)
list(APPEND ut_libs std::coroutine)
endif()
#endif()
add_ttg_executable(core-unittests-ttg "${ut_src}" LINK_LIBRARIES "${ut_libs}")


Expand Down
Loading

0 comments on commit dffa4a2

Please sign in to comment.