diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index a2cd2e1ff53..071dde14005 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -3,10 +3,12 @@ stages: - deploy - benchmarks - benchmarks-pr-comment + - macrobenchmarks include: - remote: https://gitlab-templates.ddbuild.io/apm/packaging.yml - local: ".gitlab/benchmarks.yml" + - local: ".gitlab/macrobenchmarks.yml" variables: DOWNSTREAM_BRANCH: diff --git a/.gitlab/macrobenchmarks.yml b/.gitlab/macrobenchmarks.yml new file mode 100644 index 00000000000..16cf2b3b9be --- /dev/null +++ b/.gitlab/macrobenchmarks.yml @@ -0,0 +1,86 @@ +variables: + BASE_CI_IMAGE: 486234852809.dkr.ecr.us-east-1.amazonaws.com/ci/benchmarking-platform:dd-trace-py-macrobenchmarks + +.macrobenchmarks: + stage: macrobenchmarks + needs: [] + tags: ["runner:apm-k8s-same-cpu"] + timeout: 1h + rules: + - if: $CI_PIPELINE_SOURCE == "schedule" + when: always + - when: manual + ## Next step, enable: + # - if: $CI_COMMIT_REF_NAME == "main" + # when: always + # If you have a problem with Gitlab cache, see Troubleshooting section in Benchmarking Platform docs + image: $BENCHMARKS_CI_IMAGE + script: | + git clone --branch python/macrobenchmarks https://gitlab-ci-token:${CI_JOB_TOKEN}@gitlab.ddbuild.io/DataDog/benchmarking-platform platform && cd platform + if [ "$BP_PYTHON_SCENARIO_DIR" == "flask-realworld" ]; then + bp-runner bp-runner.flask-realworld.yml --debug + else + bp-runner bp-runner.simple.yml --debug + fi + artifacts: + name: "artifacts" + when: always + paths: + - platform/artifacts/ + expire_in: 3 months + variables: + # Benchmark's env variables. Modify to tweak benchmark parameters. + DD_TRACE_DEBUG: "false" + DD_RUNTIME_METRICS_ENABLED: "true" + DD_REMOTE_CONFIGURATION_ENABLED: "false" + DD_INSTRUMENTATION_TELEMETRY_ENABLED: "false" + + K6_OPTIONS_NORMAL_OPERATION_RATE: 40 + K6_OPTIONS_NORMAL_OPERATION_DURATION: 5m + K6_OPTIONS_NORMAL_OPERATION_GRACEFUL_STOP: 1m + K6_OPTIONS_NORMAL_OPERATION_PRE_ALLOCATED_VUS: 4 + K6_OPTIONS_NORMAL_OPERATION_MAX_VUS: 4 + + K6_OPTIONS_HIGH_LOAD_RATE: 500 + K6_OPTIONS_HIGH_LOAD_DURATION: 1m + K6_OPTIONS_HIGH_LOAD_GRACEFUL_STOP: 30s + K6_OPTIONS_HIGH_LOAD_PRE_ALLOCATED_VUS: 4 + K6_OPTIONS_HIGH_LOAD_MAX_VUS: 4 + + # Gitlab and BP specific env vars. Do not modify. + FF_USE_LEGACY_KUBERNETES_EXECUTION_STRATEGY: "true" + + # Workaround: Currently we're not running the benchmarks on every PR, but GitHub still shows them as pending. + # By marking the benchmarks as allow_failure, this should go away. 
(This workaround should be removed once the + # benchmarks get changed to run on every PR) + allow_failure: true + +macrobenchmarks: + extends: .macrobenchmarks + parallel: + matrix: + - DD_BENCHMARKS_CONFIGURATION: baseline + BP_PYTHON_SCENARIO_DIR: flask-realworld + DDTRACE_INSTALL_VERSION: "git+https://github.com/Datadog/dd-trace-py@${CI_COMMIT_SHA}" + + - DD_BENCHMARKS_CONFIGURATION: only-tracing + BP_PYTHON_SCENARIO_DIR: flask-realworld + DDTRACE_INSTALL_VERSION: "git+https://github.com/Datadog/dd-trace-py@${CI_COMMIT_SHA}" + + - DD_BENCHMARKS_CONFIGURATION: only-tracing + BP_PYTHON_SCENARIO_DIR: flask-realworld + DDTRACE_INSTALL_VERSION: "git+https://github.com/Datadog/dd-trace-py@${CI_COMMIT_SHA}" + DD_REMOTE_CONFIGURATION_ENABLED: "false" + DD_INSTRUMENTATION_TELEMETRY_ENABLED: "true" + + - DD_BENCHMARKS_CONFIGURATION: only-tracing + BP_PYTHON_SCENARIO_DIR: flask-realworld + DDTRACE_INSTALL_VERSION: "git+https://github.com/Datadog/dd-trace-py@${CI_COMMIT_SHA}" + DD_REMOTE_CONFIGURATION_ENABLED: "false" + DD_INSTRUMENTATION_TELEMETRY_ENABLED: "false" + + - DD_BENCHMARKS_CONFIGURATION: only-tracing + BP_PYTHON_SCENARIO_DIR: flask-realworld + DDTRACE_INSTALL_VERSION: "git+https://github.com/Datadog/dd-trace-py@${CI_COMMIT_SHA}" + DD_REMOTE_CONFIGURATION_ENABLED: "true" + DD_INSTRUMENTATION_TELEMETRY_ENABLED: "true" diff --git a/ddtrace/_trace/trace_handlers.py b/ddtrace/_trace/trace_handlers.py index f439f87784a..f67eca90453 100644 --- a/ddtrace/_trace/trace_handlers.py +++ b/ddtrace/_trace/trace_handlers.py @@ -6,9 +6,10 @@ from typing import List from typing import Optional -from ddtrace import config from ddtrace._trace.span import Span +from ddtrace._trace.utils import extract_DD_context_from_messages from ddtrace._trace.utils import set_botocore_patched_api_call_span_tags as set_patched_api_call_span_tags +from ddtrace._trace.utils import set_botocore_response_metadata_tags from ddtrace.constants import ANALYTICS_SAMPLE_RATE_KEY from ddtrace.constants import SPAN_KIND from ddtrace.constants import SPAN_MEASURED_KEY @@ -107,6 +108,9 @@ def _start_span(ctx: core.ExecutionContext, call_trace: bool = True, **kwargs) - trace_utils.activate_distributed_headers( tracer, int_config=distributed_headers_config, request_headers=ctx["distributed_headers"] ) + distributed_context = ctx.get_item("distributed_context", traverse=True) + if distributed_context and not call_trace: + span_kwargs["child_of"] = distributed_context span_kwargs.update(kwargs) span = (tracer.trace if call_trace else tracer.start_span)(ctx["span_name"], **span_kwargs) for tk, tv in ctx.get_item("tags", dict()).items(): @@ -569,20 +573,20 @@ def _on_botocore_patched_api_call_started(ctx): span.start_ns = start_ns -def _on_botocore_patched_api_call_exception(ctx, response, exception_type, set_response_metadata_tags): +def _on_botocore_patched_api_call_exception(ctx, response, exception_type, is_error_code_fn): span = ctx.get_item(ctx.get_item("call_key")) # `ClientError.response` contains the result, so we can still grab response metadata - set_response_metadata_tags(span, response) + set_botocore_response_metadata_tags(span, response, is_error_code_fn=is_error_code_fn) # If we have a status code, and the status code is not an error, # then ignore the exception being raised status_code = span.get_tag(http.STATUS_CODE) - if status_code and not config.botocore.operations[span.resource].is_error_code(int(status_code)): + if status_code and not is_error_code_fn(int(status_code)): span._ignore_exception(exception_type) -def 
_on_botocore_patched_api_call_success(ctx, response, set_response_metadata_tags): - set_response_metadata_tags(ctx.get_item(ctx.get_item("call_key")), response) +def _on_botocore_patched_api_call_success(ctx, response): + set_botocore_response_metadata_tags(ctx.get_item(ctx.get_item("call_key")), response) def _on_botocore_trace_context_injection_prepared( @@ -682,6 +686,31 @@ def _on_botocore_bedrock_process_response( span.finish() +def _on_botocore_sqs_recvmessage_post( + ctx: core.ExecutionContext, _, result: Dict, propagate: bool, message_parser: Callable +) -> None: + if result is not None and "Messages" in result and len(result["Messages"]) >= 1: + ctx.set_item("message_received", True) + if propagate: + ctx.set_safe("distributed_context", extract_DD_context_from_messages(result["Messages"], message_parser)) + + +def _on_botocore_kinesis_getrecords_post( + ctx: core.ExecutionContext, + _, + __, + ___, + ____, + result, + propagate: bool, + message_parser: Callable, +): + if result is not None and "Records" in result and len(result["Records"]) >= 1: + ctx.set_item("message_received", True) + if propagate: + ctx.set_item("distributed_context", extract_DD_context_from_messages(result["Records"], message_parser)) + + def _on_redis_async_command_post(span, rowcount): if rowcount is not None: span.set_metric(db.ROWCOUNT, rowcount) @@ -727,10 +756,14 @@ def listen(): core.on("botocore.patched_stepfunctions_api_call.started", _on_botocore_patched_api_call_started) core.on("botocore.patched_stepfunctions_api_call.exception", _on_botocore_patched_api_call_exception) core.on("botocore.stepfunctions.update_messages", _on_botocore_update_messages) + core.on("botocore.eventbridge.update_messages", _on_botocore_update_messages) + core.on("botocore.client_context.update_messages", _on_botocore_update_messages) core.on("botocore.patched_bedrock_api_call.started", _on_botocore_patched_bedrock_api_call_started) core.on("botocore.patched_bedrock_api_call.exception", _on_botocore_patched_bedrock_api_call_exception) core.on("botocore.patched_bedrock_api_call.success", _on_botocore_patched_bedrock_api_call_success) core.on("botocore.bedrock.process_response", _on_botocore_bedrock_process_response) + core.on("botocore.sqs.ReceiveMessage.post", _on_botocore_sqs_recvmessage_post) + core.on("botocore.kinesis.GetRecords.post", _on_botocore_kinesis_getrecords_post) core.on("redis.async_command.post", _on_redis_async_command_post) for context_name in ( diff --git a/ddtrace/_trace/tracer.py b/ddtrace/_trace/tracer.py index 4f359a42d93..7259cdf16a0 100644 --- a/ddtrace/_trace/tracer.py +++ b/ddtrace/_trace/tracer.py @@ -230,7 +230,8 @@ def __init__( self.enabled = config._tracing_enabled self.context_provider = context_provider or DefaultContextProvider() - self._user_sampler: Optional[BaseSampler] = None + # _user_sampler is the backup in case we need to revert from remote config to local + self._user_sampler: Optional[BaseSampler] = DatadogSampler() self._sampler: BaseSampler = DatadogSampler() self._dogstatsd_url = agent.get_stats_url() if dogstatsd_url is None else dogstatsd_url self._compute_stats = config._trace_compute_stats @@ -286,7 +287,7 @@ def __init__( self._shutdown_lock = RLock() self._new_process = False - config._subscribe(["_trace_sample_rate"], self._on_global_config_update) + config._subscribe(["_trace_sample_rate", "_trace_sampling_rules"], self._on_global_config_update) config._subscribe(["logs_injection"], self._on_global_config_update) config._subscribe(["tags"], 
self._on_global_config_update) config._subscribe(["_tracing_enabled"], self._on_global_config_update) @@ -1125,19 +1126,10 @@ def _is_span_internal(span): def _on_global_config_update(self, cfg, items): # type: (Config, List) -> None - if "_trace_sample_rate" in items: - # Reset the user sampler if one exists - if cfg._get_source("_trace_sample_rate") != "remote_config" and self._user_sampler: - self._sampler = self._user_sampler - return - - if cfg._get_source("_trace_sample_rate") != "default": - sample_rate = cfg._trace_sample_rate - else: - sample_rate = None - sampler = DatadogSampler(default_sample_rate=sample_rate) - self._sampler = sampler + # sampling configs always come as a pair + if "_trace_sample_rate" in items and "_trace_sampling_rules" in items: + self._handle_sampler_update(cfg) if "tags" in items: self._tags = cfg.tags.copy() @@ -1160,3 +1152,42 @@ def _on_global_config_update(self, cfg, items): from ddtrace.contrib.logging import unpatch unpatch() + + def _handle_sampler_update(self, cfg): + # type: (Config) -> None + if ( + cfg._get_source("_trace_sample_rate") != "remote_config" + and cfg._get_source("_trace_sampling_rules") != "remote_config" + and self._user_sampler + ): + # if we get empty configs from rc for both sample rate and rules, we should revert to the user sampler + self.sampler = self._user_sampler + return + + if cfg._get_source("_trace_sample_rate") != "remote_config" and self._user_sampler: + try: + sample_rate = self._user_sampler.default_sample_rate # type: ignore[attr-defined] + except AttributeError: + log.debug("Custom non-DatadogSampler is being used, cannot pull default sample rate") + sample_rate = None + elif cfg._get_source("_trace_sample_rate") != "default": + sample_rate = cfg._trace_sample_rate + else: + sample_rate = None + + if cfg._get_source("_trace_sampling_rules") != "remote_config" and self._user_sampler: + try: + sampling_rules = self._user_sampler.rules # type: ignore[attr-defined] + # we need to chop off the default_sample_rate rule so the new sample_rate can be applied + sampling_rules = sampling_rules[:-1] + except AttributeError: + log.debug("Custom non-DatadogSampler is being used, cannot pull sampling rules") + sampling_rules = None + elif cfg._get_source("_trace_sampling_rules") != "default": + sampling_rules = DatadogSampler._parse_rules_from_str(cfg._trace_sampling_rules) + else: + sampling_rules = None + + sampler = DatadogSampler(rules=sampling_rules, default_sample_rate=sample_rate) + + self._sampler = sampler diff --git a/ddtrace/_trace/utils.py b/ddtrace/_trace/utils.py index 0e1a9364582..44bef3bbf23 100644 --- a/ddtrace/_trace/utils.py +++ b/ddtrace/_trace/utils.py @@ -1,3 +1,8 @@ +from typing import Any +from typing import Callable +from typing import Dict +from typing import Optional + from ddtrace import Span from ddtrace import config from ddtrace.constants import ANALYTICS_SAMPLE_RATE_KEY @@ -5,8 +10,10 @@ from ddtrace.constants import SPAN_MEASURED_KEY from ddtrace.ext import SpanKind from ddtrace.ext import aws +from ddtrace.ext import http from ddtrace.internal.constants import COMPONENT from ddtrace.internal.utils.formats import deep_getattr +from ddtrace.propagation.http import HTTPPropagator def set_botocore_patched_api_call_span_tags(span: Span, instance, args, params, endpoint_name, operation): @@ -39,3 +46,37 @@ def set_botocore_patched_api_call_span_tags(span: Span, instance, args, params, # set analytics sample rate span.set_tag(ANALYTICS_SAMPLE_RATE_KEY, config.botocore.get_analytics_sample_rate()) + 
+ +def set_botocore_response_metadata_tags( + span: Span, result: Dict[str, Any], is_error_code_fn: Optional[Callable] = None +) -> None: + if not result or not result.get("ResponseMetadata"): + return + response_meta = result["ResponseMetadata"] + + if "HTTPStatusCode" in response_meta: + status_code = response_meta["HTTPStatusCode"] + span.set_tag(http.STATUS_CODE, status_code) + + # Mark this span as an error if requested + if is_error_code_fn is not None and is_error_code_fn(int(status_code)): + span.error = 1 + + if "RetryAttempts" in response_meta: + span.set_tag("retry_attempts", response_meta["RetryAttempts"]) + + if "RequestId" in response_meta: + span.set_tag_str("aws.requestid", response_meta["RequestId"]) + + +def extract_DD_context_from_messages(messages, extract_from_message: Callable): + ctx = None + if len(messages) >= 1: + message = messages[0] + context_json = extract_from_message(message) + if context_json is not None: + child_of = HTTPPropagator.extract(context_json) + if child_of.trace_id is not None: + ctx = child_of + return ctx diff --git a/ddtrace/contrib/botocore/patch.py b/ddtrace/contrib/botocore/patch.py index e0bcc3f317f..b4f1a5265ea 100644 --- a/ddtrace/contrib/botocore/patch.py +++ b/ddtrace/contrib/botocore/patch.py @@ -39,9 +39,8 @@ from .services.sqs import update_messages as inject_trace_to_sqs_or_sns_message from .services.stepfunctions import patched_stepfunction_api_call from .services.stepfunctions import update_stepfunction_input -from .utils import inject_trace_to_client_context -from .utils import inject_trace_to_eventbridge_detail -from .utils import set_response_metadata_tags +from .utils import update_client_context +from .utils import update_eventbridge_detail _PATCHED_SUBMODULES = set() # type: Set[str] @@ -175,11 +174,11 @@ def prep_context_injection(ctx, endpoint_name, operation, trace_operation, param schematization_function = schematize_cloud_messaging_operation if endpoint_name == "lambda" and operation == "Invoke": - injection_function = inject_trace_to_client_context + injection_function = update_client_context schematization_function = schematize_cloud_faas_operation cloud_service = "lambda" if endpoint_name == "events" and operation == "PutEvents": - injection_function = inject_trace_to_eventbridge_detail + injection_function = update_eventbridge_detail cloud_service = "events" if endpoint_name == "sns" and "Publish" in operation: injection_function = inject_trace_to_sqs_or_sns_message @@ -224,9 +223,14 @@ def patched_api_call_fallback(original_func, instance, args, kwargs, function_va except botocore.exceptions.ClientError as e: core.dispatch( "botocore.patched_api_call.exception", - [ctx, e.response, botocore.exceptions.ClientError, set_response_metadata_tags], + [ + ctx, + e.response, + botocore.exceptions.ClientError, + config.botocore.operations[ctx["instrumented_api_call"].resource].is_error_code, + ], ) raise else: - core.dispatch("botocore.patched_api_call.success", [ctx, result, set_response_metadata_tags]) + core.dispatch("botocore.patched_api_call.success", [ctx, result]) return result diff --git a/ddtrace/contrib/botocore/services/kinesis.py b/ddtrace/contrib/botocore/services/kinesis.py index 412f0b0c27f..858f011410f 100644 --- a/ddtrace/contrib/botocore/services/kinesis.py +++ b/ddtrace/contrib/botocore/services/kinesis.py @@ -17,9 +17,8 @@ from ....internal.logger import get_logger from ....internal.schema import schematize_cloud_messaging_operation from ....internal.schema import schematize_service_name -from ..utils 
import extract_DD_context +from ..utils import extract_DD_json from ..utils import get_kinesis_data_object -from ..utils import set_response_metadata_tags log = get_logger(__name__) @@ -74,13 +73,14 @@ def patched_kinesis_api_call(original_func, instance, args, kwargs, function_var endpoint_name = function_vars.get("endpoint_name") operation = function_vars.get("operation") - message_received = False is_getrecords_call = False getrecords_error = None - child_of = None start_ns = None result = None + parent_ctx: core.ExecutionContext = core.ExecutionContext( + "botocore.patched_sqs_api_call.propagated", + ) if operation == "GetRecords": try: start_ns = time_ns() @@ -95,15 +95,20 @@ def patched_kinesis_api_call(original_func, instance, args, kwargs, function_var time_estimate = record.get("ApproximateArrivalTimestamp", datetime.now()).timestamp() core.dispatch( f"botocore.{endpoint_name}.{operation}.post", - [params, time_estimate, data_obj.get("_datadog"), record], + [ + parent_ctx, + params, + time_estimate, + data_obj.get("_datadog"), + record, + result, + config.botocore.propagation_enabled, + extract_DD_json, + ], ) except Exception as e: getrecords_error = e - if result is not None and "Records" in result and len(result["Records"]) >= 1: - message_received = True - if config.botocore.propagation_enabled: - child_of = extract_DD_context(result["Records"]) if endpoint_name == "kinesis" and operation in {"PutRecord", "PutRecords"}: span_name = schematize_cloud_messaging_operation( @@ -116,7 +121,7 @@ def patched_kinesis_api_call(original_func, instance, args, kwargs, function_var span_name = trace_operation stream_arn = params.get("StreamARN", params.get("StreamName", "")) function_is_not_getrecords = not is_getrecords_call - received_message_when_polling = is_getrecords_call and message_received + received_message_when_polling = is_getrecords_call and parent_ctx.get_item("message_received") instrument_empty_poll_calls = config.botocore.empty_poll_enabled should_instrument = ( received_message_when_polling or instrument_empty_poll_calls or function_is_not_getrecords or getrecords_error @@ -126,6 +131,7 @@ def patched_kinesis_api_call(original_func, instance, args, kwargs, function_var if should_instrument: with core.context_with_data( "botocore.patched_kinesis_api_call", + parent=parent_ctx, instance=instance, args=args, params=params, @@ -136,7 +142,6 @@ def patched_kinesis_api_call(original_func, instance, args, kwargs, function_var pin=pin, span_name=span_name, span_type=SpanTypes.HTTP, - child_of=child_of if child_of is not None else pin.tracer.context_provider.active(), activate=True, func_run=is_getrecords_call, start_ns=start_ns, @@ -158,15 +163,21 @@ def patched_kinesis_api_call(original_func, instance, args, kwargs, function_var if getrecords_error: raise getrecords_error - core.dispatch("botocore.patched_kinesis_api_call.success", [ctx, result, set_response_metadata_tags]) + core.dispatch("botocore.patched_kinesis_api_call.success", [ctx, result]) return result except botocore.exceptions.ClientError as e: core.dispatch( "botocore.patched_kinesis_api_call.exception", - [ctx, e.response, botocore.exceptions.ClientError, set_response_metadata_tags], + [ + ctx, + e.response, + botocore.exceptions.ClientError, + config.botocore.operations[ctx[ctx["call_key"]].resource].is_error_code, + ], ) raise + parent_ctx.end() elif is_getrecords_call: if getrecords_error: raise getrecords_error diff --git a/ddtrace/contrib/botocore/services/sqs.py b/ddtrace/contrib/botocore/services/sqs.py index 
37080c85d70..25de175853a 100644 --- a/ddtrace/contrib/botocore/services/sqs.py +++ b/ddtrace/contrib/botocore/services/sqs.py @@ -7,8 +7,6 @@ import botocore.exceptions from ddtrace import config -from ddtrace.contrib.botocore.utils import extract_DD_context -from ddtrace.contrib.botocore.utils import set_response_metadata_tags from ddtrace.ext import SpanTypes from ddtrace.internal import core from ddtrace.internal.logger import get_logger @@ -16,6 +14,8 @@ from ddtrace.internal.schema import schematize_service_name from ddtrace.internal.schema.span_attribute_schema import SpanDirection +from ..utils import extract_DD_json + log = get_logger(__name__) MAX_INJECTION_DATA_ATTRIBUTES = 10 @@ -83,16 +83,19 @@ def _ensure_datadog_messageattribute_enabled(params): def patched_sqs_api_call(original_func, instance, args, kwargs, function_vars): + with core.context_with_data("botocore.patched_sqs_api_call.propagated") as parent_ctx: + return _patched_sqs_api_call(parent_ctx, original_func, instance, args, kwargs, function_vars) + + +def _patched_sqs_api_call(parent_ctx, original_func, instance, args, kwargs, function_vars): params = function_vars.get("params") trace_operation = function_vars.get("trace_operation") pin = function_vars.get("pin") endpoint_name = function_vars.get("endpoint_name") operation = function_vars.get("operation") - message_received = False func_has_run = False func_run_err = None - child_of = None result = None if operation == "ReceiveMessage": @@ -103,16 +106,15 @@ def patched_sqs_api_call(original_func, instance, args, kwargs, function_vars): core.dispatch(f"botocore.{endpoint_name}.{operation}.pre", [params]) # run the function to extract possible parent context before creating ExecutionContext result = original_func(*args, **kwargs) - core.dispatch(f"botocore.{endpoint_name}.{operation}.post", [params, result]) + core.dispatch( + f"botocore.{endpoint_name}.{operation}.post", + [parent_ctx, params, result, config.botocore.propagation_enabled, extract_DD_json], + ) except Exception as e: func_run_err = e - if result is not None and "Messages" in result and len(result["Messages"]) >= 1: - message_received = True - if config.botocore.propagation_enabled: - child_of = extract_DD_context(result["Messages"]) function_is_not_recvmessage = not func_has_run - received_message_when_polling = func_has_run and message_received + received_message_when_polling = func_has_run and parent_ctx.get_item("message_received") instrument_empty_poll_calls = config.botocore.empty_poll_enabled should_instrument = ( received_message_when_polling or instrument_empty_poll_calls or function_is_not_recvmessage or func_run_err @@ -133,9 +135,12 @@ def patched_sqs_api_call(original_func, instance, args, kwargs, function_vars): else: call_name = trace_operation + child_of = parent_ctx.get_item("distributed_context") + if should_instrument: with core.context_with_data( "botocore.patched_sqs_api_call", + parent=parent_ctx, span_name=call_name, service=schematize_service_name("{}.{}".format(pin.service, endpoint_name)), span_type=SpanTypes.HTTP, @@ -161,7 +166,7 @@ def patched_sqs_api_call(original_func, instance, args, kwargs, function_vars): result = original_func(*args, **kwargs) core.dispatch(f"botocore.{endpoint_name}.{operation}.post", [params, result]) - core.dispatch("botocore.patched_sqs_api_call.success", [ctx, result, set_response_metadata_tags]) + core.dispatch("botocore.patched_sqs_api_call.success", [ctx, result]) if func_run_err: raise func_run_err @@ -169,7 +174,12 @@ def 
patched_sqs_api_call(original_func, instance, args, kwargs, function_vars): except botocore.exceptions.ClientError as e: core.dispatch( "botocore.patched_sqs_api_call.exception", - [ctx, e.response, botocore.exceptions.ClientError, set_response_metadata_tags], + [ + ctx, + e.response, + botocore.exceptions.ClientError, + config.botocore.operations[ctx[ctx["call_key"]].resource].is_error_code, + ], ) raise elif func_has_run: diff --git a/ddtrace/contrib/botocore/services/stepfunctions.py b/ddtrace/contrib/botocore/services/stepfunctions.py index d611f664a48..16213f2e3ed 100644 --- a/ddtrace/contrib/botocore/services/stepfunctions.py +++ b/ddtrace/contrib/botocore/services/stepfunctions.py @@ -12,7 +12,6 @@ from ....internal.schema import SpanDirection from ....internal.schema import schematize_cloud_messaging_operation from ....internal.schema import schematize_service_name -from ..utils import set_response_metadata_tags log = get_logger(__name__) @@ -81,6 +80,11 @@ def patched_stepfunction_api_call(original_func, instance, args, kwargs: Dict, f except botocore.exceptions.ClientError as e: core.dispatch( "botocore.patched_stepfunctions_api_call.exception", - [ctx, e.response, botocore.exceptions.ClientError, set_response_metadata_tags], + [ + ctx, + e.response, + botocore.exceptions.ClientError, + config.botocore.operations[ctx[ctx["call_key"]].resource].is_error_code, + ], ) raise diff --git a/ddtrace/contrib/botocore/utils.py b/ddtrace/contrib/botocore/utils.py index ead47ace10c..5804a4e1a36 100644 --- a/ddtrace/contrib/botocore/utils.py +++ b/ddtrace/contrib/botocore/utils.py @@ -8,13 +8,11 @@ from typing import Optional from typing import Tuple -from ddtrace import Span from ddtrace import config +from ddtrace.internal import core from ddtrace.internal.core import ExecutionContext -from ...ext import http from ...internal.logger import get_logger -from ...propagation.http import HTTPPropagator log = get_logger(__name__) @@ -66,11 +64,7 @@ def get_kinesis_data_object(data: str) -> Tuple[str, Optional[Dict[str, Any]]]: return None, None -def inject_trace_to_eventbridge_detail(ctx: ExecutionContext) -> None: - """ - Inject trace headers into the EventBridge record if the record's Detail object contains a JSON string - Max size per event is 256KB (https://docs.aws.amazon.com/eventbridge/latest/userguide/eb-putevent-size.html) - """ +def update_eventbridge_detail(ctx: ExecutionContext) -> None: params = ctx["params"] if "Entries" not in params: log.warning("Unable to inject context. 
The Event Bridge event had no Entries.") @@ -86,8 +80,7 @@ def inject_trace_to_eventbridge_detail(ctx: ExecutionContext) -> None: continue detail["_datadog"] = {} - span = ctx[ctx["call_key"]] - HTTPPropagator.inject(span.context, detail["_datadog"]) + core.dispatch("botocore.eventbridge.update_messages", [ctx, None, None, detail["_datadog"], None]) detail_json = json.dumps(detail) # check if detail size will exceed max size with headers @@ -99,12 +92,11 @@ def inject_trace_to_eventbridge_detail(ctx: ExecutionContext) -> None: entry["Detail"] = detail_json -def inject_trace_to_client_context(ctx): +def update_client_context(ctx: ExecutionContext) -> None: trace_headers = {} - span = ctx[ctx["call_key"]] - params = ctx["params"] - HTTPPropagator.inject(span.context, trace_headers) + core.dispatch("botocore.client_context.update_messages", [ctx, None, None, trace_headers, None]) client_context_object = {} + params = ctx["params"] if "ClientContext" in params: try: client_context_json = base64.b64decode(params["ClientContext"]).decode("utf-8") @@ -131,39 +123,7 @@ def modify_client_context(client_context_object, trace_headers): client_context_object["custom"] = trace_headers -def set_response_metadata_tags(span: Span, result: Dict[str, Any]) -> None: - if not result or not result.get("ResponseMetadata"): - return - response_meta = result["ResponseMetadata"] - - if "HTTPStatusCode" in response_meta: - status_code = response_meta["HTTPStatusCode"] - span.set_tag(http.STATUS_CODE, status_code) - - # Mark this span as an error if requested - if config.botocore.operations[span.resource].is_error_code(int(status_code)): - span.error = 1 - - if "RetryAttempts" in response_meta: - span.set_tag("retry_attempts", response_meta["RetryAttempts"]) - - if "RequestId" in response_meta: - span.set_tag_str("aws.requestid", response_meta["RequestId"]) - - -def extract_DD_context(messages): - ctx = None - if len(messages) >= 1: - message = messages[0] - context_json = extract_trace_context_json(message) - if context_json is not None: - child_of = HTTPPropagator.extract(context_json) - if child_of.trace_id is not None: - ctx = child_of - return ctx - - -def extract_trace_context_json(message): +def extract_DD_json(message): context_json = None try: if message and message.get("Type") == "Notification": @@ -200,7 +160,7 @@ def extract_trace_context_json(message): if "Body" in message: try: body = json.loads(message["Body"]) - return extract_trace_context_json(body) + return extract_DD_json(body) except ValueError: log.debug("Unable to parse AWS message body.") except Exception: diff --git a/ddtrace/contrib/trace_utils_redis.py b/ddtrace/contrib/trace_utils_redis.py new file mode 100644 index 00000000000..8df16c3ce4d --- /dev/null +++ b/ddtrace/contrib/trace_utils_redis.py @@ -0,0 +1,18 @@ +from ddtrace.contrib.redis_utils import determine_row_count +from ddtrace.contrib.redis_utils import stringify_cache_args +from ddtrace.internal.utils.deprecations import DDTraceDeprecationWarning +from ddtrace.vendor.debtcollector import deprecate + + +deprecate( + "The ddtrace.contrib.trace_utils_redis module is deprecated and will be removed.", + message="A new interface will be provided by the ddtrace.contrib.redis_utils module", + category=DDTraceDeprecationWarning, +) + + +format_command_args = stringify_cache_args + + +def determine_row_count(redis_command, span, result): # noqa: F811 + determine_row_count(redis_command=redis_command, result=result) diff --git a/ddtrace/internal/constants.py 
b/ddtrace/internal/constants.py index 566bec75dad..50b8e1280e4 100644 --- a/ddtrace/internal/constants.py +++ b/ddtrace/internal/constants.py @@ -90,7 +90,9 @@ class _PRIORITY_CATEGORY: USER = "user" - RULE = "rule" + RULE_DEF = "rule_default" + RULE_CUSTOMER = "rule_customer" + RULE_DYNAMIC = "rule_dynamic" AUTO = "auto" DEFAULT = "default" @@ -99,7 +101,9 @@ class _PRIORITY_CATEGORY: # used to simplify code that selects sampling priority based on many factors _CATEGORY_TO_PRIORITIES = { _PRIORITY_CATEGORY.USER: (USER_KEEP, USER_REJECT), - _PRIORITY_CATEGORY.RULE: (USER_KEEP, USER_REJECT), + _PRIORITY_CATEGORY.RULE_DEF: (USER_KEEP, USER_REJECT), + _PRIORITY_CATEGORY.RULE_CUSTOMER: (USER_KEEP, USER_REJECT), + _PRIORITY_CATEGORY.RULE_DYNAMIC: (USER_KEEP, USER_REJECT), _PRIORITY_CATEGORY.AUTO: (AUTO_KEEP, AUTO_REJECT), _PRIORITY_CATEGORY.DEFAULT: (AUTO_KEEP, AUTO_REJECT), } diff --git a/ddtrace/internal/datastreams/botocore.py b/ddtrace/internal/datastreams/botocore.py index 1f1b79aee80..ec004f1ff9a 100644 --- a/ddtrace/internal/datastreams/botocore.py +++ b/ddtrace/internal/datastreams/botocore.py @@ -172,7 +172,7 @@ def get_datastreams_context(message): return context_json -def handle_sqs_receive(params, result): +def handle_sqs_receive(_, params, result, *args): from . import data_streams_processor as processor queue_name = get_queue_name(params) @@ -206,7 +206,7 @@ def record_data_streams_path_for_kinesis_stream(params, time_estimate, context_j ) -def handle_kinesis_receive(params, time_estimate, context_json, record): +def handle_kinesis_receive(_, params, time_estimate, context_json, record, *args): try: record_data_streams_path_for_kinesis_stream(params, time_estimate, context_json, record) except Exception: diff --git a/ddtrace/internal/flare.py b/ddtrace/internal/flare.py index 9a11223b221..7cf850e7656 100644 --- a/ddtrace/internal/flare.py +++ b/ddtrace/internal/flare.py @@ -1,4 +1,5 @@ import binascii +import dataclasses import io import json import logging @@ -7,9 +8,7 @@ import pathlib import shutil import tarfile -from typing import Any from typing import Dict -from typing import List from typing import Optional from typing import Tuple @@ -19,7 +18,7 @@ from ddtrace.internal.utils.http import get_connection -TRACER_FLARE_DIRECTORY = pathlib.Path("tracer_flare") +TRACER_FLARE_DIRECTORY = "tracer_flare" TRACER_FLARE_TAR = pathlib.Path("tracer_flare.tar") TRACER_FLARE_ENDPOINT = "/tracer_flare/v1" TRACER_FLARE_FILE_HANDLER_NAME = "tracer_flare_file_handler" @@ -29,111 +28,99 @@ log = get_logger(__name__) +@dataclasses.dataclass +class FlareSendRequest: + case_id: str + hostname: str + email: str + source: str = "tracer_python" + + class Flare: - def __init__(self, timeout_sec: int = DEFAULT_TIMEOUT_SECONDS): - self.original_log_level = 0 # NOTSET - self.timeout = timeout_sec + def __init__(self, timeout_sec: int = DEFAULT_TIMEOUT_SECONDS, flare_dir: str = TRACER_FLARE_DIRECTORY): + self.original_log_level: int = logging.NOTSET + self.timeout: int = timeout_sec + self.flare_dir: pathlib.Path = pathlib.Path(flare_dir) self.file_handler: Optional[RotatingFileHandler] = None - def prepare(self, configs: List[dict]): + def prepare(self, log_level: str): """ Update configurations to start sending tracer logs to a file to be sent in a flare later. 
""" - if not os.path.exists(TRACER_FLARE_DIRECTORY): - try: - os.makedirs(TRACER_FLARE_DIRECTORY) - log.info("Tracer logs will now be sent to the %s directory", TRACER_FLARE_DIRECTORY) - except Exception as e: - log.error("Failed to create %s directory: %s", TRACER_FLARE_DIRECTORY, e) - return - for agent_config in configs: - # AGENT_CONFIG is currently being used for multiple purposes - # We only want to prepare for a tracer flare if the config name - # starts with 'flare-log-level' - if not agent_config.get("name", "").startswith("flare-log-level"): - return + try: + self.flare_dir.mkdir(exist_ok=True) + except Exception as e: + log.error("Failed to create %s directory: %s", self.flare_dir, e) + return + + flare_log_level_int = logging.getLevelName(log_level) + if type(flare_log_level_int) != int: + raise TypeError("Invalid log level provided: %s", log_level) - # Validate the flare log level - flare_log_level = agent_config.get("config", {}).get("log_level").upper() - flare_log_level_int = logging.getLevelName(flare_log_level) - if type(flare_log_level_int) != int: - raise TypeError("Invalid log level provided: %s", flare_log_level_int) - - ddlogger = get_logger("ddtrace") - pid = os.getpid() - flare_file_path = TRACER_FLARE_DIRECTORY / pathlib.Path(f"tracer_python_{pid}.log") - self.original_log_level = ddlogger.level - - # Set the logger level to the more verbose between original and flare - # We do this valid_original_level check because if the log level is NOTSET, the value is 0 - # which is the minimum value. In this case, we just want to use the flare level, but still - # retain the original state as NOTSET/0 - valid_original_level = 100 if self.original_log_level == 0 else self.original_log_level - logger_level = min(valid_original_level, flare_log_level_int) - ddlogger.setLevel(logger_level) - self.file_handler = _add_file_handler( - ddlogger, flare_file_path.__str__(), flare_log_level, TRACER_FLARE_FILE_HANDLER_NAME - ) - - # Create and add config file - self._generate_config_file(pid) - - def send(self, configs: List[Any]): + ddlogger = get_logger("ddtrace") + pid = os.getpid() + flare_file_path = self.flare_dir / f"tracer_python_{pid}.log" + self.original_log_level = ddlogger.level + + # Set the logger level to the more verbose between original and flare + # We do this valid_original_level check because if the log level is NOTSET, the value is 0 + # which is the minimum value. In this case, we just want to use the flare level, but still + # retain the original state as NOTSET/0 + valid_original_level = ( + logging.CRITICAL if self.original_log_level == logging.NOTSET else self.original_log_level + ) + logger_level = min(valid_original_level, flare_log_level_int) + ddlogger.setLevel(logger_level) + self.file_handler = _add_file_handler( + ddlogger, flare_file_path.__str__(), flare_log_level_int, TRACER_FLARE_FILE_HANDLER_NAME + ) + + # Create and add config file + self._generate_config_file(pid) + + def send(self, flare_send_req: FlareSendRequest): """ Revert tracer flare configurations back to original state before sending the flare. 
""" - for agent_task in configs: - # AGENT_TASK is currently being used for multiple purposes - # We only want to generate the tracer flare if the task_type is - # 'tracer_flare' - if type(agent_task) != dict or agent_task.get("task_type") != "tracer_flare": - continue - args = agent_task.get("args", {}) - - self.revert_configs() - - # We only want the flare to be sent once, even if there are - # multiple tracer instances - lock_path = TRACER_FLARE_DIRECTORY / TRACER_FLARE_LOCK - if not os.path.exists(lock_path): - try: - open(lock_path, "w").close() - except Exception as e: - log.error("Failed to create %s file", lock_path) - raise e - data = { - "case_id": args.get("case_id"), - "source": "tracer_python", - "hostname": args.get("hostname"), - "email": args.get("user_handle"), - } - try: - client = get_connection(config._trace_agent_url, timeout=self.timeout) - headers, body = self._generate_payload(data) - client.request("POST", TRACER_FLARE_ENDPOINT, body, headers) - response = client.getresponse() - if response.status == 200: - log.info("Successfully sent the flare") - else: - log.error( - "Upload failed with %s status code:(%s) %s", - response.status, - response.reason, - response.read().decode(), - ) - except Exception as e: - log.error("Failed to send tracer flare") - raise e - finally: - client.close() - # Clean up files regardless of success/failure - self.clean_up_files() - return + self.revert_configs() + + # We only want the flare to be sent once, even if there are + # multiple tracer instances + lock_path = self.flare_dir / TRACER_FLARE_LOCK + if not os.path.exists(lock_path): + try: + open(lock_path, "w").close() + except Exception as e: + log.error("Failed to create %s file", lock_path) + raise e + try: + client = get_connection(config._trace_agent_url, timeout=self.timeout) + headers, body = self._generate_payload(flare_send_req.__dict__) + client.request("POST", TRACER_FLARE_ENDPOINT, body, headers) + response = client.getresponse() + if response.status == 200: + log.info("Successfully sent the flare to Zendesk ticket %s", flare_send_req.case_id) + else: + log.error( + "Tracer flare upload to Zendesk ticket %s failed with %s status code:(%s) %s", + flare_send_req.case_id, + response.status, + response.reason, + response.read().decode(), + ) + except Exception as e: + log.error("Failed to send tracer flare to Zendesk ticket %s", flare_send_req.case_id) + raise e + finally: + client.close() + # Clean up files regardless of success/failure + self.clean_up_files() + return def _generate_config_file(self, pid: int): - config_file = TRACER_FLARE_DIRECTORY / pathlib.Path(f"tracer_config_{pid}.json") + config_file = self.flare_dir / f"tracer_config_{pid}.json" try: with open(config_file, "w") as f: tracer_configs = { @@ -162,8 +149,7 @@ def revert_configs(self): def _generate_payload(self, params: Dict[str, str]) -> Tuple[dict, bytes]: tar_stream = io.BytesIO() with tarfile.open(fileobj=tar_stream, mode="w") as tar: - for file_name in os.listdir(TRACER_FLARE_DIRECTORY): - flare_file_name = TRACER_FLARE_DIRECTORY / pathlib.Path(file_name) + for flare_file_name in self.flare_dir.iterdir(): tar.add(flare_file_name) tar_stream.seek(0) @@ -197,6 +183,6 @@ def _get_valid_logger_level(self, flare_log_level: int) -> int: def clean_up_files(self): try: - shutil.rmtree(TRACER_FLARE_DIRECTORY) + shutil.rmtree(self.flare_dir) except Exception as e: log.warning("Failed to clean up tracer flare files: %s", e) diff --git a/ddtrace/internal/packages.py b/ddtrace/internal/packages.py index 
2d8f1c5fd1e..fcec01a463b 100644 --- a/ddtrace/internal/packages.py +++ b/ddtrace/internal/packages.py @@ -238,6 +238,11 @@ def is_third_party(path: Path) -> bool: return package.name in _third_party_packages() +@cached() +def is_user_code(path: Path) -> bool: + return not (is_stdlib(path) or is_third_party(path)) + + @cached() def is_distribution_available(name: str) -> bool: """Determine if a distribution is available in the current environment.""" diff --git a/ddtrace/internal/remoteconfig/client.py b/ddtrace/internal/remoteconfig/client.py index d21081c1d94..c2768e57bc6 100644 --- a/ddtrace/internal/remoteconfig/client.py +++ b/ddtrace/internal/remoteconfig/client.py @@ -75,6 +75,7 @@ class Capabilities(enum.IntFlag): APM_TRACING_HTTP_HEADER_TAGS = 1 << 14 APM_TRACING_CUSTOM_TAGS = 1 << 15 APM_TRACING_ENABLED = 1 << 19 + APM_TRACING_SAMPLE_RULES = 1 << 29 class RemoteConfigError(Exception): @@ -382,6 +383,7 @@ def _build_payload(self, state): | Capabilities.APM_TRACING_HTTP_HEADER_TAGS | Capabilities.APM_TRACING_CUSTOM_TAGS | Capabilities.APM_TRACING_ENABLED + | Capabilities.APM_TRACING_SAMPLE_RULES ) return dict( client=dict( diff --git a/ddtrace/internal/sampling.py b/ddtrace/internal/sampling.py index 0d5aa1a2784..267c575e8a5 100644 --- a/ddtrace/internal/sampling.py +++ b/ddtrace/internal/sampling.py @@ -62,6 +62,16 @@ class SamplingMechanism(object): REMOTE_RATE_USER = 6 REMOTE_RATE_DATADOG = 7 SPAN_SAMPLING_RULE = 8 + REMOTE_USER_RULE = 11 + REMOTE_DYNAMIC_RULE = 12 + + +class PriorityCategory(object): + DEFAULT = "default" + AUTO = "auto" + RULE_DEFAULT = "rule_default" + RULE_CUSTOMER = "rule_customer" + RULE_DYNAMIC = "rule_dynamic" # Use regex to validate trace tag value @@ -278,11 +288,17 @@ def is_single_span_sampled(span): def _set_sampling_tags(span, sampled, sample_rate, priority_category): # type: (Span, bool, float, str) -> None mechanism = SamplingMechanism.TRACE_SAMPLING_RULE - if priority_category == "rule": + if priority_category == PriorityCategory.RULE_DEFAULT: + span.set_metric(SAMPLING_RULE_DECISION, sample_rate) + if priority_category == PriorityCategory.RULE_CUSTOMER: + span.set_metric(SAMPLING_RULE_DECISION, sample_rate) + mechanism = SamplingMechanism.REMOTE_USER_RULE + if priority_category == PriorityCategory.RULE_DYNAMIC: span.set_metric(SAMPLING_RULE_DECISION, sample_rate) - elif priority_category == "default": + mechanism = SamplingMechanism.REMOTE_DYNAMIC_RULE + elif priority_category == PriorityCategory.DEFAULT: mechanism = SamplingMechanism.DEFAULT - elif priority_category == "auto": + elif priority_category == PriorityCategory.AUTO: mechanism = SamplingMechanism.AGENT_RATE span.set_metric(SAMPLING_AGENT_DECISION, sample_rate) priorities = _CATEGORY_TO_PRIORITIES[priority_category] diff --git a/ddtrace/internal/symbol_db/symbols.py b/ddtrace/internal/symbol_db/symbols.py index d454e9eb8f5..9f66ffa3a86 100644 --- a/ddtrace/internal/symbol_db/symbols.py +++ b/ddtrace/internal/symbol_db/symbols.py @@ -3,7 +3,7 @@ from dataclasses import field import dis from enum import Enum -import http +from http.client import HTTPResponse from inspect import CO_VARARGS from inspect import CO_VARKEYWORDS from inspect import isasyncgenfunction @@ -31,7 +31,6 @@ from ddtrace.internal.logger import get_logger from ddtrace.internal.module import BaseModuleWatchdog from ddtrace.internal.module import origin -from ddtrace.internal.packages import is_stdlib from ddtrace.internal.runtime import get_runtime_id from ddtrace.internal.safety import _isinstance from 
ddtrace.internal.utils.cache import cached @@ -50,10 +49,10 @@ @cached() -def is_from_stdlib(obj: t.Any) -> t.Optional[bool]: +def is_from_user_code(obj: t.Any) -> t.Optional[bool]: try: path = origin(sys.modules[object.__getattribute__(obj, "__module__")]) - return is_stdlib(path) if path is not None else None + return packages.is_user_code(path) if path is not None else None except (AttributeError, KeyError): return None @@ -182,9 +181,6 @@ def _(cls, module: ModuleType, data: ScopeData): symbols = [] scopes = [] - if is_stdlib(module_origin): - return None - for alias, child in object.__getattribute__(module, "__dict__").items(): if _isinstance(child, ModuleType): # We don't want to traverse other modules. @@ -224,7 +220,7 @@ def _(cls, obj: type, data: ScopeData): return None data.seen.add(obj) - if is_from_stdlib(obj): + if not is_from_user_code(obj): return None symbols = [] @@ -347,7 +343,7 @@ def _(cls, f: FunctionType, data: ScopeData): return None data.seen.add(f) - if is_from_stdlib(f): + if not is_from_user_code(f): return None code = f.__dd_wrapped__.__code__ if hasattr(f, "__dd_wrapped__") else f.__code__ @@ -416,7 +412,7 @@ def _(cls, pr: property, data: ScopeData): data.seen.add(pr.fget) # TODO: These names don't match what is reported by the discovery. - if pr.fget is None or is_from_stdlib(pr.fget): + if pr.fget is None or not is_from_user_code(pr.fget): return None path = func_origin(t.cast(FunctionType, pr.fget)) @@ -477,7 +473,7 @@ def to_json(self) -> dict: "scopes": [_.to_json() for _ in self._scopes], } - def upload(self) -> http.client.HTTPResponse: + def upload(self) -> HTTPResponse: body, headers = multipart( parts=[ FormData( @@ -509,14 +505,24 @@ def __len__(self) -> int: def is_module_included(module: ModuleType) -> bool: + # Check if module name matches the include patterns if symdb_config._includes_re.match(module.__name__): return True - package = packages.module_to_package(module) - if package is None: + # Check if it is user code + module_origin = origin(module) + if module_origin is None: return False - return symdb_config._includes_re.match(package.name) is not None + if packages.is_user_code(module_origin): + return True + + # Check if the package name matches the include patterns + package = packages.filename_to_package(module_origin) + if package is not None and symdb_config._includes_re.match(package.name): + return True + + return False class SymbolDatabaseUploader(BaseModuleWatchdog): diff --git a/ddtrace/internal/telemetry/writer.py b/ddtrace/internal/telemetry/writer.py index 06cea670c39..836daa5da74 100644 --- a/ddtrace/internal/telemetry/writer.py +++ b/ddtrace/internal/telemetry/writer.py @@ -83,7 +83,6 @@ from .constants import TELEMETRY_TRACE_PEER_SERVICE_MAPPING from .constants import TELEMETRY_TRACE_REMOVE_INTEGRATION_SERVICE_NAMES_ENABLED from .constants import TELEMETRY_TRACE_SAMPLING_LIMIT -from .constants import TELEMETRY_TRACE_SAMPLING_RULES from .constants import TELEMETRY_TRACE_SPAN_ATTRIBUTE_SCHEMA from .constants import TELEMETRY_TRACE_WRITER_BUFFER_SIZE_BYTES from .constants import TELEMETRY_TRACE_WRITER_INTERVAL_SECONDS @@ -386,6 +385,9 @@ def _telemetry_entry(self, cfg_name: str) -> Tuple[str, str, _ConfigSource]: elif cfg_name == "_trace_sample_rate": name = "trace_sample_rate" value = str(item.value()) + elif cfg_name == "_trace_sampling_rules": + name = "trace_sampling_rules" + value = str(item.value()) elif cfg_name == "logs_injection": name = "logs_injection_enabled" value = "true" if item.value() else "false" @@ -428,6 
+430,7 @@ def _app_started_event(self, register_app_shutdown=True): self._telemetry_entry("_sca_enabled"), self._telemetry_entry("_dsm_enabled"), self._telemetry_entry("_trace_sample_rate"), + self._telemetry_entry("_trace_sampling_rules"), self._telemetry_entry("logs_injection"), self._telemetry_entry("trace_http_header_tags"), self._telemetry_entry("tags"), @@ -462,7 +465,6 @@ def _app_started_event(self, register_app_shutdown=True): (TELEMETRY_TRACE_SAMPLING_LIMIT, config._trace_rate_limit, "unknown"), (TELEMETRY_SPAN_SAMPLING_RULES, config._sampling_rules, "unknown"), (TELEMETRY_SPAN_SAMPLING_RULES_FILE, config._sampling_rules_file, "unknown"), - (TELEMETRY_TRACE_SAMPLING_RULES, config._trace_sampling_rules, "unknown"), (TELEMETRY_PRIORITY_SAMPLING, config._priority_sampling, "unknown"), (TELEMETRY_PARTIAL_FLUSH_ENABLED, config._partial_flush_enabled, "unknown"), (TELEMETRY_PARTIAL_FLUSH_MIN_SPANS, config._partial_flush_min_spans, "unknown"), diff --git a/ddtrace/llmobs/_llmobs.py b/ddtrace/llmobs/_llmobs.py index b7cf05d8beb..411c68e84af 100644 --- a/ddtrace/llmobs/_llmobs.py +++ b/ddtrace/llmobs/_llmobs.py @@ -3,6 +3,7 @@ from typing import Any from typing import Dict from typing import Optional +from typing import Union import ddtrace from ddtrace import Span @@ -29,6 +30,7 @@ from ddtrace.llmobs._utils import _get_session_id from ddtrace.llmobs._writer import LLMObsEvalMetricWriter from ddtrace.llmobs._writer import LLMObsSpanWriter +from ddtrace.llmobs.utils import ExportedLLMObsSpan from ddtrace.llmobs.utils import Messages @@ -107,6 +109,32 @@ def disable(cls) -> None: cls.enabled = False log.debug("%s disabled", cls.__name__) + @classmethod + def export_span(cls, span: Optional[Span] = None) -> Optional[ExportedLLMObsSpan]: + """Returns a simple representation of a span to export its span and trace IDs. + If no span is provided, the current active LLMObs-type span will be used. + """ + if cls.enabled is False or cls._instance is None: + log.warning("LLMObs.export_span() requires LLMObs to be enabled.") + return None + if span: + try: + if span.span_type != SpanTypes.LLM: + log.warning("Span must be an LLMObs-generated span.") + return None + return ExportedLLMObsSpan(span_id=str(span.span_id), trace_id="{:x}".format(span.trace_id)) + except (TypeError, AttributeError): + log.warning("Failed to export span. Span must be a valid Span object.") + return None + span = cls._instance.tracer.current_span() + if span is None: + log.warning("No span provided and no active LLMObs-generated span found.") + return None + if span.span_type != SpanTypes.LLM: + log.warning("Span must be an LLMObs-generated span.") + return None + return ExportedLLMObsSpan(span_id=str(span.span_id), trace_id="{:x}".format(span.trace_id)) + def _start_span( self, operation_kind: str, @@ -281,10 +309,10 @@ def annotate( if span is None: span = cls._instance.tracer.current_span() if span is None: - log.warning("No span provided and no active span found.") + log.warning("No span provided and no active LLMObs-generated span found.") return if span.span_type != SpanTypes.LLM: - log.warning("Span must be an LLM-type span.") + log.warning("Span must be an LLMObs-generated span.") return if span.finished: log.warning("Cannot annotate a finished span.") @@ -402,3 +430,56 @@ def _tag_metrics(span: Span, metrics: Dict[str, Any]) -> None: span.set_tag_str(METRICS, json.dumps(metrics)) except TypeError: log.warning("Failed to parse span metrics. 
Metric key-value pairs must be JSON serializable.") + + @classmethod + def submit_evaluation( + cls, + span_context: Dict[str, str], + label: str, + metric_type: str, + value: Union[str, int, float], + ) -> None: + """ + Submits a custom evaluation metric for a given span ID and trace ID. + + :param span_context: A dictionary containing the span_id and trace_id of interest. + :param str label: The name of the evaluation metric. + :param str metric_type: The type of the evaluation metric. One of "categorical", "numerical", and "score". + :param value: The value of the evaluation metric. + Must be a string (categorical), integer (numerical/score), or float (numerical/score). + """ + if cls.enabled is False or cls._instance is None or cls._instance._llmobs_eval_metric_writer is None: + log.warning("LLMObs.submit_evaluation() requires LLMObs to be enabled.") + return + if not isinstance(span_context, dict): + log.warning( + "span_context must be a dictionary containing both span_id and trace_id keys. " + "LLMObs.export_span() can be used to generate this dictionary from a given span." + ) + return + span_id = span_context.get("span_id") + trace_id = span_context.get("trace_id") + if not (span_id and trace_id): + log.warning("span_id and trace_id must both be specified for the given evaluation metric to be submitted.") + return + if not label: + log.warning("label must be the specified name of the evaluation metric.") + return + if not metric_type or metric_type.lower() not in ("categorical", "numerical", "score"): + log.warning("metric_type must be one of 'categorical', 'numerical', or 'score'.") + return + if metric_type == "categorical" and not isinstance(value, str): + log.warning("value must be a string for a categorical metric.") + return + if metric_type in ("numerical", "score") and not isinstance(value, (int, float)): + log.warning("value must be an integer or float for a numerical/score metric.") + return + cls._instance._llmobs_eval_metric_writer.enqueue( + { + "span_id": span_id, + "trace_id": trace_id, + "label": str(label), + "metric_type": metric_type.lower(), + "{}_value".format(metric_type): value, + } + ) diff --git a/ddtrace/llmobs/_writer.py b/ddtrace/llmobs/_writer.py index 8380f861f0c..a90251fd6c4 100644 --- a/ddtrace/llmobs/_writer.py +++ b/ddtrace/llmobs/_writer.py @@ -66,7 +66,7 @@ def __init__(self, site: str, api_key: str, interval: float, timeout: float) -> def start(self, *args, **kwargs): super(BaseLLMObsWriter, self).start() - logger.debug("started %r to %r", (self.__class__.__name__, self._url)) + logger.debug("started %r to %r", self.__class__.__name__, self._url) atexit.register(self.on_shutdown) def on_shutdown(self): @@ -76,7 +76,7 @@ def _enqueue(self, event: Union[LLMObsSpanEvent, LLMObsEvaluationMetricEvent]) - with self._lock: if len(self._buffer) >= self._buffer_limit: logger.warning( - "%r event buffer full (limit is %d), dropping event", (self.__class__.__name__, self._buffer_limit) + "%r event buffer full (limit is %d), dropping event", self.__class__.__name__, self._buffer_limit ) return self._buffer.append(event) @@ -92,7 +92,7 @@ def periodic(self) -> None: try: enc_llm_events = json.dumps(data) except TypeError: - logger.error("failed to encode %d LLMObs %s events", (len(events), self._event_type), exc_info=True) + logger.error("failed to encode %d LLMObs %s events", len(events), self._event_type, exc_info=True) return conn = httplib.HTTPSConnection(self._intake, 443, timeout=self._timeout) try: @@ -101,19 +101,17 @@ def periodic(self) -> None: if 
resp.status >= 300: logger.error( "failed to send %d LLMObs %s events to %s, got response code %d, status: %s", - ( - len(events), - self._event_type, - self._url, - resp.status, - resp.read(), - ), + len(events), + self._event_type, + self._url, + resp.status, + resp.read(), ) else: - logger.debug("sent %d LLMObs %s events to %s", (len(events), self._event_type, self._url)) + logger.debug("sent %d LLMObs %s events to %s", len(events), self._event_type, self._url) except Exception: logger.error( - "failed to send %d LLMObs %s events to %s", (len(events), self._event_type, self._intake), exc_info=True + "failed to send %d LLMObs %s events to %s", len(events), self._event_type, self._intake, exc_info=True ) finally: conn.close() diff --git a/ddtrace/llmobs/utils.py b/ddtrace/llmobs/utils.py index 997a26c0b85..1fbb7305c36 100644 --- a/ddtrace/llmobs/utils.py +++ b/ddtrace/llmobs/utils.py @@ -15,6 +15,7 @@ log = get_logger(__name__) +ExportedLLMObsSpan = TypedDict("ExportedLLMObsSpan", {"span_id": str, "trace_id": str}) Message = TypedDict("Message", {"content": str, "role": str}, total=False) diff --git a/ddtrace/sampler.py b/ddtrace/sampler.py index 69cc58c73d7..fe558c1f426 100644 --- a/ddtrace/sampler.py +++ b/ddtrace/sampler.py @@ -23,6 +23,8 @@ from .settings import _config as ddconfig +PROVENANCE_ORDER = ["customer", "dynamic", "default"] + try: from json.decoder import JSONDecodeError except ImportError: @@ -158,7 +160,7 @@ def _choose_priority_category(self, sampler): elif isinstance(sampler, _AgentRateSampler): return _PRIORITY_CATEGORY.AUTO else: - return _PRIORITY_CATEGORY.RULE + return _PRIORITY_CATEGORY.RULE_DEF def _make_sampling_decision(self, span): # type: (Span) -> Tuple[bool, BaseSampler] @@ -204,7 +206,7 @@ class DatadogSampler(RateByServiceSampler): per second. 
""" - __slots__ = ("limiter", "rules") + __slots__ = ("limiter", "rules", "default_sample_rate") NO_RATE_LIMIT = -1 # deprecate and remove the DEFAULT_RATE_LIMIT field from DatadogSampler @@ -228,7 +230,7 @@ def __init__( """ # Use default sample rate of 1.0 super(DatadogSampler, self).__init__() - + self.default_sample_rate = default_sample_rate if default_sample_rate is None: if ddconfig._get_source("_trace_sample_rate") != "default": default_sample_rate = float(ddconfig._trace_sample_rate) @@ -239,7 +241,7 @@ def __init__( if rules is None: env_sampling_rules = ddconfig._trace_sampling_rules if env_sampling_rules: - rules = self._parse_rules_from_env_variable(env_sampling_rules) + rules = self._parse_rules_from_str(env_sampling_rules) else: rules = [] self.rules = rules @@ -268,7 +270,8 @@ def __str__(self): __repr__ = __str__ - def _parse_rules_from_env_variable(self, rules): + @staticmethod + def _parse_rules_from_str(rules): # type: (str) -> List[SamplingRule] sampling_rules = [] try: @@ -283,13 +286,22 @@ def _parse_rules_from_env_variable(self, rules): name = rule.get("name", SamplingRule.NO_RULE) resource = rule.get("resource", SamplingRule.NO_RULE) tags = rule.get("tags", SamplingRule.NO_RULE) + provenance = rule.get("provenance", "default") try: sampling_rule = SamplingRule( - sample_rate=sample_rate, service=service, name=name, resource=resource, tags=tags + sample_rate=sample_rate, + service=service, + name=name, + resource=resource, + tags=tags, + provenance=provenance, ) except ValueError as e: raise ValueError("Error creating sampling rule {}: {}".format(json.dumps(rule), e)) sampling_rules.append(sampling_rule) + + # Sort the sampling_rules list using a lambda function as the key + sampling_rules = sorted(sampling_rules, key=lambda rule: PROVENANCE_ORDER.index(rule.provenance)) return sampling_rules def sample(self, span): @@ -320,7 +332,13 @@ def sample(self, span): def _choose_priority_category_with_rule(self, rule, sampler): # type: (Optional[SamplingRule], BaseSampler) -> str if rule: - return _PRIORITY_CATEGORY.RULE + provenance = rule.provenance + if provenance == "customer": + return _PRIORITY_CATEGORY.RULE_CUSTOMER + if provenance == "dynamic": + return _PRIORITY_CATEGORY.RULE_DYNAMIC + return _PRIORITY_CATEGORY.RULE_DEF + if self.limiter._has_been_configured: return _PRIORITY_CATEGORY.USER return super(DatadogSampler, self)._choose_priority_category(sampler) diff --git a/ddtrace/sampling_rule.py b/ddtrace/sampling_rule.py index aecf03de5ab..72ab1574277 100644 --- a/ddtrace/sampling_rule.py +++ b/ddtrace/sampling_rule.py @@ -34,6 +34,7 @@ def __init__( name=NO_RULE, # type: Any resource=NO_RULE, # type: Any tags=NO_RULE, # type: Any + provenance="default", # type: str ): # type: (...) 
-> None """ @@ -83,6 +84,7 @@ def __init__( self.service = self.choose_matcher(service) self.name = self.choose_matcher(name) self.resource = self.choose_matcher(resource) + self.provenance = provenance @property def sample_rate(self): @@ -236,13 +238,14 @@ def choose_matcher(self, prop): return GlobMatcher(prop) if prop != SamplingRule.NO_RULE else SamplingRule.NO_RULE def __repr__(self): - return "{}(sample_rate={!r}, service={!r}, name={!r}, resource={!r}, tags={!r})".format( + return "{}(sample_rate={!r}, service={!r}, name={!r}, resource={!r}, tags={!r}, provenance={!r})".format( self.__class__.__name__, self.sample_rate, self._no_rule_or_self(self.service), self._no_rule_or_self(self.name), self._no_rule_or_self(self.resource), self._no_rule_or_self(self.tags), + self.provenance, ) __str__ = __repr__ diff --git a/ddtrace/settings/config.py b/ddtrace/settings/config.py index c49abc83bae..fc0083d6222 100644 --- a/ddtrace/settings/config.py +++ b/ddtrace/settings/config.py @@ -1,4 +1,5 @@ from copy import deepcopy +import json import os import re import sys @@ -282,6 +283,11 @@ def _default_config(): default=1.0, envs=[("DD_TRACE_SAMPLE_RATE", float)], ), + "_trace_sampling_rules": _ConfigItem( + name="trace_sampling_rules", + default=lambda: "", + envs=[("DD_TRACE_SAMPLING_RULES", str)], + ), "logs_injection": _ConfigItem( name="logs_injection", default=False, @@ -384,7 +390,6 @@ def __init__(self): self._startup_logs_enabled = asbool(os.getenv("DD_TRACE_STARTUP_LOGS", False)) self._trace_rate_limit = int(os.getenv("DD_TRACE_RATE_LIMIT", default=DEFAULT_SAMPLING_RATE_LIMIT)) - self._trace_sampling_rules = os.getenv("DD_TRACE_SAMPLING_RULES") self._partial_flush_enabled = asbool(os.getenv("DD_TRACE_PARTIAL_FLUSH_ENABLED", default=True)) self._partial_flush_min_spans = int(os.getenv("DD_TRACE_PARTIAL_FLUSH_MIN_SPANS", default=300)) self._priority_sampling = asbool(os.getenv("DD_PRIORITY_SAMPLING", default=True)) @@ -562,7 +567,6 @@ def __init__(self): def __getattr__(self, name) -> Any: if name in self._config: return self._config[name].value() - if name not in self._integration_configs: self._integration_configs[name] = IntegrationConfig(self, name) @@ -753,6 +757,14 @@ def _handle_remoteconfig(self, data, test_tracer=None): if "tracing_sampling_rate" in lib_config: base_rc_config["_trace_sample_rate"] = lib_config["tracing_sampling_rate"] + if "tracing_sampling_rules" in lib_config: + trace_sampling_rules = lib_config["tracing_sampling_rules"] + if trace_sampling_rules: + # returns None if no rules + trace_sampling_rules = self.convert_rc_trace_sampling_rules(trace_sampling_rules) + if trace_sampling_rules: + base_rc_config["_trace_sampling_rules"] = trace_sampling_rules + if "log_injection_enabled" in lib_config: base_rc_config["logs_injection"] = lib_config["log_injection_enabled"] @@ -802,3 +814,60 @@ def enable_remote_configuration(self): remoteconfig_poller.register("APM_TRACING", remoteconfig_pubsub) remoteconfig_poller.register("AGENT_CONFIG", remoteconfig_pubsub) remoteconfig_poller.register("AGENT_TASK", remoteconfig_pubsub) + + def _remove_invalid_rules(self, rc_rules: List) -> List: + """Remove invalid sampling rules from the given list""" + # loop through list of dictionaries, if a dictionary doesn't have certain attributes, remove it + for rule in rc_rules: + if ( + ("service" not in rule and "name" not in rule and "resource" not in rule and "tags" not in rule) + or "sample_rate" not in rule + or "provenance" not in rule + ): + log.debug("Invalid sampling rule from 
remoteconfig found, rule will be removed: %s", rule) + rc_rules.remove(rule) + + return rc_rules + + def _tags_to_dict(self, tags: List[Dict]): + """ + Converts a list of tag dictionaries to a single dictionary. + """ + if isinstance(tags, list): + return {tag["key"]: tag["value_glob"] for tag in tags} + return tags + + def convert_rc_trace_sampling_rules(self, rc_rules: List[Dict[str, Any]]) -> Optional[str]: + """Example of an incoming rule: + [ + { + "service": "my-service", + "name": "web.request", + "resource": "*", + "provenance": "customer", + "sample_rate": 1.0, + "tags": [ + { + "key": "care_about", + "value_glob": "yes" + }, + { + "key": "region", + "value_glob": "us-*" + } + ] + } + ] + + Example of a converted rule: + '[{"sample_rate":1.0,"service":"my-service","resource":"*","name":"web.request","tags":{"care_about":"yes","region":"us-*"},"provenance":"customer"}]' + """ + rc_rules = self._remove_invalid_rules(rc_rules) + for rule in rc_rules: + tags = rule.get("tags") + if tags: + rule["tags"] = self._tags_to_dict(tags) + if rc_rules: + return json.dumps(rc_rules) + else: + return None diff --git a/tests/.suitespec.json b/tests/.suitespec.json index 7e6f1512ec4..143ef63a62e 100644 --- a/tests/.suitespec.json +++ b/tests/.suitespec.json @@ -141,6 +141,7 @@ "ddtrace/contrib/yaaredis/*", "ddtrace/_trace/utils_redis.py", "ddtrace/contrib/redis_utils.py", + "ddtrace/contrib/trace_utils_redis.py", "ddtrace/ext/redis.py" ], "mongo": [ diff --git a/tests/contrib/botocore/test.py b/tests/contrib/botocore/test.py index 8709964db6b..aa9627169a6 100644 --- a/tests/contrib/botocore/test.py +++ b/tests/contrib/botocore/test.py @@ -312,7 +312,7 @@ def test_s3_client(self): @mock_s3 def test_s3_head_404_default(self): """ - By default we attach exception information to s3 HeadObject + By default we do not attach exception information to s3 HeadObject API calls with a 404 response """ s3 = self.session.create_client("s3", region_name="us-west-2") diff --git a/tests/integration/test_debug.py b/tests/integration/test_debug.py index 0d486355c26..cf5520dcb7c 100644 --- a/tests/integration/test_debug.py +++ b/tests/integration/test_debug.py @@ -336,7 +336,8 @@ def test_startup_logs_sampling_rules(): f = debug.collect(tracer) assert f.get("sampler_rules") == [ - "SamplingRule(sample_rate=1.0, service='NO_RULE', name='NO_RULE', resource='NO_RULE', tags='NO_RULE')" + "SamplingRule(sample_rate=1.0, service='NO_RULE', name='NO_RULE', resource='NO_RULE'," + " tags='NO_RULE', provenance='default')" ] sampler = ddtrace.sampler.DatadogSampler( @@ -346,7 +347,8 @@ def test_startup_logs_sampling_rules(): f = debug.collect(tracer) assert f.get("sampler_rules") == [ - "SamplingRule(sample_rate=1.0, service='xyz', name='abc', resource='NO_RULE', tags='NO_RULE')" + "SamplingRule(sample_rate=1.0, service='xyz', name='abc', resource='NO_RULE'," + " tags='NO_RULE', provenance='default')" ] diff --git a/tests/internal/remoteconfig/test_remoteconfig.py b/tests/internal/remoteconfig/test_remoteconfig.py index deaa2790bde..feb83b775d6 100644 --- a/tests/internal/remoteconfig/test_remoteconfig.py +++ b/tests/internal/remoteconfig/test_remoteconfig.py @@ -10,6 +10,7 @@ from mock.mock import ANY import pytest +from ddtrace import config from ddtrace.internal.remoteconfig._connectors import PublisherSubscriberConnector from ddtrace.internal.remoteconfig._publishers import RemoteConfigPublisherMergeDicts from ddtrace.internal.remoteconfig._pubsub import PubSub @@ -20,6 +21,8 @@ from ddtrace.internal.remoteconfig.worker
import RemoteConfigPoller from ddtrace.internal.remoteconfig.worker import remoteconfig_poller from ddtrace.internal.service import ServiceStatus +from ddtrace.sampler import DatadogSampler +from ddtrace.sampling_rule import SamplingRule from tests.internal.test_utils_version import _assert_and_get_version_agent_format from tests.utils import override_global_config @@ -428,3 +431,161 @@ def test_rc_default_products_registered(): assert bool(remoteconfig_poller._client._products.get("APM_TRACING")) == rc_enabled assert bool(remoteconfig_poller._client._products.get("AGENT_CONFIG")) == rc_enabled assert bool(remoteconfig_poller._client._products.get("AGENT_TASK")) == rc_enabled + + +@pytest.mark.parametrize( + "rc_rules,expected_config_rules,expected_sampling_rules", + [ + ( + [ # Test with all fields + { + "service": "my-service", + "name": "web.request", + "resource": "*", + "provenance": "dynamic", + "sample_rate": 1.0, + "tags": [{"key": "care_about", "value_glob": "yes"}, {"key": "region", "value_glob": "us-*"}], + } + ], + '[{"service": "my-service", "name": "web.request", "resource": "*", "provenance": "dynamic",' + ' "sample_rate": 1.0, "tags": {"care_about": "yes", "region": "us-*"}}]', + [ + SamplingRule( + sample_rate=1.0, + service="my-service", + name="web.request", + resource="*", + tags={"care_about": "yes", "region": "us-*"}, + provenance="dynamic", + ) + ], + ), + ( # Test with no service + [ + { + "name": "web.request", + "resource": "*", + "provenance": "customer", + "sample_rate": 1.0, + "tags": [{"key": "care_about", "value_glob": "yes"}, {"key": "region", "value_glob": "us-*"}], + } + ], + '[{"name": "web.request", "resource": "*", "provenance": "customer", "sample_rate": 1.0, "tags": ' + '{"care_about": "yes", "region": "us-*"}}]', + [ + SamplingRule( + sample_rate=1.0, + service=SamplingRule.NO_RULE, + name="web.request", + resource="*", + tags={"care_about": "yes", "region": "us-*"}, + provenance="customer", + ) + ], + ), + ( + # Test with no tags + [ + { + "service": "my-service", + "name": "web.request", + "resource": "*", + "provenance": "customer", + "sample_rate": 1.0, + } + ], + '[{"service": "my-service", "name": "web.request", "resource": "*", "provenance": ' + '"customer", "sample_rate": 1.0}]', + [ + SamplingRule( + sample_rate=1.0, + service="my-service", + name="web.request", + resource="*", + tags=SamplingRule.NO_RULE, + provenance="customer", + ) + ], + ), + ( + # Test with no resource + [ + { + "service": "my-service", + "name": "web.request", + "provenance": "customer", + "sample_rate": 1.0, + "tags": [{"key": "care_about", "value_glob": "yes"}, {"key": "region", "value_glob": "us-*"}], + } + ], + '[{"service": "my-service", "name": "web.request", "provenance": "customer", "sample_rate": 1.0, "tags":' + ' {"care_about": "yes", "region": "us-*"}}]', + [ + SamplingRule( + sample_rate=1.0, + service="my-service", + name="web.request", + resource=SamplingRule.NO_RULE, + tags={"care_about": "yes", "region": "us-*"}, + provenance="customer", + ) + ], + ), + ( + # Test with no name + [ + { + "service": "my-service", + "resource": "*", + "provenance": "customer", + "sample_rate": 1.0, + "tags": [{"key": "care_about", "value_glob": "yes"}, {"key": "region", "value_glob": "us-*"}], + } + ], + '[{"service": "my-service", "resource": "*", "provenance": "customer", "sample_rate": 1.0, "tags":' + ' {"care_about": "yes", "region": "us-*"}}]', + [ + SamplingRule( + sample_rate=1.0, + service="my-service", + name=SamplingRule.NO_RULE, + resource="*", + 
tags={"care_about": "yes", "region": "us-*"}, + provenance="customer", + ) + ], + ), + ( + # Test with no sample rate + [ + { + "service": "my-service", + "name": "web.request", + "resource": "*", + "provenance": "customer", + "tags": [{"key": "care_about", "value_glob": "yes"}, {"key": "region", "value_glob": "us-*"}], + } + ], + None, + None, + ), + ( + # Test with no service, name, resource, tags + [ + { + "provenance": "customer", + "sample_rate": 1.0, + } + ], + None, + None, + ), + ], +) +def test_trace_sampling_rules_conversion(rc_rules, expected_config_rules, expected_sampling_rules): + trace_sampling_rules = config.convert_rc_trace_sampling_rules(rc_rules) + + assert trace_sampling_rules == expected_config_rules + if trace_sampling_rules is not None: + parsed_rules = DatadogSampler._parse_rules_from_str(trace_sampling_rules) + assert parsed_rules == expected_sampling_rules diff --git a/tests/internal/remoteconfig/test_remoteconfig_client_e2e.py b/tests/internal/remoteconfig/test_remoteconfig_client_e2e.py index ad6a9e4436c..760fa4a2e7b 100644 --- a/tests/internal/remoteconfig/test_remoteconfig_client_e2e.py +++ b/tests/internal/remoteconfig/test_remoteconfig_client_e2e.py @@ -18,7 +18,7 @@ def _expected_payload( rc_client, - capabilities="CPAA", # this was gathered by running the test and observing the payload + capabilities="IAjwAA==", # this was gathered by running the test and observing the payload has_errors=False, targets_version=0, backend_client_state=None, diff --git a/tests/internal/symbol_db/test_symbols.py b/tests/internal/symbol_db/test_symbols.py index 4c879b63e5c..a97f6c5bcee 100644 --- a/tests/internal/symbol_db/test_symbols.py +++ b/tests/internal/symbol_db/test_symbols.py @@ -203,20 +203,11 @@ def test_symbols_upload_enabled(): assert remoteconfig_poller.get_registered("LIVE_DEBUGGING_SYMBOL_DB") is not None -@pytest.mark.subprocess( - ddtrace_run=True, - env=dict( - DD_SYMBOL_DATABASE_UPLOAD_ENABLED="1", - _DD_SYMBOL_DATABASE_FORCE_UPLOAD="1", - DD_SYMBOL_DATABASE_INCLUDES="tests.submod.stuff", - ), -) +@pytest.mark.subprocess(ddtrace_run=True, env=dict(DD_SYMBOL_DATABASE_INCLUDES="tests.submod.stuff")) def test_symbols_force_upload(): from ddtrace.internal.symbol_db.symbols import ScopeType from ddtrace.internal.symbol_db.symbols import SymbolDatabaseUploader - assert SymbolDatabaseUploader.is_installed() - contexts = [] def _upload_context(context): @@ -224,11 +215,18 @@ def _upload_context(context): SymbolDatabaseUploader._upload_context = staticmethod(_upload_context) + SymbolDatabaseUploader.install() + + def get_scope(contexts, name): + for context in (_.to_json() for _ in contexts): + for scope in context["scopes"]: + if scope["name"] == name: + return scope + raise ValueError(f"Scope {name} not found in {contexts}") + import tests.submod.stuff # noqa import tests.submod.traced_stuff # noqa - (context,) = contexts - - (scope,) = context.to_json()["scopes"] + scope = get_scope(contexts, "tests.submod.stuff") assert scope["scope_type"] == ScopeType.MODULE assert scope["name"] == "tests.submod.stuff" diff --git a/tests/internal/test_settings.py b/tests/internal/test_settings.py index faea554f489..4bae14bfef9 100644 --- a/tests/internal/test_settings.py +++ b/tests/internal/test_settings.py @@ -256,6 +256,216 @@ def test_remoteconfig_sampling_rate_user(run_python_code_in_subprocess): assert status == 0, err.decode("utf-8") +def test_remoteconfig_sampling_rules(run_python_code_in_subprocess): + env = os.environ.copy() + env.update({"DD_TRACE_SAMPLING_RULES": 
'[{"sample_rate":0.1, "name":"test"}]'}) + + out, err, status, _ = run_python_code_in_subprocess( + """ +from ddtrace import config, tracer +from ddtrace.sampler import DatadogSampler +from tests.internal.test_settings import _base_rc_config, _deleted_rc_config + +with tracer.trace("test") as span: + pass +assert span.get_metric("_dd.rule_psr") == 0.1 +assert span.get_tag("_dd.p.dm") == "-3" + +config._handle_remoteconfig(_base_rc_config({"tracing_sampling_rules":[ + { + "service": "*", + "name": "test", + "resource": "*", + "provenance": "customer", + "sample_rate": 0.2, + } + ]})) +with tracer.trace("test") as span: + pass +assert span.get_metric("_dd.rule_psr") == 0.2 +assert span.get_tag("_dd.p.dm") == "-11" + +config._handle_remoteconfig(_base_rc_config({})) +with tracer.trace("test") as span: + pass +assert span.get_metric("_dd.rule_psr") == 0.1 + +custom_sampler = DatadogSampler(DatadogSampler._parse_rules_from_str('[{"sample_rate":0.3, "name":"test"}]')) +tracer.configure(sampler=custom_sampler) +with tracer.trace("test") as span: + pass +assert span.get_metric("_dd.rule_psr") == 0.3 +assert span.get_tag("_dd.p.dm") == "-3" + +config._handle_remoteconfig(_base_rc_config({"tracing_sampling_rules":[ + { + "service": "*", + "name": "test", + "resource": "*", + "provenance": "dynamic", + "sample_rate": 0.4, + } + ]})) +with tracer.trace("test") as span: + pass +assert span.get_metric("_dd.rule_psr") == 0.4 +assert span.get_tag("_dd.p.dm") == "-12" + +config._handle_remoteconfig(_base_rc_config({})) +with tracer.trace("test") as span: + pass +assert span.get_metric("_dd.rule_psr") == 0.3 +assert span.get_tag("_dd.p.dm") == "-3" + +config._handle_remoteconfig(_base_rc_config({"tracing_sampling_rules":[ + { + "service": "ok", + "name": "test", + "resource": "*", + "provenance": "customer", + "sample_rate": 0.4, + } + ]})) +with tracer.trace(service="ok", name="test") as span: + pass +assert span.get_metric("_dd.rule_psr") == 0.4 +assert span.get_tag("_dd.p.dm") == "-11" + +config._handle_remoteconfig(_deleted_rc_config()) +with tracer.trace("test") as span: + pass +assert span.get_metric("_dd.rule_psr") == 0.3 +assert span.get_tag("_dd.p.dm") == "-3" + + """, + env=env, + ) + assert status == 0, err.decode("utf-8") + + +def test_remoteconfig_sample_rate_and_rules(run_python_code_in_subprocess): + """There is complex logic regarding the interaction between setting new + sample rates and rules with remote config. 
+ """ + env = os.environ.copy() + env.update({"DD_TRACE_SAMPLING_RULES": '[{"sample_rate":0.9, "name":"rules"}]'}) + env.update({"DD_TRACE_SAMPLE_RATE": "0.8"}) + + out, err, status, _ = run_python_code_in_subprocess( + """ +from ddtrace import config, tracer +from ddtrace.sampler import DatadogSampler +from tests.internal.test_settings import _base_rc_config, _deleted_rc_config + +with tracer.trace("rules") as span: + pass +assert span.get_metric("_dd.rule_psr") == 0.9 +assert span.get_tag("_dd.p.dm") == "-3" + +with tracer.trace("sample_rate") as span: + pass +assert span.get_metric("_dd.rule_psr") == 0.8 +assert span.get_tag("_dd.p.dm") == "-3" + + +config._handle_remoteconfig(_base_rc_config({"tracing_sampling_rules":[ + { + "service": "*", + "name": "rules", + "resource": "*", + "provenance": "customer", + "sample_rate": 0.7, + } + ]})) + +with tracer.trace("rules") as span: + pass +assert span.get_metric("_dd.rule_psr") == 0.7 +assert span.get_tag("_dd.p.dm") == "-11" + +with tracer.trace("sample_rate") as span: + pass +assert span.get_metric("_dd.rule_psr") == 0.8 +assert span.get_tag("_dd.p.dm") == "-3" + + +config._handle_remoteconfig(_base_rc_config({"tracing_sampling_rate": 0.2})) + +with tracer.trace("sample_rate") as span: + pass +assert span.get_metric("_dd.rule_psr") == 0.2 +assert span.get_tag("_dd.p.dm") == "-3" + +with tracer.trace("rules") as span: + pass +assert span.get_metric("_dd.rule_psr") == 0.9 +assert span.get_tag("_dd.p.dm") == "-3" + + +config._handle_remoteconfig(_base_rc_config({"tracing_sampling_rate": 0.3})) + +with tracer.trace("sample_rate") as span: + pass +assert span.get_metric("_dd.rule_psr") == 0.3 +assert span.get_tag("_dd.p.dm") == "-3" + +with tracer.trace("rules") as span: + pass +assert span.get_metric("_dd.rule_psr") == 0.9 +assert span.get_tag("_dd.p.dm") == "-3" + + +config._handle_remoteconfig(_base_rc_config({})) + +with tracer.trace("rules") as span: + pass +assert span.get_metric("_dd.rule_psr") == 0.9 +assert span.get_tag("_dd.p.dm") == "-3" + +with tracer.trace("sample_rate") as span: + pass +assert span.get_metric("_dd.rule_psr") == 0.8 +assert span.get_tag("_dd.p.dm") == "-3" + + +config._handle_remoteconfig(_base_rc_config({"tracing_sampling_rules":[ + { + "service": "*", + "name": "rules_dynamic", + "resource": "*", + "provenance": "dynamic", + "sample_rate": 0.1, + }, + { + "service": "*", + "name": "rules_customer", + "resource": "*", + "provenance": "customer", + "sample_rate": 0.6, + } + ]})) + +with tracer.trace("rules_dynamic") as span: + pass +assert span.get_metric("_dd.rule_psr") == 0.1 +assert span.get_tag("_dd.p.dm") == "-12" + +with tracer.trace("rules_customer") as span: + pass +assert span.get_metric("_dd.rule_psr") == 0.6 +assert span.get_tag("_dd.p.dm") == "-11" + +with tracer.trace("sample_rate") as span: + pass +assert span.get_metric("_dd.rule_psr") == 0.8 +assert span.get_tag("_dd.p.dm") == "-3" + + """, + env=env, + ) + assert status == 0, err.decode("utf-8") + + def test_remoteconfig_custom_tags(run_python_code_in_subprocess): env = os.environ.copy() env.update({"DD_TAGS": "team:apm"}) diff --git a/tests/internal/test_tracer_flare.py b/tests/internal/test_tracer_flare.py index 35f38674e67..7051190e17d 100644 --- a/tests/internal/test_tracer_flare.py +++ b/tests/internal/test_tracer_flare.py @@ -2,13 +2,16 @@ from logging import Logger import multiprocessing import os +import pathlib from typing import Optional import unittest from unittest import mock +import uuid from ddtrace.internal.flare import 
TRACER_FLARE_DIRECTORY from ddtrace.internal.flare import TRACER_FLARE_FILE_HANDLER_NAME from ddtrace.internal.flare import Flare +from ddtrace.internal.flare import FlareSendRequest from ddtrace.internal.logger import get_logger @@ -16,25 +19,17 @@ class TracerFlareTests(unittest.TestCase): - mock_agent_config = [{"name": "flare-log-level", "config": {"log_level": "DEBUG"}}] - mock_agent_task = [ - False, - { - "args": { - "case_id": "1111111", - "hostname": "myhostname", - "user_handle": "user.name@datadoghq.com", - }, - "task_type": "tracer_flare", - "uuid": "d53fc8a4-8820-47a2-aa7d-d565582feb81", - }, - ] + mock_flare_send_request = FlareSendRequest( + case_id="1111111", hostname="myhostname", email="user.name@datadoghq.com" + ) def setUp(self): - self.flare = Flare() + self.flare_uuid = uuid.uuid4() + self.flare_dir = f"{TRACER_FLARE_DIRECTORY}-{self.flare_uuid}" + self.flare = Flare(flare_dir=pathlib.Path(self.flare_dir)) self.pid = os.getpid() - self.flare_file_path = f"{TRACER_FLARE_DIRECTORY}/tracer_python_{self.pid}.log" - self.config_file_path = f"{TRACER_FLARE_DIRECTORY}/tracer_config_{self.pid}.json" + self.flare_file_path = f"{self.flare_dir}/tracer_python_{self.pid}.log" + self.config_file_path = f"{self.flare_dir}/tracer_config_{self.pid}.json" def tearDown(self): self.confirm_cleanup() @@ -53,7 +48,7 @@ def test_single_process_success(self): """ ddlogger = get_logger("ddtrace") - self.flare.prepare(self.mock_agent_config) + self.flare.prepare("DEBUG") file_handler = self._get_handler() valid_logger_level = self.flare._get_valid_logger_level(DEBUG_LEVEL_INT) @@ -66,7 +61,7 @@ def test_single_process_success(self): # Sends request to testagent # This just validates the request params - self.flare.send(self.mock_agent_task) + self.flare.send(self.mock_flare_send_request) def test_single_process_partial_failure(self): """ @@ -79,7 +74,7 @@ def test_single_process_partial_failure(self): # Mock the partial failure with mock.patch("json.dump") as mock_json: mock_json.side_effect = Exception("file issue happened") - self.flare.prepare(self.mock_agent_config) + self.flare.prepare("DEBUG") file_handler = self._get_handler() assert file_handler is not None @@ -89,7 +84,7 @@ def test_single_process_partial_failure(self): assert os.path.exists(self.flare_file_path) assert not os.path.exists(self.config_file_path) - self.flare.send(self.mock_agent_task) + self.flare.send(self.mock_flare_send_request) def test_multiple_process_success(self): """ @@ -99,10 +94,10 @@ def test_multiple_process_success(self): num_processes = 3 def handle_agent_config(): - self.flare.prepare(self.mock_agent_config) + self.flare.prepare("DEBUG") def handle_agent_task(): - self.flare.send(self.mock_agent_task) + self.flare.send(self.mock_flare_send_request) # Create multiple processes for _ in range(num_processes): @@ -114,7 +109,7 @@ def handle_agent_task(): # Assert that each process wrote its file successfully # We double the process number because each will generate a log file and a config file - assert len(processes) * 2 == len(os.listdir(TRACER_FLARE_DIRECTORY)) + assert len(processes) * 2 == len(os.listdir(self.flare_dir)) for _ in range(num_processes): p = multiprocessing.Process(target=handle_agent_task) @@ -130,19 +125,19 @@ def test_multiple_process_partial_failure(self): """ processes = [] - def do_tracer_flare(agent_config, agent_task): - self.flare.prepare(agent_config) + def do_tracer_flare(prep_request, send_request): + self.flare.prepare(prep_request) # Assert that only one process wrote its 
file successfully # We check for 2 files because it will generate a log file and a config file - assert 2 == len(os.listdir(TRACER_FLARE_DIRECTORY)) - self.flare.send(agent_task) + assert 2 == len(os.listdir(self.flare_dir)) + self.flare.send(send_request) # Create successful process - p = multiprocessing.Process(target=do_tracer_flare, args=(self.mock_agent_config, self.mock_agent_task)) + p = multiprocessing.Process(target=do_tracer_flare, args=("DEBUG", self.mock_flare_send_request)) processes.append(p) p.start() # Create failing process - p = multiprocessing.Process(target=do_tracer_flare, args=(None, self.mock_agent_task)) + p = multiprocessing.Process(target=do_tracer_flare, args=(None, self.mock_flare_send_request)) processes.append(p) p.start() for p in processes: @@ -154,7 +149,7 @@ def test_no_app_logs(self): file, just the tracer logs """ app_logger = Logger(name="my-app", level=DEBUG_LEVEL_INT) - self.flare.prepare(self.mock_agent_config) + self.flare.prepare("DEBUG") app_log_line = "this is an app log" app_logger.debug(app_log_line) @@ -169,5 +164,5 @@ def test_no_app_logs(self): self.flare.revert_configs() def confirm_cleanup(self): - assert not os.path.exists(TRACER_FLARE_DIRECTORY), f"The directory {TRACER_FLARE_DIRECTORY} still exists" + assert not self.flare.flare_dir.exists(), f"The directory {self.flare.flare_dir} still exists" assert self._get_handler() is None, "File handler was not removed" diff --git a/tests/llmobs/_utils.py b/tests/llmobs/_utils.py index 3678a1392fb..2cb1456ccc5 100644 --- a/tests/llmobs/_utils.py +++ b/tests/llmobs/_utils.py @@ -182,3 +182,16 @@ def _get_llmobs_parent_id(span: Span): if parent.span_type == SpanTypes.LLM: return str(parent.span_id) parent = parent._parent + + +def _expected_llmobs_eval_metric_event( + span_id, trace_id, metric_type, label, categorical_value=None, score_value=None, numerical_value=None +): + eval_metric_event = {"span_id": span_id, "trace_id": trace_id, "metric_type": metric_type, "label": label} + if categorical_value is not None: + eval_metric_event["categorical_value"] = categorical_value + if score_value is not None: + eval_metric_event["score_value"] = score_value + if numerical_value is not None: + eval_metric_event["numerical_value"] = numerical_value + return eval_metric_event diff --git a/tests/llmobs/conftest.py b/tests/llmobs/conftest.py index a722c863c38..a0bc2daaec2 100644 --- a/tests/llmobs/conftest.py +++ b/tests/llmobs/conftest.py @@ -37,6 +37,16 @@ def mock_llmobs_span_writer(): patcher.stop() +@pytest.fixture +def mock_llmobs_eval_metric_writer(): + patcher = mock.patch("ddtrace.llmobs._llmobs.LLMObsEvalMetricWriter") + LLMObsEvalMetricWriterMock = patcher.start() + m = mock.MagicMock() + LLMObsEvalMetricWriterMock.return_value = m + yield m + patcher.stop() + + @pytest.fixture def mock_writer_logs(): with mock.patch("ddtrace.llmobs._writer.logger") as m: @@ -54,7 +64,7 @@ def default_global_config(): @pytest.fixture -def LLMObs(mock_llmobs_span_writer, ddtrace_global_config): +def LLMObs(mock_llmobs_span_writer, mock_llmobs_eval_metric_writer, ddtrace_global_config): global_config = default_global_config() global_config.update(ddtrace_global_config) with override_global_config(global_config): diff --git a/tests/llmobs/test_llmobs_eval_metric_writer.py b/tests/llmobs/test_llmobs_eval_metric_writer.py index 2f9368c8bd8..984f8645feb 100644 --- a/tests/llmobs/test_llmobs_eval_metric_writer.py +++ b/tests/llmobs/test_llmobs_eval_metric_writer.py @@ -45,9 +45,7 @@ def _numerical_metric_event(): def 
test_writer_start(mock_writer_logs): llmobs_eval_metric_writer = LLMObsEvalMetricWriter(site="datad0g.com", api_key=dd_api_key, interval=1000, timeout=1) llmobs_eval_metric_writer.start() - mock_writer_logs.debug.assert_has_calls( - [mock.call("started %r to %r", ("LLMObsEvalMetricWriter", INTAKE_ENDPOINT))] - ) + mock_writer_logs.debug.assert_has_calls([mock.call("started %r to %r", "LLMObsEvalMetricWriter", INTAKE_ENDPOINT)]) def test_buffer_limit(mock_writer_logs): @@ -55,7 +53,7 @@ def test_buffer_limit(mock_writer_logs): for _ in range(1001): llmobs_eval_metric_writer.enqueue({}) mock_writer_logs.warning.assert_called_with( - "%r event buffer full (limit is %d), dropping event", ("LLMObsEvalMetricWriter", 1000) + "%r event buffer full (limit is %d), dropping event", "LLMObsEvalMetricWriter", 1000 ) @@ -69,13 +67,11 @@ def test_send_metric_bad_api_key(mock_writer_logs): llmobs_eval_metric_writer.periodic() mock_writer_logs.error.assert_called_with( "failed to send %d LLMObs %s events to %s, got response code %d, status: %s", - ( - 1, - "evaluation_metric", - INTAKE_ENDPOINT, - 403, - b'{"status":"error","code":403,"errors":["Forbidden"],"statuspage":"http://status.datadoghq.com","twitter":"http://twitter.com/datadogops","email":"support@datadoghq.com"}', # noqa - ), + 1, + "evaluation_metric", + INTAKE_ENDPOINT, + 403, + b'{"status":"error","code":403,"errors":["Forbidden"],"statuspage":"http://status.datadoghq.com","twitter":"http://twitter.com/datadogops","email":"support@datadoghq.com"}', # noqa ) @@ -86,7 +82,7 @@ def test_send_categorical_metric(mock_writer_logs): llmobs_eval_metric_writer.enqueue(_categorical_metric_event()) llmobs_eval_metric_writer.periodic() mock_writer_logs.debug.assert_has_calls( - [mock.call("sent %d LLMObs %s events to %s", (1, "evaluation_metric", INTAKE_ENDPOINT))] + [mock.call("sent %d LLMObs %s events to %s", 1, "evaluation_metric", INTAKE_ENDPOINT)] ) @@ -97,7 +93,7 @@ def test_send_numerical_metric(mock_writer_logs): llmobs_eval_metric_writer.enqueue(_numerical_metric_event()) llmobs_eval_metric_writer.periodic() mock_writer_logs.debug.assert_has_calls( - [mock.call("sent %d LLMObs %s events to %s", (1, "evaluation_metric", INTAKE_ENDPOINT))] + [mock.call("sent %d LLMObs %s events to %s", 1, "evaluation_metric", INTAKE_ENDPOINT)] ) @@ -108,7 +104,7 @@ def test_send_score_metric(mock_writer_logs): llmobs_eval_metric_writer.enqueue(_score_metric_event()) llmobs_eval_metric_writer.periodic() mock_writer_logs.debug.assert_has_calls( - [mock.call("sent %d LLMObs %s events to %s", (1, "evaluation_metric", INTAKE_ENDPOINT))] + [mock.call("sent %d LLMObs %s events to %s", 1, "evaluation_metric", INTAKE_ENDPOINT)] ) @@ -121,13 +117,13 @@ def test_send_timed_events(mock_writer_logs): llmobs_eval_metric_writer.enqueue(_score_metric_event()) time.sleep(0.1) mock_writer_logs.debug.assert_has_calls( - [mock.call("sent %d LLMObs %s events to %s", (1, "evaluation_metric", INTAKE_ENDPOINT))] + [mock.call("sent %d LLMObs %s events to %s", 1, "evaluation_metric", INTAKE_ENDPOINT)] ) mock_writer_logs.reset_mock() llmobs_eval_metric_writer.enqueue(_numerical_metric_event()) time.sleep(0.1) mock_writer_logs.debug.assert_has_calls( - [mock.call("sent %d LLMObs %s events to %s", (1, "evaluation_metric", INTAKE_ENDPOINT))] + [mock.call("sent %d LLMObs %s events to %s", 1, "evaluation_metric", INTAKE_ENDPOINT)] ) @@ -141,7 +137,7 @@ def test_send_multiple_events(mock_writer_logs): llmobs_eval_metric_writer.enqueue(_numerical_metric_event()) time.sleep(0.1) 
mock_writer_logs.debug.assert_has_calls( - [mock.call("sent %d LLMObs %s events to %s", (2, "evaluation_metric", INTAKE_ENDPOINT))] + [mock.call("sent %d LLMObs %s events to %s", 2, "evaluation_metric", INTAKE_ENDPOINT)] ) diff --git a/tests/llmobs/test_llmobs_service.py b/tests/llmobs/test_llmobs_service.py index 88941a275e2..dfaef69c146 100644 --- a/tests/llmobs/test_llmobs_service.py +++ b/tests/llmobs/test_llmobs_service.py @@ -17,6 +17,7 @@ from ddtrace.llmobs._constants import SPAN_KIND from ddtrace.llmobs._constants import TAGS from ddtrace.llmobs._llmobs import LLMObsTraceProcessor +from tests.llmobs._utils import _expected_llmobs_eval_metric_event from tests.llmobs._utils import _expected_llmobs_llm_span_event from tests.llmobs._utils import _expected_llmobs_non_llm_span_event from tests.utils import DummyTracer @@ -33,7 +34,7 @@ def mock_logs(): yield mock_logs -def test_llmobs_service_enable(): +def test_service_enable(): with override_global_config(dict(_dd_api_key="", _llmobs_ml_app="")): dummy_tracer = DummyTracer() llmobs_service.enable(tracer=dummy_tracer) @@ -45,7 +46,7 @@ def test_llmobs_service_enable(): llmobs_service.disable() -def test_llmobs_service_disable(): +def test_service_disable(): with override_global_config(dict(_dd_api_key="", _llmobs_ml_app="")): dummy_tracer = DummyTracer() llmobs_service.enable(tracer=dummy_tracer) @@ -54,7 +55,7 @@ def test_llmobs_service_disable(): assert llmobs_service.enabled is False -def test_llmobs_service_enable_no_api_key(): +def test_service_enable_no_api_key(): with override_global_config(dict(_dd_api_key="", _llmobs_ml_app="")): dummy_tracer = DummyTracer() with pytest.raises(ValueError): @@ -64,7 +65,7 @@ def test_llmobs_service_enable_no_api_key(): assert llmobs_service.enabled is False -def test_llmobs_service_enable_no_ml_app_specified(): +def test_service_enable_no_ml_app_specified(): with override_global_config(dict(_dd_api_key="", _llmobs_ml_app="")): dummy_tracer = DummyTracer() with pytest.raises(ValueError): @@ -74,7 +75,7 @@ def test_llmobs_service_enable_no_ml_app_specified(): assert llmobs_service.enabled is False -def test_llmobs_service_enable_already_enabled(mock_logs): +def test_service_enable_already_enabled(mock_logs): with override_global_config(dict(_dd_api_key="", _llmobs_ml_app="")): dummy_tracer = DummyTracer() llmobs_service.enable(tracer=dummy_tracer) @@ -88,7 +89,7 @@ def test_llmobs_service_enable_already_enabled(mock_logs): mock_logs.debug.assert_has_calls([mock.call("%s already enabled", "LLMObs")]) -def test_llmobs_start_span_while_disabled_logs_warning(LLMObs, mock_logs): +def test_start_span_while_disabled_logs_warning(LLMObs, mock_logs): LLMObs.disable() _ = LLMObs.llm(model_name="test_model", name="test_llm_call", model_provider="test_provider") mock_logs.warning.assert_called_once_with("LLMObs.llm() cannot be used while LLMObs is disabled.") @@ -106,7 +107,7 @@ def test_llmobs_start_span_while_disabled_logs_warning(LLMObs, mock_logs): mock_logs.warning.assert_called_once_with("LLMObs.agent() cannot be used while LLMObs is disabled.") -def test_llmobs_start_span_uses_kind_as_default_name(LLMObs): +def test_start_span_uses_kind_as_default_name(LLMObs): with LLMObs.llm(model_name="test_model", model_provider="test_provider") as span: assert span.name == "llm" with LLMObs.tool() as span: @@ -119,7 +120,7 @@ def test_llmobs_start_span_uses_kind_as_default_name(LLMObs): assert span.name == "agent" -def test_llmobs_start_span_with_session_id(LLMObs): +def test_start_span_with_session_id(LLMObs): 
with LLMObs.llm(model_name="test_model", session_id="test_session_id") as span: assert span.get_tag(SESSION_ID) == "test_session_id" with LLMObs.tool(session_id="test_session_id") as span: @@ -132,7 +133,7 @@ def test_llmobs_start_span_with_session_id(LLMObs): assert span.get_tag(SESSION_ID) == "test_session_id" -def test_llmobs_session_id_becomes_top_level_field(LLMObs, mock_llmobs_span_writer): +def test_session_id_becomes_top_level_field(LLMObs, mock_llmobs_span_writer): session_id = "test_session_id" with LLMObs.task(session_id=session_id) as span: pass @@ -141,7 +142,7 @@ def test_llmobs_session_id_becomes_top_level_field(LLMObs, mock_llmobs_span_writ ) -def test_llmobs_llm_span(LLMObs, mock_llmobs_span_writer): +def test_llm_span(LLMObs, mock_llmobs_span_writer): with LLMObs.llm(model_name="test_model", name="test_llm_call", model_provider="test_provider") as span: assert span.name == "test_llm_call" assert span.resource == "llm" @@ -156,18 +157,18 @@ def test_llmobs_llm_span(LLMObs, mock_llmobs_span_writer): ) -def test_llmobs_llm_span_no_model_raises_error(LLMObs, mock_logs): +def test_llm_span_no_model_raises_error(LLMObs, mock_logs): with pytest.raises(TypeError): with LLMObs.llm(name="test_llm_call", model_provider="test_provider"): pass -def test_llmobs_llm_span_empty_model_name_logs_warning(LLMObs, mock_logs): +def test_llm_span_empty_model_name_logs_warning(LLMObs, mock_logs): _ = LLMObs.llm(model_name="", name="test_llm_call", model_provider="test_provider") mock_logs.warning.assert_called_once_with("model_name must be the specified name of the invoked model.") -def test_llmobs_default_model_provider_set_to_custom(LLMObs): +def test_default_model_provider_set_to_custom(LLMObs): with LLMObs.llm(model_name="test_model", name="test_llm_call") as span: assert span.name == "test_llm_call" assert span.resource == "llm" @@ -177,7 +178,7 @@ def test_llmobs_default_model_provider_set_to_custom(LLMObs): assert span.get_tag(MODEL_PROVIDER) == "custom" -def test_llmobs_tool_span(LLMObs, mock_llmobs_span_writer): +def test_tool_span(LLMObs, mock_llmobs_span_writer): with LLMObs.tool(name="test_tool") as span: assert span.name == "test_tool" assert span.resource == "tool" @@ -186,7 +187,7 @@ def test_llmobs_tool_span(LLMObs, mock_llmobs_span_writer): mock_llmobs_span_writer.enqueue.assert_called_with(_expected_llmobs_non_llm_span_event(span, "tool")) -def test_llmobs_task_span(LLMObs, mock_llmobs_span_writer): +def test_task_span(LLMObs, mock_llmobs_span_writer): with LLMObs.task(name="test_task") as span: assert span.name == "test_task" assert span.resource == "task" @@ -195,7 +196,7 @@ def test_llmobs_task_span(LLMObs, mock_llmobs_span_writer): mock_llmobs_span_writer.enqueue.assert_called_with(_expected_llmobs_non_llm_span_event(span, "task")) -def test_llmobs_workflow_span(LLMObs, mock_llmobs_span_writer): +def test_workflow_span(LLMObs, mock_llmobs_span_writer): with LLMObs.workflow(name="test_workflow") as span: assert span.name == "test_workflow" assert span.resource == "workflow" @@ -204,7 +205,7 @@ def test_llmobs_workflow_span(LLMObs, mock_llmobs_span_writer): mock_llmobs_span_writer.enqueue.assert_called_with(_expected_llmobs_non_llm_span_event(span, "workflow")) -def test_llmobs_agent_span(LLMObs, mock_llmobs_span_writer): +def test_agent_span(LLMObs, mock_llmobs_span_writer): with LLMObs.agent(name="test_agent") as span: assert span.name == "test_agent" assert span.resource == "agent" @@ -213,32 +214,32 @@ def test_llmobs_agent_span(LLMObs, mock_llmobs_span_writer): 
mock_llmobs_span_writer.enqueue.assert_called_with(_expected_llmobs_llm_span_event(span, "agent")) -def test_llmobs_annotate_while_disabled_logs_warning(LLMObs, mock_logs): +def test_annotate_while_disabled_logs_warning(LLMObs, mock_logs): LLMObs.disable() LLMObs.annotate(parameters={"test": "test"}) mock_logs.warning.assert_called_once_with("LLMObs.annotate() cannot be used while LLMObs is disabled.") -def test_llmobs_annotate_no_active_span_logs_warning(LLMObs, mock_logs): +def test_annotate_no_active_span_logs_warning(LLMObs, mock_logs): LLMObs.annotate(parameters={"test": "test"}) - mock_logs.warning.assert_called_once_with("No span provided and no active span found.") + mock_logs.warning.assert_called_once_with("No span provided and no active LLMObs-generated span found.") -def test_llmobs_annotate_non_llm_span_logs_warning(LLMObs, mock_logs): +def test_annotate_non_llm_span_logs_warning(LLMObs, mock_logs): dummy_tracer = DummyTracer() with dummy_tracer.trace("root") as non_llmobs_span: LLMObs.annotate(span=non_llmobs_span, parameters={"test": "test"}) - mock_logs.warning.assert_called_once_with("Span must be an LLM-type span.") + mock_logs.warning.assert_called_once_with("Span must be an LLMObs-generated span.") -def test_llmobs_annotate_finished_span_does_nothing(LLMObs, mock_logs): +def test_annotate_finished_span_does_nothing(LLMObs, mock_logs): with LLMObs.llm(model_name="test_model", name="test_llm_call", model_provider="test_provider") as span: pass LLMObs.annotate(span=span, parameters={"test": "test"}) mock_logs.warning.assert_called_once_with("Cannot annotate a finished span.") -def test_llmobs_annotate_parameters(LLMObs, mock_logs): +def test_annotate_parameters(LLMObs, mock_logs): with LLMObs.llm(model_name="test_model", name="test_llm_call", model_provider="test_provider") as span: LLMObs.annotate(span=span, parameters={"temperature": 0.9, "max_tokens": 50}) assert json.loads(span.get_tag(INPUT_PARAMETERS)) == {"temperature": 0.9, "max_tokens": 50} @@ -247,13 +248,13 @@ def test_llmobs_annotate_parameters(LLMObs, mock_logs): ) -def test_llmobs_annotate_metadata(LLMObs): +def test_annotate_metadata(LLMObs): with LLMObs.llm(model_name="test_model", name="test_llm_call", model_provider="test_provider") as span: LLMObs.annotate(span=span, metadata={"temperature": 0.5, "max_tokens": 20, "top_k": 10, "n": 3}) assert json.loads(span.get_tag(METADATA)) == {"temperature": 0.5, "max_tokens": 20, "top_k": 10, "n": 3} -def test_llmobs_annotate_metadata_wrong_type(LLMObs, mock_logs): +def test_annotate_metadata_wrong_type(LLMObs, mock_logs): with LLMObs.llm(model_name="test_model", name="test_llm_call", model_provider="test_provider") as span: LLMObs.annotate(span=span, metadata="wrong_metadata") assert span.get_tag(METADATA) is None @@ -267,13 +268,13 @@ def test_llmobs_annotate_metadata_wrong_type(LLMObs, mock_logs): ) -def test_llmobs_annotate_tag(LLMObs): +def test_annotate_tag(LLMObs): with LLMObs.llm(model_name="test_model", name="test_llm_call", model_provider="test_provider") as span: LLMObs.annotate(span=span, tags={"test_tag_name": "test_tag_value", "test_numeric_tag": 10}) assert json.loads(span.get_tag(TAGS)) == {"test_tag_name": "test_tag_value", "test_numeric_tag": 10} -def test_llmobs_annotate_tag_wrong_type(LLMObs, mock_logs): +def test_annotate_tag_wrong_type(LLMObs, mock_logs): with LLMObs.llm(model_name="test_model", name="test_llm_call", model_provider="test_provider") as span: LLMObs.annotate(span=span, tags=12345) assert span.get_tag(TAGS) is None @@ -289,7 
+290,7 @@ def test_llmobs_annotate_tag_wrong_type(LLMObs, mock_logs): ) -def test_llmobs_annotate_input_string(LLMObs): +def test_annotate_input_string(LLMObs): with LLMObs.llm(model_name="test_model") as llm_span: LLMObs.annotate(span=llm_span, input_data="test_input") assert json.loads(llm_span.get_tag(INPUT_MESSAGES)) == [{"content": "test_input"}] @@ -307,7 +308,7 @@ def test_llmobs_annotate_input_string(LLMObs): assert agent_span.get_tag(INPUT_VALUE) == "test_input" -def test_llmobs_annotate_input_serializable_value(LLMObs): +def test_annotate_input_serializable_value(LLMObs): with LLMObs.task() as task_span: LLMObs.annotate(span=task_span, input_data=["test_input"]) assert task_span.get_tag(INPUT_VALUE) == '["test_input"]' @@ -322,20 +323,20 @@ def test_llmobs_annotate_input_serializable_value(LLMObs): assert agent_span.get_tag(INPUT_VALUE) == "test_input" -def test_llmobs_annotate_input_value_wrong_type(LLMObs, mock_logs): +def test_annotate_input_value_wrong_type(LLMObs, mock_logs): with LLMObs.workflow() as llm_span: LLMObs.annotate(span=llm_span, input_data=Unserializable()) assert llm_span.get_tag(INPUT_VALUE) is None mock_logs.warning.assert_called_once_with("Failed to parse input value. Input value must be JSON serializable.") -def test_llmobs_annotate_input_llm_message(LLMObs): +def test_annotate_input_llm_message(LLMObs): with LLMObs.llm(model_name="test_model") as llm_span: LLMObs.annotate(span=llm_span, input_data=[{"content": "test_input", "role": "human"}]) assert json.loads(llm_span.get_tag(INPUT_MESSAGES)) == [{"content": "test_input", "role": "human"}] -def test_llmobs_annotate_input_llm_message_wrong_type(LLMObs, mock_logs): +def test_annotate_input_llm_message_wrong_type(LLMObs, mock_logs): with LLMObs.llm(model_name="test_model") as llm_span: LLMObs.annotate(span=llm_span, input_data=[{"content": Unserializable()}]) assert llm_span.get_tag(INPUT_MESSAGES) is None @@ -351,7 +352,7 @@ def test_llmobs_annotate_incorrect_message_content_type_raises_warning(LLMObs, m mock_logs.warning.assert_called_once_with("Failed to parse output messages.", exc_info=True) -def test_llmobs_annotate_output_string(LLMObs): +def test_annotate_output_string(LLMObs): with LLMObs.llm(model_name="test_model") as llm_span: LLMObs.annotate(span=llm_span, output_data="test_output") assert json.loads(llm_span.get_tag(OUTPUT_MESSAGES)) == [{"content": "test_output"}] @@ -369,7 +370,7 @@ def test_llmobs_annotate_output_string(LLMObs): assert agent_span.get_tag(OUTPUT_VALUE) == "test_output" -def test_llmobs_annotate_output_serializable_value(LLMObs): +def test_annotate_output_serializable_value(LLMObs): with LLMObs.task() as task_span: LLMObs.annotate(span=task_span, output_data=["test_output"]) assert task_span.get_tag(OUTPUT_VALUE) == '["test_output"]' @@ -384,7 +385,7 @@ def test_llmobs_annotate_output_serializable_value(LLMObs): assert agent_span.get_tag(OUTPUT_VALUE) == "test_output" -def test_llmobs_annotate_output_value_wrong_type(LLMObs, mock_logs): +def test_annotate_output_value_wrong_type(LLMObs, mock_logs): with LLMObs.workflow() as llm_span: LLMObs.annotate(span=llm_span, output_data=Unserializable()) assert llm_span.get_tag(OUTPUT_VALUE) is None @@ -393,26 +394,26 @@ def test_llmobs_annotate_output_value_wrong_type(LLMObs, mock_logs): ) -def test_llmobs_annotate_output_llm_message(LLMObs): +def test_annotate_output_llm_message(LLMObs): with LLMObs.llm(model_name="test_model") as llm_span: LLMObs.annotate(span=llm_span, output_data=[{"content": "test_output", "role": "human"}]) assert 
json.loads(llm_span.get_tag(OUTPUT_MESSAGES)) == [{"content": "test_output", "role": "human"}] -def test_llmobs_annotate_output_llm_message_wrong_type(LLMObs, mock_logs): +def test_annotate_output_llm_message_wrong_type(LLMObs, mock_logs): with LLMObs.llm(model_name="test_model") as llm_span: LLMObs.annotate(span=llm_span, output_data=[{"content": Unserializable()}]) assert llm_span.get_tag(OUTPUT_MESSAGES) is None mock_logs.warning.assert_called_once_with("Failed to parse output messages.", exc_info=True) -def test_llmobs_annotate_metrics(LLMObs): +def test_annotate_metrics(LLMObs): with LLMObs.llm(model_name="test_model") as span: LLMObs.annotate(span=span, metrics={"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}) assert json.loads(span.get_tag(METRICS)) == {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30} -def test_llmobs_annotate_metrics_wrong_type(LLMObs, mock_logs): +def test_annotate_metrics_wrong_type(LLMObs, mock_logs): with LLMObs.llm(model_name="test_model") as llm_span: LLMObs.annotate(span=llm_span, metrics=12345) assert llm_span.get_tag(METRICS) is None @@ -426,7 +427,7 @@ def test_llmobs_annotate_metrics_wrong_type(LLMObs, mock_logs): ) -def test_llmobs_span_error_sets_error(LLMObs, mock_llmobs_span_writer): +def test_span_error_sets_error(LLMObs, mock_llmobs_span_writer): with pytest.raises(ValueError): with LLMObs.llm(model_name="test_model", model_provider="test_model_provider") as span: raise ValueError("test error message") @@ -446,7 +447,7 @@ def test_llmobs_span_error_sets_error(LLMObs, mock_llmobs_span_writer): "ddtrace_global_config", [dict(version="1.2.3", env="test_env", service="test_service", _llmobs_ml_app="test_app_name")], ) -def test_llmobs_tags(ddtrace_global_config, LLMObs, mock_llmobs_span_writer, monkeypatch): +def test_tags(ddtrace_global_config, LLMObs, mock_llmobs_span_writer, monkeypatch): with LLMObs.task(name="test_task") as span: pass mock_llmobs_span_writer.enqueue.assert_called_with( @@ -458,7 +459,7 @@ def test_llmobs_tags(ddtrace_global_config, LLMObs, mock_llmobs_span_writer, mon ) -def test_llmobs_ml_app_override(LLMObs, mock_llmobs_span_writer): +def test_ml_app_override(LLMObs, mock_llmobs_span_writer): with LLMObs.task(name="test_task", ml_app="test_app") as span: pass mock_llmobs_span_writer.enqueue.assert_called_with( @@ -490,3 +491,185 @@ def test_llmobs_ml_app_override(LLMObs, mock_llmobs_span_writer): mock_llmobs_span_writer.enqueue.assert_called_with( _expected_llmobs_llm_span_event(span, "agent", tags={"ml_app": "test_app"}) ) + + +def test_export_span_llmobs_not_enabled_raises_warning(LLMObs, mock_logs): + LLMObs.disable() + LLMObs.export_span() + mock_logs.warning.assert_called_once_with("LLMObs.export_span() requires LLMObs to be enabled.") + + +def test_export_span_specified_span_is_incorrect_type_raises_warning(LLMObs, mock_logs): + LLMObs.export_span(span="asd") + mock_logs.warning.assert_called_once_with("Failed to export span. 
Span must be a valid Span object.") + + +def test_export_span_specified_span_is_not_llmobs_span_raises_warning(LLMObs, mock_logs): + with DummyTracer().trace("non_llmobs_span") as span: + LLMObs.export_span(span=span) + mock_logs.warning.assert_called_once_with("Span must be an LLMObs-generated span.") + + +def test_export_span_specified_span_returns_span_context(LLMObs): + with LLMObs.llm(model_name="test_model", name="test_llm_call", model_provider="test_provider") as span: + span_context = LLMObs.export_span(span=span) + assert span_context is not None + assert span_context["span_id"] == str(span.span_id) + assert span_context["trace_id"] == "{:x}".format(span.trace_id) + + +def test_export_span_no_specified_span_no_active_span_raises_warning(LLMObs, mock_logs): + LLMObs.export_span() + mock_logs.warning.assert_called_once_with("No span provided and no active LLMObs-generated span found.") + + +def test_export_span_active_span_not_llmobs_span_raises_warning(LLMObs, mock_logs): + with LLMObs._instance.tracer.trace("non_llmobs_span"): + LLMObs.export_span() + mock_logs.warning.assert_called_once_with("Span must be an LLMObs-generated span.") + + +def test_export_span_no_specified_span_returns_exported_active_span(LLMObs): + with LLMObs.llm(model_name="test_model", name="test_llm_call", model_provider="test_provider") as span: + span_context = LLMObs.export_span() + assert span_context is not None + assert span_context["span_id"] == str(span.span_id) + assert span_context["trace_id"] == "{:x}".format(span.trace_id) + + +def test_submit_evaluation_llmobs_disabled_raises_warning(LLMObs, mock_logs): + LLMObs.disable() + LLMObs.submit_evaluation( + span_context={"span_id": "123", "trace_id": "456"}, label="toxicity", metric_type="categorical", value="high" + ) + mock_logs.warning.assert_called_once_with("LLMObs.submit_evaluation() requires LLMObs to be enabled.") + + +def test_submit_evaluation_span_context_incorrect_type_raises_warning(LLMObs, mock_logs): + LLMObs.submit_evaluation(span_context="asd", label="toxicity", metric_type="categorical", value="high") + mock_logs.warning.assert_called_once_with( + "span_context must be a dictionary containing both span_id and trace_id keys. " + "LLMObs.export_span() can be used to generate this dictionary from a given span." + ) + + +def test_submit_evaluation_empty_span_or_trace_id_raises_warning(LLMObs, mock_logs): + LLMObs.submit_evaluation( + span_context={"trace_id": "456"}, label="toxicity", metric_type="categorical", value="high" + ) + mock_logs.warning.assert_called_once_with( + "span_id and trace_id must both be specified for the given evaluation metric to be submitted." + ) + mock_logs.reset_mock() + LLMObs.submit_evaluation(span_context={"span_id": "456"}, label="toxicity", metric_type="categorical", value="high") + mock_logs.warning.assert_called_once_with( + "span_id and trace_id must both be specified for the given evaluation metric to be submitted." 
+    )
+
+
+def test_submit_evaluation_empty_label_raises_warning(LLMObs, mock_logs):
+    LLMObs.submit_evaluation(
+        span_context={"span_id": "123", "trace_id": "456"}, label="", metric_type="categorical", value="high"
+    )
+    mock_logs.warning.assert_called_once_with("label must be the specified name of the evaluation metric.")
+
+
+def test_submit_evaluation_incorrect_metric_type_raises_warning(LLMObs, mock_logs):
+    LLMObs.submit_evaluation(
+        span_context={"span_id": "123", "trace_id": "456"}, label="toxicity", metric_type="wrong", value="high"
+    )
+    mock_logs.warning.assert_called_once_with("metric_type must be one of 'categorical', 'numerical', or 'score'.")
+    mock_logs.reset_mock()
+    LLMObs.submit_evaluation(
+        span_context={"span_id": "123", "trace_id": "456"}, label="toxicity", metric_type="", value="high"
+    )
+    mock_logs.warning.assert_called_once_with("metric_type must be one of 'categorical', 'numerical', or 'score'.")
+
+
+def test_submit_evaluation_incorrect_numerical_value_type_raises_warning(LLMObs, mock_logs):
+    LLMObs.submit_evaluation(
+        span_context={"span_id": "123", "trace_id": "456"}, label="token_count", metric_type="numerical", value="high"
+    )
+    mock_logs.warning.assert_called_once_with("value must be an integer or float for a numerical/score metric.")
+
+
+def test_submit_evaluation_incorrect_score_value_type_raises_warning(LLMObs, mock_logs):
+    LLMObs.submit_evaluation(
+        span_context={"span_id": "123", "trace_id": "456"}, label="token_count", metric_type="score", value="high"
+    )
+    mock_logs.warning.assert_called_once_with("value must be an integer or float for a numerical/score metric.")
+
+
+def test_submit_evaluation_enqueues_writer_with_categorical_metric(LLMObs, mock_llmobs_eval_metric_writer):
+    LLMObs.submit_evaluation(
+        span_context={"span_id": "123", "trace_id": "456"}, label="toxicity", metric_type="categorical", value="high"
+    )
+    mock_llmobs_eval_metric_writer.enqueue.assert_called_with(
+        _expected_llmobs_eval_metric_event(
+            span_id="123", trace_id="456", label="toxicity", metric_type="categorical", categorical_value="high"
+        )
+    )
+    mock_llmobs_eval_metric_writer.reset_mock()
+    with LLMObs.llm(model_name="test_model", name="test_llm_call", model_provider="test_provider") as span:
+        LLMObs.submit_evaluation(
+            span_context=LLMObs.export_span(span), label="toxicity", metric_type="categorical", value="high"
+        )
+    mock_llmobs_eval_metric_writer.enqueue.assert_called_with(
+        _expected_llmobs_eval_metric_event(
+            span_id=str(span.span_id),
+            trace_id="{:x}".format(span.trace_id),
+            label="toxicity",
+            metric_type="categorical",
+            categorical_value="high",
+        )
+    )
+
+
+def test_submit_evaluation_enqueues_writer_with_score_metric(LLMObs, mock_llmobs_eval_metric_writer):
+    LLMObs.submit_evaluation(
+        span_context={"span_id": "123", "trace_id": "456"}, label="sentiment", metric_type="score", value=0.9
+    )
+    mock_llmobs_eval_metric_writer.enqueue.assert_called_with(
+        _expected_llmobs_eval_metric_event(
+            span_id="123", trace_id="456", label="sentiment", metric_type="score", score_value=0.9
+        )
+    )
+    mock_llmobs_eval_metric_writer.reset_mock()
+    with LLMObs.llm(model_name="test_model", name="test_llm_call", model_provider="test_provider") as span:
+        LLMObs.submit_evaluation(
+            span_context=LLMObs.export_span(span), label="sentiment", metric_type="score", value=0.9
+        )
+    mock_llmobs_eval_metric_writer.enqueue.assert_called_with(
+        _expected_llmobs_eval_metric_event(
+            span_id=str(span.span_id),
+            trace_id="{:x}".format(span.trace_id),
+            label="sentiment",
+            metric_type="score",
+            score_value=0.9,
+        )
+    )
+
+
+def test_submit_evaluation_enqueues_writer_with_numerical_metric(LLMObs, mock_llmobs_eval_metric_writer):
+    LLMObs.submit_evaluation(
+        span_context={"span_id": "123", "trace_id": "456"}, label="token_count", metric_type="numerical", value=35
+    )
+    mock_llmobs_eval_metric_writer.enqueue.assert_called_with(
+        _expected_llmobs_eval_metric_event(
+            span_id="123", trace_id="456", label="token_count", metric_type="numerical", numerical_value=35
+        )
+    )
+    mock_llmobs_eval_metric_writer.reset_mock()
+    with LLMObs.llm(model_name="test_model", name="test_llm_call", model_provider="test_provider") as span:
+        LLMObs.submit_evaluation(
+            span_context=LLMObs.export_span(span), label="token_count", metric_type="numerical", value=35
+        )
+    mock_llmobs_eval_metric_writer.enqueue.assert_called_with(
+        _expected_llmobs_eval_metric_event(
+            span_id=str(span.span_id),
+            trace_id="{:x}".format(span.trace_id),
+            label="token_count",
+            metric_type="numerical",
+            numerical_value=35,
+        )
+    )
diff --git a/tests/llmobs/test_llmobs_span_writer.py b/tests/llmobs/test_llmobs_span_writer.py
index 7032acad45f..4fc96ff5118 100644
--- a/tests/llmobs/test_llmobs_span_writer.py
+++ b/tests/llmobs/test_llmobs_span_writer.py
@@ -85,7 +85,7 @@ def _chat_completion_event():
 def test_writer_start(mock_writer_logs):
     llmobs_span_writer = LLMObsSpanWriter(site="datad0g.com", api_key="asdf", interval=1000, timeout=1)
     llmobs_span_writer.start()
-    mock_writer_logs.debug.assert_has_calls([mock.call("started %r to %r", ("LLMObsSpanWriter", INTAKE_ENDPOINT))])
+    mock_writer_logs.debug.assert_has_calls([mock.call("started %r to %r", "LLMObsSpanWriter", INTAKE_ENDPOINT)])
 
 
 def test_buffer_limit(mock_writer_logs):
@@ -93,7 +93,7 @@ def test_buffer_limit(mock_writer_logs):
     for _ in range(1001):
         llmobs_span_writer.enqueue({})
     mock_writer_logs.warning.assert_called_with(
-        "%r event buffer full (limit is %d), dropping event", ("LLMObsSpanWriter", 1000)
+        "%r event buffer full (limit is %d), dropping event", "LLMObsSpanWriter", 1000
     )
 
 
@@ -103,7 +103,7 @@ def test_send_completion_event(mock_writer_logs):
     llmobs_span_writer.start()
     llmobs_span_writer.enqueue(_completion_event())
     llmobs_span_writer.periodic()
-    mock_writer_logs.debug.assert_has_calls([mock.call("sent %d LLMObs %s events to %s", (1, "span", INTAKE_ENDPOINT))])
+    mock_writer_logs.debug.assert_has_calls([mock.call("sent %d LLMObs %s events to %s", 1, "span", INTAKE_ENDPOINT)])
 
 
 @pytest.mark.vcr_logs
@@ -112,7 +112,7 @@ def test_send_chat_completion_event(mock_writer_logs):
     llmobs_span_writer.start()
     llmobs_span_writer.enqueue(_chat_completion_event())
     llmobs_span_writer.periodic()
-    mock_writer_logs.debug.assert_has_calls([mock.call("sent %d LLMObs %s events to %s", (1, "span", INTAKE_ENDPOINT))])
+    mock_writer_logs.debug.assert_has_calls([mock.call("sent %d LLMObs %s events to %s", 1, "span", INTAKE_ENDPOINT)])
 
 
 @pytest.mark.vcr_logs
@@ -123,13 +123,11 @@ def test_send_completion_bad_api_key(mock_writer_logs):
     llmobs_span_writer.periodic()
     mock_writer_logs.error.assert_called_with(
         "failed to send %d LLMObs %s events to %s, got response code %d, status: %s",
-        (
-            1,
-            "span",
-            INTAKE_ENDPOINT,
-            403,
-            b'{"errors":[{"status":"403","title":"Forbidden","detail":"API key is invalid"}]}',
-        ),
+        1,
+        "span",
+        INTAKE_ENDPOINT,
+        403,
+        b'{"errors":[{"status":"403","title":"Forbidden","detail":"API key is invalid"}]}',
     )
 
 
@@ -141,11 +139,11 @@ def test_send_timed_events(mock_writer_logs):
     llmobs_span_writer.enqueue(_completion_event())
     time.sleep(0.1)
-    mock_writer_logs.debug.assert_has_calls([mock.call("sent %d LLMObs %s events to %s", (1, "span", INTAKE_ENDPOINT))])
+    mock_writer_logs.debug.assert_has_calls([mock.call("sent %d LLMObs %s events to %s", 1, "span", INTAKE_ENDPOINT)])
     mock_writer_logs.reset_mock()
     llmobs_span_writer.enqueue(_chat_completion_event())
     time.sleep(0.1)
-    mock_writer_logs.debug.assert_has_calls([mock.call("sent %d LLMObs %s events to %s", (1, "span", INTAKE_ENDPOINT))])
+    mock_writer_logs.debug.assert_has_calls([mock.call("sent %d LLMObs %s events to %s", 1, "span", INTAKE_ENDPOINT)])
 
 
 @pytest.mark.vcr_logs
@@ -157,7 +155,7 @@ def test_send_multiple_events(mock_writer_logs):
     llmobs_span_writer.enqueue(_completion_event())
     llmobs_span_writer.enqueue(_chat_completion_event())
     time.sleep(0.1)
-    mock_writer_logs.debug.assert_has_calls([mock.call("sent %d LLMObs %s events to %s", (2, "span", INTAKE_ENDPOINT))])
+    mock_writer_logs.debug.assert_has_calls([mock.call("sent %d LLMObs %s events to %s", 2, "span", INTAKE_ENDPOINT)])
 
 
 def test_send_on_exit(mock_writer_logs, run_python_code_in_subprocess):
diff --git a/tests/telemetry/test_writer.py b/tests/telemetry/test_writer.py
index fbc56869cb6..18699170152 100644
--- a/tests/telemetry/test_writer.py
+++ b/tests/telemetry/test_writer.py
@@ -128,7 +128,6 @@ def test_app_started_event(telemetry_writer, test_agent_session, mock_time):
         {"name": "DD_TRACE_PROPAGATION_STYLE_INJECT", "origin": "unknown", "value": "datadog,tracecontext"},
         {"name": "DD_TRACE_RATE_LIMIT", "origin": "unknown", "value": 100},
         {"name": "DD_TRACE_REMOVE_INTEGRATION_SERVICE_NAMES_ENABLED", "origin": "unknown", "value": False},
-        {"name": "DD_TRACE_SAMPLING_RULES", "origin": "unknown", "value": None},
         {"name": "DD_TRACE_SPAN_ATTRIBUTE_SCHEMA", "origin": "unknown", "value": "v0"},
         {"name": "DD_TRACE_STARTUP_LOGS", "origin": "unknown", "value": False},
         {"name": "DD_TRACE_WRITER_BUFFER_SIZE_BYTES", "origin": "unknown", "value": 20 << 20},
@@ -142,6 +141,7 @@ def test_app_started_event(telemetry_writer, test_agent_session, mock_time):
         {"name": "data_streams_enabled", "origin": "default", "value": "false"},
         {"name": "appsec_enabled", "origin": "default", "value": "false"},
         {"name": "trace_sample_rate", "origin": "default", "value": "1.0"},
+        {"name": "trace_sampling_rules", "origin": "default", "value": ""},
         {"name": "trace_header_tags", "origin": "default", "value": ""},
         {"name": "logs_injection_enabled", "origin": "default", "value": "false"},
         {"name": "trace_tags", "origin": "default", "value": ""},
@@ -292,11 +292,6 @@ def test_app_started_event_configuration_override(
         {"name": "DD_TRACE_PROPAGATION_STYLE_INJECT", "origin": "unknown", "value": "tracecontext"},
         {"name": "DD_TRACE_RATE_LIMIT", "origin": "unknown", "value": 50},
         {"name": "DD_TRACE_REMOVE_INTEGRATION_SERVICE_NAMES_ENABLED", "origin": "unknown", "value": True},
-        {
-            "name": "DD_TRACE_SAMPLING_RULES",
-            "origin": "unknown",
-            "value": '[{"sample_rate":1.0,"service":"xyz","name":"abc"}]',
-        },
         {"name": "DD_TRACE_SPAN_ATTRIBUTE_SCHEMA", "origin": "unknown", "value": "v1"},
         {"name": "DD_TRACE_STARTUP_LOGS", "origin": "unknown", "value": True},
         {"name": "DD_TRACE_WRITER_BUFFER_SIZE_BYTES", "origin": "unknown", "value": 1000},
@@ -310,6 +305,11 @@ def test_app_started_event_configuration_override(
         {"name": "data_streams_enabled", "origin": "env_var", "value": "true"},
         {"name": "appsec_enabled", "origin": "env_var", "value": "true"},
         {"name": "trace_sample_rate", "origin": "env_var", "value": "0.5"},
+        {
+            "name": "trace_sampling_rules",
+            "origin": "env_var",
+            "value": '[{"sample_rate":1.0,"service":"xyz","name":"abc"}]',
+        },
         {"name": "logs_injection_enabled", "origin": "env_var", "value": "true"},
         {"name": "trace_header_tags", "origin": "default", "value": ""},
         {"name": "trace_tags", "origin": "env_var", "value": "team:apm,component:web"},