From 240c5f396e95838085bd10626b2ff4e0c18bdd4e Mon Sep 17 00:00:00 2001 From: Peter Wieringa Date: Tue, 9 Jan 2024 10:41:31 +0100 Subject: [PATCH] refactor: Updated code to Pydantic 2.0 (get rid of deprecation warnings) (#156) Co-authored-by: wieri494 --- kstreams/backends/kafka.py | 34 ++++++++++++++++------------------ kstreams/engine.py | 2 +- kstreams/streams.py | 4 ++-- pyproject.toml | 2 +- tests/test_backend_kafka.py | 2 +- tests/test_consumer.py | 2 +- tests/test_producer.py | 8 ++++---- 7 files changed, 26 insertions(+), 28 deletions(-) diff --git a/kstreams/backends/kafka.py b/kstreams/backends/kafka.py index 5c81cb6..077574a 100644 --- a/kstreams/backends/kafka.py +++ b/kstreams/backends/kafka.py @@ -2,7 +2,7 @@ from enum import Enum from typing import List, Optional -from pydantic import BaseModel, root_validator +from pydantic import BaseModel, ConfigDict, model_validator class SecurityProtocol(str, Enum): @@ -60,53 +60,51 @@ class Kafka(BaseModel): sasl_plain_username: Optional[str] = None sasl_plain_password: Optional[str] = None sasl_oauth_token_provider: Optional[str] = None + model_config = ConfigDict(arbitrary_types_allowed=True, use_enum_values=True) - class Config: - arbitrary_types_allowed = True - use_enum_values = True - - @root_validator(skip_on_failure=True) + @model_validator(mode="after") + @classmethod def protocols_validation(cls, values): - security_protocol = values["security_protocol"] + security_protocol = values.security_protocol if security_protocol == SecurityProtocol.PLAINTEXT: return values elif security_protocol == SecurityProtocol.SSL: - if values["ssl_context"] is None: + if values.ssl_context is None: raise ValueError("`ssl_context` is required") return values elif security_protocol == SecurityProtocol.SASL_PLAINTEXT: - if values["sasl_mechanism"] is SaslMechanism.OAUTHBEARER: + if values.sasl_mechanism is SaslMechanism.OAUTHBEARER: # We don't perform a username and password check if OAUTHBEARER return values if ( - values["sasl_mechanism"] is SaslMechanism.PLAIN - and values["sasl_plain_username"] is None + values.sasl_mechanism is SaslMechanism.PLAIN + and values.sasl_plain_username is None ): raise ValueError( "`sasl_plain_username` is required when using SASL_PLAIN" ) if ( - values["sasl_mechanism"] is SaslMechanism.PLAIN - and values["sasl_plain_password"] is None + values.sasl_mechanism is SaslMechanism.PLAIN + and values.sasl_plain_password is None ): raise ValueError( "`sasl_plain_password` is required when using SASL_PLAIN" ) return values elif security_protocol == SecurityProtocol.SASL_SSL: - if values["ssl_context"] is None: + if values.ssl_context is None: raise ValueError("`ssl_context` is required") if ( - values["sasl_mechanism"] is SaslMechanism.PLAIN - and values["sasl_plain_username"] is None + values.sasl_mechanism is SaslMechanism.PLAIN + and values.sasl_plain_username is None ): raise ValueError( "`sasl_plain_username` is required when using SASL_PLAIN" ) if ( - values["sasl_mechanism"] is SaslMechanism.PLAIN - and values["sasl_plain_password"] is None + values.sasl_mechanism is SaslMechanism.PLAIN + and values.sasl_plain_password is None ): raise ValueError( "`sasl_plain_password` is required when using SASL_PLAIN" diff --git a/kstreams/engine.py b/kstreams/engine.py index c3b15d5..bc6b1c2 100644 --- a/kstreams/engine.py +++ b/kstreams/engine.py @@ -151,7 +151,7 @@ async def stop_producer(self): async def start_producer(self, **kwargs) -> None: if self.producer_class is None: return None - config = {**self.backend.dict(), **kwargs} + config = {**self.backend.model_dump(), **kwargs} self._producer = self.producer_class(**config) if self._producer is None: return None diff --git a/kstreams/streams.py b/kstreams/streams.py index b75c95a..4342f80 100644 --- a/kstreams/streams.py +++ b/kstreams/streams.py @@ -121,7 +121,7 @@ def __init__( def _create_consumer(self) -> ConsumerType: if self.backend is None: raise BackendNotSet("A backend has not been set for this stream") - config = {**self.backend.dict(), **self.config} + config = {**self.backend.model_dump(), **self.config} return self.consumer_class(**config) async def stop(self) -> None: @@ -218,7 +218,7 @@ async def start(self) -> Optional[AsyncGenerator]: else: # It is not an async_generator so we need to # create an asyncio.Task with func - logging.warn( + logging.warning( "Streams with `async for in` loop approach might be deprecated. " "Consider migrating to a typing approach." ) diff --git a/pyproject.toml b/pyproject.toml index 195d807..8c7e91a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,7 +75,7 @@ update_changelog_on_bump = true major_version_zero = true [tool.pytest.ini_options] -timeout = 300 +asyncio_mode = "auto" log_level = "DEBUG" [[tool.mypy.overrides]] diff --git a/tests/test_backend_kafka.py b/tests/test_backend_kafka.py index 8e51078..3b2f13c 100644 --- a/tests/test_backend_kafka.py +++ b/tests/test_backend_kafka.py @@ -107,7 +107,7 @@ def test_backend_to_dict(): sasl_plain_password="pwd", ) assert kafka_backend.security_protocol == SecurityProtocol.SASL_PLAINTEXT - assert kafka_backend.dict() == { + assert kafka_backend.model_dump() == { "bootstrap_servers": ["localhost:9092"], "security_protocol": "SASL_PLAINTEXT", "ssl_context": None, diff --git a/tests/test_consumer.py b/tests/test_consumer.py index f50a22a..539bb1a 100644 --- a/tests/test_consumer.py +++ b/tests/test_consumer.py @@ -29,7 +29,7 @@ async def test_consumer(): @pytest.mark.asyncio async def test_consumer_with_ssl(ssl_context): backend = Kafka(security_protocol="SSL", ssl_context=ssl_context) - consumer = Consumer(**backend.dict()) + consumer = Consumer(**backend.model_dump()) assert consumer._client._ssl_context diff --git a/tests/test_producer.py b/tests/test_producer.py index ba17377..b969d56 100644 --- a/tests/test_producer.py +++ b/tests/test_producer.py @@ -10,7 +10,7 @@ async def test_producer(): with patch("kstreams.clients.aiokafka.AIOKafkaProducer.start") as mock_start_super: backend = Kafka() - prod = Producer(**backend.dict()) + prod = Producer(**backend.model_dump()) await prod.start() mock_start_super.assert_called() @@ -19,7 +19,7 @@ async def test_producer(): @pytest.mark.asyncio async def test_producer_with_ssl(ssl_context): backend = Kafka(ssl_context=ssl_context) - producer = Producer(**backend.dict()) + producer = Producer(**backend.model_dump()) assert producer.client._ssl_context await producer.client.close() @@ -46,7 +46,7 @@ async def test_two_producers(): "group_id": "my-group-consumer", } backend_1 = Kafka(bootstrap_servers=kafka_config_1["bootstrap_servers"]) - producer_1 = Producer(**backend_1.dict(), client_id="my-client") + producer_1 = Producer(**backend_1.model_dump(), client_id="my-client") kafka_config_2 = { "bootstrap_servers": ["otherhost:9092"], @@ -54,7 +54,7 @@ async def test_two_producers(): } backend_2 = Kafka(bootstrap_servers=kafka_config_2["bootstrap_servers"]) - producer_2 = Producer(**backend_2.dict(), client_id="client_id2") + producer_2 = Producer(**backend_2.model_dump(), client_id="client_id2") assert producer_1.client._bootstrap_servers == kafka_config_1["bootstrap_servers"] assert producer_1.client._client_id == "my-client"