From 5d6c6f3a04bdcdf94157d140dbee3c1db512de34 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E6=89=AC?= <654010905@qq.com> Date: Sun, 4 Aug 2024 18:03:21 +0800 Subject: [PATCH] [GLUTEN-6588][CH] Cast columns if necessary before finally writing to ORC/Parquet files during native inserting (#6691) * [GLUTEN-6588][CH] Cast columns if necessary before finally writing to ORC/Parquet files during native inserting * fix style * fix style * fix conflicts and remove examples * fix style * fix style --- .../datasources/CHDatasourceJniWrapper.java | 2 +- .../v1/CHFormatWriterInjects.scala | 14 +++++- ...lutenClickHouseNativeWriteTableSuite.scala | 10 ++--- cpp-ch/local-engine/Common/CHUtil.cpp | 5 ++- cpp-ch/local-engine/Common/CHUtil.h | 3 +- .../Storages/Output/FileWriterWrappers.cpp | 28 +++++++++--- .../Storages/Output/FileWriterWrappers.h | 12 +++--- .../Storages/Output/ORCOutputFormatFile.cpp | 6 +-- .../Storages/Output/ORCOutputFormatFile.h | 4 +- .../Storages/Output/OutputFormatFile.cpp | 43 +++++++++---------- .../Storages/Output/OutputFormatFile.h | 10 +++-- .../Output/ParquetOutputFormatFile.cpp | 8 ++-- .../Storages/Output/ParquetOutputFormatFile.h | 3 +- .../Storages/SubstraitSource/FormatFile.h | 9 +--- .../SubstraitSource/ORCFormatFile.cpp | 1 + cpp-ch/local-engine/local_engine_jni.cpp | 32 +++++++++----- 16 files changed, 107 insertions(+), 83 deletions(-) diff --git a/backends-clickhouse/src/main/java/org/apache/spark/sql/execution/datasources/CHDatasourceJniWrapper.java b/backends-clickhouse/src/main/java/org/apache/spark/sql/execution/datasources/CHDatasourceJniWrapper.java index 2bb3d44e0ff1..f19c5d39df1d 100644 --- a/backends-clickhouse/src/main/java/org/apache/spark/sql/execution/datasources/CHDatasourceJniWrapper.java +++ b/backends-clickhouse/src/main/java/org/apache/spark/sql/execution/datasources/CHDatasourceJniWrapper.java @@ -19,7 +19,7 @@ public class CHDatasourceJniWrapper { public native long nativeInitFileWriterWrapper( - String filePath, 
String[] preferredColumnNames, String formatHint); + String filePath, byte[] preferredSchema, String formatHint); public native long nativeInitMergeTreeWriterWrapper( byte[] plan, diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/CHFormatWriterInjects.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/CHFormatWriterInjects.scala index 06d5b152716d..547904d7e037 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/CHFormatWriterInjects.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/CHFormatWriterInjects.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources.v1 import org.apache.gluten.execution.datasource.GlutenRowSplitter +import org.apache.gluten.expression.ConverterUtils import org.apache.gluten.memory.CHThreadGroup import org.apache.gluten.vectorized.CHColumnVector @@ -26,6 +27,7 @@ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.orc.OrcUtils import org.apache.spark.sql.types.StructType +import io.substrait.proto.{NamedStruct, Type} import org.apache.hadoop.fs.FileStatus import org.apache.hadoop.mapreduce.TaskAttemptContext @@ -39,10 +41,20 @@ trait CHFormatWriterInjects extends GlutenFormatWriterInjectsBase { val originPath = path val datasourceJniWrapper = new CHDatasourceJniWrapper(); CHThreadGroup.registerNewThreadGroup() + + val namedStructBuilder = NamedStruct.newBuilder + val structBuilder = Type.Struct.newBuilder + for (field <- dataSchema.fields) { + namedStructBuilder.addNames(field.name) + structBuilder.addTypes(ConverterUtils.getTypeNode(field.dataType, field.nullable).toProtobuf) + } + namedStructBuilder.setStruct(structBuilder.build) + var namedStruct = namedStructBuilder.build + val instance = datasourceJniWrapper.nativeInitFileWriterWrapper( path, - dataSchema.fieldNames, + namedStruct.toByteArray, 
getFormatName()); new OutputWriter { diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala index 2fec68a49216..1f99947e5b96 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala @@ -913,12 +913,10 @@ class GlutenClickHouseNativeWriteTableSuite (table_name, create_sql, insert_sql) }, (table_name, _) => - if (isSparkVersionGE("3.5")) { - compareResultsAgainstVanillaSpark( - s"select * from $table_name", - compareResult = true, - _ => {}) - } + compareResultsAgainstVanillaSpark( + s"select * from $table_name", + compareResult = true, + _ => {}) ) } } diff --git a/cpp-ch/local-engine/Common/CHUtil.cpp b/cpp-ch/local-engine/Common/CHUtil.cpp index 12bf7ed59939..d32eed92340a 100644 --- a/cpp-ch/local-engine/Common/CHUtil.cpp +++ b/cpp-ch/local-engine/Common/CHUtil.cpp @@ -16,11 +16,12 @@ */ #include "CHUtil.h" + #include -#include #include #include #include + #include #include #include @@ -1009,7 +1010,7 @@ void BackendInitializerUtil::init(const std::string_view plan) }); } -void BackendInitializerUtil::updateConfig(const DB::ContextMutablePtr & context, const std::string_view plan) +void BackendInitializerUtil::updateConfig(const DB::ContextMutablePtr & context, std::string_view plan) { std::map backend_conf_map = getBackendConfMap(plan); diff --git a/cpp-ch/local-engine/Common/CHUtil.h b/cpp-ch/local-engine/Common/CHUtil.h index f52812803335..785d5d6c0056 100644 --- a/cpp-ch/local-engine/Common/CHUtil.h +++ b/cpp-ch/local-engine/Common/CHUtil.h @@ -161,8 +161,7 @@ class BackendInitializerUtil /// 1. 
global level resources like global_context/shared_context, notice that they can only be initialized once in process lifetime /// 2. session level resources like settings/configs, they can be initialized multiple times following the lifetime of executor/driver static void init(const std::string_view plan); - static void updateConfig(const DB::ContextMutablePtr &, const std::string_view); - + static void updateConfig(const DB::ContextMutablePtr &, std::string_view); // use excel text parser inline static const std::string USE_EXCEL_PARSER = "use_excel_serialization"; diff --git a/cpp-ch/local-engine/Storages/Output/FileWriterWrappers.cpp b/cpp-ch/local-engine/Storages/Output/FileWriterWrappers.cpp index fc4b3a72f75b..46edb7f30d5b 100644 --- a/cpp-ch/local-engine/Storages/Output/FileWriterWrappers.cpp +++ b/cpp-ch/local-engine/Storages/Output/FileWriterWrappers.cpp @@ -15,7 +15,6 @@ * limitations under the License. */ #include "FileWriterWrappers.h" -#include namespace local_engine { @@ -28,7 +27,6 @@ NormalFileWriter::NormalFileWriter(const OutputFormatFilePtr & file_, const DB:: { } - void NormalFileWriter::consume(DB::Block & block) { if (!writer) [[unlikely]] @@ -39,6 +37,22 @@ void NormalFileWriter::consume(DB::Block & block) writer = std::make_unique(*pipeline); } + /// In case input block didn't have the same types as the preferred schema, we cast the input block to the preferred schema. + /// Notice that preferred_schema is the actual file schema, which is also the data schema of current inserted table. 
+ /// Refer to issue: https://github.com/apache/incubator-gluten/issues/6588 + size_t index = 0; + const auto & preferred_schema = file->getPreferredSchema(); + for (auto & column : block) + { + if (column.name.starts_with("__bucket_value__")) + continue; + + const auto & preferred_column = preferred_schema.getByPosition(index++); + column.column = DB::castColumn(column, preferred_column.type); + column.name = preferred_column.name; + column.type = preferred_column.type; + } + /// Although gluten will append MaterializingTransform to the end of the pipeline before native insert in most cases, there are some cases in which MaterializingTransform won't be appended. /// e.g. https://github.com/oap-project/gluten/issues/2900 /// So we need to do materialize here again to make sure all blocks passed to native writer are all materialized. @@ -54,8 +68,8 @@ void NormalFileWriter::close() writer->finish(); } -OutputFormatFilePtr create_output_format_file( - const DB::ContextPtr & context, const std::string & file_uri, const DB::Names & preferred_column_names, const std::string & format_hint) +OutputFormatFilePtr createOutputFormatFile( + const DB::ContextPtr & context, const std::string & file_uri, const DB::Block & preferred_schema, const std::string & format_hint) { // the passed in file_uri is exactly what is expected to see in the output folder // e.g /xxx/中文/timestamp_field=2023-07-13 03%3A00%3A17.622/abc.parquet @@ -64,13 +78,13 @@ OutputFormatFilePtr create_output_format_file( Poco::URI::encode(file_uri, "", encoded); // encode the space and % seen in the file_uri Poco::URI poco_uri(encoded); auto write_buffer_builder = WriteBufferBuilderFactory::instance().createBuilder(poco_uri.getScheme(), context); - return OutputFormatFileUtil::createFile(context, write_buffer_builder, encoded, preferred_column_names, format_hint); + return OutputFormatFileUtil::createFile(context, write_buffer_builder, encoded, preferred_schema, format_hint); } std::unique_ptr 
createFileWriterWrapper( - const DB::ContextPtr & context, const std::string & file_uri, const DB::Names & preferred_column_names, const std::string & format_hint) + const DB::ContextPtr & context, const std::string & file_uri, const DB::Block & preferred_schema, const std::string & format_hint) { - return std::make_unique(create_output_format_file(context, file_uri, preferred_column_names, format_hint), context); + return std::make_unique(createOutputFormatFile(context, file_uri, preferred_schema, format_hint), context); } } diff --git a/cpp-ch/local-engine/Storages/Output/FileWriterWrappers.h b/cpp-ch/local-engine/Storages/Output/FileWriterWrappers.h index 57cb47e41a55..736f5a95f6bd 100644 --- a/cpp-ch/local-engine/Storages/Output/FileWriterWrappers.h +++ b/cpp-ch/local-engine/Storages/Output/FileWriterWrappers.h @@ -41,6 +41,7 @@ class FileWriterWrapper public: explicit FileWriterWrapper(const OutputFormatFilePtr & file_) : file(file_) { } virtual ~FileWriterWrapper() = default; + virtual void consume(DB::Block & block) = 0; virtual void close() = 0; @@ -53,10 +54,9 @@ using FileWriterWrapperPtr = std::shared_ptr; class NormalFileWriter : public FileWriterWrapper { public: - //TODO: EmptyFileReader and ConstColumnsFileReader ? 
- //TODO: to support complex types NormalFileWriter(const OutputFormatFilePtr & file_, const DB::ContextPtr & context_); ~NormalFileWriter() override = default; + void consume(DB::Block & block) override; void close() override; @@ -71,13 +71,13 @@ class NormalFileWriter : public FileWriterWrapper std::unique_ptr createFileWriterWrapper( const DB::ContextPtr & context, const std::string & file_uri, - const DB::Names & preferred_column_names, + const DB::Block & preferred_schema, const std::string & format_hint); -OutputFormatFilePtr create_output_format_file( +OutputFormatFilePtr createOutputFormatFile( const DB::ContextPtr & context, const std::string & file_uri, - const DB::Names & preferred_column_names, + const DB::Block & preferred_schema, const std::string & format_hint); class WriteStats : public DB::ISimpleTransform @@ -191,7 +191,7 @@ class SubstraitFileSink final : public SinkToStorage : SinkToStorage(header) , partition_id_(partition_id.empty() ? NO_PARTITION_ID : partition_id) , relative_path_(relative) - , output_format_(create_output_format_file(context, makeFilename(base_path, partition_id, relative), header.getNames(), format_hint) + , output_format_(createOutputFormatFile(context, makeFilename(base_path, partition_id, relative), header, format_hint) ->createOutputFormat(header)) { } diff --git a/cpp-ch/local-engine/Storages/Output/ORCOutputFormatFile.cpp b/cpp-ch/local-engine/Storages/Output/ORCOutputFormatFile.cpp index 007325a515ae..c54f2e7b33bf 100644 --- a/cpp-ch/local-engine/Storages/Output/ORCOutputFormatFile.cpp +++ b/cpp-ch/local-engine/Storages/Output/ORCOutputFormatFile.cpp @@ -27,8 +27,8 @@ ORCOutputFormatFile::ORCOutputFormatFile( DB::ContextPtr context_, const std::string & file_uri_, WriteBufferBuilderPtr write_buffer_builder_, - const std::vector & preferred_column_names_) - : OutputFormatFile(context_, file_uri_, write_buffer_builder_, preferred_column_names_) + const DB::Block & preferred_schema_) + : OutputFormatFile(context_, 
file_uri_, write_buffer_builder_, preferred_schema_) { } @@ -37,7 +37,7 @@ OutputFormatFile::OutputFormatPtr ORCOutputFormatFile::createOutputFormat(const auto res = std::make_shared(); res->write_buffer = write_buffer_builder->build(file_uri); - auto new_header = creatHeaderWithPreferredColumnNames(header); + auto new_header = creatHeaderWithPreferredSchema(header); // TODO: align all spark orc config with ch orc config auto format_settings = DB::getFormatSettings(context); auto output_format = std::make_shared(*(res->write_buffer), new_header, format_settings); diff --git a/cpp-ch/local-engine/Storages/Output/ORCOutputFormatFile.h b/cpp-ch/local-engine/Storages/Output/ORCOutputFormatFile.h index 0654f4ebcfdb..2ea197cddaa0 100644 --- a/cpp-ch/local-engine/Storages/Output/ORCOutputFormatFile.h +++ b/cpp-ch/local-engine/Storages/Output/ORCOutputFormatFile.h @@ -20,8 +20,6 @@ #include "config.h" #if USE_ORC - -# include # include # include @@ -34,7 +32,7 @@ class ORCOutputFormatFile : public OutputFormatFile DB::ContextPtr context_, const std::string & file_uri_, WriteBufferBuilderPtr write_buffer_builder_, - const std::vector & preferred_column_names_); + const DB::Block & preferred_schema_); ~ORCOutputFormatFile() override = default; OutputFormatFile::OutputFormatPtr createOutputFormat(const DB::Block & header) override; diff --git a/cpp-ch/local-engine/Storages/Output/OutputFormatFile.cpp b/cpp-ch/local-engine/Storages/Output/OutputFormatFile.cpp index e785f78aae67..1e8364c6dac2 100644 --- a/cpp-ch/local-engine/Storages/Output/OutputFormatFile.cpp +++ b/cpp-ch/local-engine/Storages/Output/OutputFormatFile.cpp @@ -38,49 +38,48 @@ OutputFormatFile::OutputFormatFile( DB::ContextPtr context_, const std::string & file_uri_, WriteBufferBuilderPtr write_buffer_builder_, - const std::vector & preferred_column_names_) - : context(context_), file_uri(file_uri_), write_buffer_builder(write_buffer_builder_), preferred_column_names(preferred_column_names_) + const DB::Block & 
preferred_schema_) - : context(context_), file_uri(file_uri_), write_buffer_builder(write_buffer_builder_), preferred_column_names(preferred_column_names_) + : context(context_), file_uri(file_uri_), write_buffer_builder(write_buffer_builder_), preferred_schema(preferred_schema_) { } -Block OutputFormatFile::creatHeaderWithPreferredColumnNames(const Block & header) +Block OutputFormatFile::creatHeaderWithPreferredSchema(const Block & header) { - if (!preferred_column_names.empty()) + if (!preferred_schema) + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "preferred_schema is empty"); + + /// Create a new header with the preferred column name and type + DB::ColumnsWithTypeAndName columns; + columns.reserve(preferred_schema.columns()); + size_t index = 0; + for (const auto & name_type : header.getNamesAndTypesList()) { - /// Create a new header with the preferred column name - DB::NamesAndTypesList names_types_list = header.getNamesAndTypesList(); - DB::ColumnsWithTypeAndName cols; - size_t index = 0; - for (const auto & name_type : header.getNamesAndTypesList()) - { - if (name_type.name.starts_with("__bucket_value__")) - continue; + if (name_type.name.starts_with("__bucket_value__")) + continue; - DB::ColumnWithTypeAndName col(name_type.type->createColumn(), name_type.type, preferred_column_names.at(index++)); - cols.emplace_back(std::move(col)); - } - assert(preferred_column_names.size() == index); - return {std::move(cols)}; + const auto & preferred_column = preferred_schema.getByPosition(index++); + ColumnWithTypeAndName column(preferred_column.type->createColumn(), preferred_column.type, preferred_column.name); + columns.emplace_back(std::move(column)); } - else - throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "preferred_column_names is empty"); + assert(preferred_schema.columns() == index); + return {std::move(columns)}; } OutputFormatFilePtr OutputFormatFileUtil::createFile( DB::ContextPtr context, local_engine::WriteBufferBuilderPtr write_buffer_builder, const std::string & file_uri, - const std::vector & preferred_column_names, + const DB::Block & preferred_schema, const 
std::string & format_hint) { #if USE_PARQUET if (boost::to_lower_copy(file_uri).ends_with(".parquet") || "parquet" == boost::to_lower_copy(format_hint)) - return std::make_shared(context, file_uri, write_buffer_builder, preferred_column_names); + return std::make_shared(context, file_uri, write_buffer_builder, preferred_schema); #endif #if USE_ORC if (boost::to_lower_copy(file_uri).ends_with(".orc") || "orc" == boost::to_lower_copy(format_hint)) - return std::make_shared(context, file_uri, write_buffer_builder, preferred_schema); #endif diff --git a/cpp-ch/local-engine/Storages/Output/OutputFormatFile.h b/cpp-ch/local-engine/Storages/Output/OutputFormatFile.h index 54c6ba4cdc04..93c26d7d188b 100644 --- a/cpp-ch/local-engine/Storages/Output/OutputFormatFile.h +++ b/cpp-ch/local-engine/Storages/Output/OutputFormatFile.h @@ -43,19 +43,21 @@ class OutputFormatFile DB::ContextPtr context_, const std::string & file_uri_, WriteBufferBuilderPtr write_buffer_builder_, - const std::vector & preferred_column_names_); + const DB::Block & preferred_schema_); virtual ~OutputFormatFile() = default; virtual OutputFormatPtr createOutputFormat(const DB::Block & header_) = 0; + virtual const DB::Block getPreferredSchema() const { return preferred_schema; } + protected: - DB::Block creatHeaderWithPreferredColumnNames(const DB::Block & header); + DB::Block creatHeaderWithPreferredSchema(const DB::Block & header); DB::ContextPtr context; std::string file_uri; WriteBufferBuilderPtr write_buffer_builder; - std::vector preferred_column_names; + DB::Block preferred_schema; }; using OutputFormatFilePtr = std::shared_ptr; @@ -66,7 +68,7 @@ class OutputFormatFileUtil DB::ContextPtr context, WriteBufferBuilderPtr write_buffer_builder_, const std::string & file_uri_, - const std::vector & preferred_column_names, + const DB::Block & preferred_schema_, const std::string & format_hint = ""); }; } diff --git 
a/cpp-ch/local-engine/Storages/Output/ParquetOutputFormatFile.cpp b/cpp-ch/local-engine/Storages/Output/ParquetOutputFormatFile.cpp index 6ef8b4524675..ea173b03cba5 100644 --- a/cpp-ch/local-engine/Storages/Output/ParquetOutputFormatFile.cpp +++ b/cpp-ch/local-engine/Storages/Output/ParquetOutputFormatFile.cpp @@ -17,10 +17,8 @@ #include "ParquetOutputFormatFile.h" #if USE_PARQUET - # include # include -# include # include # include @@ -35,8 +33,8 @@ ParquetOutputFormatFile::ParquetOutputFormatFile( DB::ContextPtr context_, const std::string & file_uri_, const WriteBufferBuilderPtr & write_buffer_builder_, - const std::vector & preferred_column_names_) - : OutputFormatFile(context_, file_uri_, write_buffer_builder_, preferred_column_names_) + const DB::Block & preferred_schema_) + : OutputFormatFile(context_, file_uri_, write_buffer_builder_, preferred_schema_) { } @@ -45,7 +43,7 @@ OutputFormatFile::OutputFormatPtr ParquetOutputFormatFile::createOutputFormat(co auto res = std::make_shared(); res->write_buffer = write_buffer_builder->build(file_uri); - auto new_header = creatHeaderWithPreferredColumnNames(header); + auto new_header = creatHeaderWithPreferredSchema(header); // TODO: align all spark parquet config with ch parquet config auto format_settings = DB::getFormatSettings(context); auto output_format = std::make_shared(*(res->write_buffer), new_header, format_settings); diff --git a/cpp-ch/local-engine/Storages/Output/ParquetOutputFormatFile.h b/cpp-ch/local-engine/Storages/Output/ParquetOutputFormatFile.h index 13b731600938..cc87da7da854 100644 --- a/cpp-ch/local-engine/Storages/Output/ParquetOutputFormatFile.h +++ b/cpp-ch/local-engine/Storages/Output/ParquetOutputFormatFile.h @@ -20,7 +20,6 @@ #if USE_PARQUET -#include #include #include @@ -33,7 +32,7 @@ class ParquetOutputFormatFile : public OutputFormatFile DB::ContextPtr context_, const std::string & file_uri_, const WriteBufferBuilderPtr & write_buffer_builder_, - const std::vector & 
preferred_column_names_); + const DB::Block & preferred_schema_); ~ParquetOutputFormatFile() override = default; OutputFormatFile::OutputFormatPtr createOutputFormat(const DB::Block & header) override; diff --git a/cpp-ch/local-engine/Storages/SubstraitSource/FormatFile.h b/cpp-ch/local-engine/Storages/SubstraitSource/FormatFile.h index e67259ce59cc..8ab82f312f28 100644 --- a/cpp-ch/local-engine/Storages/SubstraitSource/FormatFile.h +++ b/cpp-ch/local-engine/Storages/SubstraitSource/FormatFile.h @@ -19,21 +19,14 @@ #include #include #include - +#include #include #include - #include - #include - #include - #include -#include - - namespace DB { diff --git a/cpp-ch/local-engine/Storages/SubstraitSource/ORCFormatFile.cpp b/cpp-ch/local-engine/Storages/SubstraitSource/ORCFormatFile.cpp index d213342f6d76..1c57010751c0 100644 --- a/cpp-ch/local-engine/Storages/SubstraitSource/ORCFormatFile.cpp +++ b/cpp-ch/local-engine/Storages/SubstraitSource/ORCFormatFile.cpp @@ -67,6 +67,7 @@ FormatFile::InputFormatPtr ORCFormatFile::createInputFormat(const DB::Block & he std::back_inserter(skip_stripe_indices)); format_settings.orc.skip_stripes = std::unordered_set(skip_stripe_indices.begin(), skip_stripe_indices.end()); + auto input_format = std::make_shared(*file_format->read_buffer, header, format_settings); file_format->input = input_format; return file_format; diff --git a/cpp-ch/local-engine/local_engine_jni.cpp b/cpp-ch/local-engine/local_engine_jni.cpp index 8807a0f63991..c4e8ec67b106 100644 --- a/cpp-ch/local-engine/local_engine_jni.cpp +++ b/cpp-ch/local-engine/local_engine_jni.cpp @@ -14,7 +14,6 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#include #include #include #include @@ -55,6 +54,7 @@ #include #include #include +#include #include @@ -826,23 +826,33 @@ JNIEXPORT void Java_org_apache_gluten_vectorized_CHBlockWriterJniWrapper_nativeC } JNIEXPORT jlong Java_org_apache_spark_sql_execution_datasources_CHDatasourceJniWrapper_nativeInitFileWriterWrapper( - JNIEnv * env, jobject, jstring file_uri_, jobjectArray names_, jstring format_hint_) + JNIEnv * env, jobject, jstring file_uri_, jbyteArray preferred_schema_, jstring format_hint_) { LOCAL_ENGINE_JNI_METHOD_START - const int num_columns = env->GetArrayLength(names_); - DB::Names names; - names.reserve(num_columns); - for (int i = 0; i < num_columns; i++) + + const auto preferred_schema_ref = local_engine::getByteArrayElementsSafe(env, preferred_schema_); + auto parse_named_struct = [&]() -> std::optional { - auto * name = static_cast(env->GetObjectArrayElement(names_, i)); - names.emplace_back(jstring2string(env, name)); - env->DeleteLocalRef(name); - } + std::string_view view{ + reinterpret_cast(preferred_schema_ref.elems()), static_cast(preferred_schema_ref.length())}; + + substrait::NamedStruct res; + bool ok = res.ParseFromString(view); + if (!ok) + return {}; + return std::move(res); + }; + + auto named_struct = parse_named_struct(); + if (!named_struct.has_value()) + throw DB::Exception(DB::ErrorCodes::CANNOT_PARSE_PROTOBUF_SCHEMA, "Parse schema from substrait protobuf failed"); + + DB::Block preferred_schema = local_engine::TypeParser::buildBlockFromNamedStructWithoutDFS(*named_struct); const auto file_uri = jstring2string(env, file_uri_); // for HiveFileFormat, the file url may not end with .parquet, so we pass in the format as a hint const auto format_hint = jstring2string(env, format_hint_); const auto context = local_engine::QueryContextManager::instance().currentQueryContext(); - auto * writer = local_engine::createFileWriterWrapper(context, file_uri, names, format_hint).release(); + auto * writer = 
local_engine::createFileWriterWrapper(context, file_uri, preferred_schema, format_hint).release(); return reinterpret_cast(writer); LOCAL_ENGINE_JNI_METHOD_END(env, 0) }