Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(llmobs): submit langchain tool.invoke tool spans #10410

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 115 additions & 0 deletions ddtrace/contrib/internal/langchain/patch.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import os
import sys
from typing import Any
Expand Down Expand Up @@ -974,6 +975,118 @@ def traced_similarity_search(langchain, pin, func, instance, args, kwargs):
return documents


@with_traced_module
def traced_base_tool_invoke(langchain, pin, func, instance, args, kwargs):
    """Trace a synchronous ``BaseTool.invoke`` call.

    Opens a tool-interface span, tags the tool's static attributes
    (name/description/metadata/tags) and — when prompt/completion sampling
    selects this span — the raw input and config, then forwards the call to
    the wrapped ``invoke``. LLM Observability tags are submitted in the
    ``finally`` block so they are emitted on both success and error paths.
    """
    integration = langchain._datadog_integration
    tool_input = get_argument_value(args, kwargs, 0, "input")
    config = get_argument_value(args, kwargs, 1, "config", optional=True)

    # Resource name mirrors the wrapped bound method plus the tool's own name,
    # e.g. "langchain_core.tools.method.invoke.<tool name>".
    span_name = "{}.{}.{}.{}".format(
        func.__module__, func.__class__.__name__, func.__name__, func.__self__.name
    )
    span = integration.trace(
        pin,
        span_name,
        interface_type="tool",
        submit_to_llmobs=True,
    )

    tool_output = None
    try:
        # Static tool attributes: only tag truthy values.
        for attribute in ("name", "description"):
            value = getattr(instance, attribute, None)
            if value:
                span.set_tag_str("langchain.request.tool.%s" % attribute, str(value))

        metadata = getattr(instance, "metadata", None)
        if metadata:
            for key, value in metadata.items():
                span.set_tag_str("langchain.request.tool.metadata.%s" % key, str(value))
        instance_tags = getattr(instance, "tags", None)
        if instance_tags:
            for idx, tag in enumerate(instance_tags):
                span.set_tag_str("langchain.request.tool.tags.%d" % idx, str(tag))

        # Inputs/config are only tagged on prompt/completion-sampled spans.
        if integration.is_pc_sampled_span(span):
            if tool_input:
                span.set_tag_str("langchain.request.input", integration.trunc(str(tool_input)))
            if config:
                # config is a RunnableConfig (TypedDict) — JSON-dump it as-is.
                span.set_tag_str("langchain.request.config", json.dumps(config))
        tool_output = func(*args, **kwargs)
        if tool_output is not None and integration.is_pc_sampled_span(span):
            span.set_tag_str("langchain.response.output", integration.trunc(str(tool_output)))
    except Exception:
        span.set_exc_info(*sys.exc_info())
        raise
    finally:
        # Always submit LLMObs tags (even on error) before finishing the span.
        if integration.is_pc_sampled_llmobs(span):
            integration.llmobs_set_tags(
                "tool",
                span,
                tool_input,
                tool_output,
                error=bool(span.error),
            )
        span.finish()
    return tool_output


@with_traced_module
async def traced_base_tool_ainvoke(langchain, pin, func, instance, args, kwargs):
    """Trace an asynchronous ``BaseTool.ainvoke`` call.

    Async twin of ``traced_base_tool_invoke``: same tagging behavior, but the
    wrapped call is awaited. LLM Observability tags are submitted in the
    ``finally`` block so they are emitted on both success and error paths.
    """
    integration = langchain._datadog_integration
    tool_input = get_argument_value(args, kwargs, 0, "input")
    tool_config = get_argument_value(args, kwargs, 1, "config", optional=True)

    span = integration.trace(
        pin,
        "%s" % func.__self__.name,
        interface_type="tool",
        submit_to_llmobs=True,
    )

    tool_output = None
    try:
        # Static tool attributes: only tag truthy values.
        for attribute in ("name", "description"):
            value = getattr(instance, attribute, None)
            if value:
                span.set_tag_str("langchain.request.tool.%s" % attribute, str(value))

        metadata = getattr(instance, "metadata", None)
        if metadata:
            for key, value in metadata.items():
                span.set_tag_str("langchain.request.tool.metadata.%s" % key, str(value))
        instance_tags = getattr(instance, "tags", None)
        if instance_tags:
            for idx, tag in enumerate(instance_tags):
                span.set_tag_str("langchain.request.tool.tags.%d" % idx, str(tag))

        # Inputs/config are only tagged on prompt/completion-sampled spans.
        if integration.is_pc_sampled_span(span):
            if tool_input:
                span.set_tag_str("langchain.request.input", integration.trunc(str(tool_input)))
            if tool_config:
                # tool_config is a RunnableConfig (TypedDict) — JSON-dump it as-is.
                span.set_tag_str("langchain.request.config", json.dumps(tool_config))
        tool_output = await func(*args, **kwargs)
        if tool_output is not None and integration.is_pc_sampled_span(span):
            span.set_tag_str("langchain.response.output", integration.trunc(str(tool_output)))
    except Exception:
        span.set_exc_info(*sys.exc_info())
        raise
    finally:
        # Always submit LLMObs tags (even on error) before finishing the span.
        if integration.is_pc_sampled_llmobs(span):
            integration.llmobs_set_tags(
                "tool",
                span,
                tool_input,
                tool_output,
                error=bool(span.error),
            )
        span.finish()
    return tool_output


def _patch_embeddings_and_vectorstores():
"""
Text embedding models override two abstract base methods instead of super calls,
Expand Down Expand Up @@ -1101,6 +1214,8 @@ def patch():
)
wrap("langchain_core", "runnables.base.RunnableSequence.batch", traced_lcel_runnable_sequence(langchain))
wrap("langchain_core", "runnables.base.RunnableSequence.abatch", traced_lcel_runnable_sequence_async(langchain))
wrap("langchain_core", "tools.BaseTool.invoke", traced_base_tool_invoke(langchain))
wrap("langchain_core", "tools.BaseTool.ainvoke", traced_base_tool_ainvoke(langchain))
if langchain_openai:
wrap("langchain_openai", "OpenAIEmbeddings.embed_documents", traced_embedding(langchain))
if langchain_pinecone:
Expand Down
35 changes: 33 additions & 2 deletions ddtrace/llmobs/_integrations/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,15 @@
"system": "system",
}

SUPPORTED_OPERATIONS = ["llm", "chat", "chain", "embedding", "retrieval"]
SUPPORTED_OPERATIONS = ["llm", "chat", "chain", "embedding", "retrieval", "tool"]


class LangChainIntegration(BaseLLMIntegration):
_integration_name = "langchain"

def llmobs_set_tags(
self,
operation: str, # oneof "llm","chat","chain","embedding","retrieval"
operation: str, # oneof "llm","chat","chain","embedding","retrieval","tool"
span: Span,
inputs: Any,
response: Any = None,
Expand Down Expand Up @@ -92,6 +92,8 @@ def llmobs_set_tags(
self._llmobs_set_meta_tags_from_embedding(span, inputs, response, error, is_workflow=is_workflow)
elif operation == "retrieval":
self._llmobs_set_meta_tags_from_similarity_search(span, inputs, response, error, is_workflow=is_workflow)
elif operation == "tool":
self._llmobs_set_meta_tags_from_tool(span, inputs, response, error)
span.set_tag_str(METRICS, json.dumps({}))

def _llmobs_set_metadata(self, span: Span, model_provider: Optional[str] = None) -> None:
Expand Down Expand Up @@ -304,6 +306,35 @@ def _llmobs_set_meta_tags_from_similarity_search(
except TypeError:
log.warning("Failed to serialize similarity output documents to JSON")

def _llmobs_set_meta_tags_from_tool(
    self,
    span: Span,
    tool_input: Union[str, Dict[str, object], object],
    tool_output: object,
    error: bool,
) -> None:
    """Set LLMObs span-kind and input/output value tags for a tool span.

    String results of ``format_io`` are tagged verbatim; non-string results
    are JSON-serialized. On error the output value is tagged as an empty
    string. Serialization failures are logged and swallowed so tagging never
    breaks the traced call.
    """
    span.set_tag_str(SPAN_KIND, "tool")
    if tool_input is not None:
        try:
            formatted_inputs = self.format_io(tool_input)
            if isinstance(formatted_inputs, str):
                span.set_tag_str(INPUT_VALUE, formatted_inputs)
            else:
                # Reuse the already-formatted value instead of calling
                # format_io a second time (the original re-formatted here).
                span.set_tag_str(INPUT_VALUE, json.dumps(formatted_inputs))
        except TypeError:
            log.warning("Failed to serialize tool input data to JSON")
    if error:
        span.set_tag_str(OUTPUT_VALUE, "")
    elif tool_output is not None:
        try:
            formatted_outputs = self.format_io(tool_output)
            if isinstance(formatted_outputs, str):
                span.set_tag_str(OUTPUT_VALUE, formatted_outputs)
            else:
                # Same single-format reuse as for the input above.
                span.set_tag_str(OUTPUT_VALUE, json.dumps(formatted_outputs))
        except TypeError:
            log.warning("Failed to serialize tool output data to JSON")

def _set_base_span_tags( # type: ignore[override]
self,
span: Span,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
features:
- |
LLM Observability: The LangChain integration now submits tool spans to LLM Observability.
62 changes: 62 additions & 0 deletions tests/contrib/langchain/test_langchain_community.py
Original file line number Diff line number Diff line change
Expand Up @@ -1274,3 +1274,65 @@ def test_faiss_vectorstore_retrieval(langchain_community, langchain_openai, requ
retriever = faiss.as_retriever()
with request_vcr.use_cassette("openai_retrieval_embedding.yaml"):
retriever.invoke("What was the message of the last test query?")


@pytest.mark.snapshot(
    ignores=["meta.langchain.request.tool.description"],
    token="tests.contrib.langchain.test_langchain_community.test_base_tool_invoke",
)
def test_base_tool_invoke(langchain_core, request_vcr):
    """Invoking a langchain tool produces a 1-span trace containing a tool span."""
    if langchain_core is None:
        pytest.skip("langchain-core not installed which is required for this test.")

    from math import pi

    from langchain_core.tools import StructuredTool

    def compute_circumference(radius: float) -> float:
        return float(radius) * 2.0 * pi

    # Tool name/description are tagged on the span; description is ignored
    # by the snapshot because it varies across langchain-core versions.
    tool = StructuredTool.from_function(
        func=compute_circumference,
        name="Circumference calculator",
        description="Use this tool when you need to calculate a circumference using the radius of a circle",
        return_direct=True,
        response_format="content",
    )

    tool.invoke("2")


@pytest.mark.asyncio
@pytest.mark.snapshot(
    ignores=["meta.langchain.request.tool.description"],
    token="tests.contrib.langchain.test_langchain_community.test_base_tool_ainvoke",
)
async def test_base_tool_ainvoke(langchain_core, request_vcr):
    """Async-invoking a langchain tool produces a 1-span trace containing a tool span."""
    if langchain_core is None:
        pytest.skip("langchain-core not installed which is required for this test.")

    from math import pi

    from langchain_core.tools import StructuredTool

    def compute_circumference(radius: float) -> float:
        return float(radius) * 2.0 * pi

    # Tool name/description are tagged on the span; description is ignored
    # by the snapshot because it varies across langchain-core versions.
    tool = StructuredTool.from_function(
        func=compute_circumference,
        name="Circumference calculator",
        description="Use this tool when you need to calculate a circumference using the radius of a circle",
        return_direct=True,
        response_format="content",
    )

    await tool.ainvoke("2")
52 changes: 52 additions & 0 deletions tests/contrib/langchain/test_langchain_llmobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,18 @@ def _similarity_search(cls, pinecone, pinecone_vector_store, embedding_model, qu
LLMObs.disable()
return mock_tracer.pop_traces()[0]

@classmethod
def _invoke_tool(cls, tool, tool_input, mock_tracer, cassette_name):
    """Invoke ``tool`` with LLMObs enabled and return the first emitted span.

    On langchain >= 0.1 the tool is invoked (inside a VCR cassette when
    ``cassette_name`` is given); the invocation is skipped entirely on
    older versions.
    """
    LLMObs.enable(ml_app=cls.ml_app, integrations_enabled=False, _tracer=mock_tracer)
    if LANGCHAIN_VERSION > (0, 1):
        if cassette_name is None:
            tool.invoke(tool_input)
        else:
            vcr = get_request_vcr(subdirectory_name=cls.cassette_subdirectory_name)
            with vcr.use_cassette(cassette_name):
                tool.invoke(tool_input)
    LLMObs.disable()
    return mock_tracer.pop_traces()[0][0]


@pytest.mark.skipif(LANGCHAIN_VERSION >= (0, 1), reason="These tests are for langchain < 0.1.0")
class TestLLMObsLangchain(BaseTestLLMObsLangchain):
Expand Down Expand Up @@ -707,6 +719,46 @@ def test_llmobs_similarity_search(self, langchain_openai, langchain_pinecone, mo
)
mock_llmobs_span_writer.enqueue.assert_any_call(expected_span)

def test_llmobs_base_tool_invoke(self, langchain_core, mock_llmobs_span_writer, mock_tracer):
    """Invoking a tool submits exactly one LLMObs 'tool' span event with the tool's I/O."""
    # Fixed: the skip guard was duplicated verbatim in the original.
    if langchain_core is None:
        pytest.skip("langchain-core not installed which is required for this test.")

    from math import pi

    from langchain_core.tools import StructuredTool

    def circumference_tool(radius: float) -> float:
        return float(radius) * 2.0 * pi

    calculator = StructuredTool.from_function(
        func=circumference_tool,
        name="Circumference calculator",
        description="Use this tool when you need to calculate a circumference using the radius of a circle",
        return_direct=True,
        response_format="content",
    )

    span = self._invoke_tool(
        tool=calculator,
        tool_input="2",
        mock_tracer=mock_tracer,
        cassette_name=None,  # local tool — no HTTP traffic to record
    )
    assert mock_llmobs_span_writer.enqueue.call_count == 1
    mock_llmobs_span_writer.enqueue.assert_called_with(
        _expected_llmobs_non_llm_span_event(
            span,
            span_kind="tool",
            input_value="2",
            output_value="12.566370614359172",  # 2 * 2 * pi
            tags={"ml_app": "langchain_test"},
            integration="langchain",
        )
    )


@pytest.mark.skipif(LANGCHAIN_VERSION < (0, 1), reason="These tests are for langchain >= 0.1.0")
class TestTraceStructureWithLLMIntegrations(SubprocessTestCase):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
[[
{
"name": "langchain.request",
"service": "",
"resource": "langchain_core.tools.method.invoke.Circumference calculator",
"trace_id": 0,
"span_id": 1,
"parent_id": 0,
"type": "",
"error": 0,
"meta": {
"_dd.p.dm": "-0",
"_dd.p.tid": "66cf6d8d00000000",
"langchain.request.input": "2",
"langchain.request.tool.description": "Circumference calculator(radius: float) -> float - Use this tool when you need to calculate a circumference using the radius of a circle",
"langchain.request.tool.name": "Circumference calculator",
"langchain.request.type": "tool",
"langchain.response.output": "12.566370614359172",
"language": "python",
"runtime-id": "fdfa007d5e604a6a880c73034778aa7f"
},
"metrics": {
"_dd.measured": 1,
"_dd.top_level": 1,
"_dd.tracer_kr": 1.0,
"_sampling_priority_v1": 1,
"process_id": 73484
},
"duration": 382000,
"start": 1724870029728802000
},
{
"name": "langchain.request",
"service": "",
"resource": "langchain_core.tools.method.invoke.Circumference calculator",
"trace_id": 0,
"span_id": 2,
"parent_id": 1,
"type": "",
"error": 0,
"meta": {
"langchain.request.input": "2",
"langchain.request.tool.description": "Circumference calculator(radius: float) -> float - Use this tool when you need to calculate a circumference using the radius of a circle",
"langchain.request.tool.name": "Circumference calculator",
"langchain.request.type": "tool",
"langchain.response.output": "12.566370614359172"
},
"metrics": {
"_dd.measured": 1
},
"duration": 247000,
"start": 1724870029728902000
}]]
Loading
Loading