Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement XTensor support in core. #976

Merged
merged 3 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 212 additions & 0 deletions include/highfive/xtensor.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
#pragma once

#include "bits/H5Inspector_decl.hpp"
#include "H5Exception.hpp"

#include <xtensor/xtensor.hpp>
#include <xtensor/xarray.hpp>
#include <xtensor/xadapt.hpp>

namespace HighFive {
namespace details {

template <class XTensor>
struct xtensor_get_rank;

template <typename T, size_t N, xt::layout_type L>
struct xtensor_get_rank<xt::xtensor<T, N, L>> {
static constexpr size_t value = N;
};

template <class EC, size_t N, xt::layout_type L, class Tag>
struct xtensor_get_rank<xt::xtensor_adaptor<EC, N, L, Tag>> {
static constexpr size_t value = N;
};

template <class Derived, class XTensorType, xt::layout_type L>
struct xtensor_inspector_base {
using type = XTensorType;
using value_type = typename type::value_type;
using base_type = typename inspector<value_type>::base_type;
using hdf5_type = base_type;

static_assert(std::is_same<value_type, base_type>::value,
"HighFive's XTensor support only works for scalar elements.");

static constexpr bool IsConstExprRowMajor = L == xt::layout_type::row_major;
static constexpr bool is_trivially_copyable = IsConstExprRowMajor &&
std::is_trivially_copyable<value_type>::value &&
inspector<value_type>::is_trivially_copyable;

static constexpr bool is_trivially_nestable = false;

static size_t getRank(const type& val) {
// Non-scalar elements are not supported.
return val.shape().size();
}

static const value_type& getAnyElement(const type& val) {
return val.unchecked(0);
}

static value_type& getAnyElement(type& val) {
return val.unchecked(0);
}

static std::vector<size_t> getDimensions(const type& val) {
auto shape = val.shape();
return {shape.begin(), shape.end()};
}

static void prepare(type& val, const std::vector<size_t>& dims) {
val.resize(Derived::shapeFromDims(dims));
}

static hdf5_type* data(type& val) {
if (!is_trivially_copyable) {
throw DataSetException("Invalid used of `inspector<xarray>::data`.");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not really xarray here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you, there used to be separate inspectors for xarray and xtensor. Looks like I copied these lines from the xarray variation.

}

if (val.size() == 0) {
return nullptr;
}

return inspector<value_type>::data(getAnyElement(val));
}

static const hdf5_type* data(const type& val) {
if (!is_trivially_copyable) {
throw DataSetException("Invalid used of `inspector<xarray>::data`.");
1uc marked this conversation as resolved.
Show resolved Hide resolved
}

if (val.size() == 0) {
return nullptr;
}

return inspector<value_type>::data(getAnyElement(val));
}

static void serialize(const type& val, const std::vector<size_t>& dims, hdf5_type* m) {
// since we only support scalar types we know all dims belong to us.
size_t size = compute_total_size(dims);
xt::adapt(m, size, xt::no_ownership(), dims) = val;
}

static void unserialize(const hdf5_type* vec_align,
const std::vector<size_t>& dims,
type& val) {
// since we only support scalar types we know all dims belong to us.
size_t size = compute_total_size(dims);
val = xt::adapt(vec_align, size, xt::no_ownership(), dims);
}
};

template <class XTensorType, xt::layout_type L>
struct xtensor_inspector
: public xtensor_inspector_base<xtensor_inspector<XTensorType, L>, XTensorType, L> {
private:
using super = xtensor_inspector_base<xtensor_inspector<XTensorType, L>, XTensorType, L>;

public:
using type = typename super::type;
using value_type = typename super::value_type;
using base_type = typename super::base_type;
using hdf5_type = typename super::hdf5_type;

static constexpr size_t ndim = xtensor_get_rank<XTensorType>::value;
static constexpr size_t min_ndim = ndim + inspector<value_type>::min_ndim;
static constexpr size_t max_ndim = ndim + inspector<value_type>::max_ndim;

static std::array<size_t, ndim> shapeFromDims(const std::vector<size_t>& dims) {
std::array<size_t, ndim> shape;
std::copy(dims.cbegin(), dims.cend(), shape.begin());
return shape;
}
};

template <class XArrayType, xt::layout_type L>
struct xarray_inspector
: public xtensor_inspector_base<xarray_inspector<XArrayType, L>, XArrayType, L> {
private:
using super = xtensor_inspector_base<xarray_inspector<XArrayType, L>, XArrayType, L>;

public:
using type = typename super::type;
using value_type = typename super::value_type;
using base_type = typename super::base_type;
using hdf5_type = typename super::hdf5_type;

static constexpr size_t min_ndim = 0 + inspector<value_type>::min_ndim;
static constexpr size_t max_ndim = 1024 + inspector<value_type>::max_ndim;

static const std::vector<size_t>& shapeFromDims(const std::vector<size_t>& dims) {
return dims;
}
};

template <typename T, size_t N, xt::layout_type L>
struct inspector<xt::xtensor<T, N, L>>: public xtensor_inspector<xt::xtensor<T, N, L>, L> {
private:
using super = xtensor_inspector<xt::xtensor<T, N, L>, L>;

public:
using type = typename super::type;
using value_type = typename super::value_type;
using base_type = typename super::base_type;
using hdf5_type = typename super::hdf5_type;
};

template <typename T, xt::layout_type L>
struct inspector<xt::xarray<T, L>>: public xarray_inspector<xt::xarray<T, L>, L> {
private:
using super = xarray_inspector<xt::xarray<T, L>, L>;

public:
using type = typename super::type;
using value_type = typename super::value_type;
using base_type = typename super::base_type;
using hdf5_type = typename super::hdf5_type;
};

template <typename CT, class... S>
struct inspector<xt::xview<CT, S...>>
: public xarray_inspector<xt::xview<CT, S...>, xt::layout_type::any> {
private:
using super = xarray_inspector<xt::xview<CT, S...>, xt::layout_type::any>;

public:
using type = typename super::type;
using value_type = typename super::value_type;
using base_type = typename super::base_type;
using hdf5_type = typename super::hdf5_type;
};


template <class EC, xt::layout_type L, class SC, class Tag>
struct inspector<xt::xarray_adaptor<EC, L, SC, Tag>>
: public xarray_inspector<xt::xarray_adaptor<EC, L, SC, Tag>, xt::layout_type::any> {
private:
using super = xarray_inspector<xt::xarray_adaptor<EC, L, SC, Tag>, xt::layout_type::any>;

public:
using type = typename super::type;
using value_type = typename super::value_type;
using base_type = typename super::base_type;
using hdf5_type = typename super::hdf5_type;
};

template <class EC, size_t N, xt::layout_type L, class Tag>
struct inspector<xt::xtensor_adaptor<EC, N, L, Tag>>
: public xtensor_inspector<xt::xtensor_adaptor<EC, N, L, Tag>, xt::layout_type::any> {
private:
using super = xtensor_inspector<xt::xtensor_adaptor<EC, N, L, Tag>, xt::layout_type::any>;

public:
using type = typename super::type;
using value_type = typename super::value_type;
using base_type = typename super::base_type;
using hdf5_type = typename super::hdf5_type;
};

} // namespace details
} // namespace HighFive
8 changes: 6 additions & 2 deletions tests/unit/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ if(MSVC)
endif()

## Base tests
foreach(test_name tests_high_five_base tests_high_five_multi_dims tests_high_five_easy test_all_types test_high_five_selection tests_high_five_data_type test_empty_arrays test_legacy test_opencv test_string)
foreach(test_name tests_high_five_base tests_high_five_multi_dims tests_high_five_easy test_all_types test_high_five_selection tests_high_five_data_type test_empty_arrays test_legacy test_opencv test_string test_xtensor)
add_executable(${test_name} "${test_name}.cpp")
target_link_libraries(${test_name} HighFive HighFiveWarnings HighFiveFlags Catch2::Catch2WithMain)
target_link_libraries(${test_name} HighFiveOptionalDependencies)
Expand Down Expand Up @@ -47,7 +47,7 @@ endif()
# test succeeds if it compiles.
file(GLOB public_headers LIST_DIRECTORIES false RELATIVE ${PROJECT_SOURCE_DIR}/include CONFIGURE_DEPENDS ${PROJECT_SOURCE_DIR}/include/highfive/*.hpp)
foreach(PUBLIC_HEADER ${public_headers})
if(PUBLIC_HEADER STREQUAL "highfive/span.hpp" AND NOT HIGHFIVE_TEST_SPAN)
if(PUBLIC_HEADER STREQUAL "highfive/span.hpp" AND NOT HIGHFIVE_TEST_SPAN)
continue()
endif()

Expand All @@ -67,6 +67,10 @@ foreach(PUBLIC_HEADER ${public_headers})
continue()
endif()

if(PUBLIC_HEADER STREQUAL "highfive/xtensor.hpp" AND NOT HIGHFIVE_TEST_XTENSOR)
continue()
endif()

get_filename_component(CLASS_NAME ${PUBLIC_HEADER} NAME_WE)
configure_file(tests_import_public_headers.cpp "tests_${CLASS_NAME}.cpp" @ONLY)
add_executable("tests_include_${CLASS_NAME}" "${CMAKE_CURRENT_BINARY_DIR}/tests_${CLASS_NAME}.cpp")
Expand Down
99 changes: 95 additions & 4 deletions tests/unit/data_generator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,16 @@
#include <highfive/span.hpp>
#endif

#ifdef HIGHFIVE_TEST_XTENSOR
#include <highfive/xtensor.hpp>
#endif


namespace HighFive {
namespace testing {

std::vector<size_t> lstrip(const std::vector<size_t>& indices, size_t n) {
template <class Dims>
std::vector<size_t> lstrip(const Dims& indices, size_t n) {
std::vector<size_t> subindices(indices.size() - n);
for (size_t i = 0; i < subindices.size(); ++i) {
subindices[i] = indices[i + n];
Expand All @@ -34,7 +39,8 @@ std::vector<size_t> lstrip(const std::vector<size_t>& indices, size_t n) {
return subindices;
}

size_t ravel(std::vector<size_t>& indices, const std::vector<size_t> dims) {
template <class Dims>
size_t ravel(std::vector<size_t>& indices, const Dims& dims) {
size_t rank = dims.size();
size_t linear_index = 0;
size_t ld = 1;
Expand All @@ -47,7 +53,8 @@ size_t ravel(std::vector<size_t>& indices, const std::vector<size_t> dims) {
return linear_index;
}

std::vector<size_t> unravel(size_t flat_index, const std::vector<size_t> dims) {
template <class Dims>
std::vector<size_t> unravel(size_t flat_index, const Dims& dims) {
size_t rank = dims.size();
size_t ld = 1;
std::vector<size_t> indices(rank);
Expand All @@ -60,7 +67,8 @@ std::vector<size_t> unravel(size_t flat_index, const std::vector<size_t> dims) {
return indices;
}

static size_t flat_size(const std::vector<size_t>& dims) {
template <class Dims>
static size_t flat_size(const Dims& dims) {
size_t n = 1;
for (auto d: dims) {
n *= d;
Expand Down Expand Up @@ -388,6 +396,7 @@ struct ContainerTraits<boost::numeric::ublas::matrix<T>> {

#endif

// -- Eigen -------------------------------------------------------------------
#if HIGHFIVE_TEST_EIGEN

template <typename EigenType>
Expand Down Expand Up @@ -525,6 +534,88 @@ struct ContainerTraits<Eigen::Map<PlainObjectType, MapOptions>>
};


#endif

// -- XTensor -----------------------------------------------------------------

#if HIGHFIVE_TEST_XTENSOR
template <typename XTensorType, size_t Rank>
struct XTensorContainerTraits {
using container_type = XTensorType;
using value_type = typename container_type::value_type;
using base_type = typename ContainerTraits<value_type>::base_type;

static constexpr size_t rank = Rank;
static constexpr bool is_view = ContainerTraits<value_type>::is_view;

static void set(container_type& array,
const std::vector<size_t>& indices,
const base_type& value) {
std::vector<size_t> local_indices(indices.begin(), indices.begin() + rank);
return ContainerTraits<value_type>::set(array[local_indices], lstrip(indices, rank), value);
}

static base_type get(const container_type& array, const std::vector<size_t>& indices) {
std::vector<size_t> local_indices(indices.begin(), indices.begin() + rank);
return ContainerTraits<value_type>::get(array[local_indices], lstrip(indices, rank));
}

static void assign(container_type& dst, const container_type& src) {
dst = src;
}

static container_type allocate(const std::vector<size_t>& dims) {
const auto& local_dims = details::inspector<XTensorType>::shapeFromDims(dims);
auto array = container_type(local_dims);

size_t n_elements = flat_size(local_dims);
for (size_t i = 0; i < n_elements; ++i) {
auto element = ContainerTraits<value_type>::allocate(lstrip(dims, rank));
set(array, unravel(i, local_dims), element);
}

return array;
}

static void deallocate(container_type& array, const std::vector<size_t>& dims) {
auto local_dims = std::vector<size_t>(dims.begin(), dims.begin() + rank);
size_t n_elements = flat_size(local_dims);
for (size_t i_flat = 0; i_flat < n_elements; ++i_flat) {
auto indices = unravel(i_flat, local_dims);
std::vector<size_t> local_indices(indices.begin(), indices.begin() + rank);
ContainerTraits<value_type>::deallocate(array[local_indices], lstrip(dims, rank));
}
}

static void sanitize_dims(std::vector<size_t>& dims, size_t axis) {
ContainerTraits<value_type>::sanitize_dims(dims, axis + rank);
}
};

template <class T, size_t rank, xt::layout_type layout>
struct ContainerTraits<xt::xtensor<T, rank, layout>>
: public XTensorContainerTraits<xt::xtensor<T, rank, layout>, rank> {
private:
using super = XTensorContainerTraits<xt::xtensor<T, rank, layout>, rank>;

public:
using container_type = typename super::container_type;
using value_type = typename super::value_type;
using base_type = typename super::base_type;
};

template <class T, xt::layout_type layout>
struct ContainerTraits<xt::xarray<T, layout>>
: public XTensorContainerTraits<xt::xarray<T, layout>, 2> {
private:
using super = XTensorContainerTraits<xt::xarray<T, layout>, 2>;

public:
using container_type = typename super::container_type;
using value_type = typename super::value_type;
using base_type = typename super::base_type;
};

#endif

template <class T, class C>
Expand Down
Loading
Loading