Skip to content

Commit

Permalink
Drop Python 3.7 support
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Dec 7, 2023
1 parent fa7e67c commit 14b3960
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.7", "3.8", "3.9"]
python-version: ["3.8", "3.9", "3.10", "3.11"]

steps:
- uses: actions/checkout@v2
Expand Down
3 changes: 2 additions & 1 deletion jax_dataclasses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
# for now...
from dataclasses import dataclass as pytree_dataclass
else:
from ._dataclasses import pytree_dataclass
from ._dataclasses import pytree_dataclass # noqa
from ._dataclasses import deprecated_static_field as static_field # noqa

from ._dataclasses import Static
from ._enforced_annotations import EnforcedAnnotationsMixin
Expand Down
4 changes: 2 additions & 2 deletions jax_dataclasses/_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def pytree_dataclass(cls: Optional[Type] = None, **kwargs):
PyTrees."""

def wrap(cls):
return dataclasses.dataclass(cls, **kwargs)
return _register_pytree_dataclass(dataclasses.dataclass(cls, **kwargs))

if "frozen" in kwargs:
assert kwargs["frozen"] is True, "Pytree dataclasses can only be frozen!"
Expand All @@ -63,7 +63,7 @@ class FieldInfo:
static_field_names: List[str]


def register_pytree_dataclass(cls: Type[T]) -> Type[T]:
def _register_pytree_dataclass(cls: Type[T]) -> Type[T]:
"""Register a dataclass as a flax-serializable pytree container."""

assert dataclasses.is_dataclass(cls)
Expand Down
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
license="MIT",
packages=find_packages(),
package_data={"jax_dataclasses": ["py.typed"]},
python_requires=">=3.7",
python_requires=">=3.8",
install_requires=[
"jax",
"jaxlib",
Expand All @@ -28,9 +28,10 @@
]
},
classifiers=[
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
],
Expand Down

0 comments on commit 14b3960

Please sign in to comment.