Skip to content

Commit

Permalink
gather_internal.h: Use int64_t instead of size_t for size factors
Browse files Browse the repository at this point in the history
  • Loading branch information
felipecrv committed May 16, 2024
1 parent ed38b80 commit 77d292b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ struct FixedWidthTakeImpl {
static constexpr int kValueWidthInBits = ValueBitWidthConstant::value;

static void Exec(KernelContext* ctx, const ArraySpan& values, const ArraySpan& indices,
ArrayData* out_arr, size_t factor) {
ArrayData* out_arr, int64_t factor) {
#ifndef NDEBUG
int64_t bit_width = util::FixedWidthInBits(*values.type);
DCHECK(WithFactor::value || (kValueWidthInBits == bit_width && factor == 1));
Expand Down Expand Up @@ -394,7 +394,7 @@ struct FixedWidthTakeImpl {

template <template <typename...> class TakeImpl, typename... Args>
void TakeIndexDispatch(KernelContext* ctx, const ArraySpan& values,
const ArraySpan& indices, ArrayData* out, size_t factor = 1) {
const ArraySpan& indices, ArrayData* out, int64_t factor = 1) {
// With the simplifying assumption that boundschecking has taken place
// already at a higher level, we can now assume that the index values are all
// non-negative. Thus, we can interpret signed integers as unsigned and avoid
Expand Down Expand Up @@ -482,9 +482,8 @@ Status FixedWidthTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult*
TakeIndexDispatch<FixedWidthTakeImpl,
/*ValueBitWidth=*/std::integral_constant<int, 8>,
/*OutputIsZeroInitialized=*/std::false_type,
/*WithFactor=*/std::true_type>(
ctx, values, indices, out_arr,
/*factor=*/static_cast<size_t>(byte_width));
/*WithFactor=*/std::true_type>(ctx, values, indices, out_arr,
/*factor=*/byte_width);
} else {
return Status::NotImplemented("Unsupported primitive type for take: ",
*values.type);
Expand Down
14 changes: 7 additions & 7 deletions cpp/src/arrow/util/gather_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ class GatherBaseCRTP {
template <int kValueWidthInBits, typename IndexCType, bool kWithFactor>
class Gather : public GatherBaseCRTP<Gather<kValueWidthInBits, IndexCType, kWithFactor>> {
public:
static_assert(kValueWidthInBits % 8 == 0);
static_assert(kValueWidthInBits >= 0 && kValueWidthInBits % 8 == 0);
static constexpr int kValueWidth = kValueWidthInBits / 8;

private:
Expand All @@ -158,12 +158,12 @@ class Gather : public GatherBaseCRTP<Gather<kValueWidthInBits, IndexCType, kWith
const int64_t idx_length_; // number IndexCType elements in idx_
const IndexCType* idx_;
uint8_t* out_;
size_t factor_;
int64_t factor_;

public:
void WriteValue(int64_t position) {
if constexpr (kWithFactor) {
const size_t scaled_factor = kValueWidth * factor_;
const int64_t scaled_factor = kValueWidth * factor_;
memcpy(out_ + position * scaled_factor, src_ + idx_[position] * scaled_factor,
scaled_factor);
} else {
Expand All @@ -174,7 +174,7 @@ class Gather : public GatherBaseCRTP<Gather<kValueWidthInBits, IndexCType, kWith

void WriteZero(int64_t position) {
if constexpr (kWithFactor) {
const size_t scaled_factor = kValueWidth * factor_;
const int64_t scaled_factor = kValueWidth * factor_;
memset(out_ + position * scaled_factor, 0, scaled_factor);
} else {
memset(out_ + position * kValueWidth, 0, kValueWidth);
Expand All @@ -183,7 +183,7 @@ class Gather : public GatherBaseCRTP<Gather<kValueWidthInBits, IndexCType, kWith

void WriteZeroSegment(int64_t position, int64_t length) {
if constexpr (kWithFactor) {
const size_t scaled_factor = kValueWidth * factor_;
const int64_t scaled_factor = kValueWidth * factor_;
memset(out_ + position * scaled_factor, 0, length * scaled_factor);
} else {
memset(out_ + position * kValueWidth, 0, length * kValueWidth);
Expand All @@ -192,7 +192,7 @@ class Gather : public GatherBaseCRTP<Gather<kValueWidthInBits, IndexCType, kWith

public:
Gather(int64_t src_length, const uint8_t* src, int64_t zero_src_offset,
int64_t idx_length, const IndexCType* idx, uint8_t* out, size_t factor)
int64_t idx_length, const IndexCType* idx, uint8_t* out, int64_t factor)
: src_length_(src_length),
src_(src),
idx_length_(idx_length),
Expand Down Expand Up @@ -239,7 +239,7 @@ class Gather</*kValueWidthInBits=*/1, IndexCType, /*kWithFactor=*/false>

public:
Gather(int64_t src_length, const uint8_t* src, int64_t src_offset, int64_t idx_length,
const IndexCType* idx, uint8_t* out, size_t factor)
const IndexCType* idx, uint8_t* out, int64_t factor)
: src_length_(src_length),
src_(src),
src_offset_(src_offset),
Expand Down

0 comments on commit 77d292b

Please sign in to comment.