diff --git a/java/dataset/src/main/cpp/jni_wrapper.cc b/java/dataset/src/main/cpp/jni_wrapper.cc
index f324f87d6c301..63b8dd73f4720 100644
--- a/java/dataset/src/main/cpp/jni_wrapper.cc
+++ b/java/dataset/src/main/cpp/jni_wrapper.cc
@@ -368,29 +368,104 @@ std::shared_ptr<arrow::Buffer> LoadArrowBufferFromByteBuffer(JNIEnv* env, jobjec
 inline bool ParseBool(const std::string& value) { return value == "true" ? true : false; }
 
+inline char ParseChar(const std::string& key, const std::string& value) {
+  if (value.size() != 1) {
+    JniThrow("Option " + key + " should be a char, but is " + value);
+  }
+  return value.at(0);
+}
+
 /// \brief Construct FragmentScanOptions from config map
 #ifdef ARROW_CSV
-arrow::Result<std::shared_ptr<arrow::dataset::CsvFragmentScanOptions>>
-ToCsvFragmentScanOptions(const std::unordered_map<std::string, std::string>& configs) {
+
+bool SetCsvConvertOptions(arrow::csv::ConvertOptions& options, const std::string& key,
+                          const std::string& value) {
+  if (key == "column_types") {
+    int64_t schema_address = std::stol(value);
+    ArrowSchema* c_schema = reinterpret_cast<ArrowSchema*>(schema_address);
+    auto schema = JniGetOrThrow(arrow::ImportSchema(c_schema));
+    auto& column_types = options.column_types;
+    for (auto field : schema->fields()) {
+      column_types[field->name()] = field->type();
+    }
+  } else if (key == "strings_can_be_null") {
+    options.strings_can_be_null = ParseBool(value);
+  } else if (key == "check_utf8") {
+    options.check_utf8 = ParseBool(value);
+  } else if (key == "null_values") {
+    options.null_values = {value};
+  } else if (key == "true_values") {
+    options.true_values = {value};
+  } else if (key == "false_values") {
+    options.false_values = {value};
+  } else if (key == "quoted_strings_can_be_null") {
+    options.quoted_strings_can_be_null = ParseBool(value);
+  } else if (key == "auto_dict_encode") {
+    options.auto_dict_encode = ParseBool(value);
+  } else if (key == "auto_dict_max_cardinality") {
+    options.auto_dict_max_cardinality = std::stoi(value);
+  } else if (key == "decimal_point") {
+    options.decimal_point = ParseChar(key, value);
+  } else if (key == "include_missing_columns") {
+    options.include_missing_columns = ParseBool(value);
+  } else {
+    return false;
+  }
+  return true;
+}
+
+bool SetCsvParseOptions(arrow::csv::ParseOptions& options, const std::string& key,
+                        const std::string& value) {
+  if (key == "delimiter") {
+    options.delimiter = ParseChar(key, value);
+  } else if (key == "quoting") {
+    options.quoting = ParseBool(value);
+  } else if (key == "quote_char") {
+    options.quote_char = ParseChar(key, value);
+  } else if (key == "double_quote") {
+    options.double_quote = ParseBool(value);
+  } else if (key == "escaping") {
+    options.escaping = ParseBool(value);
+  } else if (key == "escape_char") {
+    options.escape_char = ParseChar(key, value);
+  } else if (key == "newlines_in_values") {
+    options.newlines_in_values = ParseBool(value);
+  } else if (key == "ignore_empty_lines") {
+    options.ignore_empty_lines = ParseBool(value);
+  } else {
+    return false;
+  }
+  return true;
+}
+
+bool SetCsvReadOptions(arrow::csv::ReadOptions& options, const std::string& key,
+                       const std::string& value) {
+  if (key == "use_threads") {
+    options.use_threads = ParseBool(value);
+  } else if (key == "block_size") {
+    options.block_size = std::stoi(value);
+  } else if (key == "skip_rows") {
+    options.skip_rows = std::stoi(value);
+  } else if (key == "skip_rows_after_names") {
+    options.skip_rows_after_names = std::stoi(value);
+  } else if (key == "autogenerate_column_names") {
+    options.autogenerate_column_names = ParseBool(value);
+  } else {
+    return false;
+  }
+  return true;
+}
+
+std::shared_ptr<arrow::dataset::CsvFragmentScanOptions> ToCsvFragmentScanOptions(
+    const std::unordered_map<std::string, std::string>& configs) {
   std::shared_ptr<arrow::dataset::CsvFragmentScanOptions> options =
       std::make_shared<arrow::dataset::CsvFragmentScanOptions>();
-  for (auto const& [key, value] : configs) {
-    if (key == "delimiter") {
-      options->parse_options.delimiter = value.data()[0];
-    } else if (key == "quoting") {
-      options->parse_options.quoting = ParseBool(value);
-    } else if (key == "column_types") {
-      int64_t schema_address = std::stol(value);
-      ArrowSchema* c_schema = reinterpret_cast<ArrowSchema*>(schema_address);
-      ARROW_ASSIGN_OR_RAISE(auto schema, arrow::ImportSchema(c_schema));
-      auto& column_types = options->convert_options.column_types;
-      for (auto field : schema->fields()) {
-        column_types[field->name()] = field->type();
-      }
-    } else if (key == "strings_can_be_null") {
-      options->convert_options.strings_can_be_null = ParseBool(value);
-    } else {
-      return arrow::Status::Invalid("Config " + key + " is not supported.");
+  for (const auto& [key, value] : configs) {
+    bool setValid = SetCsvParseOptions(options->parse_options, key, value) ||
+                    SetCsvConvertOptions(options->convert_options, key, value) ||
+                    SetCsvReadOptions(options->read_options, key, value);
+    if (!setValid) {
+      JniThrow("Config " + key + " is not supported.");
     }
   }
   return options;
diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/csv/CsvFragmentScanOptions.java b/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/csv/CsvFragmentScanOptions.java
index 39271b5f063fb..dddc36d38714e 100644
--- a/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/csv/CsvFragmentScanOptions.java
+++ b/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/csv/CsvFragmentScanOptions.java
@@ -32,6 +32,10 @@ public class CsvFragmentScanOptions implements FragmentScanOptions {
    * CSV scan options, map to CPP struct CsvFragmentScanOptions. The key in config map is the field
    * name of mapping cpp struct
    *
+   * <p>Currently, multi-valued options (which are std::vector values in C++) only support having a
+   * single value set. For example, for the null_values option, only one string can be set as the
+   * null value.
+   *
    * @param convertOptions similar to CsvFragmentScanOptions#convert_options in CPP, the ArrowSchema
    *     represents column_types, convert data option such as null value recognition.
    * @param readOptions similar to CsvFragmentScanOptions#read_options in CPP, specify how to read
diff --git a/java/dataset/src/test/java/org/apache/arrow/dataset/TestFragmentScanOptions.java b/java/dataset/src/test/java/org/apache/arrow/dataset/TestFragmentScanOptions.java
index 9787e8308e73e..d598190528811 100644
--- a/java/dataset/src/test/java/org/apache/arrow/dataset/TestFragmentScanOptions.java
+++ b/java/dataset/src/test/java/org/apache/arrow/dataset/TestFragmentScanOptions.java
@@ -18,10 +18,13 @@
 
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
 
 import com.google.common.collect.ImmutableMap;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.Map;
 import java.util.Optional;
 import org.apache.arrow.c.ArrowSchema;
 import org.apache.arrow.c.CDataDictionaryProvider;
@@ -42,6 +45,7 @@
 import org.apache.arrow.vector.types.pojo.ArrowType;
 import org.apache.arrow.vector.types.pojo.Field;
 import org.apache.arrow.vector.types.pojo.Schema;
+import org.apache.arrow.vector.util.Text;
 import org.hamcrest.collection.IsIterableContainingInOrder;
 import org.junit.jupiter.api.Test;
 
@@ -165,4 +169,156 @@ public void testCsvConvertOptionsNoOption() throws Exception {
       assertEquals(3, rowCount);
     }
   }
+
+  @Test
+  public void testCsvReadParseAndReadOptions() throws Exception {
+    final Schema schema =
+        new Schema(
+            Collections.singletonList(Field.nullable("Id;Name;Language", new ArrowType.Utf8())),
+            null);
+    String path = "file://" + getClass().getResource("/").getPath() + "/data/student.csv";
+    BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
+    CsvFragmentScanOptions fragmentScanOptions =
+        new CsvFragmentScanOptions(
+            new CsvConvertOptions(ImmutableMap.of()),
+            ImmutableMap.of("skip_rows_after_names", "1"),
+            ImmutableMap.of("delimiter", ";"));
+    ScanOptions options =
+        new ScanOptions.Builder(/*batchSize*/ 32768)
+            .columns(Optional.empty())
+            .fragmentScanOptions(fragmentScanOptions)
+            .build();
+    try (DatasetFactory datasetFactory =
+            new FileSystemDatasetFactory(
+                allocator,
+                NativeMemoryPool.getDefault(),
+                FileFormat.CSV,
+                path,
+                Optional.of(fragmentScanOptions));
+        Dataset dataset = datasetFactory.finish();
+        Scanner scanner = dataset.newScan(options);
+        ArrowReader reader = scanner.scanBatches()) {
+
+      assertEquals(schema.getFields(), reader.getVectorSchemaRoot().getSchema().getFields());
+      int rowCount = 0;
+      while (reader.loadNextBatch()) {
+        final ValueIterableVector<Text> idVector =
+            (ValueIterableVector<Text>)
+                reader.getVectorSchemaRoot().getVector("Id;Name;Language");
+        assertThat(
+            idVector.getValueIterable(),
+            IsIterableContainingInOrder.contains(
+                new Text("2;Peter;Python"), new Text("3;Celin;C++")));
+        rowCount += reader.getVectorSchemaRoot().getRowCount();
+      }
+      assertEquals(2, rowCount);
+    }
+  }
+
+  @Test
+  public void testCsvReadOtherOptions() throws Exception {
+    String path = "file://" + getClass().getResource("/").getPath() + "/data/student.csv";
+    BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
+    Map<String, String> convertOption =
+        ImmutableMap.of(
+            "check_utf8",
+            "true",
+            "null_values",
+            "NULL",
+            "true_values",
+            "True",
+            "false_values",
+            "False",
+            "quoted_strings_can_be_null",
+            "true",
+            "auto_dict_encode",
+            "false",
+            "auto_dict_max_cardinality",
+            "3456",
+            "decimal_point",
+            ".",
+            "include_missing_columns",
+            "false");
+    Map<String, String> readOption =
+        ImmutableMap.of(
+            "use_threads",
+            "true",
+            "block_size",
+            "1024",
+            "skip_rows",
+            "12",
+            "skip_rows_after_names",
+            "12",
+            "autogenerate_column_names",
+            "false");
+    Map<String, String> parseOption =
+        ImmutableMap.of(
+            "delimiter",
+            ".",
+            "quoting",
+            "true",
+            "quote_char",
+            "'",
+            "double_quote",
+            "False",
+            "escaping",
+            "true",
+            "escape_char",
+            "v",
+            "newlines_in_values",
+            "false",
+            "ignore_empty_lines",
+            "true");
+    CsvFragmentScanOptions fragmentScanOptions =
+        new CsvFragmentScanOptions(new CsvConvertOptions(convertOption), readOption, parseOption);
+    ScanOptions options =
+        new ScanOptions.Builder(/*batchSize*/ 32768)
+            .columns(Optional.empty())
+            .fragmentScanOptions(fragmentScanOptions)
+            .build();
+    try (DatasetFactory datasetFactory =
+            new FileSystemDatasetFactory(
+                allocator, NativeMemoryPool.getDefault(), FileFormat.CSV, path);
+        Dataset dataset = datasetFactory.finish();
+        Scanner scanner = dataset.newScan(options)) {
+      assertNotNull(scanner);
+    }
+  }
+
+  @Test
+  public void testCsvInvalidOption() throws Exception {
+    String path = "file://" + getClass().getResource("/").getPath() + "/data/student.csv";
+    BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
+    Map<String, String> convertOption = ImmutableMap.of("not_exists_key_check_utf8", "true");
+    CsvFragmentScanOptions fragmentScanOptions =
+        new CsvFragmentScanOptions(
+            new CsvConvertOptions(convertOption), ImmutableMap.of(), ImmutableMap.of());
+    ScanOptions options =
+        new ScanOptions.Builder(/*batchSize*/ 32768)
+            .columns(Optional.empty())
+            .fragmentScanOptions(fragmentScanOptions)
+            .build();
+    try (DatasetFactory datasetFactory =
+            new FileSystemDatasetFactory(
+                allocator, NativeMemoryPool.getDefault(), FileFormat.CSV, path);
+        Dataset dataset = datasetFactory.finish()) {
+      assertThrows(RuntimeException.class, () -> dataset.newScan(options));
+    }
+
+    CsvFragmentScanOptions fragmentScanOptionsFaultValue =
+        new CsvFragmentScanOptions(
+            new CsvConvertOptions(ImmutableMap.of()),
+            ImmutableMap.of("", ""),
+            ImmutableMap.of("escape_char", "vbvb"));
+    ScanOptions optionsFault =
+        new ScanOptions.Builder(/*batchSize*/ 32768)
+            .columns(Optional.empty())
+            .fragmentScanOptions(fragmentScanOptionsFaultValue)
+            .build();
+    try (DatasetFactory datasetFactory =
+            new FileSystemDatasetFactory(
+                allocator, NativeMemoryPool.getDefault(), FileFormat.CSV, path);
+        Dataset dataset = datasetFactory.finish()) {
+      assertThrows(RuntimeException.class, () -> dataset.newScan(optionsFault));
+    }
+  }
 }
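Usage sketch (not from the diff): the snippet below shows how the new option keys are expected to be passed from application code, using only the API already exercised by the tests in this PR (CsvFragmentScanOptions, CsvConvertOptions, ScanOptions.Builder, FileSystemDatasetFactory). The input path is hypothetical, and the option values are illustrative.

```java
import com.google.common.collect.ImmutableMap;
import java.util.Optional;
import org.apache.arrow.dataset.file.FileFormat;
import org.apache.arrow.dataset.file.FileSystemDatasetFactory;
import org.apache.arrow.dataset.jni.NativeMemoryPool;
import org.apache.arrow.dataset.scanner.ScanOptions;
import org.apache.arrow.dataset.scanner.Scanner;
import org.apache.arrow.dataset.scanner.csv.CsvConvertOptions;
import org.apache.arrow.dataset.scanner.csv.CsvFragmentScanOptions;
import org.apache.arrow.dataset.source.Dataset;
import org.apache.arrow.dataset.source.DatasetFactory;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.ipc.ArrowReader;

public class CsvScanOptionsExample {
  public static void main(String[] args) throws Exception {
    // Hypothetical semicolon-delimited CSV file; adjust to a real path.
    String path = "file:///tmp/student.csv";

    // Keys mirror the field names of the C++ CSV option structs handled by
    // SetCsvConvertOptions / SetCsvReadOptions / SetCsvParseOptions in jni_wrapper.cc.
    CsvFragmentScanOptions csvOptions =
        new CsvFragmentScanOptions(
            new CsvConvertOptions(ImmutableMap.of("strings_can_be_null", "true")),
            ImmutableMap.of("skip_rows_after_names", "1"), // read options
            ImmutableMap.of("delimiter", ";")); // parse options

    try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
        DatasetFactory factory =
            new FileSystemDatasetFactory(
                allocator, NativeMemoryPool.getDefault(), FileFormat.CSV, path);
        Dataset dataset = factory.finish();
        Scanner scanner =
            dataset.newScan(
                new ScanOptions.Builder(/*batchSize*/ 32768)
                    .columns(Optional.empty())
                    .fragmentScanOptions(csvOptions)
                    .build());
        ArrowReader reader = scanner.scanBatches()) {
      while (reader.loadNextBatch()) {
        // Print each batch; an unsupported key or bad value surfaces as a RuntimeException
        // thrown from the JNI layer, as covered by testCsvInvalidOption.
        System.out.println(reader.getVectorSchemaRoot().contentToTSVString());
      }
    }
  }
}
```

As in testCsvReadOtherOptions, the options here are supplied only through ScanOptions; when schema inference needs a non-default delimiter, they can also be passed to the FileSystemDatasetFactory constructor, as testCsvReadParseAndReadOptions does.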