diff --git a/ddtrace/contrib/aioredis/__init__.py b/ddtrace/contrib/aioredis/__init__.py index 46d5e71e851..ca5238c3ad4 100644 --- a/ddtrace/contrib/aioredis/__init__.py +++ b/ddtrace/contrib/aioredis/__init__.py @@ -76,8 +76,11 @@ with require_modules(required_modules) as missing_modules: if not missing_modules: - from .patch import get_version - from .patch import patch - from .patch import unpatch + # Required to allow users to import from `ddtrace.contrib.aioredis.patch` directly + from . import patch as _ # noqa: F401, I001 + + from ..internal.aioredis.patch import get_version + from ..internal.aioredis.patch import patch + from ..internal.aioredis.patch import unpatch __all__ = ["patch", "unpatch", "get_version"] diff --git a/ddtrace/contrib/aioredis/patch.py b/ddtrace/contrib/aioredis/patch.py index 276f73f6a1f..9f9a2f29e8c 100644 --- a/ddtrace/contrib/aioredis/patch.py +++ b/ddtrace/contrib/aioredis/patch.py @@ -1,235 +1,4 @@ -import asyncio -import os -import sys +from ..internal.aioredis.patch import * # noqa: F401,F403 -import aioredis -from ddtrace import config -from ddtrace._trace.utils_redis import _instrument_redis_cmd -from ddtrace._trace.utils_redis import _instrument_redis_execute_pipeline -from ddtrace.contrib.redis_utils import ROW_RETURNING_COMMANDS -from ddtrace.contrib.redis_utils import _run_redis_command_async -from ddtrace.contrib.redis_utils import determine_row_count -from ddtrace.internal.constants import COMPONENT -from ddtrace.internal.utils.wrappers import unwrap as _u -from ddtrace.pin import Pin -from ddtrace.vendor.packaging.version import parse as parse_version -from ddtrace.vendor.wrapt import wrap_function_wrapper as _w - -from ...constants import ANALYTICS_SAMPLE_RATE_KEY -from ...constants import SPAN_KIND -from ...constants import SPAN_MEASURED_KEY -from ...ext import SpanKind -from ...ext import SpanTypes -from ...ext import db -from ...ext import net -from ...ext import redis as redisx -from ...internal.schema import 
schematize_cache_operation -from ...internal.schema import schematize_service_name -from ...internal.utils.formats import CMD_MAX_LEN -from ...internal.utils.formats import asbool -from ...internal.utils.formats import stringify_cache_args -from .. import trace_utils - - -try: - from aioredis.commands.transaction import _RedisBuffer -except ImportError: - _RedisBuffer = None - -config._add( - "aioredis", - dict( - _default_service=schematize_service_name("redis"), - cmd_max_length=int(os.getenv("DD_AIOREDIS_CMD_MAX_LENGTH", CMD_MAX_LEN)), - resource_only_command=asbool(os.getenv("DD_REDIS_RESOURCE_ONLY_COMMAND", True)), - ), -) - -aioredis_version_str = getattr(aioredis, "__version__", "") -aioredis_version = parse_version(aioredis_version_str) -V2 = parse_version("2.0") - - -def get_version(): - # type: () -> str - return aioredis_version_str - - -def patch(): - if getattr(aioredis, "_datadog_patch", False): - return - aioredis._datadog_patch = True - pin = Pin() - if aioredis_version >= V2: - _w("aioredis.client", "Redis.execute_command", traced_execute_command) - _w("aioredis.client", "Redis.pipeline", traced_pipeline) - _w("aioredis.client", "Pipeline.execute", traced_execute_pipeline) - pin.onto(aioredis.client.Redis) - else: - _w("aioredis", "Redis.execute", traced_13_execute_command) - _w("aioredis", "Redis.pipeline", traced_13_pipeline) - _w("aioredis.commands.transaction", "Pipeline.execute", traced_13_execute_pipeline) - pin.onto(aioredis.Redis) - - -def unpatch(): - if not getattr(aioredis, "_datadog_patch", False): - return - - aioredis._datadog_patch = False - if aioredis_version >= V2: - _u(aioredis.client.Redis, "execute_command") - _u(aioredis.client.Redis, "pipeline") - _u(aioredis.client.Pipeline, "execute") - else: - _u(aioredis.Redis, "execute") - _u(aioredis.Redis, "pipeline") - _u(aioredis.commands.transaction.Pipeline, "execute") - - -async def traced_execute_command(func, instance, args, kwargs): - pin = Pin.get_from(instance) - if not pin 
or not pin.enabled(): - return await func(*args, **kwargs) - - with _instrument_redis_cmd(pin, config.aioredis, instance, args) as ctx: - return await _run_redis_command_async(ctx=ctx, func=func, args=args, kwargs=kwargs) - - -def traced_pipeline(func, instance, args, kwargs): - pipeline = func(*args, **kwargs) - pin = Pin.get_from(instance) - if pin: - pin.onto(pipeline) - return pipeline - - -async def traced_execute_pipeline(func, instance, args, kwargs): - pin = Pin.get_from(instance) - if not pin or not pin.enabled(): - return await func(*args, **kwargs) - - cmds = [stringify_cache_args(c, cmd_max_len=config.aioredis.cmd_max_length) for c, _ in instance.command_stack] - with _instrument_redis_execute_pipeline(pin, config.aioredis, cmds, instance): - return await func(*args, **kwargs) - - -def traced_13_pipeline(func, instance, args, kwargs): - pipeline = func(*args, **kwargs) - pin = Pin.get_from(instance) - if pin: - pin.onto(pipeline) - return pipeline - - -def traced_13_execute_command(func, instance, args, kwargs): - # If we have a _RedisBuffer then we are in a pipeline - if isinstance(instance.connection, _RedisBuffer): - return func(*args, **kwargs) - - pin = Pin.get_from(instance) - if not pin or not pin.enabled(): - return func(*args, **kwargs) - - # Don't activate the span since this operation is performed as a future which concludes sometime later on in - # execution so subsequent operations in the stack are not necessarily semantically related - # (we don't want this span to be the parent of all other spans created before the future is resolved) - parent = pin.tracer.current_span() - query = stringify_cache_args(args, cmd_max_len=config.aioredis.cmd_max_length) - span = pin.tracer.start_span( - schematize_cache_operation(redisx.CMD, cache_provider="redis"), - service=trace_utils.ext_service(pin, config.aioredis), - resource=query.split(" ")[0] if config.aioredis.resource_only_command else query, - span_type=SpanTypes.REDIS, - activate=False, - 
child_of=parent, - ) - # set span.kind to the type of request being performed - span.set_tag_str(SPAN_KIND, SpanKind.CLIENT) - - span.set_tag_str(COMPONENT, config.aioredis.integration_name) - span.set_tag_str(db.SYSTEM, redisx.APP) - span.set_tag(SPAN_MEASURED_KEY) - span.set_tag_str(redisx.RAWCMD, query) - if pin.tags: - span.set_tags(pin.tags) - - span.set_tags( - { - net.TARGET_HOST: instance.address[0], - net.TARGET_PORT: instance.address[1], - redisx.DB: instance.db or 0, - } - ) - span.set_metric(redisx.ARGS_LEN, len(args)) - # set analytics sample rate if enabled - span.set_tag(ANALYTICS_SAMPLE_RATE_KEY, config.aioredis.get_analytics_sample_rate()) - - def _finish_span(future): - try: - # Accessing the result will raise an exception if: - # - The future was cancelled (CancelledError) - # - There was an error executing the future (`future.exception()`) - # - The future is in an invalid state - redis_command = span.resource.split(" ")[0] - future.result() - if redis_command in ROW_RETURNING_COMMANDS: - span.set_metric(db.ROWCOUNT, determine_row_count(redis_command=redis_command, result=future.result())) - # CancelledError exceptions extend from BaseException as of Python 3.8, instead of usual Exception - except (Exception, aioredis.CancelledError): - span.set_exc_info(*sys.exc_info()) - if redis_command in ROW_RETURNING_COMMANDS: - span.set_metric(db.ROWCOUNT, 0) - finally: - span.finish() - - task = func(*args, **kwargs) - # Execute command returns a coroutine when no free connections are available - # https://github.com/aio-libs/aioredis-py/blob/v1.3.1/aioredis/pool.py#L191 - task = asyncio.ensure_future(task) - task.add_done_callback(_finish_span) - return task - - -async def traced_13_execute_pipeline(func, instance, args, kwargs): - pin = Pin.get_from(instance) - if not pin or not pin.enabled(): - return await func(*args, **kwargs) - - cmds = [] - for _, cmd, cmd_args, _ in instance._pipeline: - parts = [cmd] - parts.extend(cmd_args) - 
cmds.append(stringify_cache_args(parts, cmd_max_len=config.aioredis.cmd_max_length)) - - resource = cmds_string = "\n".join(cmds) - if config.aioredis.resource_only_command: - resource = "\n".join([cmd.split(" ")[0] for cmd in cmds]) - - with pin.tracer.trace( - schematize_cache_operation(redisx.CMD, cache_provider="redis"), - resource=resource, - service=trace_utils.ext_service(pin, config.aioredis), - span_type=SpanTypes.REDIS, - ) as span: - # set span.kind to the type of request being performed - span.set_tag_str(SPAN_KIND, SpanKind.CLIENT) - - span.set_tag_str(COMPONENT, config.aioredis.integration_name) - span.set_tag_str(db.SYSTEM, redisx.APP) - span.set_tags( - { - net.TARGET_HOST: instance._pool_or_conn.address[0], - net.TARGET_PORT: instance._pool_or_conn.address[1], - redisx.DB: instance._pool_or_conn.db or 0, - } - ) - - span.set_tag(SPAN_MEASURED_KEY) - span.set_tag_str(redisx.RAWCMD, cmds_string) - span.set_metric(redisx.PIPELINE_LEN, len(instance._pipeline)) - # set analytics sample rate if enabled - span.set_tag(ANALYTICS_SAMPLE_RATE_KEY, config.aioredis.get_analytics_sample_rate()) - - return await func(*args, **kwargs) +# TODO: deprecate and remove this module diff --git a/ddtrace/contrib/algoliasearch/__init__.py b/ddtrace/contrib/algoliasearch/__init__.py index 8b4c96dc82b..3d5abac0248 100644 --- a/ddtrace/contrib/algoliasearch/__init__.py +++ b/ddtrace/contrib/algoliasearch/__init__.py @@ -29,8 +29,11 @@ with require_modules(required_modules) as missing_modules: if not missing_modules: - from .patch import get_version - from .patch import patch - from .patch import unpatch + # Required to allow users to import from `ddtrace.contrib.algoliasearch.patch` directly + from . 
import patch as _ # noqa: F401, I001 + + from ..internal.algoliasearch.patch import get_version + from ..internal.algoliasearch.patch import patch + from ..internal.algoliasearch.patch import unpatch __all__ = ["patch", "unpatch", "get_version"] diff --git a/ddtrace/contrib/algoliasearch/patch.py b/ddtrace/contrib/algoliasearch/patch.py index 830fb61ee3a..f6ddddcffdf 100644 --- a/ddtrace/contrib/algoliasearch/patch.py +++ b/ddtrace/contrib/algoliasearch/patch.py @@ -1,172 +1,4 @@ -from ddtrace import config -from ddtrace.ext import SpanKind -from ddtrace.ext import SpanTypes -from ddtrace.internal.constants import COMPONENT -from ddtrace.internal.schema import schematize_cloud_api_operation -from ddtrace.internal.schema import schematize_service_name -from ddtrace.internal.utils.wrappers import unwrap as _u -from ddtrace.pin import Pin -from ddtrace.vendor.packaging.version import parse as parse_version -from ddtrace.vendor.wrapt import wrap_function_wrapper as _w +from ..internal.algoliasearch.patch import * # noqa: F401,F403 -from ...constants import SPAN_KIND -from ...constants import SPAN_MEASURED_KEY -from .. 
import trace_utils - -DD_PATCH_ATTR = "_datadog_patch" - -SERVICE_NAME = schematize_service_name("algoliasearch") -APP_NAME = "algoliasearch" -V0 = parse_version("0.0") -V1 = parse_version("1.0") -V2 = parse_version("2.0") -V3 = parse_version("3.0") - -try: - VERSION = "0.0.0" - import algoliasearch - from algoliasearch.version import VERSION - - algoliasearch_version = parse_version(VERSION) - - # Default configuration - config._add("algoliasearch", dict(_default_service=SERVICE_NAME, collect_query_text=False)) -except ImportError: - algoliasearch_version = V0 - - -def get_version(): - # type: () -> str - return VERSION - - -def patch(): - if algoliasearch_version == V0: - return - - if getattr(algoliasearch, DD_PATCH_ATTR, False): - return - - algoliasearch._datadog_patch = True - - pin = Pin() - - if algoliasearch_version < V2 and algoliasearch_version >= V1: - _w(algoliasearch.index, "Index.search", _patched_search) - pin.onto(algoliasearch.index.Index) - elif algoliasearch_version >= V2 and algoliasearch_version < V3: - from algoliasearch import search_index - - _w(algoliasearch, "search_index.SearchIndex.search", _patched_search) - pin.onto(search_index.SearchIndex) - else: - return - - -def unpatch(): - if algoliasearch_version == V0: - return - - if getattr(algoliasearch, DD_PATCH_ATTR, False): - setattr(algoliasearch, DD_PATCH_ATTR, False) - - if algoliasearch_version < V2 and algoliasearch_version >= V1: - _u(algoliasearch.index.Index, "search") - elif algoliasearch_version >= V2 and algoliasearch_version < V3: - from algoliasearch import search_index - - _u(search_index.SearchIndex, "search") - else: - return - - -# DEV: this maps serves the dual purpose of enumerating the algoliasearch.search() query_args that -# will be sent along as tags, as well as converting arguments names into tag names compliant with -# tag naming recommendations set out here: https://docs.datadoghq.com/tagging/ -QUERY_ARGS_DD_TAG_MAP = { - "page": "page", - "hitsPerPage": 
"hits_per_page", - "attributesToRetrieve": "attributes_to_retrieve", - "attributesToHighlight": "attributes_to_highlight", - "attributesToSnippet": "attributes_to_snippet", - "minWordSizefor1Typo": "min_word_size_for_1_typo", - "minWordSizefor2Typos": "min_word_size_for_2_typos", - "getRankingInfo": "get_ranking_info", - "aroundLatLng": "around_lat_lng", - "numericFilters": "numeric_filters", - "tagFilters": "tag_filters", - "queryType": "query_type", - "optionalWords": "optional_words", - "distinct": "distinct", -} - - -def _patched_search(func, instance, wrapt_args, wrapt_kwargs): - """ - wrapt_args is called the way it is to distinguish it from the 'args' - argument to the algoliasearch.index.Index.search() method. - """ - - if algoliasearch_version < V2 and algoliasearch_version >= V1: - function_query_arg_name = "args" - elif algoliasearch_version >= V2 and algoliasearch_version < V3: - function_query_arg_name = "request_options" - else: - return func(*wrapt_args, **wrapt_kwargs) - - pin = Pin.get_from(instance) - if not pin or not pin.enabled(): - return func(*wrapt_args, **wrapt_kwargs) - - with pin.tracer.trace( - schematize_cloud_api_operation("algoliasearch.search", cloud_provider="algoliasearch", cloud_service="search"), - service=trace_utils.ext_service(pin, config.algoliasearch), - span_type=SpanTypes.HTTP, - ) as span: - span.set_tag_str(COMPONENT, config.algoliasearch.integration_name) - - # set span.kind to the type of request being performed - span.set_tag_str(SPAN_KIND, SpanKind.CLIENT) - - span.set_tag(SPAN_MEASURED_KEY) - if span.context.sampling_priority is not None and span.context.sampling_priority <= 0: - return func(*wrapt_args, **wrapt_kwargs) - - if config.algoliasearch.collect_query_text: - span.set_tag_str("query.text", wrapt_kwargs.get("query", wrapt_args[0])) - - query_args = wrapt_kwargs.get(function_query_arg_name, wrapt_args[1] if len(wrapt_args) > 1 else None) - - if query_args and isinstance(query_args, dict): - for query_arg, 
tag_name in QUERY_ARGS_DD_TAG_MAP.items(): - value = query_args.get(query_arg) - if value is not None: - span.set_tag("query.args.{}".format(tag_name), value) - - # Result would look like this - # { - # 'hits': [ - # { - # .... your search results ... - # } - # ], - # 'processingTimeMS': 1, - # 'nbHits': 1, - # 'hitsPerPage': 20, - # 'exhaustiveNbHits': true, - # 'params': 'query=xxx', - # 'nbPages': 1, - # 'query': 'xxx', - # 'page': 0 - # } - result = func(*wrapt_args, **wrapt_kwargs) - - if isinstance(result, dict): - if result.get("processingTimeMS", None) is not None: - span.set_metric("processing_time_ms", int(result["processingTimeMS"])) - - if result.get("nbHits", None) is not None: - span.set_metric("number_of_hits", int(result["nbHits"])) - - return result +# TODO: deprecate and remove this module diff --git a/ddtrace/contrib/anthropic/__init__.py b/ddtrace/contrib/anthropic/__init__.py index 2bc27fb127a..9303f74a5ba 100644 --- a/ddtrace/contrib/anthropic/__init__.py +++ b/ddtrace/contrib/anthropic/__init__.py @@ -87,10 +87,11 @@ with require_modules(required_modules) as missing_modules: if not missing_modules: - from . import patch as _patch + # Required to allow users to import from `ddtrace.contrib.anthropic.patch` directly + from . 
import patch as _ # noqa: F401, I001 - patch = _patch.patch - unpatch = _patch.unpatch - get_version = _patch.get_version + from ..internal.anthropic.patch import patch + from ..internal.anthropic.patch import unpatch + from ..internal.anthropic.patch import get_version __all__ = ["patch", "unpatch", "get_version"] diff --git a/ddtrace/contrib/anthropic/_streaming.py b/ddtrace/contrib/anthropic/_streaming.py index ad4b1f13e39..c5b6a6a3f40 100644 --- a/ddtrace/contrib/anthropic/_streaming.py +++ b/ddtrace/contrib/anthropic/_streaming.py @@ -1,324 +1,15 @@ -import json -import sys -from typing import Any -from typing import Dict -from typing import Tuple +from ddtrace.internal.utils.deprecations import DDTraceDeprecationWarning +from ddtrace.vendor.debtcollector import deprecate -import anthropic +from ..internal.anthropic._streaming import * # noqa: F401,F403 -from ddtrace.contrib.anthropic.utils import tag_tool_use_output_on_span -from ddtrace.internal.logger import get_logger -from ddtrace.llmobs._integrations.anthropic import _get_attr -from ddtrace.vendor import wrapt +def __getattr__(name): + deprecate( + ("%s.%s is deprecated" % (__name__, name)), + category=DDTraceDeprecationWarning, + ) -log = get_logger(__name__) - - -def handle_streamed_response(integration, resp, args, kwargs, span): - if _is_stream(resp): - return TracedAnthropicStream(resp, integration, span, args, kwargs) - elif _is_async_stream(resp): - return TracedAnthropicAsyncStream(resp, integration, span, args, kwargs) - elif _is_stream_manager(resp): - return TracedAnthropicStreamManager(resp, integration, span, args, kwargs) - elif _is_async_stream_manager(resp): - return TracedAnthropicAsyncStreamManager(resp, integration, span, args, kwargs) - - -class BaseTracedAnthropicStream(wrapt.ObjectProxy): - def __init__(self, wrapped, integration, span, args, kwargs): - super().__init__(wrapped) - self._dd_span = span - self._streamed_chunks = [] - self._dd_integration = integration - self._kwargs = 
kwargs - self._args = args - - -class TracedAnthropicStream(BaseTracedAnthropicStream): - def __init__(self, wrapped, integration, span, args, kwargs): - super().__init__(wrapped, integration, span, args, kwargs) - # we need to set a text_stream attribute so we can trace the yielded chunks - self.text_stream = self.__stream_text__() - - def __enter__(self): - self.__wrapped__.__enter__() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.__wrapped__.__exit__(exc_type, exc_val, exc_tb) - - def __iter__(self): - return self - - def __next__(self): - try: - chunk = self.__wrapped__.__next__() - self._streamed_chunks.append(chunk) - return chunk - except StopIteration: - _process_finished_stream( - self._dd_integration, self._dd_span, self._args, self._kwargs, self._streamed_chunks - ) - self._dd_span.finish() - raise - except Exception: - self._dd_span.set_exc_info(*sys.exc_info()) - self._dd_span.finish() - raise - - def __stream_text__(self): - # this is overridden because it is a helper function that collects all stream content chunks - for chunk in self: - if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta": - yield chunk.delta.text - - -class TracedAnthropicAsyncStream(BaseTracedAnthropicStream): - def __init__(self, wrapped, integration, span, args, kwargs): - super().__init__(wrapped, integration, span, args, kwargs) - # we need to set a text_stream attribute so we can trace the yielded chunks - self.text_stream = self.__stream_text__() - - async def __aenter__(self): - await self.__wrapped__.__aenter__() - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - await self.__wrapped__.__aexit__(exc_type, exc_val, exc_tb) - - def __aiter__(self): - return self - - async def __anext__(self): - try: - chunk = await self.__wrapped__.__anext__() - self._streamed_chunks.append(chunk) - return chunk - except StopAsyncIteration: - _process_finished_stream( - self._dd_integration, - self._dd_span, - 
self._args, - self._kwargs, - self._streamed_chunks, - ) - self._dd_span.finish() - raise - except Exception: - self._dd_span.set_exc_info(*sys.exc_info()) - self._dd_span.finish() - raise - - async def __stream_text__(self): - # this is overridden because it is a helper function that collects all stream content chunks - async for chunk in self: - if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta": - yield chunk.delta.text - - -class TracedAnthropicStreamManager(BaseTracedAnthropicStream): - def __enter__(self): - stream = self.__wrapped__.__enter__() - traced_stream = TracedAnthropicStream( - stream, - self._dd_integration, - self._dd_span, - self._args, - self._kwargs, - ) - return traced_stream - - def __exit__(self, exc_type, exc_val, exc_tb): - self.__wrapped__.__exit__(exc_type, exc_val, exc_tb) - - -class TracedAnthropicAsyncStreamManager(BaseTracedAnthropicStream): - async def __aenter__(self): - stream = await self.__wrapped__.__aenter__() - traced_stream = TracedAnthropicAsyncStream( - stream, - self._dd_integration, - self._dd_span, - self._args, - self._kwargs, - ) - return traced_stream - - async def __aexit__(self, exc_type, exc_val, exc_tb): - await self.__wrapped__.__aexit__(exc_type, exc_val, exc_tb) - - -def _process_finished_stream(integration, span, args, kwargs, streamed_chunks): - # builds the response message given streamed chunks and sets according span tags - try: - resp_message = _construct_message(streamed_chunks) - - if integration.is_pc_sampled_span(span): - _tag_streamed_chat_completion_response(integration, span, resp_message) - if integration.is_pc_sampled_llmobs(span): - integration.llmobs_set_tags( - span=span, - resp=resp_message, - args=args, - kwargs=kwargs, - ) - except Exception: - log.warning("Error processing streamed completion/chat response.", exc_info=True) - - -def _construct_message(streamed_chunks): - """Iteratively build up a response message from streamed chunks. 
- - The resulting message dictionary is of form: - {"content": [{"type": [TYPE], "text": "[TEXT]"}], "role": "...", "finish_reason": "...", "usage": ...} - """ - message = {"content": []} - for chunk in streamed_chunks: - message = _extract_from_chunk(chunk, message) - return message - - -def _extract_from_chunk(chunk, message) -> Tuple[Dict[str, str], bool]: - """Constructs a chat message dictionary from streamed chunks given chunk type""" - TRANSFORMATIONS_BY_BLOCK_TYPE = { - "message_start": _on_message_start_chunk, - "content_block_start": _on_content_block_start_chunk, - "content_block_delta": _on_content_block_delta_chunk, - "content_block_stop": _on_content_block_stop_chunk, - "message_delta": _on_message_delta_chunk, - "error": _on_error_chunk, - } - chunk_type = _get_attr(chunk, "type", "") - transformation = TRANSFORMATIONS_BY_BLOCK_TYPE.get(chunk_type) - if transformation is not None: - message = transformation(chunk, message) - - return message - - -def _on_message_start_chunk(chunk, message): - # this is the starting chunk of the message - chunk_message = _get_attr(chunk, "message", "") - if chunk_message: - chunk_role = _get_attr(chunk_message, "role", "") - chunk_usage = _get_attr(chunk_message, "usage", "") - if chunk_role: - message["role"] = chunk_role - if chunk_usage: - message["usage"] = {"input_tokens": _get_attr(chunk_usage, "input_tokens", 0)} - return message - - -def _on_content_block_start_chunk(chunk, message): - # this is the start to a message.content block (possibly 1 of several content blocks) - chunk_content_block = _get_attr(chunk, "content_block", "") - if chunk_content_block: - chunk_content_block_type = _get_attr(chunk_content_block, "type", "") - if chunk_content_block_type == "text": - chunk_content_block_text = _get_attr(chunk_content_block, "text", "") - message["content"].append({"type": "text", "text": chunk_content_block_text}) - elif chunk_content_block_type == "tool_use": - chunk_content_block_name = 
_get_attr(chunk_content_block, "name", "") - message["content"].append({"type": "tool_use", "name": chunk_content_block_name, "input": ""}) - return message - - -def _on_content_block_delta_chunk(chunk, message): - # delta events contain new content for the current message.content block - delta_block = _get_attr(chunk, "delta", "") - if delta_block: - chunk_content_text = _get_attr(delta_block, "text", "") - if chunk_content_text: - message["content"][-1]["text"] += chunk_content_text - - chunk_content_json = _get_attr(delta_block, "partial_json", "") - if chunk_content_json and _get_attr(delta_block, "type", "") == "input_json_delta": - # we have a json content block, most likely a tool input dict - message["content"][-1]["input"] += chunk_content_json - return message - - -def _on_content_block_stop_chunk(chunk, message): - # this is the start to a message.content block (possibly 1 of several content blocks) - content_type = _get_attr(message["content"][-1], "type", "") - if content_type == "tool_use": - input_json = _get_attr(message["content"][-1], "input", "{}") - message["content"][-1]["input"] = json.loads(input_json) - return message - - -def _on_message_delta_chunk(chunk, message): - # message delta events signal the end of the message - delta_block = _get_attr(chunk, "delta", "") - chunk_finish_reason = _get_attr(delta_block, "stop_reason", "") - if chunk_finish_reason: - message["finish_reason"] = chunk_finish_reason - - chunk_usage = _get_attr(chunk, "usage", {}) - if chunk_usage: - message_usage = message.get("usage", {"output_tokens": 0, "input_tokens": 0}) - message_usage["output_tokens"] = _get_attr(chunk_usage, "output_tokens", 0) - message["usage"] = message_usage - - return message - - -def _on_error_chunk(chunk, message): - if _get_attr(chunk, "error"): - message["error"] = {} - if _get_attr(chunk.error, "type"): - message["error"]["type"] = chunk.error.type - if _get_attr(chunk.error, "message"): - message["error"]["message"] = 
chunk.error.message - return message - - -def _tag_streamed_chat_completion_response(integration, span, message): - """Tagging logic for streamed chat completions.""" - if message is None: - return - for idx, block in enumerate(message["content"]): - span.set_tag_str(f"anthropic.response.completions.content.{idx}.type", str(block["type"])) - span.set_tag_str("anthropic.response.completions.role", str(message["role"])) - if "text" in block: - span.set_tag_str( - f"anthropic.response.completions.content.{idx}.text", integration.trunc(str(block["text"])) - ) - if block["type"] == "tool_use": - tag_tool_use_output_on_span(integration, span, block, idx) - - if message.get("finish_reason") is not None: - span.set_tag_str("anthropic.response.completions.finish_reason", str(message["finish_reason"])) - - usage = _get_attr(message, "usage", {}) - integration.record_usage(span, usage) - - -def _is_stream(resp: Any) -> bool: - if hasattr(anthropic, "Stream") and isinstance(resp, anthropic.Stream): - return True - return False - - -def _is_async_stream(resp: Any) -> bool: - if hasattr(anthropic, "AsyncStream") and isinstance(resp, anthropic.AsyncStream): - return True - return False - - -def _is_stream_manager(resp: Any) -> bool: - if hasattr(anthropic, "MessageStreamManager") and isinstance(resp, anthropic.MessageStreamManager): - return True - return False - - -def _is_async_stream_manager(resp: Any) -> bool: - if hasattr(anthropic, "AsyncMessageStreamManager") and isinstance(resp, anthropic.AsyncMessageStreamManager): - return True - return False - - -def is_streaming_operation(resp: Any) -> bool: - return _is_stream(resp) or _is_async_stream(resp) or _is_stream_manager(resp) or _is_async_stream_manager(resp) + if name in globals(): + return globals()[name] + raise AttributeError("%s has no attribute %s", __name__, name) diff --git a/ddtrace/contrib/anthropic/patch.py b/ddtrace/contrib/anthropic/patch.py index 65adf89b49a..6944820f065 100644 --- 
a/ddtrace/contrib/anthropic/patch.py +++ b/ddtrace/contrib/anthropic/patch.py @@ -1,228 +1,4 @@ -import os -import sys +from ..internal.anthropic.patch import * # noqa: F401,F403 -import anthropic -from ddtrace import config -from ddtrace.contrib.trace_utils import unwrap -from ddtrace.contrib.trace_utils import with_traced_module -from ddtrace.contrib.trace_utils import wrap -from ddtrace.internal.logger import get_logger -from ddtrace.internal.utils import get_argument_value -from ddtrace.llmobs._integrations import AnthropicIntegration -from ddtrace.llmobs._integrations.anthropic import _get_attr -from ddtrace.pin import Pin - -from ._streaming import handle_streamed_response -from ._streaming import is_streaming_operation -from .utils import _extract_api_key -from .utils import handle_non_streamed_response -from .utils import tag_params_on_span -from .utils import tag_tool_result_input_on_span -from .utils import tag_tool_use_input_on_span - - -log = get_logger(__name__) - - -def get_version(): - # type: () -> str - return getattr(anthropic, "__version__", "") - - -config._add( - "anthropic", - { - "span_prompt_completion_sample_rate": float(os.getenv("DD_ANTHROPIC_SPAN_PROMPT_COMPLETION_SAMPLE_RATE", 1.0)), - "span_char_limit": int(os.getenv("DD_ANTHROPIC_SPAN_CHAR_LIMIT", 128)), - }, -) - - -@with_traced_module -def traced_chat_model_generate(anthropic, pin, func, instance, args, kwargs): - chat_messages = get_argument_value(args, kwargs, 0, "messages") - integration = anthropic._datadog_integration - stream = False - - span = integration.trace( - pin, - "%s.%s" % (instance.__class__.__name__, func.__name__), - submit_to_llmobs=True, - interface_type="chat_model", - provider="anthropic", - model=kwargs.get("model", ""), - api_key=_extract_api_key(instance), - ) - - chat_completions = None - try: - for message_idx, message in enumerate(chat_messages): - if not isinstance(message, dict): - continue - if isinstance(message.get("content", None), str): - if 
integration.is_pc_sampled_span(span): - span.set_tag_str( - "anthropic.request.messages.%d.content.0.text" % (message_idx), - integration.trunc(message.get("content", "")), - ) - span.set_tag_str( - "anthropic.request.messages.%d.content.0.type" % (message_idx), - "text", - ) - elif isinstance(message.get("content", None), list): - for block_idx, block in enumerate(message.get("content", [])): - if integration.is_pc_sampled_span(span): - if _get_attr(block, "type", None) == "text": - span.set_tag_str( - "anthropic.request.messages.%d.content.%d.text" % (message_idx, block_idx), - integration.trunc(str(_get_attr(block, "text", ""))), - ) - elif _get_attr(block, "type", None) == "image": - span.set_tag_str( - "anthropic.request.messages.%d.content.%d.text" % (message_idx, block_idx), - "([IMAGE DETECTED])", - ) - elif _get_attr(block, "type", None) == "tool_use": - tag_tool_use_input_on_span(integration, span, block, message_idx, block_idx) - - elif _get_attr(block, "type", None) == "tool_result": - tag_tool_result_input_on_span(integration, span, block, message_idx, block_idx) - - span.set_tag_str( - "anthropic.request.messages.%d.content.%d.type" % (message_idx, block_idx), - _get_attr(block, "type", "text"), - ) - span.set_tag_str( - "anthropic.request.messages.%d.role" % (message_idx), - message.get("role", ""), - ) - tag_params_on_span(span, kwargs, integration) - - chat_completions = func(*args, **kwargs) - - if is_streaming_operation(chat_completions): - stream = True - return handle_streamed_response(integration, chat_completions, args, kwargs, span) - else: - handle_non_streamed_response(integration, chat_completions, args, kwargs, span) - except Exception: - span.set_exc_info(*sys.exc_info()) - raise - finally: - # we don't want to finish the span if it is a stream as it will get finished once the iterator is exhausted - if span.error or not stream: - if integration.is_pc_sampled_llmobs(span): - integration.llmobs_set_tags(span=span, resp=chat_completions, 
args=args, kwargs=kwargs) - span.finish() - return chat_completions - - -@with_traced_module -async def traced_async_chat_model_generate(anthropic, pin, func, instance, args, kwargs): - chat_messages = get_argument_value(args, kwargs, 0, "messages") - integration = anthropic._datadog_integration - stream = False - - span = integration.trace( - pin, - "%s.%s" % (instance.__class__.__name__, func.__name__), - submit_to_llmobs=True, - interface_type="chat_model", - provider="anthropic", - model=kwargs.get("model", ""), - api_key=_extract_api_key(instance), - ) - - chat_completions = None - try: - for message_idx, message in enumerate(chat_messages): - if not isinstance(message, dict): - continue - if isinstance(message.get("content", None), str): - if integration.is_pc_sampled_span(span): - span.set_tag_str( - "anthropic.request.messages.%d.content.0.text" % (message_idx), - integration.trunc(message.get("content", "")), - ) - span.set_tag_str( - "anthropic.request.messages.%d.content.0.type" % (message_idx), - "text", - ) - elif isinstance(message.get("content", None), list): - for block_idx, block in enumerate(message.get("content", [])): - if integration.is_pc_sampled_span(span): - if _get_attr(block, "type", None) == "text": - span.set_tag_str( - "anthropic.request.messages.%d.content.%d.text" % (message_idx, block_idx), - integration.trunc(str(_get_attr(block, "text", ""))), - ) - elif _get_attr(block, "type", None) == "image": - span.set_tag_str( - "anthropic.request.messages.%d.content.%d.text" % (message_idx, block_idx), - "([IMAGE DETECTED])", - ) - elif _get_attr(block, "type", None) == "tool_use": - tag_tool_use_input_on_span(integration, span, block, message_idx, block_idx) - - elif _get_attr(block, "type", None) == "tool_result": - tag_tool_result_input_on_span(integration, span, block, message_idx, block_idx) - - span.set_tag_str( - "anthropic.request.messages.%d.content.%d.type" % (message_idx, block_idx), - _get_attr(block, "type", "text"), - ) - 
span.set_tag_str( - "anthropic.request.messages.%d.role" % (message_idx), - message.get("role", ""), - ) - tag_params_on_span(span, kwargs, integration) - - chat_completions = await func(*args, **kwargs) - - if is_streaming_operation(chat_completions): - stream = True - return handle_streamed_response(integration, chat_completions, args, kwargs, span) - else: - handle_non_streamed_response(integration, chat_completions, args, kwargs, span) - except Exception: - span.set_exc_info(*sys.exc_info()) - raise - finally: - # we don't want to finish the span if it is a stream as it will get finished once the iterator is exhausted - if span.error or not stream: - if integration.is_pc_sampled_llmobs(span): - integration.llmobs_set_tags(span=span, resp=chat_completions, args=args, kwargs=kwargs) - span.finish() - return chat_completions - - -def patch(): - if getattr(anthropic, "_datadog_patch", False): - return - - anthropic._datadog_patch = True - - Pin().onto(anthropic) - integration = AnthropicIntegration(integration_config=config.anthropic) - anthropic._datadog_integration = integration - - wrap("anthropic", "resources.messages.Messages.create", traced_chat_model_generate(anthropic)) - wrap("anthropic", "resources.messages.Messages.stream", traced_chat_model_generate(anthropic)) - wrap("anthropic", "resources.messages.AsyncMessages.create", traced_async_chat_model_generate(anthropic)) - # AsyncMessages.stream is a sync function - wrap("anthropic", "resources.messages.AsyncMessages.stream", traced_chat_model_generate(anthropic)) - - -def unpatch(): - if not getattr(anthropic, "_datadog_patch", False): - return - - anthropic._datadog_patch = False - - unwrap(anthropic.resources.messages.Messages, "create") - unwrap(anthropic.resources.messages.Messages, "stream") - unwrap(anthropic.resources.messages.AsyncMessages, "create") - unwrap(anthropic.resources.messages.AsyncMessages, "stream") - - delattr(anthropic, "_datadog_integration") +# TODO: deprecate and remove this 
def __getattr__(name):
    """Module-level fallback lookup (PEP 562) for this deprecated shim module.

    Called only when *name* was not found by normal module attribute lookup.
    Emits a deprecation warning, then resolves the name from this module's
    globals (populated by the star-import of the replacement
    ``ddtrace.contrib.internal.anthropic.utils`` module) when possible.
    """
    deprecate(
        ("%s.%s is deprecated" % (__name__, name)),
        category=DDTraceDeprecationWarning,
    )

    if name in globals():
        return globals()[name]
    # Bug fix: AttributeError does not %-format positional arguments the way
    # logging does; build the message string explicitly so users see
    # "<module> has no attribute <name>" instead of a raw tuple repr.
    raise AttributeError("%s has no attribute %s" % (__name__, name))
_default_service=schematize_service_name("redis"), - cmd_max_length=int(os.getenv("DD_AREDIS_CMD_MAX_LENGTH", CMD_MAX_LEN)), - resource_only_command=asbool(os.getenv("DD_REDIS_RESOURCE_ONLY_COMMAND", True)), - ), -) - - -def get_version(): - # type: () -> str - return getattr(aredis, "__version__", "") - - -def patch(): - """Patch the instrumented methods""" - if getattr(aredis, "_datadog_patch", False): - return - aredis._datadog_patch = True - - _w = wrapt.wrap_function_wrapper - - _w("aredis.client", "StrictRedis.execute_command", traced_execute_command) - _w("aredis.client", "StrictRedis.pipeline", traced_pipeline) - _w("aredis.pipeline", "StrictPipeline.execute", traced_execute_pipeline) - _w("aredis.pipeline", "StrictPipeline.immediate_execute_command", traced_execute_command) - Pin(service=None).onto(aredis.StrictRedis) - - -def unpatch(): - if getattr(aredis, "_datadog_patch", False): - aredis._datadog_patch = False - - unwrap(aredis.client.StrictRedis, "execute_command") - unwrap(aredis.client.StrictRedis, "pipeline") - unwrap(aredis.pipeline.StrictPipeline, "execute") - unwrap(aredis.pipeline.StrictPipeline, "immediate_execute_command") - - -# -# tracing functions -# -async def traced_execute_command(func, instance, args, kwargs): - pin = Pin.get_from(instance) - if not pin or not pin.enabled(): - return await func(*args, **kwargs) - - with _instrument_redis_cmd(pin, config.aredis, instance, args) as ctx: - return await _run_redis_command_async(ctx=ctx, func=func, args=args, kwargs=kwargs) - - -async def traced_pipeline(func, instance, args, kwargs): - pipeline = await func(*args, **kwargs) - pin = Pin.get_from(instance) - if pin: - pin.onto(pipeline) - return pipeline - - -async def traced_execute_pipeline(func, instance, args, kwargs): - pin = Pin.get_from(instance) - if not pin or not pin.enabled(): - return await func(*args, **kwargs) - - cmds = [stringify_cache_args(c, cmd_max_len=config.aredis.cmd_max_length) for c, _ in instance.command_stack] - 
with _instrument_redis_execute_pipeline(pin, config.aredis, cmds, instance): - return await func(*args, **kwargs) +# TODO: deprecate and remove this module diff --git a/ddtrace/contrib/asgi/__init__.py b/ddtrace/contrib/asgi/__init__.py index da09bc52603..1e8abe6a538 100644 --- a/ddtrace/contrib/asgi/__init__.py +++ b/ddtrace/contrib/asgi/__init__.py @@ -62,8 +62,11 @@ def handle_request(scope, send): with require_modules(required_modules) as missing_modules: if not missing_modules: - from .middleware import TraceMiddleware - from .middleware import get_version - from .middleware import span_from_scope + # Required to allow users to import from `ddtrace.contrib.asgi.patch` directly + from . import middleware as _ # noqa: F401, I001 + + from ..internal.asgi.middleware import TraceMiddleware + from ..internal.asgi.middleware import get_version + from ..internal.asgi.middleware import span_from_scope __all__ = ["TraceMiddleware", "span_from_scope", "get_version"] diff --git a/ddtrace/contrib/asgi/middleware.py b/ddtrace/contrib/asgi/middleware.py index 21061cf63fe..56869093458 100644 --- a/ddtrace/contrib/asgi/middleware.py +++ b/ddtrace/contrib/asgi/middleware.py @@ -1,304 +1,15 @@ -import os -import sys -from typing import Any -from typing import Mapping -from typing import Optional -from urllib import parse +from ddtrace.internal.utils.deprecations import DDTraceDeprecationWarning +from ddtrace.vendor.debtcollector import deprecate -import ddtrace -from ddtrace import config -from ddtrace._trace.span import Span -from ddtrace.constants import ANALYTICS_SAMPLE_RATE_KEY -from ddtrace.constants import SPAN_KIND -from ddtrace.ext import SpanKind -from ddtrace.ext import SpanTypes -from ddtrace.ext import http -from ddtrace.internal._exceptions import BlockingException -from ddtrace.internal.compat import is_valid_ip -from ddtrace.internal.constants import COMPONENT -from ddtrace.internal.constants import HTTP_REQUEST_BLOCKED -from ddtrace.internal.schema import 
def __getattr__(name):
    """Module-level fallback lookup (PEP 562) for this deprecated shim module.

    Called only when *name* was not found by normal module attribute lookup.
    Emits a deprecation warning, then resolves the name from this module's
    globals (populated by the star-import of the replacement
    ``ddtrace.contrib.internal.asgi.middleware`` module) when possible.
    """
    deprecate(
        ("%s.%s is deprecated" % (__name__, name)),
        category=DDTraceDeprecationWarning,
    )

    if name in globals():
        return globals()[name]
    # Bug fix: AttributeError does not %-format positional arguments; build
    # the message explicitly instead of passing (fmt, arg1, arg2).
    raise AttributeError("%s has no attribute %s" % (__name__, name))
def __getattr__(name):
    """Module-level fallback lookup (PEP 562) for this deprecated shim module.

    Called only when *name* was not found by normal module attribute lookup.
    Emits a deprecation warning, then resolves the name from this module's
    globals (populated by the star-import of the replacement
    ``ddtrace.contrib.internal.asgi.utils`` module) when possible.
    """
    deprecate(
        ("%s.%s is deprecated" % (__name__, name)),
        category=DDTraceDeprecationWarning,
    )

    if name in globals():
        return globals()[name]
    # Bug fix: AttributeError does not %-format positional arguments; build
    # the message explicitly instead of passing (fmt, arg1, arg2).
    raise AttributeError("%s has no attribute %s" % (__name__, name))
# Version of the installed aioredis package; drives the v1.3 vs v2.x code
# paths below (the two major versions expose completely different APIs).
aioredis_version_str = getattr(aioredis, "__version__", "")
aioredis_version = parse_version(aioredis_version_str)
V2 = parse_version("2.0")


def get_version():
    # type: () -> str
    """Return the detected aioredis package version string ("" if unknown)."""
    return aioredis_version_str


def patch():
    """Instrument aioredis by wrapping its command/pipeline entry points.

    Idempotent: the ``_datadog_patch`` flag prevents double-wrapping. The
    wrapped targets differ by major version because aioredis 2.x moved to a
    redis-py-like client API.
    """
    if getattr(aioredis, "_datadog_patch", False):
        return
    aioredis._datadog_patch = True
    pin = Pin()
    if aioredis_version >= V2:
        _w("aioredis.client", "Redis.execute_command", traced_execute_command)
        _w("aioredis.client", "Redis.pipeline", traced_pipeline)
        _w("aioredis.client", "Pipeline.execute", traced_execute_pipeline)
        pin.onto(aioredis.client.Redis)
    else:
        _w("aioredis", "Redis.execute", traced_13_execute_command)
        _w("aioredis", "Redis.pipeline", traced_13_pipeline)
        _w("aioredis.commands.transaction", "Pipeline.execute", traced_13_execute_pipeline)
        pin.onto(aioredis.Redis)


def unpatch():
    """Undo :func:`patch`, restoring the original aioredis methods."""
    if not getattr(aioredis, "_datadog_patch", False):
        return

    aioredis._datadog_patch = False
    if aioredis_version >= V2:
        _u(aioredis.client.Redis, "execute_command")
        _u(aioredis.client.Redis, "pipeline")
        _u(aioredis.client.Pipeline, "execute")
    else:
        _u(aioredis.Redis, "execute")
        _u(aioredis.Redis, "pipeline")
        _u(aioredis.commands.transaction.Pipeline, "execute")


async def traced_execute_command(func, instance, args, kwargs):
    """Trace a single aioredis 2.x command (wrapt wrapper for execute_command)."""
    pin = Pin.get_from(instance)
    # No pin / disabled pin means tracing is off for this client: pass through.
    if not pin or not pin.enabled():
        return await func(*args, **kwargs)

    with _instrument_redis_cmd(pin, config.aioredis, instance, args) as ctx:
        return await _run_redis_command_async(ctx=ctx, func=func, args=args, kwargs=kwargs)


def traced_pipeline(func, instance, args, kwargs):
    """Propagate the client's Pin onto the pipeline object it creates (v2.x)."""
    pipeline = func(*args, **kwargs)
    pin = Pin.get_from(instance)
    if pin:
        pin.onto(pipeline)
    return pipeline


async def traced_execute_pipeline(func, instance, args, kwargs):
    """Trace execution of an aioredis 2.x pipeline as one span."""
    pin = Pin.get_from(instance)
    if not pin or not pin.enabled():
        return await func(*args, **kwargs)

    # Stringify each queued command (truncated to the configured max length)
    # so the whole pipeline can be reported on a single span.
    cmds = [stringify_cache_args(c, cmd_max_len=config.aioredis.cmd_max_length) for c, _ in instance.command_stack]
    with _instrument_redis_execute_pipeline(pin, config.aioredis, cmds, instance):
        return await func(*args, **kwargs)


def traced_13_pipeline(func, instance, args, kwargs):
    """Propagate the client's Pin onto the pipeline object it creates (v1.3)."""
    pipeline = func(*args, **kwargs)
    pin = Pin.get_from(instance)
    if pin:
        pin.onto(pipeline)
    return pipeline


def traced_13_execute_command(func, instance, args, kwargs):
    """Trace a single aioredis 1.3 command.

    aioredis 1.3 returns a future rather than awaiting the command, so the
    span is started unactivated and finished from a done-callback.
    """
    # If we have a _RedisBuffer then we are in a pipeline
    if isinstance(instance.connection, _RedisBuffer):
        return func(*args, **kwargs)

    pin = Pin.get_from(instance)
    if not pin or not pin.enabled():
        return func(*args, **kwargs)

    # Don't activate the span since this operation is performed as a future which concludes sometime later on in
    # execution so subsequent operations in the stack are not necessarily semantically related
    # (we don't want this span to be the parent of all other spans created before the future is resolved)
    parent = pin.tracer.current_span()
    query = stringify_cache_args(args, cmd_max_len=config.aioredis.cmd_max_length)
    span = pin.tracer.start_span(
        schematize_cache_operation(redisx.CMD, cache_provider="redis"),
        service=trace_utils.ext_service(pin, config.aioredis),
        resource=query.split(" ")[0] if config.aioredis.resource_only_command else query,
        span_type=SpanTypes.REDIS,
        activate=False,
        child_of=parent,
    )
    # set span.kind to the type of request being performed
    span.set_tag_str(SPAN_KIND, SpanKind.CLIENT)

    span.set_tag_str(COMPONENT, config.aioredis.integration_name)
    span.set_tag_str(db.SYSTEM, redisx.APP)
    span.set_tag(SPAN_MEASURED_KEY)
    span.set_tag_str(redisx.RAWCMD, query)
    if pin.tags:
        span.set_tags(pin.tags)

    span.set_tags(
        {
            net.TARGET_HOST: instance.address[0],
            net.TARGET_PORT: instance.address[1],
            redisx.DB: instance.db or 0,
        }
    )
    span.set_metric(redisx.ARGS_LEN, len(args))
    # set analytics sample rate if enabled
    span.set_tag(ANALYTICS_SAMPLE_RATE_KEY, config.aioredis.get_analytics_sample_rate())

    def _finish_span(future):
        # Done-callback: record row count / error and close the span.
        try:
            # Accessing the result will raise an exception if:
            #   - The future was cancelled (CancelledError)
            #   - There was an error executing the future (`future.exception()`)
            #   - The future is in an invalid state
            redis_command = span.resource.split(" ")[0]
            future.result()
            if redis_command in ROW_RETURNING_COMMANDS:
                span.set_metric(db.ROWCOUNT, determine_row_count(redis_command=redis_command, result=future.result()))
        # CancelledError exceptions extend from BaseException as of Python 3.8, instead of usual Exception
        except (Exception, aioredis.CancelledError):
            span.set_exc_info(*sys.exc_info())
            if redis_command in ROW_RETURNING_COMMANDS:
                span.set_metric(db.ROWCOUNT, 0)
        finally:
            span.finish()

    task = func(*args, **kwargs)
    # Execute command returns a coroutine when no free connections are available
    # https://github.com/aio-libs/aioredis-py/blob/v1.3.1/aioredis/pool.py#L191
    task = asyncio.ensure_future(task)
    task.add_done_callback(_finish_span)
    return task


async def traced_13_execute_pipeline(func, instance, args, kwargs):
    """Trace execution of an aioredis 1.3 transaction pipeline as one span."""
    pin = Pin.get_from(instance)
    if not pin or not pin.enabled():
        return await func(*args, **kwargs)

    # Rebuild a printable "<cmd> <args...>" string per queued command.
    cmds = []
    for _, cmd, cmd_args, _ in instance._pipeline:
        parts = [cmd]
        parts.extend(cmd_args)
        cmds.append(stringify_cache_args(parts, cmd_max_len=config.aioredis.cmd_max_length))

    resource = cmds_string = "\n".join(cmds)
    if config.aioredis.resource_only_command:
        # Resource keeps only the command names; the full text still goes on RAWCMD.
        resource = "\n".join([cmd.split(" ")[0] for cmd in cmds])

    with pin.tracer.trace(
        schematize_cache_operation(redisx.CMD, cache_provider="redis"),
        resource=resource,
        service=trace_utils.ext_service(pin, config.aioredis),
        span_type=SpanTypes.REDIS,
    ) as span:
        # set span.kind to the type of request being performed
        span.set_tag_str(SPAN_KIND, SpanKind.CLIENT)

        span.set_tag_str(COMPONENT, config.aioredis.integration_name)
        span.set_tag_str(db.SYSTEM, redisx.APP)
        span.set_tags(
            {
                net.TARGET_HOST: instance._pool_or_conn.address[0],
                net.TARGET_PORT: instance._pool_or_conn.address[1],
                redisx.DB: instance._pool_or_conn.db or 0,
            }
        )

        span.set_tag(SPAN_MEASURED_KEY)
        span.set_tag_str(redisx.RAWCMD, cmds_string)
        span.set_metric(redisx.PIPELINE_LEN, len(instance._pipeline))
        # set analytics sample rate if enabled
        span.set_tag(ANALYTICS_SAMPLE_RATE_KEY, config.aioredis.get_analytics_sample_rate())

        return await func(*args, **kwargs)
try:
    VERSION = "0.0.0"
    import algoliasearch
    from algoliasearch.version import VERSION

    algoliasearch_version = parse_version(VERSION)

    # Default configuration
    config._add("algoliasearch", dict(_default_service=SERVICE_NAME, collect_query_text=False))
except ImportError:
    algoliasearch_version = V0


def get_version():
    # type: () -> str
    """Return the installed algoliasearch client version ("0.0.0" when absent)."""
    return VERSION


def patch():
    """Install a tracing wrapper on the algoliasearch client's search entry point.

    Only v1 and v2 of the client are supported; any other version is a no-op.
    """
    if algoliasearch_version == V0:
        return

    if getattr(algoliasearch, DD_PATCH_ATTR, False):
        return

    algoliasearch._datadog_patch = True

    pin = Pin()

    if algoliasearch_version < V2 and algoliasearch_version >= V1:
        _w(algoliasearch.index, "Index.search", _patched_search)
        pin.onto(algoliasearch.index.Index)
    elif algoliasearch_version >= V2 and algoliasearch_version < V3:
        from algoliasearch import search_index

        _w(algoliasearch, "search_index.SearchIndex.search", _patched_search)
        pin.onto(search_index.SearchIndex)
    else:
        return


def unpatch():
    """Remove the tracing wrapper installed by :func:`patch`, if present."""
    if algoliasearch_version == V0:
        return

    if getattr(algoliasearch, DD_PATCH_ATTR, False):
        setattr(algoliasearch, DD_PATCH_ATTR, False)

        if algoliasearch_version < V2 and algoliasearch_version >= V1:
            _u(algoliasearch.index.Index, "search")
        elif algoliasearch_version >= V2 and algoliasearch_version < V3:
            from algoliasearch import search_index

            _u(search_index.SearchIndex, "search")
        else:
            return


# DEV: this map serves the dual purpose of enumerating the algoliasearch.search() query_args that
# will be sent along as tags, as well as converting argument names into tag names compliant with
# tag naming recommendations set out here: https://docs.datadoghq.com/tagging/
QUERY_ARGS_DD_TAG_MAP = {
    "page": "page",
    "hitsPerPage": "hits_per_page",
    "attributesToRetrieve": "attributes_to_retrieve",
    "attributesToHighlight": "attributes_to_highlight",
    "attributesToSnippet": "attributes_to_snippet",
    "minWordSizefor1Typo": "min_word_size_for_1_typo",
    "minWordSizefor2Typos": "min_word_size_for_2_typos",
    "getRankingInfo": "get_ranking_info",
    "aroundLatLng": "around_lat_lng",
    "numericFilters": "numeric_filters",
    "tagFilters": "tag_filters",
    "queryType": "query_type",
    "optionalWords": "optional_words",
    "distinct": "distinct",
}


def _patched_search(func, instance, wrapt_args, wrapt_kwargs):
    """
    Trace a single ``search()`` call.

    wrapt_args is called the way it is to distinguish it from the 'args'
    argument to the algoliasearch.index.Index.search() method.
    """

    if algoliasearch_version < V2 and algoliasearch_version >= V1:
        function_query_arg_name = "args"
    elif algoliasearch_version >= V2 and algoliasearch_version < V3:
        function_query_arg_name = "request_options"
    else:
        return func(*wrapt_args, **wrapt_kwargs)

    pin = Pin.get_from(instance)
    if not pin or not pin.enabled():
        return func(*wrapt_args, **wrapt_kwargs)

    with pin.tracer.trace(
        schematize_cloud_api_operation("algoliasearch.search", cloud_provider="algoliasearch", cloud_service="search"),
        service=trace_utils.ext_service(pin, config.algoliasearch),
        span_type=SpanTypes.HTTP,
    ) as span:
        span.set_tag_str(COMPONENT, config.algoliasearch.integration_name)

        # set span.kind to the type of request being performed
        span.set_tag_str(SPAN_KIND, SpanKind.CLIENT)

        span.set_tag(SPAN_MEASURED_KEY)
        if span.context.sampling_priority is not None and span.context.sampling_priority <= 0:
            return func(*wrapt_args, **wrapt_kwargs)

        if config.algoliasearch.collect_query_text:
            # BUGFIX: previously `wrapt_kwargs.get("query", wrapt_args[0])`
            # evaluated the positional fallback eagerly, raising IndexError
            # whenever the query was passed as a keyword argument only.
            if "query" in wrapt_kwargs:
                span.set_tag_str("query.text", wrapt_kwargs["query"])
            elif wrapt_args:
                span.set_tag_str("query.text", wrapt_args[0])

        query_args = wrapt_kwargs.get(function_query_arg_name, wrapt_args[1] if len(wrapt_args) > 1 else None)

        if query_args and isinstance(query_args, dict):
            for query_arg, tag_name in QUERY_ARGS_DD_TAG_MAP.items():
                value = query_args.get(query_arg)
                if value is not None:
                    span.set_tag("query.args.{}".format(tag_name), value)

        # Result would look like this
        # {
        #     'hits': [{ ... your search results ... }],
        #     'processingTimeMS': 1,
        #     'nbHits': 1,
        #     'hitsPerPage': 20,
        #     'exhaustiveNbHits': true,
        #     'params': 'query=xxx',
        #     'nbPages': 1,
        #     'query': 'xxx',
        #     'page': 0
        # }
        result = func(*wrapt_args, **wrapt_kwargs)

        if isinstance(result, dict):
            if result.get("processingTimeMS", None) is not None:
                span.set_metric("processing_time_ms", int(result["processingTimeMS"]))

            if result.get("nbHits", None) is not None:
                span.set_metric("number_of_hits", int(result["nbHits"]))

        return result
class TracedAnthropicStream(BaseTracedAnthropicStream):
    """Sync stream proxy: records every chunk and finalizes the span on exhaustion."""

    def __init__(self, wrapped, integration, span, args, kwargs):
        super().__init__(wrapped, integration, span, args, kwargs)
        # we need to set a text_stream attribute so we can trace the yielded chunks
        self.text_stream = self.__stream_text__()

    def __enter__(self):
        self.__wrapped__.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # BUGFIX: propagate the wrapped stream's return value — a truthy result
        # tells the interpreter to suppress the in-flight exception. Previously
        # the value was discarded, so suppression never happened.
        return self.__wrapped__.__exit__(exc_type, exc_val, exc_tb)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            chunk = self.__wrapped__.__next__()
            self._streamed_chunks.append(chunk)
            return chunk
        except StopIteration:
            # Stream exhausted normally: tag the assembled response and close the span.
            _process_finished_stream(
                self._dd_integration, self._dd_span, self._args, self._kwargs, self._streamed_chunks
            )
            self._dd_span.finish()
            raise
        except Exception:
            self._dd_span.set_exc_info(*sys.exc_info())
            self._dd_span.finish()
            raise

    def __stream_text__(self):
        # this is overridden because it is a helper function that collects all stream content chunks
        for chunk in self:
            if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta":
                yield chunk.delta.text


class TracedAnthropicAsyncStream(BaseTracedAnthropicStream):
    """Async counterpart of :class:`TracedAnthropicStream`."""

    def __init__(self, wrapped, integration, span, args, kwargs):
        super().__init__(wrapped, integration, span, args, kwargs)
        # we need to set a text_stream attribute so we can trace the yielded chunks
        self.text_stream = self.__stream_text__()

    async def __aenter__(self):
        await self.__wrapped__.__aenter__()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # BUGFIX: propagate the wrapped stream's return value (see __exit__ above).
        return await self.__wrapped__.__aexit__(exc_type, exc_val, exc_tb)

    def __aiter__(self):
        return self

    async def __anext__(self):
        try:
            chunk = await self.__wrapped__.__anext__()
            self._streamed_chunks.append(chunk)
            return chunk
        except StopAsyncIteration:
            _process_finished_stream(
                self._dd_integration,
                self._dd_span,
                self._args,
                self._kwargs,
                self._streamed_chunks,
            )
            self._dd_span.finish()
            raise
        except Exception:
            self._dd_span.set_exc_info(*sys.exc_info())
            self._dd_span.finish()
            raise

    async def __stream_text__(self):
        # this is overridden because it is a helper function that collects all stream content chunks
        async for chunk in self:
            if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta":
                yield chunk.delta.text


class TracedAnthropicStreamManager(BaseTracedAnthropicStream):
    """Wraps MessageStreamManager so entering it yields a traced stream."""

    def __enter__(self):
        stream = self.__wrapped__.__enter__()
        return TracedAnthropicStream(
            stream,
            self._dd_integration,
            self._dd_span,
            self._args,
            self._kwargs,
        )

    def __exit__(self, exc_type, exc_val, exc_tb):
        # BUGFIX: propagate the wrapped manager's return value (exception suppression).
        return self.__wrapped__.__exit__(exc_type, exc_val, exc_tb)


class TracedAnthropicAsyncStreamManager(BaseTracedAnthropicStream):
    """Wraps AsyncMessageStreamManager so entering it yields a traced async stream."""

    async def __aenter__(self):
        stream = await self.__wrapped__.__aenter__()
        return TracedAnthropicAsyncStream(
            stream,
            self._dd_integration,
            self._dd_span,
            self._args,
            self._kwargs,
        )

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # BUGFIX: propagate the wrapped manager's return value (exception suppression).
        return await self.__wrapped__.__aexit__(exc_type, exc_val, exc_tb)


def _process_finished_stream(integration, span, args, kwargs, streamed_chunks):
    # builds the response message given streamed chunks and sets according span tags
    try:
        resp_message = _construct_message(streamed_chunks)

        if integration.is_pc_sampled_span(span):
            _tag_streamed_chat_completion_response(integration, span, resp_message)
        if integration.is_pc_sampled_llmobs(span):
            integration.llmobs_set_tags(
                span=span,
                resp=resp_message,
                args=args,
                kwargs=kwargs,
            )
    except Exception:
        # Best-effort: a malformed stream must never break the traced application.
        log.warning("Error processing streamed completion/chat response.", exc_info=True)
def _construct_message(streamed_chunks):
    """Iteratively build up a response message from streamed chunks.

    The resulting message dictionary is of form:
      {"content": [{"type": [TYPE], "text": "[TEXT]"}], "role": "...", "finish_reason": "...", "usage": ...}
    """
    message = {"content": []}
    for chunk in streamed_chunks:
        message = _extract_from_chunk(chunk, message)
    return message


def _extract_from_chunk(chunk, message) -> Dict[str, Any]:
    """Dispatch *chunk* to the handler for its event type; return the updated message.

    BUGFIX: the return annotation previously claimed ``Tuple[Dict[str, str], bool]``
    but the function has always returned only the message dict.
    """
    TRANSFORMATIONS_BY_BLOCK_TYPE = {
        "message_start": _on_message_start_chunk,
        "content_block_start": _on_content_block_start_chunk,
        "content_block_delta": _on_content_block_delta_chunk,
        "content_block_stop": _on_content_block_stop_chunk,
        "message_delta": _on_message_delta_chunk,
        "error": _on_error_chunk,
    }
    chunk_type = _get_attr(chunk, "type", "")
    transformation = TRANSFORMATIONS_BY_BLOCK_TYPE.get(chunk_type)
    if transformation is not None:
        message = transformation(chunk, message)

    return message


def _on_message_start_chunk(chunk, message):
    # this is the starting chunk of the message
    chunk_message = _get_attr(chunk, "message", "")
    if chunk_message:
        chunk_role = _get_attr(chunk_message, "role", "")
        chunk_usage = _get_attr(chunk_message, "usage", "")
        if chunk_role:
            message["role"] = chunk_role
        if chunk_usage:
            message["usage"] = {"input_tokens": _get_attr(chunk_usage, "input_tokens", 0)}
    return message


def _on_content_block_start_chunk(chunk, message):
    # this is the start to a message.content block (possibly 1 of several content blocks)
    chunk_content_block = _get_attr(chunk, "content_block", "")
    if chunk_content_block:
        chunk_content_block_type = _get_attr(chunk_content_block, "type", "")
        if chunk_content_block_type == "text":
            chunk_content_block_text = _get_attr(chunk_content_block, "text", "")
            message["content"].append({"type": "text", "text": chunk_content_block_text})
        elif chunk_content_block_type == "tool_use":
            chunk_content_block_name = _get_attr(chunk_content_block, "name", "")
            message["content"].append({"type": "tool_use", "name": chunk_content_block_name, "input": ""})
    return message


def _on_content_block_delta_chunk(chunk, message):
    # delta events contain new content for the current message.content block
    # NOTE(review): assumes a content_block_start chunk preceded this one so
    # message["content"] is non-empty — confirm against the Anthropic event order.
    delta_block = _get_attr(chunk, "delta", "")
    if delta_block:
        chunk_content_text = _get_attr(delta_block, "text", "")
        if chunk_content_text:
            message["content"][-1]["text"] += chunk_content_text

        chunk_content_json = _get_attr(delta_block, "partial_json", "")
        if chunk_content_json and _get_attr(delta_block, "type", "") == "input_json_delta":
            # we have a json content block, most likely a tool input dict
            message["content"][-1]["input"] += chunk_content_json
    return message


def _on_content_block_stop_chunk(chunk, message):
    # BUGFIX(comment): this marks the END of a message.content block — for a
    # tool_use block, parse the accumulated partial-JSON input into a dict.
    content_type = _get_attr(message["content"][-1], "type", "")
    if content_type == "tool_use":
        input_json = _get_attr(message["content"][-1], "input", "{}")
        message["content"][-1]["input"] = json.loads(input_json)
    return message


def _on_message_delta_chunk(chunk, message):
    # message delta events signal the end of the message
    delta_block = _get_attr(chunk, "delta", "")
    chunk_finish_reason = _get_attr(delta_block, "stop_reason", "")
    if chunk_finish_reason:
        message["finish_reason"] = chunk_finish_reason

    chunk_usage = _get_attr(chunk, "usage", {})
    if chunk_usage:
        message_usage = message.get("usage", {"output_tokens": 0, "input_tokens": 0})
        message_usage["output_tokens"] = _get_attr(chunk_usage, "output_tokens", 0)
        message["usage"] = message_usage

    return message


def _on_error_chunk(chunk, message):
    # error events carry a type/message pair describing the stream failure
    if _get_attr(chunk, "error"):
        message["error"] = {}
        if _get_attr(chunk.error, "type"):
            message["error"]["type"] = chunk.error.type
        if _get_attr(chunk.error, "message"):
            message["error"]["message"] = chunk.error.message
    return message
def _tag_streamed_chat_completion_response(integration, span, message):
    """Tagging logic for streamed chat completions."""
    if message is None:
        return
    for idx, block in enumerate(message["content"]):
        span.set_tag_str(f"anthropic.response.completions.content.{idx}.type", str(block["type"]))
        # NOTE: the role tag is (re)set once per content block, matching the
        # original behavior (no role tag is emitted for an empty content list).
        span.set_tag_str("anthropic.response.completions.role", str(message["role"]))
        if "text" in block:
            span.set_tag_str(
                f"anthropic.response.completions.content.{idx}.text", integration.trunc(str(block["text"]))
            )
        if block["type"] == "tool_use":
            tag_tool_use_output_on_span(integration, span, block, idx)

    if message.get("finish_reason") is not None:
        span.set_tag_str("anthropic.response.completions.finish_reason", str(message["finish_reason"]))

    integration.record_usage(span, _get_attr(message, "usage", {}))


def _anthropic_class(name):
    """Return the named class from the anthropic module, or None if this client version lacks it."""
    return getattr(anthropic, name, None)


def _is_stream(resp: Any) -> bool:
    klass = _anthropic_class("Stream")
    return klass is not None and isinstance(resp, klass)


def _is_async_stream(resp: Any) -> bool:
    klass = _anthropic_class("AsyncStream")
    return klass is not None and isinstance(resp, klass)


def _is_stream_manager(resp: Any) -> bool:
    klass = _anthropic_class("MessageStreamManager")
    return klass is not None and isinstance(resp, klass)


def _is_async_stream_manager(resp: Any) -> bool:
    klass = _anthropic_class("AsyncMessageStreamManager")
    return klass is not None and isinstance(resp, klass)


def is_streaming_operation(resp: Any) -> bool:
    """True when *resp* is any of anthropic's stream or stream-manager types."""
    return _is_stream(resp) or _is_async_stream(resp) or _is_stream_manager(resp) or _is_async_stream_manager(resp)
from ddtrace import config
from ddtrace.contrib.internal.anthropic._streaming import handle_streamed_response
from ddtrace.contrib.internal.anthropic._streaming import is_streaming_operation
from ddtrace.contrib.internal.anthropic.utils import _extract_api_key
from ddtrace.contrib.internal.anthropic.utils import handle_non_streamed_response
from ddtrace.contrib.internal.anthropic.utils import tag_params_on_span
from ddtrace.contrib.internal.anthropic.utils import tag_tool_result_input_on_span
from ddtrace.contrib.internal.anthropic.utils import tag_tool_use_input_on_span
from ddtrace.contrib.trace_utils import unwrap
from ddtrace.contrib.trace_utils import with_traced_module
from ddtrace.contrib.trace_utils import wrap
from ddtrace.internal.logger import get_logger
from ddtrace.internal.utils import get_argument_value
from ddtrace.llmobs._integrations import AnthropicIntegration
from ddtrace.llmobs._integrations.anthropic import _get_attr
from ddtrace.pin import Pin


log = get_logger(__name__)


def get_version():
    # type: () -> str
    return getattr(anthropic, "__version__", "")


config._add(
    "anthropic",
    {
        "span_prompt_completion_sample_rate": float(os.getenv("DD_ANTHROPIC_SPAN_PROMPT_COMPLETION_SAMPLE_RATE", 1.0)),
        "span_char_limit": int(os.getenv("DD_ANTHROPIC_SPAN_CHAR_LIMIT", 128)),
    },
)


def _tag_request_messages(integration, span, chat_messages):
    """Tag the request's chat messages onto *span*.

    Extracted helper: this logic was previously duplicated verbatim in the
    sync and async traced functions. Message text/images are only tagged when
    the prompt-completion sampler selects the span; structural tags (content
    type, role) are always set.

    NOTE(review): the exact scope of the sampling gate around the per-block
    ``.type`` tag was ambiguous in the source being reconstructed — confirm
    against upstream that it is set unconditionally.
    """
    for message_idx, message in enumerate(chat_messages):
        if not isinstance(message, dict):
            continue
        content = message.get("content", None)
        if isinstance(content, str):
            if integration.is_pc_sampled_span(span):
                span.set_tag_str(
                    "anthropic.request.messages.%d.content.0.text" % (message_idx),
                    integration.trunc(message.get("content", "")),
                )
            span.set_tag_str(
                "anthropic.request.messages.%d.content.0.type" % (message_idx),
                "text",
            )
        elif isinstance(content, list):
            for block_idx, block in enumerate(content):
                if integration.is_pc_sampled_span(span):
                    block_type = _get_attr(block, "type", None)
                    if block_type == "text":
                        span.set_tag_str(
                            "anthropic.request.messages.%d.content.%d.text" % (message_idx, block_idx),
                            integration.trunc(str(_get_attr(block, "text", ""))),
                        )
                    elif block_type == "image":
                        span.set_tag_str(
                            "anthropic.request.messages.%d.content.%d.text" % (message_idx, block_idx),
                            "([IMAGE DETECTED])",
                        )
                    elif block_type == "tool_use":
                        tag_tool_use_input_on_span(integration, span, block, message_idx, block_idx)
                    elif block_type == "tool_result":
                        tag_tool_result_input_on_span(integration, span, block, message_idx, block_idx)
                span.set_tag_str(
                    "anthropic.request.messages.%d.content.%d.type" % (message_idx, block_idx),
                    _get_attr(block, "type", "text"),
                )
        span.set_tag_str(
            "anthropic.request.messages.%d.role" % (message_idx),
            message.get("role", ""),
        )


@with_traced_module
def traced_chat_model_generate(anthropic, pin, func, instance, args, kwargs):
    """Trace sync Messages.create / Messages.stream calls."""
    chat_messages = get_argument_value(args, kwargs, 0, "messages")
    integration = anthropic._datadog_integration
    stream = False

    span = integration.trace(
        pin,
        "%s.%s" % (instance.__class__.__name__, func.__name__),
        submit_to_llmobs=True,
        interface_type="chat_model",
        provider="anthropic",
        model=kwargs.get("model", ""),
        api_key=_extract_api_key(instance),
    )

    chat_completions = None
    try:
        _tag_request_messages(integration, span, chat_messages)
        tag_params_on_span(span, kwargs, integration)

        chat_completions = func(*args, **kwargs)

        if is_streaming_operation(chat_completions):
            stream = True
            return handle_streamed_response(integration, chat_completions, args, kwargs, span)
        else:
            handle_non_streamed_response(integration, chat_completions, args, kwargs, span)
    except Exception:
        span.set_exc_info(*sys.exc_info())
        raise
    finally:
        # we don't want to finish the span if it is a stream as it will get finished once the iterator is exhausted
        if span.error or not stream:
            if integration.is_pc_sampled_llmobs(span):
                integration.llmobs_set_tags(span=span, resp=chat_completions, args=args, kwargs=kwargs)
            span.finish()
    return chat_completions


@with_traced_module
async def traced_async_chat_model_generate(anthropic, pin, func, instance, args, kwargs):
    """Trace async AsyncMessages.create calls (mirror of the sync wrapper)."""
    chat_messages = get_argument_value(args, kwargs, 0, "messages")
    integration = anthropic._datadog_integration
    stream = False

    span = integration.trace(
        pin,
        "%s.%s" % (instance.__class__.__name__, func.__name__),
        submit_to_llmobs=True,
        interface_type="chat_model",
        provider="anthropic",
        model=kwargs.get("model", ""),
        api_key=_extract_api_key(instance),
    )

    chat_completions = None
    try:
        _tag_request_messages(integration, span, chat_messages)
        tag_params_on_span(span, kwargs, integration)

        chat_completions = await func(*args, **kwargs)

        if is_streaming_operation(chat_completions):
            stream = True
            return handle_streamed_response(integration, chat_completions, args, kwargs, span)
        else:
            handle_non_streamed_response(integration, chat_completions, args, kwargs, span)
    except Exception:
        span.set_exc_info(*sys.exc_info())
        raise
    finally:
        # we don't want to finish the span if it is a stream as it will get finished once the iterator is exhausted
        if span.error or not stream:
            if integration.is_pc_sampled_llmobs(span):
                integration.llmobs_set_tags(span=span, resp=chat_completions, args=args, kwargs=kwargs)
            span.finish()
    return chat_completions


def patch():
    """Wrap the anthropic Messages entry points with tracing."""
    if getattr(anthropic, "_datadog_patch", False):
        return

    anthropic._datadog_patch = True

    Pin().onto(anthropic)
    anthropic._datadog_integration = AnthropicIntegration(integration_config=config.anthropic)

    wrap("anthropic", "resources.messages.Messages.create", traced_chat_model_generate(anthropic))
    wrap("anthropic", "resources.messages.Messages.stream", traced_chat_model_generate(anthropic))
    wrap("anthropic", "resources.messages.AsyncMessages.create", traced_async_chat_model_generate(anthropic))
    # AsyncMessages.stream is a sync function
    wrap("anthropic", "resources.messages.AsyncMessages.stream", traced_chat_model_generate(anthropic))


def unpatch():
    """Undo :func:`patch` and drop the integration object."""
    if not getattr(anthropic, "_datadog_patch", False):
        return

    anthropic._datadog_patch = False

    unwrap(anthropic.resources.messages.Messages, "create")
    unwrap(anthropic.resources.messages.Messages, "stream")
    unwrap(anthropic.resources.messages.AsyncMessages, "create")
    unwrap(anthropic.resources.messages.AsyncMessages, "stream")

    delattr(anthropic, "_datadog_integration")
def handle_non_streamed_response(integration, chat_completions, args, kwargs, span):
    """Tag a fully materialized (non-streamed) chat completion onto *span*."""
    for idx, block in enumerate(chat_completions.content):
        if integration.is_pc_sampled_span(span):
            block_text = getattr(block, "text", "")
            if block_text != "":
                span.set_tag_str(
                    f"anthropic.response.completions.content.{idx}.text",
                    integration.trunc(str(block_text)),
                )
            elif block.type == "tool_use":
                tag_tool_use_output_on_span(integration, span, block, idx)

        span.set_tag_str(f"anthropic.response.completions.content.{idx}.type", block.type)

    # set message level tags
    if getattr(chat_completions, "stop_reason", None) is not None:
        span.set_tag_str("anthropic.response.completions.finish_reason", chat_completions.stop_reason)
    span.set_tag_str("anthropic.response.completions.role", chat_completions.role)

    integration.record_usage(span, _get_attr(chat_completions, "usage", {}))


def tag_tool_use_input_on_span(integration, span, chat_input, message_idx, block_idx):
    """Tag a request-side tool_use block (tool name + serialized input)."""
    prefix = f"anthropic.request.messages.{message_idx}.content.{block_idx}.tool_call"
    span.set_tag_str(f"{prefix}.name", _get_attr(chat_input, "name", ""))
    span.set_tag_str(
        f"{prefix}.input",
        integration.trunc(json.dumps(_get_attr(chat_input, "input", {}))),
    )


def tag_tool_result_input_on_span(integration, span, chat_input, message_idx, block_idx):
    """Tag a request-side tool_result block; content may be a string or a list of blocks."""
    result_content = _get_attr(chat_input, "content", None)
    prefix = f"anthropic.request.messages.{message_idx}.content.{block_idx}.tool_result.content"
    if isinstance(result_content, str):
        span.set_tag_str(f"{prefix}.0.text", integration.trunc(str(result_content)))
    elif isinstance(result_content, list):
        for tool_block_idx, tool_block in enumerate(result_content):
            tool_block_type = _get_attr(tool_block, "type", "")
            if tool_block_type == "text":
                span.set_tag_str(
                    f"{prefix}.{tool_block_idx}.text",
                    integration.trunc(str(_get_attr(tool_block, "text", ""))),
                )
            elif tool_block_type == "image":
                span.set_tag_str(f"{prefix}.{tool_block_idx}.text", "([IMAGE DETECTED])")
            span.set_tag_str(f"{prefix}.{tool_block_idx}.type", tool_block_type)


def tag_tool_use_output_on_span(integration, span, chat_completion, idx):
    """Tag a response-side tool_use block (tool name + serialized input)."""
    tool_name = _get_attr(chat_completion, "name", None)
    tool_inputs = _get_attr(chat_completion, "input", None)
    if tool_name:
        span.set_tag_str(f"anthropic.response.completions.content.{idx}.tool_call.name", tool_name)
    if tool_inputs:
        span.set_tag_str(
            f"anthropic.response.completions.content.{idx}.tool_call.input",
            integration.trunc(json.dumps(tool_inputs)),
        )


def tag_params_on_span(span, kwargs, integration):
    """Tag non-message request parameters; the `system` prompt is tagged separately."""
    extra_params = {}
    for name, value in kwargs.items():
        if name == "system" and integration.is_pc_sampled_span(span):
            span.set_tag_str("anthropic.request.system", integration.trunc(value))
        elif name not in ("messages", "model"):
            extra_params[name] = value
    span.set_tag_str("anthropic.request.parameters", json.dumps(extra_params))


def _extract_api_key(instance: Any) -> Optional[str]:
    """
    Extract and format LLM-provider API key from instance.
    """
    client = getattr(instance, "_client", "")
    if not client:
        return None
    return getattr(client, "api_key", None)
+ """ + client = getattr(instance, "_client", "") + if client: + return getattr(client, "api_key", None) + return None diff --git a/ddtrace/contrib/internal/aredis/patch.py b/ddtrace/contrib/internal/aredis/patch.py new file mode 100644 index 00000000000..838aa9999d9 --- /dev/null +++ b/ddtrace/contrib/internal/aredis/patch.py @@ -0,0 +1,85 @@ +import os + +import aredis + +from ddtrace import config +from ddtrace._trace.utils_redis import _instrument_redis_cmd +from ddtrace._trace.utils_redis import _instrument_redis_execute_pipeline +from ddtrace.contrib.redis_utils import _run_redis_command_async +from ddtrace.internal.schema import schematize_service_name +from ddtrace.internal.utils.formats import CMD_MAX_LEN +from ddtrace.internal.utils.formats import asbool +from ddtrace.internal.utils.formats import stringify_cache_args +from ddtrace.internal.utils.wrappers import unwrap +from ddtrace.pin import Pin +from ddtrace.vendor import wrapt + + +config._add( + "aredis", + dict( + _default_service=schematize_service_name("redis"), + cmd_max_length=int(os.getenv("DD_AREDIS_CMD_MAX_LENGTH", CMD_MAX_LEN)), + resource_only_command=asbool(os.getenv("DD_REDIS_RESOURCE_ONLY_COMMAND", True)), + ), +) + + +def get_version(): + # type: () -> str + return getattr(aredis, "__version__", "") + + +def patch(): + """Patch the instrumented methods""" + if getattr(aredis, "_datadog_patch", False): + return + aredis._datadog_patch = True + + _w = wrapt.wrap_function_wrapper + + _w("aredis.client", "StrictRedis.execute_command", traced_execute_command) + _w("aredis.client", "StrictRedis.pipeline", traced_pipeline) + _w("aredis.pipeline", "StrictPipeline.execute", traced_execute_pipeline) + _w("aredis.pipeline", "StrictPipeline.immediate_execute_command", traced_execute_command) + Pin(service=None).onto(aredis.StrictRedis) + + +def unpatch(): + if getattr(aredis, "_datadog_patch", False): + aredis._datadog_patch = False + + unwrap(aredis.client.StrictRedis, "execute_command") + 
unwrap(aredis.client.StrictRedis, "pipeline") + unwrap(aredis.pipeline.StrictPipeline, "execute") + unwrap(aredis.pipeline.StrictPipeline, "immediate_execute_command") + + +# +# tracing functions +# +async def traced_execute_command(func, instance, args, kwargs): + pin = Pin.get_from(instance) + if not pin or not pin.enabled(): + return await func(*args, **kwargs) + + with _instrument_redis_cmd(pin, config.aredis, instance, args) as ctx: + return await _run_redis_command_async(ctx=ctx, func=func, args=args, kwargs=kwargs) + + +async def traced_pipeline(func, instance, args, kwargs): + pipeline = await func(*args, **kwargs) + pin = Pin.get_from(instance) + if pin: + pin.onto(pipeline) + return pipeline + + +async def traced_execute_pipeline(func, instance, args, kwargs): + pin = Pin.get_from(instance) + if not pin or not pin.enabled(): + return await func(*args, **kwargs) + + cmds = [stringify_cache_args(c, cmd_max_len=config.aredis.cmd_max_length) for c, _ in instance.command_stack] + with _instrument_redis_execute_pipeline(pin, config.aredis, cmds, instance): + return await func(*args, **kwargs) diff --git a/ddtrace/contrib/internal/asgi/middleware.py b/ddtrace/contrib/internal/asgi/middleware.py new file mode 100644 index 00000000000..360b7aef639 --- /dev/null +++ b/ddtrace/contrib/internal/asgi/middleware.py @@ -0,0 +1,303 @@ +import os +import sys +from typing import Any +from typing import Mapping +from typing import Optional +from urllib import parse + +import ddtrace +from ddtrace import config +from ddtrace._trace.span import Span +from ddtrace.constants import ANALYTICS_SAMPLE_RATE_KEY +from ddtrace.constants import SPAN_KIND +from ddtrace.contrib import trace_utils +from ddtrace.contrib.internal.asgi.utils import guarantee_single_callable +from ddtrace.ext import SpanKind +from ddtrace.ext import SpanTypes +from ddtrace.ext import http +from ddtrace.internal import core +from ddtrace.internal._exceptions import BlockingException +from 
log = get_logger(__name__)

config._add(
    "asgi",
    dict(
        service_name=config._get_service(default="asgi"),
        request_span_name="asgi.request",
        distributed_tracing=True,
        _trace_asgi_websocket=os.getenv("DD_ASGI_TRACE_WEBSOCKET", default=False),
    ),
)

# Tag names carrying the ASGI protocol version information.
ASGI_VERSION = "asgi.version"
ASGI_SPEC_VERSION = "asgi.spec_version"


def get_version() -> str:
    # The ASGI integration is not tied to a single upstream package version.
    return ""


def bytes_to_str(str_or_bytes):
    """Decode bytes to str (ignoring errors); pass str values through unchanged."""
    if isinstance(str_or_bytes, bytes):
        return str_or_bytes.decode(errors="ignore")
    return str_or_bytes


def _extract_versions_from_scope(scope, integration_config):
    """Collect http/asgi version tags from the ASGI *scope*."""
    tags = {}

    http_version = scope.get("http_version")
    if http_version:
        tags[http.VERSION] = http_version

    scope_asgi = scope.get("asgi")
    if scope_asgi:
        if "version" in scope_asgi:
            tags[ASGI_VERSION] = scope_asgi["version"]
        if "spec_version" in scope_asgi:
            tags[ASGI_SPEC_VERSION] = scope_asgi["spec_version"]

    return tags


def _extract_headers(scope):
    """Decode the scope's header pairs into a str->str dict ({} when absent)."""
    raw_headers = scope.get("headers")
    if not raw_headers:
        return {}
    # headers: (Iterable[[byte string, byte string]])
    return {bytes_to_str(name): bytes_to_str(value) for name, value in raw_headers}


def _default_handle_exception_span(exc, span):
    """Default handler for exception for span"""
    span.set_tag(http.STATUS_CODE, 500)


def span_from_scope(scope: Mapping[str, Any]) -> Optional[Span]:
    """Retrieve the top-level ASGI span from the scope."""
    return scope.get("datadog", {}).get("request_spans", [None])[0]


async def _blocked_asgi_app(scope, receive, send):
    # Minimal app substituted when the request is blocked: bare 403, empty body.
    await send({"type": "http.response.start", "status": 403, "headers": []})
    await send({"type": "http.response.body", "body": b""})
"http.response.body", "body": b""}) + + +class TraceMiddleware: + """ + ASGI application middleware that traces the requests. + Args: + app: The ASGI application. + tracer: Custom tracer. Defaults to the global tracer. + """ + + default_ports = {"http": 80, "https": 443, "ws": 80, "wss": 443} + + def __init__( + self, + app, + tracer=None, + integration_config=config.asgi, + handle_exception_span=_default_handle_exception_span, + span_modifier=None, + ): + self.app = guarantee_single_callable(app) + self.tracer = tracer or ddtrace.tracer + self.integration_config = integration_config + self.handle_exception_span = handle_exception_span + self.span_modifier = span_modifier + + async def __call__(self, scope, receive, send): + if scope["type"] == "http": + method = scope["method"] + elif scope["type"] == "websocket" and self.integration_config._trace_asgi_websocket: + method = "WEBSOCKET" + else: + return await self.app(scope, receive, send) + try: + headers = _extract_headers(scope) + except Exception: + log.warning("failed to decode headers for distributed tracing", exc_info=True) + headers = {} + else: + trace_utils.activate_distributed_headers( + self.tracer, int_config=self.integration_config, request_headers=headers + ) + resource = " ".join([method, scope["path"]]) + + # in the case of websockets we don't currently schematize the operation names + operation_name = self.integration_config.get("request_span_name", "asgi.request") + if scope["type"] == "http": + operation_name = schematize_url_operation(operation_name, direction=SpanDirection.INBOUND, protocol="http") + + pin = ddtrace.pin.Pin(service="asgi", tracer=self.tracer) + with pin.tracer.trace( + name=operation_name, + service=trace_utils.int_service(None, self.integration_config), + resource=resource, + span_type=SpanTypes.WEB, + ) as span, core.context_with_data( + "asgi.__call__", + remote_addr=scope.get("REMOTE_ADDR"), + headers=headers, + headers_case_sensitive=True, + environ=scope, + 
middleware=self, + span=span, + ) as ctx: + span.set_tag_str(COMPONENT, self.integration_config.integration_name) + ctx.set_item("req_span", span) + + # set span.kind to the type of request being performed + span.set_tag_str(SPAN_KIND, SpanKind.SERVER) + + if "datadog" not in scope: + scope["datadog"] = {"request_spans": [span]} + else: + scope["datadog"]["request_spans"].append(span) + + if self.span_modifier: + self.span_modifier(span, scope) + + sample_rate = self.integration_config.get_analytics_sample_rate(use_global_config=True) + if sample_rate is not None: + span.set_tag(ANALYTICS_SAMPLE_RATE_KEY, sample_rate) + + host_header = None + for key, value in _extract_headers(scope).items(): + if key.encode() == b"host": + try: + host_header = value + except UnicodeDecodeError: + log.warning( + "failed to decode host header, host from http headers will not be considered", exc_info=True + ) + break + method = scope.get("method") + server = scope.get("server") + scheme = scope.get("scheme", "http") + parsed_query = parse.parse_qs(bytes_to_str(scope.get("query_string", b""))) + full_path = scope.get("path", "") + if host_header: + url = "{}://{}{}".format(scheme, host_header, full_path) + elif server and len(server) == 2: + port = server[1] + default_port = self.default_ports.get(scheme, None) + server_host = server[0] + (":" + str(port) if port is not None and port != default_port else "") + url = "{}://{}{}".format(scheme, server_host, full_path) + else: + url = None + query_string = scope.get("query_string") + if query_string: + query_string = bytes_to_str(query_string) + if url: + url = f"{url}?{query_string}" + if not self.integration_config.trace_query_string: + query_string = None + body = None + result = core.dispatch_with_results("asgi.request.parse.body", (receive, headers)).await_receive_and_body + if result: + receive, body = await result.value + + client = scope.get("client") + if isinstance(client, list) and len(client) and is_valid_ip(client[0]): + 
peer_ip = client[0] + else: + peer_ip = None + + trace_utils.set_http_meta( + span, + self.integration_config, + method=method, + url=url, + query=query_string, + request_headers=headers, + raw_uri=url, + parsed_query=parsed_query, + request_body=body, + peer_ip=peer_ip, + headers_are_case_sensitive=True, + ) + tags = _extract_versions_from_scope(scope, self.integration_config) + span.set_tags(tags) + + async def wrapped_send(message): + try: + response_headers = _extract_headers(message) + except Exception: + log.warning("failed to extract response headers", exc_info=True) + response_headers = None + + if span and message.get("type") == "http.response.start" and "status" in message: + status_code = message["status"] + trace_utils.set_http_meta( + span, self.integration_config, status_code=status_code, response_headers=response_headers + ) + core.dispatch("asgi.start_response", ("asgi",)) + core.dispatch("asgi.finalize_response", (message.get("body"), response_headers)) + + if core.get_item(HTTP_REQUEST_BLOCKED): + raise trace_utils.InterruptException("wrapped_send") + try: + return await send(message) + finally: + # Per asgi spec, "more_body" is used if there is still data to send + # Close the span if "http.response.body" has no more data left to send in the + # response. 
+ if ( + message.get("type") == "http.response.body" + and not message.get("more_body", False) + # If the span has an error status code delay finishing the span until the + # traceback and exception message is available + and span.error == 0 + ): + span.finish() + + async def wrapped_blocked_send(message): + result = core.dispatch_with_results("asgi.block.started", (ctx, url)).status_headers_content + if result: + status, headers, content = result.value + else: + status, headers, content = 403, [], b"" + if span and message.get("type") == "http.response.start": + message["headers"] = headers + message["status"] = int(status) + core.dispatch("asgi.finalize_response", (None, headers)) + elif message.get("type") == "http.response.body": + message["body"] = ( + content if isinstance(content, bytes) else content.encode("utf-8", errors="ignore") + ) + message["more_body"] = False + core.dispatch("asgi.finalize_response", (content, None)) + try: + return await send(message) + finally: + trace_utils.set_http_meta( + span, self.integration_config, status_code=status, response_headers=headers + ) + if message.get("type") == "http.response.body" and span.error == 0: + span.finish() + + try: + core.dispatch("asgi.start_request", ("asgi",)) + return await self.app(scope, receive, wrapped_send) + except BlockingException as e: + core.set_item(HTTP_REQUEST_BLOCKED, e.args[0]) + return await _blocked_asgi_app(scope, receive, wrapped_blocked_send) + except trace_utils.InterruptException: + return await _blocked_asgi_app(scope, receive, wrapped_blocked_send) + except Exception as exc: + (exc_type, exc_val, exc_tb) = sys.exc_info() + span.set_exc_info(exc_type, exc_val, exc_tb) + self.handle_exception_span(exc, span) + raise + finally: + if span in scope["datadog"]["request_spans"]: + scope["datadog"]["request_spans"].remove(span) diff --git a/ddtrace/contrib/internal/asgi/utils.py b/ddtrace/contrib/internal/asgi/utils.py new file mode 100644 index 00000000000..73f5d176ccc --- 
/dev/null +++ b/ddtrace/contrib/internal/asgi/utils.py @@ -0,0 +1,82 @@ +""" +Compatibility functions vendored from asgiref + +Source: https://github.com/django/asgiref +Version: 3.2.10 +License: + +Copyright (c) Django Software Foundation and individual contributors. +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + 3. Neither the name of Django nor the names of its contributors may be used + to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" +import asyncio +import inspect + + +def is_double_callable(application): + """ + Tests to see if an application is a legacy-style (double-callable) application. 
+ """ + # Look for a hint on the object first + if getattr(application, "_asgi_single_callable", False): + return False + if getattr(application, "_asgi_double_callable", False): + return True + # Uninstanted classes are double-callable + if inspect.isclass(application): + return True + # Instanted classes depend on their __call__ + if hasattr(application, "__call__"): # noqa: B004 + # We only check to see if its __call__ is a coroutine function - + # if it's not, it still might be a coroutine function itself. + if asyncio.iscoroutinefunction(application.__call__): + return False + # Non-classes we just check directly + return not asyncio.iscoroutinefunction(application) + + +def double_to_single_callable(application): + """ + Transforms a double-callable ASGI application into a single-callable one. + """ + + async def new_application(scope, receive, send): + instance = application(scope) + return await instance(receive, send) + + return new_application + + +def guarantee_single_callable(application): + """ + Takes either a single- or double-callable application and always returns it + in single-callable style. Use this to add backwards compatibility for ASGI + 2.0 applications to your server/test harness/etc. + """ + if is_double_callable(application): + application = double_to_single_callable(application) + return application diff --git a/releasenotes/notes/move-integrations-to-internal-aioredis-ff7da7e2cee9c57b.yaml b/releasenotes/notes/move-integrations-to-internal-aioredis-ff7da7e2cee9c57b.yaml new file mode 100644 index 00000000000..dc3a10a5a19 --- /dev/null +++ b/releasenotes/notes/move-integrations-to-internal-aioredis-ff7da7e2cee9c57b.yaml @@ -0,0 +1,12 @@ +--- +deprecations: + - | + aioredis: Deprecates all modules in the ``ddtrace.contrib.aioredis`` package. Use attributes exposed in ``ddtrace.contrib.aioredis.__all__`` instead. + - | + algoliasearch: Deprecates all modules in the ``ddtrace.contrib.algoliasearch`` package. 
Use attributes exposed in ``ddtrace.contrib.algoliasearch.__all__`` instead. + - | + anthropic: Deprecates all modules in the ``ddtrace.contrib.anthropic`` package. Use attributes exposed in ``ddtrace.contrib.anthropic.__all__`` instead. + - | + aredis: Deprecates all modules in the ``ddtrace.contrib.aredis`` package. Use attributes exposed in ``ddtrace.contrib.aredis.__all__`` instead. + - | + asgi: Deprecates all modules in the ``ddtrace.contrib.asgi`` package. Use attributes exposed in ``ddtrace.contrib.asgi.__all__`` instead.