From e9f72ec56b5713d4138aec29ade58d0be0547901 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Sun, 4 Dec 2022 21:22:34 -0800 Subject: [PATCH] Fix build, tweak types --- jax_dataclasses/_jit.py | 29 ++++++++++--------- tests/conftest.py | 5 ++++ .../{test_jit.py => test_jit_ignore_py37.py} | 0 3 files changed, 21 insertions(+), 13 deletions(-) create mode 100644 tests/conftest.py rename tests/{test_jit.py => test_jit_ignore_py37.py} (100%) diff --git a/jax_dataclasses/_jit.py b/jax_dataclasses/_jit.py index 93fcb97..afc3602 100644 --- a/jax_dataclasses/_jit.py +++ b/jax_dataclasses/_jit.py @@ -61,7 +61,7 @@ def jit( in. This is convenient for avoiding `@functools.partial()`. """ - def wrap(fun): + def wrap(fun: CallableType) -> CallableType: signature = inspect.signature(fun) # Mark any inputs annotated with jax_dataclasses.Static[] as static. @@ -77,19 +77,22 @@ def wrap(fun): else: static_argnames.append(name) - return jax.jit( - fun, - static_argnums=static_argnums if len(static_argnums) > 0 else None, - static_argnames=static_argnames if len(static_argnames) > 0 else None, - device=device, - backend=backend, - donate_argnums=donate_argnums, - inline=inline, - keep_unused=keep_unused, - abstracted_axes=abstracted_axes, + return cast( + CallableType, + jax.jit( + fun, + static_argnums=static_argnums if len(static_argnums) > 0 else None, + static_argnames=static_argnames if len(static_argnames) > 0 else None, + device=device, + backend=backend, + donate_argnums=donate_argnums, + inline=inline, + keep_unused=keep_unused, + abstracted_axes=abstracted_axes, + ), ) if fun is None: - return cast(Callable[[CallableType], CallableType], wrap) + return wrap else: - return cast(CallableType, wrap(fun)) + return wrap(fun) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..7e6803b --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,5 @@ +import sys + +collect_ignore_glob = [] +if sys.version_info.major == 3 and sys.version_info.minor == 7: + collect_ignore_glob.append("*_ignore_py37.py") diff --git a/tests/test_jit.py b/tests/test_jit_ignore_py37.py similarity index 100% rename from tests/test_jit.py rename to tests/test_jit_ignore_py37.py