From 14b3960fe1966d76b68b8a636ffe305a3281c238 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Thu, 7 Dec 2023 15:54:37 +0000 Subject: [PATCH] Drop Python 3.7 support --- .github/workflows/build.yml | 2 +- jax_dataclasses/__init__.py | 3 ++- jax_dataclasses/_dataclasses.py | 4 ++-- setup.py | 5 +++-- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 6c8bfb1..ae5d0a8 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -11,7 +11,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.7", "3.8", "3.9"] + python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v2 diff --git a/jax_dataclasses/__init__.py b/jax_dataclasses/__init__.py index 0350ea9..25225a7 100644 --- a/jax_dataclasses/__init__.py +++ b/jax_dataclasses/__init__.py @@ -21,7 +21,8 @@ # for now... from dataclasses import dataclass as pytree_dataclass else: - from ._dataclasses import pytree_dataclass + from ._dataclasses import pytree_dataclass # noqa + from ._dataclasses import deprecated_static_field as static_field # noqa from ._dataclasses import Static from ._enforced_annotations import EnforcedAnnotationsMixin diff --git a/jax_dataclasses/_dataclasses.py b/jax_dataclasses/_dataclasses.py index f41db4e..5c139bc 100644 --- a/jax_dataclasses/_dataclasses.py +++ b/jax_dataclasses/_dataclasses.py @@ -36,7 +36,7 @@ def pytree_dataclass(cls: Optional[Type] = None, **kwargs): PyTrees.""" def wrap(cls): - return dataclasses.dataclass(cls, **kwargs) + return _register_pytree_dataclass(dataclasses.dataclass(cls, **kwargs)) if "frozen" in kwargs: assert kwargs["frozen"] is True, "Pytree dataclasses can only be frozen!" @@ -63,7 +63,7 @@ class FieldInfo: static_field_names: List[str] -def register_pytree_dataclass(cls: Type[T]) -> Type[T]: +def _register_pytree_dataclass(cls: Type[T]) -> Type[T]: """Register a dataclass as a flax-serializable pytree container.""" assert dataclasses.is_dataclass(cls) diff --git a/setup.py b/setup.py index 61cf5b0..f11bef4 100644 --- a/setup.py +++ b/setup.py @@ -14,7 +14,7 @@ license="MIT", packages=find_packages(), package_data={"jax_dataclasses": ["py.typed"]}, - python_requires=">=3.7", + python_requires=">=3.8", install_requires=[ "jax", "jaxlib", @@ -28,9 +28,10 @@ ] }, classifiers=[ - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", ],