From 1c90fc11bad74c9fc8429226f4acc11d53be10e9 Mon Sep 17 00:00:00 2001 From: PHILO-HE Date: Wed, 3 Jul 2024 08:57:08 +0800 Subject: [PATCH] Optimize get_json_object Spark function using simdjson (5179) --- velox/docs/functions/spark/json.rst | 35 +++ velox/functions/sparksql/CMakeLists.txt | 3 +- velox/functions/sparksql/JsonFunctions.h | 205 ++++++++++++++++++ velox/functions/sparksql/Register.cpp | 4 + velox/functions/sparksql/tests/CMakeLists.txt | 1 + .../sparksql/tests/JsonFunctionsTest.cpp | 123 +++++++++++ 6 files changed, 370 insertions(+), 1 deletion(-) create mode 100644 velox/docs/functions/spark/json.rst create mode 100644 velox/functions/sparksql/JsonFunctions.h create mode 100644 velox/functions/sparksql/tests/JsonFunctionsTest.cpp diff --git a/velox/docs/functions/spark/json.rst b/velox/docs/functions/spark/json.rst new file mode 100644 index 000000000000..3b90560495ab --- /dev/null +++ b/velox/docs/functions/spark/json.rst @@ -0,0 +1,35 @@ +============== +JSON Functions +============== + +JSON Format +----------- + +JSON is a language-independent data format that represents data as +human-readable text. A JSON text can represent a number, a boolean, a +string, an array, an object, or a null, with slightly different grammar. +For instance, a JSON text representing a string must escape all characters +and enclose the string in double quotes, such as ``"123\n"``, whereas a JSON +text representing a number does not need to, such as ``123``. A JSON text +representing an array must enclose the array elements in square brackets, +such as ``[1,2,3]``. More detailed grammar can be found in +`this JSON introduction`_. + +.. _this JSON introduction: https://www.json.org + +JSON Functions +-------------- + +.. spark:function:: get_json_object(jsonString, path) -> varchar + + Returns a json object, represented by VARCHAR, from ``jsonString`` by searching ``path``. 
+ Valid ``path`` should start with '$' and then contain "[index]", "['field']" or ".field" + to define a JSON path. Here are some examples: "$.a", "$.a.b", "$[0]['a'].b". Returns + ``jsonString`` if ``path`` is "$". Returns NULL if ``jsonString`` or ``path`` is malformed. + Also returns NULL if ``path`` doesn't exist. :: + + SELECT get_json_object('{"a":"b"}', '$.a'); -- 'b' + SELECT get_json_object('{"a":{"b":"c"}}', '$.a'); -- '{"b":"c"}' + SELECT get_json_object('{"a":3}', '$.b'); -- NULL (nonexistent field) + SELECT get_json_object('{"a"-3}', '$.a'); -- NULL (malformed JSON string) + SELECT get_json_object('{"a":3}', '.a'); -- NULL (malformed JSON path) diff --git a/velox/functions/sparksql/CMakeLists.txt b/velox/functions/sparksql/CMakeLists.txt index eb49abd92435..585ef0f6bbb5 100644 --- a/velox/functions/sparksql/CMakeLists.txt +++ b/velox/functions/sparksql/CMakeLists.txt @@ -50,7 +50,8 @@ target_link_libraries( velox_functions_spark_specialforms velox_is_null_functions velox_functions_util - Folly::folly) + Folly::folly + simdjson) set_property(TARGET velox_functions_spark PROPERTY JOB_POOL_COMPILE high_memory_pool) diff --git a/velox/functions/sparksql/JsonFunctions.h b/velox/functions/sparksql/JsonFunctions.h new file mode 100644 index 000000000000..8b8c0c0790f7 --- /dev/null +++ b/velox/functions/sparksql/JsonFunctions.h @@ -0,0 +1,205 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "velox/functions/prestosql/json/SIMDJsonUtil.h" + +namespace facebook::velox::functions::sparksql { + +template <typename T> +struct GetJsonObjectFunction { + VELOX_DEFINE_FUNCTION_TYPES(T); + + // ASCII input always produces ASCII result. + static constexpr bool is_default_ascii_behavior = true; + + // Pre-computes the simdjson-style path once when a non-null jsonPath that + // starts with '$' is available at initialization time. + FOLLY_ALWAYS_INLINE void initialize( + const std::vector<TypePtr>& /*inputTypes*/, + const core::QueryConfig& /*config*/, + const arg_type<Varchar>* /*json*/, + const arg_type<Varchar>* jsonPath) { + if (jsonPath != nullptr && checkJsonPath(*jsonPath)) { + jsonPath_ = removeSingleQuotes(*jsonPath); + } + } + + FOLLY_ALWAYS_INLINE bool call( + out_type<Varchar>& result, + const arg_type<Varchar>& json, + const arg_type<Varchar>& jsonPath) { + // Spark requires the first char in jsonPath is '$'. + if (!checkJsonPath(jsonPath)) { + return false; + } + // jsonPath is "$". + if (jsonPath.size() == 1) { + result.append(json); + return true; + } + simdjson::ondemand::document jsonDoc; + simdjson::padded_string paddedJson(json.data(), json.size()); + if (simdjsonParse(paddedJson).get(jsonDoc)) { + return false; + } + + auto rawResult = jsonPath_.has_value() + ? jsonDoc.at_path(jsonPath_.value().data()) + : jsonDoc.at_path(removeSingleQuotes(jsonPath)); + if (rawResult.error()) { + return false; + } + + if (!extractStringResult(rawResult, result)) { + return false; + } + + // Reject the result if the location lookup fails; otherwise currentPos + // would be read without ever having been assigned. + const char* currentPos = nullptr; + return !jsonDoc.current_location().get(currentPos) && + isValidEndingCharacter(currentPos); + } + + private: + // Spark requires the first char in jsonPath is '$'. + FOLLY_ALWAYS_INLINE bool checkJsonPath(StringView jsonPath) { + return jsonPath.size() >= 1 && jsonPath.data()[0] == '$'; + } + + // Spark's json path requires field name surrounded by single quotes if it is + // specified in "[]". But simdjson lib requires not. This method just removes + // such single quotes, e.g., converts "['a']['b']" to "[a][b]". + std::string removeSingleQuotes(StringView jsonPath) { + // Skip the initial "$". 
+ std::string result(jsonPath.data() + 1, jsonPath.size() - 1); + size_t pairEnd = 0; + while (true) { + auto pairBegin = result.find("['", pairEnd); + if (pairBegin == std::string::npos) { + break; + } + pairEnd = result.find("]", pairBegin); + if (pairEnd == std::string::npos || result[pairEnd - 1] != '\'') { + return "-1"; + } + result.erase(pairEnd - 1, 1); + result.erase(pairBegin + 1, 1); + pairEnd -= 2; + } + return result; + } + + // Returns true if no error. + bool extractStringResult( + simdjson::simdjson_result<simdjson::ondemand::value> rawResult, + out_type<Varchar>& result) { + std::stringstream ss; + switch (rawResult.type()) { + // For number and bool types, we need to explicitly get the value + // for specific types instead of using `ss << rawResult`. Thus, we + // can make simdjson's internal parsing position moved and then we + // can check the validity of ending character. + case simdjson::ondemand::json_type::number: { + switch (rawResult.get_number_type()) { + case simdjson::ondemand::number_type::unsigned_integer: { + uint64_t numberResult; + if (!rawResult.get_uint64().get(numberResult)) { + ss << numberResult; + result.append(ss.str()); + return true; + } + return false; + } + case simdjson::ondemand::number_type::signed_integer: { + int64_t numberResult; + if (!rawResult.get_int64().get(numberResult)) { + ss << numberResult; + result.append(ss.str()); + return true; + } + return false; + } + case simdjson::ondemand::number_type::floating_point_number: { + double numberResult; + if (!rawResult.get_double().get(numberResult)) { + ss << rawResult; + result.append(ss.str()); + return true; + } + return false; + } + default: + VELOX_UNREACHABLE(); + } + } + case simdjson::ondemand::json_type::boolean: { + bool boolResult; + if (!rawResult.get_bool().get(boolResult)) { + result.append(boolResult ? 
"true" : "false"); + return true; + } + return false; + } + case simdjson::ondemand::json_type::string: { + std::string_view stringResult; + if (!rawResult.get_string().get(stringResult)) { + result.append(stringResult); + return true; + } + return false; + } + case simdjson::ondemand::json_type::object: { + // For nested case, e.g., for "{"my": {"hello": 10}}", "$.my" will + // return an object type. + ss << rawResult; + result.append(ss.str()); + return true; + } + case simdjson::ondemand::json_type::array: { + ss << rawResult; + result.append(ss.str()); + return true; + } + default: { + return false; + } + } + } + + // This is a simple validation that checks whether the obtained result is + // followed by a valid char, because the ondemand parsing we are using skips + // json format validation for characters after the current parsing position. + // As the json doc is padded with NULL characters, it's safe to keep + // advancing past whitespace until another character is found. + bool isValidEndingCharacter(const char* currentPos) { + char endingChar = *currentPos; + if (endingChar == ',' || endingChar == '}' || endingChar == ']') { + return true; + } + // These chars can be prior to a valid ending char. 
+ if (endingChar == ' ' || endingChar == '\r' || endingChar == '\n' || + endingChar == '\t') { + return isValidEndingCharacter(currentPos + 1); + } + return false; + } + + std::optional<std::string> jsonPath_; +}; + +} // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/Register.cpp b/velox/functions/sparksql/Register.cpp index 1978891e6923..1485d466a6e2 100644 --- a/velox/functions/sparksql/Register.cpp +++ b/velox/functions/sparksql/Register.cpp @@ -35,6 +35,7 @@ #include "velox/functions/sparksql/DateTimeFunctions.h" #include "velox/functions/sparksql/Hash.h" #include "velox/functions/sparksql/In.h" +#include "velox/functions/sparksql/JsonFunctions.h" #include "velox/functions/sparksql/LeastGreatest.h" #include "velox/functions/sparksql/MightContain.h" #include "velox/functions/sparksql/MonotonicallyIncreasingId.h" @@ -174,6 +175,9 @@ void registerFunctions(const std::string& prefix) { registerRegexpReplace(prefix); + registerFunction<GetJsonObjectFunction, Varchar, Varchar, Varchar>( + {prefix + "get_json_object"}); + // Register string functions. registerFunction({prefix + "chr"}); registerFunction({prefix + "ascii"}); diff --git a/velox/functions/sparksql/tests/CMakeLists.txt b/velox/functions/sparksql/tests/CMakeLists.txt index af4b4c3f66da..e3e4e823fd78 100644 --- a/velox/functions/sparksql/tests/CMakeLists.txt +++ b/velox/functions/sparksql/tests/CMakeLists.txt @@ -32,6 +32,7 @@ add_executable( ElementAtTest.cpp HashTest.cpp InTest.cpp + JsonFunctionsTest.cpp LeastGreatestTest.cpp MakeDecimalTest.cpp MakeTimestampTest.cpp diff --git a/velox/functions/sparksql/tests/JsonFunctionsTest.cpp b/velox/functions/sparksql/tests/JsonFunctionsTest.cpp new file mode 100644 index 000000000000..c0c8ecc90999 --- /dev/null +++ b/velox/functions/sparksql/tests/JsonFunctionsTest.cpp @@ -0,0 +1,123 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h" +#include "velox/type/Type.h" + +#include + +namespace facebook::velox::functions::sparksql::test { +namespace { + +class GetJsonObjectTest : public SparkFunctionBaseTest { + protected: + std::optional getJsonObject( + const std::string& json, + const std::string& jsonPath) { + return evaluateOnce( + "get_json_object(c0, c1)", + std::optional(json), + std::optional(jsonPath)); + } +}; + +TEST_F(GetJsonObjectTest, basic) { + EXPECT_EQ(getJsonObject(R"({"hello": "3.5"})", "$.hello"), "3.5"); + EXPECT_EQ(getJsonObject(R"({"hello": 3.5})", "$.hello"), "3.5"); + EXPECT_EQ(getJsonObject(R"({"hello": 292222730})", "$.hello"), "292222730"); + EXPECT_EQ(getJsonObject(R"({"hello": -292222730})", "$.hello"), "-292222730"); + EXPECT_EQ(getJsonObject(R"({"my": {"hello": 3.5}})", "$.my.hello"), "3.5"); + EXPECT_EQ(getJsonObject(R"({"my": {"hello": true}})", "$.my.hello"), "true"); + EXPECT_EQ(getJsonObject(R"({"hello": ""})", "$.hello"), ""); + EXPECT_EQ( + "0.0215434648799772", + getJsonObject(R"({"score":0.0215434648799772})", "$.score")); + // Returns input json if json path is "$". 
+ EXPECT_EQ( + getJsonObject(R"({"name": "Alice", "age": 5, "id": "001"})", "$"), + R"({"name": "Alice", "age": 5, "id": "001"})"); + EXPECT_EQ( + getJsonObject(R"({"name": "Alice", "age": 5, "id": "001"})", "$.age"), + "5"); + EXPECT_EQ( + getJsonObject(R"({"name": "Alice", "age": 5, "id": "001"})", "$.id"), + "001"); + EXPECT_EQ( + getJsonObject( + R"([{"my": {"param": {"name": "Alice", "age": "5", "id": "001"}}}, {"other": "v1"}])", + "$[0]['my']['param']['age']"), + "5"); + EXPECT_EQ( + getJsonObject( + R"([{"my": {"param": {"name": "Alice", "age": "5", "id": "001"}}}, {"other": "v1"}])", + "$[0].my.param.age"), + "5"); + + // Json object as result. + EXPECT_EQ( + getJsonObject( + R"({"my": {"param": {"name": "Alice", "age": "5", "id": "001"}}})", + "$.my.param"), + R"({"name": "Alice", "age": "5", "id": "001"})"); + EXPECT_EQ( + getJsonObject( + R"({"my": {"param": {"name": "Alice", "age": "5", "id": "001"}}})", + "$['my']['param']"), + R"({"name": "Alice", "age": "5", "id": "001"})"); + + // Array as result. + EXPECT_EQ( + getJsonObject( + R"([{"my": {"param": {"name": "Alice"}}}, {"other": ["v1", "v2"]}])", + "$[1].other"), + R"(["v1", "v2"])"); + // Array element as result. + EXPECT_EQ( + getJsonObject( + R"([{"my": {"param": {"name": "Alice"}}}, {"other": ["v1", "v2"]}])", + "$[1].other[0]"), + "v1"); + EXPECT_EQ( + getJsonObject( + R"([{"my": {"param": {"name": "Alice"}}}, {"other": ["v1", "v2"]}])", + "$[1].other[1]"), + "v2"); +} + +TEST_F(GetJsonObjectTest, nullResult) { + // Field not found. + EXPECT_EQ(getJsonObject(R"({"hello": "3.5"})", "$.hi"), std::nullopt); + + // Illegal json. + EXPECT_EQ(getJsonObject(R"({"hello"-3.5})", "$.hello"), std::nullopt); + + // Illegal json path. + EXPECT_EQ(getJsonObject(R"({"hello": "3.5"})", "$hello"), std::nullopt); + EXPECT_EQ(getJsonObject(R"({"hello": "3.5"})", "$."), std::nullopt); + // The first char is not '$'. 
+ EXPECT_EQ(getJsonObject(R"({"hello": "3.5"})", ".hello"), std::nullopt); + // Contains '$' not in the first position. + EXPECT_EQ(getJsonObject(R"({"hello": "3.5"})", "$.$hello"), std::nullopt); + + // Invalid ending character. + EXPECT_EQ( + getJsonObject( + R"([{"my": {"param": {"name": "Alice"quoted""}}}, {"other": ["v1", "v2"]}])", + "$[0].my.param.name"), + std::nullopt); +} + +} // namespace +} // namespace facebook::velox::functions::sparksql::test