Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent dataclass parsing from triggering DuplicateBasesError #2629

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ What's New in astroid 3.3.6?
Release date: TBA

* Fix precedence of `path` arg in `modpath_from_file_with_callback` to be higher than `sys.path`
* Prevent dataclass parsing from triggering `DuplicateBasesError`.


What's New in astroid 3.3.5?
Expand Down
4 changes: 2 additions & 2 deletions astroid/brain/brain_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def _find_arguments_from_base_classes(
# See TODO down below
# all_have_defaults = True

for base in reversed(node.mro()):
for base in reversed(node.mro(ignore_duplicates=True)):
if not base.is_dataclass:
continue
try:
Expand Down Expand Up @@ -221,7 +221,7 @@ def _parse_arguments_into_strings(

def _get_previous_field_default(node: nodes.ClassDef, name: str) -> nodes.NodeNG | None:
"""Get the default value of a previously defined field."""
for base in reversed(node.mro()):
for base in reversed(node.mro(ignore_duplicates=True)):
if not base.is_dataclass:
continue
if name in base.locals:
Expand Down
23 changes: 17 additions & 6 deletions astroid/nodes/scoped_nodes/scoped_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,12 +144,13 @@ def clean_duplicates_mro(
sequences: list[list[ClassDef]],
cls: ClassDef,
context: InferenceContext | None,
ignore_duplicates: bool,
) -> list[list[ClassDef]]:
for sequence in sequences:
seen = set()
for node in sequence:
lineno_and_qname = (node.lineno, node.qname())
if lineno_and_qname in seen:
if lineno_and_qname in seen and not ignore_duplicates:
raise DuplicateBasesError(
message="Duplicates found in MROs {mros} for {cls!r}.",
mros=sequences,
Expand Down Expand Up @@ -2834,7 +2835,9 @@ def _inferred_bases(self, context: InferenceContext | None = None):
else:
yield from baseobj.bases

def _compute_mro(self, context: InferenceContext | None = None):
def _compute_mro(
self, context: InferenceContext | None = None, ignore_duplicates: bool = False
):
if self.qname() == "builtins.object":
return [self]

Expand All @@ -2844,23 +2847,31 @@ def _compute_mro(self, context: InferenceContext | None = None):
if base is self:
continue

mro = base._compute_mro(context=context)
mro = base._compute_mro(
context=context, ignore_duplicates=ignore_duplicates
)
bases_mro.append(mro)

unmerged_mro: list[list[ClassDef]] = [[self], *bases_mro, inferred_bases]
unmerged_mro = clean_duplicates_mro(unmerged_mro, self, context)
unmerged_mro = clean_duplicates_mro(
unmerged_mro, self, context, ignore_duplicates=ignore_duplicates
)
clean_typing_generic_mro(unmerged_mro)
return _c3_merge(unmerged_mro, self, context)

def mro(self, context: InferenceContext | None = None) -> list[ClassDef]:
def mro(
self, context: InferenceContext | None = None, ignore_duplicates: bool = False
) -> list[ClassDef]:
"""Get the method resolution order, using C3 linearization.

:param ignore_duplicates: Don't raise DuplicateBasesError on duplicate bases
of the same base class.
:returns: The list of ancestors, sorted by the mro.
:rtype: list(NodeNG)
:raises DuplicateBasesError: Duplicate bases in the same class base
:raises InconsistentMroError: A class' MRO is inconsistent
"""
return self._compute_mro(context=context)
return self._compute_mro(context=context, ignore_duplicates=ignore_duplicates)

def bool_value(self, context: InferenceContext | None = None) -> Literal[True]:
"""Determine the boolean value of this node.
Expand Down
30 changes: 30 additions & 0 deletions tests/brain/test_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -1350,3 +1350,33 @@ def attr(self, value: int) -> None:
fourth_init: bases.UnboundMethod = next(fourth.infer())
assert [a.name for a in fourth_init.args.args] == ["self", "other_attr", "attr"]
assert [a.name for a in fourth_init.args.defaults] == ["Uninferable"]


@parametrize_module
def test_dataclass_inherited_from_multiple_protocol_bases(module: str):
code = astroid.extract_node(
f"""
from {module} import dataclass
from typing import TypeVar, Protocol

BaseT = TypeVar("BaseT")
T = TypeVar("T", bound=BaseT)


class A(Protocol[BaseT]):
pass


class B(A[T], Protocol[T]):
pass


@dataclass
class Dataclass(B[T]):
pass
"""
)
inferred = code.inferred()
assert len(inferred) == 1
assert isinstance(inferred[0], nodes.ClassDef)
assert inferred[0].is_dataclass
3 changes: 3 additions & 0 deletions tests/test_scoped_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1938,6 +1938,7 @@ class A(Generic[T1], Generic[T2]): ...
assert isinstance(cls, nodes.ClassDef)
with self.assertRaises(DuplicateBasesError):
cls.mro()
assert len(cls.mro(ignore_duplicates=True)) == 3

def test_mro_generic_error_2(self):
cls = builder.extract_node(
Expand All @@ -1951,6 +1952,8 @@ class B(A[T], A[T]): ...
assert isinstance(cls, nodes.ClassDef)
with self.assertRaises(DuplicateBasesError):
cls.mro()
with self.assertRaises(InconsistentMroError):
cls.mro(ignore_duplicates=True)

def test_mro_typing_extensions(self):
"""Regression test for mro() inference on typing_extensions.
Expand Down