diff --git a/.circleci/config.templ.yml b/.circleci/config.templ.yml index ce53bc0181c..1cf66d7a209 100644 --- a/.circleci/config.templ.yml +++ b/.circleci/config.templ.yml @@ -14,7 +14,7 @@ mysql_image: &mysql_image mysql:5.7@sha256:03b6dcedf5a2754da00e119e2cc6094ed3c88 postgres_image: &postgres_image postgres:12-alpine@sha256:c6704f41eb84be53d5977cb821bf0e5e876064b55eafef1e260c2574de40ad9a mongo_image: &mongo_image mongo:3.6@sha256:19c11a8f1064fd2bb713ef1270f79a742a184cd57d9bb922efdd2a8eca514af8 httpbin_image: &httpbin_image kennethreitz/httpbin@sha256:2c7abc4803080c22928265744410173b6fea3b898872c01c5fd0f0f9df4a59fb -vertica_image: &vertica_image sumitchawla/vertica:latest +vertica_image: &vertica_image vertica/vertica-ce:latest rabbitmq_image: &rabbitmq_image rabbitmq:3.7-alpine testagent_image: &testagent_image ghcr.io/datadog/dd-apm-test-agent/ddapm-test-agent:v1.16.0 @@ -1227,21 +1227,6 @@ jobs: snapshot: true docker_services: "httpbin_local" - vertica: - <<: *contrib_job - docker: - - image: *ddtrace_dev_image - - *testagent - - image: *vertica_image - environment: - - VP_TEST_USER=dbadmin - - VP_TEST_PASSWORD=abc123 - - VP_TEST_DATABASE=docker - steps: - - run_test: - wait: vertica - pattern: 'vertica' - wsgi: <<: *machine_executor steps: diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 1c9d99baa5c..23022df324d 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -4,13 +4,13 @@ * @DataDog/apm-core-python # Framework Integrations -ddtrace/ext/ @DataDog/apm-core-python @DataDog/apm-framework-integrations -ddtrace/contrib/ @DataDog/apm-core-python @DataDog/apm-framework-integrations -ddtrace/internal/schema/ @DataDog/apm-core-python @DataDog/apm-framework-integrations -tests/contrib/ @DataDog/apm-core-python @DataDog/apm-framework-integrations -tests/internal/peer_service @DataDog/apm-core-python @DataDog/apm-framework-integrations -tests/internal/service_name @DataDog/apm-core-python @DataDog/apm-framework-integrations -tests/contrib/grpc 
@DataDog/apm-framework-integrations @DataDog/asm-python +ddtrace/ext/ @DataDog/apm-core-python @DataDog/apm-idm-python +ddtrace/contrib/ @DataDog/apm-core-python @DataDog/apm-idm-python +ddtrace/internal/schema/ @DataDog/apm-core-python @DataDog/apm-idm-python +tests/contrib/ @DataDog/apm-core-python @DataDog/apm-idm-python +tests/internal/peer_service @DataDog/apm-core-python @DataDog/apm-idm-python +tests/internal/service_name @DataDog/apm-core-python @DataDog/apm-idm-python +tests/contrib/grpc @DataDog/apm-idm-python @DataDog/asm-python # Files which can be approved by anyone # DEV: This helps not requiring apm-core-python to review new files added diff --git a/.github/workflows/build_deploy.yml b/.github/workflows/build_deploy.yml index 018cc7b2ac4..4a6775f33bf 100644 --- a/.github/workflows/build_deploy.yml +++ b/.github/workflows/build_deploy.yml @@ -9,18 +9,6 @@ on: # before merging/releasing - build_deploy* pull_request: - paths: - - ".github/workflows/build_deploy.yml" - - ".github/workflows/build_python_3.yml" - - "setup.py" - - "setup.cfg" - - "pyproject.toml" - - "**.c" - - "**.h" - - "**.cpp" - - "**.hpp" - - "**.pyx" - - "ddtrace/vendor/**" release: types: - published diff --git a/.github/workflows/build_python_3.yml b/.github/workflows/build_python_3.yml index c46a704a25a..fe6e2279087 100644 --- a/.github/workflows/build_python_3.yml +++ b/.github/workflows/build_python_3.yml @@ -21,7 +21,7 @@ jobs: include: - os: ubuntu-latest archs: x86_64 i686 - - os: ubuntu-latest + - os: arm-4core-linux archs: aarch64 - os: windows-latest archs: AMD64 x86 @@ -34,17 +34,63 @@ jobs: fetch-depth: 0 - uses: actions/setup-python@v4 + if: matrix.os != 'arm-4core-linux' name: Install Python with: python-version: '3.8' + - name: Install docker and pipx + if: matrix.os == 'arm-4core-linux' + # The ARM64 Ubuntu has less things installed by default + # We need docker, pip and venv for cibuildwheel + # acl allows us to use docker in the same session + run: | + curl -fsSL 
https://get.docker.com -o get-docker.sh + sudo sh get-docker.sh + sudo usermod -a -G docker $USER + sudo apt install -y acl python3.10-venv python3-pip + sudo setfacl --modify user:runner:rw /var/run/docker.sock + python3 -m pip install pipx + - name: Set up QEMU - if: runner.os == 'Linux' + if: runner.os == 'Linux' && matrix.os != 'arm-4core-linux' uses: docker/setup-qemu-action@v2 with: platforms: all + - name: Build wheels arm64 + if: matrix.os == 'arm-4core-linux' + run: /home/runner/.local/bin/pipx run cibuildwheel==2.16.5 --platform linux + env: + # configure cibuildwheel to build native archs ('auto'), and some + # emulated ones + CIBW_ARCHS: ${{ matrix.archs }} + CIBW_BUILD: ${{ inputs.cibw_build }} + CIBW_SKIP: ${{ inputs.cibw_skip }} + CIBW_PRERELEASE_PYTHONS: ${{ inputs.cibw_prerelease_pythons }} + CMAKE_BUILD_PARALLEL_LEVEL: 12 + CIBW_REPAIR_WHEEL_COMMAND_LINUX: | + mkdir ./tempwheelhouse && + unzip -l {wheel} | grep '\.so' && + auditwheel repair -w ./tempwheelhouse {wheel} && + (yum install -y zip || apk add zip) && + for w in ./tempwheelhouse/*.whl; do + zip -d $w \*.c \*.cpp \*.cc \*.h \*.hpp \*.pyx + mv $w {dest_dir} + done && + rm -rf ./tempwheelhouse + CIBW_REPAIR_WHEEL_COMMAND_MACOS: | + zip -d {wheel} \*.c \*.cpp \*.cc \*.h \*.hpp \*.pyx && + delocate-wheel --require-archs {delocate_archs} -w {dest_dir} -v {wheel} + CIBW_REPAIR_WHEEL_COMMAND_WINDOWS: + choco install -y 7zip && + 7z d -r "{wheel}" *.c *.cpp *.cc *.h *.hpp *.pyx && + move "{wheel}" "{dest_dir}" + # DEV: Uncomment to debug MacOS + # CIBW_BUILD_VERBOSITY_MACOS: 3 + - name: Build wheels + if: matrix.os != 'arm-4core-linux' uses: pypa/cibuildwheel@v2.16.5 env: # configure cibuildwheel to build native archs ('auto'), and some @@ -73,6 +119,7 @@ jobs: move "{wheel}" "{dest_dir}" # DEV: Uncomment to debug MacOS # CIBW_BUILD_VERBOSITY_MACOS: 3 + - uses: actions/upload-artifact@v3 with: path: ./wheelhouse/*.whl diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index e1cf66ffb83..071dde14005 
100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -3,10 +3,12 @@ stages: - deploy - benchmarks - benchmarks-pr-comment + - macrobenchmarks include: - remote: https://gitlab-templates.ddbuild.io/apm/packaging.yml - local: ".gitlab/benchmarks.yml" + - local: ".gitlab/macrobenchmarks.yml" variables: DOWNSTREAM_BRANCH: @@ -30,6 +32,7 @@ package: when: on_success script: - ../.gitlab/build-deb-rpm.sh + - find . -iregex '.*\.\(deb\|rpm\)' -printf '%f\0' | xargs -0 dd-pkg lint package-arm: extends: .package-arm @@ -40,6 +43,7 @@ package-arm: when: on_success script: - ../.gitlab/build-deb-rpm.sh + - find . -iregex '.*\.\(deb\|rpm\)' -printf '%f\0' | xargs -0 dd-pkg lint .release-package: stage: deploy diff --git a/.gitlab/macrobenchmarks.yml b/.gitlab/macrobenchmarks.yml new file mode 100644 index 00000000000..16cf2b3b9be --- /dev/null +++ b/.gitlab/macrobenchmarks.yml @@ -0,0 +1,86 @@ +variables: + BENCHMARKS_CI_IMAGE: 486234852809.dkr.ecr.us-east-1.amazonaws.com/ci/benchmarking-platform:dd-trace-py-macrobenchmarks + +.macrobenchmarks: + stage: macrobenchmarks + needs: [] + tags: ["runner:apm-k8s-same-cpu"] + timeout: 1h + rules: + - if: $CI_PIPELINE_SOURCE == "schedule" + when: always + - when: manual + ## Next step, enable: + # - if: $CI_COMMIT_REF_NAME == "main" + # when: always + # If you have a problem with Gitlab cache, see Troubleshooting section in Benchmarking Platform docs + image: $BENCHMARKS_CI_IMAGE + script: | + git clone --branch python/macrobenchmarks https://gitlab-ci-token:${CI_JOB_TOKEN}@gitlab.ddbuild.io/DataDog/benchmarking-platform platform && cd platform + if [ "$BP_PYTHON_SCENARIO_DIR" == "flask-realworld" ]; then + bp-runner bp-runner.flask-realworld.yml --debug + else + bp-runner bp-runner.simple.yml --debug + fi + artifacts: + name: "artifacts" + when: always + paths: + - platform/artifacts/ + expire_in: 3 months + variables: + # Benchmark's env variables. Modify to tweak benchmark parameters. 
+ DD_TRACE_DEBUG: "false" + DD_RUNTIME_METRICS_ENABLED: "true" + DD_REMOTE_CONFIGURATION_ENABLED: "false" + DD_INSTRUMENTATION_TELEMETRY_ENABLED: "false" + + K6_OPTIONS_NORMAL_OPERATION_RATE: 40 + K6_OPTIONS_NORMAL_OPERATION_DURATION: 5m + K6_OPTIONS_NORMAL_OPERATION_GRACEFUL_STOP: 1m + K6_OPTIONS_NORMAL_OPERATION_PRE_ALLOCATED_VUS: 4 + K6_OPTIONS_NORMAL_OPERATION_MAX_VUS: 4 + + K6_OPTIONS_HIGH_LOAD_RATE: 500 + K6_OPTIONS_HIGH_LOAD_DURATION: 1m + K6_OPTIONS_HIGH_LOAD_GRACEFUL_STOP: 30s + K6_OPTIONS_HIGH_LOAD_PRE_ALLOCATED_VUS: 4 + K6_OPTIONS_HIGH_LOAD_MAX_VUS: 4 + + # Gitlab and BP specific env vars. Do not modify. + FF_USE_LEGACY_KUBERNETES_EXECUTION_STRATEGY: "true" + + # Workaround: Currently we're not running the benchmarks on every PR, but GitHub still shows them as pending. + # By marking the benchmarks as allow_failure, this should go away. (This workaround should be removed once the + # benchmarks get changed to run on every PR) + allow_failure: true + +macrobenchmarks: + extends: .macrobenchmarks + parallel: + matrix: + - DD_BENCHMARKS_CONFIGURATION: baseline + BP_PYTHON_SCENARIO_DIR: flask-realworld + DDTRACE_INSTALL_VERSION: "git+https://github.com/Datadog/dd-trace-py@${CI_COMMIT_SHA}" + + - DD_BENCHMARKS_CONFIGURATION: only-tracing + BP_PYTHON_SCENARIO_DIR: flask-realworld + DDTRACE_INSTALL_VERSION: "git+https://github.com/Datadog/dd-trace-py@${CI_COMMIT_SHA}" + + - DD_BENCHMARKS_CONFIGURATION: only-tracing + BP_PYTHON_SCENARIO_DIR: flask-realworld + DDTRACE_INSTALL_VERSION: "git+https://github.com/Datadog/dd-trace-py@${CI_COMMIT_SHA}" + DD_REMOTE_CONFIGURATION_ENABLED: "false" + DD_INSTRUMENTATION_TELEMETRY_ENABLED: "true" + + - DD_BENCHMARKS_CONFIGURATION: only-tracing + BP_PYTHON_SCENARIO_DIR: flask-realworld + DDTRACE_INSTALL_VERSION: "git+https://github.com/Datadog/dd-trace-py@${CI_COMMIT_SHA}" + DD_REMOTE_CONFIGURATION_ENABLED: "false" + DD_INSTRUMENTATION_TELEMETRY_ENABLED: "false" + + - DD_BENCHMARKS_CONFIGURATION: only-tracing + 
BP_PYTHON_SCENARIO_DIR: flask-realworld + DDTRACE_INSTALL_VERSION: "git+https://github.com/Datadog/dd-trace-py@${CI_COMMIT_SHA}" + DD_REMOTE_CONFIGURATION_ENABLED: "true" + DD_INSTRUMENTATION_TELEMETRY_ENABLED: "true" diff --git a/CHANGELOG.md b/CHANGELOG.md index 4c3cffbf56f..47047c4a1c9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,38 @@ Changelogs for versions not listed here can be found at https://github.com/DataDog/dd-trace-py/releases +--- + +## 2.7.10 + +### Bug Fixes + +- Code Security: This fix solves an issue with fstrings where formatting was not applied to int parameters +- logging: This fix resolves an issue where `tracer.get_log_correlation_context()` incorrectly returned a 128-bit trace_id even with `DD_TRACE_128_BIT_TRACEID_LOGGING_ENABLED` set to `False` (the default), breaking log correlation. It now returns a 64-bit trace_id. +- profiling: Fixes a defect where the deprecated path to the Datadog span type was used by the profiler. + +--- + +## 2.8.3 + + +### Bug Fixes + +- Code Security: This fix solves an issue with fstrings where formatting was not applied to int parameters +- logging: This fix resolves an issue where `tracer.get_log_correlation_context()` incorrectly returned a 128-bit trace_id even with `DD_TRACE_128_BIT_TRACEID_LOGGING_ENABLED` set to `False` (the default), breaking log correlation. It now returns a 64-bit trace_id. +- profiling: Fixes a defect where the deprecated path to the Datadog span type was used by the profiler. 
+ + +--- + +## 2.6.12 + + +### Bug Fixes + +- Code Security: This fix solves an issue with fstrings where formatting was not applied to int parameters + + --- ## 2.8.2 diff --git a/benchmarks/appsec_iast_propagation/scenario.py b/benchmarks/appsec_iast_propagation/scenario.py index 56ac67128b7..ec827b7bb21 100644 --- a/benchmarks/appsec_iast_propagation/scenario.py +++ b/benchmarks/appsec_iast_propagation/scenario.py @@ -1,8 +1,7 @@ from typing import Any # noqa:F401 import bm - -from tests.utils import override_env +from bm.utils import override_env with override_env({"DD_IAST_ENABLED": "True"}): @@ -42,7 +41,7 @@ def aspect_function(internal_loop, tainted): value = "" res = value for _ in range(internal_loop): - res = add_aspect(res, join_aspect(str.join, 1, "_", (tainted, "_", tainted))) + res = add_aspect(res, join_aspect("_".join, 1, "_", (tainted, "_", tainted))) value = res res = add_aspect(res, tainted) value = res diff --git a/ddtrace/_trace/trace_handlers.py b/ddtrace/_trace/trace_handlers.py index 9ef08f7c3b6..f439f87784a 100644 --- a/ddtrace/_trace/trace_handlers.py +++ b/ddtrace/_trace/trace_handlers.py @@ -8,6 +8,7 @@ from ddtrace import config from ddtrace._trace.span import Span +from ddtrace._trace.utils import set_botocore_patched_api_call_span_tags as set_patched_api_call_span_tags from ddtrace.constants import ANALYTICS_SAMPLE_RATE_KEY from ddtrace.constants import SPAN_KIND from ddtrace.constants import SPAN_MEASURED_KEY @@ -551,9 +552,8 @@ def _on_django_after_request_headers_post( def _on_botocore_patched_api_call_started(ctx): - callback = ctx.get_item("context_started_callback") span = ctx.get_item(ctx.get_item("call_key")) - callback( + set_patched_api_call_span_tags( span, ctx.get_item("instance"), ctx.get_item("args"), @@ -589,7 +589,6 @@ def _on_botocore_trace_context_injection_prepared( ctx, cloud_service, schematization_function, injection_function, trace_operation ): endpoint_name = ctx.get_item("endpoint_name") - params = 
ctx.get_item("params") if cloud_service is not None: span = ctx.get_item(ctx["call_key"]) inject_kwargs = dict(endpoint_service=endpoint_name) if cloud_service == "sns" else dict() @@ -597,7 +596,7 @@ def _on_botocore_trace_context_injection_prepared( if endpoint_name != "lambda": schematize_kwargs["direction"] = SpanDirection.OUTBOUND try: - injection_function(ctx, params, span, **inject_kwargs) + injection_function(ctx, **inject_kwargs) span.name = schematization_function(trace_operation, **schematize_kwargs) except Exception: log.warning("Unable to inject trace context", exc_info=True) @@ -633,8 +632,9 @@ def _on_botocore_patched_bedrock_api_call_exception(ctx, exc_info): span = ctx[ctx["call_key"]] span.set_exc_info(*exc_info) prompt = ctx["prompt"] + model_name = ctx["model_name"] integration = ctx["bedrock_integration"] - if integration.is_pc_sampled_llmobs(span): + if integration.is_pc_sampled_llmobs(span) and "embed" not in model_name: integration.llmobs_set_tags(span, formatted_response=None, prompt=prompt, err=True) span.finish() @@ -656,6 +656,7 @@ def _on_botocore_bedrock_process_response( ) -> None: text = formatted_response["text"] span = ctx[ctx["call_key"]] + model_name = ctx["model_name"] if should_set_choice_ids: for i in range(len(text)): span.set_tag_str("bedrock.response.choices.{}.id".format(i), str(body["generations"][i]["id"])) @@ -663,6 +664,10 @@ def _on_botocore_bedrock_process_response( if metadata is not None: for k, v in metadata.items(): span.set_tag_str("bedrock.{}".format(k), str(v)) + if "embed" in model_name: + span.set_metric("bedrock.response.embedding_length", len(formatted_response["text"][0])) + span.finish() + return for i in range(len(formatted_response["text"])): if integration.is_pc_sampled_span(span): span.set_tag_str( @@ -677,6 +682,11 @@ def _on_botocore_bedrock_process_response( span.finish() +def _on_redis_async_command_post(span, rowcount): + if rowcount is not None: + span.set_metric(db.ROWCOUNT, rowcount) + + def 
listen(): core.on("wsgi.block.started", _wsgi_make_block_content, "status_headers_content") core.on("asgi.block.started", _asgi_make_block_content, "status_headers_content") @@ -721,6 +731,7 @@ def listen(): core.on("botocore.patched_bedrock_api_call.exception", _on_botocore_patched_bedrock_api_call_exception) core.on("botocore.patched_bedrock_api_call.success", _on_botocore_patched_bedrock_api_call_success) core.on("botocore.bedrock.process_response", _on_botocore_bedrock_process_response) + core.on("redis.async_command.post", _on_redis_async_command_post) for context_name in ( "flask.call", diff --git a/ddtrace/_trace/tracer.py b/ddtrace/_trace/tracer.py index 4f359a42d93..7259cdf16a0 100644 --- a/ddtrace/_trace/tracer.py +++ b/ddtrace/_trace/tracer.py @@ -230,7 +230,8 @@ def __init__( self.enabled = config._tracing_enabled self.context_provider = context_provider or DefaultContextProvider() - self._user_sampler: Optional[BaseSampler] = None + # _user_sampler is the backup in case we need to revert from remote config to local + self._user_sampler: Optional[BaseSampler] = DatadogSampler() self._sampler: BaseSampler = DatadogSampler() self._dogstatsd_url = agent.get_stats_url() if dogstatsd_url is None else dogstatsd_url self._compute_stats = config._trace_compute_stats @@ -286,7 +287,7 @@ def __init__( self._shutdown_lock = RLock() self._new_process = False - config._subscribe(["_trace_sample_rate"], self._on_global_config_update) + config._subscribe(["_trace_sample_rate", "_trace_sampling_rules"], self._on_global_config_update) config._subscribe(["logs_injection"], self._on_global_config_update) config._subscribe(["tags"], self._on_global_config_update) config._subscribe(["_tracing_enabled"], self._on_global_config_update) @@ -1125,19 +1126,10 @@ def _is_span_internal(span): def _on_global_config_update(self, cfg, items): # type: (Config, List) -> None - if "_trace_sample_rate" in items: - # Reset the user sampler if one exists - if 
cfg._get_source("_trace_sample_rate") != "remote_config" and self._user_sampler: - self._sampler = self._user_sampler - return - - if cfg._get_source("_trace_sample_rate") != "default": - sample_rate = cfg._trace_sample_rate - else: - sample_rate = None - sampler = DatadogSampler(default_sample_rate=sample_rate) - self._sampler = sampler + # sampling configs always come as a pair + if "_trace_sample_rate" in items and "_trace_sampling_rules" in items: + self._handle_sampler_update(cfg) if "tags" in items: self._tags = cfg.tags.copy() @@ -1160,3 +1152,42 @@ def _on_global_config_update(self, cfg, items): from ddtrace.contrib.logging import unpatch unpatch() + + def _handle_sampler_update(self, cfg): + # type: (Config) -> None + if ( + cfg._get_source("_trace_sample_rate") != "remote_config" + and cfg._get_source("_trace_sampling_rules") != "remote_config" + and self._user_sampler + ): + # if we get empty configs from rc for both sample rate and rules, we should revert to the user sampler + self.sampler = self._user_sampler + return + + if cfg._get_source("_trace_sample_rate") != "remote_config" and self._user_sampler: + try: + sample_rate = self._user_sampler.default_sample_rate # type: ignore[attr-defined] + except AttributeError: + log.debug("Custom non-DatadogSampler is being used, cannot pull default sample rate") + sample_rate = None + elif cfg._get_source("_trace_sample_rate") != "default": + sample_rate = cfg._trace_sample_rate + else: + sample_rate = None + + if cfg._get_source("_trace_sampling_rules") != "remote_config" and self._user_sampler: + try: + sampling_rules = self._user_sampler.rules # type: ignore[attr-defined] + # we need to chop off the default_sample_rate rule so the new sample_rate can be applied + sampling_rules = sampling_rules[:-1] + except AttributeError: + log.debug("Custom non-DatadogSampler is being used, cannot pull sampling rules") + sampling_rules = None + elif cfg._get_source("_trace_sampling_rules") != "default": + sampling_rules 
= DatadogSampler._parse_rules_from_str(cfg._trace_sampling_rules) + else: + sampling_rules = None + + sampler = DatadogSampler(rules=sampling_rules, default_sample_rate=sample_rate) + + self._sampler = sampler diff --git a/ddtrace/_trace/utils.py b/ddtrace/_trace/utils.py new file mode 100644 index 00000000000..0e1a9364582 --- /dev/null +++ b/ddtrace/_trace/utils.py @@ -0,0 +1,41 @@ +from ddtrace import Span +from ddtrace import config +from ddtrace.constants import ANALYTICS_SAMPLE_RATE_KEY +from ddtrace.constants import SPAN_KIND +from ddtrace.constants import SPAN_MEASURED_KEY +from ddtrace.ext import SpanKind +from ddtrace.ext import aws +from ddtrace.internal.constants import COMPONENT +from ddtrace.internal.utils.formats import deep_getattr + + +def set_botocore_patched_api_call_span_tags(span: Span, instance, args, params, endpoint_name, operation): + span.set_tag_str(COMPONENT, config.botocore.integration_name) + # set span.kind to the type of request being performed + span.set_tag_str(SPAN_KIND, SpanKind.CLIENT) + span.set_tag(SPAN_MEASURED_KEY) + + if args: + # DEV: join is the fastest way of concatenating strings that is compatible + # across Python versions (see + # https://stackoverflow.com/questions/1316887/what-is-the-most-efficient-string-concatenation-method-in-python) + span.resource = ".".join((endpoint_name, operation.lower())) + span.set_tag("aws_service", endpoint_name) + + if params and not config.botocore["tag_no_params"]: + aws._add_api_param_span_tags(span, endpoint_name, params) + + else: + span.resource = endpoint_name + + region_name = deep_getattr(instance, "meta.region_name") + + span.set_tag_str("aws.agent", "botocore") + if operation is not None: + span.set_tag_str("aws.operation", operation) + if region_name is not None: + span.set_tag_str("aws.region", region_name) + span.set_tag_str("region", region_name) + + # set analytics sample rate + span.set_tag(ANALYTICS_SAMPLE_RATE_KEY, config.botocore.get_analytics_sample_rate()) diff 
--git a/ddtrace/_trace/utils_redis.py b/ddtrace/_trace/utils_redis.py new file mode 100644 index 00000000000..1e2d7b9b9a8 --- /dev/null +++ b/ddtrace/_trace/utils_redis.py @@ -0,0 +1,92 @@ +""" +Some utils used by the dogtrace redis integration +""" +from contextlib import contextmanager +from typing import List +from typing import Optional + +from ddtrace.constants import ANALYTICS_SAMPLE_RATE_KEY +from ddtrace.constants import SPAN_KIND +from ddtrace.constants import SPAN_MEASURED_KEY +from ddtrace.contrib import trace_utils +from ddtrace.contrib.redis_utils import _extract_conn_tags +from ddtrace.ext import SpanKind +from ddtrace.ext import SpanTypes +from ddtrace.ext import db +from ddtrace.ext import redis as redisx +from ddtrace.internal.constants import COMPONENT +from ddtrace.internal.schema import schematize_cache_operation +from ddtrace.internal.utils.formats import stringify_cache_args + + +format_command_args = stringify_cache_args + + +def _set_span_tags( + span, pin, config_integration, args: Optional[List], instance, query: Optional[List], is_cluster: bool = False +): + span.set_tag_str(SPAN_KIND, SpanKind.CLIENT) + span.set_tag_str(COMPONENT, config_integration.integration_name) + span.set_tag_str(db.SYSTEM, redisx.APP) + span.set_tag(SPAN_MEASURED_KEY) + if query is not None: + span_name = schematize_cache_operation(redisx.RAWCMD, cache_provider=redisx.APP) # type: ignore[operator] + span.set_tag_str(span_name, query) + if pin.tags: + span.set_tags(pin.tags) + # some redis clients do not have a connection_pool attribute (ex. 
aioredis v1.3) + if not is_cluster and hasattr(instance, "connection_pool"): + span.set_tags(_extract_conn_tags(instance.connection_pool.connection_kwargs)) + if args is not None: + span.set_metric(redisx.ARGS_LEN, len(args)) + else: + for attr in ("command_stack", "_command_stack"): + if hasattr(instance, attr): + span.set_metric(redisx.PIPELINE_LEN, len(getattr(instance, attr))) + # set analytics sample rate if enabled + span.set_tag(ANALYTICS_SAMPLE_RATE_KEY, config_integration.get_analytics_sample_rate()) + + +@contextmanager +def _trace_redis_cmd(pin, config_integration, instance, args): + query = stringify_cache_args(args, cmd_max_len=config_integration.cmd_max_length) + with pin.tracer.trace( + schematize_cache_operation(redisx.CMD, cache_provider=redisx.APP), + service=trace_utils.ext_service(pin, config_integration), + span_type=SpanTypes.REDIS, + resource=query.split(" ")[0] if config_integration.resource_only_command else query, + ) as span: + _set_span_tags(span, pin, config_integration, args, instance, query) + yield span + + +@contextmanager +def _trace_redis_execute_pipeline(pin, config_integration, cmds, instance, is_cluster=False): + cmd_string = resource = "\n".join(cmds) + if config_integration.resource_only_command: + resource = "\n".join([cmd.split(" ")[0] for cmd in cmds]) + + with pin.tracer.trace( + schematize_cache_operation(redisx.CMD, cache_provider=redisx.APP), + resource=resource, + service=trace_utils.ext_service(pin, config_integration), + span_type=SpanTypes.REDIS, + ) as span: + _set_span_tags(span, pin, config_integration, None, instance, cmd_string) + yield span + + +@contextmanager +def _trace_redis_execute_async_cluster_pipeline(pin, config_integration, cmds, instance): + cmd_string = resource = "\n".join(cmds) + if config_integration.resource_only_command: + resource = "\n".join([cmd.split(" ")[0] for cmd in cmds]) + + with pin.tracer.trace( + schematize_cache_operation(redisx.CMD, cache_provider=redisx.APP), + 
resource=resource, + service=trace_utils.ext_service(pin, config_integration), + span_type=SpanTypes.REDIS, + ) as span: + _set_span_tags(span, pin, config_integration, None, instance, cmd_string) + yield span diff --git a/ddtrace/appsec/_asm_request_context.py b/ddtrace/appsec/_asm_request_context.py index 173027c0d10..654e06a29e5 100644 --- a/ddtrace/appsec/_asm_request_context.py +++ b/ddtrace/appsec/_asm_request_context.py @@ -13,6 +13,7 @@ from ddtrace._trace.span import Span from ddtrace.appsec import _handlers from ddtrace.appsec._constants import APPSEC +from ddtrace.appsec._constants import EXPLOIT_PREVENTION from ddtrace.appsec._constants import SPAN_DATA_NAMES from ddtrace.appsec._constants import WAF_CONTEXT_NAMES from ddtrace.appsec._ddwaf import DDWaf_result @@ -147,6 +148,12 @@ def __init__(self): "triggered": False, "timeout": False, "version": None, + "rasp": { + "called": False, + "eval": {t: 0 for _, t in EXPLOIT_PREVENTION.TYPE}, + "match": {t: 0 for _, t in EXPLOIT_PREVENTION.TYPE}, + "timeout": {t: 0 for _, t in EXPLOIT_PREVENTION.TYPE}, + }, } env.callbacks[_CONTEXT_CALL] = [] @@ -330,15 +337,27 @@ def asm_request_context_set( def set_waf_telemetry_results( - rules_version: Optional[str], is_triggered: bool, is_blocked: bool, is_timeout: bool + rules_version: Optional[str], + is_triggered: bool, + is_blocked: bool, + is_timeout: bool, + rule_type: Optional[str], ) -> None: result = get_value(_TELEMETRY, _TELEMETRY_WAF_RESULTS) if result is not None: - result["triggered"] |= is_triggered - result["blocked"] |= is_blocked - result["timeout"] |= is_timeout - if rules_version is not None: - result["version"] = rules_version + if rule_type is None: + # Request Blocking telemetry + result["triggered"] |= is_triggered + result["blocked"] |= is_blocked + result["timeout"] |= is_timeout + if rules_version is not None: + result["version"] = rules_version + else: + # Exploit Prevention telemetry + result["rasp"]["called"] = True + 
result["rasp"]["eval"][rule_type] += 1 + result["rasp"]["match"][rule_type] += int(is_triggered) + result["rasp"]["timeout"][rule_type] += int(is_timeout) def get_waf_telemetry_results() -> Optional[Dict[str, Any]]: diff --git a/ddtrace/appsec/_common_module_patches.py b/ddtrace/appsec/_common_module_patches.py index 312a88a41d4..69c2610cab5 100644 --- a/ddtrace/appsec/_common_module_patches.py +++ b/ddtrace/appsec/_common_module_patches.py @@ -3,6 +3,7 @@ import ctypes import gc +import os from typing import Any from typing import Callable from typing import Dict @@ -48,14 +49,23 @@ def wrapped_open_CFDDB7ABBA9081B6(original_open_callable, instance, args, kwargs try: from ddtrace.appsec._asm_request_context import call_waf_callback from ddtrace.appsec._asm_request_context import in_context + from ddtrace.appsec._constants import EXPLOIT_PREVENTION except ImportError: # open is used during module initialization # and shouldn't be changed at that time return original_open_callable(*args, **kwargs) - filename = args[0] if args else kwargs.get("file", None) + filename_arg = args[0] if args else kwargs.get("file", None) + try: + filename = os.fspath(filename_arg) + except Exception: + filename = "" if filename and in_context(): - call_waf_callback({"LFI_ADDRESS": filename}, crop_trace="wrapped_open_CFDDB7ABBA9081B6") + call_waf_callback( + {EXPLOIT_PREVENTION.ADDRESS.LFI: filename}, + crop_trace="wrapped_open_CFDDB7ABBA9081B6", + rule_type=EXPLOIT_PREVENTION.TYPE.LFI, + ) # DEV: Next part of the exploit prevention feature: add block here return original_open_callable(*args, **kwargs) @@ -72,6 +82,7 @@ def wrapped_open_ED4CF71136E15EBF(original_open_callable, instance, args, kwargs try: from ddtrace.appsec._asm_request_context import call_waf_callback from ddtrace.appsec._asm_request_context import in_context + from ddtrace.appsec._constants import EXPLOIT_PREVENTION except ImportError: # open is used during module initialization # and shouldn't be changed at that time 
@@ -82,7 +93,11 @@ def wrapped_open_ED4CF71136E15EBF(original_open_callable, instance, args, kwargs if url.__class__.__name__ == "Request": url = url.get_full_url() if isinstance(url, str): - call_waf_callback({"SSRF_ADDRESS": url}, crop_trace="wrapped_open_ED4CF71136E15EBF") + call_waf_callback( + {EXPLOIT_PREVENTION.ADDRESS.SSRF: url}, + crop_trace="wrapped_open_ED4CF71136E15EBF", + rule_type=EXPLOIT_PREVENTION.TYPE.SSRF, + ) # DEV: Next part of the exploit prevention feature: add block here return original_open_callable(*args, **kwargs) @@ -100,6 +115,7 @@ def wrapped_request_D8CB81E472AF98A2(original_request_callable, instance, args, try: from ddtrace.appsec._asm_request_context import call_waf_callback from ddtrace.appsec._asm_request_context import in_context + from ddtrace.appsec._constants import EXPLOIT_PREVENTION except ImportError: # open is used during module initialization # and shouldn't be changed at that time @@ -108,7 +124,11 @@ def wrapped_request_D8CB81E472AF98A2(original_request_callable, instance, args, url = args[1] if len(args) > 1 else kwargs.get("url", None) if url and in_context(): if isinstance(url, str): - call_waf_callback({"SSRF_ADDRESS": url}, crop_trace="wrapped_request_D8CB81E472AF98A2") + call_waf_callback( + {EXPLOIT_PREVENTION.ADDRESS.SSRF: url}, + crop_trace="wrapped_request_D8CB81E472AF98A2", + rule_type=EXPLOIT_PREVENTION.TYPE.SSRF, + ) # DEV: Next part of the exploit prevention feature: add block here return original_request_callable(*args, **kwargs) diff --git a/ddtrace/appsec/_constants.py b/ddtrace/appsec/_constants.py index c7a3fad3cf3..59f90a335dc 100644 --- a/ddtrace/appsec/_constants.py +++ b/ddtrace/appsec/_constants.py @@ -248,3 +248,12 @@ class EXPLOIT_PREVENTION(metaclass=Constant_Class): STACK_TRACE_ENABLED = "DD_APPSEC_STACK_TRACE_ENABLED" MAX_STACK_TRACES = "DD_APPSEC_MAX_STACK_TRACES" MAX_STACK_TRACE_DEPTH = "DD_APPSEC_MAX_STACK_TRACE_DEPTH" + + class TYPE(metaclass=Constant_Class): + LFI = "lfi" + SSRF = 
"ssrf" + SQLI = "sql_injection" + + class ADDRESS(metaclass=Constant_Class): + LFI = "LFI_ADDRESS" + SSRF = "SSRF_ADDRESS" diff --git a/ddtrace/appsec/_iast/_ast/visitor.py b/ddtrace/appsec/_iast/_ast/visitor.py index 3c071edb428..c726b7570b4 100644 --- a/ddtrace/appsec/_iast/_ast/visitor.py +++ b/ddtrace/appsec/_iast/_ast/visitor.py @@ -3,6 +3,7 @@ from _ast import ImportFrom import ast import copy +import os import sys from typing import Any # noqa:F401 from typing import List # noqa:F401 @@ -65,6 +66,9 @@ def __init__( "format_map": "ddtrace_aspects.format_map_aspect", "zfill": "ddtrace_aspects.zfill_aspect", "ljust": "ddtrace_aspects.ljust_aspect", + "split": "ddtrace_aspects.split_aspect", + "rsplit": "ddtrace_aspects.rsplit_aspect", + "splitlines": "ddtrace_aspects.splitlines_aspect", }, # Replacement function for indexes and ranges "slices": { @@ -73,10 +77,14 @@ def __init__( }, # Replacement functions for modules "module_functions": { - # "BytesIO": "ddtrace_aspects.stringio_aspect", - # "StringIO": "ddtrace_aspects.stringio_aspect", - # "format": "ddtrace_aspects.format_aspect", - # "format_map": "ddtrace_aspects.format_map_aspect", + "os.path": { + "basename": "ddtrace_aspects._aspect_ospathbasename", + "dirname": "ddtrace_aspects._aspect_ospathdirname", + "join": "ddtrace_aspects._aspect_ospathjoin", + "normcase": "ddtrace_aspects._aspect_ospathnormcase", + "split": "ddtrace_aspects._aspect_ospathsplit", + "splitext": "ddtrace_aspects._aspect_ospathsplitext", + } }, "operators": { ast.Add: "ddtrace_aspects.add_aspect", @@ -125,6 +133,13 @@ def __init__( }, }, } + + if sys.version_info >= (3, 12): + self._aspects_spec["module_functions"]["os.path"]["splitroot"] = "ddtrace_aspects._aspect_ospathsplitroot" + + if sys.version_info >= (3, 12) or os.name == "nt": + self._aspects_spec["module_functions"]["os.path"]["splitdrive"] = "ddtrace_aspects._aspect_ospathsplitdrive" + self._sinkpoints_spec = { "definitions_module": "ddtrace.appsec._iast.taint_sinks", 
"alias_module": "ddtrace_taint_sinks", @@ -492,30 +507,46 @@ def visit_Call(self, call_node): # type: (ast.Call) -> Any if self._is_string_format_with_literals(call_node): return call_node - aspect = self._aspect_methods.get(method_name) - - if aspect: - # Move the Attribute.value to 'args' - new_arg = func_member.value - call_node.args.insert(0, new_arg) - # Send 1 as flag_added_args value - call_node.args.insert(0, self._int_constant(call_node, 1)) - - # Insert None as first parameter instead of a.b.c.method - # to avoid unexpected side effects such as a.b.read(4).method - call_node.args.insert(0, self._none_constant(call_node)) + # This resolve moduleparent.modulechild.name + # TODO: use the better Hdiv method with a decorator + func_value = getattr(func_member, "value", None) + func_value_value = getattr(func_value, "value", None) if func_value else None + func_value_value_id = getattr(func_value_value, "id", None) if func_value_value else None + func_value_attr = getattr(func_value, "attr", None) if func_value else None + func_attr = getattr(func_member, "attr", None) + aspect = None + if func_value_value_id or func_attr: + if func_value_value_id and func_value_attr: + # e.g. "os.path" or "one.two.three.whatever" (all dotted previous tokens with be in the id) + key = func_value_value_id + "." 
+ func_value_attr + elif func_value_attr: + # e.g os + key = func_attr + else: + key = None + + if key: + module_dict = self._aspect_modules.get(key, None) + aspect = module_dict.get(func_attr, None) if module_dict else None + if aspect: + # Create a new Name node for the replacement and set it as node.func + call_node.func = self._attr_node(call_node, aspect) + self.ast_modified = call_modified = True - # Create a new Name node for the replacement and set it as node.func - call_node.func = self._attr_node(call_node, aspect) - self.ast_modified = call_modified = True + if not aspect: + # Not a module symbol, check if it's a known method + aspect = self._aspect_methods.get(method_name) - elif hasattr(func_member.value, "id") or hasattr(func_member.value, "attr"): - aspect = self._aspect_modules.get(method_name, None) if aspect: - # Send 0 as flag_added_args value - call_node.args.insert(0, self._int_constant(call_node, 0)) - # Move the Function to 'args' - call_node.args.insert(0, call_node.func) + # Move the Attribute.value to 'args' + new_arg = func_member.value + call_node.args.insert(0, new_arg) + # Send 1 as flag_added_args value + call_node.args.insert(0, self._int_constant(call_node, 1)) + + # Insert None as first parameter instead of a.b.c.method + # to avoid unexpected side effects such as a.b.read(4).method + call_node.args.insert(0, self._none_constant(call_node)) # Create a new Name node for the replacement and set it as node.func call_node.func = self._attr_node(call_node, aspect) diff --git a/ddtrace/appsec/_iast/_evidence_redaction/__init__.py b/ddtrace/appsec/_iast/_evidence_redaction/__init__.py new file mode 100644 index 00000000000..195391ffab2 --- /dev/null +++ b/ddtrace/appsec/_iast/_evidence_redaction/__init__.py @@ -0,0 +1,4 @@ +from ddtrace.appsec._iast._evidence_redaction._sensitive_handler import sensitive_handler + + +sensitive_handler diff --git a/ddtrace/appsec/_iast/_evidence_redaction/_sensitive_handler.py 
b/ddtrace/appsec/_iast/_evidence_redaction/_sensitive_handler.py new file mode 100644 index 00000000000..b76ad6c96b1 --- /dev/null +++ b/ddtrace/appsec/_iast/_evidence_redaction/_sensitive_handler.py @@ -0,0 +1,363 @@ +import re + +from ddtrace.internal.logger import get_logger +from ddtrace.settings.asm import config as asm_config + +from ..constants import VULN_CMDI +from ..constants import VULN_HEADER_INJECTION +from ..constants import VULN_SSRF +from .command_injection_sensitive_analyzer import command_injection_sensitive_analyzer +from .header_injection_sensitive_analyzer import header_injection_sensitive_analyzer +from .url_sensitive_analyzer import url_sensitive_analyzer + + +log = get_logger(__name__) + +REDACTED_SOURCE_BUFFER = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + + +class SensitiveHandler: + """ + Class responsible for handling sensitive information. + """ + + def __init__(self): + self._name_pattern = re.compile(asm_config._iast_redaction_name_pattern, re.IGNORECASE | re.MULTILINE) + self._value_pattern = re.compile(asm_config._iast_redaction_value_pattern, re.IGNORECASE | re.MULTILINE) + + self._sensitive_analyzers = { + VULN_CMDI: command_injection_sensitive_analyzer, + # SQL_INJECTION: sql_sensitive_analyzer, + VULN_SSRF: url_sensitive_analyzer, + VULN_HEADER_INJECTION: header_injection_sensitive_analyzer, + } + + @staticmethod + def _contains(range_container, range_contained): + """ + Checks if a range_container contains another range_contained. + + Args: + - range_container (dict): The container range. + - range_contained (dict): The contained range. + + Returns: + - bool: True if range_container contains range_contained, False otherwise. + """ + if range_container["start"] > range_contained["start"]: + return False + return range_container["end"] >= range_contained["end"] + + @staticmethod + def _intersects(range_a, range_b): + """ + Checks if two ranges intersect. + + Args: + - range_a (dict): First range. 
+ - range_b (dict): Second range. + + Returns: + - bool: True if the ranges intersect, False otherwise. + """ + return range_b["start"] < range_a["end"] and range_b["end"] > range_a["start"] + + def _remove(self, range_, range_to_remove): + """ + Removes a range_to_remove from a range_. + + Args: + - range_ (dict): The range to remove from. + - range_to_remove (dict): The range to remove. + + Returns: + - list: List containing the remaining parts after removing the range_to_remove. + """ + if not self._intersects(range_, range_to_remove): + return [range_] + elif self._contains(range_to_remove, range_): + return [] + else: + result = [] + if range_to_remove["start"] > range_["start"]: + offset = range_to_remove["start"] - range_["start"] + result.append({"start": range_["start"], "end": range_["start"] + offset}) + if range_to_remove["end"] < range_["end"]: + offset = range_["end"] - range_to_remove["end"] + result.append({"start": range_to_remove["end"], "end": range_to_remove["end"] + offset}) + return result + + def is_sensible_name(self, name): + """ + Checks if a name is sensible based on the name pattern. + + Args: + - name (str): The name to check. + + Returns: + - bool: True if the name is sensible, False otherwise. + """ + return bool(self._name_pattern.search(name)) + + def is_sensible_value(self, value): + """ + Checks if a value is sensible based on the value pattern. + + Args: + - value (str): The value to check. + + Returns: + - bool: True if the value is sensible, False otherwise. + """ + return bool(self._value_pattern.search(value)) + + def is_sensible_source(self, source): + """ + Checks if a source is sensible. + + Args: + - source (dict): The source to check. + + Returns: + - bool: True if the source is sensible, False otherwise. 
+ """ + return ( + source is not None + and source.value is not None + and (self.is_sensible_name(source.name) or self.is_sensible_value(source.value)) + ) + + def scrub_evidence(self, vulnerability_type, evidence, tainted_ranges, sources): + """ + Scrubs evidence based on the given vulnerability type. + + Args: + - vulnerability_type (str): The vulnerability type. + - evidence (dict): The evidence to scrub. + - tainted_ranges (list): List of tainted ranges. + - sources (list): List of sources. + + Returns: + - dict: The scrubbed evidence. + """ + if asm_config._iast_redaction_enabled: + sensitive_analyzer = self._sensitive_analyzers.get(vulnerability_type) + if sensitive_analyzer: + if not evidence.value: + log.debug("No evidence value found in evidence %s", evidence) + return None + sensitive_ranges = sensitive_analyzer(evidence, self._name_pattern, self._value_pattern) + return self.to_redacted_json(evidence.value, sensitive_ranges, tainted_ranges, sources) + return None + + def to_redacted_json(self, evidence_value, sensitive, tainted_ranges, sources): + """ + Converts evidence value to redacted JSON format. + + Args: + - evidence_value (str): The evidence value. + - sensitive (list): List of sensitive ranges. + - tainted_ranges (list): List of tainted ranges. + - sources (list): List of sources. + + Returns: + - dict: The redacted JSON. 
+ """ + value_parts = [] + redacted_sources = [] + redacted_sources_context = dict() + + start = 0 + next_tainted_index = 0 + source_index = None + + next_tainted = tainted_ranges.pop(0) if tainted_ranges else None + next_sensitive = sensitive.pop(0) if sensitive else None + i = 0 + while i < len(evidence_value): + if next_tainted and next_tainted["start"] == i: + self.write_value_part(value_parts, evidence_value[start:i], source_index) + + source_index = next_tainted_index + + while next_sensitive and self._contains(next_tainted, next_sensitive): + redaction_start = next_sensitive["start"] - next_tainted["start"] + redaction_end = next_sensitive["end"] - next_tainted["start"] + if redaction_start == redaction_end: + self.write_redacted_value_part(value_parts, 0) + else: + self.redact_source( + sources, + redacted_sources, + redacted_sources_context, + source_index, + redaction_start, + redaction_end, + ) + next_sensitive = sensitive.pop(0) if sensitive else None + + if next_sensitive and self._intersects(next_sensitive, next_tainted): + redaction_start = next_sensitive["start"] - next_tainted["start"] + redaction_end = next_sensitive["end"] - next_tainted["start"] + + self.redact_source( + sources, + redacted_sources, + redacted_sources_context, + source_index, + redaction_start, + redaction_end, + ) + + entries = self._remove(next_sensitive, next_tainted) + next_sensitive = entries[0] if entries else None + + if source_index < len(sources): + if not sources[source_index].redacted and self.is_sensible_source(sources[source_index]): + redacted_sources.append(source_index) + sources[source_index].pattern = REDACTED_SOURCE_BUFFER[: len(sources[source_index].value)] + sources[source_index].redacted = True + + if source_index in redacted_sources: + part_value = evidence_value[i : i + (next_tainted["end"] - next_tainted["start"])] + + self.write_redacted_value_part( + value_parts, + len(part_value), + source_index, + part_value, + sources[source_index], + 
redacted_sources_context.get(source_index), + self.is_sensible_source(sources[source_index]), + ) + redacted_sources_context[source_index] = [] + else: + substring_end = min(next_tainted["end"], len(evidence_value)) + self.write_value_part( + value_parts, evidence_value[next_tainted["start"] : substring_end], source_index + ) + + start = i + (next_tainted["end"] - next_tainted["start"]) + i = start - 1 + next_tainted = tainted_ranges.pop(0) if tainted_ranges else None + next_tainted_index += 1 + source_index = None + continue + elif next_sensitive and next_sensitive["start"] == i: + self.write_value_part(value_parts, evidence_value[start:i], source_index) + if next_tainted and self._intersects(next_sensitive, next_tainted): + source_index = next_tainted_index + + redaction_start = next_sensitive["start"] - next_tainted["start"] + redaction_end = next_sensitive["end"] - next_tainted["start"] + self.redact_source( + sources, + redacted_sources, + redacted_sources_context, + next_tainted_index, + redaction_start, + redaction_end, + ) + + entries = self._remove(next_sensitive, next_tainted) + next_sensitive = entries[0] if entries else None + + length = next_sensitive["end"] - next_sensitive["start"] + self.write_redacted_value_part(value_parts, length) + + start = i + length + i = start - 1 + next_sensitive = sensitive.pop(0) if sensitive else None + continue + i += 1 + if start < len(evidence_value): + self.write_value_part(value_parts, evidence_value[start:]) + + return {"redacted_value_parts": value_parts, "redacted_sources": redacted_sources} + + def redact_source(self, sources, redacted_sources, redacted_sources_context, source_index, start, end): + if source_index is not None: + if not sources[source_index].redacted: + redacted_sources.append(source_index) + sources[source_index].pattern = REDACTED_SOURCE_BUFFER[: len(sources[source_index].value)] + sources[source_index].redacted = True + + if source_index not in redacted_sources_context.keys(): + 
redacted_sources_context[source_index] = [] + + redacted_sources_context[source_index].append({"start": start, "end": end}) + + def write_value_part(self, value_parts, value, source_index=None): + if value: + if source_index is not None: + value_parts.append({"value": value, "source": source_index}) + else: + value_parts.append({"value": value}) + + def write_redacted_value_part( + self, + value_parts, + length, + source_index=None, + part_value=None, + source=None, + source_redaction_context=None, + is_sensible_source=False, + ): + if source_index is not None: + placeholder = source.pattern if part_value and part_value in source.value else "*" * length + + if is_sensible_source: + value_parts.append({"redacted": True, "source": source_index, "pattern": placeholder}) + else: + _value = part_value + deduped_source_redaction_contexts = [] + + for _source_redaction_context in source_redaction_context: + if _source_redaction_context not in deduped_source_redaction_contexts: + deduped_source_redaction_contexts.append(_source_redaction_context) + + offset = 0 + for _source_redaction_context in deduped_source_redaction_contexts: + if _source_redaction_context["start"] > 0: + value_parts.append( + {"source": source_index, "value": _value[: _source_redaction_context["start"] - offset]} + ) + _value = _value[_source_redaction_context["start"] - offset :] + offset = _source_redaction_context["start"] + + sensitive_start = _source_redaction_context["start"] - offset + if sensitive_start < 0: + sensitive_start = 0 + sensitive = _value[sensitive_start : _source_redaction_context["end"] - offset] + index_of_part_value_in_pattern = source.value.find(sensitive) + pattern = ( + placeholder[index_of_part_value_in_pattern : index_of_part_value_in_pattern + len(sensitive)] + if index_of_part_value_in_pattern > -1 + else placeholder[_source_redaction_context["start"] : _source_redaction_context["end"]] + ) + + value_parts.append({"redacted": True, "source": source_index, "pattern": 
pattern}) + _value = _value[len(pattern) :] + offset += len(pattern) + if _value: + value_parts.append({"source": source_index, "value": _value}) + + else: + value_parts.append({"redacted": True}) + + def set_redaction_patterns(self, redaction_name_pattern=None, redaction_value_pattern=None): + if redaction_name_pattern: + try: + self._name_pattern = re.compile(redaction_name_pattern, re.IGNORECASE | re.MULTILINE) + except re.error: + log.warning("Redaction name pattern is not valid") + + if redaction_value_pattern: + try: + self._value_pattern = re.compile(redaction_value_pattern, re.IGNORECASE | re.MULTILINE) + except re.error: + log.warning("Redaction value pattern is not valid") + + +sensitive_handler = SensitiveHandler() diff --git a/ddtrace/appsec/_iast/_evidence_redaction/command_injection_sensitive_analyzer.py b/ddtrace/appsec/_iast/_evidence_redaction/command_injection_sensitive_analyzer.py new file mode 100644 index 00000000000..57dccc03db1 --- /dev/null +++ b/ddtrace/appsec/_iast/_evidence_redaction/command_injection_sensitive_analyzer.py @@ -0,0 +1,19 @@ +import re + +from ddtrace.internal.logger import get_logger + + +log = get_logger(__name__) + +_INSIDE_QUOTES_REGEXP = re.compile(r"^(?:\s*(?:sudo|doas)\s+)?\b\S+\b\s*(.*)") +COMMAND_PATTERN = r"^(?:\s*(?:sudo|doas)\s+)?\b\S+\b\s(.*)" +pattern = re.compile(COMMAND_PATTERN, re.IGNORECASE | re.MULTILINE) + + +def command_injection_sensitive_analyzer(evidence, name_pattern=None, value_pattern=None): + regex_result = pattern.search(evidence.value) + if regex_result and len(regex_result.groups()) > 0: + start = regex_result.start(1) + end = regex_result.end(1) + return [{"start": start, "end": end}] + return [] diff --git a/ddtrace/appsec/_iast/_evidence_redaction/header_injection_sensitive_analyzer.py b/ddtrace/appsec/_iast/_evidence_redaction/header_injection_sensitive_analyzer.py new file mode 100644 index 00000000000..3b254781351 --- /dev/null +++ 
b/ddtrace/appsec/_iast/_evidence_redaction/header_injection_sensitive_analyzer.py @@ -0,0 +1,17 @@ +from ddtrace.appsec._iast.constants import HEADER_NAME_VALUE_SEPARATOR +from ddtrace.internal.logger import get_logger + + +log = get_logger(__name__) + + +def header_injection_sensitive_analyzer(evidence, name_pattern, value_pattern): + evidence_value = evidence.value + sections = evidence_value.split(HEADER_NAME_VALUE_SEPARATOR) + header_name = sections[0] + header_value = HEADER_NAME_VALUE_SEPARATOR.join(sections[1:]) + + if name_pattern.search(header_name) or value_pattern.search(header_value): + return [{"start": len(header_name) + len(HEADER_NAME_VALUE_SEPARATOR), "end": len(evidence_value)}] + + return [] diff --git a/ddtrace/appsec/_iast/_evidence_redaction/url_sensitive_analyzer.py b/ddtrace/appsec/_iast/_evidence_redaction/url_sensitive_analyzer.py new file mode 100644 index 00000000000..04ee4ecb6c8 --- /dev/null +++ b/ddtrace/appsec/_iast/_evidence_redaction/url_sensitive_analyzer.py @@ -0,0 +1,34 @@ +import re + +from ddtrace.internal.logger import get_logger + + +log = get_logger(__name__) +AUTHORITY = r"^(?:[^:]+:)?//([^@]+)@" +QUERY_FRAGMENT = r"[?#&]([^=&;]+)=([^?#&]+)" +pattern = re.compile(f"({AUTHORITY})|({QUERY_FRAGMENT})", re.IGNORECASE | re.MULTILINE) + + +def url_sensitive_analyzer(evidence, name_pattern=None, value_pattern=None): + try: + ranges = [] + regex_result = pattern.search(evidence.value) + + while regex_result is not None: + if isinstance(regex_result.group(1), str): + end = regex_result.start() + (len(regex_result.group(0)) - 1) + start = end - len(regex_result.group(1)) + ranges.append({"start": start, "end": end}) + + if isinstance(regex_result.group(3), str): + end = regex_result.start() + len(regex_result.group(0)) + start = end - len(regex_result.group(3)) + ranges.append({"start": start, "end": end}) + + regex_result = pattern.search(evidence.value, regex_result.end()) + + return ranges + except Exception as e: + log.debug(e) 
+ + return [] diff --git a/ddtrace/appsec/_iast/_patch_modules.py b/ddtrace/appsec/_iast/_patch_modules.py index 05a6900be83..c2186786dc3 100644 --- a/ddtrace/appsec/_iast/_patch_modules.py +++ b/ddtrace/appsec/_iast/_patch_modules.py @@ -3,6 +3,7 @@ IAST_PATCH = { "command_injection": True, + "header_injection": False, "path_traversal": True, "weak_cipher": True, "weak_hash": True, diff --git a/ddtrace/appsec/_iast/_taint_tracking/Aspects/AspectSplit.cpp b/ddtrace/appsec/_iast/_taint_tracking/Aspects/AspectSplit.cpp new file mode 100644 index 00000000000..5269270438d --- /dev/null +++ b/ddtrace/appsec/_iast/_taint_tracking/Aspects/AspectSplit.cpp @@ -0,0 +1,74 @@ +#include "AspectSplit.h" +#include "Initializer/Initializer.h" + +template +py::list +api_split_text(const StrType& text, const optional& separator, const optional maxsplit) +{ + TaintRangeMapType* tx_map = initializer->get_tainting_map(); + if (not tx_map) { + throw py::value_error(MSG_ERROR_TAINT_MAP); + } + + auto split = text.attr("split"); + auto split_result = split(separator, maxsplit); + auto ranges = api_get_ranges(text); + if (not ranges.empty()) { + set_ranges_on_splitted(text, ranges, split_result, tx_map, false); + } + + return split_result; +} + +template +py::list +api_rsplit_text(const StrType& text, const optional& separator, const optional maxsplit) +{ + TaintRangeMapType* tx_map = initializer->get_tainting_map(); + if (not tx_map) { + throw py::value_error(MSG_ERROR_TAINT_MAP); + } + + auto rsplit = text.attr("rsplit"); + auto split_result = rsplit(separator, maxsplit); + auto ranges = api_get_ranges(text); + if (not ranges.empty()) { + set_ranges_on_splitted(text, ranges, split_result, tx_map, false); + } + return split_result; +} + +template +py::list +api_splitlines_text(const StrType& text, bool keepends) +{ + TaintRangeMapType* tx_map = initializer->get_tainting_map(); + if (not tx_map) { + throw py::value_error(MSG_ERROR_TAINT_MAP); + } + + auto splitlines = 
text.attr("splitlines"); + auto split_result = splitlines(keepends); + auto ranges = api_get_ranges(text); + if (not ranges.empty()) { + set_ranges_on_splitted(text, ranges, split_result, tx_map, keepends); + } + return split_result; +} + +void +pyexport_aspect_split(py::module& m) +{ + m.def("_aspect_split", &api_split_text, "text"_a, "separator"_a = py::none(), "maxsplit"_a = -1); + m.def("_aspect_split", &api_split_text, "text"_a, "separator"_a = py::none(), "maxsplit"_a = -1); + m.def("_aspect_split", &api_split_text, "text"_a, "separator"_a = py::none(), "maxsplit"_a = -1); + m.def("_aspect_rsplit", &api_rsplit_text, "text"_a, "separator"_a = py::none(), "maxsplit"_a = -1); + m.def("_aspect_rsplit", &api_rsplit_text, "text"_a, "separator"_a = py::none(), "maxsplit"_a = -1); + m.def("_aspect_rsplit", &api_rsplit_text, "text"_a, "separator"_a = py::none(), "maxsplit"_a = -1); + // cppcheck-suppress assignBoolToPointer + m.def("_aspect_splitlines", &api_splitlines_text, "text"_a, "keepends"_a = false); + // cppcheck-suppress assignBoolToPointer + m.def("_aspect_splitlines", &api_splitlines_text, "text"_a, "keepends"_a = false); + // cppcheck-suppress assignBoolToPointer + m.def("_aspect_splitlines", &api_splitlines_text, "text"_a, "keepends"_a = false); +} \ No newline at end of file diff --git a/ddtrace/appsec/_iast/_taint_tracking/Aspects/AspectSplit.h b/ddtrace/appsec/_iast/_taint_tracking/Aspects/AspectSplit.h new file mode 100644 index 00000000000..5fb708e7c26 --- /dev/null +++ b/ddtrace/appsec/_iast/_taint_tracking/Aspects/AspectSplit.h @@ -0,0 +1,18 @@ +#pragma once + +#include "Helpers.h" + +template +py::list +api_split_text(const StrType& text, const optional& separator, const optional maxsplit); + +template +py::list +api_rsplit_text(const StrType& text, const optional& separator, const optional maxsplit); + +template +py::list +api_splitlines_text(const StrType& text, bool keepends); + +void +pyexport_aspect_split(py::module& m); \ No newline at end 
of file diff --git a/ddtrace/appsec/_iast/_taint_tracking/Aspects/AspectsOsPath.cpp b/ddtrace/appsec/_iast/_taint_tracking/Aspects/AspectsOsPath.cpp new file mode 100644 index 00000000000..86055cf035f --- /dev/null +++ b/ddtrace/appsec/_iast/_taint_tracking/Aspects/AspectsOsPath.cpp @@ -0,0 +1,277 @@ +#include "AspectsOsPath.h" +#include + +#include "Helpers.h" + +static bool +starts_with_separator(const py::handle& arg, const std::string& separator) +{ + std::string carg = py::cast(arg); + return carg.substr(0, 1) == separator; +} + +template +StrType +api_ospathjoin_aspect(StrType& first_part, const py::args& args) +{ + auto ospath = py::module_::import("os.path"); + auto join = ospath.attr("join"); + auto joined = join(first_part, *args); + + auto tx_map = initializer->get_tainting_map(); + if (not tx_map or tx_map->empty()) { + return joined; + } + + std::string separator = ospath.attr("sep").cast(); + auto sepsize = separator.size(); + + // Find the initial iteration point. This will be the first argument that has the separator ("/foo") + // as a first character or first_part (the first element) if no such argument is found. 
+ auto initial_arg_pos = -1; + bool root_is_after_first = false; + for (auto& arg : args) { + if (not is_text(arg.ptr())) { + return joined; + } + + if (starts_with_separator(arg, separator)) { + root_is_after_first = true; + initial_arg_pos++; + break; + } + initial_arg_pos++; + } + + TaintRangeRefs result_ranges; + result_ranges.reserve(args.size()); + + std::vector all_ranges; + unsigned long current_offset = 0; + auto first_part_len = py::len(first_part); + + if (not root_is_after_first) { + // Get the ranges of first_part and set them to the result, skipping the first character position + // if it's a separator + bool ranges_error; + TaintRangeRefs ranges; + std::tie(ranges, ranges_error) = get_ranges(first_part.ptr(), tx_map); + if (not ranges_error and not ranges.empty()) { + for (auto& range : ranges) { + result_ranges.emplace_back(shift_taint_range(range, current_offset, first_part_len)); + } + } + + if (not first_part.is(py::str(separator))) { + current_offset = py::len(first_part); + } + + current_offset += sepsize; + initial_arg_pos = 0; + } + + unsigned long unsigned_initial_arg_pos = max(0, initial_arg_pos); + + // Now go trough the arguments and do the same + for (unsigned long i = 0; i < args.size(); i++) { + if (i >= unsigned_initial_arg_pos) { + // Set the ranges from the corresponding argument + bool ranges_error; + TaintRangeRefs ranges; + std::tie(ranges, ranges_error) = get_ranges(args[i].ptr(), tx_map); + if (not ranges_error and not ranges.empty()) { + auto len_args_i = py::len(args[i]); + for (auto& range : ranges) { + result_ranges.emplace_back(shift_taint_range(range, current_offset, len_args_i)); + } + } + current_offset += py::len(args[i]); + current_offset += sepsize; + } + } + + if (not result_ranges.empty()) { + PyObject* new_result = new_pyobject_id(joined.ptr()); + set_ranges(new_result, result_ranges, tx_map); + return py::reinterpret_steal(new_result); + } + + return joined; +} + +template +StrType 
+api_ospathbasename_aspect(const StrType& path) +{ + auto tx_map = initializer->get_tainting_map(); + if (not tx_map) { + throw py::value_error(MSG_ERROR_TAINT_MAP); + } + + auto ospath = py::module_::import("os.path"); + auto basename = ospath.attr("basename"); + auto basename_result = basename(path); + if (py::len(basename_result) == 0) { + return basename_result; + } + + bool ranges_error; + TaintRangeRefs ranges; + std::tie(ranges, ranges_error) = get_ranges(path.ptr(), tx_map); + if (ranges_error or ranges.empty()) { + return basename_result; + } + + // Create a fake list to call set_ranges_on_splitted on it (we are + // only interested on the last path, which is the basename result) + auto prev_path_len = py::len(path) - py::len(basename_result); + std::string filler(prev_path_len, 'X'); + py::str filler_str(filler); + py::list apply_list; + apply_list.append(filler_str); + apply_list.append(basename_result); + + set_ranges_on_splitted(path, ranges, apply_list, tx_map, false); + return apply_list[1]; +} + +template +StrType +api_ospathdirname_aspect(const StrType& path) +{ + auto tx_map = initializer->get_tainting_map(); + if (not tx_map) { + throw py::value_error(MSG_ERROR_TAINT_MAP); + } + + auto ospath = py::module_::import("os.path"); + auto dirname = ospath.attr("dirname"); + auto dirname_result = dirname(path); + if (py::len(dirname_result) == 0) { + return dirname_result; + } + + bool ranges_error; + TaintRangeRefs ranges; + std::tie(ranges, ranges_error) = get_ranges(path.ptr(), tx_map); + if (ranges_error or ranges.empty()) { + return dirname_result; + } + + // Create a fake list to call set_ranges_on_splitted on it (we are + // only interested on the first path, which is the dirname result) + auto prev_path_len = py::len(path) - py::len(dirname_result); + std::string filler(prev_path_len, 'X'); + py::str filler_str(filler); + py::list apply_list; + apply_list.append(dirname_result); + apply_list.append(filler_str); + + set_ranges_on_splitted(path, 
ranges, apply_list, tx_map, false); + return apply_list[0]; +} + +template +static py::tuple +_forward_to_set_ranges_on_splitted(const char* function_name, const StrType& path, bool includeseparator = false) +{ + auto tx_map = initializer->get_tainting_map(); + if (not tx_map) { + throw py::value_error(MSG_ERROR_TAINT_MAP); + } + auto ospath = py::module_::import("os.path"); + auto function = ospath.attr(function_name); + auto function_result = function(path); + if (py::len(function_result) == 0) { + return function_result; + } + + bool ranges_error; + TaintRangeRefs ranges; + std::tie(ranges, ranges_error) = get_ranges(path.ptr(), tx_map); + if (ranges_error or ranges.empty()) { + return function_result; + } + + set_ranges_on_splitted(path, ranges, function_result, tx_map, includeseparator); + return function_result; +} + +template +py::tuple +api_ospathsplit_aspect(const StrType& path) +{ + return _forward_to_set_ranges_on_splitted("split", path); +} + +template +py::tuple +api_ospathsplitext_aspect(const StrType& path) +{ + return _forward_to_set_ranges_on_splitted("splitext", path, true); +} + +template +py::tuple +api_ospathsplitdrive_aspect(const StrType& path) +{ + return _forward_to_set_ranges_on_splitted("splitdrive", path, true); +} + +template +py::tuple +api_ospathsplitroot_aspect(const StrType& path) +{ + return _forward_to_set_ranges_on_splitted("splitroot", path, true); +} + +template +StrType +api_ospathnormcase_aspect(const StrType& path) +{ + auto tx_map = initializer->get_tainting_map(); + if (not tx_map) { + throw py::value_error(MSG_ERROR_TAINT_MAP); + } + + auto ospath = py::module_::import("os.path"); + auto normcase = ospath.attr("normcase"); + auto normcased = normcase(path); + + bool ranges_error; + TaintRangeRefs ranges; + std::tie(ranges, ranges_error) = get_ranges(path.ptr(), tx_map); + if (ranges_error or ranges.empty()) { + return normcased; + } + + TaintRangeRefs result_ranges = ranges; + PyObject* new_result = 
new_pyobject_id(normcased.ptr()); + if (new_result) { + set_ranges(new_result, result_ranges, tx_map); + return py::reinterpret_steal(new_result); + } + + return normcased; +} + +void +pyexport_ospath_aspects(py::module& m) +{ + m.def("_aspect_ospathjoin", &api_ospathjoin_aspect, "first_part"_a); + m.def("_aspect_ospathjoin", &api_ospathjoin_aspect, "first_part"_a); + m.def("_aspect_ospathnormcase", &api_ospathnormcase_aspect, "path"_a); + m.def("_aspect_ospathnormcase", &api_ospathnormcase_aspect, "path"_a); + m.def("_aspect_ospathbasename", &api_ospathbasename_aspect, "path"_a); + m.def("_aspect_ospathbasename", &api_ospathbasename_aspect, "path"_a); + m.def("_aspect_ospathdirname", &api_ospathdirname_aspect, "path"_a); + m.def("_aspect_ospathdirname", &api_ospathdirname_aspect, "path"_a); + m.def("_aspect_ospathsplit", &api_ospathsplit_aspect, "path"_a); + m.def("_aspect_ospathsplit", &api_ospathsplit_aspect, "path"_a); + m.def("_aspect_ospathsplitext", &api_ospathsplitext_aspect, "path"_a); + m.def("_aspect_ospathsplitext", &api_ospathsplitext_aspect, "path"_a); + m.def("_aspect_ospathsplitdrive", &api_ospathsplitdrive_aspect, "path"_a); + m.def("_aspect_ospathsplitdrive", &api_ospathsplitdrive_aspect, "path"_a); + m.def("_aspect_ospathsplitroot", &api_ospathsplitroot_aspect, "path"_a); + m.def("_aspect_ospathsplitroot", &api_ospathsplitroot_aspect, "path"_a); +} diff --git a/ddtrace/appsec/_iast/_taint_tracking/Aspects/AspectsOsPath.h b/ddtrace/appsec/_iast/_taint_tracking/Aspects/AspectsOsPath.h new file mode 100644 index 00000000000..48e1baf3542 --- /dev/null +++ b/ddtrace/appsec/_iast/_taint_tracking/Aspects/AspectsOsPath.h @@ -0,0 +1,41 @@ +#pragma once +#include "Initializer/Initializer.h" +#include "TaintTracking/TaintRange.h" +#include "TaintTracking/TaintedObject.h" + +namespace py = pybind11; + +template +StrType +api_ospathjoin_aspect(StrType& first_part, const py::args& args); + +template +StrType +api_ospathbasename_aspect(const StrType& path); + 
+template +StrType +api_ospathdirname_aspect(const StrType& path); + +template +py::tuple +api_ospathsplit_aspect(const StrType& path); + +template +py::tuple +api_ospathsplitext_aspect(const StrType& path); + +template +py::tuple +api_ospathsplitdrive_aspect(const StrType& path); + +template +py::tuple +api_ospathsplitroot_aspect(const StrType& path); + +template +StrType +api_ospathnormcase_aspect(const StrType& path); + +void +pyexport_ospath_aspects(py::module& m); diff --git a/ddtrace/appsec/_iast/_taint_tracking/Aspects/Helpers.cpp b/ddtrace/appsec/_iast/_taint_tracking/Aspects/Helpers.cpp index 8332a89c1f8..5384d147303 100644 --- a/ddtrace/appsec/_iast/_taint_tracking/Aspects/Helpers.cpp +++ b/ddtrace/appsec/_iast/_taint_tracking/Aspects/Helpers.cpp @@ -1,5 +1,6 @@ #include "Helpers.h" #include "Initializer/Initializer.h" +#include #include #include @@ -327,6 +328,87 @@ _convert_escaped_text_to_taint_text(const StrType& taint_escaped_text, TaintRang return { StrType(result), ranges }; } +/** + * @brief This function takes the ranges of a string splitted (as in string.split or rsplit or os.path.split) and + * applies the ranges of the original string to the splitted parts with updated offsets. + * + * @param source_str: The original string that was splitted. + * @param source_ranges: The ranges of the original string. + * @param split_result: The splitted parts of the original string. + * @param tx_map: The taint map to apply the ranges. + * @param include_separator: If the separator should be included in the splitted parts. 
+ */ +template +bool +set_ranges_on_splitted(const StrType& source_str, + const TaintRangeRefs& source_ranges, + const py::list& split_result, + TaintRangeMapType* tx_map, + bool include_separator) +{ + bool some_set = false; + + // Some quick shortcuts + if (source_ranges.empty() or py::len(split_result) == 0 or py::len(source_str) == 0 or not tx_map) { + return false; + } + + RANGE_START offset = 0; + std::string c_source_str = py::cast(source_str); + auto separator_increase = (int)((not include_separator)); + + for (const auto& item : split_result) { + if (not is_text(item.ptr()) or py::len(item) == 0) { + continue; + } + auto c_item = py::cast(item); + TaintRangeRefs item_ranges; + + // Find the item in the source_str. + const auto start = static_cast(c_source_str.find(c_item, offset)); + if (start == -1) { + continue; + } + const auto end = static_cast(start + c_item.length()); + + // Find what source_ranges match these positions and create a new range with the start and len updated. 
+ for (const auto& range : source_ranges) { + auto range_end_abs = range->start + range->length; + + if (range->start < end && range_end_abs > start) { + // Create a new range with the updated start + auto new_range_start = std::max(range->start - offset, 0L); + auto new_range_length = std::min(end - start, (range->length - std::max(0L, offset - range->start))); + item_ranges.emplace_back( + initializer->allocate_taint_range(new_range_start, new_range_length, range->source)); + } + } + if (not item_ranges.empty()) { + set_ranges(item.ptr(), item_ranges, tx_map); + some_set = true; + } + + offset += py::len(item) + separator_increase; + } + + return some_set; +} + +template +bool +api_set_ranges_on_splitted(const StrType& source_str, + const TaintRangeRefs& source_ranges, + const py::list& split_result, + bool include_separator) +{ + TaintRangeMapType* tx_map = initializer->get_tainting_map(); + if (not tx_map) { + throw py::value_error(MSG_ERROR_TAINT_MAP); + } + + return set_ranges_on_splitted(source_str, source_ranges, split_result, tx_map, include_separator); +} + py::object parse_params(size_t position, const char* keyword_name, @@ -348,6 +430,27 @@ pyexport_aspect_helpers(py::module& m) m.def("common_replace", &api_common_replace, "string_method"_a, "candidate_text"_a); m.def("common_replace", &api_common_replace, "string_method"_a, "candidate_text"_a); m.def("common_replace", &api_common_replace, "string_method"_a, "candidate_text"_a); + m.def("set_ranges_on_splitted", + &api_set_ranges_on_splitted, + "source_str"_a, + "source_ranges"_a, + "split_result"_a, + // cppcheck-suppress assignBoolToPointer + "include_separator"_a = false); + m.def("set_ranges_on_splitted", + &api_set_ranges_on_splitted, + "source_str"_a, + "source_ranges"_a, + "split_result"_a, + // cppcheck-suppress assignBoolToPointer + "include_separator"_a = false); + m.def("set_ranges_on_splitted", + &api_set_ranges_on_splitted, + "source_str"_a, + "source_ranges"_a, + "split_result"_a, + // 
cppcheck-suppress assignBoolToPointer + "include_separator"_a = false); m.def("_all_as_formatted_evidence", &_all_as_formatted_evidence, "text"_a, diff --git a/ddtrace/appsec/_iast/_taint_tracking/Aspects/Helpers.h b/ddtrace/appsec/_iast/_taint_tracking/Aspects/Helpers.h index bd672442f2c..3a8ddbf83ed 100644 --- a/ddtrace/appsec/_iast/_taint_tracking/Aspects/Helpers.h +++ b/ddtrace/appsec/_iast/_taint_tracking/Aspects/Helpers.h @@ -52,5 +52,20 @@ template std::tuple _convert_escaped_text_to_taint_text(const StrType& taint_escaped_text, TaintRangeRefs ranges_orig); +template +bool +set_ranges_on_splitted(const StrType& source_str, + const TaintRangeRefs& source_ranges, + const py::list& split_result, + TaintRangeMapType* tx_map, + bool include_separator = false); + +template +bool +api_set_ranges_on_splitted(const StrType& source_str, + const TaintRangeRefs& source_ranges, + const py::list& split_result, + bool include_separator = false); + void pyexport_aspect_helpers(py::module& m); diff --git a/ddtrace/appsec/_iast/_taint_tracking/Aspects/_aspects_exports.h b/ddtrace/appsec/_iast/_taint_tracking/Aspects/_aspects_exports.h index fd45b423cbb..1331af54a94 100644 --- a/ddtrace/appsec/_iast/_taint_tracking/Aspects/_aspects_exports.h +++ b/ddtrace/appsec/_iast/_taint_tracking/Aspects/_aspects_exports.h @@ -1,5 +1,7 @@ #pragma once #include "AspectFormat.h" +#include "AspectSplit.h" +#include "AspectsOsPath.h" #include "Helpers.h" #include @@ -10,4 +12,9 @@ pyexport_m_aspect_helpers(py::module& m) pyexport_aspect_helpers(m_aspect_helpers); py::module m_aspect_format = m.def_submodule("aspect_format", "Aspect Format"); pyexport_format_aspect(m_aspect_format); + + py::module m_aspects_ospath = m.def_submodule("aspects_ospath", "Aspect os.path.join"); + pyexport_ospath_aspects(m_aspects_ospath); + py::module m_aspect_split = m.def_submodule("aspect_split", "Aspect split"); + pyexport_aspect_split(m_aspect_split); } diff --git 
a/ddtrace/appsec/_iast/_taint_tracking/TaintTracking/TaintRange.cpp b/ddtrace/appsec/_iast/_taint_tracking/TaintTracking/TaintRange.cpp index 49e3e4a78fd..4883a84213a 100644 --- a/ddtrace/appsec/_iast/_taint_tracking/TaintTracking/TaintRange.cpp +++ b/ddtrace/appsec/_iast/_taint_tracking/TaintTracking/TaintRange.cpp @@ -38,7 +38,7 @@ TaintRange::get_hash() const }; TaintRangePtr -api_shift_taint_range(const TaintRangePtr& source_taint_range, RANGE_START offset, RANGE_LENGTH new_length = -1) +shift_taint_range(const TaintRangePtr& source_taint_range, RANGE_START offset, RANGE_LENGTH new_length = -1) { if (new_length == -1) { new_length = source_taint_range->length; @@ -56,7 +56,7 @@ shift_taint_ranges(const TaintRangeRefs& source_taint_ranges, RANGE_START offset new_ranges.reserve(source_taint_ranges.size()); for (const auto& trange : source_taint_ranges) { - new_ranges.emplace_back(api_shift_taint_range(trange, offset, new_length)); + new_ranges.emplace_back(shift_taint_range(trange, offset, new_length)); } return new_ranges; } @@ -243,7 +243,7 @@ get_range_by_hash(size_t range_hash, optional& taint_ranges) } TaintRangeRefs -api_get_ranges(py::object& string_input) +api_get_ranges(const py::object& string_input) { bool ranges_error; TaintRangeRefs ranges; diff --git a/ddtrace/appsec/_iast/_taint_tracking/TaintTracking/TaintRange.h b/ddtrace/appsec/_iast/_taint_tracking/TaintTracking/TaintRange.h index 72152200d50..d076dc59846 100644 --- a/ddtrace/appsec/_iast/_taint_tracking/TaintTracking/TaintRange.h +++ b/ddtrace/appsec/_iast/_taint_tracking/TaintTracking/TaintRange.h @@ -73,7 +73,13 @@ using TaintRangePtr = shared_ptr; using TaintRangeRefs = vector; TaintRangePtr -api_shift_taint_range(const TaintRangePtr& source_taint_range, RANGE_START offset, RANGE_LENGTH new_length); +shift_taint_range(const TaintRangePtr& source_taint_range, RANGE_START offset, RANGE_LENGTH new_length); + +inline TaintRangePtr +api_shift_taint_range(const TaintRangePtr& source_taint_range, 
RANGE_START offset, RANGE_LENGTH new_length) +{ + return shift_taint_range(source_taint_range, offset, new_length); +} TaintRangeRefs shift_taint_ranges(const TaintRangeRefs& source_taint_ranges, RANGE_START offset, RANGE_LENGTH new_length); @@ -91,7 +97,7 @@ py::object api_set_ranges(py::object& str, const TaintRangeRefs& ranges); TaintRangeRefs -api_get_ranges(py::object& string_input); +api_get_ranges(const py::object& string_input); void api_copy_ranges_from_strings(py::object& str_1, py::object& str_2); diff --git a/ddtrace/appsec/_iast/_taint_tracking/__init__.py b/ddtrace/appsec/_iast/_taint_tracking/__init__.py index 8e05fbefbd2..b155e7c08a9 100644 --- a/ddtrace/appsec/_iast/_taint_tracking/__init__.py +++ b/ddtrace/appsec/_iast/_taint_tracking/__init__.py @@ -23,6 +23,18 @@ from ._native.aspect_helpers import as_formatted_evidence from ._native.aspect_helpers import common_replace from ._native.aspect_helpers import parse_params + from ._native.aspect_helpers import set_ranges_on_splitted + from ._native.aspect_split import _aspect_rsplit + from ._native.aspect_split import _aspect_split + from ._native.aspect_split import _aspect_splitlines + from ._native.aspects_ospath import _aspect_ospathbasename + from ._native.aspects_ospath import _aspect_ospathdirname + from ._native.aspects_ospath import _aspect_ospathjoin + from ._native.aspects_ospath import _aspect_ospathnormcase + from ._native.aspects_ospath import _aspect_ospathsplit + from ._native.aspects_ospath import _aspect_ospathsplitdrive + from ._native.aspects_ospath import _aspect_ospathsplitext + from ._native.aspects_ospath import _aspect_ospathsplitroot from ._native.initializer import active_map_addreses_size from ._native.initializer import create_context from ._native.initializer import debug_taint_map @@ -79,9 +91,21 @@ "str_to_origin", "origin_to_str", "common_replace", + "_aspect_ospathjoin", + "_aspect_split", + "_aspect_rsplit", + "_aspect_splitlines", + "_aspect_ospathbasename", + 
"_aspect_ospathdirname", + "_aspect_ospathnormcase", + "_aspect_ospathsplit", + "_aspect_ospathsplitext", + "_aspect_ospathsplitdrive", + "_aspect_ospathsplitroot", "_format_aspect", "as_formatted_evidence", "parse_params", + "set_ranges_on_splitted", "num_objects_tainted", "debug_taint_map", "iast_taint_log_error", @@ -153,12 +177,15 @@ def get_tainted_ranges(pyobject: Any) -> Tuple: def taint_ranges_as_evidence_info(pyobject: Any) -> Tuple[List[Dict[str, Union[Any, int]]], List[Source]]: + # TODO: This function is deprecated. + # Redaction migrated to `ddtrace.appsec._iast._evidence_redaction._sensitive_handler` but we need to migrate + # all vulnerabilities to use it first. value_parts = [] - sources = [] + sources = list() current_pos = 0 tainted_ranges = get_tainted_ranges(pyobject) if not len(tainted_ranges): - return ([{"value": pyobject}], []) + return ([{"value": pyobject}], list()) for _range in tainted_ranges: if _range.start > current_pos: @@ -168,7 +195,10 @@ def taint_ranges_as_evidence_info(pyobject: Any) -> Tuple[List[Dict[str, Union[A sources.append(_range.source) value_parts.append( - {"value": pyobject[_range.start : _range.start + _range.length], "source": sources.index(_range.source)} + { + "value": pyobject[_range.start : _range.start + _range.length], + "source": sources.index(_range.source), + } ) current_pos = _range.start + _range.length diff --git a/ddtrace/appsec/_iast/_taint_tracking/aspects.py b/ddtrace/appsec/_iast/_taint_tracking/aspects.py index 166639b645d..cae1e07d455 100644 --- a/ddtrace/appsec/_iast/_taint_tracking/aspects.py +++ b/ddtrace/appsec/_iast/_taint_tracking/aspects.py @@ -16,6 +16,17 @@ from .._taint_tracking import TagMappingMode from .._taint_tracking import TaintRange +from .._taint_tracking import _aspect_ospathbasename +from .._taint_tracking import _aspect_ospathdirname +from .._taint_tracking import _aspect_ospathjoin +from .._taint_tracking import _aspect_ospathnormcase +from .._taint_tracking import 
_aspect_ospathsplit +from .._taint_tracking import _aspect_ospathsplitdrive +from .._taint_tracking import _aspect_ospathsplitext +from .._taint_tracking import _aspect_ospathsplitroot +from .._taint_tracking import _aspect_rsplit +from .._taint_tracking import _aspect_split +from .._taint_tracking import _aspect_splitlines from .._taint_tracking import _convert_escaped_text_to_tainted_text from .._taint_tracking import _format_aspect from .._taint_tracking import are_all_text_all_ranges @@ -44,7 +55,26 @@ _join_aspect = aspects.join_aspect _slice_aspect = aspects.slice_aspect -__all__ = ["add_aspect", "str_aspect", "bytearray_extend_aspect", "decode_aspect", "encode_aspect"] +__all__ = [ + "add_aspect", + "str_aspect", + "bytearray_extend_aspect", + "decode_aspect", + "encode_aspect", + "_aspect_ospathjoin", + "_aspect_split", + "_aspect_rsplit", + "_aspect_splitlines", + "_aspect_ospathbasename", + "_aspect_ospathdirname", + "_aspect_ospathnormcase", + "_aspect_ospathsplit", + "_aspect_ospathsplitext", + "_aspect_ospathsplitdrive", + "_aspect_ospathsplitroot", +] + +# TODO: Factorize the "flags_added_args" copypasta into a decorator def add_aspect(op1, op2): @@ -57,6 +87,45 @@ def add_aspect(op1, op2): return op1 + op2 +def split_aspect(orig_function: Optional[Callable], flag_added_args: int, *args: Any, **kwargs: Any) -> str: + if orig_function: + if orig_function != builtin_str: + if flag_added_args > 0: + args = args[flag_added_args:] + return orig_function(*args, **kwargs) + try: + return _aspect_split(*args, **kwargs) + except Exception as e: + iast_taint_log_error("IAST propagation error. split_aspect. 
{}".format(e)) + return args[0].split(*args[1:], **kwargs) + + +def rsplit_aspect(orig_function: Optional[Callable], flag_added_args: int, *args: Any, **kwargs: Any) -> str: + if orig_function: + if orig_function != builtin_str: + if flag_added_args > 0: + args = args[flag_added_args:] + return orig_function(*args, **kwargs) + try: + return _aspect_rsplit(*args, **kwargs) + except Exception as e: + iast_taint_log_error("IAST propagation error. rsplit_aspect. {}".format(e)) + return args[0].rsplit(*args[1:], **kwargs) + + +def splitlines_aspect(orig_function: Optional[Callable], flag_added_args: int, *args: Any, **kwargs: Any) -> str: + if orig_function: + if orig_function != builtin_str: + if flag_added_args > 0: + args = args[flag_added_args:] + return orig_function(*args, **kwargs) + try: + return _aspect_splitlines(*args, **kwargs) + except Exception as e: + iast_taint_log_error("IAST propagation error. splitlines_aspect. {}".format(e)) + return args[0].splitlines(*args[1:], **kwargs) + + def str_aspect(orig_function: Optional[Callable], flag_added_args: int, *args: Any, **kwargs: Any) -> str: if orig_function: if orig_function != builtin_str: @@ -216,12 +285,14 @@ def modulo_aspect(candidate_text: Text, candidate_tuple: Any) -> Any: tag_mapping_function=TagMappingMode.Mapper, ) % tuple( - as_formatted_evidence( - parameter, - tag_mapping_function=TagMappingMode.Mapper, + ( + as_formatted_evidence( + parameter, + tag_mapping_function=TagMappingMode.Mapper, + ) + if isinstance(parameter, IAST.TEXT_TYPES) + else parameter ) - if isinstance(parameter, IAST.TEXT_TYPES) - else parameter for parameter in parameter_list ), ranges_orig=ranges_orig, @@ -382,9 +453,11 @@ def format_map_aspect( candidate_text, candidate_text_ranges, tag_mapping_function=TagMappingMode.Mapper ).format_map( { - key: as_formatted_evidence(value, tag_mapping_function=TagMappingMode.Mapper) - if isinstance(value, IAST.TEXT_TYPES) - else value + key: ( + as_formatted_evidence(value, 
tag_mapping_function=TagMappingMode.Mapper) + if isinstance(value, IAST.TEXT_TYPES) + else value + ) for key, value in mapping.items() } ), @@ -440,6 +513,8 @@ def format_value_aspect( else: new_text = element if not isinstance(new_text, IAST.TEXT_TYPES): + if format_spec: + return format(new_text, format_spec) return format(new_text) try: diff --git a/ddtrace/appsec/_iast/_utils.py b/ddtrace/appsec/_iast/_utils.py index e2e26e291fa..7272abb9016 100644 --- a/ddtrace/appsec/_iast/_utils.py +++ b/ddtrace/appsec/_iast/_utils.py @@ -1,11 +1,8 @@ -import json import re import string import sys from typing import TYPE_CHECKING # noqa:F401 -import attr - from ddtrace.internal.logger import get_logger from ddtrace.settings.asm import config as asm_config @@ -41,6 +38,9 @@ def _is_iast_enabled(): def _has_to_scrub(s): # type: (str) -> bool + # TODO: This function is deprecated. + # Redaction migrated to `ddtrace.appsec._iast._evidence_redaction._sensitive_handler` but we need to migrate + # all vulnerabilities to use it first. global _SOURCE_NAME_SCRUB global _SOURCE_VALUE_SCRUB global _SOURCE_NUMERAL_SCRUB @@ -58,6 +58,9 @@ def _has_to_scrub(s): # type: (str) -> bool def _is_numeric(s): + # TODO: This function is deprecated. + # Redaction migrated to `ddtrace.appsec._iast._evidence_redaction._sensitive_handler` but we need to migrate + # all vulnerabilities to use it first. global _SOURCE_NUMERAL_SCRUB if _SOURCE_NUMERAL_SCRUB is None: @@ -71,17 +74,26 @@ def _is_numeric(s): def _scrub(s, has_range=False): # type: (str, bool) -> str + # TODO: This function is deprecated. + # Redaction migrated to `ddtrace.appsec._iast._evidence_redaction._sensitive_handler` but we need to migrate + # all vulnerabilities to use it first. if has_range: return "".join([_REPLACEMENTS[i % _LEN_REPLACEMENTS] for i in range(len(s))]) return "*" * len(s) def _is_evidence_value_parts(value): # type: (Any) -> bool + # TODO: This function is deprecated. 
+ # Redaction migrated to `ddtrace.appsec._iast._evidence_redaction._sensitive_handler` but we need to migrate + # all vulnerabilities to use it first. return isinstance(value, (set, list)) def _scrub_get_tokens_positions(text, tokens): # type: (str, Set[str]) -> List[Tuple[int, int]] + # TODO: This function is deprecated. + # Redaction migrated to `ddtrace.appsec._iast._evidence_redaction._sensitive_handler` but we need to migrate + # all vulnerabilities to use it first. token_positions = [] for token in tokens: @@ -93,20 +105,6 @@ def _scrub_get_tokens_positions(text, tokens): return token_positions -def _iast_report_to_str(data): - from ._taint_tracking import OriginType - from ._taint_tracking import origin_to_str - - class OriginTypeEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, OriginType): - # if the obj is uuid, we simply return the value of uuid - return origin_to_str(obj) - return json.JSONEncoder.default(self, obj) - - return json.dumps(attr.asdict(data, filter=lambda attr, x: x is not None), cls=OriginTypeEncoder) - - def _get_patched_code(module_path, module_name): # type: (str, str) -> str """ Print the patched code to stdout, for debugging purposes. 
diff --git a/ddtrace/appsec/_iast/constants.py b/ddtrace/appsec/_iast/constants.py index bd9e73928ea..17981bccbcc 100644 --- a/ddtrace/appsec/_iast/constants.py +++ b/ddtrace/appsec/_iast/constants.py @@ -11,6 +11,7 @@ VULN_NO_HTTPONLY_COOKIE = "NO_HTTPONLY_COOKIE" VULN_NO_SAMESITE_COOKIE = "NO_SAMESITE_COOKIE" VULN_CMDI = "COMMAND_INJECTION" +VULN_HEADER_INJECTION = "HEADER_INJECTION" VULN_SSRF = "SSRF" VULNERABILITY_TOKEN_TYPE = Dict[int, Dict[str, Any]] @@ -21,8 +22,11 @@ EVIDENCE_WEAK_RANDOMNESS = "WEAK_RANDOMNESS" EVIDENCE_COOKIE = "COOKIE" EVIDENCE_CMDI = "COMMAND" +EVIDENCE_HEADER_INJECTION = "HEADER_INJECTION" EVIDENCE_SSRF = "SSRF" +HEADER_NAME_VALUE_SEPARATOR = ": " + MD5_DEF = "md5" SHA1_DEF = "sha1" diff --git a/ddtrace/appsec/_iast/processor.py b/ddtrace/appsec/_iast/processor.py index 8deee2a1846..8d0adffdb90 100644 --- a/ddtrace/appsec/_iast/processor.py +++ b/ddtrace/appsec/_iast/processor.py @@ -16,6 +16,7 @@ from ._metrics import _set_span_tag_iast_executed_sink from ._metrics import _set_span_tag_iast_request_tainted from ._utils import _is_iast_enabled +from .reporter import IastSpanReporter if TYPE_CHECKING: # pragma: no cover @@ -75,14 +76,14 @@ def on_span_finish(self, span): return from ._taint_tracking import reset_context # noqa: F401 - from ._utils import _iast_report_to_str span.set_metric(IAST.ENABLED, 1.0) - data = core.get_item(IAST.CONTEXT_KEY, span=span) + report_data: IastSpanReporter = core.get_item(IAST.CONTEXT_KEY, span=span) # type: ignore - if data: - span.set_tag_str(IAST.JSON, _iast_report_to_str(data)) + if report_data: + report_data.build_and_scrub_value_parts() + span.set_tag_str(IAST.JSON, report_data._to_str()) _asm_manual_keep(span) _set_metric_iast_request_tainted() diff --git a/ddtrace/appsec/_iast/reporter.py b/ddtrace/appsec/_iast/reporter.py index 5a95aa1272d..fa2cc8ae96c 100644 --- a/ddtrace/appsec/_iast/reporter.py +++ b/ddtrace/appsec/_iast/reporter.py @@ -3,17 +3,23 @@ import operator import os from typing 
import TYPE_CHECKING +from typing import Any +from typing import Dict from typing import List from typing import Set +from typing import Tuple import zlib import attr +from ddtrace.appsec._iast._evidence_redaction import sensitive_handler +from ddtrace.appsec._iast.constants import VULN_INSECURE_HASHING_TYPE +from ddtrace.appsec._iast.constants import VULN_WEAK_CIPHER_TYPE +from ddtrace.appsec._iast.constants import VULN_WEAK_RANDOMNESS -if TYPE_CHECKING: - import Any # noqa:F401 - import Dict # noqa:F401 - import Optional # noqa:F401 + +if TYPE_CHECKING: # pragma: no cover + from typing import Optional # noqa:F401 def _only_if_true(value): @@ -23,9 +29,8 @@ def _only_if_true(value): @attr.s(eq=False, hash=False) class Evidence(object): value = attr.ib(type=str, default=None) # type: Optional[str] - pattern = attr.ib(type=str, default=None) # type: Optional[str] - valueParts = attr.ib(type=list, default=None) # type: Optional[List[Dict[str, Any]]] - redacted = attr.ib(type=bool, default=False, converter=_only_if_true) # type: bool + _ranges = attr.ib(type=dict, default={}) # type: Any + valueParts = attr.ib(type=list, default=None) # type: Any def _valueParts_hash(self): if not self.valueParts: @@ -40,15 +45,10 @@ def _valueParts_hash(self): return _hash def __hash__(self): - return hash((self.value, self.pattern, self._valueParts_hash(), self.redacted)) + return hash((self.value, self._valueParts_hash())) def __eq__(self, other): - return ( - self.value == other.value - and self.pattern == other.pattern - and self._valueParts_hash() == other._valueParts_hash() - and self.redacted == other.redacted - ) + return self.value == other.value and self._valueParts_hash() == other._valueParts_hash() @attr.s(eq=True, hash=True) @@ -69,7 +69,7 @@ def __attrs_post_init__(self): self.hash = zlib.crc32(repr(self).encode()) -@attr.s(eq=True, hash=True) +@attr.s(eq=True, hash=False) class Source(object): origin = attr.ib(type=str) # type: str name = attr.ib(type=str) # type: str 
@@ -77,11 +77,163 @@ class Source(object): value = attr.ib(type=str, default=None) # type: Optional[str] pattern = attr.ib(type=str, default=None) # type: Optional[str] + def __hash__(self): + """origin & name serve as hashes. This approach aims to mitigate false positives when searching for + identical sources in a list, especially when sources undergo changes. The provided example illustrates how + two sources with different attributes could actually represent the same source. For example: + Source(origin=, name='string1', redacted=False, value="password", pattern=None) + could be the same source as the one below: + Source(origin=, name='string1', redacted=True, value=None, pattern='ab') + :return: + """ + return hash((self.origin, self.name)) + @attr.s(eq=False, hash=False) class IastSpanReporter(object): + """ + Class representing an IAST span reporter. + """ + sources = attr.ib(type=List[Source], factory=list) # type: List[Source] vulnerabilities = attr.ib(type=Set[Vulnerability], factory=set) # type: Set[Vulnerability] + _evidences_with_no_sources = [VULN_INSECURE_HASHING_TYPE, VULN_WEAK_CIPHER_TYPE, VULN_WEAK_RANDOMNESS] - def __hash__(self): + def __hash__(self) -> int: + """ + Computes the hash value of the IAST span reporter. + + Returns: + - int: Hash value. + """ return reduce(operator.xor, (hash(obj) for obj in set(self.sources) | self.vulnerabilities)) + + def taint_ranges_as_evidence_info(self, pyobject: Any) -> Tuple[List[Source], List[Dict]]: + """ + Extracts tainted ranges as evidence information. + + Args: + - pyobject (Any): Python object. + + Returns: + - Tuple[Set[Source], List[Dict]]: Set of Source objects and list of tainted ranges as dictionaries. 
+ """ + from ddtrace.appsec._iast._taint_tracking import get_tainted_ranges + + sources = list() + tainted_ranges = get_tainted_ranges(pyobject) + tainted_ranges_to_dict = list() + if not len(tainted_ranges): + return [], [] + + for _range in tainted_ranges: + source = Source(origin=_range.source.origin, name=_range.source.name, value=_range.source.value) + if source not in sources: + sources.append(source) + + tainted_ranges_to_dict.append( + {"start": _range.start, "end": _range.start + _range.length, "length": _range.length, "source": source} + ) + return sources, tainted_ranges_to_dict + + def add_ranges_to_evidence_and_extract_sources(self, vuln): + sources, tainted_ranges_to_dict = self.taint_ranges_as_evidence_info(vuln.evidence.value) + vuln.evidence._ranges = tainted_ranges_to_dict + for source in sources: + if source not in self.sources: + self.sources = self.sources + [source] + + def _get_source_index(self, sources: List[Source], source: Source) -> int: + i = 0 + for source_ in sources: + if hash(source_) == hash(source): + return i + i += 1 + return -1 + + def build_and_scrub_value_parts(self) -> Dict[str, Any]: + """ + Builds and scrubs value parts of vulnerabilities. + + Returns: + - Dict[str, Any]: Dictionary representation of the IAST span reporter. 
+ """ + for vuln in self.vulnerabilities: + scrubbing_result = sensitive_handler.scrub_evidence( + vuln.type, vuln.evidence, vuln.evidence._ranges, self.sources + ) + if scrubbing_result: + redacted_value_parts = scrubbing_result["redacted_value_parts"] + redacted_sources = scrubbing_result["redacted_sources"] + i = 0 + for source in self.sources: + if i in redacted_sources: + source.value = None + vuln.evidence.valueParts = redacted_value_parts + vuln.evidence.value = None + elif vuln.evidence.value is not None and vuln.type not in self._evidences_with_no_sources: + vuln.evidence.valueParts = self.get_unredacted_value_parts( + vuln.evidence.value, vuln.evidence._ranges, self.sources + ) + vuln.evidence.value = None + return self._to_dict() + + def get_unredacted_value_parts(self, evidence_value: str, ranges: List[dict], sources: List[Any]) -> List[dict]: + """ + Gets unredacted value parts of evidence. + + Args: + - evidence_value (str): Evidence value. + - ranges (List[Dict]): List of tainted ranges. + - sources (List[Any]): List of sources. + + Returns: + - List[Dict]: List of unredacted value parts. + """ + value_parts = [] + from_index = 0 + + for range_ in ranges: + if from_index < range_["start"]: + value_parts.append({"value": evidence_value[from_index : range_["start"]]}) + + source_index = self._get_source_index(sources, range_["source"]) + + value_parts.append( + {"value": evidence_value[range_["start"] : range_["end"]], "source": source_index} # type: ignore[dict-item] + ) + + from_index = range_["end"] + + if from_index < len(evidence_value): + value_parts.append({"value": evidence_value[from_index:]}) + + return value_parts + + def _to_dict(self) -> Dict[str, Any]: + """ + Converts the IAST span reporter to a dictionary. + + Returns: + - Dict[str, Any]: Dictionary representation of the IAST span reporter. 
+ """ + return attr.asdict(self, filter=lambda attr, x: x is not None and attr.name != "_ranges") + + def _to_str(self) -> str: + """ + Converts the IAST span reporter to a JSON string. + + Returns: + - str: JSON representation of the IAST span reporter. + """ + from ._taint_tracking import OriginType + from ._taint_tracking import origin_to_str + + class OriginTypeEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, OriginType): + # if the obj is uuid, we simply return the value of uuid + return origin_to_str(obj) + return json.JSONEncoder.default(self, obj) + + return json.dumps(self._to_dict(), cls=OriginTypeEncoder) diff --git a/ddtrace/appsec/_iast/taint_sinks/_base.py b/ddtrace/appsec/_iast/taint_sinks/_base.py index 43dc1f5cb53..7cba289d644 100644 --- a/ddtrace/appsec/_iast/taint_sinks/_base.py +++ b/ddtrace/appsec/_iast/taint_sinks/_base.py @@ -19,7 +19,6 @@ from ..reporter import Evidence from ..reporter import IastSpanReporter from ..reporter import Location -from ..reporter import Source from ..reporter import Vulnerability @@ -89,35 +88,16 @@ def _prepare_report(cls, span, vulnerability_type, evidence, file_name, line_num line_number = -1 report = core.get_item(IAST.CONTEXT_KEY, span=span) + vulnerability = Vulnerability( + type=vulnerability_type, + evidence=evidence, + location=Location(path=file_name, line=line_number, spanId=span.span_id), + ) if report: - report.vulnerabilities.add( - Vulnerability( - type=vulnerability_type, - evidence=evidence, - location=Location(path=file_name, line=line_number, spanId=span.span_id), - ) - ) - + report.vulnerabilities.add(vulnerability) else: - report = IastSpanReporter( - vulnerabilities={ - Vulnerability( - type=vulnerability_type, - evidence=evidence, - location=Location(path=file_name, line=line_number, spanId=span.span_id), - ) - } - ) - if sources: - - def cast_value(value): - if isinstance(value, (bytes, bytearray)): - value_decoded = value.decode("utf-8") - else: - value_decoded = 
value - return value_decoded - - report.sources = [Source(origin=x.origin, name=x.name, value=cast_value(x.value)) for x in sources] + report = IastSpanReporter(vulnerabilities={vulnerability}) + report.add_ranges_to_evidence_and_extract_sources(vulnerability) if getattr(cls, "redact_report", False): redacted_report = cls._redacted_report_cache.get( @@ -130,9 +110,10 @@ def cast_value(value): return True @classmethod - def report(cls, evidence_value="", sources=None): - # type: (Union[Text|List[Dict[str, Any]]], Optional[List[Source]]) -> None + def report(cls, evidence_value="", value_parts=None, sources=None): + # type: (Any, Any, Optional[List[Any]]) -> None """Build a IastSpanReporter instance to report it in the `AppSecIastSpanProcessor` as a string JSON""" + # TODO: type of evidence_value will be Text. We wait to finish the redaction refactor. if cls.acquire_quota(): if not tracer or not hasattr(tracer, "current_root_span"): log.debug( @@ -166,11 +147,12 @@ def report(cls, evidence_value="", sources=None): if not cls.is_not_reported(file_name, line_number): return - if _is_evidence_value_parts(evidence_value): - evidence = Evidence(valueParts=evidence_value) + # TODO: This function is deprecated, but we need to migrate all vulnerabilities first before deleting it + if _is_evidence_value_parts(evidence_value) or _is_evidence_value_parts(value_parts): + evidence = Evidence(value=evidence_value, valueParts=value_parts) # Evidence is a string in weak cipher, weak hash and weak randomness elif isinstance(evidence_value, (str, bytes, bytearray)): - evidence = Evidence(value=evidence_value) + evidence = Evidence(value=evidence_value) # type: ignore else: log.debug("Unexpected evidence_value type: %s", type(evidence_value)) evidence = Evidence(value="") @@ -184,11 +166,17 @@ def report(cls, evidence_value="", sources=None): @classmethod def _extract_sensitive_tokens(cls, report): # type: (Dict[Vulnerability, str]) -> Dict[int, Dict[str, Any]] + # TODO: This function 
is deprecated. + # Redaction migrated to `ddtrace.appsec._iast._evidence_redaction._sensitive_handler` but we need to migrate + # all vulnerabilities to use it first. log.debug("Base class VulnerabilityBase._extract_sensitive_tokens called") return {} @classmethod def _get_vulnerability_text(cls, vulnerability): + # TODO: This function is deprecated. + # Redaction migrated to `ddtrace.appsec._iast._evidence_redaction._sensitive_handler` but we need to migrate + # all vulnerabilities to use it first. if vulnerability and vulnerability.evidence.value is not None: return vulnerability.evidence.value @@ -209,6 +197,9 @@ def replace_tokens( vulns_to_tokens, has_range=False, ): + # TODO: This function is deprecated. + # Redaction migrated to `ddtrace.appsec._iast._evidence_redaction._sensitive_handler` but we need to migrate + # all vulnerabilities to use it first. ret = vuln.evidence.value replaced = False @@ -222,10 +213,16 @@ def replace_tokens( def _custom_edit_valueparts(cls, vuln): # Subclasses could optionally implement this to add further processing to the # vulnerability valueParts + # TODO: This function is deprecated. + # Redaction migrated to `ddtrace.appsec._iast._evidence_redaction._sensitive_handler` but we need to migrate + # all vulnerabilities to use it first. return @classmethod def _redact_report(cls, report): # type: (IastSpanReporter) -> IastSpanReporter + # TODO: This function is deprecated. + # Redaction migrated to `ddtrace.appsec._iast._evidence_redaction._sensitive_handler` but we need to migrate + # all vulnerabilities to use it first. 
if not asm_config._iast_redaction_enabled: return report @@ -239,8 +236,8 @@ def _redact_report(cls, report): # type: (IastSpanReporter) -> IastSpanReporter for source in report.sources: # Join them so we only run the regexps once for each source # joined_fields = "%s%s" % (source.name, source.value) - if _has_to_scrub(source.name) or _has_to_scrub(source.value): - scrubbed = _scrub(source.value, has_range=True) + if _has_to_scrub(source.name) or _has_to_scrub(source.value): # type: ignore + scrubbed = _scrub(source.value, has_range=True) # type: ignore already_scrubbed[source.value] = scrubbed source.redacted = True sources_values_to_scrubbed[source.value] = scrubbed @@ -252,8 +249,6 @@ def _redact_report(cls, report): # type: (IastSpanReporter) -> IastSpanReporter if vuln.evidence.value is not None: pattern, replaced = cls.replace_tokens(vuln, vulns_to_tokens, hasattr(vuln.evidence.value, "source")) if replaced: - vuln.evidence.pattern = pattern - vuln.evidence.redacted = True vuln.evidence.value = None if vuln.evidence.valueParts is None: diff --git a/ddtrace/appsec/_iast/taint_sinks/ast_taint.py b/ddtrace/appsec/_iast/taint_sinks/ast_taint.py index af8f59b15a9..57d22f63796 100644 --- a/ddtrace/appsec/_iast/taint_sinks/ast_taint.py +++ b/ddtrace/appsec/_iast/taint_sinks/ast_taint.py @@ -14,6 +14,7 @@ from typing import Callable # noqa:F401 +# TODO: we also need a native version of this function! 
def ast_function( func, # type: Callable flag_added_args, # type: Any diff --git a/ddtrace/appsec/_iast/taint_sinks/command_injection.py b/ddtrace/appsec/_iast/taint_sinks/command_injection.py index 0b11ffd12b0..8f123a2be4c 100644 --- a/ddtrace/appsec/_iast/taint_sinks/command_injection.py +++ b/ddtrace/appsec/_iast/taint_sinks/command_injection.py @@ -1,10 +1,7 @@ import os -import re import subprocess # nosec -from typing import TYPE_CHECKING # noqa:F401 -from typing import List # noqa:F401 -from typing import Set # noqa:F401 -from typing import Union # noqa:F401 +from typing import List +from typing import Union from ddtrace.contrib import trace_utils from ddtrace.internal import core @@ -14,30 +11,15 @@ from ..._constants import IAST_SPAN_TAGS from .. import oce from .._metrics import increment_iast_span_metric -from .._utils import _has_to_scrub -from .._utils import _scrub -from .._utils import _scrub_get_tokens_positions -from ..constants import EVIDENCE_CMDI from ..constants import VULN_CMDI +from ..processor import AppSecIastSpanProcessor from ._base import VulnerabilityBase -from ._base import _check_positions_contained - - -if TYPE_CHECKING: - from typing import Any # noqa:F401 - from typing import Dict # noqa:F401 - - from ..reporter import IastSpanReporter # noqa:F401 - from ..reporter import Vulnerability # noqa:F401 log = get_logger(__name__) -_INSIDE_QUOTES_REGEXP = re.compile(r"^(?:\s*(?:sudo|doas)\s+)?\b\S+\b\s*(.*)") - -def get_version(): - # type: () -> str +def get_version() -> str: return "" @@ -61,8 +43,7 @@ def patch(): core.dispatch("exploit.prevention.ssrf.patch.urllib") -def unpatch(): - # type: () -> None +def unpatch() -> None: trace_utils.unwrap(os, "system") trace_utils.unwrap(os, "_spawnvef") trace_utils.unwrap(subprocess.Popen, "__init__") @@ -93,151 +74,29 @@ def _iast_cmdi_subprocess_init(wrapped, instance, args, kwargs): @oce.register class CommandInjection(VulnerabilityBase): vulnerability_type = VULN_CMDI - evidence_type = 
EVIDENCE_CMDI - redact_report = True - - @classmethod - def report(cls, evidence_value=None, sources=None): - if isinstance(evidence_value, (str, bytes, bytearray)): - from .._taint_tracking import taint_ranges_as_evidence_info - - evidence_value, sources = taint_ranges_as_evidence_info(evidence_value) - super(CommandInjection, cls).report(evidence_value=evidence_value, sources=sources) - - @classmethod - def _extract_sensitive_tokens(cls, vulns_to_text): - # type: (Dict[Vulnerability, str]) -> Dict[int, Dict[str, Any]] - ret = {} # type: Dict[int, Dict[str, Any]] - for vuln, text in vulns_to_text.items(): - vuln_hash = hash(vuln) - ret[vuln_hash] = { - "tokens": set(_INSIDE_QUOTES_REGEXP.findall(text)), - } - ret[vuln_hash]["token_positions"] = _scrub_get_tokens_positions(text, ret[vuln_hash]["tokens"]) - - return ret - - @classmethod - def _redact_report(cls, report): # type: (IastSpanReporter) -> IastSpanReporter - if not asm_config._iast_redaction_enabled: - return report - - # See if there is a match on either any of the sources or value parts of the report - found = False - - for source in report.sources: - # Join them so we only run the regexps once for each source - joined_fields = "%s%s" % (source.name, source.value) - if _has_to_scrub(joined_fields): - found = True - break - - vulns_to_text = {} - - if not found: - # Check the evidence's value/s - for vuln in report.vulnerabilities: - vulnerability_text = cls._get_vulnerability_text(vuln) - if _has_to_scrub(vulnerability_text) or _INSIDE_QUOTES_REGEXP.match(vulnerability_text): - vulns_to_text[vuln] = vulnerability_text - found = True - break + # TODO: Redaction migrated to `ddtrace.appsec._iast._evidence_redaction._sensitive_handler` but we need to migrate + # all vulnerabilities to use it first. 
+ redact_report = False + - if not found: - return report - - if not vulns_to_text: - vulns_to_text = {vuln: cls._get_vulnerability_text(vuln) for vuln in report.vulnerabilities} - - # If we're here, some potentially sensitive information was found, we delegate on - # the specific subclass the task of extracting the variable tokens (e.g. literals inside - # quotes for SQL Injection). Note that by just having one potentially sensitive match - # we need to then scrub all the tokens, thus why we do it in two steps instead of one - vulns_to_tokens = cls._extract_sensitive_tokens(vulns_to_text) - - if not vulns_to_tokens: - return report - - all_tokens = set() # type: Set[str] - for _, value_dict in vulns_to_tokens.items(): - all_tokens.update(value_dict["tokens"]) - - # Iterate over all the sources, if one of the tokens match it, redact it - for source in report.sources: - if source.name in "".join(all_tokens) or source.value in "".join(all_tokens): - source.pattern = _scrub(source.value, has_range=True) - source.redacted = True - source.value = None - - # Same for all the evidence values - try: - for vuln in report.vulnerabilities: - # Use the initial hash directly as iteration key since the vuln itself will change - vuln_hash = hash(vuln) - if vuln.evidence.value is not None: - pattern, replaced = cls.replace_tokens( - vuln, vulns_to_tokens, hasattr(vuln.evidence.value, "source") - ) - if replaced: - vuln.evidence.pattern = pattern - vuln.evidence.redacted = True - vuln.evidence.value = None - elif vuln.evidence.valueParts is not None: - idx = 0 - new_value_parts = [] - for part in vuln.evidence.valueParts: - value = part["value"] - part_len = len(value) - part_start = idx - part_end = idx + part_len - pattern_list = [] - - for positions in vulns_to_tokens[vuln_hash]["token_positions"]: - if _check_positions_contained(positions, (part_start, part_end)): - part_scrub_start = max(positions[0] - idx, 0) - part_scrub_end = positions[1] - idx - 
pattern_list.append(value[:part_scrub_start] + "" + value[part_scrub_end:]) - if part.get("source", False) is not False: - source = report.sources[part["source"]] - if source.redacted: - part["redacted"] = source.redacted - part["pattern"] = source.pattern - del part["value"] - new_value_parts.append(part) - break - else: - part["value"] = "".join(pattern_list) - new_value_parts.append(part) - new_value_parts.append({"redacted": True}) - break - else: - new_value_parts.append(part) - pattern_list.append(value[part_start:part_end]) - break - - idx += part_len - vuln.evidence.valueParts = new_value_parts - except (ValueError, KeyError): - log.debug("an error occurred while redacting cmdi", exc_info=True) - return report - - -def _iast_report_cmdi(shell_args): - # type: (Union[str, List[str]]) -> None +def _iast_report_cmdi(shell_args: Union[str, List[str]]) -> None: report_cmdi = "" from .._metrics import _set_metric_iast_executed_sink - from .._taint_tracking import is_pyobject_tainted - from .._taint_tracking.aspects import join_aspect - - if isinstance(shell_args, (list, tuple)): - for arg in shell_args: - if is_pyobject_tainted(arg): - report_cmdi = join_aspect(" ".join, 1, " ", shell_args) - break - elif is_pyobject_tainted(shell_args): - report_cmdi = shell_args increment_iast_span_metric(IAST_SPAN_TAGS.TELEMETRY_EXECUTED_SINK, CommandInjection.vulnerability_type) _set_metric_iast_executed_sink(CommandInjection.vulnerability_type) - if report_cmdi: - CommandInjection.report(evidence_value=report_cmdi) + + if AppSecIastSpanProcessor.is_span_analyzed() and CommandInjection.has_quota(): + from .._taint_tracking import is_pyobject_tainted + from .._taint_tracking.aspects import join_aspect + + if isinstance(shell_args, (list, tuple)): + for arg in shell_args: + if is_pyobject_tainted(arg): + report_cmdi = join_aspect(" ".join, 1, " ", shell_args) + break + elif is_pyobject_tainted(shell_args): + report_cmdi = shell_args + + if report_cmdi: + 
CommandInjection.report(evidence_value=report_cmdi) diff --git a/ddtrace/appsec/_iast/taint_sinks/header_injection.py b/ddtrace/appsec/_iast/taint_sinks/header_injection.py new file mode 100644 index 00000000000..1ce8a52d5e4 --- /dev/null +++ b/ddtrace/appsec/_iast/taint_sinks/header_injection.py @@ -0,0 +1,136 @@ +import re + +from ddtrace.internal.logger import get_logger +from ddtrace.settings.asm import config as asm_config + +from ..._common_module_patches import try_unwrap +from ..._constants import IAST_SPAN_TAGS +from .. import oce +from .._metrics import _set_metric_iast_instrumented_sink +from .._metrics import increment_iast_span_metric +from .._patch import set_and_check_module_is_patched +from .._patch import set_module_unpatched +from .._patch import try_wrap_function_wrapper +from ..constants import HEADER_NAME_VALUE_SEPARATOR +from ..constants import VULN_HEADER_INJECTION +from ..processor import AppSecIastSpanProcessor +from ._base import VulnerabilityBase + + +log = get_logger(__name__) + +_HEADERS_NAME_REGEXP = re.compile( + r"(?:p(?:ass)?w(?:or)?d|pass(?:_?phrase)?|secret|(?:api_?|private_?|public_?|access_?|secret_?)key(?:_?id)?|token|consumer_?(?:id|key|secret)|sign(?:ed|ature)?|auth(?:entication|orization)?)", + re.IGNORECASE, +) +_HEADERS_VALUE_REGEXP = re.compile( + r"(?:bearer\\s+[a-z0-9\\._\\-]+|glpat-[\\w\\-]{20}|gh[opsu]_[0-9a-zA-Z]{36}|ey[I-L][\\w=\\-]+\\.ey[I-L][\\w=\\-]+(?:\\.[\\w.+/=\\-]+)?|(?:[\\-]{5}BEGIN[a-z\\s]+PRIVATE\\sKEY[\\-]{5}[^\\-]+[\\-]{5}END[a-z\\s]+PRIVATE\\sKEY[\\-]{5}|ssh-rsa\\s*[a-z0-9/\\.+]{100,}))", + re.IGNORECASE, +) + + +def get_version(): + # type: () -> str + return "" + + +def patch(): + if not asm_config._iast_enabled: + return + + if not set_and_check_module_is_patched("flask", default_attr="_datadog_header_injection_patch"): + return + if not set_and_check_module_is_patched("django", default_attr="_datadog_header_injection_patch"): + return + + try_wrap_function_wrapper( + "wsgiref.headers", + 
"Headers.add_header", + _iast_h, + ) + try_wrap_function_wrapper( + "wsgiref.headers", + "Headers.__setitem__", + _iast_h, + ) + try_wrap_function_wrapper( + "werkzeug.datastructures", + "Headers.set", + _iast_h, + ) + try_wrap_function_wrapper( + "werkzeug.datastructures", + "Headers.add", + _iast_h, + ) + + # Django + try_wrap_function_wrapper( + "django.http.response", + "HttpResponseBase.__setitem__", + _iast_h, + ) + try_wrap_function_wrapper( + "django.http.response", + "ResponseHeaders.__setitem__", + _iast_h, + ) + + _set_metric_iast_instrumented_sink(VULN_HEADER_INJECTION, 1) + + +def unpatch(): + # type: () -> None + try_unwrap("wsgiref.headers", "Headers.add_header") + try_unwrap("wsgiref.headers", "Headers.__setitem__") + try_unwrap("werkzeug.datastructures", "Headers.set") + try_unwrap("werkzeug.datastructures", "Headers.add") + try_unwrap("django.http.response", "HttpResponseBase.__setitem__") + try_unwrap("django.http.response", "ResponseHeaders.__setitem__") + + set_module_unpatched("flask", default_attr="_datadog_header_injection_patch") + set_module_unpatched("django", default_attr="_datadog_header_injection_patch") + + pass + + +def _iast_h(wrapped, instance, args, kwargs): + if asm_config._iast_enabled: + _iast_report_header_injection(args) + return wrapped(*args, **kwargs) + + +@oce.register +class HeaderInjection(VulnerabilityBase): + vulnerability_type = VULN_HEADER_INJECTION + # TODO: Redaction migrated to `ddtrace.appsec._iast._evidence_redaction._sensitive_handler` but we need to migrate + # all vulnerabilities to use it first. 
+ redact_report = False + + +def _iast_report_header_injection(headers_args) -> None: + headers_exclusion = { + "content-type", + "content-length", + "content-encoding", + "transfer-encoding", + "set-cookie", + "vary", + } + from .._metrics import _set_metric_iast_executed_sink + from .._taint_tracking import is_pyobject_tainted + from .._taint_tracking.aspects import add_aspect + + header_name, header_value = headers_args + for header_to_exclude in headers_exclusion: + header_name_lower = header_name.lower() + if header_name_lower == header_to_exclude or header_name_lower.startswith(header_to_exclude): + return + + increment_iast_span_metric(IAST_SPAN_TAGS.TELEMETRY_EXECUTED_SINK, HeaderInjection.vulnerability_type) + _set_metric_iast_executed_sink(HeaderInjection.vulnerability_type) + + if AppSecIastSpanProcessor.is_span_analyzed() and HeaderInjection.has_quota(): + if is_pyobject_tainted(header_name) or is_pyobject_tainted(header_value): + header_evidence = add_aspect(add_aspect(header_name, HEADER_NAME_VALUE_SEPARATOR), header_value) + HeaderInjection.report(evidence_value=header_evidence) diff --git a/ddtrace/appsec/_iast/taint_sinks/path_traversal.py b/ddtrace/appsec/_iast/taint_sinks/path_traversal.py index c7618000d05..e6fde3b40e2 100644 --- a/ddtrace/appsec/_iast/taint_sinks/path_traversal.py +++ b/ddtrace/appsec/_iast/taint_sinks/path_traversal.py @@ -8,7 +8,6 @@ from .._metrics import increment_iast_span_metric from .._patch import set_and_check_module_is_patched from .._patch import set_module_unpatched -from ..constants import EVIDENCE_PATH_TRAVERSAL from ..constants import VULN_PATH_TRAVERSAL from ..processor import AppSecIastSpanProcessor from ._base import VulnerabilityBase @@ -20,15 +19,6 @@ @oce.register class PathTraversal(VulnerabilityBase): vulnerability_type = VULN_PATH_TRAVERSAL - evidence_type = EVIDENCE_PATH_TRAVERSAL - - @classmethod - def report(cls, evidence_value=None, sources=None): - if isinstance(evidence_value, (str, bytes, 
bytearray)): - from .._taint_tracking import taint_ranges_as_evidence_info - - evidence_value, sources = taint_ranges_as_evidence_info(evidence_value) - super(PathTraversal, cls).report(evidence_value=evidence_value, sources=sources) def get_version(): diff --git a/ddtrace/appsec/_iast/taint_sinks/sql_injection.py b/ddtrace/appsec/_iast/taint_sinks/sql_injection.py index ee7bcfb2f8f..68d5a289c01 100644 --- a/ddtrace/appsec/_iast/taint_sinks/sql_injection.py +++ b/ddtrace/appsec/_iast/taint_sinks/sql_injection.py @@ -32,9 +32,10 @@ class SqlInjection(VulnerabilityBase): @classmethod def report(cls, evidence_value=None, sources=None): + value_parts = [] if isinstance(evidence_value, (str, bytes, bytearray)): - evidence_value, sources = taint_ranges_as_evidence_info(evidence_value) - super(SqlInjection, cls).report(evidence_value=evidence_value, sources=sources) + value_parts, sources = taint_ranges_as_evidence_info(evidence_value) + super(SqlInjection, cls).report(evidence_value=evidence_value, value_parts=value_parts, sources=sources) @classmethod def _extract_sensitive_tokens(cls, vulns_to_text): diff --git a/ddtrace/appsec/_iast/taint_sinks/ssrf.py b/ddtrace/appsec/_iast/taint_sinks/ssrf.py index f114998605a..7a070cf5425 100644 --- a/ddtrace/appsec/_iast/taint_sinks/ssrf.py +++ b/ddtrace/appsec/_iast/taint_sinks/ssrf.py @@ -1,176 +1,33 @@ -import re -from typing import Callable # noqa:F401 -from typing import Dict # noqa:F401 -from typing import Set # noqa:F401 +from typing import Callable from ddtrace.internal.logger import get_logger -from ddtrace.settings.asm import config as asm_config from ..._constants import IAST_SPAN_TAGS from .. 
import oce from .._metrics import increment_iast_span_metric -from .._utils import _has_to_scrub -from .._utils import _is_iast_enabled -from .._utils import _scrub -from .._utils import _scrub_get_tokens_positions -from ..constants import EVIDENCE_SSRF from ..constants import VULN_SSRF -from ..constants import VULNERABILITY_TOKEN_TYPE from ..processor import AppSecIastSpanProcessor -from ..reporter import IastSpanReporter # noqa:F401 -from ..reporter import Vulnerability from ._base import VulnerabilityBase -from ._base import _check_positions_contained log = get_logger(__name__) -_AUTHORITY_REGEXP = re.compile(r"(?:\/\/([^:@\/]+)(?::([^@\/]+))?@).*") -_QUERY_FRAGMENT_REGEXP = re.compile(r"[?#&]([^=&;]+)=(?P[^?#&]+)") - - @oce.register class SSRF(VulnerabilityBase): vulnerability_type = VULN_SSRF - evidence_type = EVIDENCE_SSRF - redact_report = True - - @classmethod - def report(cls, evidence_value=None, sources=None): - if not _is_iast_enabled(): - return - - from .._taint_tracking import taint_ranges_as_evidence_info - - if isinstance(evidence_value, (str, bytes, bytearray)): - evidence_value, sources = taint_ranges_as_evidence_info(evidence_value) - super(SSRF, cls).report(evidence_value=evidence_value, sources=sources) - - @classmethod - def _extract_sensitive_tokens(cls, vulns_to_text: Dict[Vulnerability, str]) -> VULNERABILITY_TOKEN_TYPE: - ret = {} # type: VULNERABILITY_TOKEN_TYPE - for vuln, text in vulns_to_text.items(): - vuln_hash = hash(vuln) - authority = [] - authority_found = _AUTHORITY_REGEXP.findall(text) - if authority_found: - authority = list(authority_found[0]) - query = [value for param, value in _QUERY_FRAGMENT_REGEXP.findall(text)] - ret[vuln_hash] = { - "tokens": set(authority + query), - } - ret[vuln_hash]["token_positions"] = _scrub_get_tokens_positions(text, ret[vuln_hash]["tokens"]) - - return ret - - @classmethod - def _redact_report(cls, report): # type: (IastSpanReporter) -> IastSpanReporter - if not 
asm_config._iast_redaction_enabled: - return report - - # See if there is a match on either any of the sources or value parts of the report - found = False - - for source in report.sources: - # Join them so we only run the regexps once for each source - joined_fields = "%s%s" % (source.name, source.value) - if _has_to_scrub(joined_fields): - found = True - break - - vulns_to_text = {} - - if not found: - # Check the evidence's value/s - for vuln in report.vulnerabilities: - vulnerability_text = cls._get_vulnerability_text(vuln) - if _has_to_scrub(vulnerability_text) or _AUTHORITY_REGEXP.match(vulnerability_text): - vulns_to_text[vuln] = vulnerability_text - found = True - break - - if not found: - return report - - if not vulns_to_text: - vulns_to_text = {vuln: cls._get_vulnerability_text(vuln) for vuln in report.vulnerabilities} - - # If we're here, some potentially sensitive information was found, we delegate on - # the specific subclass the task of extracting the variable tokens (e.g. literals inside - # quotes for SQL Injection). 
Note that by just having one potentially sensitive match - # we need to then scrub all the tokens, thus why we do it in two steps instead of one - vulns_to_tokens = cls._extract_sensitive_tokens(vulns_to_text) - - if not vulns_to_tokens: - return report - - all_tokens = set() # type: Set[str] - for _, value_dict in vulns_to_tokens.items(): - all_tokens.update(value_dict["tokens"]) - - # Iterate over all the sources, if one of the tokens match it, redact it - for source in report.sources: - if source.name in "".join(all_tokens) or source.value in "".join(all_tokens): - source.pattern = _scrub(source.value, has_range=True) - source.redacted = True - source.value = None - - # Same for all the evidence values - for vuln in report.vulnerabilities: - # Use the initial hash directly as iteration key since the vuln itself will change - vuln_hash = hash(vuln) - if vuln.evidence.value is not None: - pattern, replaced = cls.replace_tokens(vuln, vulns_to_tokens, hasattr(vuln.evidence.value, "source")) - if replaced: - vuln.evidence.pattern = pattern - vuln.evidence.redacted = True - vuln.evidence.value = None - elif vuln.evidence.valueParts is not None: - idx = 0 - new_value_parts = [] - for part in vuln.evidence.valueParts: - value = part["value"] - part_len = len(value) - part_start = idx - part_end = idx + part_len - pattern_list = [] - - for positions in vulns_to_tokens[vuln_hash]["token_positions"]: - if _check_positions_contained(positions, (part_start, part_end)): - part_scrub_start = max(positions[0] - idx, 0) - part_scrub_end = positions[1] - idx - pattern_list.append(value[:part_scrub_start] + "" + value[part_scrub_end:]) - if part.get("source", False) is not False: - source = report.sources[part["source"]] - if source.redacted: - part["redacted"] = source.redacted - part["pattern"] = source.pattern - del part["value"] - new_value_parts.append(part) - break - else: - part["value"] = "".join(pattern_list) - new_value_parts.append(part) - 
new_value_parts.append({"redacted": True}) - break - else: - new_value_parts.append(part) - pattern_list.append(value[part_start:part_end]) - break - - idx += part_len - vuln.evidence.valueParts = new_value_parts - return report + # TODO: Redaction migrated to `ddtrace.appsec._iast._evidence_redaction._sensitive_handler` but we need to migrate + # all vulnerabilities to use it first. + redact_report = False def _iast_report_ssrf(func: Callable, *args, **kwargs): - from .._metrics import _set_metric_iast_executed_sink - report_ssrf = kwargs.get("url", False) - increment_iast_span_metric(IAST_SPAN_TAGS.TELEMETRY_EXECUTED_SINK, SSRF.vulnerability_type) - _set_metric_iast_executed_sink(SSRF.vulnerability_type) if report_ssrf: + from .._metrics import _set_metric_iast_executed_sink + + _set_metric_iast_executed_sink(SSRF.vulnerability_type) + increment_iast_span_metric(IAST_SPAN_TAGS.TELEMETRY_EXECUTED_SINK, SSRF.vulnerability_type) if AppSecIastSpanProcessor.is_span_analyzed() and SSRF.has_quota(): try: from .._taint_tracking import is_pyobject_tainted diff --git a/ddtrace/appsec/_metrics.py b/ddtrace/appsec/_metrics.py index 28644978a0f..28d712cebf7 100644 --- a/ddtrace/appsec/_metrics.py +++ b/ddtrace/appsec/_metrics.py @@ -105,6 +105,21 @@ def _set_waf_request_metrics(*args): 1.0, tags=tags_request, ) + rasp = result["rasp"] + if rasp["called"]: + for t, n in [("eval", "rasp.rule.eval"), ("match", "rasp.rule.match"), ("timeout", "rasp.timeout")]: + for rule_type, value in rasp[t].items(): + if value: + telemetry.telemetry_writer.add_count_metric( + TELEMETRY_NAMESPACE_TAG_APPSEC, + n, + float(value), + tags=( + ("rule_type", rule_type), + ("waf_version", DDWAF_VERSION), + ), + ) + except Exception: log.warning("Error reporting ASM WAF requests metrics", exc_info=True) finally: diff --git a/ddtrace/appsec/_processor.py b/ddtrace/appsec/_processor.py index 4d98ab486b4..a3d0518b6a4 100644 --- a/ddtrace/appsec/_processor.py +++ b/ddtrace/appsec/_processor.py @@ -260,7 
+260,12 @@ def waf_callable(custom_data=None, **kwargs): _asm_request_context.call_waf_callback({"REQUEST_HTTP_IP": None}) def _waf_action( - self, span: Span, ctx: ddwaf_context_capsule, custom_data: Optional[Dict[str, Any]] = None, **kwargs + self, + span: Span, + ctx: ddwaf_context_capsule, + custom_data: Optional[Dict[str, Any]] = None, + crop_trace: Optional[str] = None, + rule_type: Optional[str] = None, ) -> Optional[DDWaf_result]: """ Call the `WAF` with the given parameters. If `custom_data_names` is specified as @@ -327,7 +332,7 @@ def _waf_action( from ddtrace.appsec._exploit_prevention.stack_traces import report_stack stack_trace_id = parameters["stack_id"] - report_stack("exploit detected", span, kwargs.get("crop_trace"), stack_id=stack_trace_id) + report_stack("exploit detected", span, crop_trace, stack_id=stack_trace_id) for rule in waf_results.data: rule[EXPLOIT_PREVENTION.STACK_TRACE_ID] = stack_trace_id @@ -335,7 +340,11 @@ def _waf_action( log.debug("[DDAS-011-00] ASM In-App WAF returned: %s. 
Timeout %s", waf_results.data, waf_results.timeout) _asm_request_context.set_waf_telemetry_results( - self._ddwaf.info.version, bool(waf_results.data), bool(blocked), waf_results.timeout + self._ddwaf.info.version, + bool(waf_results.data), + bool(blocked), + waf_results.timeout, + rule_type, ) if blocked: core.set_item(WAF_CONTEXT_NAMES.BLOCKED, blocked, span=span) diff --git a/ddtrace/contrib/aioredis/patch.py b/ddtrace/contrib/aioredis/patch.py index e460211b089..f44cea456c5 100644 --- a/ddtrace/contrib/aioredis/patch.py +++ b/ddtrace/contrib/aioredis/patch.py @@ -5,6 +5,11 @@ import aioredis from ddtrace import config +from ddtrace._trace.utils_redis import _trace_redis_cmd +from ddtrace._trace.utils_redis import _trace_redis_execute_pipeline +from ddtrace.contrib.redis_utils import ROW_RETURNING_COMMANDS +from ddtrace.contrib.redis_utils import _run_redis_command_async +from ddtrace.contrib.redis_utils import determine_row_count from ddtrace.internal.constants import COMPONENT from ddtrace.internal.utils.wrappers import unwrap as _u from ddtrace.pin import Pin @@ -25,11 +30,6 @@ from ...internal.utils.formats import asbool from ...internal.utils.formats import stringify_cache_args from .. 
import trace_utils -from ..trace_utils_redis import ROW_RETURNING_COMMANDS -from ..trace_utils_redis import _run_redis_command_async -from ..trace_utils_redis import _trace_redis_cmd -from ..trace_utils_redis import _trace_redis_execute_pipeline -from ..trace_utils_redis import determine_row_count try: @@ -175,7 +175,7 @@ def _finish_span(future): redis_command = span.resource.split(" ")[0] future.result() if redis_command in ROW_RETURNING_COMMANDS: - determine_row_count(redis_command=redis_command, span=span, result=future.result()) + span.set_metric(db.ROWCOUNT, determine_row_count(redis_command=redis_command, result=future.result())) # CancelledError exceptions extend from BaseException as of Python 3.8, instead of usual Exception except (Exception, aioredis.CancelledError): span.set_exc_info(*sys.exc_info()) diff --git a/ddtrace/contrib/aredis/patch.py b/ddtrace/contrib/aredis/patch.py index 1c0dc8c88ff..65386e99932 100644 --- a/ddtrace/contrib/aredis/patch.py +++ b/ddtrace/contrib/aredis/patch.py @@ -3,6 +3,9 @@ import aredis from ddtrace import config +from ddtrace._trace.utils_redis import _trace_redis_cmd +from ddtrace._trace.utils_redis import _trace_redis_execute_pipeline +from ddtrace.contrib.redis_utils import _run_redis_command_async from ddtrace.vendor import wrapt from ...internal.schema import schematize_service_name @@ -11,9 +14,6 @@ from ...internal.utils.formats import stringify_cache_args from ...internal.utils.wrappers import unwrap from ...pin import Pin -from ..trace_utils_redis import _run_redis_command_async -from ..trace_utils_redis import _trace_redis_cmd -from ..trace_utils_redis import _trace_redis_execute_pipeline config._add( diff --git a/ddtrace/contrib/botocore/patch.py b/ddtrace/contrib/botocore/patch.py index ad540ee092e..e0bcc3f317f 100644 --- a/ddtrace/contrib/botocore/patch.py +++ b/ddtrace/contrib/botocore/patch.py @@ -41,7 +41,6 @@ from .services.stepfunctions import update_stepfunction_input from .utils import 
inject_trace_to_client_context from .utils import inject_trace_to_eventbridge_detail -from .utils import set_patched_api_call_span_tags from .utils import set_response_metadata_tags @@ -183,11 +182,7 @@ def prep_context_injection(ctx, endpoint_name, operation, trace_operation, param injection_function = inject_trace_to_eventbridge_detail cloud_service = "events" if endpoint_name == "sns" and "Publish" in operation: - injection_function = ( # noqa: E731 - lambda ctx, params, span, endpoint_service: inject_trace_to_sqs_or_sns_message( - ctx, params, endpoint_service - ) - ) + injection_function = inject_trace_to_sqs_or_sns_message cloud_service = "sns" if endpoint_name == "states" and (operation == "StartExecution" or operation == "StartSyncExecution"): injection_function = update_stepfunction_input @@ -215,7 +210,6 @@ def patched_api_call_fallback(original_func, instance, args, kwargs, function_va endpoint_name=endpoint_name, operation=operation, service=schematize_service_name("{}.{}".format(pin.service, endpoint_name)), - context_started_callback=set_patched_api_call_span_tags, pin=pin, span_name=function_vars.get("trace_operation"), span_type=SpanTypes.HTTP, diff --git a/ddtrace/contrib/botocore/services/bedrock.py b/ddtrace/contrib/botocore/services/bedrock.py index 0e13fecbf2d..e0896833e1c 100644 --- a/ddtrace/contrib/botocore/services/bedrock.py +++ b/ddtrace/contrib/botocore/services/bedrock.py @@ -42,19 +42,16 @@ def read(self, amt=None): self._body.append(json.loads(body)) if self.__wrapped__.tell() == int(self.__wrapped__._content_length): formatted_response = _extract_text_and_response_reason(self._execution_ctx, self._body[0]) + model_provider = self._execution_ctx["model_provider"] + model_name = self._execution_ctx["model_name"] + should_set_choice_ids = model_provider == _COHERE and "embed" not in model_name core.dispatch( "botocore.bedrock.process_response", - [ - self._execution_ctx, - formatted_response, - None, - self._body[0], - 
self._execution_ctx["model_provider"] == _COHERE, - ], + [self._execution_ctx, formatted_response, None, self._body[0], should_set_choice_ids], ) return body except Exception: - core.dispatch("botocore.patched_bedrock_api_call.exception", [self._execution_context, sys.exc_info()]) + core.dispatch("botocore.patched_bedrock_api_call.exception", [self._execution_ctx, sys.exc_info()]) raise def readlines(self): @@ -64,15 +61,12 @@ def readlines(self): for line in lines: self._body.append(json.loads(line)) formatted_response = _extract_text_and_response_reason(self._execution_ctx, self._body[0]) + model_provider = self._execution_ctx["model_provider"] + model_name = self._execution_ctx["model_name"] + should_set_choice_ids = model_provider == _COHERE and "embed" not in model_name core.dispatch( "botocore.bedrock.process_response", - [ - self._execution_ctx, - formatted_response, - None, - self._body[0], - self._execution_ctx["model_provider"] == _COHERE, - ], + [self._execution_ctx, formatted_response, None, self._body[0], should_set_choice_ids], ) return lines except Exception: @@ -87,15 +81,14 @@ def __iter__(self): yield line metadata = _extract_streamed_response_metadata(self._execution_ctx, self._body) formatted_response = _extract_streamed_response(self._execution_ctx, self._body) + model_provider = self._execution_ctx["model_provider"] + model_name = self._execution_ctx["model_name"] + should_set_choice_ids = ( + model_provider == _COHERE and "is_finished" not in self._body[0] and "embed" not in model_name + ) core.dispatch( "botocore.bedrock.process_response", - [ - self._execution_ctx, - formatted_response, - metadata, - self._body, - self._execution_ctx["model_provider"] == _COHERE and "is_finished" not in self._body[0], - ], + [self._execution_ctx, formatted_response, metadata, self._body, should_set_choice_ids], ) except Exception: core.dispatch("botocore.patched_bedrock_api_call.exception", [self._execution_ctx, sys.exc_info()]) @@ -107,6 +100,7 @@ def 
_extract_request_params(params: Dict[str, Any], provider: str) -> Dict[str, Extracts request parameters including prompt, temperature, top_p, max_tokens, and stop_sequences. """ request_body = json.loads(params.get("body")) + model_id = params.get("modelId") if provider == _AI21: return { "prompt": request_body.get("prompt"), @@ -115,6 +109,8 @@ def _extract_request_params(params: Dict[str, Any], provider: str) -> Dict[str, "max_tokens": request_body.get("maxTokens", ""), "stop_sequences": request_body.get("stopSequences", []), } + elif provider == _AMAZON and "embed" in model_id: + return {"prompt": request_body.get("inputText")} elif provider == _AMAZON: text_generation_config = request_body.get("textGenerationConfig", {}) return { @@ -135,6 +131,12 @@ def _extract_request_params(params: Dict[str, Any], provider: str) -> Dict[str, "max_tokens": request_body.get("max_tokens_to_sample", ""), "stop_sequences": request_body.get("stop_sequences", []), } + elif provider == _COHERE and "embed" in model_id: + return { + "prompt": request_body.get("texts"), + "input_type": request_body.get("input_type", ""), + "truncate": request_body.get("truncate", ""), + } elif provider == _COHERE: return { "prompt": request_body.get("prompt"), @@ -161,17 +163,22 @@ def _extract_request_params(params: Dict[str, Any], provider: str) -> Dict[str, def _extract_text_and_response_reason(ctx: core.ExecutionContext, body: Dict[str, Any]) -> Dict[str, List[str]]: text, finish_reason = "", "" + model_name = ctx["model_name"] provider = ctx["model_provider"] try: if provider == _AI21: text = body.get("completions")[0].get("data").get("text") finish_reason = body.get("completions")[0].get("finishReason") + elif provider == _AMAZON and "embed" in model_name: + text = [body.get("embedding", [])] elif provider == _AMAZON: text = body.get("results")[0].get("outputText") finish_reason = body.get("results")[0].get("completionReason") elif provider == _ANTHROPIC: text = body.get("completion", "") or 
body.get("content", "") finish_reason = body.get("stop_reason") + elif provider == _COHERE and "embed" in model_name: + text = body.get("embeddings", [[]]) elif provider == _COHERE: text = [generation["text"] for generation in body.get("generations")] finish_reason = [generation["finish_reason"] for generation in body.get("generations")] @@ -194,10 +201,13 @@ def _extract_text_and_response_reason(ctx: core.ExecutionContext, body: Dict[str def _extract_streamed_response(ctx: core.ExecutionContext, streamed_body: List[Dict[str, Any]]) -> Dict[str, List[str]]: text, finish_reason = "", "" + model_name = ctx["model_name"] provider = ctx["model_provider"] try: if provider == _AI21: - pass # note: ai21 does not support streamed responses + pass # DEV: ai21 does not support streamed responses + elif provider == _AMAZON and "embed" in model_name: + pass # DEV: amazon embed models do not support streamed responses elif provider == _AMAZON: text = "".join([chunk["outputText"] for chunk in streamed_body]) finish_reason = streamed_body[-1]["completionReason"] @@ -211,7 +221,9 @@ def _extract_streamed_response(ctx: core.ExecutionContext, streamed_body: List[D text += chunk["delta"].get("text", "") if "stop_reason" in chunk["delta"]: finish_reason = str(chunk["delta"]["stop_reason"]) - elif provider == _COHERE and streamed_body: + elif provider == _COHERE and "embed" in model_name: + pass # DEV: cohere embed models do not support streamed responses + elif provider == _COHERE: if "is_finished" in streamed_body[0]: # streamed response if "index" in streamed_body[0]: # n >= 2 num_generations = int(ctx.get_item("num_generations") or 0) @@ -230,8 +242,7 @@ def _extract_streamed_response(ctx: core.ExecutionContext, streamed_body: List[D text = "".join([chunk["generation"] for chunk in streamed_body]) finish_reason = streamed_body[-1]["stop_reason"] elif provider == _STABILITY: - # We do not yet support image modality models - pass + pass # DEV: we do not yet support image modality 
models except (IndexError, AttributeError): log.warning("Unable to extract text/finish_reason from response body. Defaulting to empty text/finish_reason.") @@ -306,7 +317,7 @@ def patched_bedrock_api_call(original_func, instance, args, kwargs, function_var span_name=function_vars.get("trace_operation"), service=schematize_service_name("{}.{}".format(pin.service, function_vars.get("endpoint_name"))), resource=function_vars.get("operation"), - span_type=SpanTypes.LLM, + span_type=SpanTypes.LLM if "embed" not in model_name else None, call_key="instrumented_bedrock_call", call_trace=True, bedrock_integration=function_vars.get("integration"), diff --git a/ddtrace/contrib/botocore/services/kinesis.py b/ddtrace/contrib/botocore/services/kinesis.py index 58d361f8b93..412f0b0c27f 100644 --- a/ddtrace/contrib/botocore/services/kinesis.py +++ b/ddtrace/contrib/botocore/services/kinesis.py @@ -19,7 +19,6 @@ from ....internal.schema import schematize_service_name from ..utils import extract_DD_context from ..utils import get_kinesis_data_object -from ..utils import set_patched_api_call_span_tags from ..utils import set_response_metadata_tags @@ -134,7 +133,6 @@ def patched_kinesis_api_call(original_func, instance, args, kwargs, function_var operation=operation, service=schematize_service_name("{}.{}".format(pin.service, endpoint_name)), call_trace=False, - context_started_callback=set_patched_api_call_span_tags, pin=pin, span_name=span_name, span_type=SpanTypes.HTTP, diff --git a/ddtrace/contrib/botocore/services/sqs.py b/ddtrace/contrib/botocore/services/sqs.py index 9561f7a93f0..37080c85d70 100644 --- a/ddtrace/contrib/botocore/services/sqs.py +++ b/ddtrace/contrib/botocore/services/sqs.py @@ -8,7 +8,6 @@ from ddtrace import config from ddtrace.contrib.botocore.utils import extract_DD_context -from ddtrace.contrib.botocore.utils import set_patched_api_call_span_tags from ddtrace.contrib.botocore.utils import set_response_metadata_tags from ddtrace.ext import SpanTypes from 
ddtrace.internal import core @@ -55,7 +54,8 @@ def add_dd_attributes_to_message( entry["MessageAttributes"]["_datadog"] = {"DataType": data_type, f"{data_type}Value": _encode_data(data_to_add)} -def update_messages(ctx, params: Any, endpoint_service: Optional[str] = None) -> None: +def update_messages(ctx, endpoint_service: Optional[str] = None) -> None: + params = ctx["params"] if "Entries" in params or "PublishBatchRequestEntries" in params: entries = params.get("Entries", params.get("PublishBatchRequestEntries", [])) if len(entries) == 0: @@ -141,7 +141,6 @@ def patched_sqs_api_call(original_func, instance, args, kwargs, function_vars): span_type=SpanTypes.HTTP, child_of=child_of if child_of is not None else pin.tracer.context_provider.active(), activate=True, - context_started_callback=set_patched_api_call_span_tags, instance=instance, args=args, params=params, @@ -154,7 +153,7 @@ def patched_sqs_api_call(original_func, instance, args, kwargs, function_vars): core.dispatch("botocore.patched_sqs_api_call.started", [ctx]) if should_update_messages: - update_messages(ctx, params, endpoint_service=endpoint_name) + update_messages(ctx, endpoint_service=endpoint_name) try: if not func_has_run: diff --git a/ddtrace/contrib/botocore/services/stepfunctions.py b/ddtrace/contrib/botocore/services/stepfunctions.py index 433b36c2238..d611f664a48 100644 --- a/ddtrace/contrib/botocore/services/stepfunctions.py +++ b/ddtrace/contrib/botocore/services/stepfunctions.py @@ -12,7 +12,6 @@ from ....internal.schema import SpanDirection from ....internal.schema import schematize_cloud_messaging_operation from ....internal.schema import schematize_service_name -from ..utils import set_patched_api_call_span_tags from ..utils import set_response_metadata_tags @@ -70,7 +69,6 @@ def patched_stepfunction_api_call(original_func, instance, args, kwargs: Dict, f params=params, endpoint_name=endpoint_name, operation=operation, - context_started_callback=set_patched_api_call_span_tags, pin=pin, 
) as ctx, ctx.get_item(ctx["call_key"]): core.dispatch("botocore.patched_stepfunctions_api_call.started", [ctx]) diff --git a/ddtrace/contrib/botocore/utils.py b/ddtrace/contrib/botocore/utils.py index c35a3a5312b..ead47ace10c 100644 --- a/ddtrace/contrib/botocore/utils.py +++ b/ddtrace/contrib/botocore/utils.py @@ -3,35 +3,28 @@ """ import base64 import json -from typing import Any # noqa:F401 -from typing import Dict # noqa:F401 -from typing import Optional # noqa:F401 -from typing import Tuple # noqa:F401 +from typing import Any +from typing import Dict +from typing import Optional +from typing import Tuple -from ddtrace import Span # noqa:F401 +from ddtrace import Span from ddtrace import config -from ddtrace.constants import ANALYTICS_SAMPLE_RATE_KEY -from ddtrace.constants import SPAN_KIND -from ddtrace.constants import SPAN_MEASURED_KEY -from ddtrace.ext import SpanKind -from ddtrace.internal.constants import COMPONENT +from ddtrace.internal.core import ExecutionContext -from ...ext import aws from ...ext import http from ...internal.logger import get_logger -from ...internal.utils.formats import deep_getattr from ...propagation.http import HTTPPropagator log = get_logger(__name__) -MAX_EVENTBRIDGE_DETAIL_SIZE = 1 << 18 # 256KB - +TWOFIFTYSIX_KB = 1 << 18 +MAX_EVENTBRIDGE_DETAIL_SIZE = TWOFIFTYSIX_KB LINE_BREAK = "\n" -def get_json_from_str(data_str): - # type: (str) -> Tuple[str, Optional[Dict[str, Any]]] +def get_json_from_str(data_str: str) -> Tuple[str, Optional[Dict[str, Any]]]: data_obj = json.loads(data_str) if data_str.endswith(LINE_BREAK): @@ -39,8 +32,7 @@ def get_json_from_str(data_str): return None, data_obj -def get_kinesis_data_object(data): - # type: (str) -> Tuple[str, Optional[Dict[str, Any]]] +def get_kinesis_data_object(data: str) -> Tuple[str, Optional[Dict[str, Any]]]: """ :data: the data from a kinesis stream The data from a kinesis stream comes as a string (could be json, base64 encoded, etc.) 
@@ -74,14 +66,12 @@ def get_kinesis_data_object(data): return None, None -def inject_trace_to_eventbridge_detail(ctx, params, span): - # type: (Any, Any, Span) -> None +def inject_trace_to_eventbridge_detail(ctx: ExecutionContext) -> None: """ - :params: contains the params for the current botocore action - :span: the span which provides the trace context to be propagated Inject trace headers into the EventBridge record if the record's Detail object contains a JSON string Max size per event is 256KB (https://docs.aws.amazon.com/eventbridge/latest/userguide/eb-putevent-size.html) """ + params = ctx["params"] if "Entries" not in params: log.warning("Unable to inject context. The Event Bridge event had no Entries.") return @@ -96,6 +86,7 @@ def inject_trace_to_eventbridge_detail(ctx, params, span): continue detail["_datadog"] = {} + span = ctx[ctx["call_key"]] HTTPPropagator.inject(span.context, detail["_datadog"]) detail_json = json.dumps(detail) @@ -108,18 +99,10 @@ def inject_trace_to_eventbridge_detail(ctx, params, span): entry["Detail"] = detail_json -def modify_client_context(client_context_object, trace_headers): - if config.botocore["invoke_with_legacy_context"]: - trace_headers = {"_datadog": trace_headers} - - if "custom" in client_context_object: - client_context_object["custom"].update(trace_headers) - else: - client_context_object["custom"] = trace_headers - - -def inject_trace_to_client_context(ctx, params, span): +def inject_trace_to_client_context(ctx): trace_headers = {} + span = ctx[ctx["call_key"]] + params = ctx["params"] HTTPPropagator.inject(span.context, trace_headers) client_context_object = {} if "ClientContext" in params: @@ -138,40 +121,17 @@ def inject_trace_to_client_context(ctx, params, span): params["ClientContext"] = base64.b64encode(json_context).decode("utf-8") -def set_patched_api_call_span_tags(span, instance, args, params, endpoint_name, operation): - span.set_tag_str(COMPONENT, config.botocore.integration_name) - # set span.kind 
to the type of request being performed - span.set_tag_str(SPAN_KIND, SpanKind.CLIENT) - span.set_tag(SPAN_MEASURED_KEY) - - if args: - # DEV: join is the fastest way of concatenating strings that is compatible - # across Python versions (see - # https://stackoverflow.com/questions/1316887/what-is-the-most-efficient-string-concatenation-method-in-python) - span.resource = ".".join((endpoint_name, operation.lower())) - span.set_tag("aws_service", endpoint_name) - - if params and not config.botocore["tag_no_params"]: - aws._add_api_param_span_tags(span, endpoint_name, params) +def modify_client_context(client_context_object, trace_headers): + if config.botocore["invoke_with_legacy_context"]: + trace_headers = {"_datadog": trace_headers} + if "custom" in client_context_object: + client_context_object["custom"].update(trace_headers) else: - span.resource = endpoint_name - - region_name = deep_getattr(instance, "meta.region_name") - - span.set_tag_str("aws.agent", "botocore") - if operation is not None: - span.set_tag_str("aws.operation", operation) - if region_name is not None: - span.set_tag_str("aws.region", region_name) - span.set_tag_str("region", region_name) - - # set analytics sample rate - span.set_tag(ANALYTICS_SAMPLE_RATE_KEY, config.botocore.get_analytics_sample_rate()) + client_context_object["custom"] = trace_headers -def set_response_metadata_tags(span, result): - # type: (Span, Dict[str, Any]) -> None +def set_response_metadata_tags(span: Span, result: Dict[str, Any]) -> None: if not result or not result.get("ResponseMetadata"): return response_meta = result["ResponseMetadata"] diff --git a/ddtrace/contrib/flask_cache/utils.py b/ddtrace/contrib/flask_cache/utils.py index d7f33ec160c..d770b1d065b 100644 --- a/ddtrace/contrib/flask_cache/utils.py +++ b/ddtrace/contrib/flask_cache/utils.py @@ -1,7 +1,8 @@ # project +from ddtrace._trace.utils_redis import _extract_conn_tags as extract_redis_tags + from ...ext import net from ..pylibmc.addrs import 
parse_addresses -from ..trace_utils_redis import _extract_conn_tags as extract_redis_tags def _resource_from_cache_prefix(resource, cache): diff --git a/ddtrace/contrib/redis/asyncio_patch.py b/ddtrace/contrib/redis/asyncio_patch.py index f444fef7ddb..7bcc9653c74 100644 --- a/ddtrace/contrib/redis/asyncio_patch.py +++ b/ddtrace/contrib/redis/asyncio_patch.py @@ -1,11 +1,11 @@ from ddtrace import config +from ddtrace._trace.utils_redis import _trace_redis_cmd +from ddtrace._trace.utils_redis import _trace_redis_execute_async_cluster_pipeline +from ddtrace._trace.utils_redis import _trace_redis_execute_pipeline +from ddtrace.contrib.redis_utils import _run_redis_command_async from ...internal.utils.formats import stringify_cache_args from ...pin import Pin -from ..trace_utils_redis import _run_redis_command_async -from ..trace_utils_redis import _trace_redis_cmd -from ..trace_utils_redis import _trace_redis_execute_async_cluster_pipeline -from ..trace_utils_redis import _trace_redis_execute_pipeline # diff --git a/ddtrace/contrib/redis/patch.py b/ddtrace/contrib/redis/patch.py index c4bf0f42af6..f9784c4e29b 100644 --- a/ddtrace/contrib/redis/patch.py +++ b/ddtrace/contrib/redis/patch.py @@ -3,6 +3,11 @@ import redis from ddtrace import config +from ddtrace._trace.utils_redis import _trace_redis_cmd +from ddtrace._trace.utils_redis import _trace_redis_execute_pipeline +from ddtrace.contrib.redis_utils import ROW_RETURNING_COMMANDS +from ddtrace.contrib.redis_utils import determine_row_count +from ddtrace.ext import db from ddtrace.vendor import wrapt from ...internal.schema import schematize_service_name @@ -11,9 +16,6 @@ from ...internal.utils.formats import stringify_cache_args from ...pin import Pin from ..trace_utils import unwrap -from ..trace_utils_redis import _run_redis_command -from ..trace_utils_redis import _trace_redis_cmd -from ..trace_utils_redis import _trace_redis_execute_pipeline config._add( @@ -119,6 +121,20 @@ def unpatch(): 
unwrap(redis.asyncio.cluster.ClusterPipeline, "execute") +def _run_redis_command(span, func, args, kwargs): + parsed_command = stringify_cache_args(args) + redis_command = parsed_command.split(" ")[0] + try: + result = func(*args, **kwargs) + if redis_command in ROW_RETURNING_COMMANDS: + span.set_metric(db.ROWCOUNT, determine_row_count(redis_command=redis_command, result=result)) + return result + except Exception: + if redis_command in ROW_RETURNING_COMMANDS: + span.set_metric(db.ROWCOUNT, 0) + raise + + # # tracing functions # diff --git a/ddtrace/contrib/redis_utils.py b/ddtrace/contrib/redis_utils.py new file mode 100644 index 00000000000..fa605cdeaa8 --- /dev/null +++ b/ddtrace/contrib/redis_utils.py @@ -0,0 +1,82 @@ +from typing import Dict +from typing import List +from typing import Optional +from typing import Union + +from ddtrace.ext import net +from ddtrace.ext import redis as redisx +from ddtrace.internal import core +from ddtrace.internal.utils.formats import stringify_cache_args + + +SINGLE_KEY_COMMANDS = [ + "GET", + "GETDEL", + "GETEX", + "GETRANGE", + "GETSET", + "LINDEX", + "LRANGE", + "RPOP", + "LPOP", + "HGET", + "HGETALL", + "HKEYS", + "HMGET", + "HRANDFIELD", + "HVALS", +] +MULTI_KEY_COMMANDS = ["MGET"] +ROW_RETURNING_COMMANDS = SINGLE_KEY_COMMANDS + MULTI_KEY_COMMANDS + + +def _extract_conn_tags(conn_kwargs): + """Transform redis conn info into dogtrace metas""" + try: + conn_tags = { + net.TARGET_HOST: conn_kwargs["host"], + net.TARGET_PORT: conn_kwargs["port"], + redisx.DB: conn_kwargs.get("db") or 0, + } + client_name = conn_kwargs.get("client_name") + if client_name: + conn_tags[redisx.CLIENT_NAME] = client_name + return conn_tags + except Exception: + return {} + + +def determine_row_count(redis_command: str, result: Optional[Union[List, Dict, str]]) -> int: + empty_results = [b"", [], {}, None] + # result can be an empty list / dict / string + if result not in empty_results: + if redis_command == "MGET": + # only include valid key 
results within count + result = [x for x in result if x not in empty_results] + return len(result) + elif redis_command == "HMGET": + # only include valid key results within count + result = [x for x in result if x not in empty_results] + return 1 if len(result) > 0 else 0 + else: + return 1 + else: + return 0 + + +async def _run_redis_command_async(span, func, args, kwargs): + parsed_command = stringify_cache_args(args) + redis_command = parsed_command.split(" ")[0] + rowcount = None + try: + result = await func(*args, **kwargs) + return result + except Exception: + rowcount = 0 + raise + finally: + if rowcount is None: + rowcount = determine_row_count(redis_command=redis_command, result=result) + if redis_command not in ROW_RETURNING_COMMANDS: + rowcount = None + core.dispatch("redis.async_command.post", [span, rowcount]) diff --git a/ddtrace/contrib/trace_utils_redis.py b/ddtrace/contrib/trace_utils_redis.py deleted file mode 100644 index 88bdc11639c..00000000000 --- a/ddtrace/contrib/trace_utils_redis.py +++ /dev/null @@ -1,184 +0,0 @@ -""" -Some utils used by the dogtrace redis integration -""" -from contextlib import contextmanager - -from ddtrace.constants import ANALYTICS_SAMPLE_RATE_KEY -from ddtrace.constants import SPAN_KIND -from ddtrace.constants import SPAN_MEASURED_KEY -from ddtrace.contrib import trace_utils -from ddtrace.ext import SpanKind -from ddtrace.ext import SpanTypes -from ddtrace.ext import db -from ddtrace.ext import net -from ddtrace.ext import redis as redisx -from ddtrace.internal.constants import COMPONENT -from ddtrace.internal.schema import schematize_cache_operation -from ddtrace.internal.utils.formats import stringify_cache_args - - -format_command_args = stringify_cache_args - -SINGLE_KEY_COMMANDS = [ - "GET", - "GETDEL", - "GETEX", - "GETRANGE", - "GETSET", - "LINDEX", - "LRANGE", - "RPOP", - "LPOP", - "HGET", - "HGETALL", - "HKEYS", - "HMGET", - "HRANDFIELD", - "HVALS", -] -MULTI_KEY_COMMANDS = ["MGET"] -ROW_RETURNING_COMMANDS 
= SINGLE_KEY_COMMANDS + MULTI_KEY_COMMANDS - - -def _extract_conn_tags(conn_kwargs): - """Transform redis conn info into dogtrace metas""" - try: - conn_tags = { - net.TARGET_HOST: conn_kwargs["host"], - net.TARGET_PORT: conn_kwargs["port"], - redisx.DB: conn_kwargs.get("db") or 0, - } - client_name = conn_kwargs.get("client_name") - if client_name: - conn_tags[redisx.CLIENT_NAME] = client_name - return conn_tags - except Exception: - return {} - - -def determine_row_count(redis_command, span, result): - empty_results = [b"", [], {}, None] - # result can be an empty list / dict / string - if result not in empty_results: - if redis_command == "MGET": - # only include valid key results within count - result = [x for x in result if x not in empty_results] - span.set_metric(db.ROWCOUNT, len(result)) - elif redis_command == "HMGET": - # only include valid key results within count - result = [x for x in result if x not in empty_results] - span.set_metric(db.ROWCOUNT, 1 if len(result) > 0 else 0) - else: - span.set_metric(db.ROWCOUNT, 1) - else: - # set count equal to 0 if an empty result - span.set_metric(db.ROWCOUNT, 0) - - -def _run_redis_command(span, func, args, kwargs): - parsed_command = stringify_cache_args(args) - redis_command = parsed_command.split(" ")[0] - try: - result = func(*args, **kwargs) - if redis_command in ROW_RETURNING_COMMANDS: - determine_row_count(redis_command=redis_command, span=span, result=result) - return result - except Exception: - if redis_command in ROW_RETURNING_COMMANDS: - span.set_metric(db.ROWCOUNT, 0) - raise - - -@contextmanager -def _trace_redis_cmd(pin, config_integration, instance, args): - """Create a span for the execute command method and tag it""" - query = stringify_cache_args(args, cmd_max_len=config_integration.cmd_max_length) - with pin.tracer.trace( - schematize_cache_operation(redisx.CMD, cache_provider=redisx.APP), - service=trace_utils.ext_service(pin, config_integration), - span_type=SpanTypes.REDIS, - 
resource=query.split(" ")[0] if config_integration.resource_only_command else query, - ) as span: - span.set_tag_str(SPAN_KIND, SpanKind.CLIENT) - span.set_tag_str(COMPONENT, config_integration.integration_name) - span.set_tag_str(db.SYSTEM, redisx.APP) - span.set_tag(SPAN_MEASURED_KEY) - span_name = schematize_cache_operation(redisx.RAWCMD, cache_provider=redisx.APP) - span.set_tag_str(span_name, query) - if pin.tags: - span.set_tags(pin.tags) - # some redis clients do not have a connection_pool attribute (ex. aioredis v1.3) - if hasattr(instance, "connection_pool"): - span.set_tags(_extract_conn_tags(instance.connection_pool.connection_kwargs)) - span.set_metric(redisx.ARGS_LEN, len(args)) - # set analytics sample rate if enabled - span.set_tag(ANALYTICS_SAMPLE_RATE_KEY, config_integration.get_analytics_sample_rate()) - yield span - - -@contextmanager -def _trace_redis_execute_pipeline(pin, config_integration, cmds, instance, is_cluster=False): - """Create a span for the execute pipeline method and tag it""" - cmd_string = resource = "\n".join(cmds) - if config_integration.resource_only_command: - resource = "\n".join([cmd.split(" ")[0] for cmd in cmds]) - - with pin.tracer.trace( - schematize_cache_operation(redisx.CMD, cache_provider=redisx.APP), - resource=resource, - service=trace_utils.ext_service(pin, config_integration), - span_type=SpanTypes.REDIS, - ) as span: - span.set_tag_str(SPAN_KIND, SpanKind.CLIENT) - span.set_tag_str(COMPONENT, config_integration.integration_name) - span.set_tag_str(db.SYSTEM, redisx.APP) - span.set_tag(SPAN_MEASURED_KEY) - span_name = schematize_cache_operation(redisx.RAWCMD, cache_provider=redisx.APP) - span.set_tag_str(span_name, cmd_string) - if not is_cluster: - span.set_tags(_extract_conn_tags(instance.connection_pool.connection_kwargs)) - span.set_metric(redisx.PIPELINE_LEN, len(instance.command_stack)) - # set analytics sample rate if enabled - span.set_tag(ANALYTICS_SAMPLE_RATE_KEY, 
config_integration.get_analytics_sample_rate()) - # yield the span in case the caller wants to build on span - yield span - - -@contextmanager -def _trace_redis_execute_async_cluster_pipeline(pin, config_integration, cmds, instance): - """Create a span for the execute async cluster pipeline method and tag it""" - cmd_string = resource = "\n".join(cmds) - if config_integration.resource_only_command: - resource = "\n".join([cmd.split(" ")[0] for cmd in cmds]) - - with pin.tracer.trace( - schematize_cache_operation(redisx.CMD, cache_provider=redisx.APP), - resource=resource, - service=trace_utils.ext_service(pin, config_integration), - span_type=SpanTypes.REDIS, - ) as span: - span.set_tag_str(SPAN_KIND, SpanKind.CLIENT) - span.set_tag_str(COMPONENT, config_integration.integration_name) - span.set_tag_str(db.SYSTEM, redisx.APP) - span.set_tag(SPAN_MEASURED_KEY) - span_name = schematize_cache_operation(redisx.RAWCMD, cache_provider=redisx.APP) - span.set_tag_str(span_name, cmd_string) - span.set_metric(redisx.PIPELINE_LEN, len(instance._command_stack)) - # set analytics sample rate if enabled - span.set_tag(ANALYTICS_SAMPLE_RATE_KEY, config_integration.get_analytics_sample_rate()) - # yield the span in case the caller wants to build on span - yield span - - -async def _run_redis_command_async(span, func, args, kwargs): - parsed_command = stringify_cache_args(args) - redis_command = parsed_command.split(" ")[0] - try: - result = await func(*args, **kwargs) - if redis_command in ROW_RETURNING_COMMANDS: - determine_row_count(redis_command=redis_command, span=span, result=result) - return result - except Exception: - if redis_command in ROW_RETURNING_COMMANDS: - span.set_metric(db.ROWCOUNT, 0) - raise diff --git a/ddtrace/contrib/yaaredis/patch.py b/ddtrace/contrib/yaaredis/patch.py index 5166e0d6b82..a23b0d86cc2 100644 --- a/ddtrace/contrib/yaaredis/patch.py +++ b/ddtrace/contrib/yaaredis/patch.py @@ -3,6 +3,9 @@ import yaaredis from ddtrace import config +from 
ddtrace._trace.utils_redis import _trace_redis_cmd +from ddtrace._trace.utils_redis import _trace_redis_execute_pipeline +from ddtrace.contrib.redis_utils import _run_redis_command_async from ddtrace.vendor import wrapt from ...internal.schema import schematize_service_name @@ -11,9 +14,6 @@ from ...internal.utils.formats import stringify_cache_args from ...internal.utils.wrappers import unwrap from ...pin import Pin -from ..trace_utils_redis import _run_redis_command_async -from ..trace_utils_redis import _trace_redis_cmd -from ..trace_utils_redis import _trace_redis_execute_pipeline config._add( diff --git a/ddtrace/internal/constants.py b/ddtrace/internal/constants.py index 566bec75dad..50b8e1280e4 100644 --- a/ddtrace/internal/constants.py +++ b/ddtrace/internal/constants.py @@ -90,7 +90,9 @@ class _PRIORITY_CATEGORY: USER = "user" - RULE = "rule" + RULE_DEF = "rule_default" + RULE_CUSTOMER = "rule_customer" + RULE_DYNAMIC = "rule_dynamic" AUTO = "auto" DEFAULT = "default" @@ -99,7 +101,9 @@ class _PRIORITY_CATEGORY: # used to simplify code that selects sampling priority based on many factors _CATEGORY_TO_PRIORITIES = { _PRIORITY_CATEGORY.USER: (USER_KEEP, USER_REJECT), - _PRIORITY_CATEGORY.RULE: (USER_KEEP, USER_REJECT), + _PRIORITY_CATEGORY.RULE_DEF: (USER_KEEP, USER_REJECT), + _PRIORITY_CATEGORY.RULE_CUSTOMER: (USER_KEEP, USER_REJECT), + _PRIORITY_CATEGORY.RULE_DYNAMIC: (USER_KEEP, USER_REJECT), _PRIORITY_CATEGORY.AUTO: (AUTO_KEEP, AUTO_REJECT), _PRIORITY_CATEGORY.DEFAULT: (AUTO_KEEP, AUTO_REJECT), } diff --git a/ddtrace/internal/flare.py b/ddtrace/internal/flare.py index 9a11223b221..7cf850e7656 100644 --- a/ddtrace/internal/flare.py +++ b/ddtrace/internal/flare.py @@ -1,4 +1,5 @@ import binascii +import dataclasses import io import json import logging @@ -7,9 +8,7 @@ import pathlib import shutil import tarfile -from typing import Any from typing import Dict -from typing import List from typing import Optional from typing import Tuple @@ -19,7 +18,7 @@ from 
ddtrace.internal.utils.http import get_connection -TRACER_FLARE_DIRECTORY = pathlib.Path("tracer_flare") +TRACER_FLARE_DIRECTORY = "tracer_flare" TRACER_FLARE_TAR = pathlib.Path("tracer_flare.tar") TRACER_FLARE_ENDPOINT = "/tracer_flare/v1" TRACER_FLARE_FILE_HANDLER_NAME = "tracer_flare_file_handler" @@ -29,111 +28,99 @@ log = get_logger(__name__) +@dataclasses.dataclass +class FlareSendRequest: + case_id: str + hostname: str + email: str + source: str = "tracer_python" + + class Flare: - def __init__(self, timeout_sec: int = DEFAULT_TIMEOUT_SECONDS): - self.original_log_level = 0 # NOTSET - self.timeout = timeout_sec + def __init__(self, timeout_sec: int = DEFAULT_TIMEOUT_SECONDS, flare_dir: str = TRACER_FLARE_DIRECTORY): + self.original_log_level: int = logging.NOTSET + self.timeout: int = timeout_sec + self.flare_dir: pathlib.Path = pathlib.Path(flare_dir) self.file_handler: Optional[RotatingFileHandler] = None - def prepare(self, configs: List[dict]): + def prepare(self, log_level: str): """ Update configurations to start sending tracer logs to a file to be sent in a flare later. 
""" - if not os.path.exists(TRACER_FLARE_DIRECTORY): - try: - os.makedirs(TRACER_FLARE_DIRECTORY) - log.info("Tracer logs will now be sent to the %s directory", TRACER_FLARE_DIRECTORY) - except Exception as e: - log.error("Failed to create %s directory: %s", TRACER_FLARE_DIRECTORY, e) - return - for agent_config in configs: - # AGENT_CONFIG is currently being used for multiple purposes - # We only want to prepare for a tracer flare if the config name - # starts with 'flare-log-level' - if not agent_config.get("name", "").startswith("flare-log-level"): - return + try: + self.flare_dir.mkdir(exist_ok=True) + except Exception as e: + log.error("Failed to create %s directory: %s", self.flare_dir, e) + return + + flare_log_level_int = logging.getLevelName(log_level) + if type(flare_log_level_int) != int: + raise TypeError("Invalid log level provided: %s", log_level) - # Validate the flare log level - flare_log_level = agent_config.get("config", {}).get("log_level").upper() - flare_log_level_int = logging.getLevelName(flare_log_level) - if type(flare_log_level_int) != int: - raise TypeError("Invalid log level provided: %s", flare_log_level_int) - - ddlogger = get_logger("ddtrace") - pid = os.getpid() - flare_file_path = TRACER_FLARE_DIRECTORY / pathlib.Path(f"tracer_python_{pid}.log") - self.original_log_level = ddlogger.level - - # Set the logger level to the more verbose between original and flare - # We do this valid_original_level check because if the log level is NOTSET, the value is 0 - # which is the minimum value. 
In this case, we just want to use the flare level, but still - # retain the original state as NOTSET/0 - valid_original_level = 100 if self.original_log_level == 0 else self.original_log_level - logger_level = min(valid_original_level, flare_log_level_int) - ddlogger.setLevel(logger_level) - self.file_handler = _add_file_handler( - ddlogger, flare_file_path.__str__(), flare_log_level, TRACER_FLARE_FILE_HANDLER_NAME - ) - - # Create and add config file - self._generate_config_file(pid) - - def send(self, configs: List[Any]): + ddlogger = get_logger("ddtrace") + pid = os.getpid() + flare_file_path = self.flare_dir / f"tracer_python_{pid}.log" + self.original_log_level = ddlogger.level + + # Set the logger level to the more verbose between original and flare + # We do this valid_original_level check because if the log level is NOTSET, the value is 0 + # which is the minimum value. In this case, we just want to use the flare level, but still + # retain the original state as NOTSET/0 + valid_original_level = ( + logging.CRITICAL if self.original_log_level == logging.NOTSET else self.original_log_level + ) + logger_level = min(valid_original_level, flare_log_level_int) + ddlogger.setLevel(logger_level) + self.file_handler = _add_file_handler( + ddlogger, flare_file_path.__str__(), flare_log_level_int, TRACER_FLARE_FILE_HANDLER_NAME + ) + + # Create and add config file + self._generate_config_file(pid) + + def send(self, flare_send_req: FlareSendRequest): """ Revert tracer flare configurations back to original state before sending the flare. 
""" - for agent_task in configs: - # AGENT_TASK is currently being used for multiple purposes - # We only want to generate the tracer flare if the task_type is - # 'tracer_flare' - if type(agent_task) != dict or agent_task.get("task_type") != "tracer_flare": - continue - args = agent_task.get("args", {}) - - self.revert_configs() - - # We only want the flare to be sent once, even if there are - # multiple tracer instances - lock_path = TRACER_FLARE_DIRECTORY / TRACER_FLARE_LOCK - if not os.path.exists(lock_path): - try: - open(lock_path, "w").close() - except Exception as e: - log.error("Failed to create %s file", lock_path) - raise e - data = { - "case_id": args.get("case_id"), - "source": "tracer_python", - "hostname": args.get("hostname"), - "email": args.get("user_handle"), - } - try: - client = get_connection(config._trace_agent_url, timeout=self.timeout) - headers, body = self._generate_payload(data) - client.request("POST", TRACER_FLARE_ENDPOINT, body, headers) - response = client.getresponse() - if response.status == 200: - log.info("Successfully sent the flare") - else: - log.error( - "Upload failed with %s status code:(%s) %s", - response.status, - response.reason, - response.read().decode(), - ) - except Exception as e: - log.error("Failed to send tracer flare") - raise e - finally: - client.close() - # Clean up files regardless of success/failure - self.clean_up_files() - return + self.revert_configs() + + # We only want the flare to be sent once, even if there are + # multiple tracer instances + lock_path = self.flare_dir / TRACER_FLARE_LOCK + if not os.path.exists(lock_path): + try: + open(lock_path, "w").close() + except Exception as e: + log.error("Failed to create %s file", lock_path) + raise e + try: + client = get_connection(config._trace_agent_url, timeout=self.timeout) + headers, body = self._generate_payload(flare_send_req.__dict__) + client.request("POST", TRACER_FLARE_ENDPOINT, body, headers) + response = client.getresponse() + if 
response.status == 200: + log.info("Successfully sent the flare to Zendesk ticket %s", flare_send_req.case_id) + else: + log.error( + "Tracer flare upload to Zendesk ticket %s failed with %s status code:(%s) %s", + flare_send_req.case_id, + response.status, + response.reason, + response.read().decode(), + ) + except Exception as e: + log.error("Failed to send tracer flare to Zendesk ticket %s", flare_send_req.case_id) + raise e + finally: + client.close() + # Clean up files regardless of success/failure + self.clean_up_files() + return def _generate_config_file(self, pid: int): - config_file = TRACER_FLARE_DIRECTORY / pathlib.Path(f"tracer_config_{pid}.json") + config_file = self.flare_dir / f"tracer_config_{pid}.json" try: with open(config_file, "w") as f: tracer_configs = { @@ -162,8 +149,7 @@ def revert_configs(self): def _generate_payload(self, params: Dict[str, str]) -> Tuple[dict, bytes]: tar_stream = io.BytesIO() with tarfile.open(fileobj=tar_stream, mode="w") as tar: - for file_name in os.listdir(TRACER_FLARE_DIRECTORY): - flare_file_name = TRACER_FLARE_DIRECTORY / pathlib.Path(file_name) + for flare_file_name in self.flare_dir.iterdir(): tar.add(flare_file_name) tar_stream.seek(0) @@ -197,6 +183,6 @@ def _get_valid_logger_level(self, flare_log_level: int) -> int: def clean_up_files(self): try: - shutil.rmtree(TRACER_FLARE_DIRECTORY) + shutil.rmtree(self.flare_dir) except Exception as e: log.warning("Failed to clean up tracer flare files: %s", e) diff --git a/ddtrace/internal/packages.py b/ddtrace/internal/packages.py index ddbef347ef4..2d8f1c5fd1e 100644 --- a/ddtrace/internal/packages.py +++ b/ddtrace/internal/packages.py @@ -190,8 +190,8 @@ def _third_party_packages() -> set: return ( set(decompress(read_binary("ddtrace.internal", "third-party.tar.gz")).decode("utf-8").splitlines()) - | tp_config.excludes - ) - tp_config.includes + | tp_config.includes + ) - tp_config.excludes @cached() diff --git a/ddtrace/internal/remoteconfig/client.py 
b/ddtrace/internal/remoteconfig/client.py index 5f6234cbe51..c2768e57bc6 100644 --- a/ddtrace/internal/remoteconfig/client.py +++ b/ddtrace/internal/remoteconfig/client.py @@ -75,6 +75,7 @@ class Capabilities(enum.IntFlag): APM_TRACING_HTTP_HEADER_TAGS = 1 << 14 APM_TRACING_CUSTOM_TAGS = 1 << 15 APM_TRACING_ENABLED = 1 << 19 + APM_TRACING_SAMPLE_RULES = 1 << 29 class RemoteConfigError(Exception): @@ -382,6 +383,7 @@ def _build_payload(self, state): | Capabilities.APM_TRACING_HTTP_HEADER_TAGS | Capabilities.APM_TRACING_CUSTOM_TAGS | Capabilities.APM_TRACING_ENABLED + | Capabilities.APM_TRACING_SAMPLE_RULES ) return dict( client=dict( @@ -544,9 +546,10 @@ def _process_response(self, data): # type: (Mapping[str, Any]) -> None try: payload = self.converter.structure_attrs_fromdict(data, AgentPayload) - except Exception: + except Exception as e: log.debug("invalid agent payload received: %r", data, exc_info=True) - raise RemoteConfigError("invalid agent payload received") + msg = f"invalid agent payload received: {e}" + raise RemoteConfigError(msg) self._validate_config_exists_in_target_paths(payload.client_configs, payload.target_files) diff --git a/ddtrace/internal/sampling.py b/ddtrace/internal/sampling.py index 0d5aa1a2784..267c575e8a5 100644 --- a/ddtrace/internal/sampling.py +++ b/ddtrace/internal/sampling.py @@ -62,6 +62,16 @@ class SamplingMechanism(object): REMOTE_RATE_USER = 6 REMOTE_RATE_DATADOG = 7 SPAN_SAMPLING_RULE = 8 + REMOTE_USER_RULE = 11 + REMOTE_DYNAMIC_RULE = 12 + + +class PriorityCategory(object): + DEFAULT = "default" + AUTO = "auto" + RULE_DEFAULT = "rule_default" + RULE_CUSTOMER = "rule_customer" + RULE_DYNAMIC = "rule_dynamic" # Use regex to validate trace tag value @@ -278,11 +288,17 @@ def is_single_span_sampled(span): def _set_sampling_tags(span, sampled, sample_rate, priority_category): # type: (Span, bool, float, str) -> None mechanism = SamplingMechanism.TRACE_SAMPLING_RULE - if priority_category == "rule": + if priority_category == 
PriorityCategory.RULE_DEFAULT: + span.set_metric(SAMPLING_RULE_DECISION, sample_rate) + if priority_category == PriorityCategory.RULE_CUSTOMER: + span.set_metric(SAMPLING_RULE_DECISION, sample_rate) + mechanism = SamplingMechanism.REMOTE_USER_RULE + if priority_category == PriorityCategory.RULE_DYNAMIC: span.set_metric(SAMPLING_RULE_DECISION, sample_rate) - elif priority_category == "default": + mechanism = SamplingMechanism.REMOTE_DYNAMIC_RULE + elif priority_category == PriorityCategory.DEFAULT: mechanism = SamplingMechanism.DEFAULT - elif priority_category == "auto": + elif priority_category == PriorityCategory.AUTO: mechanism = SamplingMechanism.AGENT_RATE span.set_metric(SAMPLING_AGENT_DECISION, sample_rate) priorities = _CATEGORY_TO_PRIORITIES[priority_category] diff --git a/ddtrace/internal/telemetry/writer.py b/ddtrace/internal/telemetry/writer.py index 1ef8b965351..836daa5da74 100644 --- a/ddtrace/internal/telemetry/writer.py +++ b/ddtrace/internal/telemetry/writer.py @@ -83,7 +83,6 @@ from .constants import TELEMETRY_TRACE_PEER_SERVICE_MAPPING from .constants import TELEMETRY_TRACE_REMOVE_INTEGRATION_SERVICE_NAMES_ENABLED from .constants import TELEMETRY_TRACE_SAMPLING_LIMIT -from .constants import TELEMETRY_TRACE_SAMPLING_RULES from .constants import TELEMETRY_TRACE_SPAN_ATTRIBUTE_SCHEMA from .constants import TELEMETRY_TRACE_WRITER_BUFFER_SIZE_BYTES from .constants import TELEMETRY_TRACE_WRITER_INTERVAL_SECONDS @@ -386,6 +385,9 @@ def _telemetry_entry(self, cfg_name: str) -> Tuple[str, str, _ConfigSource]: elif cfg_name == "_trace_sample_rate": name = "trace_sample_rate" value = str(item.value()) + elif cfg_name == "_trace_sampling_rules": + name = "trace_sampling_rules" + value = str(item.value()) elif cfg_name == "logs_injection": name = "logs_injection_enabled" value = "true" if item.value() else "false" @@ -428,6 +430,7 @@ def _app_started_event(self, register_app_shutdown=True): self._telemetry_entry("_sca_enabled"), 
self._telemetry_entry("_dsm_enabled"), self._telemetry_entry("_trace_sample_rate"), + self._telemetry_entry("_trace_sampling_rules"), self._telemetry_entry("logs_injection"), self._telemetry_entry("trace_http_header_tags"), self._telemetry_entry("tags"), @@ -462,7 +465,6 @@ def _app_started_event(self, register_app_shutdown=True): (TELEMETRY_TRACE_SAMPLING_LIMIT, config._trace_rate_limit, "unknown"), (TELEMETRY_SPAN_SAMPLING_RULES, config._sampling_rules, "unknown"), (TELEMETRY_SPAN_SAMPLING_RULES_FILE, config._sampling_rules_file, "unknown"), - (TELEMETRY_TRACE_SAMPLING_RULES, config._trace_sampling_rules, "unknown"), (TELEMETRY_PRIORITY_SAMPLING, config._priority_sampling, "unknown"), (TELEMETRY_PARTIAL_FLUSH_ENABLED, config._partial_flush_enabled, "unknown"), (TELEMETRY_PARTIAL_FLUSH_MIN_SPANS, config._partial_flush_min_spans, "unknown"), @@ -714,10 +716,10 @@ def _generate_metrics_event(self, namespace_metrics): elif payload_type == TELEMETRY_TYPE_GENERATE_METRICS: self.add_event(payload, TELEMETRY_TYPE_GENERATE_METRICS) - def _generate_logs_event(self, payload): + def _generate_logs_event(self, logs): # type: (Set[Dict[str, str]]) -> None log.debug("%s request payload", TELEMETRY_TYPE_LOGS) - self.add_event(list(payload), TELEMETRY_TYPE_LOGS) + self.add_event({"logs": list(logs)}, TELEMETRY_TYPE_LOGS) def periodic(self, force_flush=False): namespace_metrics = self._namespace.flush() diff --git a/ddtrace/llmobs/_llmobs.py b/ddtrace/llmobs/_llmobs.py index 7c2b73d0bd8..b7cf05d8beb 100644 --- a/ddtrace/llmobs/_llmobs.py +++ b/ddtrace/llmobs/_llmobs.py @@ -27,7 +27,8 @@ from ddtrace.llmobs._trace_processor import LLMObsTraceProcessor from ddtrace.llmobs._utils import _get_ml_app from ddtrace.llmobs._utils import _get_session_id -from ddtrace.llmobs._writer import LLMObsWriter +from ddtrace.llmobs._writer import LLMObsEvalMetricWriter +from ddtrace.llmobs._writer import LLMObsSpanWriter from ddtrace.llmobs.utils import Messages @@ -41,18 +42,25 @@ class 
LLMObs(Service): def __init__(self, tracer=None): super(LLMObs, self).__init__() self.tracer = tracer or ddtrace.tracer - self._llmobs_writer = LLMObsWriter( + self._llmobs_span_writer = LLMObsSpanWriter( site=config._dd_site, api_key=config._dd_api_key, interval=float(os.getenv("_DD_LLMOBS_WRITER_INTERVAL", 1.0)), timeout=float(os.getenv("_DD_LLMOBS_WRITER_TIMEOUT", 2.0)), ) - self._llmobs_writer.start() + self._llmobs_eval_metric_writer = LLMObsEvalMetricWriter( + site=config._dd_site, + api_key=config._dd_api_key, + interval=float(os.getenv("_DD_LLMOBS_WRITER_INTERVAL", 1.0)), + timeout=float(os.getenv("_DD_LLMOBS_WRITER_TIMEOUT", 2.0)), + ) + self._llmobs_span_writer.start() + self._llmobs_eval_metric_writer.start() def _start_service(self) -> None: tracer_filters = self.tracer._filters if not any(isinstance(tracer_filter, LLMObsTraceProcessor) for tracer_filter in tracer_filters): - tracer_filters += [LLMObsTraceProcessor(self._llmobs_writer)] + tracer_filters += [LLMObsTraceProcessor(self._llmobs_span_writer)] self.tracer.configure(settings={"FILTERS": tracer_filters}) def _stop_service(self) -> None: @@ -319,21 +327,21 @@ def _tag_llm_io(cls, span, input_messages=None, output_messages=None): Will be mapped to span's `meta.{input,output}.messages` fields. 
""" if input_messages is not None: - if not isinstance(input_messages, Messages): - input_messages = Messages(input_messages) try: + if not isinstance(input_messages, Messages): + input_messages = Messages(input_messages) if input_messages.messages: span.set_tag_str(INPUT_MESSAGES, json.dumps(input_messages.messages)) except (TypeError, AttributeError): - log.warning("Failed to parse input messages.") + log.warning("Failed to parse input messages.", exc_info=True) if output_messages is not None: - if not isinstance(output_messages, Messages): - output_messages = Messages(output_messages) try: + if not isinstance(output_messages, Messages): + output_messages = Messages(output_messages) if output_messages.messages: span.set_tag_str(OUTPUT_MESSAGES, json.dumps(output_messages.messages)) except (TypeError, AttributeError): - log.warning("Failed to parse output messages.") + log.warning("Failed to parse output messages.", exc_info=True) @classmethod def _tag_text_io(cls, span, input_value=None, output_value=None): diff --git a/ddtrace/llmobs/_trace_processor.py b/ddtrace/llmobs/_trace_processor.py index 45728ffad3c..f95b2637be0 100644 --- a/ddtrace/llmobs/_trace_processor.py +++ b/ddtrace/llmobs/_trace_processor.py @@ -41,8 +41,8 @@ class LLMObsTraceProcessor(TraceProcessor): Processor that extracts LLM-type spans in a trace to submit as separate LLMObs span events to LLM Observability. 
""" - def __init__(self, llmobs_writer): - self._writer = llmobs_writer + def __init__(self, llmobs_span_writer): + self._span_writer = llmobs_span_writer self._no_apm_traces = asbool(os.getenv("DD_LLMOBS_NO_APM", False)) def process_trace(self, trace: List[Span]) -> Optional[List[Span]]: @@ -57,7 +57,7 @@ def submit_llmobs_span(self, span: Span) -> None: """Generate and submit an LLMObs span event to be sent to LLMObs.""" try: span_event = self._llmobs_span_event(span) - self._writer.enqueue(span_event) + self._span_writer.enqueue(span_event) except (KeyError, TypeError): log.error("Error generating LLMObs span event for span %s, likely due to malformed span", span) diff --git a/ddtrace/llmobs/_writer.py b/ddtrace/llmobs/_writer.py index 9518191d21e..8380f861f0c 100644 --- a/ddtrace/llmobs/_writer.py +++ b/ddtrace/llmobs/_writer.py @@ -3,6 +3,7 @@ from typing import Any from typing import Dict from typing import List +from typing import Union # TypedDict was added to typing in python 3.8 @@ -21,7 +22,7 @@ logger = get_logger(__name__) -class LLMObsEvent(TypedDict): +class LLMObsSpanEvent(TypedDict): span_id: str trace_id: str parent_id: str @@ -37,40 +38,49 @@ class LLMObsEvent(TypedDict): metrics: Dict[str, Any] -class LLMObsWriter(PeriodicService): - """Writer to the Datadog LLMObs intake.""" +class LLMObsEvaluationMetricEvent(TypedDict, total=False): + span_id: str + trace_id: str + metric_type: str + label: str + categorical_value: str + numerical_value: float + score_value: float + + +class BaseLLMObsWriter(PeriodicService): + """Base writer class for submitting data to Datadog LLMObs endpoints.""" def __init__(self, site: str, api_key: str, interval: float, timeout: float) -> None: - super(LLMObsWriter, self).__init__(interval=interval) + super(BaseLLMObsWriter, self).__init__(interval=interval) self._lock = forksafe.RLock() - self._buffer = [] # type: List[LLMObsEvent] + self._buffer = [] # type: List[Union[LLMObsSpanEvent, LLMObsEvaluationMetricEvent]] 
self._buffer_limit = 1000 self._timeout = timeout # type: float self._api_key = api_key or "" # type: str - self._endpoint = "/api/v2/llmobs" # type: str + self._endpoint = "" # type: str self._site = site # type: str - self._intake = "llmobs-intake.%s" % self._site # type: str + self._intake = "" # type: str self._headers = {"DD-API-KEY": self._api_key, "Content-Type": "application/json"} + self._event_type = "" # type: str def start(self, *args, **kwargs): - super(LLMObsWriter, self).start() - logger.debug("started llmobs writer to %r", self._url) + super(BaseLLMObsWriter, self).start() + logger.debug("started %r to %r", (self.__class__.__name__, self._url)) atexit.register(self.on_shutdown) - def enqueue(self, event: LLMObsEvent) -> None: + def on_shutdown(self): + self.periodic() + + def _enqueue(self, event: Union[LLMObsSpanEvent, LLMObsEvaluationMetricEvent]) -> None: with self._lock: if len(self._buffer) >= self._buffer_limit: - logger.warning("LLMobs event buffer full (limit is %d), dropping record", self._buffer_limit) + logger.warning( + "%r event buffer full (limit is %d), dropping event", (self.__class__.__name__, self._buffer_limit) + ) return self._buffer.append(event) - def on_shutdown(self): - self.periodic() - - @property - def _url(self) -> str: - return "https://%s%s" % (self._intake, self._endpoint) - def periodic(self) -> None: with self._lock: if not self._buffer: @@ -78,11 +88,11 @@ def periodic(self) -> None: events = self._buffer self._buffer = [] - data = {"ml_obs": {"stage": "raw", "type": "span", "spans": events}} + data = self._data(events) try: enc_llm_events = json.dumps(data) except TypeError: - logger.error("failed to encode %d LLMObs events", len(events), exc_info=True) + logger.error("failed to encode %d LLMObs %s events", (len(events), self._event_type), exc_info=True) return conn = httplib.HTTPSConnection(self._intake, 443, timeout=self._timeout) try: @@ -90,15 +100,61 @@ def periodic(self) -> None: resp = 
get_connection_response(conn) if resp.status >= 300: logger.error( - "failed to send %d LLMObs events to %r, got response code %r, status: %r", - len(events), - self._url, - resp.status, - resp.read(), + "failed to send %d LLMObs %s events to %s, got response code %d, status: %s", + ( + len(events), + self._event_type, + self._url, + resp.status, + resp.read(), + ), ) else: - logger.debug("sent %d LLMObs events to %r", len(events), self._url) + logger.debug("sent %d LLMObs %s events to %s", (len(events), self._event_type, self._url)) except Exception: - logger.error("failed to send %d LLMObs events to %r", len(events), self._intake, exc_info=True) + logger.error( + "failed to send %d LLMObs %s events to %s", (len(events), self._event_type, self._intake), exc_info=True + ) finally: conn.close() + + @property + def _url(self) -> str: + return "https://%s%s" % (self._intake, self._endpoint) + + def _data(self, events: List[Any]) -> Dict[str, Any]: + raise NotImplementedError + + +class LLMObsSpanWriter(BaseLLMObsWriter): + """Writer to the Datadog LLMObs Span Event Endpoint.""" + + def __init__(self, site: str, api_key: str, interval: float, timeout: float) -> None: + super(LLMObsSpanWriter, self).__init__(site, api_key, interval, timeout) + self._event_type = "span" + self._buffer = [] + self._endpoint = "/api/v2/llmobs" # type: str + self._intake = "llmobs-intake.%s" % self._site # type: str + + def enqueue(self, event: LLMObsSpanEvent) -> None: + self._enqueue(event) + + def _data(self, events: List[LLMObsSpanEvent]) -> Dict[str, Any]: + return {"ml_obs": {"stage": "raw", "type": "span", "spans": events}} + + +class LLMObsEvalMetricWriter(BaseLLMObsWriter): + """Writer to the Datadog LLMObs Custom Eval Metrics Endpoint.""" + + def __init__(self, site: str, api_key: str, interval: float, timeout: float) -> None: + super(LLMObsEvalMetricWriter, self).__init__(site, api_key, interval, timeout) + self._event_type = "evaluation_metric" + self._buffer = [] + 
self._endpoint = "/api/unstable/llm-obs/v1/eval-metric" + self._intake = "api.%s" % self._site # type: str + + def enqueue(self, event: LLMObsEvaluationMetricEvent) -> None: + self._enqueue(event) + + def _data(self, events: List[LLMObsEvaluationMetricEvent]) -> Dict[str, Any]: + return {"data": {"type": "evaluation_metric", "attributes": {"metrics": events}}} diff --git a/ddtrace/llmobs/utils.py b/ddtrace/llmobs/utils.py index 9304b6a7aa5..997a26c0b85 100644 --- a/ddtrace/llmobs/utils.py +++ b/ddtrace/llmobs/utils.py @@ -23,20 +23,19 @@ def __init__(self, messages: Union[List[Dict[str, str]], Dict[str, str], str]): self.messages = [] if not isinstance(messages, list): messages = [messages] # type: ignore[list-item] - try: - for message in messages: - if isinstance(message, str): - self.messages.append(Message(content=message)) - continue - elif not isinstance(message, dict): - log.warning("messages must be a string, dictionary, or list of dictionaries.") - continue - if "role" not in message: - self.messages.append(Message(content=message.get("content", ""))) - continue - self.messages.append(Message(content=message.get("content", ""), role=message.get("role", ""))) - except (TypeError, ValueError, AttributeError): - log.warning( - "Cannot format provided messages. The messages argument must be a string, a dictionary, or a " - "list of dictionaries, or construct messages directly using the ``ddtrace.llmobs.utils.Message`` class." 
- ) + for message in messages: + if isinstance(message, str): + self.messages.append(Message(content=message)) + continue + elif not isinstance(message, dict): + raise TypeError("messages must be a string, dictionary, or list of dictionaries.") + content = message.get("content", "") + role = message.get("role") + if not isinstance(content, str): + raise TypeError("Message content must be a string.") + if not role: + self.messages.append(Message(content=content)) + continue + if not isinstance(role, str): + raise TypeError("Message role must be a string, and one of .") + self.messages.append(Message(content=content, role=role)) diff --git a/ddtrace/opentelemetry/__init__.py b/ddtrace/opentelemetry/__init__.py index 816400e3129..289b296c767 100644 --- a/ddtrace/opentelemetry/__init__.py +++ b/ddtrace/opentelemetry/__init__.py @@ -38,10 +38,10 @@ Datadog and OpenTelemetry APIs can be used interchangeably:: # Sample Usage - import opentelemetry + from opentelemetry import trace import ddtrace - oteltracer = opentelemetry.trace.get_tracer(__name__) + oteltracer = trace.get_tracer(__name__) with oteltracer.start_as_current_span("otel-span") as parent_span: parent_span.set_attribute("otel_key", "otel_val") diff --git a/ddtrace/sampler.py b/ddtrace/sampler.py index 69cc58c73d7..fe558c1f426 100644 --- a/ddtrace/sampler.py +++ b/ddtrace/sampler.py @@ -23,6 +23,8 @@ from .settings import _config as ddconfig +PROVENANCE_ORDER = ["customer", "dynamic", "default"] + try: from json.decoder import JSONDecodeError except ImportError: @@ -158,7 +160,7 @@ def _choose_priority_category(self, sampler): elif isinstance(sampler, _AgentRateSampler): return _PRIORITY_CATEGORY.AUTO else: - return _PRIORITY_CATEGORY.RULE + return _PRIORITY_CATEGORY.RULE_DEF def _make_sampling_decision(self, span): # type: (Span) -> Tuple[bool, BaseSampler] @@ -204,7 +206,7 @@ class DatadogSampler(RateByServiceSampler): per second. 
""" - __slots__ = ("limiter", "rules") + __slots__ = ("limiter", "rules", "default_sample_rate") NO_RATE_LIMIT = -1 # deprecate and remove the DEFAULT_RATE_LIMIT field from DatadogSampler @@ -228,7 +230,7 @@ def __init__( """ # Use default sample rate of 1.0 super(DatadogSampler, self).__init__() - + self.default_sample_rate = default_sample_rate if default_sample_rate is None: if ddconfig._get_source("_trace_sample_rate") != "default": default_sample_rate = float(ddconfig._trace_sample_rate) @@ -239,7 +241,7 @@ def __init__( if rules is None: env_sampling_rules = ddconfig._trace_sampling_rules if env_sampling_rules: - rules = self._parse_rules_from_env_variable(env_sampling_rules) + rules = self._parse_rules_from_str(env_sampling_rules) else: rules = [] self.rules = rules @@ -268,7 +270,8 @@ def __str__(self): __repr__ = __str__ - def _parse_rules_from_env_variable(self, rules): + @staticmethod + def _parse_rules_from_str(rules): # type: (str) -> List[SamplingRule] sampling_rules = [] try: @@ -283,13 +286,22 @@ def _parse_rules_from_env_variable(self, rules): name = rule.get("name", SamplingRule.NO_RULE) resource = rule.get("resource", SamplingRule.NO_RULE) tags = rule.get("tags", SamplingRule.NO_RULE) + provenance = rule.get("provenance", "default") try: sampling_rule = SamplingRule( - sample_rate=sample_rate, service=service, name=name, resource=resource, tags=tags + sample_rate=sample_rate, + service=service, + name=name, + resource=resource, + tags=tags, + provenance=provenance, ) except ValueError as e: raise ValueError("Error creating sampling rule {}: {}".format(json.dumps(rule), e)) sampling_rules.append(sampling_rule) + + # Sort the sampling_rules list using a lambda function as the key + sampling_rules = sorted(sampling_rules, key=lambda rule: PROVENANCE_ORDER.index(rule.provenance)) return sampling_rules def sample(self, span): @@ -320,7 +332,13 @@ def sample(self, span): def _choose_priority_category_with_rule(self, rule, sampler): # type: 
(Optional[SamplingRule], BaseSampler) -> str if rule: - return _PRIORITY_CATEGORY.RULE + provenance = rule.provenance + if provenance == "customer": + return _PRIORITY_CATEGORY.RULE_CUSTOMER + if provenance == "dynamic": + return _PRIORITY_CATEGORY.RULE_DYNAMIC + return _PRIORITY_CATEGORY.RULE_DEF + if self.limiter._has_been_configured: return _PRIORITY_CATEGORY.USER return super(DatadogSampler, self)._choose_priority_category(sampler) diff --git a/ddtrace/sampling_rule.py b/ddtrace/sampling_rule.py index aecf03de5ab..72ab1574277 100644 --- a/ddtrace/sampling_rule.py +++ b/ddtrace/sampling_rule.py @@ -34,6 +34,7 @@ def __init__( name=NO_RULE, # type: Any resource=NO_RULE, # type: Any tags=NO_RULE, # type: Any + provenance="default", # type: str ): # type: (...) -> None """ @@ -83,6 +84,7 @@ def __init__( self.service = self.choose_matcher(service) self.name = self.choose_matcher(name) self.resource = self.choose_matcher(resource) + self.provenance = provenance @property def sample_rate(self): @@ -236,13 +238,14 @@ def choose_matcher(self, prop): return GlobMatcher(prop) if prop != SamplingRule.NO_RULE else SamplingRule.NO_RULE def __repr__(self): - return "{}(sample_rate={!r}, service={!r}, name={!r}, resource={!r}, tags={!r})".format( + return "{}(sample_rate={!r}, service={!r}, name={!r}, resource={!r}, tags={!r}, provenance={!r})".format( self.__class__.__name__, self.sample_rate, self._no_rule_or_self(self.service), self._no_rule_or_self(self.name), self._no_rule_or_self(self.resource), self._no_rule_or_self(self.tags), + self.provenance, ) __str__ = __repr__ diff --git a/ddtrace/settings/asm.py b/ddtrace/settings/asm.py index fc15d95c13b..91fb9417807 100644 --- a/ddtrace/settings/asm.py +++ b/ddtrace/settings/asm.py @@ -92,8 +92,7 @@ class ASMConfig(Env): _deduplication_enabled = Env.var(bool, "_DD_APPSEC_DEDUPLICATION_ENABLED", default=True) # default will be set to True once the feature is GA. 
For now it's always False - # _ep_enabled = Env.var(bool, EXPLOIT_PREVENTION.EP_ENABLED, default=False) - _ep_enabled = False + _ep_enabled = Env.var(bool, EXPLOIT_PREVENTION.EP_ENABLED, default=False) _ep_stack_trace_enabled = Env.var(bool, EXPLOIT_PREVENTION.STACK_TRACE_ENABLED, default=True) # for max_stack_traces, 0 == unlimited _ep_max_stack_traces = Env.var( diff --git a/ddtrace/settings/config.py b/ddtrace/settings/config.py index c49abc83bae..fc0083d6222 100644 --- a/ddtrace/settings/config.py +++ b/ddtrace/settings/config.py @@ -1,4 +1,5 @@ from copy import deepcopy +import json import os import re import sys @@ -282,6 +283,11 @@ def _default_config(): default=1.0, envs=[("DD_TRACE_SAMPLE_RATE", float)], ), + "_trace_sampling_rules": _ConfigItem( + name="trace_sampling_rules", + default=lambda: "", + envs=[("DD_TRACE_SAMPLING_RULES", str)], + ), "logs_injection": _ConfigItem( name="logs_injection", default=False, @@ -384,7 +390,6 @@ def __init__(self): self._startup_logs_enabled = asbool(os.getenv("DD_TRACE_STARTUP_LOGS", False)) self._trace_rate_limit = int(os.getenv("DD_TRACE_RATE_LIMIT", default=DEFAULT_SAMPLING_RATE_LIMIT)) - self._trace_sampling_rules = os.getenv("DD_TRACE_SAMPLING_RULES") self._partial_flush_enabled = asbool(os.getenv("DD_TRACE_PARTIAL_FLUSH_ENABLED", default=True)) self._partial_flush_min_spans = int(os.getenv("DD_TRACE_PARTIAL_FLUSH_MIN_SPANS", default=300)) self._priority_sampling = asbool(os.getenv("DD_PRIORITY_SAMPLING", default=True)) @@ -562,7 +567,6 @@ def __init__(self): def __getattr__(self, name) -> Any: if name in self._config: return self._config[name].value() - if name not in self._integration_configs: self._integration_configs[name] = IntegrationConfig(self, name) @@ -753,6 +757,14 @@ def _handle_remoteconfig(self, data, test_tracer=None): if "tracing_sampling_rate" in lib_config: base_rc_config["_trace_sample_rate"] = lib_config["tracing_sampling_rate"] + if "tracing_sampling_rules" in lib_config: + 
trace_sampling_rules = lib_config["tracing_sampling_rules"] + if trace_sampling_rules: + # returns None if no rules + trace_sampling_rules = self.convert_rc_trace_sampling_rules(trace_sampling_rules) + if trace_sampling_rules: + base_rc_config["_trace_sampling_rules"] = trace_sampling_rules + if "log_injection_enabled" in lib_config: base_rc_config["logs_injection"] = lib_config["log_injection_enabled"] @@ -802,3 +814,60 @@ def enable_remote_configuration(self): remoteconfig_poller.register("APM_TRACING", remoteconfig_pubsub) remoteconfig_poller.register("AGENT_CONFIG", remoteconfig_pubsub) remoteconfig_poller.register("AGENT_TASK", remoteconfig_pubsub) + + def _remove_invalid_rules(self, rc_rules: List) -> List: + """Remove invalid sampling rules from the given list""" + # iterate over a snapshot: removing from rc_rules while looping over it directly would skip the entry after each removed rule + for rule in list(rc_rules): + if ( + ("service" not in rule and "name" not in rule and "resource" not in rule and "tags" not in rule) + or "sample_rate" not in rule + or "provenance" not in rule + ): + log.debug("Invalid sampling rule from remoteconfig found, rule will be removed: %s", rule) + rc_rules.remove(rule) + + return rc_rules + + def _tags_to_dict(self, tags: List[Dict]): + """ + Converts a list of tag dictionaries to a single dictionary.
+ """ + if isinstance(tags, list): + return {tag["key"]: tag["value_glob"] for tag in tags} + return tags + + def convert_rc_trace_sampling_rules(self, rc_rules: List[Dict[str, Any]]) -> Optional[str]: + """Example of an incoming rule: + [ + { + "service": "my-service", + "name": "web.request", + "resource": "*", + "provenance": "customer", + "sample_rate": 1.0, + "tags": [ + { + "key": "care_about", + "value_glob": "yes" + }, + { + "key": "region", + "value_glob": "us-*" + } + ] + } + ] + + Example of a converted rule: + '[{"sample_rate":1.0,"service":"my-service","resource":"*","name":"web.request","tags":{"care_about":"yes","region":"us-*"},"provenance":"customer"}]' + """ + rc_rules = self._remove_invalid_rules(rc_rules) + for rule in rc_rules: + tags = rule.get("tags") + if tags: + rule["tags"] = self._tags_to_dict(tags) + if rc_rules: + return json.dumps(rc_rules) + else: + return None diff --git a/ddtrace/settings/third_party.py b/ddtrace/settings/third_party.py index 83a7c1ef567..3416be1d524 100644 --- a/ddtrace/settings/third_party.py +++ b/ddtrace/settings/third_party.py @@ -7,14 +7,14 @@ class ThirdPartyDetectionConfig(En): excludes = En.v( set, "excludes", - help="Additional packages to treat as third-party", + help="List of packages that should not be treated as third-party", help_type="List", default=set(), ) includes = En.v( set, "includes", - help="List of packages that should not be treated as third-party", + help="Additional packages to treat as third-party", help_type="List", default=set(), ) diff --git a/docker-compose.yml b/docker-compose.yml index f5a7060507e..dd8cc79c6cf 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -138,7 +138,7 @@ services: - ENABLED_CHECKS=trace_content_length,trace_stall,meta_tracer_version_header,trace_count_header,trace_peer_service,trace_dd_service -
SNAPSHOT_IGNORED_ATTRS=span_id,trace_id,parent_id,duration,start,metrics.system.pid,metrics.system.process_id,metrics.process_id,meta.runtime-id,meta._dd.p.tid,meta.pathway.hash,metrics._dd.tracer_kr,meta._dd.parent_id vertica: - image: sumitchawla/vertica + image: vertica/vertica-ce environment: - VP_TEST_USER=dbadmin - VP_TEST_PASSWORD=abc123 @@ -195,7 +195,7 @@ services: - DD_REMOTE_CONFIGURATION_ENABLED=true - DD_AGENT_PORT=8126 - DD_TRACE_AGENT_URL=http://testagent:8126 - - _DD_APPSEC_DEDUPLICATION_ENABLED=false + - _DD_APPSEC_DEDUPLICATION_ENABLED=false volumes: ddagent: diff --git a/releasenotes/notes/exploit_prevention_feature_LFI_SSRF-5be3b699341eadb1.yaml b/releasenotes/notes/exploit_prevention_feature_LFI_SSRF-5be3b699341eadb1.yaml new file mode 100644 index 00000000000..7b1d0648f84 --- /dev/null +++ b/releasenotes/notes/exploit_prevention_feature_LFI_SSRF-5be3b699341eadb1.yaml @@ -0,0 +1,7 @@ +--- +features: + - | + ASM: This introduces Exploit Prevention for Application Security Management for + LFI (using file opening with standard CPython API) and SSRF (using either standard CPython API urllib or + the requests package available on pypi). + By default, the feature is disabled, but it can be enabled with `DD_APPSEC_RASP_ENABLED=true` in the environment. diff --git a/releasenotes/notes/feat-bedrock-embedding-d44ac603bdb83a7b.yaml b/releasenotes/notes/feat-bedrock-embedding-d44ac603bdb83a7b.yaml new file mode 100644 index 00000000000..db857b82469 --- /dev/null +++ b/releasenotes/notes/feat-bedrock-embedding-d44ac603bdb83a7b.yaml @@ -0,0 +1,4 @@ +--- +features: + - | + botocore: This introduces tracing support for bedrock-runtime embedding operations. 
diff --git a/releasenotes/notes/fix-fstring-zeropadding-e8e463a4d8623040.yaml b/releasenotes/notes/fix-fstring-zeropadding-e8e463a4d8623040.yaml new file mode 100644 index 00000000000..99f4706de38 --- /dev/null +++ b/releasenotes/notes/fix-fstring-zeropadding-e8e463a4d8623040.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + Code Security: This fix solves an issue with fstrings where formatting was not applied to int parameters diff --git a/scripts/cppcheck.sh b/scripts/cppcheck.sh index 920d1c060fd..d96809c0cf4 100755 --- a/scripts/cppcheck.sh +++ b/scripts/cppcheck.sh @@ -1,5 +1,5 @@ #!/bin/bash set -e -cppcheck --error-exitcode=1 --std=c++17 --language=c++ --force \ +cppcheck --inline-suppr --error-exitcode=1 --std=c++17 --language=c++ --force \ $(git ls-files '*.c' '*.cpp' '*.h' '*.hpp' '*.cc' '*.hh' | grep -E -v '^(ddtrace/(vendor|internal)|ddtrace/appsec/_iast/_taint_tracking/_vendor)/') diff --git a/tests/.suitespec.json b/tests/.suitespec.json index f20f33f57e6..7e6f1512ec4 100644 --- a/tests/.suitespec.json +++ b/tests/.suitespec.json @@ -139,7 +139,8 @@ "ddtrace/contrib/redis/*", "ddtrace/contrib/aredis/*", "ddtrace/contrib/yaaredis/*", - "ddtrace/contrib/trace_utils_redis.py", + "ddtrace/_trace/utils_redis.py", + "ddtrace/contrib/redis_utils.py", "ddtrace/ext/redis.py" ], "mongo": [ diff --git a/tests/appsec/contrib_appsec/utils.py b/tests/appsec/contrib_appsec/utils.py index dae02eb8f21..1a193b47a04 100644 --- a/tests/appsec/contrib_appsec/utils.py +++ b/tests/appsec/contrib_appsec/utils.py @@ -1186,8 +1186,11 @@ def test_stream_response( def test_exploit_prevention( self, interface, root_span, get_tag, asm_enabled, ep_enabled, endpoint, parameters, rule, top_functions ): + from unittest.mock import patch as mock_patch + from ddtrace.appsec._common_module_patches import patch_common_modules from ddtrace.appsec._common_module_patches import unpatch_common_modules + from ddtrace.appsec._metrics import DDWAF_VERSION from ddtrace.contrib.requests import patch as 
patch_requests from ddtrace.contrib.requests import unpatch as unpatch_requests from ddtrace.ext import http @@ -1196,7 +1199,7 @@ def test_exploit_prevention( patch_requests() with override_global_config(dict(_asm_enabled=asm_enabled, _ep_enabled=ep_enabled)), override_env( dict(DD_APPSEC_RULES=rules.RULES_EXPLOIT_PREVENTION) - ): + ), mock_patch("ddtrace.internal.telemetry.metrics_namespaces.MetricNamespace.add_metric") as mocked: patch_common_modules() self.update_tracer(interface) response = interface.client.get(f"/rasp/{endpoint}/?{parameters}") @@ -1212,6 +1215,20 @@ def test_exploit_prevention( assert any( function.endswith(top_function) for top_function in top_functions ), f"unknown top function {function}" + # assert mocked.call_args_list == [] + telemetry_calls = { + (c.__name__, f"{ns}.{nm}", t): v for (c, ns, nm, v, t), _ in mocked.call_args_list + } + assert ( + "CountMetric", + "appsec.rasp.rule.match", + (("rule_type", endpoint), ("waf_version", DDWAF_VERSION)), + ) in telemetry_calls + assert ( + "CountMetric", + "appsec.rasp.rule.eval", + (("rule_type", endpoint), ("waf_version", DDWAF_VERSION)), + ) in telemetry_calls else: assert get_triggers(root_span()) is None assert self.check_for_stack_trace(root_span) == [] diff --git a/tests/appsec/iast/aspects/test_aspect_helpers.py b/tests/appsec/iast/aspects/test_aspect_helpers.py index d261980a7b0..7e8a5a41230 100644 --- a/tests/appsec/iast/aspects/test_aspect_helpers.py +++ b/tests/appsec/iast/aspects/test_aspect_helpers.py @@ -1,3 +1,5 @@ +import os + import pytest from ddtrace.appsec._iast._taint_tracking import OriginType @@ -7,6 +9,7 @@ from ddtrace.appsec._iast._taint_tracking import common_replace from ddtrace.appsec._iast._taint_tracking import get_ranges from ddtrace.appsec._iast._taint_tracking import set_ranges +from ddtrace.appsec._iast._taint_tracking import set_ranges_on_splitted from ddtrace.appsec._iast._taint_tracking.aspects import _convert_escaped_text_to_tainted_text @@ -69,8 +72,8 
@@ def test_common_replace_tainted_bytearray(): assert get_ranges(s2) == [_RANGE1, _RANGE2] -def _build_sample_range(start, end, name): # type: (int, int) -> TaintRange - return TaintRange(start, end, Source(name, "sample_value", OriginType.PARAMETER)) +def _build_sample_range(start, length, name): # type: (int, int, str) -> TaintRange + return TaintRange(start, length, Source(name, "sample_value", OriginType.PARAMETER)) def test_as_formatted_evidence(): # type: () -> None @@ -105,3 +108,339 @@ def test_as_formatted_evidence_convert_escaped_text_to_tainted_text(): # type: as_formatted_evidence(s, tag_mapping_function=TagMappingMode.Mapper) == ":+-<1750328947>abcde<1750328947>-+:fgh" ) assert _convert_escaped_text_to_tainted_text(":+-<1750328947>abcde<1750328947>-+:fgh", [ranges]) == "abcdefgh" + + +def test_set_ranges_on_splitted_str() -> None: + s = "abc|efgh" + range1 = _build_sample_range(0, 2, "first") + range2 = _build_sample_range(4, 2, "second") + set_ranges(s, (range1, range2)) + ranges = get_ranges(s) + assert ranges + + parts = s.split("|") + assert set_ranges_on_splitted(s, ranges, parts) + assert get_ranges(parts[0]) == [TaintRange(0, 2, Source("first", "sample_value", OriginType.PARAMETER))] + assert get_ranges(parts[1]) == [TaintRange(0, 2, Source("second", "sample_value", OriginType.PARAMETER))] + + +def test_set_ranges_on_splitted_rsplit() -> None: + s = "abc|efgh|jkl" + range1 = _build_sample_range(0, 2, s[0:2]) + range2 = _build_sample_range(4, 2, s[4:6]) + range3 = _build_sample_range(9, 3, s[9:12]) + set_ranges(s, (range1, range2, range3)) + ranges = get_ranges(s) + assert ranges + + parts = s.rsplit("|", 1) + assert parts == ["abc|efgh", "jkl"] + assert set_ranges_on_splitted(s, ranges, parts) + assert get_ranges(parts[0]) == [ + TaintRange(0, 2, Source("ab", "sample_value", OriginType.PARAMETER)), + TaintRange(4, 2, Source("ef", "sample_value", OriginType.PARAMETER)), + ] + assert get_ranges(parts[1]) == [ + TaintRange(0, 3, Source("jkl",
"sample_value", OriginType.PARAMETER)), + ] + + +def test_set_ranges_on_splitted_ospathsplit(): + s = "abc/efgh/jkl" + range1 = _build_sample_range(0, 4, s[0:4]) + range2 = _build_sample_range(4, 4, s[4:8]) + range3 = _build_sample_range(9, 3, s[9:12]) + set_ranges(s, (range1, range2, range3)) + ranges = get_ranges(s) + assert ranges + + parts = list(os.path.split(s)) + assert parts == ["abc/efgh", "jkl"] + assert set_ranges_on_splitted(s, ranges, parts) + assert get_ranges(parts[0]) == [ + TaintRange(0, 4, Source("abc/", "sample_value", OriginType.PARAMETER)), + TaintRange(4, 4, Source("efgh", "sample_value", OriginType.PARAMETER)), + ] + assert get_ranges(parts[1]) == [ + TaintRange(0, 3, Source("jkl", "sample_value", OriginType.PARAMETER)), + ] + + +def test_set_ranges_on_splitted_ospathsplitext(): + s = "abc/efgh/jkl.txt" + range1 = _build_sample_range(0, 3, s[0:2]) + range2 = _build_sample_range(4, 4, s[4:8]) + range3 = _build_sample_range(9, 3, s[9:12]) + range4 = _build_sample_range(13, 4, s[13:17]) + set_ranges(s, (range1, range2, range3, range4)) + ranges = get_ranges(s) + assert ranges + + parts = list(os.path.splitext(s)) + assert parts == ["abc/efgh/jkl", ".txt"] + assert set_ranges_on_splitted(s, ranges, parts, include_separator=True) + assert get_ranges(parts[0]) == [ + TaintRange(0, 3, Source("abc", "sample_value", OriginType.PARAMETER)), + TaintRange(4, 4, Source("efgh", "sample_value", OriginType.PARAMETER)), + TaintRange(9, 3, Source("jkl", "sample_value", OriginType.PARAMETER)), + ] + assert get_ranges(parts[1]) == [ + TaintRange(1, 4, Source("txt", "sample_value", OriginType.PARAMETER)), + ] + + +def test_set_ranges_on_splitted_ospathsplit_with_empty_string(): + s = "abc/efgh/jkl/" + range1 = _build_sample_range(0, 2, s[0:2]) + range2 = _build_sample_range(4, 4, s[4:8]) + range3 = _build_sample_range(9, 3, s[9:12]) + set_ranges(s, (range1, range2, range3)) + ranges = get_ranges(s) + assert ranges + + parts = list(os.path.split(s)) + assert parts 
== ["abc/efgh/jkl", ""] + assert set_ranges_on_splitted(s, ranges, parts) + assert get_ranges(parts[0]) == [ + TaintRange(0, 2, Source("ab", "sample_value", OriginType.PARAMETER)), + TaintRange(4, 4, Source("efgh", "sample_value", OriginType.PARAMETER)), + TaintRange(9, 3, Source("jkl", "sample_value", OriginType.PARAMETER)), + ] + assert get_ranges(parts[1]) == [] + + +def test_set_ranges_on_splitted_ospathbasename(): + s = "abc/efgh/jkl" + range1 = _build_sample_range(0, 2, s[0:2]) + range2 = _build_sample_range(4, 4, s[4:8]) + range3 = _build_sample_range(9, 3, s[9:12]) + set_ranges(s, (range1, range2, range3)) + ranges = get_ranges(s) + assert ranges + + # Basename aspect implementation works by adding the previous content in a list so + # we can use set_ranges_on_splitted to set the ranges on the last part (the real result) + parts = ["abc/efgh/", os.path.basename(s)] + assert parts == ["abc/efgh/", "jkl"] + assert set_ranges_on_splitted(s, ranges, parts, include_separator=True) + assert get_ranges(parts[1]) == [ + TaintRange(0, 3, Source("jkl", "sample_value", OriginType.PARAMETER)), + ] + + +def test_set_ranges_on_splitted_ospathsplitdrive_windows(): + s = "C:/abc/efgh/jkl" + range1 = _build_sample_range(0, 2, s[0:2]) + range2 = _build_sample_range(4, 4, s[4:8]) + range3 = _build_sample_range(9, 3, s[9:12]) + range4 = _build_sample_range(12, 3, s[12:16]) + set_ranges(s, (range1, range2, range3, range4)) + ranges = get_ranges(s) + assert ranges + + # We emulate what os.path.splitdrive would do on Windows instead of calling it + parts = ["C:", "/abc/efgh/jkl"] + assert set_ranges_on_splitted(s, ranges, parts, include_separator=True) + assert get_ranges(parts[0]) == [ + TaintRange(0, 2, Source("C:", "sample_value", OriginType.PARAMETER)), + ] + assert get_ranges(parts[1]) == [ + TaintRange(2, 4, Source("bc/e", "sample_value", OriginType.PARAMETER)), + TaintRange(7, 3, Source("gh/", "sample_value", OriginType.PARAMETER)), + TaintRange(10, 3, Source("jkl", 
"sample_value", OriginType.PARAMETER)), + ] + + +def test_set_ranges_on_splitted_ospathsplitdrive_posix(): + s = "/abc/efgh/jkl" + range1 = _build_sample_range(0, 2, s[0:2]) + range2 = _build_sample_range(4, 4, s[4:8]) + range3 = _build_sample_range(9, 3, s[9:12]) + set_ranges(s, (range1, range2, range3)) + ranges = get_ranges(s) + assert ranges + + # We emulate what os.path.splitdrive would do on posix instead of calling it + parts = ["", "/abc/efgh/jkl"] + assert set_ranges_on_splitted(s, ranges, parts) + assert get_ranges(parts[0]) == [] + assert get_ranges(parts[1]) == ranges + + +def test_set_ranges_on_splitted_ospathsplitroot_windows_drive(): + s = "C:/abc/efgh/jkl" + range1 = _build_sample_range(0, 2, s[0:2]) + range2 = _build_sample_range(4, 4, s[4:8]) + range3 = _build_sample_range(9, 3, s[9:12]) + range4 = _build_sample_range(12, 3, s[12:16]) + set_ranges(s, (range1, range2, range3, range4)) + ranges = get_ranges(s) + assert ranges + + # We emulate what os.path.splitroot would do on Windows instead of calling it + parts = ["C:", "/", "abc/efgh/jkl"] + assert set_ranges_on_splitted(s, ranges, parts, include_separator=True) + assert get_ranges(parts[0]) == [ + TaintRange(0, 2, Source("C:", "sample_value", OriginType.PARAMETER)), + ] + assert get_ranges(parts[1]) == [] + assert get_ranges(parts[2]) == [ + TaintRange(1, 4, Source("bc/e", "sample_value", OriginType.PARAMETER)), + TaintRange(6, 3, Source("gh/", "sample_value", OriginType.PARAMETER)), + TaintRange(9, 3, Source("jkl", "sample_value", OriginType.PARAMETER)), + ] + + +def test_set_ranges_on_splitted_ospathsplitroot_windows_share(): + s = "//server/share/abc/efgh/jkl" + range1 = _build_sample_range(0, 2, "//") + range2 = _build_sample_range(2, 6, "server") + range3 = _build_sample_range(9, 5, "share") + range4 = _build_sample_range(14, 1, "/") + range5 = _build_sample_range(15, 3, "abc") + range6 = _build_sample_range(19, 4, "efgh") + range7 = _build_sample_range(23, 4, "/jkl") + set_ranges(s, 
(range1, range2, range3, range4, range5, range6, range7)) + ranges = get_ranges(s) + assert ranges + + # We emulate what os.path.splitroot would do on Windows instead of calling it; the implementation + # removed the second element + parts = ["//server/share", "/", "abc/efgh/jkl"] + assert set_ranges_on_splitted(s, ranges, parts, include_separator=True) + assert get_ranges(parts[0]) == [ + TaintRange(0, 2, Source("//", "sample_value", OriginType.PARAMETER)), + TaintRange(2, 6, Source("server", "sample_value", OriginType.PARAMETER)), + TaintRange(9, 5, Source("share", "sample_value", OriginType.PARAMETER)), + ] + assert get_ranges(parts[1]) == [ + TaintRange(0, 1, Source("/", "sample_value", OriginType.PARAMETER)), + ] + assert get_ranges(parts[2]) == [ + TaintRange(0, 3, Source("abc", "sample_value", OriginType.PARAMETER)), + TaintRange(4, 4, Source("efgh", "sample_value", OriginType.PARAMETER)), + TaintRange(8, 4, Source("/jkl", "sample_value", OriginType.PARAMETER)), + ] + + +def test_set_ranges_on_splitted_ospathsplitroot_posix_normal_path(): + s = "/abc/efgh/jkl" + range1 = _build_sample_range(0, 4, "/abc") + range2 = _build_sample_range(3, 5, "c/efg") + range3 = _build_sample_range(7, 5, "gh/jk") + set_ranges(s, (range1, range2, range3)) + ranges = get_ranges(s) + assert ranges + + # We emulate what os.path.splitroot would do on posix instead of calling it + parts = ["", "/", "abc/efgh/jkl"] + assert set_ranges_on_splitted(s, ranges, parts, include_separator=True) + assert get_ranges(parts[0]) == [] + assert get_ranges(parts[1]) == [ + TaintRange(0, 1, Source("/abc", "sample_value", OriginType.PARAMETER)), + ] + assert get_ranges(parts[2]) == [ + TaintRange(0, 3, Source("abc", "sample_value", OriginType.PARAMETER)), + TaintRange(2, 5, Source("c/efg", "sample_value", OriginType.PARAMETER)), + TaintRange(6, 5, Source("gh/jk", "sample_value", OriginType.PARAMETER)), + ] + + +def test_set_ranges_on_splitted_ospathsplitroot_posix_startwithtwoslashes_path(): + s = 
"//abc/efgh/jkl" + range1 = _build_sample_range(0, 2, "//") + range2 = _build_sample_range(2, 3, "abc") + range3 = _build_sample_range(5, 4, "/efg") + range4 = _build_sample_range(9, 4, "h/jk") + set_ranges(s, (range1, range2, range3, range4)) + ranges = get_ranges(s) + assert ranges + + # We emulate what os.path.splitroot would do on posix starting with double slash instead of calling it + parts = ["", "//", "abc/efgh/jkl"] + assert set_ranges_on_splitted(s, ranges, parts, include_separator=True) + assert get_ranges(parts[0]) == [] + assert get_ranges(parts[1]) == [ + TaintRange(0, 2, Source("//", "sample_value", OriginType.PARAMETER)), + ] + assert get_ranges(parts[2]) == [ + TaintRange(0, 3, Source("abc", "sample_value", OriginType.PARAMETER)), + TaintRange(3, 4, Source("/efg", "sample_value", OriginType.PARAMETER)), + TaintRange(7, 4, Source("h/jk", "sample_value", OriginType.PARAMETER)), + ] + + +def test_set_ranges_on_splitted_ospathsplitroot_posix_startwiththreeslashes_path(): + s = "///abc/efgh/jkl" + range1 = _build_sample_range(0, 3, "///") + range2 = _build_sample_range(3, 3, "abc") + range3 = _build_sample_range(6, 4, "/efg") + range4 = _build_sample_range(10, 4, "h/jk") + set_ranges(s, (range1, range2, range3, range4)) + ranges = get_ranges(s) + assert ranges + + # We emulate what os.path.splitroot would do on posix starting with triple slash instead of calling it + parts = ["", "/", "//abc/efgh/jkl"] + assert set_ranges_on_splitted(s, ranges, parts, include_separator=True) + assert get_ranges(parts[0]) == [] + assert get_ranges(parts[1]) == [ + TaintRange(0, 1, Source("/", "sample_value", OriginType.PARAMETER)), + ] + assert get_ranges(parts[2]) == [ + TaintRange(0, 2, Source("///", "sample_value", OriginType.PARAMETER)), + TaintRange(2, 3, Source("abc", "sample_value", OriginType.PARAMETER)), + TaintRange(5, 4, Source("/efg", "sample_value", OriginType.PARAMETER)), + TaintRange(9, 4, Source("h/jk", "sample_value", OriginType.PARAMETER)), + ] + + +def 
test_set_ranges_on_splitted_bytes() -> None: + s = b"abc|efgh|ijkl" + range1 = _build_sample_range(0, 2, "first") # ab -> 0, 2 + range2 = _build_sample_range(5, 1, "second") # f -> 1, 1 + range3 = _build_sample_range(11, 2, "third") # kl -> 2, 2 + set_ranges(s, (range1, range2, range3)) + ranges = get_ranges(s) + assert ranges + + parts = s.split(b"|") + assert set_ranges_on_splitted(s, ranges, parts) + assert get_ranges(parts[0]) == [TaintRange(0, 2, Source("first", "sample_value", OriginType.PARAMETER))] + assert get_ranges(parts[1]) == [TaintRange(1, 1, Source("second", "sample_value", OriginType.PARAMETER))] + assert get_ranges(parts[2]) == [TaintRange(2, 2, Source("third", "sample_value", OriginType.PARAMETER))] + + +def test_set_ranges_on_splitted_bytearray() -> None: + s = bytearray(b"abc|efgh|ijkl") + range1 = _build_sample_range(0, 2, "ab") + range2 = _build_sample_range(5, 1, "f") + range3 = _build_sample_range(5, 6, "fgh|ij") + + set_ranges(s, (range1, range2, range3)) + ranges = get_ranges(s) + assert ranges + + parts = s.split(b"|") + assert set_ranges_on_splitted(s, ranges, parts) + assert get_ranges(parts[0]) == [TaintRange(0, 2, Source("ab", "sample_value", OriginType.PARAMETER))] + assert get_ranges(parts[1]) == [ + TaintRange(1, 1, Source("f", "sample_value", OriginType.PARAMETER)), + TaintRange(1, 4, Source("second", "sample_value", OriginType.PARAMETER)), + ] + assert get_ranges(parts[2]) == [TaintRange(0, 2, Source("third", "sample_value", OriginType.PARAMETER))] + + +def test_set_ranges_on_splitted_wrong_args(): + s = "12345" + range1 = _build_sample_range(1, 3, "234") + set_ranges(s, (range1,)) + ranges = get_ranges(s) + + assert not set_ranges_on_splitted(s, [], ["123", 45]) + assert not set_ranges_on_splitted("", ranges, ["123", 45]) + assert not set_ranges_on_splitted(s, ranges, []) + parts = ["123", 45] + set_ranges_on_splitted(s, ranges, parts) + ranges = get_ranges(parts[0]) + assert ranges == [TaintRange(1, 3, Source("123",
"sample_value", OriginType.PARAMETER))] diff --git a/tests/appsec/iast/aspects/test_ospath_aspects.py b/tests/appsec/iast/aspects/test_ospath_aspects.py new file mode 100644 index 00000000000..87d60ad6a4d --- /dev/null +++ b/tests/appsec/iast/aspects/test_ospath_aspects.py @@ -0,0 +1,714 @@ +import os +import sys + +import pytest + +from ddtrace.appsec._iast._taint_tracking import OriginType +from ddtrace.appsec._iast._taint_tracking import Source +from ddtrace.appsec._iast._taint_tracking import TaintRange +from ddtrace.appsec._iast._taint_tracking import _aspect_ospathbasename +from ddtrace.appsec._iast._taint_tracking import _aspect_ospathdirname +from ddtrace.appsec._iast._taint_tracking import _aspect_ospathjoin +from ddtrace.appsec._iast._taint_tracking import _aspect_ospathnormcase +from ddtrace.appsec._iast._taint_tracking import _aspect_ospathsplit +from ddtrace.appsec._iast._taint_tracking import _aspect_ospathsplitext + + +if sys.version_info >= (3, 12) or os.name == "nt": + from ddtrace.appsec._iast._taint_tracking import _aspect_ospathsplitdrive +if sys.version_info >= (3, 12): + from ddtrace.appsec._iast._taint_tracking import _aspect_ospathsplitroot +from ddtrace.appsec._iast._taint_tracking import get_tainted_ranges +from ddtrace.appsec._iast._taint_tracking import taint_pyobject + + +def test_ospathjoin_first_arg_nottainted_noslash(): + tainted_foo = taint_pyobject( + pyobject="foo", + source_name="test_ospath", + source_value="foo", + source_origin=OriginType.PARAMETER, + ) + + tainted_bar = taint_pyobject( + pyobject="bar", + source_name="test_ospath", + source_value="bar", + source_origin=OriginType.PARAMETER, + ) + res = _aspect_ospathjoin("root", tainted_foo, "nottainted", tainted_bar, "alsonottainted") + assert res == "root/foo/nottainted/bar/alsonottainted" + assert get_tainted_ranges(res) == [ + TaintRange(5, 3, Source("test_ospath", "foo", OriginType.PARAMETER)), + TaintRange(20, 3, Source("test_ospath", "bar", OriginType.PARAMETER)), + ] 
+ + +def test_ospathjoin_later_arg_tainted_with_slash_then_ignore_previous(): + ignored_tainted_foo = taint_pyobject( + pyobject="foo", + source_name="test_ospath", + source_value="foo", + source_origin=OriginType.PARAMETER, + ) + + tainted_slashbar = taint_pyobject( + pyobject="/bar", + source_name="test_ospath", + source_value="/bar", + source_origin=OriginType.PARAMETER, + ) + + res = _aspect_ospathjoin("ignored", ignored_tainted_foo, "ignored_nottainted", tainted_slashbar, "alsonottainted") + assert res == "/bar/alsonottainted" + assert get_tainted_ranges(res) == [ + TaintRange(0, 4, Source("test_ospath", "/bar", OriginType.PARAMETER)), + ] + + +def test_ospathjoin_first_arg_tainted_no_slash(): + tainted_foo = taint_pyobject( + pyobject="foo", + source_name="test_ospath", + source_value="foo", + source_origin=OriginType.PARAMETER, + ) + + tainted_bar = taint_pyobject( + pyobject="bar", + source_name="test_ospath", + source_value="bar", + source_origin=OriginType.PARAMETER, + ) + + res = _aspect_ospathjoin(tainted_foo, "nottainted", tainted_bar, "alsonottainted") + assert res == "foo/nottainted/bar/alsonottainted" + assert get_tainted_ranges(res) == [ + TaintRange(0, 3, Source("test_ospath", "foo", OriginType.PARAMETER)), + TaintRange(15, 3, Source("test_ospath", "bar", OriginType.PARAMETER)), + ] + + +def test_ospathjoin_first_arg_tainted_with_slash(): + tainted_slashfoo = taint_pyobject( + pyobject="/foo", + source_name="test_ospath", + source_value="/foo", + source_origin=OriginType.PARAMETER, + ) + + tainted_bar = taint_pyobject( + pyobject="bar", + source_name="test_ospath", + source_value="bar", + source_origin=OriginType.PARAMETER, + ) + + res = _aspect_ospathjoin(tainted_slashfoo, "nottainted", tainted_bar, "alsonottainted") + assert res == "/foo/nottainted/bar/alsonottainted" + assert get_tainted_ranges(res) == [ + TaintRange(0, 4, Source("test_ospath", "/foo", OriginType.PARAMETER)), + TaintRange(16, 3, Source("test_ospath", "bar", 
OriginType.PARAMETER)), + ] + + +def test_ospathjoin_single_arg_nottainted(): + res = _aspect_ospathjoin("nottainted") + assert res == "nottainted" + assert not get_tainted_ranges(res) + + res = _aspect_ospathjoin("/nottainted") + assert res == "/nottainted" + assert not get_tainted_ranges(res) + + +def test_ospathjoin_single_arg_tainted(): + tainted_foo = taint_pyobject( + pyobject="foo", + source_name="test_ospath", + source_value="foo", + source_origin=OriginType.PARAMETER, + ) + res = _aspect_ospathjoin(tainted_foo) + assert res == "foo" + assert get_tainted_ranges(res) == [TaintRange(0, 3, Source("test_ospath", "/foo", OriginType.PARAMETER))] + + tainted_slashfoo = taint_pyobject( + pyobject="/foo", + source_name="test_ospath", + source_value="/foo", + source_origin=OriginType.PARAMETER, + ) + res = _aspect_ospathjoin(tainted_slashfoo) + assert res == "/foo" + assert get_tainted_ranges(res) == [TaintRange(0, 4, Source("test_ospath", "/foo", OriginType.PARAMETER))] + + +def test_ospathjoin_last_slash_nottainted(): + tainted_foo = taint_pyobject( + pyobject="foo", + source_name="test_ospath", + source_value="foo", + source_origin=OriginType.PARAMETER, + ) + + res = _aspect_ospathjoin("root", tainted_foo, "/nottainted") + assert res == "/nottainted" + assert not get_tainted_ranges(res) + + +def test_ospathjoin_last_slash_tainted(): + tainted_foo = taint_pyobject( + pyobject="foo", + source_name="test_ospath", + source_value="foo", + source_origin=OriginType.PARAMETER, + ) + + tainted_slashbar = taint_pyobject( + pyobject="/bar", + source_name="test_ospath", + source_value="/bar", + source_origin=OriginType.PARAMETER, + ) + res = _aspect_ospathjoin("root", tainted_foo, "nottainted", tainted_slashbar) + assert res == "/bar" + assert get_tainted_ranges(res) == [TaintRange(0, 4, Source("test_ospath", "/bar", OriginType.PARAMETER))] + + +def test_ospathjoin_wrong_arg(): + with pytest.raises(TypeError): + _ = _aspect_ospathjoin("root", 42, "foobar") + + +def 
test_ospathjoin_bytes_nottainted(): + res = _aspect_ospathjoin(b"nottainted", b"alsonottainted") + assert res == b"nottainted/alsonottainted" + + +def test_ospathjoin_bytes_tainted(): + tainted_foo = taint_pyobject( + pyobject=b"foo", + source_name="test_ospath", + source_value=b"foo", + source_origin=OriginType.PARAMETER, + ) + res = _aspect_ospathjoin(tainted_foo, b"nottainted") + assert res == b"foo/nottainted" + assert get_tainted_ranges(res) == [TaintRange(0, 3, Source("test_ospath", b"foo", OriginType.PARAMETER))] + + tainted_slashfoo = taint_pyobject( + pyobject=b"/foo", + source_name="test_ospath", + source_value=b"/foo", + source_origin=OriginType.PARAMETER, + ) + + res = _aspect_ospathjoin(tainted_slashfoo, b"nottainted") + assert res == b"/foo/nottainted" + assert get_tainted_ranges(res) == [TaintRange(0, 4, Source("test_ospath", b"/foo", OriginType.PARAMETER))] + + res = _aspect_ospathjoin(b"nottainted_ignore", b"alsoignored", tainted_slashfoo) + assert res == b"/foo" + assert get_tainted_ranges(res) == [TaintRange(0, 4, Source("test_ospath", b"/foo", OriginType.PARAMETER))] + + +def test_ospathjoin_empty(): + res = _aspect_ospathjoin("") + assert res == "" + + +def test_ospathjoin_noparams(): + with pytest.raises(TypeError): + _ = _aspect_ospathjoin() + + +def test_ospathbasename_tainted_normal(): + tainted_foobarbaz = taint_pyobject( + pyobject="/foo/bar/baz", + source_name="test_ospath", + source_value="/foo/bar/baz", + source_origin=OriginType.PARAMETER, + ) + + res = _aspect_ospathbasename(tainted_foobarbaz) + assert res == "baz" + assert get_tainted_ranges(res) == [TaintRange(0, 3, Source("test_ospath", "/foo/bar/baz", OriginType.PARAMETER))] + + +def test_ospathbasename_tainted_empty(): + tainted_empty = taint_pyobject( + pyobject="", + source_name="test_ospath", + source_value="", + source_origin=OriginType.PARAMETER, + ) + + res = _aspect_ospathbasename(tainted_empty) + assert res == "" + assert not get_tainted_ranges(res) + + +def 
test_ospathbasename_nottainted(): + res = _aspect_ospathbasename("/foo/bar/baz") + assert res == "baz" + assert not get_tainted_ranges(res) + + +def test_ospathbasename_wrong_arg(): + with pytest.raises(TypeError): + _ = _aspect_ospathbasename(42) + + +def test_ospathbasename_bytes_tainted(): + tainted_foobarbaz = taint_pyobject( + pyobject=b"/foo/bar/baz", + source_name="test_ospath", + source_value=b"/foo/bar/baz", + source_origin=OriginType.PARAMETER, + ) + + res = _aspect_ospathbasename(tainted_foobarbaz) + assert res == b"baz" + assert get_tainted_ranges(res) == [TaintRange(0, 3, Source("test_ospath", b"/foo/bar/baz", OriginType.PARAMETER))] + + +def test_ospathbasename_bytes_nottainted(): + res = _aspect_ospathbasename(b"/foo/bar/baz") + assert res == b"baz" + assert not get_tainted_ranges(res) + + +def test_ospathbasename_single_slash_tainted(): + tainted_slash = taint_pyobject( + pyobject="/", + source_name="test_ospath", + source_value="/", + source_origin=OriginType.PARAMETER, + ) + res = _aspect_ospathbasename(tainted_slash) + assert res == "" + assert not get_tainted_ranges(res) + + +def test_ospathbasename_nottainted_empty(): + res = _aspect_ospathbasename("") + assert res == "" + assert not get_tainted_ranges(res) + + +def test_ospathnormcase_tainted_normal(): + tainted_foobarbaz = taint_pyobject( + pyobject="/foo/bar/baz", + source_name="test_ospath", + source_value="/foo/bar/baz", + source_origin=OriginType.PARAMETER, + ) + + res = _aspect_ospathnormcase(tainted_foobarbaz) + assert res == "/foo/bar/baz" + assert get_tainted_ranges(res) == [TaintRange(0, 12, Source("test_ospath", "/foo/bar/baz", OriginType.PARAMETER))] + + +def test_ospathnormcase_tainted_empty(): + tainted_empty = taint_pyobject( + pyobject="", + source_name="test_ospath", + source_value="", + source_origin=OriginType.PARAMETER, + ) + + res = _aspect_ospathnormcase(tainted_empty) + assert res == "" + assert not get_tainted_ranges(res) + + +def test_ospathnormcase_nottainted(): + res 
= _aspect_ospathnormcase("/foo/bar/baz") + assert res == "/foo/bar/baz" + assert not get_tainted_ranges(res) + + +def test_ospathnormcase_wrong_arg(): + with pytest.raises(TypeError): + _ = _aspect_ospathnormcase(42) + + +def test_ospathnormcase_bytes_tainted(): + tainted_foobarbaz = taint_pyobject( + pyobject=b"/foo/bar/baz", + source_name="test_ospath", + source_value=b"/foo/bar/baz", + source_origin=OriginType.PARAMETER, + ) + + res = _aspect_ospathnormcase(tainted_foobarbaz) + assert res == b"/foo/bar/baz" + assert get_tainted_ranges(res) == [TaintRange(0, 12, Source("test_ospath", b"/foo/bar/baz", OriginType.PARAMETER))] + + +def test_ospathnormcase_bytes_nottainted(): + res = _aspect_ospathnormcase(b"/foo/bar/baz") + assert res == b"/foo/bar/baz" + assert not get_tainted_ranges(res) + + +def test_ospathnormcase_single_slash_tainted(): + tainted_slash = taint_pyobject( + pyobject="/", + source_name="test_ospath", + source_value="/", + source_origin=OriginType.PARAMETER, + ) + res = _aspect_ospathnormcase(tainted_slash) + assert res == "/" + assert get_tainted_ranges(res) == [TaintRange(0, 1, Source("test_ospath", "/", OriginType.PARAMETER))] + + +def test_ospathnormcase_nottainted_empty(): + res = _aspect_ospathnormcase("") + assert res == "" + assert not get_tainted_ranges(res) + + +def test_ospathdirname_tainted_normal(): + tainted_foobarbaz = taint_pyobject( + pyobject="/foo/bar/baz", + source_name="test_ospath", + source_value="/foo/bar/baz", + source_origin=OriginType.PARAMETER, + ) + + res = _aspect_ospathdirname(tainted_foobarbaz) + assert res == "/foo/bar" + assert get_tainted_ranges(res) == [TaintRange(0, 8, Source("test_ospath", "/foo/bar/baz", OriginType.PARAMETER))] + + +def test_ospathdirname_tainted_empty(): + tainted_empty = taint_pyobject( + pyobject="", + source_name="test_ospath", + source_value="", + source_origin=OriginType.PARAMETER, + ) + + res = _aspect_ospathdirname(tainted_empty) + assert res == "" + assert not get_tainted_ranges(res) 
+ + +def test_ospathdirname_nottainted(): + res = _aspect_ospathdirname("/foo/bar/baz") + assert res == "/foo/bar" + assert not get_tainted_ranges(res) + + +def test_ospathdirname_wrong_arg(): + with pytest.raises(TypeError): + _ = _aspect_ospathdirname(42) + + +def test_ospathdirname_bytes_tainted(): + tainted_foobarbaz = taint_pyobject( + pyobject=b"/foo/bar/baz", + source_name="test_ospath", + source_value=b"/foo/bar/baz", + source_origin=OriginType.PARAMETER, + ) + + res = _aspect_ospathdirname(tainted_foobarbaz) + assert res == b"/foo/bar" + assert get_tainted_ranges(res) == [TaintRange(0, 8, Source("test_ospath", b"/foo/bar/baz", OriginType.PARAMETER))] + + +def test_ospathdirname_bytes_nottainted(): + res = _aspect_ospathdirname(b"/foo/bar/baz") + assert res == b"/foo/bar" + assert not get_tainted_ranges(res) + + +def test_ospathdirname_single_slash_tainted(): + tainted_slash = taint_pyobject( + pyobject="/", + source_name="test_ospath", + source_value="/", + source_origin=OriginType.PARAMETER, + ) + res = _aspect_ospathdirname(tainted_slash) + assert res == "/" + assert get_tainted_ranges(res) == [TaintRange(0, 1, Source("test_ospath", "/", OriginType.PARAMETER))] + + +def test_ospathdirname_nottainted_empty(): + res = _aspect_ospathdirname("") + assert res == "" + assert not get_tainted_ranges(res) + + +def test_ospathsplit_tainted_normal(): + tainted_foobarbaz = taint_pyobject( + pyobject="/foo/bar/baz", + source_name="test_ospath", + source_value="/foo/bar/baz", + source_origin=OriginType.PARAMETER, + ) + + res = _aspect_ospathsplit(tainted_foobarbaz) + assert res == ("/foo/bar", "baz") + assert get_tainted_ranges(res[0]) == [TaintRange(0, 8, Source("test_ospath", "/foo/bar/baz", OriginType.PARAMETER))] + assert get_tainted_ranges(res[1]) == [TaintRange(0, 3, Source("test_ospath", "/foo/bar/baz", OriginType.PARAMETER))] + + +def test_ospathsplit_tainted_empty(): + tainted_empty = taint_pyobject( + pyobject="", + source_name="test_ospath", + 
source_value="", + source_origin=OriginType.PARAMETER, + ) + + res = _aspect_ospathsplit(tainted_empty) + assert res == ("", "") + assert not get_tainted_ranges(res[0]) + assert not get_tainted_ranges(res[1]) + + +def test_ospathsplit_nottainted(): + res = _aspect_ospathsplit("/foo/bar/baz") + assert res == ("/foo/bar", "baz") + assert not get_tainted_ranges(res[0]) + assert not get_tainted_ranges(res[1]) + + +def test_ospathsplit_wrong_arg(): + with pytest.raises(TypeError): + _ = _aspect_ospathsplit(42) + + +def test_ospathsplit_bytes_tainted(): + tainted_foobarbaz = taint_pyobject( + pyobject=b"/foo/bar/baz", + source_name="test_ospath", + source_value=b"/foo/bar/baz", + source_origin=OriginType.PARAMETER, + ) + + res = _aspect_ospathsplit(tainted_foobarbaz) + assert res == (b"/foo/bar", b"baz") + assert get_tainted_ranges(res[0]) == [ + TaintRange(0, 8, Source("test_ospath", b"/foo/bar/baz", OriginType.PARAMETER)) + ] + assert get_tainted_ranges(res[1]) == [ + TaintRange(0, 3, Source("test_ospath", b"/foo/bar/baz", OriginType.PARAMETER)) + ] + + +def test_ospathsplit_bytes_nottainted(): + res = _aspect_ospathsplit(b"/foo/bar/baz") + assert res == (b"/foo/bar", b"baz") + assert not get_tainted_ranges(res[0]) + assert not get_tainted_ranges(res[1]) + + +def test_ospathsplit_single_slash_tainted(): + tainted_slash = taint_pyobject( + pyobject="/", + source_name="test_ospath", + source_value="/", + source_origin=OriginType.PARAMETER, + ) + res = _aspect_ospathsplit(tainted_slash) + assert res == ("/", "") + assert get_tainted_ranges(res[0]) == [TaintRange(0, 1, Source("test_ospath", "/", OriginType.PARAMETER))] + assert not get_tainted_ranges(res[1]) + + +def test_ospathsplit_nottainted_empty(): + res = _aspect_ospathsplit("") + assert res == ("", "") + assert not get_tainted_ranges(res[0]) + assert not get_tainted_ranges(res[1]) + + +def test_ospathsplitext_tainted_normal(): + tainted_foobarbaz = taint_pyobject( + pyobject="/foo/bar/baz.jpg", + 
source_name="test_ospath", + source_value="/foo/bar/baz.jpg", + source_origin=OriginType.PARAMETER, + ) + + res = _aspect_ospathsplitext(tainted_foobarbaz) + assert res == ("/foo/bar/baz", ".jpg") + assert get_tainted_ranges(res[0]) == [ + TaintRange(0, 12, Source("test_ospath", "/foo/bar/baz.jpg", OriginType.PARAMETER)) + ] + assert get_tainted_ranges(res[1]) == [ + TaintRange(0, 4, Source("test_ospath", "/foo/bar/baz.jpg", OriginType.PARAMETER)) + ] + + +@pytest.mark.skipif(sys.version_info < (3, 12), reason="Requires Python 3.12") +def test_ospathsplitroot_tainted_normal(): + tainted_foobarbaz = taint_pyobject( + pyobject="/foo/bar/baz", + source_name="test_ospath", + source_value="/foo/bar/baz", + source_origin=OriginType.PARAMETER, + ) + + res = _aspect_ospathsplitroot(tainted_foobarbaz) + assert res == ("", "/", "foo/bar/baz") + assert not get_tainted_ranges(res[0]) + assert get_tainted_ranges(res[1]) == [TaintRange(0, 1, Source("test_ospath", "/foo/bar/baz", OriginType.PARAMETER))] + assert get_tainted_ranges(res[2]) == [ + TaintRange(0, 11, Source("test_ospath", "/foo/bar/baz", OriginType.PARAMETER)) + ] + + +@pytest.mark.skipif(sys.version_info < (3, 12), reason="Requires Python 3.12") +def test_ospathsplitroot_tainted_double_initial_slash(): + tainted_foobarbaz = taint_pyobject( + pyobject="//foo/bar/baz", + source_name="test_ospath", + source_value="//foo/bar/baz", + source_origin=OriginType.PARAMETER, + ) + + res = _aspect_ospathsplitroot(tainted_foobarbaz) + assert res == ("", "//", "foo/bar/baz") + assert not get_tainted_ranges(res[0]) + assert get_tainted_ranges(res[1]) == [ + TaintRange(0, 2, Source("test_ospath", "//foo/bar/baz", OriginType.PARAMETER)) + ] + assert get_tainted_ranges(res[2]) == [ + TaintRange(0, 11, Source("test_ospath", "//foo/bar/baz", OriginType.PARAMETER)) + ] + + +@pytest.mark.skipif(sys.version_info < (3, 12), reason="Requires Python 3.12") +def test_ospathsplitroot_tainted_triple_initial_slash(): + tainted_foobarbaz =
taint_pyobject( + pyobject="///foo/bar/baz", + source_name="test_ospath", + source_value="///foo/bar/baz", + source_origin=OriginType.PARAMETER, + ) + + res = _aspect_ospathsplitroot(tainted_foobarbaz) + assert res == ("", "/", "//foo/bar/baz") + assert not get_tainted_ranges(res[0]) + assert get_tainted_ranges(res[1]) == [ + TaintRange(0, 1, Source("test_ospath", "///foo/bar/baz", OriginType.PARAMETER)) + ] + assert get_tainted_ranges(res[2]) == [ + TaintRange(0, 13, Source("test_ospath", "///foo/bar/baz", OriginType.PARAMETER)) + ] + + +@pytest.mark.skipif(sys.version_info < (3, 12), reason="Requires Python 3.12") +def test_ospathsplitroot_tainted_empty(): + tainted_empty = taint_pyobject( + pyobject="", + source_name="test_ospath", + source_value="", + source_origin=OriginType.PARAMETER, + ) + + res = _aspect_ospathsplitroot(tainted_empty) + assert res == ("", "", "") + for i in res: + assert not get_tainted_ranges(i) + + +@pytest.mark.skipif(sys.version_info < (3, 12), reason="Requires Python 3.12") +def test_ospathsplitroot_nottainted(): + res = _aspect_ospathsplitroot("/foo/bar/baz") + assert res == ("", "/", "foo/bar/baz") + for i in res: + assert not get_tainted_ranges(i) + + +@pytest.mark.skipif(sys.version_info < (3, 12), reason="Requires Python 3.12") +def test_ospathsplitroot_wrong_arg(): + with pytest.raises(TypeError): + _ = _aspect_ospathsplitroot(42) + + +@pytest.mark.skipif(sys.version_info < (3, 12), reason="Requires Python 3.12") +def test_ospathsplitroot_bytes_tainted(): + tainted_foobarbaz = taint_pyobject( + pyobject=b"/foo/bar/baz", + source_name="test_ospath", + source_value=b"/foo/bar/baz", + source_origin=OriginType.PARAMETER, + ) + + res = _aspect_ospathsplitroot(tainted_foobarbaz) + assert res == (b"", b"/", b"foo/bar/baz") + assert not get_tainted_ranges(res[0]) + assert get_tainted_ranges(res[1]) == [ + TaintRange(0, 1, Source("test_ospath", b"/foo/bar/baz", OriginType.PARAMETER)) + ] + assert get_tainted_ranges(res[2]) == [ + 
TaintRange(0, 11, Source("test_ospath", b"/foo/bar/baz", OriginType.PARAMETER)) + ] + + +@pytest.mark.skipif(sys.version_info < (3, 12), reason="Requires Python 3.12") +def test_ospathsplitroot_bytes_nottainted(): + res = _aspect_ospathsplitroot(b"/foo/bar/baz") + assert res == (b"", b"/", b"foo/bar/baz") + for i in res: + assert not get_tainted_ranges(i) + + +@pytest.mark.skipif(sys.version_info < (3, 12), reason="Requires Python 3.12") +def test_ospathsplitroot_single_slash_tainted(): + tainted_slash = taint_pyobject( + pyobject="/", + source_name="test_ospath", + source_value="/", + source_origin=OriginType.PARAMETER, + ) + res = _aspect_ospathsplitroot(tainted_slash) + assert res == ("", "/", "") + assert not get_tainted_ranges(res[0]) + assert get_tainted_ranges(res[1]) == [TaintRange(0, 1, Source("test_ospath", "/", OriginType.PARAMETER))] + assert not get_tainted_ranges(res[2]) + + +@pytest.mark.skipif(sys.version_info < (3, 12) and os.name != "nt", reason="Requires Python 3.12 or Windows") +def test_ospathsplitdrive_tainted_normal(): + tainted_foobarbaz = taint_pyobject( + pyobject="/foo/bar/baz", + source_name="test_ospath", + source_value="/foo/bar/baz", + source_origin=OriginType.PARAMETER, + ) + + res = _aspect_ospathsplitdrive(tainted_foobarbaz) + assert res == ("", "/foo/bar/baz") + assert not get_tainted_ranges(res[0]) + assert get_tainted_ranges(res[1]) == [ + TaintRange(0, 12, Source("test_ospath", "/foo/bar/baz", OriginType.PARAMETER)) + ] + + +@pytest.mark.skipif(sys.version_info < (3, 12) and os.name != "nt", reason="Requires Python 3.12 or Windows") +def test_ospathsplitdrive_tainted_empty(): + tainted_empty = taint_pyobject( + pyobject="", + source_name="test_ospath", + source_value="", + source_origin=OriginType.PARAMETER, + ) + + res = _aspect_ospathsplitdrive(tainted_empty) + assert res == ("", "") + for i in res: + assert not get_tainted_ranges(i) + + +# TODO: add tests for ospathsplitdrive with different drive letters that must run +#
under Windows since they're noop under posix diff --git a/tests/appsec/iast/aspects/test_ospath_aspects_fixtures.py b/tests/appsec/iast/aspects/test_ospath_aspects_fixtures.py new file mode 100644 index 00000000000..8a10c996b2c --- /dev/null +++ b/tests/appsec/iast/aspects/test_ospath_aspects_fixtures.py @@ -0,0 +1,113 @@ +import os +import sys + +import pytest + +from ddtrace.appsec._iast._taint_tracking import OriginType +from ddtrace.appsec._iast._taint_tracking import Source +from ddtrace.appsec._iast._taint_tracking import TaintRange +from ddtrace.appsec._iast._taint_tracking import get_tainted_ranges +from ddtrace.appsec._iast._taint_tracking import taint_pyobject +from tests.appsec.iast.aspects.conftest import _iast_patched_module + + +mod = _iast_patched_module("tests.appsec.iast.fixtures.aspects.module_functions") + + +def test_ospathjoin_tainted(): + string_input = taint_pyobject( + pyobject="foo", source_name="first_element", source_value="foo", source_origin=OriginType.PARAMETER + ) + result = mod.do_os_path_join(string_input, "bar") + assert result == "foo/bar" + assert get_tainted_ranges(result) == [TaintRange(0, 3, Source("first_element", "foo", OriginType.PARAMETER))] + + +def test_ospathnormcase_tainted(): + string_input = taint_pyobject( + pyobject="/foo/bar", source_name="first_element", source_value="/foo/bar", source_origin=OriginType.PARAMETER + ) + result = mod.do_os_path_normcase(string_input) + assert result == "/foo/bar" + assert get_tainted_ranges(result) == [TaintRange(0, 8, Source("first_element", "/foo/bar", OriginType.PARAMETER))] + + +def test_ospathbasename_tainted(): + string_input = taint_pyobject( + pyobject="/foo/bar", source_name="first_element", source_value="/foo/bar", source_origin=OriginType.PARAMETER + ) + result = mod.do_os_path_basename(string_input) + assert result == "bar" + assert get_tainted_ranges(result) == [TaintRange(0, 3, Source("first_element", "/foo/bar", OriginType.PARAMETER))] + + +def 
test_ospathdirname_tainted(): + string_input = taint_pyobject( + pyobject="/foo/bar", source_name="first_element", source_value="/foo/bar", source_origin=OriginType.PARAMETER + ) + result = mod.do_os_path_dirname(string_input) + assert result == "/foo" + assert get_tainted_ranges(result) == [TaintRange(0, 4, Source("first_element", "/foo/bar", OriginType.PARAMETER))] + + +def test_ospathsplit_tainted(): + string_input = taint_pyobject( + pyobject="/foo/bar", source_name="first_element", source_value="/foo/bar", source_origin=OriginType.PARAMETER + ) + result = mod.do_os_path_split(string_input) + assert result == ("/foo", "bar") + assert get_tainted_ranges(result[0]) == [ + TaintRange(0, 4, Source("first_element", "/foo/bar", OriginType.PARAMETER)) + ] + assert get_tainted_ranges(result[1]) == [ + TaintRange(0, 3, Source("first_element", "/foo/bar", OriginType.PARAMETER)) + ] + + +def test_ospathsplitext_tainted(): + string_input = taint_pyobject( + pyobject="/foo/bar.txt", + source_name="first_element", + source_value="/foo/bar.txt", + source_origin=OriginType.PARAMETER, + ) + result = mod.do_os_path_splitext(string_input) + assert result == ("/foo/bar", ".txt") + assert get_tainted_ranges(result[0]) == [ + TaintRange(0, 8, Source("first_element", "/foo/bar.txt", OriginType.PARAMETER)) + ] + assert get_tainted_ranges(result[1]) == [ + TaintRange(0, 4, Source("first_element", "/foo/bar.txt", OriginType.PARAMETER)) + ] + + +@pytest.mark.skipif(sys.version_info < (3, 12), reason="requires python3.12 or higher") +def test_ospathsplitroot_tainted(): + string_input = taint_pyobject( + pyobject="/foo/bar", source_name="first_element", source_value="/foo/bar", source_origin=OriginType.PARAMETER + ) + result = mod.do_os_path_splitroot(string_input) + assert result == ("", "/", "foo/bar") + assert not get_tainted_ranges(result[0]) + assert get_tainted_ranges(result[1]) == [ + TaintRange(0, 1, Source("first_element", "/foo/bar", OriginType.PARAMETER)) + ] + assert 
get_tainted_ranges(result[2]) == [ + TaintRange(0, 7, Source("first_element", "/foo/bar", OriginType.PARAMETER)) + ] + + +@pytest.mark.skipif(sys.version_info < (3, 12) and os.name != "nt", reason="Required Python 3.12 or Windows") +def test_ospathsplitdrive_tainted(): + string_input = taint_pyobject( + pyobject="/foo/bar", source_name="first_element", source_value="/foo/bar", source_origin=OriginType.PARAMETER + ) + result = mod.do_os_path_splitdrive(string_input) + assert result == ("", "/foo/bar") + assert not get_tainted_ranges(result[0]) + assert get_tainted_ranges(result[1]) == [ + TaintRange(0, 8, Source("first_element", "/foo/bar", OriginType.PARAMETER)) + ] + + +# TODO: add tests for os.path.splitdrive and os.path.normcase under Windows diff --git a/tests/appsec/iast/aspects/test_split_aspect.py b/tests/appsec/iast/aspects/test_split_aspect.py new file mode 100644 index 00000000000..f7c9ec197e0 --- /dev/null +++ b/tests/appsec/iast/aspects/test_split_aspect.py @@ -0,0 +1,145 @@ +from ddtrace.appsec._iast._taint_tracking import TaintRange +from ddtrace.appsec._iast._taint_tracking import _aspect_rsplit +from ddtrace.appsec._iast._taint_tracking import _aspect_split +from ddtrace.appsec._iast._taint_tracking import _aspect_splitlines +from ddtrace.appsec._iast._taint_tracking._native.taint_tracking import OriginType +from ddtrace.appsec._iast._taint_tracking._native.taint_tracking import Source +from ddtrace.appsec._iast._taint_tracking._native.taint_tracking import get_ranges +from ddtrace.appsec._iast._taint_tracking._native.taint_tracking import set_ranges +from tests.appsec.iast.aspects.test_aspect_helpers import _build_sample_range + + +# These tests are simple ones testing the calls and replacements since most of the +# actual testing is in test_aspect_helpers' test for set_ranges_on_splitted which these +# functions call internally. 
+def test_aspect_split_simple(): + s = "abc def" + range1 = _build_sample_range(0, 3, "abc") + range2 = _build_sample_range(3, 4, " def") + set_ranges(s, (range1, range2)) + ranges = get_ranges(s) + assert ranges + res = _aspect_split(s) + assert res == ["abc", "def"] + assert get_ranges(res[0]) == [range1] + assert get_ranges(res[1]) == [TaintRange(0, 3, Source(" def", "sample_value", OriginType.PARAMETER))] + + +def test_aspect_rsplit_simple(): + s = "abc def" + range1 = _build_sample_range(0, 3, "abc") + range2 = _build_sample_range(3, 4, " def") + set_ranges(s, (range1, range2)) + ranges = get_ranges(s) + assert ranges + res = _aspect_rsplit(s) + assert res == ["abc", "def"] + assert get_ranges(res[0]) == [range1] + assert get_ranges(res[1]) == [TaintRange(0, 3, Source(" def", "sample_value", OriginType.PARAMETER))] + + +def test_aspect_split_with_separator(): + s = "abc:def" + range1 = _build_sample_range(0, 3, "abc") + range2 = _build_sample_range(3, 4, ":def") + set_ranges(s, (range1, range2)) + ranges = get_ranges(s) + assert ranges + res = _aspect_split(s, ":") + assert res == ["abc", "def"] + assert get_ranges(res[0]) == [range1] + assert get_ranges(res[1]) == [TaintRange(0, 3, Source(":def", "sample_value", OriginType.PARAMETER))] + + +def test_aspect_rsplit_with_separator(): + s = "abc:def" + range1 = _build_sample_range(0, 3, "abc") + range2 = _build_sample_range(3, 4, ":def") + set_ranges(s, (range1, range2)) + ranges = get_ranges(s) + assert ranges + res = _aspect_rsplit(s, ":") + assert res == ["abc", "def"] + assert get_ranges(res[0]) == [range1] + assert get_ranges(res[1]) == [TaintRange(0, 3, Source(":def", "sample_value", OriginType.PARAMETER))] + + +def test_aspect_split_with_maxsplit(): + s = "abc def ghi" + range1 = _build_sample_range(0, 3, "abc") + range2 = _build_sample_range(3, 4, " def") + range3 = _build_sample_range(7, 4, " ghi") + set_ranges(s, (range1, range2, range3)) + ranges = get_ranges(s) + assert ranges + res = _aspect_split(s, 
maxsplit=1) + assert res == ["abc", "def ghi"] + assert get_ranges(res[0]) == [range1] + assert get_ranges(res[1]) == [ + TaintRange(0, 3, Source(" def", "sample_value", OriginType.PARAMETER)), + TaintRange(3, 4, Source(" ghi", "sample_value", OriginType.PARAMETER)), + ] + + res = _aspect_split(s, maxsplit=2) + assert res == ["abc", "def", "ghi"] + assert get_ranges(res[0]) == [range1] + assert get_ranges(res[1]) == [TaintRange(0, 3, Source(" def", "sample_value", OriginType.PARAMETER))] + assert get_ranges(res[2]) == [TaintRange(0, 3, Source(" ghi", "sample_value", OriginType.PARAMETER))] + + res = _aspect_split(s, maxsplit=0) + assert res == ["abc def ghi"] + assert get_ranges(res[0]) == [range1, range2, range3] + + +def test_aspect_rsplit_with_maxsplit(): + s = "abc def ghi" + range1 = _build_sample_range(0, 3, "abc") + range2 = _build_sample_range(3, 4, " def") + range3 = _build_sample_range(7, 4, " ghi") + set_ranges(s, (range1, range2, range3)) + ranges = get_ranges(s) + assert ranges + res = _aspect_rsplit(s, maxsplit=1) + assert res == ["abc def", "ghi"] + assert get_ranges(res[0]) == [ + range1, + TaintRange(3, 4, Source(" def", "sample_value", OriginType.PARAMETER)), + ] + assert get_ranges(res[1]) == [TaintRange(0, 3, Source(" ghi", "sample_value", OriginType.PARAMETER))] + res = _aspect_rsplit(s, maxsplit=2) + assert res == ["abc", "def", "ghi"] + assert get_ranges(res[0]) == [range1] + assert get_ranges(res[1]) == [TaintRange(0, 3, Source(" def", "sample_value", OriginType.PARAMETER))] + assert get_ranges(res[2]) == [TaintRange(0, 3, Source(" ghi", "sample_value", OriginType.PARAMETER))] + + res = _aspect_rsplit(s, maxsplit=0) + assert res == ["abc def ghi"] + assert get_ranges(res[0]) == [range1, range2, range3] + + +def test_aspect_splitlines_simple(): + s = "abc\ndef" + range1 = _build_sample_range(0, 3, "abc") + range2 = _build_sample_range(3, 4, " def") + set_ranges(s, (range1, range2)) + ranges = get_ranges(s) + assert ranges + res = 
_aspect_splitlines(s) + assert res == ["abc", "def"] + assert get_ranges(res[0]) == [range1] + assert get_ranges(res[1]) == [TaintRange(0, 3, Source(" def", "sample_value", OriginType.PARAMETER))] + + +def test_aspect_splitlines_keepend_true(): + s = "abc\ndef\nhij\n" + range1 = _build_sample_range(0, 4, "abc\n") + range2 = _build_sample_range(4, 4, "def\n") + range3 = _build_sample_range(8, 4, "hij\n") + set_ranges(s, (range1, range2, range3)) + ranges = get_ranges(s) + assert ranges + res = _aspect_splitlines(s, True) + assert res == ["abc\n", "def\n", "hij\n"] + assert get_ranges(res[0]) == [range1] + assert get_ranges(res[1]) == [TaintRange(0, 4, Source("def\n", "sample_value", OriginType.PARAMETER))] + assert get_ranges(res[2]) == [TaintRange(0, 4, Source("hij\n", "sample_value", OriginType.PARAMETER))] diff --git a/tests/appsec/iast/aspects/test_str_aspect.py b/tests/appsec/iast/aspects/test_str_aspect.py index 5140bb1fe4e..5666fb97baf 100644 --- a/tests/appsec/iast/aspects/test_str_aspect.py +++ b/tests/appsec/iast/aspects/test_str_aspect.py @@ -4,7 +4,10 @@ from ddtrace.appsec._iast import oce from ddtrace.appsec._iast._taint_tracking import OriginType +from ddtrace.appsec._iast._taint_tracking import Source +from ddtrace.appsec._iast._taint_tracking import TaintRange from ddtrace.appsec._iast._taint_tracking import as_formatted_evidence +from ddtrace.appsec._iast._taint_tracking import get_tainted_ranges from ddtrace.appsec._iast._taint_tracking import is_pyobject_tainted from ddtrace.appsec._iast._taint_tracking import taint_pyobject import ddtrace.appsec._iast._taint_tracking.aspects as ddtrace_aspects @@ -371,6 +374,145 @@ def test_repr_aspect_tainting(obj, expected_result, formatted_result): _iast_error_metric.assert_not_called() +def test_split_tainted_noargs(): + result = mod.do_split_no_args("abc def ghi") + assert result == ["abc", "def", "ghi"] + for substr in result: + assert not get_tainted_ranges(substr) + + s_tainted = taint_pyobject( + 
pyobject="abc def ghi", + source_name="test_split_tainted", + source_value="abc def ghi", + source_origin=OriginType.PARAMETER, + ) + assert get_tainted_ranges(s_tainted) + + result2 = mod.do_split_no_args(s_tainted) + assert result2 == ["abc", "def", "ghi"] + for substr in result2: + assert get_tainted_ranges(substr) == [ + TaintRange(0, 3, Source("test_split_tainted", "abc", OriginType.PARAMETER)), + ] + + +@pytest.mark.parametrize( + "s, call, _args, should_be_tainted, result_list, result_tainted_list", + [ + ("abc def", mod.do_split_no_args, [], True, ["abc", "def"], [(0, 3), (0, 3)]), + (b"abc def", mod.do_split_no_args, [], True, [b"abc", b"def"], [(0, 3), (0, 3)]), + ( + bytearray(b"abc def"), + mod.do_split_no_args, + [], + True, + [bytearray(b"abc"), bytearray(b"def")], + [(0, 3), (0, 3)], + ), + ("abc def", mod.do_rsplit_no_args, [], True, ["abc", "def"], [(0, 3), (0, 3)]), + (b"abc def", mod.do_rsplit_no_args, [], True, [b"abc", b"def"], [(0, 3), (0, 3)]), + ( + bytearray(b"abc def"), + mod.do_rsplit_no_args, + [], + True, + [bytearray(b"abc"), bytearray(b"def")], + [(0, 3), (0, 3)], + ), + ("abc def", mod.do_split_no_args, [], False, ["abc", "def"], []), + ("abc def", mod.do_rsplit_no_args, [], False, ["abc", "def"], []), + (b"abc def", mod.do_rsplit_no_args, [], False, [b"abc", b"def"], []), + ("abc def hij", mod.do_split_no_args, [], True, ["abc", "def", "hij"], [(0, 3), (0, 3), (0, 3)]), + (b"abc def hij", mod.do_split_no_args, [], True, [b"abc", b"def", b"hij"], [(0, 3), (0, 3), (0, 3)]), + ("abc def hij", mod.do_rsplit_no_args, [], True, ["abc", "def", "hij"], [(0, 3), (0, 3), (0, 3)]), + (b"abc def hij", mod.do_rsplit_no_args, [], True, [b"abc", b"def", b"hij"], [(0, 3), (0, 3), (0, 3)]), + ("abc def hij", mod.do_split_no_args, [], False, ["abc", "def", "hij"], []), + (b"abc def hij", mod.do_split_no_args, [], False, [b"abc", b"def", b"hij"], []), + ("abc def hij", mod.do_rsplit_no_args, [], False, ["abc", "def", "hij"], []), + (b"abc def hij", 
mod.do_rsplit_no_args, [], False, [b"abc", b"def", b"hij"], []), + ( + bytearray(b"abc def hij"), + mod.do_rsplit_no_args, + [], + False, + [bytearray(b"abc"), bytearray(b"def"), bytearray(b"hij")], + [], + ), + ("abc def hij", mod.do_split_maxsplit, [1], True, ["abc", "def hij"], [(0, 3), (0, 7)]), + ("abc def hij", mod.do_rsplit_maxsplit, [1], True, ["abc def", "hij"], [(0, 7), (0, 3)]), + ("abc def hij", mod.do_split_maxsplit, [1], False, ["abc", "def hij"], []), + ("abc def hij", mod.do_rsplit_maxsplit, [1], False, ["abc def", "hij"], []), + ("abc def hij", mod.do_split_maxsplit, [2], True, ["abc", "def", "hij"], [(0, 3), (0, 3), (0, 3)]), + ("abc def hij", mod.do_rsplit_maxsplit, [2], True, ["abc", "def", "hij"], [(0, 3), (0, 3), (0, 3)]), + ("abc def hij", mod.do_split_maxsplit, [2], False, ["abc", "def", "hij"], []), + ("abc def hij", mod.do_rsplit_maxsplit, [2], False, ["abc", "def", "hij"], []), + ("abc|def|hij", mod.do_split_separator, ["|"], True, ["abc", "def", "hij"], [(0, 3), (0, 3), (0, 3)]), + ("abc|def|hij", mod.do_rsplit_separator, ["|"], True, ["abc", "def", "hij"], [(0, 3), (0, 3), (0, 3)]), + ("abc|def|hij", mod.do_split_separator, ["|"], False, ["abc", "def", "hij"], []), + ("abc|def|hij", mod.do_rsplit_separator, ["|"], False, ["abc", "def", "hij"], []), + ("abc|def hij", mod.do_split_separator, ["|"], True, ["abc", "def hij"], [(0, 3), (0, 7)]), + ("abc|def hij", mod.do_rsplit_separator, ["|"], True, ["abc", "def hij"], [(0, 3), (0, 7)]), + ("abc|def hij", mod.do_split_separator, ["|"], False, ["abc", "def hij"], []), + ("abc|def hij", mod.do_rsplit_separator, ["|"], False, ["abc", "def hij"], []), + ("abc|def|hij", mod.do_split_separator_and_maxsplit, ["|", 1], True, ["abc", "def|hij"], [(0, 3), (0, 7)]), + ("abc|def|hij", mod.do_rsplit_separator_and_maxsplit, ["|", 1], True, ["abc|def", "hij"], [(0, 7), (0, 3)]), + ("abc|def|hij", mod.do_split_separator_and_maxsplit, ["|", 1], False, ["abc", "def|hij"], []), + ("abc|def|hij", 
mod.do_rsplit_separator_and_maxsplit, ["|", 1], False, ["abc|def", "hij"], []), + ("abc\ndef\nhij", mod.do_splitlines_no_arg, [], True, ["abc", "def", "hij"], [(0, 3), (0, 3), (0, 3)]), + (b"abc\ndef\nhij", mod.do_splitlines_no_arg, [], True, [b"abc", b"def", b"hij"], [(0, 3), (0, 3), (0, 3)]), + ( + bytearray(b"abc\ndef\nhij"), + mod.do_splitlines_no_arg, + [], + True, + [bytearray(b"abc"), bytearray(b"def"), bytearray(b"hij")], + [(0, 3), (0, 3), (0, 3)], + ), + ( + "abc\ndef\nhij\n", + mod.do_splitlines_keepends, + [True], + True, + ["abc\n", "def\n", "hij\n"], + [(0, 4), (0, 4), (0, 4)], + ), + ( + b"abc\ndef\nhij\n", + mod.do_splitlines_keepends, + [True], + True, + [b"abc\n", b"def\n", b"hij\n"], + [(0, 4), (0, 4), (0, 4)], + ), + ( + bytearray(b"abc\ndef\nhij\n"), + mod.do_splitlines_keepends, + [True], + True, + [bytearray(b"abc\n"), bytearray(b"def\n"), bytearray(b"hij\n")], + [(0, 4), (0, 4), (0, 4)], + ), + ], +) +def test_split_aspect_tainting(s, call, _args, should_be_tainted, result_list, result_tainted_list): + _test_name = "test_split_aspect_tainting" + if should_be_tainted: + obj = taint_pyobject( + s, source_name="test_split_aspect_tainting", source_value=s, source_origin=OriginType.PARAMETER + ) + else: + obj = s + + result = call(obj, *_args) + assert result == result_list + for idx, result_range in enumerate(result_tainted_list): + result_item = result[idx] + assert is_pyobject_tainted(result_item) == should_be_tainted + if should_be_tainted: + _range = get_tainted_ranges(result_item)[0] + assert _range == TaintRange(result_range[0], result_range[1], Source(_test_name, obj, OriginType.PARAMETER)) + + class TestOperatorsReplacement(BaseReplacement): def test_aspect_ljust_str_tainted(self): # type: () -> None diff --git a/tests/appsec/iast/aspects/test_str_py3.py b/tests/appsec/iast/aspects/test_str_py3.py index 5437c631984..c93c35844cb 100644 --- a/tests/appsec/iast/aspects/test_str_py3.py +++ b/tests/appsec/iast/aspects/test_str_py3.py @@ -52,6 
+52,11 @@ def test_string_fstring_with_format_tainted(self): result = mod_py3.do_repr_fstring_with_format(string_input) # pylint: disable=no-member assert as_formatted_evidence(result) == "':+-foo-+:' " + def test_int_fstring_zero_padding_tainted(self): + int_input = 5 + result = mod_py3.do_zero_padding_fstring(int_input) # pylint: disable=no-member + assert result == "00005" + def test_string_fstring_repr_str_twice_tainted(self): # type: () -> None string_input = "foo" diff --git a/tests/appsec/iast/conftest.py b/tests/appsec/iast/conftest.py index e8e9bbebb9c..cc304eb56b7 100644 --- a/tests/appsec/iast/conftest.py +++ b/tests/appsec/iast/conftest.py @@ -11,6 +11,8 @@ from ddtrace.appsec._iast.taint_sinks._base import VulnerabilityBase from ddtrace.appsec._iast.taint_sinks.command_injection import patch as cmdi_patch from ddtrace.appsec._iast.taint_sinks.command_injection import unpatch as cmdi_unpatch +from ddtrace.appsec._iast.taint_sinks.header_injection import patch as header_injection_patch +from ddtrace.appsec._iast.taint_sinks.header_injection import unpatch as header_injection_unpatch from ddtrace.appsec._iast.taint_sinks.path_traversal import patch as path_traversal_patch from ddtrace.appsec._iast.taint_sinks.weak_cipher import patch as weak_cipher_patch from ddtrace.appsec._iast.taint_sinks.weak_cipher import unpatch_iast as weak_cipher_unpatch @@ -62,6 +64,7 @@ def iast_span(tracer, env, request_sampling="100", deduplication=False): psycopg_patch() sqlalchemy_patch() cmdi_patch() + header_injection_patch() langchain_patch() iast_span_processor.on_span_start(span) yield span @@ -73,6 +76,7 @@ def iast_span(tracer, env, request_sampling="100", deduplication=False): psycopg_unpatch() sqlalchemy_unpatch() cmdi_unpatch() + header_injection_unpatch() langchain_unpatch() diff --git a/tests/appsec/iast/fixtures/aspects/module_functions.py b/tests/appsec/iast/fixtures/aspects/module_functions.py new file mode 100644 index 00000000000..a50bd059d27 --- /dev/null 
+++ b/tests/appsec/iast/fixtures/aspects/module_functions.py @@ -0,0 +1,33 @@ +import os.path + + +def do_os_path_join(a, b): + return os.path.join(a, b) + + +def do_os_path_normcase(a): + return os.path.normcase(a) + + +def do_os_path_basename(a): + return os.path.basename(a) + + +def do_os_path_dirname(a): + return os.path.dirname(a) + + +def do_os_path_splitdrive(a): + return os.path.splitdrive(a) + + +def do_os_path_splitroot(a): + return os.path.splitroot(a) + + +def do_os_path_split(a): + return os.path.split(a) + + +def do_os_path_splitext(a): + return os.path.splitext(a) diff --git a/tests/appsec/iast/fixtures/aspects/str_methods.py b/tests/appsec/iast/fixtures/aspects/str_methods.py index eff192d7108..dd748b732c4 100644 --- a/tests/appsec/iast/fixtures/aspects/str_methods.py +++ b/tests/appsec/iast/fixtures/aspects/str_methods.py @@ -473,11 +473,13 @@ def django_check(all_issues, display_num_errors=False): if visible_issue_count: footer += "\n" footer += "System check identified %s (%s silenced)." % ( - "no issues" - if visible_issue_count == 0 - else "1 issue" - if visible_issue_count == 1 - else "%s issues" % visible_issue_count, + ( + "no issues" + if visible_issue_count == 0 + else "1 issue" + if visible_issue_count == 1 + else "%s issues" % visible_issue_count + ), len(all_issues) - visible_issue_count, ) @@ -513,11 +515,13 @@ def django_check_simple_formatted_ifs(f): def django_check_simple_formatted_multiple_ifs(f): visible_issue_count = 1 f += "System check identified %s (%s silenced)." 
% ( - "no issues" - if visible_issue_count == 0 - else "1 issue" - if visible_issue_count == 1 - else "%s issues" % visible_issue_count, + ( + "no issues" + if visible_issue_count == 0 + else "1 issue" + if visible_issue_count == 1 + else "%s issues" % visible_issue_count + ), 5 - visible_issue_count, ) return f @@ -1000,15 +1004,38 @@ def do_rsplit_no_args(s): # type: (str) -> List[str] return s.rsplit() -def do_split(s, sep, maxsplit=-1): # type: (str, str, int) -> List[str] - return s.split(sep, maxsplit) +def do_split_maxsplit(s, maxsplit=-1): # type: (str, int) -> List[str] + return s.split(maxsplit=maxsplit) + + +def do_rsplit_maxsplit(s, maxsplit=-1): # type: (str, int) -> List[str] + return s.rsplit(maxsplit=maxsplit) -# foosep is just needed so it has the signature expected by _test_somesplit_impl -def do_splitlines(s, foosep): # type: (str, str) -> List[str] +def do_split_separator(s, separator): # type: (str, str) -> List[str] + return s.split(separator) + + +def do_rsplit_separator(s, separator): # type: (str, str) -> List[str] + return s.rsplit(separator) + + +def do_split_separator_and_maxsplit(s, separator, maxsplit): # type: (str, str, int) -> List[str] + return s.split(separator, maxsplit) + + +def do_rsplit_separator_and_maxsplit(s, separator, maxsplit): # type: (str, str, int) -> List[str] + return s.rsplit(separator, maxsplit) + + +def do_splitlines_no_arg(s): # type: (str) -> List[str] return s.splitlines() +def do_splitlines_keepends(s, keepends): # type: (str, bool) -> List[str] + return s.splitlines(keepends=keepends) + + def do_partition(s, sep): # type: (str, str) -> Tuple[str, str, str] return s.partition(sep) diff --git a/tests/appsec/iast/fixtures/aspects/str_methods_py3.py b/tests/appsec/iast/fixtures/aspects/str_methods_py3.py index 864e868a762..9698afed88c 100644 --- a/tests/appsec/iast/fixtures/aspects/str_methods_py3.py +++ b/tests/appsec/iast/fixtures/aspects/str_methods_py3.py @@ -9,6 +9,10 @@ from typing import Tuple # noqa:F401 
+def do_zero_padding_fstring(a): # type: (int) -> str + return f"{a:05d}" + + def do_fmt_value(a): # type: (str) -> str return f"{a:<8s}bar" diff --git a/tests/appsec/iast/taint_sinks/redaction_fixtures/evidence-redaction-suite.json b/tests/appsec/iast/taint_sinks/redaction_fixtures/evidence-redaction-suite.json index 89fc975a262..0719edb550a 100644 --- a/tests/appsec/iast/taint_sinks/redaction_fixtures/evidence-redaction-suite.json +++ b/tests/appsec/iast/taint_sinks/redaction_fixtures/evidence-redaction-suite.json @@ -2733,6 +2733,1289 @@ } ] } + }, + { + "type": "VULNERABILITIES", + "description": "Consecutive ranges - at the beginning", + "input": [ + { + "type": "UNVALIDATED_REDIRECT", + "evidence": { + "value": "https://user:password@datadoghq.com:443/api/v1/test/123/?param1=pone&param2=ptwo#fragment1=fone&fragment2=ftwo", + "ranges": [ + { + "start": 0, + "end": 4, + "iinfo": { + "type": "http.request.parameter", + "parameterName": "protocol", + "parameterValue": "http" + } + }, + { + "start": 4, + "end": 5, + "iinfo": { + "type": "http.request.parameter", + "parameterName": "secure", + "parameterValue": "s" + } + }, + { + "start": 22, + "end": 35, + "iinfo": { + "type": "http.request.parameter", + "parameterName": "host", + "parameterValue": "datadoghq.com" + } + } + ] + } + } + ], + "expected": { + "sources": [ + { + "origin": "http.request.parameter", + "name": "protocol", + "value": "http" + }, + { + "origin": "http.request.parameter", + "name": "secure", + "value": "s" + }, + { + "origin": "http.request.parameter", + "name": "host", + "value": "datadoghq.com" + } + ], + "vulnerabilities": [ + { + "type": "UNVALIDATED_REDIRECT", + "evidence": { + "valueParts": [ + { + "source": 0, + "value": "http" + }, + { + "source": 1, + "value": "s" + }, + { + "value": "://" + }, + { + "redacted": true + }, + { + "value": "@" + }, + { + "source": 2, + "value": "datadoghq.com" + }, + { + "value": ":443/api/v1/test/123/?param1=" + }, + { + "redacted": true + }, + { +
"value": "&param2=" + }, + { + "redacted": true + }, + { + "value": "#fragment1=" + }, + { + "redacted": true + }, + { + "value": "&fragment2=" + }, + { + "redacted": true + } + ] + } + } + ] + } + }, + { + "type": "VULNERABILITIES", + "description": "Tainted range based redaction ", + "input": [ + { + "type": "XSS", + "evidence": { + "value": "this could be a super long text, so we need to reduce it before send it to the backend. This redaction strategy applies to XSS vulnerability but can be extended to future ones", + "ranges": [ + { + "start": 123, + "end": 126, + "iinfo": { + "type": "http.request.parameter", + "parameterName": "type", + "parameterValue": "XSS" + } + } + ] + } + } + ], + "expected": { + "sources": [ + { + "origin": "http.request.parameter", + "name": "type", + "value": "XSS" + } + ], + "vulnerabilities": [ + { + "type": "XSS", + "evidence": { + "valueParts": [ + { + "redacted": true + }, + { + "source": 0, + "value": "XSS" + }, + { + "redacted": true + } + ] + } + } + ] + } + }, + { + "type": "VULNERABILITIES", + "description": "Tainted range based redaction - with redactable source ", + "input": [ + { + "type": "XSS", + "evidence": { + "value": "this could be a super long text, so we need to reduce it before send it to the backend.
This redaction strategy applies to XSS vulnerability but can be extended to future ones", + "ranges": [ + { + "start": 123, + "end": 126, + "iinfo": { + "type": "http.request.parameter", + "parameterName": "password", + "parameterValue": "XSS" + } + } + ] + } + } + ], + "expected": { + "sources": [ + { + "origin": "http.request.parameter", + "name": "password", + "redacted": true, + "pattern": "abc" + } + ], + "vulnerabilities": [ + { + "type": "XSS", + "evidence": { + "valueParts": [ + { + "redacted": true + }, + { + "source": 0, + "redacted": true, + "pattern": "abc" + }, + { + "redacted": true + } + ] + } + } + ] + } + }, + { + "type": "VULNERABILITIES", + "description": "Tainted range based redaction - with null source ", + "input": [ + { + "type": "XSS", + "evidence": { + "value": "this could be a super long text, so we need to reduce it before send it to the backend. This redaction strategy applies to XSS vulnerability but can be extended to future ones", + "ranges": [ + { + "start": 123, + "end": 126, + "iinfo": { + "type": "http.request.body" + } + } + ] + } + } + ], + "expected": { + "sources": [ + { + "origin": "http.request.body" + } + ], + "vulnerabilities": [ + { + "type": "XSS", + "evidence": { + "valueParts": [ + { + "redacted": true + }, + { + "source": 0, + "value": "XSS" + }, + { + "redacted": true + } + ] + } + } + ] + } + }, + { + "type": "VULNERABILITIES", + "description": "Tainted range based redaction - multiple ranges", + "input": [ + { + "type": "XSS", + "evidence": { + "value": "this could be a super long text, so we need to reduce it before send it to the backend. 
This redaction strategy applies to XSS vulnerability but can be extended to future ones", + "ranges": [ + { + "start": 16, + "end": 26, + "iinfo": { + "type": "http.request.parameter", + "parameterName": "text", + "parameterValue": "super long" + } + }, + { + "start": 123, + "end": 126, + "iinfo": { + "type": "http.request.parameter", + "parameterName": "type", + "parameterValue": "XSS" + } + } + ] + } + } + ], + "expected": { + "sources": [ + { + "origin": "http.request.parameter", + "name": "text", + "value": "super long" + }, + { + "origin": "http.request.parameter", + "name": "type", + "value": "XSS" + } + ], + "vulnerabilities": [ + { + "type": "XSS", + "evidence": { + "valueParts": [ + { + "redacted": true + }, + { + "source": 0, + "value": "super long" + }, + { + "redacted": true + }, + { + "source": 1, + "value": "XSS" + }, + { + "redacted": true + } + ] + } + } + ] + } + }, + { + "type": "VULNERABILITIES", + "description": "Tainted range based redaction - first range at the beginning ", + "input": [ + { + "type": "XSS", + "evidence": { + "value": "this could be a super long text, so we need to reduce it before send it to the backend. 
This redaction strategy applies to XSS vulnerability but can be extended to future ones", + "ranges": [ + { + "start": 0, + "end": 4, + "iinfo": { + "type": "http.request.parameter", + "parameterName": "text", + "parameterValue": "this" + } + }, + { + "start": 123, + "end": 126, + "iinfo": { + "type": "http.request.parameter", + "parameterName": "type", + "parameterValue": "XSS" + } + } + ] + } + } + ], + "expected": { + "sources": [ + { + "origin": "http.request.parameter", + "name": "text", + "value": "this" + }, + { + "origin": "http.request.parameter", + "name": "type", + "value": "XSS" + } + ], + "vulnerabilities": [ + { + "type": "XSS", + "evidence": { + "valueParts": [ + { + "source": 0, + "value": "this" + }, + { + "redacted": true + }, + { + "source": 1, + "value": "XSS" + }, + { + "redacted": true + } + ] + } + } + ] + } + }, + { + "type": "VULNERABILITIES", + "description": "Tainted range based redaction - last range at the end ", + "input": [ + { + "type": "XSS", + "evidence": { + "value": "this could be a super long text, so we need to reduce it before send it to the backend. 
This redaction strategy applies to XSS", + "ranges": [ + { + "start": 0, + "end": 4, + "iinfo": { + "type": "http.request.parameter", + "parameterName": "text", + "parameterValue": "this" + } + }, + { + "start": 123, + "end": 126, + "iinfo": { + "type": "http.request.parameter", + "parameterName": "type", + "parameterValue": "XSS" + } + } + ] + } + } + ], + "expected": { + "sources": [ + { + "origin": "http.request.parameter", + "name": "text", + "value": "this" + }, + { + "origin": "http.request.parameter", + "name": "type", + "value": "XSS" + } + ], + "vulnerabilities": [ + { + "type": "XSS", + "evidence": { + "valueParts": [ + { + "source": 0, + "value": "this" + }, + { + "redacted": true + }, + { + "source": 1, + "value": "XSS" + } + ] + } + } + ] + } + }, + { + "type": "VULNERABILITIES", + "description": "Tainted range based redaction - whole text ", + "input": [ + { + "type": "XSS", + "evidence": { + "value": "this could be a super long text, so we need to reduce it before send it to the backend. This redaction strategy applies to XSS", + "ranges": [ + { + "start": 0, + "end": 126, + "iinfo": { + "type": "http.request.parameter", + "parameterName": "text", + "parameterValue": "this could be a super long text, so we need to reduce it before send it to the backend. This redaction strategy applies to XSS" + } + } + ] + } + } + ], + "expected": { + "sources": [ + { + "origin": "http.request.parameter", + "name": "text", + "value": "this could be a super long text, so we need to reduce it before send it to the backend. This redaction strategy applies to XSS" + } + ], + "vulnerabilities": [ + { + "type": "XSS", + "evidence": { + "valueParts": [ + { + "source": 0, + "value": "this could be a super long text, so we need to reduce it before send it to the backend. 
This redaction strategy applies to XSS" + } + ] + } + } + ] + } + }, + { + "type": "VULNERABILITIES", + "description": "Mongodb json query with sensitive source", + "input": [ + { + "type": "NOSQL_MONGODB_INJECTION", + "evidence": { + "value": { + "password": "1234" + }, + "ranges": [ + { + "start": 0, + "end": 4, + "iinfo": { + "type": "http.request.parameter", + "parameterName": "password", + "parameterValue": "1234" + } + } + ], + "rangesToApply": { + "password": [ + { + "start": 0, + "end": 4, + "iinfo": { + "type": "http.request.parameter", + "parameterName": "password", + "parameterValue": "1234" + } + } + ] + } + } + } + ], + "expected": { + "sources": [ + { + "origin": "http.request.parameter", + "name": "password", + "redacted": true, + "pattern": "abcd" + } + ], + "vulnerabilities": [ + { + "type": "NOSQL_MONGODB_INJECTION", + "evidence": { + "valueParts": [ + { + "value": "{\n \"password\": \"" + }, + { + "source": 0, + "redacted": true, + "pattern": "abcd" + }, + { + "value": "\"\n}" + } + ] + } + } + ] + } + }, + { + "type": "VULNERABILITIES", + "description": "Mongodb json query with non sensitive source", + "input": [ + { + "type": "NOSQL_MONGODB_INJECTION", + "evidence": { + "value": { + "username": "user" + }, + "ranges": [ + { + "start": 0, + "end": 4, + "iinfo": { + "type": "http.request.parameter", + "parameterName": "username", + "parameterValue": "user" + } + } + ], + "rangesToApply": { + "username": [ + { + "start": 0, + "end": 4, + "iinfo": { + "type": "http.request.parameter", + "parameterName": "username", + "parameterValue": "user" + } + } + ] + } + } + } + ], + "expected": { + "sources": [ + { + "origin": "http.request.parameter", + "name": "username", + "redacted": true, + "pattern": "abcd" + } + ], + "vulnerabilities": [ + { + "type": "NOSQL_MONGODB_INJECTION", + "evidence": { + "valueParts": [ + { + "value": "{\n \"username\": \"" + }, + { + "source": 0, + "redacted": true, + "pattern": "abcd" + }, + { + "value": "\"\n}" + } + ] + } + 
} + ] + } + }, + { + "type": "VULNERABILITIES", + "description": "Mongodb json query with partial non sensitive source", + "input": [ + { + "type": "NOSQL_MONGODB_INJECTION", + "evidence": { + "value": { + "username": "user" + }, + "ranges": [ + { + "start": 0, + "end": 4, + "iinfo": { + "type": "http.request.parameter", + "parameterName": "username", + "parameterValue": "PREFIX_user" + } + } + ], + "rangesToApply": { + "username": [ + { + "start": 0, + "end": 4, + "iinfo": { + "type": "http.request.parameter", + "parameterName": "username", + "parameterValue": "PREFIX_user" + } + } + ] + } + } + } + ], + "expected": { + "sources": [ + { + "origin": "http.request.parameter", + "name": "username", + "redacted": true, + "pattern": "abcdefghijk" + } + ], + "vulnerabilities": [ + { + "type": "NOSQL_MONGODB_INJECTION", + "evidence": { + "valueParts": [ + { + "value": "{\n \"username\": \"" + }, + { + "source": 0, + "redacted": true, + "pattern": "hijk" + }, + { + "value": "\"\n}" + } + ] + } + } + ] + } + }, + { + "type": "VULNERABILITIES", + "description": "Mongodb json query with non sensitive source and other fields", + "input": [ + { + "type": "NOSQL_MONGODB_INJECTION", + "evidence": { + "value": { + "username": "user", + "secret": "SECRET_VALUE" + }, + "ranges": [ + { + "start": 0, + "end": 4, + "iinfo": { + "type": "http.request.parameter", + "parameterName": "username", + "parameterValue": "user" + } + } + ], + "rangesToApply": { + "username": [ + { + "start": 0, + "end": 4, + "iinfo": { + "type": "http.request.parameter", + "parameterName": "username", + "parameterValue": "user" + } + } + ] + } + } + } + ], + "expected": { + "sources": [ + { + "origin": "http.request.parameter", + "name": "username", + "redacted": true, + "pattern": "abcd" + } + ], + "vulnerabilities": [ + { + "type": "NOSQL_MONGODB_INJECTION", + "evidence": { + "valueParts": [ + { + "value": "{\n \"username\": \"" + }, + { + "source": 0, + "redacted": true, + "pattern": "abcd" + }, + { + 
"value": "\",\n \"secret\": \"" + }, + { + "redacted": true + }, + { + "value": "\"\n}" + } + ] + } + } + ] + } + }, + { + "type": "VULNERABILITIES", + "description": "Mongodb json query with sensitive value in a key", + "input": [ + { + "type": "NOSQL_MONGODB_INJECTION", + "evidence": { + "value": { + "username": "user", + "token_usage": { + "bearer zss8dR9QP81A": 10 + } + }, + "ranges": [ + { + "start": 0, + "end": 4, + "iinfo": { + "type": "http.request.parameter", + "parameterName": "username", + "parameterValue": "user" + } + } + ], + "rangesToApply": { + "username": [ + { + "start": 0, + "end": 4, + "iinfo": { + "type": "http.request.parameter", + "parameterName": "username", + "parameterValue": "user" + } + } + ] + } + } + } + ], + "expected": { + "sources": [ + { + "origin": "http.request.parameter", + "name": "username", + "redacted": true, + "pattern": "abcd" + } + ], + "vulnerabilities": [ + { + "type": "NOSQL_MONGODB_INJECTION", + "evidence": { + "valueParts": [ + { + "value": "{\n \"username\": \"" + }, + { + "source": 0, + "redacted": true, + "pattern": "abcd" + }, + { + "value": "\",\n \"token_usage\": {\n \"" + }, + { + "redacted": true + }, + { + "value": "\": " + }, + { + "redacted": true + }, + { + "value": "\n }\n}" + } + ] + } + } + ] + } + }, + { + "type": "VULNERABILITIES", + "description": "No redacted that needs to be truncated - whole text", + "input": [ + { + "type": "XSS", + "evidence": { + "value": "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. 
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.Sed ut perspiciatis unde omnis iste natus error sit voluptatem ac", + "ranges": [ + { + "start": 0, + "end": 510, + "iinfo": { + "type": "http.request.parameter", + "parameterName": "text", + "parameterValue": "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.Sed ut perspiciatis unde omnis iste natus error sit voluptatem ac" + } + } + ] + } + } + ], + "expected": { + "sources": [ + { + "origin": "http.request.parameter", + "name": "text", + "truncated": "right", + "value": "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure do" + } + ], + "vulnerabilities": [ + { + "type": "XSS", + "evidence": { + "valueParts": [ + { + "source": 0, + "truncated": "right", + "value": "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. 
Duis aute irure do" + } + ] + } + } + ] + } + }, + { + "type": "VULNERABILITIES", + "description": "Header injection without sensitive data", + "input": [ + { + "type": "HEADER_INJECTION", + "evidence": { + "value": "custom: text", + "ranges": [ + { + "start": 8, + "end": 12, + "iinfo": { + "type": "http.request.parameter", + "parameterName": "param", + "parameterValue": "text" + } + } + ] + } + } + ], + "expected": { + "sources": [ + { + "origin": "http.request.parameter", + "name": "param", + "value": "text" + } + ], + "vulnerabilities": [ + { + "type": "HEADER_INJECTION", + "evidence": { + "valueParts": [ + { + "value": "custom: " + }, + { + "source": 0, + "value": "text" + } + ] + } + } + ] + } + }, + { + "type": "VULNERABILITIES", + "description": "Header injection with only sensitive data from tainted", + "input": [ + { + "type": "HEADER_INJECTION", + "evidence": { + "value": "custom: pass", + "ranges": [ + { + "start": 8, + "end": 12, + "iinfo": { + "type": "http.request.parameter", + "parameterName": "password", + "parameterValue": "pass" + } + } + ] + } + } + ], + "expected": { + "sources": [ + { + "origin": "http.request.parameter", + "name": "password", + "redacted": true, + "pattern": "abcd" + } + ], + "vulnerabilities": [ + { + "type": "HEADER_INJECTION", + "evidence": { + "valueParts": [ + { + "value": "custom: " + }, + { + "source": 0, + "redacted": true, + "pattern": "abcd" + } + ] + } + } + ] + } + }, + { + "type": "VULNERABILITIES", + "description": "Header injection with partial sensitive data from tainted", + "input": [ + { + "type": "HEADER_INJECTION", + "evidence": { + "value": "custom: this is pass", + "ranges": [ + { + "start": 16, + "end": 20, + "iinfo": { + "type": "http.request.parameter", + "parameterName": "password", + "parameterValue": "pass" + } + } + ] + } + } + ], + "expected": { + "sources": [ + { + "origin": "http.request.parameter", + "name": "password", + "redacted": true, + "pattern": "abcd" + } + ], + "vulnerabilities": [ + { 
+ "type": "HEADER_INJECTION", + "evidence": { + "valueParts": [ + { + "value": "custom: this is " + }, + { + "source": 0, + "redacted": true, + "pattern": "abcd" + } + ] + } + } + ] + } + }, + { + "type": "VULNERABILITIES", + "description": "Header injection with sensitive data from header name", + "input": [ + { + "type": "HEADER_INJECTION", + "evidence": { + "value": "password: text", + "ranges": [ + { + "start": 10, + "end": 14, + "iinfo": { + "type": "http.request.parameter", + "parameterName": "param", + "parameterValue": "text" + } + } + ] + } + } + ], + "expected": { + "sources": [ + { + "origin": "http.request.parameter", + "name": "param", + "redacted": true, + "pattern": "abcd" + } + ], + "vulnerabilities": [ + { + "type": "HEADER_INJECTION", + "evidence": { + "valueParts": [ + { + "value": "password: " + }, + { + "source": 0, + "redacted": true, + "pattern": "abcd" + } + ] + } + } + ] + } + }, + { + "type": "VULNERABILITIES", + "description": "Header injection with sensitive data from header value", + "input": [ + { + "type": "HEADER_INJECTION", + "evidence": { + "value": "custom: bearer 1234123", + "ranges": [ + { + "start": 15, + "end": 22, + "iinfo": { + "type": "http.request.parameter", + "parameterName": "param", + "parameterValue": "1234123" + } + } + ] + } + } + ], + "expected": { + "sources": [ + { + "origin": "http.request.parameter", + "name": "param", + "redacted": true, + "pattern": "abcdefg" + } + ], + "vulnerabilities": [ + { + "type": "HEADER_INJECTION", + "evidence": { + "valueParts": [ + { + "value": "custom: " + }, + { + "redacted": true + }, + { + "source": 0, + "redacted": true, + "pattern": "abcdefg" + } + ] + } + } + ] + } + }, + { + "type": "VULNERABILITIES", + "description": "Header injection with sensitive data from header and tainted", + "input": [ + { + "type": "HEADER_INJECTION", + "evidence": { + "value": "password: this is pass", + "ranges": [ + { + "start": 18, + "end": 22, + "iinfo": { + "type": "http.request.parameter", + 
"parameterName": "password", + "parameterValue": "pass" + } + } + ] + } + } + ], + "expected": { + "sources": [ + { + "origin": "http.request.parameter", + "name": "password", + "redacted": true, + "pattern": "abcd" + } + ], + "vulnerabilities": [ + { + "type": "HEADER_INJECTION", + "evidence": { + "valueParts": [ + { + "value": "password: " + }, + { + "redacted": true + }, + { + "source": 0, + "redacted": true, + "pattern": "abcd" + } + ] + } + } + ] + } + }, + { + "type": "VULNERABILITIES", + "description": "Header injection with sensitive data from header and tainted (source does not match)", + "input": [ + { + "type": "HEADER_INJECTION", + "evidence": { + "value": "password: this is key word", + "ranges": [ + { + "start": 18, + "end": 26, + "iinfo": { + "type": "http.request.parameter", + "parameterName": "password", + "parameterValue": "key%20word" + } + } + ] + } + } + ], + "expected": { + "sources": [ + { + "origin": "http.request.parameter", + "name": "password", + "redacted": true, + "pattern": "abcdefghij" + } + ], + "vulnerabilities": [ + { + "type": "HEADER_INJECTION", + "evidence": { + "valueParts": [ + { + "value": "password: " + }, + { + "redacted": true + }, + { + "source": 0, + "redacted": true, + "pattern": "********" + } + ] + } + } + ] + } } ] } diff --git a/tests/appsec/iast/taint_sinks/test_command_injection.py b/tests/appsec/iast/taint_sinks/test_command_injection.py index 394a1a5ef4d..0100756dd41 100644 --- a/tests/appsec/iast/taint_sinks/test_command_injection.py +++ b/tests/appsec/iast/taint_sinks/test_command_injection.py @@ -40,12 +40,11 @@ def setup(): def test_ossystem(tracer, iast_span_defaults): with override_global_config(dict(_iast_enabled=True)): patch() - _BAD_DIR = "forbidden_dir/" + _BAD_DIR = "mytest/folder/" _BAD_DIR = taint_pyobject( pyobject=_BAD_DIR, source_name="test_ossystem", source_value=_BAD_DIR, - source_origin=OriginType.PARAMETER, ) assert is_pyobject_tainted(_BAD_DIR) with tracer.trace("ossystem_test"): @@ -54,26 
+53,26 @@ def test_ossystem(tracer, iast_span_defaults): span_report = core.get_item(IAST.CONTEXT_KEY, span=iast_span_defaults) assert span_report - - vulnerability = list(span_report.vulnerabilities)[0] - source = span_report.sources[0] - assert vulnerability.type == VULN_CMDI - assert vulnerability.evidence.valueParts == [ + data = span_report.build_and_scrub_value_parts() + vulnerability = data["vulnerabilities"][0] + source = data["sources"][0] + assert vulnerability["type"] == VULN_CMDI + assert vulnerability["evidence"]["valueParts"] == [ {"value": "dir "}, {"redacted": True}, {"pattern": "abcdefghijklmn", "redacted": True, "source": 0}, ] - assert vulnerability.evidence.value is None - assert vulnerability.evidence.pattern is None - assert vulnerability.evidence.redacted is None - assert source.name == "test_ossystem" - assert source.origin == OriginType.PARAMETER - assert source.value is None + assert "value" not in vulnerability["evidence"].keys() + assert vulnerability["evidence"].get("pattern") is None + assert vulnerability["evidence"].get("redacted") is None + assert source["name"] == "test_ossystem" + assert source["origin"] == OriginType.PARAMETER + assert "value" not in source.keys() line, hash_value = get_line_and_hash("test_ossystem", VULN_CMDI, filename=FIXTURES_PATH) - assert vulnerability.location.path == FIXTURES_PATH - assert vulnerability.location.line == line - assert vulnerability.hash == hash_value + assert vulnerability["location"]["path"] == FIXTURES_PATH + assert vulnerability["location"]["line"] == line + assert vulnerability["hash"] == hash_value def test_communicate(tracer, iast_span_defaults): @@ -94,26 +93,27 @@ def test_communicate(tracer, iast_span_defaults): span_report = core.get_item(IAST.CONTEXT_KEY, span=iast_span_defaults) assert span_report + data = span_report.build_and_scrub_value_parts() - vulnerability = list(span_report.vulnerabilities)[0] - source = span_report.sources[0] - assert vulnerability.type == VULN_CMDI - 
assert vulnerability.evidence.valueParts == [ + vulnerability = data["vulnerabilities"][0] + source = data["sources"][0] + assert vulnerability["type"] == VULN_CMDI + assert vulnerability["evidence"]["valueParts"] == [ {"value": "dir "}, {"redacted": True}, {"pattern": "abcdefghijklmn", "redacted": True, "source": 0}, ] - assert vulnerability.evidence.value is None - assert vulnerability.evidence.pattern is None - assert vulnerability.evidence.redacted is None - assert source.name == "test_communicate" - assert source.origin == OriginType.PARAMETER - assert source.value is None + assert "value" not in vulnerability["evidence"].keys() + assert "pattern" not in vulnerability["evidence"].keys() + assert "redacted" not in vulnerability["evidence"].keys() + assert source["name"] == "test_communicate" + assert source["origin"] == OriginType.PARAMETER + assert "value" not in source.keys() line, hash_value = get_line_and_hash("test_communicate", VULN_CMDI, filename=FIXTURES_PATH) - assert vulnerability.location.path == FIXTURES_PATH - assert vulnerability.location.line == line - assert vulnerability.hash == hash_value + assert vulnerability["location"]["path"] == FIXTURES_PATH + assert vulnerability["location"]["line"] == line + assert vulnerability["hash"] == hash_value def test_run(tracer, iast_span_defaults): @@ -132,26 +132,27 @@ def test_run(tracer, iast_span_defaults): span_report = core.get_item(IAST.CONTEXT_KEY, span=iast_span_defaults) assert span_report + data = span_report.build_and_scrub_value_parts() - vulnerability = list(span_report.vulnerabilities)[0] - source = span_report.sources[0] - assert vulnerability.type == VULN_CMDI - assert vulnerability.evidence.valueParts == [ + vulnerability = data["vulnerabilities"][0] + source = data["sources"][0] + assert vulnerability["type"] == VULN_CMDI + assert vulnerability["evidence"]["valueParts"] == [ {"value": "dir "}, {"redacted": True}, {"pattern": "abcdefghijklmn", "redacted": True, "source": 0}, ] - assert 
vulnerability.evidence.value is None - assert vulnerability.evidence.pattern is None - assert vulnerability.evidence.redacted is None - assert source.name == "test_run" - assert source.origin == OriginType.PARAMETER - assert source.value is None + assert "value" not in vulnerability["evidence"].keys() + assert "pattern" not in vulnerability["evidence"].keys() + assert "redacted" not in vulnerability["evidence"].keys() + assert source["name"] == "test_run" + assert source["origin"] == OriginType.PARAMETER + assert "value" not in source.keys() line, hash_value = get_line_and_hash("test_run", VULN_CMDI, filename=FIXTURES_PATH) - assert vulnerability.location.path == FIXTURES_PATH - assert vulnerability.location.line == line - assert vulnerability.hash == hash_value + assert vulnerability["location"]["path"] == FIXTURES_PATH + assert vulnerability["location"]["line"] == line + assert vulnerability["hash"] == hash_value def test_popen_wait(tracer, iast_span_defaults): @@ -171,26 +172,27 @@ def test_popen_wait(tracer, iast_span_defaults): span_report = core.get_item(IAST.CONTEXT_KEY, span=iast_span_defaults) assert span_report + data = span_report.build_and_scrub_value_parts() - vulnerability = list(span_report.vulnerabilities)[0] - source = span_report.sources[0] - assert vulnerability.type == VULN_CMDI - assert vulnerability.evidence.valueParts == [ + vulnerability = data["vulnerabilities"][0] + source = data["sources"][0] + assert vulnerability["type"] == VULN_CMDI + assert vulnerability["evidence"]["valueParts"] == [ {"value": "dir "}, {"redacted": True}, {"pattern": "abcdefghijklmn", "redacted": True, "source": 0}, ] - assert vulnerability.evidence.value is None - assert vulnerability.evidence.pattern is None - assert vulnerability.evidence.redacted is None - assert source.name == "test_popen_wait" - assert source.origin == OriginType.PARAMETER - assert source.value is None + assert "value" not in vulnerability["evidence"].keys() + assert "pattern" not in 
vulnerability["evidence"].keys() + assert "redacted" not in vulnerability["evidence"].keys() + assert source["name"] == "test_popen_wait" + assert source["origin"] == OriginType.PARAMETER + assert "value" not in source.keys() line, hash_value = get_line_and_hash("test_popen_wait", VULN_CMDI, filename=FIXTURES_PATH) - assert vulnerability.location.path == FIXTURES_PATH - assert vulnerability.location.line == line - assert vulnerability.hash == hash_value + assert vulnerability["location"]["path"] == FIXTURES_PATH + assert vulnerability["location"]["line"] == line + assert vulnerability["hash"] == hash_value def test_popen_wait_shell_true(tracer, iast_span_defaults): @@ -210,26 +212,27 @@ def test_popen_wait_shell_true(tracer, iast_span_defaults): span_report = core.get_item(IAST.CONTEXT_KEY, span=iast_span_defaults) assert span_report + data = span_report.build_and_scrub_value_parts() - vulnerability = list(span_report.vulnerabilities)[0] - source = span_report.sources[0] - assert vulnerability.type == VULN_CMDI - assert vulnerability.evidence.valueParts == [ + vulnerability = data["vulnerabilities"][0] + source = data["sources"][0] + assert vulnerability["type"] == VULN_CMDI + assert vulnerability["evidence"]["valueParts"] == [ {"value": "dir "}, {"redacted": True}, {"pattern": "abcdefghijklmn", "redacted": True, "source": 0}, ] - assert vulnerability.evidence.value is None - assert vulnerability.evidence.pattern is None - assert vulnerability.evidence.redacted is None - assert source.name == "test_popen_wait_shell_true" - assert source.origin == OriginType.PARAMETER - assert source.value is None + assert "value" not in vulnerability["evidence"].keys() + assert "pattern" not in vulnerability["evidence"].keys() + assert "redacted" not in vulnerability["evidence"].keys() + assert source["name"] == "test_popen_wait_shell_true" + assert source["origin"] == OriginType.PARAMETER + assert "value" not in source.keys() line, hash_value = 
get_line_and_hash("test_popen_wait_shell_true", VULN_CMDI, filename=FIXTURES_PATH) - assert vulnerability.location.path == FIXTURES_PATH - assert vulnerability.location.line == line - assert vulnerability.hash == hash_value + assert vulnerability["location"]["path"] == FIXTURES_PATH + assert vulnerability["location"]["line"] == line + assert vulnerability["hash"] == hash_value @pytest.mark.skipif(sys.platform != "linux", reason="Only for Linux") @@ -275,22 +278,23 @@ def test_osspawn_variants(tracer, iast_span_defaults, function, mode, arguments, span_report = core.get_item(IAST.CONTEXT_KEY, span=iast_span_defaults) assert span_report - - vulnerability = list(span_report.vulnerabilities)[0] - source = span_report.sources[0] - assert vulnerability.type == VULN_CMDI - assert vulnerability.evidence.valueParts == [{"value": "/bin/ls -l "}, {"source": 0, "value": _BAD_DIR}] - assert vulnerability.evidence.value is None - assert vulnerability.evidence.pattern is None - assert vulnerability.evidence.redacted is None - assert source.name == "test_osspawn_variants" - assert source.origin == OriginType.PARAMETER - assert source.value == _BAD_DIR + data = span_report.build_and_scrub_value_parts() + + vulnerability = data["vulnerabilities"][0] + source = data["sources"][0] + assert vulnerability["type"] == VULN_CMDI + assert vulnerability["evidence"]["valueParts"] == [{"value": "/bin/ls -l "}, {"source": 0, "value": _BAD_DIR}] + assert "value" not in vulnerability["evidence"].keys() + assert "pattern" not in vulnerability["evidence"].keys() + assert "redacted" not in vulnerability["evidence"].keys() + assert source["name"] == "test_osspawn_variants" + assert source["origin"] == OriginType.PARAMETER + assert source["value"] == _BAD_DIR line, hash_value = get_line_and_hash(tag, VULN_CMDI, filename=FIXTURES_PATH) - assert vulnerability.location.path == FIXTURES_PATH - assert vulnerability.location.line == line - assert vulnerability.hash == hash_value + assert 
vulnerability["location"]["path"] == FIXTURES_PATH + assert vulnerability["location"]["line"] == line + assert vulnerability["hash"] == hash_value @pytest.mark.skipif(sys.platform != "linux", reason="Only for Linux") @@ -315,8 +319,9 @@ def test_multiple_cmdi(tracer, iast_span_defaults): span_report = core.get_item(IAST.CONTEXT_KEY, span=iast_span_defaults) assert span_report + data = span_report.build_and_scrub_value_parts() - assert len(list(span_report.vulnerabilities)) == 2 + assert len(list(data["vulnerabilities"])) == 2 @pytest.mark.skipif(sys.platform != "linux", reason="Only for Linux") @@ -334,8 +339,9 @@ def test_string_cmdi(tracer, iast_span_defaults): span_report = core.get_item(IAST.CONTEXT_KEY, span=iast_span_defaults) assert span_report + data = span_report.build_and_scrub_value_parts() - assert len(list(span_report.vulnerabilities)) == 1 + assert len(list(data["vulnerabilities"])) == 1 @pytest.mark.parametrize("num_vuln_expected", [1, 0, 0]) @@ -360,5 +366,5 @@ def test_cmdi_deduplication(num_vuln_expected, tracer, iast_span_deduplication_e assert span_report is None else: assert span_report - - assert len(span_report.vulnerabilities) == num_vuln_expected + data = span_report.build_and_scrub_value_parts() + assert len(data["vulnerabilities"]) == num_vuln_expected diff --git a/tests/appsec/iast/taint_sinks/test_command_injection_redacted.py b/tests/appsec/iast/taint_sinks/test_command_injection_redacted.py index b76e47a5805..4cb6a962c7d 100644 --- a/tests/appsec/iast/taint_sinks/test_command_injection_redacted.py +++ b/tests/appsec/iast/taint_sinks/test_command_injection_redacted.py @@ -2,12 +2,14 @@ import pytest from ddtrace.appsec._constants import IAST +from ddtrace.appsec._iast._taint_tracking import origin_to_str from ddtrace.appsec._iast._taint_tracking import str_to_origin +from ddtrace.appsec._iast._taint_tracking import taint_pyobject +from ddtrace.appsec._iast._taint_tracking.aspects import add_aspect from ddtrace.appsec._iast.constants 
import VULN_CMDI from ddtrace.appsec._iast.reporter import Evidence from ddtrace.appsec._iast.reporter import IastSpanReporter from ddtrace.appsec._iast.reporter import Location -from ddtrace.appsec._iast.reporter import Source from ddtrace.appsec._iast.reporter import Vulnerability from ddtrace.appsec._iast.taint_sinks.command_injection import CommandInjection from ddtrace.internal import core @@ -36,10 +38,14 @@ def test_cmdi_redaction_suite(evidence_input, sources_expected, vulnerabilities_ span_report = core.get_item(IAST.CONTEXT_KEY, span=iast_span_defaults) assert span_report - vulnerability = list(span_report.vulnerabilities)[0] + span_report.build_and_scrub_value_parts() + result = span_report._to_dict() + vulnerability = list(result["vulnerabilities"])[0] + source = list(result["sources"])[0] + source["origin"] = origin_to_str(source["origin"]) - assert vulnerability.type == VULN_CMDI - assert vulnerability.evidence.valueParts == vulnerabilities_expected["evidence"]["valueParts"] + assert vulnerability["type"] == VULN_CMDI + assert source == sources_expected @pytest.mark.parametrize( @@ -72,24 +78,52 @@ def test_cmdi_redaction_suite(evidence_input, sources_expected, vulnerabilities_ "/mytest/../folder/file.txt", ], ) -def test_cmdi_redact_rel_paths(file_path): - ev = Evidence( - valueParts=[ - {"value": "sudo "}, - {"value": "ls "}, - {"value": file_path, "source": 0}, +def test_cmdi_redact_rel_paths_and_sudo(file_path): + file_path = taint_pyobject(pyobject=file_path, source_name="test_ossystem", source_value=file_path) + ev = Evidence(value=add_aspect("sudo ", add_aspect("ls ", file_path))) + loc = Location(path="foobar.py", line=35, spanId=123) + v = Vulnerability(type=VULN_CMDI, evidence=ev, location=loc) + report = IastSpanReporter(vulnerabilities={v}) + report.add_ranges_to_evidence_and_extract_sources(v) + result = report.build_and_scrub_value_parts() + + assert result["vulnerabilities"] + + for v in result["vulnerabilities"]: + assert 
v["evidence"]["valueParts"] == [ + {"value": "sudo ls "}, + {"redacted": True, "pattern": ANY, "source": 0}, ] - ) + + +@pytest.mark.parametrize( + "file_path", + [ + "2 > /mytest/folder/", + "2 > mytest/folder/", + "-p mytest/folder", + "--path=../mytest/folder/", + "--path=../mytest/folder/", + "--options ../mytest/folder", + "-a /mytest/folder/", + "-b /mytest/folder/", + "-c /mytest/folder", + ], +) +def test_cmdi_redact_sudo_command_with_options(file_path): + file_path = taint_pyobject(pyobject=file_path, source_name="test_ossystem", source_value=file_path) + ev = Evidence(value=add_aspect("sudo ", add_aspect("ls ", file_path))) loc = Location(path="foobar.py", line=35, spanId=123) - v = Vulnerability(type="VulnerabilityType", evidence=ev, location=loc) - s = Source(origin="file", name="SomeName", value=file_path) - report = IastSpanReporter([s], {v}) + v = Vulnerability(type=VULN_CMDI, evidence=ev, location=loc) + report = IastSpanReporter(vulnerabilities={v}) + report.add_ranges_to_evidence_and_extract_sources(v) + result = report.build_and_scrub_value_parts() - redacted_report = CommandInjection._redact_report(report) - for v in redacted_report.vulnerabilities: - assert v.evidence.valueParts == [ - {"value": "sudo "}, - {"value": "ls "}, + assert result["vulnerabilities"] + + for v in result["vulnerabilities"]: + assert v["evidence"]["valueParts"] == [ + {"value": "sudo ls "}, {"redacted": True, "pattern": ANY, "source": 0}, ] @@ -108,24 +142,69 @@ def test_cmdi_redact_rel_paths(file_path): "-c /mytest/folder", ], ) -def test_cmdi_redact_options(file_path): - ev = Evidence( - valueParts=[ - {"value": "sudo "}, +def test_cmdi_redact_command_with_options(file_path): + file_path = taint_pyobject(pyobject=file_path, source_name="test_ossystem", source_value=file_path) + ev = Evidence(value=add_aspect("ls ", file_path)) + loc = Location(path="foobar.py", line=35, spanId=123) + v = Vulnerability(type=VULN_CMDI, evidence=ev, location=loc) + report = 
IastSpanReporter(vulnerabilities={v}) + report.add_ranges_to_evidence_and_extract_sources(v) + result = report.build_and_scrub_value_parts() + + assert result["vulnerabilities"] + + for v in result["vulnerabilities"]: + assert v["evidence"]["valueParts"] == [ {"value": "ls "}, - {"value": file_path, "source": 0}, + {"redacted": True, "pattern": ANY, "source": 0}, ] - ) + + +@pytest.mark.parametrize( + "file_path", + [ + "/mytest/folder/", + "mytest/folder/", + "mytest/folder", + "../mytest/folder/", + "../mytest/folder/", + "../mytest/folder", + "/mytest/folder/", + "/mytest/folder/", + "/mytest/folder", + "/mytest/../folder/", + "mytest/../folder/", + "mytest/../folder", + "../mytest/../folder/", + "../mytest/../folder/", + "../mytest/../folder", + "/mytest/../folder/", + "/mytest/../folder/", + "/mytest/../folder", + "/mytest/folder/file.txt", + "mytest/folder/file.txt", + "../mytest/folder/file.txt", + "/mytest/folder/file.txt", + "mytest/../folder/file.txt", + "../mytest/../folder/file.txt", + "/mytest/../folder/file.txt", + ], +) +def test_cmdi_redact_rel_paths(file_path): + file_path = taint_pyobject(pyobject=file_path, source_name="test_ossystem", source_value=file_path) + ev = Evidence(value=add_aspect("dir -l ", file_path)) loc = Location(path="foobar.py", line=35, spanId=123) - v = Vulnerability(type="VulnerabilityType", evidence=ev, location=loc) - s = Source(origin="file", name="SomeName", value=file_path) - report = IastSpanReporter([s], {v}) + v = Vulnerability(type=VULN_CMDI, evidence=ev, location=loc) + report = IastSpanReporter(vulnerabilities={v}) + report.add_ranges_to_evidence_and_extract_sources(v) + result = report.build_and_scrub_value_parts() - redacted_report = CommandInjection._redact_report(report) - for v in redacted_report.vulnerabilities: - assert v.evidence.valueParts == [ - {"value": "sudo "}, - {"value": "ls "}, + assert result["vulnerabilities"] + + for v in result["vulnerabilities"]: + assert v["evidence"]["valueParts"] == [ + 
{"value": "dir "}, + {"redacted": True}, {"redacted": True, "pattern": ANY, "source": 0}, ] @@ -145,23 +224,19 @@ def test_cmdi_redact_options(file_path): ], ) def test_cmdi_redact_source_command(file_path): - ev = Evidence( - valueParts=[ - {"value": "sudo "}, - {"value": "ls ", "source": 0}, - {"value": file_path}, - ] - ) - loc = Location(path="foobar.py", line=35, spanId=123) - v = Vulnerability(type="VulnerabilityType", evidence=ev, location=loc) - s = Source(origin="SomeOrigin", name="SomeName", value="SomeValue") - report = IastSpanReporter([s], {v}) + Ls_cmd = taint_pyobject(pyobject="ls ", source_name="test_ossystem", source_value="ls ") - redacted_report = CommandInjection._redact_report(report) - for v in redacted_report.vulnerabilities: - assert v.evidence.valueParts == [ + ev = Evidence(value=add_aspect("sudo ", add_aspect(Ls_cmd, file_path))) + loc = Location(path="foobar.py", line=35, spanId=123) + v = Vulnerability(type=VULN_CMDI, evidence=ev, location=loc) + report = IastSpanReporter(vulnerabilities={v}) + report.add_ranges_to_evidence_and_extract_sources(v) + result = report.build_and_scrub_value_parts() + + assert result["vulnerabilities"] + for v in result["vulnerabilities"]: + assert v["evidence"]["valueParts"] == [ {"value": "sudo "}, {"value": "ls ", "source": 0}, - {"value": " "}, {"redacted": True}, ] diff --git a/tests/appsec/iast/taint_sinks/test_header_injection_redacted.py b/tests/appsec/iast/taint_sinks/test_header_injection_redacted.py new file mode 100644 index 00000000000..db9272e1625 --- /dev/null +++ b/tests/appsec/iast/taint_sinks/test_header_injection_redacted.py @@ -0,0 +1,117 @@ +import pytest + +from ddtrace.appsec._constants import IAST +from ddtrace.appsec._iast._taint_tracking import is_pyobject_tainted +from ddtrace.appsec._iast._taint_tracking import origin_to_str +from ddtrace.appsec._iast._taint_tracking import str_to_origin +from ddtrace.appsec._iast.constants import VULN_HEADER_INJECTION +from 
ddtrace.appsec._iast.reporter import Evidence +from ddtrace.appsec._iast.reporter import IastSpanReporter +from ddtrace.appsec._iast.reporter import Location +from ddtrace.appsec._iast.reporter import Source +from ddtrace.appsec._iast.reporter import Vulnerability +from ddtrace.appsec._iast.taint_sinks.header_injection import HeaderInjection +from ddtrace.internal import core +from tests.appsec.iast.taint_sinks.test_taint_sinks_utils import _taint_pyobject_multiranges +from tests.appsec.iast.taint_sinks.test_taint_sinks_utils import get_parametrize + + +@pytest.mark.parametrize( + "header_name, header_value", + [ + ("test", "aaaaaaaaaaaaaa"), + ("test2", "9944b09199c62bcf9418ad846dd0e4bbdfc6ee4b"), + ], +) +def test_header_injection_redact_excluded(header_name, header_value): + ev = Evidence( + valueParts=[ + {"value": header_name + ": "}, + {"value": header_value, "source": 0}, + ] + ) + loc = Location(path="foobar.py", line=35, spanId=123) + v = Vulnerability(type=VULN_HEADER_INJECTION, evidence=ev, location=loc) + s = Source(origin="SomeOrigin", name="SomeName", value=header_value) + report = IastSpanReporter([s], {v}) + report.add_ranges_to_evidence_and_extract_sources(v) + redacted_report = HeaderInjection._redact_report(report) + for v in redacted_report.vulnerabilities: + assert v.evidence.valueParts == [{"value": header_name + ": "}, {"source": 0, "value": header_value}] + + +@pytest.mark.parametrize( + "header_name, header_value, value_part", + [ + ( + "WWW-Authenticate", + 'Basic realm="api"', + [{"value": "WWW-Authenticate: "}, {"source": 0, "value": 'Basic realm="api"'}], + ), + ( + "Authorization", + "Token 9944b09199c62bcf9418ad846dd0e4bbdfc6ee4b", + [ + {"value": "Authorization: "}, + { + "pattern": "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRST", + "redacted": True, + "source": 0, + }, + ], + ), + ], +) +def test_common_django_header_injection_redact(header_name, header_value, value_part): + ev = Evidence( + valueParts=[ + {"value": header_name + 
": "}, + {"value": header_value, "source": 0}, + ] + ) + loc = Location(path="foobar.py", line=35, spanId=123) + v = Vulnerability(type=VULN_HEADER_INJECTION, evidence=ev, location=loc) + s = Source(origin="SomeOrigin", name="SomeName", value=header_value) + report = IastSpanReporter([s], {v}) + report.add_ranges_to_evidence_and_extract_sources(v) + redacted_report = HeaderInjection._redact_report(report) + for v in redacted_report.vulnerabilities: + assert v.evidence.valueParts == value_part + + +@pytest.mark.parametrize( + "evidence_input, sources_expected, vulnerabilities_expected", + list(get_parametrize(VULN_HEADER_INJECTION)), +) +def test_header_injection_redaction_suite( + evidence_input, sources_expected, vulnerabilities_expected, iast_span_defaults +): + tainted_object = _taint_pyobject_multiranges( + evidence_input["value"], + [ + ( + input_ranges["iinfo"]["parameterName"], + input_ranges["iinfo"]["parameterValue"], + str_to_origin(input_ranges["iinfo"]["type"]), + input_ranges["start"], + input_ranges["end"] - input_ranges["start"], + ) + for input_ranges in evidence_input["ranges"] + ], + ) + + assert is_pyobject_tainted(tainted_object) + + HeaderInjection.report(tainted_object) + + span_report = core.get_item(IAST.CONTEXT_KEY, span=iast_span_defaults) + assert span_report + + span_report.build_and_scrub_value_parts() + result = span_report._to_dict() + vulnerability = list(result["vulnerabilities"])[0] + source = list(result["sources"])[0] + source["origin"] = origin_to_str(source["origin"]) + + assert vulnerability["type"] == VULN_HEADER_INJECTION + assert source == sources_expected diff --git a/tests/appsec/iast/taint_sinks/test_insecure_cookie.py b/tests/appsec/iast/taint_sinks/test_insecure_cookie.py index 2a45778a89c..9d2784b3c49 100644 --- a/tests/appsec/iast/taint_sinks/test_insecure_cookie.py +++ b/tests/appsec/iast/taint_sinks/test_insecure_cookie.py @@ -1,7 +1,9 @@ +import json + +import attr import pytest from ddtrace.appsec._constants 
import IAST -from ddtrace.appsec._iast._utils import _iast_report_to_str from ddtrace.appsec._iast.constants import VULN_INSECURE_COOKIE from ddtrace.appsec._iast.constants import VULN_NO_HTTPONLY_COOKIE from ddtrace.appsec._iast.constants import VULN_NO_SAMESITE_COOKIE @@ -9,6 +11,20 @@ from ddtrace.internal import core +def _iast_report_to_str(data): + from ddtrace.appsec._iast._taint_tracking import OriginType + from ddtrace.appsec._iast._taint_tracking import origin_to_str + + class OriginTypeEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, OriginType): + # if the obj is uuid, we simply return the value of uuid + return origin_to_str(obj) + return json.JSONEncoder.default(self, obj) + + return json.dumps(attr.asdict(data, filter=lambda attr, x: x is not None), cls=OriginTypeEncoder) + + def test_insecure_cookies(iast_span_defaults): cookies = {"foo": "bar"} asm_check_cookies(cookies) diff --git a/tests/appsec/iast/taint_sinks/test_path_traversal.py b/tests/appsec/iast/taint_sinks/test_path_traversal.py index 6a8083908ba..0dda76950e7 100644 --- a/tests/appsec/iast/taint_sinks/test_path_traversal.py +++ b/tests/appsec/iast/taint_sinks/test_path_traversal.py @@ -33,17 +33,20 @@ def test_path_traversal_open(iast_span_defaults): ) mod.pt_open(tainted_string) span_report = core.get_item(IAST.CONTEXT_KEY, span=iast_span_defaults) - vulnerability = list(span_report.vulnerabilities)[0] - source = span_report.sources[0] - assert len(span_report.vulnerabilities) == 1 - assert vulnerability.type == VULN_PATH_TRAVERSAL - assert source.name == "path" - assert source.origin == OriginType.PATH - assert source.value == file_path - assert vulnerability.evidence.valueParts == [{"source": 0, "value": file_path}] - assert vulnerability.evidence.value is None - assert vulnerability.evidence.pattern is None - assert vulnerability.evidence.redacted is None + assert span_report + data = span_report.build_and_scrub_value_parts() + + assert 
len(data["vulnerabilities"]) == 1 + vulnerability = data["vulnerabilities"][0] + source = data["sources"][0] + assert vulnerability["type"] == VULN_PATH_TRAVERSAL + assert source["name"] == "path" + assert source["origin"] == OriginType.PATH + assert source["value"] == file_path + assert vulnerability["evidence"]["valueParts"] == [{"source": 0, "value": file_path}] + assert "value" not in vulnerability["evidence"].keys() + assert vulnerability["evidence"].get("pattern") is None + assert vulnerability["evidence"].get("redacted") is None @pytest.mark.parametrize( @@ -82,19 +85,22 @@ def test_path_traversal(module, function, iast_span_defaults): getattr(mod, "path_{}_{}".format(module, function))(tainted_string) span_report = core.get_item(IAST.CONTEXT_KEY, span=iast_span_defaults) + assert span_report + data = span_report.build_and_scrub_value_parts() + line, hash_value = get_line_and_hash( "path_{}_{}".format(module, function), VULN_PATH_TRAVERSAL, filename=FIXTURES_PATH ) - vulnerability = list(span_report.vulnerabilities)[0] - assert len(span_report.vulnerabilities) == 1 - assert vulnerability.type == VULN_PATH_TRAVERSAL - assert vulnerability.location.path == FIXTURES_PATH - assert vulnerability.location.line == line - assert vulnerability.hash == hash_value - assert vulnerability.evidence.valueParts == [{"source": 0, "value": file_path}] - assert vulnerability.evidence.value is None - assert vulnerability.evidence.pattern is None - assert vulnerability.evidence.redacted is None + vulnerability = data["vulnerabilities"][0] + assert len(data["vulnerabilities"]) == 1 + assert vulnerability["type"] == VULN_PATH_TRAVERSAL + assert vulnerability["location"]["path"] == FIXTURES_PATH + assert vulnerability["location"]["line"] == line + assert vulnerability["hash"] == hash_value + assert vulnerability["evidence"]["valueParts"] == [{"source": 0, "value": file_path}] + assert "value" not in vulnerability["evidence"].keys() + assert vulnerability["evidence"].get("pattern") 
is None + assert vulnerability["evidence"].get("redacted") is None @pytest.mark.parametrize("num_vuln_expected", [1, 0, 0]) diff --git a/tests/appsec/iast/taint_sinks/test_path_traversal_redacted.py b/tests/appsec/iast/taint_sinks/test_path_traversal_redacted.py index 9ecc9cc5e14..ccd88c0ce11 100644 --- a/tests/appsec/iast/taint_sinks/test_path_traversal_redacted.py +++ b/tests/appsec/iast/taint_sinks/test_path_traversal_redacted.py @@ -2,6 +2,7 @@ import pytest +from ddtrace.appsec._iast.constants import VULN_PATH_TRAVERSAL from ddtrace.appsec._iast.reporter import Evidence from ddtrace.appsec._iast.reporter import IastSpanReporter from ddtrace.appsec._iast.reporter import Location @@ -34,7 +35,7 @@ def test_path_traversal_redact_exclude(file_path): ] ) loc = Location(path="foobar.py", line=35, spanId=123) - v = Vulnerability(type="VulnerabilityType", evidence=ev, location=loc) + v = Vulnerability(type=VULN_PATH_TRAVERSAL, evidence=ev, location=loc) s = Source(origin="SomeOrigin", name="SomeName", value="SomeValue") report = IastSpanReporter([s], {v}) @@ -80,7 +81,7 @@ def test_path_traversal_redact_rel_paths(file_path): ] ) loc = Location(path="foobar.py", line=35, spanId=123) - v = Vulnerability(type="VulnerabilityType", evidence=ev, location=loc) + v = Vulnerability(type=VULN_PATH_TRAVERSAL, evidence=ev, location=loc) s = Source(origin="SomeOrigin", name="SomeName", value="SomeValue") report = IastSpanReporter([s], {v}) @@ -97,7 +98,7 @@ def test_path_traversal_redact_abs_paths(): ] ) loc = Location(path="foobar.py", line=35, spanId=123) - v = Vulnerability(type="VulnerabilityType", evidence=ev, location=loc) + v = Vulnerability(type=VULN_PATH_TRAVERSAL, evidence=ev, location=loc) s = Source(origin="SomeOrigin", name="SomeName", value="SomeValue") report = IastSpanReporter([s], {v}) diff --git a/tests/appsec/iast/taint_sinks/test_sql_injection.py b/tests/appsec/iast/taint_sinks/test_sql_injection.py index 62252cc7808..54efea82ffe 100644 --- 
a/tests/appsec/iast/taint_sinks/test_sql_injection.py +++ b/tests/appsec/iast/taint_sinks/test_sql_injection.py @@ -53,8 +53,6 @@ def test_sql_injection(fixture_path, fixture_module, iast_span_defaults): {"value": "students", "source": 0}, ] assert vulnerability.evidence.value is None - assert vulnerability.evidence.pattern is None - assert vulnerability.evidence.redacted is None assert source.name == "test_ossystem" assert source.origin == OriginType.PARAMETER assert source.value == "students" diff --git a/tests/appsec/iast/taint_sinks/test_sql_injection_redacted.py b/tests/appsec/iast/taint_sinks/test_sql_injection_redacted.py index 038d12d3ceb..4122b53d402 100644 --- a/tests/appsec/iast/taint_sinks/test_sql_injection_redacted.py +++ b/tests/appsec/iast/taint_sinks/test_sql_injection_redacted.py @@ -1,9 +1,6 @@ -import copy - import pytest from ddtrace.appsec._constants import IAST -from ddtrace.appsec._iast import oce from ddtrace.appsec._iast._taint_tracking import is_pyobject_tainted from ddtrace.appsec._iast._taint_tracking import str_to_origin from ddtrace.appsec._iast.constants import VULN_SQL_INJECTION @@ -12,13 +9,10 @@ from ddtrace.appsec._iast.reporter import Location from ddtrace.appsec._iast.reporter import Source from ddtrace.appsec._iast.reporter import Vulnerability -from ddtrace.appsec._iast.taint_sinks._base import VulnerabilityBase from ddtrace.appsec._iast.taint_sinks.sql_injection import SqlInjection from ddtrace.internal import core -from ddtrace.internal.utils.cache import LFUCache from tests.appsec.iast.taint_sinks.test_taint_sinks_utils import _taint_pyobject_multiranges from tests.appsec.iast.taint_sinks.test_taint_sinks_utils import get_parametrize -from tests.utils import override_env from tests.utils import override_global_config @@ -91,7 +85,7 @@ def test_redacted_report_no_match(): ev = Evidence(value="SomeEvidenceValue") orig_ev = ev.value loc = Location(path="foobar.py", line=35, spanId=123) - v = 
Vulnerability(type="VulnerabilityType", evidence=ev, location=loc) + v = Vulnerability(type=VULN_SQL_INJECTION, evidence=ev, location=loc) s = Source(origin="SomeOrigin", name="SomeName", value="SomeValue") report = IastSpanReporter([s], {v}) @@ -103,46 +97,37 @@ def test_redacted_report_no_match(): def test_redacted_report_source_name_match(): ev = Evidence(value="'SomeEvidenceValue'") - len_ev = len(ev.value) - 2 loc = Location(path="foobar.py", line=35, spanId=123) - v = Vulnerability(type="VulnerabilityType", evidence=ev, location=loc) + v = Vulnerability(type=VULN_SQL_INJECTION, evidence=ev, location=loc) s = Source(origin="SomeOrigin", name="secret", value="SomeValue") report = IastSpanReporter([s], {v}) redacted_report = SqlInjection._redact_report(report) for v in redacted_report.vulnerabilities: - assert v.evidence.redacted - assert v.evidence.pattern == "'%s'" % ("*" * len_ev) assert not v.evidence.value def test_redacted_report_source_value_match(): ev = Evidence(value="'SomeEvidenceValue'") - len_ev = len(ev.value) - 2 loc = Location(path="foobar.py", line=35, spanId=123) - v = Vulnerability(type="VulnerabilityType", evidence=ev, location=loc) + v = Vulnerability(type=VULN_SQL_INJECTION, evidence=ev, location=loc) s = Source(origin="SomeOrigin", name="SomeName", value="somepassword") report = IastSpanReporter([s], {v}) redacted_report = SqlInjection._redact_report(report) for v in redacted_report.vulnerabilities: - assert v.evidence.redacted - assert v.evidence.pattern == "'%s'" % ("*" * len_ev) assert not v.evidence.value def test_redacted_report_evidence_value_match_also_redacts_source_value(): ev = Evidence(value="'SomeSecretPassword'") - len_ev = len(ev.value) - 2 loc = Location(path="foobar.py", line=35, spanId=123) - v = Vulnerability(type="VulnerabilityType", evidence=ev, location=loc) + v = Vulnerability(type=VULN_SQL_INJECTION, evidence=ev, location=loc) s = Source(origin="SomeOrigin", name="SomeName", value="SomeSecretPassword") report = 
IastSpanReporter([s], {v}) redacted_report = SqlInjection._redact_report(report) for v in redacted_report.vulnerabilities: - assert v.evidence.redacted - assert v.evidence.pattern == "'%s'" % ("*" * len_ev) assert not v.evidence.value for s in redacted_report.sources: assert s.redacted @@ -159,7 +144,7 @@ def test_redacted_report_valueparts(): ] ) loc = Location(path="foobar.py", line=35, spanId=123) - v = Vulnerability(type="VulnerabilityType", evidence=ev, location=loc) + v = Vulnerability(type=VULN_SQL_INJECTION, evidence=ev, location=loc) s = Source(origin="SomeOrigin", name="SomeName", value="SomeValue") report = IastSpanReporter([s], {v}) @@ -183,7 +168,7 @@ def test_redacted_report_valueparts_username_not_tainted(): ] ) loc = Location(path="foobar.py", line=35, spanId=123) - v = Vulnerability(type="VulnerabilityType", evidence=ev, location=loc) + v = Vulnerability(type=VULN_SQL_INJECTION, evidence=ev, location=loc) s = Source(origin="SomeOrigin", name="SomeName", value="SomeValue") report = IastSpanReporter([s], {v}) @@ -211,7 +196,7 @@ def test_redacted_report_valueparts_username_tainted(): ] ) loc = Location(path="foobar.py", line=35, spanId=123) - v = Vulnerability(type="VulnerabilityType", evidence=ev, location=loc) + v = Vulnerability(type=VULN_SQL_INJECTION, evidence=ev, location=loc) s = Source(origin="SomeOrigin", name="SomeName", value="SomeValue") report = IastSpanReporter([s], {v}) @@ -237,7 +222,7 @@ def test_regression_ci_failure(): ] ) loc = Location(path="foobar.py", line=35, spanId=123) - v = Vulnerability(type="VulnerabilityType", evidence=ev, location=loc) + v = Vulnerability(type=VULN_SQL_INJECTION, evidence=ev, location=loc) s = Source(origin="SomeOrigin", name="SomeName", value="SomeValue") report = IastSpanReporter([s], {v}) @@ -250,122 +235,3 @@ def test_regression_ci_failure(): {"redacted": True}, {"value": "'"}, ] - - -def test_scrub_cache(tracer): - valueParts1 = [ - {"value": "SELECT * FROM users WHERE password = '"}, - {"value": 
"1234", "source": 0}, - {"value": ":{SHA1}'"}, - ] - # valueParts will be modified to be scrubbed, thus these copies - valueParts1_copy1 = copy.deepcopy(valueParts1) - valueParts1_copy2 = copy.deepcopy(valueParts1) - valueParts1_copy3 = copy.deepcopy(valueParts1) - valueParts2 = [ - {"value": "SELECT * FROM users WHERE password = '"}, - {"value": "123456", "source": 0}, - {"value": ":{SHA1}'"}, - ] - - s1 = Source(origin="SomeOrigin", name="SomeName", value="SomeValue") - s2 = Source(origin="SomeOtherOrigin", name="SomeName", value="SomeValue") - - env = {"DD_IAST_REQUEST_SAMPLING": "100", "DD_IAST_ENABLED": "true"} - with override_env(env): - oce.reconfigure() - with tracer.trace("test1") as span: - oce.acquire_request(span) - VulnerabilityBase._redacted_report_cache = LFUCache() - SqlInjection.report(evidence_value=valueParts1, sources=[s1]) - span_report1 = core.get_item(IAST.CONTEXT_KEY, span=span) - assert span_report1, "no report: check that get_info_frame is not skipping this frame" - assert list(span_report1.vulnerabilities)[0].evidence == Evidence( - value=None, - pattern=None, - valueParts=[ - {"value": "SELECT * FROM users WHERE password = '"}, - {"redacted": True}, - {"value": ":{SHA1}'"}, - ], - ) - assert len(VulnerabilityBase._redacted_report_cache) == 1 - oce.release_request() - - # Should be the same report object - with tracer.trace("test2") as span: - oce.acquire_request(span) - SqlInjection.report(evidence_value=valueParts1_copy1, sources=[s1]) - span_report2 = core.get_item(IAST.CONTEXT_KEY, span=span) - assert list(span_report2.vulnerabilities)[0].evidence == Evidence( - value=None, - pattern=None, - valueParts=[ - {"value": "SELECT * FROM users WHERE password = '"}, - {"redacted": True}, - {"value": ":{SHA1}'"}, - ], - ) - assert id(span_report1) == id(span_report2) - assert span_report1 is span_report2 - assert len(VulnerabilityBase._redacted_report_cache) == 1 - oce.release_request() - - # Different report, other valueParts - with 
tracer.trace("test3") as span: - oce.acquire_request(span) - SqlInjection.report(evidence_value=valueParts2, sources=[s1]) - span_report3 = core.get_item(IAST.CONTEXT_KEY, span=span) - assert list(span_report3.vulnerabilities)[0].evidence == Evidence( - value=None, - pattern=None, - valueParts=[ - {"value": "SELECT * FROM users WHERE password = '"}, - {"redacted": True}, - {"value": ":{SHA1}'"}, - ], - ) - assert id(span_report1) != id(span_report3) - assert span_report1 is not span_report3 - assert len(VulnerabilityBase._redacted_report_cache) == 2 - oce.release_request() - - # Different report, other source - with tracer.trace("test4") as span: - oce.acquire_request(span) - SqlInjection.report(evidence_value=valueParts1_copy2, sources=[s2]) - span_report4 = core.get_item(IAST.CONTEXT_KEY, span=span) - assert list(span_report4.vulnerabilities)[0].evidence == Evidence( - value=None, - pattern=None, - valueParts=[ - {"value": "SELECT * FROM users WHERE password = '"}, - {"redacted": True}, - {"value": ":{SHA1}'"}, - ], - ) - assert id(span_report1) != id(span_report4) - assert span_report1 is not span_report4 - assert len(VulnerabilityBase._redacted_report_cache) == 3 - oce.release_request() - - # Same as previous so cache should not increase - with tracer.trace("test4") as span: - oce.acquire_request(span) - SqlInjection.report(evidence_value=valueParts1_copy3, sources=[s2]) - span_report5 = core.get_item(IAST.CONTEXT_KEY, span=span) - assert list(span_report5.vulnerabilities)[0].evidence == Evidence( - value=None, - pattern=None, - valueParts=[ - {"value": "SELECT * FROM users WHERE password = '"}, - {"redacted": True}, - {"value": ":{SHA1}'"}, - ], - ) - assert id(span_report1) != id(span_report5) - assert span_report1 is not span_report5 - assert id(span_report4) == id(span_report5) - assert span_report4 is span_report5 - assert len(VulnerabilityBase._redacted_report_cache) == 3 - oce.release_request() diff --git a/tests/appsec/iast/taint_sinks/test_ssrf.py 
b/tests/appsec/iast/taint_sinks/test_ssrf.py index 25e133830ec..49053f0b07b 100644 --- a/tests/appsec/iast/taint_sinks/test_ssrf.py +++ b/tests/appsec/iast/taint_sinks/test_ssrf.py @@ -39,25 +39,26 @@ def test_ssrf(tracer, iast_span_defaults): pass span_report = core.get_item(IAST.CONTEXT_KEY, span=iast_span_defaults) assert span_report + data = span_report.build_and_scrub_value_parts() - vulnerability = list(span_report.vulnerabilities)[0] - source = span_report.sources[0] - assert vulnerability.type == VULN_SSRF - assert vulnerability.evidence.valueParts == [ + vulnerability = data["vulnerabilities"][0] + source = data["sources"][0] + assert vulnerability["type"] == VULN_SSRF + assert vulnerability["evidence"]["valueParts"] == [ {"value": "http://localhost/"}, {"source": 0, "value": tainted_path}, ] - assert vulnerability.evidence.value is None - assert vulnerability.evidence.pattern is None - assert vulnerability.evidence.redacted is None - assert source.name == "test_ssrf" - assert source.origin == OriginType.PARAMETER - assert source.value == tainted_path + assert "value" not in vulnerability["evidence"].keys() + assert vulnerability["evidence"].get("pattern") is None + assert vulnerability["evidence"].get("redacted") is None + assert source["name"] == "test_ssrf" + assert source["origin"] == OriginType.PARAMETER + assert source["value"] == tainted_path line, hash_value = get_line_and_hash("test_ssrf", VULN_SSRF, filename=FIXTURES_PATH) - assert vulnerability.location.path == FIXTURES_PATH - assert vulnerability.location.line == line - assert vulnerability.hash == hash_value + assert vulnerability["location"]["path"] == FIXTURES_PATH + assert vulnerability["location"]["line"] == line + assert vulnerability["hash"] == hash_value @pytest.mark.parametrize("num_vuln_expected", [1, 0, 0]) diff --git a/tests/appsec/iast/taint_sinks/test_ssrf_redacted.py b/tests/appsec/iast/taint_sinks/test_ssrf_redacted.py index ca43fcb5112..aa329cb551e 100644 --- 
a/tests/appsec/iast/taint_sinks/test_ssrf_redacted.py +++ b/tests/appsec/iast/taint_sinks/test_ssrf_redacted.py @@ -3,12 +3,14 @@ import pytest from ddtrace.appsec._constants import IAST +from ddtrace.appsec._iast._taint_tracking import origin_to_str from ddtrace.appsec._iast._taint_tracking import str_to_origin +from ddtrace.appsec._iast._taint_tracking import taint_pyobject +from ddtrace.appsec._iast._taint_tracking.aspects import add_aspect from ddtrace.appsec._iast.constants import VULN_SSRF from ddtrace.appsec._iast.reporter import Evidence from ddtrace.appsec._iast.reporter import IastSpanReporter from ddtrace.appsec._iast.reporter import Location -from ddtrace.appsec._iast.reporter import Source from ddtrace.appsec._iast.reporter import Vulnerability from ddtrace.appsec._iast.taint_sinks.ssrf import SSRF from ddtrace.internal import core @@ -45,58 +47,72 @@ def test_ssrf_redaction_suite(evidence_input, sources_expected, vulnerabilities_ span_report = core.get_item(IAST.CONTEXT_KEY, span=iast_span_defaults) assert span_report - vulnerability = list(span_report.vulnerabilities)[0] + span_report.build_and_scrub_value_parts() + result = span_report._to_dict() + vulnerability = list(result["vulnerabilities"])[0] + source = list(result["sources"])[0] + source["origin"] = origin_to_str(source["origin"]) - assert vulnerability.type == VULN_SSRF - assert vulnerability.evidence.valueParts == vulnerabilities_expected["evidence"]["valueParts"] + assert vulnerability["type"] == VULN_SSRF + assert source == sources_expected -def test_cmdi_redact_param(): +def test_ssrf_redact_param(): + password_taint_range = taint_pyobject(pyobject="test1234", source_name="password", source_value="test1234") + ev = Evidence( - valueParts=[ - {"value": "https://www.domain1.com/?id="}, - {"value": "test1234", "source": 0}, - {"value": "¶m2=value2¶m3=value3¶m3=value3"}, - ] + value=add_aspect( + "https://www.domain1.com/?id=", + add_aspect(password_taint_range, 
"¶m2=value2¶m3=value3¶m3=value3"), + ) ) + loc = Location(path="foobar.py", line=35, spanId=123) - v = Vulnerability(type="VulnerabilityType", evidence=ev, location=loc) - s = Source(origin="http.request.parameter.name", name="password", value="test1234") - report = IastSpanReporter([s], {v}) - - redacted_report = SSRF._redact_report(report) - for v in redacted_report.vulnerabilities: - assert v.evidence.valueParts == [ - {"value": "https://www.domain1.com/?id="}, + v = Vulnerability(type=VULN_SSRF, evidence=ev, location=loc) + report = IastSpanReporter(vulnerabilities={v}) + report.add_ranges_to_evidence_and_extract_sources(v) + result = report.build_and_scrub_value_parts() + + assert result["vulnerabilities"] + for v in result["vulnerabilities"]: + assert v["evidence"]["valueParts"] == [ + {"value": "https://www.domain1.com/"}, + {"redacted": True}, {"pattern": "abcdefgh", "redacted": True, "source": 0}, - {"value": "¶m2=value2¶m3=value3¶m3=value3"}, + {"redacted": True}, + {"redacted": True}, + {"redacted": True}, ] def test_cmdi_redact_user_password(): + user_taint_range = taint_pyobject(pyobject="root", source_name="username", source_value="root") + password_taint_range = taint_pyobject( + pyobject="superpasswordsecure", source_name="password", source_value="superpasswordsecure" + ) + ev = Evidence( - valueParts=[ - {"value": "https://"}, - {"value": "root", "source": 0}, - {"value": ":"}, - {"value": "superpasswordsecure", "source": 1}, - {"value": "@domain1.com/?id="}, - {"value": "¶m2=value2¶m3=value3¶m3=value3"}, - ] + value=add_aspect( + "https://", + add_aspect( + add_aspect(add_aspect(user_taint_range, ":"), password_taint_range), + "@domain1.com/?id=¶m2=value2¶m3=value3¶m3=value3", + ), + ) ) + loc = Location(path="foobar.py", line=35, spanId=123) - v = Vulnerability(type="VulnerabilityType", evidence=ev, location=loc) - s1 = Source(origin="http.request.parameter.name", name="username", value="root") - s2 = Source(origin="http.request.parameter.name", 
name="password", value="superpasswordsecure") - report = IastSpanReporter([s1, s2], {v}) - - redacted_report = SSRF._redact_report(report) - for v in redacted_report.vulnerabilities: - assert v.evidence.valueParts == [ + v = Vulnerability(type=VULN_SSRF, evidence=ev, location=loc) + report = IastSpanReporter(vulnerabilities={v}) + report.add_ranges_to_evidence_and_extract_sources(v) + result = report.build_and_scrub_value_parts() + + assert result["vulnerabilities"] + for v in result["vulnerabilities"]: + assert v["evidence"]["valueParts"] == [ {"value": "https://"}, {"pattern": "abcd", "redacted": True, "source": 0}, {"value": ":"}, - {"source": 1, "value": "superpasswordsecure"}, - {"value": "@domain1.com/?id="}, - {"value": "¶m2=value2¶m3=value3¶m3=value3"}, + {"pattern": "abcdefghijklmnopqrs", "redacted": True, "source": 1}, + {"value": "@domain1.com/?id=¶m2=value2¶m3=value3¶m3=value3"}, ] diff --git a/tests/appsec/iast/taint_sinks/test_weak_randomness.py b/tests/appsec/iast/taint_sinks/test_weak_randomness.py index 602834accb2..f8aa0ab1a71 100644 --- a/tests/appsec/iast/taint_sinks/test_weak_randomness.py +++ b/tests/appsec/iast/taint_sinks/test_weak_randomness.py @@ -39,8 +39,6 @@ def test_weak_randomness(random_func, iast_span_defaults): assert vulnerability.hash == hash_value assert vulnerability.evidence.value == "Random.{}".format(random_func) assert vulnerability.evidence.valueParts is None - assert vulnerability.evidence.pattern is None - assert vulnerability.evidence.redacted is None @pytest.mark.skipif(WEEK_RANDOMNESS_PY_VERSION, reason="Some random methods exists on 3.9 or higher") @@ -73,8 +71,6 @@ def test_weak_randomness_module(random_func, iast_span_defaults): assert vulnerability.hash == hash_value assert vulnerability.evidence.value == "Random.{}".format(random_func) assert vulnerability.evidence.valueParts is None - assert vulnerability.evidence.pattern is None - assert vulnerability.evidence.redacted is None 
@pytest.mark.skipif(WEEK_RANDOMNESS_PY_VERSION, reason="Some random methods exists on 3.9 or higher") diff --git a/tests/appsec/iast/test_iast_propagation_path.py b/tests/appsec/iast/test_iast_propagation_path.py index 5456daf540d..9637b692501 100644 --- a/tests/appsec/iast/test_iast_propagation_path.py +++ b/tests/appsec/iast/test_iast_propagation_path.py @@ -13,18 +13,18 @@ FIXTURES_PATH = "tests/appsec/iast/fixtures/propagation_path.py" -def _assert_vulnerability(span_report, value_parts, file_line_label): - vulnerability = list(span_report.vulnerabilities)[0] - assert vulnerability.type == VULN_PATH_TRAVERSAL - assert vulnerability.evidence.valueParts == value_parts - assert vulnerability.evidence.value is None - assert vulnerability.evidence.pattern is None - assert vulnerability.evidence.redacted is None +def _assert_vulnerability(data, value_parts, file_line_label): + vulnerability = data["vulnerabilities"][0] + assert vulnerability["type"] == VULN_PATH_TRAVERSAL + assert vulnerability["evidence"]["valueParts"] == value_parts + assert "value" not in vulnerability["evidence"].keys() + assert "pattern" not in vulnerability["evidence"].keys() + assert "redacted" not in vulnerability["evidence"].keys() line, hash_value = get_line_and_hash(file_line_label, VULN_PATH_TRAVERSAL, filename=FIXTURES_PATH) - assert vulnerability.location.path == FIXTURES_PATH - assert vulnerability.location.line == line - assert vulnerability.hash == hash_value + assert vulnerability["location"]["path"] == FIXTURES_PATH + assert vulnerability["location"]["line"] == line + assert vulnerability["hash"] == hash_value def test_propagation_no_path(iast_span_defaults): @@ -55,19 +55,22 @@ def test_propagation_path_1_origin_1_propagation(origin1, iast_span_defaults): mod.propagation_path_1_source_1_prop(tainted_string) span_report = core.get_item(IAST.CONTEXT_KEY, span=iast_span_defaults) - source = span_report.sources[0] + span_report.build_and_scrub_value_parts() + data = 
span_report._to_dict() + sources = data["sources"] source_value_encoded = str(origin1, encoding="utf-8") if type(origin1) is not str else origin1 - assert source.name == "path" - assert source.origin == OriginType.PATH - assert source.value == source_value_encoded + assert len(sources) == 1 + assert sources[0]["name"] == "path" + assert sources[0]["origin"] == OriginType.PATH + assert sources[0]["value"] == source_value_encoded value_parts = [ {"value": ANY}, {"source": 0, "value": source_value_encoded}, {"value": ".txt"}, ] - _assert_vulnerability(span_report, value_parts, "propagation_path_1_source_1_prop") + _assert_vulnerability(data, value_parts, "propagation_path_1_source_1_prop") @pytest.mark.parametrize( @@ -87,12 +90,15 @@ def test_propagation_path_1_origins_2_propagations(origin1, iast_span_defaults): mod.propagation_path_1_source_2_prop(tainted_string_1) span_report = core.get_item(IAST.CONTEXT_KEY, span=iast_span_defaults) + span_report.build_and_scrub_value_parts() + data = span_report._to_dict() + sources = data["sources"] value_encoded = str(origin1, encoding="utf-8") if type(origin1) is not str else origin1 - sources = span_report.sources + assert len(sources) == 1 - assert sources[0].name == "path1" - assert sources[0].origin == OriginType.PATH - assert sources[0].value == value_encoded + assert sources[0]["name"] == "path1" + assert sources[0]["origin"] == OriginType.PATH + assert sources[0]["value"] == value_encoded value_parts = [ {"value": ANY}, @@ -100,14 +106,14 @@ def test_propagation_path_1_origins_2_propagations(origin1, iast_span_defaults): {"source": 0, "value": value_encoded}, {"value": ".txt"}, ] - _assert_vulnerability(span_report, value_parts, "propagation_path_1_source_2_prop") + _assert_vulnerability(data, value_parts, "propagation_path_1_source_2_prop") @pytest.mark.parametrize( "origin1, origin2", [ ("taintsource1", "taintsource2"), - ("taintsource", "taintsource"), + # ("taintsource", "taintsource"), TODO: invalid source pos 
("1", "1"), (b"taintsource1", "taintsource2"), (b"taintsource1", b"taintsource2"), @@ -130,35 +136,37 @@ def test_propagation_path_2_origins_2_propagations(origin1, origin2, iast_span_d mod.propagation_path_2_source_2_prop(tainted_string_1, tainted_string_2) span_report = core.get_item(IAST.CONTEXT_KEY, span=iast_span_defaults) + span_report.build_and_scrub_value_parts() + data = span_report._to_dict() + sources = data["sources"] - sources = span_report.sources assert len(sources) == 2 source1_value_encoded = str(origin1, encoding="utf-8") if type(origin1) is not str else origin1 - assert sources[0].name == "path1" - assert sources[0].origin == OriginType.PATH - assert sources[0].value == source1_value_encoded + assert sources[0]["name"] == "path1" + assert sources[0]["origin"] == OriginType.PATH + assert sources[0]["value"] == source1_value_encoded source2_value_encoded = str(origin2, encoding="utf-8") if type(origin2) is not str else origin2 - assert sources[1].name == "path2" - assert sources[1].origin == OriginType.PARAMETER - assert sources[1].value == source2_value_encoded - + assert sources[1]["name"] == "path2" + assert sources[1]["origin"] == OriginType.PARAMETER + assert sources[1]["value"] == source2_value_encoded value_parts = [ {"value": ANY}, {"source": 0, "value": source1_value_encoded}, {"source": 1, "value": source2_value_encoded}, {"value": ".txt"}, ] - _assert_vulnerability(span_report, value_parts, "propagation_path_2_source_2_prop") + _assert_vulnerability(data, value_parts, "propagation_path_2_source_2_prop") @pytest.mark.parametrize( "origin1, origin2", [ ("taintsource1", "taintsource2"), - ("taintsource", "taintsource"), + # ("taintsource", "taintsource"), TODO: invalid source pos ("1", "1"), (b"taintsource1", "taintsource2"), + # (b"taintsource", "taintsource"), TODO: invalid source pos (b"taintsource1", b"taintsource2"), ("taintsource1", b"taintsource2"), (bytearray(b"taintsource1"), "taintsource2"), @@ -179,18 +187,20 @@ def 
test_propagation_path_2_origins_3_propagation(origin1, origin2, iast_span_de mod.propagation_path_3_prop(tainted_string_1, tainted_string_2) span_report = core.get_item(IAST.CONTEXT_KEY, span=iast_span_defaults) + span_report.build_and_scrub_value_parts() + data = span_report._to_dict() + sources = data["sources"] - sources = span_report.sources assert len(sources) == 2 source1_value_encoded = str(origin1, encoding="utf-8") if type(origin1) is not str else origin1 - assert sources[0].name == "path1" - assert sources[0].origin == OriginType.PATH - assert sources[0].value == source1_value_encoded + assert sources[0]["name"] == "path1" + assert sources[0]["origin"] == OriginType.PATH + assert sources[0]["value"] == source1_value_encoded source2_value_encoded = str(origin2, encoding="utf-8") if type(origin2) is not str else origin2 - assert sources[1].name == "path2" - assert sources[1].origin == OriginType.PARAMETER - assert sources[1].value == source2_value_encoded + assert sources[1]["name"] == "path2" + assert sources[1]["origin"] == OriginType.PARAMETER + assert sources[1]["value"] == source2_value_encoded value_parts = [ {"value": ANY}, @@ -204,7 +214,7 @@ def test_propagation_path_2_origins_3_propagation(origin1, origin2, iast_span_de {"source": 1, "value": source2_value_encoded}, {"value": ".txt"}, ] - _assert_vulnerability(span_report, value_parts, "propagation_path_3_prop") + _assert_vulnerability(data, value_parts, "propagation_path_3_prop") @pytest.mark.parametrize( @@ -233,13 +243,14 @@ def test_propagation_path_2_origins_5_propagation(origin1, origin2, iast_span_de mod.propagation_path_5_prop(tainted_string_1, tainted_string_2) span_report = core.get_item(IAST.CONTEXT_KEY, span=iast_span_defaults) - - sources = span_report.sources + span_report.build_and_scrub_value_parts() + data = span_report._to_dict() + sources = data["sources"] assert len(sources) == 1 source1_value_encoded = str(origin1, encoding="utf-8") if type(origin1) is not str else origin1 - 
assert sources[0].name == "path1" - assert sources[0].origin == OriginType.PATH - assert sources[0].value == source1_value_encoded + assert sources[0]["name"] == "path1" + assert sources[0]["origin"] == OriginType.PATH + assert sources[0]["value"] == source1_value_encoded value_parts = [{"value": ANY}, {"source": 0, "value": "aint"}, {"value": ".txt"}] - _assert_vulnerability(span_report, value_parts, "propagation_path_5_prop") + _assert_vulnerability(data, value_parts, "propagation_path_5_prop") diff --git a/tests/appsec/integrations/test_langchain.py b/tests/appsec/integrations/test_langchain.py index d1e86e6ab68..325bfe670d5 100644 --- a/tests/appsec/integrations/test_langchain.py +++ b/tests/appsec/integrations/test_langchain.py @@ -33,21 +33,23 @@ def test_openai_llm_appsec_iast_cmdi(iast_span_defaults): # noqa: F811 span_report = core.get_item(IAST.CONTEXT_KEY, span=iast_span_defaults) assert span_report - - vulnerability = list(span_report.vulnerabilities)[0] - source = span_report.sources[0] - assert vulnerability.type == VULN_CMDI - assert vulnerability.evidence.valueParts == [ - {"value": "echo Hello World", "source": 0}, + data = span_report.build_and_scrub_value_parts() + vulnerability = data["vulnerabilities"][0] + source = data["sources"][0] + assert vulnerability["type"] == VULN_CMDI + assert vulnerability["evidence"]["valueParts"] == [ + {"source": 0, "value": "echo "}, + {"pattern": "", "redacted": True, "source": 0}, + {"source": 0, "value": "Hello World"}, ] - assert vulnerability.evidence.value is None - assert vulnerability.evidence.pattern is None - assert vulnerability.evidence.redacted is None - assert source.name == "test_openai_llm_appsec_iast_cmdi" - assert source.origin == OriginType.PARAMETER - assert source.value == string_to_taint + assert "value" not in vulnerability["evidence"].keys() + assert vulnerability["evidence"].get("pattern") is None + assert vulnerability["evidence"].get("redacted") is None + assert source["name"] == 
"test_openai_llm_appsec_iast_cmdi" + assert source["origin"] == OriginType.PARAMETER + assert "value" not in source.keys() line, hash_value = get_line_and_hash("test_openai_llm_appsec_iast_cmdi", VULN_CMDI, filename=FIXTURES_PATH) - assert vulnerability.location.path == FIXTURES_PATH - assert vulnerability.location.line == line - assert vulnerability.hash == hash_value + assert vulnerability["location"]["path"] == FIXTURES_PATH + assert vulnerability["location"]["line"] == line + assert vulnerability["hash"] == hash_value diff --git a/tests/commands/test_runner.py b/tests/commands/test_runner.py index 9a8fa33cbf0..a1439d6accd 100644 --- a/tests/commands/test_runner.py +++ b/tests/commands/test_runner.py @@ -334,49 +334,36 @@ def test_info_no_configs(): ) p.wait() stdout = p.stdout.read() - assert (stdout) == ( - b"""\x1b[94m\x1b[1mTracer Configurations:\x1b[0m - Tracer enabled: True - Application Security enabled: False - Remote Configuration enabled: False - IAST enabled (experimental): False - Debug logging: False - Writing traces to: http://localhost:8126 - Agent error: Agent not reachable at http://localhost:8126. 
""" - + b"""Exception raised: [Errno 99] Cannot assign requested address\n""" - b""" App Analytics enabled(deprecated): False - Log injection enabled: False - Health metrics enabled: False - Priority sampling enabled: True - Partial flushing enabled: True - Partial flush minimum number of spans: 300 - WAF timeout: 5.0 msecs - \x1b[92m\x1b[1mTagging:\x1b[0m - DD Service: None - DD Env: None - DD Version: None - Global Tags: None - Tracer Tags: None - -\x1b[96m\x1b[1mSummary\x1b[0m""" - b"""\n\n\x1b[91mERROR: It looks like you have an agent error: 'Agent not reachable at http://localhost:8126.""" - b""" Exception raised: [Errno 99] Cannot assign requested address'\n""" - b""" If you're experiencing a connection error, please """ - b"""make sure you've followed the setup for your particular environment so that the tracer and Datadog """ - b"""agent are configured properly to connect, and that the Datadog agent is running:""" - b""" https://ddtrace.readthedocs.io/en/stable/troubleshooting.html""" - b"""#failed-to-send-traces-connectionrefusederror""" - b"""\nIf your issue is not a connection error then please reach out to support for further assistance:""" - b""" https://docs.datadoghq.com/help/\x1b[0m""" - b"""\n\n\x1b[93mWARNING SERVICE NOT SET: It is recommended that a service tag be set for all traced """ - b"""applications. For more information please see""" - b""" https://ddtrace.readthedocs.io/en/stable/troubleshooting.html\x1b[0m""" - b"""\n\n\x1b[93mWARNING ENV NOT SET: It is recommended that an env tag be set for all traced applications. """ - b"""For more information please see https://ddtrace.readthedocs.io/en/stable/troubleshooting.html\x1b[0m""" - b"""\n\n\x1b[93mWARNING VERSION NOT SET: """ - b"""It is recommended that a version tag be set for all traced applications. 
""" - b"""For more information please see https://ddtrace.readthedocs.io/en/stable/troubleshooting.html\x1b[0m\n""" - ) + # checks most of the output but some pieces are removed due to the dynamic nature of the output + expected_strings = [ + b"\x1b[1mTracer Configurations:\x1b[0m", + b"Tracer enabled: True", + b"Application Security enabled: False", + b"Remote Configuration enabled: False", + b"Debug logging: False", + b"App Analytics enabled(deprecated): False", + b"Log injection enabled: False", + b"Health metrics enabled: False", + b"Priority sampling enabled: True", + b"Partial flushing enabled: True", + b"Partial flush minimum number of spans: 300", + b"WAF timeout: 5.0 msecs", + b"Tagging:", + b"DD Service: None", + b"DD Env: None", + b"DD Version: None", + b"Global Tags: None", + b"Tracer Tags: None", + b"Summary", + b"WARNING SERVICE NOT SET: It is recommended that a service tag be set for all traced applications.", + b"For more information please see https://ddtrace.readthedocs.io/en/stable/troubleshooting.html\x1b[0m", + b"WARNING ENV NOT SET: It is recommended that an env tag be set for all traced applications. 
For more", + b"information please see https://ddtrace.readthedocs.io/en/stable/troubleshooting.html\x1b[0m", + b"WARNING VERSION NOT SET: It is recommended that a version tag be set for all traced applications.", + b"For more information please see https://ddtrace.readthedocs.io/en/stable/troubleshooting.html\x1b[0m", + ] + for expected in expected_strings: + assert expected in stdout, f"Expected string not found in output: {expected.decode()}" assert p.returncode == 0 @@ -405,43 +392,31 @@ def test_info_w_configs(): p.wait() stdout = p.stdout.read() - assert ( - (stdout) - == b"""\x1b[94m\x1b[1mTracer Configurations:\x1b[0m - Tracer enabled: True - Application Security enabled: True - Remote Configuration enabled: True - IAST enabled (experimental): True - Debug logging: True - Writing traces to: http://168.212.226.204:8126 - Agent error: Agent not reachable at http://168.212.226.204:8126. Exception raised: timed out - App Analytics enabled(deprecated): False - Log injection enabled: True - Health metrics enabled: False - Priority sampling enabled: True - Partial flushing enabled: True - Partial flush minimum number of spans: 1000 - WAF timeout: 5.0 msecs - \x1b[92m\x1b[1mTagging:\x1b[0m - DD Service: tester - DD Env: dev - DD Version: 0.45 - Global Tags: None - Tracer Tags: None - -\x1b[96m\x1b[1mSummary\x1b[0m""" - b"""\n\n\x1b[91mERROR: It looks like you have an agent error: """ - b"""'Agent not reachable at http://168.212.226.204:8126. 
""" - b"""Exception raised: timed out'\n If you're experiencing a connection error, """ - b"""please make sure you've followed the """ - b"""setup for your particular environment so that the tracer and """ - b"""Datadog agent are configured properly to connect,""" - b""" and that the Datadog agent is running:""" - b""" https://ddtrace.readthedocs.io/en/stable/troubleshooting.html#failed-to-send-traces-""" - b"""connectionrefusederror\n""" - b"""If your issue is not a connection error then please reach out to support for further assistance: """ - b"""https://docs.datadoghq.com/help/\x1b[0m\n""" - ) + # checks most of the output but some pieces are removed due to the dynamic nature of the output + expected_strings = [ + b"1mTracer Configurations:\x1b[0m", + b"Tracer enabled: True", + b"Remote Configuration enabled: True", + b"IAST enabled (experimental)", + b"Debug logging: True", + b"App Analytics enabled(deprecated): False", + b"Log injection enabled: True", + b"Health metrics enabled: False", + b"Priority sampling enabled: True", + b"Partial flushing enabled: True", + b"Partial flush minimum number of spans: 1000", + b"WAF timeout: 5.0 msecs", + b"Tagging:", + b"DD Service: tester", + b"DD Env: dev", + b"DD Version: 0.45", + b"Global Tags: None", + b"Tracer Tags: None", + b"m\x1b[1mSummary\x1b[0m", + ] + + for expected in expected_strings: + assert expected in stdout, f"Expected string not found in output: {expected.decode()}" assert p.returncode == 0 diff --git a/tests/contrib/botocore/bedrock_cassettes/amazon_embedding.yaml b/tests/contrib/botocore/bedrock_cassettes/amazon_embedding.yaml new file mode 100644 index 00000000000..b6cdf04f298 --- /dev/null +++ b/tests/contrib/botocore/bedrock_cassettes/amazon_embedding.yaml @@ -0,0 +1,44 @@ +interactions: +- request: + body: '{"inputText": "Hello World!"}' + headers: + Content-Length: + - '29' + User-Agent: + - !!binary | + Qm90bzMvMS4zNC40OSBtZC9Cb3RvY29yZSMxLjM0LjQ5IHVhLzIuMCBvcy9tYWNvcyMyMy40LjAg + 
bWQvYXJjaCNhcm02NCBsYW5nL3B5dGhvbiMzLjEwLjUgbWQvcHlpbXBsI0NQeXRob24gY2ZnL3Jl + dHJ5LW1vZGUjbGVnYWN5IEJvdG9jb3JlLzEuMzQuNDk= + X-Amz-Date: + - !!binary | + MjAyNDA0MjNUMjA1NzAzWg== + amz-sdk-invocation-id: + - !!binary | + ZTUyMjJhZGQtNGI3My00YjM4LThhZmEtZTkxNmI1NmJkZTky + amz-sdk-request: + - !!binary | + YXR0ZW1wdD0x + method: POST + uri: https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-embed-text-v1/invoke + response: + body: + string: '{"embedding":[0.45703125,0.30078125,0.41210938,0.41015625,0.74609375,0.3828125,-0.018798828,-0.0010681152,0.515625,-0.25195312,-0.20898438,-0.15332031,-0.22460938,0.13671875,-0.34375,0.037109375,-0.03491211,-0.29882812,-0.875,0.21289062,-0.39648438,0.26171875,-0.49609375,-0.15527344,-0.014465332,0.25,-0.1796875,-0.46875,0.82421875,0.1796875,-0.34960938,1.2265625,-0.079589844,-0.3046875,-0.93359375,0.3828125,-0.031982422,-0.74609375,-0.49023438,0.546875,0.080078125,0.328125,-0.1484375,-0.30078125,0.54296875,0.58203125,-0.1875,-0.17773438,0.59375,1.0859375,-0.20703125,0.99609375,-0.25195312,0.6875,-0.10839844,-0.11035156,0.66015625,-0.12695312,-0.0079956055,-0.37695312,-0.038085938,-0.34570312,-0.62890625,0.94921875,0.12597656,0.125,-0.12109375,0.3515625,-0.17285156,0.6328125,0.41210938,0.48828125,0.19140625,0.6640625,-0.67578125,0.23730469,0.80859375,-0.16796875,0.43164062,0.27148438,0.42578125,0.06298828,-0.3984375,-0.27148438,-0.64453125,-0.04638672,0.19726562,-0.095214844,1.3542175E-4,-0.98828125,-0.048828125,-0.25,1.0078125,-0.087402344,-0.49609375,0.359375,-0.9765625,-0.5078125,0.06933594,-0.24511719,0.9921875,0.1875,0.08984375,0.36132812,0.14355469,0.52734375,-0.08251953,0.78125,0.09863281,0.49414062,0.24316406,-0.6796875,0.875,0.051757812,0.42382812,0.11425781,0.69140625,-1.0546875,-0.19824219,-0.3984375,0.18359375,-0.14648438,0.6015625,0.51171875,0.2734375,0.2265625,0.19824219,-0.640625,0.14648438,0.072265625,0.48828125,0.23730469,1.3125,-0.1484375,-0.25195312,-0.15527344,-0.37695312,0.26757812,-0.55078125,
-0.98828125,-0.03112793,-0.26367188,0.19238281,0.375,0.0146484375,0.037109375,-0.53515625,0.26953125,0.16796875,-0.44921875,0.4453125,0.21679688,0.26367188,-0.48046875,0.734375,0.3671875,0.3359375,0.3203125,-0.171875,-0.0703125,0.20800781,0.62109375,-1.1328125,0.025390625,-0.52734375,-0.41601562,-0.6796875,0.3671875,-0.21679688,-0.54296875,1.1875,0.21679688,-0.34179688,0.12060547,-0.24804688,0.057373047,0.27734375,0.20800781,0.016235352,0.51953125,0.515625,-0.2421875,0.24023438,-0.890625,-0.27539062,-0.095703125,-0.58203125,-0.7578125,0.055908203,0.95703125,-0.11328125,0.2421875,0.20800781,-0.024414062,-0.328125,0.45507812,-0.13769531,0.34375,-0.16601562,-0.26757812,-0.15039062,0.546875,0.38671875,0.009399414,-0.11767578,-0.21191406,-0.071777344,-0.040283203,-0.42773438,0.06347656,-0.111328125,-0.14648438,-0.32226562,-1.171875,0.6484375,-0.5,0.10107422,-0.20507812,0.16796875,-0.022949219,0.30273438,-0.96484375,0.30664062,0.45898438,-0.31445312,0.31054688,0.53125,0.37304688,-0.08154297,0.6015625,-0.6171875,0.296875,-0.63671875,0.48242188,0.087402344,0.30859375,-0.6484375,-0.10498047,-0.42578125,0.62890625,0.33007812,-0.42382812,-0.24121094,-0.17480469,0.10498047,0.3984375,0.28710938,0.27929688,-0.064453125,-0.17480469,-0.15039062,0.546875,-0.39257812,-0.2734375,-0.32421875,-0.025268555,-0.51953125,-0.38867188,-0.3359375,-0.09033203,0.24511719,0.056152344,0.40039062,-0.51953125,0.203125,-0.65625,0.13671875,0.44140625,-0.064453125,-1.1171875,1.3256073E-4,0.11230469,0.41601562,-0.18066406,-0.19238281,0.02722168,-1.25,0.76171875,0.19238281,-0.4375,0.359375,-0.20507812,0.84765625,-0.41601562,-0.25390625,-1.609375,-0.28125,0.3984375,0.05053711,0.080078125,-0.609375,-0.107910156,-0.64453125,0.21484375,0.8046875,0.30273438,0.43554688,0.17578125,-0.37109375,0.12890625,-0.18066406,0.25976562,0.6015625,0.2734375,0.24609375,0.15136719,0.18945312,0.08935547,-0.35546875,0.3671875,-0.39257812,-0.095214844,0.47070312,-0.22753906,0.84375,-0.14160156,-0.103515625,0.18847656,-0.2734375
,-0.1953125,-0.56640625,-0.734375,-0.79296875,0.05053711,0.103515625,-0.2890625,-0.21191406,-0.6484375,0.122558594,-0.42382812,-0.042236328,0.3125,-0.014343262,0.25195312,0.08251953,-0.80859375,-0.46875,0.92578125,0.53515625,-0.86328125,-0.39257812,-0.18261719,0.90234375,-0.484375,-0.02746582,-0.31835938,-0.083984375,0.24414062,-0.5625,-0.032958984,0.4140625,-0.26953125,-0.51171875,0.13085938,0.84375,-0.8984375,-0.13867188,-0.44726562,-0.29101562,-0.09423828,0.49023438,-0.2890625,0.07470703,-0.65625,-0.453125,-0.09765625,-0.50390625,-0.74609375,0.049072266,0.6171875,-0.515625,1.7265625,0.15039062,-0.28710938,-0.0010757446,0.30273438,-0.20410156,0.63671875,-0.45703125,-0.6796875,0.9765625,-1.0078125,-0.24511719,-0.17382812,-0.49023438,0.76953125,-0.359375,-1.1015625,-1.3125,-0.3984375,0.15039062,0.2265625,0.14648438,-0.03491211,0.09814453,0.08886719,0.83984375,-0.33203125,-0.34960938,-0.20898438,-0.20410156,0.072753906,0.42382812,-1.453125,0.65234375,-0.625,-1.0,-0.14257812,0.2578125,-0.1953125,-0.030639648,0.4375,-0.115234375,0.26757812,0.36523438,-0.36523438,0.26367188,0.0067443848,-0.578125,-1.140625,-0.39257812,0.47070312,0.5859375,0.029785156,1.03125,-0.52734375,0.6640625,-0.09472656,-0.69921875,0.5078125,-0.2890625,-0.27734375,0.66796875,0.578125,0.453125,0.3046875,-0.2890625,0.34179688,-0.41992188,-0.22460938,0.7421875,0.024414062,-0.22265625,-0.265625,-0.35546875,-0.05859375,-0.7890625,-0.14257812,-0.39257812,0.3515625,0.07519531,-0.29296875,-0.43554688,-0.71875,0.15625,0.084472656,-0.07373047,0.38476562,0.4609375,-0.13964844,0.17773438,-0.3828125,-0.06738281,-0.50390625,-0.375,-0.56640625,1.0546875,0.54296875,-0.80078125,-0.11425781,-0.7578125,0.13671875,0.30273438,0.09765625,-0.34765625,-0.0061950684,-0.05883789,0.61328125,-0.3359375,0.21972656,0.53515625,0.4140625,0.37695312,-1.2265625,0.203125,-0.578125,-0.28125,-0.17871094,-0.546875,-0.703125,-0.07080078,1.7578125,-0.6171875,-0.81640625,-0.34765625,0.1328125,-2.6512146E-4,-0.19140625,-0.016723633,0.70312
5,0.22851562,-0.029541016,-0.7265625,0.038085938,-0.44921875,0.29101562,-1.1640625,0.40234375,-0.3671875,-0.09863281,1.125,0.16992188,-0.82421875,0.48828125,0.110839844,0.65625,0.51953125,0.20800781,-0.625,0.22949219,0.057861328,0.16210938,0.57421875,0.265625,-0.8203125,1.1796875,-0.6875,-0.32421875,-1.2421875,-0.14550781,-0.0013809204,-0.43945312,-0.35742188,0.07861328,0.8984375,-0.040283203,-0.07421875,0.43554688,-0.984375,-0.38085938,-0.14746094,-0.5078125,-1.1015625,0.94921875,-0.5234375,0.063964844,-0.38867188,-0.26171875,-0.48242188,0.075683594,0.43554688,-0.36914062,0.21972656,0.018188477,0.053466797,0.92578125,-0.62109375,0.44335938,-0.5859375,-0.59375,0.25585938,-0.20898438,-0.31445312,-0.48632812,0.14355469,-0.19238281,-0.15234375,0.41015625,0.07324219,0.025756836,-0.375,-0.19140625,0.9921875,0.27734375,0.33398438,0.93359375,0.5078125,-0.515625,0.58203125,0.43164062,0.42382812,-0.29296875,0.5625,-0.51171875,-0.6328125,0.5625,0.1015625,0.18359375,0.26757812,-0.765625,-0.29101562,-0.77734375,0.546875,-0.79296875,0.19921875,0.65234375,0.23730469,-1.25,-0.73828125,-0.35351562,0.28125,0.15039062,0.49414062,-0.2578125,0.15429688,0.051513672,-0.55078125,0.119628906,-0.44726562,0.7421875,-0.7109375,0.6328125,0.671875,0.16894531,-1.09375,-0.087890625,0.20703125,0.44140625,-0.3046875,-0.734375,-0.09423828,-0.42578125,-0.79296875,-0.41601562,0.40039062,0.06933594,0.009765625,-1.0546875,0.115722656,-1.2421875,-0.734375,0.890625,0.09277344,0.8125,0.125,1.1875,-0.19824219,0.40234375,0.36328125,0.06933594,-0.515625,0.484375,0.45117188,0.16113281,0.99609375,-0.42578125,0.24804688,1.3671875,-0.4453125,0.16308594,-0.30859375,0.30078125,0.44726562,-0.15136719,-0.036376953,0.068847656,0.33398438,0.19824219,-0.83984375,-0.24511719,-0.35546875,0.54296875,-0.03125,0.35742188,0.9140625,0.4375,0.14355469,-0.48242188,0.52734375,0.20019531,0.31445312,-0.98046875,0.76171875,-0.0022735596,-0.22851562,-0.30859375,0.100097656,-0.30078125,-0.87109375,-0.8125,0.31640625,-0.33789062,-0.535
15625,-0.16992188,-0.296875,1.3359375,0.54296875,-0.24902344,-0.095703125,0.71484375,-0.026977539,0.72265625,0.4375,0.66796875,0.21777344,0.18652344,0.009216309,0.11279297,-0.21777344,1.0859375,-0.65625,0.6328125,-0.008911133,-0.58203125,-0.44726562,0.8125,-0.703125,-0.15917969,0.051513672,0.49609375,0.35351562,0.24023438,0.20996094,-0.5859375,0.26171875,0.6484375,-0.57421875,0.024169922,0.41992188,0.55078125,0.037109375,-0.6171875,-0.01977539,0.58203125,0.88671875,0.29296875,-0.1875,0.3984375,0.48046875,-0.045898438,0.515625,1.2890625,0.41992188,0.3203125,-0.13867188,-0.37109375,-0.47460938,-0.26171875,-0.5546875,-0.49023438,-0.6796875,-0.04736328,-0.25976562,-0.94921875,0.1484375,-0.02319336,-0.7578125,0.7578125,0.08105469,0.59765625,-0.29882812,0.84375,0.26953125,0.09472656,-0.66015625,-0.99609375,0.7265625,0.5703125,-0.09716797,0.55078125,-0.11230469,0.46875,0.546875,0.21582031,-0.19140625,-0.29492188,-0.20507812,-0.1328125,0.3671875,0.28320312,0.072265625,0.091796875,-0.07373047,0.4296875,-0.106933594,-0.033447266,0.22851562,0.84765625,0.3203125,-0.007598877,-0.5078125,0.06689453,0.34179688,0.484375,-1.515625,1.0625,0.6171875,0.08935547,-0.6484375,0.55078125,0.3828125,-0.22363281,0.765625,-0.28320312,0.09765625,0.76953125,-0.45703125,0.05493164,0.42578125,0.18066406,0.12695312,-0.038085938,0.44335938,-0.15722656,0.24414062,0.032958984,0.6484375,0.31640625,0.35351562,0.048339844,0.20410156,-0.38085938,-0.51953125,-0.21191406,-0.48046875,-0.7890625,-0.21972656,0.09375,0.42773438,0.51171875,-0.30859375,1.5546875,-0.17382812,-0.7109375,0.061767578,0.0036621094,0.15234375,-0.076660156,-0.16894531,-0.33007812,-0.05908203,0.6796875,0.13769531,-0.37304688,-0.21972656,-0.014343262,-0.47070312,-0.6484375,1.21875,-0.12451172,-0.046875,0.107910156,-0.37109375,0.057861328,0.51171875,0.640625,0.14648438,0.37695312,0.16601562,-0.24707031,1.4296875,-0.57421875,-0.39257812,-0.4921875,0.45507812,-0.12792969,-0.09033203,-0.31054688,0.10253906,-0.42773438,0.14160156,-0.11376953,-0
.73828125,0.5546875,-0.16796875,0.36132812,0.24609375,-0.8359375,-0.6484375,0.2578125,0.1328125,-0.21191406,-0.23046875,0.33203125,0.23632812,0.59375,0.26367188,-0.08984375,1.0078125,-0.060791016,0.58203125,-0.6015625,0.44921875,0.2109375,-0.08300781,1.2578125,0.4765625,0.072753906,-0.03930664,0.24121094,0.41992188,0.6875,0.46679688,0.41210938,0.08984375,0.59375,0.03173828,-0.6875,-0.08642578,0.69140625,-0.59765625,-0.10888672,0.19238281,0.053222656,0.118652344,-0.13085938,-0.15917969,-0.055419922,-0.23828125,0.25195312,0.057861328,-0.19238281,0.23925781,0.75390625,-0.05810547,0.828125,0.87890625,-0.65234375,-0.55859375,-1.0859375,-0.1328125,-0.00793457,0.013916016,0.19042969,-0.10107422,0.34765625,-0.12695312,-0.14941406,0.375,-0.5078125,-0.22167969,0.4609375,-0.18066406,-0.18359375,-0.51171875,0.40234375,0.6015625,0.29296875,0.453125,0.115722656,-1.265625,4.386902E-5,0.59375,-0.44335938,0.26367188,-0.34960938,-0.8359375,-0.33203125,-0.039794922,0.58203125,0.3203125,0.39648438,-0.43554688,0.0013046265,-0.07373047,-0.7578125,0.31640625,-0.22070312,-0.004272461,-0.60546875,-0.7890625,-0.07861328,-0.69140625,-0.32421875,-0.2734375,0.38476562,-1.25,0.010620117,-1.1953125,0.15234375,-0.66015625,0.265625,-0.08496094,0.33007812,-0.23828125,0.060546875,-0.039794922,-0.17773438,0.8359375,0.34765625,-0.73046875,0.37890625,-0.23632812,-0.45703125,-0.015136719,-0.73828125,-0.076660156,0.11230469,0.45117188,-0.8125,-0.27148438,0.22851562,-0.10498047,-0.03930664,0.59375,-0.62109375,-0.6796875,0.74609375,0.50390625,-0.76171875,-0.70703125,-0.29882812,0.049072266,-0.060546875,-0.43554688,-0.578125,0.20410156,0.1640625,0.040039062,-0.62109375,-0.10839844,-0.19824219,-0.30859375,0.30859375,0.0042419434,0.24511719,0.19238281,0.5859375,0.6328125,0.3359375,-0.88671875,-0.5546875,0.40234375,-0.022460938,0.16113281,-0.04272461,0.81640625,0.98828125,-0.27734375,-0.19921875,-0.55078125,-0.05053711,0.01171875,-0.69921875,-0.15917969,0.43164062,0.22167969,-0.43945312,-0.44921875,0.21484375,-
0.5703125,-0.24707031,0.17578125,0.008483887,0.14453125,-0.36328125,-0.118652344,0.28710938,0.0010604858,0.4453125,0.24609375,-0.83203125,-0.33007812,-0.12402344,-0.30273438,-0.51953125,-0.18066406,-0.17871094,-0.5,-0.34375,0.072753906,0.25390625,0.37304688,-0.12109375,0.35546875,-0.4140625,-0.16308594,-0.23828125,0.52734375,-0.3984375,-0.17578125,-0.17871094,-0.20117188,0.33984375,-0.71875,-0.35546875,-0.14648438,0.020996094,0.11621094,0.90234375,0.21386719,-0.31054688,-0.49023438,-0.61328125,-0.12402344,-0.07421875,0.12207031,0.04321289,-0.3515625,0.06201172,-0.07763672,0.21875,-0.7578125,0.55859375,-0.12597656,0.73046875,0.44335938,1.328125,-0.3515625,-0.0035858154,-0.34765625,0.42382812,-0.17773438,-0.90625,0.83203125,0.16992188,0.1328125,0.64453125,-0.09277344,-0.06640625,-0.21875,0.48046875,0.3359375,0.36328125,1.0703125,0.97265625,-0.14550781,-1.0234375,-0.38671875,0.16308594,0.3828125,-0.47070312,-0.84375,-0.28125,-0.50390625,0.23828125,1.0390625,0.14746094,-0.34570312,0.29882812,-0.37109375,-0.01977539,-0.65234375,0.4453125,0.21875,0.24121094,0.4609375,0.44726562,-0.40039062,-1.28125,-0.81640625,0.546875,1.1640625,-0.22753906,-0.296875,1.1484375,-0.640625,0.1640625,0.5,-0.29101562,0.03930664,-0.32421875,0.12695312,-0.49414062,0.052734375,-0.3671875,0.2890625,-0.08105469,0.640625,-0.546875,0.11816406,1.109375,0.23730469,-0.12890625,-1.4375,-0.3515625,0.80078125,-0.25390625,-0.079589844,-1.0,0.62890625,-0.39453125,-0.72265625,0.34765625,-0.875,-0.46875,0.48242188,0.32617188,-0.060302734,-0.41210938,-0.18457031,0.09472656,-0.8515625,0.83984375,0.3671875,-0.072265625,0.875,0.55859375,-0.33203125,-0.25,0.35742188,0.31445312,0.04248047,0.8125,0.33203125,-0.072753906,-0.18554688,-0.59765625,0.07128906,-0.27148438,-0.25976562,0.08105469,1.0625,-1.421875,-0.09423828,-0.46484375,-0.29492188,-0.65234375,1.3046875,0.3359375,-0.091796875,-0.17578125,0.26171875,0.546875,-0.061279297,-0.15234375,0.65234375,0.2578125,-1.375,0.609375,-0.38867188,-0.265625,-0.859375,0.196289
06,-0.3984375,0.41015625,0.1484375,0.0115356445,-0.44726562,0.5078125,0.54296875,-0.30078125,0.2734375,0.12695312,0.2109375,0.984375,0.060546875,-0.36132812,-1.53125,0.625,-0.71875,-0.41210938,0.99609375,0.061523438,-0.19042969,-0.14648438,-0.3515625,0.051757812,0.2578125,-0.70703125,-0.022705078,-0.035888672,0.3359375,-0.08984375,-0.4609375,-0.038330078,-1.5078125,0.3125,0.3828125,0.42578125,0.22070312,1.0859375,0.703125,0.38085938,0.84765625,-0.40820312,-0.26757812,0.50390625,-0.53125,-0.12158203,-0.32226562,0.234375,-0.68359375,-0.29882812,0.11230469,0.3203125,-0.29882812,0.5,-0.609375,-0.3671875,-0.052001953,0.30859375,-0.24316406,0.5234375,-0.41601562,-0.17578125,0.734375,0.26367188,0.30078125,0.084472656,0.9140625,-0.98828125,-0.70703125,-0.044189453,-0.25976562,-0.7265625,0.9140625,0.017211914,0.6171875,-0.6171875,-0.73828125,0.12451172,-0.13769531,0.30273438,-0.62890625,-0.921875,0.16308594,0.07470703,-0.5,0.59375,0.16015625,0.31445312,-0.11279297,1.875,-0.4140625,0.7421875,0.17773438,0.2421875,-0.23828125,0.421875,0.2265625,-0.84765625,0.2421875,0.005706787,-0.18847656,0.21679688,0.39453125,0.39257812,-0.703125,-0.55078125,-0.74609375,-0.13769531,0.055419922,0.20214844,-0.026367188,-0.59765625,1.4140625,-0.32421875,-0.14550781,0.026977539,0.31054688,-0.0703125,1.4140625,0.46679688,0.65625,-0.17089844,-0.07373047,0.41210938,0.028076172,0.3671875,0.041992188,-0.3515625,-0.1640625,-0.8203125,-0.029785156,-0.03515625,-0.140625,-0.12207031,-0.43945312,0.44140625,-0.0013046265,-0.24804688,-0.041748047,0.13378906,0.81640625,-0.14550781,-0.12792969,-0.15527344,0.6015625,-0.17578125,0.66015625,-0.3984375,0.5234375,-0.21679688,0.14648438,0.2890625,-1.0546875,0.09814453,-0.016967773,-0.013000488,-0.20800781,0.82421875,0.3359375,-0.6953125,0.30273438,0.10205078,-0.828125,-0.29882812,0.42773438,-0.55859375,-1.5625,-0.46289062,-0.25585938,0.68359375,0.66796875,0.27539062,-0.7421875,-0.140625,0.055419922,0.012023926,-0.11328125,0.3671875,-0.37890625,0.75390625,-0.60546875
,0.734375,0.041503906,0.83984375,-0.640625,-0.671875,-0.27929688,-0.24316406,0.4921875,-0.6640625,-0.16210938,-0.29296875,0.4140625,-0.29101562,0.40429688,0.296875,0.875,0.43359375,-0.13964844,0.28515625,-0.359375,0.20996094,-0.23144531,-0.54296875,-0.083984375,-0.28125,0.01574707,-0.18945312,-0.65625,0.05810547,-0.10205078,-0.4921875,0.94921875,-0.78515625,0.122558594,-0.14550781,-0.39453125,0.046875,0.16796875,0.71875,0.66796875,-0.53515625,0.0033416748,0.45117188,0.004211426,0.09667969,0.12792969,-0.7578125,0.28125,-0.94921875,0.36914062,-0.049316406,-0.07080078,-0.40820312,0.38671875,0.5859375,0.609375,-0.84765625,0.90625,-0.051513672,0.734375,-0.5859375,0.71484375,1.015625,-0.061523438,0.005432129,-0.28710938,0.3671875,-0.46484375,0.59375,0.87109375,0.59375,0.67578125,-0.6171875,-0.11621094,0.061035156,-0.26367188,0.625,-0.22363281,-0.36523438,0.41601562,-0.030151367,-0.59375,-0.27929688,-0.92578125,-0.28125,0.20703125,-0.2890625,0.4296875,1.359375,0.7890625,-0.49804688,0.20703125,-0.13085938,-0.0043945312,-0.73046875,0.37890625,0.17773438,-0.11816406,-0.37890625,0.017944336,0.76171875,0.8125,0.35742188,-0.19628906,-0.044433594,0.5078125,-0.65234375,0.04663086,-0.99609375,-0.42382812,-0.65625,-0.12988281,1.03125,0.30664062,-0.30078125,-1.71875,0.35546875,-0.01940918,0.09814453,0.59765625,-0.21386719,-0.87109375,1.6171875,-0.49023438,-0.10107422,-0.12597656,0.18554688,-0.42382812,0.69140625,-0.64453125,0.114746094],"inputTextTokenCount":3}' + headers: + Connection: + - keep-alive + Content-Length: + - '17006' + Content-Type: + - application/json + Date: + - Tue, 23 Apr 2024 20:57:03 GMT + X-Amzn-Bedrock-Input-Token-Count: + - '3' + X-Amzn-Bedrock-Invocation-Latency: + - '311' + x-amzn-RequestId: + - 1fd884e0-c9e8-44fa-b736-d31e2f607d54 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/contrib/botocore/bedrock_cassettes/cohere_embedding.yaml b/tests/contrib/botocore/bedrock_cassettes/cohere_embedding.yaml new file mode 100644 index 
00000000000..3c7ca4b192b --- /dev/null +++ b/tests/contrib/botocore/bedrock_cassettes/cohere_embedding.yaml @@ -0,0 +1,45 @@ +interactions: +- request: + body: '{"texts": ["Hello World!", "Goodbye cruel world!"], "input_type": "search_document"}' + headers: + Content-Length: + - '84' + User-Agent: + - !!binary | + Qm90bzMvMS4zNC40OSBtZC9Cb3RvY29yZSMxLjM0LjQ5IHVhLzIuMCBvcy9tYWNvcyMyMy40LjAg + bWQvYXJjaCNhcm02NCBsYW5nL3B5dGhvbiMzLjEwLjUgbWQvcHlpbXBsI0NQeXRob24gY2ZnL3Jl + dHJ5LW1vZGUjbGVnYWN5IEJvdG9jb3JlLzEuMzQuNDk= + X-Amz-Date: + - !!binary | + MjAyNDA0MjNUMjEwMDExWg== + amz-sdk-invocation-id: + - !!binary | + MGIwYzc5OTUtOWE5MC00MmM2LWIxODYtNGViYTJkNmQ2MWFm + amz-sdk-request: + - !!binary | + YXR0ZW1wdD0x + method: POST + uri: https://bedrock-runtime.us-east-1.amazonaws.com/model/cohere.embed-english-v3/invoke + response: + body: + string: '{"embeddings":[[-0.041503906,-0.026153564,-0.070373535,-0.057525635,-0.026245117,-0.019714355,-0.029449463,0.060333252,-0.012016296,0.031219482,-0.03036499,-0.0023498535,0.026153564,-0.036193848,0.050231934,-0.0152282715,0.050445557,-0.002483368,0.014816284,-0.0055618286,-0.016464233,0.03363037,-0.0049209595,0.01776123,0.04031372,-0.018188477,-0.0035190582,-0.017364502,-0.035064697,-0.050842285,0.008964539,0.05026245,0.021530151,0.026947021,-0.0010185242,0.06341553,-0.015213013,0.022537231,0.020935059,-0.015823364,0.023162842,-0.01763916,0.018157959,-0.01637268,-0.038330078,0.010925293,-0.00762558,-0.004917145,0.0049591064,-0.048583984,0.017349243,-0.022094727,0.023269653,0.009536743,-0.006526947,0.012512207,-0.04284668,-0.022705078,0.032348633,0.015808105,-0.006450653,-0.0058784485,0.015792847,8.506775E-4,0.027633667,-0.030853271,-0.007522583,0.007587433,0.024002075,0.029251099,9.994507E-4,0.0010881424,-0.018234253,-3.2234192E-4,0.02130127,0.0051574707,-0.008659363,0.029281616,0.016479492,0.026229858,-0.01486969,0.009094238,0.0032100677,-0.06359863,0.031402588,-0.016723633,-0.0017356873,0.009475708,-0.0035057068,-0.08514404,-0.0
13404846,0.057525635,-0.035003662,0.0022697449,-0.047698975,-0.031585693,0.0098724365,0.04864502,-0.010551453,0.0234375,-0.060943604,-0.002948761,-0.014450073,-0.03857422,0.008773804,-0.050231934,0.007457733,-0.0051956177,0.04171753,0.045318604,-0.057403564,0.07757568,0.01473999,0.021224976,-0.005619049,-0.009849548,-9.765625E-4,0.0029563904,-0.024383545,0.02015686,0.013069153,0.02268982,-0.0025596619,0.017486572,-0.004875183,0.016708374,-0.028945923,-0.016921997,0.06939697,-0.054870605,-0.002248764,0.048034668,-0.09857178,-0.06604004,-0.038635254,0.037719727,0.027420044,-0.049316406,-0.0031814575,-0.010917664,0.037384033,-0.0021858215,8.3065033E-4,0.012664795,-0.034484863,0.038360596,0.018753052,-0.004699707,-0.042236328,-0.045013428,-0.024154663,-0.046142578,-0.021118164,0.03189087,0.043823242,0.021652222,0.004501343,0.008804321,0.059448242,0.05593872,-0.009033203,0.004043579,0.0496521,-0.058929443,-0.028503418,-0.040405273,-0.037597656,0.012016296,-0.00579834,0.027130127,-0.029067993,0.007537842,-0.02279663,-0.011871338,-2.1362305E-4,0.0073127747,0.010383606,-0.011375427,-0.018081665,0.07141113,0.07513428,-0.0018005371,-0.08093262,0.08026123,-0.032592773,-0.01386261,0.0050735474,-0.011375427,0.037841797,0.0065727234,0.019454956,0.0056152344,-0.061706543,0.007293701,-0.015930176,-0.0032348633,0.019363403,0.0068130493,-5.0354004E-4,0.015563965,-0.0129776,0.022033691,0.016967773,-0.0075645447,0.055541992,-0.0033626556,0.015853882,0.012435913,0.042510986,0.009346008,-0.022140503,-0.025527954,0.06121826,-0.13354492,-0.06323242,-0.013664246,-0.044189453,0.01272583,-0.010894775,0.021469116,0.032318115,-0.018997192,0.01184082,0.008255005,0.013465881,0.032196045,-0.037139893,0.009536743,-0.044799805,0.02670288,-0.031951904,0.01675415,-0.05947876,-0.033966064,-0.06222534,-0.040740967,-0.07141113,0.0025959015,0.022979736,0.046569824,0.016799927,-0.019744873,-0.010566711,0.0047683716,0.0335083,0.028427124,0.025131226,0.015792847,-0.00390625,0.036224365,0.0061302185,0.0090026
86,0.022232056,-0.0063438416,-0.012245178,0.068481445,-0.012573242,0.0043678284,-0.008300781,0.009353638,-0.00541687,0.031433105,0.014595032,0.015434265,-0.036621094,-0.012878418,-0.034210205,-0.003583908,-0.018997192,0.016815186,0.027069092,-0.0491333,-0.016494751,-0.021621704,0.05960083,-0.015655518,-0.067993164,0.019622803,0.009819031,-0.0119018555,-0.0050697327,-0.019897461,0.018951416,0.009895325,0.018875122,0.02218628,0.04348755,0.026657104,0.0044441223,0.03387451,0.06536865,0.02784729,0.016342163,0.037902832,0.0010757446,0.01020813,0.042755127,0.05908203,-0.019073486,0.010620117,-0.056121826,-0.011314392,0.008270264,-0.024139404,-0.0034065247,-0.024673462,0.025421143,0.0115356445,-6.2942505E-4,0.021362305,-0.028305054,0.0440979,0.0099487305,0.035247803,-0.048706055,-0.05999756,-0.024368286,-0.019622803,-0.11859131,-0.065979004,-0.0309906,-0.0073547363,-0.004371643,-0.018005371,0.0030021667,-0.0044517517,-0.039794922,0.014312744,-0.01689148,-0.018035889,0.0051193237,-0.013595581,-0.027404785,0.00705719,0.018661499,-0.008934021,0.02520752,0.010749817,-0.0038280487,-0.006034851,0.040649414,-0.027175903,0.003419876,0.021362305,0.021713257,-0.040283203,0.006652832,-0.013458252,0.03189087,0.011428833,0.001083374,-0.024414062,0.017150879,-0.021194458,0.03591919,-0.06011963,-0.043670654,-0.038085938,-0.021865845,-0.0635376,0.05203247,0.021896362,0.017303467,-0.040405273,-0.060699463,-0.05432129,0.030578613,-0.0039138794,-0.059814453,-0.010299683,-0.029174805,-0.04147339,0.040893555,-0.026000977,-0.062438965,0.024993896,-0.003967285,-0.039031982,-0.021026611,-0.010467529,0.024658203,0.035461426,-0.01676941,0.018493652,0.0077171326,0.026565552,0.001209259,0.012168884,0.03778076,0.011230469,0.002632141,0.030273438,-0.04119873,0.0015602112,0.04272461,0.004989624,-0.053588867,0.03366089,-0.11859131,0.023147583,-0.024993896,-0.01600647,-0.02017212,0.0025901794,-0.07891846,-0.009033203,-0.056243896,-0.024612427,0.02368164,0.022979736,0.031433105,-0.002412796,0.039001465,-0.
010360718,0.012123108,-0.021972656,-0.087890625,0.016479492,-0.012496948,0.017303467,-0.024017334,0.02381897,-0.054107666,-0.004699707,-0.021087646,0.045166016,0.0033435822,0.025024414,0.0132369995,-0.030975342,0.011276245,-0.039611816,-0.018295288,-0.03652954,-0.017196655,0.01676941,-0.023040771,0.028411865,-0.007724762,-0.016998291,0.016647339,0.014465332,0.020843506,0.019088745,-0.010063171,-0.004512787,0.022521973,0.03378296,-0.053863525,0.007675171,-0.02532959,0.026260376,-0.02947998,0.025756836,-0.016616821,-0.009803772,0.01727295,0.04827881,0.003780365,-0.010955811,0.0055274963,-0.026367188,-0.042541504,0.0072784424,-0.028869629,-0.0054779053,0.07788086,-0.020355225,-0.03677368,0.01449585,-0.03250122,0.018081665,0.022628784,0.018157959,-0.015434265,0.002368927,-0.025024414,0.008857727,-0.030059814,-1.9836426E-4,-0.022064209,-0.012069702,-0.033416748,-0.04763794,-0.016052246,-0.0027198792,-0.05557251,-0.00970459,0.006515503,0.028823853,0.021026611,-0.026611328,0.031555176,-0.0061798096,-0.031585693,-0.0053863525,-0.03161621,0.032287598,0.0149002075,-0.060150146,-0.015594482,-0.016799927,0.005432129,-0.013442993,0.024856567,-0.018920898,-0.010696411,0.023834229,0.020568848,0.055023193,0.004219055,0.002937317,0.013938904,0.026153564,-0.019546509,0.03955078,-0.0061187744,-0.007659912,0.0039749146,0.0011529922,0.01864624,0.032440186,-0.006843567,0.037841797,0.03111267,-0.006515503,-0.041625977,0.01751709,-0.0115356445,-0.08465576,0.033294678,0.006000519,-0.006477356,0.024169922,-0.016906738,0.085998535,-0.012420654,-0.0035476685,0.009208679,-0.008361816,0.012489319,0.014389038,-0.031280518,-0.02760315,0.022705078,0.014923096,0.029663086,0.090026855,-0.032684326,0.08679199,-2.670288E-4,-0.012367249,0.031311035,0.022583008,-0.017028809,-0.05883789,-0.042266846,0.006515503,-0.018676758,0.020126343,-0.009780884,0.038513184,-0.055236816,0.023651123,-0.031036377,0.016189575,0.006752014,-0.016311646,-0.009292603,-0.017150879,-0.008041382,-0.03225708,-0.045532227,0.001237
8693,-0.026062012,0.043945312,-0.038391113,0.0027770996,0.018676758,0.087768555,-0.0026512146,0.024551392,-0.03677368,0.048034668,0.009979248,0.045684814,0.0017814636,0.024734497,-8.9645386E-4,0.02748108,-0.021606445,-0.020217896,-0.008232117,0.0037136078,0.027496338,0.008773804,0.022094727,-0.048614502,0.019058228,0.0023002625,-0.04171753,-0.018066406,0.011566162,-0.034057617,0.058013916,-0.012496948,-0.025360107,-0.022750854,0.023345947,0.007041931,0.028762817,-0.011993408,-0.015838623,0.05618286,-0.009490967,-0.054473877,-0.004760742,0.01953125,0.010566711,-0.013214111,0.02558899,-0.06555176,0.028442383,-6.1416626E-4,0.0019302368,-0.032043457,-3.0517578E-4,0.028930664,-0.0037651062,-0.019561768,0.028549194,-0.019195557,-0.016418457,0.0062332153,0.025909424,0.026733398,-0.02798462,0.012001038,-0.05291748,-0.023284912,0.00422287,0.013687134,0.041870117,-0.025726318,-0.020370483,-0.044006348,-3.9196014E-4,0.00944519,-0.023834229,-0.015098572,-0.023223877,0.008781433,0.0076789856,0.020141602,-0.014175415,0.016662598,-0.005973816,0.033813477,-0.013748169,-0.033111572,-0.016845703,0.0051345825,-0.010635376,-0.02268982,0.019210815,0.0124053955,0.015052795,0.034118652,-0.01600647,0.040374756,-0.007949829,-0.0030612946,0.021774292,-0.007896423,-0.03164673,-0.01576233,0.043884277,-0.017059326,0.0039596558,0.007537842,-0.019470215,0.023986816,0.011787415,0.020629883,-0.045074463,0.022628784,0.01914978,0.011131287,-0.016403198,-0.039276123,-0.07208252,0.023010254,-0.03894043,0.010787964,0.04019165,0.0017576218,-0.043823242,-0.034454346,0.030059814,0.02558899,1.5258789E-5,-0.03302002,0.03741455,-0.040374756,0.014251709,0.0046806335,-0.011749268,0.0289917,-0.025726318,0.006515503,-2.9087067E-4,0.043670654,-0.05911255,-0.03189087,0.038726807,-0.027862549,-0.06311035,-0.007610321,-0.02104187,-0.02180481,-0.029296875,-0.068725586,-0.016113281,-0.012924194,0.017684937,-0.020828247,-0.026885986,-0.0058670044,0.008880615,0.0056419373,0.016693115,0.0473938,-0.011367798,-0.0010662079,
-0.0013999939,0.02822876,0.014808655,0.010635376,0.006538391,0.0030784607,-0.05682373,0.0035820007,0.019012451,0.02571106,0.021362305,0.04168701,0.02029419,0.040039062,-0.017074585,0.0127334595,0.019332886,0.006351471,0.05267334,0.0029335022,0.014518738,-0.040405273,-0.038635254,0.034179688,0.07299805,-0.027801514,-0.050476074,-0.030014038,-0.00617218,0.06488037,0.0038414001,0.064208984,0.034210205,0.02494812,-0.012954712,0.026641846,0.0597229,0.01146698,0.0014743805,-0.027877808,-0.04699707,0.037597656,0.014572144,-0.012710571,0.018417358,0.02508545,6.599426E-4,0.003255844,-0.043884277,-0.021469116,-0.0284729,-0.037109375,0.044311523,0.043640137,0.018676758,0.1005249,-0.022979736,0.02911377,-0.0015258789,0.05899048,0.042175293,0.016601562,0.012954712,-0.0038909912,0.017425537,-0.03274536,-0.019714355,0.011199951,-0.014831543,0.0069389343,-0.006549835,-0.07409668,0.027420044,0.0491333,-0.0038833618,0.023590088,-9.317398E-4,-0.027160645,8.049011E-4,0.015716553,0.008773804,-0.003025055,-0.00642395,-0.0012283325,0.010566711,-0.05407715,-0.011138916,0.03326416,0.03125,-0.051696777,-0.016860962,0.028656006,0.017044067,-0.021911621,-0.012763977,-0.01890564,0.039794922,-0.013145447,-4.6682358E-4,0.020568848,-0.011108398,0.0021705627,0.03765869,-0.039855957,0.049591064,0.0110321045,-0.005542755,-0.0113220215,0.0050315857,-0.003232956,-0.079589844,0.018722534,0.034423828,-2.7942657E-4,-0.013671875,0.05960083,0.05230713,0.057281494,0.029251099,0.019073486,-0.007331848,0.018981934,0.0074005127,-0.030715942,-0.04446411,-0.013702393,-0.02027893,-2.0217896E-4,0.017913818,0.02960205,0.006713867,0.0044059753,-0.041015625,-0.011566162,-0.0054397583,-0.034362793,0.0073280334,0.02130127,0.012771606,-0.06100464,-0.03945923,-0.014793396,-0.009017944,-0.017608643,-0.037139893,0.058624268,0.0135650635,0.015274048,-0.013259888,-0.041229248,-0.02255249,0.030029297,-0.028579712,0.036224365,0.011756897,0.0043754578,0.029129028,-0.040374756,-0.021484375,0.014190674,-0.077819824,0.002494812,-0.
017791748,0.019805908,-0.02432251,0.046691895,-0.041290283,-0.028915405,-0.0020446777,0.009262085,-0.032440186,-0.00554657,0.014709473,0.012992859,-0.024871826,-0.048858643,0.026321411,0.005897522,0.024353027,0.064697266,-0.048950195,0.017547607,-0.009010315,0.014549255,0.040802002,-0.025970459,-0.023788452,0.004211426,-0.02810669,-0.030014038,-0.011566162,0.025314331,-0.05480957,-0.02720642,-0.006198883,-0.01209259,-0.019378662,-0.047668457,-0.09552002,-0.014328003,-0.014564514,-0.046417236,-0.005859375,-0.006511688,-0.014915466,-0.008666992,-0.016555786,-0.016479492,0.0070610046,0.04147339,-0.04446411,0.030334473,0.01423645,-0.01802063,-0.019104004,0.0045928955,0.0038833618,-0.013938904,-0.0061950684,-0.0023040771,0.012863159,0.042114258,-0.052612305,-0.03289795,-0.019195557,0.029571533,-2.5558472E-4,0.028396606,0.057556152,0.03375244,-0.06903076,-0.008117676,-0.04675293,-0.00806427,0.026321411,-0.004749298,-0.030944824,-0.006416321,-6.5612793E-4,-0.02166748,0.009925842,-0.0069274902,0.034576416,-0.010894775,-0.011184692,0.03353882,-0.0657959,0.03781128,-0.012008667,0.020965576,-0.024719238,0.067871094,-0.033721924,0.012908936,0.005168915,0.018966675,0.0158844,0.0044174194,-0.030136108,-0.022460938,0.064331055,0.028320312,0.01259613,0.004337311,0.047424316,0.0025100708,-0.053009033,0.024597168,0.05508423,0.028564453,-0.042633057,-0.0047836304,0.049438477,0.0046958923,0.006164551,-0.060394287,-0.039398193,0.055236816,-0.050323486,-0.028961182,-0.02078247,-0.044555664,-0.008033752,0.0053710938,-0.020370483,-0.061553955,0.016067505,0.054779053,-0.012863159,0.021575928],[0.007972717,0.0024280548,-0.023376465,-0.036071777,1.9264221E-4,0.014801025,0.0071029663,-0.004360199,-0.018234253,0.023132324,-0.042877197,-0.013389587,0.045318604,-0.03543091,0.042907715,0.0048332214,0.025680542,0.026672363,0.0035133362,-0.0020771027,-0.021209717,0.008041382,-0.030914307,-0.04434204,0.004173279,0.021499634,-0.04208374,0.006576538,-0.073913574,-0.08654785,-0.013748169,0.092285156,0.0
1713562,0.020599365,0.010017395,0.011482239,-0.036590576,0.029083252,0.016189575,-0.020095825,0.021408081,-0.0087890625,-0.06109619,-0.024230957,-0.08343506,-0.012451172,0.022827148,0.026016235,0.0073013306,-0.028549194,0.040283203,0.034576416,0.051727295,0.016906738,-0.020889282,0.0022678375,-0.012893677,0.005531311,0.033996582,0.0022392273,-0.018875122,-0.013549805,0.024108887,-0.032440186,0.0031280518,0.013534546,-0.007888794,0.013366699,0.058166504,0.01725769,-0.04083252,-0.0011711121,0.013961792,-0.013442993,0.009841919,0.064086914,-0.026931763,0.051086426,0.040740967,-0.006477356,-0.013435364,6.008148E-4,-0.008781433,-0.009712219,-0.011795044,-0.010375977,-0.006969452,0.0029678345,-0.012237549,-0.089416504,-0.015083313,0.06713867,-0.037017822,-0.019180298,-0.056549072,-0.03930664,0.024093628,0.043273926,-0.061157227,-0.031799316,-0.04284668,0.009384155,-0.047912598,-0.01083374,-0.023757935,0.002407074,-0.010772705,-0.017974854,0.011367798,0.05911255,-0.0184021,-0.023208618,-0.010986328,0.03668213,0.023208618,-0.018875122,-0.05630493,0.03845215,-0.047332764,0.04147339,0.00995636,0.019439697,-0.05505371,0.055358887,0.026550293,0.05307007,-0.0039138794,0.026153564,0.08666992,0.012229919,-0.008392334,0.018981934,-0.11810303,-0.089538574,-0.052490234,0.016082764,0.045654297,-0.039855957,-0.010025024,-0.0012817383,0.037322998,-0.002067566,-0.029464722,-0.014595032,-0.0017738342,-0.009475708,0.043182373,0.014801025,-0.014251709,0.009094238,-0.04940796,0.027893066,0.023086548,0.034423828,0.01461792,0.027130127,0.01033783,-0.013534546,0.013656616,0.039520264,-0.01096344,0.011108398,0.06921387,-0.026016235,-0.0030212402,-0.011543274,-0.020599365,0.005302429,0.0023651123,0.04220581,-0.06793213,-0.02532959,-0.010414124,0.012817383,-0.007850647,0.04498291,0.015396118,-0.017288208,0.006729126,0.020706177,0.030914307,-0.012954712,-0.017532349,0.047180176,-0.021606445,-0.021575928,-0.0060043335,-0.03277588,0.045318604,-0.014854431,-0.024551392,0.03704834,0.0087890625,0.035461
426,-0.024475098,-0.05505371,-0.009529114,0.0014896393,-0.021194458,-0.026733398,0.011680603,0.001285553,0.016189575,-0.03942871,-0.0076446533,-0.012550354,-4.4250488E-4,0.036956787,0.04034424,0.047607422,-0.0044670105,-0.02168274,-0.00894165,0.017456055,0.0041160583,-0.01399231,-0.017654419,-0.014175415,0.009124756,-0.0069351196,0.06341553,-0.021240234,0.026138306,-0.01828003,0.041656494,-0.019226074,-0.009681702,-0.044403076,0.036834717,-0.011131287,0.01234436,0.03427124,-0.042663574,-0.035949707,-0.051086426,-0.00504303,-0.020950317,0.04232788,0.007270813,0.011054993,0.0015964508,-0.011894226,-0.054473877,-0.0569458,-0.008010864,-0.022842407,0.010177612,0.0026245117,0.0390625,0.018478394,0.008834839,-0.025054932,0.03857422,0.020507812,0.029785156,0.061828613,-0.0026779175,-0.0012540817,0.0345459,-0.024261475,0.005680084,0.034820557,-0.0026245117,0.014022827,-0.026641846,-0.028533936,-0.028656006,0.016448975,-0.0034885406,-0.008125305,0.028930664,-0.032958984,-0.02003479,0.009506226,0.036102295,-0.0121536255,-0.049987793,0.0025253296,0.0019054413,-0.0066566467,0.014137268,-0.0054473877,0.04724121,0.020126343,0.018295288,0.03466797,0.048614502,0.040527344,-0.004722595,-0.012260437,0.028564453,0.074523926,0.024230957,0.02923584,0.07922363,0.03677368,0.023513794,0.045166016,-0.008644104,0.014854431,-0.035186768,0.009254456,0.008491516,-0.027999878,-0.0016345978,0.078308105,0.05126953,-0.013305664,0.02217102,0.054870605,-0.038726807,-0.0019245148,0.023071289,0.007724762,0.0057525635,-0.04473877,0.020996094,0.027786255,0.013893127,0.034942627,0.037353516,8.029938E-4,-0.049926758,-0.006713867,-0.003396988,-0.034362793,-0.008590698,-0.023620605,0.0023536682,0.0060691833,0.05783081,0.05517578,0.012481689,0.011428833,0.028656006,-0.011505127,0.018600464,0.015838623,0.022521973,-0.007949829,0.033996582,-0.03086853,-0.023452759,0.018722534,0.034057617,-0.061828613,0.010726929,-0.0018997192,0.07434082,0.033843994,-0.009117126,-0.02470398,-0.01927185,0.020629883,0.03692627,0.0
09597778,0.0023822784,-0.02330017,-0.0017147064,-0.0680542,-0.035583496,0.016708374,-0.03112793,0.046813965,0.011497498,-0.05529785,0.021759033,0.0368042,-0.024169922,0.050811768,-0.013748169,0.014923096,-0.030960083,-0.019454956,-0.09020996,-0.005874634,-0.007095337,-0.066833496,0.008651733,-0.011703491,0.012435913,0.026931763,0.007865906,0.048858643,0.03152466,-0.03173828,-0.06021118,0.028167725,0.011543274,0.04925537,-0.01600647,-0.013244629,0.03173828,-0.0284729,-0.02508545,-0.026519775,0.039398193,0.037017822,0.0059890747,-0.0020275116,-0.022521973,0.037109375,-0.03074646,0.047698975,-0.07800293,0.0068626404,-0.04623413,-0.04949951,0.016906738,0.022033691,0.003753662,0.057556152,0.019714355,0.008262634,-0.020553589,-0.023864746,-0.029388428,0.026611328,-0.041259766,-0.016082764,-0.0013427734,0.010002136,-0.023834229,-0.022735596,-0.012268066,0.016159058,-0.0038642883,0.021331787,-0.011642456,0.02168274,-0.0029525757,-0.015327454,-0.020736694,-0.039886475,-0.026275635,-0.024337769,-0.01689148,0.04156494,0.03012085,-0.018127441,-0.05038452,0.006095886,-0.018676758,0.020996094,-0.03744507,-0.008956909,0.0032100677,-0.03253174,-0.036071777,-9.975433E-4,-0.06378174,0.030380249,-0.037200928,0.021743774,-0.011383057,-0.052886963,9.2697144E-4,0.030670166,0.0012054443,0.05090332,0.06149292,0.002105713,-0.029953003,0.011146545,-0.033477783,-0.053955078,0.03414917,-0.034851074,-0.014160156,-0.021148682,-0.081604004,0.09564209,0.021209717,0.023651123,-0.017868042,-0.00970459,-5.41687E-4,0.04840088,-0.031921387,-0.0087890625,-0.026672363,0.010871887,0.042236328,-0.022125244,-0.02558899,0.016677856,-0.016403198,0.03643799,0.010864258,-0.06719971,-0.034332275,-0.047790527,0.026855469,-0.06137085,-0.022064209,-0.014144897,8.239746E-4,-0.028900146,-0.004585266,-0.030685425,-0.07611084,9.288788E-4,0.06689453,-0.052490234,0.028945923,-0.01979065,-0.0413208,-0.013008118,0.019638062,0.02859497,-0.022583008,-0.00623703,0.07684326,0.026626587,-0.012588501,-0.04864502,0.04788208,0.017
39502,-0.026901245,0.017318726,0.037963867,-0.018356323,0.05996704,0.024429321,-0.04168701,0.023391724,0.007698059,-0.001660347,0.015686035,-0.063964844,0.03765869,-1.9073486E-4,-0.0033950806,0.05557251,-0.021255493,0.056854248,-0.024719238,0.01939392,0.023376465,-0.057617188,-0.005973816,-0.0037002563,-0.03100586,-0.009712219,0.0039711,-0.07922363,-0.024841309,0.01411438,0.006313324,0.026657104,0.01586914,-0.012710571,2.784729E-4,0.0019798279,0.014564514,-0.02418518,-0.022872925,0.0040130615,-0.02281189,-0.014892578,-0.008270264,0.004299164,-0.036315918,0.014770508,0.01424408,-0.0031585693,0.038757324,-0.018463135,0.008255005,-0.027832031,-6.465912E-4,-0.026473999,0.039489746,0.008277893,-0.017913818,-0.01398468,-0.03111267,-0.010543823,0.0076828003,0.01876831,-0.009513855,0.03012085,0.011604309,-0.022521973,0.028060913,-0.016113281,0.0046844482,0.011230469,0.0063476562,-0.0057754517,-0.013763428,-0.030090332,0.017089844,-0.01651001,-0.0063667297,-0.03933716,0.0637207,-0.016921997,-0.0158844,0.002439499,-0.02407837,-0.0015716553,0.024505615,-0.038726807,0.005142212,0.015449524,-0.013900757,-0.03967285,0.018096924,-0.03515625,0.05734253,-0.04901123,-0.03491211,0.09124756,-0.05026245,-0.03451538,0.0635376,0.009437561,-0.04852295,-0.05722046,0.010345459,-0.05090332,0.022003174,0.009017944,0.011566162,-0.030517578,0.0602417,-0.007347107,-0.022735596,0.016921997,-0.011161804,0.014839172,-0.03250122,-0.06149292,-0.002532959,0.041015625,0.010803223,0.0020713806,0.0178833,0.0132751465,-0.007801056,0.019348145,-0.015289307,-0.052490234,0.014862061,-0.028259277,-0.054260254,0.0017242432,-0.028213501,0.0031280518,-0.01234436,0.008598328,-0.058898926,0.04055786,0.042816162,0.0061836243,-0.026123047,0.0552063,0.008476257,0.011627197,0.011108398,0.048065186,-0.01725769,-0.006969452,0.030639648,0.004463196,0.056274414,0.024169922,0.010620117,0.03552246,-0.0013484955,9.95636E-4,0.0050811768,0.006210327,0.025482178,-0.05279541,-0.0034561157,-0.0037002563,-0.036102295,-0.01914978,-0
.0011482239,0.023284912,0.040374756,-0.034240723,-0.016860962,-0.00995636,0.055999756,0.036621094,-0.016296387,0.0046653748,-0.029693604,0.019180298,-0.049743652,0.012390137,0.028015137,-0.039276123,-0.0256958,0.0385437,0.060760498,0.00390625,-0.0019607544,0.003868103,0.020980835,0.010070801,0.022232056,-0.0042304993,0.028671265,-0.040130615,0.01525116,0.018936157,-0.0074005127,0.010147095,-0.036621094,-0.0059661865,0.027008057,-0.016708374,-0.014480591,0.03237915,0.004550934,-0.04425049,-0.036743164,-0.050689697,0.018554688,-0.025360107,0.03213501,-0.014831543,-0.012130737,-0.005302429,0.016433716,0.015174866,0.055908203,0.05117798,-0.019699097,0.038085938,0.021850586,-0.006465912,0.014511108,-0.015052795,-0.021865845,0.010940552,-0.019500732,0.0027427673,0.0066719055,0.010673523,-0.014343262,0.017669678,-0.011878967,0.009338379,0.012748718,0.002670288,0.0335083,0.0057907104,0.009147644,-0.01424408,0.0033569336,-0.011749268,0.04144287,0.004371643,0.0015029907,-0.040130615,-0.06530762,-0.0262146,-0.02822876,0.028549194,0.02305603,0.006778717,-0.06677246,-0.013961792,0.03753662,0.0101623535,0.012748718,0.017822266,-8.8500977E-4,-0.028793335,0.005077362,0.02053833,0.025619507,-0.047729492,-0.006877899,0.01524353,0.0012569427,0.03338623,-0.06512451,-0.009796143,-0.008056641,0.02192688,0.045410156,-0.024307251,0.032073975,0.03591919,-0.0016384125,-0.018585205,-0.012321472,0.036315918,0.013786316,0.02784729,0.028213501,-8.010864E-5,-0.039031982,-0.0121154785,-9.3460083E-4,-0.028839111,-0.030258179,-0.0034751892,-0.0256958,-0.048858643,-0.038726807,0.038482666,-0.0013504028,-0.0065307617,0.0030670166,-0.025177002,0.0070266724,0.05355835,0.019454956,0.03036499,-0.008895874,-0.019561768,-0.022079468,-0.041168213,-0.01802063,0.04168701,0.024475098,-0.022094727,-0.031951904,0.0024299622,0.02243042,0.00504303,-0.018615723,0.0022087097,0.040374756,-0.0096206665,-0.017303467,-0.013702393,-0.009414673,0.05609131,0.032348633,-0.04107666,0.048736572,0.015945435,-0.010169983,0.01846
3135,-0.0069084167,0.03225708,-0.11608887,0.049713135,0.06750488,-0.01751709,0.014785767,0.04345703,0.05065918,0.060516357,-0.007091522,0.034362793,0.042022705,0.0871582,-0.003118515,-0.042053223,-0.046813965,-0.0115356445,-0.004142761,0.01838684,0.012565613,-0.027648926,0.04522705,0.0065727234,0.0031471252,-0.012519836,-0.022888184,-0.015838623,0.019607544,-0.0026855469,0.016677856,0.010879517,-0.03302002,0.0011777878,0.021942139,-0.021713257,0.022003174,-0.012924194,-0.006134033,-0.049713135,0.024017334,0.020263672,0.0038452148,-0.07635498,0.028244019,-0.018447876,0.020706177,-0.006877899,-0.014923096,-0.014389038,-0.022216797,-0.008102417,-0.047088623,-0.0115356445,-0.03161621,0.02166748,-0.016647339,-0.010101318,0.002243042,-0.03048706,-0.008117676,-0.013755798,0.02331543,-0.05029297,-0.023971558,0.007507324,-0.055664062,-0.070129395,0.09082031,-0.026443481,-0.043548584,-0.0020980835,-0.008628845,0.02394104,-0.09869385,-0.02519226,-0.062683105,-0.021820068,-0.0134887695,-0.032958984,0.019744873,-6.275177E-4,-0.023071289,-0.006454468,0.012886047,-0.034973145,0.04812622,-0.042541504,-0.027511597,-0.025024414,-5.4979324E-4,-0.0028800964,0.02633667,-0.019241333,0.005504608,0.0045318604,0.011688232,-0.043670654,0.011413574,-0.04638672,-0.02810669,0.006465912,-0.023132324,0.040771484,-0.006450653,0.012374878,0.030761719,0.048034668,0.03817749,-0.006034851,-0.039855957,0.014457703,0.008560181,0.045715332,-0.037841797,-0.0209198,-0.0128479,0.056396484,-0.00856781,0.047943115,0.038330078,-0.0057907104,-0.0284729,-0.037017822,-0.020324707,0.0121536255,-0.037384033,-0.02835083,0.018722534,-8.659363E-4,0.015899658,-0.07098389,0.0069503784,0.026901245,0.018997192,0.00724411,-0.04309082,0.011207581,-0.0048713684,0.043914795,0.0012407303,0.007255554,0.009460449,0.034118652,-0.029403687,0.020263672,-0.015472412,-7.543564E-4,0.027282715,-0.00995636,0.0066108704,-0.0104904175,0.0077667236,0.031951904,0.009170532,0.01802063,-0.0039978027,-0.004688263,-0.037139893,-0.018554688,0.03
1402588,-0.04385376,-0.036743164,0.032836914,0.02658081,-0.005706787,-0.057678223,0.026901245,-0.023086548,0.018951416,-0.050720215,0.06384277,-0.0031585693,-0.0041236877,5.760193E-4,0.028533936,0.036346436,-0.0057868958,0.0049591064,-0.0024299622,0.0065078735,-0.0051994324]],"id":"0e9cb5ab-1fef-46eb-8e2c-773f0f60f39d","response_type":"embeddings_floats","texts":["hello + world!","goodbye cruel world!"]}' + headers: + Connection: + - keep-alive + Content-Length: + - '25552' + Content-Type: + - application/json + Date: + - Tue, 23 Apr 2024 21:00:11 GMT + X-Amzn-Bedrock-Input-Token-Count: + - '7' + X-Amzn-Bedrock-Invocation-Latency: + - '271' + x-amzn-RequestId: + - 0e9cb5ab-1fef-46eb-8e2c-773f0f60f39d + status: + code: 200 + message: OK +version: 1 diff --git a/tests/contrib/botocore/test_bedrock.py b/tests/contrib/botocore/test_bedrock.py index 47ec875020b..a94699813ed 100644 --- a/tests/contrib/botocore/test_bedrock.py +++ b/tests/contrib/botocore/test_bedrock.py @@ -131,7 +131,7 @@ def aws_credentials(): @pytest.fixture -def boto3(aws_credentials, mock_llmobs_writer, ddtrace_global_config, ddtrace_config_botocore): +def boto3(aws_credentials, mock_llmobs_span_writer, ddtrace_global_config, ddtrace_config_botocore): global_config = default_global_config() global_config.update(ddtrace_global_config) with override_global_config(global_config): @@ -156,12 +156,12 @@ def bedrock_client(boto3, request_vcr): @pytest.fixture -def mock_llmobs_writer(): - patcher = mock.patch("ddtrace.llmobs._llmobs.LLMObsWriter") +def mock_llmobs_span_writer(): + patcher = mock.patch("ddtrace.llmobs._llmobs.LLMObsSpanWriter") try: - LLMObsWriterMock = patcher.start() + LLMObsSpanWriterMock = patcher.start() m = mock.MagicMock() - LLMObsWriterMock.return_value = m + LLMObsSpanWriterMock.return_value = m yield m finally: patcher.stop() @@ -433,6 +433,24 @@ def test_readlines_error(bedrock_client, request_vcr): response.get("body").readlines() +@pytest.mark.snapshot +def 
test_amazon_embedding(bedrock_client, request_vcr): + body = json.dumps({"inputText": "Hello World!"}) + model = "amazon.titan-embed-text-v1" + with request_vcr.use_cassette("amazon_embedding.yaml"): + response = bedrock_client.invoke_model(body=body, modelId=model) + json.loads(response.get("body").read()) + + +@pytest.mark.snapshot +def test_cohere_embedding(bedrock_client, request_vcr): + body = json.dumps({"texts": ["Hello World!", "Goodbye cruel world!"], "input_type": "search_document"}) + model = "cohere.embed-english-v3" + with request_vcr.use_cassette("cohere_embedding.yaml"): + response = bedrock_client.invoke_model(body=body, modelId=model) + json.loads(response.get("body").read()) + + @pytest.mark.parametrize( "ddtrace_global_config", [dict(_llmobs_enabled=True, _llmobs_sample_rate=1.0, _llmobs_ml_app="")] ) @@ -463,7 +481,7 @@ def expected_llmobs_span_event(span, n_output, message=False): ) @classmethod - def _test_llmobs_invoke(cls, provider, bedrock_client, mock_llmobs_writer, cassette_name=None, n_output=1): + def _test_llmobs_invoke(cls, provider, bedrock_client, mock_llmobs_span_writer, cassette_name=None, n_output=1): mock_tracer = DummyTracer(writer=DummyWriter(trace_flush_enabled=False)) pin = Pin.get_from(bedrock_client) pin.override(bedrock_client, tracer=mock_tracer) @@ -491,13 +509,15 @@ def _test_llmobs_invoke(cls, provider, bedrock_client, mock_llmobs_writer, casse json.loads(response.get("body").read()) span = mock_tracer.pop_traces()[0][0] - assert mock_llmobs_writer.enqueue.call_count == 1 - mock_llmobs_writer.enqueue.assert_called_with( + assert mock_llmobs_span_writer.enqueue.call_count == 1 + mock_llmobs_span_writer.enqueue.assert_called_with( cls.expected_llmobs_span_event(span, n_output, message="message" in provider) ) @classmethod - def _test_llmobs_invoke_stream(cls, provider, bedrock_client, mock_llmobs_writer, cassette_name=None, n_output=1): + def _test_llmobs_invoke_stream( + cls, provider, bedrock_client, 
mock_llmobs_span_writer, cassette_name=None, n_output=1 + ): mock_tracer = DummyTracer(writer=DummyWriter(trace_flush_enabled=False)) pin = Pin.get_from(bedrock_client) pin.override(bedrock_client, tracer=mock_tracer) @@ -526,70 +546,76 @@ def _test_llmobs_invoke_stream(cls, provider, bedrock_client, mock_llmobs_writer pass span = mock_tracer.pop_traces()[0][0] - assert mock_llmobs_writer.enqueue.call_count == 1 - mock_llmobs_writer.enqueue.assert_called_with( + assert mock_llmobs_span_writer.enqueue.call_count == 1 + mock_llmobs_span_writer.enqueue.assert_called_with( cls.expected_llmobs_span_event(span, n_output, message="message" in provider) ) - def test_llmobs_ai21_invoke(self, ddtrace_global_config, bedrock_client, mock_llmobs_writer): - self._test_llmobs_invoke("ai21", bedrock_client, mock_llmobs_writer) + def test_llmobs_ai21_invoke(self, ddtrace_global_config, bedrock_client, mock_llmobs_span_writer): + self._test_llmobs_invoke("ai21", bedrock_client, mock_llmobs_span_writer) - def test_llmobs_amazon_invoke(self, ddtrace_global_config, bedrock_client, mock_llmobs_writer): - self._test_llmobs_invoke("amazon", bedrock_client, mock_llmobs_writer) + def test_llmobs_amazon_invoke(self, ddtrace_global_config, bedrock_client, mock_llmobs_span_writer): + self._test_llmobs_invoke("amazon", bedrock_client, mock_llmobs_span_writer) - def test_llmobs_anthropic_invoke(self, ddtrace_global_config, bedrock_client, mock_llmobs_writer): - self._test_llmobs_invoke("anthropic", bedrock_client, mock_llmobs_writer) + def test_llmobs_anthropic_invoke(self, ddtrace_global_config, bedrock_client, mock_llmobs_span_writer): + self._test_llmobs_invoke("anthropic", bedrock_client, mock_llmobs_span_writer) - def test_llmobs_anthropic_message(self, ddtrace_global_config, bedrock_client, mock_llmobs_writer): - self._test_llmobs_invoke("anthropic_message", bedrock_client, mock_llmobs_writer) + def test_llmobs_anthropic_message(self, ddtrace_global_config, bedrock_client, 
mock_llmobs_span_writer): + self._test_llmobs_invoke("anthropic_message", bedrock_client, mock_llmobs_span_writer) - def test_llmobs_cohere_single_output_invoke(self, ddtrace_global_config, bedrock_client, mock_llmobs_writer): + def test_llmobs_cohere_single_output_invoke(self, ddtrace_global_config, bedrock_client, mock_llmobs_span_writer): self._test_llmobs_invoke( - "cohere", bedrock_client, mock_llmobs_writer, cassette_name="cohere_invoke_single_output.yaml" + "cohere", bedrock_client, mock_llmobs_span_writer, cassette_name="cohere_invoke_single_output.yaml" ) - def test_llmobs_cohere_multi_output_invoke(self, ddtrace_global_config, bedrock_client, mock_llmobs_writer): + def test_llmobs_cohere_multi_output_invoke(self, ddtrace_global_config, bedrock_client, mock_llmobs_span_writer): self._test_llmobs_invoke( "cohere", bedrock_client, - mock_llmobs_writer, + mock_llmobs_span_writer, cassette_name="cohere_invoke_multi_output.yaml", n_output=2, ) - def test_llmobs_meta_invoke(self, ddtrace_global_config, bedrock_client, mock_llmobs_writer): - self._test_llmobs_invoke("meta", bedrock_client, mock_llmobs_writer) + def test_llmobs_meta_invoke(self, ddtrace_global_config, bedrock_client, mock_llmobs_span_writer): + self._test_llmobs_invoke("meta", bedrock_client, mock_llmobs_span_writer) - def test_llmobs_amazon_invoke_stream(self, ddtrace_global_config, bedrock_client, mock_llmobs_writer): - self._test_llmobs_invoke_stream("amazon", bedrock_client, mock_llmobs_writer) + def test_llmobs_amazon_invoke_stream(self, ddtrace_global_config, bedrock_client, mock_llmobs_span_writer): + self._test_llmobs_invoke_stream("amazon", bedrock_client, mock_llmobs_span_writer) - def test_llmobs_anthropic_invoke_stream(self, ddtrace_global_config, bedrock_client, mock_llmobs_writer): - self._test_llmobs_invoke_stream("anthropic", bedrock_client, mock_llmobs_writer) + def test_llmobs_anthropic_invoke_stream(self, ddtrace_global_config, bedrock_client, mock_llmobs_span_writer): + 
self._test_llmobs_invoke_stream("anthropic", bedrock_client, mock_llmobs_span_writer) - def test_llmobs_anthropic_message_invoke_stream(self, ddtrace_global_config, bedrock_client, mock_llmobs_writer): - self._test_llmobs_invoke_stream("anthropic_message", bedrock_client, mock_llmobs_writer) + def test_llmobs_anthropic_message_invoke_stream( + self, ddtrace_global_config, bedrock_client, mock_llmobs_span_writer + ): + self._test_llmobs_invoke_stream("anthropic_message", bedrock_client, mock_llmobs_span_writer) - def test_llmobs_cohere_single_output_invoke_stream(self, ddtrace_global_config, bedrock_client, mock_llmobs_writer): + def test_llmobs_cohere_single_output_invoke_stream( + self, ddtrace_global_config, bedrock_client, mock_llmobs_span_writer + ): self._test_llmobs_invoke_stream( "cohere", bedrock_client, - mock_llmobs_writer, + mock_llmobs_span_writer, cassette_name="cohere_invoke_stream_single_output.yaml", ) - def test_llmobs_cohere_multi_output_invoke_stream(self, ddtrace_global_config, bedrock_client, mock_llmobs_writer): + def test_llmobs_cohere_multi_output_invoke_stream( + self, ddtrace_global_config, bedrock_client, mock_llmobs_span_writer + ): self._test_llmobs_invoke_stream( "cohere", bedrock_client, - mock_llmobs_writer, + mock_llmobs_span_writer, cassette_name="cohere_invoke_stream_multi_output.yaml", n_output=2, ) - def test_llmobs_meta_invoke_stream(self, ddtrace_global_config, bedrock_client, mock_llmobs_writer): - self._test_llmobs_invoke_stream("meta", bedrock_client, mock_llmobs_writer) + def test_llmobs_meta_invoke_stream(self, ddtrace_global_config, bedrock_client, mock_llmobs_span_writer): + self._test_llmobs_invoke_stream("meta", bedrock_client, mock_llmobs_span_writer) - def test_llmobs_error(self, ddtrace_global_config, bedrock_client, mock_llmobs_writer, request_vcr): + def test_llmobs_error(self, ddtrace_global_config, bedrock_client, mock_llmobs_span_writer, request_vcr): import botocore mock_tracer = 
DummyTracer(writer=DummyWriter(trace_flush_enabled=False)) @@ -626,5 +652,5 @@ def test_llmobs_error(self, ddtrace_global_config, bedrock_client, mock_llmobs_w ), ] - assert mock_llmobs_writer.enqueue.call_count == 1 - mock_llmobs_writer.assert_has_calls(expected_llmobs_writer_calls) + assert mock_llmobs_span_writer.enqueue.call_count == 1 + mock_llmobs_span_writer.assert_has_calls(expected_llmobs_writer_calls) diff --git a/tests/contrib/django/django_app/appsec_urls.py b/tests/contrib/django/django_app/appsec_urls.py index 9a90d6fd790..c076b238a76 100644 --- a/tests/contrib/django/django_app/appsec_urls.py +++ b/tests/contrib/django/django_app/appsec_urls.py @@ -1,4 +1,5 @@ import hashlib +import os from typing import TYPE_CHECKING # noqa:F401 import django @@ -209,6 +210,23 @@ def sqli_http_request_body(request): return HttpResponse(value, status=200) +def command_injection(request): + value = decode_aspect(bytes.decode, 1, request.body) + # label iast_command_injection + os.system(add_aspect("dir -l ", value)) + + return HttpResponse("OK", status=200) + + +def header_injection(request): + value = decode_aspect(bytes.decode, 1, request.body) + + response = HttpResponse("OK", status=200) + # label iast_header_injection + response.headers["Header-Injection"] = value + return response + + def validate_querydict(request): qd = request.GET res = qd.getlist("x") @@ -224,6 +242,8 @@ def validate_querydict(request): handler("body/$", body_view, name="body_view"), handler("weak-hash/$", weak_hash_view, name="weak_hash"), handler("block/$", block_callable_view, name="block"), + handler("command-injection/$", command_injection, name="command_injection"), + handler("header-injection/$", header_injection, name="header_injection"), handler("taint-checking-enabled/$", taint_checking_enabled_view, name="taint_checking_enabled_view"), handler("taint-checking-disabled/$", taint_checking_disabled_view, name="taint_checking_disabled_view"), handler("sqli_http_request_parameter/$", 
sqli_http_request_parameter, name="sqli_http_request_parameter"), diff --git a/tests/contrib/django/test_django_appsec_iast.py b/tests/contrib/django/test_django_appsec_iast.py index e483f57ca5e..7298e06cd22 100644 --- a/tests/contrib/django/test_django_appsec_iast.py +++ b/tests/contrib/django/test_django_appsec_iast.py @@ -8,6 +8,8 @@ from ddtrace.appsec._iast import oce from ddtrace.appsec._iast._patch_modules import patch_iast from ddtrace.appsec._iast._utils import _is_python_version_supported as python_supported_by_iast +from ddtrace.appsec._iast.constants import VULN_CMDI +from ddtrace.appsec._iast.constants import VULN_HEADER_INJECTION from ddtrace.appsec._iast.constants import VULN_SQL_INJECTION from ddtrace.internal.compat import urlencode from ddtrace.settings.asm import config as asm_config @@ -491,13 +493,12 @@ def test_django_tainted_user_agent_iast_enabled_sqli_http_body(client, test_span payload=payload, content_type=content_type, ) - vuln_type = "SQL_INJECTION" loaded = json.loads(root_span.get_tag(IAST.JSON)) - line, hash_value = get_line_and_hash("iast_enabled_sqli_http_body", vuln_type, filename=TEST_FILE) + line, hash_value = get_line_and_hash("iast_enabled_sqli_http_body", VULN_SQL_INJECTION, filename=TEST_FILE) assert loaded["sources"] == [{"origin": "http.request.body", "name": "body", "value": "master"}] - assert loaded["vulnerabilities"][0]["type"] == "SQL_INJECTION" + assert loaded["vulnerabilities"][0]["type"] == VULN_SQL_INJECTION assert loaded["vulnerabilities"][0]["hash"] == hash_value assert loaded["vulnerabilities"][0]["evidence"] == { "valueParts": [ @@ -548,9 +549,70 @@ def test_querydict_django_with_iast(client, test_spans, tracer): ) assert root_span.get_tag(IAST.JSON) is None - assert response.status_code == 200 assert ( response.content == b"x=['1', '3'], all=[('x', ['1', '3']), ('y', ['2'])]," b" keys=['x', 'y'], urlencode=x=1&x=3&y=2" ) + + +@pytest.mark.skipif(not python_supported_by_iast(), reason="Python version not 
supported by IAST") +def test_django_command_injection(client, test_spans, tracer): + with override_global_config(dict(_iast_enabled=True, _deduplication_enabled=False)), override_env( + dict(DD_IAST_ENABLED="True") + ): + oce.reconfigure() + patch_iast({"command_injection": True}) + root_span, _ = _aux_appsec_get_root_span( + client, + test_spans, + tracer, + url="/appsec/command-injection/", + payload="master", + content_type="application/json", + ) + + loaded = json.loads(root_span.get_tag(IAST.JSON)) + + line, hash_value = get_line_and_hash("iast_command_injection", VULN_CMDI, filename=TEST_FILE) + + assert loaded["sources"] == [ + {"name": "body", "origin": "http.request.body", "pattern": "abcdef", "redacted": True} + ] + assert loaded["vulnerabilities"][0]["type"] == VULN_CMDI + assert loaded["vulnerabilities"][0]["hash"] == hash_value + assert loaded["vulnerabilities"][0]["evidence"] == { + "valueParts": [{"value": "dir "}, {"redacted": True}, {"pattern": "abcdef", "redacted": True, "source": 0}] + } + assert loaded["vulnerabilities"][0]["location"]["line"] == line + assert loaded["vulnerabilities"][0]["location"]["path"] == TEST_FILE + + +@pytest.mark.skipif(not python_supported_by_iast(), reason="Python version not supported by IAST") +def test_django_header_injection(client, test_spans, tracer): + with override_global_config(dict(_iast_enabled=True, _deduplication_enabled=False)), override_env( + dict(DD_IAST_ENABLED="True") + ): + oce.reconfigure() + patch_iast({"header_injection": True}) + root_span, _ = _aux_appsec_get_root_span( + client, + test_spans, + tracer, + url="/appsec/header-injection/", + payload="master", + content_type="application/json", + ) + + loaded = json.loads(root_span.get_tag(IAST.JSON)) + + line, hash_value = get_line_and_hash("iast_header_injection", VULN_HEADER_INJECTION, filename=TEST_FILE) + + assert loaded["sources"] == [{"origin": "http.request.body", "name": "body", "value": "master"}] + assert 
loaded["vulnerabilities"][0]["type"] == VULN_HEADER_INJECTION + assert loaded["vulnerabilities"][0]["hash"] == hash_value + assert loaded["vulnerabilities"][0]["evidence"] == { + "valueParts": [{"value": "Header-Injection: "}, {"source": 0, "value": "master"}] + } + assert loaded["vulnerabilities"][0]["location"]["line"] == line + assert loaded["vulnerabilities"][0]["location"]["path"] == TEST_FILE diff --git a/tests/contrib/flask/test_flask_appsec_iast.py b/tests/contrib/flask/test_flask_appsec_iast.py index 8b48aaad61d..d3b7f603ab0 100644 --- a/tests/contrib/flask/test_flask_appsec_iast.py +++ b/tests/contrib/flask/test_flask_appsec_iast.py @@ -7,8 +7,10 @@ from ddtrace.appsec._constants import IAST from ddtrace.appsec._iast import oce from ddtrace.appsec._iast._utils import _is_python_version_supported as python_supported_by_iast +from ddtrace.appsec._iast.constants import VULN_HEADER_INJECTION from ddtrace.appsec._iast.constants import VULN_SQL_INJECTION -from ddtrace.contrib.sqlite3.patch import patch +from ddtrace.appsec._iast.taint_sinks.header_injection import patch as patch_header_injection +from ddtrace.contrib.sqlite3.patch import patch as patch_sqlite_sqli from tests.appsec.iast.iast_utils import get_line_and_hash from tests.contrib.flask import BaseFlaskTestCase from tests.utils import override_env @@ -47,7 +49,8 @@ def setUp(self): ) ), override_env(IAST_ENV): super(FlaskAppSecIASTEnabledTestCase, self).setUp() - patch() + patch_sqlite_sqli() + patch_header_injection() oce.reconfigure() self.tracer._iast_enabled = True @@ -57,7 +60,7 @@ def setUp(self): @pytest.mark.skipif(not python_supported_by_iast(), reason="Python version not supported by IAST") def test_flask_full_sqli_iast_http_request_path_parameter(self): @self.app.route("/sqli//", methods=["GET", "POST"]) - def test_sqli(param_str): + def sqli_1(param_str): import sqlite3 from ddtrace.appsec._iast._taint_tracking import is_pyobject_tainted @@ -108,7 +111,7 @@ def test_sqli(param_str): 
@pytest.mark.skipif(not python_supported_by_iast(), reason="Python version not supported by IAST") def test_flask_full_sqli_iast_enabled_http_request_header_getitem(self): @self.app.route("/sqli//", methods=["GET", "POST"]) - def test_sqli(param_str): + def sqli_2(param_str): import sqlite3 from flask import request @@ -164,7 +167,7 @@ def test_sqli(param_str): @pytest.mark.skipif(not python_supported_by_iast(), reason="Python version not supported by IAST") def test_flask_full_sqli_iast_enabled_http_request_header_name_keys(self): @self.app.route("/sqli//", methods=["GET", "POST"]) - def test_sqli(param_str): + def sqli_3(param_str): import sqlite3 from flask import request @@ -218,7 +221,7 @@ def test_sqli(param_str): @pytest.mark.skipif(not python_supported_by_iast(), reason="Python version not supported by IAST") def test_flask_full_sqli_iast_enabled_http_request_header_values(self): @self.app.route("/sqli//", methods=["GET", "POST"]) - def test_sqli(param_str): + def sqli_4(param_str): import sqlite3 from flask import request @@ -270,7 +273,7 @@ def test_sqli(param_str): @pytest.mark.skipif(not python_supported_by_iast(), reason="Python version not supported by IAST") def test_flask_simple_iast_path_header_and_querystring_tainted(self): @self.app.route("/sqli///", methods=["GET", "POST"]) - def test_sqli(param_str, param_int): + def sqli_5(param_str, param_int): from flask import request from ddtrace.appsec._iast._taint_tracking import OriginType @@ -326,7 +329,7 @@ def test_sqli(param_str, param_int): @pytest.mark.skipif(not python_supported_by_iast(), reason="Python version not supported by IAST") def test_flask_simple_iast_path_header_and_querystring_tainted_request_sampling_0(self): @self.app.route("/sqli//", methods=["GET", "POST"]) - def test_sqli(param_str): + def sqli_6(param_str): from flask import request from ddtrace.appsec._iast._taint_tracking import is_pyobject_tainted @@ -357,7 +360,7 @@ def test_sqli(param_str): @pytest.mark.skipif(not 
python_supported_by_iast(), reason="Python version not supported by IAST") def test_flask_full_sqli_iast_enabled_http_request_cookies_value(self): @self.app.route("/sqli/cookies/", methods=["GET", "POST"]) - def test_sqli(): + def sqli_7(): import sqlite3 from flask import request @@ -423,7 +426,7 @@ def test_sqli(): @pytest.mark.skipif(not python_supported_by_iast(), reason="Python version not supported by IAST") def test_flask_full_sqli_iast_enabled_http_request_cookies_name(self): @self.app.route("/sqli/cookies/", methods=["GET", "POST"]) - def test_sqli(): + def sqli_8(): import sqlite3 from flask import request @@ -487,7 +490,7 @@ def test_sqli(): @pytest.mark.skipif(not python_supported_by_iast(), reason="Python version not supported by IAST") def test_flask_full_sqli_iast_http_request_parameter(self): @self.app.route("/sqli/parameter/", methods=["GET"]) - def test_sqli(): + def sqli_9(): import sqlite3 from ddtrace.appsec._iast._taint_tracking.aspects import add_aspect @@ -535,7 +538,7 @@ def test_sqli(): @pytest.mark.skipif(not python_supported_by_iast(), reason="Python version not supported by IAST") def test_flask_full_sqli_iast_enabled_http_request_header_values_scrubbed(self): @self.app.route("/sqli//", methods=["GET", "POST"]) - def test_sqli(param_str): + def sqli_10(param_str): import sqlite3 from flask import request @@ -586,6 +589,54 @@ def test_sqli(param_str): assert vulnerability["location"]["path"] == TEST_FILE_PATH assert vulnerability["hash"] == hash_value + @pytest.mark.skipif(not python_supported_by_iast(), reason="Python version not supported by IAST") + def test_flask_header_injection(self): + @self.app.route("/header_injection/", methods=["GET", "POST"]) + def header_injection(): + from flask import Response + from flask import request + + from ddtrace.appsec._iast._taint_tracking import is_pyobject_tainted + + tainted_string = request.form.get("name") + assert is_pyobject_tainted(tainted_string) + resp = Response("OK") + 
resp.headers["Vary"] = tainted_string + + # label test_flask_header_injection_label + resp.headers["Header-Injection"] = tainted_string + return resp + + with override_global_config( + dict( + _iast_enabled=True, + _asm_enabled=True, + ) + ): + resp = self.client.post("/header_injection/", data={"name": "test"}) + assert resp.status_code == 200 + + root_span = self.pop_spans()[0] + assert root_span.get_metric(IAST.ENABLED) == 1.0 + + loaded = json.loads(root_span.get_tag(IAST.JSON)) + assert loaded["sources"] == [{"origin": "http.request.parameter", "name": "name", "value": "test"}] + + line, hash_value = get_line_and_hash( + "test_flask_header_injection_label", + VULN_HEADER_INJECTION, + filename=TEST_FILE_PATH, + ) + vulnerability = loaded["vulnerabilities"][0] + assert vulnerability["type"] == VULN_HEADER_INJECTION + assert vulnerability["evidence"] == { + "valueParts": [{"value": "Header-Injection: "}, {"source": 0, "value": "test"}] + } + # TODO: vulnerability path is flaky, it points to "tests/contrib/flask/__init__.py" + # assert vulnerability["location"]["path"] == TEST_FILE_PATH + # assert vulnerability["location"]["line"] == line + # assert vulnerability["hash"] == hash_value + class FlaskAppSecIASTDisabledTestCase(BaseFlaskTestCase): @pytest.fixture(autouse=True) diff --git a/tests/contrib/langchain/conftest.py b/tests/contrib/langchain/conftest.py index 391a09505f9..d0df10ec929 100644 --- a/tests/contrib/langchain/conftest.py +++ b/tests/contrib/langchain/conftest.py @@ -77,12 +77,12 @@ def mock_tracer(langchain, mock_logs, mock_metrics): @pytest.fixture -def mock_llmobs_writer(): - patcher = mock.patch("ddtrace.llmobs._llmobs.LLMObsWriter") +def mock_llmobs_span_writer(): + patcher = mock.patch("ddtrace.llmobs._llmobs.LLMObsSpanWriter") try: - LLMObsWriterMock = patcher.start() + LLMObsSpanWriterMock = patcher.start() m = mock.MagicMock() - LLMObsWriterMock.return_value = m + LLMObsSpanWriterMock.return_value = m yield m finally: patcher.stop() diff 
--git a/tests/contrib/langchain/test_langchain.py b/tests/contrib/langchain/test_langchain.py index f9d702b33a1..466966f939b 100644 --- a/tests/contrib/langchain/test_langchain.py +++ b/tests/contrib/langchain/test_langchain.py @@ -1341,7 +1341,7 @@ def _test_llmobs_llm_invoke( provider, generate_trace, request_vcr, - mock_llmobs_writer, + mock_llmobs_span_writer, mock_tracer, cassette_name, input_role=None, @@ -1369,15 +1369,15 @@ def _test_llmobs_llm_invoke( ), ] - assert mock_llmobs_writer.enqueue.call_count == 1 - mock_llmobs_writer.assert_has_calls(expected_llmons_writer_calls) + assert mock_llmobs_span_writer.enqueue.call_count == 1 + mock_llmobs_span_writer.assert_has_calls(expected_llmons_writer_calls) @classmethod def _test_llmobs_chain_invoke( cls, generate_trace, request_vcr, - mock_llmobs_writer, + mock_llmobs_span_writer, mock_tracer, cassette_name, expected_spans_data=[("llm", {"provider": "openai", "input_role": None, "output_role": None})], @@ -1396,48 +1396,48 @@ def _test_llmobs_chain_invoke( expected_llmobs_writer_calls = cls._expected_llmobs_chain_calls( trace=trace, expected_spans_data=expected_spans_data ) - assert mock_llmobs_writer.enqueue.call_count == len(expected_spans_data) - mock_llmobs_writer.assert_has_calls(expected_llmobs_writer_calls) + assert mock_llmobs_span_writer.enqueue.call_count == len(expected_spans_data) + mock_llmobs_span_writer.assert_has_calls(expected_llmobs_writer_calls) - def test_llmobs_openai_llm(self, langchain, mock_llmobs_writer, mock_tracer, request_vcr): + def test_llmobs_openai_llm(self, langchain, mock_llmobs_span_writer, mock_tracer, request_vcr): llm = langchain.llms.OpenAI() self._test_llmobs_llm_invoke( generate_trace=llm, request_vcr=request_vcr, - mock_llmobs_writer=mock_llmobs_writer, + mock_llmobs_span_writer=mock_llmobs_span_writer, mock_tracer=mock_tracer, cassette_name="openai_completion_sync.yaml", different_py39_cassette=True, provider="openai", ) - def test_llmobs_cohere_llm(self, langchain, 
mock_llmobs_writer, mock_tracer, request_vcr): + def test_llmobs_cohere_llm(self, langchain, mock_llmobs_span_writer, mock_tracer, request_vcr): llm = langchain.llms.Cohere(model="cohere.command-light-text-v14") self._test_llmobs_llm_invoke( generate_trace=llm, request_vcr=request_vcr, - mock_llmobs_writer=mock_llmobs_writer, + mock_llmobs_span_writer=mock_llmobs_span_writer, mock_tracer=mock_tracer, cassette_name="cohere_completion_sync.yaml", provider="cohere", ) - def test_llmobs_ai21_llm(self, langchain, mock_llmobs_writer, mock_tracer, request_vcr): + def test_llmobs_ai21_llm(self, langchain, mock_llmobs_span_writer, mock_tracer, request_vcr): llm = langchain.llms.AI21() self._test_llmobs_llm_invoke( generate_trace=llm, request_vcr=request_vcr, - mock_llmobs_writer=mock_llmobs_writer, + mock_llmobs_span_writer=mock_llmobs_span_writer, mock_tracer=mock_tracer, cassette_name="ai21_completion_sync.yaml", provider="ai21", different_py39_cassette=True, ) - def test_llmobs_huggingfacehub_llm(self, langchain, mock_llmobs_writer, mock_tracer, request_vcr): + def test_llmobs_huggingfacehub_llm(self, langchain, mock_llmobs_span_writer, mock_tracer, request_vcr): llm = langchain.llms.HuggingFaceHub( repo_id="google/flan-t5-xxl", model_kwargs={"temperature": 0.0, "max_tokens": 256}, @@ -1447,19 +1447,19 @@ def test_llmobs_huggingfacehub_llm(self, langchain, mock_llmobs_writer, mock_tra self._test_llmobs_llm_invoke( generate_trace=llm, request_vcr=request_vcr, - mock_llmobs_writer=mock_llmobs_writer, + mock_llmobs_span_writer=mock_llmobs_span_writer, mock_tracer=mock_tracer, cassette_name="huggingfacehub_completion_sync.yaml", provider="huggingface_hub", ) - def test_llmobs_openai_chat_model(self, langchain, mock_llmobs_writer, mock_tracer, request_vcr): + def test_llmobs_openai_chat_model(self, langchain, mock_llmobs_span_writer, mock_tracer, request_vcr): chat = langchain.chat_models.ChatOpenAI(temperature=0, max_tokens=256) self._test_llmobs_llm_invoke( 
generate_trace=lambda prompt: chat([langchain.schema.HumanMessage(content=prompt)]), request_vcr=request_vcr, - mock_llmobs_writer=mock_llmobs_writer, + mock_llmobs_span_writer=mock_llmobs_span_writer, mock_tracer=mock_tracer, cassette_name="openai_chat_completion_sync_call.yaml", provider="openai", @@ -1468,13 +1468,13 @@ def test_llmobs_openai_chat_model(self, langchain, mock_llmobs_writer, mock_trac different_py39_cassette=True, ) - def test_llmobs_openai_chat_model_custom_role(self, langchain, mock_llmobs_writer, mock_tracer, request_vcr): + def test_llmobs_openai_chat_model_custom_role(self, langchain, mock_llmobs_span_writer, mock_tracer, request_vcr): chat = langchain.chat_models.ChatOpenAI(temperature=0, max_tokens=256) self._test_llmobs_llm_invoke( generate_trace=lambda prompt: chat([langchain.schema.ChatMessage(content=prompt, role="custom")]), request_vcr=request_vcr, - mock_llmobs_writer=mock_llmobs_writer, + mock_llmobs_span_writer=mock_llmobs_span_writer, mock_tracer=mock_tracer, cassette_name="openai_chat_completion_sync_call.yaml", provider="openai", @@ -1483,13 +1483,13 @@ def test_llmobs_openai_chat_model_custom_role(self, langchain, mock_llmobs_write different_py39_cassette=True, ) - def test_llmobs_chain(self, langchain, mock_llmobs_writer, mock_tracer, request_vcr): + def test_llmobs_chain(self, langchain, mock_llmobs_span_writer, mock_tracer, request_vcr): chain = langchain.chains.LLMMathChain(llm=langchain.llms.OpenAI(temperature=0, max_tokens=256)) self._test_llmobs_chain_invoke( generate_trace=lambda prompt: chain.run("what is two raised to the fifty-fourth power?"), request_vcr=request_vcr, - mock_llmobs_writer=mock_llmobs_writer, + mock_llmobs_span_writer=mock_llmobs_span_writer, mock_tracer=mock_tracer, cassette_name="openai_math_chain_sync.yaml", expected_spans_data=[ @@ -1529,7 +1529,7 @@ def test_llmobs_chain(self, langchain, mock_llmobs_writer, mock_tracer, request_ ) @pytest.mark.skipif(sys.version_info < (3, 10, 0), 
reason="Requires unnecessary cassette file for Python 3.9") - def test_llmobs_chain_nested(self, langchain, mock_llmobs_writer, mock_tracer, request_vcr): + def test_llmobs_chain_nested(self, langchain, mock_llmobs_span_writer, mock_tracer, request_vcr): template = """Paraphrase this text: {input_text} @@ -1567,7 +1567,7 @@ def test_llmobs_chain_nested(self, langchain, mock_llmobs_writer, mock_tracer, r self._test_llmobs_chain_invoke( generate_trace=lambda prompt: sequential_chain.run({"input_text": input_text}), request_vcr=request_vcr, - mock_llmobs_writer=mock_llmobs_writer, + mock_llmobs_span_writer=mock_llmobs_span_writer, mock_tracer=mock_tracer, cassette_name="openai_sequential_paraphrase_and_rhyme_sync.yaml", expected_spans_data=[ diff --git a/tests/contrib/langchain/test_langchain_community.py b/tests/contrib/langchain/test_langchain_community.py index c207c1f761e..2932b725aad 100644 --- a/tests/contrib/langchain/test_langchain_community.py +++ b/tests/contrib/langchain/test_langchain_community.py @@ -1329,7 +1329,7 @@ def _test_llmobs_llm_invoke( provider, generate_trace, request_vcr, - mock_llmobs_writer, + mock_llmobs_span_writer, mock_tracer, cassette_name, input_role=None, @@ -1354,15 +1354,15 @@ def _test_llmobs_llm_invoke( ), ] - assert mock_llmobs_writer.enqueue.call_count == 1 - mock_llmobs_writer.assert_has_calls(expected_llmons_writer_calls) + assert mock_llmobs_span_writer.enqueue.call_count == 1 + mock_llmobs_span_writer.assert_has_calls(expected_llmons_writer_calls) @classmethod def _test_llmobs_chain_invoke( cls, generate_trace, request_vcr, - mock_llmobs_writer, + mock_llmobs_span_writer, mock_tracer, cassette_name, expected_spans_data=[("llm", {"provider": "openai", "input_role": None, "output_role": None})], @@ -1378,54 +1378,54 @@ def _test_llmobs_chain_invoke( expected_llmobs_writer_calls = cls._expected_llmobs_chain_calls( trace=trace, expected_spans_data=expected_spans_data ) - assert mock_llmobs_writer.enqueue.call_count == 
len(expected_spans_data) - mock_llmobs_writer.assert_has_calls(expected_llmobs_writer_calls) + assert mock_llmobs_span_writer.enqueue.call_count == len(expected_spans_data) + mock_llmobs_span_writer.assert_has_calls(expected_llmobs_writer_calls) @flaky(1735812000) - def test_llmobs_openai_llm(self, langchain_openai, mock_llmobs_writer, mock_tracer, request_vcr): + def test_llmobs_openai_llm(self, langchain_openai, mock_llmobs_span_writer, mock_tracer, request_vcr): llm = langchain_openai.OpenAI() self._test_llmobs_llm_invoke( generate_trace=llm.invoke, request_vcr=request_vcr, - mock_llmobs_writer=mock_llmobs_writer, + mock_llmobs_span_writer=mock_llmobs_span_writer, mock_tracer=mock_tracer, cassette_name="openai_completion_sync.yaml", provider="openai", ) - def test_llmobs_cohere_llm(self, langchain_community, mock_llmobs_writer, mock_tracer, request_vcr): + def test_llmobs_cohere_llm(self, langchain_community, mock_llmobs_span_writer, mock_tracer, request_vcr): llm = langchain_community.llms.Cohere(model="cohere.command-light-text-v14") self._test_llmobs_llm_invoke( generate_trace=llm.invoke, request_vcr=request_vcr, - mock_llmobs_writer=mock_llmobs_writer, + mock_llmobs_span_writer=mock_llmobs_span_writer, mock_tracer=mock_tracer, cassette_name="cohere_completion_sync.yaml", provider="cohere", ) - def test_llmobs_ai21_llm(self, langchain_community, mock_llmobs_writer, mock_tracer, request_vcr): + def test_llmobs_ai21_llm(self, langchain_community, mock_llmobs_span_writer, mock_tracer, request_vcr): llm = langchain_community.llms.AI21() self._test_llmobs_llm_invoke( generate_trace=llm.invoke, request_vcr=request_vcr, - mock_llmobs_writer=mock_llmobs_writer, + mock_llmobs_span_writer=mock_llmobs_span_writer, mock_tracer=mock_tracer, cassette_name="ai21_completion_sync.yaml", provider="ai21", ) @flaky(1735812000) - def test_llmobs_openai_chat_model(self, langchain_openai, mock_llmobs_writer, mock_tracer, request_vcr): + def test_llmobs_openai_chat_model(self, 
langchain_openai, mock_llmobs_span_writer, mock_tracer, request_vcr): chat = langchain_openai.ChatOpenAI(temperature=0, max_tokens=256) self._test_llmobs_llm_invoke( generate_trace=lambda prompt: chat.invoke([langchain.schema.HumanMessage(content=prompt)]), request_vcr=request_vcr, - mock_llmobs_writer=mock_llmobs_writer, + mock_llmobs_span_writer=mock_llmobs_span_writer, mock_tracer=mock_tracer, cassette_name="openai_chat_completion_sync_call.yaml", provider="openai", @@ -1434,13 +1434,15 @@ def test_llmobs_openai_chat_model(self, langchain_openai, mock_llmobs_writer, mo ) @flaky(1735812000) - def test_llmobs_openai_chat_model_custom_role(self, langchain_openai, mock_llmobs_writer, mock_tracer, request_vcr): + def test_llmobs_openai_chat_model_custom_role( + self, langchain_openai, mock_llmobs_span_writer, mock_tracer, request_vcr + ): chat = langchain_openai.ChatOpenAI(temperature=0, max_tokens=256) self._test_llmobs_llm_invoke( generate_trace=lambda prompt: chat.invoke([langchain.schema.ChatMessage(content=prompt, role="custom")]), request_vcr=request_vcr, - mock_llmobs_writer=mock_llmobs_writer, + mock_llmobs_span_writer=mock_llmobs_span_writer, mock_tracer=mock_tracer, cassette_name="openai_chat_completion_sync_call.yaml", provider="openai", @@ -1449,7 +1451,7 @@ def test_llmobs_openai_chat_model_custom_role(self, langchain_openai, mock_llmob ) @flaky(1735812000) - def test_llmobs_chain(self, langchain_core, langchain_openai, mock_llmobs_writer, mock_tracer, request_vcr): + def test_llmobs_chain(self, langchain_core, langchain_openai, mock_llmobs_span_writer, mock_tracer, request_vcr): prompt = langchain_core.prompts.ChatPromptTemplate.from_messages( [("system", "You are world class technical documentation writer."), ("user", "{input}")] ) @@ -1471,7 +1473,7 @@ def test_llmobs_chain(self, langchain_core, langchain_openai, mock_llmobs_writer self._test_llmobs_chain_invoke( generate_trace=lambda prompt: chain.invoke({"input": prompt}), request_vcr=request_vcr, - 
mock_llmobs_writer=mock_llmobs_writer, + mock_llmobs_span_writer=mock_llmobs_span_writer, mock_tracer=mock_tracer, cassette_name="lcel_openai_chain_call.yaml", expected_spans_data=[ @@ -1486,7 +1488,9 @@ def test_llmobs_chain(self, langchain_core, langchain_openai, mock_llmobs_writer ], ) - def test_llmobs_chain_nested(self, langchain_core, langchain_openai, mock_llmobs_writer, mock_tracer, request_vcr): + def test_llmobs_chain_nested( + self, langchain_core, langchain_openai, mock_llmobs_span_writer, mock_tracer, request_vcr + ): prompt1 = langchain_core.prompts.ChatPromptTemplate.from_template("what is the city {person} is from?") prompt2 = langchain_core.prompts.ChatPromptTemplate.from_template( "what country is the city {city} in? respond in {language}" @@ -1504,7 +1508,7 @@ def test_llmobs_chain_nested(self, langchain_core, langchain_openai, mock_llmobs {"person": "Spongebob Squarepants", "language": "Spanish"} ), request_vcr=request_vcr, - mock_llmobs_writer=mock_llmobs_writer, + mock_llmobs_span_writer=mock_llmobs_span_writer, mock_tracer=mock_tracer, cassette_name="lcel_openai_chain_nested.yaml", expected_spans_data=[ @@ -1528,7 +1532,9 @@ def test_llmobs_chain_nested(self, langchain_core, langchain_openai, mock_llmobs ) @pytest.mark.skipif(sys.version_info >= (3, 11, 0), reason="Python <3.11 required") - def test_llmobs_chain_batch(self, langchain_core, langchain_openai, mock_llmobs_writer, mock_tracer, request_vcr): + def test_llmobs_chain_batch( + self, langchain_core, langchain_openai, mock_llmobs_span_writer, mock_tracer, request_vcr + ): prompt = langchain_core.prompts.ChatPromptTemplate.from_template("Tell me a short joke about {topic}") output_parser = langchain_core.output_parsers.StrOutputParser() model = langchain_openai.ChatOpenAI() @@ -1537,7 +1543,7 @@ def test_llmobs_chain_batch(self, langchain_core, langchain_openai, mock_llmobs_ self._test_llmobs_chain_invoke( generate_trace=lambda inputs: chain.batch(["chickens", "pigs"]), 
request_vcr=request_vcr, - mock_llmobs_writer=mock_llmobs_writer, + mock_llmobs_span_writer=mock_llmobs_span_writer, mock_tracer=mock_tracer, cassette_name="lcel_openai_chain_batch.yaml", expected_spans_data=[ diff --git a/tests/contrib/openai/conftest.py b/tests/contrib/openai/conftest.py index 9ddffa0c74b..542c1568236 100644 --- a/tests/contrib/openai/conftest.py +++ b/tests/contrib/openai/conftest.py @@ -130,11 +130,11 @@ def mock_logs(scope="session"): @pytest.fixture def mock_llmobs_writer(scope="session"): - patcher = mock.patch("ddtrace.llmobs._llmobs.LLMObsWriter") + patcher = mock.patch("ddtrace.llmobs._llmobs.LLMObsSpanWriter") try: - LLMObsWriterMock = patcher.start() + LLMObsSpanWriterMock = patcher.start() m = mock.MagicMock() - LLMObsWriterMock.return_value = m + LLMObsSpanWriterMock.return_value = m yield m finally: patcher.stop() diff --git a/tests/integration/test_debug.py b/tests/integration/test_debug.py index 0d486355c26..cf5520dcb7c 100644 --- a/tests/integration/test_debug.py +++ b/tests/integration/test_debug.py @@ -336,7 +336,8 @@ def test_startup_logs_sampling_rules(): f = debug.collect(tracer) assert f.get("sampler_rules") == [ - "SamplingRule(sample_rate=1.0, service='NO_RULE', name='NO_RULE', resource='NO_RULE', tags='NO_RULE')" + "SamplingRule(sample_rate=1.0, service='NO_RULE', name='NO_RULE', resource='NO_RULE'," + " tags='NO_RULE', provenance='default')" ] sampler = ddtrace.sampler.DatadogSampler( @@ -346,7 +347,8 @@ def test_startup_logs_sampling_rules(): f = debug.collect(tracer) assert f.get("sampler_rules") == [ - "SamplingRule(sample_rate=1.0, service='xyz', name='abc', resource='NO_RULE', tags='NO_RULE')" + "SamplingRule(sample_rate=1.0, service='xyz', name='abc', resource='NO_RULE'," + " tags='NO_RULE', provenance='default')" ] diff --git a/tests/integration/test_integration_civisibility.py b/tests/integration/test_integration_civisibility.py index da406b75e4b..d1287452d5b 100644 --- 
a/tests/integration/test_integration_civisibility.py +++ b/tests/integration/test_integration_civisibility.py @@ -91,7 +91,7 @@ def test_civisibility_intake_payloads(): span.finish() conn = t._writer._conn t.shutdown() - assert conn.request.call_count == 2 + assert 2 <= conn.request.call_count <= 3 assert conn.request.call_args_list[0].args[1] == "api/v2/citestcycle" assert ( b"svc-no-cov" in conn.request.call_args_list[0].args[2] diff --git a/tests/internal/remoteconfig/test_remoteconfig.py b/tests/internal/remoteconfig/test_remoteconfig.py index e1870ef2867..feb83b775d6 100644 --- a/tests/internal/remoteconfig/test_remoteconfig.py +++ b/tests/internal/remoteconfig/test_remoteconfig.py @@ -10,6 +10,7 @@ from mock.mock import ANY import pytest +from ddtrace import config from ddtrace.internal.remoteconfig._connectors import PublisherSubscriberConnector from ddtrace.internal.remoteconfig._publishers import RemoteConfigPublisherMergeDicts from ddtrace.internal.remoteconfig._pubsub import PubSub @@ -20,6 +21,8 @@ from ddtrace.internal.remoteconfig.worker import RemoteConfigPoller from ddtrace.internal.remoteconfig.worker import remoteconfig_poller from ddtrace.internal.service import ServiceStatus +from ddtrace.sampler import DatadogSampler +from ddtrace.sampling_rule import SamplingRule from tests.internal.test_utils_version import _assert_and_get_version_agent_format from tests.utils import override_global_config @@ -313,7 +316,7 @@ def _reload_features(self, features, test_tracer=None): mock_send_request.assert_called() sleep(0.5) assert callback.features == {} - assert rc._client._last_error == "invalid agent payload received" + assert rc._client._last_error.startswith("invalid agent payload received") class Callback: features = {} @@ -351,7 +354,7 @@ def _reload_features(self, features, test_tracer=None): mock_send_request.assert_called() sleep(0.5) assert callback.features == {} - assert rc._client._last_error == "invalid agent payload received" + assert 
rc._client._last_error.startswith("invalid agent payload received") mock_send_request.return_value = get_mock_encoded_msg(b'{"asm":{"enabled":true}}') rc._online() @@ -428,3 +431,161 @@ def test_rc_default_products_registered(): assert bool(remoteconfig_poller._client._products.get("APM_TRACING")) == rc_enabled assert bool(remoteconfig_poller._client._products.get("AGENT_CONFIG")) == rc_enabled assert bool(remoteconfig_poller._client._products.get("AGENT_TASK")) == rc_enabled + + +@pytest.mark.parametrize( + "rc_rules,expected_config_rules,expected_sampling_rules", + [ + ( + [ # Test with all fields + { + "service": "my-service", + "name": "web.request", + "resource": "*", + "provenance": "dynamic", + "sample_rate": 1.0, + "tags": [{"key": "care_about", "value_glob": "yes"}, {"key": "region", "value_glob": "us-*"}], + } + ], + '[{"service": "my-service", "name": "web.request", "resource": "*", "provenance": "dynamic",' + ' "sample_rate": 1.0, "tags": {"care_about": "yes", "region": "us-*"}}]', + [ + SamplingRule( + sample_rate=1.0, + service="my-service", + name="web.request", + resource="*", + tags={"care_about": "yes", "region": "us-*"}, + provenance="dynamic", + ) + ], + ), + ( # Test with no service + [ + { + "name": "web.request", + "resource": "*", + "provenance": "customer", + "sample_rate": 1.0, + "tags": [{"key": "care_about", "value_glob": "yes"}, {"key": "region", "value_glob": "us-*"}], + } + ], + '[{"name": "web.request", "resource": "*", "provenance": "customer", "sample_rate": 1.0, "tags": ' + '{"care_about": "yes", "region": "us-*"}}]', + [ + SamplingRule( + sample_rate=1.0, + service=SamplingRule.NO_RULE, + name="web.request", + resource="*", + tags={"care_about": "yes", "region": "us-*"}, + provenance="customer", + ) + ], + ), + ( + # Test with no tags + [ + { + "service": "my-service", + "name": "web.request", + "resource": "*", + "provenance": "customer", + "sample_rate": 1.0, + } + ], + '[{"service": "my-service", "name": "web.request", 
"resource": "*", "provenance": ' + '"customer", "sample_rate": 1.0}]', + [ + SamplingRule( + sample_rate=1.0, + service="my-service", + name="web.request", + resource="*", + tags=SamplingRule.NO_RULE, + provenance="customer", + ) + ], + ), + ( + # Test with no resource + [ + { + "service": "my-service", + "name": "web.request", + "provenance": "customer", + "sample_rate": 1.0, + "tags": [{"key": "care_about", "value_glob": "yes"}, {"key": "region", "value_glob": "us-*"}], + } + ], + '[{"service": "my-service", "name": "web.request", "provenance": "customer", "sample_rate": 1.0, "tags":' + ' {"care_about": "yes", "region": "us-*"}}]', + [ + SamplingRule( + sample_rate=1.0, + service="my-service", + name="web.request", + resource=SamplingRule.NO_RULE, + tags={"care_about": "yes", "region": "us-*"}, + provenance="customer", + ) + ], + ), + ( + # Test with no name + [ + { + "service": "my-service", + "resource": "*", + "provenance": "customer", + "sample_rate": 1.0, + "tags": [{"key": "care_about", "value_glob": "yes"}, {"key": "region", "value_glob": "us-*"}], + } + ], + '[{"service": "my-service", "resource": "*", "provenance": "customer", "sample_rate": 1.0, "tags":' + ' {"care_about": "yes", "region": "us-*"}}]', + [ + SamplingRule( + sample_rate=1.0, + service="my-service", + name=SamplingRule.NO_RULE, + resource="*", + tags={"care_about": "yes", "region": "us-*"}, + provenance="customer", + ) + ], + ), + ( + # Test with no sample rate + [ + { + "service": "my-service", + "name": "web.request", + "resource": "*", + "provenance": "customer", + "tags": [{"key": "care_about", "value_glob": "yes"}, {"key": "region", "value_glob": "us-*"}], + } + ], + None, + None, + ), + ( + # Test with no service, name, resource, tags + [ + { + "provenance": "customer", + "sample_rate": 1.0, + } + ], + None, + None, + ), + ], +) +def test_trace_sampling_rules_conversion(rc_rules, expected_config_rules, expected_sampling_rules): + trace_sampling_rules = 
config.convert_rc_trace_sampling_rules(rc_rules) + + assert trace_sampling_rules == expected_config_rules + if trace_sampling_rules is not None: + parsed_rules = DatadogSampler._parse_rules_from_str(trace_sampling_rules) + assert parsed_rules == expected_sampling_rules diff --git a/tests/internal/remoteconfig/test_remoteconfig_client_e2e.py b/tests/internal/remoteconfig/test_remoteconfig_client_e2e.py index ad6a9e4436c..760fa4a2e7b 100644 --- a/tests/internal/remoteconfig/test_remoteconfig_client_e2e.py +++ b/tests/internal/remoteconfig/test_remoteconfig_client_e2e.py @@ -18,7 +18,7 @@ def _expected_payload( rc_client, - capabilities="CPAA", # this was gathered by running the test and observing the payload + capabilities="IAjwAA==", # this was gathered by running the test and observing the payload has_errors=False, targets_version=0, backend_client_state=None, diff --git a/tests/internal/test_packages.py b/tests/internal/test_packages.py index 9189830a12f..763504159a2 100644 --- a/tests/internal/test_packages.py +++ b/tests/internal/test_packages.py @@ -86,8 +86,8 @@ def test_third_party_packages(): @pytest.mark.subprocess( env={ - "DD_THIRD_PARTY_DETECTION_EXCLUDES": "myfancypackage,myotherfancypackage", - "DD_THIRD_PARTY_DETECTION_INCLUDES": "requests", + "DD_THIRD_PARTY_DETECTION_INCLUDES": "myfancypackage,myotherfancypackage", + "DD_THIRD_PARTY_DETECTION_EXCLUDES": "requests", } ) def test_third_party_packages_excludes_includes(): diff --git a/tests/internal/test_settings.py b/tests/internal/test_settings.py index faea554f489..4bae14bfef9 100644 --- a/tests/internal/test_settings.py +++ b/tests/internal/test_settings.py @@ -256,6 +256,216 @@ def test_remoteconfig_sampling_rate_user(run_python_code_in_subprocess): assert status == 0, err.decode("utf-8") +def test_remoteconfig_sampling_rules(run_python_code_in_subprocess): + env = os.environ.copy() + env.update({"DD_TRACE_SAMPLING_RULES": '[{"sample_rate":0.1, "name":"test"}]'}) + + out, err, status, _ = 
run_python_code_in_subprocess( + """ +from ddtrace import config, tracer +from ddtrace.sampler import DatadogSampler +from tests.internal.test_settings import _base_rc_config, _deleted_rc_config + +with tracer.trace("test") as span: + pass +assert span.get_metric("_dd.rule_psr") == 0.1 +assert span.get_tag("_dd.p.dm") == "-3" + +config._handle_remoteconfig(_base_rc_config({"tracing_sampling_rules":[ + { + "service": "*", + "name": "test", + "resource": "*", + "provenance": "customer", + "sample_rate": 0.2, + } + ]})) +with tracer.trace("test") as span: + pass +assert span.get_metric("_dd.rule_psr") == 0.2 +assert span.get_tag("_dd.p.dm") == "-11" + +config._handle_remoteconfig(_base_rc_config({})) +with tracer.trace("test") as span: + pass +assert span.get_metric("_dd.rule_psr") == 0.1 + +custom_sampler = DatadogSampler(DatadogSampler._parse_rules_from_str('[{"sample_rate":0.3, "name":"test"}]')) +tracer.configure(sampler=custom_sampler) +with tracer.trace("test") as span: + pass +assert span.get_metric("_dd.rule_psr") == 0.3 +assert span.get_tag("_dd.p.dm") == "-3" + +config._handle_remoteconfig(_base_rc_config({"tracing_sampling_rules":[ + { + "service": "*", + "name": "test", + "resource": "*", + "provenance": "dynamic", + "sample_rate": 0.4, + } + ]})) +with tracer.trace("test") as span: + pass +assert span.get_metric("_dd.rule_psr") == 0.4 +assert span.get_tag("_dd.p.dm") == "-12" + +config._handle_remoteconfig(_base_rc_config({})) +with tracer.trace("test") as span: + pass +assert span.get_metric("_dd.rule_psr") == 0.3 +assert span.get_tag("_dd.p.dm") == "-3" + +config._handle_remoteconfig(_base_rc_config({"tracing_sampling_rules":[ + { + "service": "ok", + "name": "test", + "resource": "*", + "provenance": "customer", + "sample_rate": 0.4, + } + ]})) +with tracer.trace(service="ok", name="test") as span: + pass +assert span.get_metric("_dd.rule_psr") == 0.4 +assert span.get_tag("_dd.p.dm") == "-11" + +config._handle_remoteconfig(_deleted_rc_config()) +with 
tracer.trace("test") as span: + pass +assert span.get_metric("_dd.rule_psr") == 0.3 +assert span.get_tag("_dd.p.dm") == "-3" + + """, + env=env, + ) + assert status == 0, err.decode("utf-8") + + +def test_remoteconfig_sample_rate_and_rules(run_python_code_in_subprocess): + """There is complex logic regarding the interaction between setting new + sample rates and rules with remote config. + """ + env = os.environ.copy() + env.update({"DD_TRACE_SAMPLING_RULES": '[{"sample_rate":0.9, "name":"rules"}]'}) + env.update({"DD_TRACE_SAMPLE_RATE": "0.8"}) + + out, err, status, _ = run_python_code_in_subprocess( + """ +from ddtrace import config, tracer +from ddtrace.sampler import DatadogSampler +from tests.internal.test_settings import _base_rc_config, _deleted_rc_config + +with tracer.trace("rules") as span: + pass +assert span.get_metric("_dd.rule_psr") == 0.9 +assert span.get_tag("_dd.p.dm") == "-3" + +with tracer.trace("sample_rate") as span: + pass +assert span.get_metric("_dd.rule_psr") == 0.8 +assert span.get_tag("_dd.p.dm") == "-3" + + +config._handle_remoteconfig(_base_rc_config({"tracing_sampling_rules":[ + { + "service": "*", + "name": "rules", + "resource": "*", + "provenance": "customer", + "sample_rate": 0.7, + } + ]})) + +with tracer.trace("rules") as span: + pass +assert span.get_metric("_dd.rule_psr") == 0.7 +assert span.get_tag("_dd.p.dm") == "-11" + +with tracer.trace("sample_rate") as span: + pass +assert span.get_metric("_dd.rule_psr") == 0.8 +assert span.get_tag("_dd.p.dm") == "-3" + + +config._handle_remoteconfig(_base_rc_config({"tracing_sampling_rate": 0.2})) + +with tracer.trace("sample_rate") as span: + pass +assert span.get_metric("_dd.rule_psr") == 0.2 +assert span.get_tag("_dd.p.dm") == "-3" + +with tracer.trace("rules") as span: + pass +assert span.get_metric("_dd.rule_psr") == 0.9 +assert span.get_tag("_dd.p.dm") == "-3" + + +config._handle_remoteconfig(_base_rc_config({"tracing_sampling_rate": 0.3})) + +with tracer.trace("sample_rate") as 
span: + pass +assert span.get_metric("_dd.rule_psr") == 0.3 +assert span.get_tag("_dd.p.dm") == "-3" + +with tracer.trace("rules") as span: + pass +assert span.get_metric("_dd.rule_psr") == 0.9 +assert span.get_tag("_dd.p.dm") == "-3" + + +config._handle_remoteconfig(_base_rc_config({})) + +with tracer.trace("rules") as span: + pass +assert span.get_metric("_dd.rule_psr") == 0.9 +assert span.get_tag("_dd.p.dm") == "-3" + +with tracer.trace("sample_rate") as span: + pass +assert span.get_metric("_dd.rule_psr") == 0.8 +assert span.get_tag("_dd.p.dm") == "-3" + + +config._handle_remoteconfig(_base_rc_config({"tracing_sampling_rules":[ + { + "service": "*", + "name": "rules_dynamic", + "resource": "*", + "provenance": "dynamic", + "sample_rate": 0.1, + }, + { + "service": "*", + "name": "rules_customer", + "resource": "*", + "provenance": "customer", + "sample_rate": 0.6, + } + ]})) + +with tracer.trace("rules_dynamic") as span: + pass +assert span.get_metric("_dd.rule_psr") == 0.1 +assert span.get_tag("_dd.p.dm") == "-12" + +with tracer.trace("rules_customer") as span: + pass +assert span.get_metric("_dd.rule_psr") == 0.6 +assert span.get_tag("_dd.p.dm") == "-11" + +with tracer.trace("sample_rate") as span: + pass +assert span.get_metric("_dd.rule_psr") == 0.8 +assert span.get_tag("_dd.p.dm") == "-3" + + """, + env=env, + ) + assert status == 0, err.decode("utf-8") + + def test_remoteconfig_custom_tags(run_python_code_in_subprocess): env = os.environ.copy() env.update({"DD_TAGS": "team:apm"}) diff --git a/tests/internal/test_tracer_flare.py b/tests/internal/test_tracer_flare.py index 35f38674e67..7051190e17d 100644 --- a/tests/internal/test_tracer_flare.py +++ b/tests/internal/test_tracer_flare.py @@ -2,13 +2,16 @@ from logging import Logger import multiprocessing import os +import pathlib from typing import Optional import unittest from unittest import mock +import uuid from ddtrace.internal.flare import TRACER_FLARE_DIRECTORY from ddtrace.internal.flare import 
TRACER_FLARE_FILE_HANDLER_NAME from ddtrace.internal.flare import Flare +from ddtrace.internal.flare import FlareSendRequest from ddtrace.internal.logger import get_logger @@ -16,25 +19,17 @@ class TracerFlareTests(unittest.TestCase): - mock_agent_config = [{"name": "flare-log-level", "config": {"log_level": "DEBUG"}}] - mock_agent_task = [ - False, - { - "args": { - "case_id": "1111111", - "hostname": "myhostname", - "user_handle": "user.name@datadoghq.com", - }, - "task_type": "tracer_flare", - "uuid": "d53fc8a4-8820-47a2-aa7d-d565582feb81", - }, - ] + mock_flare_send_request = FlareSendRequest( + case_id="1111111", hostname="myhostname", email="user.name@datadoghq.com" + ) def setUp(self): - self.flare = Flare() + self.flare_uuid = uuid.uuid4() + self.flare_dir = f"{TRACER_FLARE_DIRECTORY}-{self.flare_uuid}" + self.flare = Flare(flare_dir=pathlib.Path(self.flare_dir)) self.pid = os.getpid() - self.flare_file_path = f"{TRACER_FLARE_DIRECTORY}/tracer_python_{self.pid}.log" - self.config_file_path = f"{TRACER_FLARE_DIRECTORY}/tracer_config_{self.pid}.json" + self.flare_file_path = f"{self.flare_dir}/tracer_python_{self.pid}.log" + self.config_file_path = f"{self.flare_dir}/tracer_config_{self.pid}.json" def tearDown(self): self.confirm_cleanup() @@ -53,7 +48,7 @@ def test_single_process_success(self): """ ddlogger = get_logger("ddtrace") - self.flare.prepare(self.mock_agent_config) + self.flare.prepare("DEBUG") file_handler = self._get_handler() valid_logger_level = self.flare._get_valid_logger_level(DEBUG_LEVEL_INT) @@ -66,7 +61,7 @@ def test_single_process_success(self): # Sends request to testagent # This just validates the request params - self.flare.send(self.mock_agent_task) + self.flare.send(self.mock_flare_send_request) def test_single_process_partial_failure(self): """ @@ -79,7 +74,7 @@ def test_single_process_partial_failure(self): # Mock the partial failure with mock.patch("json.dump") as mock_json: mock_json.side_effect = Exception("file issue 
happened") - self.flare.prepare(self.mock_agent_config) + self.flare.prepare("DEBUG") file_handler = self._get_handler() assert file_handler is not None @@ -89,7 +84,7 @@ def test_single_process_partial_failure(self): assert os.path.exists(self.flare_file_path) assert not os.path.exists(self.config_file_path) - self.flare.send(self.mock_agent_task) + self.flare.send(self.mock_flare_send_request) def test_multiple_process_success(self): """ @@ -99,10 +94,10 @@ def test_multiple_process_success(self): num_processes = 3 def handle_agent_config(): - self.flare.prepare(self.mock_agent_config) + self.flare.prepare("DEBUG") def handle_agent_task(): - self.flare.send(self.mock_agent_task) + self.flare.send(self.mock_flare_send_request) # Create multiple processes for _ in range(num_processes): @@ -114,7 +109,7 @@ def handle_agent_task(): # Assert that each process wrote its file successfully # We double the process number because each will generate a log file and a config file - assert len(processes) * 2 == len(os.listdir(TRACER_FLARE_DIRECTORY)) + assert len(processes) * 2 == len(os.listdir(self.flare_dir)) for _ in range(num_processes): p = multiprocessing.Process(target=handle_agent_task) @@ -130,19 +125,19 @@ def test_multiple_process_partial_failure(self): """ processes = [] - def do_tracer_flare(agent_config, agent_task): - self.flare.prepare(agent_config) + def do_tracer_flare(prep_request, send_request): + self.flare.prepare(prep_request) # Assert that only one process wrote its file successfully # We check for 2 files because it will generate a log file and a config file - assert 2 == len(os.listdir(TRACER_FLARE_DIRECTORY)) - self.flare.send(agent_task) + assert 2 == len(os.listdir(self.flare_dir)) + self.flare.send(send_request) # Create successful process - p = multiprocessing.Process(target=do_tracer_flare, args=(self.mock_agent_config, self.mock_agent_task)) + p = multiprocessing.Process(target=do_tracer_flare, args=("DEBUG", self.mock_flare_send_request)) 
processes.append(p) p.start() # Create failing process - p = multiprocessing.Process(target=do_tracer_flare, args=(None, self.mock_agent_task)) + p = multiprocessing.Process(target=do_tracer_flare, args=(None, self.mock_flare_send_request)) processes.append(p) p.start() for p in processes: @@ -154,7 +149,7 @@ def test_no_app_logs(self): file, just the tracer logs """ app_logger = Logger(name="my-app", level=DEBUG_LEVEL_INT) - self.flare.prepare(self.mock_agent_config) + self.flare.prepare("DEBUG") app_log_line = "this is an app log" app_logger.debug(app_log_line) @@ -169,5 +164,5 @@ def test_no_app_logs(self): self.flare.revert_configs() def confirm_cleanup(self): - assert not os.path.exists(TRACER_FLARE_DIRECTORY), f"The directory {TRACER_FLARE_DIRECTORY} still exists" + assert not self.flare.flare_dir.exists(), f"The directory {self.flare.flare_dir} still exists" assert self._get_handler() is None, "File handler was not removed" diff --git a/tests/llmobs/conftest.py b/tests/llmobs/conftest.py index d450a8e9057..a722c863c38 100644 --- a/tests/llmobs/conftest.py +++ b/tests/llmobs/conftest.py @@ -28,15 +28,21 @@ def pytest_configure(config): @pytest.fixture -def mock_llmobs_writer(): - patcher = mock.patch("ddtrace.llmobs._llmobs.LLMObsWriter") - LLMObsWriterMock = patcher.start() +def mock_llmobs_span_writer(): + patcher = mock.patch("ddtrace.llmobs._llmobs.LLMObsSpanWriter") + LLMObsSpanWriterMock = patcher.start() m = mock.MagicMock() - LLMObsWriterMock.return_value = m + LLMObsSpanWriterMock.return_value = m yield m patcher.stop() +@pytest.fixture +def mock_writer_logs(): + with mock.patch("ddtrace.llmobs._writer.logger") as m: + yield m + + @pytest.fixture def ddtrace_global_config(): config = {} @@ -48,7 +54,7 @@ def default_global_config(): @pytest.fixture -def LLMObs(mock_llmobs_writer, ddtrace_global_config): +def LLMObs(mock_llmobs_span_writer, ddtrace_global_config): global_config = default_global_config() global_config.update(ddtrace_global_config) with 
override_global_config(global_config): diff --git a/tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_eval_metric_writer.send_score_metric.yaml b/tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_eval_metric_writer.send_score_metric.yaml new file mode 100644 index 00000000000..4b45cb499b4 --- /dev/null +++ b/tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_eval_metric_writer.send_score_metric.yaml @@ -0,0 +1,38 @@ +interactions: +- request: + body: '{"data": {"type": "evaluation_metric", "attributes": {"metrics": [{"span_id": + "12345678902", "trace_id": "98765432102", "metric_type": "score", "label": "sentiment", + "score_value": 0.9}]}}}' + headers: + Content-Type: + - application/json + DD-API-KEY: + - XXXXXX + method: POST + uri: https://api.datad0g.com/api/unstable/llm-obs/v1/eval-metric + response: + body: + string: '{"data":{"id":"3a7f7352-d7b0-485b-a352-46a66517da67","type":"evaluation_metric","attributes":{"metrics":[{"id":"b2660a92-f811-4bef-a45e-0eff068cea87","trace_id":"98765432102","span_id":"12345678902","timestamp":1714076672148,"metric_type":"score","label":"sentiment","score_value":0.9}]}}}' + headers: + Connection: + - keep-alive + Content-Length: + - '289' + Content-Type: + - application/vnd.api+json + Date: + - Thu, 25 Apr 2024 20:24:32 GMT + content-security-policy: + - frame-ancestors 'self'; report-uri https://logs.browser-intake-datadoghq.com/api/v2/logs?dd-api-key=pub293163a918901030b79492fe1ab424cf&dd-evp-origin=content-security-policy&ddsource=csp-report&ddtags=site%3Adatad0g.com + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + vary: + - Accept-Encoding + x-content-type-options: + - nosniff + x-frame-options: + - SAMEORIGIN + status: + code: 200 + message: OK +version: 1 diff --git a/tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_eval_metric_writer.test_send_categorical_metric.yaml 
b/tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_eval_metric_writer.test_send_categorical_metric.yaml new file mode 100644 index 00000000000..9d262616534 --- /dev/null +++ b/tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_eval_metric_writer.test_send_categorical_metric.yaml @@ -0,0 +1,36 @@ +interactions: +- request: + body: '{"data": {"type": "evaluation_metric", "attributes": {"metrics": [{"span_id": + "12345678901", "trace_id": "98765432101", "metric_type": "categorical", "label": + "toxicity", "categorical_value": "high"}]}}}' + headers: + Content-Type: + - application/json + DD-API-KEY: + - XXXXXX + method: POST + uri: https://api.datad0g.com/api/unstable/llm-obs/v1/eval-metric + response: + body: + string: '{"data":{"id":"2e899670-dc6f-4e93-a8fa-cb43a38e8576","type":"evaluation_metric","attributes":{"metrics":[{"id":"a87b1729-8ed3-42df-9dc7-6db9d5ab7426","trace_id":"98765432101","span_id":"12345678901","timestamp":1714067595986,"metric_type":"categorical","label":"toxicity","categorical_value":"high"}]}}}' + headers: + content-length: + - '303' + content-security-policy: + - frame-ancestors 'self'; report-uri https://logs.browser-intake-datadoghq.com/api/v2/logs?dd-api-key=pub293163a918901030b79492fe1ab424cf&dd-evp-origin=content-security-policy&ddsource=csp-report&ddtags=site%3Adatad0g.com + content-type: + - application/vnd.api+json + date: + - Thu, 25 Apr 2024 17:53:16 GMT + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + vary: + - Accept-Encoding + x-content-type-options: + - nosniff + x-frame-options: + - SAMEORIGIN + status: + code: 200 + message: OK +version: 1 diff --git a/tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_eval_metric_writer.test_send_metric_bad_api_key.yaml b/tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_eval_metric_writer.test_send_metric_bad_api_key.yaml new file mode 100644 index 00000000000..34ab13d376b --- /dev/null +++ 
b/tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_eval_metric_writer.test_send_metric_bad_api_key.yaml @@ -0,0 +1,32 @@ +interactions: +- request: + body: '{"data": {"type": "evaluation_metric", "attributes": {"metrics": [{"span_id": + "12345678901", "trace_id": "98765432101", "metric_type": "categorical", "label": + "toxicity", "categorical_value": "high"}]}}}' + headers: + Content-Type: + - application/json + DD-API-KEY: + - XXXXXX + method: POST + uri: https://api.datad0g.com/api/unstable/llm-obs/v1/eval-metric + response: + body: + string: '{"status":"error","code":403,"errors":["Forbidden"],"statuspage":"http://status.datadoghq.com","twitter":"http://twitter.com/datadogops","email":"support@datadoghq.com"}' + headers: + Connection: + - keep-alive + Content-Length: + - '169' + Content-Type: + - application/json + Date: + - Wed, 24 Apr 2024 23:02:22 GMT + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + x-content-type-options: + - nosniff + status: + code: 403 + message: Forbidden +version: 1 diff --git a/tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_eval_metric_writer.test_send_multiple_events.yaml b/tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_eval_metric_writer.test_send_multiple_events.yaml new file mode 100644 index 00000000000..9225b826117 --- /dev/null +++ b/tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_eval_metric_writer.test_send_multiple_events.yaml @@ -0,0 +1,39 @@ +interactions: +- request: + body: '{"data": {"type": "evaluation_metric", "attributes": {"metrics": [{"span_id": + "12345678902", "trace_id": "98765432102", "metric_type": "score", "label": "sentiment", + "score_value": 0.9}, {"span_id": "12345678903", "trace_id": "98765432103", "metric_type": + "numerical", "label": "token_count", "numerical_value": 35}]}}}' + headers: + Content-Type: + - application/json + DD-API-KEY: + - XXXXXX + method: POST + uri: https://api.datad0g.com/api/unstable/llm-obs/v1/eval-metric + response: + 
body: + string: '{"data":{"id":"ffe1e6ef-bfa4-4a7f-ae8e-2d184fc43fa5","type":"evaluation_metric","attributes":{"metrics":[{"id":"934b9d78-5734-4c8d-b7e2-9753a4b72fbc","trace_id":"98765432102","span_id":"12345678902","timestamp":1714076444671,"metric_type":"score","label":"sentiment","score_value":0.9},{"id":"44d9e3f1-4107-4fb9-a4e4-caaedca998c6","trace_id":"98765432103","span_id":"12345678903","timestamp":1714076444671,"metric_type":"numerical","label":"token_count","numerical_value":35}]}}}' + headers: + Connection: + - keep-alive + Content-Length: + - '479' + Content-Type: + - application/vnd.api+json + Date: + - Thu, 25 Apr 2024 20:20:44 GMT + content-security-policy: + - frame-ancestors 'self'; report-uri https://logs.browser-intake-datadoghq.com/api/v2/logs?dd-api-key=pub293163a918901030b79492fe1ab424cf&dd-evp-origin=content-security-policy&ddsource=csp-report&ddtags=site%3Adatad0g.com + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + vary: + - Accept-Encoding + x-content-type-options: + - nosniff + x-frame-options: + - SAMEORIGIN + status: + code: 200 + message: OK +version: 1 diff --git a/tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_eval_metric_writer.test_send_numerical_metric.yaml b/tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_eval_metric_writer.test_send_numerical_metric.yaml new file mode 100644 index 00000000000..ffdfcfc3f08 --- /dev/null +++ b/tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_eval_metric_writer.test_send_numerical_metric.yaml @@ -0,0 +1,38 @@ +interactions: +- request: + body: '{"data": {"type": "evaluation_metric", "attributes": {"metrics": [{"span_id": + "12345678903", "trace_id": "98765432103", "metric_type": "numerical", "label": + "token_count", "numerical_value": 35}]}}}' + headers: + Content-Type: + - application/json + DD-API-KEY: + - XXXXXX + method: POST + uri: https://api.datad0g.com/api/unstable/llm-obs/v1/eval-metric + response: + body: + string: 
'{"data":{"id":"1d36673c-642d-402e-9861-e74edda54304","type":"evaluation_metric","attributes":{"metrics":[{"id":"1f4829cf-a2b0-4d84-a1c5-760cd693e76b","trace_id":"98765432103","span_id":"12345678903","timestamp":1714068096282,"metric_type":"numerical","label":"token_count","numerical_value":35}]}}}' + headers: + Connection: + - keep-alive + Content-Length: + - '298' + Content-Type: + - application/vnd.api+json + Date: + - Thu, 25 Apr 2024 18:01:36 GMT + content-security-policy: + - frame-ancestors 'self'; report-uri https://logs.browser-intake-datadoghq.com/api/v2/logs?dd-api-key=pub293163a918901030b79492fe1ab424cf&dd-evp-origin=content-security-policy&ddsource=csp-report&ddtags=site%3Adatad0g.com + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + vary: + - Accept-Encoding + x-content-type-options: + - nosniff + x-frame-options: + - SAMEORIGIN + status: + code: 200 + message: OK +version: 1 diff --git a/tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_eval_metric_writer.test_send_score_metric.yaml b/tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_eval_metric_writer.test_send_score_metric.yaml new file mode 100644 index 00000000000..0f9eaebe667 --- /dev/null +++ b/tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_eval_metric_writer.test_send_score_metric.yaml @@ -0,0 +1,36 @@ +interactions: +- request: + body: '{"data": {"type": "evaluation_metric", "attributes": {"metrics": [{"span_id": + "12345678902", "trace_id": "98765432102", "metric_type": "score", "label": "sentiment", + "score_value": 0.9}]}}}' + headers: + Content-Type: + - application/json + DD-API-KEY: + - XXXXXX + method: POST + uri: https://api.datad0g.com/api/unstable/llm-obs/v1/eval-metric + response: + body: + string: 
'{"data":{"id":"33220b48-b299-4456-abe5-942ad9347f26","type":"evaluation_metric","attributes":{"metrics":[{"id":"9197af4e-e0a1-4bb3-b121-a5dbf80b6fa1","trace_id":"98765432102","span_id":"12345678902","timestamp":1714067719961,"metric_type":"score","label":"sentiment","score_value":0.9}]}}}' + headers: + content-length: + - '289' + content-security-policy: + - frame-ancestors 'self'; report-uri https://logs.browser-intake-datadoghq.com/api/v2/logs?dd-api-key=pub293163a918901030b79492fe1ab424cf&dd-evp-origin=content-security-policy&ddsource=csp-report&ddtags=site%3Adatad0g.com + content-type: + - application/vnd.api+json + date: + - Thu, 25 Apr 2024 17:55:19 GMT + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + vary: + - Accept-Encoding + x-content-type-options: + - nosniff + x-frame-options: + - SAMEORIGIN + status: + code: 200 + message: OK +version: 1 diff --git a/tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_eval_metric_writer.test_send_timed_events.yaml b/tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_eval_metric_writer.test_send_timed_events.yaml new file mode 100644 index 00000000000..60d10371e8e --- /dev/null +++ b/tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_eval_metric_writer.test_send_timed_events.yaml @@ -0,0 +1,74 @@ +interactions: +- request: + body: '{"data": {"type": "evaluation_metric", "attributes": {"metrics": [{"span_id": + "12345678902", "trace_id": "98765432102", "metric_type": "score", "label": "sentiment", + "score_value": 0.9}]}}}' + headers: + Content-Type: + - application/json + DD-API-KEY: + - XXXXXX + method: POST + uri: https://api.datad0g.com/api/unstable/llm-obs/v1/eval-metric + response: + body: + string: 
'{"data":{"id":"eb80fe76-bb98-4680-8759-2410aaf7d22b","type":"evaluation_metric","attributes":{"metrics":[{"id":"3b1d6a64-d3be-4abc-9ff0-72411cda4360","trace_id":"98765432102","span_id":"12345678902","timestamp":1714076534306,"metric_type":"score","label":"sentiment","score_value":0.9}]}}}' + headers: + Connection: + - keep-alive + Content-Length: + - '289' + Content-Type: + - application/vnd.api+json + Date: + - Thu, 25 Apr 2024 20:22:14 GMT + content-security-policy: + - frame-ancestors 'self'; report-uri https://logs.browser-intake-datadoghq.com/api/v2/logs?dd-api-key=pub293163a918901030b79492fe1ab424cf&dd-evp-origin=content-security-policy&ddsource=csp-report&ddtags=site%3Adatad0g.com + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + vary: + - Accept-Encoding + x-content-type-options: + - nosniff + x-frame-options: + - SAMEORIGIN + status: + code: 200 + message: OK +- request: + body: '{"data": {"type": "evaluation_metric", "attributes": {"metrics": [{"span_id": + "12345678903", "trace_id": "98765432103", "metric_type": "numerical", "label": + "token_count", "numerical_value": 35}]}}}' + headers: + Content-Type: + - application/json + DD-API-KEY: + - XXXXXX + method: POST + uri: https://api.datad0g.com/api/unstable/llm-obs/v1/eval-metric + response: + body: + string: '{"data":{"id":"ec0f3141-02ce-47cd-bab2-50b6f426bb95","type":"evaluation_metric","attributes":{"metrics":[{"id":"c9fff0a7-9859-4a53-b5f1-4155a30f765c","trace_id":"98765432103","span_id":"12345678903","timestamp":1714076535329,"metric_type":"numerical","label":"token_count","numerical_value":35}]}}}' + headers: + Connection: + - keep-alive + Content-Length: + - '298' + Content-Type: + - application/vnd.api+json + Date: + - Thu, 25 Apr 2024 20:22:15 GMT + content-security-policy: + - frame-ancestors 'self'; report-uri 
https://logs.browser-intake-datadoghq.com/api/v2/logs?dd-api-key=pub293163a918901030b79492fe1ab424cf&dd-evp-origin=content-security-policy&ddsource=csp-report&ddtags=site%3Adatad0g.com + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + vary: + - Accept-Encoding + x-content-type-options: + - nosniff + x-frame-options: + - SAMEORIGIN + status: + code: 200 + message: OK +version: 1 diff --git a/tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_writer.test_send_chat_completion_event.yaml b/tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_span_writer.test_send_chat_completion_event.yaml similarity index 100% rename from tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_writer.test_send_chat_completion_event.yaml rename to tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_span_writer.test_send_chat_completion_event.yaml diff --git a/tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_writer.test_send_completion_bad_api_key.yaml b/tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_span_writer.test_send_completion_bad_api_key.yaml similarity index 100% rename from tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_writer.test_send_completion_bad_api_key.yaml rename to tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_span_writer.test_send_completion_bad_api_key.yaml diff --git a/tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_writer.test_send_completion_event.yaml b/tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_span_writer.test_send_completion_event.yaml similarity index 100% rename from tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_writer.test_send_completion_event.yaml rename to tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_span_writer.test_send_completion_event.yaml diff --git a/tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_writer.test_send_multiple_events.yaml b/tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_span_writer.test_send_multiple_events.yaml 
similarity index 100% rename from tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_writer.test_send_multiple_events.yaml rename to tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_span_writer.test_send_multiple_events.yaml diff --git a/tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_writer.test_send_timed_events.yaml b/tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_span_writer.test_send_timed_events.yaml similarity index 100% rename from tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_writer.test_send_timed_events.yaml rename to tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_span_writer.test_send_timed_events.yaml diff --git a/tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_writer.test_send_on_exit.yaml b/tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_writer.test_send_on_exit.yaml deleted file mode 100644 index 92721c580c8..00000000000 --- a/tests/llmobs/llmobs_cassettes/tests.llmobs.test_llmobs_writer.test_send_on_exit.yaml +++ /dev/null @@ -1,42 +0,0 @@ -interactions: -- request: - body: '{"ml_obs": {"stage": "raw", "type": "span", "spans": [{"span_id": - "12345678901", "trace_id": "98765432101", "parent_id": "", "session_id": "98765432101", - "name": "completion_span", "tags": ["version:", "env:", "service:", "source:integration"], - "start_ns": 1707763310981223236, "duration": 12345678900, "error": 0, "meta": - {"span.kind": "llm", "model_name": "ada", "model_provider": "openai", "input": {"messages": - [{"content": "who broke enigma?"}], "parameters": {"temperature": 0, "max_tokens": - 256}}, "output": {"messages": [{"content": "\n\nThe Enigma code was broken by - a team of codebreakers at Bletchley Park, led by mathematician Alan Turing."}]}}, - "metrics": {"prompt_tokens": 64, "completion_tokens": 128, "total_tokens": 192}}]}}' - headers: - Content-Type: - - application/json - DD-API-KEY: - - XXXXXX - method: POST - uri: https://llmobs-intake.datad0g.com/api/v2/llmobs - response: - body: - string: '{}' 
- headers: - Connection: - - keep-alive - Content-Length: - - '2' - Content-Type: - - application/json - Date: - - Mon, 12 Feb 2024 20:49:24 GMT - accept-encoding: - - identity,gzip,x-gzip,deflate,x-deflate,zstd - cross-origin-resource-policy: - - cross-origin - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - x-content-type-options: - - nosniff - status: - code: 202 - message: Accepted -version: 1 diff --git a/tests/llmobs/test_llmobs_decorators.py b/tests/llmobs/test_llmobs_decorators.py index 7fc59ec0f9e..f106c9db51b 100644 --- a/tests/llmobs/test_llmobs_decorators.py +++ b/tests/llmobs/test_llmobs_decorators.py @@ -39,21 +39,21 @@ def f(): mock_logs.reset_mock() -def test_llm_decorator(LLMObs, mock_llmobs_writer): +def test_llm_decorator(LLMObs, mock_llmobs_span_writer): @llm(model_name="test_model", model_provider="test_provider", name="test_function", session_id="test_session_id") def f(): pass f() span = LLMObs._instance.tracer.pop()[0] - mock_llmobs_writer.enqueue.assert_called_with( + mock_llmobs_span_writer.enqueue.assert_called_with( _expected_llmobs_llm_span_event( span, "llm", model_name="test_model", model_provider="test_provider", session_id="test_session_id" ) ) -def test_llm_decorator_no_model_name_raises_error(LLMObs, mock_llmobs_writer): +def test_llm_decorator_no_model_name_raises_error(LLMObs, mock_llmobs_span_writer): with pytest.raises(TypeError): @llm(model_provider="test_provider", name="test_function", session_id="test_session_id") @@ -61,107 +61,107 @@ def f(): pass -def test_llm_decorator_default_kwargs(LLMObs, mock_llmobs_writer): +def test_llm_decorator_default_kwargs(LLMObs, mock_llmobs_span_writer): @llm(model_name="test_model") def f(): pass f() span = LLMObs._instance.tracer.pop()[0] - mock_llmobs_writer.enqueue.assert_called_with( + mock_llmobs_span_writer.enqueue.assert_called_with( _expected_llmobs_llm_span_event(span, "llm", model_name="test_model", model_provider="custom") ) -def 
test_task_decorator(LLMObs, mock_llmobs_writer): +def test_task_decorator(LLMObs, mock_llmobs_span_writer): @task(name="test_function", session_id="test_session_id") def f(): pass f() span = LLMObs._instance.tracer.pop()[0] - mock_llmobs_writer.enqueue.assert_called_with( + mock_llmobs_span_writer.enqueue.assert_called_with( _expected_llmobs_non_llm_span_event(span, "task", session_id="test_session_id") ) -def test_task_decorator_default_kwargs(LLMObs, mock_llmobs_writer): +def test_task_decorator_default_kwargs(LLMObs, mock_llmobs_span_writer): @task() def f(): pass f() span = LLMObs._instance.tracer.pop()[0] - mock_llmobs_writer.enqueue.assert_called_with(_expected_llmobs_non_llm_span_event(span, "task")) + mock_llmobs_span_writer.enqueue.assert_called_with(_expected_llmobs_non_llm_span_event(span, "task")) -def test_tool_decorator(LLMObs, mock_llmobs_writer): +def test_tool_decorator(LLMObs, mock_llmobs_span_writer): @tool(name="test_function", session_id="test_session_id") def f(): pass f() span = LLMObs._instance.tracer.pop()[0] - mock_llmobs_writer.enqueue.assert_called_with( + mock_llmobs_span_writer.enqueue.assert_called_with( _expected_llmobs_non_llm_span_event(span, "tool", session_id="test_session_id") ) -def test_tool_decorator_default_kwargs(LLMObs, mock_llmobs_writer): +def test_tool_decorator_default_kwargs(LLMObs, mock_llmobs_span_writer): @tool() def f(): pass f() span = LLMObs._instance.tracer.pop()[0] - mock_llmobs_writer.enqueue.assert_called_with(_expected_llmobs_non_llm_span_event(span, "tool")) + mock_llmobs_span_writer.enqueue.assert_called_with(_expected_llmobs_non_llm_span_event(span, "tool")) -def test_workflow_decorator(LLMObs, mock_llmobs_writer): +def test_workflow_decorator(LLMObs, mock_llmobs_span_writer): @workflow(name="test_function", session_id="test_session_id") def f(): pass f() span = LLMObs._instance.tracer.pop()[0] - mock_llmobs_writer.enqueue.assert_called_with( + mock_llmobs_span_writer.enqueue.assert_called_with( 
_expected_llmobs_non_llm_span_event(span, "workflow", session_id="test_session_id") ) -def test_workflow_decorator_default_kwargs(LLMObs, mock_llmobs_writer): +def test_workflow_decorator_default_kwargs(LLMObs, mock_llmobs_span_writer): @workflow() def f(): pass f() span = LLMObs._instance.tracer.pop()[0] - mock_llmobs_writer.enqueue.assert_called_with(_expected_llmobs_non_llm_span_event(span, "workflow")) + mock_llmobs_span_writer.enqueue.assert_called_with(_expected_llmobs_non_llm_span_event(span, "workflow")) -def test_agent_decorator(LLMObs, mock_llmobs_writer): +def test_agent_decorator(LLMObs, mock_llmobs_span_writer): @agent(name="test_function", session_id="test_session_id") def f(): pass f() span = LLMObs._instance.tracer.pop()[0] - mock_llmobs_writer.enqueue.assert_called_with( + mock_llmobs_span_writer.enqueue.assert_called_with( _expected_llmobs_llm_span_event(span, "agent", session_id="test_session_id") ) -def test_agent_decorator_default_kwargs(LLMObs, mock_llmobs_writer): +def test_agent_decorator_default_kwargs(LLMObs, mock_llmobs_span_writer): @agent() def f(): pass f() span = LLMObs._instance.tracer.pop()[0] - mock_llmobs_writer.enqueue.assert_called_with(_expected_llmobs_llm_span_event(span, "agent")) + mock_llmobs_span_writer.enqueue.assert_called_with(_expected_llmobs_llm_span_event(span, "agent")) -def test_llm_decorator_with_error(LLMObs, mock_llmobs_writer): +def test_llm_decorator_with_error(LLMObs, mock_llmobs_span_writer): @llm(model_name="test_model", model_provider="test_provider", name="test_function", session_id="test_session_id") def f(): raise ValueError("test_error") @@ -169,7 +169,7 @@ def f(): with pytest.raises(ValueError): f() span = LLMObs._instance.tracer.pop()[0] - mock_llmobs_writer.enqueue.assert_called_with( + mock_llmobs_span_writer.enqueue.assert_called_with( _expected_llmobs_llm_span_event( span, "llm", @@ -183,7 +183,7 @@ def f(): ) -def test_non_llm_decorators_with_error(LLMObs, mock_llmobs_writer): +def 
test_non_llm_decorators_with_error(LLMObs, mock_llmobs_span_writer): for decorator_name, decorator in [("task", task), ("workflow", workflow), ("tool", tool), ("agent", agent)]: @decorator(name="test_function", session_id="test_session_id") @@ -193,7 +193,7 @@ def f(): with pytest.raises(ValueError): f() span = LLMObs._instance.tracer.pop()[0] - mock_llmobs_writer.enqueue.assert_called_with( + mock_llmobs_span_writer.enqueue.assert_called_with( _expected_llmobs_non_llm_span_event( span, decorator_name, @@ -205,7 +205,7 @@ def f(): ) -def test_llm_annotate(LLMObs, mock_llmobs_writer): +def test_llm_annotate(LLMObs, mock_llmobs_span_writer): @llm(model_name="test_model", model_provider="test_provider", name="test_function", session_id="test_session_id") def f(): LLMObs.annotate( @@ -218,7 +218,7 @@ def f(): f() span = LLMObs._instance.tracer.pop()[0] - mock_llmobs_writer.enqueue.assert_called_with( + mock_llmobs_span_writer.enqueue.assert_called_with( _expected_llmobs_llm_span_event( span, "llm", @@ -234,7 +234,7 @@ def f(): ) -def test_llm_annotate_raw_string_io(LLMObs, mock_llmobs_writer): +def test_llm_annotate_raw_string_io(LLMObs, mock_llmobs_span_writer): @llm(model_name="test_model", model_provider="test_provider", name="test_function", session_id="test_session_id") def f(): LLMObs.annotate( @@ -247,7 +247,7 @@ def f(): f() span = LLMObs._instance.tracer.pop()[0] - mock_llmobs_writer.enqueue.assert_called_with( + mock_llmobs_span_writer.enqueue.assert_called_with( _expected_llmobs_llm_span_event( span, "llm", @@ -263,7 +263,7 @@ def f(): ) -def test_non_llm_decorators_no_args(LLMObs, mock_llmobs_writer): +def test_non_llm_decorators_no_args(LLMObs, mock_llmobs_span_writer): """Test that using the decorators without any arguments, i.e. 
@tool, works the same as @tool(...).""" for decorator_name, decorator in [("task", task), ("workflow", workflow), ("tool", tool)]: @@ -273,10 +273,10 @@ def f(): f() span = LLMObs._instance.tracer.pop()[0] - mock_llmobs_writer.enqueue.assert_called_with(_expected_llmobs_non_llm_span_event(span, decorator_name)) + mock_llmobs_span_writer.enqueue.assert_called_with(_expected_llmobs_non_llm_span_event(span, decorator_name)) -def test_agent_decorator_no_args(LLMObs, mock_llmobs_writer): +def test_agent_decorator_no_args(LLMObs, mock_llmobs_span_writer): """Test that using agent decorator without any arguments, i.e. @agent, works the same as @agent(...).""" @agent @@ -285,10 +285,10 @@ def f(): f() span = LLMObs._instance.tracer.pop()[0] - mock_llmobs_writer.enqueue.assert_called_with(_expected_llmobs_llm_span_event(span, "agent")) + mock_llmobs_span_writer.enqueue.assert_called_with(_expected_llmobs_llm_span_event(span, "agent")) -def test_ml_app_override(LLMObs, mock_llmobs_writer): +def test_ml_app_override(LLMObs, mock_llmobs_span_writer): """Test that setting ml_app kwarg on the LLMObs decorators will override the DD_LLMOBS_APP_NAME value.""" for decorator_name, decorator in [("task", task), ("workflow", workflow), ("tool", tool)]: @@ -298,7 +298,7 @@ def f(): f() span = LLMObs._instance.tracer.pop()[0] - mock_llmobs_writer.enqueue.assert_called_with( + mock_llmobs_span_writer.enqueue.assert_called_with( _expected_llmobs_non_llm_span_event(span, decorator_name, tags={"ml_app": "test_ml_app"}) ) @@ -308,7 +308,7 @@ def g(): g() span = LLMObs._instance.tracer.pop()[0] - mock_llmobs_writer.enqueue.assert_called_with( + mock_llmobs_span_writer.enqueue.assert_called_with( _expected_llmobs_llm_span_event( span, "llm", model_name="test_model", model_provider="custom", tags={"ml_app": "test_ml_app"} ) @@ -320,6 +320,6 @@ def h(): h() span = LLMObs._instance.tracer.pop()[0] - mock_llmobs_writer.enqueue.assert_called_with( + 
mock_llmobs_span_writer.enqueue.assert_called_with( _expected_llmobs_llm_span_event(span, "agent", tags={"ml_app": "test_ml_app"}) ) diff --git a/tests/llmobs/test_llmobs_eval_metric_writer.py b/tests/llmobs/test_llmobs_eval_metric_writer.py new file mode 100644 index 00000000000..2f9368c8bd8 --- /dev/null +++ b/tests/llmobs/test_llmobs_eval_metric_writer.py @@ -0,0 +1,171 @@ +import os +import time + +import mock +import pytest + +from ddtrace.llmobs._writer import LLMObsEvalMetricWriter + + +INTAKE_ENDPOINT = "https://api.datad0g.com/api/unstable/llm-obs/v1/eval-metric" +DD_SITE = "datad0g.com" +dd_api_key = os.getenv("DD_API_KEY", default="") + + +def _categorical_metric_event(): + return { + "span_id": "12345678901", + "trace_id": "98765432101", + "metric_type": "categorical", + "label": "toxicity", + "categorical_value": "high", + } + + +def _score_metric_event(): + return { + "span_id": "12345678902", + "trace_id": "98765432102", + "metric_type": "score", + "label": "sentiment", + "score_value": 0.9, + } + + +def _numerical_metric_event(): + return { + "span_id": "12345678903", + "trace_id": "98765432103", + "metric_type": "numerical", + "label": "token_count", + "numerical_value": 35, + } + + +def test_writer_start(mock_writer_logs): + llmobs_eval_metric_writer = LLMObsEvalMetricWriter(site="datad0g.com", api_key=dd_api_key, interval=1000, timeout=1) + llmobs_eval_metric_writer.start() + mock_writer_logs.debug.assert_has_calls( + [mock.call("started %r to %r", ("LLMObsEvalMetricWriter", INTAKE_ENDPOINT))] + ) + + +def test_buffer_limit(mock_writer_logs): + llmobs_eval_metric_writer = LLMObsEvalMetricWriter(site="datadoghq.com", api_key="asdf", interval=1000, timeout=1) + for _ in range(1001): + llmobs_eval_metric_writer.enqueue({}) + mock_writer_logs.warning.assert_called_with( + "%r event buffer full (limit is %d), dropping event", ("LLMObsEvalMetricWriter", 1000) + ) + + +@pytest.mark.vcr_logs +def test_send_metric_bad_api_key(mock_writer_logs): + 
llmobs_eval_metric_writer = LLMObsEvalMetricWriter( + site="datad0g.com", api_key="", interval=1000, timeout=1 + ) + llmobs_eval_metric_writer.start() + llmobs_eval_metric_writer.enqueue(_categorical_metric_event()) + llmobs_eval_metric_writer.periodic() + mock_writer_logs.error.assert_called_with( + "failed to send %d LLMObs %s events to %s, got response code %d, status: %s", + ( + 1, + "evaluation_metric", + INTAKE_ENDPOINT, + 403, + b'{"status":"error","code":403,"errors":["Forbidden"],"statuspage":"http://status.datadoghq.com","twitter":"http://twitter.com/datadogops","email":"support@datadoghq.com"}', # noqa + ), + ) + + +@pytest.mark.vcr_logs +def test_send_categorical_metric(mock_writer_logs): + llmobs_eval_metric_writer = LLMObsEvalMetricWriter(site="datad0g.com", api_key=dd_api_key, interval=1000, timeout=1) + llmobs_eval_metric_writer.start() + llmobs_eval_metric_writer.enqueue(_categorical_metric_event()) + llmobs_eval_metric_writer.periodic() + mock_writer_logs.debug.assert_has_calls( + [mock.call("sent %d LLMObs %s events to %s", (1, "evaluation_metric", INTAKE_ENDPOINT))] + ) + + +@pytest.mark.vcr_logs +def test_send_numerical_metric(mock_writer_logs): + llmobs_eval_metric_writer = LLMObsEvalMetricWriter(site="datad0g.com", api_key=dd_api_key, interval=1000, timeout=1) + llmobs_eval_metric_writer.start() + llmobs_eval_metric_writer.enqueue(_numerical_metric_event()) + llmobs_eval_metric_writer.periodic() + mock_writer_logs.debug.assert_has_calls( + [mock.call("sent %d LLMObs %s events to %s", (1, "evaluation_metric", INTAKE_ENDPOINT))] + ) + + +@pytest.mark.vcr_logs +def test_send_score_metric(mock_writer_logs): + llmobs_eval_metric_writer = LLMObsEvalMetricWriter(site="datad0g.com", api_key=dd_api_key, interval=1000, timeout=1) + llmobs_eval_metric_writer.start() + llmobs_eval_metric_writer.enqueue(_score_metric_event()) + llmobs_eval_metric_writer.periodic() + mock_writer_logs.debug.assert_has_calls( + [mock.call("sent %d LLMObs %s events to %s", 
(1, "evaluation_metric", INTAKE_ENDPOINT))] + ) + + +@pytest.mark.vcr_logs +def test_send_timed_events(mock_writer_logs): + llmobs_eval_metric_writer = LLMObsEvalMetricWriter(site="datad0g.com", api_key=dd_api_key, interval=0.01, timeout=1) + llmobs_eval_metric_writer.start() + mock_writer_logs.reset_mock() + + llmobs_eval_metric_writer.enqueue(_score_metric_event()) + time.sleep(0.1) + mock_writer_logs.debug.assert_has_calls( + [mock.call("sent %d LLMObs %s events to %s", (1, "evaluation_metric", INTAKE_ENDPOINT))] + ) + mock_writer_logs.reset_mock() + llmobs_eval_metric_writer.enqueue(_numerical_metric_event()) + time.sleep(0.1) + mock_writer_logs.debug.assert_has_calls( + [mock.call("sent %d LLMObs %s events to %s", (1, "evaluation_metric", INTAKE_ENDPOINT))] + ) + + +@pytest.mark.vcr_logs +def test_send_multiple_events(mock_writer_logs): + llmobs_eval_metric_writer = LLMObsEvalMetricWriter(site="datad0g.com", api_key=dd_api_key, interval=0.01, timeout=1) + llmobs_eval_metric_writer.start() + mock_writer_logs.reset_mock() + + llmobs_eval_metric_writer.enqueue(_score_metric_event()) + llmobs_eval_metric_writer.enqueue(_numerical_metric_event()) + time.sleep(0.1) + mock_writer_logs.debug.assert_has_calls( + [mock.call("sent %d LLMObs %s events to %s", (2, "evaluation_metric", INTAKE_ENDPOINT))] + ) + + +def test_send_on_exit(mock_writer_logs, run_python_code_in_subprocess): + out, err, status, pid = run_python_code_in_subprocess( + """ +import atexit +import os +import time + +from ddtrace.llmobs._writer import LLMObsEvalMetricWriter +from tests.llmobs.test_llmobs_eval_metric_writer import _score_metric_event +from tests.llmobs._utils import logs_vcr + +ctx = logs_vcr.use_cassette("tests.llmobs.test_llmobs_eval_metric_writer.send_score_metric.yaml") +ctx.__enter__() +atexit.register(lambda: ctx.__exit__()) +llmobs_eval_metric_writer = LLMObsEvalMetricWriter( +site="datad0g.com", api_key=os.getenv("DD_API_KEY"), interval=0.01, timeout=1 +) 
+llmobs_eval_metric_writer.start() +llmobs_eval_metric_writer.enqueue(_score_metric_event()) +""", + ) + assert status == 0, err + assert out == b"" + assert err == b"" diff --git a/tests/llmobs/test_llmobs_service.py b/tests/llmobs/test_llmobs_service.py index e91a6cd64ef..88941a275e2 100644 --- a/tests/llmobs/test_llmobs_service.py +++ b/tests/llmobs/test_llmobs_service.py @@ -132,16 +132,16 @@ def test_llmobs_start_span_with_session_id(LLMObs): assert span.get_tag(SESSION_ID) == "test_session_id" -def test_llmobs_session_id_becomes_top_level_field(LLMObs, mock_llmobs_writer): +def test_llmobs_session_id_becomes_top_level_field(LLMObs, mock_llmobs_span_writer): session_id = "test_session_id" with LLMObs.task(session_id=session_id) as span: pass - mock_llmobs_writer.enqueue.assert_called_with( + mock_llmobs_span_writer.enqueue.assert_called_with( _expected_llmobs_non_llm_span_event(span, "task", session_id=session_id) ) -def test_llmobs_llm_span(LLMObs, mock_llmobs_writer): +def test_llmobs_llm_span(LLMObs, mock_llmobs_span_writer): with LLMObs.llm(model_name="test_model", name="test_llm_call", model_provider="test_provider") as span: assert span.name == "test_llm_call" assert span.resource == "llm" @@ -151,7 +151,7 @@ def test_llmobs_llm_span(LLMObs, mock_llmobs_writer): assert span.get_tag(MODEL_PROVIDER) == "test_provider" assert span.get_tag(SESSION_ID) == "{:x}".format(span.trace_id) - mock_llmobs_writer.enqueue.assert_called_with( + mock_llmobs_span_writer.enqueue.assert_called_with( _expected_llmobs_llm_span_event(span, "llm", model_name="test_model", model_provider="test_provider") ) @@ -177,40 +177,40 @@ def test_llmobs_default_model_provider_set_to_custom(LLMObs): assert span.get_tag(MODEL_PROVIDER) == "custom" -def test_llmobs_tool_span(LLMObs, mock_llmobs_writer): +def test_llmobs_tool_span(LLMObs, mock_llmobs_span_writer): with LLMObs.tool(name="test_tool") as span: assert span.name == "test_tool" assert span.resource == "tool" assert span.span_type 
== "llm" assert span.get_tag(SPAN_KIND) == "tool" - mock_llmobs_writer.enqueue.assert_called_with(_expected_llmobs_non_llm_span_event(span, "tool")) + mock_llmobs_span_writer.enqueue.assert_called_with(_expected_llmobs_non_llm_span_event(span, "tool")) -def test_llmobs_task_span(LLMObs, mock_llmobs_writer): +def test_llmobs_task_span(LLMObs, mock_llmobs_span_writer): with LLMObs.task(name="test_task") as span: assert span.name == "test_task" assert span.resource == "task" assert span.span_type == "llm" assert span.get_tag(SPAN_KIND) == "task" - mock_llmobs_writer.enqueue.assert_called_with(_expected_llmobs_non_llm_span_event(span, "task")) + mock_llmobs_span_writer.enqueue.assert_called_with(_expected_llmobs_non_llm_span_event(span, "task")) -def test_llmobs_workflow_span(LLMObs, mock_llmobs_writer): +def test_llmobs_workflow_span(LLMObs, mock_llmobs_span_writer): with LLMObs.workflow(name="test_workflow") as span: assert span.name == "test_workflow" assert span.resource == "workflow" assert span.span_type == "llm" assert span.get_tag(SPAN_KIND) == "workflow" - mock_llmobs_writer.enqueue.assert_called_with(_expected_llmobs_non_llm_span_event(span, "workflow")) + mock_llmobs_span_writer.enqueue.assert_called_with(_expected_llmobs_non_llm_span_event(span, "workflow")) -def test_llmobs_agent_span(LLMObs, mock_llmobs_writer): +def test_llmobs_agent_span(LLMObs, mock_llmobs_span_writer): with LLMObs.agent(name="test_agent") as span: assert span.name == "test_agent" assert span.resource == "agent" assert span.span_type == "llm" assert span.get_tag(SPAN_KIND) == "agent" - mock_llmobs_writer.enqueue.assert_called_with(_expected_llmobs_llm_span_event(span, "agent")) + mock_llmobs_span_writer.enqueue.assert_called_with(_expected_llmobs_llm_span_event(span, "agent")) def test_llmobs_annotate_while_disabled_logs_warning(LLMObs, mock_logs): @@ -339,7 +339,16 @@ def test_llmobs_annotate_input_llm_message_wrong_type(LLMObs, mock_logs): with LLMObs.llm(model_name="test_model") as 
llm_span: LLMObs.annotate(span=llm_span, input_data=[{"content": Unserializable()}]) assert llm_span.get_tag(INPUT_MESSAGES) is None - mock_logs.warning.assert_called_once_with("Failed to parse input messages.") + mock_logs.warning.assert_called_once_with("Failed to parse input messages.", exc_info=True) + + +def test_llmobs_annotate_incorrect_message_content_type_raises_warning(LLMObs, mock_logs): + with LLMObs.llm(model_name="test_model") as llm_span: + LLMObs.annotate(span=llm_span, input_data={"role": "user", "content": {"nested": "yes"}}) + mock_logs.warning.assert_called_once_with("Failed to parse input messages.", exc_info=True) + mock_logs.reset_mock() + LLMObs.annotate(span=llm_span, output_data={"role": "user", "content": {"nested": "yes"}}) + mock_logs.warning.assert_called_once_with("Failed to parse output messages.", exc_info=True) def test_llmobs_annotate_output_string(LLMObs): @@ -394,7 +403,7 @@ def test_llmobs_annotate_output_llm_message_wrong_type(LLMObs, mock_logs): with LLMObs.llm(model_name="test_model") as llm_span: LLMObs.annotate(span=llm_span, output_data=[{"content": Unserializable()}]) assert llm_span.get_tag(OUTPUT_MESSAGES) is None - mock_logs.warning.assert_called_once_with("Failed to parse output messages.") + mock_logs.warning.assert_called_once_with("Failed to parse output messages.", exc_info=True) def test_llmobs_annotate_metrics(LLMObs): @@ -417,11 +426,11 @@ def test_llmobs_annotate_metrics_wrong_type(LLMObs, mock_logs): ) -def test_llmobs_span_error_sets_error(LLMObs, mock_llmobs_writer): +def test_llmobs_span_error_sets_error(LLMObs, mock_llmobs_span_writer): with pytest.raises(ValueError): with LLMObs.llm(model_name="test_model", model_provider="test_model_provider") as span: raise ValueError("test error message") - mock_llmobs_writer.enqueue.assert_called_with( + mock_llmobs_span_writer.enqueue.assert_called_with( _expected_llmobs_llm_span_event( span, model_name="test_model", @@ -437,10 +446,10 @@ def 
test_llmobs_span_error_sets_error(LLMObs, mock_llmobs_writer): "ddtrace_global_config", [dict(version="1.2.3", env="test_env", service="test_service", _llmobs_ml_app="test_app_name")], ) -def test_llmobs_tags(ddtrace_global_config, LLMObs, mock_llmobs_writer, monkeypatch): +def test_llmobs_tags(ddtrace_global_config, LLMObs, mock_llmobs_span_writer, monkeypatch): with LLMObs.task(name="test_task") as span: pass - mock_llmobs_writer.enqueue.assert_called_with( + mock_llmobs_span_writer.enqueue.assert_called_with( _expected_llmobs_non_llm_span_event( span, "task", @@ -449,22 +458,22 @@ def test_llmobs_tags(ddtrace_global_config, LLMObs, mock_llmobs_writer, monkeypa ) -def test_llmobs_ml_app_override(LLMObs, mock_llmobs_writer): +def test_llmobs_ml_app_override(LLMObs, mock_llmobs_span_writer): with LLMObs.task(name="test_task", ml_app="test_app") as span: pass - mock_llmobs_writer.enqueue.assert_called_with( + mock_llmobs_span_writer.enqueue.assert_called_with( _expected_llmobs_non_llm_span_event(span, "task", tags={"ml_app": "test_app"}) ) with LLMObs.tool(name="test_tool", ml_app="test_app") as span: pass - mock_llmobs_writer.enqueue.assert_called_with( + mock_llmobs_span_writer.enqueue.assert_called_with( _expected_llmobs_non_llm_span_event(span, "tool", tags={"ml_app": "test_app"}) ) with LLMObs.llm(model_name="model_name", name="test_llm", ml_app="test_app") as span: pass - mock_llmobs_writer.enqueue.assert_called_with( + mock_llmobs_span_writer.enqueue.assert_called_with( _expected_llmobs_llm_span_event( span, "llm", model_name="model_name", model_provider="custom", tags={"ml_app": "test_app"} ) @@ -472,12 +481,12 @@ def test_llmobs_ml_app_override(LLMObs, mock_llmobs_writer): with LLMObs.workflow(name="test_workflow", ml_app="test_app") as span: pass - mock_llmobs_writer.enqueue.assert_called_with( + mock_llmobs_span_writer.enqueue.assert_called_with( _expected_llmobs_non_llm_span_event(span, "workflow", tags={"ml_app": "test_app"}) ) with 
LLMObs.agent(name="test_agent", ml_app="test_app") as span: pass - mock_llmobs_writer.enqueue.assert_called_with( + mock_llmobs_span_writer.enqueue.assert_called_with( _expected_llmobs_llm_span_event(span, "agent", tags={"ml_app": "test_app"}) ) diff --git a/tests/llmobs/test_llmobs_span_writer.py b/tests/llmobs/test_llmobs_span_writer.py new file mode 100644 index 00000000000..7032acad45f --- /dev/null +++ b/tests/llmobs/test_llmobs_span_writer.py @@ -0,0 +1,184 @@ +import os +import time + +import mock +import pytest + +from ddtrace.llmobs._writer import LLMObsSpanWriter + + +INTAKE_ENDPOINT = "https://llmobs-intake.datad0g.com/api/v2/llmobs" +DD_SITE = "datad0g.com" +dd_api_key = os.getenv("DD_API_KEY", default="") + + +def _completion_event(): + return { + "kind": "llm", + "span_id": "12345678901", + "trace_id": "98765432101", + "parent_id": "", + "session_id": "98765432101", + "name": "completion_span", + "tags": ["version:", "env:", "service:", "source:integration"], + "start_ns": 1707763310981223236, + "duration": 12345678900, + "error": 0, + "meta": { + "span.kind": "llm", + "model_name": "ada", + "model_provider": "openai", + "input": { + "messages": [{"content": "who broke enigma?"}], + "parameters": {"temperature": 0, "max_tokens": 256}, + }, + "output": { + "messages": [ + { + "content": "\n\nThe Enigma code was broken by a team of codebreakers at Bletchley Park, led by mathematician Alan Turing." 
# noqa: E501 + } + ] + }, + }, + "metrics": {"prompt_tokens": 64, "completion_tokens": 128, "total_tokens": 192}, + } + + +def _chat_completion_event(): + return { + "span_id": "12345678902", + "trace_id": "98765432102", + "parent_id": "", + "session_id": "98765432102", + "name": "chat_completion_span", + "tags": ["version:", "env:", "service:", "source:integration"], + "start_ns": 1707763310981223936, + "duration": 12345678900, + "error": 0, + "meta": { + "span.kind": "llm", + "model_name": "gpt-3.5-turbo", + "model_provider": "openai", + "input": { + "messages": [ + { + "role": "system", + "content": "You are an evil dark lord looking for his one ring to rule them all", + }, + {"role": "user", "content": "I am a hobbit looking to go to Mordor"}, + ], + "parameters": {"temperature": 0.9, "max_tokens": 256}, + }, + "output": { + "messages": [ + { + "content": "Ah, a bold and foolish hobbit seeking to challenge my dominion in Mordor. Very well, little creature, I shall play along. But know that I am always watching, and your quest will not go unnoticed", # noqa: E501 + "role": "assistant", + }, + ] + }, + }, + "metrics": {"prompt_tokens": 64, "completion_tokens": 128, "total_tokens": 192}, + } + + +def test_writer_start(mock_writer_logs): + llmobs_span_writer = LLMObsSpanWriter(site="datad0g.com", api_key="asdf", interval=1000, timeout=1) + llmobs_span_writer.start() + mock_writer_logs.debug.assert_has_calls([mock.call("started %r to %r", ("LLMObsSpanWriter", INTAKE_ENDPOINT))]) + + +def test_buffer_limit(mock_writer_logs): + llmobs_span_writer = LLMObsSpanWriter(site="datad0g.com", api_key="asdf", interval=1000, timeout=1) + for _ in range(1001): + llmobs_span_writer.enqueue({}) + mock_writer_logs.warning.assert_called_with( + "%r event buffer full (limit is %d), dropping event", ("LLMObsSpanWriter", 1000) + ) + + +@pytest.mark.vcr_logs +def test_send_completion_event(mock_writer_logs): + llmobs_span_writer = LLMObsSpanWriter(site="datad0g.com", api_key=dd_api_key, 
interval=1, timeout=1) + llmobs_span_writer.start() + llmobs_span_writer.enqueue(_completion_event()) + llmobs_span_writer.periodic() + mock_writer_logs.debug.assert_has_calls([mock.call("sent %d LLMObs %s events to %s", (1, "span", INTAKE_ENDPOINT))]) + + +@pytest.mark.vcr_logs +def test_send_chat_completion_event(mock_writer_logs): + llmobs_span_writer = LLMObsSpanWriter(site="datad0g.com", api_key=dd_api_key, interval=1, timeout=1) + llmobs_span_writer.start() + llmobs_span_writer.enqueue(_chat_completion_event()) + llmobs_span_writer.periodic() + mock_writer_logs.debug.assert_has_calls([mock.call("sent %d LLMObs %s events to %s", (1, "span", INTAKE_ENDPOINT))]) + + +@pytest.mark.vcr_logs +def test_send_completion_bad_api_key(mock_writer_logs): + llmobs_span_writer = LLMObsSpanWriter(site="datad0g.com", api_key="", interval=1, timeout=1) + llmobs_span_writer.start() + llmobs_span_writer.enqueue(_completion_event()) + llmobs_span_writer.periodic() + mock_writer_logs.error.assert_called_with( + "failed to send %d LLMObs %s events to %s, got response code %d, status: %s", + ( + 1, + "span", + INTAKE_ENDPOINT, + 403, + b'{"errors":[{"status":"403","title":"Forbidden","detail":"API key is invalid"}]}', + ), + ) + + +@pytest.mark.vcr_logs +def test_send_timed_events(mock_writer_logs): + llmobs_span_writer = LLMObsSpanWriter(site="datad0g.com", api_key=dd_api_key, interval=0.01, timeout=1) + llmobs_span_writer.start() + mock_writer_logs.reset_mock() + + llmobs_span_writer.enqueue(_completion_event()) + time.sleep(0.1) + mock_writer_logs.debug.assert_has_calls([mock.call("sent %d LLMObs %s events to %s", (1, "span", INTAKE_ENDPOINT))]) + mock_writer_logs.reset_mock() + llmobs_span_writer.enqueue(_chat_completion_event()) + time.sleep(0.1) + mock_writer_logs.debug.assert_has_calls([mock.call("sent %d LLMObs %s events to %s", (1, "span", INTAKE_ENDPOINT))]) + + +@pytest.mark.vcr_logs +def test_send_multiple_events(mock_writer_logs): + llmobs_span_writer = 
LLMObsSpanWriter(site="datad0g.com", api_key=dd_api_key, interval=0.01, timeout=1) + llmobs_span_writer.start() + mock_writer_logs.reset_mock() + + llmobs_span_writer.enqueue(_completion_event()) + llmobs_span_writer.enqueue(_chat_completion_event()) + time.sleep(0.1) + mock_writer_logs.debug.assert_has_calls([mock.call("sent %d LLMObs %s events to %s", (2, "span", INTAKE_ENDPOINT))]) + + +def test_send_on_exit(mock_writer_logs, run_python_code_in_subprocess): + out, err, status, pid = run_python_code_in_subprocess( + """ +import atexit +import os +import time + +from ddtrace.llmobs._writer import LLMObsSpanWriter +from tests.llmobs.test_llmobs_span_writer import _completion_event +from tests.llmobs._utils import logs_vcr + +ctx = logs_vcr.use_cassette("tests.llmobs.test_llmobs_span_writer.test_send_completion_event.yaml") +ctx.__enter__() +atexit.register(lambda: ctx.__exit__()) +llmobs_span_writer = LLMObsSpanWriter(site="datad0g.com", api_key=os.getenv("DD_API_KEY"), interval=0.01, timeout=1) +llmobs_span_writer.start() +llmobs_span_writer.enqueue(_completion_event()) +""", + ) + assert status == 0, err + assert out == b"" + assert err == b"" diff --git a/tests/llmobs/test_llmobs_trace_processor.py b/tests/llmobs/test_llmobs_trace_processor.py index 4d15fb056a8..85d9f0541c2 100644 --- a/tests/llmobs/test_llmobs_trace_processor.py +++ b/tests/llmobs/test_llmobs_trace_processor.py @@ -31,7 +31,7 @@ def mock_logs(): def test_processor_returns_all_traces_by_default(monkeypatch): """Test that the LLMObsTraceProcessor returns all traces by default.""" - trace_filter = LLMObsTraceProcessor(llmobs_writer=mock.MagicMock()) + trace_filter = LLMObsTraceProcessor(llmobs_span_writer=mock.MagicMock()) root_llm_span = Span(name="span1", span_type=SpanTypes.LLM) root_llm_span.set_tag_str(SPAN_KIND, "llm") trace1 = [root_llm_span] @@ -41,7 +41,7 @@ def test_processor_returns_all_traces_by_default(monkeypatch): def 
test_processor_returns_all_traces_if_no_apm_env_var_is_false(monkeypatch): """Test that the LLMObsTraceProcessor returns all traces if DD_LLMOBS_NO_APM is not set to true.""" monkeypatch.setenv("DD_LLMOBS_NO_APM", "0") - trace_filter = LLMObsTraceProcessor(llmobs_writer=mock.MagicMock()) + trace_filter = LLMObsTraceProcessor(llmobs_span_writer=mock.MagicMock()) root_llm_span = Span(name="span1", span_type=SpanTypes.LLM) root_llm_span.set_tag_str(SPAN_KIND, "llm") trace1 = [root_llm_span] @@ -51,7 +51,7 @@ def test_processor_returns_all_traces_if_no_apm_env_var_is_false(monkeypatch): def test_processor_returns_none_if_no_apm_env_var_is_true(monkeypatch): """Test that the LLMObsTraceProcessor returns None if DD_LLMOBS_NO_APM is set to true.""" monkeypatch.setenv("DD_LLMOBS_NO_APM", "1") - trace_filter = LLMObsTraceProcessor(llmobs_writer=mock.MagicMock()) + trace_filter = LLMObsTraceProcessor(llmobs_span_writer=mock.MagicMock()) root_llm_span = Span(name="span1", span_type=SpanTypes.LLM) root_llm_span.set_tag_str(SPAN_KIND, "llm") trace1 = [root_llm_span] @@ -60,21 +60,21 @@ def test_processor_returns_none_if_no_apm_env_var_is_true(monkeypatch): def test_processor_creates_llmobs_span_event(): with override_global_config(dict(_llmobs_ml_app="unnamed-ml-app")): - mock_llmobs_writer = mock.MagicMock() - trace_filter = LLMObsTraceProcessor(llmobs_writer=mock_llmobs_writer) + mock_llmobs_span_writer = mock.MagicMock() + trace_filter = LLMObsTraceProcessor(llmobs_span_writer=mock_llmobs_span_writer) root_llm_span = Span(name="root", span_type=SpanTypes.LLM) root_llm_span.set_tag_str(SPAN_KIND, "llm") trace = [root_llm_span] trace_filter.process_trace(trace) - assert mock_llmobs_writer.enqueue.call_count == 1 - mock_llmobs_writer.assert_has_calls([mock.call.enqueue(_expected_llmobs_llm_span_event(root_llm_span, "llm"))]) + assert mock_llmobs_span_writer.enqueue.call_count == 1 + 
mock_llmobs_span_writer.assert_has_calls([mock.call.enqueue(_expected_llmobs_llm_span_event(root_llm_span, "llm"))]) def test_processor_only_creates_llmobs_span_event(): """Test that the LLMObsTraceProcessor only creates LLMObs span events for LLM span types.""" dummy_tracer = DummyTracer() - mock_llmobs_writer = mock.MagicMock() - trace_filter = LLMObsTraceProcessor(llmobs_writer=mock_llmobs_writer) + mock_llmobs_span_writer = mock.MagicMock() + trace_filter = LLMObsTraceProcessor(llmobs_span_writer=mock_llmobs_span_writer) with override_global_config(dict(_llmobs_ml_app="unnamed-ml-app")): with dummy_tracer.trace("root_llm_span", span_type=SpanTypes.LLM) as root_span: root_span.set_tag_str(SPAN_KIND, "llm") @@ -85,8 +85,8 @@ def test_processor_only_creates_llmobs_span_event(): expected_grandchild_llmobs_span = _expected_llmobs_llm_span_event(grandchild_span, "llm") expected_grandchild_llmobs_span["parent_id"] = str(root_span.span_id) trace_filter.process_trace(trace) - assert mock_llmobs_writer.enqueue.call_count == 2 - mock_llmobs_writer.assert_has_calls( + assert mock_llmobs_span_writer.enqueue.call_count == 2 + mock_llmobs_span_writer.assert_has_calls( [ mock.call.enqueue(_expected_llmobs_llm_span_event(root_span, "llm")), mock.call.enqueue(expected_grandchild_llmobs_span), @@ -220,28 +220,28 @@ def test_ml_app_propagates_ignore_non_llmobs_spans(): def test_malformed_span_logs_error_instead_of_raising(mock_logs): """Test that a trying to create a span event from a malformed span will log an error instead of crashing.""" dummy_tracer = DummyTracer() - mock_llmobs_writer = mock.MagicMock() + mock_llmobs_span_writer = mock.MagicMock() with dummy_tracer.trace("root_llm_span", span_type=SpanTypes.LLM) as llm_span: # span does not have SPAN_KIND tag pass - tp = LLMObsTraceProcessor(llmobs_writer=mock_llmobs_writer) + tp = LLMObsTraceProcessor(llmobs_span_writer=mock_llmobs_span_writer) tp.process_trace([llm_span]) mock_logs.error.assert_called_once_with( "Error 
generating LLMObs span event for span %s, likely due to malformed span", llm_span ) - mock_llmobs_writer.enqueue.assert_not_called() + mock_llmobs_span_writer.enqueue.assert_not_called() def test_model_and_provider_are_set(): """Test that model and provider are set on the span event if they are present on the LLM-kind span.""" dummy_tracer = DummyTracer() - mock_llmobs_writer = mock.MagicMock() + mock_llmobs_span_writer = mock.MagicMock() with override_global_config(dict(_llmobs_ml_app="unnamed-ml-app")): with dummy_tracer.trace("root_llm_span", span_type=SpanTypes.LLM) as llm_span: llm_span.set_tag(SPAN_KIND, "llm") llm_span.set_tag(MODEL_NAME, "model_name") llm_span.set_tag(MODEL_PROVIDER, "model_provider") - tp = LLMObsTraceProcessor(llmobs_writer=mock_llmobs_writer) + tp = LLMObsTraceProcessor(llmobs_span_writer=mock_llmobs_span_writer) span_event = tp._llmobs_span_event(llm_span) assert span_event["meta"]["model_name"] == "model_name" assert span_event["meta"]["model_provider"] == "model_provider" @@ -250,12 +250,12 @@ def test_model_and_provider_are_set(): def test_model_provider_defaults_to_custom(): """Test that model provider defaults to "custom" if not provided.""" dummy_tracer = DummyTracer() - mock_llmobs_writer = mock.MagicMock() + mock_llmobs_span_writer = mock.MagicMock() with override_global_config(dict(_llmobs_ml_app="unnamed-ml-app")): with dummy_tracer.trace("root_llm_span", span_type=SpanTypes.LLM) as llm_span: llm_span.set_tag(SPAN_KIND, "llm") llm_span.set_tag(MODEL_NAME, "model_name") - tp = LLMObsTraceProcessor(llmobs_writer=mock_llmobs_writer) + tp = LLMObsTraceProcessor(llmobs_span_writer=mock_llmobs_span_writer) span_event = tp._llmobs_span_event(llm_span) assert span_event["meta"]["model_name"] == "model_name" assert span_event["meta"]["model_provider"] == "custom" @@ -264,12 +264,12 @@ def test_model_provider_defaults_to_custom(): def test_model_not_set_if_not_llm_kind_span(): """Test that model name and provider not set if non-LLM 
span.""" dummy_tracer = DummyTracer() - mock_llmobs_writer = mock.MagicMock() + mock_llmobs_span_writer = mock.MagicMock() with override_global_config(dict(_llmobs_ml_app="unnamed-ml-app")): with dummy_tracer.trace("root_workflow_span", span_type=SpanTypes.LLM) as span: span.set_tag(SPAN_KIND, "workflow") span.set_tag(MODEL_NAME, "model_name") - tp = LLMObsTraceProcessor(llmobs_writer=mock_llmobs_writer) + tp = LLMObsTraceProcessor(llmobs_span_writer=mock_llmobs_span_writer) span_event = tp._llmobs_span_event(span) assert "model_name" not in span_event["meta"] assert "model_provider" not in span_event["meta"] @@ -278,97 +278,97 @@ def test_model_not_set_if_not_llm_kind_span(): def test_input_messages_are_set(): """Test that input messages are set on the span event if they are present on the span.""" dummy_tracer = DummyTracer() - mock_llmobs_writer = mock.MagicMock() + mock_llmobs_span_writer = mock.MagicMock() with override_global_config(dict(_llmobs_ml_app="unnamed-ml-app")): with dummy_tracer.trace("root_llm_span", span_type=SpanTypes.LLM) as llm_span: llm_span.set_tag(SPAN_KIND, "llm") llm_span.set_tag(INPUT_MESSAGES, '[{"content": "message", "role": "user"}]') - tp = LLMObsTraceProcessor(llmobs_writer=mock_llmobs_writer) + tp = LLMObsTraceProcessor(llmobs_span_writer=mock_llmobs_span_writer) assert tp._llmobs_span_event(llm_span)["meta"]["input"]["messages"] == [{"content": "message", "role": "user"}] def test_input_value_is_set(): """Test that input value is set on the span event if they are present on the span.""" dummy_tracer = DummyTracer() - mock_llmobs_writer = mock.MagicMock() + mock_llmobs_span_writer = mock.MagicMock() with override_global_config(dict(_llmobs_ml_app="unnamed-ml-app")): with dummy_tracer.trace("root_llm_span", span_type=SpanTypes.LLM) as llm_span: llm_span.set_tag(SPAN_KIND, "llm") llm_span.set_tag(INPUT_VALUE, "value") - tp = LLMObsTraceProcessor(llmobs_writer=mock_llmobs_writer) + tp = 
LLMObsTraceProcessor(llmobs_span_writer=mock_llmobs_span_writer) assert tp._llmobs_span_event(llm_span)["meta"]["input"]["value"] == "value" def test_input_parameters_are_set(): """Test that input parameters are set on the span event if they are present on the span.""" dummy_tracer = DummyTracer() - mock_llmobs_writer = mock.MagicMock() + mock_llmobs_span_writer = mock.MagicMock() with override_global_config(dict(_llmobs_ml_app="unnamed-ml-app")): with dummy_tracer.trace("root_llm_span", span_type=SpanTypes.LLM) as llm_span: llm_span.set_tag(SPAN_KIND, "llm") llm_span.set_tag(INPUT_PARAMETERS, '{"key": "value"}') - tp = LLMObsTraceProcessor(llmobs_writer=mock_llmobs_writer) + tp = LLMObsTraceProcessor(llmobs_span_writer=mock_llmobs_span_writer) assert tp._llmobs_span_event(llm_span)["meta"]["input"]["parameters"] == {"key": "value"} def test_output_messages_are_set(): """Test that output messages are set on the span event if they are present on the span.""" dummy_tracer = DummyTracer() - mock_llmobs_writer = mock.MagicMock() + mock_llmobs_span_writer = mock.MagicMock() with override_global_config(dict(_llmobs_ml_app="unnamed-ml-app")): with dummy_tracer.trace("root_llm_span", span_type=SpanTypes.LLM) as llm_span: llm_span.set_tag(SPAN_KIND, "llm") llm_span.set_tag(OUTPUT_MESSAGES, '[{"content": "message", "role": "user"}]') - tp = LLMObsTraceProcessor(llmobs_writer=mock_llmobs_writer) + tp = LLMObsTraceProcessor(llmobs_span_writer=mock_llmobs_span_writer) assert tp._llmobs_span_event(llm_span)["meta"]["output"]["messages"] == [{"content": "message", "role": "user"}] def test_output_value_is_set(): """Test that output value is set on the span event if they are present on the span.""" dummy_tracer = DummyTracer() - mock_llmobs_writer = mock.MagicMock() + mock_llmobs_span_writer = mock.MagicMock() with override_global_config(dict(_llmobs_ml_app="unnamed-ml-app")): with dummy_tracer.trace("root_llm_span", span_type=SpanTypes.LLM) as llm_span: 
llm_span.set_tag(SPAN_KIND, "llm") llm_span.set_tag(OUTPUT_VALUE, "value") - tp = LLMObsTraceProcessor(llmobs_writer=mock_llmobs_writer) + tp = LLMObsTraceProcessor(llmobs_span_writer=mock_llmobs_span_writer) assert tp._llmobs_span_event(llm_span)["meta"]["output"]["value"] == "value" def test_metadata_is_set(): """Test that metadata is set on the span event if it is present on the span.""" dummy_tracer = DummyTracer() - mock_llmobs_writer = mock.MagicMock() + mock_llmobs_span_writer = mock.MagicMock() with override_global_config(dict(_llmobs_ml_app="unnamed-ml-app")): with dummy_tracer.trace("root_llm_span", span_type=SpanTypes.LLM) as llm_span: llm_span.set_tag(SPAN_KIND, "llm") llm_span.set_tag(METADATA, '{"key": "value"}') - tp = LLMObsTraceProcessor(llmobs_writer=mock_llmobs_writer) + tp = LLMObsTraceProcessor(llmobs_span_writer=mock_llmobs_span_writer) assert tp._llmobs_span_event(llm_span)["meta"]["metadata"] == {"key": "value"} def test_metrics_are_set(): """Test that metadata is set on the span event if it is present on the span.""" dummy_tracer = DummyTracer() - mock_llmobs_writer = mock.MagicMock() + mock_llmobs_span_writer = mock.MagicMock() with override_global_config(dict(_llmobs_ml_app="unnamed-ml-app")): with dummy_tracer.trace("root_llm_span", span_type=SpanTypes.LLM) as llm_span: llm_span.set_tag(SPAN_KIND, "llm") llm_span.set_tag(METRICS, '{"tokens": 100}') - tp = LLMObsTraceProcessor(llmobs_writer=mock_llmobs_writer) + tp = LLMObsTraceProcessor(llmobs_span_writer=mock_llmobs_span_writer) assert tp._llmobs_span_event(llm_span)["metrics"] == {"tokens": 100} def test_error_is_set(): """Test that error is set on the span event if it is present on the span.""" dummy_tracer = DummyTracer() - mock_llmobs_writer = mock.MagicMock() + mock_llmobs_span_writer = mock.MagicMock() with override_global_config(dict(_llmobs_ml_app="unnamed-ml-app")): with pytest.raises(ValueError): with dummy_tracer.trace("root_llm_span", span_type=SpanTypes.LLM) as llm_span: 
llm_span.set_tag(SPAN_KIND, "llm") raise ValueError("error") - tp = LLMObsTraceProcessor(llmobs_writer=mock_llmobs_writer) + tp = LLMObsTraceProcessor(llmobs_span_writer=mock_llmobs_span_writer) span_event = tp._llmobs_span_event(llm_span) assert span_event["meta"]["error.message"] == "error" assert "ValueError" in span_event["meta"]["error.type"] diff --git a/tests/llmobs/test_llmobs_writer.py b/tests/llmobs/test_llmobs_writer.py deleted file mode 100644 index 32248cc1c4b..00000000000 --- a/tests/llmobs/test_llmobs_writer.py +++ /dev/null @@ -1,183 +0,0 @@ -import os -import time - -import mock -import pytest - -from ddtrace.llmobs._writer import LLMObsWriter - - -INTAKE_ENDPOINT = "https://llmobs-intake.datad0g.com/api/v2/llmobs" -DD_SITE = "datad0g.com" -dd_api_key = os.getenv("DD_API_KEY", default="") - - -@pytest.fixture -def mock_logs(): - with mock.patch("ddtrace.llmobs._writer.logger") as m: - yield m - - -def _completion_event(): - return { - "kind": "llm", - "span_id": "12345678901", - "trace_id": "98765432101", - "parent_id": "", - "session_id": "98765432101", - "name": "completion_span", - "tags": ["version:", "env:", "service:", "source:integration"], - "start_ns": 1707763310981223236, - "duration": 12345678900, - "error": 0, - "meta": { - "span.kind": "llm", - "model_name": "ada", - "model_provider": "openai", - "input": { - "messages": [{"content": "who broke enigma?"}], - "parameters": {"temperature": 0, "max_tokens": 256}, - }, - "output": { - "messages": [ - { - "content": "\n\nThe Enigma code was broken by a team of codebreakers at Bletchley Park, led by mathematician Alan Turing." 
# noqa: E501 - } - ] - }, - }, - "metrics": {"prompt_tokens": 64, "completion_tokens": 128, "total_tokens": 192}, - } - - -def _chat_completion_event(): - return { - "span_id": "12345678902", - "trace_id": "98765432102", - "parent_id": "", - "session_id": "98765432102", - "name": "chat_completion_span", - "tags": ["version:", "env:", "service:", "source:integration"], - "start_ns": 1707763310981223936, - "duration": 12345678900, - "error": 0, - "meta": { - "span.kind": "llm", - "model_name": "gpt-3.5-turbo", - "model_provider": "openai", - "input": { - "messages": [ - { - "role": "system", - "content": "You are an evil dark lord looking for his one ring to rule them all", - }, - {"role": "user", "content": "I am a hobbit looking to go to Mordor"}, - ], - "parameters": {"temperature": 0.9, "max_tokens": 256}, - }, - "output": { - "messages": [ - { - "content": "Ah, a bold and foolish hobbit seeking to challenge my dominion in Mordor. Very well, little creature, I shall play along. But know that I am always watching, and your quest will not go unnoticed", # noqa: E501 - "role": "assistant", - }, - ] - }, - }, - "metrics": {"prompt_tokens": 64, "completion_tokens": 128, "total_tokens": 192}, - } - - -def test_buffer_limit(mock_logs): - llmobs_writer = LLMObsWriter(site="datadoghq.com", api_key="asdf", interval=1000, timeout=1) - for _ in range(1001): - llmobs_writer.enqueue({}) - mock_logs.warning.assert_called_with("LLMobs event buffer full (limit is %d), dropping record", 1000) - - -@pytest.mark.vcr_logs -def test_send_completion_event(mock_logs): - llmobs_writer = LLMObsWriter(site="datad0g.com", api_key=dd_api_key, interval=1, timeout=1) - llmobs_writer.start() - mock_logs.debug.assert_has_calls([mock.call("started llmobs writer to %r", INTAKE_ENDPOINT)]) - llmobs_writer.enqueue(_completion_event()) - mock_logs.reset_mock() - llmobs_writer.periodic() - mock_logs.debug.assert_has_calls([mock.call("sent %d LLMObs events to %r", 1, INTAKE_ENDPOINT)]) - - 
-@pytest.mark.vcr_logs -def test_send_chat_completion_event(mock_logs): - llmobs_writer = LLMObsWriter(site="datad0g.com", api_key=dd_api_key, interval=1, timeout=1) - llmobs_writer.start() - mock_logs.debug.assert_has_calls([mock.call("started llmobs writer to %r", INTAKE_ENDPOINT)]) - llmobs_writer.enqueue(_chat_completion_event()) - mock_logs.reset_mock() - llmobs_writer.periodic() - mock_logs.debug.assert_has_calls([mock.call("sent %d LLMObs events to %r", 1, INTAKE_ENDPOINT)]) - - -@pytest.mark.vcr_logs -def test_send_completion_bad_api_key(mock_logs): - llmobs_writer = LLMObsWriter(site="datad0g.com", api_key="", interval=1, timeout=1) - llmobs_writer.start() - llmobs_writer.enqueue(_completion_event()) - llmobs_writer.periodic() - mock_logs.error.assert_called_with( - "failed to send %d LLMObs events to %r, got response code %r, status: %r", - 1, - INTAKE_ENDPOINT, - 403, - b'{"errors":[{"status":"403","title":"Forbidden","detail":"API key is invalid"}]}', - ) - - -@pytest.mark.vcr_logs -def test_send_timed_events(mock_logs): - llmobs_writer = LLMObsWriter(site="datad0g.com", api_key=dd_api_key, interval=0.01, timeout=1) - llmobs_writer.start() - mock_logs.reset_mock() - - llmobs_writer.enqueue(_completion_event()) - time.sleep(0.1) - mock_logs.debug.assert_has_calls([mock.call("sent %d LLMObs events to %r", 1, INTAKE_ENDPOINT)]) - mock_logs.reset_mock() - llmobs_writer.enqueue(_chat_completion_event()) - time.sleep(0.1) - mock_logs.debug.assert_has_calls([mock.call("sent %d LLMObs events to %r", 1, INTAKE_ENDPOINT)]) - - -@pytest.mark.vcr_logs -def test_send_multiple_events(mock_logs): - llmobs_writer = LLMObsWriter(site="datad0g.com", api_key=dd_api_key, interval=0.01, timeout=1) - llmobs_writer.start() - mock_logs.reset_mock() - - llmobs_writer.enqueue(_completion_event()) - llmobs_writer.enqueue(_chat_completion_event()) - time.sleep(0.1) - mock_logs.debug.assert_has_calls([mock.call("sent %d LLMObs events to %r", 2, INTAKE_ENDPOINT)]) - - -def 
test_send_on_exit(mock_logs, run_python_code_in_subprocess): - out, err, status, pid = run_python_code_in_subprocess( - """ -import atexit -import os -import time - -from ddtrace.llmobs._writer import LLMObsWriter -from tests.llmobs.test_llmobs_writer import _completion_event -from tests.llmobs._utils import logs_vcr - -ctx = logs_vcr.use_cassette("tests.llmobs.test_llmobs_writer.test_send_on_exit.yaml") -ctx.__enter__() -atexit.register(lambda: ctx.__exit__()) -llmobs_writer = LLMObsWriter(site="datad0g.com", api_key=os.getenv("DD_API_KEY"), interval=0.01, timeout=1) -llmobs_writer.start() -llmobs_writer.enqueue(_completion_event()) -""", - ) - assert status == 0, err - assert out == b"" - assert err == b"" diff --git a/tests/llmobs/test_utils.py b/tests/llmobs/test_utils.py new file mode 100644 index 00000000000..26241b90b07 --- /dev/null +++ b/tests/llmobs/test_utils.py @@ -0,0 +1,57 @@ +import pytest + +from ddtrace.llmobs.utils import Messages + + +class Unserializable: + pass + + +def test_messages_with_string(): + messages = Messages("hello") + assert messages.messages == [{"content": "hello"}] + + +def test_messages_with_dict(): + messages = Messages({"content": "hello", "role": "user"}) + assert messages.messages == [{"content": "hello", "role": "user"}] + + +def test_messages_with_list_of_dicts(): + messages = Messages([{"content": "hello", "role": "user"}, {"content": "world", "role": "system"}]) + assert messages.messages == [{"content": "hello", "role": "user"}, {"content": "world", "role": "system"}] + + +def test_messages_with_incorrect_type(): + with pytest.raises(TypeError): + Messages(123) + with pytest.raises(TypeError): + Messages(Unserializable()) + with pytest.raises(TypeError): + Messages(None) + + +def test_messages_with_non_string_content(): + with pytest.raises(TypeError): + Messages([{"content": 123}]) + with pytest.raises(TypeError): + Messages([{"content": Unserializable()}]) + with pytest.raises(TypeError): + Messages([{"content": 
None}]) + with pytest.raises(TypeError): + Messages({"content": {"key": "value"}}) + + +def test_messages_with_non_string_role(): + with pytest.raises(TypeError): + Messages([{"content": "hello", "role": 123}]) + with pytest.raises(TypeError): + Messages([{"content": "hello", "role": Unserializable()}]) + with pytest.raises(TypeError): + Messages({"content": "hello", "role": {"key": "value"}}) + + +def test_messages_with_no_role_is_ok(): + """Test that a message with no role is ok and returns a message with only content.""" + messages = Messages([{"content": "hello"}, {"content": "world"}]) + assert messages.messages == [{"content": "hello"}, {"content": "world"}] diff --git a/tests/snapshots/tests.contrib.botocore.test_bedrock.test_amazon_embedding.json b/tests/snapshots/tests.contrib.botocore.test_bedrock.test_amazon_embedding.json new file mode 100644 index 00000000000..f4c09d2734b --- /dev/null +++ b/tests/snapshots/tests.contrib.botocore.test_bedrock.test_amazon_embedding.json @@ -0,0 +1,34 @@ +[[ + { + "name": "bedrock-runtime.command", + "service": "aws.bedrock-runtime", + "resource": "InvokeModel", + "trace_id": 0, + "span_id": 1, + "parent_id": 0, + "type": "", + "error": 0, + "meta": { + "_dd.base_service": "", + "_dd.p.dm": "-0", + "_dd.p.tid": "662820e400000000", + "bedrock.request.model": "titan-embed-text-v1", + "bedrock.request.model_provider": "amazon", + "bedrock.request.prompt": "Hello World!", + "bedrock.response.duration": "311", + "bedrock.response.id": "1fd884e0-c9e8-44fa-b736-d31e2f607d54", + "bedrock.usage.completion_tokens": "", + "bedrock.usage.prompt_tokens": "3", + "language": "python", + "runtime-id": "a7bb6456241740dea419398d37aa13d2" + }, + "metrics": { + "_dd.top_level": 1, + "_dd.tracer_kr": 1.0, + "_sampling_priority_v1": 1, + "bedrock.response.embedding_length": 1536, + "process_id": 60939 + }, + "duration": 6739000, + "start": 1713905892539987000 + }]] diff --git 
a/tests/snapshots/tests.contrib.botocore.test_bedrock.test_cohere_embedding.json b/tests/snapshots/tests.contrib.botocore.test_bedrock.test_cohere_embedding.json new file mode 100644 index 00000000000..d1522b46ff5 --- /dev/null +++ b/tests/snapshots/tests.contrib.botocore.test_bedrock.test_cohere_embedding.json @@ -0,0 +1,36 @@ +[[ + { + "name": "bedrock-runtime.command", + "service": "aws.bedrock-runtime", + "resource": "InvokeModel", + "trace_id": 0, + "span_id": 1, + "parent_id": 0, + "type": "", + "error": 0, + "meta": { + "_dd.base_service": "", + "_dd.p.dm": "-0", + "_dd.p.tid": "6628215a00000000", + "bedrock.request.input_type": "search_document", + "bedrock.request.model": "embed-english-v3", + "bedrock.request.model_provider": "cohere", + "bedrock.request.prompt": "['Hello World!', 'Goodbye cruel world!']", + "bedrock.request.truncate": "", + "bedrock.response.duration": "271", + "bedrock.response.id": "0e9cb5ab-1fef-46eb-8e2c-773f0f60f39d", + "bedrock.usage.completion_tokens": "", + "bedrock.usage.prompt_tokens": "7", + "language": "python", + "runtime-id": "c02c555fdac14227bee7b37a0c304534" + }, + "metrics": { + "_dd.top_level": 1, + "_dd.tracer_kr": 1.0, + "_sampling_priority_v1": 1, + "bedrock.response.embedding_length": 1024, + "process_id": 61336 + }, + "duration": 630192000, + "start": 1713906010873383000 + }]] diff --git a/tests/telemetry/test_telemetry_metrics.py b/tests/telemetry/test_telemetry_metrics.py index a3948b7af7d..2d7de578279 100644 --- a/tests/telemetry/test_telemetry_metrics.py +++ b/tests/telemetry/test_telemetry_metrics.py @@ -53,12 +53,12 @@ def _assert_logs( test_agent.telemetry_writer.periodic() events = test_agent.get_events() - expected_body = _get_request_body(expected_payload, TELEMETRY_TYPE_LOGS, seq_id) - expected_body["payload"].sort(key=lambda x: x["message"], reverse=False) - expected_body_sorted = expected_body["payload"] + expected_body = _get_request_body({"logs": expected_payload}, TELEMETRY_TYPE_LOGS, seq_id) + 
expected_body["payload"]["logs"].sort(key=lambda x: x["message"], reverse=False) + expected_body_sorted = expected_body["payload"]["logs"] - events[0]["payload"].sort(key=lambda x: x["message"], reverse=False) - result_event = events[0]["payload"] + events[0]["payload"]["logs"].sort(key=lambda x: x["message"], reverse=False) + result_event = events[0]["payload"]["logs"] assert result_event == expected_body_sorted diff --git a/tests/telemetry/test_writer.py b/tests/telemetry/test_writer.py index fbc56869cb6..18699170152 100644 --- a/tests/telemetry/test_writer.py +++ b/tests/telemetry/test_writer.py @@ -128,7 +128,6 @@ def test_app_started_event(telemetry_writer, test_agent_session, mock_time): {"name": "DD_TRACE_PROPAGATION_STYLE_INJECT", "origin": "unknown", "value": "datadog,tracecontext"}, {"name": "DD_TRACE_RATE_LIMIT", "origin": "unknown", "value": 100}, {"name": "DD_TRACE_REMOVE_INTEGRATION_SERVICE_NAMES_ENABLED", "origin": "unknown", "value": False}, - {"name": "DD_TRACE_SAMPLING_RULES", "origin": "unknown", "value": None}, {"name": "DD_TRACE_SPAN_ATTRIBUTE_SCHEMA", "origin": "unknown", "value": "v0"}, {"name": "DD_TRACE_STARTUP_LOGS", "origin": "unknown", "value": False}, {"name": "DD_TRACE_WRITER_BUFFER_SIZE_BYTES", "origin": "unknown", "value": 20 << 20}, @@ -142,6 +141,7 @@ def test_app_started_event(telemetry_writer, test_agent_session, mock_time): {"name": "data_streams_enabled", "origin": "default", "value": "false"}, {"name": "appsec_enabled", "origin": "default", "value": "false"}, {"name": "trace_sample_rate", "origin": "default", "value": "1.0"}, + {"name": "trace_sampling_rules", "origin": "default", "value": ""}, {"name": "trace_header_tags", "origin": "default", "value": ""}, {"name": "logs_injection_enabled", "origin": "default", "value": "false"}, {"name": "trace_tags", "origin": "default", "value": ""}, @@ -292,11 +292,6 @@ def test_app_started_event_configuration_override( {"name": "DD_TRACE_PROPAGATION_STYLE_INJECT", "origin": "unknown", 
"value": "tracecontext"}, {"name": "DD_TRACE_RATE_LIMIT", "origin": "unknown", "value": 50}, {"name": "DD_TRACE_REMOVE_INTEGRATION_SERVICE_NAMES_ENABLED", "origin": "unknown", "value": True}, - { - "name": "DD_TRACE_SAMPLING_RULES", - "origin": "unknown", - "value": '[{"sample_rate":1.0,"service":"xyz","name":"abc"}]', - }, {"name": "DD_TRACE_SPAN_ATTRIBUTE_SCHEMA", "origin": "unknown", "value": "v1"}, {"name": "DD_TRACE_STARTUP_LOGS", "origin": "unknown", "value": True}, {"name": "DD_TRACE_WRITER_BUFFER_SIZE_BYTES", "origin": "unknown", "value": 1000}, @@ -310,6 +305,11 @@ def test_app_started_event_configuration_override( {"name": "data_streams_enabled", "origin": "env_var", "value": "true"}, {"name": "appsec_enabled", "origin": "env_var", "value": "true"}, {"name": "trace_sample_rate", "origin": "env_var", "value": "0.5"}, + { + "name": "trace_sampling_rules", + "origin": "env_var", + "value": '[{"sample_rate":1.0,"service":"xyz","name":"abc"}]', + }, {"name": "logs_injection_enabled", "origin": "env_var", "value": "true"}, {"name": "trace_header_tags", "origin": "default", "value": ""}, {"name": "trace_tags", "origin": "env_var", "value": "team:apm,component:web"},