From 1d78377d709c5067a59f8568460ae4785bcf4ed5 Mon Sep 17 00:00:00 2001 From: Drew Hubley Date: Wed, 24 Apr 2024 20:25:46 -0300 Subject: [PATCH] Add pure xtensor fft implementation --- .gitignore | 1 + CMakeLists.txt | 6 ++ docs/source/api/container_index.rst | 1 + include/xtensor/xbroadcast.hpp | 23 +++++ include/xtensor/xfunction.hpp | 36 +++++++ include/xtensor/xgenerator.hpp | 15 +++ include/xtensor/xmath.hpp | 1 + include/xtensor/xsemantic.hpp | 37 +++++++ include/xtensor/xutils.hpp | 147 ++++++++++++++++++++++++++++ test/CMakeLists.txt | 1 + 10 files changed, 268 insertions(+) diff --git a/.gitignore b/.gitignore index 80fa14348..0f4ce3a73 100644 --- a/.gitignore +++ b/.gitignore @@ -62,3 +62,4 @@ __pycache__ # Generated files *.pc +.vscode/settings.json diff --git a/CMakeLists.txt b/CMakeLists.txt index 92c886dfc..67c58f083 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -140,6 +140,7 @@ set(XTENSOR_HEADERS ${XTENSOR_INCLUDE_DIR}/xtensor/xfixed.hpp ${XTENSOR_INCLUDE_DIR}/xtensor/xfunction.hpp ${XTENSOR_INCLUDE_DIR}/xtensor/xfunctor_view.hpp + ${XTENSOR_INCLUDE_DIR}/xtensor/xfft.hpp ${XTENSOR_INCLUDE_DIR}/xtensor/xgenerator.hpp ${XTENSOR_INCLUDE_DIR}/xtensor/xhistogram.hpp ${XTENSOR_INCLUDE_DIR}/xtensor/xindex_view.hpp @@ -199,6 +200,7 @@ target_link_libraries(xtensor INTERFACE xtl) OPTION(XTENSOR_ENABLE_ASSERT "xtensor bound check" OFF) OPTION(XTENSOR_CHECK_DIMENSION "xtensor dimension check" OFF) +OPTION(XTENSOR_FORCE_TEMPORARY_MEMORY_IN_ASSIGNMENTS "xtensor force the use of temporary memory when assigning instead of an automatic overlap check" ON) OPTION(BUILD_TESTS "xtensor test suite" OFF) OPTION(BUILD_BENCHMARK "xtensor benchmark" OFF) OPTION(DOWNLOAD_GTEST "build gtest from downloaded sources" OFF) @@ -219,6 +221,10 @@ if(XTENSOR_CHECK_DIMENSION) add_definitions(-DXTENSOR_ENABLE_CHECK_DIMENSION) endif() +if(XTENSOR_FORCE_TEMPORARY_MEMORY_IN_ASSIGNMENTS) + add_definitions(-DXTENSOR_FORCE_TEMPORARY_MEMORY_IN_ASSIGNMENTS) +endif() + if(DEFAULT_COLUMN_MAJOR) add_definitions(-DXTENSOR_DEFAULT_LAYOUT=layout_type::column_major) endif() diff --git a/docs/source/api/container_index.rst b/docs/source/api/container_index.rst index bb3e2d724..ef8c0ba84 100644 --- a/docs/source/api/container_index.rst +++ b/docs/source/api/container_index.rst @@ -33,3 +33,4 @@ xexpression API is actually implemented in ``xstrided_container`` and ``xcontain xindex_view xfunctor_view xrepeat + xfft diff --git a/include/xtensor/xbroadcast.hpp b/include/xtensor/xbroadcast.hpp index 798b9cc9d..20b04edab 100644 --- a/include/xtensor/xbroadcast.hpp +++ b/include/xtensor/xbroadcast.hpp @@ -118,6 +118,29 @@ namespace xt return linear_end(c.expression()); } + /************************************* + * overlapping_memory_checker_traits * + *************************************/ + + template + struct overlapping_memory_checker_traits< + E, + std::enable_if_t::value && is_specialization_of::value>> + { + static bool check_overlap(const E& expr, const memory_range& dst_range) + { + if (expr.size() == 0) + { + return false; + } + else + { + using ChildE = std::decay_t; + return overlapping_memory_checker_traits::check_overlap(expr.expression(), dst_range); + } + } + }; + /** * @class xbroadcast * @brief Broadcasted xexpression to a specified shape. diff --git a/include/xtensor/xfunction.hpp b/include/xtensor/xfunction.hpp index 08a3dc1c1..f11362cdb 100644 --- a/include/xtensor/xfunction.hpp +++ b/include/xtensor/xfunction.hpp @@ -162,6 +162,42 @@ namespace xt { }; + /************************************* + * overlapping_memory_checker_traits * + *************************************/ + + template + struct overlapping_memory_checker_traits< + E, + std::enable_if_t::value && is_specialization_of::value>> + { + template = 0> + static bool check_tuple(const std::tuple&, const memory_range&) + { + return false; + } + + template = 0> + static bool check_tuple(const std::tuple& t, const memory_range& dst_range) + { + using ChildE = std::decay_t(t))>; + return overlapping_memory_checker_traits::check_overlap(std::get(t), dst_range) + || check_tuple(t, dst_range); + } + + static bool check_overlap(const E& expr, const memory_range& dst_range) + { + if (expr.size() == 0) + { + return false; + } + else + { + return check_tuple(expr.arguments(), dst_range); + } + } + }; + /************* * xfunction * *************/ diff --git a/include/xtensor/xgenerator.hpp b/include/xtensor/xgenerator.hpp index 551bb7e24..03433adca 100644 --- a/include/xtensor/xgenerator.hpp +++ b/include/xtensor/xgenerator.hpp @@ -76,6 +76,21 @@ namespace xt using size_type = std::size_t; }; + /************************************* + * overlapping_memory_checker_traits * + *************************************/ + + template + struct overlapping_memory_checker_traits< + E, + std::enable_if_t::value && is_specialization_of::value>> + { + static bool check_overlap(const E&, const memory_range&) + { + return false; + } + }; + /** * @class xgenerator * @brief Multidimensional function operating on indices. diff --git a/include/xtensor/xmath.hpp b/include/xtensor/xmath.hpp index 6e32df15b..3035502a5 100644 --- a/include/xtensor/xmath.hpp +++ b/include/xtensor/xmath.hpp @@ -338,6 +338,7 @@ namespace xt XTENSOR_UNARY_MATH_FUNCTOR(isfinite); XTENSOR_UNARY_MATH_FUNCTOR(isinf); XTENSOR_UNARY_MATH_FUNCTOR(isnan); + XTENSOR_UNARY_MATH_FUNCTOR(conj); } #undef XTENSOR_UNARY_MATH_FUNCTOR diff --git a/include/xtensor/xsemantic.hpp b/include/xtensor/xsemantic.hpp index 41f14951c..8aa76cfc9 100644 --- a/include/xtensor/xsemantic.hpp +++ b/include/xtensor/xsemantic.hpp @@ -217,6 +217,29 @@ namespace xt template using disable_xcontainer_semantics = typename std::enable_if::value, R>::type; + + template + class xview_semantic; + + template + struct overlapping_memory_checker_traits< + E, + std::enable_if_t::value && is_crtp_base_of::value>> + { + static bool check_overlap(const E& expr, const memory_range& dst_range) + { + if (expr.size() == 0) + { + return false; + } + else + { + using ChildE = std::decay_t; + return overlapping_memory_checker_traits::check_overlap(expr.expression(), dst_range); + } + } + }; + /** * @class xview_semantic * @brief Implementation of the xsemantic_base interface for @@ -598,8 +621,22 @@ namespace xt template inline auto xsemantic_base::operator=(const xexpression& e) -> derived_type& { +#ifdef XTENSOR_FORCE_TEMPORARY_MEMORY_IN_ASSIGNMENTS temporary_type tmp(e); return this->derived_cast().assign_temporary(std::move(tmp)); +#else + auto&& this_derived = this->derived_cast(); + auto memory_checker = make_overlapping_memory_checker(this_derived); + if (memory_checker.check_overlap(e.derived_cast())) + { + temporary_type tmp(e); + return this_derived.assign_temporary(std::move(tmp)); + } + else + { + return this->assign(e); + } +#endif } /************************************** diff --git a/include/xtensor/xutils.hpp b/include/xtensor/xutils.hpp index 137d0e70e..21c452489 100644 --- a/include/xtensor/xutils.hpp +++ b/include/xtensor/xutils.hpp @@ -119,6 +119,20 @@ namespace xt using type = T; }; + /*************************************** + * is_specialization_of implementation * + ***************************************/ + + template