Skip to content

Commit

Permalink
columnar calculating partition id
Browse files Browse the repository at this point in the history
  • Loading branch information
taiyang-li committed Dec 6, 2023
1 parent b2ae9b4 commit 4ee3179
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 29 deletions.
73 changes: 45 additions & 28 deletions cpp-ch/local-engine/Shuffle/SelectorBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ namespace ErrorCodes
}
namespace local_engine
{
PartitionInfo PartitionInfo::fromSelector(DB::IColumn::Selector selector, size_t partition_num)
PartitionInfo PartitionInfo::fromSelector(const DB::IColumn::Selector & selector, size_t partition_num)
{
auto rows = selector.size();
std::vector<size_t> partition_row_idx_start_points(partition_num + 1, 0);
Expand Down Expand Up @@ -81,7 +81,7 @@ PartitionInfo RoundRobinSelectorBuilder::build(DB::Block & block)
pid = pid_selection;
pid_selection = (pid_selection + 1) % parts_num;
}
return PartitionInfo::fromSelector(std::move(result), parts_num);
return PartitionInfo::fromSelector(result, parts_num);
}

HashSelectorBuilder::HashSelectorBuilder(
Expand All @@ -100,51 +100,68 @@ PartitionInfo HashSelectorBuilder::build(DB::Block & block)
auto flatten_block = BlockUtil::flattenBlock(DB::Block(args), BlockUtil::FLAT_STRUCT_FORCE | BlockUtil::FLAT_NESTED_TABLE, true);
args = flatten_block.getColumnsWithTypeAndName();

auto & factory = DB::FunctionFactory::instance();
if (!hash_function) [[unlikely]]
{
auto & factory = DB::FunctionFactory::instance();
auto function = factory.get(hash_function_name, local_engine::SerializedPlanParser::global_context);

hash_function = function->build(args);
}

auto rows = block.rows();
DB::IColumn::Selector partition_ids;
partition_ids.reserve(rows);
auto result_type = hash_function->getResultType();
auto hash_column = hash_function->execute(args, result_type, rows, false);

if (isNothing(removeNullable(result_type)))
ColumnPtr selector;
auto hash_result_type = hash_function->getResultType();
if (isNothing(removeNullable(hash_result_type)))
{
/// TODO: implement new hash function sparkCityHash64 like sparkXxHash64 to process null literal as column more gracefully.
/// Current implementation may cause partition skew.
for (size_t i = 0; i < rows; i++)
partition_ids.emplace_back(0);
auto tmp = DataTypeUInt64().createColumn();
tmp->insertManyDefaults(rows);
selector = std::move(tmp);
}
else
{
/// UInt64 partition_id = positive_modulo(hash(args)::Int32, parts_num::UInt64)
const auto & global_context = local_engine::SerializedPlanParser::global_context;
auto hash_column = hash_function->execute(args, hash_result_type, rows);
if (hash_function_name == "sparkMurmurHash3_32")
{
auto parts_num_int32 = static_cast<Int32>(parts_num);
for (size_t i = 0; i < rows; i++)
{
// cast to int32 to be the same as the data type of the vanilla spark
auto hash_int32 = static_cast<Int32>(hash_column->get64(i));
auto res = hash_int32 % parts_num_int32;
if (res < 0)
{
res += parts_num_int32;
}
partition_ids.emplace_back(static_cast<UInt64>(res));
}
ColumnsWithTypeAndName cast_args
= {{hash_column, hash_result_type, ""},
{DataTypeString().createColumnConst(rows, "Int32"), std::make_shared<DataTypeString>(), ""}};
if (!cast_function)
cast_function = factory.get("CAST", global_context)->build(cast_args);
auto cast_column = cast_function->execute(cast_args, cast_function->getResultType(), rows);

ColumnsWithTypeAndName pmod_args
= {{cast_column, cast_function->getResultType(), ""},
{DataTypeUInt64().createColumnConst(rows, static_cast<UInt64>(parts_num)), std::make_shared<DataTypeUInt64>(), ""}};
if (!pmod_function)
pmod_function = factory.get("positiveModulo", global_context)->build(pmod_args);
selector = pmod_function->execute(pmod_args, pmod_function->getResultType(), rows);
}
else
{
for (size_t i = 0; i < rows; i++)
partition_ids.emplace_back(static_cast<UInt64>(hash_column->get64(i) % parts_num));
/// UInt64 partition_id = assumeNotNull(modulo(hash(args), parts_num::UInt64)), assumeNotNull is used because cityHash64 may returns Nullable(UInt64)
ColumnsWithTypeAndName modulo_args
= {{hash_column, hash_result_type, ""},
{DataTypeUInt64().createColumnConst(rows, static_cast<UInt64>(parts_num)), std::make_shared<DataTypeUInt64>(), ""}};
if (!modulo_function)
modulo_function = factory.get("modulo", global_context)->build(modulo_args);
auto modulo_column = modulo_function->execute(modulo_args, modulo_function->getResultType(), rows);

ColumnsWithTypeAndName assume_notnull_args = {{modulo_column, modulo_function->getResultType(), ""}};
if (!assume_notnull_function)
assume_notnull_function = factory.get("assumeNotNull", global_context)->build(assume_notnull_args);
selector = assume_notnull_function->execute(assume_notnull_args, assume_notnull_function->getResultType(), rows);
}
}
return PartitionInfo::fromSelector(std::move(partition_ids), parts_num);

const auto * selector_col = checkAndGetColumn<ColumnUInt64>(selector.get());
if (!selector_col)
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Wrong type of selector column:{} expect ColumnUInt64", selector->getName());

const DB::IColumn::Selector & partition_ids = selector_col->getData();
return PartitionInfo::fromSelector(partition_ids, parts_num);
}


Expand All @@ -164,7 +181,7 @@ PartitionInfo RangeSelectorBuilder::build(DB::Block & block)
{
DB::IColumn::Selector result;
computePartitionIdByBinarySearch(block, result);
return PartitionInfo::fromSelector(std::move(result), partition_num);
return PartitionInfo::fromSelector(result, partition_num);
}

void RangeSelectorBuilder::initSortInformation(Poco::JSON::Array::Ptr orderings)
Expand Down
10 changes: 9 additions & 1 deletion cpp-ch/local-engine/Shuffle/SelectorBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ struct PartitionInfo
std::vector<size_t> partition_start_points;
size_t partition_num;

static PartitionInfo fromSelector(DB::IColumn::Selector selector, size_t partition_num);
static PartitionInfo fromSelector(const DB::IColumn::Selector & selector, size_t partition_num);
};

class SelectorBuilder
Expand Down Expand Up @@ -71,6 +71,14 @@ class HashSelectorBuilder : public SelectorBuilder
std::vector<size_t> exprs_index;
std::string hash_function_name;
DB::FunctionBasePtr hash_function;

/// Only used when hash function is sparkMurmurHash3_32
DB::FunctionBasePtr cast_function;
DB::FunctionBasePtr pmod_function;

/// Only used when hash function is cityHash64
DB::FunctionBasePtr modulo_function;
DB::FunctionBasePtr assume_notnull_function;
};

class RangeSelectorBuilder : public SelectorBuilder
Expand Down

0 comments on commit 4ee3179

Please sign in to comment.