Skip to content

Commit

Permalink
Support python 3.11
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 647581034
  • Loading branch information
shaobohou authored and TF2JAXDev committed Jun 28, 2024
1 parent 776d2f2 commit 0a5da95
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:

strategy:
matrix:
python-version: ["3.10"]
python-version: ["3.10", "3.11"]
os: [ubuntu-latest]

steps:
Expand Down
4 changes: 2 additions & 2 deletions requirements_tests.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
chex==0.1.85
dm-haiku==0.0.11
tensorflow==2.11.1; sys_platform != 'darwin' or platform_machine != 'arm64'
tensorflow-macos==2.11.1; sys_platform == 'darwin' and platform_machine == 'arm64'
tensorflow==2.16.1; sys_platform != 'darwin' or platform_machine != 'arm64'
tensorflow-macos==2.16.1; sys_platform == 'darwin' and platform_machine == 'arm64'
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def _parse_requirements(path):
'License :: OSI Approved :: Apache Software License',
'Operating System :: OS Independent',
'Programming Language :: Python :: 3.10',
'Programming Language :: Python :: 3.11',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Topic :: Scientific/Engineering :: Mathematics',
'Topic :: Software Development :: Libraries :: Python Modules',
Expand Down
2 changes: 1 addition & 1 deletion test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ python --version
# Install dependencies.
pip install --upgrade pip setuptools wheel
# See https://github.com/google/pytype/issues/1316
pip install flake8 pytest-xdist pytype==2022.9.19 importlab==0.7 pylint pylint-exit
pip install flake8 pytest-xdist pytype importlab pylint pylint-exit
pip install -r requirements.txt
pip install -r requirements_tests.txt

Expand Down
11 changes: 11 additions & 0 deletions tf2jax/_src/roundtrip_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,17 @@ def uses_native_serialization():

class Jax2TfTest(test_util.TestCase):

def setUp(self):
super().setUp()
if not uses_native_serialization():
self._xla_op = tf2jax.ops._jax_ops.pop("XlaCallModule", None)

def tearDown(self):
super().tearDown()
if not uses_native_serialization():
if self._xla_op is not None:
tf2jax.ops._jax_ops["XlaCallModule"] = self._xla_op

def _test_convert(
self,
jax_func,
Expand Down

0 comments on commit 0a5da95

Please sign in to comment.