diff --git a/tests/unit/test_req.py b/tests/unit/test_req.py index dd0af289e7b..3b78ead3fe3 100644 --- a/tests/unit/test_req.py +++ b/tests/unit/test_req.py @@ -770,11 +770,16 @@ def test_install_req_drop_extras(self, inp: str, out: str) -> None: without_extras = install_req_drop_extras(req) assert not without_extras.extras assert str(without_extras.req) == out - # should always be a copy - assert req is not without_extras - assert req.req is not without_extras.req + + # if there are no extras they should be the same object, + # otherwise they may be a copy due to cache + if req.extras: + assert req is not without_extras + assert req.req is not without_extras.req + # comes_from should point to original assert without_extras.comes_from is req + # all else should be the same assert without_extras.link == req.link assert without_extras.markers == req.markers @@ -790,9 +795,9 @@ def test_install_req_drop_extras(self, inp: str, out: str) -> None: @pytest.mark.parametrize( "inp, extras, out", [ - ("pkg", {}, "pkg"), - ("pkg==1.0", {}, "pkg==1.0"), - ("pkg[ext]", {}, "pkg[ext]"), + ("pkg", set(), "pkg"), + ("pkg==1.0", set(), "pkg==1.0"), + ("pkg[ext]", set(), "pkg[ext]"), ("pkg", {"ext"}, "pkg[ext]"), ("pkg==1.0", {"ext"}, "pkg[ext]==1.0"), ("pkg==1.0", {"ext1", "ext2"}, "pkg[ext1,ext2]==1.0"), @@ -816,9 +821,14 @@ def test_install_req_extend_extras( assert str(extended.req) == out assert extended.req is not None assert set(extended.extras) == set(extended.req.extras) - # should always be a copy - assert req is not extended - assert req.req is not extended.req + + # if extras is not a subset of req.extras then the extended + # requirement object should not be the same, otherwise they + # might be a copy due to cache + if not extras.issubset(req.extras): + assert req is not extended + assert req.req is not extended.req + # all else should be the same assert extended.link == req.link assert extended.markers == req.markers