Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(llmobs): submit span events for chains from langchain #8920

Merged
merged 24 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 38 additions & 4 deletions ddtrace/contrib/langchain/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,13 @@ def traced_embedding(langchain, pin, func, instance, args, kwargs):
@with_traced_module
def traced_chain_call(langchain, pin, func, instance, args, kwargs):
integration = langchain._datadog_integration
span = integration.trace(pin, "%s.%s" % (instance.__module__, instance.__class__.__name__), interface_type="chain")
span = integration.trace(
pin,
"{}.{}".format(instance.__module__, instance.__class__.__name__),
submit_to_llmobs=True,
interface_type="chain",
)
inputs = None
final_outputs = {}
try:
if SHOULD_PATCH_LANGCHAIN_COMMUNITY:
Expand All @@ -620,6 +626,8 @@ def traced_chain_call(langchain, pin, func, instance, args, kwargs):
integration.metric(span, "incr", "request.error", 1)
raise
finally:
if integration.is_pc_sampled_llmobs(span):
integration.llmobs_set_tags("chain", span, inputs, final_outputs, error=bool(span.error))
span.finish()
integration.metric(span, "dist", "request.duration", span.duration_ns)
if integration.is_pc_sampled_log(span):
Expand All @@ -645,7 +653,13 @@ def traced_chain_call(langchain, pin, func, instance, args, kwargs):
@with_traced_module
async def traced_chain_acall(langchain, pin, func, instance, args, kwargs):
integration = langchain._datadog_integration
span = integration.trace(pin, "%s.%s" % (instance.__module__, instance.__class__.__name__), interface_type="chain")
span = integration.trace(
pin,
"{}.{}".format(instance.__module__, instance.__class__.__name__),
submit_to_llmobs=True,
interface_type="chain",
)
inputs = None
final_outputs = {}
try:
if SHOULD_PATCH_LANGCHAIN_COMMUNITY:
Expand All @@ -669,6 +683,8 @@ async def traced_chain_acall(langchain, pin, func, instance, args, kwargs):
integration.metric(span, "incr", "request.error", 1)
raise
finally:
if integration.is_pc_sampled_llmobs(span):
integration.llmobs_set_tags("chain", span, inputs, final_outputs, error=bool(span.error))
span.finish()
integration.metric(span, "dist", "request.duration", span.duration_ns)
if integration.is_pc_sampled_log(span):
Expand Down Expand Up @@ -706,7 +722,14 @@ def traced_lcel_runnable_sequence(langchain, pin, func, instance, args, kwargs):
This method captures the initial inputs to the chain, as well as the final outputs, and tags them appropriately.
"""
integration = langchain._datadog_integration
span = integration.trace(pin, "%s.%s" % (instance.__module__, instance.__class__.__name__), interface_type="chain")
span = integration.trace(
pin,
"{}.{}".format(instance.__module__, instance.__class__.__name__),
submit_to_llmobs=True,
interface_type="chain",
)
inputs = None
final_output = None
try:
inputs = get_argument_value(args, kwargs, 0, "input")
if integration.is_pc_sampled_span(span):
Expand All @@ -730,6 +753,8 @@ def traced_lcel_runnable_sequence(langchain, pin, func, instance, args, kwargs):
integration.metric(span, "incr", "request.error", 1)
raise
finally:
if integration.is_pc_sampled_llmobs(span):
integration.llmobs_set_tags("chain", span, inputs, final_output, error=bool(span.error))
span.finish()
integration.metric(span, "dist", "request.duration", span.duration_ns)
return final_output
Expand All @@ -741,7 +766,14 @@ async def traced_lcel_runnable_sequence_async(langchain, pin, func, instance, ar
Similar to `traced_lcel_runnable_sequence`, but for async chaining calls.
"""
integration = langchain._datadog_integration
span = integration.trace(pin, "%s.%s" % (instance.__module__, instance.__class__.__name__), interface_type="chain")
span = integration.trace(
pin,
"{}.{}".format(instance.__module__, instance.__class__.__name__),
submit_to_llmobs=True,
interface_type="chain",
)
inputs = None
final_output = None
try:
inputs = get_argument_value(args, kwargs, 0, "input")
if integration.is_pc_sampled_span(span):
Expand All @@ -765,6 +797,8 @@ async def traced_lcel_runnable_sequence_async(langchain, pin, func, instance, ar
integration.metric(span, "incr", "request.error", 1)
raise
finally:
if integration.is_pc_sampled_llmobs(span):
integration.llmobs_set_tags("chain", span, inputs, final_output, error=bool(span.error))
span.finish()
integration.metric(span, "dist", "request.duration", span.duration_ns)
return final_output
Expand Down
33 changes: 28 additions & 5 deletions ddtrace/llmobs/_integrations/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,19 @@
from typing import Dict
from typing import List
from typing import Optional
from typing import Union

from ddtrace import config
from ddtrace._trace.span import Span
from ddtrace.constants import ERROR_TYPE
from ddtrace.llmobs._constants import INPUT_MESSAGES
from ddtrace.llmobs._constants import INPUT_PARAMETERS
from ddtrace.llmobs._constants import INPUT_VALUE
from ddtrace.llmobs._constants import METRICS
from ddtrace.llmobs._constants import MODEL_NAME
from ddtrace.llmobs._constants import MODEL_PROVIDER
from ddtrace.llmobs._constants import OUTPUT_MESSAGES
from ddtrace.llmobs._constants import OUTPUT_VALUE
from ddtrace.llmobs._constants import SPAN_KIND

from .base import BaseLLMIntegration
Expand All @@ -39,14 +42,13 @@ def llmobs_set_tags(
operation: str, # oneof "llm","chat","chain"
span: Span,
inputs: Any,
response: Any,
response: Any = None,
error: bool = False,
) -> None:
"""Sets meta tags and metrics for span events to be sent to LLMObs."""
if not self.llmobs_enabled:
return
model_provider = span.get_tag(PROVIDER)
span.set_tag_str(SPAN_KIND, "llm")
span.set_tag_str(MODEL_NAME, span.get_tag(MODEL) or "")
span.set_tag_str(MODEL_PROVIDER, model_provider or "")

Expand All @@ -57,7 +59,7 @@ def llmobs_set_tags(
elif operation == "chat":
self._llmobs_set_meta_tags_from_chat_model(span, inputs, response, error)
elif operation == "chain":
pass
self._llmobs_set_meta_tags_from_chain(span, inputs, response, error)

span.set_tag_str(METRICS, json.dumps({}))

Expand All @@ -79,9 +81,9 @@ def _llmobs_set_input_parameters(
or span.get_tag(f"langchain.request.{model_provider}.parameters.model_kwargs.max_tokens") # huggingface
)

if temperature:
if temperature is not None:
input_parameters["temperature"] = float(temperature)
if max_tokens:
if max_tokens is not None:
input_parameters["max_tokens"] = int(max_tokens)
if input_parameters:
span.set_tag_str(INPUT_PARAMETERS, json.dumps(input_parameters))
Expand All @@ -93,6 +95,8 @@ def _llmobs_set_meta_tags_from_llm(
completions: Any,
err: bool = False,
) -> None:
span.set_tag_str(SPAN_KIND, "llm")

if isinstance(prompts, str):
prompts = [prompts]
span.set_tag_str(INPUT_MESSAGES, json.dumps([{"content": str(prompt)} for prompt in prompts]))
Expand All @@ -109,6 +113,8 @@ def _llmobs_set_meta_tags_from_chat_model(
chat_completions: Any,
err: bool = False,
) -> None:
span.set_tag_str(SPAN_KIND, "llm")

input_messages = []
for message_set in chat_messages:
for message in message_set:
Expand Down Expand Up @@ -136,6 +142,23 @@ def _llmobs_set_meta_tags_from_chat_model(
)
span.set_tag_str(OUTPUT_MESSAGES, json.dumps(output_messages))

def _llmobs_set_meta_tags_from_chain(
    self,
    span: Span,
    inputs: Union[str, Dict[str, Any], List[Union[str, Dict[str, Any]]]],
    outputs: Any,
    error: bool = False,
) -> None:
    """Tag a chain span as an LLMObs workflow, recording its input and output values.

    The raw chain inputs/outputs are stringified as-is; a failed chain is
    tagged with an empty output value so the span still carries the field.
    """
    # Chains are represented as "workflow" spans in LLMObs.
    span.set_tag_str(SPAN_KIND, "workflow")

    if inputs is not None:
        span.set_tag_str(INPUT_VALUE, str(inputs))

    # On error, record an explicitly empty output; otherwise stringify any result.
    output_value = "" if error else (str(outputs) if outputs is not None else None)
    if output_value is not None:
        span.set_tag_str(OUTPUT_VALUE, output_value)

def _set_base_span_tags( # type: ignore[override]
self,
span: Span,
Expand Down
Loading
Loading