diff --git a/tests/contrib/langchain/test_langchain.py b/tests/contrib/langchain/test_langchain.py index 4dc8fc54622..9be17c51b5a 100644 --- a/tests/contrib/langchain/test_langchain.py +++ b/tests/contrib/langchain/test_langchain.py @@ -2,19 +2,17 @@ import re import sys -import langchain as _langchain import mock import pytest +from ddtrace.contrib.langchain.patch import PATCH_LANGCHAIN_V0 from ddtrace.internal.utils.version import parse_version from tests.contrib.langchain.utils import get_request_vcr from tests.contrib.langchain.utils import long_input_text from tests.utils import override_global_config -pytestmark = pytest.mark.skipif( - parse_version(_langchain.__version__) >= (0, 1, 0), reason="This module only tests langchain < 0.1" -) +pytestmark = pytest.mark.skipif(not PATCH_LANGCHAIN_V0, reason="This module only tests langchain < 0.1") @pytest.fixture(scope="session") diff --git a/tests/contrib/langchain/test_langchain_community.py b/tests/contrib/langchain/test_langchain_community.py index eb7b66e0558..908bd5c67a5 100644 --- a/tests/contrib/langchain/test_langchain_community.py +++ b/tests/contrib/langchain/test_langchain_community.py @@ -8,14 +8,14 @@ import mock import pytest -from ddtrace.internal.utils.version import parse_version +from ddtrace.contrib.langchain.patch import PATCH_LANGCHAIN_V0 from tests.contrib.langchain.utils import get_request_vcr from tests.utils import flaky from tests.utils import override_global_config pytestmark = pytest.mark.skipif( - parse_version(langchain.__version__) < (0, 1, 0) or sys.version_info < (3, 10), + PATCH_LANGCHAIN_V0 or sys.version_info < (3, 10), reason="This module only tests langchain >= 0.1 and Python 3.10+", ) diff --git a/tests/contrib/langchain/test_langchain_llmobs.py b/tests/contrib/langchain/test_langchain_llmobs.py index 65d1fc4c2b7..70ca6016bba 100644 --- a/tests/contrib/langchain/test_langchain_llmobs.py +++ b/tests/contrib/langchain/test_langchain_llmobs.py @@ -3,12 +3,11 @@ import os import sys -import langchain import mock import pytest from ddtrace import patch -from ddtrace.internal.utils.version import parse_version +from ddtrace.contrib.langchain.patch import PATCH_LANGCHAIN_V0 from ddtrace.llmobs import LLMObs from tests.contrib.langchain.utils import get_request_vcr from tests.contrib.langchain.utils import long_input_text @@ -19,9 +18,7 @@ from tests.utils import flaky -LANGCHAIN_VERSION = parse_version(langchain.__version__) - -if LANGCHAIN_VERSION < (0, 1, 0): +if PATCH_LANGCHAIN_V0: from langchain.schema import AIMessage from langchain.schema import ChatMessage from langchain.schema import HumanMessage @@ -91,7 +88,7 @@ class BaseTestLLMObsLangchain: def _invoke_llm(cls, llm, prompt, mock_tracer, cassette_name): LLMObs.enable(ml_app=cls.ml_app, integrations_enabled=False, _tracer=mock_tracer) with get_request_vcr(subdirectory_name=cls.cassette_subdirectory_name).use_cassette(cassette_name): - if LANGCHAIN_VERSION < (0, 1, 0): + if PATCH_LANGCHAIN_V0: llm(prompt) else: llm.invoke(prompt) @@ -106,7 +103,7 @@ def _invoke_chat(cls, chat_model, prompt, mock_tracer, cassette_name, role="user messages = [HumanMessage(content=prompt)] else: messages = [ChatMessage(content=prompt, role="custom")] - if LANGCHAIN_VERSION < (0, 1, 0): + if PATCH_LANGCHAIN_V0: chat_model(messages) else: chat_model.invoke(messages) @@ -119,7 +116,7 @@ def _invoke_chain(cls, chain, prompt, mock_tracer, cassette_name, batch=False): with get_request_vcr(subdirectory_name=cls.cassette_subdirectory_name).use_cassette(cassette_name): if batch: chain.batch(inputs=prompt) - elif LANGCHAIN_VERSION < (0, 1, 0): + elif PATCH_LANGCHAIN_V0: chain.run(prompt) else: chain.invoke(prompt) @@ -147,7 +144,7 @@ def _embed_documents(cls, embedding_model, documents, mock_tracer, cassette_name return mock_tracer.pop_traces()[0] -@pytest.mark.skipif(LANGCHAIN_VERSION >= (0, 1, 0), reason="These tests are for langchain < 0.1.0") +@pytest.mark.skipif(not PATCH_LANGCHAIN_V0, reason="These tests are for langchain < 0.1.0") class TestLLMObsLangchain(BaseTestLLMObsLangchain): cassette_subdirectory_name = "langchain" @@ -403,19 +400,19 @@ def test_llmobs_embedding_documents(self, langchain, mock_llmobs_span_writer, mo @flaky(1735812000, reason="Community cassette tests are flaky") -@pytest.mark.skipif(LANGCHAIN_VERSION < (0, 1, 0), reason="These tests are for langchain >= 0.1.0") +@pytest.mark.skipif(PATCH_LANGCHAIN_V0, reason="These tests are for langchain >= 0.1.0") class TestLLMObsLangchainCommunity(BaseTestLLMObsLangchain): cassette_subdirectory_name = "langchain_community" def test_llmobs_openai_llm(self, langchain_openai, mock_llmobs_span_writer, mock_tracer): + if sys.version_info < (3, 10): + assert 0, "sys.version_info < (3, 10) results in True even though sys.version_info == " % sys.version_info span = self._invoke_llm( llm=langchain_openai.OpenAI(), prompt="Can you explain what Descartes meant by 'I think, therefore I am'?", mock_tracer=mock_tracer, cassette_name="openai_completion_sync.yaml", ) - if sys.version_info < (3, 10): - assert 0, "sys.version_info < (3, 10) results in True even though sys.version_info == " % sys.version_info assert mock_llmobs_span_writer.enqueue.call_count == 1 _assert_expected_llmobs_llm_span(span, mock_llmobs_span_writer) @@ -655,7 +652,7 @@ def test_llmobs_embedding_documents( @flaky(1735812000, reason="Community cassette tests are flaky") -@pytest.mark.skipif(LANGCHAIN_VERSION < (0, 1, 0), reason="These tests are for langchain >= 0.1.0") +@pytest.mark.skipif(PATCH_LANGCHAIN_V0, reason="These tests are for langchain >= 0.1.0") class TestLangchainTraceStructureWithLlmIntegrations(SubprocessTestCase): bedrock_env_config = dict( AWS_ACCESS_KEY_ID="testing",