diff --git a/common/BUILD b/common/BUILD index 1a5342a5d..04eb740f6 100644 --- a/common/BUILD +++ b/common/BUILD @@ -196,6 +196,7 @@ cc_test( "//internal:types", "//internal:value_internal", "//testutil:util", + "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/strings", "@com_google_googleapis//google/rpc:status_cc_proto", "@com_google_googleapis//google/type:money_cc_proto", @@ -218,6 +219,7 @@ cc_library( "//internal:list_impl", "//internal:map_impl", "//internal:types", + "@com_google_absl//absl/meta:type_traits", ], ) diff --git a/common/converters.h b/common/converters.h index 2a1bc59a7..c8836ed9b 100644 --- a/common/converters.h +++ b/common/converters.h @@ -5,6 +5,7 @@ #include +#include "absl/meta/type_traits.h" #include "common/parent_ref.h" #include "common/value.h" #include "internal/list_impl.h" @@ -53,7 +54,7 @@ template Value ValueFromList(T&& value) { static_assert(!std::is_pointer::value, "use ValueForList"); return Value::MakeList< - internal::ListWrapper>>( + internal::ListWrapper>>( std::forward(value)); } @@ -92,7 +93,7 @@ template Value ValueFromMap(T&& value) { static_assert(!std::is_pointer::value, "use ValueForList"); return Value::MakeMap< - internal::MapWrapper>>( + internal::MapWrapper>>( std::forward(value)); } diff --git a/common/escaping.cc b/common/escaping.cc index 26b1367b8..98e7e8d28 100644 --- a/common/escaping.cc +++ b/common/escaping.cc @@ -241,7 +241,7 @@ inline std::tuple unescape_char( } if (value < 0x80 || !encode) { - char tmp[2] = {(char)value, '\0'}; + char tmp[2] = {static_cast(value), '\0'}; return std::make_tuple(std::string(tmp), s, ""); } else { char tmp[5]; @@ -294,7 +294,7 @@ absl::optional unescape(const std::string& s, bool is_bytes) { } value = value.substr(1, n - 2); // If there is nothing to escape, then return. - if (is_raw_literal || (value.find('\\') == std::string::npos)) { + if (is_raw_literal || (!absl::StrContains(value, '\\'))) { return value; } diff --git a/common/value.cc b/common/value.cc index 429d2b412..47eaab59c 100644 --- a/common/value.cc +++ b/common/value.cc @@ -20,8 +20,6 @@ namespace api { namespace expr { namespace common { -using ::google::api::expr::internal::NotFoundError; - namespace { static constexpr const Value::Kind kIndexToKind[] = { diff --git a/common/value_test.cc b/common/value_test.cc index 3a6a60c3a..95b8646e3 100644 --- a/common/value_test.cc +++ b/common/value_test.cc @@ -10,6 +10,7 @@ #include "google/type/money.pb.h" #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "absl/meta/type_traits.h" #include "absl/strings/str_cat.h" #include "common/custom_object.h" #include "internal/status_util.h" @@ -709,7 +710,7 @@ class ValueVisitTest : public ::testing::Test { template void TestGetPtrVisitor() { - using T = remove_reference_t())>; + using T = absl::remove_reference_t())>; ExpectSameType>(); // Return by const ref. ExpectSameType>(); @@ -717,7 +718,7 @@ class ValueVisitTest : public ::testing::Test { template void TestGetVisitor() { - using T = remove_reference_t())>; + using T = absl::remove_reference_t())>; // Return by optional. ExpectSameType, GetVisitorType, T>>(); @@ -727,7 +728,7 @@ class ValueVisitTest : public ::testing::Test { template void TestValueAdapter() { - using T = remove_reference_t())>; + using T = absl::remove_reference_t())>; ExpectSameType>(); ExpectSameType>(); } diff --git a/conformance/BUILD b/conformance/BUILD index b357cef48..8592cec0c 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -18,6 +18,7 @@ ALL_TESTS = [ "@com_google_cel_spec//tests/simple:testdata/logic.textproto", "@com_google_cel_spec//tests/simple:testdata/macros.textproto", "@com_google_cel_spec//tests/simple:testdata/namespace.textproto", + "@com_google_cel_spec//tests/simple:testdata/parse.textproto", "@com_google_cel_spec//tests/simple:testdata/plumbing.textproto", "@com_google_cel_spec//tests/simple:testdata/proto2.textproto", "@com_google_cel_spec//tests/simple:testdata/proto3.textproto", @@ -26,27 +27,6 @@ ALL_TESTS = [ "@com_google_cel_spec//tests/simple:testdata/unknowns.textproto", ] -DASHBOARD_TESTS = [ - "@com_google_cel_spec//tests/simple:testdata/basic.textproto", - "@com_google_cel_spec//tests/simple:testdata/comparisons.textproto", - "@com_google_cel_spec//tests/simple:testdata/conversions.textproto", - "@com_google_cel_spec//tests/simple:testdata/dynamic.textproto", - "@com_google_cel_spec//tests/simple:testdata/enums.textproto", - "@com_google_cel_spec//tests/simple:testdata/fields.textproto", - "@com_google_cel_spec//tests/simple:testdata/fp_math.textproto", - "@com_google_cel_spec//tests/simple:testdata/integer_math.textproto", - "@com_google_cel_spec//tests/simple:testdata/lists.textproto", - "@com_google_cel_spec//tests/simple:testdata/logic.textproto", - "@com_google_cel_spec//tests/simple:testdata/macros.textproto", - "@com_google_cel_spec//tests/simple:testdata/namespace.textproto", - "@com_google_cel_spec//tests/simple:testdata/proto2.textproto", - "@com_google_cel_spec//tests/simple:testdata/proto3.textproto", - "@com_google_cel_spec//tests/simple:testdata/plumbing.textproto", - "@com_google_cel_spec//tests/simple:testdata/string.textproto", - "@com_google_cel_spec//tests/simple:testdata/timestamps.textproto", - "@com_google_cel_spec//tests/simple:testdata/unknowns.textproto", -] - cc_binary( name = "server", testonly = 1, @@ -81,71 +61,54 @@ cc_binary( "--server=\"$(location :server) " + arg + "\"", "--skip_check", "--pipe", - # TODO(issues/93): Inconsistent Duration.getMilliseconds() behavior. + + # Tests which require spec changes. + # TODO(issues/93): Deprecate Duration.getMilliseconds. "--skip_test=timestamps/duration_converters/get_milliseconds", - # TODO(issues/94): Missing timestamp conversion functions (type / string) - "--skip_test=timestamps/timestamp_conversions/toType_timestamp,toString_timestamp", - "--skip_test=timestamps/duration_conversions/toType_duration,toString_duration", # TODO(issues/81): Conversion functions for int(), uint() which can be # uncommented when the spec changes to truncation rather than rounding. "--skip_test=conversions/int/double_nearest,double_nearest_neg,double_half_away_neg,double_half_away_pos", "--skip_test=conversions/uint/double_nearest,double_nearest_int,double_half_away", - # TODO(issues/96): Well-known type conversion support. - "--skip_test=proto2/literal_wellknown", - "--skip_test=proto3/literal_wellknown", - "--skip_test=proto2/empty_field/wkt", - "--skip_test=proto3/empty_field/wkt", - # Requires container support - "--skip_test=namespace/namespace/self_eval_container_lookup_unchecked", + # TODO(issues/110): Tune parse limits to mirror those for proto deserialization and C++ safety limits. + "--skip_test=parse/nest/list_index,message_literal,funcall,list_literal,map_literal;repeat/conditional,add_sub,mul_div,select,index,map_literal,message_literal", + + # Broken test cases which should be supported. + # TODO(issues/111): Byte literal decoding of invalid UTF-8 results in incorrect bytes output. "--skip_test=basic/self_eval_nonzeroish/self_eval_bytes_invalid_utf8", - # Requires heteregenous equality spec clarification - "--skip_test=comparisons/in_list_literal/elem_in_mixed_type_list_error", - "--skip_test=comparisons/in_map_literal/key_in_mixed_key_type_map_error", - "--skip_test=fields/in/singleton", - # Requires qualified bindings error message relaxation - "--skip_test=fields/qualified_identifier_resolution/qualified_identifier_resolution_unchecked", - "--skip_test=string/size/one_unicode,unicode", "--skip_test=string/bytes_concat/left_unit", - # TODO(issues/85): The exists one macro should not short-circuit false. - "--skip_test=macros/exists_one/list_no_shortcircuit", - # TODO(issues/86): Map macro may produce incorrect results on error. - "--skip_test=macros/map/list_error", - # TODO(issues/97): Parse-only qualified variable lookup "x.y" wtih binding "x.y" or "y" within container "x" fails - "--skip_test=namespace/qualified/self_eval_qualified_lookup", - "--skip_test=namespace/namespace/self_eval_container_lookup", - "--skip_test=fields/qualified_identifier_resolution/qualified_ident", - "--skip_test=fields/qualified_identifier_resolution/map_field_select", - "--skip_test=fields/qualified_identifier_resolution/ident_with_longest_prefix_check", - # New conformance tests awaiting synchronization. + # TODO(issues/112): Unbound functions result in empty eval response. "--skip_test=basic/functions/unbound", "--skip_test=basic/functions/unbound_is_runtime_error", + # TODO(issues/113): Aggregate values must logically AND element equality results. "--skip_test=comparisons/eq_literal/not_eq_list_false_vs_types", "--skip_test=comparisons/eq_literal/not_eq_map_false_vs_types", - "--skip_test=dynamic/int32", - "--skip_test=dynamic/int64", - "--skip_test=dynamic/uint32", - "--skip_test=dynamic/uint64", - "--skip_test=dynamic/float", - "--skip_test=dynamic/double", - "--skip_test=dynamic/string", - "--skip_test=dynamic/bytes", - "--skip_test=dynamic/bool", - "--skip_test=dynamic/list", - "--skip_test=dynamic/struct", - "--skip_test=dynamic/value_null", - "--skip_test=dynamic/value_number", - "--skip_test=dynamic/value_string", - "--skip_test=dynamic/value_bool", - "--skip_test=dynamic/value_struct", - "--skip_test=dynamic/value_list", - "--skip_test=dynamic/any", - "--skip_test=dynamic/complex", - "--skip_test=enums/legacy_proto2", - "--skip_test=enums/legacy_proto3", + # TODO(issues/114): Ensure the 'in' operator is a logical OR of element equality results. + "--skip_test=comparisons/in_list_literal/elem_in_mixed_type_list_error", + "--skip_test=comparisons/in_map_literal/key_in_mixed_key_type_map_error", + # TODO(issues/115): The 'in' operator fails with maps containing boolean keys. + "--skip_test=fields/in/singleton", + # TODO(issues/97): Parse-only qualified variable lookup "x.y" wtih binding "x.y" or "y" within container "x" fails + "--skip_test=fields/qualified_identifier_resolution/qualified_ident,map_field_select,ident_with_longest_prefix_check,qualified_identifier_resolution_unchecked", + "--skip_test=namespace/qualified/self_eval_qualified_lookup", + "--skip_test=namespace/namespace/self_eval_container_lookup,self_eval_container_lookup_unchecked", + # TODO(issues/116): Debug why dynamic/list/var fails to JSON parse correctly. + "--skip_test=dynamic/list/var", + # TODO(issues/109): Ensure that unset wrapper fields return 'null' rather than the default value of the wrapper. + "--skip_test=dynamic/int32/field_read_proto2_unset,field_read_proto3_unset;uint32/field_read_proto2_unset;uint64/field_read_proto2_unset;float/field_read_proto2_unset,field_read_proto3_unset;double/field_read_proto2_unset,field_read_proto3_unset", + "--skip_test=proto2/empty_field/wkt", + "--skip_test=proto3/empty_field/wkt", + # TODO(issues/117): Integer overflow on enum assignments should error. + "--skip_test=enums/legacy_proto2/select_big,select_neg,assign_standalone_int_too_big,assign_standalone_int_too_neg", + "--skip_test=enums/legacy_proto3/assign_standalone_int_too_big,assign_standalone_int_too_neg", + # TODO(issues/118): Duration and timestamp range errors should result in evaluation errors. + "--skip_test=timestamps/timestamp_range", + + # Future features for CEL 1.0 + # TODO(issues/119): Strong typing support for enums, specified but not implemented. "--skip_test=enums/strong_proto2", "--skip_test=enums/strong_proto3", - "--skip_test=timestamps/timestamp_range", - "--skip_test=timestamps/duration_range", + # Bad tests, temporarily disable. + "--skip_test=dynamic/float/field_assign_proto2_range,field_assign_proto3_range", ] + ["$(location " + test + ")" for test in ALL_TESTS], data = [ ":server", @@ -165,12 +128,17 @@ sh_test( "$(location @com_google_cel_spec//tests/simple:simple_test)", "--server=$(location :server)", "--skip_check", + # TODO(issues/116): Debug why dynamic/list/var fails to JSON parse correctly. + "--skip_test=dynamic/list/var", + # TODO(issues/119): Strong typing support for enums, specified but not implemented. + "--skip_test=enums/strong_proto2", + "--skip_test=enums/strong_proto3", "--pipe", - ] + ["$(location " + test + ")" for test in DASHBOARD_TESTS], + ] + ["$(location " + test + ")" for test in ALL_TESTS], data = [ ":server", "@com_google_cel_spec//tests/simple:simple_test", - ] + DASHBOARD_TESTS, + ] + ALL_TESTS, visibility = [ "//:__subpackages__", "//third_party/cel:__pkg__", diff --git a/conformance/server.cc b/conformance/server.cc index 8b9ddac35..dcceae3a3 100644 --- a/conformance/server.cc +++ b/conformance/server.cc @@ -1,8 +1,6 @@ -#include - +#include "google/api/expr/v1alpha1/conformance_service.pb.h" #include "google/api/expr/v1alpha1/syntax.pb.h" #include "google/api/expr/v1alpha1/checked.pb.h" -#include "google/api/expr/v1alpha1/conformance_service.pb.h" #include "google/api/expr/v1alpha1/eval.pb.h" #include "google/api/expr/v1alpha1/syntax.pb.h" #include "google/api/expr/v1alpha1/value.pb.h" @@ -151,6 +149,7 @@ int RunServer(bool optimize) { google::protobuf::Arena arena; InterpreterOptions options; options.enable_qualified_type_identifiers = true; + options.enable_string_size_as_unicode_codepoints = true; if (optimize) { std::cerr << "Enabling optimizations" << std::endl; @@ -169,7 +168,8 @@ int RunServer(bool optimize) { NestedEnum_descriptor()); type_registry->Register(google::api::expr::test::v1::proto3::TestAllTypes:: NestedEnum_descriptor()); - auto register_status = RegisterBuiltinFunctions(builder->GetRegistry()); + auto register_status = + RegisterBuiltinFunctions(builder->GetRegistry(), options); if (!register_status.ok()) { std::cerr << "Failed to initialize: " << register_status.ToString() << std::endl; @@ -181,7 +181,7 @@ int RunServer(bool optimize) { // Implementation of a simple pipe protocol: // INPUT LINE 1: parse/check/eval // INPUT LINE 2: JSON of the corresponding request protobuf - // OUTPUT LINE 1: JSON of the coressponding response protobuf + // OUTPUT LINE 1: JSON of the corresponding response protobuf while (true) { std::string cmd, input, output; std::getline(std::cin, cmd); diff --git a/eval/eval/BUILD b/eval/eval/BUILD index 48d2ea33b..a2cb71025 100644 --- a/eval/eval/BUILD +++ b/eval/eval/BUILD @@ -173,13 +173,10 @@ cc_library( deps = [ ":evaluator_core", ":expression_step_base", - "//eval/public:activation", - "//eval/public:cel_value", "//eval/public/containers:container_backed_list_impl", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", ], ) diff --git a/eval/eval/attribute_utility_test.cc b/eval/eval/attribute_utility_test.cc index b7a09d4a8..48bf9a901 100644 --- a/eval/eval/attribute_utility_test.cc +++ b/eval/eval/attribute_utility_test.cc @@ -15,7 +15,6 @@ namespace runtime { using google::api::expr::v1alpha1::Expr; using testing::Eq; -using testing::IsNull; using testing::NotNull; using testing::SizeIs; using testing::UnorderedPointwise; diff --git a/eval/eval/comprehension_step.cc b/eval/eval/comprehension_step.cc index a42cf822a..5f4f471a6 100644 --- a/eval/eval/comprehension_step.cc +++ b/eval/eval/comprehension_step.cc @@ -209,7 +209,7 @@ absl::Status ComprehensionFinish::Evaluate(ExecutionFrame* frame) const { class ListKeysStep : public ExpressionStepBase { public: - ListKeysStep(int64_t expr_id) : ExpressionStepBase(expr_id, false) {} + explicit ListKeysStep(int64_t expr_id) : ExpressionStepBase(expr_id, false) {} absl::Status Evaluate(ExecutionFrame* frame) const override; private: diff --git a/eval/eval/container_access_step.cc b/eval/eval/container_access_step.cc index 619a065d5..eea6cca70 100644 --- a/eval/eval/container_access_step.cc +++ b/eval/eval/container_access_step.cc @@ -22,7 +22,7 @@ constexpr int NUM_CONTAINER_ACCESS_ARGUMENTS = 2; // message. class ContainerAccessStep : public ExpressionStepBase { public: - ContainerAccessStep(int64_t expr_id) : ExpressionStepBase(expr_id) {} + explicit ContainerAccessStep(int64_t expr_id) : ExpressionStepBase(expr_id) {} absl::Status Evaluate(ExecutionFrame* frame) const override; @@ -54,7 +54,7 @@ inline CelValue ContainerAccessStep::LookupInMap(const CelMap* cel_map, break; } } - return CreateNoSuchKeyError(arena, absl::StrCat("Key not found in map")); + return CreateNoSuchKeyError(arena, "Key not found in map"); } inline CelValue ContainerAccessStep::LookupInList(const CelList* cel_list, diff --git a/eval/eval/create_list_step.cc b/eval/eval/create_list_step.cc index 21be321a9..d3373c397 100644 --- a/eval/eval/create_list_step.cc +++ b/eval/eval/create_list_step.cc @@ -1,5 +1,6 @@ #include "eval/eval/create_list_step.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "eval/public/containers/container_backed_list_impl.h" @@ -35,6 +36,14 @@ absl::Status CreateListStep::Evaluate(ExecutionFrame* frame) const { auto args = frame->value_stack().GetSpan(list_size_); CelValue result; + for (const auto& arg : args) { + if (arg.IsError()) { + result = arg; + frame->value_stack().Pop(list_size_); + frame->value_stack().Push(result); + return absl::OkStatus(); + } + } const UnknownSet* unknown_set = nullptr; if (frame->enable_unknowns()) { @@ -44,18 +53,17 @@ absl::Status CreateListStep::Evaluate(ExecutionFrame* frame) const { /*use_partial=*/true); if (unknown_set != nullptr) { result = CelValue::CreateUnknownSet(unknown_set); + frame->value_stack().Pop(list_size_); + frame->value_stack().Push(result); + return absl::OkStatus(); } } - if (unknown_set == nullptr) { - CelList* cel_list = google::protobuf::Arena::Create( - frame->arena(), std::vector(args.begin(), args.end())); - result = CelValue::CreateList(cel_list); - } - + CelList* cel_list = google::protobuf::Arena::Create( + frame->arena(), std::vector(args.begin(), args.end())); + result = CelValue::CreateList(cel_list); frame->value_stack().Pop(list_size_); frame->value_stack().Push(result); - return absl::OkStatus(); } diff --git a/eval/eval/create_list_step_test.cc b/eval/eval/create_list_step_test.cc index 7cdd569b9..a6ba1a159 100644 --- a/eval/eval/create_list_step_test.cc +++ b/eval/eval/create_list_step_test.cc @@ -100,7 +100,7 @@ class CreateListStepTest : public testing::TestWithParam {}; // Tests error when not enough list elements are on the stack during list // creation. -TEST(CreateListStepTest, TestCreateListStackUndeflow) { +TEST(CreateListStepTest, TestCreateListStackUnderflow) { ExecutionPath path; Expr dummy_expr; @@ -144,6 +144,42 @@ TEST_P(CreateListStepTest, CreateListOne) { EXPECT_THAT((*result_value.ListOrDie())[0].Int64OrDie(), Eq(100)); } +TEST_P(CreateListStepTest, CreateListWithError) { + google::protobuf::Arena arena; + std::vector values; + CelError error = absl::InvalidArgumentError("bad arg"); + values.push_back(CelValue::CreateError(&error)); + auto eval_result = RunExpressionWithCelValues(values, &arena, GetParam()); + + ASSERT_OK(eval_result); + const CelValue result_value = eval_result.value(); + ASSERT_TRUE(result_value.IsError()); + EXPECT_THAT(*result_value.ErrorOrDie(), + Eq(absl::InvalidArgumentError("bad arg"))); +} + +TEST_P(CreateListStepTest, CreateListWithErrorAndUnknown) { + google::protobuf::Arena arena; + // list composition is: {unknown, error} + std::vector values; + Expr expr0; + expr0.mutable_ident_expr()->set_name("name0"); + CelAttribute attr0(expr0, {}); + UnknownSet unknown_set0(UnknownAttributeSet({&attr0})); + values.push_back(CelValue::CreateUnknownSet(&unknown_set0)); + CelError error = absl::InvalidArgumentError("bad arg"); + values.push_back(CelValue::CreateError(&error)); + + auto eval_result = RunExpressionWithCelValues(values, &arena, GetParam()); + + // The bad arg should win. + ASSERT_OK(eval_result); + const CelValue result_value = eval_result.value(); + ASSERT_TRUE(result_value.IsError()); + EXPECT_THAT(*result_value.ErrorOrDie(), + Eq(absl::InvalidArgumentError("bad arg"))); +} + TEST_P(CreateListStepTest, CreateListHundred) { google::protobuf::Arena arena; std::vector values; diff --git a/eval/eval/create_struct_step.cc b/eval/eval/create_struct_step.cc index 68f1c727b..fd9c79dec 100644 --- a/eval/eval/create_struct_step.cc +++ b/eval/eval/create_struct_step.cc @@ -147,12 +147,13 @@ absl::Status CreateStructStepForMessage::DoEvaluate(ExecutionFrame* frame, } Message* entry_msg = msg->GetReflection()->AddMessage(msg, entry.field); - status = SetValueToSingleField(key, key_field_descriptor, entry_msg); + status = SetValueToSingleField(key, key_field_descriptor, entry_msg, + frame->arena()); if (!status.ok()) { break; } status = SetValueToSingleField(value.value(), value_field_descriptor, - entry_msg); + entry_msg, frame->arena()); if (!status.ok()) { break; } @@ -170,11 +171,12 @@ absl::Status CreateStructStepForMessage::DoEvaluate(ExecutionFrame* frame, } for (int i = 0; i < cel_list->size(); i++) { - status = AddValueToRepeatedField((*cel_list)[i], entry.field, msg); + status = AddValueToRepeatedField((*cel_list)[i], entry.field, msg, + frame->arena()); if (!status.ok()) break; } } else { - status = SetValueToSingleField(arg, entry.field, msg); + status = SetValueToSingleField(arg, entry.field, msg, frame->arena()); } if (!status.ok()) { @@ -274,8 +276,6 @@ absl::StatusOr> CreateCreateStructStep( const google::api::expr::v1alpha1::Expr::CreateStruct* create_struct_expr, const Descriptor* message_desc, int64_t expr_id) { if (message_desc != nullptr) { - // TODO(issues/92): Support resolving a type name within a container. - // Make message-creating step. std::vector entries; for (const auto& entry : create_struct_expr->entries()) { diff --git a/eval/eval/function_step_test.cc b/eval/eval/function_step_test.cc index a27f33a8a..f3f210ab4 100644 --- a/eval/eval/function_step_test.cc +++ b/eval/eval/function_step_test.cc @@ -122,7 +122,8 @@ class AddFunction : public CelFunction { class SinkFunction : public CelFunction { public: - SinkFunction(CelValue::Type type) : CelFunction(CreateDescriptor(type)) {} + explicit SinkFunction(CelValue::Type type) + : CelFunction(CreateDescriptor(type)) {} static CelFunctionDescriptor CreateDescriptor(CelValue::Type type) { return CelFunctionDescriptor{"Sink", false, {type}}; diff --git a/eval/eval/ternary_step.cc b/eval/eval/ternary_step.cc index 420cf10e5..f63006711 100644 --- a/eval/eval/ternary_step.cc +++ b/eval/eval/ternary_step.cc @@ -17,7 +17,7 @@ namespace { class TernaryStep : public ExpressionStepBase { public: // Constructs FunctionStep that uses overloads specified. - TernaryStep(int64_t expr_id) : ExpressionStepBase(expr_id) {} + explicit TernaryStep(int64_t expr_id) : ExpressionStepBase(expr_id) {} absl::Status Evaluate(ExecutionFrame* frame) const override; }; diff --git a/eval/public/BUILD b/eval/public/BUILD index 2effb7e8e..d3a25becd 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -108,7 +108,7 @@ cc_library( "//eval/public/containers:field_access", "//eval/public/containers:field_backed_list_impl", "//eval/public/containers:field_backed_map_impl", - "@com_google_absl//absl/strings", + "@com_google_absl//absl/status", ], ) @@ -122,6 +122,8 @@ cc_library( ], deps = [ ":cel_value", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], ) @@ -182,10 +184,11 @@ cc_library( ":cel_options", "//base:unilib", "//eval/public/containers:container_backed_list_impl", + "//internal:proto_util", "@com_google_absl//absl/numeric:int128", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@com_google_protobuf//:protobuf", + "@com_google_absl//absl/time", "@com_googlesource_code_re2//:re2", ], ) @@ -402,8 +405,8 @@ cc_test( "//base:status_macros", "//eval/testutil:test_message_cc_proto", "//testutil:util", + "@com_google_absl//absl/status", "@com_google_googletest//:gtest_main", - "@com_google_protobuf//:protobuf", ], ) @@ -490,8 +493,10 @@ cc_test( ":cel_builtins", ":cel_expr_builder_factory", ":cel_function_registry", + ":cel_value", "//base:status_macros", "//eval/public/structs:cel_proto_wrapper", + "//internal:proto_util", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", diff --git a/eval/public/activation_bind_helper.cc b/eval/public/activation_bind_helper.cc index a5bede00d..1e8004003 100644 --- a/eval/public/activation_bind_helper.cc +++ b/eval/public/activation_bind_helper.cc @@ -1,5 +1,6 @@ #include "eval/public/activation_bind_helper.h" +#include "absl/status/status.h" #include "eval/public/containers/field_access.h" #include "eval/public/containers/field_backed_list_impl.h" #include "eval/public/containers/field_backed_map_impl.h" @@ -37,6 +38,13 @@ absl::Status CreateValueFromField(const google::protobuf::Message* msg, absl::Status BindProtoToActivation(const Message* message, Arena* arena, Activation* activation, ProtoUnsetFieldOptions options) { + // If we need to bind any types that are backed by an arena allocation, we + // will cause a memory leak. + if (arena == nullptr) { + return absl::InvalidArgumentError( + "arena must not be null for BindProtoToActivation."); + } + // TODO(issues/24): Improve the utilities to bind dynamic values as well. const Descriptor* desc = message->GetDescriptor(); const google::protobuf::Reflection* reflection = message->GetReflection(); diff --git a/eval/public/activation_bind_helper.h b/eval/public/activation_bind_helper.h index 2154b91ec..92b32b917 100644 --- a/eval/public/activation_bind_helper.h +++ b/eval/public/activation_bind_helper.h @@ -17,7 +17,8 @@ enum class ProtoUnsetFieldOptions { }; // Utility method, that takes a protobuf Message and interprets it as a -// namespace, binding its fields to Activation. +// namespace, binding its fields to Activation. |arena| must be non-null. +// // Field names and values become respective names and values of parameters // bound to the Activation object. // Example: diff --git a/eval/public/activation_bind_helper_test.cc b/eval/public/activation_bind_helper_test.cc index d28653ac2..b423679ab 100644 --- a/eval/public/activation_bind_helper_test.cc +++ b/eval/public/activation_bind_helper_test.cc @@ -2,6 +2,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "absl/status/status.h" #include "eval/public/activation.h" #include "eval/testutil/test_message.pb.h" #include "testutil/util.h" @@ -127,6 +128,17 @@ TEST(ActivationBindHelperTest, TestBindDefaultFields) { EqualsProto(*result.value().MessageOrDie())); } +TEST(ActivationBindHelperTest, RejectsNullArena) { + TestMessage message; + message.set_bool_value(true); + + Activation activation; + + ASSERT_EQ(BindProtoToActivation(&message, /*arena=*/nullptr, &activation), + absl::InvalidArgumentError( + "arena must not be null for BindProtoToActivation.")); +} + } // namespace } // namespace runtime diff --git a/eval/public/builtin_func_registrar.cc b/eval/public/builtin_func_registrar.cc index 86fa2adbe..fd0e574b8 100644 --- a/eval/public/builtin_func_registrar.cc +++ b/eval/public/builtin_func_registrar.cc @@ -1,20 +1,22 @@ #include "eval/public/builtin_func_registrar.h" +#include #include #include -#include "google/protobuf/util/time_util.h" #include "absl/numeric/int128.h" #include "absl/status/status.h" #include "absl/strings/match.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_replace.h" +#include "absl/time/time.h" #include "eval/public/cel_builtins.h" #include "eval/public/cel_function_adapter.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" #include "eval/public/containers/container_backed_list_impl.h" +#include "internal/proto_util.h" #include "re2/re2.h" #include "base/unilib.h" @@ -28,9 +30,20 @@ using google::protobuf::Arena; namespace { const int64_t kIntMax = std::numeric_limits::max(); -const int64_t kIntMin = std::numeric_limits::min(); +const int64_t kIntMin = std::numeric_limits::lowest(); const uint64_t kUintMax = std::numeric_limits::max(); +// Returns the number of UTF8 codepoints within a string. +// The input string must first be checked to see if it is valid UTF8. +static int UTF8CodepointCount(absl::string_view str) { + int n = 0; + // Increment the codepoint count on non-trail-byte characters. + for (const auto p : str) { + n += (*reinterpret_cast(&p) >= -0x40); + } + return n; +} + // Comparison template functions template CelValue Inequal(Arena*, Type t1, Type t2) { @@ -1087,8 +1100,10 @@ absl::Status RegisterIntConversionFunctions(CelFunctionRegistry* registry, status = FunctionAdapter::CreateAndRegister( builtin::kInt, false, [](Arena* arena, double v) { - if ((v > static_cast(kIntMax)) || - (v < static_cast(kIntMin))) { + // NaN and -+infinite numbers cannot be represented as int values, + // nor can double values which exceed the integer 64-bit range. + if (!std::isfinite(v) || v > static_cast(kIntMax) || + v < static_cast(kIntMin)) { return CreateErrorValue(arena, "double out of int range", absl::StatusCode::kInvalidArgument); } @@ -1199,6 +1214,37 @@ absl::Status RegisterStringConversionFunctions( registry); if (!status.ok()) return status; + // duration -> string + status = FunctionAdapter::CreateAndRegister( + builtin::kString, false, + [](Arena* arena, absl::Duration value) -> CelValue { + auto encode = + google::api::expr::internal::EncodeDurationToString(value); + if (!encode.ok()) { + const auto& status = encode.status(); + return CreateErrorValue(arena, status.message(), status.code()); + } + return CelValue::CreateString(CelValue::StringHolder( + Arena::Create(arena, encode.value()))); + }, + registry); + if (!status.ok()) return status; + + // timestamp -> string + status = FunctionAdapter::CreateAndRegister( + builtin::kString, false, + [](Arena* arena, absl::Time value) -> CelValue { + auto encode = google::api::expr::internal::EncodeTimeToString(value); + if (!encode.ok()) { + const auto& status = encode.status(); + return CreateErrorValue(arena, status.message(), status.code()); + } + return CelValue::CreateString(CelValue::StringHolder( + Arena::Create(arena, encode.value()))); + }, + registry); + if (!status.ok()) return status; + return absl::OkStatus(); } @@ -1208,7 +1254,16 @@ absl::Status RegisterUintConversionFunctions(CelFunctionRegistry* registry, auto status = FunctionAdapter::CreateAndRegister( builtin::kUint, false, [](Arena* arena, double v) { - if ((v > static_cast(kUintMax)) || (v < 0)) { + // NaN and -+infinite numbers cannot be represented as uint values, + // nor doubles that exceed the uint64_t range. In some limited cases, + // like 1.84467e+19, the value appears to fit within the uint64_t range + // but type conversion results in rounding that overflows. + // + // Note, the double is checked to make sure it is not greater than 2^64 + // before it is converted to a uint128 value, as the type conversion + // may check-fail for some double inputs that exceed this value. + if (!std::isfinite(v) || v < 0 || v > std::ldexp(1.0, 64) || + absl::uint128(v) > kUintMax) { return CreateErrorValue(arena, "double out of uint range", absl::StatusCode::kInvalidArgument); } @@ -1351,16 +1406,24 @@ absl::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, if (!status.ok()) return status; // String size - auto string_size_func = [](Arena*, CelValue::StringHolder value) -> int64_t { - return value.value().size(); + auto size_func = [=](Arena* arena, CelValue::StringHolder value) -> CelValue { + absl::string_view str = value.value(); + if (options.enable_string_size_as_unicode_codepoints) { + if (!UniLib::IsStructurallyValid(str)) { + return CreateErrorValue(arena, "invalid utf-8 string", + absl::StatusCode::kInvalidArgument); + } + return CelValue::CreateInt64(UTF8CodepointCount(str)); + } + return CelValue::CreateInt64(str.size()); }; // receiver style = true/false // Support global and receiver style size() operations on strings. - status = FunctionAdapter::CreateAndRegister( - builtin::kSize, true, string_size_func, registry); + status = FunctionAdapter::CreateAndRegister( + builtin::kSize, true, size_func, registry); if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( - builtin::kSize, false, string_size_func, registry); + status = FunctionAdapter::CreateAndRegister( + builtin::kSize, false, size_func, registry); if (!status.ok()) return status; // Bytes size diff --git a/eval/public/builtin_func_test.cc b/eval/public/builtin_func_test.cc index 008f3bf08..5cff08eda 100644 --- a/eval/public/builtin_func_test.cc +++ b/eval/public/builtin_func_test.cc @@ -9,7 +9,9 @@ #include "eval/public/cel_builtins.h" #include "eval/public/cel_expr_builder_factory.h" #include "eval/public/cel_function_registry.h" +#include "eval/public/cel_value.h" #include "eval/public/structs/cel_proto_wrapper.h" +#include "internal/proto_util.h" #include "base/status_macros.h" namespace google { @@ -28,6 +30,9 @@ using google::api::expr::v1alpha1::SourceInfo; using google::protobuf::Arena; using google::protobuf::util::TimeUtil; +using ::google::api::expr::internal::MakeGoogleApiDurationMax; +using ::google::api::expr::internal::MakeGoogleApiDurationMin; +using ::google::api::expr::internal::MakeGoogleApiTimeMin; using testing::Eq; class BuiltinsTest : public ::testing::Test { @@ -132,6 +137,18 @@ class BuiltinsTest : public ::testing::Test { << operation << " for " << CelValue::TypeName(ref.type()); } + // Helper method. Looks up in registry and tests Type conversions. + void TestTypeConverts(absl::string_view operation, const CelValue& ref, + CelValue::StringHolder result) { + CelValue result_value; + + ASSERT_NO_FATAL_FAILURE(PerformRun(operation, {}, {ref}, &result_value)); + + ASSERT_EQ(result_value.IsString(), true); + ASSERT_EQ(result_value.StringOrDie().value(), result.value()) + << operation << " for " << CelValue::TypeName(ref.type()); + } + void TestTypeConverts(absl::string_view operation, const CelValue& ref, double result) { CelValue result_value; @@ -165,6 +182,18 @@ class BuiltinsTest : public ::testing::Test { << operation << " for " << CelValue::TypeName(ref.type()); } + void TestTypeConverts(absl::string_view operation, const CelValue& ref, + Duration& result) { + CelValue result_value; + + ASSERT_NO_FATAL_FAILURE(PerformRun(operation, {}, {ref}, &result_value)); + + ASSERT_EQ(result_value.IsDuration(), true); + ASSERT_EQ(result_value.DurationOrDie(), + CelProtoWrapper::CreateDuration(&result).DurationOrDie()) + << operation << " for " << CelValue::TypeName(ref.type()); + } + // Helper method. Attempts to perform a type conversion and expects an error // as the result. void TestTypeConversionError(absl::string_view operation, @@ -548,6 +577,20 @@ TEST_F(BuiltinsTest, TestTimestampDurationArithmeticalOperation) { ASSERT_EQ(result_value.IsDuration(), true); ASSERT_EQ(absl::ToInt64Nanoseconds(result_value.DurationOrDie()), TimeUtil::DurationToNanoseconds(d0)); + + const auto min = CelValue::CreateDuration(MakeGoogleApiDurationMin()); + ASSERT_TRUE(min.IsDuration()); + ASSERT_NO_FATAL_FAILURE(PerformRun( + builtin::kSubtract, {}, + {min, CelValue::CreateDuration(absl::Nanoseconds(1))}, &result_value)); + ASSERT_TRUE(result_value.IsError()); + + const auto max = CelValue::CreateDuration(MakeGoogleApiDurationMax()); + ASSERT_TRUE(max.IsDuration()); + ASSERT_NO_FATAL_FAILURE(PerformRun( + builtin::kAdd, {}, {max, CelValue::CreateDuration(absl::Nanoseconds(1))}, + &result_value)); + ASSERT_TRUE(result_value.IsError()); } // Test functions for Duration @@ -565,6 +608,11 @@ TEST_F(BuiltinsTest, TestDurationFunctions) { TestFunctions(builtin::kMilliseconds, CelProtoWrapper::CreateDuration(&ref), 11L); + std::string result = "93541.011s"; + TestTypeConverts(builtin::kString, CelProtoWrapper::CreateDuration(&ref), + CelValue::StringHolder(&result)); + TestTypeConverts(builtin::kDuration, CelValue::CreateString(&result), ref); + ref.set_seconds(-93541L); ref.set_nanos(-11000000L); @@ -575,6 +623,24 @@ TEST_F(BuiltinsTest, TestDurationFunctions) { -93541L); TestFunctions(builtin::kMilliseconds, CelProtoWrapper::CreateDuration(&ref), -11L); + + result = "-93541.011s"; + TestTypeConverts(builtin::kString, CelProtoWrapper::CreateDuration(&ref), + CelValue::StringHolder(&result)); + TestTypeConverts(builtin::kDuration, CelValue::CreateString(&result), ref); + + absl::Duration d = MakeGoogleApiDurationMin() + absl::Seconds(-1); + result = absl::FormatDuration(d); + TestTypeConversionError(builtin::kDuration, CelValue::CreateString(&result)); + + d = MakeGoogleApiDurationMax() + absl::Seconds(1); + result = absl::FormatDuration(d); + TestTypeConversionError(builtin::kDuration, CelValue::CreateString(&result)); + + std::string inf = "inf"; + std::string ninf = "-inf"; + TestTypeConversionError(builtin::kDuration, CelValue::CreateString(&inf)); + TestTypeConversionError(builtin::kDuration, CelValue::CreateString(&ninf)); } // Test functions for Timestamp @@ -598,11 +664,19 @@ TEST_F(BuiltinsTest, TestTimestampFunctions) { TestFunctions(builtin::kMilliseconds, CelProtoWrapper::CreateTimestamp(&ref), 11L); + std::string result = "1970-01-01T00:00:01.011Z"; + TestTypeConverts(builtin::kString, CelProtoWrapper::CreateTimestamp(&ref), + CelValue::StringHolder(&result)); + ref.set_seconds(259200L); ref.set_nanos(0L); TestFunctions(builtin::kDayOfWeek, CelProtoWrapper::CreateTimestamp(&ref), 0L); + result = "1970-01-04T00:00:00Z"; + TestTypeConverts(builtin::kString, CelProtoWrapper::CreateTimestamp(&ref), + CelValue::StringHolder(&result)); + // Test timestamp functions w/ IANA timezone ref.set_seconds(1L); ref.set_nanos(11000000L); @@ -702,6 +776,10 @@ TEST_F(BuiltinsTest, TestTimestampFunctions) { TestFunctions(builtin::kSeconds, CelProtoWrapper::CreateTimestamp(&ref), 59L); TestFunctions(builtin::kDayOfWeek, CelProtoWrapper::CreateTimestamp(&ref), 3L); + + TestTypeConversionError( + builtin::kString, + CelValue::CreateTimestamp(MakeGoogleApiTimeMin() + absl::Seconds(-1))); } TEST_F(BuiltinsTest, TestBytesConversions_bytes) { @@ -1327,8 +1405,28 @@ TEST_F(BuiltinsTest, StringSize) { builtin::kSize, {}, {CelValue::CreateString(&test)}, &result_value)); ASSERT_EQ(result_value.IsInt64(), true); + ASSERT_EQ(result_value.Int64OrDie(), 9); +} - ASSERT_EQ(result_value.Int64OrDie(), test.size()); +TEST_F(BuiltinsTest, StringUnicodeSize) { + std::string test = "πέντε"; + CelValue result_value; + InterpreterOptions options; + options.enable_string_size_as_unicode_codepoints = true; + ASSERT_NO_FATAL_FAILURE(PerformRun(builtin::kSize, {}, + {CelValue::CreateString(&test)}, + &result_value, options)); + ASSERT_EQ(result_value.IsInt64(), true); + ASSERT_EQ(result_value.Int64OrDie(), 5); + + // Disable the option to measure string size by codepoints, and the return + // value should be equal to the number of bytes in the string (10). + options.enable_string_size_as_unicode_codepoints = false; + ASSERT_NO_FATAL_FAILURE(PerformRun(builtin::kSize, {}, + {CelValue::CreateString(&test)}, + &result_value, options)); + ASSERT_EQ(result_value.IsInt64(), true); + ASSERT_EQ(result_value.Int64OrDie(), 10); } TEST_F(BuiltinsTest, BytesSize) { diff --git a/eval/public/cel_function.h b/eval/public/cel_function.h index 7e1dc0275..d789a02d9 100644 --- a/eval/public/cel_function.h +++ b/eval/public/cel_function.h @@ -1,8 +1,12 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_FUNCTION_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_FUNCTION_H_ +#include +#include #include +#include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "eval/public/cel_value.h" @@ -15,9 +19,11 @@ namespace runtime { // This complex structure is needed for overloads support. class CelFunctionDescriptor { public: - CelFunctionDescriptor(const std::string& name, bool receiver_style, - const std::vector types) - : name_(name), receiver_style_(receiver_style), types_(types) {} + CelFunctionDescriptor(absl::string_view name, bool receiver_style, + std::vector types) + : name_(name), + receiver_style_(receiver_style), + types_(std::move(types)) {} // Function name. const std::string& name() const { return name_; } @@ -56,8 +62,8 @@ class CelFunctionDescriptor { class CelFunction { public: // Build CelFunction from descriptor - explicit CelFunction(const CelFunctionDescriptor& descriptor) - : descriptor_(descriptor) {} + explicit CelFunction(CelFunctionDescriptor descriptor) + : descriptor_(std::move(descriptor)) {} // Non-copyable CelFunction(const CelFunction& other) = delete; diff --git a/eval/public/cel_function_adapter.h b/eval/public/cel_function_adapter.h index a26391cb0..7bca2ca96 100644 --- a/eval/public/cel_function_adapter.h +++ b/eval/public/cel_function_adapter.h @@ -1,6 +1,8 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_FUNCTION_ADAPTER_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_FUNCTION_ADAPTER_H_ + #include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -63,6 +65,9 @@ bool AddType(std::vector* arg_types) { // It accepts method implementations as std::function, allowing // them to be lambdas/regular C++ functions. CEL method descriptors are // deduced based on C++ function signatures. +// CelFunction::Evaluate will set result to the value returned by the handler. +// To handle errors, choose CelValue as the return type, and use the +// CreateError/Create* helpers in cel_value.h. // // Usage example: // @@ -81,13 +86,14 @@ class FunctionAdapter : public CelFunction { public: using FuncType = std::function; - FunctionAdapter(const CelFunctionDescriptor& descriptor, FuncType handler) - : CelFunction(descriptor), handler_(std::move(handler)) {} + FunctionAdapter(CelFunctionDescriptor descriptor, FuncType handler) + : CelFunction(std::move(descriptor)), handler_(std::move(handler)) {} static absl::StatusOr> Create( absl::string_view name, bool receiver_type, std::function handler) { std::vector arg_types; + arg_types.reserve(sizeof...(Arguments)); if (!internal::AddType<0, Arguments...>(&arg_types)) { return absl::Status( @@ -96,10 +102,9 @@ class FunctionAdapter : public CelFunction { ": failed to determine input parameter type")); } - std::unique_ptr cel_func = absl::make_unique( - CelFunctionDescriptor(std::string(name), receiver_type, arg_types), + return absl::make_unique( + CelFunctionDescriptor(name, receiver_type, std::move(arg_types)), std::move(handler)); - return std::move(cel_func); } // Creates function handler and attempts to register it with @@ -113,7 +118,7 @@ class FunctionAdapter : public CelFunction { return status.status(); } - return registry->Register(std::move(status.value())); + return registry->Register(std::move(status).value()); } #if defined(__clang_major_version__) && __clang_major_version__ >= 8 && !defined(__APPLE__) diff --git a/eval/public/cel_function_provider_test.cc b/eval/public/cel_function_provider_test.cc index 8a6ff3e42..897cffe12 100644 --- a/eval/public/cel_function_provider_test.cc +++ b/eval/public/cel_function_provider_test.cc @@ -10,7 +10,6 @@ namespace expr { namespace runtime { namespace { -using testing::_; using testing::Eq; using testing::HasSubstr; using testing::Ne; diff --git a/eval/public/cel_options.h b/eval/public/cel_options.h index dcd774c6c..4dbe48d03 100644 --- a/eval/public/cel_options.h +++ b/eval/public/cel_options.h @@ -27,6 +27,11 @@ struct InterpreterOptions { bool enable_missing_attribute_errors = false; + // Enable functions which return the string.size() as the number of unicode + // codepoints. + // Starting on 4/7/2021 this will default to 'true' + bool enable_string_size_as_unicode_codepoints = false; + // Enable short-circuiting of the logical operator evaluation. If enabled, // AND, OR, and TERNARY do not evaluate the entire expression once the the // resulting value is known from the left-hand side. diff --git a/eval/public/cel_type_registry.cc b/eval/public/cel_type_registry.cc index 565ae3e27..7f19a1ff9 100644 --- a/eval/public/cel_type_registry.cc +++ b/eval/public/cel_type_registry.cc @@ -1,6 +1,8 @@ #include "eval/public/cel_type_registry.h" +#include "google/protobuf/struct.pb.h" #include "google/protobuf/descriptor.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/node_hash_set.h" #include "absl/status/status.h" #include "absl/types/optional.h" @@ -30,9 +32,19 @@ const absl::node_hash_set& GetCoreTypes() { return *kCoreTypes; } +const absl::flat_hash_set GetCoreEnums() { + static const auto* const kCoreEnums = + new absl::flat_hash_set{ + // Register the NULL_VALUE enum. + google::protobuf::NullValue_descriptor(), + }; + return *kCoreEnums; +} + } // namespace -CelTypeRegistry::CelTypeRegistry() : types_(GetCoreTypes()), enums_() {} +CelTypeRegistry::CelTypeRegistry() + : types_(GetCoreTypes()), enums_(GetCoreEnums()) {} void CelTypeRegistry::Register(std::string fully_qualified_type_name) { // Registers the fully qualified type name as a CEL type. diff --git a/eval/public/cel_type_registry_test.cc b/eval/public/cel_type_registry_test.cc index 3117722da..c194bec9f 100644 --- a/eval/public/cel_type_registry_test.cc +++ b/eval/public/cel_type_registry_test.cc @@ -24,6 +24,7 @@ TEST(CelTypeRegistryTest, TestRegisterEnumDescriptor) { enum_set.insert(enum_desc->full_name()); } absl::flat_hash_set expected_set; + expected_set.insert({"google.protobuf.NullValue"}); expected_set.insert({"google.api.expr.runtime.TestMessage.TestEnum"}); EXPECT_THAT(enum_set, Eq(expected_set)); } diff --git a/eval/public/cel_value.cc b/eval/public/cel_value.cc index 2d47bdbb0..c81f7427d 100644 --- a/eval/public/cel_value.cc +++ b/eval/public/cel_value.cc @@ -42,8 +42,25 @@ constexpr absl::string_view kListTypeName = "list"; constexpr absl::string_view kMapTypeName = "map"; constexpr absl::string_view kCelTypeTypeName = "type"; +// Exclusive bounds for valid duration values. +constexpr absl::Duration kDurationHigh = absl::Seconds(315576000001); +constexpr absl::Duration kDurationLow = absl::Seconds(-315576000001); + +const absl::Status* DurationOverflowError() { + static const auto* const kDurationOverflow = new absl::Status( + absl::StatusCode::kInvalidArgument, "Duration is out of range"); + return kDurationOverflow; +} + } // namespace +CelValue CelValue::CreateDuration(absl::Duration value) { + if (value >= kDurationHigh || value <= kDurationLow) { + return CelValue(DurationOverflowError()); + } + return CelValue(value); +} + std::string CelValue::TypeName(Type value_type) { switch (value_type) { case Type::kBool: diff --git a/eval/public/cel_value.h b/eval/public/cel_value.h index a7c34bdce..6b0fa400b 100644 --- a/eval/public/cel_value.h +++ b/eval/public/cel_value.h @@ -183,9 +183,7 @@ class CelValue { return CelValue(BytesHolder(str)); } - static CelValue CreateDuration(absl::Duration value) { - return CelValue(value); - } + static CelValue CreateDuration(absl::Duration value); static CelValue CreateTimestamp(absl::Time value) { return CelValue(value); } diff --git a/eval/public/containers/BUILD b/eval/public/containers/BUILD index 04e76b4e2..cbe409d2e 100644 --- a/eval/public/containers/BUILD +++ b/eval/public/containers/BUILD @@ -17,7 +17,6 @@ cc_library( deps = [ "//eval/public:cel_value", "//eval/public/structs:cel_proto_wrapper", - "//internal:proto_util", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", @@ -121,9 +120,22 @@ cc_test( ], deps = [ ":field_backed_map_impl", - "//eval/eval:evaluator_core", "//eval/testutil:test_message_cc_proto", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", ], ) + +cc_test( + name = "field_access_test", + srcs = ["field_access_test.cc"], + deps = [ + ":field_access", + "//internal:proto_util", + "@com_google_absl//absl/status", + "@com_google_absl//absl/time", + "@com_google_cel_spec//proto/test/v1/proto3:test_all_types_cc_proto", + "@com_google_googletest//:gtest_main", + "@com_google_protobuf//:protobuf", + ], +) diff --git a/eval/public/containers/field_access.cc b/eval/public/containers/field_access.cc index 23442b186..fe208c3e1 100644 --- a/eval/public/containers/field_access.cc +++ b/eval/public/containers/field_access.cc @@ -3,12 +3,15 @@ #include #include "google/protobuf/any.pb.h" +#include "google/protobuf/struct.pb.h" +#include "google/protobuf/wrappers.pb.h" +#include "google/protobuf/arena.h" #include "google/protobuf/map_field.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/strings/substitute.h" #include "eval/public/structs/cel_proto_wrapper.h" -#include "internal/proto_util.h" namespace google { namespace api { @@ -24,11 +27,8 @@ using ::google::protobuf::Message; using ::google::protobuf::Reflection; // Well-known type protobuf type names which require special get / set behavior. -constexpr const char kProtobufDuration[] = "google.protobuf.Duration"; -constexpr const char kProtobufTimestamp[] = "google.protobuf.Timestamp"; -constexpr const char kProtobufAny[] = "google.protobuf.Any"; - -const char kTypeGoogleApisComPrefix[] = "type.googleapis.com/"; +constexpr absl::string_view kProtobufAny = "google.protobuf.Any"; +constexpr absl::string_view kTypeGoogleApisComPrefix = "type.googleapis.com/"; // Singular message fields and repeated message fields have similar access model // To provide common approach, we implement accessor classes, based on CRTP. @@ -197,6 +197,8 @@ class ScalarFieldAccessor : public FieldAccessor { } const Message* GetMessage() const { + // TODO(issues/109): When the field descriptor is a wrapper type, check if + // the field is set. If set, return the unwrapped value, else return 'null'. return &GetReflection()->GetMessage(*msg_, field_desc_); } @@ -456,7 +458,6 @@ class FieldSetter { MessageRetrieverOp()); if (!value.has_value()) { - GOOGLE_LOG(ERROR) << "Has No Value"; return false; } @@ -464,30 +465,6 @@ class FieldSetter { return true; } - bool AssignDuration(const CelValue& cel_value) const { - absl::Duration d; - if (!cel_value.GetValue(&d)) { - GOOGLE_LOG(ERROR) << "Unable to retrieve duration"; - return false; - } - google::protobuf::Duration duration; - google::api::expr::internal::EncodeDuration(d, &duration); - static_cast(this)->SetMessage(&duration); - return true; - } - - bool AssignTimestamp(const CelValue& cel_value) const { - absl::Time t; - if (!cel_value.GetValue(&t)) { - GOOGLE_LOG(ERROR) << "Unable to retrieve timestamp"; - return false; - } - google::protobuf::Timestamp timestamp; - google::api::expr::internal::EncodeTime(t, ×tamp); - static_cast(this)->SetMessage(×tamp); - return true; - } - // This method provides message field content, wrapped in CelValue. // If value provided successfully, returns Ok. // arena Arena to use for allocations if needed. @@ -528,16 +505,14 @@ class FieldSetter { break; } case FieldDescriptor::CPPTYPE_MESSAGE: { - const std::string& type_name = field_desc_->message_type()->full_name(); + const absl::string_view type_name = + field_desc_->message_type()->full_name(); // When the field is a message, it might be a well-known type with a // non-proto representation that requires special handling before it // can be set on the field. - if (type_name == kProtobufTimestamp) { - return AssignTimestamp(value); - } else if (type_name == kProtobufDuration) { - return AssignDuration(value); - } - return AssignMessage(value); + auto wrapped_value = + CelProtoWrapper::MaybeWrapValue(type_name, value, arena_); + return AssignMessage(wrapped_value.value_or(value)); } case FieldDescriptor::CPPTYPE_ENUM: { return AssignEnum(value); @@ -550,18 +525,20 @@ class FieldSetter { } protected: - FieldSetter(Message* msg, const FieldDescriptor* field_desc) - : msg_(msg), field_desc_(field_desc) {} + FieldSetter(Message* msg, const FieldDescriptor* field_desc, Arena* arena) + : msg_(msg), field_desc_(field_desc), arena_(arena) {} Message* msg_; const FieldDescriptor* field_desc_; + Arena* arena_; }; // Accessor class, to work with singular fields class ScalarFieldSetter : public FieldSetter { public: - ScalarFieldSetter(Message* msg, const FieldDescriptor* field_desc) - : FieldSetter(msg, field_desc) {} + ScalarFieldSetter(Message* msg, const FieldDescriptor* field_desc, + Arena* arena) + : FieldSetter(msg, field_desc, arena) {} bool SetBool(bool value) const { GetReflection()->SetBool(msg_, field_desc_, value); @@ -645,8 +622,9 @@ class ScalarFieldSetter : public FieldSetter { // Appender class, to work with repeated fields class RepeatedFieldSetter : public FieldSetter { public: - RepeatedFieldSetter(Message* msg, const FieldDescriptor* field_desc) - : FieldSetter(msg, field_desc) {} + RepeatedFieldSetter(Message* msg, const FieldDescriptor* field_desc, + Arena* arena) + : FieldSetter(msg, field_desc, arena) {} bool SetBool(bool value) const { GetReflection()->AddBool(msg_, field_desc_, value); @@ -718,8 +696,9 @@ class RepeatedFieldSetter : public FieldSetter { // arena Arena to use for allocations if needed. // result pointer to object to store value in. absl::Status SetValueToSingleField(const CelValue& value, - const FieldDescriptor* desc, Message* msg) { - ScalarFieldSetter setter(msg, desc); + const FieldDescriptor* desc, Message* msg, + Arena* arena) { + ScalarFieldSetter setter(msg, desc, arena); return (setter.SetFieldFromCelValue(value)) ? absl::OkStatus() : absl::InvalidArgumentError(absl::Substitute( @@ -730,9 +709,9 @@ absl::Status SetValueToSingleField(const CelValue& value, } absl::Status AddValueToRepeatedField(const CelValue& value, - const FieldDescriptor* desc, - Message* msg) { - RepeatedFieldSetter setter(msg, desc); + const FieldDescriptor* desc, Message* msg, + Arena* arena) { + RepeatedFieldSetter setter(msg, desc, arena); return (setter.SetFieldFromCelValue(value)) ? absl::OkStatus() : absl::InvalidArgumentError(absl::Substitute( @@ -742,32 +721,6 @@ absl::Status AddValueToRepeatedField(const CelValue& value, value.DebugString())); } -absl::Status AddValueToMapField(const CelValue& key, const CelValue& value, - const FieldDescriptor* desc, Message* msg) { - auto entry_msg = msg->GetReflection()->AddMessage(msg, desc); - auto key_field_desc = entry_msg->GetDescriptor()->FindFieldByNumber(1); - auto value_field_desc = entry_msg->GetDescriptor()->FindFieldByNumber(2); - - ScalarFieldSetter key_setter(entry_msg, key_field_desc); - ScalarFieldSetter value_setter(entry_msg, value_field_desc); - - if (!key_setter.SetFieldFromCelValue(key)) { - return absl::InvalidArgumentError(absl::Substitute( - "Could not assign supplied argument \"$2\" to message " - "\"$0\" field \"$1\" map key.", - msg->GetDescriptor()->name(), desc->name(), key.DebugString())); - } - - if (!value_setter.SetFieldFromCelValue(value)) { - return absl::InvalidArgumentError(absl::Substitute( - "Could not assign supplied argument \"$2\" to message \"$0\" " - "field \"$1\" map value.", - msg->GetDescriptor()->name(), desc->name(), value.DebugString())); - } - - return absl::OkStatus(); -} - } // namespace runtime } // namespace expr } // namespace api diff --git a/eval/public/containers/field_access.h b/eval/public/containers/field_access.h index 711cf1744..4f52b4d0d 100644 --- a/eval/public/containers/field_access.h +++ b/eval/public/containers/field_access.h @@ -46,25 +46,20 @@ absl::Status CreateValueFromMapValue(const google::protobuf::Message* msg, // Returns status of the operation. // msg Message containing the field. // desc Descriptor of the field to access. +// arena Arena to perform allocations, if necessary, when setting the field. absl::Status SetValueToSingleField(const CelValue& value, const google::protobuf::FieldDescriptor* desc, - google::protobuf::Message* msg); + google::protobuf::Message* msg, google::protobuf::Arena* arena); + // Adds content of CelValue to repeated message field. // Returns status of the operation. // msg Message containing the field. // desc Descriptor of the field to access. +// arena Arena to perform allocations, if necessary, when adding the value. absl::Status AddValueToRepeatedField(const CelValue& value, const google::protobuf::FieldDescriptor* desc, - google::protobuf::Message* msg); - -// Adds content of CelValue to repeated message field. -// Returns status of the operation. -// msg Message containing the field. -// desc Descriptor of the field to access. - -absl::Status AddValueToMapField(const CelValue& key, const CelValue& value, - const google::protobuf::FieldDescriptor* desc, - google::protobuf::Message* msg); + google::protobuf::Message* msg, + google::protobuf::Arena* arena); } // namespace runtime } // namespace expr diff --git a/eval/public/containers/field_access_test.cc b/eval/public/containers/field_access_test.cc new file mode 100644 index 000000000..5ad4fbf5c --- /dev/null +++ b/eval/public/containers/field_access_test.cc @@ -0,0 +1,97 @@ +#include "eval/public/containers/field_access.h" + +#include "google/protobuf/arena.h" +#include "google/protobuf/message.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/status/status.h" +#include "absl/time/time.h" +#include "internal/proto_util.h" +#include "proto/test/v1/proto3/test_all_types.pb.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +namespace { + +using google::api::expr::internal::MakeGoogleApiDurationMax; +using google::api::expr::internal::MakeGoogleApiTimeMax; +using google::protobuf::Arena; +using google::protobuf::FieldDescriptor; +using test::v1::proto3::TestAllTypes; + +TEST(FieldAccessTest, SetDuration) { + Arena arena; + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("single_duration"); + auto status = SetValueToSingleField( + CelValue::CreateDuration(MakeGoogleApiDurationMax()), field, &msg, + &arena); + EXPECT_TRUE(status.ok()); +} + +TEST(FieldAccessTest, SetDurationBadDuration) { + Arena arena; + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("single_duration"); + auto status = SetValueToSingleField( + CelValue::CreateDuration(MakeGoogleApiDurationMax() + absl::Seconds(1)), + field, &msg, &arena); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); +} + +TEST(FieldAccessTest, SetDurationBadInputType) { + Arena arena; + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("single_duration"); + auto status = + SetValueToSingleField(CelValue::CreateInt64(1), field, &msg, &arena); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); +} + +TEST(FieldAccessTest, SetTimestamp) { + Arena arena; + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("single_timestamp"); + auto status = SetValueToSingleField( + CelValue::CreateTimestamp(MakeGoogleApiTimeMax()), field, &msg, &arena); + EXPECT_TRUE(status.ok()); +} + +TEST(FieldAccessTest, SetTimestampBadTime) { + Arena arena; + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("single_timestamp"); + auto status = SetValueToSingleField( + CelValue::CreateTimestamp(MakeGoogleApiTimeMax() + absl::Seconds(1)), + field, &msg, &arena); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); +} + +TEST(FieldAccessTest, SetTimestampBadInputType) { + Arena arena; + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("single_timestamp"); + auto status = + SetValueToSingleField(CelValue::CreateInt64(1), field, &msg, &arena); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); +} + +} // namespace + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/eval/public/containers/field_backed_map_impl_test.cc b/eval/public/containers/field_backed_map_impl_test.cc index 41b214aa1..59f75bf8e 100644 --- a/eval/public/containers/field_backed_map_impl_test.cc +++ b/eval/public/containers/field_backed_map_impl_test.cc @@ -103,10 +103,6 @@ TEST(FieldBackedMapImplTest, EmptySizeTest) { google::protobuf::Arena arena; auto cel_map = CreateMap(&message, "string_int32_map", &arena); - - std::string test0 = "test0"; - std::string test1 = "test1"; - EXPECT_EQ(cel_map->size(), 0); } diff --git a/eval/public/extension_func_test.cc b/eval/public/extension_func_test.cc index aad518b7e..479dcef58 100644 --- a/eval/public/extension_func_test.cc +++ b/eval/public/extension_func_test.cc @@ -13,8 +13,6 @@ namespace expr { namespace runtime { namespace { -using google::protobuf::Duration; -using google::protobuf::Timestamp; using google::protobuf::Arena; static const int kNanosPerSecond = 1000000000; diff --git a/eval/public/set_util.cc b/eval/public/set_util.cc index fd85903f1..885d9031f 100644 --- a/eval/public/set_util.cc +++ b/eval/public/set_util.cc @@ -89,7 +89,7 @@ int ComparisonImpl(const CelMap* lhs, const CelMap* rhs) { struct ComparisonVisitor { CelValue rhs; - ComparisonVisitor(CelValue rhs) : rhs(rhs) {} + explicit ComparisonVisitor(CelValue rhs) : rhs(rhs) {} template int operator()(T lhs_value) { T rhs_value; diff --git a/eval/public/structs/BUILD b/eval/public/structs/BUILD index 3660d1e31..3afa8d9c2 100644 --- a/eval/public/structs/BUILD +++ b/eval/public/structs/BUILD @@ -12,11 +12,11 @@ cc_library( ], deps = [ "//eval/public:cel_value", + "//eval/testutil:test_message_cc_proto", "//internal:proto_util", - "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/types:optional", "@com_google_protobuf//:protobuf", ], ) @@ -29,10 +29,15 @@ cc_test( ], deps = [ ":cel_proto_wrapper", + "//eval/public:cel_value", + "//eval/public/containers:container_backed_list_impl", + "//eval/public/containers:container_backed_map_impl", "//eval/testutil:test_message_cc_proto", "//internal:proto_util", "//testutil:util", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", "@com_google_googletest//:gtest_main", "@com_google_protobuf//:protobuf", ], diff --git a/eval/public/structs/cel_proto_wrapper.cc b/eval/public/structs/cel_proto_wrapper.cc index 77d318c27..b121ccf59 100644 --- a/eval/public/structs/cel_proto_wrapper.cc +++ b/eval/public/structs/cel_proto_wrapper.cc @@ -1,11 +1,24 @@ #include "eval/public/structs/cel_proto_wrapper.h" +#include + +#include +#include + #include "google/protobuf/any.pb.h" #include "google/protobuf/struct.pb.h" +#include "google/protobuf/timestamp.pb.h" #include "google/protobuf/wrappers.pb.h" -#include "absl/container/node_hash_map.h" +#include "google/protobuf/message.h" +#include "absl/container/flat_hash_map.h" +#include "absl/strings/escaping.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/strings/substitute.h" #include "absl/synchronization/mutex.h" +#include "eval/public/cel_value.h" +#include "eval/testutil/test_message.pb.h" +#include "internal/proto_util.h" namespace google { namespace api { @@ -35,8 +48,26 @@ using google::protobuf::UInt32Value; using google::protobuf::UInt64Value; using google::protobuf::Value; +// kMaxIntJSON is defined as the Number.MAX_SAFE_INTEGER value per EcmaScript 6. +constexpr int64_t kMaxIntJSON = (1l << 53) - 1; + +// kMinIntJSON is defined as the Number.MIN_SAFE_INTEGER value per EcmaScript 6. +constexpr int64_t kMinIntJSON = -kMaxIntJSON; + // Forward declaration for google.protobuf.Value CelValue ValueFromMessage(const Value* value, Arena* arena); +absl::optional MessageFromValue(const CelValue& value, + Value* json); + +// IsJSONSafe indicates whether the int is safely representable as a floating +// point value in JSON. +static bool IsJSONSafe(int64_t i) { return i >= kMinIntJSON && i <= kMaxIntJSON; } + +// IsJSONSafe indicates whether the uint is safely representable as a floating +// point value in JSON. +static bool IsJSONSafe(uint64_t i) { + return i <= static_cast(kMaxIntJSON); +} // Map implementation wrapping google.protobuf.ListValue class DynamicList : public CelList { @@ -145,8 +176,7 @@ CelValue ValueFromMessage(const Struct* struct_value, Arena* arena) { CelValue ValueFromMessage(const Any* any_value, Arena* arena) { auto type_url = any_value->type_url(); - - auto pos = type_url.find_last_of("/"); + auto pos = type_url.find_last_of('/'); if (pos == absl::string_view::npos) { // TODO(issues/25) What error code? // Malformed type_url @@ -235,7 +265,7 @@ CelValue ValueFromMessage(const Value* value, Arena* arena) { case Value::KindCase::kListValue: return CelProtoWrapper::CreateMessage(&value->list_value(), arena); default: - return CreateErrorValue(arena, "No known fields set in Value message"); + return CelValue::CreateNull(); } } @@ -320,11 +350,492 @@ class ValueFromMessageMaker { factories_.emplace(desc, std::move(factory)); } - absl::node_hash_map> factories_; }; +absl::optional MessageFromValue(const CelValue& value, + Duration* duration) { + absl::Duration val; + if (!value.GetValue(&val)) { + return {}; + } + auto status = google::api::expr::internal::EncodeDuration(val, duration); + if (!status.ok()) { + return {}; + } + return duration; +} + +absl::optional MessageFromValue(const CelValue& value, + BoolValue* wrapper) { + bool val; + if (!value.GetValue(&val)) { + return {}; + } + wrapper->set_value(val); + return wrapper; +} + +absl::optional MessageFromValue(const CelValue& value, + BytesValue* wrapper) { + CelValue::BytesHolder view_val; + if (!value.GetValue(&view_val)) { + return {}; + } + wrapper->set_value(view_val.value().data()); + return wrapper; +} + +absl::optional MessageFromValue(const CelValue& value, + DoubleValue* wrapper) { + double val; + if (!value.GetValue(&val)) { + return {}; + } + wrapper->set_value(val); + return wrapper; +} + +absl::optional MessageFromValue(const CelValue& value, + FloatValue* wrapper) { + double val; + if (!value.GetValue(&val)) { + return {}; + } + // Abort the conversion if the value is outside the float range. + if (val > std::numeric_limits::max()) { + wrapper->set_value(std::numeric_limits::infinity()); + return wrapper; + } + if (val < std::numeric_limits::lowest()) { + wrapper->set_value(-std::numeric_limits::infinity()); + return wrapper; + } + wrapper->set_value(val); + return wrapper; +} + +absl::optional MessageFromValue(const CelValue& value, + Int32Value* wrapper) { + int64_t val; + if (!value.GetValue(&val)) { + return {}; + } + // Abort the conversion if the value is outside the int32_t range. + if (val > std::numeric_limits::max() || + val < std::numeric_limits::lowest()) { + return {}; + } + wrapper->set_value(val); + return wrapper; +} + +absl::optional MessageFromValue(const CelValue& value, + Int64Value* wrapper) { + int64_t val; + if (!value.GetValue(&val)) { + return {}; + } + wrapper->set_value(val); + return wrapper; +} + +absl::optional MessageFromValue(const CelValue& value, + StringValue* wrapper) { + CelValue::StringHolder view_val; + if (!value.GetValue(&view_val)) { + return {}; + } + wrapper->set_value(view_val.value().data()); + return wrapper; +} + +absl::optional MessageFromValue(const CelValue& value, + Timestamp* timestamp) { + absl::Time val; + if (!value.GetValue(&val)) { + return {}; + } + auto status = google::api::expr::internal::EncodeTime(val, timestamp); + if (!status.ok()) { + return {}; + } + return timestamp; +} + +absl::optional MessageFromValue(const CelValue& value, + UInt32Value* wrapper) { + uint64_t val; + if (!value.GetValue(&val)) { + return {}; + } + // Abort the conversion if the value is outside the uint32_t range. + if (val > std::numeric_limits::max()) { + return {}; + } + wrapper->set_value(val); + return wrapper; +} + +absl::optional MessageFromValue(const CelValue& value, + UInt64Value* wrapper) { + uint64_t val; + if (!value.GetValue(&val)) { + return {}; + } + wrapper->set_value(val); + return wrapper; +} + +absl::optional MessageFromValue(const CelValue& value, + ListValue* json_list) { + if (!value.IsList()) { + return {}; + } + const CelList& list = *value.ListOrDie(); + for (int i = 0; i < list.size(); i++) { + auto e = list[i]; + Value* elem = json_list->add_values(); + auto result = MessageFromValue(e, elem); + if (!result.has_value()) { + return {}; + } + } + return json_list; +} + +absl::optional MessageFromValue(const CelValue& value, + Struct* json_struct) { + if (!value.IsMap()) { + return {}; + } + const CelMap& map = *value.MapOrDie(); + const auto& keys = *map.ListKeys(); + auto fields = json_struct->mutable_fields(); + for (int i = 0; i < keys.size(); i++) { + auto k = keys[i]; + // If the key is not a string type, abort the conversion. + if (!k.IsString()) { + return {}; + } + absl::string_view key = k.StringOrDie().value(); + + auto v = map[k]; + if (!v.has_value()) { + return {}; + } + Value field_value; + auto result = MessageFromValue(v.value(), &field_value); + // If the value is not a valid JSON type, abort the conversion. + if (!result.has_value()) { + return {}; + } + (*fields)[key] = field_value; + } + return json_struct; +} + +absl::optional MessageFromValue(const CelValue& value, + Value* json) { + switch (value.type()) { + case CelValue::Type::kBool: { + bool val; + if (value.GetValue(&val)) { + json->set_bool_value(val); + return json; + } + } break; + case CelValue::Type::kBytes: { + // Base64 encode byte strings to ensure they can safely be transpored + // in a JSON string. + CelValue::BytesHolder val; + if (value.GetValue(&val)) { + json->set_string_value(absl::Base64Escape(val.value())); + return json; + } + } break; + case CelValue::Type::kDouble: { + double val; + if (value.GetValue(&val)) { + json->set_number_value(val); + return json; + } + } break; + case CelValue::Type::kDuration: { + // Convert duration values to a protobuf JSON format. + absl::Duration val; + if (value.GetValue(&val)) { + auto encode = google::api::expr::internal::EncodeDurationToString(val); + if (!encode.ok()) { + return {}; + } + json->set_string_value(*encode); + return json; + } + } break; + case CelValue::Type::kInt64: { + int64_t val; + // Convert int64_t values within the int53 range to doubles, otherwise + // serialize the value to a string. + if (value.GetValue(&val)) { + if (IsJSONSafe(val)) { + json->set_number_value(val); + } else { + json->set_string_value(absl::StrCat(val)); + } + return json; + } + } break; + case CelValue::Type::kString: { + CelValue::StringHolder val; + if (value.GetValue(&val)) { + json->set_string_value(val.value().data()); + return json; + } + } break; + case CelValue::Type::kTimestamp: { + // Convert timestamp values to a protobuf JSON format. + absl::Time val; + if (value.GetValue(&val)) { + auto encode = google::api::expr::internal::EncodeTimeToString(val); + if (!encode.ok()) { + return {}; + } + json->set_string_value(*encode); + return json; + } + } break; + case CelValue::Type::kUint64: { + uint64_t val; + // Convert uint64_t values within the int53 range to doubles, otherwise + // serialize the value to a string. + if (value.GetValue(&val)) { + if (IsJSONSafe(val)) { + json->set_number_value(val); + } else { + json->set_string_value(absl::StrCat(val)); + } + return json; + } + } break; + case CelValue::Type::kList: { + auto lv = MessageFromValue(value, json->mutable_list_value()); + if (lv.has_value()) { + return json; + } + } break; + case CelValue::Type::kMap: { + auto sv = MessageFromValue(value, json->mutable_struct_value()); + if (sv.has_value()) { + return json; + } + } break; + default: + if (value.IsNull()) { + json->set_null_value(protobuf::NULL_VALUE); + return json; + } + return {}; + } + return {}; +} + +absl::optional MessageFromValue(const CelValue& value, + Any* any) { + // In open source, any->PackFrom() returns void rather than boolean. + switch (value.type()) { + case CelValue::Type::kBool: { + BoolValue v; + auto msg = MessageFromValue(value, &v); + if (msg.has_value()) { + any->PackFrom(**msg); + return any; + } + } break; + case CelValue::Type::kBytes: { + BytesValue v; + auto msg = MessageFromValue(value, &v); + if (msg.has_value()) { + any->PackFrom(**msg); + return any; + } + } break; + case CelValue::Type::kDouble: { + DoubleValue v; + auto msg = MessageFromValue(value, &v); + if (msg.has_value()) { + any->PackFrom(**msg); + return any; + } + } break; + case CelValue::Type::kDuration: { + Duration v; + auto msg = MessageFromValue(value, &v); + if (msg.has_value()) { + any->PackFrom(**msg); + return any; + } + } break; + case CelValue::Type::kInt64: { + Int64Value v; + auto msg = MessageFromValue(value, &v); + if (msg.has_value()) { + any->PackFrom(**msg); + return any; + } + } break; + case CelValue::Type::kString: { + StringValue v; + auto msg = MessageFromValue(value, &v); + if (msg.has_value()) { + any->PackFrom(**msg); + return any; + } + } break; + case CelValue::Type::kTimestamp: { + Timestamp v; + auto msg = MessageFromValue(value, &v); + if (msg.has_value()) { + any->PackFrom(**msg); + return any; + } + } break; + case CelValue::Type::kUint64: { + UInt64Value v; + auto msg = MessageFromValue(value, &v); + if (msg.has_value()) { + any->PackFrom(**msg); + return any; + } + } break; + case CelValue::Type::kList: { + ListValue v; + auto msg = MessageFromValue(value, &v); + if (msg.has_value()) { + any->PackFrom(**msg); + return any; + } + } break; + case CelValue::Type::kMap: { + Struct v; + auto msg = MessageFromValue(value, &v); + if (msg.has_value()) { + any->PackFrom(**msg); + return any; + } + } break; + case CelValue::Type::kMessage: { + if (value.IsNull()) { + Value v; + auto msg = MessageFromValue(value, &v); + if (msg.has_value()) { + any->PackFrom(**msg); + return any; + } + } else { + any->PackFrom(*(value.MessageOrDie())); + return any; + } + } break; + default: + break; + } + return {}; +} + +// Factory class, responsible for populating a Message type instance with the +// value of a simple CelValue. +class MessageFromValueFactory { + public: + virtual ~MessageFromValueFactory() {} + virtual const google::protobuf::Descriptor* GetDescriptor() const = 0; + virtual absl::optional WrapMessage( + const CelValue& value, Arena* arena) const = 0; +}; + +// This template class has a good performance, but performes downcast +// operations on google::protobuf::Message pointers. +template +class CastingMessageFromValueFactory : public MessageFromValueFactory { + public: + const google::protobuf::Descriptor* GetDescriptor() const override { + return MessageType::descriptor(); + } + + absl::optional WrapMessage( + const CelValue& value, Arena* arena) const override { + // Convert nulls separately from other messages as a null value is still + // technically a message value, but not one that can be converted in the + // standard way. + if (value.IsNull()) { + return MessageFromValue(value, Arena::CreateMessage(arena)); + } + // If the value is a message type, see if it is already of the proper type + // name, and return it directly. + if (value.IsMessage()) { + const auto* msg = value.MessageOrDie(); + if (MessageType::descriptor() == msg->GetDescriptor()) { + return {}; + } + } + // Otherwise, allocate an empty message type, and attempt to populate it + // using the proper MessageFromValue overload. + auto* msg_buffer = Arena::CreateMessage(arena); + return MessageFromValue(value, msg_buffer); + } +}; + +// MessageFromValueMaker makes a specific protobuf Message instance based on +// the desired protobuf type name and an input CelValue. +// +// It holds a registry of CelValue factories for specific subtypes of Message. +// If message does not match any of types stored in registry, an the factory +// returns an absent value. +class MessageFromValueMaker { + public: + explicit MessageFromValueMaker() { + Add(absl::make_unique>()); + Add(absl::make_unique>()); + Add(absl::make_unique>()); + Add(absl::make_unique>()); + Add(absl::make_unique>()); + Add(absl::make_unique>()); + Add(absl::make_unique>()); + Add(absl::make_unique>()); + Add(absl::make_unique>()); + Add(absl::make_unique>()); + Add(absl::make_unique>()); + Add(absl::make_unique>()); + Add(absl::make_unique>()); + Add(absl::make_unique>()); + Add(absl::make_unique>()); + } + // Non-copyable, non-assignable + MessageFromValueMaker(const MessageFromValueMaker&) = delete; + MessageFromValueMaker& operator=(const MessageFromValueMaker&) = delete; + + absl::optional MaybeWrapMessage( + absl::string_view type_name, const CelValue& value, Arena* arena) const { + auto it = factories_.find(type_name); + if (it == factories_.end()) { + // Descriptor not found for type name. + return {}; + } + return (it->second)->WrapMessage(value, arena); + } + + private: + void Add(std::unique_ptr factory) { + const Descriptor* desc = factory->GetDescriptor(); + factories_.emplace(desc->full_name(), std::move(factory)); + } + + absl::flat_hash_map> + factories_; +}; + } // namespace // CreateMessage creates CelValue from google::protobuf::Message. @@ -340,10 +851,20 @@ CelValue CelProtoWrapper::CreateMessage(const google::protobuf::Message* value, } auto special_value = maker->CreateValue(value, arena); - return special_value.has_value() ? special_value.value() : CelValue(value); } +absl::optional CelProtoWrapper::MaybeWrapValue( + absl::string_view type_name, const CelValue& value, Arena* arena) { + static const MessageFromValueMaker* maker = new MessageFromValueMaker(); + + auto msg = maker->MaybeWrapMessage(type_name, value, arena); + if (!msg.has_value()) { + return {}; + } + return CelValue(msg.value()); +} + } // namespace runtime } // namespace expr } // namespace api diff --git a/eval/public/structs/cel_proto_wrapper.h b/eval/public/structs/cel_proto_wrapper.h index 830bfd67f..379f20b4e 100644 --- a/eval/public/structs/cel_proto_wrapper.h +++ b/eval/public/structs/cel_proto_wrapper.h @@ -27,6 +27,19 @@ class CelProtoWrapper { static CelValue CreateTimestamp(const google::protobuf::Timestamp *value) { return CelValue(expr::internal::DecodeTime(*value)); } + + // MaybeWrapValue attempts to wrap the input value in a proto message with + // the given type_name. If the value can be wrapped, it is returned as a + // CelValue pointing to the protobuf message. Otherwise, the result will be + // empty. + // + // This method is the complement to CreateMessage which may unwrap a protobuf + // message to native CelValue representation during a protobuf field read. + // Just as CreateMessage should only be used when reading protobuf values, + // MaybeWrapValue should only be used when assigning protobuf fields. + static absl::optional MaybeWrapValue(absl::string_view type_name, + const CelValue &value, + google::protobuf::Arena *arena); }; } // namespace runtime diff --git a/eval/public/structs/cel_proto_wrapper_test.cc b/eval/public/structs/cel_proto_wrapper_test.cc index 4e797589a..a93f28b61 100644 --- a/eval/public/structs/cel_proto_wrapper_test.cc +++ b/eval/public/structs/cel_proto_wrapper_test.cc @@ -1,6 +1,10 @@ #include "eval/public/structs/cel_proto_wrapper.h" +#include +#include + #include "google/protobuf/any.pb.h" +#include "google/protobuf/duration.pb.h" #include "google/protobuf/empty.pb.h" #include "google/protobuf/struct.pb.h" #include "google/protobuf/wrappers.pb.h" @@ -9,6 +13,11 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" #include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/time/time.h" +#include "eval/public/cel_value.h" +#include "eval/public/containers/container_backed_list_impl.h" +#include "eval/public/containers/container_backed_map_impl.h" #include "eval/testutil/test_message.pb.h" #include "internal/proto_util.h" #include "testutil/util.h" @@ -18,6 +27,8 @@ namespace api { namespace expr { namespace runtime { +namespace { + using testing::Eq; using testing::UnorderedPointwise; @@ -38,9 +49,84 @@ using google::protobuf::StringValue; using google::protobuf::UInt32Value; using google::protobuf::UInt64Value; -TEST(CelProtoWrapperTest, TestType) { - ::google::protobuf::Arena arena; +using google::protobuf::Arena; + +class CelProtoWrapperTest : public ::testing::Test { + protected: + CelProtoWrapperTest() {} + + void ExpectWrappedMessage(const CelValue& value, + const google::protobuf::Message& message) { + // Test the input value wraps to the destination message type. + std::string type_name = message.GetTypeName(); + auto result = CelProtoWrapper::MaybeWrapValue(type_name, value, arena()); + EXPECT_TRUE(result.has_value()); + EXPECT_TRUE((*result).IsMessage()); + EXPECT_THAT((*result).MessageOrDie(), testutil::EqualsProto(message)); + + // Ensure that double wrapping results in the object being wrapped once. + auto identity = + CelProtoWrapper::MaybeWrapValue(type_name, *result, arena()); + EXPECT_FALSE(identity.has_value()); + + // Check to make sure that even dynamic messages can be used as input to + // the wrapping call. + result = CelProtoWrapper::MaybeWrapValue( + ReflectedCopy(message)->GetTypeName(), value, arena()); + EXPECT_TRUE(result.has_value()); + EXPECT_TRUE((*result).IsMessage()); + EXPECT_THAT((*result).MessageOrDie(), testutil::EqualsProto(message)); + } + + void ExpectNotWrapped(const CelValue& value, const google::protobuf::Message& message) { + // Test the input value does not wrap by asserting value == result. + auto result = + CelProtoWrapper::MaybeWrapValue(message.GetTypeName(), value, arena()); + EXPECT_FALSE(result.has_value()); + } + + template + void ExpectUnwrappedPrimitive(const google::protobuf::Message& message, T result) { + CelValue cel_value = CelProtoWrapper::CreateMessage(&message, arena()); + T value; + EXPECT_TRUE(cel_value.GetValue(&value)); + EXPECT_THAT(value, Eq(result)); + + T dyn_value; + CelValue cel_dyn_value = + CelProtoWrapper::CreateMessage(ReflectedCopy(message).get(), arena()); + EXPECT_THAT(cel_dyn_value.type(), Eq(cel_value.type())); + EXPECT_TRUE(cel_dyn_value.GetValue(&dyn_value)); + EXPECT_THAT(value, Eq(dyn_value)); + } + + void ExpectUnwrappedMessage(const google::protobuf::Message& message, + google::protobuf::Message* result) { + CelValue cel_value = CelProtoWrapper::CreateMessage(&message, arena()); + if (result == nullptr) { + EXPECT_TRUE(cel_value.IsNull()); + return; + } + EXPECT_TRUE(cel_value.IsMessage()); + EXPECT_THAT(cel_value.MessageOrDie(), testutil::EqualsProto(*result)); + } + std::unique_ptr ReflectedCopy( + const google::protobuf::Message& message) { + std::unique_ptr dynamic_value( + factory_.GetPrototype(message.GetDescriptor())->New()); + dynamic_value->CopyFrom(message); + return dynamic_value; + } + + Arena* arena() { return &arena_; } + + private: + Arena arena_; + google::protobuf::DynamicMessageFactory factory_; +}; + +TEST_F(CelProtoWrapperTest, TestType) { Duration msg_duration; msg_duration.set_seconds(2); msg_duration.set_nanos(3); @@ -48,7 +134,7 @@ TEST(CelProtoWrapperTest, TestType) { EXPECT_THAT(value_duration1.type(), Eq(CelValue::Type::kDuration)); CelValue value_duration2 = - CelProtoWrapper::CreateMessage(&msg_duration, &arena); + CelProtoWrapper::CreateMessage(&msg_duration, arena()); EXPECT_THAT(value_duration2.type(), Eq(CelValue::Type::kDuration)); Timestamp msg_timestamp; @@ -58,14 +144,12 @@ TEST(CelProtoWrapperTest, TestType) { EXPECT_THAT(value_timestamp1.type(), Eq(CelValue::Type::kTimestamp)); CelValue value_timestamp2 = - CelProtoWrapper::CreateMessage(&msg_timestamp, &arena); + CelProtoWrapper::CreateMessage(&msg_timestamp, arena()); EXPECT_THAT(value_timestamp2.type(), Eq(CelValue::Type::kTimestamp)); } // This test verifies CelValue support of Duration type. -TEST(CelProtoWrapperTest, TestDuration) { - google::protobuf::Arena arena; - +TEST_F(CelProtoWrapperTest, TestDuration) { Duration msg_duration; msg_duration.set_seconds(2); msg_duration.set_nanos(3); @@ -73,21 +157,19 @@ TEST(CelProtoWrapperTest, TestDuration) { EXPECT_THAT(value_duration1.type(), Eq(CelValue::Type::kDuration)); CelValue value_duration2 = - CelProtoWrapper::CreateMessage(&msg_duration, &arena); + CelProtoWrapper::CreateMessage(&msg_duration, arena()); EXPECT_THAT(value_duration2.type(), Eq(CelValue::Type::kDuration)); CelValue value = CelProtoWrapper::CreateDuration(&msg_duration); - // CelValue value = CelValue::CreateString("test"); EXPECT_TRUE(value.IsDuration()); Duration out; - expr::internal::EncodeDuration(value.DurationOrDie(), &out); + auto status = expr::internal::EncodeDuration(value.DurationOrDie(), &out); + EXPECT_TRUE(status.ok()); EXPECT_THAT(out, testutil::EqualsProto(msg_duration)); } // This test verifies CelValue support of Timestamp type. -TEST(CelProtoWrapperTest, TestTimestamp) { - google::protobuf::Arena arena; - +TEST_F(CelProtoWrapperTest, TestTimestamp) { Timestamp msg_timestamp; msg_timestamp.set_seconds(2); msg_timestamp.set_nanos(3); @@ -95,70 +177,63 @@ TEST(CelProtoWrapperTest, TestTimestamp) { EXPECT_THAT(value_timestamp1.type(), Eq(CelValue::Type::kTimestamp)); CelValue value_timestamp2 = - CelProtoWrapper::CreateMessage(&msg_timestamp, &arena); + CelProtoWrapper::CreateMessage(&msg_timestamp, arena()); EXPECT_THAT(value_timestamp2.type(), Eq(CelValue::Type::kTimestamp)); CelValue value = CelProtoWrapper::CreateTimestamp(&msg_timestamp); // CelValue value = CelValue::CreateString("test"); EXPECT_TRUE(value.IsTimestamp()); Timestamp out; - expr::internal::EncodeTime(value.TimestampOrDie(), &out); + auto status = expr::internal::EncodeTime(value.TimestampOrDie(), &out); + EXPECT_TRUE(status.ok()); EXPECT_THAT(out, testutil::EqualsProto(msg_timestamp)); } // Dynamic Values test // - -TEST(CelProtoWrapperTest, TestValueFieldNull) { - ::google::protobuf::Arena arena; - - Value value1; - value1.set_null_value(google::protobuf::NullValue::NULL_VALUE); - - CelValue value = CelProtoWrapper::CreateMessage(&value1, &arena); - ASSERT_TRUE(value.IsNull()); +TEST_F(CelProtoWrapperTest, UnwrapValueNull) { + Value json; + json.set_null_value(google::protobuf::NullValue::NULL_VALUE); + ExpectUnwrappedMessage(json, nullptr); } -TEST(CelProtoWrapperTest, TestValueFieldBool) { - ::google::protobuf::Arena arena; +// Test support for unwrapping a google::protobuf::Value to a CEL value. +TEST_F(CelProtoWrapperTest, UnwrapDynamicValueNull) { + Value value_msg; + value_msg.set_null_value(protobuf::NULL_VALUE); - Value value1; - value1.set_bool_value(true); - - CelValue value = CelProtoWrapper::CreateMessage(&value1, &arena); - ASSERT_TRUE(value.IsBool()); - EXPECT_EQ(value.BoolOrDie(), true); + CelValue value = + CelProtoWrapper::CreateMessage(ReflectedCopy(value_msg).get(), arena()); + EXPECT_TRUE(value.IsNull()); } -TEST(CelProtoWrapperTest, TestValueFieldNumeric) { - ::google::protobuf::Arena arena; - - Value value1; - value1.set_number_value(1.0); +TEST_F(CelProtoWrapperTest, UnwrapValueBool) { + bool value = true; - CelValue value = CelProtoWrapper::CreateMessage(&value1, &arena); - ASSERT_TRUE(value.IsDouble()); - EXPECT_DOUBLE_EQ(value.DoubleOrDie(), 1.0); + Value json; + json.set_bool_value(true); + ExpectUnwrappedPrimitive(json, value); } -TEST(CelProtoWrapperTest, TestValueFieldString) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, UnwrapValueNumber) { + double value = 1.0; - const std::string kTest = "test"; + Value json; + json.set_number_value(value); + ExpectUnwrappedPrimitive(json, value); +} - Value value1; - value1.set_string_value(kTest); +TEST_F(CelProtoWrapperTest, UnwrapValueString) { + const std::string test = "test"; + auto value = CelValue::StringHolder(&test); - CelValue value = CelProtoWrapper::CreateMessage(&value1, &arena); - ASSERT_TRUE(value.IsString()); - EXPECT_EQ(value.StringOrDie().value(), kTest); + Value json; + json.set_string_value(test); + ExpectUnwrappedPrimitive(json, value); } -TEST(CelProtoWrapperTest, TestValueFieldStruct) { - ::google::protobuf::Arena arena; - +TEST_F(CelProtoWrapperTest, UnwrapValueStruct) { const std::vector kFields = {"field1", "field2", "field3"}; - Struct value_struct; auto& value1 = (*value_struct.mutable_fields())[kFields[0]]; @@ -170,7 +245,7 @@ TEST(CelProtoWrapperTest, TestValueFieldStruct) { auto& value3 = (*value_struct.mutable_fields())[kFields[2]]; value3.set_string_value("test"); - CelValue value = CelProtoWrapper::CreateMessage(&value_struct, &arena); + CelValue value = CelProtoWrapper::CreateMessage(&value_struct, arena()); ASSERT_TRUE(value.IsMap()); const CelMap* cel_map = value.MapOrDie(); @@ -203,9 +278,55 @@ TEST(CelProtoWrapperTest, TestValueFieldStruct) { EXPECT_THAT(result_keys, UnorderedPointwise(Eq(), kFields)); } -TEST(CelProtoWrapperTest, TestListFieldStruct) { - ::google::protobuf::Arena arena; +// Test support for google::protobuf::Struct when it is created as dynamic +// message +TEST_F(CelProtoWrapperTest, UnwrapDynamicStruct) { + Struct struct_msg; + const std::string kFieldInt = "field_int"; + const std::string kFieldBool = "field_bool"; + (*struct_msg.mutable_fields())[kFieldInt].set_number_value(1.); + (*struct_msg.mutable_fields())[kFieldBool].set_bool_value(true); + CelValue value = + CelProtoWrapper::CreateMessage(ReflectedCopy(struct_msg).get(), arena()); + EXPECT_TRUE(value.IsMap()); + const CelMap* cel_map = value.MapOrDie(); + ASSERT_TRUE(cel_map != nullptr); + { + auto lookup = (*cel_map)[CelValue::CreateString(&kFieldInt)]; + ASSERT_TRUE(lookup.has_value()); + auto v = lookup.value(); + ASSERT_TRUE(v.IsDouble()); + EXPECT_THAT(v.DoubleOrDie(), testing::DoubleEq(1.)); + } + { + auto lookup = (*cel_map)[CelValue::CreateString(&kFieldBool)]; + ASSERT_TRUE(lookup.has_value()); + auto v = lookup.value(); + ASSERT_TRUE(v.IsBool()); + EXPECT_EQ(v.BoolOrDie(), true); + } +} + +TEST_F(CelProtoWrapperTest, UnwrapDynamicValueStruct) { + const std::string kField1 = "field1"; + const std::string kField2 = "field2"; + Value value_msg; + (*value_msg.mutable_struct_value()->mutable_fields())[kField1] + .set_number_value(1); + (*value_msg.mutable_struct_value()->mutable_fields())[kField2] + .set_number_value(2); + + CelValue value = + CelProtoWrapper::CreateMessage(ReflectedCopy(value_msg).get(), arena()); + EXPECT_TRUE(value.IsMap()); + EXPECT_TRUE( + (*value.MapOrDie())[CelValue::CreateString(&kField1)].has_value()); + EXPECT_TRUE( + (*value.MapOrDie())[CelValue::CreateString(&kField2)].has_value()); +} + +TEST_F(CelProtoWrapperTest, UnwrapValueList) { const std::vector kFields = {"field1", "field2", "field3"}; ListValue list_value; @@ -214,7 +335,7 @@ TEST(CelProtoWrapperTest, TestListFieldStruct) { list_value.add_values()->set_number_value(1.0); list_value.add_values()->set_string_value("test"); - CelValue value = CelProtoWrapper::CreateMessage(&list_value, &arena); + CelValue value = CelProtoWrapper::CreateMessage(&list_value, arena()); ASSERT_TRUE(value.IsList()); const CelList* cel_list = value.ListOrDie(); @@ -234,437 +355,471 @@ TEST(CelProtoWrapperTest, TestListFieldStruct) { EXPECT_EQ(value3.StringOrDie().value(), "test"); } -// Test support of google.protobuf.Any in CelValue. -TEST(CelProtoWrapperTest, TestAnyValue) { - ::google::protobuf::Arena arena; - Any any; +TEST_F(CelProtoWrapperTest, UnwrapDynamicValueListValue) { + Value value_msg; + value_msg.mutable_list_value()->add_values()->set_number_value(1.); + value_msg.mutable_list_value()->add_values()->set_number_value(2.); + + CelValue value = + CelProtoWrapper::CreateMessage(ReflectedCopy(value_msg).get(), arena()); + EXPECT_TRUE(value.IsList()); + EXPECT_THAT((*value.ListOrDie())[0].DoubleOrDie(), testing::DoubleEq(1)); + EXPECT_THAT((*value.ListOrDie())[1].DoubleOrDie(), testing::DoubleEq(2)); +} +// Test support of google.protobuf.Any in CelValue. +TEST_F(CelProtoWrapperTest, UnwrapAnyValue) { TestMessage test_message; test_message.set_string_value("test"); + Any any; any.PackFrom(test_message); - - CelValue value = CelProtoWrapper::CreateMessage(&any, &arena); - ASSERT_TRUE(value.IsMessage()); - - const google::protobuf::Message* unpacked_message = value.MessageOrDie(); - EXPECT_THAT(test_message, testutil::EqualsProto(*unpacked_message)); + ExpectUnwrappedMessage(any, &test_message); } -TEST(CelProtoWrapperTest, TestHandlingInvalidAnyValue) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, UnwrapInvalidAny) { Any any; - - CelValue value = CelProtoWrapper::CreateMessage(&any, &arena); + CelValue value = CelProtoWrapper::CreateMessage(&any, arena()); ASSERT_TRUE(value.IsError()); any.set_type_url("/"); - ASSERT_TRUE(CelProtoWrapper::CreateMessage(&any, &arena).IsError()); + ASSERT_TRUE(CelProtoWrapper::CreateMessage(&any, arena()).IsError()); any.set_type_url("/invalid.proto.name"); - ASSERT_TRUE(CelProtoWrapper::CreateMessage(&any, &arena).IsError()); + ASSERT_TRUE(CelProtoWrapper::CreateMessage(&any, arena()).IsError()); } // Test support of google.protobuf.Value wrappers in CelValue. -TEST(CelProtoWrapperTest, TestBoolWrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, UnwrapBoolWrapper) { + bool value = true; BoolValue wrapper; - wrapper.set_value(true); - - CelValue value = CelProtoWrapper::CreateMessage(&wrapper, &arena); - ASSERT_TRUE(value.IsBool()); - - EXPECT_EQ(value.BoolOrDie(), wrapper.value()); + wrapper.set_value(value); + ExpectUnwrappedPrimitive(wrapper, value); } -TEST(CelProtoWrapperTest, TestInt32Wrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, UnwrapInt32Wrapper) { + int64_t value = 12; Int32Value wrapper; - wrapper.set_value(12); - - CelValue value = CelProtoWrapper::CreateMessage(&wrapper, &arena); - ASSERT_TRUE(value.IsInt64()); - - EXPECT_EQ(value.Int64OrDie(), wrapper.value()); + wrapper.set_value(value); + ExpectUnwrappedPrimitive(wrapper, value); } -TEST(CelProtoWrapperTest, TestUInt32Wrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, UnwrapUInt32Wrapper) { + uint64_t value = 12; UInt32Value wrapper; - wrapper.set_value(12); - - CelValue value = CelProtoWrapper::CreateMessage(&wrapper, &arena); - ASSERT_TRUE(value.IsUint64()); - - EXPECT_EQ(value.Uint64OrDie(), wrapper.value()); + wrapper.set_value(value); + ExpectUnwrappedPrimitive(wrapper, value); } -TEST(CelProtoWrapperTest, TestInt64Wrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, UnwrapInt64Wrapper) { + int64_t value = 12; Int64Value wrapper; - wrapper.set_value(12); - - CelValue value = CelProtoWrapper::CreateMessage(&wrapper, &arena); - ASSERT_TRUE(value.IsInt64()); - - EXPECT_EQ(value.Int64OrDie(), wrapper.value()); + wrapper.set_value(value); + ExpectUnwrappedPrimitive(wrapper, value); } -TEST(CelProtoWrapperTest, TestUInt64Wrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, UnwrapUInt64Wrapper) { + uint64_t value = 12; UInt64Value wrapper; - wrapper.set_value(12); + wrapper.set_value(value); + ExpectUnwrappedPrimitive(wrapper, value); +} - CelValue value = CelProtoWrapper::CreateMessage(&wrapper, &arena); - ASSERT_TRUE(value.IsUint64()); +TEST_F(CelProtoWrapperTest, UnwrapFloatWrapper) { + double value = 42.5; - EXPECT_EQ(value.Uint64OrDie(), wrapper.value()); + FloatValue wrapper; + wrapper.set_value(value); + ExpectUnwrappedPrimitive(wrapper, value); } -TEST(CelProtoWrapperTest, TestFloatWrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, UnwrapDoubleWrapper) { + double value = 42.5; - FloatValue wrapper; - wrapper.set_value(42); + DoubleValue wrapper; + wrapper.set_value(value); + ExpectUnwrappedPrimitive(wrapper, value); +} - CelValue value = CelProtoWrapper::CreateMessage(&wrapper, &arena); - ASSERT_TRUE(value.IsDouble()); +TEST_F(CelProtoWrapperTest, UnwrapStringWrapper) { + std::string text = "42"; + auto value = CelValue::StringHolder(&text); - EXPECT_DOUBLE_EQ(value.DoubleOrDie(), wrapper.value()); + StringValue wrapper; + wrapper.set_value(text); + ExpectUnwrappedPrimitive(wrapper, value); } -TEST(CelProtoWrapperTest, TestDoubleWrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, UnwrapBytesWrapper) { + std::string text = "42"; + auto value = CelValue::BytesHolder(&text); - DoubleValue wrapper; - wrapper.set_value(42); + BytesValue wrapper; + wrapper.set_value("42"); + ExpectUnwrappedPrimitive(wrapper, value); +} + +TEST_F(CelProtoWrapperTest, WrapNull) { + auto cel_value = CelValue::CreateNull(); - CelValue value = CelProtoWrapper::CreateMessage(&wrapper, &arena); - ASSERT_TRUE(value.IsDouble()); + Value json; + json.set_null_value(protobuf::NULL_VALUE); + ExpectWrappedMessage(cel_value, json); - EXPECT_DOUBLE_EQ(value.DoubleOrDie(), wrapper.value()); + Any any; + any.PackFrom(json); + ExpectWrappedMessage(cel_value, any); } -TEST(CelProtoWrapperTest, TestStringWrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, WrapBool) { + auto cel_value = CelValue::CreateBool(true); - StringValue wrapper; - wrapper.set_value("42"); + Value json; + json.set_bool_value(true); + ExpectWrappedMessage(cel_value, json); - CelValue value = CelProtoWrapper::CreateMessage(&wrapper, &arena); - ASSERT_TRUE(value.IsString()); + BoolValue wrapper; + wrapper.set_value(true); + ExpectWrappedMessage(cel_value, wrapper); - EXPECT_EQ(value.StringOrDie().value(), wrapper.value()); + Any any; + any.PackFrom(wrapper); + ExpectWrappedMessage(cel_value, any); } -TEST(CelProtoWrapperTest, TestBytesWrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, WrapBytes) { + std::string str = "hello world"; + auto cel_value = CelValue::CreateBytes(CelValue::BytesHolder(&str)); BytesValue wrapper; - wrapper.set_value("42"); + wrapper.set_value(str); + ExpectWrappedMessage(cel_value, wrapper); - CelValue value = CelProtoWrapper::CreateMessage(&wrapper, &arena); - ASSERT_TRUE(value.IsBytes()); + Any any; + any.PackFrom(wrapper); + ExpectWrappedMessage(cel_value, any); +} + +TEST_F(CelProtoWrapperTest, WrapBytesToValue) { + std::string str = "hello world"; + auto cel_value = CelValue::CreateBytes(CelValue::BytesHolder(&str)); - EXPECT_EQ(value.BytesOrDie().value(), wrapper.value()); + Value json; + json.set_string_value("aGVsbG8gd29ybGQ="); + ExpectWrappedMessage(cel_value, json); } -// Test support for google::protobuf::Struct when it is created as dynamic -// message -TEST(CelProtoWrapperTest, DynamicStructSupport) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, WrapDuration) { + auto cel_value = CelValue::CreateDuration(absl::Seconds(300)); - google::protobuf::DynamicMessageFactory factory; - { - Struct struct_msg; - - const std::string kFieldInt = "field_int"; - const std::string kFieldBool = "field_bool"; - - (*struct_msg.mutable_fields())[kFieldInt].set_number_value(1.); - (*struct_msg.mutable_fields())[kFieldBool].set_bool_value(true); - std::unique_ptr dynamic_struct( - factory.GetPrototype(Struct::descriptor())->New()); - dynamic_struct->CopyFrom(struct_msg); - CelValue value = - CelProtoWrapper::CreateMessage(dynamic_struct.get(), &arena); - EXPECT_TRUE(value.IsMap()); - const CelMap* cel_map = value.MapOrDie(); - ASSERT_TRUE(cel_map != nullptr); - - { - auto lookup = (*cel_map)[CelValue::CreateString(&kFieldInt)]; - ASSERT_TRUE(lookup.has_value()); - auto v = lookup.value(); - ASSERT_TRUE(v.IsDouble()); - EXPECT_THAT(v.DoubleOrDie(), testing::DoubleEq(1.)); - } - { - auto lookup = (*cel_map)[CelValue::CreateString(&kFieldBool)]; - ASSERT_TRUE(lookup.has_value()); - auto v = lookup.value(); - ASSERT_TRUE(v.IsBool()); - EXPECT_EQ(v.BoolOrDie(), true); - } - } + Duration d; + d.set_seconds(300); + ExpectWrappedMessage(cel_value, d); + + Any any; + any.PackFrom(d); + ExpectWrappedMessage(cel_value, any); } -// Test support for google::protobuf::Value when it is created as dynamic -// message -TEST(CelProtoWrapperTest, DynamicValueSupport) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, WrapDurationToValue) { + auto cel_value = CelValue::CreateDuration(absl::Seconds(300)); - google::protobuf::DynamicMessageFactory factory; - // Null - { - Value value_msg; - value_msg.set_null_value(protobuf::NULL_VALUE); - std::unique_ptr dynamic_value( - factory.GetPrototype(Value::descriptor())->New()); - dynamic_value->CopyFrom(value_msg); - CelValue value = - CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); - EXPECT_TRUE(value.IsNull()); - } - // Boolean - { - Value value_msg; - value_msg.set_bool_value(true); - std::unique_ptr dynamic_value( - factory.GetPrototype(Value::descriptor())->New()); - dynamic_value->CopyFrom(value_msg); - CelValue value = - CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); - EXPECT_TRUE(value.IsBool()); - EXPECT_TRUE(value.BoolOrDie()); - } - // Numeric - { - Value value_msg; - value_msg.set_number_value(1.0); - std::unique_ptr dynamic_value( - factory.GetPrototype(Value::descriptor())->New()); - dynamic_value->CopyFrom(value_msg); - CelValue value = - CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); - EXPECT_TRUE(value.IsDouble()); - EXPECT_THAT(value.DoubleOrDie(), testing::DoubleEq(1.)); - } - // String - { - Value value_msg; - value_msg.set_string_value("test"); - std::unique_ptr dynamic_value( - factory.GetPrototype(Value::descriptor())->New()); - dynamic_value->CopyFrom(value_msg); - CelValue value = - CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); - EXPECT_TRUE(value.IsString()); - EXPECT_THAT(value.StringOrDie().value(), Eq("test")); - } - // List - { - Value value_msg; - value_msg.mutable_list_value()->add_values()->set_number_value(1.); - value_msg.mutable_list_value()->add_values()->set_number_value(2.); - std::unique_ptr dynamic_value( - factory.GetPrototype(Value::descriptor())->New()); - dynamic_value->CopyFrom(value_msg); - CelValue value = - CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); - EXPECT_TRUE(value.IsList()); - EXPECT_THAT((*value.ListOrDie())[0].DoubleOrDie(), testing::DoubleEq(1)); - EXPECT_THAT((*value.ListOrDie())[1].DoubleOrDie(), testing::DoubleEq(2)); - } - // Struct - { - const std::string kField1 = "field1"; - const std::string kField2 = "field2"; - - Value value_msg; - (*value_msg.mutable_struct_value()->mutable_fields())[kField1] - .set_number_value(1); - (*value_msg.mutable_struct_value()->mutable_fields())[kField2] - .set_number_value(2); - std::unique_ptr dynamic_value( - factory.GetPrototype(Value::descriptor())->New()); - dynamic_value->CopyFrom(value_msg); - CelValue value = - CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); - EXPECT_TRUE(value.IsMap()); - EXPECT_TRUE( - (*value.MapOrDie())[CelValue::CreateString(&kField1)].has_value()); - EXPECT_TRUE( - (*value.MapOrDie())[CelValue::CreateString(&kField2)].has_value()); - } + Value json; + json.set_string_value("300s"); + ExpectWrappedMessage(cel_value, json); } -// Test support of google.protobuf.Value wrappers in CelValue. -TEST(CelProtoWrapperTest, DynamicBoolWrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, WrapDouble) { + double num = 1.5; + auto cel_value = CelValue::CreateDouble(num); - BoolValue wrapper; - wrapper.set_value(true); - google::protobuf::DynamicMessageFactory factory; - std::unique_ptr dynamic_value( - factory.GetPrototype(BoolValue::descriptor())->New()); - dynamic_value->CopyFrom(wrapper); + Value json; + json.set_number_value(num); + ExpectWrappedMessage(cel_value, json); - CelValue value = CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); - ASSERT_TRUE(value.IsBool()); + DoubleValue wrapper; + wrapper.set_value(num); + ExpectWrappedMessage(cel_value, wrapper); - EXPECT_EQ(value.BoolOrDie(), wrapper.value()); + Any any; + any.PackFrom(wrapper); + ExpectWrappedMessage(cel_value, any); } -TEST(CelProtoWrapperTest, DynamicInt32Wrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, WrapDoubleToFloatValue) { + double num = 1.5; + auto cel_value = CelValue::CreateDouble(num); - Int32Value wrapper; - wrapper.set_value(12); + FloatValue wrapper; + wrapper.set_value(num); + ExpectWrappedMessage(cel_value, wrapper); - google::protobuf::DynamicMessageFactory factory; - std::unique_ptr dynamic_value( - factory.GetPrototype(wrapper.descriptor())->New()); - dynamic_value->CopyFrom(wrapper); + // Imprecise double -> float representation results in truncation. + double small_num = -9.9e-100; + wrapper.set_value(small_num); + cel_value = CelValue::CreateDouble(small_num); + ExpectWrappedMessage(cel_value, wrapper); +} + +TEST_F(CelProtoWrapperTest, WrapDoubleOverflow) { + double lowest_double = std::numeric_limits::lowest(); + auto cel_value = CelValue::CreateDouble(lowest_double); - CelValue value = CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); + // Double exceeds float precision, overflow to -infinity. + FloatValue wrapper; + wrapper.set_value(-std::numeric_limits::infinity()); + ExpectWrappedMessage(cel_value, wrapper); - ASSERT_TRUE(value.IsInt64()); + double max_double = std::numeric_limits::max(); + cel_value = CelValue::CreateDouble(max_double); - EXPECT_EQ(value.Int64OrDie(), wrapper.value()); + wrapper.set_value(std::numeric_limits::infinity()); + ExpectWrappedMessage(cel_value, wrapper); } -TEST(CelProtoWrapperTest, DynamicUInt32Wrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, WrapInt64) { + int32_t num = std::numeric_limits::lowest(); + auto cel_value = CelValue::CreateInt64(num); - UInt32Value wrapper; - wrapper.set_value(12); + Value json; + json.set_number_value(static_cast(num)); + ExpectWrappedMessage(cel_value, json); - google::protobuf::DynamicMessageFactory factory; - std::unique_ptr dynamic_value( - factory.GetPrototype(wrapper.descriptor())->New()); - dynamic_value->CopyFrom(wrapper); - CelValue value = CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); + Int64Value wrapper; + wrapper.set_value(num); + ExpectWrappedMessage(cel_value, wrapper); - ASSERT_TRUE(value.IsUint64()); - EXPECT_EQ(value.Uint64OrDie(), wrapper.value()); + Any any; + any.PackFrom(wrapper); + ExpectWrappedMessage(cel_value, any); } -TEST(CelProtoWrapperTest, DynamocInt64Wrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, WrapInt64ToInt32Value) { + int32_t num = std::numeric_limits::lowest(); + auto cel_value = CelValue::CreateInt64(num); - Int64Value wrapper; - wrapper.set_value(12); + Int32Value wrapper; + wrapper.set_value(num); + ExpectWrappedMessage(cel_value, wrapper); +} - google::protobuf::DynamicMessageFactory factory; - std::unique_ptr dynamic_value( - factory.GetPrototype(wrapper.descriptor())->New()); - dynamic_value->CopyFrom(wrapper); - CelValue value = CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); +TEST_F(CelProtoWrapperTest, WrapFailureInt64ToInt32Value) { + int64_t num = std::numeric_limits::lowest(); + auto cel_value = CelValue::CreateInt64(num); - EXPECT_EQ(value.Int64OrDie(), wrapper.value()); + Int32Value wrapper; + ExpectNotWrapped(cel_value, wrapper); } -TEST(CelProtoWrapperTest, DynamicUInt64Wrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, WrapInt64ToValue) { + int64_t max = std::numeric_limits::max(); + auto cel_value = CelValue::CreateInt64(max); - UInt64Value wrapper; - wrapper.set_value(12); + Value json; + json.set_string_value(absl::StrCat(max)); + ExpectWrappedMessage(cel_value, json); - google::protobuf::DynamicMessageFactory factory; - std::unique_ptr dynamic_value( - factory.GetPrototype(wrapper.descriptor())->New()); - dynamic_value->CopyFrom(wrapper); - CelValue value = CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); - ASSERT_TRUE(value.IsUint64()); + int64_t min = std::numeric_limits::min(); + cel_value = CelValue::CreateInt64(min); - EXPECT_EQ(value.Uint64OrDie(), wrapper.value()); + json.set_string_value(absl::StrCat(min)); + ExpectWrappedMessage(cel_value, json); } -TEST(CelProtoWrapperTest, DynamicFloatWrapper) { - ::google::protobuf::Arena arena; - - FloatValue wrapper; - wrapper.set_value(42); +TEST_F(CelProtoWrapperTest, WrapUint64) { + uint32_t num = std::numeric_limits::max(); + auto cel_value = CelValue::CreateUint64(num); - google::protobuf::DynamicMessageFactory factory; - std::unique_ptr dynamic_value( - factory.GetPrototype(wrapper.descriptor())->New()); - dynamic_value->CopyFrom(wrapper); - CelValue value = CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); + Value json; + json.set_number_value(static_cast(num)); + ExpectWrappedMessage(cel_value, json); - ASSERT_TRUE(value.IsDouble()); + UInt64Value wrapper; + wrapper.set_value(num); + ExpectWrappedMessage(cel_value, wrapper); - EXPECT_DOUBLE_EQ(value.DoubleOrDie(), wrapper.value()); + Any any; + any.PackFrom(wrapper); + ExpectWrappedMessage(cel_value, any); } -TEST(CelProtoWrapperTest, DynamicDoubleWrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, WrapUint64ToUint32Value) { + uint32_t num = std::numeric_limits::max(); + auto cel_value = CelValue::CreateUint64(num); - DoubleValue wrapper; - wrapper.set_value(42); + UInt32Value wrapper; + wrapper.set_value(num); + ExpectWrappedMessage(cel_value, wrapper); +} - google::protobuf::DynamicMessageFactory factory; - std::unique_ptr dynamic_value( - factory.GetPrototype(wrapper.descriptor())->New()); - dynamic_value->CopyFrom(wrapper); - CelValue value = CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); +TEST_F(CelProtoWrapperTest, WrapUint64ToValue) { + uint64_t num = std::numeric_limits::max(); + auto cel_value = CelValue::CreateUint64(num); - ASSERT_TRUE(value.IsDouble()); + Value json; + json.set_string_value(absl::StrCat(num)); + ExpectWrappedMessage(cel_value, json); +} - EXPECT_DOUBLE_EQ(value.DoubleOrDie(), wrapper.value()); +TEST_F(CelProtoWrapperTest, WrapFailureUint64ToUint32Value) { + uint64_t num = std::numeric_limits::max(); + auto cel_value = CelValue::CreateUint64(num); + + UInt32Value wrapper; + ExpectNotWrapped(cel_value, wrapper); } -TEST(CelProtoWrapperTest, DynamicStringWrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, WrapString) { + std::string str = "test"; + auto cel_value = CelValue::CreateString(CelValue::StringHolder(&str)); + + Value json; + json.set_string_value(str); + ExpectWrappedMessage(cel_value, json); StringValue wrapper; - wrapper.set_value("42"); + wrapper.set_value(str); + ExpectWrappedMessage(cel_value, wrapper); - google::protobuf::DynamicMessageFactory factory; - std::unique_ptr dynamic_value( - factory.GetPrototype(wrapper.descriptor())->New()); - dynamic_value->CopyFrom(wrapper); - CelValue value = CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); + Any any; + any.PackFrom(wrapper); + ExpectWrappedMessage(cel_value, any); +} + +TEST_F(CelProtoWrapperTest, WrapTimestamp) { + absl::Time ts = absl::FromUnixSeconds(1615852799); + auto cel_value = CelValue::CreateTimestamp(ts); - ASSERT_TRUE(value.IsString()); + Timestamp t; + t.set_seconds(1615852799); + ExpectWrappedMessage(cel_value, t); - EXPECT_EQ(value.StringOrDie().value(), wrapper.value()); + Any any; + any.PackFrom(t); + ExpectWrappedMessage(cel_value, any); } -TEST(CelProtoWrapperTest, DynamicBytesWrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, WrapTimestampToValue) { + absl::Time ts = absl::FromUnixSeconds(1615852799); + auto cel_value = CelValue::CreateTimestamp(ts); - BytesValue wrapper; - wrapper.set_value("42"); + Value json; + json.set_string_value("2021-03-15T23:59:59Z"); + ExpectWrappedMessage(cel_value, json); +} - google::protobuf::DynamicMessageFactory factory; - std::unique_ptr dynamic_value( - factory.GetPrototype(wrapper.descriptor())->New()); - dynamic_value->CopyFrom(wrapper); - CelValue value = CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); +TEST_F(CelProtoWrapperTest, WrapList) { + std::vector list_elems = { + CelValue::CreateDouble(1.5), + CelValue::CreateInt64(-2L), + }; + ContainerBackedListImpl list(std::move(list_elems)); + auto cel_value = CelValue::CreateList(&list); - ASSERT_TRUE(value.IsBytes()); + Value json; + json.mutable_list_value()->add_values()->set_number_value(1.5); + json.mutable_list_value()->add_values()->set_number_value(-2.); + ExpectWrappedMessage(cel_value, json); + ExpectWrappedMessage(cel_value, json.list_value()); - EXPECT_EQ(value.BytesOrDie().value(), wrapper.value()); + Any any; + any.PackFrom(json.list_value()); + ExpectWrappedMessage(cel_value, any); +} + +TEST_F(CelProtoWrapperTest, WrapFailureListValueBadJSON) { + TestMessage message; + std::vector list_elems = { + CelValue::CreateDouble(1.5), + CelProtoWrapper::CreateMessage(&message, arena()), + }; + ContainerBackedListImpl list(std::move(list_elems)); + auto cel_value = CelValue::CreateList(&list); + + Value json; + ExpectNotWrapped(cel_value, json); +} + +TEST_F(CelProtoWrapperTest, WrapStruct) { + const std::string kField1 = "field1"; + std::vector> args = { + {CelValue::CreateString(CelValue::StringHolder(&kField1)), + CelValue::CreateBool(true)}}; + auto cel_map = CreateContainerBackedMap( + absl::Span>(args.data(), args.size())); + auto cel_value = CelValue::CreateMap(cel_map.get()); + + Value json; + (*json.mutable_struct_value()->mutable_fields())[kField1].set_bool_value( + true); + ExpectWrappedMessage(cel_value, json); + ExpectWrappedMessage(cel_value, json.struct_value()); + + Any any; + any.PackFrom(json.struct_value()); + ExpectWrappedMessage(cel_value, any); +} + +TEST_F(CelProtoWrapperTest, WrapFailureStructBadKeyType) { + std::vector> args = { + {CelValue::CreateInt64(1L), CelValue::CreateBool(true)}}; + auto cel_map = CreateContainerBackedMap( + absl::Span>(args.data(), args.size())); + auto cel_value = CelValue::CreateMap(cel_map.get()); + + Value json; + ExpectNotWrapped(cel_value, json); +} + +TEST_F(CelProtoWrapperTest, WrapFailureStructBadValueType) { + const std::string kField1 = "field1"; + TestMessage bad_value; + std::vector> args = { + {CelValue::CreateString(CelValue::StringHolder(&kField1)), + CelProtoWrapper::CreateMessage(&bad_value, arena())}}; + auto cel_map = CreateContainerBackedMap( + absl::Span>(args.data(), args.size())); + auto cel_value = CelValue::CreateMap(cel_map.get()); + Value json; + ExpectNotWrapped(cel_value, json); +} + +TEST_F(CelProtoWrapperTest, WrapFailureWrongType) { + auto cel_value = CelValue::CreateNull(); + std::vector wrong_types = { + &BoolValue::default_instance(), &BytesValue::default_instance(), + &DoubleValue::default_instance(), &Duration::default_instance(), + &FloatValue::default_instance(), &Int32Value::default_instance(), + &Int64Value::default_instance(), &ListValue::default_instance(), + &StringValue::default_instance(), &Struct::default_instance(), + &Timestamp::default_instance(), &UInt32Value::default_instance(), + &UInt64Value::default_instance(), + }; + for (const auto* wrong_type : wrong_types) { + ExpectNotWrapped(cel_value, *wrong_type); + } +} + +TEST_F(CelProtoWrapperTest, WrapFailureErrorToAny) { + auto cel_value = CreateNoSuchFieldError(arena(), "error_field"); + ExpectNotWrapped(cel_value, Any::default_instance()); } -TEST(CelProtoWrapperTest, DebugString) { +TEST_F(CelProtoWrapperTest, DebugString) { google::protobuf::Empty e; - ::google::protobuf::Arena arena; - EXPECT_EQ(CelProtoWrapper::CreateMessage(&e, &arena).DebugString(), + EXPECT_EQ(CelProtoWrapper::CreateMessage(&e, arena()).DebugString(), "Message: "); ListValue list_value; list_value.add_values()->set_bool_value(true); list_value.add_values()->set_number_value(1.0); list_value.add_values()->set_string_value("test"); - CelValue value = CelProtoWrapper::CreateMessage(&list_value, &arena); + CelValue value = CelProtoWrapper::CreateMessage(&list_value, arena()); EXPECT_EQ(value.DebugString(), "List, size: 3"); Struct value_struct; @@ -675,10 +830,12 @@ TEST(CelProtoWrapperTest, DebugString) { auto& value3 = (*value_struct.mutable_fields())["c"]; value3.set_string_value("test"); - value = CelProtoWrapper::CreateMessage(&value_struct, &arena); + value = CelProtoWrapper::CreateMessage(&value_struct, arena()); EXPECT_EQ(value.DebugString(), "Map, size: 3"); } +} // namespace + } // namespace runtime } // namespace expr } // namespace api diff --git a/eval/public/testing/matchers.cc b/eval/public/testing/matchers.cc index 2315a0127..930df8d77 100644 --- a/eval/public/testing/matchers.cc +++ b/eval/public/testing/matchers.cc @@ -42,7 +42,7 @@ class CelValueEqualImpl : public MatcherInterface { template class CelValueMatcherImpl : public testing::MatcherInterface { public: - CelValueMatcherImpl(testing::Matcher m) + explicit CelValueMatcherImpl(testing::Matcher m) : underlying_type_matcher_(m) {} bool MatchAndExplain(const CelValue& v, testing::MatchResultListener* listener) const override { diff --git a/eval/public/transform_utility.cc b/eval/public/transform_utility.cc index 4b4ebb8ad..7081170c6 100644 --- a/eval/public/transform_utility.cc +++ b/eval/public/transform_utility.cc @@ -44,18 +44,26 @@ absl::Status CelValueToValue(const CelValue& value, Value* result) { break; case CelValue::Type::kDuration: { google::protobuf::Duration duration; - expr::internal::EncodeDuration(value.DurationOrDie(), &duration); + auto status = + expr::internal::EncodeDuration(value.DurationOrDie(), &duration); + if (!status.ok()) { + return status; + } result->mutable_object_value()->PackFrom(duration); break; } case CelValue::Type::kTimestamp: { google::protobuf::Timestamp timestamp; - expr::internal::EncodeTime(value.TimestampOrDie(), ×tamp); + auto status = + expr::internal::EncodeTime(value.TimestampOrDie(), ×tamp); + if (!status.ok()) { + return status; + } result->mutable_object_value()->PackFrom(timestamp); break; } case CelValue::Type::kMessage: - if (value.MessageOrDie() == nullptr) { + if (value.IsNull()) { result->set_null_value(google::protobuf::NullValue::NULL_VALUE); } else { result->mutable_object_value()->PackFrom(*value.MessageOrDie()); diff --git a/eval/public/value_export_util.cc b/eval/public/value_export_util.cc index 89ef53022..d1c9892a4 100644 --- a/eval/public/value_export_util.cc +++ b/eval/public/value_export_util.cc @@ -12,12 +12,8 @@ namespace expr { namespace runtime { using google::protobuf::Duration; -using google::protobuf::ListValue; -using google::protobuf::Struct; using google::protobuf::Timestamp; using google::protobuf::Value; -using google::protobuf::FieldDescriptor; -using google::protobuf::Message; using google::protobuf::util::TimeUtil; absl::Status KeyAsString(const CelValue& value, std::string* key) { @@ -77,13 +73,21 @@ absl::Status ExportAsProtoValue(const CelValue& in_value, Value* out_value) { } case CelValue::Type::kDuration: { Duration duration; - expr::internal::EncodeDuration(in_value.DurationOrDie(), &duration); + auto status = + expr::internal::EncodeDuration(in_value.DurationOrDie(), &duration); + if (!status.ok()) { + return status; + } out_value->set_string_value(TimeUtil::ToString(duration)); break; } case CelValue::Type::kTimestamp: { Timestamp timestamp; - expr::internal::EncodeTime(in_value.TimestampOrDie(), ×tamp); + auto status = + expr::internal::EncodeTime(in_value.TimestampOrDie(), ×tamp); + if (!status.ok()) { + return status; + } out_value->set_string_value(TimeUtil::ToString(timestamp)); break; } diff --git a/eval/tests/BUILD b/eval/tests/BUILD index 10f0fb3c6..53f0206cb 100644 --- a/eval/tests/BUILD +++ b/eval/tests/BUILD @@ -51,6 +51,7 @@ cc_test( "//eval/public:cel_value", "//eval/public/structs:cel_proto_wrapper", "//eval/testutil:test_message_cc_proto", + "//testutil:util", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", "@com_google_googletest//:gtest_main", "@com_google_protobuf//:protobuf", diff --git a/eval/tests/end_to_end_test.cc b/eval/tests/end_to_end_test.cc index f28907da9..0591b6619 100644 --- a/eval/tests/end_to_end_test.cc +++ b/eval/tests/end_to_end_test.cc @@ -1,4 +1,5 @@ #include "google/api/expr/v1alpha1/syntax.pb.h" +#include "google/protobuf/struct.pb.h" #include "google/protobuf/text_format.h" #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -9,6 +10,7 @@ #include "eval/public/cel_value.h" #include "eval/public/structs/cel_proto_wrapper.h" #include "eval/testutil/test_message.pb.h" +#include "testutil/util.h" #include "base/status_macros.h" namespace google { @@ -167,6 +169,56 @@ TEST(EndToEndTest, EmptyStringCompare) { EXPECT_TRUE(result.BoolOrDie()); } +TEST(EndToEndTest, NullLiteral) { + // AST CEL equivalent of "Value{null_value: NullValue.NULL_VALUE}" + constexpr char kExpr0[] = R"( + struct_expr: < + message_name: "Value" + entries: < + field_key: "null_value" + value: < + select_expr: < + operand: < + ident_expr: < + name: "NullValue" + > + > + field: "NULL_VALUE" + > + > + > + > + )"; + + Expr expr; + SourceInfo source_info; + TextFormat::ParseFromString(kExpr0, &expr); + + // Obtain CEL Expression builder. + std::unique_ptr builder = CreateCelExpressionBuilder(); + builder->set_container("google.protobuf"); + + // Builtin registration. + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + + // Create CelExpression from AST (Expr object). + auto cel_expression_status = builder->CreateExpression(&expr, &source_info); + ASSERT_OK(cel_expression_status); + + auto cel_expression = std::move(cel_expression_status.value()); + Activation activation; + Arena arena; + // Run evaluation. + auto eval_status = cel_expression->Evaluate(activation, &arena); + + ASSERT_OK(eval_status); + + google::protobuf::Value null_value; + null_value.set_null_value(protobuf::NULL_VALUE); + CelValue result = eval_status.value(); + ASSERT_TRUE(result.IsNull()); +} + } // namespace } // namespace runtime diff --git a/eval/tests/unknowns_end_to_end_test.cc b/eval/tests/unknowns_end_to_end_test.cc index 672846534..e709669b2 100644 --- a/eval/tests/unknowns_end_to_end_test.cc +++ b/eval/tests/unknowns_end_to_end_test.cc @@ -277,7 +277,7 @@ TEST_F(UnknownsTest, UnknownFunctions) { ASSERT_OK(maybe_response); CelValue response = maybe_response.value(); - ASSERT_TRUE(response.IsUnknownSet()) << response.ErrorOrDie()->ToString(); + ASSERT_TRUE(response.IsUnknownSet()) << *response.ErrorOrDie(); EXPECT_THAT(response.UnknownSetOrDie() ->unknown_function_results() .unknown_function_results(), @@ -302,7 +302,7 @@ TEST_F(UnknownsTest, UnknownsMerge) { ASSERT_OK(maybe_response); CelValue response = maybe_response.value(); - ASSERT_TRUE(response.IsUnknownSet()) << response.ErrorOrDie()->ToString(); + ASSERT_TRUE(response.IsUnknownSet()) << *response.ErrorOrDie(); EXPECT_THAT(response.UnknownSetOrDie() ->unknown_function_results() .unknown_function_results(), @@ -454,7 +454,7 @@ TEST_F(UnknownsCompTest, UnknownsMerge) { ASSERT_OK(eval_status); CelValue response = eval_status.value(); - ASSERT_TRUE(response.IsUnknownSet()) << response.ErrorOrDie()->ToString(); + ASSERT_TRUE(response.IsUnknownSet()) << *response.ErrorOrDie(); EXPECT_THAT(response.UnknownSetOrDie() ->unknown_function_results() .unknown_function_results(), @@ -589,7 +589,7 @@ TEST_F(UnknownsCompCondTest, UnknownConditionReturned) { ASSERT_OK(eval_status); CelValue response = eval_status.value(); - ASSERT_TRUE(response.IsUnknownSet()) << response.ErrorOrDie()->ToString(); + ASSERT_TRUE(response.IsUnknownSet()) << *response.ErrorOrDie(); // The comprehension ends on the first non-bool condition, so we only get one // call captured in the UnknownSet. EXPECT_THAT(response.UnknownSetOrDie() diff --git a/internal/BUILD b/internal/BUILD index 7819be972..072a78592 100644 --- a/internal/BUILD +++ b/internal/BUILD @@ -16,6 +16,7 @@ cc_library( deps = [ ":holder", ":specialize", + "@com_google_absl//absl/meta:type_traits", ], ) @@ -57,10 +58,10 @@ cc_library( "holder.h", ], deps = [ - ":port", ":specialize", ":types", ":visitor_util", + "@com_google_absl//absl/meta:type_traits", ], ) @@ -85,8 +86,8 @@ cc_library( "hash_util.h", ], deps = [ - ":port", ":specialize", + "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@com_google_googleapis//google/rpc:status_cc_proto", @@ -129,7 +130,6 @@ cc_test( ":visitor_util", "//testutil:util", "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", ], ) @@ -174,12 +174,13 @@ cc_library( srcs = ["proto_util.cc"], hdrs = ["proto_util.h"], deps = [ - ":status_util", + "//base:status_macros", "//common:macros", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", - "@com_google_googleapis//google/rpc:status_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -236,11 +237,6 @@ cc_test( ], ) -cc_library( - name = "port", - hdrs = ["port.h"], -) - cc_library( name = "specialize", hdrs = ["specialize.h"], @@ -250,10 +246,10 @@ cc_library( name = "cast", hdrs = ["cast.h"], deps = [ - ":port", ":specialize", ":types", "@com_google_absl//absl/memory", + "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/types:optional", ], ) @@ -273,9 +269,9 @@ cc_library( "types.h", ], deps = [ - ":port", ":specialize", "@com_google_absl//absl/memory", + "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/strings", ], ) diff --git a/internal/cast.h b/internal/cast.h index f5c3349b4..6001ab512 100644 --- a/internal/cast.h +++ b/internal/cast.h @@ -5,8 +5,8 @@ #include #include "absl/memory/memory.h" +#include "absl/meta/type_traits.h" #include "absl/types/optional.h" -#include "internal/port.h" #include "internal/specialize.h" #include "internal/types.h" @@ -63,7 +63,7 @@ struct StaticDownCastHelper> { template struct RepresentableAsHelper { static constexpr bool check(const U&) { - return std::is_same, U>::value; + return std::is_same, U>::value; } }; diff --git a/internal/hash_util.h b/internal/hash_util.h index 655c75c1c..dda25586f 100644 --- a/internal/hash_util.h +++ b/internal/hash_util.h @@ -6,9 +6,9 @@ #include "google/protobuf/any.pb.h" #include "google/rpc/status.pb.h" +#include "absl/meta/type_traits.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" -#include "internal/port.h" #include "internal/specialize.h" namespace google { diff --git a/internal/holder.h b/internal/holder.h index 76833aef1..8e6370c9c 100644 --- a/internal/holder.h +++ b/internal/holder.h @@ -4,7 +4,7 @@ #include #include -#include "internal/port.h" +#include "absl/meta/type_traits.h" #include "internal/specialize.h" #include "internal/types.h" #include "internal/visitor_util.h" @@ -69,7 +69,7 @@ struct Copy : BaseHolderPolicy { constexpr static const bool kOwnsValue = true; template - using ValueType = remove_const_t; + using ValueType = absl::remove_const_t; template static T& get(T& value) { diff --git a/internal/holder_test.cc b/internal/holder_test.cc index ac06d043c..4cbf51331 100644 --- a/internal/holder_test.cc +++ b/internal/holder_test.cc @@ -149,7 +149,7 @@ TEST(Holder, UnownedPtr) { EXPECT_TRUE(std::is_move_assignable::value); // Null cannot be accessed. - HolderType holder(static_cast(0)); + HolderType holder(static_cast(nullptr)); #ifndef NDEBUG // Assert only throws when debugging. EXPECT_DEATH(holder.value(), "null"); holder = HolderType(nullptr); @@ -185,7 +185,7 @@ TEST(Holder, UnownedPtr_const) { EXPECT_TRUE(std::is_move_assignable::value); // Null cannot be accessed. - HolderType holder(static_cast(0)); + HolderType holder(static_cast(nullptr)); #ifndef NDEBUG // Assert only throws when debugging. EXPECT_DEATH(holder.value(), "null"); holder = HolderType(nullptr); diff --git a/internal/port.h b/internal/port.h deleted file mode 100644 index 07473fed7..000000000 --- a/internal/port.h +++ /dev/null @@ -1,66 +0,0 @@ -// This files is a forwarding header for other headers containing various -// portability macros and functions. - -#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_PORT_H_ -#define THIRD_PARTY_CEL_CPP_INTERNAL_PORT_H_ - -#include - -namespace google { -namespace api { -namespace expr { -namespace internal { - -// Back port some helpers. -// Defined in std as of c++14. -template -using decay_t = typename std::decay::type; -template -using enable_if_t = typename std::enable_if::type; -template -using conditional_t = typename std::conditional::type; -template -using remove_const_t = typename std::remove_const::type; -template -using remove_reference_t = typename std::remove_reference::type; -template -using remove_cv_t = typename std::remove_cv::type; -template -using remove_const_t = typename std::remove_const::type; -template -using remove_volatile_t = typename std::remove_volatile::type; - -// Defined in std as of c++17 -template -struct conjunction : std::true_type {}; -template -struct conjunction : T {}; -template -struct conjunction - : std::conditional, T>::type {}; -template -struct disjunction : std::false_type {}; -template -struct disjunction : B1 {}; -template -struct disjunction - : conditional_t> {}; -template -using bool_constant = std::integral_constant; -template -struct negation : bool_constant(B::value)> {}; - -// Defined in std as of c++20 -template -struct remove_cvref { - typedef remove_cv_t> type; -}; -template -using remove_cvref_t = typename remove_cvref::type; - -} // namespace internal -} // namespace expr -} // namespace api -} // namespace google - -#endif // THIRD_PARTY_CEL_CPP_INTERNAL_PORT_H_ diff --git a/internal/proto_util.cc b/internal/proto_util.cc index b2dd0a22b..06eb9c76b 100644 --- a/internal/proto_util.cc +++ b/internal/proto_util.cc @@ -1,10 +1,12 @@ #include "internal/proto_util.h" + #include "google/protobuf/duration.pb.h" #include "google/protobuf/timestamp.pb.h" -#include "google/rpc/status.pb.h" +#include "google/protobuf/util/time_util.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "common/macros.h" -#include "internal/status_util.h" +#include "base/status_macros.h" namespace google { namespace api { @@ -13,30 +15,30 @@ namespace internal { namespace { -google::rpc::Status Validate(absl::Duration duration) { - if (duration < MakeGoogleApiDurationMin()) { - return InvalidArgumentError(absl::StrCat("duration below min")); +absl::Status Validate(absl::Time time) { + if (time < MakeGoogleApiTimeMin()) { + return absl::InvalidArgumentError("time below min"); } - if (duration > MakeGoogleApiDurationMax()) { - return InvalidArgumentError(absl::StrCat("duration above max")); + if (time > MakeGoogleApiTimeMax()) { + return absl::InvalidArgumentError("time above max"); } - return OkStatus(); + return absl::OkStatus(); } -google::rpc::Status Validate(absl::Time time) { - if (time < MakeGoogleApiTimeMin()) { - return InvalidArgumentError(absl::StrCat("time below min")); +} // namespace + +absl::Status ValidateDuration(absl::Duration duration) { + if (duration < MakeGoogleApiDurationMin()) { + return absl::InvalidArgumentError("duration below min"); } - if (time > MakeGoogleApiTimeMax()) { - return InvalidArgumentError(absl::StrCat("time above max")); + if (duration > MakeGoogleApiDurationMax()) { + return absl::InvalidArgumentError("duration above max"); } - return OkStatus(); + return absl::OkStatus(); } -} // namespace - absl::Duration DecodeDuration(const google::protobuf::Duration& proto) { return absl::Seconds(proto.seconds()) + absl::Nanoseconds(proto.nanos()); } @@ -46,24 +48,41 @@ absl::Time DecodeTime(const google::protobuf::Timestamp& proto) { absl::Nanoseconds(proto.nanos()); } -google::rpc::Status EncodeDuration(absl::Duration duration, - google::protobuf::Duration* proto) { - RETURN_IF_STATUS_ERROR(Validate(duration)); +absl::Status EncodeDuration(absl::Duration duration, + google::protobuf::Duration* proto) { + RETURN_IF_ERROR(ValidateDuration(duration)); // s and n may both be negative, per the Duration proto spec. const int64_t s = absl::IDivDuration(duration, absl::Seconds(1), &duration); const int64_t n = absl::IDivDuration(duration, absl::Nanoseconds(1), &duration); proto->set_seconds(s); proto->set_nanos(n); - return OkStatus(); + return absl::OkStatus(); } -google::rpc::Status EncodeTime(absl::Time time, - google::protobuf::Timestamp* proto) { - RETURN_IF_STATUS_ERROR(Validate(time)); +absl::StatusOr EncodeDurationToString(absl::Duration duration) { + google::protobuf::Duration d; + auto status = EncodeDuration(duration, &d); + if (!status.ok()) { + return status; + } + return google::protobuf::util::TimeUtil::ToString(d); +} + +absl::Status EncodeTime(absl::Time time, google::protobuf::Timestamp* proto) { + RETURN_IF_ERROR(Validate(time)); const int64_t s = absl::ToUnixSeconds(time); proto->set_seconds(s); proto->set_nanos((time - absl::FromUnixSeconds(s)) / absl::Nanoseconds(1)); - return OkStatus(); + return absl::OkStatus(); +} + +absl::StatusOr EncodeTimeToString(absl::Time time) { + google::protobuf::Timestamp t; + auto status = EncodeTime(time, &t); + if (!status.ok()) { + return status; + } + return google::protobuf::util::TimeUtil::ToString(t); } } // namespace internal diff --git a/internal/proto_util.h b/internal/proto_util.h index 12534ec50..02e2b92fa 100644 --- a/internal/proto_util.h +++ b/internal/proto_util.h @@ -3,9 +3,10 @@ #include "google/protobuf/duration.pb.h" #include "google/protobuf/timestamp.pb.h" -#include "google/rpc/status.pb.h" #include "google/protobuf/util/message_differencer.h" #include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/time/time.h" namespace google { @@ -20,13 +21,21 @@ struct DefaultProtoEqual { } }; +/** Validate that the duration is in the valid protobuf duration range. */ +absl::Status ValidateDuration(absl::Duration duration); + /** Helper function to encode a duration in a google::protobuf::Duration. */ -google::rpc::Status EncodeDuration(absl::Duration duration, - google::protobuf::Duration* proto); +absl::Status EncodeDuration(absl::Duration duration, + google::protobuf::Duration* proto); + +/** Helper function to encode an absl::Duration to a JSON-formatted string. */ +absl::StatusOr EncodeDurationToString(absl::Duration duration); /** Helper function to encode a time in a google::protobuf::Timestamp. */ -google::rpc::Status EncodeTime(absl::Time time, - google::protobuf::Timestamp* proto); +absl::Status EncodeTime(absl::Time time, google::protobuf::Timestamp* proto); + +/** Helper function to encode an absl::Time to a JSON-formatted string. */ +absl::StatusOr EncodeTimeToString(absl::Time time); /** Helper function to decode a duration from a google::protobuf::Duration. */ absl::Duration DecodeDuration(const google::protobuf::Duration& proto); diff --git a/internal/ref_countable.h b/internal/ref_countable.h index 514cbce88..e455d2c16 100644 --- a/internal/ref_countable.h +++ b/internal/ref_countable.h @@ -6,6 +6,7 @@ #include #include +#include "absl/meta/type_traits.h" #include "internal/holder.h" #include "internal/specialize.h" @@ -240,8 +241,9 @@ using RefCopyHolder = Holder>; // If the size of T is smaller than MAX, it is stored inline, otherwise // it is stored in a heap allocated, reference counted holder. template -using SizeLimitHolder = conditional_t, - Holder>>; +using SizeLimitHolder = + absl::conditional_t, + Holder>>; template void ReffedPtr::reset() { diff --git a/internal/types.h b/internal/types.h index d2bd2a270..55e9f3a38 100644 --- a/internal/types.h +++ b/internal/types.h @@ -7,8 +7,8 @@ #include #include "absl/memory/memory.h" +#include "absl/meta/type_traits.h" #include "absl/strings/string_view.h" -#include "internal/port.h" #include "internal/specialize.h" namespace google { @@ -18,11 +18,11 @@ namespace internal { // Short names for AND, OR and NOT. template -using and_t = conjunction; +using and_t = absl::conjunction; template -using or_t = disjunction; +using or_t = absl::disjunction; template -using not_t = negation; +using not_t = absl::negation; // Holder for a static list of types. template @@ -33,7 +33,7 @@ template using types_size = std::tuple_size; template -using types_cat = decltype(std::tuple_cat(inst_of>()...)); +using types_cat = decltype(std::tuple_cat(inst_of>()...)); // Helper that resolves to the Ith type in T. template @@ -77,7 +77,7 @@ template using type_is = std::is_same; template -using type_not = negation>; +using type_not = absl::negation>; /** * Tests if a type is a raw or smart pointer. @@ -85,39 +85,39 @@ using type_not = negation>; * Any type that defines overloads for * and -> is considered a smart pointer. */ template -struct is_ptr : std::is_pointer> {}; +struct is_ptr : std::is_pointer> {}; template -struct is_ptr::operator*), - decltype(&decay_t::operator->)>> +struct is_ptr::operator*), + decltype(&absl::decay_t::operator->)>> : std::true_type {}; /** * Tests if a type is convertible to absl::string_view. */ template -using is_string = and_t, std::nullptr_t>, +using is_string = and_t, std::nullptr_t>, std::is_convertible>; /** * Tests if a type is a signed integer type. */ template -using is_int = conjunction>, - std::is_signed>>; +using is_int = absl::conjunction>, + std::is_signed>>; /** * Tests if a type is an unsigned integer type. */ template -using is_uint = - conjunction, bool>, std::is_integral>, - std::is_unsigned>>; +using is_uint = absl::conjunction, bool>, + std::is_integral>, + std::is_unsigned>>; /** * Tests if a type is a floating point type. */ template -using is_float = std::is_floating_point>; +using is_float = std::is_floating_point>; template using is_numeric = or_t, is_uint, is_float>; @@ -128,20 +128,21 @@ using is_numeric = or_t, is_uint, is_float>; template struct is_container : public std::false_type {}; template -struct is_container::value_type, - typename remove_cvref_t::iterator>> +struct is_container::value_type, + typename absl::remove_cvref_t::iterator>> : public std::true_type {}; // Maps are containers that also define a "mapped_type". template struct is_map : public std::false_type {}; template -struct is_map::mapped_type>> +struct is_map::mapped_type>> : public is_container {}; // Lists are containers that are not maps. template -using is_list = bool_constant::value && !is_map::value>; +using is_list = std::bool_constant::value && !is_map::value>; // Used to create a compiler error when a specialized function/class is // instantiated with an unsupported type. diff --git a/internal/value_internal.h b/internal/value_internal.h index f9cee1334..c9dae2a83 100644 --- a/internal/value_internal.h +++ b/internal/value_internal.h @@ -54,12 +54,12 @@ class BaseValue { protected: // The return type of `Value::get_if` call for T. template - using GetIfType = - conditional_t::value, const T*, absl::optional>; + using GetIfType = absl::conditional_t::value, const T*, + absl::optional>; // The return type of `Value::get` call for T. template - using GetType = conditional_t::value, const T&, T>; + using GetType = absl::conditional_t::value, const T&, T>; // If the string is 8 bytes, it is assumed to be copy-on-write and is stored // inline. Otherwise it is held in a ref-counted container. @@ -178,18 +178,18 @@ class BaseValue { friend class ValueAdapterTest; template - using NumericValueType = - conditional_t::value, int64_t, - conditional_t::value, uint64_t, double>>; + using NumericValueType = absl::conditional_t< + is_int::value, int64_t, + absl::conditional_t::value, uint64_t, double>>; template - using CustomValueType = conditional_t< + using CustomValueType = absl::conditional_t< std::is_convertible::value, common::Map, - conditional_t::value, - common::List, common::Object>>; + absl::conditional_t::value, + common::List, common::Object>>; template - using HolderType = conditional_t::value, - CopyHolder, RefCopyHolder>; + using HolderType = absl::conditional_t::value, + CopyHolder, RefCopyHolder>; // A ValueData visitor that, for any type in 'Alts', returns the associated // value as T. diff --git a/parser/BUILD b/parser/BUILD index 3f2be39f1..23e382c8b 100644 --- a/parser/BUILD +++ b/parser/BUILD @@ -29,6 +29,7 @@ cc_library( ":visitor", "@antlr4_runtimes//:cpp", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", diff --git a/parser/parser.cc b/parser/parser.cc index 40dce202e..9fb31e118 100644 --- a/parser/parser.cc +++ b/parser/parser.cc @@ -2,6 +2,7 @@ #include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "absl/strings/str_replace.h" #include "absl/types/optional.h" #include "parser/cel_grammar.inc/cel_grammar/CelLexer.h" #include "parser/cel_grammar.inc/cel_grammar/CelParser.h" @@ -31,14 +32,26 @@ using ::cel_grammar::CelParser; namespace { +// Replacements for absl::StrReplaceAll for escaping standard whitespace +// characters. +static constexpr auto kStandardReplacements = + std::array, 3>{ + std::make_pair("\n", "\\n"), + std::make_pair("\r", "\\r"), + std::make_pair("\t", "\\t"), + }; + +static constexpr absl::string_view kSingleQuote = "'"; + // ExprRecursionListener extends the standard ANTLR CelParser to ensure that // recursive entries into the 'expr' rule are limited to a configurable depth so // as to prevent stack overflows. class ExprRecursionListener : public ParseTreeListener { public: - ExprRecursionListener( + explicit ExprRecursionListener( const int max_recursion_depth = kDefaultMaxRecursionDepth) : max_recursion_depth_(max_recursion_depth), recursion_depth_(0) {} + ~ExprRecursionListener() override {} void visitTerminal(TerminalNode* node) override{}; void visitErrorNode(ErrorNode* error) override{}; @@ -70,21 +83,62 @@ void ExprRecursionListener::exitEveryRule(ParserRuleContext* ctx) { } } +class RecoveryLimitErrorStrategy : public antlr4::DefaultErrorStrategy { + public: + explicit RecoveryLimitErrorStrategy( + int recovery_limit = kDefaultErrorRecoveryLimit) + : recovery_limit_(recovery_limit), recovery_attempts_(0) {} + + void recover(antlr4::Parser* recognizer, std::exception_ptr e) override { + checkRecoveryLimit(recognizer); + antlr4::DefaultErrorStrategy::recover(recognizer, e); + } + + antlr4::Token* recoverInline(antlr4::Parser* recognizer) override { + checkRecoveryLimit(recognizer); + return antlr4::DefaultErrorStrategy::recoverInline(recognizer); + } + + protected: + std::string escapeWSAndQuote(const std::string& s) const override { + std::string result; + result.reserve(s.size() + 2); + absl::StrAppend(&result, kSingleQuote, s, kSingleQuote); + absl::StrReplaceAll(kStandardReplacements, &result); + return result; + } + + private: + void checkRecoveryLimit(antlr4::Parser* recognizer) { + if (recovery_attempts_++ >= recovery_limit_) { + std::string too_many_errors = + absl::StrFormat("More than %d parse errors.", recovery_limit_); + recognizer->notifyErrorListeners(too_many_errors); + throw ParseCancellationException(too_many_errors); + } + } + + int recovery_limit_; + int recovery_attempts_; +}; + } // namespace absl::StatusOr Parse(const std::string& expression, const std::string& description, - const int max_recursion_depth) { + int max_recursion_depth, + int error_recovery_limit) { return ParseWithMacros(expression, Macro::AllMacros(), description, - max_recursion_depth); + max_recursion_depth, error_recovery_limit); } absl::StatusOr ParseWithMacros(const std::string& expression, const std::vector& macros, const std::string& description, - const int max_recursion_depth) { - auto result = - EnrichedParse(expression, macros, description, max_recursion_depth); + int max_recursion_depth, + int error_recovery_limit) { + auto result = EnrichedParse(expression, macros, description, + max_recursion_depth, error_recovery_limit); if (result.ok()) { return result->parsed_expr(); } @@ -93,7 +147,8 @@ absl::StatusOr ParseWithMacros(const std::string& expression, absl::StatusOr EnrichedParse( const std::string& expression, const std::vector& macros, - const std::string& description, const int max_recursion_depth) { + const std::string& description, int max_recursion_depth, + int error_recovery_limit) { ANTLRInputStream input(expression); CelLexer lexer(&input); CommonTokenStream tokens(&lexer); @@ -107,21 +162,25 @@ absl::StatusOr EnrichedParse( parser.addErrorListener(&visitor); parser.addParseListener(&listener); - // if we were to ignore errors completely: - // std::shared_ptr error_strategy(new BailErrorStrategy()); - // parser.setErrorHandler(error_strategy); + // Limit the number of error recovery attempts to prevent bad expressions + // from consuming lots of cpu / memory. + std::shared_ptr error_strategy( + new RecoveryLimitErrorStrategy(error_recovery_limit)); + parser.setErrorHandler(error_strategy); CelParser::StartContext* root; try { root = parser.start(); } catch (const ParseCancellationException& e) { + if (visitor.hasErrored()) { + return absl::InvalidArgumentError(visitor.errorMessage()); + } return absl::CancelledError(e.what()); } catch (const std::exception& e) { return absl::AbortedError(e.what()); } Expr expr = visitor.visit(root).as(); - if (visitor.hasErrored()) { return absl::InvalidArgumentError(visitor.errorMessage()); } diff --git a/parser/parser.h b/parser/parser.h index 7227c8fed..70edd5848 100644 --- a/parser/parser.h +++ b/parser/parser.h @@ -12,6 +12,7 @@ namespace api { namespace expr { namespace parser { +constexpr int kDefaultErrorRecoveryLimit = 30; constexpr int kDefaultMaxRecursionDepth = 250; class VerboseParsedExpr { @@ -36,16 +37,19 @@ class VerboseParsedExpr { absl::StatusOr EnrichedParse( const std::string& expression, const std::vector& macros, const std::string& description = "", - int max_recursion_depth = kDefaultMaxRecursionDepth); + int max_recursion_depth = kDefaultMaxRecursionDepth, + int error_recovery_limit = kDefaultErrorRecoveryLimit); absl::StatusOr Parse( const std::string& expression, const std::string& description = "", - int max_recursion_depth = kDefaultMaxRecursionDepth); + int max_recursion_depth = kDefaultMaxRecursionDepth, + int error_recovery_limit = kDefaultErrorRecoveryLimit); absl::StatusOr ParseWithMacros( const std::string& expression, const std::vector& macros, const std::string& description = "", - int max_recursion_depth = kDefaultMaxRecursionDepth); + int max_recursion_depth = kDefaultMaxRecursionDepth, + int error_recovery_limit = kDefaultErrorRecoveryLimit); } // namespace parser } // namespace expr diff --git a/parser/parser_test.cc b/parser/parser_test.cc index 81f8e0324..1b9c3ca2c 100644 --- a/parser/parser_test.cc +++ b/parser/parser_test.cc @@ -423,24 +423,21 @@ std::vector test_cases = { " // Init\n" " 0^#5:int64#,\n" " // LoopCondition\n" - " _<=_(\n" - " __result__^#7:Expr.Ident#,\n" - " 1^#6:int64#\n" - " )^#8:Expr.Call#,\n" + " true^#7:bool#,\n" " // LoopStep\n" " _?_:_(\n" " f^#4:Expr.Ident#,\n" " _+_(\n" - " __result__^#9:Expr.Ident#,\n" + " __result__^#8:Expr.Ident#,\n" " 1^#6:int64#\n" - " )^#10:Expr.Call#,\n" - " __result__^#11:Expr.Ident#\n" - " )^#12:Expr.Call#,\n" + " )^#9:Expr.Call#,\n" + " __result__^#10:Expr.Ident#\n" + " )^#11:Expr.Call#,\n" " // Result\n" " _==_(\n" - " __result__^#13:Expr.Ident#,\n" + " __result__^#12:Expr.Ident#,\n" " 1^#6:int64#\n" - " )^#14:Expr.Call#)^#15:Expr.Comprehension#"}, + " )^#13:Expr.Call#)^#14:Expr.Comprehension#"}, {"m.map(v, f)", "__comprehension__(\n" " // Variable\n" @@ -824,6 +821,14 @@ std::vector test_cases = { "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[['just fine'],[1],[2],[3],[4],[5]]]]]]]" "]]]]]]]]]]]]]]]]]]]]]]]]", "" // parse output not validated as it is too large. + }, + { + "[\n\t\r[\n\t\r[\n\t\r]\n\t\r]\n\t\r", + "", // parse output not validated as it is too large. + "ERROR: :6:3: Syntax error: mismatched input '' expecting " + "{']', ','}\n" + " | \r\n" + " | ..^", }}; class KindAndIdAdorner : public testutil::ExpressionAdorner { @@ -867,7 +872,7 @@ class KindAndIdAdorner : public testutil::ExpressionAdorner { class LocationAdorner : public testutil::ExpressionAdorner { public: - LocationAdorner(const google::api::expr::v1alpha1::SourceInfo& source_info) + explicit LocationAdorner(const google::api::expr::v1alpha1::SourceInfo& source_info) : source_info_(source_info) {} absl::optional> getLocation(int64_t id) const { @@ -972,6 +977,38 @@ TEST_P(ExpressionTest, Parse) { } } +TEST(ExpressionTest, TsanOom) { + Parse( + "[[a([[???[a[[??[a([[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" + "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" + "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" + "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" + "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" + "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" + "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" + "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" + "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" + "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" + "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" + "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" + "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" + "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[???[" + "a([[????") + .IgnoreError(); +} + +TEST(ExpressionTest, ErrorRecoveryLimits) { + auto result = Parse("......", "", kDefaultMaxRecursionDepth, 1); + EXPECT_FALSE(result.ok()); + EXPECT_EQ(result.status().message(), + "ERROR: :1:2: Syntax error: missing IDENTIFIER at '.'\n" + " | ......\n" + " | .^\n" + "ERROR: :1:3: Syntax error: More than 1 parse errors.\n" + " | ......\n" + " | ..^"); +} + INSTANTIATE_TEST_SUITE_P(CelParserTest, ExpressionTest, testing::ValuesIn(test_cases)); diff --git a/parser/source_factory.cc b/parser/source_factory.cc index 3052b2407..e7879f99e 100644 --- a/parser/source_factory.cc +++ b/parser/source_factory.cc @@ -250,8 +250,7 @@ Expr SourceFactory::newQuantifierExprForMacro( Expr zero_expr = newLiteralIntForMacro(macro_id, 0); Expr one_expr = newLiteralIntForMacro(macro_id, 1); init = zero_expr; - condition = newGlobalCallForMacro(macro_id, CelOperator::LESS_EQUALS, - {accu_ident(), one_expr}); + condition = newLiteralBoolForMacro(macro_id, true); step = newGlobalCallForMacro( macro_id, CelOperator::CONDITIONAL, {args[1], diff --git a/protoutil/type_registry.cc b/protoutil/type_registry.cc index ba5c7854a..0bfc15e24 100644 --- a/protoutil/type_registry.cc +++ b/protoutil/type_registry.cc @@ -150,7 +150,8 @@ template class UnrecognizedMessageObject final : public common::Object { public: template - UnrecognizedMessageObject(T&& value) : holder_(std::forward(value)) {} + explicit UnrecognizedMessageObject(T&& value) + : holder_(std::forward(value)) {} common::Value GetMember(absl::string_view name) const override { return common::Value::FromError( diff --git a/testutil/expr_printer.cc b/testutil/expr_printer.cc index a8b5bde9a..8b618ede3 100644 --- a/testutil/expr_printer.cc +++ b/testutil/expr_printer.cc @@ -15,11 +15,11 @@ using ::google::api::expr::v1alpha1::Expr; class EmptyAdorner : public ExpressionAdorner { public: - ~EmptyAdorner() {} + ~EmptyAdorner() override {} - std::string adorn(const Expr& e) const { return ""; } + std::string adorn(const Expr& e) const override { return ""; } - std::string adorn(const Expr::CreateStruct::Entry& e) const { + std::string adorn(const Expr::CreateStruct::Entry& e) const override { return ""; } }; @@ -28,7 +28,7 @@ const EmptyAdorner the_empty_adorner; class Writer { public: - Writer(const ExpressionAdorner& adorner) + explicit Writer(const ExpressionAdorner& adorner) : adorner_(adorner), line_start_(true), indent_(0) {} void appendExpr(const Expr& e) { diff --git a/testutil/test_data_gen.cc b/testutil/test_data_gen.cc index 1701442e2..13c6f9d58 100644 --- a/testutil/test_data_gen.cc +++ b/testutil/test_data_gen.cc @@ -86,7 +86,7 @@ TestData UniqueValues() { expr::internal::MakeGoogleApiDurationMin() + absl::Nanoseconds(1), "min+1"); add_val = NewValue(absl::Nanoseconds(-1), "-1"); - add_val = NewValue(absl::Nanoseconds(0), "0"); + add_val = NewValue(absl::ZeroDuration(), "0"); add_val = NewValue(absl::Nanoseconds(1), "1"); add_val = NewValue( expr::internal::MakeGoogleApiDurationMax() - absl::Nanoseconds(1), diff --git a/testutil/test_data_util.cc b/testutil/test_data_util.cc index 3a1a10928..b18e2d39d 100644 --- a/testutil/test_data_util.cc +++ b/testutil/test_data_util.cc @@ -345,7 +345,7 @@ TestValue NewBytesValue(absl::string_view value, absl::string_view name) { // wrapped values. google::protobuf::BytesValue wrapped_bytes; - wrapped_bytes.set_value(std::string(value)); + wrapped_bytes.set_value(value.data()); *AddProto(&result, "wrapped_bytes")->mutable_wrapped_bytes() = wrapped_bytes; AddProto(&result, "single_any") ->mutable_single_any() @@ -373,7 +373,7 @@ TestValue NewValue(const google::protobuf::Message& value, absl::string_view nam TestValue NewValue(absl::Duration value, absl::string_view name) { google::protobuf::Duration duration; auto status = expr::internal::EncodeDuration(value, &duration); - assert(status.code() == google::rpc::Code::OK); + assert(status.ok()); auto result = NewValue(duration, name); *AddProto(&result, "single_duration")->mutable_single_duration() = duration; return result; @@ -382,7 +382,7 @@ TestValue NewValue(absl::Duration value, absl::string_view name) { TestValue NewValue(absl::Time value, absl::string_view name) { google::protobuf::Timestamp timestamp; auto status = expr::internal::EncodeTime(value, ×tamp); - assert(status.code() == google::rpc::Code::OK); + assert(status.ok()); auto result = NewValue(timestamp, name); *AddProto(&result, "single_timestamp")->mutable_single_timestamp() = timestamp; diff --git a/tools/flatbuffers_backed_impl.cc b/tools/flatbuffers_backed_impl.cc index 55f3e4852..ffeb06044 100644 --- a/tools/flatbuffers_backed_impl.cc +++ b/tools/flatbuffers_backed_impl.cc @@ -34,7 +34,7 @@ class FlatBuffersListImpl : public CelList { class StringListImpl : public CelList { public: - StringListImpl( + explicit StringListImpl( const flatbuffers::Vector>* list) : list_(list) {} int size() const override { return list_ ? list_->size() : 0; } diff --git a/v1beta1/converters.cc b/v1beta1/converters.cc index 97e6b3b8f..22a6e4ff7 100644 --- a/v1beta1/converters.cc +++ b/v1beta1/converters.cc @@ -81,6 +81,16 @@ struct ToExprValue { return true; } + google::rpc::Status StatusToRpcStatus(const absl::Status& value) { + if (value.ok()) { + return OkStatus(); + } + google::rpc::Status error; + error.set_code(static_cast(value.code())); + error.set_message(value.message().data(), value.message().size()); + return error; + } + void EncodeMessage(const google::protobuf::Message& value) { result->mutable_value()->mutable_object_value()->PackFrom(value); } @@ -99,7 +109,7 @@ struct ToExprValue { google::rpc::Status operator()(absl::Duration value) { google::protobuf::Duration duration; - auto status = EncodeDuration(value, &duration); + auto status = StatusToRpcStatus(EncodeDuration(value, &duration)); if (CheckAndEncodeIfError(status)) { EncodeMessage(duration); } @@ -108,7 +118,7 @@ struct ToExprValue { google::rpc::Status operator()(absl::Time value) { google::protobuf::Timestamp time; - auto status = EncodeTime(value, &time); + auto status = StatusToRpcStatus(EncodeTime(value, &time)); if (CheckAndEncodeIfError(status)) { EncodeMessage(time); }