Skip to content

Commit

Permalink
Resolved unit test failures
Browse files Browse the repository at this point in the history
  • Loading branch information
divya2108 committed Oct 17, 2024
1 parent 785c900 commit 9b3a0d9
Showing 1 changed file with 29 additions and 5 deletions.
34 changes: 29 additions & 5 deletions src/common/hist_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,32 @@ class GHistBuildingManager {
};

#ifdef XGBOOST_SVE_COMPILER_SUPPORT
template <typename BinIdxType>
__attribute__((target("arch=armv8-a+sve")))
inline svuint32_t load_index_vec(svbool_t pg, BinIdxType *d) {
std::cout << "Missing template for type " << typeid(BinIdxType).name() << std::endl;
assert(0);
return svindex_u32(0, 2); // dummy
}

template <>
__attribute__((target("arch=armv8-a+sve")))
inline svuint32_t load_index_vec(svbool_t pg, const uint32_t *d) {
return svld1(pg, d);
}

template <>
__attribute__((target("arch=armv8-a+sve")))
inline svuint32_t load_index_vec(svbool_t pg, const uint16_t *d) {
return svld1uh_u32(pg, d);
}

template <>
__attribute__((target("arch=armv8-a+sve")))
inline svuint32_t load_index_vec(svbool_t pg, const uint8_t *d) {
return svld1ub_u32(pg, d);
}

template <typename BinIdxType>
__attribute__((target("arch=armv8-a+sve")))
inline void UpdateHistogramWithSVE(size_t row_size, const BinIdxType *gr_index_local,
Expand All @@ -206,14 +232,12 @@ inline void UpdateHistogramWithSVE(size_t row_size, const BinIdxType *gr_index_l
svbool_t pg64_upper = svwhilelt_b64(j+svcntd(), row_size);

// Load the gradient index values and offsets for the current chunk of the row
svuint32_t gr_index_vec =
svld1ub_u32(pg32, reinterpret_cast<const uint8_t *>(&gr_index_local[j]));
svuint32_t offsets_vec = svld1(pg32, &offsets[j]);

svuint32_t gr_index_vec = load_index_vec(pg32, &gr_index_local[j]);
svuint32_t idx_bin_vec;
if (kAnyMissing) {
idx_bin_vec = svmul_n_u32_x(pg32, gr_index_vec, two);
} else {
svuint32_t offsets_vec = svld1(pg32, &offsets[j]);
svuint32_t temp = svadd_u32_m(pg32, gr_index_vec, offsets_vec);
idx_bin_vec = svmul_n_u32_x(pg32, temp, two);
}
Expand Down Expand Up @@ -341,7 +365,7 @@ void RowsWiseBuildHistKernel(Span<GradientPair const> gpair, Span<bst_idx_t cons
*(hist_local + 1) += pgh_t[1];
}
#ifdef XGBOOST_SVE_COMPILER_SUPPORT
}
}
#endif
}
}
Expand Down

0 comments on commit 9b3a0d9

Please sign in to comment.