Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-28866: [Java] Java Dataset API ScanOptions expansion #41646

Merged
merged 6 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 107 additions & 10 deletions java/dataset/src/main/cpp/jni_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
#include "arrow/c/helpers.h"
#include "arrow/dataset/api.h"
#include "arrow/dataset/file_base.h"
#ifdef ARROW_CSV
#include "arrow/dataset/file_csv.h"
#endif
#include "arrow/filesystem/api.h"
#include "arrow/filesystem/path_util.h"
#include "arrow/engine/substrait/util.h"
Expand Down Expand Up @@ -363,6 +366,63 @@ std::shared_ptr<arrow::Buffer> LoadArrowBufferFromByteBuffer(JNIEnv* env, jobjec
return buffer;
}

inline bool ParseBool(const std::string& value) { return value == "true" ? true : false; }

/// \brief Construct FragmentScanOptions from config map
#ifdef ARROW_CSV
arrow::Result<std::shared_ptr<arrow::dataset::FragmentScanOptions>>
ToCsvFragmentScanOptions(const std::unordered_map<std::string, std::string>& configs) {
std::shared_ptr<arrow::dataset::CsvFragmentScanOptions> options =
std::make_shared<arrow::dataset::CsvFragmentScanOptions>();
for (auto const& [key, value] : configs) {
if (key == "delimiter") {
options->parse_options.delimiter = value.data()[0];
} else if (key == "quoting") {
options->parse_options.quoting = ParseBool(value);
} else if (key == "column_types") {
int64_t schema_address = std::stol(value);
ArrowSchema* c_schema = reinterpret_cast<ArrowSchema*>(schema_address);
ARROW_ASSIGN_OR_RAISE(auto schema, arrow::ImportSchema(c_schema));
auto& column_types = options->convert_options.column_types;
for (auto field : schema->fields()) {
column_types[field->name()] = field->type();
}
} else if (key == "strings_can_be_null") {
options->convert_options.strings_can_be_null = ParseBool(value);
} else {
return arrow::Status::Invalid("Config " + key + " is not supported.");
}
}
return options;
}
#endif

arrow::Result<std::shared_ptr<arrow::dataset::FragmentScanOptions>>
GetFragmentScanOptions(jint file_format_id,
const std::unordered_map<std::string, std::string>& configs) {
switch (file_format_id) {
#ifdef ARROW_CSV
case 3:
return ToCsvFragmentScanOptions(configs);
#endif
default:
return arrow::Status::Invalid("Illegal file format id: ", file_format_id);
}
}

std::unordered_map<std::string, std::string> ToStringMap(JNIEnv* env,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can just take the Java Map and create the C++ Map using JNI, rather than using String[]?
cc @lidavidm

jobjectArray& str_array) {
int length = env->GetArrayLength(str_array);
std::unordered_map<std::string, std::string> map;
map.reserve(length / 2);
for (int i = 0; i < length; i += 2) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we do i =+2, should we ensure the length is even? And maybe reserve the map since we know the length?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this logic is guaranteed by internal code logic, so I don't think we need to do the check.
Reserve is added.

auto key = reinterpret_cast<jstring>(env->GetObjectArrayElement(str_array, i));
auto value = reinterpret_cast<jstring>(env->GetObjectArrayElement(str_array, i + 1));
map[JStringToCString(env, key)] = JStringToCString(env, value);
}
return map;
}

/*
* Class: org_apache_arrow_dataset_jni_NativeMemoryPool
* Method: getDefaultMemoryPool
Expand Down Expand Up @@ -501,12 +561,13 @@ JNIEXPORT void JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_closeDataset
/*
* Class: org_apache_arrow_dataset_jni_JniWrapper
* Method: createScanner
* Signature: (J[Ljava/lang/String;Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;JJ)J
* Signature:
* (J[Ljava/lang/String;Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;JI;[Ljava/lang/String;J)J
*/
JNIEXPORT jlong JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_createScanner(
JNIEnv* env, jobject, jlong dataset_id, jobjectArray columns,
jobject substrait_projection, jobject substrait_filter,
jlong batch_size, jlong memory_pool_id) {
jobject substrait_projection, jobject substrait_filter, jlong batch_size,
jint file_format_id, jobjectArray options, jlong memory_pool_id) {
JNI_METHOD_START
arrow::MemoryPool* pool = reinterpret_cast<arrow::MemoryPool*>(memory_pool_id);
if (pool == nullptr) {
Expand Down Expand Up @@ -555,6 +616,12 @@ JNIEXPORT jlong JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_createScann
}
JniAssertOkOrThrow(scanner_builder->Filter(*filter_expr));
}
if (file_format_id != -1 && options != nullptr) {
std::unordered_map<std::string, std::string> option_map = ToStringMap(env, options);
std::shared_ptr<arrow::dataset::FragmentScanOptions> scan_options =
JniGetOrThrow(GetFragmentScanOptions(file_format_id, option_map));
JniAssertOkOrThrow(scanner_builder->FragmentScanOptions(scan_options));
}
JniAssertOkOrThrow(scanner_builder->BatchSize(batch_size));

auto scanner = JniGetOrThrow(scanner_builder->Finish());
Expand Down Expand Up @@ -668,14 +735,29 @@ JNIEXPORT void JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_ensureS3Fina
/*
* Class: org_apache_arrow_dataset_file_JniWrapper
* Method: makeFileSystemDatasetFactory
* Signature: (Ljava/lang/String;II)J
* Signature: (Ljava/lang/String;II;Ljava/lang/String;Ljava/lang/String)J
*/
JNIEXPORT jlong JNICALL
Java_org_apache_arrow_dataset_file_JniWrapper_makeFileSystemDatasetFactory__Ljava_lang_String_2I(
JNIEnv* env, jobject, jstring uri, jint file_format_id) {
Java_org_apache_arrow_dataset_file_JniWrapper_makeFileSystemDatasetFactory(
JNIEnv* env, jobject, jstring uri, jint file_format_id, jobjectArray options) {
JNI_METHOD_START
std::shared_ptr<arrow::dataset::FileFormat> file_format =
JniGetOrThrow(GetFileFormat(file_format_id));
if (options != nullptr) {
std::unordered_map<std::string, std::string> option_map = ToStringMap(env, options);
std::shared_ptr<arrow::dataset::FragmentScanOptions> scan_options =
JniGetOrThrow(GetFragmentScanOptions(file_format_id, option_map));
file_format->default_fragment_scan_options = scan_options;
#ifdef ARROW_CSV
if (file_format_id == 3) {
std::shared_ptr<arrow::dataset::CsvFileFormat> csv_file_format =
std::dynamic_pointer_cast<arrow::dataset::CsvFileFormat>(file_format);
csv_file_format->parse_options =
std::dynamic_pointer_cast<arrow::dataset::CsvFragmentScanOptions>(scan_options)
->parse_options;
}
#endif
}
arrow::dataset::FileSystemFactoryOptions options;
std::shared_ptr<arrow::dataset::DatasetFactory> d =
JniGetOrThrow(arrow::dataset::FileSystemDatasetFactory::Make(
Expand All @@ -686,16 +768,31 @@ Java_org_apache_arrow_dataset_file_JniWrapper_makeFileSystemDatasetFactory__Ljav

/*
* Class: org_apache_arrow_dataset_file_JniWrapper
* Method: makeFileSystemDatasetFactory
* Signature: ([Ljava/lang/String;II)J
* Method: makeFileSystemDatasetFactoryWithFiles
* Signature: ([Ljava/lang/String;II;[Ljava/lang/String)J
*/
JNIEXPORT jlong JNICALL
Java_org_apache_arrow_dataset_file_JniWrapper_makeFileSystemDatasetFactory___3Ljava_lang_String_2I(
JNIEnv* env, jobject, jobjectArray uris, jint file_format_id) {
Java_org_apache_arrow_dataset_file_JniWrapper_makeFileSystemDatasetFactoryWithFiles(
JNIEnv* env, jobject, jobjectArray uris, jint file_format_id, jobjectArray options) {
JNI_METHOD_START

std::shared_ptr<arrow::dataset::FileFormat> file_format =
JniGetOrThrow(GetFileFormat(file_format_id));
if (options != nullptr) {
std::unordered_map<std::string, std::string> option_map = ToStringMap(env, options);
std::shared_ptr<arrow::dataset::FragmentScanOptions> scan_options =
JniGetOrThrow(GetFragmentScanOptions(file_format_id, option_map));
file_format->default_fragment_scan_options = scan_options;
#ifdef ARROW_CSV
if (file_format_id == 3) {
std::shared_ptr<arrow::dataset::CsvFileFormat> csv_file_format =
std::dynamic_pointer_cast<arrow::dataset::CsvFileFormat>(file_format);
csv_file_format->parse_options =
std::dynamic_pointer_cast<arrow::dataset::CsvFragmentScanOptions>(scan_options)
->parse_options;
}
#endif
}
arrow::dataset::FileSystemFactoryOptions options;

std::vector<std::string> uri_vec = ToStringVector(env, uris);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,28 +16,56 @@
*/
package org.apache.arrow.dataset.file;

import java.util.Optional;
import org.apache.arrow.dataset.jni.NativeDatasetFactory;
import org.apache.arrow.dataset.jni.NativeMemoryPool;
import org.apache.arrow.dataset.scanner.FragmentScanOptions;
import org.apache.arrow.memory.BufferAllocator;

/** Java binding of the C++ FileSystemDatasetFactory. */
public class FileSystemDatasetFactory extends NativeDatasetFactory {

public FileSystemDatasetFactory(
BufferAllocator allocator, NativeMemoryPool memoryPool, FileFormat format, String uri) {
super(allocator, memoryPool, createNative(format, uri));
super(allocator, memoryPool, createNative(format, uri, Optional.empty()));
}

public FileSystemDatasetFactory(
BufferAllocator allocator,
NativeMemoryPool memoryPool,
FileFormat format,
String uri,
Optional<FragmentScanOptions> fragmentScanOptions) {
super(allocator, memoryPool, createNative(format, uri, fragmentScanOptions));
}

public FileSystemDatasetFactory(
BufferAllocator allocator, NativeMemoryPool memoryPool, FileFormat format, String[] uris) {
super(allocator, memoryPool, createNative(format, uris));
super(allocator, memoryPool, createNative(format, uris, Optional.empty()));
}

public FileSystemDatasetFactory(
BufferAllocator allocator,
NativeMemoryPool memoryPool,
FileFormat format,
String[] uris,
Optional<FragmentScanOptions> fragmentScanOptions) {
super(allocator, memoryPool, createNative(format, uris, fragmentScanOptions));
}

private static long createNative(FileFormat format, String uri) {
return JniWrapper.get().makeFileSystemDatasetFactory(uri, format.id());
private static long createNative(
FileFormat format, String uri, Optional<FragmentScanOptions> fragmentScanOptions) {
return JniWrapper.get()
.makeFileSystemDatasetFactory(
uri, format.id(), fragmentScanOptions.map(FragmentScanOptions::serialize).orElse(null));
}

private static long createNative(FileFormat format, String[] uris) {
return JniWrapper.get().makeFileSystemDatasetFactory(uris, format.id());
private static long createNative(
FileFormat format, String[] uris, Optional<FragmentScanOptions> fragmentScanOptions) {
return JniWrapper.get()
.makeFileSystemDatasetFactoryWithFiles(
uris,
format.id(),
fragmentScanOptions.map(FragmentScanOptions::serialize).orElse(null));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,22 +37,26 @@ private JniWrapper() {}
* intermediate shared_ptr of the factory instance.
*
* @param uri file uri to read, either a file or a directory
* @param fileFormat file format ID
* @param fileFormat file format ID.
* @param serializedFragmentScanOptions serialized FragmentScanOptions.
* @return the native pointer of the arrow::dataset::FileSystemDatasetFactory instance.
* @see FileFormat
*/
public native long makeFileSystemDatasetFactory(String uri, int fileFormat);
public native long makeFileSystemDatasetFactory(
String uri, int fileFormat, String[] serializedFragmentScanOptions);

/**
* Create FileSystemDatasetFactory and return its native pointer. The pointer is pointing to a
* intermediate shared_ptr of the factory instance.
*
* @param uris List of file uris to read, each path pointing to an individual file
* @param fileFormat file format ID
* @param fileFormat file format ID.
* @param serializedFragmentScanOptions serialized FragmentScanOptions.
* @return the native pointer of the arrow::dataset::FileSystemDatasetFactory instance.
* @see FileFormat
*/
public native long makeFileSystemDatasetFactory(String[] uris, int fileFormat);
public native long makeFileSystemDatasetFactoryWithFiles(
String[] uris, int fileFormat, String[] serializedFragmentScanOptions);

/**
* Write the content in a {@link org.apache.arrow.c.ArrowArrayStream} into files. This internally
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ private JniWrapper() {}
* @param substraitProjection substrait extended expression to evaluate for project new columns
* @param substraitFilter substrait extended expression to evaluate for apply filter
* @param batchSize batch size of scanned record batches.
* @param fileFormat file format ID.
* @param serializedFragmentScanOptions serialized FragmentScanOptions.
* @param memoryPool identifier of memory pool used in the native scanner.
* @return the native pointer of the arrow::dataset::Scanner instance.
*/
Expand All @@ -80,6 +82,8 @@ public native long createScanner(
ByteBuffer substraitProjection,
ByteBuffer substraitFilter,
long batchSize,
int fileFormat,
String[] serializedFragmentScanOptions,
long memoryPool);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/
package org.apache.arrow.dataset.jni;

import org.apache.arrow.dataset.scanner.FragmentScanOptions;
import org.apache.arrow.dataset.scanner.ScanOptions;
import org.apache.arrow.dataset.source.Dataset;

Expand All @@ -37,7 +38,13 @@ public synchronized NativeScanner newScan(ScanOptions options) {
if (closed) {
throw new NativeInstanceReleasedException();
}

int fileFormatId = -1;
String[] serialized = null;
if (options.getFragmentScanOptions().isPresent()) {
FragmentScanOptions fragmentScanOptions = options.getFragmentScanOptions().get();
fileFormatId = fragmentScanOptions.fileFormat().id();
serialized = fragmentScanOptions.serialize();
}
long scannerId =
JniWrapper.get()
.createScanner(
Expand All @@ -46,6 +53,8 @@ public synchronized NativeScanner newScan(ScanOptions options) {
options.getSubstraitProjection().orElse(null),
options.getSubstraitFilter().orElse(null),
options.getBatchSize(),
fileFormatId,
serialized,
context.getMemoryPool().getNativeInstanceId());

return new NativeScanner(context, scannerId);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.arrow.dataset.scanner;

import org.apache.arrow.dataset.file.FileFormat;

/** The file fragment scan options interface. It is used to transfer to JNI call. */
public interface FragmentScanOptions {
lidavidm marked this conversation as resolved.
Show resolved Hide resolved
FileFormat fileFormat();

String[] serialize();
}
Loading
Loading