From 65547bcd7c01b7e671719a2879ceaebfc6e52d80 Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Wed, 4 Aug 2021 15:45:27 +0300 Subject: [PATCH] Multiple fixes and improvements (#244) --- sanic_openapi/__init__.py | 5 +- sanic_openapi/openapi2/blueprint.py | 6 +- sanic_openapi/openapi3/__init__.py | 12 +- sanic_openapi/openapi3/blueprint.py | 10 +- sanic_openapi/openapi3/builders.py | 102 +++++++++++++- sanic_openapi/openapi3/definitions.py | 83 ++++++++--- sanic_openapi/openapi3/openapi.py | 165 +++++++++++++++++++++- sanic_openapi/openapi3/types.py | 55 +++++--- setup.cfg | 2 +- tests/conftest.py | 9 +- tests/samples/petstore.yaml | 113 +++++++++++++++ tests/test_oas3_blueprint.py | 67 +++++++++ tests/test_oas3_definition.py | 190 ++++++++++++++++++++++++++ tests/test_oas3_specification.py | 48 +++++++ 14 files changed, 813 insertions(+), 54 deletions(-) create mode 100644 tests/samples/petstore.yaml create mode 100644 tests/test_oas3_blueprint.py create mode 100644 tests/test_oas3_definition.py create mode 100644 tests/test_oas3_specification.py diff --git a/sanic_openapi/__init__.py b/sanic_openapi/__init__.py index 81066331..2dab60d5 100644 --- a/sanic_openapi/__init__.py +++ b/sanic_openapi/__init__.py @@ -1,13 +1,14 @@ from .openapi2 import doc, openapi2_blueprint -from .openapi3 import openapi, openapi3_blueprint +from .openapi3 import openapi, openapi3_blueprint, specification swagger_blueprint = openapi2_blueprint -__version__ = "21.3.3" +__version__ = "21.6.0" __all__ = [ "openapi2_blueprint", "swagger_blueprint", "openapi3_blueprint", "openapi", + "specification", "doc", ] diff --git a/sanic_openapi/openapi2/blueprint.py b/sanic_openapi/openapi2/blueprint.py index a914068a..377a6daa 100644 --- a/sanic_openapi/openapi2/blueprint.py +++ b/sanic_openapi/openapi2/blueprint.py @@ -54,6 +54,8 @@ def build_spec(app, loop): for blueprint_name, handler in get_blueprinted_routes(app): route_spec = route_specs[handler] route_spec.blueprint = blueprint_name + if route_spec.exclude: + continue if not route_spec.tags: route_spec.tags.append(blueprint_name) @@ -195,7 +197,9 @@ def build_spec(app, loop): methods[_method.lower()] = endpoint if methods: - paths[uri] = methods + if uri not in paths: + paths[uri] = {} + paths[uri].update(methods) # --------------------------------------------------------------- # # Definitions diff --git a/sanic_openapi/openapi3/__init__.py b/sanic_openapi/openapi3/__init__.py index 8eeea925..90fca6d3 100644 --- a/sanic_openapi/openapi3/__init__.py +++ b/sanic_openapi/openapi3/__init__.py @@ -3,15 +3,23 @@ """ from collections import defaultdict - +from typing import Dict, TypeVar from .builders import OperationBuilder, SpecificationBuilder +try: + from sanic.models.handler_types import RouteHandler +except ImportError: + RouteHandler = TypeVar("RouteHandler") # type: ignore + # Static datastores, which get added to via the oas3.openapi decorators, # and then read from in the blueprint generation -operations = defaultdict(OperationBuilder) +operations: Dict[RouteHandler, OperationBuilder] = defaultdict( + OperationBuilder +) specification = SpecificationBuilder() + from .blueprint import blueprint_factory # noqa diff --git a/sanic_openapi/openapi3/blueprint.py b/sanic_openapi/openapi3/blueprint.py index b1198419..6757366a 100644 --- a/sanic_openapi/openapi3/blueprint.py +++ b/sanic_openapi/openapi3/blueprint.py @@ -76,6 +76,10 @@ def build_spec(app, loop): if hasattr(_handler, "view_class"): _handler = getattr(_handler.view_class, method.lower()) operation = operations[_handler] + + if operation._exclude: + continue + docstring = inspect.getdoc(_handler) if docstring: @@ -107,19 +111,19 @@ def add_static_info_to_spec_from_config(app, specification): Modifies specification in-place and returns None """ - specification.describe( + specification._do_describe( getattr(app.config, "API_TITLE", "API"), getattr(app.config, "API_VERSION", "1.0.0"), getattr(app.config, "API_DESCRIPTION", None), getattr(app.config, "API_TERMS_OF_SERVICE", None), ) - specification.license( + specification._do_license( getattr(app.config, "API_LICENSE_NAME", None), getattr(app.config, "API_LICENSE_URL", None), ) - specification.contact( + specification._do_contact( getattr(app.config, "API_CONTACT_NAME", None), getattr(app.config, "API_CONTACT_URL", None), getattr(app.config, "API_CONTACT_EMAIL", None), diff --git a/sanic_openapi/openapi3/builders.py b/sanic_openapi/openapi3/builders.py index b6701db7..46d3b950 100644 --- a/sanic_openapi/openapi3/builders.py +++ b/sanic_openapi/openapi3/builders.py @@ -11,6 +11,7 @@ from ..utils import remove_nulls, remove_nulls_from_kwargs from .definitions import ( Any, + Components, Contact, Dict, ExternalDocumentation, @@ -47,6 +48,7 @@ def __init__(self): self.parameters = [] self.responses = {} self._autodoc = None + self._exclude = False def name(self, value: str): self.operationId = value @@ -108,6 +110,9 @@ def autodoc(self, docstring: str): y = YamlStyleParametersParser(docstring) self._autodoc = y.to_openAPI_3() + def exclude(self, flag: bool = True): + self._exclude = flag + class SpecificationBuilder: _urls: List[str] @@ -119,14 +124,24 @@ class SpecificationBuilder: _license: License _paths: Dict[str, Dict[str, OperationBuilder]] _tags: Dict[str, Tag] + _components: Dict[str, Any] + _servers: List[Server] # _components: ComponentsBuilder # deliberately not included def __init__(self): + self._components = defaultdict(dict) + self._contact = None + self._description = None + self._external = None + self._license = None self._paths = defaultdict(dict) + self._servers = [] self._tags = {} - self._license = None + self._terms = None + self._title = None self._urls = [] + self._version = None def url(self, value: str): self._urls.append(value) @@ -143,17 +158,45 @@ def describe( self._description = description self._terms = terms - def tag(self, name: str, **kwargs): - self._tags[name] = Tag(name, **kwargs) + def _do_describe( + self, + title: str, + version: str, + description: Optional[str] = None, + terms: Optional[str] = None, + ): + if any([self._title, self._version, self._description, self._terms]): + return + self.describe(title, version, description, terms) + + def tag(self, name: str, description: Optional[str] = None, **kwargs): + self._tags[name] = Tag(name, description=description, **kwargs) + + def external(self, url: str, description: Optional[str] = None, **kwargs): + self._external = ExternalDocumentation(url, description=description) def contact(self, name: str = None, url: str = None, email: str = None): kwargs = remove_nulls_from_kwargs(name=name, url=url, email=email) self._contact = Contact(**kwargs) + def _do_contact( + self, name: str = None, url: str = None, email: str = None + ): + if self._contact: + return + + self.contact(name, url, email) + def license(self, name: str = None, url: str = None): if name is not None: self._license = License(name, url=url) + def _do_license(self, name: str = None, url: str = None): + if self._license: + return + + self.license(name, url) + def operation(self, path: str, method: str, operation: OperationBuilder): for _tag in operation.tags: if _tag in self._tags.keys(): @@ -163,18 +206,62 @@ def operation(self, path: str, method: str, operation: OperationBuilder): self._paths[path][method.lower()] = operation + def add_component(self, location: str, name: str, obj: Any): + self._components[location].update({name: obj}) + + def raw(self, data): + if "info" in data: + self.describe( + data["info"].get("title"), + data["info"].get("version"), + data["info"].get("description"), + data["info"].get("terms"), + ) + + if "servers" in data: + for server in data["servers"]: + self._servers.append(Server(**server)) + + if "paths" in data: + self._paths.update(data["paths"]) + + if "components" in data: + for location, component in data["components"].items(): + self._components[location].update(component) + + if "security" in data: + ... + + if "tags" in data: + for tag in data["tags"]: + self.tag(**tag) + + if "externalDocs" in data: + self.external(**data["externalDocs"]) + def build(self) -> OpenAPI: info = self._build_info() paths = self._build_paths() tags = self._build_tags() url_servers = getattr(self, "_urls", None) - servers = [] + servers = self._servers if url_servers is not None: for url_server in url_servers: servers.append(Server(url=url_server)) - return OpenAPI(info, paths, tags=tags, servers=servers) + components = ( + Components(**self._components) if self._components else None + ) + + return OpenAPI( + info, + paths, + tags=tags, + servers=servers, + components=components, + externalDocs=self._external, + ) def _build_info(self) -> Info: kwargs = remove_nulls( @@ -197,7 +284,10 @@ def _build_paths(self) -> Dict: for path, operations in self._paths.items(): paths[path] = PathItem( - **{k: v.build() for k, v in operations.items()} + **{ + k: v if isinstance(v, dict) else v.build() + for k, v in operations.items() + } ) return paths diff --git a/sanic_openapi/openapi3/definitions.py b/sanic_openapi/openapi3/definitions.py index 7e544e04..7311657b 100644 --- a/sanic_openapi/openapi3/definitions.py +++ b/sanic_openapi/openapi3/definitions.py @@ -4,7 +4,7 @@ I.e., the objects described https://swagger.io/docs/specification """ -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional, Type, Union from .types import Definition, Schema @@ -82,11 +82,20 @@ def all(content: Any): class Response(Definition): - content: Dict[str, MediaType] - description: str + content: Union[Any, Dict[str, Union[Any, MediaType]]] + description: Optional[str] + status: str - def __init__(self, content=None, **kwargs): - super().__init__(content=content, **kwargs) + def __init__( + self, + content: Optional[Union[Any, Dict[str, Union[Any, MediaType]]]] = None, + status: int = 200, + description: Optional[str] = None, + **kwargs, + ): + super().__init__( + content=content, status=status, description=description, **kwargs + ) @staticmethod def make(content, description: str = None, **kwargs): @@ -99,12 +108,29 @@ def make(content, description: str = None, **kwargs): class RequestBody(Definition): - description: str - required: bool - content: Dict[str, MediaType] + description: Optional[str] + required: Optional[bool] + content: Union[Any, Dict[str, Union[Any, MediaType]]] + + def __init__( + self, + content: Union[Any, Dict[str, Union[Any, MediaType]]], + required: Optional[bool] = None, + description: Optional[str] = None, + **kwargs, + ): + """Can be initialized with content in one of a few ways: - def __init__(self, content: Dict[str, MediaType], **kwargs): - super().__init__(content=content, **kwargs) + RequestBody(SomeModel) + RequestBody({"application/json": SomeModel}) + RequestBody({"application/json": {"name": str}}) + """ + super().__init__( + content=content, + required=required, + description=description, + **kwargs, + ) @staticmethod def make(content: Any, **kwargs): @@ -138,15 +164,34 @@ def make(url: str, description: str = None): class Parameter(Definition): name: str + schema: Union[Type, Schema] location: str - description: str - required: bool - deprecated: bool - allowEmptyValue: bool - schema: Schema + description: Optional[str] + required: Optional[bool] + deprecated: Optional[bool] + allowEmptyValue: Optional[bool] - def __init__(self, name, schema: Schema, location="query", **kwargs): - super().__init__(name=name, schema=schema, location=location, **kwargs) + def __init__( + self, + name: str, + schema: Union[Type, Schema], + location: str = "query", + description: Optional[str] = None, + required: Optional[bool] = None, + deprecated: Optional[bool] = None, + allowEmptyValue: Optional[bool] = None, + **kwargs, + ): + super().__init__( + name=name, + schema=schema, + location=location, + description=description, + required=required, + deprecated=deprecated, + allowEmptyValue=allowEmptyValue, + **kwargs, + ) @property def fields(self): @@ -214,13 +259,13 @@ def fields(self): return values @staticmethod - def make(_type: str, cls: type, **kwargs): + def make(_type: str, cls: Type, **kwargs): params = cls.__dict__ if hasattr(cls, "__dict__") else {} return SecurityScheme(_type, **params, **kwargs) -class ServerVariable: +class ServerVariable(Definition): default: str description: str enum: List[str] diff --git a/sanic_openapi/openapi3/openapi.py b/sanic_openapi/openapi3/openapi.py index dc75ec77..fcfc9dc1 100644 --- a/sanic_openapi/openapi3/openapi.py +++ b/sanic_openapi/openapi3/openapi.py @@ -3,7 +3,18 @@ documentation to operations and components created in the blueprints. """ -from typing import Any +from typing import Any, Dict, List, Optional, Sequence, Union + +from sanic.blueprints import Blueprint +from sanic.exceptions import SanicException + +from sanic_openapi.openapi3.definitions import ( + ExternalDocumentation, + Parameter, + RequestBody, + Response, + Tag, +) from . import operations from .types import Array # noqa @@ -23,6 +34,19 @@ from .types import Time # noqa +def exclude(flag: bool = True, *, bp: Optional[Blueprint] = None): + if bp: + for route in bp.routes: + exclude(flag)(route.handler) + return + + def inner(func): + operations[func].exclude(flag) + return func + + return inner + + def operation(name: str): def inner(func): operations[func].name(name) @@ -106,3 +130,142 @@ def inner(func): return func return inner + + +def definition( + *, + exclude: Optional[bool] = None, + operation: Optional[str] = None, + summary: Optional[str] = None, + description: Optional[str] = None, + document: Optional[Union[str, ExternalDocumentation]] = None, + tag: Optional[Union[Union[str, Tag], Sequence[Union[str, Tag]]]] = None, + deprecated: bool = False, + body: Optional[Union[Dict[str, Any], RequestBody, Any]] = None, + parameter: Optional[ + Union[ + Union[Dict[str, Any], Parameter, str], + List[Union[Dict[str, Any], Parameter, str]], + ] + ] = None, + response: Optional[ + Union[ + Union[Dict[str, Any], Response, Any], + List[Union[Dict[str, Any], Response, Any]], + ] + ] = None, +): + def inner(func): + glbl = globals() + + if exclude is not None: + glbl["exclude"](exclude)(func) + + if operation: + glbl["operation"](operation)(func) + + if summary: + glbl["summary"](summary)(func) + + if description: + glbl["description"](description)(func) + + if document: + kwargs = {} + if isinstance(document, str): + kwargs["url"] = document + else: + kwargs["url"] = document.fields["url"] + kwargs["description"] = document.fields["description"] + + glbl["document"](**kwargs)(func) + + if tag: + taglist = [] + op = ( + "extend" + if isinstance(tag, (list, tuple, set, frozenset)) + else "append" + ) + + getattr(taglist, op)(tag) + glbl["tag"]( + *[ + tag.fields["name"] if isinstance(tag, Tag) else tag + for tag in taglist + ] + )(func) + + if deprecated: + glbl["deprecated"]()(func) + + if body: + kwargs = {} + if isinstance(body, RequestBody): + kwargs = body.fields + elif isinstance(body, dict): + if "content" in body: + kwargs = body + else: + kwargs["content"] = body + else: + kwargs["content"] = body + glbl["body"](**kwargs)(func) + + if parameter: + paramlist = [] + op = ( + "extend" + if isinstance(parameter, (list, tuple, set, frozenset)) + else "append" + ) + getattr(paramlist, op)(parameter) + + for param in paramlist: + kwargs = {} + if isinstance(param, Parameter): + kwargs = param.fields + elif isinstance(param, dict) and "name" in param: + kwargs = param + elif isinstance(param, str): + kwargs["name"] = param + else: + raise SanicException( + "parameter must be a Parameter instance, a string, or " + "a dictionary containing at least 'name'." + ) + + if "schema" not in kwargs: + kwargs["schema"] = str + + glbl["parameter"](**kwargs)(func) + + if response: + resplist = [] + op = ( + "extend" + if isinstance(response, (list, tuple, set, frozenset)) + else "append" + ) + getattr(resplist, op)(response) + + for resp in resplist: + kwargs = {} + if isinstance(resp, Response): + kwargs = resp.fields + elif isinstance(resp, dict): + if "content" in resp: + kwargs = resp + else: + kwargs["content"] = resp + else: + kwargs["content"] = resp + + if "status" not in kwargs: + kwargs["status"] = 200 + + glbl["response"](**kwargs)(func) + + return func + + return inner diff --git a/sanic_openapi/openapi3/types.py b/sanic_openapi/openapi3/types.py index 9f332c5d..171bf7e5 100644 --- a/sanic_openapi/openapi3/types.py +++ b/sanic_openapi/openapi3/types.py @@ -1,6 +1,8 @@ import json +import typing as t from datetime import date, datetime, time from enum import Enum +from inspect import isclass from typing import Any, Dict, List, Union, get_type_hints @@ -27,14 +29,24 @@ def serialize(self): def __str__(self): return json.dumps(self.serialize()) + def apply(self, func, operations, *args, **kwargs): + op = operations[func] + method_name = getattr( + self.__class__, "__method__", self.__class__.__name__.lower() + ) + method = getattr(op, method_name) + if not args and not kwargs: + kwargs = self.__dict__ + method(*args, **kwargs) + class Schema(Definition): title: str description: str type: str format: str - nullable: False - required: False + nullable: bool + required: bool default: None example: None oneOf: List[Definition] @@ -43,9 +55,9 @@ class Schema(Definition): multipleOf: int maximum: int - exclusiveMaximum: False + exclusiveMaximum: bool minimum: int - exclusiveMinimum: False + exclusiveMinimum: bool maxLength: int minLength: int pattern: str @@ -104,14 +116,11 @@ def make(value, **kwargs): return Array(schema, **kwargs) elif _type == dict: - return Object( - {k: Schema.make(v) for k, v in value.items()}, **kwargs - ) + return Object.make(value, **kwargs) + elif _type == t._GenericAlias and value.__origin__ == list: + return Array(Schema.make(value.__args__[0]), **kwargs) else: - return Object( - {k: Schema.make(v) for k, v in _properties(value).items()}, - **kwargs, - ) + return Object.make(value, **kwargs) class Boolean(Schema): @@ -187,12 +196,19 @@ class Object(Schema): def __init__(self, properties: Dict[str, Schema] = None, **kwargs): super().__init__(type="object", properties=properties or {}, **kwargs) + @classmethod + def make(cls, value: Any, **kwargs): + return cls( + {k: Schema.make(v) for k, v in _properties(value).items()}, + **kwargs, + ) + class Array(Schema): items: Any maxItems: int minItems: int - uniqueItems: False + uniqueItems: bool def __init__(self, items: Any, **kwargs): super().__init__(type="array", items=Schema.make(items), **kwargs) @@ -216,10 +232,13 @@ def _serialize(value) -> Any: def _properties(value: object) -> Dict: try: - fields = { - x: v for x, v in value.__dict__.items() if not x.startswith("_") - } + fields = {x: v for x, v in value.__dict__.items()} except AttributeError: - return {} - - return {**get_type_hints(value.__class__), **fields} + fields = {} + + cls = value if isclass(value) else value.__class__ + return { + k: v + for k, v in {**get_type_hints(cls), **fields}.items() + if not k.startswith("_") + } diff --git a/setup.cfg b/setup.cfg index 737cff3b..ebb57dc2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [tool:pytest] -addopts=--durations=0 --html=test_reports/report.html --self-contained-html --cov=sanic_openapi --cov-config .coveragerc --cov-report html +# addopts=--durations=0 --cov=sanic_openapi --cov-config .coveragerc --cov-report html [aliases] test=pytest diff --git a/tests/conftest.py b/tests/conftest.py index 2c118d20..22bc175c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,7 +3,7 @@ import pytest from sanic import Sanic -from sanic_openapi import openapi2_blueprint +from sanic_openapi import openapi2_blueprint, openapi3_blueprint app_ID = itertools.count() @@ -16,3 +16,10 @@ def app(): # Clean up openapi2_blueprint.definitions = {} + + +@pytest.fixture() +def app3(): + app = Sanic("test_{}".format(next(app_ID))) + app.blueprint(openapi3_blueprint) + yield app diff --git a/tests/samples/petstore.yaml b/tests/samples/petstore.yaml new file mode 100644 index 00000000..f1374677 --- /dev/null +++ b/tests/samples/petstore.yaml @@ -0,0 +1,113 @@ +openapi: "3.0.0" +info: + version: 1.0.0 + title: Swagger Petstore + license: + name: MIT +servers: + - url: http://petstore.swagger.io/v1 +tags: + - name: pets +paths: + /pets: + get: + summary: List all pets + operationId: listPets + tags: + - pets + parameters: + - name: limit + in: query + description: How many items to return at one time (max 100) + required: false + schema: + type: integer + format: int32 + responses: + "200": + description: A paged array of pets + headers: + x-next: + description: A link to the next page of responses + schema: + type: string + content: + application/json: + schema: + $ref: "#/components/schemas/Pets" + default: + description: unexpected error + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + post: + summary: Create a pet + operationId: createPets + tags: + - pets + responses: + "201": + description: Null response + default: + description: unexpected error + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + /pets/{petId}: + get: + summary: Info for a specific pet + operationId: showPetById + tags: + - pets + parameters: + - name: petId + in: path + required: true + description: The id of the pet to retrieve + schema: + type: string + responses: + "200": + description: Expected response to a valid request + content: + application/json: + schema: + $ref: "#/components/schemas/Pet" + default: + description: unexpected error + content: + application/json: + schema: + $ref: "#/components/schemas/Error" +components: + schemas: + Pet: + type: object + required: + - id + - name + properties: + id: + type: integer + format: int64 + name: + type: string + tag: + type: string + Pets: + type: array + items: + $ref: "#/components/schemas/Pet" + Error: + type: object + required: + - code + - message + properties: + code: + type: integer + format: int32 + message: + type: string diff --git a/tests/test_oas3_blueprint.py b/tests/test_oas3_blueprint.py new file mode 100644 index 00000000..f1117cc1 --- /dev/null +++ b/tests/test_oas3_blueprint.py @@ -0,0 +1,67 @@ +from collections import defaultdict + +from sanic.blueprints import Blueprint + +from sanic_openapi import openapi, openapi3 +from sanic_openapi.openapi3.blueprint import blueprint_factory +from sanic_openapi.openapi3.builders import ( + OperationBuilder, + SpecificationBuilder, +) + + +def test_exclude_entire_blueprint(app3): + _, response = app3.test_client.get("/swagger/swagger.json") + path_count = len(response.json["paths"]) + tag_count = len(response.json["tags"]) + + bp = Blueprint("noshow") + + @bp.get("/") + def noshow(_): + ... + + # For 21.3+ + try: + app3.router.reset() + except AttributeError: + ... + + app3.blueprint(bp) + openapi.exclude(bp=bp) + + _, response = app3.test_client.get("/swagger/swagger.json") + + assert len(response.json["paths"]) == path_count + assert len(response.json["tags"]) == tag_count + + +def test_exclude_single_blueprint_route(app3): + _, response = app3.test_client.get("/swagger/swagger.json") + path_count = len(response.json["paths"]) + tag_count = len(response.json["tags"]) + + bp = Blueprint("somebp") + + @bp.get("/") + @openapi.exclude() + def noshow(_): + ... + + @bp.get("/ok") + def ok(_): + ... + + # For 21.3+ + try: + app3.router.reset() + except AttributeError: + ... + + app3.blueprint(bp) + + _, response = app3.test_client.get("/swagger/swagger.json") + + assert "/ok" in response.json["paths"] + assert len(response.json["paths"]) == path_count + 1 + assert len(response.json["tags"]) == tag_count + 1 diff --git a/tests/test_oas3_definition.py b/tests/test_oas3_definition.py new file mode 100644 index 00000000..3e1c428f --- /dev/null +++ b/tests/test_oas3_definition.py @@ -0,0 +1,190 @@ +import pytest + +from sanic_openapi import openapi +from sanic_openapi.openapi3.definitions import ( + ExternalDocumentation, + Parameter, + RequestBody, + Response, + Tag, +) +from sanic_openapi.openapi3.types import Schema, _serialize + + +class Name: + first_name: str + last_name: str + + +class User: + name: Name + + +def test_def_operation(app3): + @app3.post("/path") + @openapi.definition( + operation="operID", + ) + async def handler(request): + ... + + _, response = app3.test_client.get("/swagger/swagger.json") + op = response.json["paths"]["/path"]["post"] + + assert op["operationId"] == "operID" + + +def test_def_summary(app3): + @app3.post("/path") + @openapi.definition( + summary="Hello, world.", + ) + async def handler(request): + ... + + _, response = app3.test_client.get("/swagger/swagger.json") + op = response.json["paths"]["/path"]["post"] + + assert op["summary"] == "Hello, world." + + +def test_def_description(app3): + @app3.post("/path") + @openapi.definition( + description="Hello, world.", + ) + async def handler(request): + ... + + _, response = app3.test_client.get("/swagger/swagger.json") + op = response.json["paths"]["/path"]["post"] + + assert op["description"] == "Hello, world." + + +@pytest.mark.parametrize( + "value", + ( + "foo", + Tag("foo"), + ["foo"], + [Tag("foo")], + [Tag("foo"), "bar"], + ), +) +def test_def_tag(app3, value): + @app3.post("/path") + @openapi.definition( + tag=value, + ) + async def handler(request): + ... + + _, response = app3.test_client.get("/swagger/swagger.json") + op = response.json["paths"]["/path"]["post"] + length = len(value) if isinstance(value, list) else 1 + + assert "foo" in op["tags"] + assert len(op["tags"]) == length + + +@pytest.mark.parametrize( + "value", + ("http://somewhere", ExternalDocumentation("http://somewhere")), +) +def test_def_document(app3, value): + @app3.post("/path") + @openapi.definition( + document=value, + ) + async def handler(request): + ... + + _, response = app3.test_client.get("/swagger/swagger.json") + op = response.json["paths"]["/path"]["post"] + + assert op["externalDocs"]["url"] == "http://somewhere" + + +@pytest.mark.parametrize( + "media,value", + ( + ("*/*", User), + ("*/*", {"*/*": User}), + ("*/*", {"content": {"*/*": User}}), + ("*/*", RequestBody(User)), + ("*/*", RequestBody(content=User)), + ("application/json", RequestBody({"application/json": User})), + ), +) +def test_def_body(app3, media, value): + @app3.post("/path") + @openapi.definition( + body=value, + ) + async def handler(request): + ... + + _, response = app3.test_client.get("/swagger/swagger.json") + body = response.json["paths"]["/path"]["post"]["requestBody"] + + assert body["content"][media]["schema"] == _serialize(Schema.make(User)) + + +@pytest.mark.parametrize( + "type_,value", + ( + ("string", "something"), + ("string", {"name": "something"}), + ("string", ["something", "else"]), + ("integer", {"name": "something", "schema": int}), + ("string", Parameter("something", str)), + ("string", [Parameter("something", str)]), + ), +) +def test_def_parameter(app3, value, type_): + @app3.post("/path") + @openapi.definition( + parameter=value, + ) + async def handler(request): + ... + + _, response = app3.test_client.get("/swagger/swagger.json") + params = response.json["paths"]["/path"]["post"]["parameters"] + + length = len(value) if isinstance(value, list) else 1 + + assert len(params) == length + assert params[0]["name"] == "something" + assert params[0]["schema"]["type"] == type_ + + +@pytest.mark.parametrize( + "status,media,value", + ( + (200, "*/*", User), + (200, "application/json", {"application/json": User}), + (200, "*/*", Response(User)), + (201, "*/*", Response(User, 201)), + (201, "*/*", [Response(User, 201), User]), + (200, "application/json", Response({"application/json": User})), + ), +) +def test_def_response(app3, status, media, value): + @app3.post("/path") + @openapi.definition( + response=value, + ) + async def handler(request): + ... + + _, response = app3.test_client.get("/swagger/swagger.json") + responses = response.json["paths"]["/path"]["post"]["responses"] + + length = len(value) if isinstance(value, list) else 1 + + assert len(responses) == length + assert responses[f"{status}"]["content"][media] == { + "schema": _serialize(Schema.make(User)) + } diff --git a/tests/test_oas3_specification.py b/tests/test_oas3_specification.py new file mode 100644 index 00000000..1991821e --- /dev/null +++ b/tests/test_oas3_specification.py @@ -0,0 +1,48 @@ +from pathlib import Path + +import yaml + +from sanic_openapi import specification + + +def test_apply_describe(app3): + title = "This is a test" + version = "1.1.1" + specification.describe(title, version=version) + + _, response = app3.test_client.get("/swagger/swagger.json") + + assert response.json["info"]["title"] == title + assert response.json["info"]["version"] == version + + +def test_raw(app3): + with open(Path(__file__).parent / "samples" / "petstore.yaml", "r") as f: + data = yaml.safe_load(f) + + _, response = app3.test_client.get("/swagger/swagger.json") + + path_count = len(response.json["paths"]) + schema_count = len( + (response.json.get("components", {}) or {}).get("schemas", {}) + ) + servers_count = len(response.json["servers"]) + tags_count = len(response.json["tags"]) + + specification.tag("one") + specification.raw(data) + specification.url("http://foobar") + specification.tag("two") + + _, response = app3.test_client.get("/swagger/swagger.json") + + assert len(response.json["paths"]) == path_count + 2 + assert len(response.json["components"]["schemas"]) == schema_count + 3 + assert len(response.json["servers"]) == servers_count + 2 + assert len(response.json["tags"]) == tags_count + 3 + + assert response.json["servers"][1]["url"] == "http://foobar" + assert all( + x in {x["name"] for x in response.json["tags"]} + for x in ["one", "two", "pets"] + )