Skip to content

Commit

Permalink
Change integrations --> integrations_enabled
Browse files Browse the repository at this point in the history
  • Loading branch information
Yun-Kim committed May 27, 2024
1 parent be9936b commit dee98a9
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 46 deletions.
44 changes: 9 additions & 35 deletions ddtrace/llmobs/_llmobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import os
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Union

Expand Down Expand Up @@ -50,12 +49,7 @@
log = get_logger(__name__)


SUPPORTED_INTEGRATIONS = {
"bedrock": lambda: patch(botocore=True),
"langchain": lambda: patch(langchain=True),
"openai": lambda: patch(openai=True),
}

# Maps each LLMObs-supported integration name to the ddtrace module keyword
# passed to ``patch()`` (e.g. "bedrock" is traced by patching botocore).
SUPPORTED_LLMOBS_INTEGRATIONS = {"bedrock": "botocore", "openai": "openai", "langchain": "langchain"}

class LLMObs(Service):
_instance = None # type: LLMObs
Expand Down Expand Up @@ -105,7 +99,7 @@ def _stop_service(self) -> None:
def enable(
cls,
ml_app: Optional[str] = None,
integrations: Optional[List[str]] = None,
integrations_enabled: bool = True,
agentless_enabled: bool = False,
site: Optional[str] = None,
api_key: Optional[str] = None,
Expand All @@ -117,8 +111,7 @@ def enable(
Enable LLM Observability tracing.
:param str ml_app: The name of your ml application.
:param List[str] integrations: A list of integrations to enable auto-tracing for.
Must be subset of ("openai", "langchain", "bedrock")
:param bool integrations_enabled: Set to ``True`` to enable automatic tracing of supported LLM integrations.
:param bool agentless_enabled: Set to ``True`` to disable sending data that requires a Datadog Agent.
:param str site: Your datadog site.
:param str api_key: Your datadog api key.
Expand Down Expand Up @@ -170,7 +163,8 @@ def enable(
log.debug("Remote configuration disabled because DD_LLMOBS_AGENTLESS_ENABLED is set to true.")
remoteconfig_poller.disable()

cls._patch_integrations(integrations)
if integrations_enabled:
cls._patch_integrations()
# override the default _instance with a new tracer
cls._instance = cls(tracer=_tracer)
cls.enabled = True
Expand Down Expand Up @@ -207,30 +201,10 @@ def flush(cls):
log.warning("Failed to flush LLMObs spans and evaluation metrics.", exc_info=True)

@staticmethod
def _patch_integrations(integrations: Optional[List[str]] = None):
"""
Patch LLM integrations based on a list of integrations passed in. Patch all supported integrations by default.
"""
integrations_to_patch = {}
if integrations is None:
integrations_to_patch.update(SUPPORTED_INTEGRATIONS)
else:
for integration in integrations:
integration = integration.lower()
if integration in SUPPORTED_INTEGRATIONS:
integrations_to_patch.update({integration: SUPPORTED_INTEGRATIONS[integration]})
else:
log.warning(
"%s is unsupported - LLMObs currently supports %s",
integration,
str(SUPPORTED_INTEGRATIONS.keys()),
)
for integration in integrations_to_patch:
try:
SUPPORTED_INTEGRATIONS[integration]()
except Exception:
log.warning("couldn't patch %s", integration, exc_info=True)
return
def _patch_integrations() -> None:
    """Enable auto-tracing for every LLM integration supported by LLMObs."""
    # Build per-module enable flags (e.g. {"botocore": True, ...}) and apply
    # them in a single patch() call.
    modules = SUPPORTED_LLMOBS_INTEGRATIONS.values()
    enable_flags = {module: True for module in modules}
    patch(**enable_flags)  # type: ignore[arg-type]
    log.debug("Patched LLM integrations: %s", list(modules))

@classmethod
def export_span(cls, span: Optional[Span] = None) -> Optional[ExportedLLMObsSpan]:
Expand Down
6 changes: 3 additions & 3 deletions tests/contrib/botocore/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ def _test_llmobs_invoke(cls, provider, bedrock_client, mock_llmobs_span_writer,
pin.override(bedrock_client, tracer=mock_tracer)
# Need to disable and re-enable LLMObs service to use the mock tracer
LLMObs.disable()
LLMObs.enable(_tracer=mock_tracer, integrations=["bedrock"])
LLMObs.enable(_tracer=mock_tracer, integrations_enabled=False) # only want botocore patched

if cassette_name is None:
cassette_name = "%s_invoke.yaml" % provider
Expand Down Expand Up @@ -524,7 +524,7 @@ def _test_llmobs_invoke_stream(
pin.override(bedrock_client, tracer=mock_tracer)
# Need to disable and re-enable LLMObs service to use the mock tracer
LLMObs.disable()
LLMObs.enable(_tracer=mock_tracer, integrations=["bedrock"])
LLMObs.enable(_tracer=mock_tracer, integrations_enabled=False) # only want botocore patched

if cassette_name is None:
cassette_name = "%s_invoke_stream.yaml" % provider
Expand Down Expand Up @@ -624,7 +624,7 @@ def test_llmobs_error(self, ddtrace_global_config, bedrock_client, mock_llmobs_s
pin.override(bedrock_client, tracer=mock_tracer)
# Need to disable and re-enable LLMObs service to use the mock tracer
LLMObs.disable()
LLMObs.enable(_tracer=mock_tracer, integrations=["bedrock"])
LLMObs.enable(_tracer=mock_tracer, integrations_enabled=False) # only want botocore patched
with pytest.raises(botocore.exceptions.ClientError):
with request_vcr.use_cassette("meta_invoke_error.yaml"):
body, model = json.dumps(_REQUEST_BODIES["meta"]), _MODELS["meta"]
Expand Down
4 changes: 2 additions & 2 deletions tests/contrib/langchain/test_langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -1352,7 +1352,7 @@ def _test_llmobs_llm_invoke(
different_py39_cassette=False,
):
LLMObs.disable()
LLMObs.enable(_tracer=mock_tracer, integrations=["langchain"])
LLMObs.enable(_tracer=mock_tracer, integrations_enabled=False) # only want langchain patched

if sys.version_info < (3, 10, 0) and different_py39_cassette:
cassette_name = cassette_name.replace(".yaml", "_39.yaml")
Expand Down Expand Up @@ -1388,7 +1388,7 @@ def _test_llmobs_chain_invoke(
):
# disable the service before re-enabling it, as it was enabled in another test
LLMObs.disable()
LLMObs.enable(_tracer=mock_tracer, integrations=["langchain"])
LLMObs.enable(_tracer=mock_tracer, integrations_enabled=False) # only want langchain patched

if sys.version_info < (3, 10, 0) and different_py39_cassette:
cassette_name = cassette_name.replace(".yaml", "_39.yaml")
Expand Down
4 changes: 2 additions & 2 deletions tests/contrib/langchain/test_langchain_community.py
Original file line number Diff line number Diff line change
Expand Up @@ -1339,7 +1339,7 @@ def _test_llmobs_llm_invoke(
output_role=None,
):
LLMObs.disable()
LLMObs.enable(_tracer=mock_tracer, integrations=["langchain"])
LLMObs.enable(_tracer=mock_tracer, integrations_enabled=False) # only want langchain patched

with request_vcr.use_cassette(cassette_name):
generate_trace("Can you explain what an LLM chain is?")
Expand Down Expand Up @@ -1372,7 +1372,7 @@ def _test_llmobs_chain_invoke(
):
# disable the service before re-enabling it, as it was enabled in another test
LLMObs.disable()
LLMObs.enable(_tracer=mock_tracer, integrations=["langchain"])
LLMObs.enable(_tracer=mock_tracer, integrations_enabled=False) # only want langchain patched

with request_vcr.use_cassette(cassette_name):
generate_trace("Can you explain what an LLM chain is?")
Expand Down
2 changes: 1 addition & 1 deletion tests/contrib/openai/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def mock_tracer(ddtrace_global_config, openai, patch_openai, mock_logs, mock_met
if ddtrace_global_config.get("_llmobs_enabled", False):
# Have to disable and re-enable LLMObs to use to mock tracer.
LLMObs.disable()
LLMObs.enable(_tracer=mock_tracer, integrations=["openai"])
LLMObs.enable(_tracer=mock_tracer, integrations_enabled=False)

yield mock_tracer

Expand Down
3 changes: 0 additions & 3 deletions tests/contrib/openai/test_openai_v0.py
Original file line number Diff line number Diff line change
Expand Up @@ -1935,7 +1935,6 @@ def test_integration_service_name(openai_api_key, ddtrace_run_python_code_in_sub
)
def test_llmobs_completion(openai_vcr, openai, ddtrace_global_config, mock_llmobs_writer, mock_tracer):
"""Ensure llmobs records are emitted for completion endpoints when configured.
Also ensure the llmobs records have the correct tagging including trace/span ID for trace correlation.
"""
with openai_vcr.use_cassette("completion.yaml"):
Expand Down Expand Up @@ -1990,7 +1989,6 @@ def test_llmobs_completion_stream(openai_vcr, openai, ddtrace_global_config, moc
)
def test_llmobs_chat_completion(openai_vcr, openai, ddtrace_global_config, mock_llmobs_writer, mock_tracer):
"""Ensure llmobs records are emitted for chat completion endpoints when configured.
Also ensure the llmobs records have the correct tagging including trace/span ID for trace correlation.
"""
if not hasattr(openai, "ChatCompletion"):
Expand Down Expand Up @@ -2033,7 +2031,6 @@ async def test_llmobs_chat_completion_stream(
openai_vcr, openai, ddtrace_global_config, mock_llmobs_writer, mock_tracer
):
"""Ensure llmobs records are emitted for chat completion endpoints when configured.
Also ensure the llmobs records have the correct tagging including trace/span ID for trace correlation.
"""
if not hasattr(openai, "ChatCompletion"):
Expand Down

0 comments on commit dee98a9

Please sign in to comment.