diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 23022df324d..01e3effa28e 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -63,6 +63,7 @@ ddtrace/appsec/ @DataDog/asm-python ddtrace/settings/asm.py @DataDog/asm-python ddtrace/contrib/subprocess/ @DataDog/asm-python ddtrace/contrib/flask_login/ @DataDog/asm-python +ddtrace/internal/_exceptions.py @DataDog/asm-python tests/appsec/ @DataDog/asm-python tests/contrib/dbapi/test_dbapi_appsec.py @DataDog/asm-python tests/contrib/subprocess @DataDog/asm-python diff --git a/ddtrace/_trace/trace_handlers.py b/ddtrace/_trace/trace_handlers.py index f439f87784a..f67eca90453 100644 --- a/ddtrace/_trace/trace_handlers.py +++ b/ddtrace/_trace/trace_handlers.py @@ -6,9 +6,10 @@ from typing import List from typing import Optional -from ddtrace import config from ddtrace._trace.span import Span +from ddtrace._trace.utils import extract_DD_context_from_messages from ddtrace._trace.utils import set_botocore_patched_api_call_span_tags as set_patched_api_call_span_tags +from ddtrace._trace.utils import set_botocore_response_metadata_tags from ddtrace.constants import ANALYTICS_SAMPLE_RATE_KEY from ddtrace.constants import SPAN_KIND from ddtrace.constants import SPAN_MEASURED_KEY @@ -107,6 +108,9 @@ def _start_span(ctx: core.ExecutionContext, call_trace: bool = True, **kwargs) - trace_utils.activate_distributed_headers( tracer, int_config=distributed_headers_config, request_headers=ctx["distributed_headers"] ) + distributed_context = ctx.get_item("distributed_context", traverse=True) + if distributed_context and not call_trace: + span_kwargs["child_of"] = distributed_context span_kwargs.update(kwargs) span = (tracer.trace if call_trace else tracer.start_span)(ctx["span_name"], **span_kwargs) for tk, tv in ctx.get_item("tags", dict()).items(): @@ -569,20 +573,20 @@ def _on_botocore_patched_api_call_started(ctx): span.start_ns = start_ns -def _on_botocore_patched_api_call_exception(ctx, response, exception_type, set_response_metadata_tags): +def _on_botocore_patched_api_call_exception(ctx, response, exception_type, is_error_code_fn): span = ctx.get_item(ctx.get_item("call_key")) # `ClientError.response` contains the result, so we can still grab response metadata - set_response_metadata_tags(span, response) + set_botocore_response_metadata_tags(span, response, is_error_code_fn=is_error_code_fn) # If we have a status code, and the status code is not an error, # then ignore the exception being raised status_code = span.get_tag(http.STATUS_CODE) - if status_code and not config.botocore.operations[span.resource].is_error_code(int(status_code)): + if status_code and not is_error_code_fn(int(status_code)): span._ignore_exception(exception_type) -def _on_botocore_patched_api_call_success(ctx, response, set_response_metadata_tags): - set_response_metadata_tags(ctx.get_item(ctx.get_item("call_key")), response) +def _on_botocore_patched_api_call_success(ctx, response): + set_botocore_response_metadata_tags(ctx.get_item(ctx.get_item("call_key")), response) def _on_botocore_trace_context_injection_prepared( @@ -682,6 +686,31 @@ def _on_botocore_bedrock_process_response( span.finish() +def _on_botocore_sqs_recvmessage_post( + ctx: core.ExecutionContext, _, result: Dict, propagate: bool, message_parser: Callable +) -> None: + if result is not None and "Messages" in result and len(result["Messages"]) >= 1: + ctx.set_item("message_received", True) + if propagate: + ctx.set_safe("distributed_context", extract_DD_context_from_messages(result["Messages"], message_parser)) + + +def _on_botocore_kinesis_getrecords_post( + ctx: core.ExecutionContext, + _, + __, + ___, + ____, + result, + propagate: bool, + message_parser: Callable, +): + if result is not None and "Records" in result and len(result["Records"]) >= 1: + ctx.set_item("message_received", True) + if propagate: + ctx.set_item("distributed_context", extract_DD_context_from_messages(result["Records"], message_parser)) + + def _on_redis_async_command_post(span, rowcount): if rowcount is not None: span.set_metric(db.ROWCOUNT, rowcount) @@ -727,10 +756,14 @@ def listen(): core.on("botocore.patched_stepfunctions_api_call.started", _on_botocore_patched_api_call_started) core.on("botocore.patched_stepfunctions_api_call.exception", _on_botocore_patched_api_call_exception) core.on("botocore.stepfunctions.update_messages", _on_botocore_update_messages) + core.on("botocore.eventbridge.update_messages", _on_botocore_update_messages) + core.on("botocore.client_context.update_messages", _on_botocore_update_messages) core.on("botocore.patched_bedrock_api_call.started", _on_botocore_patched_bedrock_api_call_started) core.on("botocore.patched_bedrock_api_call.exception", _on_botocore_patched_bedrock_api_call_exception) core.on("botocore.patched_bedrock_api_call.success", _on_botocore_patched_bedrock_api_call_success) core.on("botocore.bedrock.process_response", _on_botocore_bedrock_process_response) + core.on("botocore.sqs.ReceiveMessage.post", _on_botocore_sqs_recvmessage_post) + core.on("botocore.kinesis.GetRecords.post", _on_botocore_kinesis_getrecords_post) core.on("redis.async_command.post", _on_redis_async_command_post) for context_name in ( diff --git a/ddtrace/_trace/utils.py b/ddtrace/_trace/utils.py index 0e1a9364582..44bef3bbf23 100644 --- a/ddtrace/_trace/utils.py +++ b/ddtrace/_trace/utils.py @@ -1,3 +1,8 @@ +from typing import Any +from typing import Callable +from typing import Dict +from typing import Optional + from ddtrace import Span from ddtrace import config from ddtrace.constants import ANALYTICS_SAMPLE_RATE_KEY @@ -5,8 +10,10 @@ from ddtrace.constants import SPAN_MEASURED_KEY from ddtrace.ext import SpanKind from ddtrace.ext import aws +from ddtrace.ext import http from ddtrace.internal.constants import COMPONENT from ddtrace.internal.utils.formats import deep_getattr +from ddtrace.propagation.http import HTTPPropagator def set_botocore_patched_api_call_span_tags(span: Span, instance, args, params, endpoint_name, operation): @@ -39,3 +46,37 @@ def set_botocore_patched_api_call_span_tags(span: Span, instance, args, params, # set analytics sample rate span.set_tag(ANALYTICS_SAMPLE_RATE_KEY, config.botocore.get_analytics_sample_rate()) + + +def set_botocore_response_metadata_tags( + span: Span, result: Dict[str, Any], is_error_code_fn: Optional[Callable] = None +) -> None: + if not result or not result.get("ResponseMetadata"): + return + response_meta = result["ResponseMetadata"] + + if "HTTPStatusCode" in response_meta: + status_code = response_meta["HTTPStatusCode"] + span.set_tag(http.STATUS_CODE, status_code) + + # Mark this span as an error if requested + if is_error_code_fn is not None and is_error_code_fn(int(status_code)): + span.error = 1 + + if "RetryAttempts" in response_meta: + span.set_tag("retry_attempts", response_meta["RetryAttempts"]) + + if "RequestId" in response_meta: + span.set_tag_str("aws.requestid", response_meta["RequestId"]) + + +def extract_DD_context_from_messages(messages, extract_from_message: Callable): + ctx = None + if len(messages) >= 1: + message = messages[0] + context_json = extract_from_message(message) + if context_json is not None: + child_of = HTTPPropagator.extract(context_json) + if child_of.trace_id is not None: + ctx = child_of + return ctx diff --git a/ddtrace/appsec/_asm_request_context.py b/ddtrace/appsec/_asm_request_context.py index 654e06a29e5..ec88464cabe 100644 --- a/ddtrace/appsec/_asm_request_context.py +++ b/ddtrace/appsec/_asm_request_context.py @@ -20,6 +20,7 @@ from ddtrace.appsec._iast._utils import _is_iast_enabled from ddtrace.appsec._utils import get_triggers from ddtrace.internal import core +from ddtrace.internal._exceptions import BlockingException from ddtrace.internal.constants import REQUEST_PATH_PARAMS from ddtrace.internal.logger import get_logger from ddtrace.settings.asm import config as asm_config @@ -140,6 +141,7 @@ def __init__(self): env = ASM_Environment(True) self._id = _DataHandler.main_id + self._root = not in_context() self.active = True self.execution_context = core.ExecutionContext(__name__, **{"asm_env": env}) @@ -393,6 +395,12 @@ def asm_request_context_manager( if resources is not None: try: yield resources + except BlockingException as e: + # ensure that the BlockingRequest that is never raised outside a context + # is also never propagated outside the context + core.set_item(WAF_CONTEXT_NAMES.BLOCKED, e.args[0]) + if not resources._root: + raise finally: _end_context(resources) else: diff --git a/ddtrace/appsec/_common_module_patches.py b/ddtrace/appsec/_common_module_patches.py index 69c2610cab5..71d2fa59b5b 100644 --- a/ddtrace/appsec/_common_module_patches.py +++ b/ddtrace/appsec/_common_module_patches.py @@ -8,7 +8,9 @@ from typing import Callable from typing import Dict +from ddtrace.appsec._constants import WAF_CONTEXT_NAMES from ddtrace.internal import core +from ddtrace.internal._exceptions import BlockingException from ddtrace.internal.logger import get_logger from ddtrace.settings.asm import config as asm_config from ddtrace.vendor.wrapt import FunctionWrapper @@ -49,6 +51,7 @@ def wrapped_open_CFDDB7ABBA9081B6(original_open_callable, instance, args, kwargs try: from ddtrace.appsec._asm_request_context import call_waf_callback from ddtrace.appsec._asm_request_context import in_context + from ddtrace.appsec._asm_request_context import is_blocked from ddtrace.appsec._constants import EXPLOIT_PREVENTION except ImportError: # open is used during module initialization @@ -66,7 +69,9 @@ def wrapped_open_CFDDB7ABBA9081B6(original_open_callable, instance, args, kwargs crop_trace="wrapped_open_CFDDB7ABBA9081B6", rule_type=EXPLOIT_PREVENTION.TYPE.LFI, ) - # DEV: Next part of the exploit prevention feature: add block here + if is_blocked(): + raise BlockingException(core.get_item(WAF_CONTEXT_NAMES.BLOCKED), "exploit_prevention", "lfi", filename) + return original_open_callable(*args, **kwargs) @@ -82,6 +87,7 @@ def wrapped_open_ED4CF71136E15EBF(original_open_callable, instance, args, kwargs try: from ddtrace.appsec._asm_request_context import call_waf_callback from ddtrace.appsec._asm_request_context import in_context + from ddtrace.appsec._asm_request_context import is_blocked from ddtrace.appsec._constants import EXPLOIT_PREVENTION except ImportError: # open is used during module initialization @@ -98,7 +104,8 @@ def wrapped_open_ED4CF71136E15EBF(original_open_callable, instance, args, kwargs crop_trace="wrapped_open_ED4CF71136E15EBF", rule_type=EXPLOIT_PREVENTION.TYPE.SSRF, ) - # DEV: Next part of the exploit prevention feature: add block here + if is_blocked(): + raise BlockingException(core.get_item(WAF_CONTEXT_NAMES.BLOCKED), "exploit_prevention", "ssrf", url) return original_open_callable(*args, **kwargs) @@ -115,6 +122,7 @@ def wrapped_request_D8CB81E472AF98A2(original_request_callable, instance, args, try: from ddtrace.appsec._asm_request_context import call_waf_callback from ddtrace.appsec._asm_request_context import in_context + from ddtrace.appsec._asm_request_context import is_blocked from ddtrace.appsec._constants import EXPLOIT_PREVENTION except ImportError: # open is used during module initialization @@ -129,7 +137,9 @@ def wrapped_request_D8CB81E472AF98A2(original_request_callable, instance, args, crop_trace="wrapped_request_D8CB81E472AF98A2", rule_type=EXPLOIT_PREVENTION.TYPE.SSRF, ) - # DEV: Next part of the exploit prevention feature: add block here + if is_blocked(): + raise BlockingException(core.get_item(WAF_CONTEXT_NAMES.BLOCKED), "exploit_prevention", "ssrf", url) + return original_request_callable(*args, **kwargs) diff --git a/ddtrace/contrib/asgi/middleware.py b/ddtrace/contrib/asgi/middleware.py index 70388af0de5..21061cf63fe 100644 --- a/ddtrace/contrib/asgi/middleware.py +++ b/ddtrace/contrib/asgi/middleware.py @@ -13,6 +13,7 @@ from ddtrace.ext import SpanKind from ddtrace.ext import SpanTypes from ddtrace.ext import http +from ddtrace.internal._exceptions import BlockingException from ddtrace.internal.compat import is_valid_ip from ddtrace.internal.constants import COMPONENT from ddtrace.internal.constants import HTTP_REQUEST_BLOCKED @@ -288,6 +289,9 @@ async def wrapped_blocked_send(message): try: core.dispatch("asgi.start_request", ("asgi",)) return await self.app(scope, receive, wrapped_send) + except BlockingException as e: + core.set_item(HTTP_REQUEST_BLOCKED, e.args[0]) + return await _blocked_asgi_app(scope, receive, wrapped_blocked_send) except trace_utils.InterruptException: return await _blocked_asgi_app(scope, receive, wrapped_blocked_send) except Exception as exc: diff --git a/ddtrace/contrib/botocore/patch.py b/ddtrace/contrib/botocore/patch.py index e0bcc3f317f..b4f1a5265ea 100644 --- a/ddtrace/contrib/botocore/patch.py +++ b/ddtrace/contrib/botocore/patch.py @@ -39,9 +39,8 @@ from .services.sqs import update_messages as inject_trace_to_sqs_or_sns_message from .services.stepfunctions import patched_stepfunction_api_call from .services.stepfunctions import update_stepfunction_input -from .utils import inject_trace_to_client_context -from .utils import inject_trace_to_eventbridge_detail -from .utils import set_response_metadata_tags +from .utils import update_client_context +from .utils import update_eventbridge_detail _PATCHED_SUBMODULES = set() # type: Set[str] @@ -175,11 +174,11 @@ def prep_context_injection(ctx, endpoint_name, operation, trace_operation, param schematization_function = schematize_cloud_messaging_operation if endpoint_name == "lambda" and operation == "Invoke": - injection_function = inject_trace_to_client_context + injection_function = update_client_context schematization_function = schematize_cloud_faas_operation cloud_service = "lambda" if endpoint_name == "events" and operation == "PutEvents": - injection_function = inject_trace_to_eventbridge_detail + injection_function = update_eventbridge_detail cloud_service = "events" if endpoint_name == "sns" and "Publish" in operation: injection_function = inject_trace_to_sqs_or_sns_message @@ -224,9 +223,14 @@ def patched_api_call_fallback(original_func, instance, args, kwargs, function_va except botocore.exceptions.ClientError as e: core.dispatch( "botocore.patched_api_call.exception", - [ctx, e.response, botocore.exceptions.ClientError, set_response_metadata_tags], + [ + ctx, + e.response, + botocore.exceptions.ClientError, + config.botocore.operations[ctx["instrumented_api_call"].resource].is_error_code, + ], ) raise else: - core.dispatch("botocore.patched_api_call.success", [ctx, result, set_response_metadata_tags]) + core.dispatch("botocore.patched_api_call.success", [ctx, result]) return result diff --git a/ddtrace/contrib/botocore/services/kinesis.py b/ddtrace/contrib/botocore/services/kinesis.py index 412f0b0c27f..858f011410f 100644 --- a/ddtrace/contrib/botocore/services/kinesis.py +++ b/ddtrace/contrib/botocore/services/kinesis.py @@ -17,9 +17,8 @@ from ....internal.logger import get_logger from ....internal.schema import schematize_cloud_messaging_operation from ....internal.schema import schematize_service_name -from ..utils import extract_DD_context +from ..utils import extract_DD_json from ..utils import get_kinesis_data_object -from ..utils import set_response_metadata_tags log = get_logger(__name__) @@ -74,13 +73,14 @@ def patched_kinesis_api_call(original_func, instance, args, kwargs, function_var endpoint_name = function_vars.get("endpoint_name") operation = function_vars.get("operation") - message_received = False is_getrecords_call = False getrecords_error = None - child_of = None start_ns = None result = None + parent_ctx: core.ExecutionContext = core.ExecutionContext( + "botocore.patched_sqs_api_call.propagated", + ) if operation == "GetRecords": try: start_ns = time_ns() @@ -95,15 +95,20 @@ def patched_kinesis_api_call(original_func, instance, args, kwargs, function_var time_estimate = record.get("ApproximateArrivalTimestamp", datetime.now()).timestamp() core.dispatch( f"botocore.{endpoint_name}.{operation}.post", - [params, time_estimate, data_obj.get("_datadog"), record], + [ + parent_ctx, + params, + time_estimate, + data_obj.get("_datadog"), + record, + result, + config.botocore.propagation_enabled, + extract_DD_json, + ], ) except Exception as e: getrecords_error = e - if result is not None and "Records" in result and len(result["Records"]) >= 1: - message_received = True - if config.botocore.propagation_enabled: - child_of = extract_DD_context(result["Records"]) if endpoint_name == "kinesis" and operation in {"PutRecord", "PutRecords"}: span_name = schematize_cloud_messaging_operation( @@ -116,7 +121,7 @@ def patched_kinesis_api_call(original_func, instance, args, kwargs, function_var span_name = trace_operation stream_arn = params.get("StreamARN", params.get("StreamName", "")) function_is_not_getrecords = not is_getrecords_call - received_message_when_polling = is_getrecords_call and message_received + received_message_when_polling = is_getrecords_call and parent_ctx.get_item("message_received") instrument_empty_poll_calls = config.botocore.empty_poll_enabled should_instrument = ( received_message_when_polling or instrument_empty_poll_calls or function_is_not_getrecords or getrecords_error @@ -126,6 +131,7 @@ def patched_kinesis_api_call(original_func, instance, args, kwargs, function_var if should_instrument: with core.context_with_data( "botocore.patched_kinesis_api_call", + parent=parent_ctx, instance=instance, args=args, params=params, @@ -136,7 +142,6 @@ def patched_kinesis_api_call(original_func, instance, args, kwargs, function_var pin=pin, span_name=span_name, span_type=SpanTypes.HTTP, - child_of=child_of if child_of is not None else pin.tracer.context_provider.active(), activate=True, func_run=is_getrecords_call, start_ns=start_ns, @@ -158,15 +163,21 @@ def patched_kinesis_api_call(original_func, instance, args, kwargs, function_var if getrecords_error: raise getrecords_error - core.dispatch("botocore.patched_kinesis_api_call.success", [ctx, result, set_response_metadata_tags]) + core.dispatch("botocore.patched_kinesis_api_call.success", [ctx, result]) return result except botocore.exceptions.ClientError as e: core.dispatch( "botocore.patched_kinesis_api_call.exception", - [ctx, e.response, botocore.exceptions.ClientError, set_response_metadata_tags], + [ + ctx, + e.response, + botocore.exceptions.ClientError, + config.botocore.operations[ctx[ctx["call_key"]].resource].is_error_code, + ], ) raise + parent_ctx.end() elif is_getrecords_call: if getrecords_error: raise getrecords_error diff --git a/ddtrace/contrib/botocore/services/sqs.py b/ddtrace/contrib/botocore/services/sqs.py index 37080c85d70..25de175853a 100644 --- a/ddtrace/contrib/botocore/services/sqs.py +++ b/ddtrace/contrib/botocore/services/sqs.py @@ -7,8 +7,6 @@ import botocore.exceptions from ddtrace import config -from ddtrace.contrib.botocore.utils import extract_DD_context -from ddtrace.contrib.botocore.utils import set_response_metadata_tags from ddtrace.ext import SpanTypes from ddtrace.internal import core from ddtrace.internal.logger import get_logger @@ -16,6 +14,8 @@ from ddtrace.internal.schema import schematize_service_name from ddtrace.internal.schema.span_attribute_schema import SpanDirection +from ..utils import extract_DD_json + log = get_logger(__name__) MAX_INJECTION_DATA_ATTRIBUTES = 10 @@ -83,16 +83,19 @@ def _ensure_datadog_messageattribute_enabled(params): def patched_sqs_api_call(original_func, instance, args, kwargs, function_vars): + with core.context_with_data("botocore.patched_sqs_api_call.propagated") as parent_ctx: + return _patched_sqs_api_call(parent_ctx, original_func, instance, args, kwargs, function_vars) + + +def _patched_sqs_api_call(parent_ctx, original_func, instance, args, kwargs, function_vars): params = function_vars.get("params") trace_operation = function_vars.get("trace_operation") pin = function_vars.get("pin") endpoint_name = function_vars.get("endpoint_name") operation = function_vars.get("operation") - message_received = False func_has_run = False func_run_err = None - child_of = None result = None if operation == "ReceiveMessage": @@ -103,16 +106,15 @@ def patched_sqs_api_call(original_func, instance, args, kwargs, function_vars): core.dispatch(f"botocore.{endpoint_name}.{operation}.pre", [params]) # run the function to extract possible parent context before creating ExecutionContext result = original_func(*args, **kwargs) - core.dispatch(f"botocore.{endpoint_name}.{operation}.post", [params, result]) + core.dispatch( + f"botocore.{endpoint_name}.{operation}.post", + [parent_ctx, params, result, config.botocore.propagation_enabled, extract_DD_json], + ) except Exception as e: func_run_err = e - if result is not None and "Messages" in result and len(result["Messages"]) >= 1: - message_received = True - if config.botocore.propagation_enabled: - child_of = extract_DD_context(result["Messages"]) function_is_not_recvmessage = not func_has_run - received_message_when_polling = func_has_run and message_received + received_message_when_polling = func_has_run and parent_ctx.get_item("message_received") instrument_empty_poll_calls = config.botocore.empty_poll_enabled should_instrument = ( received_message_when_polling or instrument_empty_poll_calls or function_is_not_recvmessage or func_run_err @@ -133,9 +135,12 @@ def patched_sqs_api_call(original_func, instance, args, kwargs, function_vars): else: call_name = trace_operation + child_of = parent_ctx.get_item("distributed_context") + if should_instrument: with core.context_with_data( "botocore.patched_sqs_api_call", + parent=parent_ctx, span_name=call_name, service=schematize_service_name("{}.{}".format(pin.service, endpoint_name)), span_type=SpanTypes.HTTP, @@ -161,7 +166,7 @@ def patched_sqs_api_call(original_func, instance, args, kwargs, function_vars): result = original_func(*args, **kwargs) core.dispatch(f"botocore.{endpoint_name}.{operation}.post", [params, result]) - core.dispatch("botocore.patched_sqs_api_call.success", [ctx, result, set_response_metadata_tags]) + core.dispatch("botocore.patched_sqs_api_call.success", [ctx, result]) if func_run_err: raise func_run_err @@ -169,7 +174,12 @@ def patched_sqs_api_call(original_func, instance, args, kwargs, function_vars): except botocore.exceptions.ClientError as e: core.dispatch( "botocore.patched_sqs_api_call.exception", - [ctx, e.response, botocore.exceptions.ClientError, set_response_metadata_tags], + [ + ctx, + e.response, + botocore.exceptions.ClientError, + config.botocore.operations[ctx[ctx["call_key"]].resource].is_error_code, + ], ) raise elif func_has_run: diff --git a/ddtrace/contrib/botocore/services/stepfunctions.py b/ddtrace/contrib/botocore/services/stepfunctions.py index d611f664a48..16213f2e3ed 100644 --- a/ddtrace/contrib/botocore/services/stepfunctions.py +++ b/ddtrace/contrib/botocore/services/stepfunctions.py @@ -12,7 +12,6 @@ from ....internal.schema import SpanDirection from ....internal.schema import schematize_cloud_messaging_operation from ....internal.schema import schematize_service_name -from ..utils import set_response_metadata_tags log = get_logger(__name__) @@ -81,6 +80,11 @@ def patched_stepfunction_api_call(original_func, instance, args, kwargs: Dict, f except botocore.exceptions.ClientError as e: core.dispatch( "botocore.patched_stepfunctions_api_call.exception", - [ctx, e.response, botocore.exceptions.ClientError, set_response_metadata_tags], + [ + ctx, + e.response, + botocore.exceptions.ClientError, + config.botocore.operations[ctx[ctx["call_key"]].resource].is_error_code, + ], ) raise diff --git a/ddtrace/contrib/botocore/utils.py b/ddtrace/contrib/botocore/utils.py index ead47ace10c..5804a4e1a36 100644 --- a/ddtrace/contrib/botocore/utils.py +++ b/ddtrace/contrib/botocore/utils.py @@ -8,13 +8,11 @@ from typing import Optional from typing import Tuple -from ddtrace import Span from ddtrace import config +from ddtrace.internal import core from ddtrace.internal.core import ExecutionContext -from ...ext import http from ...internal.logger import get_logger -from ...propagation.http import HTTPPropagator log = get_logger(__name__) @@ -66,11 +64,7 @@ def get_kinesis_data_object(data: str) -> Tuple[str, Optional[Dict[str, Any]]]: return None, None -def inject_trace_to_eventbridge_detail(ctx: ExecutionContext) -> None: - """ - Inject trace headers into the EventBridge record if the record's Detail object contains a JSON string - Max size per event is 256KB (https://docs.aws.amazon.com/eventbridge/latest/userguide/eb-putevent-size.html) - """ +def update_eventbridge_detail(ctx: ExecutionContext) -> None: params = ctx["params"] if "Entries" not in params: log.warning("Unable to inject context. The Event Bridge event had no Entries.") @@ -86,8 +80,7 @@ def inject_trace_to_eventbridge_detail(ctx: ExecutionContext) -> None: continue detail["_datadog"] = {} - span = ctx[ctx["call_key"]] - HTTPPropagator.inject(span.context, detail["_datadog"]) + core.dispatch("botocore.eventbridge.update_messages", [ctx, None, None, detail["_datadog"], None]) detail_json = json.dumps(detail) # check if detail size will exceed max size with headers @@ -99,12 +92,11 @@ def inject_trace_to_eventbridge_detail(ctx: ExecutionContext) -> None: entry["Detail"] = detail_json -def inject_trace_to_client_context(ctx): +def update_client_context(ctx: ExecutionContext) -> None: trace_headers = {} - span = ctx[ctx["call_key"]] - params = ctx["params"] - HTTPPropagator.inject(span.context, trace_headers) + core.dispatch("botocore.client_context.update_messages", [ctx, None, None, trace_headers, None]) client_context_object = {} + params = ctx["params"] if "ClientContext" in params: try: client_context_json = base64.b64decode(params["ClientContext"]).decode("utf-8") @@ -131,39 +123,7 @@ def modify_client_context(client_context_object, trace_headers): client_context_object["custom"] = trace_headers -def set_response_metadata_tags(span: Span, result: Dict[str, Any]) -> None: - if not result or not result.get("ResponseMetadata"): - return - response_meta = result["ResponseMetadata"] - - if "HTTPStatusCode" in response_meta: - status_code = response_meta["HTTPStatusCode"] - span.set_tag(http.STATUS_CODE, status_code) - - # Mark this span as an error if requested - if config.botocore.operations[span.resource].is_error_code(int(status_code)): - span.error = 1 - - if "RetryAttempts" in response_meta: - span.set_tag("retry_attempts", response_meta["RetryAttempts"]) - - if "RequestId" in response_meta: - span.set_tag_str("aws.requestid", response_meta["RequestId"]) - - -def extract_DD_context(messages): - ctx = None - if len(messages) >= 1: - message = messages[0] - context_json = extract_trace_context_json(message) - if context_json is not None: - child_of = HTTPPropagator.extract(context_json) - if child_of.trace_id is not None: - ctx = child_of - return ctx - - -def extract_trace_context_json(message): +def extract_DD_json(message): context_json = None try: if message and message.get("Type") == "Notification": @@ -200,7 +160,7 @@ def extract_trace_context_json(message): if "Body" in message: try: body = json.loads(message["Body"]) - return extract_trace_context_json(body) + return extract_DD_json(body) except ValueError: log.debug("Unable to parse AWS message body.") except Exception: diff --git a/ddtrace/contrib/django/patch.py b/ddtrace/contrib/django/patch.py index 0f4e2318c89..670e94fe1ba 100644 --- a/ddtrace/contrib/django/patch.py +++ b/ddtrace/contrib/django/patch.py @@ -22,6 +22,7 @@ from ddtrace.ext import http from ddtrace.ext import sql as sqlx from ddtrace.internal import core +from ddtrace.internal._exceptions import BlockingException from ddtrace.internal.compat import Iterable from ddtrace.internal.compat import maybe_stringify from ddtrace.internal.constants import COMPONENT @@ -467,7 +468,7 @@ def traced_get_response(django, pin, func, instance, args, kwargs): def blocked_response(): from django.http import HttpResponse - block_config = core.get_item(HTTP_REQUEST_BLOCKED) + block_config = core.get_item(HTTP_REQUEST_BLOCKED) or {} desired_type = block_config.get("type", "auto") status = block_config.get("status_code", 403) if desired_type == "none": @@ -510,7 +511,12 @@ def blocked_response(): response = blocked_response() return response - response = func(*args, **kwargs) + try: + response = func(*args, **kwargs) + except BlockingException as e: + core.set_item(HTTP_REQUEST_BLOCKED, e.args[0]) + response = blocked_response() + return response if core.get_item(HTTP_REQUEST_BLOCKED): response = blocked_response() diff --git a/ddtrace/contrib/trace_utils_redis.py b/ddtrace/contrib/trace_utils_redis.py new file mode 100644 index 00000000000..8df16c3ce4d --- /dev/null +++ b/ddtrace/contrib/trace_utils_redis.py @@ -0,0 +1,18 @@ +from ddtrace.contrib.redis_utils import determine_row_count +from ddtrace.contrib.redis_utils import stringify_cache_args +from ddtrace.internal.utils.deprecations import DDTraceDeprecationWarning +from ddtrace.vendor.debtcollector import deprecate + + +deprecate( + "The ddtrace.contrib.trace_utils_redis module is deprecated and will be removed.", + message="A new interface will be provided by the ddtrace.contrib.redis_utils module", + category=DDTraceDeprecationWarning, +) + + +format_command_args = stringify_cache_args + + +def determine_row_count(redis_command, span, result): # noqa: F811 + determine_row_count(redis_command=redis_command, result=result) diff --git a/ddtrace/contrib/wsgi/wsgi.py b/ddtrace/contrib/wsgi/wsgi.py index 1714bdfa1a1..aff74e3b0a0 100644 --- a/ddtrace/contrib/wsgi/wsgi.py +++ b/ddtrace/contrib/wsgi/wsgi.py @@ -24,6 +24,7 @@ from ddtrace.contrib import trace_utils from ddtrace.ext import SpanKind from ddtrace.ext import SpanTypes +from ddtrace.internal._exceptions import BlockingException from ddtrace.internal.constants import COMPONENT from ddtrace.internal.constants import HTTP_REQUEST_BLOCKED from ddtrace.internal.logger import get_logger @@ -109,15 +110,6 @@ def __call__(self, environ: Iterable, start_response: Callable) -> wrapt.ObjectP call_key="req_span", ) as ctx: ctx.set_item("wsgi.construct_url", construct_url) - if core.get_item(HTTP_REQUEST_BLOCKED): - result = core.dispatch_with_results("wsgi.block.started", (ctx, construct_url)).status_headers_content - if result: - status, headers, content = result.value - else: - status, headers, content = 403, [], "" - start_response(str(status), headers) - closing_iterable = [content] - not_blocked = False def blocked_view(): result = core.dispatch_with_results("wsgi.block.started", (ctx, construct_url)).status_headers_content @@ -127,12 +119,24 @@ def blocked_view(): status, headers, content = 403, [], "" return content, status, headers + if core.get_item(HTTP_REQUEST_BLOCKED): + content, status, headers = blocked_view() + start_response(str(status), headers) + closing_iterable = [content] + not_blocked = False + core.dispatch("wsgi.block_decided", (blocked_view,)) if not_blocked: core.dispatch("wsgi.request.prepare", (ctx, start_response)) try: closing_iterable = self.app(environ, ctx.get_item("intercept_start_response")) + except BlockingException as e: + core.set_item(HTTP_REQUEST_BLOCKED, e.args[0]) + content, status, headers = blocked_view() + start_response(str(status), headers) + closing_iterable = [content] + core.dispatch("wsgi.app.exception", (ctx,)) except BaseException: core.dispatch("wsgi.app.exception", (ctx,)) raise diff --git a/ddtrace/internal/_exceptions.py b/ddtrace/internal/_exceptions.py new file mode 100644 index 00000000000..01e45d2b063 --- /dev/null +++ b/ddtrace/internal/_exceptions.py @@ -0,0 +1,5 @@ +class BlockingException(BaseException): + """ + Exception raised when a request is blocked by ASM + It derives from BaseException to avoid being caught by the general Exception handler + """ diff --git a/ddtrace/internal/datastreams/botocore.py b/ddtrace/internal/datastreams/botocore.py index 1f1b79aee80..ec004f1ff9a 100644 --- a/ddtrace/internal/datastreams/botocore.py +++ b/ddtrace/internal/datastreams/botocore.py @@ -172,7 +172,7 @@ def get_datastreams_context(message): return context_json -def handle_sqs_receive(params, result): +def handle_sqs_receive(_, params, result, *args): from . import data_streams_processor as processor queue_name = get_queue_name(params) @@ -206,7 +206,7 @@ def record_data_streams_path_for_kinesis_stream(params, time_estimate, context_j ) -def handle_kinesis_receive(params, time_estimate, context_json, record): +def handle_kinesis_receive(_, params, time_estimate, context_json, record, *args): try: record_data_streams_path_for_kinesis_stream(params, time_estimate, context_json, record) except Exception: diff --git a/ddtrace/internal/packages.py b/ddtrace/internal/packages.py index 2d8f1c5fd1e..fcec01a463b 100644 --- a/ddtrace/internal/packages.py +++ b/ddtrace/internal/packages.py @@ -238,6 +238,11 @@ def is_third_party(path: Path) -> bool: return package.name in _third_party_packages() +@cached() +def is_user_code(path: Path) -> bool: + return not (is_stdlib(path) or is_third_party(path)) + + @cached() def is_distribution_available(name: str) -> bool: """Determine if a distribution is available in the current environment.""" diff --git a/ddtrace/internal/symbol_db/symbols.py b/ddtrace/internal/symbol_db/symbols.py index d454e9eb8f5..9f66ffa3a86 100644 --- a/ddtrace/internal/symbol_db/symbols.py +++ b/ddtrace/internal/symbol_db/symbols.py @@ -3,7 +3,7 @@ from dataclasses import field import dis from enum import Enum -import http +from http.client import HTTPResponse from inspect import CO_VARARGS from inspect import CO_VARKEYWORDS from inspect import isasyncgenfunction @@ -31,7 +31,6 @@ from ddtrace.internal.logger import get_logger from ddtrace.internal.module import BaseModuleWatchdog from ddtrace.internal.module import origin -from ddtrace.internal.packages import is_stdlib from ddtrace.internal.runtime import get_runtime_id from ddtrace.internal.safety import _isinstance from ddtrace.internal.utils.cache import cached @@ -50,10 +49,10 @@ @cached() -def is_from_stdlib(obj: t.Any) -> t.Optional[bool]: +def is_from_user_code(obj: t.Any) -> t.Optional[bool]: try: path = origin(sys.modules[object.__getattribute__(obj, "__module__")]) - return is_stdlib(path) if path is not None else None + return packages.is_user_code(path) if path is not None else None except (AttributeError, KeyError): return None @@ -182,9 +181,6 @@ def _(cls, module: ModuleType, data: ScopeData): symbols = [] scopes = [] - if is_stdlib(module_origin): - return None - for alias, child in object.__getattribute__(module, "__dict__").items(): if _isinstance(child, ModuleType): # We don't want to traverse other modules. @@ -224,7 +220,7 @@ def _(cls, obj: type, data: ScopeData): return None data.seen.add(obj) - if is_from_stdlib(obj): + if not is_from_user_code(obj): return None symbols = [] @@ -347,7 +343,7 @@ def _(cls, f: FunctionType, data: ScopeData): return None data.seen.add(f) - if is_from_stdlib(f): + if not is_from_user_code(f): return None code = f.__dd_wrapped__.__code__ if hasattr(f, "__dd_wrapped__") else f.__code__ @@ -416,7 +412,7 @@ def _(cls, pr: property, data: ScopeData): data.seen.add(pr.fget) # TODO: These names don't match what is reported by the discovery. - if pr.fget is None or is_from_stdlib(pr.fget): + if pr.fget is None or not is_from_user_code(pr.fget): return None path = func_origin(t.cast(FunctionType, pr.fget)) @@ -477,7 +473,7 @@ def to_json(self) -> dict: "scopes": [_.to_json() for _ in self._scopes], } - def upload(self) -> http.client.HTTPResponse: + def upload(self) -> HTTPResponse: body, headers = multipart( parts=[ FormData( @@ -509,14 +505,24 @@ def __len__(self) -> int: def is_module_included(module: ModuleType) -> bool: + # Check if module name matches the include patterns if symdb_config._includes_re.match(module.__name__): return True - package = packages.module_to_package(module) - if package is None: + # Check if it is user code + module_origin = origin(module) + if module_origin is None: return False - return symdb_config._includes_re.match(package.name) is not None + if packages.is_user_code(module_origin): + return True + + # Check if the package name matches the include patterns + package = packages.filename_to_package(module_origin) + if package is not None and symdb_config._includes_re.match(package.name): + return True + + return False class SymbolDatabaseUploader(BaseModuleWatchdog): diff --git a/ddtrace/internal/telemetry/writer.py b/ddtrace/internal/telemetry/writer.py index 836daa5da74..8c9ebf6f1f7 100644 --- a/ddtrace/internal/telemetry/writer.py +++ b/ddtrace/internal/telemetry/writer.py @@ -422,6 +422,14 @@ def _app_started_event(self, register_app_shutdown=True): if register_app_shutdown: atexit.register(self.app_shutdown) + inst_config_id_entry = ("instrumentation_config_id", "", "default") + if "DD_INSTRUMENTATION_CONFIG_ID" in os.environ: + inst_config_id_entry = ( + "instrumentation_config_id", + os.environ["DD_INSTRUMENTATION_CONFIG_ID"], + "env_var", + ) + self.add_configurations( [ self._telemetry_entry("_trace_enabled"), @@ -435,6 +443,7 @@ def _app_started_event(self, register_app_shutdown=True): self._telemetry_entry("trace_http_header_tags"), self._telemetry_entry("tags"), self._telemetry_entry("_tracing_enabled"), + inst_config_id_entry, (TELEMETRY_STARTUP_LOGS_ENABLED, config._startup_logs_enabled, "unknown"), (TELEMETRY_DYNAMIC_INSTRUMENTATION_ENABLED, di_config.enabled, "unknown"), (TELEMETRY_EXCEPTION_DEBUGGING_ENABLED, ed_config.enabled, "unknown"), diff --git a/ddtrace/llmobs/_constants.py b/ddtrace/llmobs/_constants.py index fa92a3ed566..9d04fa68cbf 100644 --- a/ddtrace/llmobs/_constants.py +++ b/ddtrace/llmobs/_constants.py @@ -8,9 +8,11 @@ MODEL_NAME = "_ml_obs.meta.model_name" MODEL_PROVIDER = "_ml_obs.meta.model_provider" +INPUT_DOCUMENTS = "_ml_obs.meta.input.documents" INPUT_MESSAGES = "_ml_obs.meta.input.messages" INPUT_VALUE = "_ml_obs.meta.input.value" INPUT_PARAMETERS = "_ml_obs.meta.input.parameters" +OUTPUT_DOCUMENTS = "_ml_obs.meta.output.documents" OUTPUT_MESSAGES = "_ml_obs.meta.output.messages" OUTPUT_VALUE = "_ml_obs.meta.output.value" diff --git a/ddtrace/llmobs/_llmobs.py b/ddtrace/llmobs/_llmobs.py index 411c68e84af..d72aa983fe5 100644 --- a/ddtrace/llmobs/_llmobs.py +++ b/ddtrace/llmobs/_llmobs.py @@ -12,6 +12,7 @@ from ddtrace.internal import atexit from ddtrace.internal.logger import get_logger from ddtrace.internal.service import Service +from ddtrace.llmobs._constants import INPUT_DOCUMENTS from ddtrace.llmobs._constants import INPUT_MESSAGES from ddtrace.llmobs._constants import INPUT_PARAMETERS from ddtrace.llmobs._constants import INPUT_VALUE @@ -20,6 +21,7 @@ from ddtrace.llmobs._constants import ML_APP from ddtrace.llmobs._constants import MODEL_NAME from ddtrace.llmobs._constants import MODEL_PROVIDER +from ddtrace.llmobs._constants import OUTPUT_DOCUMENTS from ddtrace.llmobs._constants import OUTPUT_MESSAGES from ddtrace.llmobs._constants import OUTPUT_VALUE from ddtrace.llmobs._constants import SESSION_ID @@ -30,6 +32,7 @@ from ddtrace.llmobs._utils import _get_session_id from ddtrace.llmobs._writer import LLMObsEvalMetricWriter from ddtrace.llmobs._writer import LLMObsSpanWriter +from ddtrace.llmobs.utils import Documents from ddtrace.llmobs.utils import ExportedLLMObsSpan from ddtrace.llmobs.utils import Messages @@ -270,6 +273,64 @@ def workflow( return None return cls._instance._start_span("workflow", name=name, session_id=session_id, ml_app=ml_app) + @classmethod + def embedding( + cls, + model_name: str, + name: Optional[str] = None, + model_provider: Optional[str] = None, + session_id: Optional[str] = None, + ml_app: Optional[str] = None, + ) -> Optional[Span]: + """ + Trace a call to an embedding model or function to create an embedding. + + :param str model_name: The name of the invoked embedding model. + :param str name: The name of the traced operation. If not provided, a default value of "embedding" will be set. + :param str model_provider: The name of the invoked LLM provider (ex: openai, bedrock). + If not provided, a default value of "custom" will be set. + :param str session_id: The ID of the underlying user session. Required for tracking sessions. + :param str ml_app: The name of the ML application that the agent is orchestrating. If not provided, the default + value DD_LLMOBS_APP_NAME will be set. + + :returns: The Span object representing the traced operation. + """ + if cls.enabled is False or cls._instance is None: + log.warning("LLMObs.embedding() cannot be used while LLMObs is disabled.") + return None + if not model_name: + log.warning("model_name must be the specified name of the invoked model.") + return None + if model_provider is None: + model_provider = "custom" + return cls._instance._start_span( + "embedding", + name, + model_name=model_name, + model_provider=model_provider, + session_id=session_id, + ml_app=ml_app, + ) + + @classmethod + def retrieval( + cls, name: Optional[str] = None, session_id: Optional[str] = None, ml_app: Optional[str] = None + ) -> Optional[Span]: + """ + Trace a vector search operation involving a list of documents being returned from an external knowledge base. + + :param str name: The name of the traced operation. If not provided, a default value of "workflow" will be set. + :param str session_id: The ID of the underlying user session. Required for tracking sessions. + :param str ml_app: The name of the ML application that the agent is orchestrating. If not provided, the default + value DD_LLMOBS_APP_NAME will be set. + + :returns: The Span object representing the traced operation. + """ + if cls.enabled is False or cls._instance is None: + log.warning("LLMObs.retrieval() cannot be used while LLMObs is disabled.") + return None + return cls._instance._start_span("retrieval", name=name, session_id=session_id, ml_app=ml_app) + @classmethod def annotate( cls, @@ -290,10 +351,15 @@ def annotate( :param input_data: A single input string, dictionary, or a list of dictionaries based on the span kind: - llm spans: accepts a string, or a dictionary of form {"content": "...", "role": "..."}, or a list of dictionaries with the same signature. + - embedding spans: accepts a string, list of strings, or a dictionary of form + {"text": "...", ...} or a list of dictionaries with the same signature. - other: any JSON serializable type. :param output_data: A single output string, dictionary, or a list of dictionaries based on the span kind: - llm spans: accepts a string, or a dictionary of form {"content": "...", "role": "..."}, or a list of dictionaries with the same signature. + - retrieval spans: a dictionary containing any of the key value pairs + {"name": str, "id": str, "text": str, "score": float}, + or a list of dictionaries with the same signature. - other: any JSON serializable type. :param parameters: (DEPRECATED) Dictionary of JSON serializable key-value pairs to set as input parameters. :param metadata: Dictionary of JSON serializable key-value metadata pairs relevant to the input/output operation @@ -327,6 +393,10 @@ def annotate( if input_data or output_data: if span_kind == "llm": cls._tag_llm_io(span, input_messages=input_data, output_messages=output_data) + elif span_kind == "embedding": + cls._tag_embedding_io(span, input_documents=input_data, output_text=output_data) + elif span_kind == "retrieval": + cls._tag_retrieval_io(span, input_text=input_data, output_documents=output_data) else: cls._tag_text_io(span, input_value=input_data, output_value=output_data) if metadata is not None: @@ -371,6 +441,50 @@ def _tag_llm_io(cls, span, input_messages=None, output_messages=None): except (TypeError, AttributeError): log.warning("Failed to parse output messages.", exc_info=True) + @classmethod + def _tag_embedding_io(cls, span, input_documents=None, output_text=None): + """Tags input documents and output text for embedding-kind spans. + Will be mapped to span's `meta.{input,output}.text` fields. + """ + if input_documents is not None: + try: + if not isinstance(input_documents, Documents): + input_documents = Documents(input_documents) + if input_documents.documents: + span.set_tag_str(INPUT_DOCUMENTS, json.dumps(input_documents.documents)) + except (TypeError, AttributeError): + log.warning("Failed to parse input documents.", exc_info=True) + if output_text is not None: + if isinstance(output_text, str): + span.set_tag_str(OUTPUT_VALUE, output_text) + else: + try: + span.set_tag_str(OUTPUT_VALUE, json.dumps(output_text)) + except TypeError: + log.warning("Failed to parse output text. Output text must be JSON serializable.") + + @classmethod + def _tag_retrieval_io(cls, span, input_text=None, output_documents=None): + """Tags input text and output documents for retrieval-kind spans. + Will be mapped to span's `meta.{input,output}.text` fields. + """ + if input_text is not None: + if isinstance(input_text, str): + span.set_tag_str(INPUT_VALUE, input_text) + else: + try: + span.set_tag_str(INPUT_VALUE, json.dumps(input_text)) + except TypeError: + log.warning("Failed to parse input text. Input text must be JSON serializable.") + if output_documents is not None: + try: + if not isinstance(output_documents, Documents): + output_documents = Documents(output_documents) + if output_documents.documents: + span.set_tag_str(OUTPUT_DOCUMENTS, json.dumps(output_documents.documents)) + except (TypeError, AttributeError): + log.warning("Failed to parse output documents.", exc_info=True) + @classmethod def _tag_text_io(cls, span, input_value=None, output_value=None): """Tags input/output values for non-LLM kind spans. diff --git a/ddtrace/llmobs/_trace_processor.py b/ddtrace/llmobs/_trace_processor.py index f95b2637be0..ac07cf1d484 100644 --- a/ddtrace/llmobs/_trace_processor.py +++ b/ddtrace/llmobs/_trace_processor.py @@ -15,6 +15,7 @@ from ddtrace.ext import SpanTypes from ddtrace.internal.logger import get_logger from ddtrace.internal.utils.formats import asbool +from ddtrace.llmobs._constants import INPUT_DOCUMENTS from ddtrace.llmobs._constants import INPUT_MESSAGES from ddtrace.llmobs._constants import INPUT_PARAMETERS from ddtrace.llmobs._constants import INPUT_VALUE @@ -23,6 +24,7 @@ from ddtrace.llmobs._constants import ML_APP from ddtrace.llmobs._constants import MODEL_NAME from ddtrace.llmobs._constants import MODEL_PROVIDER +from ddtrace.llmobs._constants import OUTPUT_DOCUMENTS from ddtrace.llmobs._constants import OUTPUT_MESSAGES from ddtrace.llmobs._constants import OUTPUT_VALUE from ddtrace.llmobs._constants import SESSION_ID @@ -65,7 +67,7 @@ def _llmobs_span_event(self, span: Span) -> Dict[str, Any]: """Span event object structure.""" span_kind = span._meta.pop(SPAN_KIND) meta: Dict[str, Any] = {"span.kind": span_kind, "input": {}, "output": {}} - if span_kind == "llm" and span.get_tag(MODEL_NAME) is not None: + if span_kind in ("llm", "embedding") and span.get_tag(MODEL_NAME) is not None: meta["model_name"] = span._meta.pop(MODEL_NAME) meta["model_provider"] = span._meta.pop(MODEL_PROVIDER, "custom").lower() if span.get_tag(METADATA) is not None: @@ -78,8 +80,12 @@ def _llmobs_span_event(self, span: Span) -> Dict[str, Any]: meta["input"]["value"] = span._meta.pop(INPUT_VALUE) if span_kind == "llm" and span.get_tag(OUTPUT_MESSAGES) is not None: meta["output"]["messages"] = json.loads(span._meta.pop(OUTPUT_MESSAGES)) + if span_kind == "embedding" and span.get_tag(INPUT_DOCUMENTS) is not None: + meta["input"]["documents"] = json.loads(span._meta.pop(INPUT_DOCUMENTS)) if span.get_tag(OUTPUT_VALUE) is not None: meta["output"]["value"] = span._meta.pop(OUTPUT_VALUE) + if span_kind == "retrieval" and span.get_tag(OUTPUT_DOCUMENTS) is not None: + meta["output"]["documents"] = json.loads(span._meta.pop(OUTPUT_DOCUMENTS)) if span.error: meta[ERROR_MSG] = span.get_tag(ERROR_MSG) meta[ERROR_STACK] = span.get_tag(ERROR_STACK) diff --git a/ddtrace/llmobs/decorators.py b/ddtrace/llmobs/decorators.py index 1cb18620ea4..cdb9dd9762e 100644 --- a/ddtrace/llmobs/decorators.py +++ b/ddtrace/llmobs/decorators.py @@ -9,34 +9,42 @@ log = get_logger(__name__) -def llm( - model_name: str, - model_provider: Optional[str] = None, - name: Optional[str] = None, - session_id: Optional[str] = None, - ml_app: Optional[str] = None, -): - def inner(func): - @wraps(func) - def wrapper(*args, **kwargs): - if not LLMObs.enabled or LLMObs._instance is None: - log.warning("LLMObs.llm() cannot be used while LLMObs is disabled.") - return func(*args, **kwargs) - span_name = name - if span_name is None: - span_name = func.__name__ - with LLMObs.llm( - model_name=model_name, - model_provider=model_provider, - name=span_name, - session_id=session_id, - ml_app=ml_app, - ): - return func(*args, **kwargs) +def _model_decorator(operation_kind): + def decorator( + model_name: str, + original_func: Optional[Callable] = None, + model_provider: Optional[str] = None, + name: Optional[str] = None, + session_id: Optional[str] = None, + ml_app: Optional[str] = None, + ): + def inner(func): + @wraps(func) + def wrapper(*args, **kwargs): + if not LLMObs.enabled or LLMObs._instance is None: + log.warning("LLMObs.%s() cannot be used while LLMObs is disabled.", operation_kind) + return func(*args, **kwargs) + traced_model_name = model_name + if traced_model_name is None: + raise TypeError("model_name is required for LLMObs.{}()".format(operation_kind)) + span_name = name + if span_name is None: + span_name = func.__name__ + traced_operation = getattr(LLMObs, operation_kind, "llm") + with traced_operation( + model_name=model_name, + model_provider=model_provider, + name=span_name, + session_id=session_id, + ml_app=ml_app, + ): + return func(*args, **kwargs) - return wrapper + return wrapper + + return inner - return inner + return decorator def _llmobs_decorator(operation_kind): @@ -50,7 +58,7 @@ def inner(func): @wraps(func) def wrapper(*args, **kwargs): if not LLMObs.enabled or LLMObs._instance is None: - log.warning("LLMObs.{}() cannot be used while LLMObs is disabled.", operation_kind) + log.warning("LLMObs.%s() cannot be used while LLMObs is disabled.", operation_kind) return func(*args, **kwargs) span_name = name if span_name is None: @@ -68,7 +76,10 @@ def wrapper(*args, **kwargs): return decorator +llm = _model_decorator("llm") +embedding = _model_decorator("embedding") workflow = _llmobs_decorator("workflow") task = _llmobs_decorator("task") tool = _llmobs_decorator("tool") +retrieval = _llmobs_decorator("retrieval") agent = _llmobs_decorator("agent") diff --git a/ddtrace/llmobs/utils.py b/ddtrace/llmobs/utils.py index 1fbb7305c36..cbb1f97d4f6 100644 --- a/ddtrace/llmobs/utils.py +++ b/ddtrace/llmobs/utils.py @@ -16,6 +16,7 @@ ExportedLLMObsSpan = TypedDict("ExportedLLMObsSpan", {"span_id": str, "trace_id": str}) +Document = TypedDict("Document", {"name": str, "id": str, "text": str, "score": float}, total=False) Message = TypedDict("Message", {"content": str, "role": str}, total=False) @@ -40,3 +41,36 @@ def __init__(self, messages: Union[List[Dict[str, str]], Dict[str, str], str]): if not isinstance(role, str): raise TypeError("Message role must be a string, and one of .") self.messages.append(Message(content=content, role=role)) + + +class Documents: + def __init__(self, documents: Union[List[Dict[str, str]], Dict[str, str], str]): + self.documents = [] + if not isinstance(documents, list): + documents = [documents] # type: ignore[list-item] + for document in documents: + if isinstance(document, str): + self.documents.append(Document(text=document)) + continue + elif not isinstance(document, dict): + raise TypeError("documents must be a string, dictionary, or list of dictionaries.") + document_text = document.get("text") + document_name = document.get("name") + document_id = document.get("id") + document_score = document.get("score") + if not isinstance(document_text, str): + raise TypeError("Document text must be a string.") + formatted_document = Document(text=document_text) + if document_name: + if not isinstance(document_name, str): + raise TypeError("document name must be a string.") + formatted_document["name"] = document_name + if document_id: + if not isinstance(document_id, str): + raise TypeError("document id must be a string.") + formatted_document["id"] = document_id + if document_score: + if not isinstance(document_score, (int, float)): + raise TypeError("document score must be an integer or float.") + formatted_document["score"] = document_score + self.documents.append(formatted_document) diff --git a/ddtrace/propagation/_database_monitoring.py b/ddtrace/propagation/_database_monitoring.py index 4210e4cbec6..4002f864dd8 100644 --- a/ddtrace/propagation/_database_monitoring.py +++ b/ddtrace/propagation/_database_monitoring.py @@ -23,6 +23,7 @@ DBM_DATABASE_SERVICE_NAME_KEY = "dddbs" DBM_PEER_HOSTNAME_KEY = "ddh" DBM_PEER_DB_NAME_KEY = "dddb" +DBM_PEER_SERVICE_KEY = "ddprs" DBM_ENVIRONMENT_KEY = "dde" DBM_VERSION_KEY = "ddpv" DBM_TRACE_PARENT_KEY = "traceparent" @@ -56,12 +57,14 @@ def __init__( sql_injector=default_sql_injector, peer_hostname_tag="out.host", peer_db_name_tag="db.name", + peer_service_tag="peer.service", ): self.sql_pos = sql_pos self.sql_kw = sql_kw self.sql_injector = sql_injector self.peer_hostname_tag = peer_hostname_tag self.peer_db_name_tag = peer_db_name_tag + self.peer_service_tag = peer_service_tag def inject(self, dbspan, args, kwargs): # run sampling before injection to propagate correct sampling priority @@ -114,6 +117,10 @@ def _get_dbm_comment(self, db_span): if peer_hostname: dbm_tags[DBM_PEER_HOSTNAME_KEY] = peer_hostname + peer_service = db_span.get_tag(self.peer_service_tag) + if peer_service: + dbm_tags[DBM_PEER_SERVICE_KEY] = peer_service + if dbm_config.propagation_mode == "full": db_span.set_tag_str(DBM_TRACE_INJECTED_TAG, "true") dbm_tags[DBM_TRACE_PARENT_KEY] = db_span.context._traceparent diff --git a/riotfile.py b/riotfile.py index d1ee65dfa21..e52d53f37b1 100644 --- a/riotfile.py +++ b/riotfile.py @@ -133,7 +133,7 @@ def select_pys(min_version=MIN_PYTHON_VERSION, max_version=MAX_PYTHON_VERSION): Venv( name="appsec_iast", pys=select_pys(), - command="pytest {cmdargs} tests/appsec/iast/", + command="pytest -v {cmdargs} tests/appsec/iast/", pkgs={ "requests": latest, "pycryptodome": latest, diff --git a/tests/.suitespec.json b/tests/.suitespec.json index 7e6f1512ec4..e1d036b4581 100644 --- a/tests/.suitespec.json +++ b/tests/.suitespec.json @@ -34,6 +34,7 @@ ], "core": [ "ddtrace/internal/__init__.py", + "ddtrace/internal/_exceptions.py", "ddtrace/internal/_rand.pyi", "ddtrace/internal/_rand.pyx", "ddtrace/internal/_stdint.h", @@ -141,6 +142,7 @@ "ddtrace/contrib/yaaredis/*", "ddtrace/_trace/utils_redis.py", "ddtrace/contrib/redis_utils.py", + "ddtrace/contrib/trace_utils_redis.py", "ddtrace/ext/redis.py" ], "mongo": [ diff --git a/tests/appsec/appsec/rules-rasp-blocking.json b/tests/appsec/appsec/rules-rasp-blocking.json new file mode 100644 index 00000000000..21604f13e04 --- /dev/null +++ b/tests/appsec/appsec/rules-rasp-blocking.json @@ -0,0 +1,106 @@ +{ + "version": "2.1", + "metadata": { + "rules_version": "rules_rasp" + }, + "rules": [ + { + "id": "rasp-930-100", + "name": "Local file inclusion exploit", + "tags": { + "type": "lfi", + "category": "vulnerability_trigger", + "cwe": "22", + "capec": "1000/255/153/126", + "confidence": "0", + "module": "rasp" + }, + "conditions": [ + { + "parameters": { + "resource": [ + { + "address": "server.io.fs.file" + } + ], + "params": [ + { + "address": "server.request.query" + }, + { + "address": "server.request.body" + }, + { + "address": "server.request.path_params" + }, + { + "address": "grpc.server.request.message" + }, + { + "address": "graphql.server.all_resolvers" + }, + { + "address": "graphql.server.resolver" + } + ] + }, + "operator": "lfi_detector" + } + ], + "transformers": [], + "on_match": [ + "stack_trace", + "block" + ] + }, + { + "id": "rasp-934-100", + "name": "Server-side request forgery exploit", + "tags": { + "type": "ssrf", + "category": "vulnerability_trigger", + "cwe": "918", + "capec": "1000/225/115/664", + "confidence": "0", + "module": "rasp" + }, + "conditions": [ + { + "parameters": { + "resource": [ + { + "address": "server.io.net.url" + } + ], + "params": [ + { + "address": "server.request.query" + }, + { + "address": "server.request.body" + }, + { + "address": "server.request.path_params" + }, + { + "address": "grpc.server.request.message" + }, + { + "address": "graphql.server.all_resolvers" + }, + { + "address": "graphql.server.resolver" + } + ] + }, + "operator": "ssrf_detector" + } + ], + "transformers": [], + "on_match": [ + "stack_trace", + "block" + ] + } + ] +} \ No newline at end of file diff --git a/tests/appsec/appsec/test_asm_request_context.py b/tests/appsec/appsec/test_asm_request_context.py index b6e3a6da9c2..487401f00ed 100644 --- a/tests/appsec/appsec/test_asm_request_context.py +++ b/tests/appsec/appsec/test_asm_request_context.py @@ -1,6 +1,7 @@ import pytest from ddtrace.appsec import _asm_request_context +from ddtrace.internal._exceptions import BlockingException from tests.utils import override_global_config @@ -94,3 +95,19 @@ def test_asm_request_context_manager(): assert _asm_request_context.get_headers() == {} assert _asm_request_context.get_value("callbacks", "block") is None assert not _asm_request_context.get_headers_case_sensitive() + + +def test_blocking_exception_correctly_propagated(): + with override_global_config({"_asm_enabled": True}): + with _asm_request_context.asm_request_context_manager(): + witness = 0 + with _asm_request_context.asm_request_context_manager(): + witness = 1 + raise BlockingException({}, "rule", "type", "value") + # should be skipped by exception + witness = 3 + # should be also skipped by exception + witness = 4 + # no more exception there + # ensure that the exception was raised and caught at the end of the last context manager + assert witness == 1 diff --git a/tests/appsec/contrib_appsec/django_app/urls.py b/tests/appsec/contrib_appsec/django_app/urls.py index d8c45b4cb2e..a297f18fab3 100644 --- a/tests/appsec/contrib_appsec/django_app/urls.py +++ b/tests/appsec/contrib_appsec/django_app/urls.py @@ -71,6 +71,7 @@ def rasp(request, endpoint: str): res.append(f"File: {f.read()}") except Exception as e: res.append(f"Error: {e}") + tracer.current_span()._local_root.set_tag("rasp.request.done", endpoint) return HttpResponse("<\br>\n".join(res)) elif endpoint == "ssrf": res = ["ssrf endpoint"] @@ -98,7 +99,9 @@ def rasp(request, endpoint: str): res.append(f"Url: {r.text}") except Exception as e: res.append(f"Error: {e}") + tracer.current_span()._local_root.set_tag("rasp.request.done", endpoint) return HttpResponse("<\\br>\n".join(res)) + tracer.current_span()._local_root.set_tag("rasp.request.done", endpoint) return HttpResponse(f"Unknown endpoint: {endpoint}") diff --git a/tests/appsec/contrib_appsec/fastapi_app/app.py b/tests/appsec/contrib_appsec/fastapi_app/app.py index 820c25ce47a..5111fb6a218 100644 --- a/tests/appsec/contrib_appsec/fastapi_app/app.py +++ b/tests/appsec/contrib_appsec/fastapi_app/app.py @@ -128,6 +128,7 @@ async def rasp(endpoint: str, request: Request): res.append(f"File: {f.read()}") except Exception as e: res.append(f"Error: {e}") + tracer.current_span()._local_root.set_tag("rasp.request.done", endpoint) return HTMLResponse("<\br>\n".join(res)) elif endpoint == "ssrf": res = ["ssrf endpoint"] @@ -155,7 +156,9 @@ async def rasp(endpoint: str, request: Request): res.append(f"Url: {r.text}") except Exception as e: res.append(f"Error: {e}") + tracer.current_span()._local_root.set_tag("rasp.request.done", endpoint) return HTMLResponse("<\\br>\n".join(res)) + tracer.current_span()._local_root.set_tag("rasp.request.done", endpoint) return HTMLResponse(f"Unknown endpoint: {endpoint}") return app diff --git a/tests/appsec/contrib_appsec/flask_app/app.py b/tests/appsec/contrib_appsec/flask_app/app.py index 0ecb3784ddb..8997c3fa0e6 100644 --- a/tests/appsec/contrib_appsec/flask_app/app.py +++ b/tests/appsec/contrib_appsec/flask_app/app.py @@ -72,6 +72,7 @@ def rasp(endpoint: str): res.append(f"File: {f.read()}") except Exception as e: res.append(f"Error: {e}") + tracer.current_span()._local_root.set_tag("rasp.request.done", endpoint) return "<\\br>\n".join(res) elif endpoint == "ssrf": res = ["ssrf endpoint"] @@ -99,6 +100,7 @@ def rasp(endpoint: str): res.append(f"Url: {r.text}") except Exception as e: res.append(f"Error: {e}") + tracer.current_span()._local_root.set_tag("rasp.request.done", endpoint) return "<\\br>\n".join(res) elif endpoint == "shell": res = ["shell endpoint"] @@ -112,5 +114,7 @@ def rasp(endpoint: str): res.append(f"cmd stdout: {f.stdout.read()}") except Exception as e: res.append(f"Error: {e}") + tracer.current_span()._local_root.set_tag("rasp.request.done", endpoint) return "<\\br>\n".join(res) + tracer.current_span()._local_root.set_tag("rasp.request.done", endpoint) return f"Unknown endpoint: {endpoint}" diff --git a/tests/appsec/contrib_appsec/utils.py b/tests/appsec/contrib_appsec/utils.py index 1a193b47a04..41d4f5e2b8d 100644 --- a/tests/appsec/contrib_appsec/utils.py +++ b/tests/appsec/contrib_appsec/utils.py @@ -1183,8 +1183,26 @@ def test_stream_response( ) ], ) + @pytest.mark.parametrize( + ("rule_file", "blocking"), + [ + (rules.RULES_EXPLOIT_PREVENTION, False), + (rules.RULES_EXPLOIT_PREVENTION_BLOCKING, True), + ], + ) def test_exploit_prevention( - self, interface, root_span, get_tag, asm_enabled, ep_enabled, endpoint, parameters, rule, top_functions + self, + interface, + root_span, + get_tag, + asm_enabled, + ep_enabled, + endpoint, + parameters, + rule, + top_functions, + rule_file, + blocking, ): from unittest.mock import patch as mock_patch @@ -1198,16 +1216,18 @@ def test_exploit_prevention( try: patch_requests() with override_global_config(dict(_asm_enabled=asm_enabled, _ep_enabled=ep_enabled)), override_env( - dict(DD_APPSEC_RULES=rules.RULES_EXPLOIT_PREVENTION) + dict(DD_APPSEC_RULES=rule_file) ), mock_patch("ddtrace.internal.telemetry.metrics_namespaces.MetricNamespace.add_metric") as mocked: patch_common_modules() self.update_tracer(interface) response = interface.client.get(f"/rasp/{endpoint}/?{parameters}") - assert self.status(response) == 200 - assert get_tag(http.STATUS_CODE) == "200" - assert self.body(response).startswith(f"{endpoint} endpoint") + code = 403 if blocking and asm_enabled and ep_enabled else 200 + assert self.status(response) == code + assert get_tag(http.STATUS_CODE) == str(code) + if code == 200: + assert self.body(response).startswith(f"{endpoint} endpoint") if asm_enabled and ep_enabled: - self.check_rules_triggered([rule] * 2, root_span) + self.check_rules_triggered([rule] * (1 if blocking else 2), root_span) assert self.check_for_stack_trace(root_span) for trace in self.check_for_stack_trace(root_span): assert "frames" in trace @@ -1229,9 +1249,14 @@ def test_exploit_prevention( "appsec.rasp.rule.eval", (("rule_type", endpoint), ("waf_version", DDWAF_VERSION)), ) in telemetry_calls + if blocking: + assert get_tag("rasp.request.done") is None + else: + assert get_tag("rasp.request.done") == endpoint else: assert get_triggers(root_span()) is None assert self.check_for_stack_trace(root_span) == [] + assert get_tag("rasp.request.done") == endpoint finally: unpatch_common_modules() unpatch_requests() diff --git a/tests/appsec/iast/fixtures/propagation_path.py b/tests/appsec/iast/fixtures/propagation_path.py index b4f6616bc27..44b4d2aafee 100644 --- a/tests/appsec/iast/fixtures/propagation_path.py +++ b/tests/appsec/iast/fixtures/propagation_path.py @@ -3,6 +3,7 @@ make some changes """ import os +import sys ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -114,25 +115,68 @@ def propagation_path_5_prop(origin_string1, tainted_string_2): def propagation_memory_check(origin_string1, tainted_string_2): + import os.path + if type(origin_string1) is str: string1 = str(origin_string1) # 1 Range else: string1 = str(origin_string1, encoding="utf-8") # 1 Range + # string1 = taintsource if type(tainted_string_2) is str: string2 = str(tainted_string_2) # 1 Range else: string2 = str(tainted_string_2, encoding="utf-8") # 1 Range + # string2 = taintsource2 string3 = string1 + string2 # 2 Ranges + # taintsource1taintsource2 string4 = "-".join([string3, string3, string3]) # 6 Ranges + # taintsource1taintsource2-taintsource1taintsource2-taintsource1taintsource2 string5 = string4[0 : (len(string4) - 1)] + # taintsource1taintsource2-taintsource1taintsource2-taintsource1taintsource string6 = string5.title() + # Taintsource1Taintsource2-Taintsource1Taintsource2-Taintsource1Taintsource string7 = string6.upper() + # TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE string8 = "%s_notainted" % string7 - string9 = "notainted_{}".format(string8) + # TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE_notainted + string9 = "notainted#{}".format(string8) + # notainted#TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE_notainted + string10 = string9.split("#")[1] + # TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE_notainted + string11 = "notainted#{}".format(string10) + # TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE_notainted + string12 = string11.rsplit("#")[1] + string13 = string12 + "\n" + "notainted" + # TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE_notainted\nnotainted + string14 = string13.splitlines()[0] # string14 = string12 + # TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE_notainted + string15 = os.path.join("foo", "bar", string14) + # /foo/bar/TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE_notainted + string16 = os.path.split(string15)[1] + # TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE_notainted + string17 = string16 + ".jpg" + # TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE_notainted.jpg + string18 = os.path.splitext(string17)[0] + # TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE_notainted + string19 = os.path.join(os.sep + string18, "nottainted_notdir") + # /TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE_notainted/nottainted_notdir + string20 = os.path.dirname(string19) + # /TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE_notainted + string21 = os.path.basename(string20) + # TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE_notainted + + if sys.version_info >= (3, 12): + string22 = os.sep + string21 + # /TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE_notainted + string23 = os.path.splitroot(string22)[2] + # TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE2-TAINTSOURCE1TAINTSOURCE_notainted + else: + string23 = string21 + try: # label propagation_memory_check - m = open(ROOT_DIR + "/" + string9 + ".txt") + m = open(ROOT_DIR + "/" + string23 + ".txt") _ = m.read() except Exception: pass - return string9 + return string23 diff --git a/tests/appsec/rules.py b/tests/appsec/rules.py index 83c2adb1981..d4aa4119062 100644 --- a/tests/appsec/rules.py +++ b/tests/appsec/rules.py @@ -11,6 +11,7 @@ RULES_SRB_METHOD = os.path.join(ROOT_DIR, "rules-suspicious-requests-get.json") RULES_BAD_VERSION = os.path.join(ROOT_DIR, "rules-bad_version.json") RULES_EXPLOIT_PREVENTION = os.path.join(ROOT_DIR, "rules-rasp.json") +RULES_EXPLOIT_PREVENTION_BLOCKING = os.path.join(ROOT_DIR, "rules-rasp-blocking.json") RESPONSE_CUSTOM_JSON = os.path.join(ROOT_DIR, "response-custom.json") RESPONSE_CUSTOM_HTML = os.path.join(ROOT_DIR, "response-custom.html") diff --git a/tests/contrib/aiomysql/test_aiomysql.py b/tests/contrib/aiomysql/test_aiomysql.py index 35e0a7e09c6..2247b2dba6f 100644 --- a/tests/contrib/aiomysql/test_aiomysql.py +++ b/tests/contrib/aiomysql/test_aiomysql.py @@ -230,7 +230,9 @@ class AioMySQLTestCase(AsyncioTestCase): TEST_SERVICE = "mysql" conn = None - async def _get_conn_tracer(self): + async def _get_conn_tracer(self, tags=None): + tags = tags if tags is not None else {} + if not self.conn: self.conn = await aiomysql.connect(**AIOMYSQL_CONFIG) assert not self.conn.closed @@ -239,7 +241,7 @@ async def _get_conn_tracer(self): assert pin # Customize the service # we have to apply it on the existing one since new one won't inherit `app` - pin.clone(tracer=self.tracer).onto(self.conn) + pin.clone(tracer=self.tracer, tags={**tags, **pin.tags}).onto(self.conn) return self.conn, self.tracer @@ -429,3 +431,26 @@ async def test_aiomysql_dbm_propagation_comment_peer_service_enabled(self): await shared_tests._test_dbm_propagation_comment_peer_service_enabled( config=AIOMYSQL_CONFIG, cursor=cursor, wrapped_instance=cursor.__wrapped__ ) + + @mark_asyncio + @AsyncioTestCase.run_in_subprocess( + env_overrides=dict( + DD_DBM_PROPAGATION_MODE="service", + DD_SERVICE="orders-app", + DD_ENV="staging", + DD_VERSION="v7343437-d7ac743", + DD_TRACE_SPAN_ATTRIBUTE_SCHEMA="v1", + ) + ) + async def test_aiomysql_dbm_propagation_comment_with_peer_service_tag(self): + """tests if dbm comment is set in mysql""" + conn, tracer = await self._get_conn_tracer({"peer.service": "peer_service_name"}) + cursor = await conn.cursor() + cursor.__wrapped__ = mock.AsyncMock() + + await shared_tests._test_dbm_propagation_comment_with_peer_service_tag( + config=AIOMYSQL_CONFIG, + cursor=cursor, + wrapped_instance=cursor.__wrapped__, + peer_service_name="peer_service_name", + ) diff --git a/tests/contrib/botocore/test.py b/tests/contrib/botocore/test.py index 8709964db6b..aa9627169a6 100644 --- a/tests/contrib/botocore/test.py +++ b/tests/contrib/botocore/test.py @@ -312,7 +312,7 @@ def test_s3_client(self): @mock_s3 def test_s3_head_404_default(self): """ - By default we attach exception information to s3 HeadObject + By default we do not attach exception information to s3 HeadObject API calls with a 404 response """ s3 = self.session.create_client("s3", region_name="us-west-2") diff --git a/tests/contrib/celery/test_integration.py b/tests/contrib/celery/test_integration.py index 21e716b8193..09c30f2c2ef 100644 --- a/tests/contrib/celery/test_integration.py +++ b/tests/contrib/celery/test_integration.py @@ -17,6 +17,7 @@ import ddtrace.internal.forksafe as forksafe from ddtrace.propagation.http import HTTPPropagator from tests.opentracer.utils import init_tracer +from tests.utils import flaky from ...utils import override_global_config from .base import CeleryBaseTestCase @@ -209,6 +210,7 @@ def fn_task_parameters(user, force_logout=False): assert run_span.get_tag("component") == "celery" assert run_span.get_tag("span.kind") == "consumer" + @flaky(1722529274) def test_fn_task_delay(self): # using delay shorthand must preserve arguments @self.app.task diff --git a/tests/contrib/shared_tests.py b/tests/contrib/shared_tests.py index 2ccb319551f..97d1df32cfa 100644 --- a/tests/contrib/shared_tests.py +++ b/tests/contrib/shared_tests.py @@ -94,3 +94,18 @@ async def _test_dbm_propagation_comment_peer_service_enabled(config, cursor, wra await _test_execute(dbm_comment, cursor, wrapped_instance) if execute_many: await _test_execute_many(dbm_comment, cursor, wrapped_instance) + + +async def _test_dbm_propagation_comment_with_peer_service_tag( + config, cursor, wrapped_instance, peer_service_name, execute_many=True +): + """tests if dbm comment is set in mysql""" + db_name = config["db"] + + dbm_comment = ( + f"/*dddb='{db_name}',dddbs='test',dde='staging',ddh='127.0.0.1',ddprs='{peer_service_name}',ddps='orders-app'," + "ddpv='v7343437-d7ac743'*/ " + ) + await _test_execute(dbm_comment, cursor, wrapped_instance) + if execute_many: + await _test_execute_many(dbm_comment, cursor, wrapped_instance) diff --git a/tests/internal/symbol_db/test_symbols.py b/tests/internal/symbol_db/test_symbols.py index 4c879b63e5c..a97f6c5bcee 100644 --- a/tests/internal/symbol_db/test_symbols.py +++ b/tests/internal/symbol_db/test_symbols.py @@ -203,20 +203,11 @@ def test_symbols_upload_enabled(): assert remoteconfig_poller.get_registered("LIVE_DEBUGGING_SYMBOL_DB") is not None -@pytest.mark.subprocess( - ddtrace_run=True, - env=dict( - DD_SYMBOL_DATABASE_UPLOAD_ENABLED="1", - _DD_SYMBOL_DATABASE_FORCE_UPLOAD="1", - DD_SYMBOL_DATABASE_INCLUDES="tests.submod.stuff", - ), -) +@pytest.mark.subprocess(ddtrace_run=True, env=dict(DD_SYMBOL_DATABASE_INCLUDES="tests.submod.stuff")) def test_symbols_force_upload(): from ddtrace.internal.symbol_db.symbols import ScopeType from ddtrace.internal.symbol_db.symbols import SymbolDatabaseUploader - assert SymbolDatabaseUploader.is_installed() - contexts = [] def _upload_context(context): @@ -224,11 +215,18 @@ def _upload_context(context): SymbolDatabaseUploader._upload_context = staticmethod(_upload_context) + SymbolDatabaseUploader.install() + + def get_scope(contexts, name): + for context in (_.to_json() for _ in contexts): + for scope in context["scopes"]: + if scope["name"] == name: + return scope + raise ValueError(f"Scope {name} not found in {contexts}") + import tests.submod.stuff # noqa import tests.submod.traced_stuff # noqa - (context,) = contexts - - (scope,) = context.to_json()["scopes"] + scope = get_scope(contexts, "tests.submod.stuff") assert scope["scope_type"] == ScopeType.MODULE assert scope["name"] == "tests.submod.stuff" diff --git a/tests/internal/test_tracer_flare.py b/tests/internal/test_tracer_flare.py index 7051190e17d..560dcdc1ddd 100644 --- a/tests/internal/test_tracer_flare.py +++ b/tests/internal/test_tracer_flare.py @@ -13,6 +13,7 @@ from ddtrace.internal.flare import Flare from ddtrace.internal.flare import FlareSendRequest from ddtrace.internal.logger import get_logger +from tests.utils import flaky DEBUG_LEVEL_INT = logging.DEBUG @@ -118,6 +119,7 @@ def handle_agent_task(): for p in processes: p.join() + @flaky(1722529274) def test_multiple_process_partial_failure(self): """ Validte that even if the tracer flare fails for one process, we should diff --git a/tests/llmobs/test_llmobs_decorators.py b/tests/llmobs/test_llmobs_decorators.py index f106c9db51b..31ecfbf37e1 100644 --- a/tests/llmobs/test_llmobs_decorators.py +++ b/tests/llmobs/test_llmobs_decorators.py @@ -2,7 +2,9 @@ import pytest from ddtrace.llmobs.decorators import agent +from ddtrace.llmobs.decorators import embedding from ddtrace.llmobs.decorators import llm +from ddtrace.llmobs.decorators import retrieval from ddtrace.llmobs.decorators import task from ddtrace.llmobs.decorators import tool from ddtrace.llmobs.decorators import workflow @@ -17,17 +19,28 @@ def mock_logs(): def test_llm_decorator_with_llmobs_disabled_logs_warning(LLMObs, mock_logs): - @llm(model_name="test_model", model_provider="test_provider", name="test_function", session_id="test_session_id") - def f(): - pass + for decorator_name, decorator in (("llm", llm), ("embedding", embedding)): - LLMObs.disable() - f() - mock_logs.warning.assert_called_with("LLMObs.llm() cannot be used while LLMObs is disabled.") + @decorator( + model_name="test_model", model_provider="test_provider", name="test_function", session_id="test_session_id" + ) + def f(): + pass + + LLMObs.disable() + f() + mock_logs.warning.assert_called_with("LLMObs.%s() cannot be used while LLMObs is disabled.", decorator_name) + mock_logs.reset_mock() def test_non_llm_decorator_with_llmobs_disabled_logs_warning(LLMObs, mock_logs): - for decorator_name, decorator in [("task", task), ("workflow", workflow), ("tool", tool), ("agent", agent)]: + for decorator_name, decorator in ( + ("task", task), + ("workflow", workflow), + ("tool", tool), + ("agent", agent), + ("retrieval", retrieval), + ): @decorator(name="test_function", session_id="test_session_id") def f(): @@ -35,7 +48,7 @@ def f(): LLMObs.disable() f() - mock_logs.warning.assert_called_with("LLMObs.{}() cannot be used while LLMObs is disabled.", decorator_name) + mock_logs.warning.assert_called_with("LLMObs.%s() cannot be used while LLMObs is disabled.", decorator_name) mock_logs.reset_mock() @@ -73,6 +86,64 @@ def f(): ) +def test_embedding_decorator(LLMObs, mock_llmobs_span_writer): + @embedding( + model_name="test_model", model_provider="test_provider", name="test_function", session_id="test_session_id" + ) + def f(): + pass + + f() + span = LLMObs._instance.tracer.pop()[0] + mock_llmobs_span_writer.enqueue.assert_called_with( + _expected_llmobs_llm_span_event( + span, "embedding", model_name="test_model", model_provider="test_provider", session_id="test_session_id" + ) + ) + + +def test_embedding_decorator_no_model_name_raises_error(LLMObs): + with pytest.raises(TypeError): + + @embedding(model_provider="test_provider", name="test_function", session_id="test_session_id") + def f(): + pass + + +def test_embedding_decorator_default_kwargs(LLMObs, mock_llmobs_span_writer): + @embedding(model_name="test_model") + def f(): + pass + + f() + span = LLMObs._instance.tracer.pop()[0] + mock_llmobs_span_writer.enqueue.assert_called_with( + _expected_llmobs_llm_span_event(span, "embedding", model_name="test_model", model_provider="custom") + ) + + +def test_retrieval_decorator(LLMObs, mock_llmobs_span_writer): + @retrieval(name="test_function", session_id="test_session_id") + def f(): + pass + + f() + span = LLMObs._instance.tracer.pop()[0] + mock_llmobs_span_writer.enqueue.assert_called_with( + _expected_llmobs_non_llm_span_event(span, "retrieval", session_id="test_session_id") + ) + + +def test_retrieval_decorator_default_kwargs(LLMObs, mock_llmobs_span_writer): + @retrieval() + def f(): + pass + + f() + span = LLMObs._instance.tracer.pop()[0] + mock_llmobs_span_writer.enqueue.assert_called_with(_expected_llmobs_non_llm_span_event(span, "retrieval")) + + def test_task_decorator(LLMObs, mock_llmobs_span_writer): @task(name="test_function", session_id="test_session_id") def f(): @@ -265,7 +336,13 @@ def f(): def test_non_llm_decorators_no_args(LLMObs, mock_llmobs_span_writer): """Test that using the decorators without any arguments, i.e. @tool, works the same as @tool(...).""" - for decorator_name, decorator in [("task", task), ("workflow", workflow), ("tool", tool)]: + for decorator_name, decorator in [ + ("task", task), + ("workflow", workflow), + ("tool", tool), + ("agent", agent), + ("retrieval", retrieval), + ]: @decorator def f(): @@ -314,12 +391,14 @@ def g(): ) ) - @agent(ml_app="test_ml_app") + @embedding(model_name="test_model", ml_app="test_ml_app") def h(): pass h() span = LLMObs._instance.tracer.pop()[0] mock_llmobs_span_writer.enqueue.assert_called_with( - _expected_llmobs_llm_span_event(span, "agent", tags={"ml_app": "test_ml_app"}) + _expected_llmobs_llm_span_event( + span, "embedding", model_name="test_model", model_provider="custom", tags={"ml_app": "test_ml_app"} + ) ) diff --git a/tests/llmobs/test_llmobs_service.py b/tests/llmobs/test_llmobs_service.py index dfaef69c146..4b9153de1d5 100644 --- a/tests/llmobs/test_llmobs_service.py +++ b/tests/llmobs/test_llmobs_service.py @@ -4,6 +4,7 @@ import pytest from ddtrace.llmobs import LLMObs as llmobs_service +from ddtrace.llmobs._constants import INPUT_DOCUMENTS from ddtrace.llmobs._constants import INPUT_MESSAGES from ddtrace.llmobs._constants import INPUT_PARAMETERS from ddtrace.llmobs._constants import INPUT_VALUE @@ -11,6 +12,7 @@ from ddtrace.llmobs._constants import METRICS from ddtrace.llmobs._constants import MODEL_NAME from ddtrace.llmobs._constants import MODEL_PROVIDER +from ddtrace.llmobs._constants import OUTPUT_DOCUMENTS from ddtrace.llmobs._constants import OUTPUT_MESSAGES from ddtrace.llmobs._constants import OUTPUT_VALUE from ddtrace.llmobs._constants import SESSION_ID @@ -214,6 +216,42 @@ def test_agent_span(LLMObs, mock_llmobs_span_writer): mock_llmobs_span_writer.enqueue.assert_called_with(_expected_llmobs_llm_span_event(span, "agent")) +def test_embedding_span_no_model_raises_error(LLMObs): + with pytest.raises(TypeError): + with LLMObs.embedding(name="test_embedding", model_provider="test_provider"): + pass + + +def test_embedding_span_empty_model_name_logs_warning(LLMObs, mock_logs): + _ = LLMObs.embedding(model_name="", name="test_embedding", model_provider="test_provider") + mock_logs.warning.assert_called_once_with("model_name must be the specified name of the invoked model.") + + +def test_embedding_default_model_provider_set_to_custom(LLMObs): + with LLMObs.embedding(model_name="test_model", name="test_embedding") as span: + assert span.name == "test_embedding" + assert span.resource == "embedding" + assert span.span_type == "llm" + assert span.get_tag(SPAN_KIND) == "embedding" + assert span.get_tag(MODEL_NAME) == "test_model" + assert span.get_tag(MODEL_PROVIDER) == "custom" + + +def test_embedding_span(LLMObs, mock_llmobs_span_writer): + with LLMObs.embedding(model_name="test_model", name="test_embedding", model_provider="test_provider") as span: + assert span.name == "test_embedding" + assert span.resource == "embedding" + assert span.span_type == "llm" + assert span.get_tag(SPAN_KIND) == "embedding" + assert span.get_tag(MODEL_NAME) == "test_model" + assert span.get_tag(MODEL_PROVIDER) == "test_provider" + assert span.get_tag(SESSION_ID) == "{:x}".format(span.trace_id) + + mock_llmobs_span_writer.enqueue.assert_called_with( + _expected_llmobs_llm_span_event(span, "embedding", model_name="test_model", model_provider="test_provider") + ) + + def test_annotate_while_disabled_logs_warning(LLMObs, mock_logs): LLMObs.disable() LLMObs.annotate(parameters={"test": "test"}) @@ -306,6 +344,9 @@ def test_annotate_input_string(LLMObs): with LLMObs.agent() as agent_span: LLMObs.annotate(span=agent_span, input_data="test_input") assert agent_span.get_tag(INPUT_VALUE) == "test_input" + with LLMObs.retrieval() as retrieval_span: + LLMObs.annotate(span=retrieval_span, input_data="test_input") + assert retrieval_span.get_tag(INPUT_VALUE) == "test_input" def test_annotate_input_serializable_value(LLMObs): @@ -321,6 +362,9 @@ def test_annotate_input_serializable_value(LLMObs): with LLMObs.agent() as agent_span: LLMObs.annotate(span=agent_span, input_data="test_input") assert agent_span.get_tag(INPUT_VALUE) == "test_input" + with LLMObs.retrieval() as retrieval_span: + LLMObs.annotate(span=retrieval_span, input_data=[0, 1, 2, 3, 4]) + assert retrieval_span.get_tag(INPUT_VALUE) == "[0, 1, 2, 3, 4]" def test_annotate_input_value_wrong_type(LLMObs, mock_logs): @@ -352,10 +396,130 @@ def test_llmobs_annotate_incorrect_message_content_type_raises_warning(LLMObs, m mock_logs.warning.assert_called_once_with("Failed to parse output messages.", exc_info=True) +def test_annotate_document_str(LLMObs): + with LLMObs.embedding(model_name="test_model") as span: + LLMObs.annotate(span=span, input_data="test_document_text") + documents = json.loads(span.get_tag(INPUT_DOCUMENTS)) + assert documents + assert len(documents) == 1 + assert documents[0]["text"] == "test_document_text" + with LLMObs.retrieval() as span: + LLMObs.annotate(span=span, output_data="test_document_text") + documents = json.loads(span.get_tag(OUTPUT_DOCUMENTS)) + assert documents + assert len(documents) == 1 + assert documents[0]["text"] == "test_document_text" + + +def test_annotate_document_dict(LLMObs): + with LLMObs.embedding(model_name="test_model") as span: + LLMObs.annotate(span=span, input_data={"text": "test_document_text"}) + documents = json.loads(span.get_tag(INPUT_DOCUMENTS)) + assert documents + assert len(documents) == 1 + assert documents[0]["text"] == "test_document_text" + with LLMObs.retrieval() as span: + LLMObs.annotate(span=span, output_data={"text": "test_document_text"}) + documents = json.loads(span.get_tag(OUTPUT_DOCUMENTS)) + assert documents + assert len(documents) == 1 + assert documents[0]["text"] == "test_document_text" + + +def test_annotate_document_list(LLMObs): + with LLMObs.embedding(model_name="test_model") as span: + LLMObs.annotate( + span=span, + input_data=[{"text": "test_document_text"}, {"text": "text", "name": "name", "score": 0.9, "id": "id"}], + ) + documents = json.loads(span.get_tag(INPUT_DOCUMENTS)) + assert documents + assert len(documents) == 2 + assert documents[0]["text"] == "test_document_text" + assert documents[1]["text"] == "text" + assert documents[1]["name"] == "name" + assert documents[1]["id"] == "id" + assert documents[1]["score"] == 0.9 + with LLMObs.retrieval() as span: + LLMObs.annotate( + span=span, + output_data=[{"text": "test_document_text"}, {"text": "text", "name": "name", "score": 0.9, "id": "id"}], + ) + documents = json.loads(span.get_tag(OUTPUT_DOCUMENTS)) + assert documents + assert len(documents) == 2 + assert documents[0]["text"] == "test_document_text" + assert documents[1]["text"] == "text" + assert documents[1]["name"] == "name" + assert documents[1]["id"] == "id" + assert documents[1]["score"] == 0.9 + + +def test_annotate_incorrect_document_type_raises_warning(LLMObs, mock_logs): + with LLMObs.embedding(model_name="test_model") as span: + LLMObs.annotate(span=span, input_data={"text": 123}) + mock_logs.warning.assert_called_once_with("Failed to parse input documents.", exc_info=True) + mock_logs.reset_mock() + with LLMObs.embedding(model_name="test_model") as span: + LLMObs.annotate(span=span, input_data=123) + mock_logs.warning.assert_called_once_with("Failed to parse input documents.", exc_info=True) + mock_logs.reset_mock() + with LLMObs.embedding(model_name="test_model") as span: + LLMObs.annotate(span=span, input_data=Unserializable()) + mock_logs.warning.assert_called_once_with("Failed to parse input documents.", exc_info=True) + mock_logs.reset_mock() + with LLMObs.retrieval() as span: + LLMObs.annotate(span=span, output_data=[{"score": 0.9, "id": "id", "name": "name"}]) + mock_logs.warning.assert_called_once_with("Failed to parse output documents.", exc_info=True) + mock_logs.reset_mock() + with LLMObs.retrieval() as span: + LLMObs.annotate(span=span, output_data=123) + mock_logs.warning.assert_called_once_with("Failed to parse output documents.", exc_info=True) + mock_logs.reset_mock() + with LLMObs.retrieval() as span: + LLMObs.annotate(span=span, output_data=Unserializable()) + mock_logs.warning.assert_called_once_with("Failed to parse output documents.", exc_info=True) + + +def test_annotate_document_no_text_raises_warning(LLMObs, mock_logs): + with LLMObs.embedding(model_name="test_model") as span: + LLMObs.annotate(span=span, input_data=[{"score": 0.9, "id": "id", "name": "name"}]) + mock_logs.warning.assert_called_once_with("Failed to parse input documents.", exc_info=True) + mock_logs.reset_mock() + with LLMObs.retrieval() as span: + LLMObs.annotate(span=span, output_data=[{"score": 0.9, "id": "id", "name": "name"}]) + mock_logs.warning.assert_called_once_with("Failed to parse output documents.", exc_info=True) + + +def test_annotate_incorrect_document_field_type_raises_warning(LLMObs, mock_logs): + with LLMObs.embedding(model_name="test_model") as span: + LLMObs.annotate(span=span, input_data=[{"text": "test_document_text", "score": "0.9"}]) + mock_logs.warning.assert_called_once_with("Failed to parse input documents.", exc_info=True) + mock_logs.reset_mock() + with LLMObs.embedding(model_name="test_model") as span: + LLMObs.annotate( + span=span, input_data=[{"text": "text", "id": 123, "score": "0.9", "name": ["h", "e", "l", "l", "o"]}] + ) + mock_logs.warning.assert_called_once_with("Failed to parse input documents.", exc_info=True) + mock_logs.reset_mock() + with LLMObs.retrieval() as span: + LLMObs.annotate(span=span, output_data=[{"text": "test_document_text", "score": "0.9"}]) + mock_logs.warning.assert_called_once_with("Failed to parse output documents.", exc_info=True) + mock_logs.reset_mock() + with LLMObs.retrieval() as span: + LLMObs.annotate( + span=span, output_data=[{"text": "text", "id": 123, "score": "0.9", "name": ["h", "e", "l", "l", "o"]}] + ) + mock_logs.warning.assert_called_once_with("Failed to parse output documents.", exc_info=True) + + def test_annotate_output_string(LLMObs): with LLMObs.llm(model_name="test_model") as llm_span: LLMObs.annotate(span=llm_span, output_data="test_output") assert json.loads(llm_span.get_tag(OUTPUT_MESSAGES)) == [{"content": "test_output"}] + with LLMObs.embedding(model_name="test_model") as embedding_span: + LLMObs.annotate(span=embedding_span, output_data="test_output") + assert embedding_span.get_tag(OUTPUT_VALUE) == "test_output" with LLMObs.task() as task_span: LLMObs.annotate(span=task_span, output_data="test_output") assert task_span.get_tag(OUTPUT_VALUE) == "test_output" @@ -371,6 +535,9 @@ def test_annotate_output_string(LLMObs): def test_annotate_output_serializable_value(LLMObs): + with LLMObs.embedding(model_name="test_model") as embedding_span: + LLMObs.annotate(span=embedding_span, output_data=[[0, 1, 2, 3], [4, 5, 6, 7]]) + assert embedding_span.get_tag(OUTPUT_VALUE) == "[[0, 1, 2, 3], [4, 5, 6, 7]]" with LLMObs.task() as task_span: LLMObs.annotate(span=task_span, output_data=["test_output"]) assert task_span.get_tag(OUTPUT_VALUE) == '["test_output"]' @@ -465,13 +632,11 @@ def test_ml_app_override(LLMObs, mock_llmobs_span_writer): mock_llmobs_span_writer.enqueue.assert_called_with( _expected_llmobs_non_llm_span_event(span, "task", tags={"ml_app": "test_app"}) ) - with LLMObs.tool(name="test_tool", ml_app="test_app") as span: pass mock_llmobs_span_writer.enqueue.assert_called_with( _expected_llmobs_non_llm_span_event(span, "tool", tags={"ml_app": "test_app"}) ) - with LLMObs.llm(model_name="model_name", name="test_llm", ml_app="test_app") as span: pass mock_llmobs_span_writer.enqueue.assert_called_with( @@ -479,18 +644,28 @@ def test_ml_app_override(LLMObs, mock_llmobs_span_writer): span, "llm", model_name="model_name", model_provider="custom", tags={"ml_app": "test_app"} ) ) - + with LLMObs.embedding(model_name="model_name", name="test_embedding", ml_app="test_app") as span: + pass + mock_llmobs_span_writer.enqueue.assert_called_with( + _expected_llmobs_llm_span_event( + span, "embedding", model_name="model_name", model_provider="custom", tags={"ml_app": "test_app"} + ) + ) with LLMObs.workflow(name="test_workflow", ml_app="test_app") as span: pass mock_llmobs_span_writer.enqueue.assert_called_with( _expected_llmobs_non_llm_span_event(span, "workflow", tags={"ml_app": "test_app"}) ) - with LLMObs.agent(name="test_agent", ml_app="test_app") as span: pass mock_llmobs_span_writer.enqueue.assert_called_with( _expected_llmobs_llm_span_event(span, "agent", tags={"ml_app": "test_app"}) ) + with LLMObs.retrieval(name="test_retrieval", ml_app="test_app") as span: + pass + mock_llmobs_span_writer.enqueue.assert_called_with( + _expected_llmobs_non_llm_span_event(span, "retrieval", tags={"ml_app": "test_app"}) + ) def test_export_span_llmobs_not_enabled_raises_warning(LLMObs, mock_logs): diff --git a/tests/llmobs/test_utils.py b/tests/llmobs/test_utils.py index 26241b90b07..41ae6bee95c 100644 --- a/tests/llmobs/test_utils.py +++ b/tests/llmobs/test_utils.py @@ -1,5 +1,6 @@ import pytest +from ddtrace.llmobs.utils import Documents from ddtrace.llmobs.utils import Messages @@ -55,3 +56,50 @@ def test_messages_with_no_role_is_ok(): """Test that a message with no role is ok and returns a message with only content.""" messages = Messages([{"content": "hello"}, {"content": "world"}]) assert messages.messages == [{"content": "hello"}, {"content": "world"}] + + +def test_documents_with_string(): + documents = Documents("hello") + assert documents.documents == [{"text": "hello"}] + + +def test_documents_with_dict(): + documents = Documents({"text": "hello", "name": "doc1", "id": "123", "score": 0.5}) + assert len(documents.documents) == 1 + assert documents.documents == [{"text": "hello", "name": "doc1", "id": "123", "score": 0.5}] + + +def test_documents_with_list_of_dicts(): + documents = Documents([{"text": "hello", "name": "doc1", "id": "123", "score": 0.5}, {"text": "world"}]) + assert len(documents.documents) == 2 + assert documents.documents[0] == {"text": "hello", "name": "doc1", "id": "123", "score": 0.5} + assert documents.documents[1] == {"text": "world"} + + +def test_documents_with_incorrect_type(): + with pytest.raises(TypeError): + Documents(123) + with pytest.raises(TypeError): + Documents(Unserializable()) + with pytest.raises(TypeError): + Documents(None) + + +def test_documents_dictionary_no_text_value(): + with pytest.raises(TypeError): + Documents([{"text": None}]) + with pytest.raises(TypeError): + Documents([{"name": "doc1", "id": "123", "score": 0.5}]) + + +def test_documents_dictionary_with_incorrect_value_types(): + with pytest.raises(TypeError): + Documents([{"text": 123}]) + with pytest.raises(TypeError): + Documents([{"text": [1, 2, 3]}]) + with pytest.raises(TypeError): + Documents([{"text": "hello", "id": 123}]) + with pytest.raises(TypeError): + Documents({"text": "hello", "name": {"key": "value"}}) + with pytest.raises(TypeError): + Documents([{"text": "hello", "score": "123"}]) diff --git a/tests/telemetry/test_writer.py b/tests/telemetry/test_writer.py index 18699170152..c25482e849e 100644 --- a/tests/telemetry/test_writer.py +++ b/tests/telemetry/test_writer.py @@ -146,6 +146,7 @@ def test_app_started_event(telemetry_writer, test_agent_session, mock_time): {"name": "logs_injection_enabled", "origin": "default", "value": "false"}, {"name": "trace_tags", "origin": "default", "value": ""}, {"name": "tracing_enabled", "origin": "default", "value": "true"}, + {"name": "instrumentation_config_id", "origin": "default", "value": ""}, ], key=lambda x: x["name"], ), @@ -229,6 +230,7 @@ def test_app_started_event_configuration_override( env["DD_TRACE_WRITER_INTERVAL_SECONDS"] = "30" env["DD_TRACE_WRITER_REUSE_CONNECTIONS"] = "True" env["DD_TAGS"] = "team:apm,component:web" + env["DD_INSTRUMENTATION_CONFIG_ID"] = "abcedf123" env[env_var] = value file = tmpdir.join("moon_ears.json") @@ -314,6 +316,7 @@ def test_app_started_event_configuration_override( {"name": "trace_header_tags", "origin": "default", "value": ""}, {"name": "trace_tags", "origin": "env_var", "value": "team:apm,component:web"}, {"name": "tracing_enabled", "origin": "env_var", "value": "false"}, + {"name": "instrumentation_config_id", "origin": "env_var", "value": "abcedf123"}, ], key=lambda x: x["name"], )