Skip to content

Commit

Permalink
add open function
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangstar333 committed Oct 24, 2023
1 parent 86f890e commit 1893583
Showing 1 changed file with 72 additions and 56 deletions.
128 changes: 72 additions & 56 deletions be/src/vec/functions/function_timestamp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <cstring>
#include <memory>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

Expand Down Expand Up @@ -378,6 +379,12 @@ struct MakeDateImpl {
}
};

// State object for date_trunc. `open` stores the truncation routine matching the
// requested time unit here, and `execute` invokes it for each block.
struct DateTruncState {
    // Arguments, in order: input column, result column, null map, number of rows.
    using Callback_function = std::function<void(const ColumnPtr& /*input*/, ColumnPtr& /*res*/,
                                                 NullMap& /*null_map*/, size_t /*rows*/)>;

    // Default-constructed (empty) until `open` assigns a routine.
    Callback_function callback_function {};
};

template <typename DateType>
struct DateTrunc {
static constexpr auto name = "date_trunc";
Expand All @@ -396,73 +403,71 @@ struct DateTrunc {
return make_nullable(std::make_shared<DateType>());
}

// Pre-change implementation of date_trunc's execute.
//
// Reads the date/datetime column at arguments[0] and the time-unit string at
// arguments[1], truncates every row through execute_impl_right_const, and
// stores the result at block position `result` as a ColumnNullable built from
// the value column and the null map.
//
// Only row 0 of the unit column is read, so the unit is treated as a constant
// for the whole block (enforced by the DCHECK below only).
static Status execute(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
size_t result, size_t input_rows_count) {
DCHECK_EQ(arguments.size(), 2);

// One null flag per input row, created with the initial value 0.
auto null_map = ColumnUInt8::create(input_rows_count, 0);
const auto& col0 = block.get_by_position(arguments[0]).column;
// Only col_const[0] is given explicitly; col_const[1] is value-initialised to
// false here and overwritten by the std::tie assignment below.
bool col_const[2] = {is_column_const(*col0)};
// A constant first argument is expanded to a full column so it can be indexed per row.
ColumnPtr argument_columns[2] = {
col_const[0] ? static_cast<const ColumnConst&>(*col0).convert_to_full_column()
: col0};

// Second argument: unpack_if_const yields the column to read together with a
// flag telling whether the original column was constant.
std::tie(argument_columns[1], col_const[1]) =
unpack_if_const(block.get_by_position(arguments[1]).column);

// Unchecked downcasts: the columns are assumed to have these concrete types.
auto datetime_column = static_cast<const ColumnType*>(argument_columns[0].get());
auto str_column = static_cast<const ColumnString*>(argument_columns[1].get());

// Created empty; execute_impl_right_const is handed its data array to fill.
ColumnPtr res = ColumnType::create();
// NOTE(review): DCHECK-style assertions are normally compiled out of release
// builds, so a non-constant unit column is not rejected there - confirm.
DCHECK(col_const[1])
<< "the argument[1] must be const string literal, have check function in FE.";
execute_impl_right_const(datetime_column->get_data(), str_column->get_data_at(0),
static_cast<ColumnType*>(res->assume_mutable().get())->get_data(),
null_map->get_data(), input_rows_count);

block.get_by_position(result).column = ColumnNullable::create(res, std::move(null_map));
return Status::OK();
}

private:
static void execute_impl_right_const(const PaddedPODArray<ArgType>& ldata,
const StringRef& rdata, PaddedPODArray<ArgType>& res,
NullMap& null_map, size_t input_rows_count) {
res.resize(input_rows_count);
std::string lower_str(rdata.data, rdata.size);
static Status open(FunctionContext* context, FunctionContext::FunctionStateScope scope) {
if (scope != FunctionContext::THREAD_LOCAL) {
return Status::OK();
}
if (!context->is_col_constant(1)) {
return Status::InvalidArgument(
"date_trunc function of time unit argument must be constant.");
}
const auto& data_str = context->get_constant_col(1)->column_ptr->get_data_at(0);
std::string lower_str(data_str.data, data_str.size);
std::transform(lower_str.begin(), lower_str.end(), lower_str.begin(),
[](unsigned char c) { return std::tolower(c); });

auto _execute_inner_loop = [&]<TimeUnit Unit>() {
for (size_t i = 0; i < input_rows_count; ++i) {
auto dt = binary_cast<ArgType, DateValueType>(ldata[i]);
null_map[i] = !dt.template datetime_trunc<Unit>();
res[i] = binary_cast<DateValueType, ArgType>(dt);
}
};

std::shared_ptr<DateTruncState> state = std::make_shared<DateTruncState>();
if (std::strncmp("year", lower_str.data(), 4) == 0) {
_execute_inner_loop.template operator()<TimeUnit::YEAR>();
state->callback_function = &execute_impl_right_const<TimeUnit::YEAR>;
} else if (std::strncmp("quarter", lower_str.data(), 7) == 0) {
_execute_inner_loop.template operator()<TimeUnit::QUARTER>();
state->callback_function = &execute_impl_right_const<TimeUnit::QUARTER>;
} else if (std::strncmp("month", lower_str.data(), 5) == 0) {
_execute_inner_loop.template operator()<TimeUnit::MONTH>();
state->callback_function = &execute_impl_right_const<TimeUnit::MONTH>;
} else if (std::strncmp("week", lower_str.data(), 4) == 0) {
_execute_inner_loop.template operator()<TimeUnit::WEEK>();
state->callback_function = &execute_impl_right_const<TimeUnit::WEEK>;
} else if (std::strncmp("day", lower_str.data(), 3) == 0) {
_execute_inner_loop.template operator()<TimeUnit::DAY>();
state->callback_function = &execute_impl_right_const<TimeUnit::DAY>;
} else if (std::strncmp("hour", lower_str.data(), 4) == 0) {
_execute_inner_loop.template operator()<TimeUnit::HOUR>();
state->callback_function = &execute_impl_right_const<TimeUnit::HOUR>;
} else if (std::strncmp("minute", lower_str.data(), 6) == 0) {
_execute_inner_loop.template operator()<TimeUnit::MINUTE>();
state->callback_function = &execute_impl_right_const<TimeUnit::MINUTE>;
} else if (std::strncmp("second", lower_str.data(), 6) == 0) {
_execute_inner_loop.template operator()<TimeUnit::SECOND>();
} else { //here maybe unreachable
for (size_t i = 0; i < input_rows_count; ++i) {
null_map[i] = 1;
auto dt = binary_cast<ArgType, DateValueType>(ldata[i]);
res[i] = binary_cast<DateValueType, ArgType>(dt);
}
state->callback_function = &execute_impl_right_const<TimeUnit::SECOND>;
} else {
return Status::RuntimeError(
"Illegal second argument column of function date_trunc. now only support "
"[second,minute,hour,day,week,month,quarter,year]");
}
context->set_function_state(scope, state);
return Status::OK();
}

// Truncates every value of the date/datetime column at arguments[0] and stores
// the result at block position `result` as a ColumnNullable (values + null map).
//
// The time unit is not parsed here. The per-unit routine is taken from the
// DateTruncState registered under FunctionContext::THREAD_LOCAL, which `open`
// fills in.
//
// Returns Status::OK() on success, or a RuntimeError status when the function
// state is missing or holds no callback.
static Status execute(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
                      size_t result, size_t input_rows_count) {
    DCHECK_EQ(arguments.size(), 2);

    auto* state = reinterpret_cast<DateTruncState*>(
            context->get_function_state(FunctionContext::THREAD_LOCAL));
    // Checked explicitly rather than with DCHECK: debug assertions are normally
    // compiled out of release builds, where a missing state would otherwise be
    // dereferenced and an empty std::function would throw when called.
    if (state == nullptr || !state->callback_function) {
        return Status::RuntimeError(
                "date_trunc function state is not initialized, open must succeed before "
                "execute.");
    }

    // One null flag per input row, created with the initial value 0.
    auto null_map = ColumnUInt8::create(input_rows_count, 0);
    // A constant input is expanded so the callback can index it row by row.
    const auto& datetime_column =
            block.get_by_position(arguments[0]).column->convert_to_full_column_if_const();
    // Created with input_rows_count entries; the callback writes each one in place.
    ColumnPtr res = ColumnType::create(input_rows_count);

    state->callback_function(datetime_column, res, null_map->get_data(), input_rows_count);
    block.get_by_position(result).column = ColumnNullable::create(res, std::move(null_map));
    return Status::OK();
}

private:
// Truncates the first `input_rows_count` values of `datetime_column` to `Unit`.
//
// For each row the value is reinterpreted as DateValueType, truncated in
// place, and written to the same row of `result_column`. The matching
// `null_map` entry is set to the negation of datetime_trunc's return value.
// Both columns are downcast to ColumnType without a runtime check.
template <TimeUnit Unit>
static void execute_impl_right_const(const ColumnPtr& datetime_column, ColumnPtr& result_column,
                                     NullMap& null_map, size_t input_rows_count) {
    const auto* source_column = static_cast<const ColumnType*>(datetime_column.get());
    auto& source_values = source_column->get_data();

    auto* target_column = static_cast<ColumnType*>(result_column->assume_mutable().get());
    auto& target_values = target_column->get_data();

    for (size_t row = 0; row < input_rows_count; ++row) {
        auto value = binary_cast<ArgType, DateValueType>(source_values[row]);
        null_map[row] = !value.template datetime_trunc<Unit>();
        target_values[row] = binary_cast<DateValueType, ArgType>(value);
    }
}
};
Expand Down Expand Up @@ -1250,6 +1255,17 @@ class FunctionOtherTypesToDateType : public IFunction {
return Impl::get_return_type_impl(arguments);
}

// Forwards `open` to Impl when Impl is one of the four DateTrunc
// instantiations; for every other Impl this reports OK without doing anything.
Status open(FunctionContext* context, FunctionContext::FunctionStateScope scope) override {
    // Depends on Impl, so the branch that is not taken below is discarded at
    // compile time and Impl::open is only required to exist for DateTrunc.
    constexpr bool impl_is_date_trunc = std::is_same_v<Impl, DateTrunc<DataTypeDate>> ||
                                        std::is_same_v<Impl, DateTrunc<DataTypeDateV2>> ||
                                        std::is_same_v<Impl, DateTrunc<DataTypeDateTime>> ||
                                        std::is_same_v<Impl, DateTrunc<DataTypeDateTimeV2>>;
    if constexpr (impl_is_date_trunc) {
        return Impl::open(context, scope);
    } else {
        return Status::OK();
    }
}

//TODO: enable the function below once be-ut is fixed.
//ColumnNumbers get_arguments_that_are_always_constant() const override { return {1}; }

Expand Down

0 comments on commit 1893583

Please sign in to comment.