Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(llmobs): modify integration patch behavior in llmobs.enable() [backport 2.9] #9434

Merged
merged 2 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 9 additions & 34 deletions ddtrace/llmobs/_llmobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import os
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Union

Expand Down Expand Up @@ -50,11 +49,7 @@
log = get_logger(__name__)


SUPPORTED_INTEGRATIONS = {
"bedrock": lambda: patch(botocore=True),
"langchain": lambda: patch(langchain=True),
"openai": lambda: patch(openai=True),
}
SUPPORTED_LLMOBS_INTEGRATIONS = {"bedrock": "botocore", "openai": "openai", "langchain": "langchain"}


class LLMObs(Service):
Expand Down Expand Up @@ -105,7 +100,7 @@ def _stop_service(self) -> None:
def enable(
cls,
ml_app: Optional[str] = None,
integrations: Optional[List[str]] = None,
integrations_enabled: bool = True,
agentless_enabled: bool = False,
site: Optional[str] = None,
api_key: Optional[str] = None,
Expand All @@ -117,8 +112,7 @@ def enable(
Enable LLM Observability tracing.

:param str ml_app: The name of your ml application.
:param List[str] integrations: A list of integrations to enable auto-tracing for.
Must be subset of ("openai", "langchain", "bedrock")
:param bool integrations_enabled: Set to `true` to enable LLM integrations.
:param bool agentless_enabled: Set to `true` to disable sending data that requires a Datadog Agent.
:param str site: Your datadog site.
:param str api_key: Your datadog api key.
Expand Down Expand Up @@ -170,7 +164,8 @@ def enable(
log.debug("Remote configuration disabled because DD_LLMOBS_AGENTLESS_ENABLED is set to true.")
remoteconfig_poller.disable()

cls._patch_integrations(integrations)
if integrations_enabled:
cls._patch_integrations()
# override the default _instance with a new tracer
cls._instance = cls(tracer=_tracer)
cls.enabled = True
Expand Down Expand Up @@ -207,30 +202,10 @@ def flush(cls):
log.warning("Failed to flush LLMObs spans and evaluation metrics.", exc_info=True)

@staticmethod
def _patch_integrations(integrations: Optional[List[str]] = None):
"""
Patch LLM integrations based on a list of integrations passed in. Patch all supported integrations by default.
"""
integrations_to_patch = {}
if integrations is None:
integrations_to_patch.update(SUPPORTED_INTEGRATIONS)
else:
for integration in integrations:
integration = integration.lower()
if integration in SUPPORTED_INTEGRATIONS:
integrations_to_patch.update({integration: SUPPORTED_INTEGRATIONS[integration]})
else:
log.warning(
"%s is unsupported - LLMObs currently supports %s",
integration,
str(SUPPORTED_INTEGRATIONS.keys()),
)
for integration in integrations_to_patch:
try:
SUPPORTED_INTEGRATIONS[integration]()
except Exception:
log.warning("couldn't patch %s", integration, exc_info=True)
return
def _patch_integrations() -> None:
"""Patch LLM integrations."""
patch(**{integration: True for integration in SUPPORTED_LLMOBS_INTEGRATIONS.values()}) # type: ignore[arg-type]
log.debug("Patched LLM integrations: %s", list(SUPPORTED_LLMOBS_INTEGRATIONS.values()))

@classmethod
def export_span(cls, span: Optional[Span] = None) -> Optional[ExportedLLMObsSpan]:
Expand Down
6 changes: 3 additions & 3 deletions tests/contrib/botocore/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ def _test_llmobs_invoke(cls, provider, bedrock_client, mock_llmobs_span_writer,
pin.override(bedrock_client, tracer=mock_tracer)
# Need to disable and re-enable LLMObs service to use the mock tracer
LLMObs.disable()
LLMObs.enable(_tracer=mock_tracer, integrations=["bedrock"])
LLMObs.enable(_tracer=mock_tracer, integrations_enabled=False) # only want botocore patched

if cassette_name is None:
cassette_name = "%s_invoke.yaml" % provider
Expand Down Expand Up @@ -524,7 +524,7 @@ def _test_llmobs_invoke_stream(
pin.override(bedrock_client, tracer=mock_tracer)
# Need to disable and re-enable LLMObs service to use the mock tracer
LLMObs.disable()
LLMObs.enable(_tracer=mock_tracer, integrations=["bedrock"])
LLMObs.enable(_tracer=mock_tracer, integrations_enabled=False) # only want botocore patched

if cassette_name is None:
cassette_name = "%s_invoke_stream.yaml" % provider
Expand Down Expand Up @@ -624,7 +624,7 @@ def test_llmobs_error(self, ddtrace_global_config, bedrock_client, mock_llmobs_s
pin.override(bedrock_client, tracer=mock_tracer)
# Need to disable and re-enable LLMObs service to use the mock tracer
LLMObs.disable()
LLMObs.enable(_tracer=mock_tracer, integrations=["bedrock"])
LLMObs.enable(_tracer=mock_tracer, integrations_enabled=False) # only want botocore patched
with pytest.raises(botocore.exceptions.ClientError):
with request_vcr.use_cassette("meta_invoke_error.yaml"):
body, model = json.dumps(_REQUEST_BODIES["meta"]), _MODELS["meta"]
Expand Down
4 changes: 2 additions & 2 deletions tests/contrib/langchain/test_langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -1352,7 +1352,7 @@ def _test_llmobs_llm_invoke(
different_py39_cassette=False,
):
LLMObs.disable()
LLMObs.enable(_tracer=mock_tracer, integrations=["langchain"])
LLMObs.enable(_tracer=mock_tracer, integrations_enabled=False) # only want langchain patched

if sys.version_info < (3, 10, 0) and different_py39_cassette:
cassette_name = cassette_name.replace(".yaml", "_39.yaml")
Expand Down Expand Up @@ -1388,7 +1388,7 @@ def _test_llmobs_chain_invoke(
):
# disable the service before re-enabling it, as it was enabled in another test
LLMObs.disable()
LLMObs.enable(_tracer=mock_tracer, integrations=["langchain"])
LLMObs.enable(_tracer=mock_tracer, integrations_enabled=False) # only want langchain patched

if sys.version_info < (3, 10, 0) and different_py39_cassette:
cassette_name = cassette_name.replace(".yaml", "_39.yaml")
Expand Down
4 changes: 2 additions & 2 deletions tests/contrib/langchain/test_langchain_community.py
Original file line number Diff line number Diff line change
Expand Up @@ -1339,7 +1339,7 @@ def _test_llmobs_llm_invoke(
output_role=None,
):
LLMObs.disable()
LLMObs.enable(_tracer=mock_tracer, integrations=["langchain"])
LLMObs.enable(_tracer=mock_tracer, integrations_enabled=False) # only want langchain patched

with request_vcr.use_cassette(cassette_name):
generate_trace("Can you explain what an LLM chain is?")
Expand Down Expand Up @@ -1372,7 +1372,7 @@ def _test_llmobs_chain_invoke(
):
# disable the service before re-enabling it, as it was enabled in another test
LLMObs.disable()
LLMObs.enable(_tracer=mock_tracer, integrations=["langchain"])
LLMObs.enable(_tracer=mock_tracer, integrations_enabled=False) # only want langchain patched

with request_vcr.use_cassette(cassette_name):
generate_trace("Can you explain what an LLM chain is?")
Expand Down
2 changes: 1 addition & 1 deletion tests/contrib/openai/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def mock_tracer(ddtrace_global_config, openai, patch_openai, mock_logs, mock_met
if ddtrace_global_config.get("_llmobs_enabled", False):
# Have to disable and re-enable LLMObs to use to mock tracer.
LLMObs.disable()
LLMObs.enable(_tracer=mock_tracer, integrations=["openai"])
LLMObs.enable(_tracer=mock_tracer, integrations_enabled=False)

yield mock_tracer

Expand Down
3 changes: 0 additions & 3 deletions tests/contrib/openai/test_openai_v0.py
Original file line number Diff line number Diff line change
Expand Up @@ -1935,7 +1935,6 @@ def test_integration_service_name(openai_api_key, ddtrace_run_python_code_in_sub
)
def test_llmobs_completion(openai_vcr, openai, ddtrace_global_config, mock_llmobs_writer, mock_tracer):
"""Ensure llmobs records are emitted for completion endpoints when configured.

Also ensure the llmobs records have the correct tagging including trace/span ID for trace correlation.
"""
with openai_vcr.use_cassette("completion.yaml"):
Expand Down Expand Up @@ -1990,7 +1989,6 @@ def test_llmobs_completion_stream(openai_vcr, openai, ddtrace_global_config, moc
)
def test_llmobs_chat_completion(openai_vcr, openai, ddtrace_global_config, mock_llmobs_writer, mock_tracer):
"""Ensure llmobs records are emitted for chat completion endpoints when configured.

Also ensure the llmobs records have the correct tagging including trace/span ID for trace correlation.
"""
if not hasattr(openai, "ChatCompletion"):
Expand Down Expand Up @@ -2033,7 +2031,6 @@ async def test_llmobs_chat_completion_stream(
openai_vcr, openai, ddtrace_global_config, mock_llmobs_writer, mock_tracer
):
"""Ensure llmobs records are emitted for chat completion endpoints when configured.

Also ensure the llmobs records have the correct tagging including trace/span ID for trace correlation.
"""
if not hasattr(openai, "ChatCompletion"):
Expand Down
Loading