Skip to content

Commit

Permalink
fix(TestClient): call task_done after the topic has consumed the cr (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
marcosschroh authored Nov 27, 2023
1 parent 43f27e4 commit 6986e2a
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 9 deletions.
9 changes: 1 addition & 8 deletions kstreams/test_utils/test_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from kstreams.types import Headers

from .structs import RecordMetadata
from .topics import Topic, TopicManager
from .topics import TopicManager


class Base:
Expand Down Expand Up @@ -69,7 +69,6 @@ def __init__(self, group_id: Optional[str] = None, **kwargs) -> None:
self.topics: Optional[Tuple[str]] = None
self._group_id: Optional[str] = group_id
self._assignment: List[TopicPartition] = []
self._previous_topic: Optional[Topic] = None
self.partitions_committed: Dict[TopicPartition, int] = {}

# Called to make sure that has all the kafka attributes like _coordinator
Expand Down Expand Up @@ -177,11 +176,6 @@ def partitions_for_topic(self, topic: str) -> Set:
async def getone(
self,
) -> Optional[ConsumerRecord]: # The return type must be fixed later on
if self._previous_topic:
# Assumes previous record retrieved through getone was completed
self._previous_topic.task_done()
self._previous_topic = None

topic = None
for topic_partition in self._assignment:
topic = TopicManager.get(topic_partition.topic)
Expand All @@ -192,7 +186,6 @@ async def getone(
if topic is not None:
consumer_record = await topic.get()
self._check_partition_assignments(consumer_record)
self._previous_topic = topic
return consumer_record

return None
Expand Down
4 changes: 3 additions & 1 deletion kstreams/test_utils/topics.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ async def put(self, event: ConsumerRecord) -> None:
self._inc_amount(event)

async def get(self) -> ConsumerRecord:
return await self.queue.get()
cr = await self.queue.get()
self.task_done()
return cr

def get_nowait(self) -> ConsumerRecord:
return self.queue.get_nowait()
Expand Down

0 comments on commit 6986e2a

Please sign in to comment.