From 0a5da9590fb4217a0949180788b3dcf75cd6b315 Mon Sep 17 00:00:00 2001 From: Shaobo Hou Date: Fri, 28 Jun 2024 00:37:15 -0700 Subject: [PATCH] Support python 3.11 PiperOrigin-RevId: 647581034 --- .github/workflows/ci.yml | 2 +- requirements_tests.txt | 4 ++-- setup.py | 1 + test.sh | 2 +- tf2jax/_src/roundtrip_test.py | 11 +++++++++++ 5 files changed, 16 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5a73c41..789c423 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -15,7 +15,7 @@ jobs: strategy: matrix: - python-version: ["3.10"] + python-version: ["3.10", "3.11"] os: [ubuntu-latest] steps: diff --git a/requirements_tests.txt b/requirements_tests.txt index 6f929a7..561fd12 100644 --- a/requirements_tests.txt +++ b/requirements_tests.txt @@ -1,4 +1,4 @@ chex==0.1.85 dm-haiku==0.0.11 -tensorflow==2.11.1; sys_platform != 'darwin' or platform_machine != 'arm64' -tensorflow-macos==2.11.1; sys_platform == 'darwin' and platform_machine == 'arm64' +tensorflow==2.16.1; sys_platform != 'darwin' or platform_machine != 'arm64' +tensorflow-macos==2.16.1; sys_platform == 'darwin' and platform_machine == 'arm64' diff --git a/setup.py b/setup.py index dfc1cb4..5088956 100644 --- a/setup.py +++ b/setup.py @@ -67,6 +67,7 @@ def _parse_requirements(path): 'License :: OSI Approved :: Apache Software License', 'Operating System :: OS Independent', 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', 'Topic :: Scientific/Engineering :: Artificial Intelligence', 'Topic :: Scientific/Engineering :: Mathematics', 'Topic :: Software Development :: Libraries :: Python Modules', diff --git a/test.sh b/test.sh index 450cef5..75f11a7 100755 --- a/test.sh +++ b/test.sh @@ -26,7 +26,7 @@ python --version # Install dependencies. pip install --upgrade pip setuptools wheel # See https://github.com/google/pytype/issues/1316 -pip install flake8 pytest-xdist pytype==2022.9.19 importlab==0.7 pylint pylint-exit +pip install flake8 pytest-xdist pytype importlab pylint pylint-exit pip install -r requirements.txt pip install -r requirements_tests.txt diff --git a/tf2jax/_src/roundtrip_test.py b/tf2jax/_src/roundtrip_test.py index 138d48d..8738fff 100644 --- a/tf2jax/_src/roundtrip_test.py +++ b/tf2jax/_src/roundtrip_test.py @@ -69,6 +69,17 @@ def uses_native_serialization(): class Jax2TfTest(test_util.TestCase): + def setUp(self): + super().setUp() + if not uses_native_serialization(): + self._xla_op = tf2jax.ops._jax_ops.pop("XlaCallModule", None) + + def tearDown(self): + super().tearDown() + if not uses_native_serialization(): + if self._xla_op is not None: + tf2jax.ops._jax_ops["XlaCallModule"] = self._xla_op + def _test_convert( self, jax_func,