From abf9ea65308819ffd642d373fc56411e9f19e955 Mon Sep 17 00:00:00 2001 From: tswadell Date: Fri, 29 Jan 2021 16:16:44 -0800 Subject: [PATCH 01/23] Remove extant comment PiperOrigin-RevId: 354632109 --- eval/eval/create_struct_step.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/eval/eval/create_struct_step.cc b/eval/eval/create_struct_step.cc index 68f1c727b..5e4247071 100644 --- a/eval/eval/create_struct_step.cc +++ b/eval/eval/create_struct_step.cc @@ -274,8 +274,6 @@ absl::StatusOr> CreateCreateStructStep( const google::api::expr::v1alpha1::Expr::CreateStruct* create_struct_expr, const Descriptor* message_desc, int64_t expr_id) { if (message_desc != nullptr) { - // TODO(issues/92): Support resolving a type name within a container. - // Make message-creating step. std::vector entries; for (const auto& entry : create_struct_expr->entries()) { From 3fd14128ad87700ede8827730f57f9ffcf2ff708 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 3 Feb 2021 06:31:46 -0800 Subject: [PATCH 02/23] Fix handling of not-a-number cases within C++ evaluator. PiperOrigin-RevId: 355383746 --- eval/public/builtin_func_registrar.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/eval/public/builtin_func_registrar.cc b/eval/public/builtin_func_registrar.cc index 86fa2adbe..b2cb4b337 100644 --- a/eval/public/builtin_func_registrar.cc +++ b/eval/public/builtin_func_registrar.cc @@ -1,5 +1,6 @@ #include "eval/public/builtin_func_registrar.h" +#include #include #include @@ -1088,7 +1089,7 @@ absl::Status RegisterIntConversionFunctions(CelFunctionRegistry* registry, builtin::kInt, false, [](Arena* arena, double v) { if ((v > static_cast(kIntMax)) || - (v < static_cast(kIntMin))) { + (v < static_cast(kIntMin)) || std::isnan(v)) { return CreateErrorValue(arena, "double out of int range", absl::StatusCode::kInvalidArgument); } @@ -1208,7 +1209,7 @@ absl::Status RegisterUintConversionFunctions(CelFunctionRegistry* registry, auto status = FunctionAdapter::CreateAndRegister( builtin::kUint, false, [](Arena* arena, double v) { - if ((v > static_cast(kUintMax)) || (v < 0)) { + if ((v > static_cast(kUintMax)) || (v < 0) || std::isnan(v)) { return CreateErrorValue(arena, "double out of uint range", absl::StatusCode::kInvalidArgument); } From c0fcd1c2294511bea76813ffdb64b92c3e70db3a Mon Sep 17 00:00:00 2001 From: tswadell Date: Wed, 10 Feb 2021 14:56:41 -0800 Subject: [PATCH 03/23] Fix list creation issue where error and unknown values considered value list elements. PiperOrigin-RevId: 356833286 --- conformance/BUILD | 2 -- eval/eval/BUILD | 5 +--- eval/eval/create_list_step.cc | 22 +++++++++++------ eval/eval/create_list_step_test.cc | 38 +++++++++++++++++++++++++++++- 4 files changed, 53 insertions(+), 14 deletions(-) diff --git a/conformance/BUILD b/conformance/BUILD index b357cef48..f6234eb4e 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -108,8 +108,6 @@ cc_binary( "--skip_test=string/bytes_concat/left_unit", # TODO(issues/85): The exists one macro should not short-circuit false. "--skip_test=macros/exists_one/list_no_shortcircuit", - # TODO(issues/86): Map macro may produce incorrect results on error. - "--skip_test=macros/map/list_error", # TODO(issues/97): Parse-only qualified variable lookup "x.y" wtih binding "x.y" or "y" within container "x" fails "--skip_test=namespace/qualified/self_eval_qualified_lookup", "--skip_test=namespace/namespace/self_eval_container_lookup", diff --git a/eval/eval/BUILD b/eval/eval/BUILD index 48d2ea33b..a2cb71025 100644 --- a/eval/eval/BUILD +++ b/eval/eval/BUILD @@ -173,13 +173,10 @@ cc_library( deps = [ ":evaluator_core", ":expression_step_base", - "//eval/public:activation", - "//eval/public:cel_value", "//eval/public/containers:container_backed_list_impl", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", ], ) diff --git a/eval/eval/create_list_step.cc b/eval/eval/create_list_step.cc index 21be321a9..d3373c397 100644 --- a/eval/eval/create_list_step.cc +++ b/eval/eval/create_list_step.cc @@ -1,5 +1,6 @@ #include "eval/eval/create_list_step.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "eval/public/containers/container_backed_list_impl.h" @@ -35,6 +36,14 @@ absl::Status CreateListStep::Evaluate(ExecutionFrame* frame) const { auto args = frame->value_stack().GetSpan(list_size_); CelValue result; + for (const auto& arg : args) { + if (arg.IsError()) { + result = arg; + frame->value_stack().Pop(list_size_); + frame->value_stack().Push(result); + return absl::OkStatus(); + } + } const UnknownSet* unknown_set = nullptr; if (frame->enable_unknowns()) { @@ -44,18 +53,17 @@ absl::Status CreateListStep::Evaluate(ExecutionFrame* frame) const { /*use_partial=*/true); if (unknown_set != nullptr) { result = CelValue::CreateUnknownSet(unknown_set); + frame->value_stack().Pop(list_size_); + frame->value_stack().Push(result); + return absl::OkStatus(); } } - if (unknown_set == nullptr) { - CelList* cel_list = google::protobuf::Arena::Create( - frame->arena(), std::vector(args.begin(), args.end())); - result = CelValue::CreateList(cel_list); - } - + CelList* cel_list = google::protobuf::Arena::Create( + frame->arena(), std::vector(args.begin(), args.end())); + result = CelValue::CreateList(cel_list); frame->value_stack().Pop(list_size_); frame->value_stack().Push(result); - return absl::OkStatus(); } diff --git a/eval/eval/create_list_step_test.cc b/eval/eval/create_list_step_test.cc index 7cdd569b9..a6ba1a159 100644 --- a/eval/eval/create_list_step_test.cc +++ b/eval/eval/create_list_step_test.cc @@ -100,7 +100,7 @@ class CreateListStepTest : public testing::TestWithParam {}; // Tests error when not enough list elements are on the stack during list // creation. -TEST(CreateListStepTest, TestCreateListStackUndeflow) { +TEST(CreateListStepTest, TestCreateListStackUnderflow) { ExecutionPath path; Expr dummy_expr; @@ -144,6 +144,42 @@ TEST_P(CreateListStepTest, CreateListOne) { EXPECT_THAT((*result_value.ListOrDie())[0].Int64OrDie(), Eq(100)); } +TEST_P(CreateListStepTest, CreateListWithError) { + google::protobuf::Arena arena; + std::vector values; + CelError error = absl::InvalidArgumentError("bad arg"); + values.push_back(CelValue::CreateError(&error)); + auto eval_result = RunExpressionWithCelValues(values, &arena, GetParam()); + + ASSERT_OK(eval_result); + const CelValue result_value = eval_result.value(); + ASSERT_TRUE(result_value.IsError()); + EXPECT_THAT(*result_value.ErrorOrDie(), + Eq(absl::InvalidArgumentError("bad arg"))); +} + +TEST_P(CreateListStepTest, CreateListWithErrorAndUnknown) { + google::protobuf::Arena arena; + // list composition is: {unknown, error} + std::vector values; + Expr expr0; + expr0.mutable_ident_expr()->set_name("name0"); + CelAttribute attr0(expr0, {}); + UnknownSet unknown_set0(UnknownAttributeSet({&attr0})); + values.push_back(CelValue::CreateUnknownSet(&unknown_set0)); + CelError error = absl::InvalidArgumentError("bad arg"); + values.push_back(CelValue::CreateError(&error)); + + auto eval_result = RunExpressionWithCelValues(values, &arena, GetParam()); + + // The bad arg should win. + ASSERT_OK(eval_result); + const CelValue result_value = eval_result.value(); + ASSERT_TRUE(result_value.IsError()); + EXPECT_THAT(*result_value.ErrorOrDie(), + Eq(absl::InvalidArgumentError("bad arg"))); +} + TEST_P(CreateListStepTest, CreateListHundred) { google::protobuf::Arena arena; std::vector values; From cb2e67139301e0ab7097856723d7f708ae5bda2c Mon Sep 17 00:00:00 2001 From: tswadell Date: Fri, 12 Feb 2021 10:25:21 -0800 Subject: [PATCH 04/23] Test to ensure that the parser does not OOM on inputs with heavy nesting PiperOrigin-RevId: 357218300 --- parser/parser_test.cc | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/parser/parser_test.cc b/parser/parser_test.cc index 81f8e0324..2c7ffb6c1 100644 --- a/parser/parser_test.cc +++ b/parser/parser_test.cc @@ -972,6 +972,26 @@ TEST_P(ExpressionTest, Parse) { } } +TEST(ExpressionTest, TsanOom) { + Parse( + "[[a([[???[a[[??[a([[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" + "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" + "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" + "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" + "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" + "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" + "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" + "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" + "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" + "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" + "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" + "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" + "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" + "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[???[" + "a([[????") + .IgnoreError(); +} + INSTANTIATE_TEST_SUITE_P(CelParserTest, ExpressionTest, testing::ValuesIn(test_cases)); From 99edfb1859da4444ad0262efba4659475c3ee2c5 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Tue, 23 Feb 2021 07:21:44 -0800 Subject: [PATCH 05/23] Return a status if nullptr arena is passed. PiperOrigin-RevId: 359041829 --- eval/public/BUILD | 4 ++-- eval/public/activation_bind_helper.cc | 8 ++++++++ eval/public/activation_bind_helper.h | 3 ++- eval/public/activation_bind_helper_test.cc | 12 ++++++++++++ 4 files changed, 24 insertions(+), 3 deletions(-) diff --git a/eval/public/BUILD b/eval/public/BUILD index 2effb7e8e..6d0a2815d 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -108,7 +108,7 @@ cc_library( "//eval/public/containers:field_access", "//eval/public/containers:field_backed_list_impl", "//eval/public/containers:field_backed_map_impl", - "@com_google_absl//absl/strings", + "@com_google_absl//absl/status", ], ) @@ -402,8 +402,8 @@ cc_test( "//base:status_macros", "//eval/testutil:test_message_cc_proto", "//testutil:util", + "@com_google_absl//absl/status", "@com_google_googletest//:gtest_main", - "@com_google_protobuf//:protobuf", ], ) diff --git a/eval/public/activation_bind_helper.cc b/eval/public/activation_bind_helper.cc index a5bede00d..1e8004003 100644 --- a/eval/public/activation_bind_helper.cc +++ b/eval/public/activation_bind_helper.cc @@ -1,5 +1,6 @@ #include "eval/public/activation_bind_helper.h" +#include "absl/status/status.h" #include "eval/public/containers/field_access.h" #include "eval/public/containers/field_backed_list_impl.h" #include "eval/public/containers/field_backed_map_impl.h" @@ -37,6 +38,13 @@ absl::Status CreateValueFromField(const google::protobuf::Message* msg, absl::Status BindProtoToActivation(const Message* message, Arena* arena, Activation* activation, ProtoUnsetFieldOptions options) { + // If we need to bind any types that are backed by an arena allocation, we + // will cause a memory leak. + if (arena == nullptr) { + return absl::InvalidArgumentError( + "arena must not be null for BindProtoToActivation."); + } + // TODO(issues/24): Improve the utilities to bind dynamic values as well. const Descriptor* desc = message->GetDescriptor(); const google::protobuf::Reflection* reflection = message->GetReflection(); diff --git a/eval/public/activation_bind_helper.h b/eval/public/activation_bind_helper.h index 2154b91ec..92b32b917 100644 --- a/eval/public/activation_bind_helper.h +++ b/eval/public/activation_bind_helper.h @@ -17,7 +17,8 @@ enum class ProtoUnsetFieldOptions { }; // Utility method, that takes a protobuf Message and interprets it as a -// namespace, binding its fields to Activation. +// namespace, binding its fields to Activation. |arena| must be non-null. +// // Field names and values become respective names and values of parameters // bound to the Activation object. // Example: diff --git a/eval/public/activation_bind_helper_test.cc b/eval/public/activation_bind_helper_test.cc index d28653ac2..b423679ab 100644 --- a/eval/public/activation_bind_helper_test.cc +++ b/eval/public/activation_bind_helper_test.cc @@ -2,6 +2,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "absl/status/status.h" #include "eval/public/activation.h" #include "eval/testutil/test_message.pb.h" #include "testutil/util.h" @@ -127,6 +128,17 @@ TEST(ActivationBindHelperTest, TestBindDefaultFields) { EqualsProto(*result.value().MessageOrDie())); } +TEST(ActivationBindHelperTest, RejectsNullArena) { + TestMessage message; + message.set_bool_value(true); + + Activation activation; + + ASSERT_EQ(BindProtoToActivation(&message, /*arena=*/nullptr, &activation), + absl::InvalidArgumentError( + "arena must not be null for BindProtoToActivation.")); +} + } // namespace } // namespace runtime From 90959470c3cc440b149d81f7b93ae43e632d10b7 Mon Sep 17 00:00:00 2001 From: tswadell Date: Tue, 23 Feb 2021 12:51:06 -0800 Subject: [PATCH 06/23] Disable short-circuiting in `exists_one` macro. PiperOrigin-RevId: 359110912 --- conformance/BUILD | 8 +++----- parser/parser_test.cc | 17 +++++++---------- parser/source_factory.cc | 3 +-- 3 files changed, 11 insertions(+), 17 deletions(-) diff --git a/conformance/BUILD b/conformance/BUILD index f6234eb4e..f7d9a0d0f 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -81,11 +81,11 @@ cc_binary( "--server=\"$(location :server) " + arg + "\"", "--skip_check", "--pipe", - # TODO(issues/93): Inconsistent Duration.getMilliseconds() behavior. + # TODO(issues/93): Deprecate Duration.getMilliseconds. "--skip_test=timestamps/duration_converters/get_milliseconds", # TODO(issues/94): Missing timestamp conversion functions (type / string) - "--skip_test=timestamps/timestamp_conversions/toType_timestamp,toString_timestamp", - "--skip_test=timestamps/duration_conversions/toType_duration,toString_duration", + "--skip_test=timestamps/timestamp_conversions/toString_timestamp", + "--skip_test=timestamps/duration_conversions/toString_duration", # TODO(issues/81): Conversion functions for int(), uint() which can be # uncommented when the spec changes to truncation rather than rounding. "--skip_test=conversions/int/double_nearest,double_nearest_neg,double_half_away_neg,double_half_away_pos", @@ -106,8 +106,6 @@ cc_binary( "--skip_test=fields/qualified_identifier_resolution/qualified_identifier_resolution_unchecked", "--skip_test=string/size/one_unicode,unicode", "--skip_test=string/bytes_concat/left_unit", - # TODO(issues/85): The exists one macro should not short-circuit false. - "--skip_test=macros/exists_one/list_no_shortcircuit", # TODO(issues/97): Parse-only qualified variable lookup "x.y" wtih binding "x.y" or "y" within container "x" fails "--skip_test=namespace/qualified/self_eval_qualified_lookup", "--skip_test=namespace/namespace/self_eval_container_lookup", diff --git a/parser/parser_test.cc b/parser/parser_test.cc index 2c7ffb6c1..373e29bf4 100644 --- a/parser/parser_test.cc +++ b/parser/parser_test.cc @@ -423,24 +423,21 @@ std::vector test_cases = { " // Init\n" " 0^#5:int64#,\n" " // LoopCondition\n" - " _<=_(\n" - " __result__^#7:Expr.Ident#,\n" - " 1^#6:int64#\n" - " )^#8:Expr.Call#,\n" + " true^#7:bool#,\n" " // LoopStep\n" " _?_:_(\n" " f^#4:Expr.Ident#,\n" " _+_(\n" - " __result__^#9:Expr.Ident#,\n" + " __result__^#8:Expr.Ident#,\n" " 1^#6:int64#\n" - " )^#10:Expr.Call#,\n" - " __result__^#11:Expr.Ident#\n" - " )^#12:Expr.Call#,\n" + " )^#9:Expr.Call#,\n" + " __result__^#10:Expr.Ident#\n" + " )^#11:Expr.Call#,\n" " // Result\n" " _==_(\n" - " __result__^#13:Expr.Ident#,\n" + " __result__^#12:Expr.Ident#,\n" " 1^#6:int64#\n" - " )^#14:Expr.Call#)^#15:Expr.Comprehension#"}, + " )^#13:Expr.Call#)^#14:Expr.Comprehension#"}, {"m.map(v, f)", "__comprehension__(\n" " // Variable\n" diff --git a/parser/source_factory.cc b/parser/source_factory.cc index 3052b2407..e7879f99e 100644 --- a/parser/source_factory.cc +++ b/parser/source_factory.cc @@ -250,8 +250,7 @@ Expr SourceFactory::newQuantifierExprForMacro( Expr zero_expr = newLiteralIntForMacro(macro_id, 0); Expr one_expr = newLiteralIntForMacro(macro_id, 1); init = zero_expr; - condition = newGlobalCallForMacro(macro_id, CelOperator::LESS_EQUALS, - {accu_ident(), one_expr}); + condition = newLiteralBoolForMacro(macro_id, true); step = newGlobalCallForMacro( macro_id, CelOperator::CONDITIONAL, {args[1], From 9eea89d7634b8b8a7b4059184caf36f220deb39b Mon Sep 17 00:00:00 2001 From: tswadell Date: Wed, 24 Feb 2021 09:25:19 -0800 Subject: [PATCH 07/23] String conversion functions for timestamp and duration. The functions `string(timestamp)` and `string(duration)` will produce formatted strings which match those found in string-encoded protobuf values for `google.protobuf.Timestamp` and `google.protobuf.Duration` respectively. PiperOrigin-RevId: 359302330 --- conformance/BUILD | 4 +- conformance/server.cc | 2 +- eval/public/BUILD | 4 + eval/public/builtin_func_registrar.cc | 34 ++++++++ eval/public/builtin_func_test.cc | 43 ++++++++++ eval/public/containers/BUILD | 15 +++- eval/public/containers/field_access.cc | 10 ++- eval/public/containers/field_access_test.cc | 86 +++++++++++++++++++ .../containers/field_backed_map_impl_test.cc | 4 - eval/public/structs/cel_proto_wrapper_test.cc | 6 +- eval/public/transform_utility.cc | 12 ++- eval/public/value_export_util.cc | 12 ++- internal/BUILD | 4 +- internal/proto_util.cc | 36 ++++---- internal/proto_util.h | 9 +- testutil/test_data_util.cc | 4 +- v1beta1/converters.cc | 14 ++- 17 files changed, 253 insertions(+), 46 deletions(-) create mode 100644 eval/public/containers/field_access_test.cc diff --git a/conformance/BUILD b/conformance/BUILD index f7d9a0d0f..ed6809c4f 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -83,9 +83,6 @@ cc_binary( "--pipe", # TODO(issues/93): Deprecate Duration.getMilliseconds. "--skip_test=timestamps/duration_converters/get_milliseconds", - # TODO(issues/94): Missing timestamp conversion functions (type / string) - "--skip_test=timestamps/timestamp_conversions/toString_timestamp", - "--skip_test=timestamps/duration_conversions/toString_duration", # TODO(issues/81): Conversion functions for int(), uint() which can be # uncommented when the spec changes to truncation rather than rounding. "--skip_test=conversions/int/double_nearest,double_nearest_neg,double_half_away_neg,double_half_away_pos", @@ -161,6 +158,7 @@ sh_test( "$(location @com_google_cel_spec//tests/simple:simple_test)", "--server=$(location :server)", "--skip_check", + "--skip_test=dynamic/list/var", "--pipe", ] + ["$(location " + test + ")" for test in DASHBOARD_TESTS], data = [ diff --git a/conformance/server.cc b/conformance/server.cc index 8b9ddac35..de3af202b 100644 --- a/conformance/server.cc +++ b/conformance/server.cc @@ -181,7 +181,7 @@ int RunServer(bool optimize) { // Implementation of a simple pipe protocol: // INPUT LINE 1: parse/check/eval // INPUT LINE 2: JSON of the corresponding request protobuf - // OUTPUT LINE 1: JSON of the coressponding response protobuf + // OUTPUT LINE 1: JSON of the corresponding response protobuf while (true) { std::string cmd, input, output; std::getline(std::cin, cmd); diff --git a/eval/public/BUILD b/eval/public/BUILD index 6d0a2815d..23ad142f4 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -182,9 +182,11 @@ cc_library( ":cel_options", "//base:unilib", "//eval/public/containers:container_backed_list_impl", + "//internal:proto_util", "@com_google_absl//absl/numeric:int128", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", "@com_google_protobuf//:protobuf", "@com_googlesource_code_re2//:re2", ], @@ -490,8 +492,10 @@ cc_test( ":cel_builtins", ":cel_expr_builder_factory", ":cel_function_registry", + ":cel_value", "//base:status_macros", "//eval/public/structs:cel_proto_wrapper", + "//internal:proto_util", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", diff --git a/eval/public/builtin_func_registrar.cc b/eval/public/builtin_func_registrar.cc index b2cb4b337..94f3381f3 100644 --- a/eval/public/builtin_func_registrar.cc +++ b/eval/public/builtin_func_registrar.cc @@ -11,11 +11,13 @@ #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_replace.h" +#include "absl/time/time.h" #include "eval/public/cel_builtins.h" #include "eval/public/cel_function_adapter.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" #include "eval/public/containers/container_backed_list_impl.h" +#include "internal/proto_util.h" #include "re2/re2.h" #include "base/unilib.h" @@ -1200,6 +1202,38 @@ absl::Status RegisterStringConversionFunctions( registry); if (!status.ok()) return status; + // duration -> string + status = FunctionAdapter::CreateAndRegister( + builtin::kString, false, + [](Arena* arena, absl::Duration value) -> CelValue { + google::protobuf::Duration d; + auto status = google::api::expr::internal::EncodeDuration(value, &d); + if (!status.ok()) { + return CreateErrorValue(arena, status.message(), status.code()); + } + return CelValue::CreateString( + CelValue::StringHolder(Arena::Create( + arena, google::protobuf::util::TimeUtil::ToString(d)))); + }, + registry); + if (!status.ok()) return status; + + // timestamp -> string + status = FunctionAdapter::CreateAndRegister( + builtin::kString, false, + [](Arena* arena, absl::Time value) -> CelValue { + google::protobuf::Timestamp ts; + auto status = google::api::expr::internal::EncodeTime(value, &ts); + if (!status.ok()) { + return CreateErrorValue(arena, status.message(), status.code()); + } + return CelValue::CreateString( + CelValue::StringHolder(Arena::Create( + arena, google::protobuf::util::TimeUtil::ToString(ts)))); + }, + registry); + if (!status.ok()) return status; + return absl::OkStatus(); } diff --git a/eval/public/builtin_func_test.cc b/eval/public/builtin_func_test.cc index 008f3bf08..c3f4d6621 100644 --- a/eval/public/builtin_func_test.cc +++ b/eval/public/builtin_func_test.cc @@ -9,7 +9,9 @@ #include "eval/public/cel_builtins.h" #include "eval/public/cel_expr_builder_factory.h" #include "eval/public/cel_function_registry.h" +#include "eval/public/cel_value.h" #include "eval/public/structs/cel_proto_wrapper.h" +#include "internal/proto_util.h" #include "base/status_macros.h" namespace google { @@ -28,6 +30,9 @@ using google::api::expr::v1alpha1::SourceInfo; using google::protobuf::Arena; using google::protobuf::util::TimeUtil; +using ::google::api::expr::internal::MakeGoogleApiDurationMax; +using ::google::api::expr::internal::MakeGoogleApiDurationMin; +using ::google::api::expr::internal::MakeGoogleApiTimeMin; using testing::Eq; class BuiltinsTest : public ::testing::Test { @@ -132,6 +137,18 @@ class BuiltinsTest : public ::testing::Test { << operation << " for " << CelValue::TypeName(ref.type()); } + // Helper method. Looks up in registry and tests Type conversions. + void TestTypeConverts(absl::string_view operation, const CelValue& ref, + CelValue::StringHolder result) { + CelValue result_value; + + ASSERT_NO_FATAL_FAILURE(PerformRun(operation, {}, {ref}, &result_value)); + + ASSERT_EQ(result_value.IsString(), true); + ASSERT_EQ(result_value.StringOrDie().value(), result.value()) + << operation << " for " << CelValue::TypeName(ref.type()); + } + void TestTypeConverts(absl::string_view operation, const CelValue& ref, double result) { CelValue result_value; @@ -565,6 +582,10 @@ TEST_F(BuiltinsTest, TestDurationFunctions) { TestFunctions(builtin::kMilliseconds, CelProtoWrapper::CreateDuration(&ref), 11L); + std::string result = "93541.011s"; + TestTypeConverts(builtin::kString, CelProtoWrapper::CreateDuration(&ref), + CelValue::StringHolder(&result)); + ref.set_seconds(-93541L); ref.set_nanos(-11000000L); @@ -575,6 +596,16 @@ TEST_F(BuiltinsTest, TestDurationFunctions) { -93541L); TestFunctions(builtin::kMilliseconds, CelProtoWrapper::CreateDuration(&ref), -11L); + + result = "-93541.011s"; + TestTypeConverts(builtin::kString, CelProtoWrapper::CreateDuration(&ref), + CelValue::StringHolder(&result)); + + absl::Duration d = MakeGoogleApiDurationMin() + absl::Seconds(-1); + TestTypeConversionError(builtin::kString, CelValue::CreateDuration(d)); + + d = MakeGoogleApiDurationMax() + absl::Seconds(1); + TestTypeConversionError(builtin::kString, CelValue::CreateDuration(d)); } // Test functions for Timestamp @@ -598,11 +629,19 @@ TEST_F(BuiltinsTest, TestTimestampFunctions) { TestFunctions(builtin::kMilliseconds, CelProtoWrapper::CreateTimestamp(&ref), 11L); + std::string result = "1970-01-01T00:00:01.011Z"; + TestTypeConverts(builtin::kString, CelProtoWrapper::CreateTimestamp(&ref), + CelValue::StringHolder(&result)); + ref.set_seconds(259200L); ref.set_nanos(0L); TestFunctions(builtin::kDayOfWeek, CelProtoWrapper::CreateTimestamp(&ref), 0L); + result = "1970-01-04T00:00:00Z"; + TestTypeConverts(builtin::kString, CelProtoWrapper::CreateTimestamp(&ref), + CelValue::StringHolder(&result)); + // Test timestamp functions w/ IANA timezone ref.set_seconds(1L); ref.set_nanos(11000000L); @@ -702,6 +741,10 @@ TEST_F(BuiltinsTest, TestTimestampFunctions) { TestFunctions(builtin::kSeconds, CelProtoWrapper::CreateTimestamp(&ref), 59L); TestFunctions(builtin::kDayOfWeek, CelProtoWrapper::CreateTimestamp(&ref), 3L); + + TestTypeConversionError( + builtin::kString, + CelValue::CreateTimestamp(MakeGoogleApiTimeMin() + absl::Seconds(-1))); } TEST_F(BuiltinsTest, TestBytesConversions_bytes) { diff --git a/eval/public/containers/BUILD b/eval/public/containers/BUILD index 04e76b4e2..d7dcdc689 100644 --- a/eval/public/containers/BUILD +++ b/eval/public/containers/BUILD @@ -121,9 +121,22 @@ cc_test( ], deps = [ ":field_backed_map_impl", - "//eval/eval:evaluator_core", "//eval/testutil:test_message_cc_proto", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", ], ) + +cc_test( + name = "field_access_test", + srcs = ["field_access_test.cc"], + deps = [ + ":field_access", + "//internal:proto_util", + "@com_google_absl//absl/status", + "@com_google_absl//absl/time", + "@com_google_cel_spec//proto/test/v1/proto3:test_all_types_cc_proto", + "@com_google_googletest//:gtest_main", + "@com_google_protobuf//:protobuf", + ], +) diff --git a/eval/public/containers/field_access.cc b/eval/public/containers/field_access.cc index 23442b186..88cd2436d 100644 --- a/eval/public/containers/field_access.cc +++ b/eval/public/containers/field_access.cc @@ -471,7 +471,10 @@ class FieldSetter { return false; } google::protobuf::Duration duration; - google::api::expr::internal::EncodeDuration(d, &duration); + auto status = google::api::expr::internal::EncodeDuration(d, &duration); + if (!status.ok()) { + return false; + } static_cast(this)->SetMessage(&duration); return true; } @@ -483,7 +486,10 @@ class FieldSetter { return false; } google::protobuf::Timestamp timestamp; - google::api::expr::internal::EncodeTime(t, ×tamp); + auto status = google::api::expr::internal::EncodeTime(t, ×tamp); + if (!status.ok()) { + return false; + } static_cast(this)->SetMessage(×tamp); return true; } diff --git a/eval/public/containers/field_access_test.cc b/eval/public/containers/field_access_test.cc new file mode 100644 index 000000000..c6b380b3a --- /dev/null +++ b/eval/public/containers/field_access_test.cc @@ -0,0 +1,86 @@ +#include "eval/public/containers/field_access.h" + +#include "google/protobuf/message.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/status/status.h" +#include "absl/time/time.h" +#include "internal/proto_util.h" +#include "proto/test/v1/proto3/test_all_types.pb.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +namespace { + +using google::api::expr::internal::MakeGoogleApiDurationMax; +using google::api::expr::internal::MakeGoogleApiTimeMax; +using google::protobuf::FieldDescriptor; +using test::v1::proto3::TestAllTypes; + +TEST(FieldAccessTest, SetDuration) { + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("single_duration"); + auto status = SetValueToSingleField( + CelValue::CreateDuration(MakeGoogleApiDurationMax()), field, &msg); + EXPECT_TRUE(status.ok()); +} + +TEST(FieldAccessTest, SetDurationBadDuration) { + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("single_duration"); + auto status = SetValueToSingleField( + CelValue::CreateDuration(MakeGoogleApiDurationMax() + absl::Seconds(1)), + field, &msg); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); +} + +TEST(FieldAccessTest, SetDurationBadInputType) { + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("single_duration"); + auto status = SetValueToSingleField(CelValue::CreateInt64(1), field, &msg); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); +} + +TEST(FieldAccessTest, SetTimestamp) { + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("single_timestamp"); + auto status = SetValueToSingleField( + CelValue::CreateTimestamp(MakeGoogleApiTimeMax()), field, &msg); + EXPECT_TRUE(status.ok()); +} + +TEST(FieldAccessTest, SetTimestampBadTime) { + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("single_timestamp"); + auto status = SetValueToSingleField( + CelValue::CreateTimestamp(MakeGoogleApiTimeMax() + absl::Seconds(1)), + field, &msg); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); +} + +TEST(FieldAccessTest, SetTimestampBadInputType) { + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("single_timestamp"); + auto status = SetValueToSingleField(CelValue::CreateInt64(1), field, &msg); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); +} + +} // namespace + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/eval/public/containers/field_backed_map_impl_test.cc b/eval/public/containers/field_backed_map_impl_test.cc index 41b214aa1..59f75bf8e 100644 --- a/eval/public/containers/field_backed_map_impl_test.cc +++ b/eval/public/containers/field_backed_map_impl_test.cc @@ -103,10 +103,6 @@ TEST(FieldBackedMapImplTest, EmptySizeTest) { google::protobuf::Arena arena; auto cel_map = CreateMap(&message, "string_int32_map", &arena); - - std::string test0 = "test0"; - std::string test1 = "test1"; - EXPECT_EQ(cel_map->size(), 0); } diff --git a/eval/public/structs/cel_proto_wrapper_test.cc b/eval/public/structs/cel_proto_wrapper_test.cc index 4e797589a..7c48abd4e 100644 --- a/eval/public/structs/cel_proto_wrapper_test.cc +++ b/eval/public/structs/cel_proto_wrapper_test.cc @@ -80,7 +80,8 @@ TEST(CelProtoWrapperTest, TestDuration) { // CelValue value = CelValue::CreateString("test"); EXPECT_TRUE(value.IsDuration()); Duration out; - expr::internal::EncodeDuration(value.DurationOrDie(), &out); + auto status = expr::internal::EncodeDuration(value.DurationOrDie(), &out); + EXPECT_TRUE(status.ok()); EXPECT_THAT(out, testutil::EqualsProto(msg_duration)); } @@ -102,7 +103,8 @@ TEST(CelProtoWrapperTest, TestTimestamp) { // CelValue value = CelValue::CreateString("test"); EXPECT_TRUE(value.IsTimestamp()); Timestamp out; - expr::internal::EncodeTime(value.TimestampOrDie(), &out); + auto status = expr::internal::EncodeTime(value.TimestampOrDie(), &out); + EXPECT_TRUE(status.ok()); EXPECT_THAT(out, testutil::EqualsProto(msg_timestamp)); } diff --git a/eval/public/transform_utility.cc b/eval/public/transform_utility.cc index 4b4ebb8ad..bcfae9d94 100644 --- a/eval/public/transform_utility.cc +++ b/eval/public/transform_utility.cc @@ -44,13 +44,21 @@ absl::Status CelValueToValue(const CelValue& value, Value* result) { break; case CelValue::Type::kDuration: { google::protobuf::Duration duration; - expr::internal::EncodeDuration(value.DurationOrDie(), &duration); + auto status = + expr::internal::EncodeDuration(value.DurationOrDie(), &duration); + if (!status.ok()) { + return status; + } result->mutable_object_value()->PackFrom(duration); break; } case CelValue::Type::kTimestamp: { google::protobuf::Timestamp timestamp; - expr::internal::EncodeTime(value.TimestampOrDie(), ×tamp); + auto status = + expr::internal::EncodeTime(value.TimestampOrDie(), ×tamp); + if (!status.ok()) { + return status; + } result->mutable_object_value()->PackFrom(timestamp); break; } diff --git a/eval/public/value_export_util.cc b/eval/public/value_export_util.cc index 89ef53022..56274ebf3 100644 --- a/eval/public/value_export_util.cc +++ b/eval/public/value_export_util.cc @@ -77,13 +77,21 @@ absl::Status ExportAsProtoValue(const CelValue& in_value, Value* out_value) { } case CelValue::Type::kDuration: { Duration duration; - expr::internal::EncodeDuration(in_value.DurationOrDie(), &duration); + auto status = + expr::internal::EncodeDuration(in_value.DurationOrDie(), &duration); + if (!status.ok()) { + return status; + } out_value->set_string_value(TimeUtil::ToString(duration)); break; } case CelValue::Type::kTimestamp: { Timestamp timestamp; - expr::internal::EncodeTime(in_value.TimestampOrDie(), ×tamp); + auto status = + expr::internal::EncodeTime(in_value.TimestampOrDie(), ×tamp); + if (!status.ok()) { + return status; + } out_value->set_string_value(TimeUtil::ToString(timestamp)); break; } diff --git a/internal/BUILD b/internal/BUILD index 7819be972..a6fc4d3a2 100644 --- a/internal/BUILD +++ b/internal/BUILD @@ -174,12 +174,12 @@ cc_library( srcs = ["proto_util.cc"], hdrs = ["proto_util.h"], deps = [ - ":status_util", + "//base:status_macros", "//common:macros", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", - "@com_google_googleapis//google/rpc:status_cc_proto", "@com_google_protobuf//:protobuf", ], ) diff --git a/internal/proto_util.cc b/internal/proto_util.cc index b2dd0a22b..32789b61d 100644 --- a/internal/proto_util.cc +++ b/internal/proto_util.cc @@ -1,10 +1,11 @@ #include "internal/proto_util.h" + #include "google/protobuf/duration.pb.h" #include "google/protobuf/timestamp.pb.h" -#include "google/rpc/status.pb.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "common/macros.h" -#include "internal/status_util.h" +#include "base/status_macros.h" namespace google { namespace api { @@ -13,26 +14,26 @@ namespace internal { namespace { -google::rpc::Status Validate(absl::Duration duration) { +absl::Status Validate(absl::Duration duration) { if (duration < MakeGoogleApiDurationMin()) { - return InvalidArgumentError(absl::StrCat("duration below min")); + return absl::InvalidArgumentError("duration below min"); } if (duration > MakeGoogleApiDurationMax()) { - return InvalidArgumentError(absl::StrCat("duration above max")); + return absl::InvalidArgumentError("duration above max"); } - return OkStatus(); + return absl::OkStatus(); } -google::rpc::Status Validate(absl::Time time) { +absl::Status Validate(absl::Time time) { if (time < MakeGoogleApiTimeMin()) { - return InvalidArgumentError(absl::StrCat("time below min")); + return absl::InvalidArgumentError("time below min"); } if (time > MakeGoogleApiTimeMax()) { - return InvalidArgumentError(absl::StrCat("time above max")); + return absl::InvalidArgumentError("time above max"); } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -46,24 +47,23 @@ absl::Time DecodeTime(const google::protobuf::Timestamp& proto) { absl::Nanoseconds(proto.nanos()); } -google::rpc::Status EncodeDuration(absl::Duration duration, - google::protobuf::Duration* proto) { - RETURN_IF_STATUS_ERROR(Validate(duration)); +absl::Status EncodeDuration(absl::Duration duration, + google::protobuf::Duration* proto) { + RETURN_IF_ERROR(Validate(duration)); // s and n may both be negative, per the Duration proto spec. const int64_t s = absl::IDivDuration(duration, absl::Seconds(1), &duration); const int64_t n = absl::IDivDuration(duration, absl::Nanoseconds(1), &duration); proto->set_seconds(s); proto->set_nanos(n); - return OkStatus(); + return absl::OkStatus(); } -google::rpc::Status EncodeTime(absl::Time time, - google::protobuf::Timestamp* proto) { - RETURN_IF_STATUS_ERROR(Validate(time)); +absl::Status EncodeTime(absl::Time time, google::protobuf::Timestamp* proto) { + RETURN_IF_ERROR(Validate(time)); const int64_t s = absl::ToUnixSeconds(time); proto->set_seconds(s); proto->set_nanos((time - absl::FromUnixSeconds(s)) / absl::Nanoseconds(1)); - return OkStatus(); + return absl::OkStatus(); } } // namespace internal diff --git a/internal/proto_util.h b/internal/proto_util.h index 12534ec50..3134fa43f 100644 --- a/internal/proto_util.h +++ b/internal/proto_util.h @@ -3,9 +3,9 @@ #include "google/protobuf/duration.pb.h" #include "google/protobuf/timestamp.pb.h" -#include "google/rpc/status.pb.h" #include "google/protobuf/util/message_differencer.h" #include "absl/memory/memory.h" +#include "absl/status/status.h" #include "absl/time/time.h" namespace google { @@ -21,12 +21,11 @@ struct DefaultProtoEqual { }; /** Helper function to encode a duration in a google::protobuf::Duration. */ -google::rpc::Status EncodeDuration(absl::Duration duration, - google::protobuf::Duration* proto); +absl::Status EncodeDuration(absl::Duration duration, + google::protobuf::Duration* proto); /** Helper function to encode a time in a google::protobuf::Timestamp. */ -google::rpc::Status EncodeTime(absl::Time time, - google::protobuf::Timestamp* proto); +absl::Status EncodeTime(absl::Time time, google::protobuf::Timestamp* proto); /** Helper function to decode a duration from a google::protobuf::Duration. */ absl::Duration DecodeDuration(const google::protobuf::Duration& proto); diff --git a/testutil/test_data_util.cc b/testutil/test_data_util.cc index 3a1a10928..fc8f894ab 100644 --- a/testutil/test_data_util.cc +++ b/testutil/test_data_util.cc @@ -373,7 +373,7 @@ TestValue NewValue(const google::protobuf::Message& value, absl::string_view nam TestValue NewValue(absl::Duration value, absl::string_view name) { google::protobuf::Duration duration; auto status = expr::internal::EncodeDuration(value, &duration); - assert(status.code() == google::rpc::Code::OK); + assert(status.ok()); auto result = NewValue(duration, name); *AddProto(&result, "single_duration")->mutable_single_duration() = duration; return result; @@ -382,7 +382,7 @@ TestValue NewValue(absl::Duration value, absl::string_view name) { TestValue NewValue(absl::Time value, absl::string_view name) { google::protobuf::Timestamp timestamp; auto status = expr::internal::EncodeTime(value, ×tamp); - assert(status.code() == google::rpc::Code::OK); + assert(status.ok()); auto result = NewValue(timestamp, name); *AddProto(&result, "single_timestamp")->mutable_single_timestamp() = timestamp; diff --git a/v1beta1/converters.cc b/v1beta1/converters.cc index 97e6b3b8f..22a6e4ff7 100644 --- a/v1beta1/converters.cc +++ b/v1beta1/converters.cc @@ -81,6 +81,16 @@ struct ToExprValue { return true; } + google::rpc::Status StatusToRpcStatus(const absl::Status& value) { + if (value.ok()) { + return OkStatus(); + } + google::rpc::Status error; + error.set_code(static_cast(value.code())); + error.set_message(value.message().data(), value.message().size()); + return error; + } + void EncodeMessage(const google::protobuf::Message& value) { result->mutable_value()->mutable_object_value()->PackFrom(value); } @@ -99,7 +109,7 @@ struct ToExprValue { google::rpc::Status operator()(absl::Duration value) { google::protobuf::Duration duration; - auto status = EncodeDuration(value, &duration); + auto status = StatusToRpcStatus(EncodeDuration(value, &duration)); if (CheckAndEncodeIfError(status)) { EncodeMessage(duration); } @@ -108,7 +118,7 @@ struct ToExprValue { google::rpc::Status operator()(absl::Time value) { google::protobuf::Timestamp time; - auto status = EncodeTime(value, &time); + auto status = StatusToRpcStatus(EncodeTime(value, &time)); if (CheckAndEncodeIfError(status)) { EncodeMessage(time); } From d133b25b2af28d9441fb16925680ab6de187c498 Mon Sep 17 00:00:00 2001 From: tswadell Date: Thu, 25 Feb 2021 01:25:09 -0800 Subject: [PATCH 08/23] Internal build change. PiperOrigin-RevId: 359468897 --- conformance/BUILD | 110 +++++++++++++++++++++------------------------- 1 file changed, 51 insertions(+), 59 deletions(-) diff --git a/conformance/BUILD b/conformance/BUILD index ed6809c4f..8a19ac09f 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -18,6 +18,7 @@ ALL_TESTS = [ "@com_google_cel_spec//tests/simple:testdata/logic.textproto", "@com_google_cel_spec//tests/simple:testdata/macros.textproto", "@com_google_cel_spec//tests/simple:testdata/namespace.textproto", + "@com_google_cel_spec//tests/simple:testdata/parse.textproto", "@com_google_cel_spec//tests/simple:testdata/plumbing.textproto", "@com_google_cel_spec//tests/simple:testdata/proto2.textproto", "@com_google_cel_spec//tests/simple:testdata/proto3.textproto", @@ -26,27 +27,6 @@ ALL_TESTS = [ "@com_google_cel_spec//tests/simple:testdata/unknowns.textproto", ] -DASHBOARD_TESTS = [ - "@com_google_cel_spec//tests/simple:testdata/basic.textproto", - "@com_google_cel_spec//tests/simple:testdata/comparisons.textproto", - "@com_google_cel_spec//tests/simple:testdata/conversions.textproto", - "@com_google_cel_spec//tests/simple:testdata/dynamic.textproto", - "@com_google_cel_spec//tests/simple:testdata/enums.textproto", - "@com_google_cel_spec//tests/simple:testdata/fields.textproto", - "@com_google_cel_spec//tests/simple:testdata/fp_math.textproto", - "@com_google_cel_spec//tests/simple:testdata/integer_math.textproto", - "@com_google_cel_spec//tests/simple:testdata/lists.textproto", - "@com_google_cel_spec//tests/simple:testdata/logic.textproto", - "@com_google_cel_spec//tests/simple:testdata/macros.textproto", - "@com_google_cel_spec//tests/simple:testdata/namespace.textproto", - "@com_google_cel_spec//tests/simple:testdata/proto2.textproto", - "@com_google_cel_spec//tests/simple:testdata/proto3.textproto", - "@com_google_cel_spec//tests/simple:testdata/plumbing.textproto", - "@com_google_cel_spec//tests/simple:testdata/string.textproto", - "@com_google_cel_spec//tests/simple:testdata/timestamps.textproto", - "@com_google_cel_spec//tests/simple:testdata/unknowns.textproto", -] - cc_binary( name = "server", testonly = 1, @@ -81,48 +61,45 @@ cc_binary( "--server=\"$(location :server) " + arg + "\"", "--skip_check", "--pipe", + + # Tests which require spec changes. # TODO(issues/93): Deprecate Duration.getMilliseconds. "--skip_test=timestamps/duration_converters/get_milliseconds", # TODO(issues/81): Conversion functions for int(), uint() which can be # uncommented when the spec changes to truncation rather than rounding. "--skip_test=conversions/int/double_nearest,double_nearest_neg,double_half_away_neg,double_half_away_pos", "--skip_test=conversions/uint/double_nearest,double_nearest_int,double_half_away", - # TODO(issues/96): Well-known type conversion support. - "--skip_test=proto2/literal_wellknown", - "--skip_test=proto3/literal_wellknown", - "--skip_test=proto2/empty_field/wkt", - "--skip_test=proto3/empty_field/wkt", - # Requires container support - "--skip_test=namespace/namespace/self_eval_container_lookup_unchecked", + # TODO(issues/110): Tune parse limits to mirror those for proto deserialization and C++ safety limits. + "--skip_test=parse/nest/list_index,message_literal,funcall,list_literal,map_literal;repeat/conditional,add_sub,mul_div,select,index,map_literal,message_literal", + + # Broken test cases which should be supported. + # TODO(issues/111): Byte literal decoding of invalid UTF-8 results in incorrect bytes output. "--skip_test=basic/self_eval_nonzeroish/self_eval_bytes_invalid_utf8", - # Requires heteregenous equality spec clarification + # TODO(issues/112): Unbound functions result in empty eval response. + "--skip_test=basic/functions/unbound", + "--skip_test=basic/functions/unbound_is_runtime_error", + # TODO(issues/113): Aggregate values must logically AND element equality results. + "--skip_test=comparisons/eq_literal/not_eq_list_false_vs_types", + "--skip_test=comparisons/eq_literal/not_eq_map_false_vs_types", + # TODO(issues/114): Ensure the 'in' operator is a logical OR of element equality results. "--skip_test=comparisons/in_list_literal/elem_in_mixed_type_list_error", "--skip_test=comparisons/in_map_literal/key_in_mixed_key_type_map_error", + # TODO(issues/115): The 'in' operator fails with maps containing boolean keys. "--skip_test=fields/in/singleton", - # Requires qualified bindings error message relaxation - "--skip_test=fields/qualified_identifier_resolution/qualified_identifier_resolution_unchecked", - "--skip_test=string/size/one_unicode,unicode", - "--skip_test=string/bytes_concat/left_unit", # TODO(issues/97): Parse-only qualified variable lookup "x.y" wtih binding "x.y" or "y" within container "x" fails + "--skip_test=fields/qualified_identifier_resolution/qualified_ident,map_field_select,ident_with_longest_prefix_check,qualified_identifier_resolution_unchecked", "--skip_test=namespace/qualified/self_eval_qualified_lookup", - "--skip_test=namespace/namespace/self_eval_container_lookup", - "--skip_test=fields/qualified_identifier_resolution/qualified_ident", - "--skip_test=fields/qualified_identifier_resolution/map_field_select", - "--skip_test=fields/qualified_identifier_resolution/ident_with_longest_prefix_check", - # New conformance tests awaiting synchronization. - "--skip_test=basic/functions/unbound", - "--skip_test=basic/functions/unbound_is_runtime_error", - "--skip_test=comparisons/eq_literal/not_eq_list_false_vs_types", - "--skip_test=comparisons/eq_literal/not_eq_map_false_vs_types", - "--skip_test=dynamic/int32", - "--skip_test=dynamic/int64", - "--skip_test=dynamic/uint32", - "--skip_test=dynamic/uint64", - "--skip_test=dynamic/float", - "--skip_test=dynamic/double", - "--skip_test=dynamic/string", - "--skip_test=dynamic/bytes", - "--skip_test=dynamic/bool", + "--skip_test=namespace/namespace/self_eval_container_lookup,self_eval_container_lookup_unchecked", + # TODO(issues/96): Well-known type conversion support. + "--skip_test=dynamic/int32/field_assign_proto2,field_assign_proto2_zero,field_read_proto2,field_read_proto2_zero,field_read_proto2_unset,field_assign_proto3,field_assign_proto3_zero,field_read_proto3,field_read_proto3_zero,field_read_proto3_unset", + "--skip_test=dynamic/int64/field_assign_proto2,field_assign_proto2_zero,field_assign_proto3,field_assign_proto3_zero", + "--skip_test=dynamic/uint32/field_assign_proto2,field_assign_proto2_zero,field_read_proto2,field_read_proto2_zero,field_read_proto2_unset,field_assign_proto3,field_assign_proto3_zero,field_read_proto3,field_read_proto3_zero,field_read_proto3_unset", + "--skip_test=dynamic/uint64/field_assign_proto2,field_assign_proto2_zero,field_read_proto2,field_read_proto2_zero,field_read_proto2_unset,field_assign_proto3,field_assign_proto3_zero", + "--skip_test=dynamic/float/field_assign_proto2,field_assign_proto2_zero,field_read_proto2,field_read_proto2_zero,field_read_proto2_unset,field_assign_proto3,field_assign_proto3_zero,field_read_proto3,field_read_proto3_zero,field_read_proto3_unset", + "--skip_test=dynamic/double/field_assign_proto2,field_assign_proto2_zero,field_assign_proto2_range,field_read_proto2,field_read_proto2_zero,field_read_proto2_unset,field_assign_proto3,field_assign_proto3_zero,field_assign_proto3_range,field_read_proto3,field_read_proto3_zero,field_read_proto3_unset", + "--skip_test=dynamic/string/field_assign_proto2,field_assign_proto2_empty,field_assign_proto3,field_assign_proto3_empty", + "--skip_test=dynamic/bytes/field_assign_proto2,field_assign_proto2_empty,field_assign_proto3,field_assign_proto3_empty", + "--skip_test=dynamic/bool/field_assign_proto2,field_assign_proto2_false,field_assign_proto3,field_assign_proto3_false", "--skip_test=dynamic/list", "--skip_test=dynamic/struct", "--skip_test=dynamic/value_null", @@ -131,14 +108,25 @@ cc_binary( "--skip_test=dynamic/value_bool", "--skip_test=dynamic/value_struct", "--skip_test=dynamic/value_list", - "--skip_test=dynamic/any", - "--skip_test=dynamic/complex", - "--skip_test=enums/legacy_proto2", - "--skip_test=enums/legacy_proto3", + "--skip_test=dynamic/complex/any_list_map", + "--skip_test=proto2/literal_wellknown", + "--skip_test=proto3/literal_wellknown", + "--skip_test=proto2/empty_field/wkt", + "--skip_test=proto3/empty_field/wkt", + # TODO(issues/120): Ensure that string size refers to the number of code points. + "--skip_test=string/size/one_unicode,unicode", + "--skip_test=string/bytes_concat/left_unit", + # TODO(issues/117): Integer overflow on enum assignments should error. + "--skip_test=enums/legacy_proto2/select_big,select_neg,assign_standalone_int_too_big,assign_standalone_int_too_neg", + "--skip_test=enums/legacy_proto3/assign_standalone_int_too_big,assign_standalone_int_too_neg", + # TODO(issues/118): Duration and timestamp range errors should result in evaluation errors. + "--skip_test=timestamps/duration_range", + "--skip_test=timestamps/timestamp_range", + + # Future features for CEL 1.0 + # TODO(issues/119): Strong typing support for enums, specified but not implemented. "--skip_test=enums/strong_proto2", "--skip_test=enums/strong_proto3", - "--skip_test=timestamps/timestamp_range", - "--skip_test=timestamps/duration_range", ] + ["$(location " + test + ")" for test in ALL_TESTS], data = [ ":server", @@ -158,13 +146,17 @@ sh_test( "$(location @com_google_cel_spec//tests/simple:simple_test)", "--server=$(location :server)", "--skip_check", + # TODO(issues/116): Debug why dynamic/list/var fails to JSON parse correctly. "--skip_test=dynamic/list/var", + # TODO(issues/119): Strong typing support for enums, specified but not implemented. + "--skip_test=enums/strong_proto2", + "--skip_test=enums/strong_proto3", "--pipe", - ] + ["$(location " + test + ")" for test in DASHBOARD_TESTS], + ] + ["$(location " + test + ")" for test in ALL_TESTS], data = [ ":server", "@com_google_cel_spec//tests/simple:simple_test", - ] + DASHBOARD_TESTS, + ] + ALL_TESTS, visibility = [ "//:__subpackages__", "//third_party/cel:__pkg__", From 01ad69e8d493b32f0ebb3308970fd3e05239b222 Mon Sep 17 00:00:00 2001 From: tswadell Date: Thu, 11 Mar 2021 02:00:44 -0800 Subject: [PATCH 09/23] Introduce option for enabling / disabling string.size() returning the number unicode codepoints rather than the number of bytes. PiperOrigin-RevId: 362242965 --- eval/public/cel_options.h | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/eval/public/cel_options.h b/eval/public/cel_options.h index dcd774c6c..4dbe48d03 100644 --- a/eval/public/cel_options.h +++ b/eval/public/cel_options.h @@ -27,6 +27,11 @@ struct InterpreterOptions { bool enable_missing_attribute_errors = false; + // Enable functions which return the string.size() as the number of unicode + // codepoints. + // Starting on 4/7/2021 this will default to 'true' + bool enable_string_size_as_unicode_codepoints = false; + // Enable short-circuiting of the logical operator evaluation. If enabled, // AND, OR, and TERNARY do not evaluate the entire expression once the the // resulting value is known from the left-hand side. From 663c2d83a0baa5a6128d012ab6bb46fd014c1b58 Mon Sep 17 00:00:00 2001 From: tswadell Date: Thu, 11 Mar 2021 21:08:31 -0800 Subject: [PATCH 10/23] Use unicode codepoint count as string size. The CEL Spec indicates that the size of a string should be the number of unicode codepoints contained within it. This option is flag enabled to ensure that users may transition their logic to the new, more intuitive, convention. PiperOrigin-RevId: 362442988 --- conformance/BUILD | 4 +--- conformance/server.cc | 4 +++- eval/public/BUILD | 1 + eval/public/builtin_func_registrar.cc | 21 +++++++++++++++------ eval/public/builtin_func_test.cc | 22 +++++++++++++++++++++- 5 files changed, 41 insertions(+), 11 deletions(-) diff --git a/conformance/BUILD b/conformance/BUILD index 8a19ac09f..596bf2254 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -75,6 +75,7 @@ cc_binary( # Broken test cases which should be supported. # TODO(issues/111): Byte literal decoding of invalid UTF-8 results in incorrect bytes output. "--skip_test=basic/self_eval_nonzeroish/self_eval_bytes_invalid_utf8", + "--skip_test=string/bytes_concat/left_unit", # TODO(issues/112): Unbound functions result in empty eval response. "--skip_test=basic/functions/unbound", "--skip_test=basic/functions/unbound_is_runtime_error", @@ -113,9 +114,6 @@ cc_binary( "--skip_test=proto3/literal_wellknown", "--skip_test=proto2/empty_field/wkt", "--skip_test=proto3/empty_field/wkt", - # TODO(issues/120): Ensure that string size refers to the number of code points. - "--skip_test=string/size/one_unicode,unicode", - "--skip_test=string/bytes_concat/left_unit", # TODO(issues/117): Integer overflow on enum assignments should error. "--skip_test=enums/legacy_proto2/select_big,select_neg,assign_standalone_int_too_big,assign_standalone_int_too_neg", "--skip_test=enums/legacy_proto3/assign_standalone_int_too_big,assign_standalone_int_too_neg", diff --git a/conformance/server.cc b/conformance/server.cc index de3af202b..089b7936c 100644 --- a/conformance/server.cc +++ b/conformance/server.cc @@ -151,6 +151,7 @@ int RunServer(bool optimize) { google::protobuf::Arena arena; InterpreterOptions options; options.enable_qualified_type_identifiers = true; + options.enable_string_size_as_unicode_codepoints = true; if (optimize) { std::cerr << "Enabling optimizations" << std::endl; @@ -169,7 +170,8 @@ int RunServer(bool optimize) { NestedEnum_descriptor()); type_registry->Register(google::api::expr::test::v1::proto3::TestAllTypes:: NestedEnum_descriptor()); - auto register_status = RegisterBuiltinFunctions(builder->GetRegistry()); + auto register_status = + RegisterBuiltinFunctions(builder->GetRegistry(), options); if (!register_status.ok()) { std::cerr << "Failed to initialize: " << register_status.ToString() << std::endl; diff --git a/eval/public/BUILD b/eval/public/BUILD index 23ad142f4..08fadabb7 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -183,6 +183,7 @@ cc_library( "//base:unilib", "//eval/public/containers:container_backed_list_impl", "//internal:proto_util", + "//util/utf8/public:unicodetext", "@com_google_absl//absl/numeric:int128", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", diff --git a/eval/public/builtin_func_registrar.cc b/eval/public/builtin_func_registrar.cc index 94f3381f3..6770516b1 100644 --- a/eval/public/builtin_func_registrar.cc +++ b/eval/public/builtin_func_registrar.cc @@ -19,6 +19,7 @@ #include "eval/public/containers/container_backed_list_impl.h" #include "internal/proto_util.h" #include "re2/re2.h" +#include "util/utf8/public/unicodetext.h" #include "base/unilib.h" namespace google { @@ -1386,16 +1387,24 @@ absl::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, if (!status.ok()) return status; // String size - auto string_size_func = [](Arena*, CelValue::StringHolder value) -> int64_t { - return value.value().size(); + auto size_func = [=](Arena* arena, CelValue::StringHolder value) -> CelValue { + if (options.enable_string_size_as_unicode_codepoints) { + auto status = UnownedUnicodeTextFromUTF8(value.value()); + if (!status.ok()) { + return CreateErrorValue(arena, "invalid utf-8 string", + absl::StatusCode::kInvalidArgument); + } + return CelValue::CreateInt64(status.value().size()); + } + return CelValue::CreateInt64(value.value().size()); }; // receiver style = true/false // Support global and receiver style size() operations on strings. - status = FunctionAdapter::CreateAndRegister( - builtin::kSize, true, string_size_func, registry); + status = FunctionAdapter::CreateAndRegister( + builtin::kSize, true, size_func, registry); if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( - builtin::kSize, false, string_size_func, registry); + status = FunctionAdapter::CreateAndRegister( + builtin::kSize, false, size_func, registry); if (!status.ok()) return status; // Bytes size diff --git a/eval/public/builtin_func_test.cc b/eval/public/builtin_func_test.cc index c3f4d6621..1b391c23e 100644 --- a/eval/public/builtin_func_test.cc +++ b/eval/public/builtin_func_test.cc @@ -1370,8 +1370,28 @@ TEST_F(BuiltinsTest, StringSize) { builtin::kSize, {}, {CelValue::CreateString(&test)}, &result_value)); ASSERT_EQ(result_value.IsInt64(), true); + ASSERT_EQ(result_value.Int64OrDie(), 9); +} - ASSERT_EQ(result_value.Int64OrDie(), test.size()); +TEST_F(BuiltinsTest, StringUnicodeSize) { + std::string test = "πέντε"; + CelValue result_value; + InterpreterOptions options; + options.enable_string_size_as_unicode_codepoints = true; + ASSERT_NO_FATAL_FAILURE(PerformRun(builtin::kSize, {}, + {CelValue::CreateString(&test)}, + &result_value, options)); + ASSERT_EQ(result_value.IsInt64(), true); + ASSERT_EQ(result_value.Int64OrDie(), 5); + + // Disable the option to measure string size by codepoints, and the return + // value should be equal to the number of bytes in the string (10). + options.enable_string_size_as_unicode_codepoints = false; + ASSERT_NO_FATAL_FAILURE(PerformRun(builtin::kSize, {}, + {CelValue::CreateString(&test)}, + &result_value, options)); + ASSERT_EQ(result_value.IsInt64(), true); + ASSERT_EQ(result_value.Int64OrDie(), 10); } TEST_F(BuiltinsTest, BytesSize) { From e12f02c05d78a3fce293c4d442622f66ef0ae0bf Mon Sep 17 00:00:00 2001 From: tswadell Date: Fri, 12 Mar 2021 15:21:10 -0800 Subject: [PATCH 11/23] Change the UTF8 codepoint count method to be more OSS friendly. PiperOrigin-RevId: 362611318 --- eval/public/BUILD | 1 - eval/public/builtin_func_registrar.cc | 20 +++++++++++++++----- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/eval/public/BUILD b/eval/public/BUILD index 08fadabb7..23ad142f4 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -183,7 +183,6 @@ cc_library( "//base:unilib", "//eval/public/containers:container_backed_list_impl", "//internal:proto_util", - "//util/utf8/public:unicodetext", "@com_google_absl//absl/numeric:int128", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", diff --git a/eval/public/builtin_func_registrar.cc b/eval/public/builtin_func_registrar.cc index 6770516b1..626beefea 100644 --- a/eval/public/builtin_func_registrar.cc +++ b/eval/public/builtin_func_registrar.cc @@ -19,7 +19,6 @@ #include "eval/public/containers/container_backed_list_impl.h" #include "internal/proto_util.h" #include "re2/re2.h" -#include "util/utf8/public/unicodetext.h" #include "base/unilib.h" namespace google { @@ -35,6 +34,17 @@ const int64_t kIntMax = std::numeric_limits::max(); const int64_t kIntMin = std::numeric_limits::min(); const uint64_t kUintMax = std::numeric_limits::max(); +// Returns the number of UTF8 codepoints within a string. +// The input string must first be checked to see if it is valid UTF8. +static int UTF8CodepointCount(absl::string_view str) { + int n = 0; + // Increment the codepoint count on non-trail-byte characters. + for (const auto p : str) { + n += (*reinterpret_cast(&p) >= -0x40); + } + return n; +} + // Comparison template functions template CelValue Inequal(Arena*, Type t1, Type t2) { @@ -1388,15 +1398,15 @@ absl::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, // String size auto size_func = [=](Arena* arena, CelValue::StringHolder value) -> CelValue { + absl::string_view str = value.value(); if (options.enable_string_size_as_unicode_codepoints) { - auto status = UnownedUnicodeTextFromUTF8(value.value()); - if (!status.ok()) { + if (!UniLib::IsStructurallyValid(str)) { return CreateErrorValue(arena, "invalid utf-8 string", absl::StatusCode::kInvalidArgument); } - return CelValue::CreateInt64(status.value().size()); + return CelValue::CreateInt64(UTF8CodepointCount(str)); } - return CelValue::CreateInt64(value.value().size()); + return CelValue::CreateInt64(str.size()); }; // receiver style = true/false // Support global and receiver style size() operations on strings. From 8ad130527ca68fd1c1c4108e794bf488dee5b86d Mon Sep 17 00:00:00 2001 From: tswadell Date: Wed, 17 Mar 2021 17:50:13 -0700 Subject: [PATCH 12/23] Introduce an error recovery limit for the CEL Parser PiperOrigin-RevId: 363548932 --- parser/parser.cc | 60 +++++++++++++++++++++++++++++++++++-------- parser/parser.h | 10 +++++--- parser/parser_test.cc | 12 +++++++++ 3 files changed, 68 insertions(+), 14 deletions(-) diff --git a/parser/parser.cc b/parser/parser.cc index 40dce202e..cedce38bd 100644 --- a/parser/parser.cc +++ b/parser/parser.cc @@ -36,9 +36,10 @@ namespace { // as to prevent stack overflows. class ExprRecursionListener : public ParseTreeListener { public: - ExprRecursionListener( + explicit ExprRecursionListener( const int max_recursion_depth = kDefaultMaxRecursionDepth) : max_recursion_depth_(max_recursion_depth), recursion_depth_(0) {} + ~ExprRecursionListener() override {} void visitTerminal(TerminalNode* node) override{}; void visitErrorNode(ErrorNode* error) override{}; @@ -70,21 +71,53 @@ void ExprRecursionListener::exitEveryRule(ParserRuleContext* ctx) { } } +class RecoveryLimitErrorStrategy : public antlr4::DefaultErrorStrategy { + public: + explicit RecoveryLimitErrorStrategy( + int recovery_limit = kDefaultErrorRecoveryLimit) + : recovery_limit_(recovery_limit), recovery_attempts_(0) {} + + void recover(antlr4::Parser* recognizer, std::exception_ptr e) override { + checkRecoveryLimit(recognizer); + antlr4::DefaultErrorStrategy::recover(recognizer, e); + } + + antlr4::Token* recoverInline(antlr4::Parser* recognizer) override { + checkRecoveryLimit(recognizer); + return antlr4::DefaultErrorStrategy::recoverInline(recognizer); + } + + private: + void checkRecoveryLimit(antlr4::Parser* recognizer) { + if (recovery_attempts_++ >= recovery_limit_) { + std::string too_many_errors = + absl::StrFormat("More than %d parse errors.", recovery_limit_); + recognizer->notifyErrorListeners(too_many_errors); + throw ParseCancellationException(too_many_errors); + } + } + + int recovery_limit_; + int recovery_attempts_; +}; + } // namespace absl::StatusOr Parse(const std::string& expression, const std::string& description, - const int max_recursion_depth) { + int max_recursion_depth, + int error_recovery_limit) { return ParseWithMacros(expression, Macro::AllMacros(), description, - max_recursion_depth); + max_recursion_depth, error_recovery_limit); } absl::StatusOr ParseWithMacros(const std::string& expression, const std::vector& macros, const std::string& description, - const int max_recursion_depth) { - auto result = - EnrichedParse(expression, macros, description, max_recursion_depth); + int max_recursion_depth, + int error_recovery_limit) { + auto result = EnrichedParse(expression, macros, description, + max_recursion_depth, error_recovery_limit); if (result.ok()) { return result->parsed_expr(); } @@ -93,7 +126,8 @@ absl::StatusOr ParseWithMacros(const std::string& expression, absl::StatusOr EnrichedParse( const std::string& expression, const std::vector& macros, - const std::string& description, const int max_recursion_depth) { + const std::string& description, int max_recursion_depth, + int error_recovery_limit) { ANTLRInputStream input(expression); CelLexer lexer(&input); CommonTokenStream tokens(&lexer); @@ -107,21 +141,25 @@ absl::StatusOr EnrichedParse( parser.addErrorListener(&visitor); parser.addParseListener(&listener); - // if we were to ignore errors completely: - // std::shared_ptr error_strategy(new BailErrorStrategy()); - // parser.setErrorHandler(error_strategy); + // Limit the number of error recovery attempts to prevent bad expressions + // from consuming lots of cpu / memory. + std::shared_ptr error_strategy( + new RecoveryLimitErrorStrategy(error_recovery_limit)); + parser.setErrorHandler(error_strategy); CelParser::StartContext* root; try { root = parser.start(); } catch (const ParseCancellationException& e) { + if (visitor.hasErrored()) { + return absl::InvalidArgumentError(visitor.errorMessage()); + } return absl::CancelledError(e.what()); } catch (const std::exception& e) { return absl::AbortedError(e.what()); } Expr expr = visitor.visit(root).as(); - if (visitor.hasErrored()) { return absl::InvalidArgumentError(visitor.errorMessage()); } diff --git a/parser/parser.h b/parser/parser.h index 7227c8fed..70edd5848 100644 --- a/parser/parser.h +++ b/parser/parser.h @@ -12,6 +12,7 @@ namespace api { namespace expr { namespace parser { +constexpr int kDefaultErrorRecoveryLimit = 30; constexpr int kDefaultMaxRecursionDepth = 250; class VerboseParsedExpr { @@ -36,16 +37,19 @@ class VerboseParsedExpr { absl::StatusOr EnrichedParse( const std::string& expression, const std::vector& macros, const std::string& description = "", - int max_recursion_depth = kDefaultMaxRecursionDepth); + int max_recursion_depth = kDefaultMaxRecursionDepth, + int error_recovery_limit = kDefaultErrorRecoveryLimit); absl::StatusOr Parse( const std::string& expression, const std::string& description = "", - int max_recursion_depth = kDefaultMaxRecursionDepth); + int max_recursion_depth = kDefaultMaxRecursionDepth, + int error_recovery_limit = kDefaultErrorRecoveryLimit); absl::StatusOr ParseWithMacros( const std::string& expression, const std::vector& macros, const std::string& description = "", - int max_recursion_depth = kDefaultMaxRecursionDepth); + int max_recursion_depth = kDefaultMaxRecursionDepth, + int error_recovery_limit = kDefaultErrorRecoveryLimit); } // namespace parser } // namespace expr diff --git a/parser/parser_test.cc b/parser/parser_test.cc index 373e29bf4..860e7f813 100644 --- a/parser/parser_test.cc +++ b/parser/parser_test.cc @@ -989,6 +989,18 @@ TEST(ExpressionTest, TsanOom) { .IgnoreError(); } +TEST(ExpressionTest, ErrorRecoveryLimits) { + auto result = Parse("......", "", kDefaultMaxRecursionDepth, 1); + EXPECT_FALSE(result.ok()); + EXPECT_EQ(result.status().message(), + "ERROR: :1:2: Syntax error: missing IDENTIFIER at '.'\n" + " | ......\n" + " | .^\n" + "ERROR: :1:3: Syntax error: More than 1 parse errors.\n" + " | ......\n" + " | ..^"); +} + INSTANTIATE_TEST_SUITE_P(CelParserTest, ExpressionTest, testing::ValuesIn(test_cases)); From a105a0484730920c40c8634518efc38a44e7e67f Mon Sep 17 00:00:00 2001 From: tswadell Date: Thu, 18 Mar 2021 09:16:35 -0700 Subject: [PATCH 13/23] Fix for float-cast-overflow in `uint()` expressions. PiperOrigin-RevId: 363672573 --- eval/public/builtin_func_registrar.cc | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/eval/public/builtin_func_registrar.cc b/eval/public/builtin_func_registrar.cc index 626beefea..ceb99538b 100644 --- a/eval/public/builtin_func_registrar.cc +++ b/eval/public/builtin_func_registrar.cc @@ -1101,8 +1101,10 @@ absl::Status RegisterIntConversionFunctions(CelFunctionRegistry* registry, status = FunctionAdapter::CreateAndRegister( builtin::kInt, false, [](Arena* arena, double v) { - if ((v > static_cast(kIntMax)) || - (v < static_cast(kIntMin)) || std::isnan(v)) { + // NaN and -+infinite numbers cannot be represented as int values, + // nor can double values which exceed the integer 64-bit range. + if (!std::isfinite(v) || v > static_cast(kIntMax) || + v < static_cast(kIntMin)) { return CreateErrorValue(arena, "double out of int range", absl::StatusCode::kInvalidArgument); } @@ -1254,7 +1256,16 @@ absl::Status RegisterUintConversionFunctions(CelFunctionRegistry* registry, auto status = FunctionAdapter::CreateAndRegister( builtin::kUint, false, [](Arena* arena, double v) { - if ((v > static_cast(kUintMax)) || (v < 0) || std::isnan(v)) { + // NaN and -+infinite numbers cannot be represented as uint values, + // nor doubles that exceed the uint64_t range. In some limited cases, + // like 1.84467e+19, the value appears to fit within the uint64_t range + // but type conversion results in rounding that overflows. + // + // Note, the double is checked to make sure it is not greater than 2^64 + // before it is converted to a uint128 value, as the type conversion + // may check-fail for some double inputs that exceed this value. + if (!std::isfinite(v) || v < 0 || v > std::ldexp(1.0, 64) || + absl::uint128(v) > kUintMax) { return CreateErrorValue(arena, "double out of uint range", absl::StatusCode::kInvalidArgument); } From c332be67ae5ee17ce4a5b11b747f692cc7f08cc1 Mon Sep 17 00:00:00 2001 From: kuat Date: Thu, 18 Mar 2021 11:20:05 -0700 Subject: [PATCH 14/23] Ensure that string->duration conversion validates that the duration is in the valid range. Fixes https://github.com/google/cel-cpp/issues/104. PiperOrigin-RevId: 363701763 --- conformance/BUILD | 2 +- eval/public/builtin_func_registrar.cc | 6 ++++++ eval/public/builtin_func_test.cc | 23 +++++++++++++++++++++++ internal/proto_util.cc | 24 ++++++++++++------------ internal/proto_util.h | 3 +++ 5 files changed, 45 insertions(+), 13 deletions(-) diff --git a/conformance/BUILD b/conformance/BUILD index 596bf2254..114526512 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -118,7 +118,7 @@ cc_binary( "--skip_test=enums/legacy_proto2/select_big,select_neg,assign_standalone_int_too_big,assign_standalone_int_too_neg", "--skip_test=enums/legacy_proto3/assign_standalone_int_too_big,assign_standalone_int_too_neg", # TODO(issues/118): Duration and timestamp range errors should result in evaluation errors. - "--skip_test=timestamps/duration_range", + "--skip_test=timestamps/duration_range/add_over,add_under", "--skip_test=timestamps/timestamp_range", # Future features for CEL 1.0 diff --git a/eval/public/builtin_func_registrar.cc b/eval/public/builtin_func_registrar.cc index ceb99538b..42d88da2d 100644 --- a/eval/public/builtin_func_registrar.cc +++ b/eval/public/builtin_func_registrar.cc @@ -745,6 +745,12 @@ CelValue CreateDurationFromString(Arena* arena, absl::StatusCode::kInvalidArgument); } + absl::Status status = google::api::expr::internal::ValidateDuration(d); + if (!status.ok()) { + return CreateErrorValue(arena, "Duration is out of range", + absl::StatusCode::kInvalidArgument); + } + return CelValue::CreateDuration(d); } diff --git a/eval/public/builtin_func_test.cc b/eval/public/builtin_func_test.cc index 1b391c23e..fb913ee27 100644 --- a/eval/public/builtin_func_test.cc +++ b/eval/public/builtin_func_test.cc @@ -182,6 +182,18 @@ class BuiltinsTest : public ::testing::Test { << operation << " for " << CelValue::TypeName(ref.type()); } + void TestTypeConverts(absl::string_view operation, const CelValue& ref, + Duration& result) { + CelValue result_value; + + ASSERT_NO_FATAL_FAILURE(PerformRun(operation, {}, {ref}, &result_value)); + + ASSERT_EQ(result_value.IsDuration(), true); + ASSERT_EQ(result_value.DurationOrDie(), + CelProtoWrapper::CreateDuration(&result).DurationOrDie()) + << operation << " for " << CelValue::TypeName(ref.type()); + } + // Helper method. Attempts to perform a type conversion and expects an error // as the result. void TestTypeConversionError(absl::string_view operation, @@ -585,6 +597,7 @@ TEST_F(BuiltinsTest, TestDurationFunctions) { std::string result = "93541.011s"; TestTypeConverts(builtin::kString, CelProtoWrapper::CreateDuration(&ref), CelValue::StringHolder(&result)); + TestTypeConverts(builtin::kDuration, CelValue::CreateString(&result), ref); ref.set_seconds(-93541L); ref.set_nanos(-11000000L); @@ -600,12 +613,22 @@ TEST_F(BuiltinsTest, TestDurationFunctions) { result = "-93541.011s"; TestTypeConverts(builtin::kString, CelProtoWrapper::CreateDuration(&ref), CelValue::StringHolder(&result)); + TestTypeConverts(builtin::kDuration, CelValue::CreateString(&result), ref); absl::Duration d = MakeGoogleApiDurationMin() + absl::Seconds(-1); TestTypeConversionError(builtin::kString, CelValue::CreateDuration(d)); + result = absl::FormatDuration(d); + TestTypeConversionError(builtin::kDuration, CelValue::CreateString(&result)); d = MakeGoogleApiDurationMax() + absl::Seconds(1); TestTypeConversionError(builtin::kString, CelValue::CreateDuration(d)); + result = absl::FormatDuration(d); + TestTypeConversionError(builtin::kDuration, CelValue::CreateString(&result)); + + std::string inf = "inf"; + std::string ninf = "-inf"; + TestTypeConversionError(builtin::kDuration, CelValue::CreateString(&inf)); + TestTypeConversionError(builtin::kDuration, CelValue::CreateString(&ninf)); } // Test functions for Timestamp diff --git a/internal/proto_util.cc b/internal/proto_util.cc index 32789b61d..e33d267aa 100644 --- a/internal/proto_util.cc +++ b/internal/proto_util.cc @@ -14,17 +14,6 @@ namespace internal { namespace { -absl::Status Validate(absl::Duration duration) { - if (duration < MakeGoogleApiDurationMin()) { - return absl::InvalidArgumentError("duration below min"); - } - - if (duration > MakeGoogleApiDurationMax()) { - return absl::InvalidArgumentError("duration above max"); - } - return absl::OkStatus(); -} - absl::Status Validate(absl::Time time) { if (time < MakeGoogleApiTimeMin()) { return absl::InvalidArgumentError("time below min"); @@ -38,6 +27,17 @@ absl::Status Validate(absl::Time time) { } // namespace +absl::Status ValidateDuration(absl::Duration duration) { + if (duration < MakeGoogleApiDurationMin()) { + return absl::InvalidArgumentError("duration below min"); + } + + if (duration > MakeGoogleApiDurationMax()) { + return absl::InvalidArgumentError("duration above max"); + } + return absl::OkStatus(); +} + absl::Duration DecodeDuration(const google::protobuf::Duration& proto) { return absl::Seconds(proto.seconds()) + absl::Nanoseconds(proto.nanos()); } @@ -49,7 +49,7 @@ absl::Time DecodeTime(const google::protobuf::Timestamp& proto) { absl::Status EncodeDuration(absl::Duration duration, google::protobuf::Duration* proto) { - RETURN_IF_ERROR(Validate(duration)); + RETURN_IF_ERROR(ValidateDuration(duration)); // s and n may both be negative, per the Duration proto spec. const int64_t s = absl::IDivDuration(duration, absl::Seconds(1), &duration); const int64_t n = absl::IDivDuration(duration, absl::Nanoseconds(1), &duration); diff --git a/internal/proto_util.h b/internal/proto_util.h index 3134fa43f..3bb771ce3 100644 --- a/internal/proto_util.h +++ b/internal/proto_util.h @@ -20,6 +20,9 @@ struct DefaultProtoEqual { } }; +/** Validate that the duration is in the valid protobuf duration range. */ +absl::Status ValidateDuration(absl::Duration duration); + /** Helper function to encode a duration in a google::protobuf::Duration. */ absl::Status EncodeDuration(absl::Duration duration, google::protobuf::Duration* proto); From 79bef78361cb768d63fdc8c7a98dec3d8aac1d1a Mon Sep 17 00:00:00 2001 From: tswadell Date: Tue, 23 Mar 2021 08:43:17 -0700 Subject: [PATCH 15/23] Internal change PiperOrigin-RevId: 364565575 --- conformance/server.cc | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/conformance/server.cc b/conformance/server.cc index 089b7936c..dcceae3a3 100644 --- a/conformance/server.cc +++ b/conformance/server.cc @@ -1,8 +1,6 @@ -#include - +#include "google/api/expr/v1alpha1/conformance_service.pb.h" #include "google/api/expr/v1alpha1/syntax.pb.h" #include "google/api/expr/v1alpha1/checked.pb.h" -#include "google/api/expr/v1alpha1/conformance_service.pb.h" #include "google/api/expr/v1alpha1/eval.pb.h" #include "google/api/expr/v1alpha1/syntax.pb.h" #include "google/api/expr/v1alpha1/value.pb.h" From 3c4fb48eb26d8fd97a9df9dfc1b8ac9939891370 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Tue, 23 Mar 2021 10:14:47 -0700 Subject: [PATCH 16/23] Document how to handle errors when creating functions with FunctionAdadapter. No change to behaviour. PiperOrigin-RevId: 364585728 --- eval/public/cel_function_adapter.h | 3 +++ 1 file changed, 3 insertions(+) diff --git a/eval/public/cel_function_adapter.h b/eval/public/cel_function_adapter.h index a26391cb0..4f7d6e081 100644 --- a/eval/public/cel_function_adapter.h +++ b/eval/public/cel_function_adapter.h @@ -63,6 +63,9 @@ bool AddType(std::vector* arg_types) { // It accepts method implementations as std::function, allowing // them to be lambdas/regular C++ functions. CEL method descriptors are // deduced based on C++ function signatures. +// CelFunction::Evaluate will set result to the value returned by the handler. +// To handle errors, choose CelValue as the return type, and use the +// CreateError/Create* helpers in cel_value.h. // // Usage example: // From ae60aee73b82c0d3a4e7ba497b559ed9ba767596 Mon Sep 17 00:00:00 2001 From: kuat Date: Tue, 23 Mar 2021 21:01:19 -0700 Subject: [PATCH 17/23] Enforce that all durations are within the valid range. PiperOrigin-RevId: 364714161 --- conformance/BUILD | 1 - eval/public/builtin_func_registrar.cc | 6 ------ eval/public/builtin_func_test.cc | 16 ++++++++++++++-- eval/public/cel_value.cc | 17 +++++++++++++++++ eval/public/cel_value.h | 4 +--- 5 files changed, 32 insertions(+), 12 deletions(-) diff --git a/conformance/BUILD b/conformance/BUILD index 114526512..e1b1bbf66 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -118,7 +118,6 @@ cc_binary( "--skip_test=enums/legacy_proto2/select_big,select_neg,assign_standalone_int_too_big,assign_standalone_int_too_neg", "--skip_test=enums/legacy_proto3/assign_standalone_int_too_big,assign_standalone_int_too_neg", # TODO(issues/118): Duration and timestamp range errors should result in evaluation errors. - "--skip_test=timestamps/duration_range/add_over,add_under", "--skip_test=timestamps/timestamp_range", # Future features for CEL 1.0 diff --git a/eval/public/builtin_func_registrar.cc b/eval/public/builtin_func_registrar.cc index 42d88da2d..ceb99538b 100644 --- a/eval/public/builtin_func_registrar.cc +++ b/eval/public/builtin_func_registrar.cc @@ -745,12 +745,6 @@ CelValue CreateDurationFromString(Arena* arena, absl::StatusCode::kInvalidArgument); } - absl::Status status = google::api::expr::internal::ValidateDuration(d); - if (!status.ok()) { - return CreateErrorValue(arena, "Duration is out of range", - absl::StatusCode::kInvalidArgument); - } - return CelValue::CreateDuration(d); } diff --git a/eval/public/builtin_func_test.cc b/eval/public/builtin_func_test.cc index fb913ee27..5cff08eda 100644 --- a/eval/public/builtin_func_test.cc +++ b/eval/public/builtin_func_test.cc @@ -577,6 +577,20 @@ TEST_F(BuiltinsTest, TestTimestampDurationArithmeticalOperation) { ASSERT_EQ(result_value.IsDuration(), true); ASSERT_EQ(absl::ToInt64Nanoseconds(result_value.DurationOrDie()), TimeUtil::DurationToNanoseconds(d0)); + + const auto min = CelValue::CreateDuration(MakeGoogleApiDurationMin()); + ASSERT_TRUE(min.IsDuration()); + ASSERT_NO_FATAL_FAILURE(PerformRun( + builtin::kSubtract, {}, + {min, CelValue::CreateDuration(absl::Nanoseconds(1))}, &result_value)); + ASSERT_TRUE(result_value.IsError()); + + const auto max = CelValue::CreateDuration(MakeGoogleApiDurationMax()); + ASSERT_TRUE(max.IsDuration()); + ASSERT_NO_FATAL_FAILURE(PerformRun( + builtin::kAdd, {}, {max, CelValue::CreateDuration(absl::Nanoseconds(1))}, + &result_value)); + ASSERT_TRUE(result_value.IsError()); } // Test functions for Duration @@ -616,12 +630,10 @@ TEST_F(BuiltinsTest, TestDurationFunctions) { TestTypeConverts(builtin::kDuration, CelValue::CreateString(&result), ref); absl::Duration d = MakeGoogleApiDurationMin() + absl::Seconds(-1); - TestTypeConversionError(builtin::kString, CelValue::CreateDuration(d)); result = absl::FormatDuration(d); TestTypeConversionError(builtin::kDuration, CelValue::CreateString(&result)); d = MakeGoogleApiDurationMax() + absl::Seconds(1); - TestTypeConversionError(builtin::kString, CelValue::CreateDuration(d)); result = absl::FormatDuration(d); TestTypeConversionError(builtin::kDuration, CelValue::CreateString(&result)); diff --git a/eval/public/cel_value.cc b/eval/public/cel_value.cc index 2d47bdbb0..c81f7427d 100644 --- a/eval/public/cel_value.cc +++ b/eval/public/cel_value.cc @@ -42,8 +42,25 @@ constexpr absl::string_view kListTypeName = "list"; constexpr absl::string_view kMapTypeName = "map"; constexpr absl::string_view kCelTypeTypeName = "type"; +// Exclusive bounds for valid duration values. +constexpr absl::Duration kDurationHigh = absl::Seconds(315576000001); +constexpr absl::Duration kDurationLow = absl::Seconds(-315576000001); + +const absl::Status* DurationOverflowError() { + static const auto* const kDurationOverflow = new absl::Status( + absl::StatusCode::kInvalidArgument, "Duration is out of range"); + return kDurationOverflow; +} + } // namespace +CelValue CelValue::CreateDuration(absl::Duration value) { + if (value >= kDurationHigh || value <= kDurationLow) { + return CelValue(DurationOverflowError()); + } + return CelValue(value); +} + std::string CelValue::TypeName(Type value_type) { switch (value_type) { case Type::kBool: diff --git a/eval/public/cel_value.h b/eval/public/cel_value.h index a7c34bdce..6b0fa400b 100644 --- a/eval/public/cel_value.h +++ b/eval/public/cel_value.h @@ -183,9 +183,7 @@ class CelValue { return CelValue(BytesHolder(str)); } - static CelValue CreateDuration(absl::Duration value) { - return CelValue(value); - } + static CelValue CreateDuration(absl::Duration value); static CelValue CreateTimestamp(absl::Time value) { return CelValue(value); } From 86db81c2e25659bb8832d11f135eb056bc24559d Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 24 Mar 2021 08:58:06 -0700 Subject: [PATCH 18/23] Apply clang-tidy fixes PiperOrigin-RevId: 364814147 --- common/escaping.cc | 4 ++-- common/value.cc | 2 -- eval/eval/attribute_utility_test.cc | 1 - eval/eval/comprehension_step.cc | 2 +- eval/eval/container_access_step.cc | 4 ++-- eval/eval/function_step_test.cc | 3 ++- eval/eval/ternary_step.cc | 2 +- eval/public/cel_function_provider_test.cc | 1 - eval/public/extension_func_test.cc | 2 -- eval/public/set_util.cc | 2 +- eval/public/structs/BUILD | 2 +- eval/public/structs/cel_proto_wrapper.cc | 6 +++--- eval/public/testing/matchers.cc | 2 +- eval/public/value_export_util.cc | 4 ---- eval/tests/unknowns_end_to_end_test.cc | 8 ++++---- internal/holder_test.cc | 4 ++-- parser/parser_test.cc | 2 +- protoutil/type_registry.cc | 3 ++- testutil/expr_printer.cc | 8 ++++---- testutil/test_data_gen.cc | 2 +- testutil/test_data_util.cc | 2 +- tools/flatbuffers_backed_impl.cc | 2 +- 22 files changed, 30 insertions(+), 38 deletions(-) diff --git a/common/escaping.cc b/common/escaping.cc index 26b1367b8..98e7e8d28 100644 --- a/common/escaping.cc +++ b/common/escaping.cc @@ -241,7 +241,7 @@ inline std::tuple unescape_char( } if (value < 0x80 || !encode) { - char tmp[2] = {(char)value, '\0'}; + char tmp[2] = {static_cast(value), '\0'}; return std::make_tuple(std::string(tmp), s, ""); } else { char tmp[5]; @@ -294,7 +294,7 @@ absl::optional unescape(const std::string& s, bool is_bytes) { } value = value.substr(1, n - 2); // If there is nothing to escape, then return. - if (is_raw_literal || (value.find('\\') == std::string::npos)) { + if (is_raw_literal || (!absl::StrContains(value, '\\'))) { return value; } diff --git a/common/value.cc b/common/value.cc index 429d2b412..47eaab59c 100644 --- a/common/value.cc +++ b/common/value.cc @@ -20,8 +20,6 @@ namespace api { namespace expr { namespace common { -using ::google::api::expr::internal::NotFoundError; - namespace { static constexpr const Value::Kind kIndexToKind[] = { diff --git a/eval/eval/attribute_utility_test.cc b/eval/eval/attribute_utility_test.cc index b7a09d4a8..48bf9a901 100644 --- a/eval/eval/attribute_utility_test.cc +++ b/eval/eval/attribute_utility_test.cc @@ -15,7 +15,6 @@ namespace runtime { using google::api::expr::v1alpha1::Expr; using testing::Eq; -using testing::IsNull; using testing::NotNull; using testing::SizeIs; using testing::UnorderedPointwise; diff --git a/eval/eval/comprehension_step.cc b/eval/eval/comprehension_step.cc index a42cf822a..5f4f471a6 100644 --- a/eval/eval/comprehension_step.cc +++ b/eval/eval/comprehension_step.cc @@ -209,7 +209,7 @@ absl::Status ComprehensionFinish::Evaluate(ExecutionFrame* frame) const { class ListKeysStep : public ExpressionStepBase { public: - ListKeysStep(int64_t expr_id) : ExpressionStepBase(expr_id, false) {} + explicit ListKeysStep(int64_t expr_id) : ExpressionStepBase(expr_id, false) {} absl::Status Evaluate(ExecutionFrame* frame) const override; private: diff --git a/eval/eval/container_access_step.cc b/eval/eval/container_access_step.cc index 619a065d5..eea6cca70 100644 --- a/eval/eval/container_access_step.cc +++ b/eval/eval/container_access_step.cc @@ -22,7 +22,7 @@ constexpr int NUM_CONTAINER_ACCESS_ARGUMENTS = 2; // message. class ContainerAccessStep : public ExpressionStepBase { public: - ContainerAccessStep(int64_t expr_id) : ExpressionStepBase(expr_id) {} + explicit ContainerAccessStep(int64_t expr_id) : ExpressionStepBase(expr_id) {} absl::Status Evaluate(ExecutionFrame* frame) const override; @@ -54,7 +54,7 @@ inline CelValue ContainerAccessStep::LookupInMap(const CelMap* cel_map, break; } } - return CreateNoSuchKeyError(arena, absl::StrCat("Key not found in map")); + return CreateNoSuchKeyError(arena, "Key not found in map"); } inline CelValue ContainerAccessStep::LookupInList(const CelList* cel_list, diff --git a/eval/eval/function_step_test.cc b/eval/eval/function_step_test.cc index a27f33a8a..f3f210ab4 100644 --- a/eval/eval/function_step_test.cc +++ b/eval/eval/function_step_test.cc @@ -122,7 +122,8 @@ class AddFunction : public CelFunction { class SinkFunction : public CelFunction { public: - SinkFunction(CelValue::Type type) : CelFunction(CreateDescriptor(type)) {} + explicit SinkFunction(CelValue::Type type) + : CelFunction(CreateDescriptor(type)) {} static CelFunctionDescriptor CreateDescriptor(CelValue::Type type) { return CelFunctionDescriptor{"Sink", false, {type}}; diff --git a/eval/eval/ternary_step.cc b/eval/eval/ternary_step.cc index 420cf10e5..f63006711 100644 --- a/eval/eval/ternary_step.cc +++ b/eval/eval/ternary_step.cc @@ -17,7 +17,7 @@ namespace { class TernaryStep : public ExpressionStepBase { public: // Constructs FunctionStep that uses overloads specified. - TernaryStep(int64_t expr_id) : ExpressionStepBase(expr_id) {} + explicit TernaryStep(int64_t expr_id) : ExpressionStepBase(expr_id) {} absl::Status Evaluate(ExecutionFrame* frame) const override; }; diff --git a/eval/public/cel_function_provider_test.cc b/eval/public/cel_function_provider_test.cc index 8a6ff3e42..897cffe12 100644 --- a/eval/public/cel_function_provider_test.cc +++ b/eval/public/cel_function_provider_test.cc @@ -10,7 +10,6 @@ namespace expr { namespace runtime { namespace { -using testing::_; using testing::Eq; using testing::HasSubstr; using testing::Ne; diff --git a/eval/public/extension_func_test.cc b/eval/public/extension_func_test.cc index aad518b7e..479dcef58 100644 --- a/eval/public/extension_func_test.cc +++ b/eval/public/extension_func_test.cc @@ -13,8 +13,6 @@ namespace expr { namespace runtime { namespace { -using google::protobuf::Duration; -using google::protobuf::Timestamp; using google::protobuf::Arena; static const int kNanosPerSecond = 1000000000; diff --git a/eval/public/set_util.cc b/eval/public/set_util.cc index fd85903f1..885d9031f 100644 --- a/eval/public/set_util.cc +++ b/eval/public/set_util.cc @@ -89,7 +89,7 @@ int ComparisonImpl(const CelMap* lhs, const CelMap* rhs) { struct ComparisonVisitor { CelValue rhs; - ComparisonVisitor(CelValue rhs) : rhs(rhs) {} + explicit ComparisonVisitor(CelValue rhs) : rhs(rhs) {} template int operator()(T lhs_value) { T rhs_value; diff --git a/eval/public/structs/BUILD b/eval/public/structs/BUILD index 3660d1e31..c5f701193 100644 --- a/eval/public/structs/BUILD +++ b/eval/public/structs/BUILD @@ -13,7 +13,7 @@ cc_library( deps = [ "//eval/public:cel_value", "//internal:proto_util", - "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:optional", diff --git a/eval/public/structs/cel_proto_wrapper.cc b/eval/public/structs/cel_proto_wrapper.cc index 77d318c27..bcd0867d0 100644 --- a/eval/public/structs/cel_proto_wrapper.cc +++ b/eval/public/structs/cel_proto_wrapper.cc @@ -3,7 +3,7 @@ #include "google/protobuf/any.pb.h" #include "google/protobuf/struct.pb.h" #include "google/protobuf/wrappers.pb.h" -#include "absl/container/node_hash_map.h" +#include "absl/container/flat_hash_map.h" #include "absl/strings/substitute.h" #include "absl/synchronization/mutex.h" @@ -146,7 +146,7 @@ CelValue ValueFromMessage(const Struct* struct_value, Arena* arena) { CelValue ValueFromMessage(const Any* any_value, Arena* arena) { auto type_url = any_value->type_url(); - auto pos = type_url.find_last_of("/"); + auto pos = type_url.find_last_of('/'); if (pos == absl::string_view::npos) { // TODO(issues/25) What error code? // Malformed type_url @@ -320,7 +320,7 @@ class ValueFromMessageMaker { factories_.emplace(desc, std::move(factory)); } - absl::node_hash_map> factories_; }; diff --git a/eval/public/testing/matchers.cc b/eval/public/testing/matchers.cc index 2315a0127..930df8d77 100644 --- a/eval/public/testing/matchers.cc +++ b/eval/public/testing/matchers.cc @@ -42,7 +42,7 @@ class CelValueEqualImpl : public MatcherInterface { template class CelValueMatcherImpl : public testing::MatcherInterface { public: - CelValueMatcherImpl(testing::Matcher m) + explicit CelValueMatcherImpl(testing::Matcher m) : underlying_type_matcher_(m) {} bool MatchAndExplain(const CelValue& v, testing::MatchResultListener* listener) const override { diff --git a/eval/public/value_export_util.cc b/eval/public/value_export_util.cc index 56274ebf3..d1c9892a4 100644 --- a/eval/public/value_export_util.cc +++ b/eval/public/value_export_util.cc @@ -12,12 +12,8 @@ namespace expr { namespace runtime { using google::protobuf::Duration; -using google::protobuf::ListValue; -using google::protobuf::Struct; using google::protobuf::Timestamp; using google::protobuf::Value; -using google::protobuf::FieldDescriptor; -using google::protobuf::Message; using google::protobuf::util::TimeUtil; absl::Status KeyAsString(const CelValue& value, std::string* key) { diff --git a/eval/tests/unknowns_end_to_end_test.cc b/eval/tests/unknowns_end_to_end_test.cc index 672846534..e709669b2 100644 --- a/eval/tests/unknowns_end_to_end_test.cc +++ b/eval/tests/unknowns_end_to_end_test.cc @@ -277,7 +277,7 @@ TEST_F(UnknownsTest, UnknownFunctions) { ASSERT_OK(maybe_response); CelValue response = maybe_response.value(); - ASSERT_TRUE(response.IsUnknownSet()) << response.ErrorOrDie()->ToString(); + ASSERT_TRUE(response.IsUnknownSet()) << *response.ErrorOrDie(); EXPECT_THAT(response.UnknownSetOrDie() ->unknown_function_results() .unknown_function_results(), @@ -302,7 +302,7 @@ TEST_F(UnknownsTest, UnknownsMerge) { ASSERT_OK(maybe_response); CelValue response = maybe_response.value(); - ASSERT_TRUE(response.IsUnknownSet()) << response.ErrorOrDie()->ToString(); + ASSERT_TRUE(response.IsUnknownSet()) << *response.ErrorOrDie(); EXPECT_THAT(response.UnknownSetOrDie() ->unknown_function_results() .unknown_function_results(), @@ -454,7 +454,7 @@ TEST_F(UnknownsCompTest, UnknownsMerge) { ASSERT_OK(eval_status); CelValue response = eval_status.value(); - ASSERT_TRUE(response.IsUnknownSet()) << response.ErrorOrDie()->ToString(); + ASSERT_TRUE(response.IsUnknownSet()) << *response.ErrorOrDie(); EXPECT_THAT(response.UnknownSetOrDie() ->unknown_function_results() .unknown_function_results(), @@ -589,7 +589,7 @@ TEST_F(UnknownsCompCondTest, UnknownConditionReturned) { ASSERT_OK(eval_status); CelValue response = eval_status.value(); - ASSERT_TRUE(response.IsUnknownSet()) << response.ErrorOrDie()->ToString(); + ASSERT_TRUE(response.IsUnknownSet()) << *response.ErrorOrDie(); // The comprehension ends on the first non-bool condition, so we only get one // call captured in the UnknownSet. EXPECT_THAT(response.UnknownSetOrDie() diff --git a/internal/holder_test.cc b/internal/holder_test.cc index ac06d043c..4cbf51331 100644 --- a/internal/holder_test.cc +++ b/internal/holder_test.cc @@ -149,7 +149,7 @@ TEST(Holder, UnownedPtr) { EXPECT_TRUE(std::is_move_assignable::value); // Null cannot be accessed. - HolderType holder(static_cast(0)); + HolderType holder(static_cast(nullptr)); #ifndef NDEBUG // Assert only throws when debugging. EXPECT_DEATH(holder.value(), "null"); holder = HolderType(nullptr); @@ -185,7 +185,7 @@ TEST(Holder, UnownedPtr_const) { EXPECT_TRUE(std::is_move_assignable::value); // Null cannot be accessed. - HolderType holder(static_cast(0)); + HolderType holder(static_cast(nullptr)); #ifndef NDEBUG // Assert only throws when debugging. EXPECT_DEATH(holder.value(), "null"); holder = HolderType(nullptr); diff --git a/parser/parser_test.cc b/parser/parser_test.cc index 860e7f813..f93efa681 100644 --- a/parser/parser_test.cc +++ b/parser/parser_test.cc @@ -864,7 +864,7 @@ class KindAndIdAdorner : public testutil::ExpressionAdorner { class LocationAdorner : public testutil::ExpressionAdorner { public: - LocationAdorner(const google::api::expr::v1alpha1::SourceInfo& source_info) + explicit LocationAdorner(const google::api::expr::v1alpha1::SourceInfo& source_info) : source_info_(source_info) {} absl::optional> getLocation(int64_t id) const { diff --git a/protoutil/type_registry.cc b/protoutil/type_registry.cc index ba5c7854a..0bfc15e24 100644 --- a/protoutil/type_registry.cc +++ b/protoutil/type_registry.cc @@ -150,7 +150,8 @@ template class UnrecognizedMessageObject final : public common::Object { public: template - UnrecognizedMessageObject(T&& value) : holder_(std::forward(value)) {} + explicit UnrecognizedMessageObject(T&& value) + : holder_(std::forward(value)) {} common::Value GetMember(absl::string_view name) const override { return common::Value::FromError( diff --git a/testutil/expr_printer.cc b/testutil/expr_printer.cc index a8b5bde9a..8b618ede3 100644 --- a/testutil/expr_printer.cc +++ b/testutil/expr_printer.cc @@ -15,11 +15,11 @@ using ::google::api::expr::v1alpha1::Expr; class EmptyAdorner : public ExpressionAdorner { public: - ~EmptyAdorner() {} + ~EmptyAdorner() override {} - std::string adorn(const Expr& e) const { return ""; } + std::string adorn(const Expr& e) const override { return ""; } - std::string adorn(const Expr::CreateStruct::Entry& e) const { + std::string adorn(const Expr::CreateStruct::Entry& e) const override { return ""; } }; @@ -28,7 +28,7 @@ const EmptyAdorner the_empty_adorner; class Writer { public: - Writer(const ExpressionAdorner& adorner) + explicit Writer(const ExpressionAdorner& adorner) : adorner_(adorner), line_start_(true), indent_(0) {} void appendExpr(const Expr& e) { diff --git a/testutil/test_data_gen.cc b/testutil/test_data_gen.cc index 1701442e2..13c6f9d58 100644 --- a/testutil/test_data_gen.cc +++ b/testutil/test_data_gen.cc @@ -86,7 +86,7 @@ TestData UniqueValues() { expr::internal::MakeGoogleApiDurationMin() + absl::Nanoseconds(1), "min+1"); add_val = NewValue(absl::Nanoseconds(-1), "-1"); - add_val = NewValue(absl::Nanoseconds(0), "0"); + add_val = NewValue(absl::ZeroDuration(), "0"); add_val = NewValue(absl::Nanoseconds(1), "1"); add_val = NewValue( expr::internal::MakeGoogleApiDurationMax() - absl::Nanoseconds(1), diff --git a/testutil/test_data_util.cc b/testutil/test_data_util.cc index fc8f894ab..b18e2d39d 100644 --- a/testutil/test_data_util.cc +++ b/testutil/test_data_util.cc @@ -345,7 +345,7 @@ TestValue NewBytesValue(absl::string_view value, absl::string_view name) { // wrapped values. google::protobuf::BytesValue wrapped_bytes; - wrapped_bytes.set_value(std::string(value)); + wrapped_bytes.set_value(value.data()); *AddProto(&result, "wrapped_bytes")->mutable_wrapped_bytes() = wrapped_bytes; AddProto(&result, "single_any") ->mutable_single_any() diff --git a/tools/flatbuffers_backed_impl.cc b/tools/flatbuffers_backed_impl.cc index 55f3e4852..ffeb06044 100644 --- a/tools/flatbuffers_backed_impl.cc +++ b/tools/flatbuffers_backed_impl.cc @@ -34,7 +34,7 @@ class FlatBuffersListImpl : public CelList { class StringListImpl : public CelList { public: - StringListImpl( + explicit StringListImpl( const flatbuffers::Vector>* list) : list_(list) {} int size() const override { return list_ ? list_->size() : 0; } From f3114eedefce399b27ae5befa8b696762247d4b2 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 25 Mar 2021 08:37:49 -0700 Subject: [PATCH 19/23] Remove some unnecessary copying by using or fixing move semantics PiperOrigin-RevId: 365044313 --- eval/public/BUILD | 2 ++ eval/public/cel_function.h | 16 +++++++++++----- eval/public/cel_function_adapter.h | 14 ++++++++------ 3 files changed, 21 insertions(+), 11 deletions(-) diff --git a/eval/public/BUILD b/eval/public/BUILD index 23ad142f4..a4214ce0b 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -122,6 +122,8 @@ cc_library( ], deps = [ ":cel_value", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], ) diff --git a/eval/public/cel_function.h b/eval/public/cel_function.h index 7e1dc0275..d789a02d9 100644 --- a/eval/public/cel_function.h +++ b/eval/public/cel_function.h @@ -1,8 +1,12 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_FUNCTION_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_FUNCTION_H_ +#include +#include #include +#include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "eval/public/cel_value.h" @@ -15,9 +19,11 @@ namespace runtime { // This complex structure is needed for overloads support. class CelFunctionDescriptor { public: - CelFunctionDescriptor(const std::string& name, bool receiver_style, - const std::vector types) - : name_(name), receiver_style_(receiver_style), types_(types) {} + CelFunctionDescriptor(absl::string_view name, bool receiver_style, + std::vector types) + : name_(name), + receiver_style_(receiver_style), + types_(std::move(types)) {} // Function name. const std::string& name() const { return name_; } @@ -56,8 +62,8 @@ class CelFunctionDescriptor { class CelFunction { public: // Build CelFunction from descriptor - explicit CelFunction(const CelFunctionDescriptor& descriptor) - : descriptor_(descriptor) {} + explicit CelFunction(CelFunctionDescriptor descriptor) + : descriptor_(std::move(descriptor)) {} // Non-copyable CelFunction(const CelFunction& other) = delete; diff --git a/eval/public/cel_function_adapter.h b/eval/public/cel_function_adapter.h index 4f7d6e081..7bca2ca96 100644 --- a/eval/public/cel_function_adapter.h +++ b/eval/public/cel_function_adapter.h @@ -1,6 +1,8 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_FUNCTION_ADAPTER_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_FUNCTION_ADAPTER_H_ + #include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -84,13 +86,14 @@ class FunctionAdapter : public CelFunction { public: using FuncType = std::function; - FunctionAdapter(const CelFunctionDescriptor& descriptor, FuncType handler) - : CelFunction(descriptor), handler_(std::move(handler)) {} + FunctionAdapter(CelFunctionDescriptor descriptor, FuncType handler) + : CelFunction(std::move(descriptor)), handler_(std::move(handler)) {} static absl::StatusOr> Create( absl::string_view name, bool receiver_type, std::function handler) { std::vector arg_types; + arg_types.reserve(sizeof...(Arguments)); if (!internal::AddType<0, Arguments...>(&arg_types)) { return absl::Status( @@ -99,10 +102,9 @@ class FunctionAdapter : public CelFunction { ": failed to determine input parameter type")); } - std::unique_ptr cel_func = absl::make_unique( - CelFunctionDescriptor(std::string(name), receiver_type, arg_types), + return absl::make_unique( + CelFunctionDescriptor(name, receiver_type, std::move(arg_types)), std::move(handler)); - return std::move(cel_func); } // Creates function handler and attempts to register it with @@ -116,7 +118,7 @@ class FunctionAdapter : public CelFunction { return status.status(); } - return registry->Register(std::move(status.value())); + return registry->Register(std::move(status).value()); } #if defined(__clang_major_version__) && __clang_major_version__ >= 8 && !defined(__APPLE__) From 958554e3f716be3ae43128f7872cc43ba7872e74 Mon Sep 17 00:00:00 2001 From: tswadell Date: Thu, 25 Mar 2021 15:17:46 -0700 Subject: [PATCH 20/23] Fix a timeout which occurs due to an inefficient string replacement algo used internal to Antlr. Replaced default implementation with `absl::StrReplaceAll` PiperOrigin-RevId: 365135683 --- parser/BUILD | 1 + parser/parser.cc | 21 +++++++++++++++++++++ parser/parser_test.cc | 8 ++++++++ 3 files changed, 30 insertions(+) diff --git a/parser/BUILD b/parser/BUILD index 3f2be39f1..23e382c8b 100644 --- a/parser/BUILD +++ b/parser/BUILD @@ -29,6 +29,7 @@ cc_library( ":visitor", "@antlr4_runtimes//:cpp", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", diff --git a/parser/parser.cc b/parser/parser.cc index cedce38bd..9fb31e118 100644 --- a/parser/parser.cc +++ b/parser/parser.cc @@ -2,6 +2,7 @@ #include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "absl/strings/str_replace.h" #include "absl/types/optional.h" #include "parser/cel_grammar.inc/cel_grammar/CelLexer.h" #include "parser/cel_grammar.inc/cel_grammar/CelParser.h" @@ -31,6 +32,17 @@ using ::cel_grammar::CelParser; namespace { +// Replacements for absl::StrReplaceAll for escaping standard whitespace +// characters. +static constexpr auto kStandardReplacements = + std::array, 3>{ + std::make_pair("\n", "\\n"), + std::make_pair("\r", "\\r"), + std::make_pair("\t", "\\t"), + }; + +static constexpr absl::string_view kSingleQuote = "'"; + // ExprRecursionListener extends the standard ANTLR CelParser to ensure that // recursive entries into the 'expr' rule are limited to a configurable depth so // as to prevent stack overflows. @@ -87,6 +99,15 @@ class RecoveryLimitErrorStrategy : public antlr4::DefaultErrorStrategy { return antlr4::DefaultErrorStrategy::recoverInline(recognizer); } + protected: + std::string escapeWSAndQuote(const std::string& s) const override { + std::string result; + result.reserve(s.size() + 2); + absl::StrAppend(&result, kSingleQuote, s, kSingleQuote); + absl::StrReplaceAll(kStandardReplacements, &result); + return result; + } + private: void checkRecoveryLimit(antlr4::Parser* recognizer) { if (recovery_attempts_++ >= recovery_limit_) { diff --git a/parser/parser_test.cc b/parser/parser_test.cc index f93efa681..1b9c3ca2c 100644 --- a/parser/parser_test.cc +++ b/parser/parser_test.cc @@ -821,6 +821,14 @@ std::vector test_cases = { "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[['just fine'],[1],[2],[3],[4],[5]]]]]]]" "]]]]]]]]]]]]]]]]]]]]]]]]", "" // parse output not validated as it is too large. + }, + { + "[\n\t\r[\n\t\r[\n\t\r]\n\t\r]\n\t\r", + "", // parse output not validated as it is too large. + "ERROR: :6:3: Syntax error: mismatched input '' expecting " + "{']', ','}\n" + " | \r\n" + " | ..^", }}; class KindAndIdAdorner : public testutil::ExpressionAdorner { From e88e8fee77173e52e2da5fd185548f4cd434e63d Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Tue, 30 Mar 2021 08:37:17 -0700 Subject: [PATCH 21/23] Replace internal/port.h with absl/meta/type_traits.h PiperOrigin-RevId: 365817691 --- common/BUILD | 2 ++ common/converters.h | 5 +-- common/value_test.cc | 7 +++-- internal/BUILD | 15 +++------ internal/cast.h | 4 +-- internal/hash_util.h | 2 +- internal/holder.h | 4 +-- internal/port.h | 66 --------------------------------------- internal/ref_countable.h | 6 ++-- internal/types.h | 41 ++++++++++++------------ internal/value_internal.h | 22 ++++++------- 11 files changed, 55 insertions(+), 119 deletions(-) delete mode 100644 internal/port.h diff --git a/common/BUILD b/common/BUILD index 1a5342a5d..04eb740f6 100644 --- a/common/BUILD +++ b/common/BUILD @@ -196,6 +196,7 @@ cc_test( "//internal:types", "//internal:value_internal", "//testutil:util", + "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/strings", "@com_google_googleapis//google/rpc:status_cc_proto", "@com_google_googleapis//google/type:money_cc_proto", @@ -218,6 +219,7 @@ cc_library( "//internal:list_impl", "//internal:map_impl", "//internal:types", + "@com_google_absl//absl/meta:type_traits", ], ) diff --git a/common/converters.h b/common/converters.h index 2a1bc59a7..c8836ed9b 100644 --- a/common/converters.h +++ b/common/converters.h @@ -5,6 +5,7 @@ #include +#include "absl/meta/type_traits.h" #include "common/parent_ref.h" #include "common/value.h" #include "internal/list_impl.h" @@ -53,7 +54,7 @@ template Value ValueFromList(T&& value) { static_assert(!std::is_pointer::value, "use ValueForList"); return Value::MakeList< - internal::ListWrapper>>( + internal::ListWrapper>>( std::forward(value)); } @@ -92,7 +93,7 @@ template Value ValueFromMap(T&& value) { static_assert(!std::is_pointer::value, "use ValueForList"); return Value::MakeMap< - internal::MapWrapper>>( + internal::MapWrapper>>( std::forward(value)); } diff --git a/common/value_test.cc b/common/value_test.cc index 3a6a60c3a..95b8646e3 100644 --- a/common/value_test.cc +++ b/common/value_test.cc @@ -10,6 +10,7 @@ #include "google/type/money.pb.h" #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "absl/meta/type_traits.h" #include "absl/strings/str_cat.h" #include "common/custom_object.h" #include "internal/status_util.h" @@ -709,7 +710,7 @@ class ValueVisitTest : public ::testing::Test { template void TestGetPtrVisitor() { - using T = remove_reference_t())>; + using T = absl::remove_reference_t())>; ExpectSameType>(); // Return by const ref. ExpectSameType>(); @@ -717,7 +718,7 @@ class ValueVisitTest : public ::testing::Test { template void TestGetVisitor() { - using T = remove_reference_t())>; + using T = absl::remove_reference_t())>; // Return by optional. ExpectSameType, GetVisitorType, T>>(); @@ -727,7 +728,7 @@ class ValueVisitTest : public ::testing::Test { template void TestValueAdapter() { - using T = remove_reference_t())>; + using T = absl::remove_reference_t())>; ExpectSameType>(); ExpectSameType>(); } diff --git a/internal/BUILD b/internal/BUILD index a6fc4d3a2..1c2e70b72 100644 --- a/internal/BUILD +++ b/internal/BUILD @@ -16,6 +16,7 @@ cc_library( deps = [ ":holder", ":specialize", + "@com_google_absl//absl/meta:type_traits", ], ) @@ -57,10 +58,10 @@ cc_library( "holder.h", ], deps = [ - ":port", ":specialize", ":types", ":visitor_util", + "@com_google_absl//absl/meta:type_traits", ], ) @@ -85,8 +86,8 @@ cc_library( "hash_util.h", ], deps = [ - ":port", ":specialize", + "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@com_google_googleapis//google/rpc:status_cc_proto", @@ -129,7 +130,6 @@ cc_test( ":visitor_util", "//testutil:util", "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", ], ) @@ -236,11 +236,6 @@ cc_test( ], ) -cc_library( - name = "port", - hdrs = ["port.h"], -) - cc_library( name = "specialize", hdrs = ["specialize.h"], @@ -250,10 +245,10 @@ cc_library( name = "cast", hdrs = ["cast.h"], deps = [ - ":port", ":specialize", ":types", "@com_google_absl//absl/memory", + "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/types:optional", ], ) @@ -273,9 +268,9 @@ cc_library( "types.h", ], deps = [ - ":port", ":specialize", "@com_google_absl//absl/memory", + "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/strings", ], ) diff --git a/internal/cast.h b/internal/cast.h index f5c3349b4..6001ab512 100644 --- a/internal/cast.h +++ b/internal/cast.h @@ -5,8 +5,8 @@ #include #include "absl/memory/memory.h" +#include "absl/meta/type_traits.h" #include "absl/types/optional.h" -#include "internal/port.h" #include "internal/specialize.h" #include "internal/types.h" @@ -63,7 +63,7 @@ struct StaticDownCastHelper> { template struct RepresentableAsHelper { static constexpr bool check(const U&) { - return std::is_same, U>::value; + return std::is_same, U>::value; } }; diff --git a/internal/hash_util.h b/internal/hash_util.h index 655c75c1c..dda25586f 100644 --- a/internal/hash_util.h +++ b/internal/hash_util.h @@ -6,9 +6,9 @@ #include "google/protobuf/any.pb.h" #include "google/rpc/status.pb.h" +#include "absl/meta/type_traits.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" -#include "internal/port.h" #include "internal/specialize.h" namespace google { diff --git a/internal/holder.h b/internal/holder.h index 76833aef1..8e6370c9c 100644 --- a/internal/holder.h +++ b/internal/holder.h @@ -4,7 +4,7 @@ #include #include -#include "internal/port.h" +#include "absl/meta/type_traits.h" #include "internal/specialize.h" #include "internal/types.h" #include "internal/visitor_util.h" @@ -69,7 +69,7 @@ struct Copy : BaseHolderPolicy { constexpr static const bool kOwnsValue = true; template - using ValueType = remove_const_t; + using ValueType = absl::remove_const_t; template static T& get(T& value) { diff --git a/internal/port.h b/internal/port.h deleted file mode 100644 index 07473fed7..000000000 --- a/internal/port.h +++ /dev/null @@ -1,66 +0,0 @@ -// This files is a forwarding header for other headers containing various -// portability macros and functions. - -#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_PORT_H_ -#define THIRD_PARTY_CEL_CPP_INTERNAL_PORT_H_ - -#include - -namespace google { -namespace api { -namespace expr { -namespace internal { - -// Back port some helpers. -// Defined in std as of c++14. -template -using decay_t = typename std::decay::type; -template -using enable_if_t = typename std::enable_if::type; -template -using conditional_t = typename std::conditional::type; -template -using remove_const_t = typename std::remove_const::type; -template -using remove_reference_t = typename std::remove_reference::type; -template -using remove_cv_t = typename std::remove_cv::type; -template -using remove_const_t = typename std::remove_const::type; -template -using remove_volatile_t = typename std::remove_volatile::type; - -// Defined in std as of c++17 -template -struct conjunction : std::true_type {}; -template -struct conjunction : T {}; -template -struct conjunction - : std::conditional, T>::type {}; -template -struct disjunction : std::false_type {}; -template -struct disjunction : B1 {}; -template -struct disjunction - : conditional_t> {}; -template -using bool_constant = std::integral_constant; -template -struct negation : bool_constant(B::value)> {}; - -// Defined in std as of c++20 -template -struct remove_cvref { - typedef remove_cv_t> type; -}; -template -using remove_cvref_t = typename remove_cvref::type; - -} // namespace internal -} // namespace expr -} // namespace api -} // namespace google - -#endif // THIRD_PARTY_CEL_CPP_INTERNAL_PORT_H_ diff --git a/internal/ref_countable.h b/internal/ref_countable.h index 514cbce88..e455d2c16 100644 --- a/internal/ref_countable.h +++ b/internal/ref_countable.h @@ -6,6 +6,7 @@ #include #include +#include "absl/meta/type_traits.h" #include "internal/holder.h" #include "internal/specialize.h" @@ -240,8 +241,9 @@ using RefCopyHolder = Holder>; // If the size of T is smaller than MAX, it is stored inline, otherwise // it is stored in a heap allocated, reference counted holder. template -using SizeLimitHolder = conditional_t, - Holder>>; +using SizeLimitHolder = + absl::conditional_t, + Holder>>; template void ReffedPtr::reset() { diff --git a/internal/types.h b/internal/types.h index d2bd2a270..55e9f3a38 100644 --- a/internal/types.h +++ b/internal/types.h @@ -7,8 +7,8 @@ #include #include "absl/memory/memory.h" +#include "absl/meta/type_traits.h" #include "absl/strings/string_view.h" -#include "internal/port.h" #include "internal/specialize.h" namespace google { @@ -18,11 +18,11 @@ namespace internal { // Short names for AND, OR and NOT. template -using and_t = conjunction; +using and_t = absl::conjunction; template -using or_t = disjunction; +using or_t = absl::disjunction; template -using not_t = negation; +using not_t = absl::negation; // Holder for a static list of types. template @@ -33,7 +33,7 @@ template using types_size = std::tuple_size; template -using types_cat = decltype(std::tuple_cat(inst_of>()...)); +using types_cat = decltype(std::tuple_cat(inst_of>()...)); // Helper that resolves to the Ith type in T. template @@ -77,7 +77,7 @@ template using type_is = std::is_same; template -using type_not = negation>; +using type_not = absl::negation>; /** * Tests if a type is a raw or smart pointer. @@ -85,39 +85,39 @@ using type_not = negation>; * Any type that defines overloads for * and -> is considered a smart pointer. */ template -struct is_ptr : std::is_pointer> {}; +struct is_ptr : std::is_pointer> {}; template -struct is_ptr::operator*), - decltype(&decay_t::operator->)>> +struct is_ptr::operator*), + decltype(&absl::decay_t::operator->)>> : std::true_type {}; /** * Tests if a type is convertible to absl::string_view. */ template -using is_string = and_t, std::nullptr_t>, +using is_string = and_t, std::nullptr_t>, std::is_convertible>; /** * Tests if a type is a signed integer type. */ template -using is_int = conjunction>, - std::is_signed>>; +using is_int = absl::conjunction>, + std::is_signed>>; /** * Tests if a type is an unsigned integer type. */ template -using is_uint = - conjunction, bool>, std::is_integral>, - std::is_unsigned>>; +using is_uint = absl::conjunction, bool>, + std::is_integral>, + std::is_unsigned>>; /** * Tests if a type is a floating point type. */ template -using is_float = std::is_floating_point>; +using is_float = std::is_floating_point>; template using is_numeric = or_t, is_uint, is_float>; @@ -128,20 +128,21 @@ using is_numeric = or_t, is_uint, is_float>; template struct is_container : public std::false_type {}; template -struct is_container::value_type, - typename remove_cvref_t::iterator>> +struct is_container::value_type, + typename absl::remove_cvref_t::iterator>> : public std::true_type {}; // Maps are containers that also define a "mapped_type". template struct is_map : public std::false_type {}; template -struct is_map::mapped_type>> +struct is_map::mapped_type>> : public is_container {}; // Lists are containers that are not maps. template -using is_list = bool_constant::value && !is_map::value>; +using is_list = std::bool_constant::value && !is_map::value>; // Used to create a compiler error when a specialized function/class is // instantiated with an unsupported type. diff --git a/internal/value_internal.h b/internal/value_internal.h index f9cee1334..c9dae2a83 100644 --- a/internal/value_internal.h +++ b/internal/value_internal.h @@ -54,12 +54,12 @@ class BaseValue { protected: // The return type of `Value::get_if` call for T. template - using GetIfType = - conditional_t::value, const T*, absl::optional>; + using GetIfType = absl::conditional_t::value, const T*, + absl::optional>; // The return type of `Value::get` call for T. template - using GetType = conditional_t::value, const T&, T>; + using GetType = absl::conditional_t::value, const T&, T>; // If the string is 8 bytes, it is assumed to be copy-on-write and is stored // inline. Otherwise it is held in a ref-counted container. @@ -178,18 +178,18 @@ class BaseValue { friend class ValueAdapterTest; template - using NumericValueType = - conditional_t::value, int64_t, - conditional_t::value, uint64_t, double>>; + using NumericValueType = absl::conditional_t< + is_int::value, int64_t, + absl::conditional_t::value, uint64_t, double>>; template - using CustomValueType = conditional_t< + using CustomValueType = absl::conditional_t< std::is_convertible::value, common::Map, - conditional_t::value, - common::List, common::Object>>; + absl::conditional_t::value, + common::List, common::Object>>; template - using HolderType = conditional_t::value, - CopyHolder, RefCopyHolder>; + using HolderType = absl::conditional_t::value, + CopyHolder, RefCopyHolder>; // A ValueData visitor that, for any type in 'Alts', returns the associated // value as T. From f433eb33ff0981a5846a5fc4db0e55e172b1f81e Mon Sep 17 00:00:00 2001 From: tswadell Date: Tue, 30 Mar 2021 09:39:56 -0700 Subject: [PATCH 22/23] Well-known type wrapping support for field assignments This change increases cel-cpp's conformance coverage and pass rate to 94% PiperOrigin-RevId: 365830271 --- conformance/BUILD | 27 +- eval/eval/create_struct_step.cc | 10 +- eval/public/BUILD | 1 - eval/public/builtin_func_registrar.cc | 26 +- eval/public/cel_type_registry.cc | 14 +- eval/public/cel_type_registry_test.cc | 1 + eval/public/containers/BUILD | 1 - eval/public/containers/field_access.cc | 109 +-- eval/public/containers/field_access.h | 17 +- eval/public/containers/field_access_test.cc | 23 +- eval/public/structs/BUILD | 7 +- eval/public/structs/cel_proto_wrapper.cc | 514 +++++++++- eval/public/structs/cel_proto_wrapper.h | 13 + eval/public/structs/cel_proto_wrapper_test.cc | 881 ++++++++++-------- eval/public/transform_utility.cc | 2 +- eval/tests/BUILD | 1 + eval/tests/end_to_end_test.cc | 52 ++ internal/BUILD | 1 + internal/proto_util.cc | 19 + internal/proto_util.h | 7 + 20 files changed, 1218 insertions(+), 508 deletions(-) diff --git a/conformance/BUILD b/conformance/BUILD index e1b1bbf66..8592cec0c 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -91,27 +91,10 @@ cc_binary( "--skip_test=fields/qualified_identifier_resolution/qualified_ident,map_field_select,ident_with_longest_prefix_check,qualified_identifier_resolution_unchecked", "--skip_test=namespace/qualified/self_eval_qualified_lookup", "--skip_test=namespace/namespace/self_eval_container_lookup,self_eval_container_lookup_unchecked", - # TODO(issues/96): Well-known type conversion support. - "--skip_test=dynamic/int32/field_assign_proto2,field_assign_proto2_zero,field_read_proto2,field_read_proto2_zero,field_read_proto2_unset,field_assign_proto3,field_assign_proto3_zero,field_read_proto3,field_read_proto3_zero,field_read_proto3_unset", - "--skip_test=dynamic/int64/field_assign_proto2,field_assign_proto2_zero,field_assign_proto3,field_assign_proto3_zero", - "--skip_test=dynamic/uint32/field_assign_proto2,field_assign_proto2_zero,field_read_proto2,field_read_proto2_zero,field_read_proto2_unset,field_assign_proto3,field_assign_proto3_zero,field_read_proto3,field_read_proto3_zero,field_read_proto3_unset", - "--skip_test=dynamic/uint64/field_assign_proto2,field_assign_proto2_zero,field_read_proto2,field_read_proto2_zero,field_read_proto2_unset,field_assign_proto3,field_assign_proto3_zero", - "--skip_test=dynamic/float/field_assign_proto2,field_assign_proto2_zero,field_read_proto2,field_read_proto2_zero,field_read_proto2_unset,field_assign_proto3,field_assign_proto3_zero,field_read_proto3,field_read_proto3_zero,field_read_proto3_unset", - "--skip_test=dynamic/double/field_assign_proto2,field_assign_proto2_zero,field_assign_proto2_range,field_read_proto2,field_read_proto2_zero,field_read_proto2_unset,field_assign_proto3,field_assign_proto3_zero,field_assign_proto3_range,field_read_proto3,field_read_proto3_zero,field_read_proto3_unset", - "--skip_test=dynamic/string/field_assign_proto2,field_assign_proto2_empty,field_assign_proto3,field_assign_proto3_empty", - "--skip_test=dynamic/bytes/field_assign_proto2,field_assign_proto2_empty,field_assign_proto3,field_assign_proto3_empty", - "--skip_test=dynamic/bool/field_assign_proto2,field_assign_proto2_false,field_assign_proto3,field_assign_proto3_false", - "--skip_test=dynamic/list", - "--skip_test=dynamic/struct", - "--skip_test=dynamic/value_null", - "--skip_test=dynamic/value_number", - "--skip_test=dynamic/value_string", - "--skip_test=dynamic/value_bool", - "--skip_test=dynamic/value_struct", - "--skip_test=dynamic/value_list", - "--skip_test=dynamic/complex/any_list_map", - "--skip_test=proto2/literal_wellknown", - "--skip_test=proto3/literal_wellknown", + # TODO(issues/116): Debug why dynamic/list/var fails to JSON parse correctly. + "--skip_test=dynamic/list/var", + # TODO(issues/109): Ensure that unset wrapper fields return 'null' rather than the default value of the wrapper. + "--skip_test=dynamic/int32/field_read_proto2_unset,field_read_proto3_unset;uint32/field_read_proto2_unset;uint64/field_read_proto2_unset;float/field_read_proto2_unset,field_read_proto3_unset;double/field_read_proto2_unset,field_read_proto3_unset", "--skip_test=proto2/empty_field/wkt", "--skip_test=proto3/empty_field/wkt", # TODO(issues/117): Integer overflow on enum assignments should error. @@ -124,6 +107,8 @@ cc_binary( # TODO(issues/119): Strong typing support for enums, specified but not implemented. "--skip_test=enums/strong_proto2", "--skip_test=enums/strong_proto3", + # Bad tests, temporarily disable. + "--skip_test=dynamic/float/field_assign_proto2_range,field_assign_proto3_range", ] + ["$(location " + test + ")" for test in ALL_TESTS], data = [ ":server", diff --git a/eval/eval/create_struct_step.cc b/eval/eval/create_struct_step.cc index 5e4247071..fd9c79dec 100644 --- a/eval/eval/create_struct_step.cc +++ b/eval/eval/create_struct_step.cc @@ -147,12 +147,13 @@ absl::Status CreateStructStepForMessage::DoEvaluate(ExecutionFrame* frame, } Message* entry_msg = msg->GetReflection()->AddMessage(msg, entry.field); - status = SetValueToSingleField(key, key_field_descriptor, entry_msg); + status = SetValueToSingleField(key, key_field_descriptor, entry_msg, + frame->arena()); if (!status.ok()) { break; } status = SetValueToSingleField(value.value(), value_field_descriptor, - entry_msg); + entry_msg, frame->arena()); if (!status.ok()) { break; } @@ -170,11 +171,12 @@ absl::Status CreateStructStepForMessage::DoEvaluate(ExecutionFrame* frame, } for (int i = 0; i < cel_list->size(); i++) { - status = AddValueToRepeatedField((*cel_list)[i], entry.field, msg); + status = AddValueToRepeatedField((*cel_list)[i], entry.field, msg, + frame->arena()); if (!status.ok()) break; } } else { - status = SetValueToSingleField(arg, entry.field, msg); + status = SetValueToSingleField(arg, entry.field, msg, frame->arena()); } if (!status.ok()) { diff --git a/eval/public/BUILD b/eval/public/BUILD index a4214ce0b..d3a25becd 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -189,7 +189,6 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", - "@com_google_protobuf//:protobuf", "@com_googlesource_code_re2//:re2", ], ) diff --git a/eval/public/builtin_func_registrar.cc b/eval/public/builtin_func_registrar.cc index ceb99538b..fd0e574b8 100644 --- a/eval/public/builtin_func_registrar.cc +++ b/eval/public/builtin_func_registrar.cc @@ -4,7 +4,6 @@ #include #include -#include "google/protobuf/util/time_util.h" #include "absl/numeric/int128.h" #include "absl/status/status.h" #include "absl/strings/match.h" @@ -31,7 +30,7 @@ using google::protobuf::Arena; namespace { const int64_t kIntMax = std::numeric_limits::max(); -const int64_t kIntMin = std::numeric_limits::min(); +const int64_t kIntMin = std::numeric_limits::lowest(); const uint64_t kUintMax = std::numeric_limits::max(); // Returns the number of UTF8 codepoints within a string. @@ -1219,14 +1218,14 @@ absl::Status RegisterStringConversionFunctions( status = FunctionAdapter::CreateAndRegister( builtin::kString, false, [](Arena* arena, absl::Duration value) -> CelValue { - google::protobuf::Duration d; - auto status = google::api::expr::internal::EncodeDuration(value, &d); - if (!status.ok()) { + auto encode = + google::api::expr::internal::EncodeDurationToString(value); + if (!encode.ok()) { + const auto& status = encode.status(); return CreateErrorValue(arena, status.message(), status.code()); } - return CelValue::CreateString( - CelValue::StringHolder(Arena::Create( - arena, google::protobuf::util::TimeUtil::ToString(d)))); + return CelValue::CreateString(CelValue::StringHolder( + Arena::Create(arena, encode.value()))); }, registry); if (!status.ok()) return status; @@ -1235,14 +1234,13 @@ absl::Status RegisterStringConversionFunctions( status = FunctionAdapter::CreateAndRegister( builtin::kString, false, [](Arena* arena, absl::Time value) -> CelValue { - google::protobuf::Timestamp ts; - auto status = google::api::expr::internal::EncodeTime(value, &ts); - if (!status.ok()) { + auto encode = google::api::expr::internal::EncodeTimeToString(value); + if (!encode.ok()) { + const auto& status = encode.status(); return CreateErrorValue(arena, status.message(), status.code()); } - return CelValue::CreateString( - CelValue::StringHolder(Arena::Create( - arena, google::protobuf::util::TimeUtil::ToString(ts)))); + return CelValue::CreateString(CelValue::StringHolder( + Arena::Create(arena, encode.value()))); }, registry); if (!status.ok()) return status; diff --git a/eval/public/cel_type_registry.cc b/eval/public/cel_type_registry.cc index 565ae3e27..7f19a1ff9 100644 --- a/eval/public/cel_type_registry.cc +++ b/eval/public/cel_type_registry.cc @@ -1,6 +1,8 @@ #include "eval/public/cel_type_registry.h" +#include "google/protobuf/struct.pb.h" #include "google/protobuf/descriptor.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/node_hash_set.h" #include "absl/status/status.h" #include "absl/types/optional.h" @@ -30,9 +32,19 @@ const absl::node_hash_set& GetCoreTypes() { return *kCoreTypes; } +const absl::flat_hash_set GetCoreEnums() { + static const auto* const kCoreEnums = + new absl::flat_hash_set{ + // Register the NULL_VALUE enum. + google::protobuf::NullValue_descriptor(), + }; + return *kCoreEnums; +} + } // namespace -CelTypeRegistry::CelTypeRegistry() : types_(GetCoreTypes()), enums_() {} +CelTypeRegistry::CelTypeRegistry() + : types_(GetCoreTypes()), enums_(GetCoreEnums()) {} void CelTypeRegistry::Register(std::string fully_qualified_type_name) { // Registers the fully qualified type name as a CEL type. diff --git a/eval/public/cel_type_registry_test.cc b/eval/public/cel_type_registry_test.cc index 3117722da..c194bec9f 100644 --- a/eval/public/cel_type_registry_test.cc +++ b/eval/public/cel_type_registry_test.cc @@ -24,6 +24,7 @@ TEST(CelTypeRegistryTest, TestRegisterEnumDescriptor) { enum_set.insert(enum_desc->full_name()); } absl::flat_hash_set expected_set; + expected_set.insert({"google.protobuf.NullValue"}); expected_set.insert({"google.api.expr.runtime.TestMessage.TestEnum"}); EXPECT_THAT(enum_set, Eq(expected_set)); } diff --git a/eval/public/containers/BUILD b/eval/public/containers/BUILD index d7dcdc689..cbe409d2e 100644 --- a/eval/public/containers/BUILD +++ b/eval/public/containers/BUILD @@ -17,7 +17,6 @@ cc_library( deps = [ "//eval/public:cel_value", "//eval/public/structs:cel_proto_wrapper", - "//internal:proto_util", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", diff --git a/eval/public/containers/field_access.cc b/eval/public/containers/field_access.cc index 88cd2436d..fe208c3e1 100644 --- a/eval/public/containers/field_access.cc +++ b/eval/public/containers/field_access.cc @@ -3,12 +3,15 @@ #include #include "google/protobuf/any.pb.h" +#include "google/protobuf/struct.pb.h" +#include "google/protobuf/wrappers.pb.h" +#include "google/protobuf/arena.h" #include "google/protobuf/map_field.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/strings/substitute.h" #include "eval/public/structs/cel_proto_wrapper.h" -#include "internal/proto_util.h" namespace google { namespace api { @@ -24,11 +27,8 @@ using ::google::protobuf::Message; using ::google::protobuf::Reflection; // Well-known type protobuf type names which require special get / set behavior. -constexpr const char kProtobufDuration[] = "google.protobuf.Duration"; -constexpr const char kProtobufTimestamp[] = "google.protobuf.Timestamp"; -constexpr const char kProtobufAny[] = "google.protobuf.Any"; - -const char kTypeGoogleApisComPrefix[] = "type.googleapis.com/"; +constexpr absl::string_view kProtobufAny = "google.protobuf.Any"; +constexpr absl::string_view kTypeGoogleApisComPrefix = "type.googleapis.com/"; // Singular message fields and repeated message fields have similar access model // To provide common approach, we implement accessor classes, based on CRTP. @@ -197,6 +197,8 @@ class ScalarFieldAccessor : public FieldAccessor { } const Message* GetMessage() const { + // TODO(issues/109): When the field descriptor is a wrapper type, check if + // the field is set. If set, return the unwrapped value, else return 'null'. return &GetReflection()->GetMessage(*msg_, field_desc_); } @@ -456,7 +458,6 @@ class FieldSetter { MessageRetrieverOp()); if (!value.has_value()) { - GOOGLE_LOG(ERROR) << "Has No Value"; return false; } @@ -464,36 +465,6 @@ class FieldSetter { return true; } - bool AssignDuration(const CelValue& cel_value) const { - absl::Duration d; - if (!cel_value.GetValue(&d)) { - GOOGLE_LOG(ERROR) << "Unable to retrieve duration"; - return false; - } - google::protobuf::Duration duration; - auto status = google::api::expr::internal::EncodeDuration(d, &duration); - if (!status.ok()) { - return false; - } - static_cast(this)->SetMessage(&duration); - return true; - } - - bool AssignTimestamp(const CelValue& cel_value) const { - absl::Time t; - if (!cel_value.GetValue(&t)) { - GOOGLE_LOG(ERROR) << "Unable to retrieve timestamp"; - return false; - } - google::protobuf::Timestamp timestamp; - auto status = google::api::expr::internal::EncodeTime(t, ×tamp); - if (!status.ok()) { - return false; - } - static_cast(this)->SetMessage(×tamp); - return true; - } - // This method provides message field content, wrapped in CelValue. // If value provided successfully, returns Ok. // arena Arena to use for allocations if needed. @@ -534,16 +505,14 @@ class FieldSetter { break; } case FieldDescriptor::CPPTYPE_MESSAGE: { - const std::string& type_name = field_desc_->message_type()->full_name(); + const absl::string_view type_name = + field_desc_->message_type()->full_name(); // When the field is a message, it might be a well-known type with a // non-proto representation that requires special handling before it // can be set on the field. - if (type_name == kProtobufTimestamp) { - return AssignTimestamp(value); - } else if (type_name == kProtobufDuration) { - return AssignDuration(value); - } - return AssignMessage(value); + auto wrapped_value = + CelProtoWrapper::MaybeWrapValue(type_name, value, arena_); + return AssignMessage(wrapped_value.value_or(value)); } case FieldDescriptor::CPPTYPE_ENUM: { return AssignEnum(value); @@ -556,18 +525,20 @@ class FieldSetter { } protected: - FieldSetter(Message* msg, const FieldDescriptor* field_desc) - : msg_(msg), field_desc_(field_desc) {} + FieldSetter(Message* msg, const FieldDescriptor* field_desc, Arena* arena) + : msg_(msg), field_desc_(field_desc), arena_(arena) {} Message* msg_; const FieldDescriptor* field_desc_; + Arena* arena_; }; // Accessor class, to work with singular fields class ScalarFieldSetter : public FieldSetter { public: - ScalarFieldSetter(Message* msg, const FieldDescriptor* field_desc) - : FieldSetter(msg, field_desc) {} + ScalarFieldSetter(Message* msg, const FieldDescriptor* field_desc, + Arena* arena) + : FieldSetter(msg, field_desc, arena) {} bool SetBool(bool value) const { GetReflection()->SetBool(msg_, field_desc_, value); @@ -651,8 +622,9 @@ class ScalarFieldSetter : public FieldSetter { // Appender class, to work with repeated fields class RepeatedFieldSetter : public FieldSetter { public: - RepeatedFieldSetter(Message* msg, const FieldDescriptor* field_desc) - : FieldSetter(msg, field_desc) {} + RepeatedFieldSetter(Message* msg, const FieldDescriptor* field_desc, + Arena* arena) + : FieldSetter(msg, field_desc, arena) {} bool SetBool(bool value) const { GetReflection()->AddBool(msg_, field_desc_, value); @@ -724,8 +696,9 @@ class RepeatedFieldSetter : public FieldSetter { // arena Arena to use for allocations if needed. // result pointer to object to store value in. absl::Status SetValueToSingleField(const CelValue& value, - const FieldDescriptor* desc, Message* msg) { - ScalarFieldSetter setter(msg, desc); + const FieldDescriptor* desc, Message* msg, + Arena* arena) { + ScalarFieldSetter setter(msg, desc, arena); return (setter.SetFieldFromCelValue(value)) ? absl::OkStatus() : absl::InvalidArgumentError(absl::Substitute( @@ -736,9 +709,9 @@ absl::Status SetValueToSingleField(const CelValue& value, } absl::Status AddValueToRepeatedField(const CelValue& value, - const FieldDescriptor* desc, - Message* msg) { - RepeatedFieldSetter setter(msg, desc); + const FieldDescriptor* desc, Message* msg, + Arena* arena) { + RepeatedFieldSetter setter(msg, desc, arena); return (setter.SetFieldFromCelValue(value)) ? absl::OkStatus() : absl::InvalidArgumentError(absl::Substitute( @@ -748,32 +721,6 @@ absl::Status AddValueToRepeatedField(const CelValue& value, value.DebugString())); } -absl::Status AddValueToMapField(const CelValue& key, const CelValue& value, - const FieldDescriptor* desc, Message* msg) { - auto entry_msg = msg->GetReflection()->AddMessage(msg, desc); - auto key_field_desc = entry_msg->GetDescriptor()->FindFieldByNumber(1); - auto value_field_desc = entry_msg->GetDescriptor()->FindFieldByNumber(2); - - ScalarFieldSetter key_setter(entry_msg, key_field_desc); - ScalarFieldSetter value_setter(entry_msg, value_field_desc); - - if (!key_setter.SetFieldFromCelValue(key)) { - return absl::InvalidArgumentError(absl::Substitute( - "Could not assign supplied argument \"$2\" to message " - "\"$0\" field \"$1\" map key.", - msg->GetDescriptor()->name(), desc->name(), key.DebugString())); - } - - if (!value_setter.SetFieldFromCelValue(value)) { - return absl::InvalidArgumentError(absl::Substitute( - "Could not assign supplied argument \"$2\" to message \"$0\" " - "field \"$1\" map value.", - msg->GetDescriptor()->name(), desc->name(), value.DebugString())); - } - - return absl::OkStatus(); -} - } // namespace runtime } // namespace expr } // namespace api diff --git a/eval/public/containers/field_access.h b/eval/public/containers/field_access.h index 711cf1744..4f52b4d0d 100644 --- a/eval/public/containers/field_access.h +++ b/eval/public/containers/field_access.h @@ -46,25 +46,20 @@ absl::Status CreateValueFromMapValue(const google::protobuf::Message* msg, // Returns status of the operation. // msg Message containing the field. // desc Descriptor of the field to access. +// arena Arena to perform allocations, if necessary, when setting the field. absl::Status SetValueToSingleField(const CelValue& value, const google::protobuf::FieldDescriptor* desc, - google::protobuf::Message* msg); + google::protobuf::Message* msg, google::protobuf::Arena* arena); + // Adds content of CelValue to repeated message field. // Returns status of the operation. // msg Message containing the field. // desc Descriptor of the field to access. +// arena Arena to perform allocations, if necessary, when adding the value. absl::Status AddValueToRepeatedField(const CelValue& value, const google::protobuf::FieldDescriptor* desc, - google::protobuf::Message* msg); - -// Adds content of CelValue to repeated message field. -// Returns status of the operation. -// msg Message containing the field. -// desc Descriptor of the field to access. - -absl::Status AddValueToMapField(const CelValue& key, const CelValue& value, - const google::protobuf::FieldDescriptor* desc, - google::protobuf::Message* msg); + google::protobuf::Message* msg, + google::protobuf::Arena* arena); } // namespace runtime } // namespace expr diff --git a/eval/public/containers/field_access_test.cc b/eval/public/containers/field_access_test.cc index c6b380b3a..5ad4fbf5c 100644 --- a/eval/public/containers/field_access_test.cc +++ b/eval/public/containers/field_access_test.cc @@ -1,5 +1,6 @@ #include "eval/public/containers/field_access.h" +#include "google/protobuf/arena.h" #include "google/protobuf/message.h" #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -17,63 +18,73 @@ namespace { using google::api::expr::internal::MakeGoogleApiDurationMax; using google::api::expr::internal::MakeGoogleApiTimeMax; +using google::protobuf::Arena; using google::protobuf::FieldDescriptor; using test::v1::proto3::TestAllTypes; TEST(FieldAccessTest, SetDuration) { + Arena arena; TestAllTypes msg; const FieldDescriptor* field = TestAllTypes::descriptor()->FindFieldByName("single_duration"); auto status = SetValueToSingleField( - CelValue::CreateDuration(MakeGoogleApiDurationMax()), field, &msg); + CelValue::CreateDuration(MakeGoogleApiDurationMax()), field, &msg, + &arena); EXPECT_TRUE(status.ok()); } TEST(FieldAccessTest, SetDurationBadDuration) { + Arena arena; TestAllTypes msg; const FieldDescriptor* field = TestAllTypes::descriptor()->FindFieldByName("single_duration"); auto status = SetValueToSingleField( CelValue::CreateDuration(MakeGoogleApiDurationMax() + absl::Seconds(1)), - field, &msg); + field, &msg, &arena); EXPECT_FALSE(status.ok()); EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); } TEST(FieldAccessTest, SetDurationBadInputType) { + Arena arena; TestAllTypes msg; const FieldDescriptor* field = TestAllTypes::descriptor()->FindFieldByName("single_duration"); - auto status = SetValueToSingleField(CelValue::CreateInt64(1), field, &msg); + auto status = + SetValueToSingleField(CelValue::CreateInt64(1), field, &msg, &arena); EXPECT_FALSE(status.ok()); EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); } TEST(FieldAccessTest, SetTimestamp) { + Arena arena; TestAllTypes msg; const FieldDescriptor* field = TestAllTypes::descriptor()->FindFieldByName("single_timestamp"); auto status = SetValueToSingleField( - CelValue::CreateTimestamp(MakeGoogleApiTimeMax()), field, &msg); + CelValue::CreateTimestamp(MakeGoogleApiTimeMax()), field, &msg, &arena); EXPECT_TRUE(status.ok()); } TEST(FieldAccessTest, SetTimestampBadTime) { + Arena arena; TestAllTypes msg; const FieldDescriptor* field = TestAllTypes::descriptor()->FindFieldByName("single_timestamp"); auto status = SetValueToSingleField( CelValue::CreateTimestamp(MakeGoogleApiTimeMax() + absl::Seconds(1)), - field, &msg); + field, &msg, &arena); EXPECT_FALSE(status.ok()); EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); } TEST(FieldAccessTest, SetTimestampBadInputType) { + Arena arena; TestAllTypes msg; const FieldDescriptor* field = TestAllTypes::descriptor()->FindFieldByName("single_timestamp"); - auto status = SetValueToSingleField(CelValue::CreateInt64(1), field, &msg); + auto status = + SetValueToSingleField(CelValue::CreateInt64(1), field, &msg, &arena); EXPECT_FALSE(status.ok()); EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); } diff --git a/eval/public/structs/BUILD b/eval/public/structs/BUILD index c5f701193..3afa8d9c2 100644 --- a/eval/public/structs/BUILD +++ b/eval/public/structs/BUILD @@ -12,11 +12,11 @@ cc_library( ], deps = [ "//eval/public:cel_value", + "//eval/testutil:test_message_cc_proto", "//internal:proto_util", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/types:optional", "@com_google_protobuf//:protobuf", ], ) @@ -29,10 +29,15 @@ cc_test( ], deps = [ ":cel_proto_wrapper", + "//eval/public:cel_value", + "//eval/public/containers:container_backed_list_impl", + "//eval/public/containers:container_backed_map_impl", "//eval/testutil:test_message_cc_proto", "//internal:proto_util", "//testutil:util", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", "@com_google_googletest//:gtest_main", "@com_google_protobuf//:protobuf", ], diff --git a/eval/public/structs/cel_proto_wrapper.cc b/eval/public/structs/cel_proto_wrapper.cc index bcd0867d0..eedb75e7a 100644 --- a/eval/public/structs/cel_proto_wrapper.cc +++ b/eval/public/structs/cel_proto_wrapper.cc @@ -1,11 +1,24 @@ #include "eval/public/structs/cel_proto_wrapper.h" +#include + +#include +#include + #include "google/protobuf/any.pb.h" #include "google/protobuf/struct.pb.h" +#include "google/protobuf/timestamp.pb.h" #include "google/protobuf/wrappers.pb.h" +#include "google/protobuf/message.h" #include "absl/container/flat_hash_map.h" +#include "absl/strings/escaping.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/strings/substitute.h" #include "absl/synchronization/mutex.h" +#include "eval/public/cel_value.h" +#include "eval/testutil/test_message.pb.h" +#include "internal/proto_util.h" namespace google { namespace api { @@ -35,8 +48,26 @@ using google::protobuf::UInt32Value; using google::protobuf::UInt64Value; using google::protobuf::Value; +// kMaxIntJSON is defined as the Number.MAX_SAFE_INTEGER value per EcmaScript 6. +constexpr int64_t kMaxIntJSON = (1l << 53) - 1; + +// kMinIntJSON is defined as the Number.MIN_SAFE_INTEGER value per EcmaScript 6. +constexpr int64_t kMinIntJSON = -kMaxIntJSON; + // Forward declaration for google.protobuf.Value CelValue ValueFromMessage(const Value* value, Arena* arena); +absl::optional MessageFromValue(const CelValue& value, + Value* json); + +// IsJSONSafe indicates whether the int is safely representable as a floating +// point value in JSON. +static bool IsJSONSafe(int64_t i) { return i >= kMinIntJSON && i <= kMaxIntJSON; } + +// IsJSONSafe indicates whether the uint is safely representable as a floating +// point value in JSON. +static bool IsJSONSafe(uint64_t i) { + return i <= static_cast(kMaxIntJSON); +} // Map implementation wrapping google.protobuf.ListValue class DynamicList : public CelList { @@ -145,7 +176,6 @@ CelValue ValueFromMessage(const Struct* struct_value, Arena* arena) { CelValue ValueFromMessage(const Any* any_value, Arena* arena) { auto type_url = any_value->type_url(); - auto pos = type_url.find_last_of('/'); if (pos == absl::string_view::npos) { // TODO(issues/25) What error code? @@ -235,7 +265,7 @@ CelValue ValueFromMessage(const Value* value, Arena* arena) { case Value::KindCase::kListValue: return CelProtoWrapper::CreateMessage(&value->list_value(), arena); default: - return CreateErrorValue(arena, "No known fields set in Value message"); + return CelValue::CreateNull(); } } @@ -325,6 +355,474 @@ class ValueFromMessageMaker { factories_; }; +absl::optional MessageFromValue(const CelValue& value, + Duration* duration) { + absl::Duration val; + if (!value.GetValue(&val)) { + return {}; + } + auto status = google::api::expr::internal::EncodeDuration(val, duration); + if (!status.ok()) { + return {}; + } + return duration; +} + +absl::optional MessageFromValue(const CelValue& value, + BoolValue* wrapper) { + bool val; + if (!value.GetValue(&val)) { + return {}; + } + wrapper->set_value(val); + return wrapper; +} + +absl::optional MessageFromValue(const CelValue& value, + BytesValue* wrapper) { + CelValue::BytesHolder val; + if (!value.GetValue(&val)) { + return {}; + } + wrapper->set_value(val.value()); + return wrapper; +} + +absl::optional MessageFromValue(const CelValue& value, + DoubleValue* wrapper) { + double val; + if (!value.GetValue(&val)) { + return {}; + } + wrapper->set_value(val); + return wrapper; +} + +absl::optional MessageFromValue(const CelValue& value, + FloatValue* wrapper) { + double val; + if (!value.GetValue(&val)) { + return {}; + } + // Abort the conversion if the value is outside the float range. + if (val > std::numeric_limits::max()) { + wrapper->set_value(std::numeric_limits::infinity()); + return wrapper; + } + if (val < std::numeric_limits::lowest()) { + wrapper->set_value(-std::numeric_limits::infinity()); + return wrapper; + } + wrapper->set_value(val); + return wrapper; +} + +absl::optional MessageFromValue(const CelValue& value, + Int32Value* wrapper) { + int64_t val; + if (!value.GetValue(&val)) { + return {}; + } + // Abort the conversion if the value is outside the int32_t range. + if (val > std::numeric_limits::max() || + val < std::numeric_limits::lowest()) { + return {}; + } + wrapper->set_value(val); + return wrapper; +} + +absl::optional MessageFromValue(const CelValue& value, + Int64Value* wrapper) { + int64_t val; + if (!value.GetValue(&val)) { + return {}; + } + wrapper->set_value(val); + return wrapper; +} + +absl::optional MessageFromValue(const CelValue& value, + StringValue* wrapper) { + CelValue::StringHolder val; + if (!value.GetValue(&val)) { + return {}; + } + wrapper->set_value(val.value()); + return wrapper; +} + +absl::optional MessageFromValue(const CelValue& value, + Timestamp* timestamp) { + absl::Time val; + if (!value.GetValue(&val)) { + return {}; + } + auto status = google::api::expr::internal::EncodeTime(val, timestamp); + if (!status.ok()) { + return {}; + } + return timestamp; +} + +absl::optional MessageFromValue(const CelValue& value, + UInt32Value* wrapper) { + uint64_t val; + if (!value.GetValue(&val)) { + return {}; + } + // Abort the conversion if the value is outside the uint32_t range. + if (val > std::numeric_limits::max()) { + return {}; + } + wrapper->set_value(val); + return wrapper; +} + +absl::optional MessageFromValue(const CelValue& value, + UInt64Value* wrapper) { + uint64_t val; + if (!value.GetValue(&val)) { + return {}; + } + wrapper->set_value(val); + return wrapper; +} + +absl::optional MessageFromValue(const CelValue& value, + ListValue* json_list) { + if (!value.IsList()) { + return {}; + } + const CelList& list = *value.ListOrDie(); + for (int i = 0; i < list.size(); i++) { + auto e = list[i]; + Value* elem = json_list->add_values(); + auto result = MessageFromValue(e, elem); + if (!result.has_value()) { + return {}; + } + } + return json_list; +} + +absl::optional MessageFromValue(const CelValue& value, + Struct* json_struct) { + if (!value.IsMap()) { + return {}; + } + const CelMap& map = *value.MapOrDie(); + const auto& keys = *map.ListKeys(); + auto fields = json_struct->mutable_fields(); + for (int i = 0; i < keys.size(); i++) { + auto k = keys[i]; + // If the key is not a string type, abort the conversion. + if (!k.IsString()) { + return {}; + } + absl::string_view key = k.StringOrDie().value(); + + auto v = map[k]; + if (!v.has_value()) { + return {}; + } + Value field_value; + auto result = MessageFromValue(v.value(), &field_value); + // If the value is not a valid JSON type, abort the conversion. + if (!result.has_value()) { + return {}; + } + (*fields)[key] = field_value; + } + return json_struct; +} + +absl::optional MessageFromValue(const CelValue& value, + Value* json) { + switch (value.type()) { + case CelValue::Type::kBool: { + bool val; + if (value.GetValue(&val)) { + json->set_bool_value(val); + return json; + } + } break; + case CelValue::Type::kBytes: { + // Base64 encode byte strings to ensure they can safely be transpored + // in a JSON string. + CelValue::BytesHolder val; + if (value.GetValue(&val)) { + json->set_string_value(absl::Base64Escape(val.value())); + return json; + } + } break; + case CelValue::Type::kDouble: { + double val; + if (value.GetValue(&val)) { + json->set_number_value(val); + return json; + } + } break; + case CelValue::Type::kDuration: { + // Convert duration values to a protobuf JSON format. + absl::Duration val; + if (value.GetValue(&val)) { + auto encode = google::api::expr::internal::EncodeDurationToString(val); + if (!encode.ok()) { + return {}; + } + json->set_string_value(*encode); + return json; + } + } break; + case CelValue::Type::kInt64: { + int64_t val; + // Convert int64_t values within the int53 range to doubles, otherwise + // serialize the value to a string. + if (value.GetValue(&val)) { + if (IsJSONSafe(val)) { + json->set_number_value(val); + } else { + json->set_string_value(absl::StrCat(val)); + } + return json; + } + } break; + case CelValue::Type::kString: { + CelValue::StringHolder val; + if (value.GetValue(&val)) { + json->set_string_value(val.value().data()); + return json; + } + } break; + case CelValue::Type::kTimestamp: { + // Convert timestamp values to a protobuf JSON format. + absl::Time val; + if (value.GetValue(&val)) { + auto encode = google::api::expr::internal::EncodeTimeToString(val); + if (!encode.ok()) { + return {}; + } + json->set_string_value(*encode); + return json; + } + } break; + case CelValue::Type::kUint64: { + uint64_t val; + // Convert uint64_t values within the int53 range to doubles, otherwise + // serialize the value to a string. + if (value.GetValue(&val)) { + if (IsJSONSafe(val)) { + json->set_number_value(val); + } else { + json->set_string_value(absl::StrCat(val)); + } + return json; + } + } break; + case CelValue::Type::kList: { + auto lv = MessageFromValue(value, json->mutable_list_value()); + if (lv.has_value()) { + return json; + } + } break; + case CelValue::Type::kMap: { + auto sv = MessageFromValue(value, json->mutable_struct_value()); + if (sv.has_value()) { + return json; + } + } break; + default: + if (value.IsNull()) { + json->set_null_value(protobuf::NULL_VALUE); + return json; + } + return {}; + } + return {}; +} + +absl::optional MessageFromValue(const CelValue& value, + Any* any) { + switch (value.type()) { + case CelValue::Type::kBool: { + BoolValue v; + auto msg = MessageFromValue(value, &v); + if (msg.has_value() && any->PackFrom(**msg)) { + return any; + } + } break; + case CelValue::Type::kBytes: { + BytesValue v; + auto msg = MessageFromValue(value, &v); + if (msg.has_value() && any->PackFrom(**msg)) { + return any; + } + } break; + case CelValue::Type::kDouble: { + DoubleValue v; + auto msg = MessageFromValue(value, &v); + if (msg.has_value() && any->PackFrom(**msg)) { + return any; + } + } break; + case CelValue::Type::kDuration: { + Duration v; + auto msg = MessageFromValue(value, &v); + if (msg.has_value() && any->PackFrom(**msg)) { + return any; + } + } break; + case CelValue::Type::kInt64: { + Int64Value v; + auto msg = MessageFromValue(value, &v); + if (msg.has_value() && any->PackFrom(**msg)) { + return any; + } + } break; + case CelValue::Type::kString: { + StringValue v; + auto msg = MessageFromValue(value, &v); + if (msg.has_value() && any->PackFrom(**msg)) { + return any; + } + } break; + case CelValue::Type::kTimestamp: { + Timestamp v; + auto msg = MessageFromValue(value, &v); + if (msg.has_value() && any->PackFrom(**msg)) { + return any; + } + } break; + case CelValue::Type::kUint64: { + UInt64Value v; + auto msg = MessageFromValue(value, &v); + if (msg.has_value() && any->PackFrom(**msg)) { + return any; + } + } break; + case CelValue::Type::kList: { + ListValue v; + auto msg = MessageFromValue(value, &v); + if (msg.has_value() && any->PackFrom(**msg)) { + return any; + } + } break; + case CelValue::Type::kMap: { + Struct v; + auto msg = MessageFromValue(value, &v); + if (msg.has_value() && any->PackFrom(**msg)) { + return any; + } + } break; + case CelValue::Type::kMessage: { + if (value.IsNull()) { + Value v; + auto msg = MessageFromValue(value, &v); + if (msg.has_value() && any->PackFrom(**msg)) { + return any; + } + } else if (any->PackFrom(*(value.MessageOrDie()))) { + return any; + } + } break; + default: + break; + } + return {}; +} + +// Factory class, responsible for populating a Message type instance with the +// value of a simple CelValue. +class MessageFromValueFactory { + public: + virtual ~MessageFromValueFactory() {} + virtual const google::protobuf::Descriptor* GetDescriptor() const = 0; + virtual absl::optional WrapMessage( + const CelValue& value, Arena* arena) const = 0; +}; + +// This template class has a good performance, but performes downcast +// operations on google::protobuf::Message pointers. +template +class CastingMessageFromValueFactory : public MessageFromValueFactory { + public: + const google::protobuf::Descriptor* GetDescriptor() const override { + return MessageType::descriptor(); + } + + absl::optional WrapMessage( + const CelValue& value, Arena* arena) const override { + // Convert nulls separately from other messages as a null value is still + // technically a message value, but not one that can be converted in the + // standard way. + if (value.IsNull()) { + return MessageFromValue(value, Arena::CreateMessage(arena)); + } + // If the value is a message type, see if it is already of the proper type + // name, and return it directly. + if (value.IsMessage()) { + const auto* msg = value.MessageOrDie(); + if (MessageType::descriptor() == msg->GetDescriptor()) { + return {}; + } + } + // Otherwise, allocate an empty message type, and attempt to populate it + // using the proper MessageFromValue overload. + auto* msg_buffer = Arena::CreateMessage(arena); + return MessageFromValue(value, msg_buffer); + } +}; + +// MessageFromValueMaker makes a specific protobuf Message instance based on +// the desired protobuf type name and an input CelValue. +// +// It holds a registry of CelValue factories for specific subtypes of Message. +// If message does not match any of types stored in registry, an the factory +// returns an absent value. +class MessageFromValueMaker { + public: + explicit MessageFromValueMaker() { + Add(absl::make_unique>()); + Add(absl::make_unique>()); + Add(absl::make_unique>()); + Add(absl::make_unique>()); + Add(absl::make_unique>()); + Add(absl::make_unique>()); + Add(absl::make_unique>()); + Add(absl::make_unique>()); + Add(absl::make_unique>()); + Add(absl::make_unique>()); + Add(absl::make_unique>()); + Add(absl::make_unique>()); + Add(absl::make_unique>()); + Add(absl::make_unique>()); + Add(absl::make_unique>()); + } + // Non-copyable, non-assignable + MessageFromValueMaker(const MessageFromValueMaker&) = delete; + MessageFromValueMaker& operator=(const MessageFromValueMaker&) = delete; + + absl::optional MaybeWrapMessage( + absl::string_view type_name, const CelValue& value, Arena* arena) const { + auto it = factories_.find(type_name); + if (it == factories_.end()) { + // Descriptor not found for type name. + return {}; + } + return (it->second)->WrapMessage(value, arena); + } + + private: + void Add(std::unique_ptr factory) { + const Descriptor* desc = factory->GetDescriptor(); + factories_.emplace(desc->full_name(), std::move(factory)); + } + + absl::flat_hash_map> + factories_; +}; + } // namespace // CreateMessage creates CelValue from google::protobuf::Message. @@ -340,10 +838,20 @@ CelValue CelProtoWrapper::CreateMessage(const google::protobuf::Message* value, } auto special_value = maker->CreateValue(value, arena); - return special_value.has_value() ? special_value.value() : CelValue(value); } +absl::optional CelProtoWrapper::MaybeWrapValue( + absl::string_view type_name, const CelValue& value, Arena* arena) { + static const MessageFromValueMaker* maker = new MessageFromValueMaker(); + + auto msg = maker->MaybeWrapMessage(type_name, value, arena); + if (!msg.has_value()) { + return {}; + } + return CelValue(msg.value()); +} + } // namespace runtime } // namespace expr } // namespace api diff --git a/eval/public/structs/cel_proto_wrapper.h b/eval/public/structs/cel_proto_wrapper.h index 830bfd67f..379f20b4e 100644 --- a/eval/public/structs/cel_proto_wrapper.h +++ b/eval/public/structs/cel_proto_wrapper.h @@ -27,6 +27,19 @@ class CelProtoWrapper { static CelValue CreateTimestamp(const google::protobuf::Timestamp *value) { return CelValue(expr::internal::DecodeTime(*value)); } + + // MaybeWrapValue attempts to wrap the input value in a proto message with + // the given type_name. If the value can be wrapped, it is returned as a + // CelValue pointing to the protobuf message. Otherwise, the result will be + // empty. + // + // This method is the complement to CreateMessage which may unwrap a protobuf + // message to native CelValue representation during a protobuf field read. + // Just as CreateMessage should only be used when reading protobuf values, + // MaybeWrapValue should only be used when assigning protobuf fields. + static absl::optional MaybeWrapValue(absl::string_view type_name, + const CelValue &value, + google::protobuf::Arena *arena); }; } // namespace runtime diff --git a/eval/public/structs/cel_proto_wrapper_test.cc b/eval/public/structs/cel_proto_wrapper_test.cc index 7c48abd4e..a93f28b61 100644 --- a/eval/public/structs/cel_proto_wrapper_test.cc +++ b/eval/public/structs/cel_proto_wrapper_test.cc @@ -1,6 +1,10 @@ #include "eval/public/structs/cel_proto_wrapper.h" +#include +#include + #include "google/protobuf/any.pb.h" +#include "google/protobuf/duration.pb.h" #include "google/protobuf/empty.pb.h" #include "google/protobuf/struct.pb.h" #include "google/protobuf/wrappers.pb.h" @@ -9,6 +13,11 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" #include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/time/time.h" +#include "eval/public/cel_value.h" +#include "eval/public/containers/container_backed_list_impl.h" +#include "eval/public/containers/container_backed_map_impl.h" #include "eval/testutil/test_message.pb.h" #include "internal/proto_util.h" #include "testutil/util.h" @@ -18,6 +27,8 @@ namespace api { namespace expr { namespace runtime { +namespace { + using testing::Eq; using testing::UnorderedPointwise; @@ -38,9 +49,84 @@ using google::protobuf::StringValue; using google::protobuf::UInt32Value; using google::protobuf::UInt64Value; -TEST(CelProtoWrapperTest, TestType) { - ::google::protobuf::Arena arena; +using google::protobuf::Arena; + +class CelProtoWrapperTest : public ::testing::Test { + protected: + CelProtoWrapperTest() {} + + void ExpectWrappedMessage(const CelValue& value, + const google::protobuf::Message& message) { + // Test the input value wraps to the destination message type. + std::string type_name = message.GetTypeName(); + auto result = CelProtoWrapper::MaybeWrapValue(type_name, value, arena()); + EXPECT_TRUE(result.has_value()); + EXPECT_TRUE((*result).IsMessage()); + EXPECT_THAT((*result).MessageOrDie(), testutil::EqualsProto(message)); + + // Ensure that double wrapping results in the object being wrapped once. + auto identity = + CelProtoWrapper::MaybeWrapValue(type_name, *result, arena()); + EXPECT_FALSE(identity.has_value()); + + // Check to make sure that even dynamic messages can be used as input to + // the wrapping call. + result = CelProtoWrapper::MaybeWrapValue( + ReflectedCopy(message)->GetTypeName(), value, arena()); + EXPECT_TRUE(result.has_value()); + EXPECT_TRUE((*result).IsMessage()); + EXPECT_THAT((*result).MessageOrDie(), testutil::EqualsProto(message)); + } + + void ExpectNotWrapped(const CelValue& value, const google::protobuf::Message& message) { + // Test the input value does not wrap by asserting value == result. + auto result = + CelProtoWrapper::MaybeWrapValue(message.GetTypeName(), value, arena()); + EXPECT_FALSE(result.has_value()); + } + + template + void ExpectUnwrappedPrimitive(const google::protobuf::Message& message, T result) { + CelValue cel_value = CelProtoWrapper::CreateMessage(&message, arena()); + T value; + EXPECT_TRUE(cel_value.GetValue(&value)); + EXPECT_THAT(value, Eq(result)); + + T dyn_value; + CelValue cel_dyn_value = + CelProtoWrapper::CreateMessage(ReflectedCopy(message).get(), arena()); + EXPECT_THAT(cel_dyn_value.type(), Eq(cel_value.type())); + EXPECT_TRUE(cel_dyn_value.GetValue(&dyn_value)); + EXPECT_THAT(value, Eq(dyn_value)); + } + + void ExpectUnwrappedMessage(const google::protobuf::Message& message, + google::protobuf::Message* result) { + CelValue cel_value = CelProtoWrapper::CreateMessage(&message, arena()); + if (result == nullptr) { + EXPECT_TRUE(cel_value.IsNull()); + return; + } + EXPECT_TRUE(cel_value.IsMessage()); + EXPECT_THAT(cel_value.MessageOrDie(), testutil::EqualsProto(*result)); + } + std::unique_ptr ReflectedCopy( + const google::protobuf::Message& message) { + std::unique_ptr dynamic_value( + factory_.GetPrototype(message.GetDescriptor())->New()); + dynamic_value->CopyFrom(message); + return dynamic_value; + } + + Arena* arena() { return &arena_; } + + private: + Arena arena_; + google::protobuf::DynamicMessageFactory factory_; +}; + +TEST_F(CelProtoWrapperTest, TestType) { Duration msg_duration; msg_duration.set_seconds(2); msg_duration.set_nanos(3); @@ -48,7 +134,7 @@ TEST(CelProtoWrapperTest, TestType) { EXPECT_THAT(value_duration1.type(), Eq(CelValue::Type::kDuration)); CelValue value_duration2 = - CelProtoWrapper::CreateMessage(&msg_duration, &arena); + CelProtoWrapper::CreateMessage(&msg_duration, arena()); EXPECT_THAT(value_duration2.type(), Eq(CelValue::Type::kDuration)); Timestamp msg_timestamp; @@ -58,14 +144,12 @@ TEST(CelProtoWrapperTest, TestType) { EXPECT_THAT(value_timestamp1.type(), Eq(CelValue::Type::kTimestamp)); CelValue value_timestamp2 = - CelProtoWrapper::CreateMessage(&msg_timestamp, &arena); + CelProtoWrapper::CreateMessage(&msg_timestamp, arena()); EXPECT_THAT(value_timestamp2.type(), Eq(CelValue::Type::kTimestamp)); } // This test verifies CelValue support of Duration type. -TEST(CelProtoWrapperTest, TestDuration) { - google::protobuf::Arena arena; - +TEST_F(CelProtoWrapperTest, TestDuration) { Duration msg_duration; msg_duration.set_seconds(2); msg_duration.set_nanos(3); @@ -73,11 +157,10 @@ TEST(CelProtoWrapperTest, TestDuration) { EXPECT_THAT(value_duration1.type(), Eq(CelValue::Type::kDuration)); CelValue value_duration2 = - CelProtoWrapper::CreateMessage(&msg_duration, &arena); + CelProtoWrapper::CreateMessage(&msg_duration, arena()); EXPECT_THAT(value_duration2.type(), Eq(CelValue::Type::kDuration)); CelValue value = CelProtoWrapper::CreateDuration(&msg_duration); - // CelValue value = CelValue::CreateString("test"); EXPECT_TRUE(value.IsDuration()); Duration out; auto status = expr::internal::EncodeDuration(value.DurationOrDie(), &out); @@ -86,9 +169,7 @@ TEST(CelProtoWrapperTest, TestDuration) { } // This test verifies CelValue support of Timestamp type. -TEST(CelProtoWrapperTest, TestTimestamp) { - google::protobuf::Arena arena; - +TEST_F(CelProtoWrapperTest, TestTimestamp) { Timestamp msg_timestamp; msg_timestamp.set_seconds(2); msg_timestamp.set_nanos(3); @@ -96,7 +177,7 @@ TEST(CelProtoWrapperTest, TestTimestamp) { EXPECT_THAT(value_timestamp1.type(), Eq(CelValue::Type::kTimestamp)); CelValue value_timestamp2 = - CelProtoWrapper::CreateMessage(&msg_timestamp, &arena); + CelProtoWrapper::CreateMessage(&msg_timestamp, arena()); EXPECT_THAT(value_timestamp2.type(), Eq(CelValue::Type::kTimestamp)); CelValue value = CelProtoWrapper::CreateTimestamp(&msg_timestamp); @@ -110,57 +191,49 @@ TEST(CelProtoWrapperTest, TestTimestamp) { // Dynamic Values test // - -TEST(CelProtoWrapperTest, TestValueFieldNull) { - ::google::protobuf::Arena arena; - - Value value1; - value1.set_null_value(google::protobuf::NullValue::NULL_VALUE); - - CelValue value = CelProtoWrapper::CreateMessage(&value1, &arena); - ASSERT_TRUE(value.IsNull()); +TEST_F(CelProtoWrapperTest, UnwrapValueNull) { + Value json; + json.set_null_value(google::protobuf::NullValue::NULL_VALUE); + ExpectUnwrappedMessage(json, nullptr); } -TEST(CelProtoWrapperTest, TestValueFieldBool) { - ::google::protobuf::Arena arena; +// Test support for unwrapping a google::protobuf::Value to a CEL value. +TEST_F(CelProtoWrapperTest, UnwrapDynamicValueNull) { + Value value_msg; + value_msg.set_null_value(protobuf::NULL_VALUE); - Value value1; - value1.set_bool_value(true); - - CelValue value = CelProtoWrapper::CreateMessage(&value1, &arena); - ASSERT_TRUE(value.IsBool()); - EXPECT_EQ(value.BoolOrDie(), true); + CelValue value = + CelProtoWrapper::CreateMessage(ReflectedCopy(value_msg).get(), arena()); + EXPECT_TRUE(value.IsNull()); } -TEST(CelProtoWrapperTest, TestValueFieldNumeric) { - ::google::protobuf::Arena arena; - - Value value1; - value1.set_number_value(1.0); +TEST_F(CelProtoWrapperTest, UnwrapValueBool) { + bool value = true; - CelValue value = CelProtoWrapper::CreateMessage(&value1, &arena); - ASSERT_TRUE(value.IsDouble()); - EXPECT_DOUBLE_EQ(value.DoubleOrDie(), 1.0); + Value json; + json.set_bool_value(true); + ExpectUnwrappedPrimitive(json, value); } -TEST(CelProtoWrapperTest, TestValueFieldString) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, UnwrapValueNumber) { + double value = 1.0; - const std::string kTest = "test"; + Value json; + json.set_number_value(value); + ExpectUnwrappedPrimitive(json, value); +} - Value value1; - value1.set_string_value(kTest); +TEST_F(CelProtoWrapperTest, UnwrapValueString) { + const std::string test = "test"; + auto value = CelValue::StringHolder(&test); - CelValue value = CelProtoWrapper::CreateMessage(&value1, &arena); - ASSERT_TRUE(value.IsString()); - EXPECT_EQ(value.StringOrDie().value(), kTest); + Value json; + json.set_string_value(test); + ExpectUnwrappedPrimitive(json, value); } -TEST(CelProtoWrapperTest, TestValueFieldStruct) { - ::google::protobuf::Arena arena; - +TEST_F(CelProtoWrapperTest, UnwrapValueStruct) { const std::vector kFields = {"field1", "field2", "field3"}; - Struct value_struct; auto& value1 = (*value_struct.mutable_fields())[kFields[0]]; @@ -172,7 +245,7 @@ TEST(CelProtoWrapperTest, TestValueFieldStruct) { auto& value3 = (*value_struct.mutable_fields())[kFields[2]]; value3.set_string_value("test"); - CelValue value = CelProtoWrapper::CreateMessage(&value_struct, &arena); + CelValue value = CelProtoWrapper::CreateMessage(&value_struct, arena()); ASSERT_TRUE(value.IsMap()); const CelMap* cel_map = value.MapOrDie(); @@ -205,9 +278,55 @@ TEST(CelProtoWrapperTest, TestValueFieldStruct) { EXPECT_THAT(result_keys, UnorderedPointwise(Eq(), kFields)); } -TEST(CelProtoWrapperTest, TestListFieldStruct) { - ::google::protobuf::Arena arena; +// Test support for google::protobuf::Struct when it is created as dynamic +// message +TEST_F(CelProtoWrapperTest, UnwrapDynamicStruct) { + Struct struct_msg; + const std::string kFieldInt = "field_int"; + const std::string kFieldBool = "field_bool"; + (*struct_msg.mutable_fields())[kFieldInt].set_number_value(1.); + (*struct_msg.mutable_fields())[kFieldBool].set_bool_value(true); + CelValue value = + CelProtoWrapper::CreateMessage(ReflectedCopy(struct_msg).get(), arena()); + EXPECT_TRUE(value.IsMap()); + const CelMap* cel_map = value.MapOrDie(); + ASSERT_TRUE(cel_map != nullptr); + { + auto lookup = (*cel_map)[CelValue::CreateString(&kFieldInt)]; + ASSERT_TRUE(lookup.has_value()); + auto v = lookup.value(); + ASSERT_TRUE(v.IsDouble()); + EXPECT_THAT(v.DoubleOrDie(), testing::DoubleEq(1.)); + } + { + auto lookup = (*cel_map)[CelValue::CreateString(&kFieldBool)]; + ASSERT_TRUE(lookup.has_value()); + auto v = lookup.value(); + ASSERT_TRUE(v.IsBool()); + EXPECT_EQ(v.BoolOrDie(), true); + } +} + +TEST_F(CelProtoWrapperTest, UnwrapDynamicValueStruct) { + const std::string kField1 = "field1"; + const std::string kField2 = "field2"; + Value value_msg; + (*value_msg.mutable_struct_value()->mutable_fields())[kField1] + .set_number_value(1); + (*value_msg.mutable_struct_value()->mutable_fields())[kField2] + .set_number_value(2); + + CelValue value = + CelProtoWrapper::CreateMessage(ReflectedCopy(value_msg).get(), arena()); + EXPECT_TRUE(value.IsMap()); + EXPECT_TRUE( + (*value.MapOrDie())[CelValue::CreateString(&kField1)].has_value()); + EXPECT_TRUE( + (*value.MapOrDie())[CelValue::CreateString(&kField2)].has_value()); +} + +TEST_F(CelProtoWrapperTest, UnwrapValueList) { const std::vector kFields = {"field1", "field2", "field3"}; ListValue list_value; @@ -216,7 +335,7 @@ TEST(CelProtoWrapperTest, TestListFieldStruct) { list_value.add_values()->set_number_value(1.0); list_value.add_values()->set_string_value("test"); - CelValue value = CelProtoWrapper::CreateMessage(&list_value, &arena); + CelValue value = CelProtoWrapper::CreateMessage(&list_value, arena()); ASSERT_TRUE(value.IsList()); const CelList* cel_list = value.ListOrDie(); @@ -236,437 +355,471 @@ TEST(CelProtoWrapperTest, TestListFieldStruct) { EXPECT_EQ(value3.StringOrDie().value(), "test"); } -// Test support of google.protobuf.Any in CelValue. -TEST(CelProtoWrapperTest, TestAnyValue) { - ::google::protobuf::Arena arena; - Any any; +TEST_F(CelProtoWrapperTest, UnwrapDynamicValueListValue) { + Value value_msg; + value_msg.mutable_list_value()->add_values()->set_number_value(1.); + value_msg.mutable_list_value()->add_values()->set_number_value(2.); + + CelValue value = + CelProtoWrapper::CreateMessage(ReflectedCopy(value_msg).get(), arena()); + EXPECT_TRUE(value.IsList()); + EXPECT_THAT((*value.ListOrDie())[0].DoubleOrDie(), testing::DoubleEq(1)); + EXPECT_THAT((*value.ListOrDie())[1].DoubleOrDie(), testing::DoubleEq(2)); +} +// Test support of google.protobuf.Any in CelValue. +TEST_F(CelProtoWrapperTest, UnwrapAnyValue) { TestMessage test_message; test_message.set_string_value("test"); + Any any; any.PackFrom(test_message); - - CelValue value = CelProtoWrapper::CreateMessage(&any, &arena); - ASSERT_TRUE(value.IsMessage()); - - const google::protobuf::Message* unpacked_message = value.MessageOrDie(); - EXPECT_THAT(test_message, testutil::EqualsProto(*unpacked_message)); + ExpectUnwrappedMessage(any, &test_message); } -TEST(CelProtoWrapperTest, TestHandlingInvalidAnyValue) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, UnwrapInvalidAny) { Any any; - - CelValue value = CelProtoWrapper::CreateMessage(&any, &arena); + CelValue value = CelProtoWrapper::CreateMessage(&any, arena()); ASSERT_TRUE(value.IsError()); any.set_type_url("/"); - ASSERT_TRUE(CelProtoWrapper::CreateMessage(&any, &arena).IsError()); + ASSERT_TRUE(CelProtoWrapper::CreateMessage(&any, arena()).IsError()); any.set_type_url("/invalid.proto.name"); - ASSERT_TRUE(CelProtoWrapper::CreateMessage(&any, &arena).IsError()); + ASSERT_TRUE(CelProtoWrapper::CreateMessage(&any, arena()).IsError()); } // Test support of google.protobuf.Value wrappers in CelValue. -TEST(CelProtoWrapperTest, TestBoolWrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, UnwrapBoolWrapper) { + bool value = true; BoolValue wrapper; - wrapper.set_value(true); - - CelValue value = CelProtoWrapper::CreateMessage(&wrapper, &arena); - ASSERT_TRUE(value.IsBool()); - - EXPECT_EQ(value.BoolOrDie(), wrapper.value()); + wrapper.set_value(value); + ExpectUnwrappedPrimitive(wrapper, value); } -TEST(CelProtoWrapperTest, TestInt32Wrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, UnwrapInt32Wrapper) { + int64_t value = 12; Int32Value wrapper; - wrapper.set_value(12); - - CelValue value = CelProtoWrapper::CreateMessage(&wrapper, &arena); - ASSERT_TRUE(value.IsInt64()); - - EXPECT_EQ(value.Int64OrDie(), wrapper.value()); + wrapper.set_value(value); + ExpectUnwrappedPrimitive(wrapper, value); } -TEST(CelProtoWrapperTest, TestUInt32Wrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, UnwrapUInt32Wrapper) { + uint64_t value = 12; UInt32Value wrapper; - wrapper.set_value(12); - - CelValue value = CelProtoWrapper::CreateMessage(&wrapper, &arena); - ASSERT_TRUE(value.IsUint64()); - - EXPECT_EQ(value.Uint64OrDie(), wrapper.value()); + wrapper.set_value(value); + ExpectUnwrappedPrimitive(wrapper, value); } -TEST(CelProtoWrapperTest, TestInt64Wrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, UnwrapInt64Wrapper) { + int64_t value = 12; Int64Value wrapper; - wrapper.set_value(12); - - CelValue value = CelProtoWrapper::CreateMessage(&wrapper, &arena); - ASSERT_TRUE(value.IsInt64()); - - EXPECT_EQ(value.Int64OrDie(), wrapper.value()); + wrapper.set_value(value); + ExpectUnwrappedPrimitive(wrapper, value); } -TEST(CelProtoWrapperTest, TestUInt64Wrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, UnwrapUInt64Wrapper) { + uint64_t value = 12; UInt64Value wrapper; - wrapper.set_value(12); + wrapper.set_value(value); + ExpectUnwrappedPrimitive(wrapper, value); +} - CelValue value = CelProtoWrapper::CreateMessage(&wrapper, &arena); - ASSERT_TRUE(value.IsUint64()); +TEST_F(CelProtoWrapperTest, UnwrapFloatWrapper) { + double value = 42.5; - EXPECT_EQ(value.Uint64OrDie(), wrapper.value()); + FloatValue wrapper; + wrapper.set_value(value); + ExpectUnwrappedPrimitive(wrapper, value); } -TEST(CelProtoWrapperTest, TestFloatWrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, UnwrapDoubleWrapper) { + double value = 42.5; - FloatValue wrapper; - wrapper.set_value(42); + DoubleValue wrapper; + wrapper.set_value(value); + ExpectUnwrappedPrimitive(wrapper, value); +} - CelValue value = CelProtoWrapper::CreateMessage(&wrapper, &arena); - ASSERT_TRUE(value.IsDouble()); +TEST_F(CelProtoWrapperTest, UnwrapStringWrapper) { + std::string text = "42"; + auto value = CelValue::StringHolder(&text); - EXPECT_DOUBLE_EQ(value.DoubleOrDie(), wrapper.value()); + StringValue wrapper; + wrapper.set_value(text); + ExpectUnwrappedPrimitive(wrapper, value); } -TEST(CelProtoWrapperTest, TestDoubleWrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, UnwrapBytesWrapper) { + std::string text = "42"; + auto value = CelValue::BytesHolder(&text); - DoubleValue wrapper; - wrapper.set_value(42); + BytesValue wrapper; + wrapper.set_value("42"); + ExpectUnwrappedPrimitive(wrapper, value); +} + +TEST_F(CelProtoWrapperTest, WrapNull) { + auto cel_value = CelValue::CreateNull(); - CelValue value = CelProtoWrapper::CreateMessage(&wrapper, &arena); - ASSERT_TRUE(value.IsDouble()); + Value json; + json.set_null_value(protobuf::NULL_VALUE); + ExpectWrappedMessage(cel_value, json); - EXPECT_DOUBLE_EQ(value.DoubleOrDie(), wrapper.value()); + Any any; + any.PackFrom(json); + ExpectWrappedMessage(cel_value, any); } -TEST(CelProtoWrapperTest, TestStringWrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, WrapBool) { + auto cel_value = CelValue::CreateBool(true); - StringValue wrapper; - wrapper.set_value("42"); + Value json; + json.set_bool_value(true); + ExpectWrappedMessage(cel_value, json); - CelValue value = CelProtoWrapper::CreateMessage(&wrapper, &arena); - ASSERT_TRUE(value.IsString()); + BoolValue wrapper; + wrapper.set_value(true); + ExpectWrappedMessage(cel_value, wrapper); - EXPECT_EQ(value.StringOrDie().value(), wrapper.value()); + Any any; + any.PackFrom(wrapper); + ExpectWrappedMessage(cel_value, any); } -TEST(CelProtoWrapperTest, TestBytesWrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, WrapBytes) { + std::string str = "hello world"; + auto cel_value = CelValue::CreateBytes(CelValue::BytesHolder(&str)); BytesValue wrapper; - wrapper.set_value("42"); + wrapper.set_value(str); + ExpectWrappedMessage(cel_value, wrapper); - CelValue value = CelProtoWrapper::CreateMessage(&wrapper, &arena); - ASSERT_TRUE(value.IsBytes()); + Any any; + any.PackFrom(wrapper); + ExpectWrappedMessage(cel_value, any); +} + +TEST_F(CelProtoWrapperTest, WrapBytesToValue) { + std::string str = "hello world"; + auto cel_value = CelValue::CreateBytes(CelValue::BytesHolder(&str)); - EXPECT_EQ(value.BytesOrDie().value(), wrapper.value()); + Value json; + json.set_string_value("aGVsbG8gd29ybGQ="); + ExpectWrappedMessage(cel_value, json); } -// Test support for google::protobuf::Struct when it is created as dynamic -// message -TEST(CelProtoWrapperTest, DynamicStructSupport) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, WrapDuration) { + auto cel_value = CelValue::CreateDuration(absl::Seconds(300)); - google::protobuf::DynamicMessageFactory factory; - { - Struct struct_msg; - - const std::string kFieldInt = "field_int"; - const std::string kFieldBool = "field_bool"; - - (*struct_msg.mutable_fields())[kFieldInt].set_number_value(1.); - (*struct_msg.mutable_fields())[kFieldBool].set_bool_value(true); - std::unique_ptr dynamic_struct( - factory.GetPrototype(Struct::descriptor())->New()); - dynamic_struct->CopyFrom(struct_msg); - CelValue value = - CelProtoWrapper::CreateMessage(dynamic_struct.get(), &arena); - EXPECT_TRUE(value.IsMap()); - const CelMap* cel_map = value.MapOrDie(); - ASSERT_TRUE(cel_map != nullptr); - - { - auto lookup = (*cel_map)[CelValue::CreateString(&kFieldInt)]; - ASSERT_TRUE(lookup.has_value()); - auto v = lookup.value(); - ASSERT_TRUE(v.IsDouble()); - EXPECT_THAT(v.DoubleOrDie(), testing::DoubleEq(1.)); - } - { - auto lookup = (*cel_map)[CelValue::CreateString(&kFieldBool)]; - ASSERT_TRUE(lookup.has_value()); - auto v = lookup.value(); - ASSERT_TRUE(v.IsBool()); - EXPECT_EQ(v.BoolOrDie(), true); - } - } + Duration d; + d.set_seconds(300); + ExpectWrappedMessage(cel_value, d); + + Any any; + any.PackFrom(d); + ExpectWrappedMessage(cel_value, any); } -// Test support for google::protobuf::Value when it is created as dynamic -// message -TEST(CelProtoWrapperTest, DynamicValueSupport) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, WrapDurationToValue) { + auto cel_value = CelValue::CreateDuration(absl::Seconds(300)); - google::protobuf::DynamicMessageFactory factory; - // Null - { - Value value_msg; - value_msg.set_null_value(protobuf::NULL_VALUE); - std::unique_ptr dynamic_value( - factory.GetPrototype(Value::descriptor())->New()); - dynamic_value->CopyFrom(value_msg); - CelValue value = - CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); - EXPECT_TRUE(value.IsNull()); - } - // Boolean - { - Value value_msg; - value_msg.set_bool_value(true); - std::unique_ptr dynamic_value( - factory.GetPrototype(Value::descriptor())->New()); - dynamic_value->CopyFrom(value_msg); - CelValue value = - CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); - EXPECT_TRUE(value.IsBool()); - EXPECT_TRUE(value.BoolOrDie()); - } - // Numeric - { - Value value_msg; - value_msg.set_number_value(1.0); - std::unique_ptr dynamic_value( - factory.GetPrototype(Value::descriptor())->New()); - dynamic_value->CopyFrom(value_msg); - CelValue value = - CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); - EXPECT_TRUE(value.IsDouble()); - EXPECT_THAT(value.DoubleOrDie(), testing::DoubleEq(1.)); - } - // String - { - Value value_msg; - value_msg.set_string_value("test"); - std::unique_ptr dynamic_value( - factory.GetPrototype(Value::descriptor())->New()); - dynamic_value->CopyFrom(value_msg); - CelValue value = - CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); - EXPECT_TRUE(value.IsString()); - EXPECT_THAT(value.StringOrDie().value(), Eq("test")); - } - // List - { - Value value_msg; - value_msg.mutable_list_value()->add_values()->set_number_value(1.); - value_msg.mutable_list_value()->add_values()->set_number_value(2.); - std::unique_ptr dynamic_value( - factory.GetPrototype(Value::descriptor())->New()); - dynamic_value->CopyFrom(value_msg); - CelValue value = - CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); - EXPECT_TRUE(value.IsList()); - EXPECT_THAT((*value.ListOrDie())[0].DoubleOrDie(), testing::DoubleEq(1)); - EXPECT_THAT((*value.ListOrDie())[1].DoubleOrDie(), testing::DoubleEq(2)); - } - // Struct - { - const std::string kField1 = "field1"; - const std::string kField2 = "field2"; - - Value value_msg; - (*value_msg.mutable_struct_value()->mutable_fields())[kField1] - .set_number_value(1); - (*value_msg.mutable_struct_value()->mutable_fields())[kField2] - .set_number_value(2); - std::unique_ptr dynamic_value( - factory.GetPrototype(Value::descriptor())->New()); - dynamic_value->CopyFrom(value_msg); - CelValue value = - CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); - EXPECT_TRUE(value.IsMap()); - EXPECT_TRUE( - (*value.MapOrDie())[CelValue::CreateString(&kField1)].has_value()); - EXPECT_TRUE( - (*value.MapOrDie())[CelValue::CreateString(&kField2)].has_value()); - } + Value json; + json.set_string_value("300s"); + ExpectWrappedMessage(cel_value, json); } -// Test support of google.protobuf.Value wrappers in CelValue. -TEST(CelProtoWrapperTest, DynamicBoolWrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, WrapDouble) { + double num = 1.5; + auto cel_value = CelValue::CreateDouble(num); - BoolValue wrapper; - wrapper.set_value(true); - google::protobuf::DynamicMessageFactory factory; - std::unique_ptr dynamic_value( - factory.GetPrototype(BoolValue::descriptor())->New()); - dynamic_value->CopyFrom(wrapper); + Value json; + json.set_number_value(num); + ExpectWrappedMessage(cel_value, json); - CelValue value = CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); - ASSERT_TRUE(value.IsBool()); + DoubleValue wrapper; + wrapper.set_value(num); + ExpectWrappedMessage(cel_value, wrapper); - EXPECT_EQ(value.BoolOrDie(), wrapper.value()); + Any any; + any.PackFrom(wrapper); + ExpectWrappedMessage(cel_value, any); } -TEST(CelProtoWrapperTest, DynamicInt32Wrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, WrapDoubleToFloatValue) { + double num = 1.5; + auto cel_value = CelValue::CreateDouble(num); - Int32Value wrapper; - wrapper.set_value(12); + FloatValue wrapper; + wrapper.set_value(num); + ExpectWrappedMessage(cel_value, wrapper); - google::protobuf::DynamicMessageFactory factory; - std::unique_ptr dynamic_value( - factory.GetPrototype(wrapper.descriptor())->New()); - dynamic_value->CopyFrom(wrapper); + // Imprecise double -> float representation results in truncation. + double small_num = -9.9e-100; + wrapper.set_value(small_num); + cel_value = CelValue::CreateDouble(small_num); + ExpectWrappedMessage(cel_value, wrapper); +} + +TEST_F(CelProtoWrapperTest, WrapDoubleOverflow) { + double lowest_double = std::numeric_limits::lowest(); + auto cel_value = CelValue::CreateDouble(lowest_double); - CelValue value = CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); + // Double exceeds float precision, overflow to -infinity. + FloatValue wrapper; + wrapper.set_value(-std::numeric_limits::infinity()); + ExpectWrappedMessage(cel_value, wrapper); - ASSERT_TRUE(value.IsInt64()); + double max_double = std::numeric_limits::max(); + cel_value = CelValue::CreateDouble(max_double); - EXPECT_EQ(value.Int64OrDie(), wrapper.value()); + wrapper.set_value(std::numeric_limits::infinity()); + ExpectWrappedMessage(cel_value, wrapper); } -TEST(CelProtoWrapperTest, DynamicUInt32Wrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, WrapInt64) { + int32_t num = std::numeric_limits::lowest(); + auto cel_value = CelValue::CreateInt64(num); - UInt32Value wrapper; - wrapper.set_value(12); + Value json; + json.set_number_value(static_cast(num)); + ExpectWrappedMessage(cel_value, json); - google::protobuf::DynamicMessageFactory factory; - std::unique_ptr dynamic_value( - factory.GetPrototype(wrapper.descriptor())->New()); - dynamic_value->CopyFrom(wrapper); - CelValue value = CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); + Int64Value wrapper; + wrapper.set_value(num); + ExpectWrappedMessage(cel_value, wrapper); - ASSERT_TRUE(value.IsUint64()); - EXPECT_EQ(value.Uint64OrDie(), wrapper.value()); + Any any; + any.PackFrom(wrapper); + ExpectWrappedMessage(cel_value, any); } -TEST(CelProtoWrapperTest, DynamocInt64Wrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, WrapInt64ToInt32Value) { + int32_t num = std::numeric_limits::lowest(); + auto cel_value = CelValue::CreateInt64(num); - Int64Value wrapper; - wrapper.set_value(12); + Int32Value wrapper; + wrapper.set_value(num); + ExpectWrappedMessage(cel_value, wrapper); +} - google::protobuf::DynamicMessageFactory factory; - std::unique_ptr dynamic_value( - factory.GetPrototype(wrapper.descriptor())->New()); - dynamic_value->CopyFrom(wrapper); - CelValue value = CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); +TEST_F(CelProtoWrapperTest, WrapFailureInt64ToInt32Value) { + int64_t num = std::numeric_limits::lowest(); + auto cel_value = CelValue::CreateInt64(num); - EXPECT_EQ(value.Int64OrDie(), wrapper.value()); + Int32Value wrapper; + ExpectNotWrapped(cel_value, wrapper); } -TEST(CelProtoWrapperTest, DynamicUInt64Wrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, WrapInt64ToValue) { + int64_t max = std::numeric_limits::max(); + auto cel_value = CelValue::CreateInt64(max); - UInt64Value wrapper; - wrapper.set_value(12); + Value json; + json.set_string_value(absl::StrCat(max)); + ExpectWrappedMessage(cel_value, json); - google::protobuf::DynamicMessageFactory factory; - std::unique_ptr dynamic_value( - factory.GetPrototype(wrapper.descriptor())->New()); - dynamic_value->CopyFrom(wrapper); - CelValue value = CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); - ASSERT_TRUE(value.IsUint64()); + int64_t min = std::numeric_limits::min(); + cel_value = CelValue::CreateInt64(min); - EXPECT_EQ(value.Uint64OrDie(), wrapper.value()); + json.set_string_value(absl::StrCat(min)); + ExpectWrappedMessage(cel_value, json); } -TEST(CelProtoWrapperTest, DynamicFloatWrapper) { - ::google::protobuf::Arena arena; - - FloatValue wrapper; - wrapper.set_value(42); +TEST_F(CelProtoWrapperTest, WrapUint64) { + uint32_t num = std::numeric_limits::max(); + auto cel_value = CelValue::CreateUint64(num); - google::protobuf::DynamicMessageFactory factory; - std::unique_ptr dynamic_value( - factory.GetPrototype(wrapper.descriptor())->New()); - dynamic_value->CopyFrom(wrapper); - CelValue value = CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); + Value json; + json.set_number_value(static_cast(num)); + ExpectWrappedMessage(cel_value, json); - ASSERT_TRUE(value.IsDouble()); + UInt64Value wrapper; + wrapper.set_value(num); + ExpectWrappedMessage(cel_value, wrapper); - EXPECT_DOUBLE_EQ(value.DoubleOrDie(), wrapper.value()); + Any any; + any.PackFrom(wrapper); + ExpectWrappedMessage(cel_value, any); } -TEST(CelProtoWrapperTest, DynamicDoubleWrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, WrapUint64ToUint32Value) { + uint32_t num = std::numeric_limits::max(); + auto cel_value = CelValue::CreateUint64(num); - DoubleValue wrapper; - wrapper.set_value(42); + UInt32Value wrapper; + wrapper.set_value(num); + ExpectWrappedMessage(cel_value, wrapper); +} - google::protobuf::DynamicMessageFactory factory; - std::unique_ptr dynamic_value( - factory.GetPrototype(wrapper.descriptor())->New()); - dynamic_value->CopyFrom(wrapper); - CelValue value = CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); +TEST_F(CelProtoWrapperTest, WrapUint64ToValue) { + uint64_t num = std::numeric_limits::max(); + auto cel_value = CelValue::CreateUint64(num); - ASSERT_TRUE(value.IsDouble()); + Value json; + json.set_string_value(absl::StrCat(num)); + ExpectWrappedMessage(cel_value, json); +} - EXPECT_DOUBLE_EQ(value.DoubleOrDie(), wrapper.value()); +TEST_F(CelProtoWrapperTest, WrapFailureUint64ToUint32Value) { + uint64_t num = std::numeric_limits::max(); + auto cel_value = CelValue::CreateUint64(num); + + UInt32Value wrapper; + ExpectNotWrapped(cel_value, wrapper); } -TEST(CelProtoWrapperTest, DynamicStringWrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, WrapString) { + std::string str = "test"; + auto cel_value = CelValue::CreateString(CelValue::StringHolder(&str)); + + Value json; + json.set_string_value(str); + ExpectWrappedMessage(cel_value, json); StringValue wrapper; - wrapper.set_value("42"); + wrapper.set_value(str); + ExpectWrappedMessage(cel_value, wrapper); - google::protobuf::DynamicMessageFactory factory; - std::unique_ptr dynamic_value( - factory.GetPrototype(wrapper.descriptor())->New()); - dynamic_value->CopyFrom(wrapper); - CelValue value = CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); + Any any; + any.PackFrom(wrapper); + ExpectWrappedMessage(cel_value, any); +} + +TEST_F(CelProtoWrapperTest, WrapTimestamp) { + absl::Time ts = absl::FromUnixSeconds(1615852799); + auto cel_value = CelValue::CreateTimestamp(ts); - ASSERT_TRUE(value.IsString()); + Timestamp t; + t.set_seconds(1615852799); + ExpectWrappedMessage(cel_value, t); - EXPECT_EQ(value.StringOrDie().value(), wrapper.value()); + Any any; + any.PackFrom(t); + ExpectWrappedMessage(cel_value, any); } -TEST(CelProtoWrapperTest, DynamicBytesWrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, WrapTimestampToValue) { + absl::Time ts = absl::FromUnixSeconds(1615852799); + auto cel_value = CelValue::CreateTimestamp(ts); - BytesValue wrapper; - wrapper.set_value("42"); + Value json; + json.set_string_value("2021-03-15T23:59:59Z"); + ExpectWrappedMessage(cel_value, json); +} - google::protobuf::DynamicMessageFactory factory; - std::unique_ptr dynamic_value( - factory.GetPrototype(wrapper.descriptor())->New()); - dynamic_value->CopyFrom(wrapper); - CelValue value = CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); +TEST_F(CelProtoWrapperTest, WrapList) { + std::vector list_elems = { + CelValue::CreateDouble(1.5), + CelValue::CreateInt64(-2L), + }; + ContainerBackedListImpl list(std::move(list_elems)); + auto cel_value = CelValue::CreateList(&list); - ASSERT_TRUE(value.IsBytes()); + Value json; + json.mutable_list_value()->add_values()->set_number_value(1.5); + json.mutable_list_value()->add_values()->set_number_value(-2.); + ExpectWrappedMessage(cel_value, json); + ExpectWrappedMessage(cel_value, json.list_value()); - EXPECT_EQ(value.BytesOrDie().value(), wrapper.value()); + Any any; + any.PackFrom(json.list_value()); + ExpectWrappedMessage(cel_value, any); +} + +TEST_F(CelProtoWrapperTest, WrapFailureListValueBadJSON) { + TestMessage message; + std::vector list_elems = { + CelValue::CreateDouble(1.5), + CelProtoWrapper::CreateMessage(&message, arena()), + }; + ContainerBackedListImpl list(std::move(list_elems)); + auto cel_value = CelValue::CreateList(&list); + + Value json; + ExpectNotWrapped(cel_value, json); +} + +TEST_F(CelProtoWrapperTest, WrapStruct) { + const std::string kField1 = "field1"; + std::vector> args = { + {CelValue::CreateString(CelValue::StringHolder(&kField1)), + CelValue::CreateBool(true)}}; + auto cel_map = CreateContainerBackedMap( + absl::Span>(args.data(), args.size())); + auto cel_value = CelValue::CreateMap(cel_map.get()); + + Value json; + (*json.mutable_struct_value()->mutable_fields())[kField1].set_bool_value( + true); + ExpectWrappedMessage(cel_value, json); + ExpectWrappedMessage(cel_value, json.struct_value()); + + Any any; + any.PackFrom(json.struct_value()); + ExpectWrappedMessage(cel_value, any); +} + +TEST_F(CelProtoWrapperTest, WrapFailureStructBadKeyType) { + std::vector> args = { + {CelValue::CreateInt64(1L), CelValue::CreateBool(true)}}; + auto cel_map = CreateContainerBackedMap( + absl::Span>(args.data(), args.size())); + auto cel_value = CelValue::CreateMap(cel_map.get()); + + Value json; + ExpectNotWrapped(cel_value, json); +} + +TEST_F(CelProtoWrapperTest, WrapFailureStructBadValueType) { + const std::string kField1 = "field1"; + TestMessage bad_value; + std::vector> args = { + {CelValue::CreateString(CelValue::StringHolder(&kField1)), + CelProtoWrapper::CreateMessage(&bad_value, arena())}}; + auto cel_map = CreateContainerBackedMap( + absl::Span>(args.data(), args.size())); + auto cel_value = CelValue::CreateMap(cel_map.get()); + Value json; + ExpectNotWrapped(cel_value, json); +} + +TEST_F(CelProtoWrapperTest, WrapFailureWrongType) { + auto cel_value = CelValue::CreateNull(); + std::vector wrong_types = { + &BoolValue::default_instance(), &BytesValue::default_instance(), + &DoubleValue::default_instance(), &Duration::default_instance(), + &FloatValue::default_instance(), &Int32Value::default_instance(), + &Int64Value::default_instance(), &ListValue::default_instance(), + &StringValue::default_instance(), &Struct::default_instance(), + &Timestamp::default_instance(), &UInt32Value::default_instance(), + &UInt64Value::default_instance(), + }; + for (const auto* wrong_type : wrong_types) { + ExpectNotWrapped(cel_value, *wrong_type); + } +} + +TEST_F(CelProtoWrapperTest, WrapFailureErrorToAny) { + auto cel_value = CreateNoSuchFieldError(arena(), "error_field"); + ExpectNotWrapped(cel_value, Any::default_instance()); } -TEST(CelProtoWrapperTest, DebugString) { +TEST_F(CelProtoWrapperTest, DebugString) { google::protobuf::Empty e; - ::google::protobuf::Arena arena; - EXPECT_EQ(CelProtoWrapper::CreateMessage(&e, &arena).DebugString(), + EXPECT_EQ(CelProtoWrapper::CreateMessage(&e, arena()).DebugString(), "Message: "); ListValue list_value; list_value.add_values()->set_bool_value(true); list_value.add_values()->set_number_value(1.0); list_value.add_values()->set_string_value("test"); - CelValue value = CelProtoWrapper::CreateMessage(&list_value, &arena); + CelValue value = CelProtoWrapper::CreateMessage(&list_value, arena()); EXPECT_EQ(value.DebugString(), "List, size: 3"); Struct value_struct; @@ -677,10 +830,12 @@ TEST(CelProtoWrapperTest, DebugString) { auto& value3 = (*value_struct.mutable_fields())["c"]; value3.set_string_value("test"); - value = CelProtoWrapper::CreateMessage(&value_struct, &arena); + value = CelProtoWrapper::CreateMessage(&value_struct, arena()); EXPECT_EQ(value.DebugString(), "Map, size: 3"); } +} // namespace + } // namespace runtime } // namespace expr } // namespace api diff --git a/eval/public/transform_utility.cc b/eval/public/transform_utility.cc index bcfae9d94..7081170c6 100644 --- a/eval/public/transform_utility.cc +++ b/eval/public/transform_utility.cc @@ -63,7 +63,7 @@ absl::Status CelValueToValue(const CelValue& value, Value* result) { break; } case CelValue::Type::kMessage: - if (value.MessageOrDie() == nullptr) { + if (value.IsNull()) { result->set_null_value(google::protobuf::NullValue::NULL_VALUE); } else { result->mutable_object_value()->PackFrom(*value.MessageOrDie()); diff --git a/eval/tests/BUILD b/eval/tests/BUILD index 10f0fb3c6..53f0206cb 100644 --- a/eval/tests/BUILD +++ b/eval/tests/BUILD @@ -51,6 +51,7 @@ cc_test( "//eval/public:cel_value", "//eval/public/structs:cel_proto_wrapper", "//eval/testutil:test_message_cc_proto", + "//testutil:util", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", "@com_google_googletest//:gtest_main", "@com_google_protobuf//:protobuf", diff --git a/eval/tests/end_to_end_test.cc b/eval/tests/end_to_end_test.cc index f28907da9..0591b6619 100644 --- a/eval/tests/end_to_end_test.cc +++ b/eval/tests/end_to_end_test.cc @@ -1,4 +1,5 @@ #include "google/api/expr/v1alpha1/syntax.pb.h" +#include "google/protobuf/struct.pb.h" #include "google/protobuf/text_format.h" #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -9,6 +10,7 @@ #include "eval/public/cel_value.h" #include "eval/public/structs/cel_proto_wrapper.h" #include "eval/testutil/test_message.pb.h" +#include "testutil/util.h" #include "base/status_macros.h" namespace google { @@ -167,6 +169,56 @@ TEST(EndToEndTest, EmptyStringCompare) { EXPECT_TRUE(result.BoolOrDie()); } +TEST(EndToEndTest, NullLiteral) { + // AST CEL equivalent of "Value{null_value: NullValue.NULL_VALUE}" + constexpr char kExpr0[] = R"( + struct_expr: < + message_name: "Value" + entries: < + field_key: "null_value" + value: < + select_expr: < + operand: < + ident_expr: < + name: "NullValue" + > + > + field: "NULL_VALUE" + > + > + > + > + )"; + + Expr expr; + SourceInfo source_info; + TextFormat::ParseFromString(kExpr0, &expr); + + // Obtain CEL Expression builder. + std::unique_ptr builder = CreateCelExpressionBuilder(); + builder->set_container("google.protobuf"); + + // Builtin registration. + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + + // Create CelExpression from AST (Expr object). + auto cel_expression_status = builder->CreateExpression(&expr, &source_info); + ASSERT_OK(cel_expression_status); + + auto cel_expression = std::move(cel_expression_status.value()); + Activation activation; + Arena arena; + // Run evaluation. + auto eval_status = cel_expression->Evaluate(activation, &arena); + + ASSERT_OK(eval_status); + + google::protobuf::Value null_value; + null_value.set_null_value(protobuf::NULL_VALUE); + CelValue result = eval_status.value(); + ASSERT_TRUE(result.IsNull()); +} + } // namespace } // namespace runtime diff --git a/internal/BUILD b/internal/BUILD index 1c2e70b72..072a78592 100644 --- a/internal/BUILD +++ b/internal/BUILD @@ -178,6 +178,7 @@ cc_library( "//common:macros", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@com_google_protobuf//:protobuf", diff --git a/internal/proto_util.cc b/internal/proto_util.cc index e33d267aa..06eb9c76b 100644 --- a/internal/proto_util.cc +++ b/internal/proto_util.cc @@ -2,6 +2,7 @@ #include "google/protobuf/duration.pb.h" #include "google/protobuf/timestamp.pb.h" +#include "google/protobuf/util/time_util.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "common/macros.h" @@ -58,6 +59,15 @@ absl::Status EncodeDuration(absl::Duration duration, return absl::OkStatus(); } +absl::StatusOr EncodeDurationToString(absl::Duration duration) { + google::protobuf::Duration d; + auto status = EncodeDuration(duration, &d); + if (!status.ok()) { + return status; + } + return google::protobuf::util::TimeUtil::ToString(d); +} + absl::Status EncodeTime(absl::Time time, google::protobuf::Timestamp* proto) { RETURN_IF_ERROR(Validate(time)); const int64_t s = absl::ToUnixSeconds(time); @@ -66,6 +76,15 @@ absl::Status EncodeTime(absl::Time time, google::protobuf::Timestamp* proto) { return absl::OkStatus(); } +absl::StatusOr EncodeTimeToString(absl::Time time) { + google::protobuf::Timestamp t; + auto status = EncodeTime(time, &t); + if (!status.ok()) { + return status; + } + return google::protobuf::util::TimeUtil::ToString(t); +} + } // namespace internal } // namespace expr } // namespace api diff --git a/internal/proto_util.h b/internal/proto_util.h index 3bb771ce3..02e2b92fa 100644 --- a/internal/proto_util.h +++ b/internal/proto_util.h @@ -6,6 +6,7 @@ #include "google/protobuf/util/message_differencer.h" #include "absl/memory/memory.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/time/time.h" namespace google { @@ -27,9 +28,15 @@ absl::Status ValidateDuration(absl::Duration duration); absl::Status EncodeDuration(absl::Duration duration, google::protobuf::Duration* proto); +/** Helper function to encode an absl::Duration to a JSON-formatted string. */ +absl::StatusOr EncodeDurationToString(absl::Duration duration); + /** Helper function to encode a time in a google::protobuf::Timestamp. */ absl::Status EncodeTime(absl::Time time, google::protobuf::Timestamp* proto); +/** Helper function to encode an absl::Time to a JSON-formatted string. */ +absl::StatusOr EncodeTimeToString(absl::Time time); + /** Helper function to decode a duration from a google::protobuf::Duration. */ absl::Duration DecodeDuration(const google::protobuf::Duration& proto); From f3b9b4692d8410858576b7c33d1f54bc1b4d67ce Mon Sep 17 00:00:00 2001 From: tswadell Date: Wed, 31 Mar 2021 11:55:32 -0700 Subject: [PATCH 23/23] Internal build change PiperOrigin-RevId: 366083872 --- eval/public/structs/cel_proto_wrapper.cc | 49 +++++++++++++++--------- 1 file changed, 31 insertions(+), 18 deletions(-) diff --git a/eval/public/structs/cel_proto_wrapper.cc b/eval/public/structs/cel_proto_wrapper.cc index eedb75e7a..b121ccf59 100644 --- a/eval/public/structs/cel_proto_wrapper.cc +++ b/eval/public/structs/cel_proto_wrapper.cc @@ -380,11 +380,11 @@ absl::optional MessageFromValue(const CelValue absl::optional MessageFromValue(const CelValue& value, BytesValue* wrapper) { - CelValue::BytesHolder val; - if (!value.GetValue(&val)) { + CelValue::BytesHolder view_val; + if (!value.GetValue(&view_val)) { return {}; } - wrapper->set_value(val.value()); + wrapper->set_value(view_val.value().data()); return wrapper; } @@ -444,11 +444,11 @@ absl::optional MessageFromValue(const CelValue absl::optional MessageFromValue(const CelValue& value, StringValue* wrapper) { - CelValue::StringHolder val; - if (!value.GetValue(&val)) { + CelValue::StringHolder view_val; + if (!value.GetValue(&view_val)) { return {}; } - wrapper->set_value(val.value()); + wrapper->set_value(view_val.value().data()); return wrapper; } @@ -644,74 +644,85 @@ absl::optional MessageFromValue(const CelValue absl::optional MessageFromValue(const CelValue& value, Any* any) { + // In open source, any->PackFrom() returns void rather than boolean. switch (value.type()) { case CelValue::Type::kBool: { BoolValue v; auto msg = MessageFromValue(value, &v); - if (msg.has_value() && any->PackFrom(**msg)) { + if (msg.has_value()) { + any->PackFrom(**msg); return any; } } break; case CelValue::Type::kBytes: { BytesValue v; auto msg = MessageFromValue(value, &v); - if (msg.has_value() && any->PackFrom(**msg)) { + if (msg.has_value()) { + any->PackFrom(**msg); return any; } } break; case CelValue::Type::kDouble: { DoubleValue v; auto msg = MessageFromValue(value, &v); - if (msg.has_value() && any->PackFrom(**msg)) { + if (msg.has_value()) { + any->PackFrom(**msg); return any; } } break; case CelValue::Type::kDuration: { Duration v; auto msg = MessageFromValue(value, &v); - if (msg.has_value() && any->PackFrom(**msg)) { + if (msg.has_value()) { + any->PackFrom(**msg); return any; } } break; case CelValue::Type::kInt64: { Int64Value v; auto msg = MessageFromValue(value, &v); - if (msg.has_value() && any->PackFrom(**msg)) { + if (msg.has_value()) { + any->PackFrom(**msg); return any; } } break; case CelValue::Type::kString: { StringValue v; auto msg = MessageFromValue(value, &v); - if (msg.has_value() && any->PackFrom(**msg)) { + if (msg.has_value()) { + any->PackFrom(**msg); return any; } } break; case CelValue::Type::kTimestamp: { Timestamp v; auto msg = MessageFromValue(value, &v); - if (msg.has_value() && any->PackFrom(**msg)) { + if (msg.has_value()) { + any->PackFrom(**msg); return any; } } break; case CelValue::Type::kUint64: { UInt64Value v; auto msg = MessageFromValue(value, &v); - if (msg.has_value() && any->PackFrom(**msg)) { + if (msg.has_value()) { + any->PackFrom(**msg); return any; } } break; case CelValue::Type::kList: { ListValue v; auto msg = MessageFromValue(value, &v); - if (msg.has_value() && any->PackFrom(**msg)) { + if (msg.has_value()) { + any->PackFrom(**msg); return any; } } break; case CelValue::Type::kMap: { Struct v; auto msg = MessageFromValue(value, &v); - if (msg.has_value() && any->PackFrom(**msg)) { + if (msg.has_value()) { + any->PackFrom(**msg); return any; } } break; @@ -719,10 +730,12 @@ absl::optional MessageFromValue(const CelValue if (value.IsNull()) { Value v; auto msg = MessageFromValue(value, &v); - if (msg.has_value() && any->PackFrom(**msg)) { + if (msg.has_value()) { + any->PackFrom(**msg); return any; } - } else if (any->PackFrom(*(value.MessageOrDie()))) { + } else { + any->PackFrom(*(value.MessageOrDie())); return any; } } break;