From 9b3a0d97d7507b7ab452cc9e66cc0fcf6cd8c8f0 Mon Sep 17 00:00:00 2001 From: divya2108 Date: Thu, 17 Oct 2024 14:25:27 +0530 Subject: [PATCH] Resolved unit test failures --- src/common/hist_util.cc | 34 +++++++++++++++++++++++++++++----- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc index 1986a7e4277f..a3b46c6eac68 100644 --- a/src/common/hist_util.cc +++ b/src/common/hist_util.cc @@ -189,6 +189,32 @@ class GHistBuildingManager { }; #ifdef XGBOOST_SVE_COMPILER_SUPPORT +template +__attribute__((target("arch=armv8-a+sve"))) +inline svuint32_t load_index_vec(svbool_t pg, BinIdxType *d) { + std::cout << "Missing template for type " << typeid(BinIdxType).name() << std::endl; + assert(0); + return svindex_u32(0, 2); // dummy +} + +template <> +__attribute__((target("arch=armv8-a+sve"))) +inline svuint32_t load_index_vec(svbool_t pg, const uint32_t *d) { + return svld1(pg, d); +} + +template <> +__attribute__((target("arch=armv8-a+sve"))) +inline svuint32_t load_index_vec(svbool_t pg, const uint16_t *d) { + return svld1uh_u32(pg, d); +} + +template <> +__attribute__((target("arch=armv8-a+sve"))) +inline svuint32_t load_index_vec(svbool_t pg, const uint8_t *d) { + return svld1ub_u32(pg, d); +} + template __attribute__((target("arch=armv8-a+sve"))) inline void UpdateHistogramWithSVE(size_t row_size, const BinIdxType *gr_index_local, @@ -206,14 +232,12 @@ inline void UpdateHistogramWithSVE(size_t row_size, const BinIdxType *gr_index_l svbool_t pg64_upper = svwhilelt_b64(j+svcntd(), row_size); // Load the gradient index values and offsets for the current chunk of the row - svuint32_t gr_index_vec = - svld1ub_u32(pg32, reinterpret_cast(&gr_index_local[j])); - svuint32_t offsets_vec = svld1(pg32, &offsets[j]); - + svuint32_t gr_index_vec = load_index_vec(pg32, &gr_index_local[j]); svuint32_t idx_bin_vec; if (kAnyMissing) { idx_bin_vec = svmul_n_u32_x(pg32, gr_index_vec, two); } else { + svuint32_t offsets_vec = svld1(pg32, &offsets[j]); svuint32_t temp = svadd_u32_m(pg32, gr_index_vec, offsets_vec); idx_bin_vec = svmul_n_u32_x(pg32, temp, two); } @@ -341,7 +365,7 @@ void RowsWiseBuildHistKernel(Span gpair, Span