Skip to content

Commit

Permalink
Revert langchain version gating change
Browse files Browse the repository at this point in the history
  • Loading branch information
Yun-Kim committed Jul 27, 2024
1 parent e5e5a9d commit af016f4
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 21 deletions.
6 changes: 2 additions & 4 deletions tests/contrib/langchain/test_langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,17 @@
import re
import sys

import langchain as _langchain
import mock
import pytest

from ddtrace.contrib.langchain.patch import PATCH_LANGCHAIN_V0
from ddtrace.internal.utils.version import parse_version
from tests.contrib.langchain.utils import get_request_vcr
from tests.contrib.langchain.utils import long_input_text
from tests.utils import override_global_config


pytestmark = pytest.mark.skipif(
parse_version(_langchain.__version__) >= (0, 1, 0), reason="This module only tests langchain < 0.1"
)
pytestmark = pytest.mark.skipif(not PATCH_LANGCHAIN_V0, reason="This module only tests langchain < 0.1")


@pytest.fixture(scope="session")
Expand Down
4 changes: 2 additions & 2 deletions tests/contrib/langchain/test_langchain_community.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@
import mock
import pytest

from ddtrace.internal.utils.version import parse_version
from ddtrace.contrib.langchain.patch import PATCH_LANGCHAIN_V0
from tests.contrib.langchain.utils import get_request_vcr
from tests.utils import flaky
from tests.utils import override_global_config


pytestmark = pytest.mark.skipif(
parse_version(langchain.__version__) < (0, 1, 0) or sys.version_info < (3, 10),
PATCH_LANGCHAIN_V0 or sys.version_info < (3, 10),
reason="This module only tests langchain >= 0.1 and Python 3.10+",
)

Expand Down
27 changes: 12 additions & 15 deletions tests/contrib/langchain/test_langchain_llmobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@
import os
import sys

import langchain
import mock
import pytest

from ddtrace import patch
from ddtrace.internal.utils.version import parse_version
from ddtrace.contrib.langchain.patch import PATCH_LANGCHAIN_V0
from ddtrace.llmobs import LLMObs
from tests.contrib.langchain.utils import get_request_vcr
from tests.contrib.langchain.utils import long_input_text
Expand All @@ -19,9 +18,7 @@
from tests.utils import flaky


LANGCHAIN_VERSION = parse_version(langchain.__version__)

if LANGCHAIN_VERSION < (0, 1, 0):
if PATCH_LANGCHAIN_V0:
from langchain.schema import AIMessage
from langchain.schema import ChatMessage
from langchain.schema import HumanMessage
Expand Down Expand Up @@ -91,7 +88,7 @@ class BaseTestLLMObsLangchain:
def _invoke_llm(cls, llm, prompt, mock_tracer, cassette_name):
LLMObs.enable(ml_app=cls.ml_app, integrations_enabled=False, _tracer=mock_tracer)
with get_request_vcr(subdirectory_name=cls.cassette_subdirectory_name).use_cassette(cassette_name):
if LANGCHAIN_VERSION < (0, 1, 0):
if PATCH_LANGCHAIN_V0:
llm(prompt)
else:
llm.invoke(prompt)
Expand All @@ -106,7 +103,7 @@ def _invoke_chat(cls, chat_model, prompt, mock_tracer, cassette_name, role="user
messages = [HumanMessage(content=prompt)]
else:
messages = [ChatMessage(content=prompt, role="custom")]
if LANGCHAIN_VERSION < (0, 1, 0):
if PATCH_LANGCHAIN_V0:
chat_model(messages)
else:
chat_model.invoke(messages)
Expand All @@ -119,7 +116,7 @@ def _invoke_chain(cls, chain, prompt, mock_tracer, cassette_name, batch=False):
with get_request_vcr(subdirectory_name=cls.cassette_subdirectory_name).use_cassette(cassette_name):
if batch:
chain.batch(inputs=prompt)
elif LANGCHAIN_VERSION < (0, 1, 0):
elif PATCH_LANGCHAIN_V0:
chain.run(prompt)
else:
chain.invoke(prompt)
Expand Down Expand Up @@ -147,20 +144,20 @@ def _embed_documents(cls, embedding_model, documents, mock_tracer, cassette_name
return mock_tracer.pop_traces()[0]


@pytest.mark.skipif(LANGCHAIN_VERSION >= (0, 1, 0), reason="These tests are for langchain < 0.1.0")
@pytest.mark.skipif(not PATCH_LANGCHAIN_V0, reason="These tests are for langchain < 0.1.0")
class TestLLMObsLangchain(BaseTestLLMObsLangchain):
cassette_subdirectory_name = "langchain"

# @pytest.mark.skipif(sys.version_info < (3, 10), reason="Requires unnecessary cassette file for Python 3.9")
def test_llmobs_openai_llm(self, langchain, mock_llmobs_span_writer, mock_tracer):
if sys.version_info < (3, 10):
assert 0, "sys.version_info < (3, 10) results in True even though sys.version_info == " % sys.version_info
span = self._invoke_llm(
llm=langchain.llms.OpenAI(model="gpt-3.5-turbo-instruct"),
prompt="Can you explain what Descartes meant by 'I think, therefore I am'?",
mock_tracer=mock_tracer,
cassette_name="openai_completion_sync.yaml",
)
if sys.version_info < (3, 10):
assert 0, "sys.version_info < (3, 10) results in True even though sys.version_info == " % sys.version_info
assert mock_llmobs_span_writer.enqueue.call_count == 1
_assert_expected_llmobs_llm_span(span, mock_llmobs_span_writer)

Expand Down Expand Up @@ -403,19 +400,19 @@ def test_llmobs_embedding_documents(self, langchain, mock_llmobs_span_writer, mo


@flaky(1735812000, reason="Community cassette tests are flaky")
@pytest.mark.skipif(LANGCHAIN_VERSION < (0, 1, 0), reason="These tests are for langchain >= 0.1.0")
@pytest.mark.skipif(PATCH_LANGCHAIN_V0, reason="These tests are for langchain >= 0.1.0")
class TestLLMObsLangchainCommunity(BaseTestLLMObsLangchain):
cassette_subdirectory_name = "langchain_community"

def test_llmobs_openai_llm(self, langchain_openai, mock_llmobs_span_writer, mock_tracer):
if sys.version_info < (3, 10):
assert 0, "sys.version_info < (3, 10) results in True even though sys.version_info == " % sys.version_info
span = self._invoke_llm(
llm=langchain_openai.OpenAI(),
prompt="Can you explain what Descartes meant by 'I think, therefore I am'?",
mock_tracer=mock_tracer,
cassette_name="openai_completion_sync.yaml",
)
if sys.version_info < (3, 10):
assert 0, "sys.version_info < (3, 10) results in True even though sys.version_info == " % sys.version_info
assert mock_llmobs_span_writer.enqueue.call_count == 1
_assert_expected_llmobs_llm_span(span, mock_llmobs_span_writer)

Expand Down Expand Up @@ -655,7 +652,7 @@ def test_llmobs_embedding_documents(


@flaky(1735812000, reason="Community cassette tests are flaky")
@pytest.mark.skipif(LANGCHAIN_VERSION < (0, 1, 0), reason="These tests are for langchain >= 0.1.0")
@pytest.mark.skipif(PATCH_LANGCHAIN_V0, reason="These tests are for langchain >= 0.1.0")
class TestLangchainTraceStructureWithLlmIntegrations(SubprocessTestCase):
bedrock_env_config = dict(
AWS_ACCESS_KEY_ID="testing",
Expand Down

0 comments on commit af016f4

Please sign in to comment.