'jax.experimental.maps' import error #962

Open
MikeMpapa opened this issue Jul 25, 2024 · 2 comments
Comments

MikeMpapa commented Jul 25, 2024

Hi, I am trying to use the Levanter image, but I get the following error: ModuleNotFoundError: No module named 'jax.experimental.maps'.

Was the module renamed? It worked fine yesterday.

Thanks!

The complete error log:

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/levanter/src/levanter/main/train_lm.py", line 220, in <module>
    levanter.config.main(main)()
  File "/levanter/src/levanter/config.py", line 84, in wrapper_inner
    response = fn(cfg, *args, **kwargs)
  File "/levanter/src/levanter/main/train_lm.py", line 119, in main
    Vocab = round_axis_for_partitioning(Axis("vocab", vocab_size), parameter_axis_mapping)
  File "/opt/haliax/src/haliax/partitioning.py", line 597, in round_axis_for_partitioning
    size = physical_axis_size(axis, mapping)
  File "/opt/haliax/src/haliax/partitioning.py", line 566, in physical_axis_size
    mesh = _get_mesh()
  File "/opt/haliax/src/haliax/partitioning.py", line 606, in _get_mesh
    from jax.experimental.maps import thread_resources
ModuleNotFoundError: No module named 'jax.experimental.maps'

@chaserileyroberts (Contributor)

I recently hit this error too, on a separate problem.

I think JAX recently removed the maps module from jax.experimental. It had been deprecated for a while: https://github.com/google/jax/blob/5e418f5ab2692d4791816e85ed82eb0834a579cb/CHANGELOG.md?plain=1#L284
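
To confirm that this is what changed inside the image, a quick check along these lines may help. It is only a diagnostic sketch: the helper names `module_available` and `installed_version` are made up for illustration, and it assumes nothing about which JAX release actually dropped the module.

```python
import importlib.util
from importlib import metadata
from typing import Optional


def module_available(name: str) -> bool:
    """Return True if the dotted module `name` can be located."""
    try:
        return importlib.util.find_spec(name) is not None
    except (ImportError, ValueError):
        # find_spec imports parent packages first; a missing parent
        # (for example, jax not being installed at all) lands here.
        return False


def installed_version(dist: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if absent."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


print("jax version:", installed_version("jax"))
print("jax.experimental.maps available:", module_available("jax.experimental.maps"))
```

Comparing the printed version between yesterday's working image and today's failing one should show whether an unpinned JAX upgrade is the cause.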

HMUNACHI commented Aug 10, 2024

Problem:
This problem comes from the Haliax package (see here): thread_resources has moved into jax.experimental.mesh_utils with a recent refactoring, so Haliax needs to update its import.

Solution:
You can fork the Haliax repo yourself, fix the import, and replace the Haliax link in the Docker image (here) with a link to your fork.
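
In a fork, the import fix could be written as a version-tolerant lookup rather than a single hard-coded path. The following is a rough sketch, not Haliax's actual patch: `find_attribute` and `CANDIDATE_MODULES` are hypothetical names, and the candidate module paths are assumptions that should be checked against the JAX version installed in the image (`jax.interpreters.pxla` is a guess at a newer location, `jax.experimental.mesh_utils` is the one named above, and `jax.experimental.maps` is the old one from the traceback).

```python
import importlib

# Assumed candidate locations for thread_resources; verify against your JAX version.
CANDIDATE_MODULES = (
    "jax.interpreters.pxla",        # assumption: possible newer location
    "jax.experimental.mesh_utils",  # location named in the comment above
    "jax.experimental.maps",        # old location from the traceback
)


def find_attribute(attr: str, module_names):
    """Return `attr` from the first importable module that defines it."""
    for module_name in module_names:
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            continue  # module missing in this environment, try the next one
        if hasattr(module, attr):
            return getattr(module, attr)
    raise ImportError(f"{attr!r} not found in any of: {', '.join(module_names)}")


if __name__ == "__main__":
    try:
        thread_resources = find_attribute("thread_resources", CANDIDATE_MODULES)
        print("found thread_resources:", thread_resources)
    except ImportError as exc:
        print("lookup failed:", exc)
```

Trying several locations keeps the fork working across JAX versions, at the cost of hiding which path was actually used; pinning the JAX version in the image is the simpler alternative if only one version needs to be supported.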
