Skip to content

Commit

Permalink
change to serialize ExpressionLiteral
Browse files Browse the repository at this point in the history
  • Loading branch information
jinchengchenghh committed May 20, 2024
1 parent 0f35c70 commit 608f568
Show file tree
Hide file tree
Showing 11 changed files with 52 additions and 120 deletions.
2 changes: 2 additions & 0 deletions cpp/src/arrow/dataset/file_csv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,8 @@ Result<std::shared_ptr<FragmentScanOptions>> CsvFragmentScanOptions::from(
for (auto field : schema->fields()) {
column_types[field->name()] = field->type();
}
} else if (key == "strings_can_be_null") {
options->convert_options.strings_can_be_null = parseBool(value);
} else {
return Status::Invalid("Config " + it.first + "is not supported.");
}
Expand Down
1 change: 0 additions & 1 deletion cpp/src/arrow/engine/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ arrow_install_all_headers("arrow/engine")
set(ARROW_SUBSTRAIT_SRCS
substrait/expression_internal.cc
substrait/extended_expression_internal.cc
substrait/extension_internal.cc
substrait/extension_set.cc
substrait/extension_types.cc
substrait/options.cc
Expand Down
12 changes: 12 additions & 0 deletions cpp/src/arrow/engine/substrait/expression_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1537,5 +1537,17 @@ Result<std::unique_ptr<substrait::Expression>> ToProto(
return std::move(out);
}

Status FromProto(const substrait::Expression::Literal& literal,
std::unordered_map<std::string, std::string>& out) {
ARROW_RETURN_IF(!literal.has_map(), Status::Invalid("Literal does not have a map."));
auto literalMap = literal.map();
auto size = literalMap.key_values_size();
for (auto i = 0; i < size; i++) {
substrait::Expression_Literal_Map_KeyValue keyValue = literalMap.key_values(i);
out.emplace(keyValue.key().string(), keyValue.value().string());
}
return Status::OK();
}

} // namespace engine
} // namespace arrow
4 changes: 4 additions & 0 deletions cpp/src/arrow/engine/substrait/expression_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,5 +61,9 @@ ARROW_ENGINE_EXPORT
Result<SubstraitCall> FromProto(const substrait::AggregateFunction&, bool is_hash,
const ExtensionSet&, const ConversionOptions&);

ARROW_ENGINE_EXPORT
Status FromProto(const substrait::Expression::Literal& literal,
std::unordered_map<std::string, std::string>& out);

} // namespace engine
} // namespace arrow
48 changes: 0 additions & 48 deletions cpp/src/arrow/engine/substrait/extension_internal.cc

This file was deleted.

44 changes: 0 additions & 44 deletions cpp/src/arrow/engine/substrait/extension_internal.h

This file was deleted.

12 changes: 7 additions & 5 deletions cpp/src/arrow/engine/substrait/serde.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
#include "arrow/dataset/file_base.h"
#include "arrow/engine/substrait/expression_internal.h"
#include "arrow/engine/substrait/extended_expression_internal.h"
#include "arrow/engine/substrait/extension_internal.h"
#include "arrow/engine/substrait/extension_set.h"
#include "arrow/engine/substrait/plan_internal.h"
#include "arrow/engine/substrait/relation.h"
Expand Down Expand Up @@ -249,10 +248,13 @@ Result<BoundExpressions> DeserializeExpressions(
}

Status DeserializeMap(const Buffer& buf,
std::unordered_map<std::string, std::string> out) {
ARROW_ASSIGN_OR_RAISE(auto advanced_extension,
ParseFromBuffer<substrait::extensions::AdvancedExtension>(buf));
return FromProto(advanced_extension, out);
std::unordered_map<std::string, std::string>& out) {
// ARROW_ASSIGN_OR_RAISE(auto advanced_extension,
// ParseFromBuffer<substrait::extensions::AdvancedExtension>(buf));
// return FromProto(advanced_extension, out);
ARROW_ASSIGN_OR_RAISE(auto literal,
ParseFromBuffer<substrait::Expression::Literal>(buf));
return FromProto(literal, out);
}

namespace {
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/arrow/engine/substrait/serde.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ ARROW_ENGINE_EXPORT Result<BoundExpressions> DeserializeExpressions(
ExtensionSet* ext_set_out = NULLPTR);

ARROW_ENGINE_EXPORT Status
DeserializeMap(const Buffer& buf, std::unordered_map<std::string, std::string> out);
DeserializeMap(const Buffer& buf, std::unordered_map<std::string, std::string>& out);

/// \brief Deserializes a Substrait Type message to the corresponding Arrow type
///
Expand Down
18 changes: 18 additions & 0 deletions java/dataset/src/main/cpp/jni_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,15 @@ Java_org_apache_arrow_dataset_file_JniWrapper_makeFileSystemDatasetFactory(
std::shared_ptr<arrow::dataset::FragmentScanOptions> scan_options =
JniGetOrThrow(GetFragmentScanOptions(file_format_id, option_map));
file_format->default_fragment_scan_options = scan_options;
#ifdef ARROW_CSV
if (file_format_id == 3) {
std::shared_ptr<arrow::dataset::CsvFileFormat> csv_file_format =
std::dynamic_pointer_cast<arrow::dataset::CsvFileFormat>(file_format);
csv_file_format->parse_options =
std::dynamic_pointer_cast<arrow::dataset::CsvFragmentScanOptions>(scan_options)
->parse_options;
}
#endif
}
arrow::dataset::FileSystemFactoryOptions options;
std::shared_ptr<arrow::dataset::DatasetFactory> d =
Expand Down Expand Up @@ -733,6 +742,15 @@ Java_org_apache_arrow_dataset_file_JniWrapper_makeFileSystemDatasetFactoryWithFi
std::shared_ptr<arrow::dataset::FragmentScanOptions> scan_options =
JniGetOrThrow(GetFragmentScanOptions(file_format_id, option_map));
file_format->default_fragment_scan_options = scan_options;
#ifdef ARROW_CSV
if (file_format_id == 3) {
std::shared_ptr<arrow::dataset::CsvFileFormat> csv_file_format =
std::dynamic_pointer_cast<arrow::dataset::CsvFileFormat>(file_format);
csv_file_format->parse_options =
std::dynamic_pointer_cast<arrow::dataset::CsvFragmentScanOptions>(scan_options)
->parse_options;
}
#endif
}
arrow::dataset::FileSystemFactoryOptions options;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.arrow.dataset.substrait.util.ConvertUtil;

import io.substrait.proto.AdvancedExtension;
import io.substrait.proto.Expression;

public interface FragmentScanOptions {
String typeName();
Expand All @@ -42,9 +43,11 @@ default ByteBuffer serializeMap(Map<String, String> config) {
return null;
}

AdvancedExtension extension = ConvertUtil.expressionToExtension(ConvertUtil.mapToExpression(config));
ByteBuffer buf = ByteBuffer.allocateDirect(extension.getSerializedSize());
buf.put(extension.toByteArray());
Expression.Literal literal = ConvertUtil.mapToExpressionLiteral(config);

// AdvancedExtension extension = ConvertUtil.expressionToExtension(ConvertUtil.mapToExpression(config));
ByteBuffer buf = ByteBuffer.allocateDirect(literal.getSerializedSize());
buf.put(literal.toByteArray());
return buf;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@

import java.util.Map;

import com.google.protobuf.Any;

import io.substrait.proto.AdvancedExtension;
import io.substrait.proto.Expression;

public class ConvertUtil {
Expand All @@ -31,7 +28,7 @@ public class ConvertUtil {
*
* @return Substrait Expression
*/
public static Expression mapToExpression(Map<String, String> values) {
public static Expression.Literal mapToExpressionLiteral(Map<String, String> values) {
Expression.Literal.Builder literalBuilder = Expression.Literal.newBuilder();
Expression.Literal.Map.KeyValue.Builder keyValueBuilder =
Expression.Literal.Map.KeyValue.newBuilder();
Expand All @@ -44,19 +41,6 @@ public static Expression mapToExpression(Map<String, String> values) {
mapBuilder.addKeyValues(keyValueBuilder.build());
}
literalBuilder.setMap(mapBuilder.build());
return Expression.newBuilder().setLiteral(literalBuilder.build()).build();
}

/**
* Add substrait expression to AdvancedExtension.
*
* @param expr Substrait Expression.
* @return Substrait AdvancedExtension
*/
public static AdvancedExtension expressionToExtension(Expression expr) {
AdvancedExtension.Builder extensionBuilder = AdvancedExtension.newBuilder();
Any.Builder builder = extensionBuilder.getEnhancementBuilder();
builder.setValue(expr.toByteString());
return extensionBuilder.build();
return literalBuilder.build();
}
}

0 comments on commit 608f568

Please sign in to comment.