Skip to content

Commit

Permalink
refactor: Updated code to Pydantic 2.0 (get rid of deprecation warnin…
Browse files Browse the repository at this point in the history
…gs) (#156)

Co-authored-by: wieri494 <[email protected]>
  • Loading branch information
wpeterw and wieri494 authored Jan 9, 2024
1 parent cb7e6f4 commit 240c5f3
Show file tree
Hide file tree
Showing 7 changed files with 26 additions and 28 deletions.
34 changes: 16 additions & 18 deletions kstreams/backends/kafka.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from enum import Enum
from typing import List, Optional

from pydantic import BaseModel, root_validator
from pydantic import BaseModel, ConfigDict, model_validator


class SecurityProtocol(str, Enum):
Expand Down Expand Up @@ -60,53 +60,51 @@ class Kafka(BaseModel):
sasl_plain_username: Optional[str] = None
sasl_plain_password: Optional[str] = None
sasl_oauth_token_provider: Optional[str] = None
model_config = ConfigDict(arbitrary_types_allowed=True, use_enum_values=True)

class Config:
arbitrary_types_allowed = True
use_enum_values = True

@root_validator(skip_on_failure=True)
@model_validator(mode="after")
@classmethod
def protocols_validation(cls, values):
security_protocol = values["security_protocol"]
security_protocol = values.security_protocol

if security_protocol == SecurityProtocol.PLAINTEXT:
return values
elif security_protocol == SecurityProtocol.SSL:
if values["ssl_context"] is None:
if values.ssl_context is None:
raise ValueError("`ssl_context` is required")
return values
elif security_protocol == SecurityProtocol.SASL_PLAINTEXT:
if values["sasl_mechanism"] is SaslMechanism.OAUTHBEARER:
if values.sasl_mechanism is SaslMechanism.OAUTHBEARER:
# We don't perform a username and password check if OAUTHBEARER
return values
if (
values["sasl_mechanism"] is SaslMechanism.PLAIN
and values["sasl_plain_username"] is None
values.sasl_mechanism is SaslMechanism.PLAIN
and values.sasl_plain_username is None
):
raise ValueError(
"`sasl_plain_username` is required when using SASL_PLAIN"
)
if (
values["sasl_mechanism"] is SaslMechanism.PLAIN
and values["sasl_plain_password"] is None
values.sasl_mechanism is SaslMechanism.PLAIN
and values.sasl_plain_password is None
):
raise ValueError(
"`sasl_plain_password` is required when using SASL_PLAIN"
)
return values
elif security_protocol == SecurityProtocol.SASL_SSL:
if values["ssl_context"] is None:
if values.ssl_context is None:
raise ValueError("`ssl_context` is required")
if (
values["sasl_mechanism"] is SaslMechanism.PLAIN
and values["sasl_plain_username"] is None
values.sasl_mechanism is SaslMechanism.PLAIN
and values.sasl_plain_username is None
):
raise ValueError(
"`sasl_plain_username` is required when using SASL_PLAIN"
)
if (
values["sasl_mechanism"] is SaslMechanism.PLAIN
and values["sasl_plain_password"] is None
values.sasl_mechanism is SaslMechanism.PLAIN
and values.sasl_plain_password is None
):
raise ValueError(
"`sasl_plain_password` is required when using SASL_PLAIN"
Expand Down
2 changes: 1 addition & 1 deletion kstreams/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ async def stop_producer(self):
async def start_producer(self, **kwargs) -> None:
if self.producer_class is None:
return None
config = {**self.backend.dict(), **kwargs}
config = {**self.backend.model_dump(), **kwargs}
self._producer = self.producer_class(**config)
if self._producer is None:
return None
Expand Down
4 changes: 2 additions & 2 deletions kstreams/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def __init__(
def _create_consumer(self) -> ConsumerType:
if self.backend is None:
raise BackendNotSet("A backend has not been set for this stream")
config = {**self.backend.dict(), **self.config}
config = {**self.backend.model_dump(), **self.config}
return self.consumer_class(**config)

async def stop(self) -> None:
Expand Down Expand Up @@ -218,7 +218,7 @@ async def start(self) -> Optional[AsyncGenerator]:
else:
# It is not an async_generator so we need to
# create an asyncio.Task with func
logging.warn(
logging.warning(
"Streams with `async for in` loop approach might be deprecated. "
"Consider migrating to a typing approach."
)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ update_changelog_on_bump = true
major_version_zero = true

[tool.pytest.ini_options]
timeout = 300
asyncio_mode = "auto"
log_level = "DEBUG"

[[tool.mypy.overrides]]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_backend_kafka.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def test_backend_to_dict():
sasl_plain_password="pwd",
)
assert kafka_backend.security_protocol == SecurityProtocol.SASL_PLAINTEXT
assert kafka_backend.dict() == {
assert kafka_backend.model_dump() == {
"bootstrap_servers": ["localhost:9092"],
"security_protocol": "SASL_PLAINTEXT",
"ssl_context": None,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ async def test_consumer():
@pytest.mark.asyncio
async def test_consumer_with_ssl(ssl_context):
backend = Kafka(security_protocol="SSL", ssl_context=ssl_context)
consumer = Consumer(**backend.dict())
consumer = Consumer(**backend.model_dump())
assert consumer._client._ssl_context


Expand Down
8 changes: 4 additions & 4 deletions tests/test_producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
async def test_producer():
with patch("kstreams.clients.aiokafka.AIOKafkaProducer.start") as mock_start_super:
backend = Kafka()
prod = Producer(**backend.dict())
prod = Producer(**backend.model_dump())

await prod.start()
mock_start_super.assert_called()
Expand All @@ -19,7 +19,7 @@ async def test_producer():
@pytest.mark.asyncio
async def test_producer_with_ssl(ssl_context):
backend = Kafka(ssl_context=ssl_context)
producer = Producer(**backend.dict())
producer = Producer(**backend.model_dump())
assert producer.client._ssl_context

await producer.client.close()
Expand All @@ -46,15 +46,15 @@ async def test_two_producers():
"group_id": "my-group-consumer",
}
backend_1 = Kafka(bootstrap_servers=kafka_config_1["bootstrap_servers"])
producer_1 = Producer(**backend_1.dict(), client_id="my-client")
producer_1 = Producer(**backend_1.model_dump(), client_id="my-client")

kafka_config_2 = {
"bootstrap_servers": ["otherhost:9092"],
"group_id": "my-group-consumer",
}

backend_2 = Kafka(bootstrap_servers=kafka_config_2["bootstrap_servers"])
producer_2 = Producer(**backend_2.dict(), client_id="client_id2")
producer_2 = Producer(**backend_2.model_dump(), client_id="client_id2")

assert producer_1.client._bootstrap_servers == kafka_config_1["bootstrap_servers"]
assert producer_1.client._client_id == "my-client"
Expand Down

0 comments on commit 240c5f3

Please sign in to comment.