Skip to content

Commit

Permalink
feat: default call_name for broker.subscriber (#1589)
Browse files Browse the repository at this point in the history
* default call_name for broker.subscriber

* default call_name test fix

* asyncapi schema generation with default call_name

* remove unnecessary comment

* test_subscriber_naming_default error description fix

* default multi subscribers naming test

* multiple subscriber default payload title fix

* refactor: remove side effects

---------

Co-authored-by: Nikita Pastukhov <[email protected]>
  • Loading branch information
KrySeyt and Lancetnik authored Jul 16, 2024
1 parent ee41299 commit 5e37604
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 4 deletions.
8 changes: 6 additions & 2 deletions faststream/asyncapi/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def get_app_schema(app: Union["FastStream", "StreamRouter[Any]"]) -> Schema:
payloads,
messages,
)

schema = Schema(
info=Info(
title=app.title,
Expand Down Expand Up @@ -146,9 +145,13 @@ def _resolve_msg_payloads(
payloads: Dict[str, Any],
messages: Dict[str, Any],
) -> Reference:
one_of_list: List[Reference] = []
"""Replace message payload by reference and normalize payloads.
Payloads and messages are editable dicts to store schemas for reference in AsyncAPI.
"""
one_of_list: List[Reference] = []
m.payload = _move_pydantic_refs(m.payload, DEF_KEY)

if DEF_KEY in m.payload:
payloads.update(m.payload.pop(DEF_KEY))

Expand Down Expand Up @@ -186,6 +189,7 @@ def _move_pydantic_refs(
original: Any,
key: str,
) -> Any:
"""Remove pydantic references and replacem them by real schemas."""
if not isinstance(original, Dict):
return original

Expand Down
14 changes: 13 additions & 1 deletion faststream/broker/subscriber/usecase.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,9 @@ def get_log_context(
@property
def call_name(self) -> str:
"""Returns the name of the handler call."""
# TODO: default call_name
if not self.calls:
return "Subscriber"

return to_camelcase(self.calls[0].call_name)

def get_description(self) -> Optional[str]:
Expand All @@ -433,4 +435,14 @@ def get_payloads(self) -> List[Tuple["AnyDict", str]]:

payloads.append((body, to_camelcase(h.call_name)))

if not self.calls:
payloads.append(
(
{
"title": f"{self.title_ or self.call_name}:Message:Payload",
},
to_camelcase(self.call_name),
)
)

return payloads
1 change: 0 additions & 1 deletion faststream/confluent/subscriber/asyncapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def get_schema(self) -> Dict[str, Channel]:
channels = {}

payloads = self.get_payloads()

for t in self.topics:
handler_name = self.title_ or f"{t}:{self.call_name}"

Expand Down
70 changes: 70 additions & 0 deletions tests/asyncapi/base/naming.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,76 @@ async def handle_user_created(msg: str): ...
"custom:Message:Payload"
]

def test_subscriber_naming_default(self):
broker = self.broker_class()

broker.subscriber("test")

schema = get_app_schema(FastStream(broker)).to_jsonable()

assert list(schema["channels"].keys()) == [
IsStr(regex=r"test[\w:]*:Subscriber")
]

assert list(schema["components"]["messages"].keys()) == [
IsStr(regex=r"test[\w:]*:Subscriber:Message")
]

for key, v in schema["components"]["schemas"].items():
assert key == "Subscriber:Message:Payload"
assert v == {"title": key}

def test_subscriber_naming_default_with_title(self):
broker = self.broker_class()

broker.subscriber("test", title="custom")

schema = get_app_schema(FastStream(broker)).to_jsonable()

assert list(schema["channels"].keys()) == ["custom"]

assert list(schema["components"]["messages"].keys()) == ["custom:Message"]

assert list(schema["components"]["schemas"].keys()) == [
"custom:Message:Payload"
]

assert schema["components"]["schemas"]["custom:Message:Payload"] == {
"title": "custom:Message:Payload"
}

def test_multi_subscribers_naming_default(self):
broker = self.broker_class()

@broker.subscriber("test")
async def handle_user_created(msg: str): ...

broker.subscriber("test2")
broker.subscriber("test3")

schema = get_app_schema(FastStream(broker)).to_jsonable()

assert list(schema["channels"].keys()) == [
IsStr(regex=r"test[\w:]*:HandleUserCreated"),
IsStr(regex=r"test2[\w:]*:Subscriber"),
IsStr(regex=r"test3[\w:]*:Subscriber"),
]

assert list(schema["components"]["messages"].keys()) == [
IsStr(regex=r"test[\w:]*:HandleUserCreated:Message"),
IsStr(regex=r"test2[\w:]*:Subscriber:Message"),
IsStr(regex=r"test3[\w:]*:Subscriber:Message"),
]

assert list(schema["components"]["schemas"].keys()) == [
"HandleUserCreated:Message:Payload",
"Subscriber:Message:Payload",
]

assert schema["components"]["schemas"]["Subscriber:Message:Payload"] == {
"title": "Subscriber:Message:Payload"
}


class FilterNaming(BaseNaming):
def test_subscriber_filter_base(self):
Expand Down

0 comments on commit 5e37604

Please sign in to comment.