From cf6f007dcde78958ea44514f0ceca3701635b068 Mon Sep 17 00:00:00 2001
From: Yun Kim <35776586+Yun-Kim@users.noreply.github.com>
Date: Mon, 23 Sep 2024 18:27:37 -0400
Subject: [PATCH] fix(llmobs): avoid raising errors during llmobs integration
 span processing (#10713)

This PR does two things:

### User-facing changes
- Catches any error raised by an integration-specific `_llmobs_set_tags()` method and logs it instead of potentially crashing the user application.

### Non-user-facing changes
Refactors the `BaseLLMIntegration` class and its child classes to share a single, cleaner `llmobs_set_tags()` method, which internally wraps a call to the abstract method `_llmobs_set_tags()` in a try/except (each integration implements `_llmobs_set_tags()`). We also no longer need to check `integration.is_pc_sampled_llmobs(span)` since we do not currently do any sampling; if we add sampling later, it can be handled inside `llmobs_set_tags()`.

tl;dr: `_llmobs_set_tags()` is now an abstract method that must be implemented by all LLM integrations, and its signature takes the following arguments/keyword arguments (same as `llmobs_set_tags()`):
- span: span to annotate
- args: list of args passed to the traced method
- kwargs: dict of keyword args passed to the traced method. If an integration requires additional data not contained in args/kwargs (such as the model instance in Gemini or the tool_input dictionary in LangChain), we pass it in through the kwargs dict.
- response: the response returned from the LLM provider (streamed or non-streamed)
- operation: string denoting which LLM operation this is (e.g. "completion", "chat", "embedding", "chain", "retrieval")

I refactored each integration to follow this new signature, which included merging the logic for handling streamed responses and for additional required args (i.e. the model instance and tool inputs). Previously each integration did its own thing for `llmobs_set_tags()` with arbitrary args/kwargs, which was difficult to maintain. With a strict function signature, future integrations should be simpler to create and existing integrations easier to maintain.
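For illustration, a minimal child integration under the new contract might look like the sketch below. This is not code from this patch: the integration name, the message extraction, and the tag payloads are placeholders; only the `_llmobs_set_tags()` signature and the fact that `BaseLLMIntegration.llmobs_set_tags()` catches and logs its exceptions reflect the actual change.

```python
import json
from typing import Any, Dict, List, Optional

from ddtrace._trace.span import Span
from ddtrace.llmobs._constants import OUTPUT_MESSAGES
from ddtrace.llmobs._constants import SPAN_KIND
from ddtrace.llmobs._integrations.base import BaseLLMIntegration


class MyProviderIntegration(BaseLLMIntegration):
    """Hypothetical integration implementing the new abstract method."""

    _integration_name = "my_provider"

    def _llmobs_set_tags(
        self,
        span: Span,
        args: List[Any],
        kwargs: Dict[str, Any],
        response: Optional[Any] = None,
        operation: str = "",
    ) -> None:
        # Any exception raised here is caught and logged by
        # BaseLLMIntegration.llmobs_set_tags(), so malformed data can no
        # longer crash the user application.
        span.set_tag_str(SPAN_KIND, "llm")
        if span.error or response is None:
            span.set_tag_str(OUTPUT_MESSAGES, json.dumps([{"content": ""}]))
            return
        # Placeholder extraction; a real integration parses the provider response.
        span.set_tag_str(OUTPUT_MESSAGES, json.dumps([{"content": str(response)}]))
```

A traced method then calls `integration.llmobs_set_tags(span, args=args, kwargs=kwargs, response=resp, operation="chat")` in its `finally` block, as the updated contrib code in the diff below does.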
## Checklist - [x] PR author has checked that all the criteria below are met - The PR description includes an overview of the change - The PR description articulates the motivation for the change - The change includes tests OR the PR description describes a testing strategy - The PR description notes risks associated with the change, if any - Newly-added code is easy to change - The change follows the [library release note guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html) - The change includes or references documentation updates if necessary - Backport labels are set (if [applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)) ## Reviewer Checklist - [x] Reviewer has checked that all the criteria below are met - Title is accurate - All changes are related to the pull request's stated goal - Avoids breaking [API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces) changes - Testing strategy adequately addresses listed risks - Newly-added code is easy to change - Release note makes sense to a user of the library - If necessary, author has acknowledged and discussed the performance implications of this PR as reported in the benchmarks PR comment - Backport labels are set in a manner that is consistent with the [release branch maintenance policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting) --- ddtrace/_trace/trace_handlers.py | 8 +- .../contrib/internal/anthropic/_streaming.py | 9 +- ddtrace/contrib/internal/anthropic/patch.py | 6 +- .../internal/google_generativeai/_utils.py | 22 ++- .../internal/google_generativeai/patch.py | 8 +- ddtrace/contrib/internal/langchain/patch.py | 94 ++-------- .../internal/openai/_endpoint_hooks.py | 14 +- ddtrace/contrib/internal/openai/utils.py | 6 +- ddtrace/llmobs/_integrations/anthropic.py | 17 +- ddtrace/llmobs/_integrations/base.py | 42 ++++- ddtrace/llmobs/_integrations/bedrock.py | 40 ++--- ddtrace/llmobs/_integrations/gemini.py | 19 +- ddtrace/llmobs/_integrations/langchain.py | 162 +++++++++--------- ddtrace/llmobs/_integrations/openai.py | 146 ++++++++-------- ...rations-safe-tagging-5e170868e5758510.yaml | 5 + tests/llmobs/test_llmobs_integrations.py | 21 +++ 16 files changed, 286 insertions(+), 333 deletions(-) create mode 100644 releasenotes/notes/fix-llmobs-integrations-safe-tagging-5e170868e5758510.yaml diff --git a/ddtrace/_trace/trace_handlers.py b/ddtrace/_trace/trace_handlers.py index aaa123ccb13..c078a7dda55 100644 --- a/ddtrace/_trace/trace_handlers.py +++ b/ddtrace/_trace/trace_handlers.py @@ -687,11 +687,10 @@ def _on_botocore_patched_bedrock_api_call_started(ctx, request_params): def _on_botocore_patched_bedrock_api_call_exception(ctx, exc_info): span = ctx[ctx["call_key"]] span.set_exc_info(*exc_info) - prompt = ctx["prompt"] model_name = ctx["model_name"] integration = ctx["bedrock_integration"] - if integration.is_pc_sampled_llmobs(span) and "embed" not in model_name: - integration.llmobs_set_tags(span, formatted_response=None, prompt=prompt, err=True) + if "embed" not in model_name: + integration.llmobs_set_tags(span, args=[], kwargs={"prompt": ctx["prompt"]}) span.finish() @@ -733,8 +732,7 @@ def _on_botocore_bedrock_process_response( span.set_tag_str( "bedrock.response.choices.{}.finish_reason".format(i), str(formatted_response["finish_reason"][i]) ) - if integration.is_pc_sampled_llmobs(span): - integration.llmobs_set_tags(span, formatted_response=formatted_response, prompt=ctx["prompt"]) + integration.llmobs_set_tags(span, args=[], 
kwargs={"prompt": ctx["prompt"]}, response=formatted_response) span.finish() diff --git a/ddtrace/contrib/internal/anthropic/_streaming.py b/ddtrace/contrib/internal/anthropic/_streaming.py index f79d4965d12..439f61bb5a6 100644 --- a/ddtrace/contrib/internal/anthropic/_streaming.py +++ b/ddtrace/contrib/internal/anthropic/_streaming.py @@ -154,16 +154,9 @@ def _process_finished_stream(integration, span, args, kwargs, streamed_chunks): # builds the response message given streamed chunks and sets according span tags try: resp_message = _construct_message(streamed_chunks) - if integration.is_pc_sampled_span(span): _tag_streamed_chat_completion_response(integration, span, resp_message) - if integration.is_pc_sampled_llmobs(span): - integration.llmobs_set_tags( - span=span, - resp=resp_message, - args=args, - kwargs=kwargs, - ) + integration.llmobs_set_tags(span, args=[], kwargs=kwargs, response=resp_message) except Exception: log.warning("Error processing streamed completion/chat response.", exc_info=True) diff --git a/ddtrace/contrib/internal/anthropic/patch.py b/ddtrace/contrib/internal/anthropic/patch.py index 0e56f3cc170..e82c4421e78 100644 --- a/ddtrace/contrib/internal/anthropic/patch.py +++ b/ddtrace/contrib/internal/anthropic/patch.py @@ -105,8 +105,7 @@ def traced_chat_model_generate(anthropic, pin, func, instance, args, kwargs): finally: # we don't want to finish the span if it is a stream as it will get finished once the iterator is exhausted if span.error or not stream: - if integration.is_pc_sampled_llmobs(span): - integration.llmobs_set_tags(span=span, resp=chat_completions, args=args, kwargs=kwargs) + integration.llmobs_set_tags(span, args=[], kwargs=kwargs, response=chat_completions) span.finish() return chat_completions @@ -178,8 +177,7 @@ async def traced_async_chat_model_generate(anthropic, pin, func, instance, args, finally: # we don't want to finish the span if it is a stream as it will get finished once the iterator is exhausted if span.error or not stream: - if integration.is_pc_sampled_llmobs(span): - integration.llmobs_set_tags(span=span, resp=chat_completions, args=args, kwargs=kwargs) + integration.llmobs_set_tags(span, args=[], kwargs=kwargs, response=chat_completions) span.finish() return chat_completions diff --git a/ddtrace/contrib/internal/google_generativeai/_utils.py b/ddtrace/contrib/internal/google_generativeai/_utils.py index a4e46383828..5982f990b18 100644 --- a/ddtrace/contrib/internal/google_generativeai/_utils.py +++ b/ddtrace/contrib/internal/google_generativeai/_utils.py @@ -30,10 +30,13 @@ def __iter__(self): else: tag_response(self._dd_span, self.__wrapped__, self._dd_integration, self._model_instance) finally: - if self._dd_integration.is_pc_sampled_llmobs(self._dd_span): - self._dd_integration.llmobs_set_tags( - self._dd_span, self._args, self._kwargs, self._model_instance, self.__wrapped__ - ) + self._kwargs["instance"] = self._model_instance + self._dd_integration.llmobs_set_tags( + self._dd_span, + args=self._args, + kwargs=self._kwargs, + response=self.__wrapped__, + ) self._dd_span.finish() @@ -48,10 +51,13 @@ async def __aiter__(self): else: tag_response(self._dd_span, self.__wrapped__, self._dd_integration, self._model_instance) finally: - if self._dd_integration.is_pc_sampled_llmobs(self._dd_span): - self._dd_integration.llmobs_set_tags( - self._dd_span, self._args, self._kwargs, self._model_instance, self.__wrapped__ - ) + self._kwargs["instance"] = self._model_instance + self._dd_integration.llmobs_set_tags( + self._dd_span, + 
args=self._args, + kwargs=self._kwargs, + response=self.__wrapped__, + ) self._dd_span.finish() diff --git a/ddtrace/contrib/internal/google_generativeai/patch.py b/ddtrace/contrib/internal/google_generativeai/patch.py index eb131bb0bce..2e2c27912eb 100644 --- a/ddtrace/contrib/internal/google_generativeai/patch.py +++ b/ddtrace/contrib/internal/google_generativeai/patch.py @@ -60,8 +60,8 @@ def traced_generate(genai, pin, func, instance, args, kwargs): finally: # streamed spans will be finished separately once the stream generator is exhausted if span.error or not stream: - if integration.is_pc_sampled_llmobs(span): - integration.llmobs_set_tags(span, args, kwargs, instance, generations) + kwargs["instance"] = instance + integration.llmobs_set_tags(span, args=args, kwargs=kwargs, response=generations) span.finish() return generations @@ -90,8 +90,8 @@ async def traced_agenerate(genai, pin, func, instance, args, kwargs): finally: # streamed spans will be finished separately once the stream generator is exhausted if span.error or not stream: - if integration.is_pc_sampled_llmobs(span): - integration.llmobs_set_tags(span, args, kwargs, instance, generations) + kwargs["instance"] = instance + integration.llmobs_set_tags(span, args=args, kwargs=kwargs, response=generations) span.finish() return generations diff --git a/ddtrace/contrib/internal/langchain/patch.py b/ddtrace/contrib/internal/langchain/patch.py index 3f24b3e4b02..1a356f1a93b 100644 --- a/ddtrace/contrib/internal/langchain/patch.py +++ b/ddtrace/contrib/internal/langchain/patch.py @@ -244,14 +244,7 @@ def traced_llm_generate(langchain, pin, func, instance, args, kwargs): integration.metric(span, "incr", "request.error", 1) raise finally: - if integration.is_pc_sampled_llmobs(span): - integration.llmobs_set_tags( - "llm", - span, - prompts, - completions, - error=bool(span.error), - ) + integration.llmobs_set_tags(span, args=args, kwargs=kwargs, response=completions, operation="llm") span.finish() integration.metric(span, "dist", "request.duration", span.duration_ns) if integration.is_pc_sampled_log(span): @@ -322,14 +315,7 @@ async def traced_llm_agenerate(langchain, pin, func, instance, args, kwargs): integration.metric(span, "incr", "request.error", 1) raise finally: - if integration.is_pc_sampled_llmobs(span): - integration.llmobs_set_tags( - "llm", - span, - prompts, - completions, - error=bool(span.error), - ) + integration.llmobs_set_tags(span, args=args, kwargs=kwargs, response=completions, operation="llm") span.finish() integration.metric(span, "dist", "request.duration", span.duration_ns) if integration.is_pc_sampled_log(span): @@ -438,14 +424,7 @@ def traced_chat_model_generate(langchain, pin, func, instance, args, kwargs): integration.metric(span, "incr", "request.error", 1) raise finally: - if integration.is_pc_sampled_llmobs(span): - integration.llmobs_set_tags( - "chat", - span, - chat_messages, - chat_completions, - error=bool(span.error), - ) + integration.llmobs_set_tags(span, args=args, kwargs=kwargs, response=chat_completions, operation="chat") span.finish() integration.metric(span, "dist", "request.duration", span.duration_ns) if integration.is_pc_sampled_log(span): @@ -570,14 +549,7 @@ async def traced_chat_model_agenerate(langchain, pin, func, instance, args, kwar integration.metric(span, "incr", "request.error", 1) raise finally: - if integration.is_pc_sampled_llmobs(span): - integration.llmobs_set_tags( - "chat", - span, - chat_messages, - chat_completions, - error=bool(span.error), - ) + 
integration.llmobs_set_tags(span, args=args, kwargs=kwargs, response=chat_completions, operation="chat") span.finish() integration.metric(span, "dist", "request.duration", span.duration_ns) if integration.is_pc_sampled_log(span): @@ -662,14 +634,7 @@ def traced_embedding(langchain, pin, func, instance, args, kwargs): integration.metric(span, "incr", "request.error", 1) raise finally: - if integration.is_pc_sampled_llmobs(span): - integration.llmobs_set_tags( - "embedding", - span, - input_texts, - embeddings, - error=bool(span.error), - ) + integration.llmobs_set_tags(span, args=args, kwargs=kwargs, response=embeddings, operation="embedding") span.finish() integration.metric(span, "dist", "request.duration", span.duration_ns) if integration.is_pc_sampled_log(span): @@ -717,8 +682,7 @@ def traced_chain_call(langchain, pin, func, instance, args, kwargs): integration.metric(span, "incr", "request.error", 1) raise finally: - if integration.is_pc_sampled_llmobs(span): - integration.llmobs_set_tags("chain", span, inputs, final_outputs, error=bool(span.error)) + integration.llmobs_set_tags(span, args=[], kwargs=inputs, response=final_outputs, operation="chain") span.finish() integration.metric(span, "dist", "request.duration", span.duration_ns) if integration.is_pc_sampled_log(span): @@ -774,8 +738,7 @@ async def traced_chain_acall(langchain, pin, func, instance, args, kwargs): integration.metric(span, "incr", "request.error", 1) raise finally: - if integration.is_pc_sampled_llmobs(span): - integration.llmobs_set_tags("chain", span, inputs, final_outputs, error=bool(span.error)) + integration.llmobs_set_tags(span, args=[], kwargs=inputs, response=final_outputs, operation="chain") span.finish() integration.metric(span, "dist", "request.duration", span.duration_ns) if integration.is_pc_sampled_log(span): @@ -847,8 +810,7 @@ def traced_lcel_runnable_sequence(langchain, pin, func, instance, args, kwargs): integration.metric(span, "incr", "request.error", 1) raise finally: - if integration.is_pc_sampled_llmobs(span): - integration.llmobs_set_tags("chain", span, inputs, final_output, error=bool(span.error)) + integration.llmobs_set_tags(span, args=[], kwargs=inputs, response=final_output, operation="chain") span.finish() integration.metric(span, "dist", "request.duration", span.duration_ns) return final_output @@ -894,8 +856,7 @@ async def traced_lcel_runnable_sequence_async(langchain, pin, func, instance, ar integration.metric(span, "incr", "request.error", 1) raise finally: - if integration.is_pc_sampled_llmobs(span): - integration.llmobs_set_tags("chain", span, inputs, final_output, error=bool(span.error)) + integration.llmobs_set_tags(span, args=[], kwargs=inputs, response=final_output, operation="chain") span.finish() integration.metric(span, "dist", "request.duration", span.duration_ns) return final_output @@ -953,14 +914,7 @@ def traced_similarity_search(langchain, pin, func, instance, args, kwargs): integration.metric(span, "incr", "request.error", 1) raise finally: - if integration.is_pc_sampled_llmobs(span): - integration.llmobs_set_tags( - "retrieval", - span, - query, - documents, - error=bool(span.error), - ) + integration.llmobs_set_tags(span, args=args, kwargs=kwargs, response=documents, operation="retrieval") span.finish() integration.metric(span, "dist", "request.duration", span.duration_ns) if integration.is_pc_sampled_log(span): @@ -1024,18 +978,8 @@ def traced_base_tool_invoke(langchain, pin, func, instance, args, kwargs): span.set_exc_info(*sys.exc_info()) raise finally: - if 
integration.is_pc_sampled_llmobs(span): - integration.llmobs_set_tags( - "tool", - span, - { - "input": tool_input, - "config": config if config else {}, - "info": tool_info if tool_info else {}, - }, - tool_output, - error=bool(span.error), - ) + tool_inputs = {"input": tool_input, "config": config or {}, "info": tool_info or {}} + integration.llmobs_set_tags(span, args=[], kwargs=tool_inputs, response=tool_output, operation="tool") span.finish() return tool_output @@ -1085,18 +1029,8 @@ async def traced_base_tool_ainvoke(langchain, pin, func, instance, args, kwargs) span.set_exc_info(*sys.exc_info()) raise finally: - if integration.is_pc_sampled_llmobs(span): - integration.llmobs_set_tags( - "tool", - span, - { - "input": tool_input, - "config": config if config else {}, - "info": tool_info if tool_info else {}, - }, - tool_output, - error=bool(span.error), - ) + tool_inputs = {"input": tool_input, "config": config or {}, "info": tool_info or {}} + integration.llmobs_set_tags(span, args=[], kwargs=tool_inputs, response=tool_output, operation="tool") span.finish() return tool_output diff --git a/ddtrace/contrib/internal/openai/_endpoint_hooks.py b/ddtrace/contrib/internal/openai/_endpoint_hooks.py index 5ac7a6e8f6c..73a2b2511c9 100644 --- a/ddtrace/contrib/internal/openai/_endpoint_hooks.py +++ b/ddtrace/contrib/internal/openai/_endpoint_hooks.py @@ -9,7 +9,6 @@ from ddtrace.contrib.internal.openai.utils import _process_finished_stream from ddtrace.contrib.internal.openai.utils import _tag_tool_calls from ddtrace.internal.utils.version import parse_version -from ddtrace.llmobs._constants import SPAN_KIND API_VERSION = "v1" @@ -189,8 +188,6 @@ class _CompletionHook(_BaseCompletionHook): def _record_request(self, pin, integration, span, args, kwargs): super()._record_request(pin, integration, span, args, kwargs) - if integration.is_pc_sampled_llmobs(span): - span.set_tag_str(SPAN_KIND, "llm") if integration.is_pc_sampled_span(span): prompt = kwargs.get("prompt", "") if isinstance(prompt, str): @@ -212,8 +209,7 @@ def _record_response(self, pin, integration, span, args, kwargs, resp, error): integration.log( span, "info" if error is None else "error", "sampled %s" % self.OPERATION_ID, attrs=attrs_dict ) - if integration.is_pc_sampled_llmobs(span): - integration.llmobs_set_tags("completion", resp, span, kwargs, err=error) + integration.llmobs_set_tags(span, args=[], kwargs=kwargs, response=resp, operation="completion") if not resp: return for choice in resp.choices: @@ -247,8 +243,6 @@ class _ChatCompletionHook(_BaseCompletionHook): def _record_request(self, pin, integration, span, args, kwargs): super()._record_request(pin, integration, span, args, kwargs) - if integration.is_pc_sampled_llmobs(span): - span.set_tag_str(SPAN_KIND, "llm") for idx, m in enumerate(kwargs.get("messages", [])): role = getattr(m, "role", "") name = getattr(m, "name", "") @@ -274,8 +268,7 @@ def _record_response(self, pin, integration, span, args, kwargs, resp, error): integration.log( span, "info" if error is None else "error", "sampled %s" % self.OPERATION_ID, attrs=attrs_dict ) - if integration.is_pc_sampled_llmobs(span): - integration.llmobs_set_tags("chat", resp, span, kwargs, err=error) + integration.llmobs_set_tags(span, args=[], kwargs=kwargs, response=resp, operation="chat") if not resp: return for choice in resp.choices: @@ -319,8 +312,7 @@ def _record_request(self, pin, integration, span, args, kwargs): def _record_response(self, pin, integration, span, args, kwargs, resp, error): resp = 
super()._record_response(pin, integration, span, args, kwargs, resp, error) - if integration.is_pc_sampled_llmobs(span): - integration.llmobs_set_tags("embedding", resp, span, kwargs, err=error) + integration.llmobs_set_tags(span, args=[], kwargs=kwargs, response=resp, operation="embedding") if not resp: return span.set_metric("openai.response.embeddings_count", len(resp.data)) diff --git a/ddtrace/contrib/internal/openai/utils.py b/ddtrace/contrib/internal/openai/utils.py index b7e549a73ec..d967383e366 100644 --- a/ddtrace/contrib/internal/openai/utils.py +++ b/ddtrace/contrib/internal/openai/utils.py @@ -208,10 +208,8 @@ def _process_finished_stream(integration, span, kwargs, streamed_chunks, is_comp if integration.is_pc_sampled_span(span): _tag_streamed_response(integration, span, formatted_completions) _set_token_metrics(span, integration, formatted_completions, prompts, request_messages, kwargs) - if integration.is_pc_sampled_llmobs(span): - integration.llmobs_set_tags( - "completion" if is_completion else "chat", None, span, kwargs, formatted_completions, None - ) + operation = "completion" if is_completion else "chat" + integration.llmobs_set_tags(span, args=[], kwargs=kwargs, response=formatted_completions, operation=operation) except Exception: log.warning("Error processing streamed completion/chat response.", exc_info=True) diff --git a/ddtrace/llmobs/_integrations/anthropic.py b/ddtrace/llmobs/_integrations/anthropic.py index c62b68fb200..0747d68e77b 100644 --- a/ddtrace/llmobs/_integrations/anthropic.py +++ b/ddtrace/llmobs/_integrations/anthropic.py @@ -48,18 +48,15 @@ def _set_base_span_tags( else: span.set_tag_str(API_KEY, api_key) - def llmobs_set_tags( + def _llmobs_set_tags( self, - resp: Any, span: Span, args: List[Any], kwargs: Dict[str, Any], - err: Optional[Any] = None, + response: Optional[Any] = None, + operation: str = "", ) -> None: """Extract prompt/response tags from a completion and set them as temporary "_ml_obs.*" tags.""" - if not self.llmobs_enabled: - return - parameters = {} if kwargs.get("temperature"): parameters["temperature"] = kwargs.get("temperature") @@ -74,11 +71,13 @@ def llmobs_set_tags( span.set_tag_str(INPUT_MESSAGES, safe_json(input_messages)) span.set_tag_str(METADATA, safe_json(parameters)) span.set_tag_str(MODEL_PROVIDER, "anthropic") - if err or resp is None: - span.set_tag_str(OUTPUT_MESSAGES, safe_json([{"content": ""}])) + + if span.error or response is None: + span.set_tag_str(OUTPUT_MESSAGES, json.dumps([{"content": ""}])) else: - output_messages = self._extract_output_message(resp) + output_messages = self._extract_output_message(response) span.set_tag_str(OUTPUT_MESSAGES, safe_json(output_messages)) + usage = self._get_llmobs_metrics_tags(span) if usage: span.set_tag_str(METRICS, safe_json(usage)) diff --git a/ddtrace/llmobs/_integrations/base.py b/ddtrace/llmobs/_integrations/base.py index c4186c50a5a..709e72f3a26 100644 --- a/ddtrace/llmobs/_integrations/base.py +++ b/ddtrace/llmobs/_integrations/base.py @@ -15,6 +15,7 @@ from ddtrace.internal.agent import get_stats_url from ddtrace.internal.dogstatsd import get_dogstatsd_client from ddtrace.internal.hostname import get_hostname +from ddtrace.internal.logger import get_logger from ddtrace.internal.utils.formats import asbool from ddtrace.llmobs._constants import PARENT_ID_KEY from ddtrace.llmobs._constants import PROPAGATED_PARENT_ID_KEY @@ -25,6 +26,9 @@ from ddtrace.settings import IntegrationConfig +log = get_logger(__name__) + + class BaseLLMIntegration: _integration_name = 
"baseLLM" @@ -81,18 +85,15 @@ def llmobs_enabled(self) -> bool: return LLMObs.enabled def is_pc_sampled_span(self, span: Span) -> bool: - if span.context.sampling_priority is not None: - if span.context.sampling_priority <= 0: - return False + if span.context.sampling_priority is not None and span.context.sampling_priority <= 0: + return False return self._span_pc_sampler.sample(span) def is_pc_sampled_log(self, span: Span) -> bool: - if span.context.sampling_priority is not None: - if span.context.sampling_priority <= 0: - return False - if not self.logs_enabled: return False + if span.context.sampling_priority is not None and span.context.sampling_priority <= 0: + return False return self._log_pc_sampler.sample(span) def is_pc_sampled_llmobs(self, span: Span) -> bool: @@ -195,3 +196,30 @@ def trunc(self, text: str) -> str: if len(text) > self.integration_config.span_char_limit: text = text[: self.integration_config.span_char_limit] + "..." return text + + def llmobs_set_tags( + self, + span: Span, + args: List[Any], + kwargs: Dict[str, Any], + response: Optional[Any] = None, + operation: str = "", + ) -> None: + """Extract input/output information from the request and response to be submitted to LLMObs.""" + if not self.llmobs_enabled: + return + try: + self._llmobs_set_tags(span, args, kwargs, response, operation) + except Exception: + log.error("Error extracting LLMObs fields for span %s, likely due to malformed data", span, exc_info=True) + + @abc.abstractmethod + def _llmobs_set_tags( + self, + span: Span, + args: List[Any], + kwargs: Dict[str, Any], + response: Optional[Any] = None, + operation: str = "", + ) -> None: + raise NotImplementedError() diff --git a/ddtrace/llmobs/_integrations/bedrock.py b/ddtrace/llmobs/_integrations/bedrock.py index 82aa0ff6a08..0aaa545b47e 100644 --- a/ddtrace/llmobs/_integrations/bedrock.py +++ b/ddtrace/llmobs/_integrations/bedrock.py @@ -1,5 +1,6 @@ from typing import Any from typing import Dict +from typing import List from typing import Optional from ddtrace._trace.span import Span @@ -27,16 +28,10 @@ class BedrockIntegration(BaseLLMIntegration): _integration_name = "bedrock" - def llmobs_set_tags( - self, - span: Span, - formatted_response: Optional[Dict[str, Any]] = None, - prompt: Optional[str] = None, - err: bool = False, + def _llmobs_set_tags( + self, span: Span, args: List[Any], kwargs: Dict[str, Any], response: Optional[Any] = None, operation: str = "" ) -> None: """Extract prompt/response tags from a completion and set them as temporary "_ml_obs.*" tags.""" - if not self.llmobs_enabled: - return if span.get_tag(PROPAGATED_PARENT_ID_KEY) is None: parent_id = _get_llmobs_parent_id(span) or "undefined" span.set_tag(PARENT_ID_KEY, parent_id) @@ -45,25 +40,28 @@ def llmobs_set_tags( parameters["temperature"] = float(span.get_tag("bedrock.request.temperature") or 0.0) if span.get_tag("bedrock.request.max_tokens"): parameters["max_tokens"] = int(span.get_tag("bedrock.request.max_tokens") or 0) + + prompt = kwargs.get("prompt", "") input_messages = self._extract_input_message(prompt) span.set_tag_str(SPAN_KIND, "llm") span.set_tag_str(MODEL_NAME, span.get_tag("bedrock.request.model") or "") span.set_tag_str(MODEL_PROVIDER, span.get_tag("bedrock.request.model_provider") or "") + span.set_tag_str(INPUT_MESSAGES, safe_json(input_messages)) span.set_tag_str(METADATA, safe_json(parameters)) - if err or formatted_response is None: + if span.error or response is None: span.set_tag_str(OUTPUT_MESSAGES, safe_json([{"content": ""}])) else: - 
output_messages = self._extract_output_message(formatted_response) + output_messages = self._extract_output_message(response) span.set_tag_str(OUTPUT_MESSAGES, safe_json(output_messages)) - metrics = self._llmobs_metrics(span, formatted_response) + metrics = self._llmobs_metrics(span, response) span.set_tag_str(METRICS, safe_json(metrics)) @staticmethod - def _llmobs_metrics(span: Span, formatted_response: Optional[Dict[str, Any]]) -> Dict[str, Any]: + def _llmobs_metrics(span: Span, response: Optional[Dict[str, Any]]) -> Dict[str, Any]: metrics = {} - if formatted_response and formatted_response.get("text"): + if response and response.get("text"): prompt_tokens = int(span.get_tag("bedrock.usage.prompt_tokens") or 0) completion_tokens = int(span.get_tag("bedrock.usage.completion_tokens") or 0) metrics[INPUT_TOKENS_METRIC_KEY] = prompt_tokens @@ -96,14 +94,14 @@ def _extract_input_message(prompt): return input_messages @staticmethod - def _extract_output_message(formatted_response): + def _extract_output_message(response): """Extract output messages from the stored response. Anthropic allows for chat messages, which requires some special casing. """ - if isinstance(formatted_response["text"], str): - return [{"content": formatted_response["text"]}] - if isinstance(formatted_response["text"], list): - if isinstance(formatted_response["text"][0], str): - return [{"content": str(resp)} for resp in formatted_response["text"]] - if isinstance(formatted_response["text"][0], dict): - return [{"content": formatted_response["text"][0].get("text", "")}] + if isinstance(response["text"], str): + return [{"content": response["text"]}] + if isinstance(response["text"], list): + if isinstance(response["text"][0], str): + return [{"content": str(content)} for content in response["text"]] + if isinstance(response["text"][0], dict): + return [{"content": response["text"][0].get("text", "")}] diff --git a/ddtrace/llmobs/_integrations/gemini.py b/ddtrace/llmobs/_integrations/gemini.py index 6f75a02048a..21e74b036f0 100644 --- a/ddtrace/llmobs/_integrations/gemini.py +++ b/ddtrace/llmobs/_integrations/gemini.py @@ -32,16 +32,19 @@ def _set_base_span_tags( if model is not None: span.set_tag_str("google_generativeai.request.model", str(model)) - def llmobs_set_tags( - self, span: Span, args: List[Any], kwargs: Dict[str, Any], instance: Any, generations: Any = None + def _llmobs_set_tags( + self, + span: Span, + args: List[Any], + kwargs: Dict[str, Any], + response: Optional[Any] = None, + operation: str = "", ) -> None: - if not self.llmobs_enabled: - return - span.set_tag_str(SPAN_KIND, "llm") span.set_tag_str(MODEL_NAME, span.get_tag("google_generativeai.request.model") or "") span.set_tag_str(MODEL_PROVIDER, span.get_tag("google_generativeai.request.provider") or "") + instance = kwargs.get("instance", None) metadata = self._llmobs_set_metadata(kwargs, instance) span.set_tag_str(METADATA, safe_json(metadata)) @@ -50,10 +53,10 @@ def llmobs_set_tags( input_messages = self._extract_input_message(input_contents, system_instruction) span.set_tag_str(INPUT_MESSAGES, safe_json(input_messages)) - if span.error or generations is None: + if span.error or response is None: span.set_tag_str(OUTPUT_MESSAGES, safe_json([{"content": ""}])) else: - output_messages = self._extract_output_message(generations) + output_messages = self._extract_output_message(response) span.set_tag_str(OUTPUT_MESSAGES, safe_json(output_messages)) usage = self._get_llmobs_metrics_tags(span) @@ -63,7 +66,7 @@ def llmobs_set_tags( @staticmethod 
def _llmobs_set_metadata(kwargs, instance): metadata = {} - model_config = instance._generation_config or {} + model_config = _get_attr(instance, "_generation_config", {}) request_config = kwargs.get("generation_config", {}) parameters = ("temperature", "max_output_tokens", "candidate_count", "top_p", "top_k") for param in parameters: diff --git a/ddtrace/llmobs/_integrations/langchain.py b/ddtrace/llmobs/_integrations/langchain.py index 3a852e91958..87a2ba482dc 100644 --- a/ddtrace/llmobs/_integrations/langchain.py +++ b/ddtrace/llmobs/_integrations/langchain.py @@ -9,6 +9,8 @@ from ddtrace._trace.span import Span from ddtrace.constants import ERROR_TYPE from ddtrace.internal.logger import get_logger +from ddtrace.internal.utils import ArgumentError +from ddtrace.internal.utils import get_argument_value from ddtrace.llmobs import LLMObs from ddtrace.llmobs._constants import INPUT_DOCUMENTS from ddtrace.llmobs._constants import INPUT_MESSAGES @@ -51,13 +53,13 @@ class LangChainIntegration(BaseLLMIntegration): _integration_name = "langchain" - def llmobs_set_tags( + def _llmobs_set_tags( self, - operation: str, # oneof "llm","chat","chain","embedding","retrieval","tool" span: Span, - inputs: Any, - response: Any = None, - error: bool = False, + args: List[Any], + kwargs: Dict[str, Any], + response: Optional[Any] = None, + operation: str = "", # oneof "llm","chat","chain","embedding","retrieval","tool" ) -> None: """Sets meta tags and metrics for span events to be sent to LLMObs.""" if not self.llmobs_enabled: @@ -83,17 +85,18 @@ def llmobs_set_tags( is_workflow = LLMObs._integration_is_enabled(llmobs_integration) if operation == "llm": - self._llmobs_set_meta_tags_from_llm(span, inputs, response, error, is_workflow=is_workflow) + self._llmobs_set_meta_tags_from_llm(span, args, kwargs, response, is_workflow=is_workflow) elif operation == "chat": - self._llmobs_set_meta_tags_from_chat_model(span, inputs, response, error, is_workflow=is_workflow) + self._llmobs_set_meta_tags_from_chat_model(span, args, kwargs, response, is_workflow=is_workflow) elif operation == "chain": - self._llmobs_set_meta_tags_from_chain(span, inputs, response, error) + self._llmobs_set_meta_tags_from_chain(span, inputs=kwargs, outputs=response) elif operation == "embedding": - self._llmobs_set_meta_tags_from_embedding(span, inputs, response, error, is_workflow=is_workflow) + self._llmobs_set_meta_tags_from_embedding(span, args, kwargs, response, is_workflow=is_workflow) elif operation == "retrieval": - self._llmobs_set_meta_tags_from_similarity_search(span, inputs, response, error, is_workflow=is_workflow) + self._llmobs_set_meta_tags_from_similarity_search(span, args, kwargs, response, is_workflow=is_workflow) elif operation == "tool": - self._llmobs_set_meta_tags_from_tool(span, inputs, response, error) + self._llmobs_set_meta_tags_from_tool(span, tool_inputs=kwargs, tool_output=response) + span.set_tag_str(METRICS, safe_json({})) def _llmobs_set_metadata(self, span: Span, model_provider: Optional[str] = None) -> None: @@ -118,7 +121,7 @@ def _llmobs_set_metadata(self, span: Span, model_provider: Optional[str] = None) span.set_tag_str(METADATA, safe_json(metadata)) def _llmobs_set_meta_tags_from_llm( - self, span: Span, prompts: List[Any], completions: Any, err: bool = False, is_workflow: bool = False + self, span: Span, args: List[Any], kwargs: Dict[str, Any], completions: Any, is_workflow: bool = False ) -> None: span.set_tag_str(SPAN_KIND, "workflow" if is_workflow else "llm") span.set_tag_str(MODEL_NAME, 
span.get_tag(MODEL) or "") @@ -127,22 +130,23 @@ def _llmobs_set_meta_tags_from_llm( input_tag_key = INPUT_VALUE if is_workflow else INPUT_MESSAGES output_tag_key = OUTPUT_VALUE if is_workflow else OUTPUT_MESSAGES - if isinstance(prompts, str): + prompts = get_argument_value(args, kwargs, 0, "prompts") + if isinstance(prompts, str) or not isinstance(prompts, list): prompts = [prompts] span.set_tag_str(input_tag_key, safe_json([{"content": str(prompt)} for prompt in prompts])) - - message_content = [{"content": ""}] - if not err: - message_content = [{"content": completion[0].text} for completion in completions.generations] + if span.error: + span.set_tag_str(output_tag_key, safe_json([{"content": ""}])) + return + message_content = [{"content": completion[0].text} for completion in completions.generations] span.set_tag_str(output_tag_key, safe_json(message_content)) def _llmobs_set_meta_tags_from_chat_model( self, span: Span, - chat_messages: List[List[Any]], + args: List[Any], + kwargs: Dict[str, Any], chat_completions: Any, - err: bool = False, is_workflow: bool = False, ) -> None: span.set_tag_str(SPAN_KIND, "workflow" if is_workflow else "llm") @@ -153,32 +157,27 @@ def _llmobs_set_meta_tags_from_chat_model( output_tag_key = OUTPUT_VALUE if is_workflow else OUTPUT_MESSAGES input_messages = [] + chat_messages = get_argument_value(args, kwargs, 0, "messages", optional=True) or [] for message_set in chat_messages: for message in message_set: content = message.get("content", "") if isinstance(message, dict) else getattr(message, "content", "") - input_messages.append( - { - "content": str(content), - "role": getattr(message, "role", ROLE_MAPPING.get(message.type, "")), - } - ) + role = getattr(message, "role", ROLE_MAPPING.get(message.type, "")) + input_messages.append({"content": str(content), "role": str(role)}) span.set_tag_str(input_tag_key, safe_json(input_messages)) - output_messages = [{"content": ""}] - if not err: - output_messages = [] - for message_set in chat_completions.generations: - for chat_completion in message_set: - chat_completion_msg = chat_completion.message - role = getattr(chat_completion_msg, "role", ROLE_MAPPING.get(chat_completion_msg.type, "")) - output_message = { - "content": str(chat_completion.text), - "role": role, - } - tool_calls_info = self._extract_tool_calls(chat_completion_msg) - if tool_calls_info: - output_message["tool_calls"] = tool_calls_info - output_messages.append(output_message) + if span.error: + span.set_tag_str(output_tag_key, json.dumps([{"content": ""}])) + return + output_messages = [] + for message_set in chat_completions.generations: + for chat_completion in message_set: + chat_completion_msg = chat_completion.message + role = getattr(chat_completion_msg, "role", ROLE_MAPPING.get(chat_completion_msg.type, "")) + output_message = {"content": str(chat_completion.text), "role": role} + tool_calls_info = self._extract_tool_calls(chat_completion_msg) + if tool_calls_info: + output_message["tool_calls"] = tool_calls_info + output_messages.append(output_message) span.set_tag_str(output_tag_key, safe_json(output_messages)) def _extract_tool_calls(self, chat_completion_msg: Any) -> List[Dict[str, Any]]: @@ -197,30 +196,24 @@ def _extract_tool_calls(self, chat_completion_msg: Any) -> List[Dict[str, Any]]: tool_calls_info.append(tool_call_info) return tool_calls_info - def _llmobs_set_meta_tags_from_chain( - self, - span: Span, - inputs: Union[str, Dict[str, Any], List[Union[str, Dict[str, Any]]]], - outputs: Any, - error: bool = False, - ) -> 
None: + def _llmobs_set_meta_tags_from_chain(self, span: Span, outputs: Any, inputs: Optional[Any] = None) -> None: span.set_tag_str(SPAN_KIND, "workflow") if inputs is not None: formatted_inputs = self.format_io(inputs) span.set_tag_str(INPUT_VALUE, safe_json(formatted_inputs)) - if error: + if span.error or outputs is None: span.set_tag_str(OUTPUT_VALUE, "") - elif outputs is not None: - formatted_outputs = self.format_io(outputs) - span.set_tag_str(OUTPUT_VALUE, safe_json(formatted_outputs)) + return + formatted_outputs = self.format_io(outputs) + span.set_tag_str(OUTPUT_VALUE, safe_json(formatted_outputs)) def _llmobs_set_meta_tags_from_embedding( self, span: Span, - input_texts: Union[str, List[str]], + args: List[Any], + kwargs: Dict[str, Any], output_embedding: Union[List[float], List[List[float]], None], - error: bool = False, is_workflow: bool = False, ) -> None: span.set_tag_str(SPAN_KIND, "workflow" if is_workflow else "embedding") @@ -232,6 +225,10 @@ def _llmobs_set_meta_tags_from_embedding( output_values: Any + try: + input_texts = get_argument_value(args, kwargs, 0, "texts") + except ArgumentError: + input_texts = get_argument_value(args, kwargs, 0, "text") try: if isinstance(input_texts, str) or ( isinstance(input_texts, list) and all(isinstance(text, str) for text in input_texts) @@ -246,42 +243,43 @@ def _llmobs_set_meta_tags_from_embedding( span.set_tag_str(input_tag_key, safe_json(input_documents)) except TypeError: log.warning("Failed to serialize embedding input data to JSON") - if error: + if span.error or output_embedding is None: span.set_tag_str(output_tag_key, "") - elif output_embedding is not None: - try: - if isinstance(output_embedding[0], float): - # single embedding through embed_query - output_values = [output_embedding] - embeddings_count = 1 - else: - # multiple embeddings through embed_documents - output_values = output_embedding - embeddings_count = len(output_embedding) - embedding_dim = len(output_values[0]) - span.set_tag_str( - output_tag_key, - "[{} embedding(s) returned with size {}]".format(embeddings_count, embedding_dim), - ) - except (TypeError, IndexError): - log.warning("Failed to write output vectors", output_embedding) + return + try: + if isinstance(output_embedding[0], float): + # single embedding through embed_query + output_values = [output_embedding] + embeddings_count = 1 + else: + # multiple embeddings through embed_documents + output_values = output_embedding + embeddings_count = len(output_embedding) + embedding_dim = len(output_values[0]) + span.set_tag_str( + output_tag_key, + "[{} embedding(s) returned with size {}]".format(embeddings_count, embedding_dim), + ) + except (TypeError, IndexError): + log.warning("Failed to write output vectors", output_embedding) def _llmobs_set_meta_tags_from_similarity_search( self, span: Span, - input_query: str, + args: List[Any], + kwargs: Dict[str, Any], output_documents: Union[List[Any], None], - error: bool = False, is_workflow: bool = False, ) -> None: span.set_tag_str(SPAN_KIND, "workflow" if is_workflow else "retrieval") span.set_tag_str(MODEL_NAME, span.get_tag(MODEL) or "") span.set_tag_str(MODEL_PROVIDER, span.get_tag(PROVIDER) or "") + input_query = get_argument_value(args, kwargs, 0, "query") if input_query is not None: formatted_inputs = self.format_io(input_query) span.set_tag_str(INPUT_VALUE, safe_json(formatted_inputs)) - if error or output_documents is None: + if span.error or not output_documents or not isinstance(output_documents, list): span.set_tag_str(OUTPUT_VALUE, "") 
return if is_workflow: @@ -298,13 +296,7 @@ def _llmobs_set_meta_tags_from_similarity_search( # we set the value as well to ensure that the UI would display it in case the span was the root span.set_tag_str(OUTPUT_VALUE, "[{} document(s) retrieved]".format(len(documents))) - def _llmobs_set_meta_tags_from_tool( - self, - span: Span, - tool_inputs: Dict[str, Any], - tool_output: object, - error: bool, - ) -> None: + def _llmobs_set_meta_tags_from_tool(self, span: Span, tool_inputs: Dict[str, Any], tool_output: object) -> None: if span.get_tag(METADATA): metadata = json.loads(str(span.get_tag(METADATA))) else: @@ -321,11 +313,11 @@ def _llmobs_set_meta_tags_from_tool( span.set_tag_str(METADATA, safe_json(metadata)) formatted_input = self.format_io(tool_input) span.set_tag_str(INPUT_VALUE, safe_json(formatted_input)) - if error: + if span.error or tool_output is None: span.set_tag_str(OUTPUT_VALUE, "") - elif tool_output is not None: - formatted_outputs = self.format_io(tool_output) - span.set_tag_str(OUTPUT_VALUE, safe_json(formatted_outputs)) + return + formatted_outputs = self.format_io(tool_output) + span.set_tag_str(OUTPUT_VALUE, safe_json(formatted_outputs)) def _set_base_span_tags( # type: ignore[override] self, diff --git a/ddtrace/llmobs/_integrations/openai.py b/ddtrace/llmobs/_integrations/openai.py index d9098bc81da..5c9e73eaca7 100644 --- a/ddtrace/llmobs/_integrations/openai.py +++ b/ddtrace/llmobs/_integrations/openai.py @@ -138,18 +138,15 @@ def record_usage(self, span: Span, usage: Dict[str, Any]) -> None: span.set_metric("openai.response.usage.%s_tokens" % token_type, num_tokens) self.metric(span, "dist", "tokens.%s" % token_type, num_tokens, tags=tags) - def llmobs_set_tags( + def _llmobs_set_tags( self, - operation: str, # oneof "completion", "chat", "embedding" - resp: Any, span: Span, + args: List[Any], kwargs: Dict[str, Any], - streamed_completions: Optional[Any] = None, - err: Optional[Any] = None, + response: Optional[Any] = None, + operation: str = "", # oneof "completion", "chat", "embedding" ) -> None: """Sets meta tags and metrics for span events to be sent to LLMObs.""" - if not self.llmobs_enabled: - return span_kind = "embedding" if operation == "embedding" else "llm" span.set_tag_str(SPAN_KIND, span_kind) model_name = span.get_tag("openai.response.model") or span.get_tag("openai.request.model") @@ -157,18 +154,16 @@ def llmobs_set_tags( model_provider = "azure_openai" if self._is_azure_openai(span) else "openai" span.set_tag_str(MODEL_PROVIDER, model_provider) if operation == "completion": - self._llmobs_set_meta_tags_from_completion(resp, err, kwargs, streamed_completions, span) + self._llmobs_set_meta_tags_from_completion(span, kwargs, response) elif operation == "chat": - self._llmobs_set_meta_tags_from_chat(resp, err, kwargs, streamed_completions, span) + self._llmobs_set_meta_tags_from_chat(span, kwargs, response) elif operation == "embedding": - self._llmobs_set_meta_tags_from_embedding(resp, err, kwargs, span) - metrics = self._set_llmobs_metrics_tags(span, resp, streamed_completions is not None) + self._llmobs_set_meta_tags_from_embedding(span, kwargs, response) + metrics = self._set_llmobs_metrics_tags(span, response) span.set_tag_str(METRICS, safe_json(metrics)) @staticmethod - def _llmobs_set_meta_tags_from_completion( - resp: Any, err: Any, kwargs: Dict[str, Any], streamed_completions: Optional[Any], span: Span - ) -> None: + def _llmobs_set_meta_tags_from_completion(span: Span, kwargs: Dict[str, Any], completions: Any) -> None: """Extract 
prompt/response tags from a completion and set them as temporary "_ml_obs.meta.*" tags.""" prompt = kwargs.get("prompt", "") if isinstance(prompt, str): @@ -178,19 +173,18 @@ def _llmobs_set_meta_tags_from_completion( parameters = {k: v for k, v in kwargs.items() if k not in ("model", "prompt")} span.set_tag_str(METADATA, safe_json(parameters)) - if err is not None: + if span.error or not completions: span.set_tag_str(OUTPUT_MESSAGES, safe_json([{"content": ""}])) return - if streamed_completions: - messages = [{"content": _get_attr(choice, "text", "")} for choice in streamed_completions] - else: - messages = [{"content": _get_attr(choice, "text", "")} for choice in resp.choices] + if hasattr(completions, "choices"): # non-streaming response + choices = completions.choices + else: # streamed response + choices = completions + messages = [{"content": _get_attr(choice, "text", "")} for choice in choices] span.set_tag_str(OUTPUT_MESSAGES, safe_json(messages)) @staticmethod - def _llmobs_set_meta_tags_from_chat( - resp: Any, err: Any, kwargs: Dict[str, Any], streamed_messages: Optional[Any], span: Span - ) -> None: + def _llmobs_set_meta_tags_from_chat(span: Span, kwargs: Dict[str, Any], messages: Optional[Any]) -> None: """Extract prompt/response tags from a chat completion and set them as temporary "_ml_obs.meta.*" tags.""" input_messages = [] for m in kwargs.get("messages", []): @@ -200,12 +194,12 @@ def _llmobs_set_meta_tags_from_chat( parameters = {k: v for k, v in kwargs.items() if k not in ("model", "messages", "tools", "functions")} span.set_tag_str(METADATA, safe_json(parameters)) - if err is not None: + if span.error or not messages: span.set_tag_str(OUTPUT_MESSAGES, safe_json([{"content": ""}])) return - if streamed_messages: - messages = [] - for streamed_message in streamed_messages: + output_messages = [] + if isinstance(messages, list): # streamed response + for streamed_message in messages: message = {"content": streamed_message["content"], "role": streamed_message["role"]} tool_calls = streamed_message.get("tool_calls", []) if tool_calls: @@ -218,41 +212,39 @@ def _llmobs_set_meta_tags_from_chat( } for tool_call in tool_calls ] - messages.append(message) - span.set_tag_str(OUTPUT_MESSAGES, safe_json(messages)) + output_messages.append(message) + span.set_tag_str(OUTPUT_MESSAGES, safe_json(output_messages)) return - output_messages = [] - for idx, choice in enumerate(resp.choices): + choices = _get_attr(messages, "choices", []) + for idx, choice in enumerate(choices): tool_calls_info = [] - content = getattr(choice.message, "content", "") - if getattr(choice.message, "function_call", None): - function_call_info = { - "name": getattr(choice.message.function_call, "name", ""), - "arguments": json.loads(getattr(choice.message.function_call, "arguments", "")), + choice_message = _get_attr(choice, "message", {}) + role = _get_attr(choice_message, "role", "") + content = _get_attr(choice_message, "content", "") or "" + function_call = _get_attr(choice_message, "function_call", None) + if function_call: + function_name = _get_attr(function_call, "name", "") + arguments = json.loads(_get_attr(function_call, "arguments", "")) + function_call_info = {"name": function_name, "arguments": arguments} + output_messages.append({"content": content, "role": role, "tool_calls": [function_call_info]}) + continue + tool_calls = _get_attr(choice_message, "tool_calls", []) or [] + for tool_call in tool_calls: + tool_call_info = { + "name": getattr(tool_call.function, "name", ""), + "arguments": 
json.loads(getattr(tool_call.function, "arguments", "")), + "tool_id": getattr(tool_call, "id", ""), + "type": getattr(tool_call, "type", ""), } - if content is None: - content = "" - output_messages.append( - {"content": content, "role": choice.message.role, "tool_calls": [function_call_info]} - ) - elif getattr(choice.message, "tool_calls", None): - for tool_call in choice.message.tool_calls: - tool_call_info = { - "name": getattr(tool_call.function, "name", ""), - "arguments": json.loads(getattr(tool_call.function, "arguments", "")), - "tool_id": getattr(tool_call, "id", ""), - "type": getattr(tool_call, "type", ""), - } - tool_calls_info.append(tool_call_info) - if content is None: - content = "" - output_messages.append({"content": content, "role": choice.message.role, "tool_calls": tool_calls_info}) - else: - output_messages.append({"content": content, "role": choice.message.role}) + tool_calls_info.append(tool_call_info) + if tool_calls_info: + output_messages.append({"content": content, "role": role, "tool_calls": tool_calls_info}) + continue + output_messages.append({"content": content, "role": role}) span.set_tag_str(OUTPUT_MESSAGES, safe_json(output_messages)) @staticmethod - def _llmobs_set_meta_tags_from_embedding(resp: Any, err: Any, kwargs: Dict[str, Any], span: Span) -> None: + def _llmobs_set_meta_tags_from_embedding(span: Span, kwargs: Dict[str, Any], resp: Any) -> None: """Extract prompt tags from an embedding and set them as temporary "_ml_obs.meta.*" tags.""" encoding_format = kwargs.get("encoding_format") or "float" metadata = {"encoding_format": encoding_format} @@ -268,7 +260,7 @@ def _llmobs_set_meta_tags_from_embedding(resp: Any, err: Any, kwargs: Dict[str, input_documents.append(Document(text=str(doc))) span.set_tag_str(INPUT_DOCUMENTS, safe_json(input_documents)) - if err is not None: + if span.error: return if encoding_format == "float": embedding_dim = len(resp.data[0].embedding) @@ -279,27 +271,23 @@ def _llmobs_set_meta_tags_from_embedding(resp: Any, err: Any, kwargs: Dict[str, span.set_tag_str(OUTPUT_VALUE, "[{} embedding(s) returned]".format(len(resp.data))) @staticmethod - def _set_llmobs_metrics_tags(span: Span, resp: Any, streamed: bool = False) -> Dict[str, Any]: + def _set_llmobs_metrics_tags(span: Span, resp: Any) -> Dict[str, Any]: """Extract metrics from a chat/completion and set them as a temporary "_ml_obs.metrics" tag.""" - metrics = {} - if streamed: - prompt_tokens = span.get_metric("openai.response.usage.prompt_tokens") or 0 - completion_tokens = span.get_metric("openai.response.usage.completion_tokens") or 0 - metrics.update( - { - INPUT_TOKENS_METRIC_KEY: prompt_tokens, - OUTPUT_TOKENS_METRIC_KEY: completion_tokens, - TOTAL_TOKENS_METRIC_KEY: prompt_tokens + completion_tokens, - } - ) - elif resp: - prompt_tokens = getattr(resp.usage, "prompt_tokens", 0) - completion_tokens = getattr(resp.usage, "completion_tokens", 0) - metrics.update( - { - INPUT_TOKENS_METRIC_KEY: prompt_tokens, - OUTPUT_TOKENS_METRIC_KEY: completion_tokens, - TOTAL_TOKENS_METRIC_KEY: prompt_tokens + completion_tokens, - } - ) - return metrics + token_usage = _get_attr(resp, "usage", None) + if token_usage is not None: + prompt_tokens = _get_attr(token_usage, "prompt_tokens", 0) + completion_tokens = _get_attr(token_usage, "completion_tokens", 0) + return { + INPUT_TOKENS_METRIC_KEY: prompt_tokens, + OUTPUT_TOKENS_METRIC_KEY: completion_tokens, + TOTAL_TOKENS_METRIC_KEY: prompt_tokens + completion_tokens, + } + prompt_tokens = 
span.get_metric("openai.response.usage.prompt_tokens") + completion_tokens = span.get_metric("openai.response.usage.completion_tokens") + if prompt_tokens is None or completion_tokens is None: + return {} + return { + INPUT_TOKENS_METRIC_KEY: prompt_tokens, + OUTPUT_TOKENS_METRIC_KEY: completion_tokens, + TOTAL_TOKENS_METRIC_KEY: prompt_tokens + completion_tokens, + } diff --git a/releasenotes/notes/fix-llmobs-integrations-safe-tagging-5e170868e5758510.yaml b/releasenotes/notes/fix-llmobs-integrations-safe-tagging-5e170868e5758510.yaml new file mode 100644 index 00000000000..785ee135e4e --- /dev/null +++ b/releasenotes/notes/fix-llmobs-integrations-safe-tagging-5e170868e5758510.yaml @@ -0,0 +1,5 @@ +--- +fixes: + - | + LLM Observability: The OpenAI, LangChain, Anthropic, Bedrock, and Gemini integrations now will handle and log errors + during LLM Observability span processing to avoid disrupting user applications. diff --git a/tests/llmobs/test_llmobs_integrations.py b/tests/llmobs/test_llmobs_integrations.py index e983b2a5d12..f1af5737213 100644 --- a/tests/llmobs/test_llmobs_integrations.py +++ b/tests/llmobs/test_llmobs_integrations.py @@ -190,3 +190,24 @@ def test_integration_trace(mock_integration_config, mock_pin): assert span[0].resource == "dummy_operation_id" assert span[0].service == "dummy_service" mock_set_base_span_tags.assert_called_once() + + +@mock.patch("ddtrace.llmobs._integrations.base.log") +@mock.patch("ddtrace.llmobs._integrations.base.LLMObs") +def test_llmobs_set_tags(mock_llmobs, mock_log, mock_integration_config): + span = DummyTracer().trace("Dummy span", service="dummy_service") + integration = BaseLLMIntegration(mock_integration_config) + integration._llmobs_set_tags = mock.Mock() + integration.llmobs_set_tags(span, args=[], kwargs={}, response="response", operation="operation") + integration._llmobs_set_tags.assert_called_once_with(span, [], {}, "response", "operation") + + integration._llmobs_set_tags = mock.Mock(side_effect=AttributeError("Mocked Exception during _llmobs_set_tags()")) + integration.llmobs_set_tags( + span, args=[1, 2, 3], kwargs={"a": 123}, response=[{"content": "hello"}], operation="operation" + ) + integration._llmobs_set_tags.assert_called_once_with( + span, [1, 2, 3], {"a": 123}, [{"content": "hello"}], "operation" + ) + mock_log.error.assert_called_once_with( + "Error extracting LLMObs fields for span %s, likely due to malformed data", span, exc_info=True + )