Skip to content

Commit

Permalink
Remove skip marker for langchain community tmp
Browse files Browse the repository at this point in the history
  • Loading branch information
Yun-Kim committed Jul 29, 2024
1 parent c061dce commit 2c5b38c
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 14 deletions.
2 changes: 1 addition & 1 deletion tests/contrib/langchain/test_langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


pytestmark = pytest.mark.skipif(
parse_version(_langchain.__version__) >= (0, 1, 0), reason="This module only tests langchain < 0.1"
parse_version(_langchain.__version__) >= (0, 1), reason="This module only tests langchain < 0.1"
)

PY39 = sys.version_info < (3, 10)
Expand Down
9 changes: 3 additions & 6 deletions tests/contrib/langchain/test_langchain_community.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@

LANGCHAIN_VERSION = parse_version(langchain.__version__)

pytestmark = pytest.mark.skipif(
LANGCHAIN_VERSION < (0, 1, 0) or sys.version_info < (3, 10),
reason="This module only tests langchain >= 0.1 and Python 3.10+",
)
pytestmark = pytest.mark.skipif(LANGCHAIN_VERSION < (0, 1, 0), reason="This module only tests langchain >= 0.1")

IGNORE_FIELDS = [
"resources",
Expand Down Expand Up @@ -128,15 +125,15 @@ def test_openai_llm_error(langchain, langchain_openai, request_vcr):
llm.generate([12345, 123456])


@pytest.mark.skipif(LANGCHAIN_VERSION < (0, 2, 0), reason="Requires separate cassette for langchain v0.1")
@pytest.mark.skipif(LANGCHAIN_VERSION < (0, 2), reason="Requires separate cassette for langchain v0.1")
@pytest.mark.snapshot
def test_cohere_llm_sync(langchain_cohere, request_vcr):
llm = langchain_cohere.llms.Cohere(cohere_api_key=os.getenv("COHERE_API_KEY", "<not-a-real-key>"))
with request_vcr.use_cassette("cohere_completion_sync.yaml"):
llm.invoke("What is the secret Krabby Patty recipe?")


@pytest.mark.skipif(LANGCHAIN_VERSION < (0, 2, 0), reason="Requires separate cassette for langchain v0.1")
@pytest.mark.skipif(LANGCHAIN_VERSION < (0, 2), reason="Requires separate cassette for langchain v0.1")
@pytest.mark.snapshot
def test_ai21_llm_sync(langchain_community, request_vcr):
if langchain_community is None:
Expand Down
14 changes: 7 additions & 7 deletions tests/contrib/langchain/test_langchain_llmobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
LANGCHAIN_VERSION = parse_version(langchain_.__version__)
PY39 = sys.version_info < (3, 10)

if LANGCHAIN_VERSION < (0, 1, 0):
if LANGCHAIN_VERSION < (0, 1):
from langchain.schema import AIMessage
from langchain.schema import ChatMessage
from langchain.schema import HumanMessage
Expand Down Expand Up @@ -92,7 +92,7 @@ class BaseTestLLMObsLangchain:
def _invoke_llm(cls, llm, prompt, mock_tracer, cassette_name):
LLMObs.enable(ml_app=cls.ml_app, integrations_enabled=False, _tracer=mock_tracer)
with get_request_vcr(subdirectory_name=cls.cassette_subdirectory_name).use_cassette(cassette_name):
if LANGCHAIN_VERSION < (0, 1, 0):
if LANGCHAIN_VERSION < (0, 1):
llm(prompt)
else:
llm.invoke(prompt)
Expand All @@ -107,7 +107,7 @@ def _invoke_chat(cls, chat_model, prompt, mock_tracer, cassette_name, role="user
messages = [HumanMessage(content=prompt)]
else:
messages = [ChatMessage(content=prompt, role="custom")]
if LANGCHAIN_VERSION < (0, 1, 0):
if LANGCHAIN_VERSION < (0, 1):
chat_model(messages)
else:
chat_model.invoke(messages)
Expand All @@ -120,15 +120,15 @@ def _invoke_chain(cls, chain, prompt, mock_tracer, cassette_name, batch=False):
with get_request_vcr(subdirectory_name=cls.cassette_subdirectory_name).use_cassette(cassette_name):
if batch:
chain.batch(inputs=prompt)
elif LANGCHAIN_VERSION < (0, 1, 0):
elif LANGCHAIN_VERSION < (0, 1):
chain.run(prompt)
else:
chain.invoke(prompt)
LLMObs.disable()
return mock_tracer.pop_traces()[0]


@pytest.mark.skipif(LANGCHAIN_VERSION >= (0, 1, 0), reason="These tests are for langchain < 0.1.0")
@pytest.mark.skipif(LANGCHAIN_VERSION >= (0, 1), reason="These tests are for langchain < 0.1.0")
class TestLLMObsLangchain(BaseTestLLMObsLangchain):
cassette_subdirectory_name = "langchain"

Expand Down Expand Up @@ -316,7 +316,7 @@ def test_llmobs_chain_schema_io(self, langchain, mock_llmobs_span_writer, mock_t
_assert_expected_llmobs_llm_span(trace[1], mock_llmobs_span_writer, mock_io=True)


@pytest.mark.skipif(LANGCHAIN_VERSION < (0, 1, 0), reason="These tests are for langchain >= 0.1.0")
@pytest.mark.skipif(LANGCHAIN_VERSION < (0, 1), reason="These tests are for langchain >= 0.1.0")
class TestLLMObsLangchainCommunity(BaseTestLLMObsLangchain):
cassette_subdirectory_name = "langchain_community"

Expand Down Expand Up @@ -500,7 +500,7 @@ def test_llmobs_anthropic_chat_model(self, langchain_anthropic, mock_llmobs_span
_assert_expected_llmobs_llm_span(span, mock_llmobs_span_writer, input_role="user")


@pytest.mark.skipif(LANGCHAIN_VERSION < (0, 1, 0), reason="These tests are for langchain >= 0.1.0")
@pytest.mark.skipif(LANGCHAIN_VERSION < (0, 1), reason="These tests are for langchain >= 0.1.0")
class TestLangchainTraceStructureWithLlmIntegrations(SubprocessTestCase):
bedrock_env_config = dict(
AWS_ACCESS_KEY_ID="testing",
Expand Down

0 comments on commit 2c5b38c

Please sign in to comment.