Skip to content

Commit

Permalink
Add PYD003/4/5 (#4)
Browse files Browse the repository at this point in the history
- Drop support for Python 3.8, this was causing compatibility issues with slices
  • Loading branch information
Viicos authored Feb 24, 2024
1 parent cf04e05 commit 66eca2e
Show file tree
Hide file tree
Showing 12 changed files with 308 additions and 21 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
python-version: ['3.9', '3.10', '3.11', '3.12']

steps:
- uses: actions/checkout@v4
Expand Down
62 changes: 62 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,68 @@ class Model(BaseModel):
foo = 1 # Will error at runtime
```

### `PYD003` - *Unecessary Field call to specify a default value*

Raise an error if the [`Field`](https://docs.pydantic.dev/latest/api/fields/#pydantic.fields.Field) function
is used only to specify a default value.

```python
class Model(BaseModel):
foo: int = Field(default=1)
```

Instead, consider specifying the default value directly:

```python
class Model(BaseModel):
foo: int = 1
```

### `PYD004` - *Default argument specified in annotated*

Raise an error if the `default` argument of the [`Field`](https://docs.pydantic.dev/latest/api/fields/#pydantic.fields.Field) function is used together with [`Annotated`](https://docs.python.org/3/library/typing.html#typing.Annotated).

```python
class Model(BaseModel):
foo: Annotated[int, Field(default=1, description="desc")]
```

To make type checkers aware of the default value, consider specifying the default value directly:

```python
class Model(BaseModel):
foo: Annotated[int, Field(description="desc")] = 1
```

### `PYD005` - *Field name overrides annotation*

Raise an error if the field name clashes with the annotation.

```python
from datetime import date

class Model(BaseModel):
date: date | None = None
```

Because of how Python [evaluates](https://docs.python.org/3/reference/simple_stmts.html#annassign)
annotated assignments, unexpected issues can happen when using an annotation name that clashes with a field
name. Pydantic will try its best to warn you about such issues, but can fail in complex scenarios (and the
issue may even be silently ignored).

Instead, consider, using an [alias](https://docs.pydantic.dev/latest/concepts/fields/#field-aliases) or referencing your type under a different name:

```python
from datetime import date

date_ = date

class Model(BaseModel):
date_aliased: date | None = Field(default=None, alias="date")
# or
date: date_ | None = None
```

### `PYD010` - *Usage of `__pydantic_config__`*

Raise an error if a Pydantic configuration is set with [`__pydantic_config__`](https://docs.pydantic.dev/dev/concepts/config/#configuration-with-dataclass-from-the-standard-library-or-typeddict).
Expand Down
5 changes: 2 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,12 @@ readme = "README.md"
authors = [
{name = "Victorien", email = "[email protected]"}
]
requires-python = ">=3.8"
requires-python = ">=3.9"
classifiers = [
"Development Status :: 4 - Beta",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
Expand Down Expand Up @@ -54,7 +53,7 @@ where = ["src"]
[tool.ruff]
line-length = 120
src = ["src"]
target-version = "py38"
target-version = "py39"

[tool.ruff.lint]
preview = true
Expand Down
49 changes: 42 additions & 7 deletions src/flake8_pydantic/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def _has_pydantic_decorator(node: ast.ClassDef) -> bool:
for stmt in node.body:
if isinstance(stmt, ast.FunctionDef):
decorator_names = get_decorator_names(stmt.decorator_list)
if PYDANTIC_DECORATORS.intersection(decorator_names):
if PYDANTIC_DECORATORS & decorator_names:
return True
return False

Expand All @@ -91,12 +91,6 @@ def _has_pydantic_method(node: ast.ClassDef) -> bool:
return False


def is_dataclass(node: ast.ClassDef) -> bool:
"""Determine if a class is a dataclass."""

return bool({"dataclass", "pydantic_dataclass"}.intersection(get_decorator_names(node.decorator_list)))


def is_pydantic_model(node: ast.ClassDef, include_root_model: bool = True) -> bool:
"""Determine if a class definition is a Pydantic model.
Expand All @@ -119,3 +113,44 @@ def is_pydantic_model(node: ast.ClassDef, include_root_model: bool = True) -> bo
or _has_pydantic_decorator(node)
or _has_pydantic_method(node)
)


def is_dataclass(node: ast.ClassDef) -> bool:
"""Determine if a class is a dataclass."""

return bool({"dataclass", "pydantic_dataclass"} & get_decorator_names(node.decorator_list))


def is_function(node: ast.Call, function_name: str) -> bool:
return (
isinstance(node.func, ast.Name)
and node.func.id == function_name
or isinstance(node.func, ast.Attribute)
and node.func.attr == function_name
)


def is_name(node: ast.expr, name: str) -> bool:
return isinstance(node, ast.Name) and node.id == name or isinstance(node, ast.Attribute) and node.attr == name


def extract_annotations(node: ast.expr) -> set[str]:
annotations: set[str] = set()

if isinstance(node, ast.Name):
# foo: date = ...
annotations.add(node.id)
if isinstance(node, ast.BinOp):
# foo: date | None = ...
annotations |= extract_annotations(node.left)
annotations |= extract_annotations(node.right)
if isinstance(node, ast.Subscript):
# foo: dict[str, date]
# foo: Annotated[list[date], ...]
if isinstance(node.slice, ast.Tuple):
for elt in node.slice.elts:
annotations |= extract_annotations(elt)
if isinstance(node.slice, ast.Name):
annotations.add(node.slice.id)

return annotations
15 changes: 15 additions & 0 deletions src/flake8_pydantic/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,21 @@ class PYD002(Error):
message = "Non-annotated attribute inside Pydantic model"


class PYD003(Error):
error_code = "PYD003"
message = "Unecessary Field call to specify a default value"


class PYD004(Error):
error_code = "PYD004"
message = "Default argument specified in annotated"


class PYD005(Error):
error_code = "PYD005"
message = "Field name overrides annotation"


class PYD010(Error):
error_code = "PYD010"
message = "Usage of __pydantic_config__"
3 changes: 2 additions & 1 deletion src/flake8_pydantic/plugin.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

import ast
from collections.abc import Iterator
from importlib.metadata import version
from typing import Any, Iterator
from typing import Any

from .visitor import Visitor

Expand Down
55 changes: 49 additions & 6 deletions src/flake8_pydantic/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from typing import Literal

from ._compat import TypeAlias
from ._utils import is_dataclass, is_pydantic_model
from .errors import PYD001, PYD002, PYD010, Error
from ._utils import extract_annotations, is_dataclass, is_function, is_name, is_pydantic_model
from .errors import PYD001, PYD002, PYD003, PYD004, PYD005, PYD010, Error

ClassType: TypeAlias = Literal["pydantic_model", "dataclass", "other_class"]

Expand Down Expand Up @@ -35,10 +35,7 @@ def _check_pyd_001(self, node: ast.AnnAssign) -> None:
if (
self.current_class in {"pydantic_model", "dataclass"}
and isinstance(node.value, ast.Call)
and (
(isinstance(node.value.func, ast.Name) and node.value.func.id == "Field")
or (isinstance(node.value.func, ast.Attribute) and node.value.func.attr == "Field")
)
and is_function(node.value, "Field")
and len(node.value.args) >= 1
):
self.errors.append(PYD001.from_node(node))
Expand All @@ -55,6 +52,49 @@ def _check_pyd_002(self, node: ast.ClassDef) -> None:
for assignment in invalid_assignments:
self.errors.append(PYD002.from_node(assignment))

def _check_pyd_003(self, node: ast.AnnAssign) -> None:
if (
self.current_class in {"pydantic_model", "dataclass"}
and isinstance(node.value, ast.Call)
and is_function(node.value, "Field")
and len(node.value.keywords) == 1
and node.value.keywords[0].arg == "default"
):
self.errors.append(PYD003.from_node(node))

def _check_pyd_004(self, node: ast.AnnAssign) -> None:
if (
self.current_class in {"pydantic_model", "dataclass"}
and isinstance(node.annotation, ast.Subscript)
and is_name(node.annotation.value, "Annotated")
and isinstance(node.annotation.slice, ast.Tuple)
):
field_call = next(
(
elt
for elt in node.annotation.slice.elts
if isinstance(elt, ast.Call)
and is_function(elt, "Field")
and any(k.arg == "default" for k in elt.keywords)
),
None,
)
if field_call is not None:
self.errors.append(PYD004.from_node(node))

def _check_pyd_005(self, node: ast.ClassDef) -> None:
if self.current_class in {"pydantic_model", "dataclass"}:
previous_targets: set[str] = set()

for stmt in node.body:
if isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name):
# TODO only add before if AnnAssign?
# the following seems to work:
# date: date
previous_targets.add(stmt.target.id)
if previous_targets & extract_annotations(stmt.annotation):
self.errors.append(PYD005.from_node(stmt))

def _check_pyd_010(self, node: ast.ClassDef) -> None:
if self.current_class == "other_class":
for stmt in node.body:
Expand All @@ -74,10 +114,13 @@ def _check_pyd_010(self, node: ast.ClassDef) -> None:
def visit_ClassDef(self, node: ast.ClassDef) -> None:
self.enter_class(node)
self._check_pyd_002(node)
self._check_pyd_005(node)
self._check_pyd_010(node)
self.generic_visit(node)
self.leave_class()

def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
self._check_pyd_001(node)
self._check_pyd_003(node)
self._check_pyd_004(node)
self.generic_visit(node)
2 changes: 1 addition & 1 deletion tests/rules/test_pyd001.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class Model:

PYD001_OK = """
class Model(BaseModel):
a: int = Field(default=1)
a: int = Field(default=1, description="")
"""


Expand Down
33 changes: 33 additions & 0 deletions tests/rules/test_pyd003.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from __future__ import annotations

import ast

import pytest

from flake8_pydantic.errors import PYD003, Error
from flake8_pydantic.visitor import Visitor

PYD003_NOT_OK = """
class Model(BaseModel):
a: int = Field(default=1)
"""

PYD003_OK = """
class Model(BaseModel):
a: int = Field(default=1, description="")
"""


@pytest.mark.parametrize(
["source", "expected"],
[
(PYD003_NOT_OK, [PYD003(3, 4)]),
(PYD003_OK, []),
],
)
def test_pyd003(source: str, expected: list[Error]) -> None:
module = ast.parse(source)
visitor = Visitor()
visitor.visit(module)

assert visitor.errors == expected
33 changes: 33 additions & 0 deletions tests/rules/test_pyd004.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from __future__ import annotations

import ast

import pytest

from flake8_pydantic.errors import PYD004, Error
from flake8_pydantic.visitor import Visitor

PYD004_1 = """
class Model(BaseModel):
a: Annotated[int, Field(default=1, description="")]
"""

PYD004_2 = """
class Model(BaseModel):
a: Annotated[int, Unrelated(), Field(default=1)]
"""


@pytest.mark.parametrize(
["source", "expected"],
[
(PYD004_1, [PYD004(3, 4)]),
(PYD004_2, [PYD004(3, 4)]),
],
)
def test_pyd004(source: str, expected: list[Error]) -> None:
module = ast.parse(source)
visitor = Visitor()
visitor.visit(module)

assert visitor.errors == expected
Loading

0 comments on commit 66eca2e

Please sign in to comment.