Skip to content

Commit

Permalink
feat: NATS test client supports filter subsciption
Browse files Browse the repository at this point in the history
  • Loading branch information
Lancetnik committed Jun 11, 2024
1 parent 2bb5620 commit 56dffa9
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 13 deletions.
13 changes: 12 additions & 1 deletion faststream/nats/subscriber/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
DEFAULT_SUB_PENDING_BYTES_LIMIT,
DEFAULT_SUB_PENDING_MSGS_LIMIT,
)
from nats.js.api import ConsumerConfig
from nats.js.client import (
DEFAULT_JS_SUB_PENDING_BYTES_LIMIT,
DEFAULT_JS_SUB_PENDING_MSGS_LIMIT,
Expand Down Expand Up @@ -83,6 +84,8 @@ def create_subscriber(
if not subject and not config:
raise SetupError("You must provide either `subject` or `config` option.")

config = config or ConsumerConfig(filter_subjects=[])

if stream:
# TODO: pull & queue warning
# TODO: push & durable warning
Expand All @@ -94,7 +97,6 @@ def create_subscriber(
or DEFAULT_JS_SUB_PENDING_BYTES_LIMIT,
"durable": durable,
"stream": stream.name,
"config": config,
}

if pull_sub is not None:
Expand Down Expand Up @@ -123,6 +125,7 @@ def create_subscriber(
if obj_watch is not None:
return AsyncAPIObjStoreWatchSubscriber(
subject=subject,
config=config,
obj_watch=obj_watch,
broker_dependencies=broker_dependencies,
broker_middlewares=broker_middlewares,
Expand All @@ -134,6 +137,7 @@ def create_subscriber(
if kv_watch is not None:
return AsyncAPIKeyValueWatchSubscriber(
subject=subject,
config=config,
kv_watch=kv_watch,
broker_dependencies=broker_dependencies,
broker_middlewares=broker_middlewares,
Expand All @@ -147,6 +151,7 @@ def create_subscriber(
return AsyncAPIConcurrentCoreSubscriber(
max_workers=max_workers,
subject=subject,
config=config,
queue=queue,
# basic args
extra_options=extra_options,
Expand All @@ -165,6 +170,7 @@ def create_subscriber(
else:
return AsyncAPICoreSubscriber(
subject=subject,
config=config,
queue=queue,
# basic args
extra_options=extra_options,
Expand All @@ -188,6 +194,7 @@ def create_subscriber(
pull_sub=pull_sub,
stream=stream,
subject=subject,
config=config,
# basic args
extra_options=extra_options,
# Subscriber args
Expand All @@ -207,6 +214,7 @@ def create_subscriber(
max_workers=max_workers,
stream=stream,
subject=subject,
config=config,
queue=queue,
# basic args
extra_options=extra_options,
Expand All @@ -229,6 +237,7 @@ def create_subscriber(
pull_sub=pull_sub,
stream=stream,
subject=subject,
config=config,
# basic args
extra_options=extra_options,
# Subscriber args
Expand All @@ -248,6 +257,7 @@ def create_subscriber(
pull_sub=pull_sub,
stream=stream,
subject=subject,
config=config,
# basic args
extra_options=extra_options,
# Subscriber args
Expand All @@ -267,6 +277,7 @@ def create_subscriber(
stream=stream,
subject=subject,
queue=queue,
config=config,
# basic args
extra_options=extra_options,
# Subscriber args
Expand Down
40 changes: 36 additions & 4 deletions faststream/nats/subscriber/usecase.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import anyio
from fast_depends.dependencies import Depends
from nats.errors import ConnectionClosedError, TimeoutError
from nats.js.api import ObjectInfo
from nats.js.api import ConsumerConfig, ObjectInfo
from nats.js.kv import KeyValue
from typing_extensions import Annotated, Doc, override

Expand Down Expand Up @@ -73,6 +73,7 @@ def __init__(
self,
*,
subject: str,
config: "ConsumerConfig",
extra_options: Optional[AnyDict],
# Subscriber args
default_parser: "AsyncCallable",
Expand All @@ -88,6 +89,7 @@ def __init__(
include_in_schema: bool,
) -> None:
self.subject = subject
self.config = config

self.extra_options = extra_options or {}

Expand Down Expand Up @@ -205,10 +207,18 @@ def build_log_context(

def add_prefix(self, prefix: str) -> None:
"""Include Subscriber in router."""
self.subject = "".join((prefix, self.subject))
if self.subject:
self.subject = "".join((prefix, self.subject))
else:
self.config.filter_subjects = [
"".join((prefix, subject))
for subject in (self.config.filter_subjects or ())
]

def __hash__(self) -> int:
return self.get_routing_hash(self.subject)
return self.get_routing_hash(
self.subject or "".join(self.config.filter_subjects or ())
)

@staticmethod
def get_routing_hash(
Expand All @@ -229,6 +239,7 @@ def __init__(
self,
*,
subject: str,
config: "ConsumerConfig",
# default args
extra_options: Optional[AnyDict],
# Subscriber args
Expand All @@ -246,6 +257,7 @@ def __init__(
) -> None:
super().__init__(
subject=subject,
config=config,
extra_options=extra_options,
# subscriber args
default_parser=default_parser,
Expand Down Expand Up @@ -368,6 +380,7 @@ def __init__(
*,
# default args
subject: str,
config: "ConsumerConfig",
queue: str,
extra_options: Optional[AnyDict],
# Subscriber args
Expand All @@ -387,6 +400,7 @@ def __init__(

super().__init__(
subject=subject,
config=config,
extra_options=extra_options,
# subscriber args
default_parser=parser_.parse_message,
Expand Down Expand Up @@ -439,6 +453,7 @@ def __init__(
max_workers: int,
# default args
subject: str,
config: "ConsumerConfig",
queue: str,
extra_options: Optional[AnyDict],
# Subscriber args
Expand All @@ -456,6 +471,7 @@ def __init__(
max_workers=max_workers,
# basic args
subject=subject,
config=config,
queue=queue,
extra_options=extra_options,
# Propagated args
Expand Down Expand Up @@ -494,6 +510,7 @@ def __init__(
stream: "JStream",
# default args
subject: str,
config: "ConsumerConfig",
queue: str,
extra_options: Optional[AnyDict],
# Subscriber args
Expand All @@ -514,6 +531,7 @@ def __init__(

super().__init__(
subject=subject,
config=config,
extra_options=extra_options,
# subscriber args
default_parser=parser_.parse_message,
Expand All @@ -540,7 +558,7 @@ def get_log_context(
"""Log context factory using in `self.consume` scope."""
return self.build_log_context(
message=message,
subject=self.subject,
subject=self.subject or ", ".join(self.config.filter_subjects or ()),
queue=self.queue,
stream=self.stream.name,
)
Expand All @@ -560,6 +578,7 @@ async def _create_subscription( # type: ignore[override]
subject=self.clear_subject,
queue=self.queue,
cb=self.consume,
config=self.config,
**self.extra_options,
)

Expand All @@ -574,6 +593,7 @@ def __init__(
stream: "JStream",
# default args
subject: str,
config: "ConsumerConfig",
queue: str,
extra_options: Optional[AnyDict],
# Subscriber args
Expand All @@ -592,6 +612,7 @@ def __init__(
# basic args
stream=stream,
subject=subject,
config=config,
queue=queue,
extra_options=extra_options,
# Propagated args
Expand Down Expand Up @@ -619,6 +640,7 @@ async def _create_subscription( # type: ignore[override]
subject=self.clear_subject,
queue=self.queue,
cb=self._put_msg,
config=self.config,
**self.extra_options,
)

Expand All @@ -633,6 +655,7 @@ def __init__(
stream: "JStream",
# default args
subject: str,
config: "ConsumerConfig",
extra_options: Optional[AnyDict],
# Subscriber args
no_ack: bool,
Expand All @@ -651,6 +674,7 @@ def __init__(
# basic args
stream=stream,
subject=subject,
config=config,
extra_options=extra_options,
queue="",
# Propagated args
Expand Down Expand Up @@ -708,6 +732,7 @@ def __init__(
pull_sub: "PullSub",
stream: "JStream",
subject: str,
config: "ConsumerConfig",
extra_options: Optional[AnyDict],
# Subscriber args
no_ack: bool,
Expand All @@ -726,6 +751,7 @@ def __init__(
pull_sub=pull_sub,
stream=stream,
subject=subject,
config=config,
extra_options=extra_options,
# Propagated args
no_ack=no_ack,
Expand Down Expand Up @@ -765,6 +791,7 @@ def __init__(
*,
# default args
subject: str,
config: "ConsumerConfig",
stream: "JStream",
pull_sub: "PullSub",
extra_options: Optional[AnyDict],
Expand All @@ -786,6 +813,7 @@ def __init__(

super().__init__(
subject=subject,
config=config,
extra_options=extra_options,
# subscriber args
default_parser=parser.parse_batch,
Expand Down Expand Up @@ -837,6 +865,7 @@ def __init__(
self,
*,
subject: str,
config: "ConsumerConfig",
kv_watch: "KvWatch",
broker_dependencies: Iterable[Depends],
broker_middlewares: Iterable["BrokerMiddleware[KeyValue.Entry]"],
Expand All @@ -850,6 +879,7 @@ def __init__(

super().__init__(
subject=subject,
config=config,
extra_options=None,
no_ack=True,
no_reply=True,
Expand Down Expand Up @@ -941,6 +971,7 @@ def __init__(
self,
*,
subject: str,
config: "ConsumerConfig",
obj_watch: "ObjWatch",
broker_dependencies: Iterable[Depends],
broker_middlewares: Iterable["BrokerMiddleware[List[Msg]]"],
Expand All @@ -955,6 +986,7 @@ def __init__(

super().__init__(
subject=subject,
config=config,
extra_options=None,
no_ack=True,
no_reply=True,
Expand Down
5 changes: 4 additions & 1 deletion faststream/nats/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,10 @@ async def publish( # type: ignore[override]
):
continue

if is_subject_match_wildcard(subject, handler.clear_subject):
if is_subject_match_wildcard(subject, handler.clear_subject) or any(
is_subject_match_wildcard(subject, filter_subject)
for filter_subject in (handler.config.filter_subjects or ())
):
msg: Union[List[PatchedMessage], PatchedMessage]
if (pull := getattr(handler, "pull_sub", None)) and pull.batch:
msg = [incoming]
Expand Down
34 changes: 32 additions & 2 deletions tests/brokers/nats/test_consume.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import asyncio
from unittest.mock import patch
from unittest.mock import Mock, patch

import pytest
from nats.aio.msg import Msg

from faststream.exceptions import AckMessage
from faststream.nats import JStream, NatsBroker, PullSub
from faststream.nats import ConsumerConfig, JStream, NatsBroker, PullSub
from faststream.nats.annotations import NatsMessage
from tests.brokers.base.consume import BrokerRealConsumeTestcase
from tests.tools import spy_decorator
Expand Down Expand Up @@ -40,6 +40,36 @@ def subscriber(m):

assert event.is_set()

async def test_consume_with_filter(
self,
queue,
mock: Mock,
event: asyncio.Event,
):
consume_broker = self.get_broker()

@consume_broker.subscriber(
config=ConsumerConfig(filter_subjects=[f"{queue}.a"]),
stream=JStream(queue, subjects=[f"{queue}.*"]),
)
def subscriber(m):
mock(m)
event.set()

async with self.patch_broker(consume_broker) as br:
await br.start()
await asyncio.wait(
(
asyncio.create_task(br.publish(1, f"{queue}.b")),
asyncio.create_task(br.publish(2, f"{queue}.a")),
asyncio.create_task(event.wait()),
),
timeout=3,
)

assert event.is_set()
mock.assert_called_once_with(2)

async def test_consume_pull(
self,
queue: str,
Expand Down
Loading

0 comments on commit 56dffa9

Please sign in to comment.