From cc1c101ba37cce71f17797f91824b7afd5e47bed Mon Sep 17 00:00:00 2001 From: "Gabriele N. Tornetta" Date: Wed, 1 May 2024 14:31:41 +0100 Subject: [PATCH 1/9] chore: exclude non-user symbols from symbol DB (#9013) We prevent non-user symbols from being collected by the symbol database to reduce the work done by the client library, as well as the size of the uploaded payloads. ## Checklist - [x] Change(s) are motivated and described in the PR description - [x] Testing strategy is described if automated tests are not included in the PR - [x] Risks are described (performance impact, potential for breakage, maintainability) - [x] Change is maintainable (easy to change, telemetry, documentation) - [x] [Library release note guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html) are followed or label `changelog/no-changelog` is set - [x] Documentation is included (in-code, generated user docs, [public corp docs](https://github.com/DataDog/documentation/)) - [x] Backport labels are set (if [applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)) - [x] If this PR changes the public interface, I've notified `@DataDog/apm-tees`. ## Reviewer Checklist - [ ] Title is accurate - [ ] All changes are related to the pull request's stated goal - [ ] Description motivates each change - [ ] Avoids breaking [API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces) changes - [ ] Testing strategy adequately addresses listed risks - [ ] Change is maintainable (easy to change, telemetry, documentation) - [ ] Release note makes sense to a user of the library - [ ] Author has acknowledged and discussed the performance implications of this PR as reported in the benchmarks PR comment - [ ] Backport labels are set in a manner that is consistent with the [release branch maintenance policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting) --- ddtrace/internal/packages.py | 5 ++++ ddtrace/internal/symbol_db/symbols.py | 34 ++++++++++++++---------- tests/internal/symbol_db/test_symbols.py | 24 ++++++++--------- 3 files changed, 36 insertions(+), 27 deletions(-) diff --git a/ddtrace/internal/packages.py b/ddtrace/internal/packages.py index 2d8f1c5fd1e..fcec01a463b 100644 --- a/ddtrace/internal/packages.py +++ b/ddtrace/internal/packages.py @@ -238,6 +238,11 @@ def is_third_party(path: Path) -> bool: return package.name in _third_party_packages() +@cached() +def is_user_code(path: Path) -> bool: + return not (is_stdlib(path) or is_third_party(path)) + + @cached() def is_distribution_available(name: str) -> bool: """Determine if a distribution is available in the current environment.""" diff --git a/ddtrace/internal/symbol_db/symbols.py b/ddtrace/internal/symbol_db/symbols.py index d454e9eb8f5..9f66ffa3a86 100644 --- a/ddtrace/internal/symbol_db/symbols.py +++ b/ddtrace/internal/symbol_db/symbols.py @@ -3,7 +3,7 @@ from dataclasses import field import dis from enum import Enum -import http +from http.client import HTTPResponse from inspect import CO_VARARGS from inspect import CO_VARKEYWORDS from inspect import isasyncgenfunction @@ -31,7 +31,6 @@ from ddtrace.internal.logger import get_logger from ddtrace.internal.module import BaseModuleWatchdog from ddtrace.internal.module import origin -from ddtrace.internal.packages import is_stdlib from ddtrace.internal.runtime import get_runtime_id from ddtrace.internal.safety import _isinstance from ddtrace.internal.utils.cache import cached @@ -50,10 +49,10 @@ @cached() -def is_from_stdlib(obj: t.Any) -> t.Optional[bool]: +def is_from_user_code(obj: t.Any) -> t.Optional[bool]: try: path = origin(sys.modules[object.__getattribute__(obj, "__module__")]) - return is_stdlib(path) if path is not None else None + return packages.is_user_code(path) if path is not None else None except (AttributeError, KeyError): return None @@ -182,9 +181,6 @@ def _(cls, module: ModuleType, data: ScopeData): symbols = [] scopes = [] - if is_stdlib(module_origin): - return None - for alias, child in object.__getattribute__(module, "__dict__").items(): if _isinstance(child, ModuleType): # We don't want to traverse other modules. @@ -224,7 +220,7 @@ def _(cls, obj: type, data: ScopeData): return None data.seen.add(obj) - if is_from_stdlib(obj): + if not is_from_user_code(obj): return None symbols = [] @@ -347,7 +343,7 @@ def _(cls, f: FunctionType, data: ScopeData): return None data.seen.add(f) - if is_from_stdlib(f): + if not is_from_user_code(f): return None code = f.__dd_wrapped__.__code__ if hasattr(f, "__dd_wrapped__") else f.__code__ @@ -416,7 +412,7 @@ def _(cls, pr: property, data: ScopeData): data.seen.add(pr.fget) # TODO: These names don't match what is reported by the discovery. - if pr.fget is None or is_from_stdlib(pr.fget): + if pr.fget is None or not is_from_user_code(pr.fget): return None path = func_origin(t.cast(FunctionType, pr.fget)) @@ -477,7 +473,7 @@ def to_json(self) -> dict: "scopes": [_.to_json() for _ in self._scopes], } - def upload(self) -> http.client.HTTPResponse: + def upload(self) -> HTTPResponse: body, headers = multipart( parts=[ FormData( @@ -509,14 +505,24 @@ def __len__(self) -> int: def is_module_included(module: ModuleType) -> bool: + # Check if module name matches the include patterns if symdb_config._includes_re.match(module.__name__): return True - package = packages.module_to_package(module) - if package is None: + # Check if it is user code + module_origin = origin(module) + if module_origin is None: return False - return symdb_config._includes_re.match(package.name) is not None + if packages.is_user_code(module_origin): + return True + + # Check if the package name matches the include patterns + package = packages.filename_to_package(module_origin) + if package is not None and symdb_config._includes_re.match(package.name): + return True + + return False class SymbolDatabaseUploader(BaseModuleWatchdog): diff --git a/tests/internal/symbol_db/test_symbols.py b/tests/internal/symbol_db/test_symbols.py index 4c879b63e5c..a97f6c5bcee 100644 --- a/tests/internal/symbol_db/test_symbols.py +++ b/tests/internal/symbol_db/test_symbols.py @@ -203,20 +203,11 @@ def test_symbols_upload_enabled(): assert remoteconfig_poller.get_registered("LIVE_DEBUGGING_SYMBOL_DB") is not None -@pytest.mark.subprocess( - ddtrace_run=True, - env=dict( - DD_SYMBOL_DATABASE_UPLOAD_ENABLED="1", - _DD_SYMBOL_DATABASE_FORCE_UPLOAD="1", - DD_SYMBOL_DATABASE_INCLUDES="tests.submod.stuff", - ), -) +@pytest.mark.subprocess(ddtrace_run=True, env=dict(DD_SYMBOL_DATABASE_INCLUDES="tests.submod.stuff")) def test_symbols_force_upload(): from ddtrace.internal.symbol_db.symbols import ScopeType from ddtrace.internal.symbol_db.symbols import SymbolDatabaseUploader - assert SymbolDatabaseUploader.is_installed() - contexts = [] def _upload_context(context): @@ -224,11 +215,18 @@ def _upload_context(context): SymbolDatabaseUploader._upload_context = staticmethod(_upload_context) + SymbolDatabaseUploader.install() + + def get_scope(contexts, name): + for context in (_.to_json() for _ in contexts): + for scope in context["scopes"]: + if scope["name"] == name: + return scope + raise ValueError(f"Scope {name} not found in {contexts}") + import tests.submod.stuff # noqa import tests.submod.traced_stuff # noqa - (context,) = contexts - - (scope,) = context.to_json()["scopes"] + scope = get_scope(contexts, "tests.submod.stuff") assert scope["scope_type"] == ScopeType.MODULE assert scope["name"] == "tests.submod.stuff" From c8b907b32408728e23dffed73945ed5dc6318c2e Mon Sep 17 00:00:00 2001 From: Emmett Butler <723615+emmettbutler@users.noreply.github.com> Date: Wed, 1 May 2024 07:14:27 -0700 Subject: [PATCH 2/9] chore(botocore): abstract away propagation header extraction code (#9087) This change adds a layer of abstraction between the botocore integration and the extraction of distributed tracing information from request data by using the Core API, increasing the separation of concerns between instrumentation and products. ## Checklist - [x] Change(s) are motivated and described in the PR description - [x] Testing strategy is described if automated tests are not included in the PR - [x] Risks are described (performance impact, potential for breakage, maintainability) - [x] Change is maintainable (easy to change, telemetry, documentation) - [x] [Library release note guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html) are followed or label `changelog/no-changelog` is set - [x] Documentation is included (in-code, generated user docs, [public corp docs](https://github.com/DataDog/documentation/)) - [x] Backport labels are set (if [applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)) - [x] If this PR changes the public interface, I've notified `@DataDog/apm-tees`. ## Reviewer Checklist - [ ] Title is accurate - [ ] All changes are related to the pull request's stated goal - [ ] Description motivates each change - [ ] Avoids breaking [API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces) changes - [ ] Testing strategy adequately addresses listed risks - [ ] Change is maintainable (easy to change, telemetry, documentation) - [ ] Release note makes sense to a user of the library - [ ] Author has acknowledged and discussed the performance implications of this PR as reported in the benchmarks PR comment - [ ] Backport labels are set in a manner that is consistent with the [release branch maintenance policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting) --- ddtrace/_trace/trace_handlers.py | 45 +++++++++++++-- ddtrace/_trace/utils.py | 41 ++++++++++++++ ddtrace/contrib/botocore/patch.py | 18 +++--- ddtrace/contrib/botocore/services/kinesis.py | 37 +++++++----- ddtrace/contrib/botocore/services/sqs.py | 34 +++++++---- .../botocore/services/stepfunctions.py | 8 ++- ddtrace/contrib/botocore/utils.py | 56 +++---------------- ddtrace/internal/datastreams/botocore.py | 4 +- tests/contrib/botocore/test.py | 2 +- 9 files changed, 154 insertions(+), 91 deletions(-) diff --git a/ddtrace/_trace/trace_handlers.py b/ddtrace/_trace/trace_handlers.py index f439f87784a..f67eca90453 100644 --- a/ddtrace/_trace/trace_handlers.py +++ b/ddtrace/_trace/trace_handlers.py @@ -6,9 +6,10 @@ from typing import List from typing import Optional -from ddtrace import config from ddtrace._trace.span import Span +from ddtrace._trace.utils import extract_DD_context_from_messages from ddtrace._trace.utils import set_botocore_patched_api_call_span_tags as set_patched_api_call_span_tags +from ddtrace._trace.utils import set_botocore_response_metadata_tags from ddtrace.constants import ANALYTICS_SAMPLE_RATE_KEY from ddtrace.constants import SPAN_KIND from ddtrace.constants import SPAN_MEASURED_KEY @@ -107,6 +108,9 @@ def _start_span(ctx: core.ExecutionContext, call_trace: bool = True, **kwargs) - trace_utils.activate_distributed_headers( tracer, int_config=distributed_headers_config, request_headers=ctx["distributed_headers"] ) + distributed_context = ctx.get_item("distributed_context", traverse=True) + if distributed_context and not call_trace: + span_kwargs["child_of"] = distributed_context span_kwargs.update(kwargs) span = (tracer.trace if call_trace else tracer.start_span)(ctx["span_name"], **span_kwargs) for tk, tv in ctx.get_item("tags", dict()).items(): @@ -569,20 +573,20 @@ def _on_botocore_patched_api_call_started(ctx): span.start_ns = start_ns -def _on_botocore_patched_api_call_exception(ctx, response, exception_type, set_response_metadata_tags): +def _on_botocore_patched_api_call_exception(ctx, response, exception_type, is_error_code_fn): span = ctx.get_item(ctx.get_item("call_key")) # `ClientError.response` contains the result, so we can still grab response metadata - set_response_metadata_tags(span, response) + set_botocore_response_metadata_tags(span, response, is_error_code_fn=is_error_code_fn) # If we have a status code, and the status code is not an error, # then ignore the exception being raised status_code = span.get_tag(http.STATUS_CODE) - if status_code and not config.botocore.operations[span.resource].is_error_code(int(status_code)): + if status_code and not is_error_code_fn(int(status_code)): span._ignore_exception(exception_type) -def _on_botocore_patched_api_call_success(ctx, response, set_response_metadata_tags): - set_response_metadata_tags(ctx.get_item(ctx.get_item("call_key")), response) +def _on_botocore_patched_api_call_success(ctx, response): + set_botocore_response_metadata_tags(ctx.get_item(ctx.get_item("call_key")), response) def _on_botocore_trace_context_injection_prepared( @@ -682,6 +686,31 @@ def _on_botocore_bedrock_process_response( span.finish() +def _on_botocore_sqs_recvmessage_post( + ctx: core.ExecutionContext, _, result: Dict, propagate: bool, message_parser: Callable +) -> None: + if result is not None and "Messages" in result and len(result["Messages"]) >= 1: + ctx.set_item("message_received", True) + if propagate: + ctx.set_safe("distributed_context", extract_DD_context_from_messages(result["Messages"], message_parser)) + + +def _on_botocore_kinesis_getrecords_post( + ctx: core.ExecutionContext, + _, + __, + ___, + ____, + result, + propagate: bool, + message_parser: Callable, +): + if result is not None and "Records" in result and len(result["Records"]) >= 1: + ctx.set_item("message_received", True) + if propagate: + ctx.set_item("distributed_context", extract_DD_context_from_messages(result["Records"], message_parser)) + + def _on_redis_async_command_post(span, rowcount): if rowcount is not None: span.set_metric(db.ROWCOUNT, rowcount) @@ -727,10 +756,14 @@ def listen(): core.on("botocore.patched_stepfunctions_api_call.started", _on_botocore_patched_api_call_started) core.on("botocore.patched_stepfunctions_api_call.exception", _on_botocore_patched_api_call_exception) core.on("botocore.stepfunctions.update_messages", _on_botocore_update_messages) + core.on("botocore.eventbridge.update_messages", _on_botocore_update_messages) + core.on("botocore.client_context.update_messages", _on_botocore_update_messages) core.on("botocore.patched_bedrock_api_call.started", _on_botocore_patched_bedrock_api_call_started) core.on("botocore.patched_bedrock_api_call.exception", _on_botocore_patched_bedrock_api_call_exception) core.on("botocore.patched_bedrock_api_call.success", _on_botocore_patched_bedrock_api_call_success) core.on("botocore.bedrock.process_response", _on_botocore_bedrock_process_response) + core.on("botocore.sqs.ReceiveMessage.post", _on_botocore_sqs_recvmessage_post) + core.on("botocore.kinesis.GetRecords.post", _on_botocore_kinesis_getrecords_post) core.on("redis.async_command.post", _on_redis_async_command_post) for context_name in ( diff --git a/ddtrace/_trace/utils.py b/ddtrace/_trace/utils.py index 0e1a9364582..44bef3bbf23 100644 --- a/ddtrace/_trace/utils.py +++ b/ddtrace/_trace/utils.py @@ -1,3 +1,8 @@ +from typing import Any +from typing import Callable +from typing import Dict +from typing import Optional + from ddtrace import Span from ddtrace import config from ddtrace.constants import ANALYTICS_SAMPLE_RATE_KEY @@ -5,8 +10,10 @@ from ddtrace.constants import SPAN_MEASURED_KEY from ddtrace.ext import SpanKind from ddtrace.ext import aws +from ddtrace.ext import http from ddtrace.internal.constants import COMPONENT from ddtrace.internal.utils.formats import deep_getattr +from ddtrace.propagation.http import HTTPPropagator def set_botocore_patched_api_call_span_tags(span: Span, instance, args, params, endpoint_name, operation): @@ -39,3 +46,37 @@ def set_botocore_patched_api_call_span_tags(span: Span, instance, args, params, # set analytics sample rate span.set_tag(ANALYTICS_SAMPLE_RATE_KEY, config.botocore.get_analytics_sample_rate()) + + +def set_botocore_response_metadata_tags( + span: Span, result: Dict[str, Any], is_error_code_fn: Optional[Callable] = None +) -> None: + if not result or not result.get("ResponseMetadata"): + return + response_meta = result["ResponseMetadata"] + + if "HTTPStatusCode" in response_meta: + status_code = response_meta["HTTPStatusCode"] + span.set_tag(http.STATUS_CODE, status_code) + + # Mark this span as an error if requested + if is_error_code_fn is not None and is_error_code_fn(int(status_code)): + span.error = 1 + + if "RetryAttempts" in response_meta: + span.set_tag("retry_attempts", response_meta["RetryAttempts"]) + + if "RequestId" in response_meta: + span.set_tag_str("aws.requestid", response_meta["RequestId"]) + + +def extract_DD_context_from_messages(messages, extract_from_message: Callable): + ctx = None + if len(messages) >= 1: + message = messages[0] + context_json = extract_from_message(message) + if context_json is not None: + child_of = HTTPPropagator.extract(context_json) + if child_of.trace_id is not None: + ctx = child_of + return ctx diff --git a/ddtrace/contrib/botocore/patch.py b/ddtrace/contrib/botocore/patch.py index e0bcc3f317f..b4f1a5265ea 100644 --- a/ddtrace/contrib/botocore/patch.py +++ b/ddtrace/contrib/botocore/patch.py @@ -39,9 +39,8 @@ from .services.sqs import update_messages as inject_trace_to_sqs_or_sns_message from .services.stepfunctions import patched_stepfunction_api_call from .services.stepfunctions import update_stepfunction_input -from .utils import inject_trace_to_client_context -from .utils import inject_trace_to_eventbridge_detail -from .utils import set_response_metadata_tags +from .utils import update_client_context +from .utils import update_eventbridge_detail _PATCHED_SUBMODULES = set() # type: Set[str] @@ -175,11 +174,11 @@ def prep_context_injection(ctx, endpoint_name, operation, trace_operation, param schematization_function = schematize_cloud_messaging_operation if endpoint_name == "lambda" and operation == "Invoke": - injection_function = inject_trace_to_client_context + injection_function = update_client_context schematization_function = schematize_cloud_faas_operation cloud_service = "lambda" if endpoint_name == "events" and operation == "PutEvents": - injection_function = inject_trace_to_eventbridge_detail + injection_function = update_eventbridge_detail cloud_service = "events" if endpoint_name == "sns" and "Publish" in operation: injection_function = inject_trace_to_sqs_or_sns_message @@ -224,9 +223,14 @@ def patched_api_call_fallback(original_func, instance, args, kwargs, function_va except botocore.exceptions.ClientError as e: core.dispatch( "botocore.patched_api_call.exception", - [ctx, e.response, botocore.exceptions.ClientError, set_response_metadata_tags], + [ + ctx, + e.response, + botocore.exceptions.ClientError, + config.botocore.operations[ctx["instrumented_api_call"].resource].is_error_code, + ], ) raise else: - core.dispatch("botocore.patched_api_call.success", [ctx, result, set_response_metadata_tags]) + core.dispatch("botocore.patched_api_call.success", [ctx, result]) return result diff --git a/ddtrace/contrib/botocore/services/kinesis.py b/ddtrace/contrib/botocore/services/kinesis.py index 412f0b0c27f..858f011410f 100644 --- a/ddtrace/contrib/botocore/services/kinesis.py +++ b/ddtrace/contrib/botocore/services/kinesis.py @@ -17,9 +17,8 @@ from ....internal.logger import get_logger from ....internal.schema import schematize_cloud_messaging_operation from ....internal.schema import schematize_service_name -from ..utils import extract_DD_context +from ..utils import extract_DD_json from ..utils import get_kinesis_data_object -from ..utils import set_response_metadata_tags log = get_logger(__name__) @@ -74,13 +73,14 @@ def patched_kinesis_api_call(original_func, instance, args, kwargs, function_var endpoint_name = function_vars.get("endpoint_name") operation = function_vars.get("operation") - message_received = False is_getrecords_call = False getrecords_error = None - child_of = None start_ns = None result = None + parent_ctx: core.ExecutionContext = core.ExecutionContext( + "botocore.patched_sqs_api_call.propagated", + ) if operation == "GetRecords": try: start_ns = time_ns() @@ -95,15 +95,20 @@ def patched_kinesis_api_call(original_func, instance, args, kwargs, function_var time_estimate = record.get("ApproximateArrivalTimestamp", datetime.now()).timestamp() core.dispatch( f"botocore.{endpoint_name}.{operation}.post", - [params, time_estimate, data_obj.get("_datadog"), record], + [ + parent_ctx, + params, + time_estimate, + data_obj.get("_datadog"), + record, + result, + config.botocore.propagation_enabled, + extract_DD_json, + ], ) except Exception as e: getrecords_error = e - if result is not None and "Records" in result and len(result["Records"]) >= 1: - message_received = True - if config.botocore.propagation_enabled: - child_of = extract_DD_context(result["Records"]) if endpoint_name == "kinesis" and operation in {"PutRecord", "PutRecords"}: span_name = schematize_cloud_messaging_operation( @@ -116,7 +121,7 @@ def patched_kinesis_api_call(original_func, instance, args, kwargs, function_var span_name = trace_operation stream_arn = params.get("StreamARN", params.get("StreamName", "")) function_is_not_getrecords = not is_getrecords_call - received_message_when_polling = is_getrecords_call and message_received + received_message_when_polling = is_getrecords_call and parent_ctx.get_item("message_received") instrument_empty_poll_calls = config.botocore.empty_poll_enabled should_instrument = ( received_message_when_polling or instrument_empty_poll_calls or function_is_not_getrecords or getrecords_error @@ -126,6 +131,7 @@ def patched_kinesis_api_call(original_func, instance, args, kwargs, function_var if should_instrument: with core.context_with_data( "botocore.patched_kinesis_api_call", + parent=parent_ctx, instance=instance, args=args, params=params, @@ -136,7 +142,6 @@ def patched_kinesis_api_call(original_func, instance, args, kwargs, function_var pin=pin, span_name=span_name, span_type=SpanTypes.HTTP, - child_of=child_of if child_of is not None else pin.tracer.context_provider.active(), activate=True, func_run=is_getrecords_call, start_ns=start_ns, @@ -158,15 +163,21 @@ def patched_kinesis_api_call(original_func, instance, args, kwargs, function_var if getrecords_error: raise getrecords_error - core.dispatch("botocore.patched_kinesis_api_call.success", [ctx, result, set_response_metadata_tags]) + core.dispatch("botocore.patched_kinesis_api_call.success", [ctx, result]) return result except botocore.exceptions.ClientError as e: core.dispatch( "botocore.patched_kinesis_api_call.exception", - [ctx, e.response, botocore.exceptions.ClientError, set_response_metadata_tags], + [ + ctx, + e.response, + botocore.exceptions.ClientError, + config.botocore.operations[ctx[ctx["call_key"]].resource].is_error_code, + ], ) raise + parent_ctx.end() elif is_getrecords_call: if getrecords_error: raise getrecords_error diff --git a/ddtrace/contrib/botocore/services/sqs.py b/ddtrace/contrib/botocore/services/sqs.py index 37080c85d70..25de175853a 100644 --- a/ddtrace/contrib/botocore/services/sqs.py +++ b/ddtrace/contrib/botocore/services/sqs.py @@ -7,8 +7,6 @@ import botocore.exceptions from ddtrace import config -from ddtrace.contrib.botocore.utils import extract_DD_context -from ddtrace.contrib.botocore.utils import set_response_metadata_tags from ddtrace.ext import SpanTypes from ddtrace.internal import core from ddtrace.internal.logger import get_logger @@ -16,6 +14,8 @@ from ddtrace.internal.schema import schematize_service_name from ddtrace.internal.schema.span_attribute_schema import SpanDirection +from ..utils import extract_DD_json + log = get_logger(__name__) MAX_INJECTION_DATA_ATTRIBUTES = 10 @@ -83,16 +83,19 @@ def _ensure_datadog_messageattribute_enabled(params): def patched_sqs_api_call(original_func, instance, args, kwargs, function_vars): + with core.context_with_data("botocore.patched_sqs_api_call.propagated") as parent_ctx: + return _patched_sqs_api_call(parent_ctx, original_func, instance, args, kwargs, function_vars) + + +def _patched_sqs_api_call(parent_ctx, original_func, instance, args, kwargs, function_vars): params = function_vars.get("params") trace_operation = function_vars.get("trace_operation") pin = function_vars.get("pin") endpoint_name = function_vars.get("endpoint_name") operation = function_vars.get("operation") - message_received = False func_has_run = False func_run_err = None - child_of = None result = None if operation == "ReceiveMessage": @@ -103,16 +106,15 @@ def patched_sqs_api_call(original_func, instance, args, kwargs, function_vars): core.dispatch(f"botocore.{endpoint_name}.{operation}.pre", [params]) # run the function to extract possible parent context before creating ExecutionContext result = original_func(*args, **kwargs) - core.dispatch(f"botocore.{endpoint_name}.{operation}.post", [params, result]) + core.dispatch( + f"botocore.{endpoint_name}.{operation}.post", + [parent_ctx, params, result, config.botocore.propagation_enabled, extract_DD_json], + ) except Exception as e: func_run_err = e - if result is not None and "Messages" in result and len(result["Messages"]) >= 1: - message_received = True - if config.botocore.propagation_enabled: - child_of = extract_DD_context(result["Messages"]) function_is_not_recvmessage = not func_has_run - received_message_when_polling = func_has_run and message_received + received_message_when_polling = func_has_run and parent_ctx.get_item("message_received") instrument_empty_poll_calls = config.botocore.empty_poll_enabled should_instrument = ( received_message_when_polling or instrument_empty_poll_calls or function_is_not_recvmessage or func_run_err @@ -133,9 +135,12 @@ def patched_sqs_api_call(original_func, instance, args, kwargs, function_vars): else: call_name = trace_operation + child_of = parent_ctx.get_item("distributed_context") + if should_instrument: with core.context_with_data( "botocore.patched_sqs_api_call", + parent=parent_ctx, span_name=call_name, service=schematize_service_name("{}.{}".format(pin.service, endpoint_name)), span_type=SpanTypes.HTTP, @@ -161,7 +166,7 @@ def patched_sqs_api_call(original_func, instance, args, kwargs, function_vars): result = original_func(*args, **kwargs) core.dispatch(f"botocore.{endpoint_name}.{operation}.post", [params, result]) - core.dispatch("botocore.patched_sqs_api_call.success", [ctx, result, set_response_metadata_tags]) + core.dispatch("botocore.patched_sqs_api_call.success", [ctx, result]) if func_run_err: raise func_run_err @@ -169,7 +174,12 @@ def patched_sqs_api_call(original_func, instance, args, kwargs, function_vars): except botocore.exceptions.ClientError as e: core.dispatch( "botocore.patched_sqs_api_call.exception", - [ctx, e.response, botocore.exceptions.ClientError, set_response_metadata_tags], + [ + ctx, + e.response, + botocore.exceptions.ClientError, + config.botocore.operations[ctx[ctx["call_key"]].resource].is_error_code, + ], ) raise elif func_has_run: diff --git a/ddtrace/contrib/botocore/services/stepfunctions.py b/ddtrace/contrib/botocore/services/stepfunctions.py index d611f664a48..16213f2e3ed 100644 --- a/ddtrace/contrib/botocore/services/stepfunctions.py +++ b/ddtrace/contrib/botocore/services/stepfunctions.py @@ -12,7 +12,6 @@ from ....internal.schema import SpanDirection from ....internal.schema import schematize_cloud_messaging_operation from ....internal.schema import schematize_service_name -from ..utils import set_response_metadata_tags log = get_logger(__name__) @@ -81,6 +80,11 @@ def patched_stepfunction_api_call(original_func, instance, args, kwargs: Dict, f except botocore.exceptions.ClientError as e: core.dispatch( "botocore.patched_stepfunctions_api_call.exception", - [ctx, e.response, botocore.exceptions.ClientError, set_response_metadata_tags], + [ + ctx, + e.response, + botocore.exceptions.ClientError, + config.botocore.operations[ctx[ctx["call_key"]].resource].is_error_code, + ], ) raise diff --git a/ddtrace/contrib/botocore/utils.py b/ddtrace/contrib/botocore/utils.py index ead47ace10c..5804a4e1a36 100644 --- a/ddtrace/contrib/botocore/utils.py +++ b/ddtrace/contrib/botocore/utils.py @@ -8,13 +8,11 @@ from typing import Optional from typing import Tuple -from ddtrace import Span from ddtrace import config +from ddtrace.internal import core from ddtrace.internal.core import ExecutionContext -from ...ext import http from ...internal.logger import get_logger -from ...propagation.http import HTTPPropagator log = get_logger(__name__) @@ -66,11 +64,7 @@ def get_kinesis_data_object(data: str) -> Tuple[str, Optional[Dict[str, Any]]]: return None, None -def inject_trace_to_eventbridge_detail(ctx: ExecutionContext) -> None: - """ - Inject trace headers into the EventBridge record if the record's Detail object contains a JSON string - Max size per event is 256KB (https://docs.aws.amazon.com/eventbridge/latest/userguide/eb-putevent-size.html) - """ +def update_eventbridge_detail(ctx: ExecutionContext) -> None: params = ctx["params"] if "Entries" not in params: log.warning("Unable to inject context. The Event Bridge event had no Entries.") @@ -86,8 +80,7 @@ def inject_trace_to_eventbridge_detail(ctx: ExecutionContext) -> None: continue detail["_datadog"] = {} - span = ctx[ctx["call_key"]] - HTTPPropagator.inject(span.context, detail["_datadog"]) + core.dispatch("botocore.eventbridge.update_messages", [ctx, None, None, detail["_datadog"], None]) detail_json = json.dumps(detail) # check if detail size will exceed max size with headers @@ -99,12 +92,11 @@ def inject_trace_to_eventbridge_detail(ctx: ExecutionContext) -> None: entry["Detail"] = detail_json -def inject_trace_to_client_context(ctx): +def update_client_context(ctx: ExecutionContext) -> None: trace_headers = {} - span = ctx[ctx["call_key"]] - params = ctx["params"] - HTTPPropagator.inject(span.context, trace_headers) + core.dispatch("botocore.client_context.update_messages", [ctx, None, None, trace_headers, None]) client_context_object = {} + params = ctx["params"] if "ClientContext" in params: try: client_context_json = base64.b64decode(params["ClientContext"]).decode("utf-8") @@ -131,39 +123,7 @@ def modify_client_context(client_context_object, trace_headers): client_context_object["custom"] = trace_headers -def set_response_metadata_tags(span: Span, result: Dict[str, Any]) -> None: - if not result or not result.get("ResponseMetadata"): - return - response_meta = result["ResponseMetadata"] - - if "HTTPStatusCode" in response_meta: - status_code = response_meta["HTTPStatusCode"] - span.set_tag(http.STATUS_CODE, status_code) - - # Mark this span as an error if requested - if config.botocore.operations[span.resource].is_error_code(int(status_code)): - span.error = 1 - - if "RetryAttempts" in response_meta: - span.set_tag("retry_attempts", response_meta["RetryAttempts"]) - - if "RequestId" in response_meta: - span.set_tag_str("aws.requestid", response_meta["RequestId"]) - - -def extract_DD_context(messages): - ctx = None - if len(messages) >= 1: - message = messages[0] - context_json = extract_trace_context_json(message) - if context_json is not None: - child_of = HTTPPropagator.extract(context_json) - if child_of.trace_id is not None: - ctx = child_of - return ctx - - -def extract_trace_context_json(message): +def extract_DD_json(message): context_json = None try: if message and message.get("Type") == "Notification": @@ -200,7 +160,7 @@ def extract_trace_context_json(message): if "Body" in message: try: body = json.loads(message["Body"]) - return extract_trace_context_json(body) + return extract_DD_json(body) except ValueError: log.debug("Unable to parse AWS message body.") except Exception: diff --git a/ddtrace/internal/datastreams/botocore.py b/ddtrace/internal/datastreams/botocore.py index 1f1b79aee80..ec004f1ff9a 100644 --- a/ddtrace/internal/datastreams/botocore.py +++ b/ddtrace/internal/datastreams/botocore.py @@ -172,7 +172,7 @@ def get_datastreams_context(message): return context_json -def handle_sqs_receive(params, result): +def handle_sqs_receive(_, params, result, *args): from . import data_streams_processor as processor queue_name = get_queue_name(params) @@ -206,7 +206,7 @@ def record_data_streams_path_for_kinesis_stream(params, time_estimate, context_j ) -def handle_kinesis_receive(params, time_estimate, context_json, record): +def handle_kinesis_receive(_, params, time_estimate, context_json, record, *args): try: record_data_streams_path_for_kinesis_stream(params, time_estimate, context_json, record) except Exception: diff --git a/tests/contrib/botocore/test.py b/tests/contrib/botocore/test.py index 8709964db6b..aa9627169a6 100644 --- a/tests/contrib/botocore/test.py +++ b/tests/contrib/botocore/test.py @@ -312,7 +312,7 @@ def test_s3_client(self): @mock_s3 def test_s3_head_404_default(self): """ - By default we attach exception information to s3 HeadObject + By default we do not attach exception information to s3 HeadObject API calls with a 404 response """ s3 = self.session.create_client("s3", region_name="us-west-2") From 357cb3b858d46234ab4e2eb1d422a156964b7e60 Mon Sep 17 00:00:00 2001 From: erikayasuda <153395705+erikayasuda@users.noreply.github.com> Date: Wed, 1 May 2024 16:33:28 -0400 Subject: [PATCH 3/9] fix(redis): added back tracer_utils_redis with deprecation warn (#9145) ## Checklist Adds back and deprecates old `tracer_utils_redis` module with public method and variables. - [x] Change(s) are motivated and described in the PR description - [x] Testing strategy is described if automated tests are not included in the PR - [x] Risks are described (performance impact, potential for breakage, maintainability) - [x] Change is maintainable (easy to change, telemetry, documentation) - [x] [Library release note guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html) are followed or label `changelog/no-changelog` is set - [x] Documentation is included (in-code, generated user docs, [public corp docs](https://github.com/DataDog/documentation/)) - [x] Backport labels are set (if [applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)) - [x] If this PR changes the public interface, I've notified `@DataDog/apm-tees`. ## Reviewer Checklist - [x] Title is accurate - [x] All changes are related to the pull request's stated goal - [x] Description motivates each change - [x] Avoids breaking [API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces) changes - [x] Testing strategy adequately addresses listed risks - [x] Change is maintainable (easy to change, telemetry, documentation) - [x] Release note makes sense to a user of the library - [x] Author has acknowledged and discussed the performance implications of this PR as reported in the benchmarks PR comment - [x] Backport labels are set in a manner that is consistent with the [release branch maintenance policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting) --------- Co-authored-by: Emmett Butler <723615+emmettbutler@users.noreply.github.com> --- ddtrace/contrib/trace_utils_redis.py | 18 ++++++++++++++++++ tests/.suitespec.json | 1 + 2 files changed, 19 insertions(+) create mode 100644 ddtrace/contrib/trace_utils_redis.py diff --git a/ddtrace/contrib/trace_utils_redis.py b/ddtrace/contrib/trace_utils_redis.py new file mode 100644 index 00000000000..8df16c3ce4d --- /dev/null +++ b/ddtrace/contrib/trace_utils_redis.py @@ -0,0 +1,18 @@ +from ddtrace.contrib.redis_utils import determine_row_count +from ddtrace.contrib.redis_utils import stringify_cache_args +from ddtrace.internal.utils.deprecations import DDTraceDeprecationWarning +from ddtrace.vendor.debtcollector import deprecate + + +deprecate( + "The ddtrace.contrib.trace_utils_redis module is deprecated and will be removed.", + message="A new interface will be provided by the ddtrace.contrib.redis_utils module", + category=DDTraceDeprecationWarning, +) + + +format_command_args = stringify_cache_args + + +def determine_row_count(redis_command, span, result): # noqa: F811 + determine_row_count(redis_command=redis_command, result=result) diff --git a/tests/.suitespec.json b/tests/.suitespec.json index 7e6f1512ec4..143ef63a62e 100644 --- a/tests/.suitespec.json +++ b/tests/.suitespec.json @@ -141,6 +141,7 @@ "ddtrace/contrib/yaaredis/*", "ddtrace/_trace/utils_redis.py", "ddtrace/contrib/redis_utils.py", + "ddtrace/contrib/trace_utils_redis.py", "ddtrace/ext/redis.py" ], "mongo": [ From 7a55b3ef290e824a2584fcbb88e4d205b50d488e Mon Sep 17 00:00:00 2001 From: Juanjo Alvarez Martinez Date: Thu, 2 May 2024 14:30:26 +0200 Subject: [PATCH 4/9] chore: add memcheck tests for the new splitter aspects (#9146) ## Description - Add memcheck fixtures for the new splitter aspects that were recently merged. - Add comments to make them easier to follow. ## Checklist - [X] Change(s) are motivated and described in the PR description - [X] Testing strategy is described if automated tests are not included in the PR - [X] Risks are described (performance impact, potential for breakage, maintainability) - [X] Change is maintainable (easy to change, telemetry, documentation) - [X] [Library release note guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html) are followed or label `changelog/no-changelog` is set - [X] Documentation is included (in-code, generated user docs, [public corp docs](https://github.com/DataDog/documentation/)) - [X] Backport labels are set (if [applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)) - [X] If this PR changes the public interface, I've notified `@DataDog/apm-tees`. ## Reviewer Checklist - [x] Title is accurate - [x] All changes are related to the pull request's stated goal - [x] Description motivates each change - [x] Avoids breaking [API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces) changes - [x] Testing strategy adequately addresses listed risks - [x] Change is maintainable (easy to change, telemetry, documentation) - [x] Release note makes sense to a user of the library - [x] Author has acknowledged and discussed the performance implications of this PR as reported in the benchmarks PR comment - [x] Backport labels are set in a manner that is consistent with the [release branch maintenance policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting) --------- Signed-off-by: Juanjo Alvarez --- .../appsec/iast/fixtures/propagation_path.py | 50 +++++++++++++++++-- 1 file changed, 47 insertions(+), 3 deletions(-) diff --git a/tests/appsec/iast/fixtures/propagation_path.py b/tests/appsec/iast/fixtures/propagation_path.py index b4f6616bc27..44b4d2aafee 100644 --- a/tests/appsec/iast/fixtures/propagation_path.py +++ b/tests/appsec/iast/fixtures/propagation_path.py @@ -3,6 +3,7 @@ make some changes """ import os +import sys ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -114,25 +115,68 @@ def propagation_path_5_prop(origin_string1, tainted_string_2): def propagation_memory_check(origin_string1, tainted_string_2): + import os.path + if type(origin_string1) is str: string1 = str(origin_string1) # 1 Range else: string1 = str(origin_string1, encoding="utf-8") # 1 Range + # string1 = taintsource if type(tainted_string_2) is str: string2 = str(tainted_string_2) # 1 Range else: string2 = str(tainted_string_2, encoding="utf-8") # 1 Range + # string2 = taintsource2 string3 = string1 + string2 # 2 Ranges + # taintsource1taintsource2 string4 = "-".join([string3, string3, string3]) # 6 Ranges + # taintsource1taintsource2-taintsource1taintsource2-taintsource1taintsource2 string5 = string4[0 : (len(string4) - 1)] + # taintsource1taintsource2-taintsource1taintsource2-taintsource1taintsource string6 = string5.title() + # Taintsource1Taintsource2-Taintsource1Taintsource2-Taintsource1Taintsource string7 = string6.upper() + # TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE string8 = "%s_notainted" % string7 - string9 = "notainted_{}".format(string8) + # TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE_notainted + string9 = "notainted#{}".format(string8) + # notainted#TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE_notainted + string10 = string9.split("#")[1] + # TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE_notainted + string11 = "notainted#{}".format(string10) + # TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE_notainted + string12 = string11.rsplit("#")[1] + string13 = string12 + "\n" + "notainted" + # TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE_notainted\nnotainted + string14 = string13.splitlines()[0] # string14 = string12 + # TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE_notainted + string15 = os.path.join("foo", "bar", string14) + # /foo/bar/TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE_notainted + string16 = os.path.split(string15)[1] + # TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE_notainted + string17 = string16 + ".jpg" + # TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE_notainted.jpg + string18 = os.path.splitext(string17)[0] + # TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE_notainted + string19 = os.path.join(os.sep + string18, "nottainted_notdir") + # /TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE_notainted/nottainted_notdir + string20 = os.path.dirname(string19) + # /TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE_notainted + string21 = os.path.basename(string20) + # TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE_notainted + + if sys.version_info >= (3, 12): + string22 = os.sep + string21 + # /TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE_notainted + string23 = os.path.splitroot(string22)[2] + # TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE_notainted + else: + string23 = string21 + try: # label propagation_memory_check - m = open(ROOT_DIR + "/" + string9 + ".txt") + m = open(ROOT_DIR + "/" + string23 + ".txt") _ = m.read() except Exception: pass - return string9 + return string23 From 5f9e15db0ca51385299e997c58163031b553227e Mon Sep 17 00:00:00 2001 From: Emmett Butler <723615+emmettbutler@users.noreply.github.com> Date: Thu, 2 May 2024 06:44:59 -0700 Subject: [PATCH 5/9] ci: mark some flaky tests (#9144) This change marks these two ephemeral CI failures as flaky: https://app.circleci.com/pipelines/github/DataDog/dd-trace-py/60664/workflows/e6f953d9-721e-4a27-8064-5209f7ac3a15/jobs/3805647 last touched by @zarirhamza https://app.circleci.com/pipelines/github/DataDog/dd-trace-py/60685/workflows/c25bae04-8827-4893-ac67-f239b2226774/jobs/3806940 last touched by @erikayasuda It also adds verbose output to the `appsec_iast` test suite to aid in figuring out which test or tests generates this ephemeral segfault: https://app.circleci.com/pipelines/github/DataDog/dd-trace-py/60686/workflows/21d8cecd-9345-4e16-9157-71425bd65ccb/jobs/3807035 ## Checklist - [x] Change(s) are motivated and described in the PR description - [x] Testing strategy is described if automated tests are not included in the PR - [x] Risks are described (performance impact, potential for breakage, maintainability) - [x] Change is maintainable (easy to change, telemetry, documentation) - [x] [Library release note guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html) are followed or label `changelog/no-changelog` is set - [x] Documentation is included (in-code, generated user docs, [public corp docs](https://github.com/DataDog/documentation/)) - [x] Backport labels are set (if [applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)) - [x] If this PR changes the public interface, I've notified `@DataDog/apm-tees`. ## Reviewer Checklist - [x] Title is accurate - [x] All changes are related to the pull request's stated goal - [x] Description motivates each change - [x] Avoids breaking [API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces) changes - [x] Testing strategy adequately addresses listed risks - [x] Change is maintainable (easy to change, telemetry, documentation) - [x] Release note makes sense to a user of the library - [x] Author has acknowledged and discussed the performance implications of this PR as reported in the benchmarks PR comment - [x] Backport labels are set in a manner that is consistent with the [release branch maintenance policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting) --- riotfile.py | 2 +- tests/contrib/celery/test_integration.py | 2 ++ tests/internal/test_tracer_flare.py | 2 ++ 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/riotfile.py b/riotfile.py index d1ee65dfa21..e52d53f37b1 100644 --- a/riotfile.py +++ b/riotfile.py @@ -133,7 +133,7 @@ def select_pys(min_version=MIN_PYTHON_VERSION, max_version=MAX_PYTHON_VERSION): Venv( name="appsec_iast", pys=select_pys(), - command="pytest {cmdargs} tests/appsec/iast/", + command="pytest -v {cmdargs} tests/appsec/iast/", pkgs={ "requests": latest, "pycryptodome": latest, diff --git a/tests/contrib/celery/test_integration.py b/tests/contrib/celery/test_integration.py index 21e716b8193..09c30f2c2ef 100644 --- a/tests/contrib/celery/test_integration.py +++ b/tests/contrib/celery/test_integration.py @@ -17,6 +17,7 @@ import ddtrace.internal.forksafe as forksafe from ddtrace.propagation.http import HTTPPropagator from tests.opentracer.utils import init_tracer +from tests.utils import flaky from ...utils import override_global_config from .base import CeleryBaseTestCase @@ -209,6 +210,7 @@ def fn_task_parameters(user, force_logout=False): assert run_span.get_tag("component") == "celery" assert run_span.get_tag("span.kind") == "consumer" + @flaky(1722529274) def test_fn_task_delay(self): # using delay shorthand must preserve arguments @self.app.task diff --git a/tests/internal/test_tracer_flare.py b/tests/internal/test_tracer_flare.py index 7051190e17d..560dcdc1ddd 100644 --- a/tests/internal/test_tracer_flare.py +++ b/tests/internal/test_tracer_flare.py @@ -13,6 +13,7 @@ from ddtrace.internal.flare import Flare from ddtrace.internal.flare import FlareSendRequest from ddtrace.internal.logger import get_logger +from tests.utils import flaky DEBUG_LEVEL_INT = logging.DEBUG @@ -118,6 +119,7 @@ def handle_agent_task(): for p in processes: p.join() + @flaky(1722529274) def test_multiple_process_partial_failure(self): """ Validte that even if the tracer flare fails for one process, we should From a2b1dbb90f883db3af07915f666dd7bab6a47605 Mon Sep 17 00:00:00 2001 From: William Conti <58711692+wconti27@users.noreply.github.com> Date: Thu, 2 May 2024 11:53:34 -0400 Subject: [PATCH 6/9] chore(dbm): add peer service precursor tag to sql injection (#9052) # Description We are adding the following tags to the DBM SQL injection comment in all tracers: - ddh: 'peer.hostname': hostname (or IP) of the db server the client is connecting to (ALREADY EXISTS IN PYTHON) - dddps: 'peer.db.name': database namespace (ALREADY EXISTS IN PYTHON) - ddprs: 'peer.service': only set if user explicitly tags the span with `peer.service` ## Checklist - [x] Change(s) are motivated and described in the PR description - [x] Testing strategy is described if automated tests are not included in the PR - [x] Risks are described (performance impact, potential for breakage, maintainability) - [x] Change is maintainable (easy to change, telemetry, documentation) - [x] [Library release note guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html) are followed or label `changelog/no-changelog` is set - [x] Documentation is included (in-code, generated user docs, [public corp docs](https://github.com/DataDog/documentation/)) - [x] Backport labels are set (if [applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)) - [x] If this PR changes the public interface, I've notified `@DataDog/apm-tees`. ## Reviewer Checklist - [x] Title is accurate - [x] All changes are related to the pull request's stated goal - [x] Description motivates each change - [x] Avoids breaking [API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces) changes - [x] Testing strategy adequately addresses listed risks - [x] Change is maintainable (easy to change, telemetry, documentation) - [x] Release note makes sense to a user of the library - [x] Author has acknowledged and discussed the performance implications of this PR as reported in the benchmarks PR comment - [x] Backport labels are set in a manner that is consistent with the [release branch maintenance policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting) --- ddtrace/propagation/_database_monitoring.py | 7 +++++ tests/contrib/aiomysql/test_aiomysql.py | 29 +++++++++++++++++++-- tests/contrib/shared_tests.py | 15 +++++++++++ 3 files changed, 49 insertions(+), 2 deletions(-) diff --git a/ddtrace/propagation/_database_monitoring.py b/ddtrace/propagation/_database_monitoring.py index 4210e4cbec6..4002f864dd8 100644 --- a/ddtrace/propagation/_database_monitoring.py +++ b/ddtrace/propagation/_database_monitoring.py @@ -23,6 +23,7 @@ DBM_DATABASE_SERVICE_NAME_KEY = "dddbs" DBM_PEER_HOSTNAME_KEY = "ddh" DBM_PEER_DB_NAME_KEY = "dddb" +DBM_PEER_SERVICE_KEY = "ddprs" DBM_ENVIRONMENT_KEY = "dde" DBM_VERSION_KEY = "ddpv" DBM_TRACE_PARENT_KEY = "traceparent" @@ -56,12 +57,14 @@ def __init__( sql_injector=default_sql_injector, peer_hostname_tag="out.host", peer_db_name_tag="db.name", + peer_service_tag="peer.service", ): self.sql_pos = sql_pos self.sql_kw = sql_kw self.sql_injector = sql_injector self.peer_hostname_tag = peer_hostname_tag self.peer_db_name_tag = peer_db_name_tag + self.peer_service_tag = peer_service_tag def inject(self, dbspan, args, kwargs): # run sampling before injection to propagate correct sampling priority @@ -114,6 +117,10 @@ def _get_dbm_comment(self, db_span): if peer_hostname: dbm_tags[DBM_PEER_HOSTNAME_KEY] = peer_hostname + peer_service = db_span.get_tag(self.peer_service_tag) + if peer_service: + dbm_tags[DBM_PEER_SERVICE_KEY] = peer_service + if dbm_config.propagation_mode == "full": db_span.set_tag_str(DBM_TRACE_INJECTED_TAG, "true") dbm_tags[DBM_TRACE_PARENT_KEY] = db_span.context._traceparent diff --git a/tests/contrib/aiomysql/test_aiomysql.py b/tests/contrib/aiomysql/test_aiomysql.py index 35e0a7e09c6..2247b2dba6f 100644 --- a/tests/contrib/aiomysql/test_aiomysql.py +++ b/tests/contrib/aiomysql/test_aiomysql.py @@ -230,7 +230,9 @@ class AioMySQLTestCase(AsyncioTestCase): TEST_SERVICE = "mysql" conn = None - async def _get_conn_tracer(self): + async def _get_conn_tracer(self, tags=None): + tags = tags if tags is not None else {} + if not self.conn: self.conn = await aiomysql.connect(**AIOMYSQL_CONFIG) assert not self.conn.closed @@ -239,7 +241,7 @@ async def _get_conn_tracer(self): assert pin # Customize the service # we have to apply it on the existing one since new one won't inherit `app` - pin.clone(tracer=self.tracer).onto(self.conn) + pin.clone(tracer=self.tracer, tags={**tags, **pin.tags}).onto(self.conn) return self.conn, self.tracer @@ -429,3 +431,26 @@ async def test_aiomysql_dbm_propagation_comment_peer_service_enabled(self): await shared_tests._test_dbm_propagation_comment_peer_service_enabled( config=AIOMYSQL_CONFIG, cursor=cursor, wrapped_instance=cursor.__wrapped__ ) + + @mark_asyncio + @AsyncioTestCase.run_in_subprocess( + env_overrides=dict( + DD_DBM_PROPAGATION_MODE="service", + DD_SERVICE="orders-app", + DD_ENV="staging", + DD_VERSION="v7343437-d7ac743", + DD_TRACE_SPAN_ATTRIBUTE_SCHEMA="v1", + ) + ) + async def test_aiomysql_dbm_propagation_comment_with_peer_service_tag(self): + """tests if dbm comment is set in mysql""" + conn, tracer = await self._get_conn_tracer({"peer.service": "peer_service_name"}) + cursor = await conn.cursor() + cursor.__wrapped__ = mock.AsyncMock() + + await shared_tests._test_dbm_propagation_comment_with_peer_service_tag( + config=AIOMYSQL_CONFIG, + cursor=cursor, + wrapped_instance=cursor.__wrapped__, + peer_service_name="peer_service_name", + ) diff --git a/tests/contrib/shared_tests.py b/tests/contrib/shared_tests.py index 2ccb319551f..97d1df32cfa 100644 --- a/tests/contrib/shared_tests.py +++ b/tests/contrib/shared_tests.py @@ -94,3 +94,18 @@ async def _test_dbm_propagation_comment_peer_service_enabled(config, cursor, wra await _test_execute(dbm_comment, cursor, wrapped_instance) if execute_many: await _test_execute_many(dbm_comment, cursor, wrapped_instance) + + +async def _test_dbm_propagation_comment_with_peer_service_tag( + config, cursor, wrapped_instance, peer_service_name, execute_many=True +): + """tests if dbm comment is set in mysql""" + db_name = config["db"] + + dbm_comment = ( + f"/*dddb='{db_name}',dddbs='test',dde='staging',ddh='127.0.0.1',ddprs='{peer_service_name}',ddps='orders-app'," + "ddpv='v7343437-d7ac743'*/ " + ) + await _test_execute(dbm_comment, cursor, wrapped_instance) + if execute_many: + await _test_execute_many(dbm_comment, cursor, wrapped_instance) From 01fbf9127180794392cc7dbe16acd610087dacf6 Mon Sep 17 00:00:00 2001 From: Christophe Papazian <114495376+christophe-papazian@users.noreply.github.com> Date: Thu, 2 May 2024 18:19:01 +0200 Subject: [PATCH 7/9] chore(asm): add support for blocking request in rasp flask (#9147) Add support for blocking web requests from anywhere using exploit prevention in all ASM supported frameworks. # Motivation Exploit Prevention, a new ASM feature, must be able to block a request from anywhere in the customer code, bypassing all remaining customer code to the end of the request # Content - Add a BlockingException in the tracer internal, deriving from BaseException to avoid any "catch all exception" mechanism in the code - Add support in Django, FastAPI and Flask to properly catch and manage BlockingException - Add a failsafe mechanism in asm_request_context to ensure that no BlockingException will be propagated outside of the outermost context. Also ensure that the exception is only thrown inside an asm context. - Add unit tests for all the frameworks to test blocking requests and for asm_request_context A specific release note will be added once exploit prevention will be enabled by default. APPSEC-52972 ## Checklist - [x] Change(s) are motivated and described in the PR description - [x] Testing strategy is described if automated tests are not included in the PR - [x] Risks are described (performance impact, potential for breakage, maintainability) - [x] Change is maintainable (easy to change, telemetry, documentation) - [x] [Library release note guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html) are followed or label `changelog/no-changelog` is set - [x] Documentation is included (in-code, generated user docs, [public corp docs](https://github.com/DataDog/documentation/)) - [x] Backport labels are set (if [applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)) - [x] If this PR changes the public interface, I've notified `@DataDog/apm-tees`. ## Reviewer Checklist - [x] Title is accurate - [x] All changes are related to the pull request's stated goal - [x] Description motivates each change - [x] Avoids breaking [API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces) changes - [x] Testing strategy adequately addresses listed risks - [x] Change is maintainable (easy to change, telemetry, documentation) - [x] Release note makes sense to a user of the library - [x] Author has acknowledged and discussed the performance implications of this PR as reported in the benchmarks PR comment - [x] Backport labels are set in a manner that is consistent with the [release branch maintenance policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting) --- .github/CODEOWNERS | 1 + ddtrace/appsec/_asm_request_context.py | 8 ++ ddtrace/appsec/_common_module_patches.py | 16 ++- ddtrace/contrib/asgi/middleware.py | 4 + ddtrace/contrib/django/patch.py | 10 +- ddtrace/contrib/wsgi/wsgi.py | 22 ++-- ddtrace/internal/_exceptions.py | 5 + tests/.suitespec.json | 1 + tests/appsec/appsec/rules-rasp-blocking.json | 106 ++++++++++++++++++ .../appsec/appsec/test_asm_request_context.py | 17 +++ .../appsec/contrib_appsec/django_app/urls.py | 3 + .../appsec/contrib_appsec/fastapi_app/app.py | 3 + tests/appsec/contrib_appsec/flask_app/app.py | 4 + tests/appsec/contrib_appsec/utils.py | 37 +++++- tests/appsec/rules.py | 1 + 15 files changed, 218 insertions(+), 20 deletions(-) create mode 100644 ddtrace/internal/_exceptions.py create mode 100644 tests/appsec/appsec/rules-rasp-blocking.json diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 23022df324d..01e3effa28e 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -63,6 +63,7 @@ ddtrace/appsec/ @DataDog/asm-python ddtrace/settings/asm.py @DataDog/asm-python ddtrace/contrib/subprocess/ @DataDog/asm-python ddtrace/contrib/flask_login/ @DataDog/asm-python +ddtrace/internal/_exceptions.py @DataDog/asm-python tests/appsec/ @DataDog/asm-python tests/contrib/dbapi/test_dbapi_appsec.py @DataDog/asm-python tests/contrib/subprocess @DataDog/asm-python diff --git a/ddtrace/appsec/_asm_request_context.py b/ddtrace/appsec/_asm_request_context.py index 654e06a29e5..ec88464cabe 100644 --- a/ddtrace/appsec/_asm_request_context.py +++ b/ddtrace/appsec/_asm_request_context.py @@ -20,6 +20,7 @@ from ddtrace.appsec._iast._utils import _is_iast_enabled from ddtrace.appsec._utils import get_triggers from ddtrace.internal import core +from ddtrace.internal._exceptions import BlockingException from ddtrace.internal.constants import REQUEST_PATH_PARAMS from ddtrace.internal.logger import get_logger from ddtrace.settings.asm import config as asm_config @@ -140,6 +141,7 @@ def __init__(self): env = ASM_Environment(True) self._id = _DataHandler.main_id + self._root = not in_context() self.active = True self.execution_context = core.ExecutionContext(__name__, **{"asm_env": env}) @@ -393,6 +395,12 @@ def asm_request_context_manager( if resources is not None: try: yield resources + except BlockingException as e: + # ensure that the BlockingRequest that is never raised outside a context + # is also never propagated outside the context + core.set_item(WAF_CONTEXT_NAMES.BLOCKED, e.args[0]) + if not resources._root: + raise finally: _end_context(resources) else: diff --git a/ddtrace/appsec/_common_module_patches.py b/ddtrace/appsec/_common_module_patches.py index 69c2610cab5..71d2fa59b5b 100644 --- a/ddtrace/appsec/_common_module_patches.py +++ b/ddtrace/appsec/_common_module_patches.py @@ -8,7 +8,9 @@ from typing import Callable from typing import Dict +from ddtrace.appsec._constants import WAF_CONTEXT_NAMES from ddtrace.internal import core +from ddtrace.internal._exceptions import BlockingException from ddtrace.internal.logger import get_logger from ddtrace.settings.asm import config as asm_config from ddtrace.vendor.wrapt import FunctionWrapper @@ -49,6 +51,7 @@ def wrapped_open_CFDDB7ABBA9081B6(original_open_callable, instance, args, kwargs try: from ddtrace.appsec._asm_request_context import call_waf_callback from ddtrace.appsec._asm_request_context import in_context + from ddtrace.appsec._asm_request_context import is_blocked from ddtrace.appsec._constants import EXPLOIT_PREVENTION except ImportError: # open is used during module initialization @@ -66,7 +69,9 @@ def wrapped_open_CFDDB7ABBA9081B6(original_open_callable, instance, args, kwargs crop_trace="wrapped_open_CFDDB7ABBA9081B6", rule_type=EXPLOIT_PREVENTION.TYPE.LFI, ) - # DEV: Next part of the exploit prevention feature: add block here + if is_blocked(): + raise BlockingException(core.get_item(WAF_CONTEXT_NAMES.BLOCKED), "exploit_prevention", "lfi", filename) + return original_open_callable(*args, **kwargs) @@ -82,6 +87,7 @@ def wrapped_open_ED4CF71136E15EBF(original_open_callable, instance, args, kwargs try: from ddtrace.appsec._asm_request_context import call_waf_callback from ddtrace.appsec._asm_request_context import in_context + from ddtrace.appsec._asm_request_context import is_blocked from ddtrace.appsec._constants import EXPLOIT_PREVENTION except ImportError: # open is used during module initialization @@ -98,7 +104,8 @@ def wrapped_open_ED4CF71136E15EBF(original_open_callable, instance, args, kwargs crop_trace="wrapped_open_ED4CF71136E15EBF", rule_type=EXPLOIT_PREVENTION.TYPE.SSRF, ) - # DEV: Next part of the exploit prevention feature: add block here + if is_blocked(): + raise BlockingException(core.get_item(WAF_CONTEXT_NAMES.BLOCKED), "exploit_prevention", "ssrf", url) return original_open_callable(*args, **kwargs) @@ -115,6 +122,7 @@ def wrapped_request_D8CB81E472AF98A2(original_request_callable, instance, args, try: from ddtrace.appsec._asm_request_context import call_waf_callback from ddtrace.appsec._asm_request_context import in_context + from ddtrace.appsec._asm_request_context import is_blocked from ddtrace.appsec._constants import EXPLOIT_PREVENTION except ImportError: # open is used during module initialization @@ -129,7 +137,9 @@ def wrapped_request_D8CB81E472AF98A2(original_request_callable, instance, args, crop_trace="wrapped_request_D8CB81E472AF98A2", rule_type=EXPLOIT_PREVENTION.TYPE.SSRF, ) - # DEV: Next part of the exploit prevention feature: add block here + if is_blocked(): + raise BlockingException(core.get_item(WAF_CONTEXT_NAMES.BLOCKED), "exploit_prevention", "ssrf", url) + return original_request_callable(*args, **kwargs) diff --git a/ddtrace/contrib/asgi/middleware.py b/ddtrace/contrib/asgi/middleware.py index 70388af0de5..21061cf63fe 100644 --- a/ddtrace/contrib/asgi/middleware.py +++ b/ddtrace/contrib/asgi/middleware.py @@ -13,6 +13,7 @@ from ddtrace.ext import SpanKind from ddtrace.ext import SpanTypes from ddtrace.ext import http +from ddtrace.internal._exceptions import BlockingException from ddtrace.internal.compat import is_valid_ip from ddtrace.internal.constants import COMPONENT from ddtrace.internal.constants import HTTP_REQUEST_BLOCKED @@ -288,6 +289,9 @@ async def wrapped_blocked_send(message): try: core.dispatch("asgi.start_request", ("asgi",)) return await self.app(scope, receive, wrapped_send) + except BlockingException as e: + core.set_item(HTTP_REQUEST_BLOCKED, e.args[0]) + return await _blocked_asgi_app(scope, receive, wrapped_blocked_send) except trace_utils.InterruptException: return await _blocked_asgi_app(scope, receive, wrapped_blocked_send) except Exception as exc: diff --git a/ddtrace/contrib/django/patch.py b/ddtrace/contrib/django/patch.py index 0f4e2318c89..670e94fe1ba 100644 --- a/ddtrace/contrib/django/patch.py +++ b/ddtrace/contrib/django/patch.py @@ -22,6 +22,7 @@ from ddtrace.ext import http from ddtrace.ext import sql as sqlx from ddtrace.internal import core +from ddtrace.internal._exceptions import BlockingException from ddtrace.internal.compat import Iterable from ddtrace.internal.compat import maybe_stringify from ddtrace.internal.constants import COMPONENT @@ -467,7 +468,7 @@ def traced_get_response(django, pin, func, instance, args, kwargs): def blocked_response(): from django.http import HttpResponse - block_config = core.get_item(HTTP_REQUEST_BLOCKED) + block_config = core.get_item(HTTP_REQUEST_BLOCKED) or {} desired_type = block_config.get("type", "auto") status = block_config.get("status_code", 403) if desired_type == "none": @@ -510,7 +511,12 @@ def blocked_response(): response = blocked_response() return response - response = func(*args, **kwargs) + try: + response = func(*args, **kwargs) + except BlockingException as e: + core.set_item(HTTP_REQUEST_BLOCKED, e.args[0]) + response = blocked_response() + return response if core.get_item(HTTP_REQUEST_BLOCKED): response = blocked_response() diff --git a/ddtrace/contrib/wsgi/wsgi.py b/ddtrace/contrib/wsgi/wsgi.py index 1714bdfa1a1..aff74e3b0a0 100644 --- a/ddtrace/contrib/wsgi/wsgi.py +++ b/ddtrace/contrib/wsgi/wsgi.py @@ -24,6 +24,7 @@ from ddtrace.contrib import trace_utils from ddtrace.ext import SpanKind from ddtrace.ext import SpanTypes +from ddtrace.internal._exceptions import BlockingException from ddtrace.internal.constants import COMPONENT from ddtrace.internal.constants import HTTP_REQUEST_BLOCKED from ddtrace.internal.logger import get_logger @@ -109,15 +110,6 @@ def __call__(self, environ: Iterable, start_response: Callable) -> wrapt.ObjectP call_key="req_span", ) as ctx: ctx.set_item("wsgi.construct_url", construct_url) - if core.get_item(HTTP_REQUEST_BLOCKED): - result = core.dispatch_with_results("wsgi.block.started", (ctx, construct_url)).status_headers_content - if result: - status, headers, content = result.value - else: - status, headers, content = 403, [], "" - start_response(str(status), headers) - closing_iterable = [content] - not_blocked = False def blocked_view(): result = core.dispatch_with_results("wsgi.block.started", (ctx, construct_url)).status_headers_content @@ -127,12 +119,24 @@ def blocked_view(): status, headers, content = 403, [], "" return content, status, headers + if core.get_item(HTTP_REQUEST_BLOCKED): + content, status, headers = blocked_view() + start_response(str(status), headers) + closing_iterable = [content] + not_blocked = False + core.dispatch("wsgi.block_decided", (blocked_view,)) if not_blocked: core.dispatch("wsgi.request.prepare", (ctx, start_response)) try: closing_iterable = self.app(environ, ctx.get_item("intercept_start_response")) + except BlockingException as e: + core.set_item(HTTP_REQUEST_BLOCKED, e.args[0]) + content, status, headers = blocked_view() + start_response(str(status), headers) + closing_iterable = [content] + core.dispatch("wsgi.app.exception", (ctx,)) except BaseException: core.dispatch("wsgi.app.exception", (ctx,)) raise diff --git a/ddtrace/internal/_exceptions.py b/ddtrace/internal/_exceptions.py new file mode 100644 index 00000000000..01e45d2b063 --- /dev/null +++ b/ddtrace/internal/_exceptions.py @@ -0,0 +1,5 @@ +class BlockingException(BaseException): + """ + Exception raised when a request is blocked by ASM + It derives from BaseException to avoid being caught by the general Exception handler + """ diff --git a/tests/.suitespec.json b/tests/.suitespec.json index 143ef63a62e..e1d036b4581 100644 --- a/tests/.suitespec.json +++ b/tests/.suitespec.json @@ -34,6 +34,7 @@ ], "core": [ "ddtrace/internal/__init__.py", + "ddtrace/internal/_exceptions.py", "ddtrace/internal/_rand.pyi", "ddtrace/internal/_rand.pyx", "ddtrace/internal/_stdint.h", diff --git a/tests/appsec/appsec/rules-rasp-blocking.json b/tests/appsec/appsec/rules-rasp-blocking.json new file mode 100644 index 00000000000..21604f13e04 --- /dev/null +++ b/tests/appsec/appsec/rules-rasp-blocking.json @@ -0,0 +1,106 @@ +{ + "version": "2.1", + "metadata": { + "rules_version": "rules_rasp" + }, + "rules": [ + { + "id": "rasp-930-100", + "name": "Local file inclusion exploit", + "tags": { + "type": "lfi", + "category": "vulnerability_trigger", + "cwe": "22", + "capec": "1000/255/153/126", + "confidence": "0", + "module": "rasp" + }, + "conditions": [ + { + "parameters": { + "resource": [ + { + "address": "server.io.fs.file" + } + ], + "params": [ + { + "address": "server.request.query" + }, + { + "address": "server.request.body" + }, + { + "address": "server.request.path_params" + }, + { + "address": "grpc.server.request.message" + }, + { + "address": "graphql.server.all_resolvers" + }, + { + "address": "graphql.server.resolver" + } + ] + }, + "operator": "lfi_detector" + } + ], + "transformers": [], + "on_match": [ + "stack_trace", + "block" + ] + }, + { + "id": "rasp-934-100", + "name": "Server-side request forgery exploit", + "tags": { + "type": "ssrf", + "category": "vulnerability_trigger", + "cwe": "918", + "capec": "1000/225/115/664", + "confidence": "0", + "module": "rasp" + }, + "conditions": [ + { + "parameters": { + "resource": [ + { + "address": "server.io.net.url" + } + ], + "params": [ + { + "address": "server.request.query" + }, + { + "address": "server.request.body" + }, + { + "address": "server.request.path_params" + }, + { + "address": "grpc.server.request.message" + }, + { + "address": "graphql.server.all_resolvers" + }, + { + "address": "graphql.server.resolver" + } + ] + }, + "operator": "ssrf_detector" + } + ], + "transformers": [], + "on_match": [ + "stack_trace", + "block" + ] + } + ] +} \ No newline at end of file diff --git a/tests/appsec/appsec/test_asm_request_context.py b/tests/appsec/appsec/test_asm_request_context.py index b6e3a6da9c2..487401f00ed 100644 --- a/tests/appsec/appsec/test_asm_request_context.py +++ b/tests/appsec/appsec/test_asm_request_context.py @@ -1,6 +1,7 @@ import pytest from ddtrace.appsec import _asm_request_context +from ddtrace.internal._exceptions import BlockingException from tests.utils import override_global_config @@ -94,3 +95,19 @@ def test_asm_request_context_manager(): assert _asm_request_context.get_headers() == {} assert _asm_request_context.get_value("callbacks", "block") is None assert not _asm_request_context.get_headers_case_sensitive() + + +def test_blocking_exception_correctly_propagated(): + with override_global_config({"_asm_enabled": True}): + with _asm_request_context.asm_request_context_manager(): + witness = 0 + with _asm_request_context.asm_request_context_manager(): + witness = 1 + raise BlockingException({}, "rule", "type", "value") + # should be skipped by exception + witness = 3 + # should be also skipped by exception + witness = 4 + # no more exception there + # ensure that the exception was raised and caught at the end of the last context manager + assert witness == 1 diff --git a/tests/appsec/contrib_appsec/django_app/urls.py b/tests/appsec/contrib_appsec/django_app/urls.py index d8c45b4cb2e..a297f18fab3 100644 --- a/tests/appsec/contrib_appsec/django_app/urls.py +++ b/tests/appsec/contrib_appsec/django_app/urls.py @@ -71,6 +71,7 @@ def rasp(request, endpoint: str): res.append(f"File: {f.read()}") except Exception as e: res.append(f"Error: {e}") + tracer.current_span()._local_root.set_tag("rasp.request.done", endpoint) return HttpResponse("<\br>\n".join(res)) elif endpoint == "ssrf": res = ["ssrf endpoint"] @@ -98,7 +99,9 @@ def rasp(request, endpoint: str): res.append(f"Url: {r.text}") except Exception as e: res.append(f"Error: {e}") + tracer.current_span()._local_root.set_tag("rasp.request.done", endpoint) return HttpResponse("<\\br>\n".join(res)) + tracer.current_span()._local_root.set_tag("rasp.request.done", endpoint) return HttpResponse(f"Unknown endpoint: {endpoint}") diff --git a/tests/appsec/contrib_appsec/fastapi_app/app.py b/tests/appsec/contrib_appsec/fastapi_app/app.py index 820c25ce47a..5111fb6a218 100644 --- a/tests/appsec/contrib_appsec/fastapi_app/app.py +++ b/tests/appsec/contrib_appsec/fastapi_app/app.py @@ -128,6 +128,7 @@ async def rasp(endpoint: str, request: Request): res.append(f"File: {f.read()}") except Exception as e: res.append(f"Error: {e}") + tracer.current_span()._local_root.set_tag("rasp.request.done", endpoint) return HTMLResponse("<\br>\n".join(res)) elif endpoint == "ssrf": res = ["ssrf endpoint"] @@ -155,7 +156,9 @@ async def rasp(endpoint: str, request: Request): res.append(f"Url: {r.text}") except Exception as e: res.append(f"Error: {e}") + tracer.current_span()._local_root.set_tag("rasp.request.done", endpoint) return HTMLResponse("<\\br>\n".join(res)) + tracer.current_span()._local_root.set_tag("rasp.request.done", endpoint) return HTMLResponse(f"Unknown endpoint: {endpoint}") return app diff --git a/tests/appsec/contrib_appsec/flask_app/app.py b/tests/appsec/contrib_appsec/flask_app/app.py index 0ecb3784ddb..8997c3fa0e6 100644 --- a/tests/appsec/contrib_appsec/flask_app/app.py +++ b/tests/appsec/contrib_appsec/flask_app/app.py @@ -72,6 +72,7 @@ def rasp(endpoint: str): res.append(f"File: {f.read()}") except Exception as e: res.append(f"Error: {e}") + tracer.current_span()._local_root.set_tag("rasp.request.done", endpoint) return "<\\br>\n".join(res) elif endpoint == "ssrf": res = ["ssrf endpoint"] @@ -99,6 +100,7 @@ def rasp(endpoint: str): res.append(f"Url: {r.text}") except Exception as e: res.append(f"Error: {e}") + tracer.current_span()._local_root.set_tag("rasp.request.done", endpoint) return "<\\br>\n".join(res) elif endpoint == "shell": res = ["shell endpoint"] @@ -112,5 +114,7 @@ def rasp(endpoint: str): res.append(f"cmd stdout: {f.stdout.read()}") except Exception as e: res.append(f"Error: {e}") + tracer.current_span()._local_root.set_tag("rasp.request.done", endpoint) return "<\\br>\n".join(res) + tracer.current_span()._local_root.set_tag("rasp.request.done", endpoint) return f"Unknown endpoint: {endpoint}" diff --git a/tests/appsec/contrib_appsec/utils.py b/tests/appsec/contrib_appsec/utils.py index 1a193b47a04..41d4f5e2b8d 100644 --- a/tests/appsec/contrib_appsec/utils.py +++ b/tests/appsec/contrib_appsec/utils.py @@ -1183,8 +1183,26 @@ def test_stream_response( ) ], ) + @pytest.mark.parametrize( + ("rule_file", "blocking"), + [ + (rules.RULES_EXPLOIT_PREVENTION, False), + (rules.RULES_EXPLOIT_PREVENTION_BLOCKING, True), + ], + ) def test_exploit_prevention( - self, interface, root_span, get_tag, asm_enabled, ep_enabled, endpoint, parameters, rule, top_functions + self, + interface, + root_span, + get_tag, + asm_enabled, + ep_enabled, + endpoint, + parameters, + rule, + top_functions, + rule_file, + blocking, ): from unittest.mock import patch as mock_patch @@ -1198,16 +1216,18 @@ def test_exploit_prevention( try: patch_requests() with override_global_config(dict(_asm_enabled=asm_enabled, _ep_enabled=ep_enabled)), override_env( - dict(DD_APPSEC_RULES=rules.RULES_EXPLOIT_PREVENTION) + dict(DD_APPSEC_RULES=rule_file) ), mock_patch("ddtrace.internal.telemetry.metrics_namespaces.MetricNamespace.add_metric") as mocked: patch_common_modules() self.update_tracer(interface) response = interface.client.get(f"/rasp/{endpoint}/?{parameters}") - assert self.status(response) == 200 - assert get_tag(http.STATUS_CODE) == "200" - assert self.body(response).startswith(f"{endpoint} endpoint") + code = 403 if blocking and asm_enabled and ep_enabled else 200 + assert self.status(response) == code + assert get_tag(http.STATUS_CODE) == str(code) + if code == 200: + assert self.body(response).startswith(f"{endpoint} endpoint") if asm_enabled and ep_enabled: - self.check_rules_triggered([rule] * 2, root_span) + self.check_rules_triggered([rule] * (1 if blocking else 2), root_span) assert self.check_for_stack_trace(root_span) for trace in self.check_for_stack_trace(root_span): assert "frames" in trace @@ -1229,9 +1249,14 @@ def test_exploit_prevention( "appsec.rasp.rule.eval", (("rule_type", endpoint), ("waf_version", DDWAF_VERSION)), ) in telemetry_calls + if blocking: + assert get_tag("rasp.request.done") is None + else: + assert get_tag("rasp.request.done") == endpoint else: assert get_triggers(root_span()) is None assert self.check_for_stack_trace(root_span) == [] + assert get_tag("rasp.request.done") == endpoint finally: unpatch_common_modules() unpatch_requests() diff --git a/tests/appsec/rules.py b/tests/appsec/rules.py index 83c2adb1981..d4aa4119062 100644 --- a/tests/appsec/rules.py +++ b/tests/appsec/rules.py @@ -11,6 +11,7 @@ RULES_SRB_METHOD = os.path.join(ROOT_DIR, "rules-suspicious-requests-get.json") RULES_BAD_VERSION = os.path.join(ROOT_DIR, "rules-bad_version.json") RULES_EXPLOIT_PREVENTION = os.path.join(ROOT_DIR, "rules-rasp.json") +RULES_EXPLOIT_PREVENTION_BLOCKING = os.path.join(ROOT_DIR, "rules-rasp-blocking.json") RESPONSE_CUSTOM_JSON = os.path.join(ROOT_DIR, "response-custom.json") RESPONSE_CUSTOM_HTML = os.path.join(ROOT_DIR, "response-custom.html") From faedc3553903bda9f6d9349732cd09ea7021fc73 Mon Sep 17 00:00:00 2001 From: kyle Date: Fri, 3 May 2024 01:23:09 +0200 Subject: [PATCH 8/9] chore(telemetry): add item for instrumentation config id (#8783) When enabling library injection remotely through the UI, we'd like to show which services have been instrumented as a result. To do this we are proposing to submit the remote configuration ID that was used to instrument the service. [](https://datadoghq.atlassian.net/browse/APMON-887) ## Checklist - [x] Change(s) are motivated and described in the PR description - [x] Testing strategy is described if automated tests are not included in the PR - [x] Risks are described (performance impact, potential for breakage, maintainability) - [x] Change is maintainable (easy to change, telemetry, documentation) - [x] [Library release note guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html) are followed or label `changelog/no-changelog` is set - [x] Documentation is included (in-code, generated user docs, [public corp docs](https://github.com/DataDog/documentation/)) - [x] Backport labels are set (if [applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)) - [x] If this PR changes the public interface, I've notified `@DataDog/apm-tees`. - [x] If change touches code that signs or publishes builds or packages, or handles credentials of any kind, I've requested a review from `@DataDog/security-design-and-guidance`. ## Reviewer Checklist - [x] Title is accurate - [x] All changes are related to the pull request's stated goal - [x] Description motivates each change - [x] Avoids breaking [API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces) changes - [x] Testing strategy adequately addresses listed risks - [x] Change is maintainable (easy to change, telemetry, documentation) - [x] Release note makes sense to a user of the library - [x] Author has acknowledged and discussed the performance implications of this PR as reported in the benchmarks PR comment - [x] Backport labels are set in a manner that is consistent with the [release branch maintenance policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting) --- ddtrace/internal/telemetry/writer.py | 9 +++++++++ tests/telemetry/test_writer.py | 3 +++ 2 files changed, 12 insertions(+) diff --git a/ddtrace/internal/telemetry/writer.py b/ddtrace/internal/telemetry/writer.py index 836daa5da74..8c9ebf6f1f7 100644 --- a/ddtrace/internal/telemetry/writer.py +++ b/ddtrace/internal/telemetry/writer.py @@ -422,6 +422,14 @@ def _app_started_event(self, register_app_shutdown=True): if register_app_shutdown: atexit.register(self.app_shutdown) + inst_config_id_entry = ("instrumentation_config_id", "", "default") + if "DD_INSTRUMENTATION_CONFIG_ID" in os.environ: + inst_config_id_entry = ( + "instrumentation_config_id", + os.environ["DD_INSTRUMENTATION_CONFIG_ID"], + "env_var", + ) + self.add_configurations( [ self._telemetry_entry("_trace_enabled"), @@ -435,6 +443,7 @@ def _app_started_event(self, register_app_shutdown=True): self._telemetry_entry("trace_http_header_tags"), self._telemetry_entry("tags"), self._telemetry_entry("_tracing_enabled"), + inst_config_id_entry, (TELEMETRY_STARTUP_LOGS_ENABLED, config._startup_logs_enabled, "unknown"), (TELEMETRY_DYNAMIC_INSTRUMENTATION_ENABLED, di_config.enabled, "unknown"), (TELEMETRY_EXCEPTION_DEBUGGING_ENABLED, ed_config.enabled, "unknown"), diff --git a/tests/telemetry/test_writer.py b/tests/telemetry/test_writer.py index 18699170152..c25482e849e 100644 --- a/tests/telemetry/test_writer.py +++ b/tests/telemetry/test_writer.py @@ -146,6 +146,7 @@ def test_app_started_event(telemetry_writer, test_agent_session, mock_time): {"name": "logs_injection_enabled", "origin": "default", "value": "false"}, {"name": "trace_tags", "origin": "default", "value": ""}, {"name": "tracing_enabled", "origin": "default", "value": "true"}, + {"name": "instrumentation_config_id", "origin": "default", "value": ""}, ], key=lambda x: x["name"], ), @@ -229,6 +230,7 @@ def test_app_started_event_configuration_override( env["DD_TRACE_WRITER_INTERVAL_SECONDS"] = "30" env["DD_TRACE_WRITER_REUSE_CONNECTIONS"] = "True" env["DD_TAGS"] = "team:apm,component:web" + env["DD_INSTRUMENTATION_CONFIG_ID"] = "abcedf123" env[env_var] = value file = tmpdir.join("moon_ears.json") @@ -314,6 +316,7 @@ def test_app_started_event_configuration_override( {"name": "trace_header_tags", "origin": "default", "value": ""}, {"name": "trace_tags", "origin": "env_var", "value": "team:apm,component:web"}, {"name": "tracing_enabled", "origin": "env_var", "value": "false"}, + {"name": "instrumentation_config_id", "origin": "env_var", "value": "abcedf123"}, ], key=lambda x: x["name"], ) From d10e081f8cd9e921fe004154a7b7fa12bb76e18e Mon Sep 17 00:00:00 2001 From: Yun Kim <35776586+Yun-Kim@users.noreply.github.com> Date: Fri, 3 May 2024 14:29:29 +0200 Subject: [PATCH 9/9] feat(llmobs): add retrieval and embedding spans (#9134) This PR adds support for submitting embedding and retrieval type spans for LLM Observability, both via `LLMObs.{retrieval/embedding}` and `@ddtrace.llmobs.decorators.{retrieval/embedding}`. Additionally, this PR adds a public helper class `ddtrace.llmobs.utils.Documents` for users to create SDK-compatible input/output annotation objects for Embedding/Retrieval spans. Embedding spans require a model name to be set, and also optionally accepts model provider values (will default to `custom`). Embedding spans can be annotated with input strings, dictionaries, or a list of dictionaries, which will be cast as `Documents` when submitted to LLMObs. Embedding spans can be annotated with output strings or any JSON serializable value. Retrieval spans can be annotated with input strings or any JSON serializable value. Retrieval spans can also be annotated with output strings, dictionaries, or a list of dictionaries, which will be cast as `Documents` when submitted to LLMObs. This PR also introduces a class of type `ddtrace.llmobs.utils.Documents`, which can be used to convert arguments to be tagged as input/output documents. The `Documents` TypedDict object can contain the following fields: - `name`: str - `id`: str - `text`: str - `score`: int/float ## Checklist - [x] Change(s) are motivated and described in the PR description - [x] Testing strategy is described if automated tests are not included in the PR - [x] Risks are described (performance impact, potential for breakage, maintainability) - [x] Change is maintainable (easy to change, telemetry, documentation) - [x] [Library release note guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html) are followed or label `changelog/no-changelog` is set - [x] Documentation is included (in-code, generated user docs, [public corp docs](https://github.com/DataDog/documentation/)) - [x] Backport labels are set (if [applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)) - [x] If this PR changes the public interface, I've notified `@DataDog/apm-tees`. ## Reviewer Checklist - [x] Title is accurate - [x] All changes are related to the pull request's stated goal - [x] Description motivates each change - [x] Avoids breaking [API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces) changes - [x] Testing strategy adequately addresses listed risks - [x] Change is maintainable (easy to change, telemetry, documentation) - [x] Release note makes sense to a user of the library - [x] Author has acknowledged and discussed the performance implications of this PR as reported in the benchmarks PR comment - [x] Backport labels are set in a manner that is consistent with the [release branch maintenance policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting) --- ddtrace/llmobs/_constants.py | 2 + ddtrace/llmobs/_llmobs.py | 114 +++++++++++++++ ddtrace/llmobs/_trace_processor.py | 8 +- ddtrace/llmobs/decorators.py | 65 +++++---- ddtrace/llmobs/utils.py | 34 +++++ tests/llmobs/test_llmobs_decorators.py | 101 ++++++++++++-- tests/llmobs/test_llmobs_service.py | 183 ++++++++++++++++++++++++- tests/llmobs/test_utils.py | 48 +++++++ 8 files changed, 512 insertions(+), 43 deletions(-) diff --git a/ddtrace/llmobs/_constants.py b/ddtrace/llmobs/_constants.py index fa92a3ed566..9d04fa68cbf 100644 --- a/ddtrace/llmobs/_constants.py +++ b/ddtrace/llmobs/_constants.py @@ -8,9 +8,11 @@ MODEL_NAME = "_ml_obs.meta.model_name" MODEL_PROVIDER = "_ml_obs.meta.model_provider" +INPUT_DOCUMENTS = "_ml_obs.meta.input.documents" INPUT_MESSAGES = "_ml_obs.meta.input.messages" INPUT_VALUE = "_ml_obs.meta.input.value" INPUT_PARAMETERS = "_ml_obs.meta.input.parameters" +OUTPUT_DOCUMENTS = "_ml_obs.meta.output.documents" OUTPUT_MESSAGES = "_ml_obs.meta.output.messages" OUTPUT_VALUE = "_ml_obs.meta.output.value" diff --git a/ddtrace/llmobs/_llmobs.py b/ddtrace/llmobs/_llmobs.py index 411c68e84af..d72aa983fe5 100644 --- a/ddtrace/llmobs/_llmobs.py +++ b/ddtrace/llmobs/_llmobs.py @@ -12,6 +12,7 @@ from ddtrace.internal import atexit from ddtrace.internal.logger import get_logger from ddtrace.internal.service import Service +from ddtrace.llmobs._constants import INPUT_DOCUMENTS from ddtrace.llmobs._constants import INPUT_MESSAGES from ddtrace.llmobs._constants import INPUT_PARAMETERS from ddtrace.llmobs._constants import INPUT_VALUE @@ -20,6 +21,7 @@ from ddtrace.llmobs._constants import ML_APP from ddtrace.llmobs._constants import MODEL_NAME from ddtrace.llmobs._constants import MODEL_PROVIDER +from ddtrace.llmobs._constants import OUTPUT_DOCUMENTS from ddtrace.llmobs._constants import OUTPUT_MESSAGES from ddtrace.llmobs._constants import OUTPUT_VALUE from ddtrace.llmobs._constants import SESSION_ID @@ -30,6 +32,7 @@ from ddtrace.llmobs._utils import _get_session_id from ddtrace.llmobs._writer import LLMObsEvalMetricWriter from ddtrace.llmobs._writer import LLMObsSpanWriter +from ddtrace.llmobs.utils import Documents from ddtrace.llmobs.utils import ExportedLLMObsSpan from ddtrace.llmobs.utils import Messages @@ -270,6 +273,64 @@ def workflow( return None return cls._instance._start_span("workflow", name=name, session_id=session_id, ml_app=ml_app) + @classmethod + def embedding( + cls, + model_name: str, + name: Optional[str] = None, + model_provider: Optional[str] = None, + session_id: Optional[str] = None, + ml_app: Optional[str] = None, + ) -> Optional[Span]: + """ + Trace a call to an embedding model or function to create an embedding. + + :param str model_name: The name of the invoked embedding model. + :param str name: The name of the traced operation. If not provided, a default value of "embedding" will be set. + :param str model_provider: The name of the invoked LLM provider (ex: openai, bedrock). + If not provided, a default value of "custom" will be set. + :param str session_id: The ID of the underlying user session. Required for tracking sessions. + :param str ml_app: The name of the ML application that the agent is orchestrating. If not provided, the default + value DD_LLMOBS_APP_NAME will be set. + + :returns: The Span object representing the traced operation. + """ + if cls.enabled is False or cls._instance is None: + log.warning("LLMObs.embedding() cannot be used while LLMObs is disabled.") + return None + if not model_name: + log.warning("model_name must be the specified name of the invoked model.") + return None + if model_provider is None: + model_provider = "custom" + return cls._instance._start_span( + "embedding", + name, + model_name=model_name, + model_provider=model_provider, + session_id=session_id, + ml_app=ml_app, + ) + + @classmethod + def retrieval( + cls, name: Optional[str] = None, session_id: Optional[str] = None, ml_app: Optional[str] = None + ) -> Optional[Span]: + """ + Trace a vector search operation involving a list of documents being returned from an external knowledge base. + + :param str name: The name of the traced operation. If not provided, a default value of "workflow" will be set. + :param str session_id: The ID of the underlying user session. Required for tracking sessions. + :param str ml_app: The name of the ML application that the agent is orchestrating. If not provided, the default + value DD_LLMOBS_APP_NAME will be set. + + :returns: The Span object representing the traced operation. + """ + if cls.enabled is False or cls._instance is None: + log.warning("LLMObs.retrieval() cannot be used while LLMObs is disabled.") + return None + return cls._instance._start_span("retrieval", name=name, session_id=session_id, ml_app=ml_app) + @classmethod def annotate( cls, @@ -290,10 +351,15 @@ def annotate( :param input_data: A single input string, dictionary, or a list of dictionaries based on the span kind: - llm spans: accepts a string, or a dictionary of form {"content": "...", "role": "..."}, or a list of dictionaries with the same signature. + - embedding spans: accepts a string, list of strings, or a dictionary of form + {"text": "...", ...} or a list of dictionaries with the same signature. - other: any JSON serializable type. :param output_data: A single output string, dictionary, or a list of dictionaries based on the span kind: - llm spans: accepts a string, or a dictionary of form {"content": "...", "role": "..."}, or a list of dictionaries with the same signature. + - retrieval spans: a dictionary containing any of the key value pairs + {"name": str, "id": str, "text": str, "score": float}, + or a list of dictionaries with the same signature. - other: any JSON serializable type. :param parameters: (DEPRECATED) Dictionary of JSON serializable key-value pairs to set as input parameters. :param metadata: Dictionary of JSON serializable key-value metadata pairs relevant to the input/output operation @@ -327,6 +393,10 @@ def annotate( if input_data or output_data: if span_kind == "llm": cls._tag_llm_io(span, input_messages=input_data, output_messages=output_data) + elif span_kind == "embedding": + cls._tag_embedding_io(span, input_documents=input_data, output_text=output_data) + elif span_kind == "retrieval": + cls._tag_retrieval_io(span, input_text=input_data, output_documents=output_data) else: cls._tag_text_io(span, input_value=input_data, output_value=output_data) if metadata is not None: @@ -371,6 +441,50 @@ def _tag_llm_io(cls, span, input_messages=None, output_messages=None): except (TypeError, AttributeError): log.warning("Failed to parse output messages.", exc_info=True) + @classmethod + def _tag_embedding_io(cls, span, input_documents=None, output_text=None): + """Tags input documents and output text for embedding-kind spans. + Will be mapped to span's `meta.{input,output}.text` fields. + """ + if input_documents is not None: + try: + if not isinstance(input_documents, Documents): + input_documents = Documents(input_documents) + if input_documents.documents: + span.set_tag_str(INPUT_DOCUMENTS, json.dumps(input_documents.documents)) + except (TypeError, AttributeError): + log.warning("Failed to parse input documents.", exc_info=True) + if output_text is not None: + if isinstance(output_text, str): + span.set_tag_str(OUTPUT_VALUE, output_text) + else: + try: + span.set_tag_str(OUTPUT_VALUE, json.dumps(output_text)) + except TypeError: + log.warning("Failed to parse output text. Output text must be JSON serializable.") + + @classmethod + def _tag_retrieval_io(cls, span, input_text=None, output_documents=None): + """Tags input text and output documents for retrieval-kind spans. + Will be mapped to span's `meta.{input,output}.text` fields. + """ + if input_text is not None: + if isinstance(input_text, str): + span.set_tag_str(INPUT_VALUE, input_text) + else: + try: + span.set_tag_str(INPUT_VALUE, json.dumps(input_text)) + except TypeError: + log.warning("Failed to parse input text. Input text must be JSON serializable.") + if output_documents is not None: + try: + if not isinstance(output_documents, Documents): + output_documents = Documents(output_documents) + if output_documents.documents: + span.set_tag_str(OUTPUT_DOCUMENTS, json.dumps(output_documents.documents)) + except (TypeError, AttributeError): + log.warning("Failed to parse output documents.", exc_info=True) + @classmethod def _tag_text_io(cls, span, input_value=None, output_value=None): """Tags input/output values for non-LLM kind spans. diff --git a/ddtrace/llmobs/_trace_processor.py b/ddtrace/llmobs/_trace_processor.py index f95b2637be0..ac07cf1d484 100644 --- a/ddtrace/llmobs/_trace_processor.py +++ b/ddtrace/llmobs/_trace_processor.py @@ -15,6 +15,7 @@ from ddtrace.ext import SpanTypes from ddtrace.internal.logger import get_logger from ddtrace.internal.utils.formats import asbool +from ddtrace.llmobs._constants import INPUT_DOCUMENTS from ddtrace.llmobs._constants import INPUT_MESSAGES from ddtrace.llmobs._constants import INPUT_PARAMETERS from ddtrace.llmobs._constants import INPUT_VALUE @@ -23,6 +24,7 @@ from ddtrace.llmobs._constants import ML_APP from ddtrace.llmobs._constants import MODEL_NAME from ddtrace.llmobs._constants import MODEL_PROVIDER +from ddtrace.llmobs._constants import OUTPUT_DOCUMENTS from ddtrace.llmobs._constants import OUTPUT_MESSAGES from ddtrace.llmobs._constants import OUTPUT_VALUE from ddtrace.llmobs._constants import SESSION_ID @@ -65,7 +67,7 @@ def _llmobs_span_event(self, span: Span) -> Dict[str, Any]: """Span event object structure.""" span_kind = span._meta.pop(SPAN_KIND) meta: Dict[str, Any] = {"span.kind": span_kind, "input": {}, "output": {}} - if span_kind == "llm" and span.get_tag(MODEL_NAME) is not None: + if span_kind in ("llm", "embedding") and span.get_tag(MODEL_NAME) is not None: meta["model_name"] = span._meta.pop(MODEL_NAME) meta["model_provider"] = span._meta.pop(MODEL_PROVIDER, "custom").lower() if span.get_tag(METADATA) is not None: @@ -78,8 +80,12 @@ def _llmobs_span_event(self, span: Span) -> Dict[str, Any]: meta["input"]["value"] = span._meta.pop(INPUT_VALUE) if span_kind == "llm" and span.get_tag(OUTPUT_MESSAGES) is not None: meta["output"]["messages"] = json.loads(span._meta.pop(OUTPUT_MESSAGES)) + if span_kind == "embedding" and span.get_tag(INPUT_DOCUMENTS) is not None: + meta["input"]["documents"] = json.loads(span._meta.pop(INPUT_DOCUMENTS)) if span.get_tag(OUTPUT_VALUE) is not None: meta["output"]["value"] = span._meta.pop(OUTPUT_VALUE) + if span_kind == "retrieval" and span.get_tag(OUTPUT_DOCUMENTS) is not None: + meta["output"]["documents"] = json.loads(span._meta.pop(OUTPUT_DOCUMENTS)) if span.error: meta[ERROR_MSG] = span.get_tag(ERROR_MSG) meta[ERROR_STACK] = span.get_tag(ERROR_STACK) diff --git a/ddtrace/llmobs/decorators.py b/ddtrace/llmobs/decorators.py index 1cb18620ea4..cdb9dd9762e 100644 --- a/ddtrace/llmobs/decorators.py +++ b/ddtrace/llmobs/decorators.py @@ -9,34 +9,42 @@ log = get_logger(__name__) -def llm( - model_name: str, - model_provider: Optional[str] = None, - name: Optional[str] = None, - session_id: Optional[str] = None, - ml_app: Optional[str] = None, -): - def inner(func): - @wraps(func) - def wrapper(*args, **kwargs): - if not LLMObs.enabled or LLMObs._instance is None: - log.warning("LLMObs.llm() cannot be used while LLMObs is disabled.") - return func(*args, **kwargs) - span_name = name - if span_name is None: - span_name = func.__name__ - with LLMObs.llm( - model_name=model_name, - model_provider=model_provider, - name=span_name, - session_id=session_id, - ml_app=ml_app, - ): - return func(*args, **kwargs) +def _model_decorator(operation_kind): + def decorator( + model_name: str, + original_func: Optional[Callable] = None, + model_provider: Optional[str] = None, + name: Optional[str] = None, + session_id: Optional[str] = None, + ml_app: Optional[str] = None, + ): + def inner(func): + @wraps(func) + def wrapper(*args, **kwargs): + if not LLMObs.enabled or LLMObs._instance is None: + log.warning("LLMObs.%s() cannot be used while LLMObs is disabled.", operation_kind) + return func(*args, **kwargs) + traced_model_name = model_name + if traced_model_name is None: + raise TypeError("model_name is required for LLMObs.{}()".format(operation_kind)) + span_name = name + if span_name is None: + span_name = func.__name__ + traced_operation = getattr(LLMObs, operation_kind, "llm") + with traced_operation( + model_name=model_name, + model_provider=model_provider, + name=span_name, + session_id=session_id, + ml_app=ml_app, + ): + return func(*args, **kwargs) - return wrapper + return wrapper + + return inner - return inner + return decorator def _llmobs_decorator(operation_kind): @@ -50,7 +58,7 @@ def inner(func): @wraps(func) def wrapper(*args, **kwargs): if not LLMObs.enabled or LLMObs._instance is None: - log.warning("LLMObs.{}() cannot be used while LLMObs is disabled.", operation_kind) + log.warning("LLMObs.%s() cannot be used while LLMObs is disabled.", operation_kind) return func(*args, **kwargs) span_name = name if span_name is None: @@ -68,7 +76,10 @@ def wrapper(*args, **kwargs): return decorator +llm = _model_decorator("llm") +embedding = _model_decorator("embedding") workflow = _llmobs_decorator("workflow") task = _llmobs_decorator("task") tool = _llmobs_decorator("tool") +retrieval = _llmobs_decorator("retrieval") agent = _llmobs_decorator("agent") diff --git a/ddtrace/llmobs/utils.py b/ddtrace/llmobs/utils.py index 1fbb7305c36..cbb1f97d4f6 100644 --- a/ddtrace/llmobs/utils.py +++ b/ddtrace/llmobs/utils.py @@ -16,6 +16,7 @@ ExportedLLMObsSpan = TypedDict("ExportedLLMObsSpan", {"span_id": str, "trace_id": str}) +Document = TypedDict("Document", {"name": str, "id": str, "text": str, "score": float}, total=False) Message = TypedDict("Message", {"content": str, "role": str}, total=False) @@ -40,3 +41,36 @@ def __init__(self, messages: Union[List[Dict[str, str]], Dict[str, str], str]): if not isinstance(role, str): raise TypeError("Message role must be a string, and one of .") self.messages.append(Message(content=content, role=role)) + + +class Documents: + def __init__(self, documents: Union[List[Dict[str, str]], Dict[str, str], str]): + self.documents = [] + if not isinstance(documents, list): + documents = [documents] # type: ignore[list-item] + for document in documents: + if isinstance(document, str): + self.documents.append(Document(text=document)) + continue + elif not isinstance(document, dict): + raise TypeError("documents must be a string, dictionary, or list of dictionaries.") + document_text = document.get("text") + document_name = document.get("name") + document_id = document.get("id") + document_score = document.get("score") + if not isinstance(document_text, str): + raise TypeError("Document text must be a string.") + formatted_document = Document(text=document_text) + if document_name: + if not isinstance(document_name, str): + raise TypeError("document name must be a string.") + formatted_document["name"] = document_name + if document_id: + if not isinstance(document_id, str): + raise TypeError("document id must be a string.") + formatted_document["id"] = document_id + if document_score: + if not isinstance(document_score, (int, float)): + raise TypeError("document score must be an integer or float.") + formatted_document["score"] = document_score + self.documents.append(formatted_document) diff --git a/tests/llmobs/test_llmobs_decorators.py b/tests/llmobs/test_llmobs_decorators.py index f106c9db51b..31ecfbf37e1 100644 --- a/tests/llmobs/test_llmobs_decorators.py +++ b/tests/llmobs/test_llmobs_decorators.py @@ -2,7 +2,9 @@ import pytest from ddtrace.llmobs.decorators import agent +from ddtrace.llmobs.decorators import embedding from ddtrace.llmobs.decorators import llm +from ddtrace.llmobs.decorators import retrieval from ddtrace.llmobs.decorators import task from ddtrace.llmobs.decorators import tool from ddtrace.llmobs.decorators import workflow @@ -17,17 +19,28 @@ def mock_logs(): def test_llm_decorator_with_llmobs_disabled_logs_warning(LLMObs, mock_logs): - @llm(model_name="test_model", model_provider="test_provider", name="test_function", session_id="test_session_id") - def f(): - pass + for decorator_name, decorator in (("llm", llm), ("embedding", embedding)): - LLMObs.disable() - f() - mock_logs.warning.assert_called_with("LLMObs.llm() cannot be used while LLMObs is disabled.") + @decorator( + model_name="test_model", model_provider="test_provider", name="test_function", session_id="test_session_id" + ) + def f(): + pass + + LLMObs.disable() + f() + mock_logs.warning.assert_called_with("LLMObs.%s() cannot be used while LLMObs is disabled.", decorator_name) + mock_logs.reset_mock() def test_non_llm_decorator_with_llmobs_disabled_logs_warning(LLMObs, mock_logs): - for decorator_name, decorator in [("task", task), ("workflow", workflow), ("tool", tool), ("agent", agent)]: + for decorator_name, decorator in ( + ("task", task), + ("workflow", workflow), + ("tool", tool), + ("agent", agent), + ("retrieval", retrieval), + ): @decorator(name="test_function", session_id="test_session_id") def f(): @@ -35,7 +48,7 @@ def f(): LLMObs.disable() f() - mock_logs.warning.assert_called_with("LLMObs.{}() cannot be used while LLMObs is disabled.", decorator_name) + mock_logs.warning.assert_called_with("LLMObs.%s() cannot be used while LLMObs is disabled.", decorator_name) mock_logs.reset_mock() @@ -73,6 +86,64 @@ def f(): ) +def test_embedding_decorator(LLMObs, mock_llmobs_span_writer): + @embedding( + model_name="test_model", model_provider="test_provider", name="test_function", session_id="test_session_id" + ) + def f(): + pass + + f() + span = LLMObs._instance.tracer.pop()[0] + mock_llmobs_span_writer.enqueue.assert_called_with( + _expected_llmobs_llm_span_event( + span, "embedding", model_name="test_model", model_provider="test_provider", session_id="test_session_id" + ) + ) + + +def test_embedding_decorator_no_model_name_raises_error(LLMObs): + with pytest.raises(TypeError): + + @embedding(model_provider="test_provider", name="test_function", session_id="test_session_id") + def f(): + pass + + +def test_embedding_decorator_default_kwargs(LLMObs, mock_llmobs_span_writer): + @embedding(model_name="test_model") + def f(): + pass + + f() + span = LLMObs._instance.tracer.pop()[0] + mock_llmobs_span_writer.enqueue.assert_called_with( + _expected_llmobs_llm_span_event(span, "embedding", model_name="test_model", model_provider="custom") + ) + + +def test_retrieval_decorator(LLMObs, mock_llmobs_span_writer): + @retrieval(name="test_function", session_id="test_session_id") + def f(): + pass + + f() + span = LLMObs._instance.tracer.pop()[0] + mock_llmobs_span_writer.enqueue.assert_called_with( + _expected_llmobs_non_llm_span_event(span, "retrieval", session_id="test_session_id") + ) + + +def test_retrieval_decorator_default_kwargs(LLMObs, mock_llmobs_span_writer): + @retrieval() + def f(): + pass + + f() + span = LLMObs._instance.tracer.pop()[0] + mock_llmobs_span_writer.enqueue.assert_called_with(_expected_llmobs_non_llm_span_event(span, "retrieval")) + + def test_task_decorator(LLMObs, mock_llmobs_span_writer): @task(name="test_function", session_id="test_session_id") def f(): @@ -265,7 +336,13 @@ def f(): def test_non_llm_decorators_no_args(LLMObs, mock_llmobs_span_writer): """Test that using the decorators without any arguments, i.e. @tool, works the same as @tool(...).""" - for decorator_name, decorator in [("task", task), ("workflow", workflow), ("tool", tool)]: + for decorator_name, decorator in [ + ("task", task), + ("workflow", workflow), + ("tool", tool), + ("agent", agent), + ("retrieval", retrieval), + ]: @decorator def f(): @@ -314,12 +391,14 @@ def g(): ) ) - @agent(ml_app="test_ml_app") + @embedding(model_name="test_model", ml_app="test_ml_app") def h(): pass h() span = LLMObs._instance.tracer.pop()[0] mock_llmobs_span_writer.enqueue.assert_called_with( - _expected_llmobs_llm_span_event(span, "agent", tags={"ml_app": "test_ml_app"}) + _expected_llmobs_llm_span_event( + span, "embedding", model_name="test_model", model_provider="custom", tags={"ml_app": "test_ml_app"} + ) ) diff --git a/tests/llmobs/test_llmobs_service.py b/tests/llmobs/test_llmobs_service.py index dfaef69c146..4b9153de1d5 100644 --- a/tests/llmobs/test_llmobs_service.py +++ b/tests/llmobs/test_llmobs_service.py @@ -4,6 +4,7 @@ import pytest from ddtrace.llmobs import LLMObs as llmobs_service +from ddtrace.llmobs._constants import INPUT_DOCUMENTS from ddtrace.llmobs._constants import INPUT_MESSAGES from ddtrace.llmobs._constants import INPUT_PARAMETERS from ddtrace.llmobs._constants import INPUT_VALUE @@ -11,6 +12,7 @@ from ddtrace.llmobs._constants import METRICS from ddtrace.llmobs._constants import MODEL_NAME from ddtrace.llmobs._constants import MODEL_PROVIDER +from ddtrace.llmobs._constants import OUTPUT_DOCUMENTS from ddtrace.llmobs._constants import OUTPUT_MESSAGES from ddtrace.llmobs._constants import OUTPUT_VALUE from ddtrace.llmobs._constants import SESSION_ID @@ -214,6 +216,42 @@ def test_agent_span(LLMObs, mock_llmobs_span_writer): mock_llmobs_span_writer.enqueue.assert_called_with(_expected_llmobs_llm_span_event(span, "agent")) +def test_embedding_span_no_model_raises_error(LLMObs): + with pytest.raises(TypeError): + with LLMObs.embedding(name="test_embedding", model_provider="test_provider"): + pass + + +def test_embedding_span_empty_model_name_logs_warning(LLMObs, mock_logs): + _ = LLMObs.embedding(model_name="", name="test_embedding", model_provider="test_provider") + mock_logs.warning.assert_called_once_with("model_name must be the specified name of the invoked model.") + + +def test_embedding_default_model_provider_set_to_custom(LLMObs): + with LLMObs.embedding(model_name="test_model", name="test_embedding") as span: + assert span.name == "test_embedding" + assert span.resource == "embedding" + assert span.span_type == "llm" + assert span.get_tag(SPAN_KIND) == "embedding" + assert span.get_tag(MODEL_NAME) == "test_model" + assert span.get_tag(MODEL_PROVIDER) == "custom" + + +def test_embedding_span(LLMObs, mock_llmobs_span_writer): + with LLMObs.embedding(model_name="test_model", name="test_embedding", model_provider="test_provider") as span: + assert span.name == "test_embedding" + assert span.resource == "embedding" + assert span.span_type == "llm" + assert span.get_tag(SPAN_KIND) == "embedding" + assert span.get_tag(MODEL_NAME) == "test_model" + assert span.get_tag(MODEL_PROVIDER) == "test_provider" + assert span.get_tag(SESSION_ID) == "{:x}".format(span.trace_id) + + mock_llmobs_span_writer.enqueue.assert_called_with( + _expected_llmobs_llm_span_event(span, "embedding", model_name="test_model", model_provider="test_provider") + ) + + def test_annotate_while_disabled_logs_warning(LLMObs, mock_logs): LLMObs.disable() LLMObs.annotate(parameters={"test": "test"}) @@ -306,6 +344,9 @@ def test_annotate_input_string(LLMObs): with LLMObs.agent() as agent_span: LLMObs.annotate(span=agent_span, input_data="test_input") assert agent_span.get_tag(INPUT_VALUE) == "test_input" + with LLMObs.retrieval() as retrieval_span: + LLMObs.annotate(span=retrieval_span, input_data="test_input") + assert retrieval_span.get_tag(INPUT_VALUE) == "test_input" def test_annotate_input_serializable_value(LLMObs): @@ -321,6 +362,9 @@ def test_annotate_input_serializable_value(LLMObs): with LLMObs.agent() as agent_span: LLMObs.annotate(span=agent_span, input_data="test_input") assert agent_span.get_tag(INPUT_VALUE) == "test_input" + with LLMObs.retrieval() as retrieval_span: + LLMObs.annotate(span=retrieval_span, input_data=[0, 1, 2, 3, 4]) + assert retrieval_span.get_tag(INPUT_VALUE) == "[0, 1, 2, 3, 4]" def test_annotate_input_value_wrong_type(LLMObs, mock_logs): @@ -352,10 +396,130 @@ def test_llmobs_annotate_incorrect_message_content_type_raises_warning(LLMObs, m mock_logs.warning.assert_called_once_with("Failed to parse output messages.", exc_info=True) +def test_annotate_document_str(LLMObs): + with LLMObs.embedding(model_name="test_model") as span: + LLMObs.annotate(span=span, input_data="test_document_text") + documents = json.loads(span.get_tag(INPUT_DOCUMENTS)) + assert documents + assert len(documents) == 1 + assert documents[0]["text"] == "test_document_text" + with LLMObs.retrieval() as span: + LLMObs.annotate(span=span, output_data="test_document_text") + documents = json.loads(span.get_tag(OUTPUT_DOCUMENTS)) + assert documents + assert len(documents) == 1 + assert documents[0]["text"] == "test_document_text" + + +def test_annotate_document_dict(LLMObs): + with LLMObs.embedding(model_name="test_model") as span: + LLMObs.annotate(span=span, input_data={"text": "test_document_text"}) + documents = json.loads(span.get_tag(INPUT_DOCUMENTS)) + assert documents + assert len(documents) == 1 + assert documents[0]["text"] == "test_document_text" + with LLMObs.retrieval() as span: + LLMObs.annotate(span=span, output_data={"text": "test_document_text"}) + documents = json.loads(span.get_tag(OUTPUT_DOCUMENTS)) + assert documents + assert len(documents) == 1 + assert documents[0]["text"] == "test_document_text" + + +def test_annotate_document_list(LLMObs): + with LLMObs.embedding(model_name="test_model") as span: + LLMObs.annotate( + span=span, + input_data=[{"text": "test_document_text"}, {"text": "text", "name": "name", "score": 0.9, "id": "id"}], + ) + documents = json.loads(span.get_tag(INPUT_DOCUMENTS)) + assert documents + assert len(documents) == 2 + assert documents[0]["text"] == "test_document_text" + assert documents[1]["text"] == "text" + assert documents[1]["name"] == "name" + assert documents[1]["id"] == "id" + assert documents[1]["score"] == 0.9 + with LLMObs.retrieval() as span: + LLMObs.annotate( + span=span, + output_data=[{"text": "test_document_text"}, {"text": "text", "name": "name", "score": 0.9, "id": "id"}], + ) + documents = json.loads(span.get_tag(OUTPUT_DOCUMENTS)) + assert documents + assert len(documents) == 2 + assert documents[0]["text"] == "test_document_text" + assert documents[1]["text"] == "text" + assert documents[1]["name"] == "name" + assert documents[1]["id"] == "id" + assert documents[1]["score"] == 0.9 + + +def test_annotate_incorrect_document_type_raises_warning(LLMObs, mock_logs): + with LLMObs.embedding(model_name="test_model") as span: + LLMObs.annotate(span=span, input_data={"text": 123}) + mock_logs.warning.assert_called_once_with("Failed to parse input documents.", exc_info=True) + mock_logs.reset_mock() + with LLMObs.embedding(model_name="test_model") as span: + LLMObs.annotate(span=span, input_data=123) + mock_logs.warning.assert_called_once_with("Failed to parse input documents.", exc_info=True) + mock_logs.reset_mock() + with LLMObs.embedding(model_name="test_model") as span: + LLMObs.annotate(span=span, input_data=Unserializable()) + mock_logs.warning.assert_called_once_with("Failed to parse input documents.", exc_info=True) + mock_logs.reset_mock() + with LLMObs.retrieval() as span: + LLMObs.annotate(span=span, output_data=[{"score": 0.9, "id": "id", "name": "name"}]) + mock_logs.warning.assert_called_once_with("Failed to parse output documents.", exc_info=True) + mock_logs.reset_mock() + with LLMObs.retrieval() as span: + LLMObs.annotate(span=span, output_data=123) + mock_logs.warning.assert_called_once_with("Failed to parse output documents.", exc_info=True) + mock_logs.reset_mock() + with LLMObs.retrieval() as span: + LLMObs.annotate(span=span, output_data=Unserializable()) + mock_logs.warning.assert_called_once_with("Failed to parse output documents.", exc_info=True) + + +def test_annotate_document_no_text_raises_warning(LLMObs, mock_logs): + with LLMObs.embedding(model_name="test_model") as span: + LLMObs.annotate(span=span, input_data=[{"score": 0.9, "id": "id", "name": "name"}]) + mock_logs.warning.assert_called_once_with("Failed to parse input documents.", exc_info=True) + mock_logs.reset_mock() + with LLMObs.retrieval() as span: + LLMObs.annotate(span=span, output_data=[{"score": 0.9, "id": "id", "name": "name"}]) + mock_logs.warning.assert_called_once_with("Failed to parse output documents.", exc_info=True) + + +def test_annotate_incorrect_document_field_type_raises_warning(LLMObs, mock_logs): + with LLMObs.embedding(model_name="test_model") as span: + LLMObs.annotate(span=span, input_data=[{"text": "test_document_text", "score": "0.9"}]) + mock_logs.warning.assert_called_once_with("Failed to parse input documents.", exc_info=True) + mock_logs.reset_mock() + with LLMObs.embedding(model_name="test_model") as span: + LLMObs.annotate( + span=span, input_data=[{"text": "text", "id": 123, "score": "0.9", "name": ["h", "e", "l", "l", "o"]}] + ) + mock_logs.warning.assert_called_once_with("Failed to parse input documents.", exc_info=True) + mock_logs.reset_mock() + with LLMObs.retrieval() as span: + LLMObs.annotate(span=span, output_data=[{"text": "test_document_text", "score": "0.9"}]) + mock_logs.warning.assert_called_once_with("Failed to parse output documents.", exc_info=True) + mock_logs.reset_mock() + with LLMObs.retrieval() as span: + LLMObs.annotate( + span=span, output_data=[{"text": "text", "id": 123, "score": "0.9", "name": ["h", "e", "l", "l", "o"]}] + ) + mock_logs.warning.assert_called_once_with("Failed to parse output documents.", exc_info=True) + + def test_annotate_output_string(LLMObs): with LLMObs.llm(model_name="test_model") as llm_span: LLMObs.annotate(span=llm_span, output_data="test_output") assert json.loads(llm_span.get_tag(OUTPUT_MESSAGES)) == [{"content": "test_output"}] + with LLMObs.embedding(model_name="test_model") as embedding_span: + LLMObs.annotate(span=embedding_span, output_data="test_output") + assert embedding_span.get_tag(OUTPUT_VALUE) == "test_output" with LLMObs.task() as task_span: LLMObs.annotate(span=task_span, output_data="test_output") assert task_span.get_tag(OUTPUT_VALUE) == "test_output" @@ -371,6 +535,9 @@ def test_annotate_output_string(LLMObs): def test_annotate_output_serializable_value(LLMObs): + with LLMObs.embedding(model_name="test_model") as embedding_span: + LLMObs.annotate(span=embedding_span, output_data=[[0, 1, 2, 3], [4, 5, 6, 7]]) + assert embedding_span.get_tag(OUTPUT_VALUE) == "[[0, 1, 2, 3], [4, 5, 6, 7]]" with LLMObs.task() as task_span: LLMObs.annotate(span=task_span, output_data=["test_output"]) assert task_span.get_tag(OUTPUT_VALUE) == '["test_output"]' @@ -465,13 +632,11 @@ def test_ml_app_override(LLMObs, mock_llmobs_span_writer): mock_llmobs_span_writer.enqueue.assert_called_with( _expected_llmobs_non_llm_span_event(span, "task", tags={"ml_app": "test_app"}) ) - with LLMObs.tool(name="test_tool", ml_app="test_app") as span: pass mock_llmobs_span_writer.enqueue.assert_called_with( _expected_llmobs_non_llm_span_event(span, "tool", tags={"ml_app": "test_app"}) ) - with LLMObs.llm(model_name="model_name", name="test_llm", ml_app="test_app") as span: pass mock_llmobs_span_writer.enqueue.assert_called_with( @@ -479,18 +644,28 @@ def test_ml_app_override(LLMObs, mock_llmobs_span_writer): span, "llm", model_name="model_name", model_provider="custom", tags={"ml_app": "test_app"} ) ) - + with LLMObs.embedding(model_name="model_name", name="test_embedding", ml_app="test_app") as span: + pass + mock_llmobs_span_writer.enqueue.assert_called_with( + _expected_llmobs_llm_span_event( + span, "embedding", model_name="model_name", model_provider="custom", tags={"ml_app": "test_app"} + ) + ) with LLMObs.workflow(name="test_workflow", ml_app="test_app") as span: pass mock_llmobs_span_writer.enqueue.assert_called_with( _expected_llmobs_non_llm_span_event(span, "workflow", tags={"ml_app": "test_app"}) ) - with LLMObs.agent(name="test_agent", ml_app="test_app") as span: pass mock_llmobs_span_writer.enqueue.assert_called_with( _expected_llmobs_llm_span_event(span, "agent", tags={"ml_app": "test_app"}) ) + with LLMObs.retrieval(name="test_retrieval", ml_app="test_app") as span: + pass + mock_llmobs_span_writer.enqueue.assert_called_with( + _expected_llmobs_non_llm_span_event(span, "retrieval", tags={"ml_app": "test_app"}) + ) def test_export_span_llmobs_not_enabled_raises_warning(LLMObs, mock_logs): diff --git a/tests/llmobs/test_utils.py b/tests/llmobs/test_utils.py index 26241b90b07..41ae6bee95c 100644 --- a/tests/llmobs/test_utils.py +++ b/tests/llmobs/test_utils.py @@ -1,5 +1,6 @@ import pytest +from ddtrace.llmobs.utils import Documents from ddtrace.llmobs.utils import Messages @@ -55,3 +56,50 @@ def test_messages_with_no_role_is_ok(): """Test that a message with no role is ok and returns a message with only content.""" messages = Messages([{"content": "hello"}, {"content": "world"}]) assert messages.messages == [{"content": "hello"}, {"content": "world"}] + + +def test_documents_with_string(): + documents = Documents("hello") + assert documents.documents == [{"text": "hello"}] + + +def test_documents_with_dict(): + documents = Documents({"text": "hello", "name": "doc1", "id": "123", "score": 0.5}) + assert len(documents.documents) == 1 + assert documents.documents == [{"text": "hello", "name": "doc1", "id": "123", "score": 0.5}] + + +def test_documents_with_list_of_dicts(): + documents = Documents([{"text": "hello", "name": "doc1", "id": "123", "score": 0.5}, {"text": "world"}]) + assert len(documents.documents) == 2 + assert documents.documents[0] == {"text": "hello", "name": "doc1", "id": "123", "score": 0.5} + assert documents.documents[1] == {"text": "world"} + + +def test_documents_with_incorrect_type(): + with pytest.raises(TypeError): + Documents(123) + with pytest.raises(TypeError): + Documents(Unserializable()) + with pytest.raises(TypeError): + Documents(None) + + +def test_documents_dictionary_no_text_value(): + with pytest.raises(TypeError): + Documents([{"text": None}]) + with pytest.raises(TypeError): + Documents([{"name": "doc1", "id": "123", "score": 0.5}]) + + +def test_documents_dictionary_with_incorrect_value_types(): + with pytest.raises(TypeError): + Documents([{"text": 123}]) + with pytest.raises(TypeError): + Documents([{"text": [1, 2, 3]}]) + with pytest.raises(TypeError): + Documents([{"text": "hello", "id": 123}]) + with pytest.raises(TypeError): + Documents({"text": "hello", "name": {"key": "value"}}) + with pytest.raises(TypeError): + Documents([{"text": "hello", "score": "123"}])