Skip to content

Commit

Permalink
Fix build, tweak types
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Dec 5, 2022
1 parent 87d5c44 commit e9f72ec
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 13 deletions.
29 changes: 16 additions & 13 deletions jax_dataclasses/_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def jit(
in. This is convenient for avoiding `@functools.partial()`.
"""

def wrap(fun):
def wrap(fun: CallableType) -> CallableType:
signature = inspect.signature(fun)

# Mark any inputs annotated with jax_dataclasses.Static[] as static.
Expand All @@ -77,19 +77,22 @@ def wrap(fun):
else:
static_argnames.append(name)

return jax.jit(
fun,
static_argnums=static_argnums if len(static_argnums) > 0 else None,
static_argnames=static_argnames if len(static_argnames) > 0 else None,
device=device,
backend=backend,
donate_argnums=donate_argnums,
inline=inline,
keep_unused=keep_unused,
abstracted_axes=abstracted_axes,
return cast(
CallableType,
jax.jit(
fun,
static_argnums=static_argnums if len(static_argnums) > 0 else None,
static_argnames=static_argnames if len(static_argnames) > 0 else None,
device=device,
backend=backend,
donate_argnums=donate_argnums,
inline=inline,
keep_unused=keep_unused,
abstracted_axes=abstracted_axes,
),
)

if fun is None:
return cast(Callable[[CallableType], CallableType], wrap)
return wrap
else:
return cast(CallableType, wrap(fun))
return wrap(fun)
5 changes: 5 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import sys

collect_ignore_glob = []
if sys.version_info.major == 3 and sys.version_info.minor == 7:
collect_ignore_glob.append("*_ignore_py37.py")
File renamed without changes.

0 comments on commit e9f72ec

Please sign in to comment.