Skip to content

Commit

Permalink
support options
Browse files Browse the repository at this point in the history
  • Loading branch information
jinchengchenghh committed May 16, 2024
1 parent 17cf7e7 commit 9d70700
Show file tree
Hide file tree
Showing 19 changed files with 283 additions and 88 deletions.
30 changes: 30 additions & 0 deletions cpp/src/arrow/dataset/file_csv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <unordered_set>
#include <utility>

#include "arrow/c/bridge.h"
#include "arrow/csv/options.h"
#include "arrow/csv/parser.h"
#include "arrow/csv/reader.h"
Expand Down Expand Up @@ -52,6 +53,9 @@ using internal::Executor;
using internal::SerialExecutor;

namespace dataset {
namespace {
inline bool parseBool(const std::string& value) { return value == "true" ? true : false; }
} // namespace

struct CsvInspectedFragment : public InspectedFragment {
CsvInspectedFragment(std::vector<std::string> column_names,
Expand Down Expand Up @@ -503,5 +507,31 @@ Future<> CsvFileWriter::FinishInternal() {
return Status::OK();
}

Result<std::shared_ptr<FragmentScanOptions>> CsvFragmentScanOptions::from(
const std::unordered_map<std::string, std::string>& configs) {
std::shared_ptr<CsvFragmentScanOptions> options =
std::make_shared<CsvFragmentScanOptions>();
for (auto it : configs) {
auto& key = it.first;
auto& value = it.second;
if (key == "delimiter") {
options->parse_options.delimiter = value.data()[0];
} else if (key == "quoting") {
options->parse_options.quoting = parseBool(value);
} else if (key == "ArrowSchemaAddress") {
int64_t schema_address = std::stol(value);
ArrowSchema* cSchema = reinterpret_cast<ArrowSchema*>(schema_address);
ARROW_ASSIGN_OR_RAISE(auto schema, arrow::ImportSchema(cSchema));
auto column_types = options->convert_options.column_types;
for (auto field : schema->fields()) {
column_types[field->name()] = field->type();
}
} else {
return Status::Invalid("Not support this config " + it.first);
}
}
return options;
}

} // namespace dataset
} // namespace arrow
3 changes: 3 additions & 0 deletions cpp/src/arrow/dataset/file_csv.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ class ARROW_DS_EXPORT CsvFileFormat : public FileFormat {
struct ARROW_DS_EXPORT CsvFragmentScanOptions : public FragmentScanOptions {
std::string type_name() const override { return kCsvTypeName; }

static Result<std::shared_ptr<FragmentScanOptions>> from(
const std::unordered_map<std::string, std::string>& configs);

using StreamWrapFunc = std::function<Result<std::shared_ptr<io::InputStream>>(
std::shared_ptr<io::InputStream>)>;

Expand Down
1 change: 1 addition & 0 deletions cpp/src/arrow/engine/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ arrow_install_all_headers("arrow/engine")
set(ARROW_SUBSTRAIT_SRCS
substrait/expression_internal.cc
substrait/extended_expression_internal.cc
substrait/extension_internal.cc
substrait/extension_set.cc
substrait/extension_types.cc
substrait/options.cc
Expand Down
8 changes: 8 additions & 0 deletions cpp/src/arrow/engine/substrait/serde.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "arrow/dataset/file_base.h"
#include "arrow/engine/substrait/expression_internal.h"
#include "arrow/engine/substrait/extended_expression_internal.h"
#include "arrow/engine/substrait/extension_internal.h"
#include "arrow/engine/substrait/extension_set.h"
#include "arrow/engine/substrait/plan_internal.h"
#include "arrow/engine/substrait/relation.h"
Expand Down Expand Up @@ -247,6 +248,13 @@ Result<BoundExpressions> DeserializeExpressions(
return FromProto(extended_expression, ext_set_out, conversion_options, registry);
}

Status DeserializeMap(const Buffer& buf,
std::unordered_map<std::string, std::string> out) {
ARROW_ASSIGN_OR_RAISE(auto advanced_extension,
ParseFromBuffer<substrait::extensions::AdvancedExtension>(buf));
return FromProto(advanced_extension, out);
}

namespace {

Result<std::shared_ptr<acero::ExecPlan>> MakeSingleDeclarationPlan(
Expand Down
3 changes: 3 additions & 0 deletions cpp/src/arrow/engine/substrait/serde.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,9 @@ ARROW_ENGINE_EXPORT Result<BoundExpressions> DeserializeExpressions(
const ConversionOptions& conversion_options = {},
ExtensionSet* ext_set_out = NULLPTR);

ARROW_ENGINE_EXPORT Status
DeserializeMap(const Buffer& buf, std::unordered_map<std::string, std::string> out);

/// \brief Deserializes a Substrait Type message to the corresponding Arrow type
///
/// \param[in] buf a buffer containing the protobuf serialization of a Substrait Type
Expand Down
2 changes: 1 addition & 1 deletion cpp/thirdparty/versions.txt
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ ARROW_SUBSTRAIT_BUILD_SHA256_CHECKSUM=f989a862f694e7dbb695925ddb7c4ce06aa6c51aca
ARROW_S2N_TLS_BUILD_VERSION=v1.3.35
ARROW_S2N_TLS_BUILD_SHA256_CHECKSUM=9d32b26e6bfcc058d98248bf8fc231537e347395dd89cf62bb432b55c5da990d
ARROW_THRIFT_BUILD_VERSION=0.16.0
ARROW_THRIFT_BUILD_SHA256_CHECKSUM=f460b5c1ca30d8918ff95ea3eb6291b3951cf518553566088f3f2be8981f6209
ARROW_THRIFT_BUILD_SHA256_CHECKSUM=df2931de646a366c2e5962af679018bca2395d586e00ba82d09c0379f14f8e7b
ARROW_UCX_BUILD_VERSION=1.12.1
ARROW_UCX_BUILD_SHA256_CHECKSUM=9bef31aed0e28bf1973d28d74d9ac4f8926c43ca3b7010bd22a084e164e31b71
ARROW_UTF8PROC_BUILD_VERSION=v2.7.0
Expand Down
14 changes: 14 additions & 0 deletions java/dataset/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
<arrow.cpp.build.dir>../../../cpp/release-build/</arrow.cpp.build.dir>
<parquet.version>1.13.1</parquet.version>
<avro.version>1.11.3</avro.version>
<substrait.version>0.31.0</substrait.version>
<protobuf.version>3.25.3</protobuf.version>
</properties>

<dependencies>
Expand All @@ -48,6 +50,18 @@
<groupId>org.immutables</groupId>
<artifactId>value</artifactId>
</dependency>
<dependency>
<groupId>io.substrait</groupId>
<artifactId>core</artifactId>
<version>${substrait.version}</version>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId>
<version>${protobuf.version}</version>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-memory-netty</artifactId>
Expand Down
64 changes: 33 additions & 31 deletions java/dataset/src/main/cpp/jni_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,20 +156,12 @@ arrow::Result<std::shared_ptr<arrow::dataset::FileFormat>> GetFileFormat(
}

arrow::Result<std::shared_ptr<arrow::dataset::FragmentScanOptions>>
GetFragmentScanOptions(jint file_format_id, jlong schema_address) {
GetFragmentScanOptions(jint file_format_id,
const std::unordered_map<std::string, string>& configs) {
switch (file_format_id) {
#ifdef ARROW_CSV
case 3: {
std::shared_ptr<arrow::dataset::CsvFragmentScanOptions> csv_options =
std::make_shared<arrow::dataset::CsvFragmentScanOptions>();
ArrowSchema* cSchema = reinterpret_cast<ArrowSchema*>(schema_address);
auto schema = JniGetOrThrow(arrow::ImportSchema(cSchema));
auto column_types = csv_options->convert_options.column_types;
for (auto field : schema->fields()) {
column_types[field->name()] = field->type();
}
return csv_options;
}
case 3:
return arrow::dataset::CsvFragmentScanOptions::from(configs);
#endif
default:
std::string error_message =
Expand Down Expand Up @@ -526,12 +518,13 @@ JNIEXPORT void JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_closeDataset
/*
* Class: org_apache_arrow_dataset_jni_JniWrapper
* Method: createScanner
* Signature: (J[Ljava/lang/String;Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;JJJJ)J
* Signature:
* (J[Ljava/lang/String;Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;JJ;Ljava/nio/ByteBuffer;J)J
*/
JNIEXPORT jlong JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_createScanner(
JNIEnv* env, jobject, jlong dataset_id, jobjectArray columns,
jobject substrait_projection, jobject substrait_filter, jlong file_format_id,
jlong schema_address, jlong batch_size, jlong memory_pool_id) {
jobject substrait_projection, jobject substrait_filter, jlong batch_size,
jlong file_format_id, jobject options, jlong memory_pool_id) {
JNI_METHOD_START
arrow::MemoryPool* pool = reinterpret_cast<arrow::MemoryPool*>(memory_pool_id);
if (pool == nullptr) {
Expand Down Expand Up @@ -580,10 +573,13 @@ JNIEXPORT jlong JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_createScann
}
JniAssertOkOrThrow(scanner_builder->Filter(*filter_expr));
}
if (file_format_id != -1 && schema_address != -1) {
std::shared_ptr<arrow::dataset::FragmentScanOptions> options =
JniGetOrThrow(GetFragmentScanOptions(file_format_id, schema_address));
JniAssertOkOrThrow(scanner_builder->FragmentScanOptions(options));
if (file_format_id != -1 && options != nullptr) {
std::unordered_map<std::string, std::string> optionMap;
std::shared_ptr<arrow::Buffer> buffer = LoadArrowBufferFromByteBuffer(env, options);
JniAssertOkOrThrow(arrow::engine::DeserializeMap(*buffer, optionMap));
std::shared_ptr<arrow::dataset::FragmentScanOptions> scan_options =
JniGetOrThrow(GetragmentScanOptions(file_format_id, optionsMap));
JniAssertOkOrThrow(scanner_builder->FragmentScanOptions(scan_options));
}
JniAssertOkOrThrow(scanner_builder->BatchSize(batch_size));

Expand Down Expand Up @@ -696,18 +692,21 @@ JNIEXPORT void JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_ensureS3Fina
/*
* Class: org_apache_arrow_dataset_file_JniWrapper
* Method: makeFileSystemDatasetFactory
* Signature: (Ljava/lang/String;IIJ)J
* Signature: (Ljava/lang/String;IILjava/lang/String;Ljava/nio/ByteBuffer)J
*/
JNIEXPORT jlong JNICALL
Java_org_apache_arrow_dataset_file_JniWrapper_makeFileSystemDatasetFactory__Ljava_lang_String_2I(
JNIEnv* env, jobject, jstring uri, jint file_format_id, jlong schema_address) {
JNIEnv* env, jobject, jstring uri, jint file_format_id, jobject options) {
JNI_METHOD_START
std::shared_ptr<arrow::dataset::FileFormat> file_format =
JniGetOrThrow(GetFileFormat(file_format_id));
if (schema_address != -1) {
std::shared_ptr<arrow::dataset::FragmentScanOptions> options =
JniGetOrThrow(GetFragmentScanOptions(file_format_id, schema_address));
file_format->default_fragment_scan_options = options;
if (options != nullptr) {
std::unordered_map<std::string, std::string> optionMap;
std::shared_ptr<arrow::Buffer> buffer = LoadArrowBufferFromByteBuffer(env, options);
JniAssertOkOrThrow(arrow::engine::DeserializeMap(*buffer, optionMap));
std::shared_ptr<arrow::dataset::FragmentScanOptions> scan_options =
JniGetOrThrow(GetragmentScanOptions(file_format_id, optionsMap));
file_format->default_fragment_scan_options = scan_options;
}
arrow::dataset::FileSystemFactoryOptions options;
std::shared_ptr<arrow::dataset::DatasetFactory> d =
Expand All @@ -720,19 +719,22 @@ Java_org_apache_arrow_dataset_file_JniWrapper_makeFileSystemDatasetFactory__Ljav
/*
* Class: org_apache_arrow_dataset_file_JniWrapper
* Method: makeFileSystemDatasetFactory
* Signature: ([Ljava/lang/String;IIJ)J
* Signature: ([Ljava/lang/String;IIJ;Ljava/nio/ByteBuffer)J
*/
JNIEXPORT jlong JNICALL
Java_org_apache_arrow_dataset_file_JniWrapper_makeFileSystemDatasetFactory___3Ljava_lang_String_2I(
JNIEnv* env, jobject, jobjectArray uris, jint file_format_id, jlong schema_address) {
JNIEnv* env, jobject, jobjectArray uris, jint file_format_id, jobject options) {
JNI_METHOD_START

std::shared_ptr<arrow::dataset::FileFormat> file_format =
JniGetOrThrow(GetFileFormat(file_format_id));
if (schema_address != -1) {
std::shared_ptr<arrow::dataset::FragmentScanOptions> options =
JniGetOrThrow(GetFragmentScanOptions(file_format_id, schema_address));
file_format->default_fragment_scan_options = options;
if (options != nullptr) {
std::unordered_map<std::string, std::string> optionMap;
std::shared_ptr<arrow::Buffer> buffer = LoadArrowBufferFromByteBuffer(env, options);
JniAssertOkOrThrow(arrow::engine::DeserializeMap(*buffer, optionMap));
std::shared_ptr<arrow::dataset::FragmentScanOptions> scan_options =
JniGetOrThrow(GetragmentScanOptions(file_format_id, optionsMap));
file_format->default_fragment_scan_options = scan_options;
}
arrow::dataset::FileSystemFactoryOptions options;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,26 +50,14 @@ public FileSystemDatasetFactory(BufferAllocator allocator, NativeMemoryPool memo
super(allocator, memoryPool, createNative(format, uris, Optional.empty()));
}

private static long getArrowSchemaAddress(Optional<FragmentScanOptions> fragmentScanOptions) {
if (fragmentScanOptions.isPresent()) {
FragmentScanOptions options = fragmentScanOptions.get();
if (options instanceof CsvFragmentScanOptions) {
return ((CsvFragmentScanOptions) options)
.getConvertOptions().getArrowSchemaAddress();
}
}

return -1;
}

private static long createNative(FileFormat format, String uri, Optional<FragmentScanOptions> fragmentScanOptions) {
long cSchemaAddress = getArrowSchemaAddress(fragmentScanOptions);
return JniWrapper.get().makeFileSystemDatasetFactory(uri, format.id(), cSchemaAddress);
return JniWrapper.get().makeFileSystemDatasetFactory(uri, format.id(),
fragmentScanOptions.map(FragmentScanOptions::serialize).orElse(null));
}

private static long createNative(FileFormat format, String[] uris, Optional<FragmentScanOptions> fragmentScanOptions) {
long cSchemaAddress = getArrowSchemaAddress(fragmentScanOptions);
return JniWrapper.get().makeFileSystemDatasetFactory(uris, format.id(), cSchemaAddress);
return JniWrapper.get().makeFileSystemDatasetFactory(uris, format.id(),
fragmentScanOptions.map(FragmentScanOptions::serialize).orElse(null));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

import org.apache.arrow.dataset.jni.JniLoader;

import java.nio.ByteBuffer;

/**
* JniWrapper for filesystem based {@link org.apache.arrow.dataset.source.Dataset} implementations.
*/
Expand All @@ -43,7 +45,8 @@ private JniWrapper() {
* @return the native pointer of the arrow::dataset::FileSystemDatasetFactory instance.
* @see FileFormat
*/
public native long makeFileSystemDatasetFactory(String uri, int fileFormat, long schemaAddress);
public native long makeFileSystemDatasetFactory(String uri, int fileFormat,
ByteBuffer serializedFragmentScanOptions);

/**
* Create FileSystemDatasetFactory and return its native pointer. The pointer is pointing to a
Expand All @@ -54,7 +57,8 @@ private JniWrapper() {
* @return the native pointer of the arrow::dataset::FileSystemDatasetFactory instance.
* @see FileFormat
*/
public native long makeFileSystemDatasetFactory(String[] uris, int fileFormat, long cSchemaAddress);
public native long makeFileSystemDatasetFactory(String[] uris, int fileFormat,
ByteBuffer serializedFragmentScanOptions);

/**
* Write the content in a {@link org.apache.arrow.c.ArrowArrayStream} into files. This internally
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ private JniWrapper() {
*/
public native long createScanner(long datasetId, String[] columns, ByteBuffer substraitProjection,
ByteBuffer substraitFilter, long batchSize, long fileFormat,
long schemaAddress, long memoryPool);
ByteBuffer serializedFragmentScanOptions, long memoryPool);

/**
* Get a serialized schema from native instance of a Scanner.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import org.apache.arrow.dataset.scanner.csv.CsvFragmentScanOptions;
import org.apache.arrow.dataset.source.Dataset;

import java.nio.ByteBuffer;

/**
* Native implementation of {@link Dataset}.
*/
Expand All @@ -42,20 +44,18 @@ public synchronized NativeScanner newScan(ScanOptions options) {
if (closed) {
throw new NativeInstanceReleasedException();
}
long cSchemaAddress = -1;
int fileFormat = -1;
ByteBuffer serialized = null;
if (options.getFragmentScanOptions().isPresent()) {
FragmentScanOptions fragmentScanOptions = options.getFragmentScanOptions().get();
if (fragmentScanOptions instanceof CsvFragmentScanOptions) {
cSchemaAddress = ((CsvFragmentScanOptions) fragmentScanOptions)
.getConvertOptions().getArrowSchemaAddress();
}
fileFormat = fragmentScanOptions.fileFormatId();
serialized = fragmentScanOptions.serialize();
}
long scannerId = JniWrapper.get().createScanner(datasetId, options.getColumns().orElse(null),
options.getSubstraitProjection().orElse(null),
options.getSubstraitFilter().orElse(null),
options.getBatchSize(), fileFormat, cSchemaAddress, context.getMemoryPool().getNativeInstanceId());
options.getBatchSize(), fileFormat, serialized,
context.getMemoryPool().getNativeInstanceId());

return new NativeScanner(context, scannerId);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,29 @@

package org.apache.arrow.dataset.scanner;

import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import io.substrait.proto.AdvancedExtension;
import org.apache.arrow.dataset.substrait.StringMapNode;

import java.nio.ByteBuffer;
import java.util.Map;

public interface FragmentScanOptions {
String typeName();

int fileFormatId();

ByteBuffer serialize();

default ByteBuffer serializeMap(Map<String, String> config) {
if (config.isEmpty()) {
return null;
}
StringMapNode stringMapNode = new StringMapNode(config);
AdvancedExtension.Builder extensionBuilder = AdvancedExtension.newBuilder();
Any.Builder builder = extensionBuilder.getEnhancementBuilder();
builder.setValue(stringMapNode.toProtobuf().toByteString());
return ByteBuffer.wrap(extensionBuilder.build().toByteArray());
}
}
Loading

0 comments on commit 9d70700

Please sign in to comment.