Skip to content

Commit

Permalink
Implement XTensor support in core.
Browse files Browse the repository at this point in the history
Adds support for `xt::xtensor`, `xt::xarray` and `xt::xview`, both row
and column major. This works by wrapping the internal row-major with
`xt::adapt`. Therefore, the `T` in `xt::xtensor<T, ...>` must be scalar
(trivial).
  • Loading branch information
1uc committed May 23, 2024
1 parent ce46f86 commit 3192c1d
Show file tree
Hide file tree
Showing 5 changed files with 479 additions and 6 deletions.
212 changes: 212 additions & 0 deletions include/highfive/xtensor.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
#pragma once

#include "bits/H5Inspector_decl.hpp"
#include "H5Exception.hpp"

#include <xtensor/xtensor.hpp>
#include <xtensor/xarray.hpp>
#include <xtensor/xadapt.hpp>

namespace HighFive {
namespace details {

template <class XTensor>
struct xtensor_get_rank;

template <typename T, size_t N, xt::layout_type L>
struct xtensor_get_rank<xt::xtensor<T, N, L>> {
static constexpr size_t value = N;
};

template <class EC, size_t N, xt::layout_type L, class Tag>
struct xtensor_get_rank<xt::xtensor_adaptor<EC, N, L, Tag>> {
static constexpr size_t value = N;
};

template <class Derived, class XTensorType, xt::layout_type L>
struct xtensor_inspector_base {
using type = XTensorType;
using value_type = typename type::value_type;
using base_type = typename inspector<value_type>::base_type;
using hdf5_type = base_type;

static_assert(std::is_same<value_type, base_type>::value,
"HighFive's XTensor support only works for scalar elements.");

static constexpr bool IsConstExprRowMajor = L == xt::layout_type::row_major;
static constexpr bool is_trivially_copyable = IsConstExprRowMajor &&
std::is_trivially_copyable<value_type>::value &&
inspector<value_type>::is_trivially_copyable;

static constexpr bool is_trivially_nestable = false;

static size_t getRank(const type& val) {
// Non-scalar elements are not supported.
return val.shape().size();
}

static const value_type& getAnyElement(const type& val) {
return val.unchecked(0);
}

static value_type& getAnyElement(type& val) {
return val.unchecked(0);
}

static std::vector<size_t> getDimensions(const type& val) {
auto shape = val.shape();
return {shape.begin(), shape.end()};
}

static void prepare(type& val, const std::vector<size_t>& dims) {
val.resize(Derived::shapeFromDims(dims));
}

static hdf5_type* data(type& val) {
if (!is_trivially_copyable) {
throw DataSetException("Invalid used of `inspector<xarray>::data`.");
}

if (val.size() == 0) {
return nullptr;
}

return inspector<value_type>::data(getAnyElement(val));
}

static const hdf5_type* data(const type& val) {
if (!is_trivially_copyable) {
throw DataSetException("Invalid used of `inspector<xarray>::data`.");
}

if (val.size() == 0) {
return nullptr;
}

return inspector<value_type>::data(getAnyElement(val));
}

static void serialize(const type& val, const std::vector<size_t>& dims, hdf5_type* m) {
// since we only support scalar types we know all dims belong to us.
size_t size = compute_total_size(dims);
xt::adapt(m, size, xt::no_ownership(), dims) = val;
}

static void unserialize(const hdf5_type* vec_align,
const std::vector<size_t>& dims,
type& val) {
// since we only support scalar types we know all dims belong to us.
size_t size = compute_total_size(dims);
val = xt::adapt(vec_align, size, xt::no_ownership(), dims);
}
};

template <class XTensorType, xt::layout_type L>
struct xtensor_inspector
: public xtensor_inspector_base<xtensor_inspector<XTensorType, L>, XTensorType, L> {
private:
using super = xtensor_inspector_base<xtensor_inspector<XTensorType, L>, XTensorType, L>;

public:
using type = typename super::type;
using value_type = typename super::value_type;
using base_type = typename super::base_type;
using hdf5_type = typename super::hdf5_type;

static constexpr size_t ndim = xtensor_get_rank<XTensorType>::value;
static constexpr size_t min_ndim = ndim + inspector<value_type>::min_ndim;
static constexpr size_t max_ndim = ndim + inspector<value_type>::max_ndim;

static std::array<size_t, ndim> shapeFromDims(const std::vector<size_t>& dims) {
std::array<size_t, ndim> shape;
std::copy(dims.cbegin(), dims.cend(), shape.begin());
return shape;
}
};

template <class XArrayType, xt::layout_type L>
struct xarray_inspector
: public xtensor_inspector_base<xarray_inspector<XArrayType, L>, XArrayType, L> {
private:
using super = xtensor_inspector_base<xarray_inspector<XArrayType, L>, XArrayType, L>;

public:
using type = typename super::type;
using value_type = typename super::value_type;
using base_type = typename super::base_type;
using hdf5_type = typename super::hdf5_type;

static constexpr size_t min_ndim = 0 + inspector<value_type>::min_ndim;
static constexpr size_t max_ndim = 1024 + inspector<value_type>::max_ndim;

static const std::vector<size_t>& shapeFromDims(const std::vector<size_t>& dims) {
return dims;
}
};

template <typename T, size_t N, xt::layout_type L>
struct inspector<xt::xtensor<T, N, L>>: public xtensor_inspector<xt::xtensor<T, N, L>, L> {
private:
using super = xtensor_inspector<xt::xtensor<T, N, L>, L>;

public:
using type = typename super::type;
using value_type = typename super::value_type;
using base_type = typename super::base_type;
using hdf5_type = typename super::hdf5_type;
};

template <typename T, xt::layout_type L>
struct inspector<xt::xarray<T, L>>: public xarray_inspector<xt::xarray<T, L>, L> {
private:
using super = xarray_inspector<xt::xarray<T, L>, L>;

public:
using type = typename super::type;
using value_type = typename super::value_type;
using base_type = typename super::base_type;
using hdf5_type = typename super::hdf5_type;
};

template <typename CT, class... S>
struct inspector<xt::xview<CT, S...>>
: public xarray_inspector<xt::xview<CT, S...>, xt::layout_type::any> {
private:
using super = xarray_inspector<xt::xview<CT, S...>, xt::layout_type::any>;

public:
using type = typename super::type;
using value_type = typename super::value_type;
using base_type = typename super::base_type;
using hdf5_type = typename super::hdf5_type;
};


template <class EC, xt::layout_type L, class SC, class Tag>
struct inspector<xt::xarray_adaptor<EC, L, SC, Tag>>
: public xarray_inspector<xt::xarray_adaptor<EC, L, SC, Tag>, xt::layout_type::any> {
private:
using super = xarray_inspector<xt::xarray_adaptor<EC, L, SC, Tag>, xt::layout_type::any>;

public:
using type = typename super::type;
using value_type = typename super::value_type;
using base_type = typename super::base_type;
using hdf5_type = typename super::hdf5_type;
};

template <class EC, size_t N, xt::layout_type L, class Tag>
struct inspector<xt::xtensor_adaptor<EC, N, L, Tag>>
: public xtensor_inspector<xt::xtensor_adaptor<EC, N, L, Tag>, xt::layout_type::any> {
private:
using super = xtensor_inspector<xt::xtensor_adaptor<EC, N, L, Tag>, xt::layout_type::any>;

public:
using type = typename super::type;
using value_type = typename super::value_type;
using base_type = typename super::base_type;
using hdf5_type = typename super::hdf5_type;
};

} // namespace details
} // namespace HighFive
8 changes: 6 additions & 2 deletions tests/unit/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ if(MSVC)
endif()

## Base tests
foreach(test_name tests_high_five_base tests_high_five_multi_dims tests_high_five_easy test_all_types test_high_five_selection tests_high_five_data_type test_empty_arrays test_legacy test_opencv test_string)
foreach(test_name tests_high_five_base tests_high_five_multi_dims tests_high_five_easy test_all_types test_high_five_selection tests_high_five_data_type test_empty_arrays test_legacy test_opencv test_string test_xtensor)
add_executable(${test_name} "${test_name}.cpp")
target_link_libraries(${test_name} HighFive HighFiveWarnings HighFiveFlags Catch2::Catch2WithMain)
target_link_libraries(${test_name} HighFiveOptionalDependencies)
Expand Down Expand Up @@ -47,7 +47,7 @@ endif()
# test succeeds if it compiles.
file(GLOB public_headers LIST_DIRECTORIES false RELATIVE ${PROJECT_SOURCE_DIR}/include CONFIGURE_DEPENDS ${PROJECT_SOURCE_DIR}/include/highfive/*.hpp)
foreach(PUBLIC_HEADER ${public_headers})
if(PUBLIC_HEADER STREQUAL "highfive/span.hpp" AND NOT HIGHFIVE_TEST_SPAN)
if(PUBLIC_HEADER STREQUAL "highfive/span.hpp" AND NOT HIGHFIVE_TEST_SPAN)
continue()
endif()

Expand All @@ -67,6 +67,10 @@ foreach(PUBLIC_HEADER ${public_headers})
continue()
endif()

if(PUBLIC_HEADER STREQUAL "highfive/xtensor.hpp" AND NOT HIGHFIVE_TEST_XTENSOR)
continue()
endif()

get_filename_component(CLASS_NAME ${PUBLIC_HEADER} NAME_WE)
configure_file(tests_import_public_headers.cpp "tests_${CLASS_NAME}.cpp" @ONLY)
add_executable("tests_include_${CLASS_NAME}" "${CMAKE_CURRENT_BINARY_DIR}/tests_${CLASS_NAME}.cpp")
Expand Down
99 changes: 95 additions & 4 deletions tests/unit/data_generator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,16 @@
#include <highfive/span.hpp>
#endif

#ifdef HIGHFIVE_TEST_XTENSOR
#include <highfive/xtensor.hpp>
#endif


namespace HighFive {
namespace testing {

std::vector<size_t> lstrip(const std::vector<size_t>& indices, size_t n) {
template <class Dims>
std::vector<size_t> lstrip(const Dims& indices, size_t n) {
std::vector<size_t> subindices(indices.size() - n);
for (size_t i = 0; i < subindices.size(); ++i) {
subindices[i] = indices[i + n];
Expand All @@ -34,7 +39,8 @@ std::vector<size_t> lstrip(const std::vector<size_t>& indices, size_t n) {
return subindices;
}

size_t ravel(std::vector<size_t>& indices, const std::vector<size_t> dims) {
template <class Dims>
size_t ravel(std::vector<size_t>& indices, const Dims& dims) {
size_t rank = dims.size();
size_t linear_index = 0;
size_t ld = 1;
Expand All @@ -47,7 +53,8 @@ size_t ravel(std::vector<size_t>& indices, const std::vector<size_t> dims) {
return linear_index;
}

std::vector<size_t> unravel(size_t flat_index, const std::vector<size_t> dims) {
template <class Dims>
std::vector<size_t> unravel(size_t flat_index, const Dims& dims) {
size_t rank = dims.size();
size_t ld = 1;
std::vector<size_t> indices(rank);
Expand All @@ -60,7 +67,8 @@ std::vector<size_t> unravel(size_t flat_index, const std::vector<size_t> dims) {
return indices;
}

static size_t flat_size(const std::vector<size_t>& dims) {
template <class Dims>
static size_t flat_size(const Dims& dims) {
size_t n = 1;
for (auto d: dims) {
n *= d;
Expand Down Expand Up @@ -388,6 +396,7 @@ struct ContainerTraits<boost::numeric::ublas::matrix<T>> {

#endif

// -- Eigen -------------------------------------------------------------------
#if HIGHFIVE_TEST_EIGEN

template <typename EigenType>
Expand Down Expand Up @@ -525,6 +534,88 @@ struct ContainerTraits<Eigen::Map<PlainObjectType, MapOptions>>
};


#endif

// -- XTensor -----------------------------------------------------------------

#if HIGHFIVE_TEST_XTENSOR
template <typename XTensorType, size_t Rank>
struct XTensorContainerTraits {
using container_type = XTensorType;
using value_type = typename container_type::value_type;
using base_type = typename ContainerTraits<value_type>::base_type;

static constexpr size_t rank = Rank;
static constexpr bool is_view = ContainerTraits<value_type>::is_view;

static void set(container_type& array,
const std::vector<size_t>& indices,
const base_type& value) {
std::vector<size_t> local_indices(indices.begin(), indices.begin() + rank);
return ContainerTraits<value_type>::set(array[local_indices], lstrip(indices, rank), value);
}

static base_type get(const container_type& array, const std::vector<size_t>& indices) {
std::vector<size_t> local_indices(indices.begin(), indices.begin() + rank);
return ContainerTraits<value_type>::get(array[local_indices], lstrip(indices, rank));
}

static void assign(container_type& dst, const container_type& src) {
dst = src;
}

static container_type allocate(const std::vector<size_t>& dims) {
const auto& local_dims = details::inspector<XTensorType>::shapeFromDims(dims);
auto array = container_type(local_dims);

size_t n_elements = flat_size(local_dims);
for (size_t i = 0; i < n_elements; ++i) {
auto element = ContainerTraits<value_type>::allocate(lstrip(dims, rank));
set(array, unravel(i, local_dims), element);
}

return array;
}

static void deallocate(container_type& array, const std::vector<size_t>& dims) {
auto local_dims = std::vector<size_t>(dims.begin(), dims.begin() + rank);
size_t n_elements = flat_size(local_dims);
for (size_t i_flat = 0; i_flat < n_elements; ++i_flat) {
auto indices = unravel(i_flat, local_dims);
std::vector<size_t> local_indices(indices.begin(), indices.begin() + rank);
ContainerTraits<value_type>::deallocate(array[local_indices], lstrip(dims, rank));
}
}

static void sanitize_dims(std::vector<size_t>& dims, size_t axis) {
ContainerTraits<value_type>::sanitize_dims(dims, axis + rank);
}
};

template <class T, size_t rank, xt::layout_type layout>
struct ContainerTraits<xt::xtensor<T, rank, layout>>
: public XTensorContainerTraits<xt::xtensor<T, rank, layout>, rank> {
private:
using super = XTensorContainerTraits<xt::xtensor<T, rank, layout>, rank>;

public:
using container_type = typename super::container_type;
using value_type = typename super::value_type;
using base_type = typename super::base_type;
};

template <class T, xt::layout_type layout>
struct ContainerTraits<xt::xarray<T, layout>>
: public XTensorContainerTraits<xt::xarray<T, layout>, 2> {
private:
using super = XTensorContainerTraits<xt::xarray<T, layout>, 2>;

public:
using container_type = typename super::container_type;
using value_type = typename super::value_type;
using base_type = typename super::base_type;
};

#endif

template <class T, class C>
Expand Down
Loading

0 comments on commit 3192c1d

Please sign in to comment.