diff --git a/freezegun/api.py b/freezegun/api.py index d235292..4103baf 100644 --- a/freezegun/api.py +++ b/freezegun/api.py @@ -335,6 +335,27 @@ def __sub__(self, other: Any) -> "FakeDate": # type: ignore else: return result # type: ignore + def __le__(self, other) -> bool: + if isinstance(other, real_date) and not isinstance(other, real_datetime): + return real_date.__le__(self, other) + return NotImplemented + + def __lt__(self, other) -> bool: + if isinstance(other, real_date) and not isinstance(other, real_datetime): + return real_date.__lt__(self, other) + return NotImplemented + + def __ge__(self, other) -> bool: + if isinstance(other, real_date) and not isinstance(other, real_datetime): + return real_date.__ge__(self, other) + return NotImplemented + + def __gt__(self, other) -> bool: + if isinstance(other, real_date) and not isinstance(other, real_datetime): + return real_date.__gt__(self, other) + return NotImplemented + + @classmethod def today(cls: Type["FakeDate"]) -> "FakeDate": result = cls._date_to_freeze() + cls._tz_offset() diff --git a/tests/test_datetimes.py b/tests/test_datetimes.py index b75ad3b..ead1329 100644 --- a/tests/test_datetimes.py +++ b/tests/test_datetimes.py @@ -2,6 +2,7 @@ import calendar import datetime import fractions +import itertools import unittest import locale import sys @@ -826,3 +827,36 @@ def test_datetime_in_timezone(monkeypatch: pytest.MonkeyPatch) -> None: assert datetime.datetime.now() == datetime.datetime(1970, 1, 1, 1, 0, 0) finally: time.tzset() # set the timezone back to what is was before + + +def test_cannot_compare_date_and_datetime() -> None: + std_date = datetime.date(2012, 1, 14) + std_datetime = datetime.datetime(2012, 1, 14, 12, 0, 0) + fake_date = FakeDate(2012, 1, 15) + fake_datetime = FakeDatetime(2012, 1, 14, 12, 1, 0) + + assert std_date < fake_date + assert std_datetime < fake_datetime + + d_objs = (std_date, fake_date) + dt_objs = (std_datetime, fake_datetime) + + for base, diff in itertools.chain( + itertools.product(d_objs, dt_objs), itertools.product(dt_objs, d_objs) + ): + with pytest.raises(TypeError): + _ = base < diff + with pytest.raises(TypeError): + _ = base <= diff + with pytest.raises(TypeError): + _ = base > diff + with pytest.raises(TypeError): + _ = base >= diff + + for base, same in itertools.chain( + itertools.product(d_objs, d_objs), itertools.product(dt_objs, dt_objs) + ): + _ = base < same + _ = base <= same + _ = base > same + _ = base >= same