Skip to content

Commit

Permalink
Fix array_join() for Date type (facebookincubator#11003)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: facebookincubator#11003

Reviewed By: amitkdutta

Differential Revision: D62671141

Pulled By: kewang1024

fbshipit-source-id: e6a405993e3f0d9581f95b1db9c94e2162184ac5
  • Loading branch information
kewang1024 authored and facebook-github-bot committed Sep 16, 2024
1 parent 079ed31 commit 696f7d9
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 3 deletions.
17 changes: 16 additions & 1 deletion velox/functions/prestosql/ArrayFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,13 +170,17 @@ struct ArrayJoinFunction {
}

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const std::vector<TypePtr>& inputTypes,
const core::QueryConfig& config,
const arg_type<velox::Array<T>>* /*arr*/,
const arg_type<Varchar>* /*delimiter*/,
const arg_type<Varchar>* /*nullReplacement*/) {
const exec::PrestoCastHooks hooks{config};
options_ = hooks.timestampToStringOptions();
VELOX_CHECK(
inputTypes[0]->isArray(),
"Array join's first parameter type has to be array");
arrayElementType_ = inputTypes[0]->asArray().elementType();
}

template <typename C>
Expand All @@ -185,6 +189,16 @@ struct ArrayJoinFunction {
result += util::Converter<TypeKind::VARCHAR>::tryCast(value).value();
}

void writeValue(out_type<velox::Varchar>& result, const int32_t& value) {
if (arrayElementType_->isDate()) {
result += util::Converter<TypeKind::VARCHAR>::tryCast(
DateType::get()->toString(value))
.value();
return;
}
result += util::Converter<TypeKind::VARCHAR>::tryCast(value).value();
}

void writeValue(out_type<velox::Varchar>& result, const Timestamp& value) {
Timestamp inputValue{value};
if (options_.timeZone) {
Expand Down Expand Up @@ -244,6 +258,7 @@ struct ArrayJoinFunction {

private:
TimestampToStringOptions options_;
TypePtr arrayElementType_;
};

/// Function Signature: combinations(array(T), n) -> array(array(T))
Expand Down
27 changes: 25 additions & 2 deletions velox/functions/prestosql/tests/ArrayJoinTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,14 @@ class ArrayJoinTest : public FunctionBaseTest {
void testArrayJoinNoReplacement(
std::vector<std::optional<T>> array,
const StringView& delimiter,
const StringView& expected) {
const StringView& expected,
bool isDate = false) {
auto arrayVector = makeNullableArrayVector(
std::vector<std::vector<std::optional<T>>>{array});
if (isDate) {
arrayVector = makeNullableArrayVector(
std::vector<std::vector<std::optional<T>>>{array}, ARRAY(DATE()));
}
auto delimiterVector = makeFlatVector<StringView>({delimiter});
auto expectedVector = makeFlatVector<StringView>({expected});
testExpr(
Expand All @@ -52,9 +57,14 @@ class ArrayJoinTest : public FunctionBaseTest {
std::vector<std::optional<T>> array,
const StringView& delimiter,
const StringView& replacement,
const StringView& expected) {
const StringView& expected,
bool isDate = false) {
auto arrayVector = makeNullableArrayVector(
std::vector<std::vector<std::optional<T>>>{array});
if (isDate) {
arrayVector = makeNullableArrayVector(
std::vector<std::vector<std::optional<T>>>{array}, ARRAY(DATE()));
}
auto delimiterVector = makeFlatVector<StringView>({delimiter});
auto replacementVector = makeFlatVector<StringView>({replacement});
auto expectedVector = makeFlatVector<StringView>({expected});
Expand Down Expand Up @@ -137,4 +147,17 @@ TEST_F(ArrayJoinTest, timestampTest) {
"1970-01-04 12:33:03.000~<absent>~1970-02-03 12:33:03.000"_sv);
}

TEST_F(ArrayJoinTest, dateTest) {
std::cout << std::nextafter(0.67777777, INFINITY);
setLegacyCast(false);
testArrayJoinNoReplacement<int32_t>(
{-7204, std::nullopt, -7203}, ","_sv, "1950-04-12,1950-04-13"_sv, true);
testArrayJoinReplacement<int32_t>(
{-7204, std::nullopt, -7203},
","_sv,
"."_sv,
"1950-04-12,.,1950-04-13"_sv,
true);
}

} // namespace

0 comments on commit 696f7d9

Please sign in to comment.