Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding Jax tests for the M1 mac. #508

Open
wants to merge 37 commits into
base: main
Choose a base branch
from
Open

Conversation

alxmrs
Copy link
Contributor

@alxmrs alxmrs commented Jul 20, 2024

I'm beginning to explore #304 in greater depth. Since the only local GPU I have access to is an M1 chip (I have an M1 Macbook Air), I thought I would replicate this environment in CI.

@alxmrs
Copy link
Contributor Author

alxmrs commented Jul 20, 2024

Locally, I'm hitting a large number of errors. It looks like jax-metal is still highly experimental.

104 failed tests!
FAILED cubed/tests/test_array_api.py::test_ones[single-threaded] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_array_api.py::test_ones_like[single-threaded] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_array_api.py::test_add_top_level_namespace[single-threaded] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_array_api.py::test_add_different_chunks[single-threaded] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_array_api.py::test_add_different_chunks_fail[single-threaded] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_array_api.py::test_matmul[single-threaded] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/alxrsngrtn/git/cubed/cubed/array_api/linear_algebra_functions.py:63:12: error: 'mps.matmul' op operand...
FAILED cubed/tests/test_array_api.py::test_broadcast_arrays[single-threaded] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_array_api.py::test_all_zero_dimension[single-threaded] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_array_api.py::test_ones[processes] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_array_api.py::test_ones_like[processes] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_array_api.py::test_add_top_level_namespace[processes] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_array_api.py::test_add_different_chunks[processes] - pytest.PytestUnhandledThreadExceptionWarning: Exception in thread Thread-6
FAILED cubed/tests/test_array_api.py::test_add_different_chunks_fail[processes] - pytest.PytestUnhandledThreadExceptionWarning: Exception in thread Thread-7
FAILED cubed/tests/test_array_api.py::test_matmul[processes] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/alxrsngrtn/git/cubed/cubed/array_api/linear_algebra_functions.py:63:12: error: 'mps.matmul' op operand...
FAILED cubed/tests/test_array_api.py::test_outer[processes] - pytest.PytestUnhandledThreadExceptionWarning: Exception in thread Thread-10
FAILED cubed/tests/test_array_api.py::test_broadcast_arrays[processes] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_array_api.py::test_all_zero_dimension[processes] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_array_api.py::test_eye[-1] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_array_api.py::test_eye[0] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_array_api.py::test_eye[1] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_array_api.py::test_linspace[True] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_array_api.py::test_linspace[False] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_array_api.py::test_tril_triu[-1] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_array_api.py::test_tril_triu[0] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_array_api.py::test_tril_triu[1] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_array_api.py::test_index_2d_step[shape0-chunks0-ind0-new_chunks_expected0] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_array_api.py::test_tensordot[1] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/alxrsngrtn/git/cubed/cubed/array_api/linear_algebra_functions.py:156:8: error: 'mps.matmul' op operand...
FAILED cubed/tests/test_array_api.py::test_tensordot[axes1] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/alxrsngrtn/git/cubed/cubed/array_api/linear_algebra_functions.py:156:8: error: 'mps.matmul' op operand...
FAILED cubed/tests/test_core.py::test_compute_is_idempotent[single-threaded] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_core.py::test_default_spec[single-threaded] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_core.py::test_array_pickle[single-threaded] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/alxrsngrtn/git/cubed/cubed/array_api/linear_algebra_functions.py:63:12: error: 'mps.matmul' op operand...
FAILED cubed/tests/test_core.py::test_compute_is_idempotent[processes] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_core.py::test_default_spec[processes] - pytest.PytestUnhandledThreadExceptionWarning: Exception in thread Thread-42
FAILED cubed/tests/test_core.py::test_array_pickle[processes] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/alxrsngrtn/git/cubed/cubed/array_api/linear_algebra_functions.py:63:12: error: 'mps.matmul' op operand...
FAILED cubed/tests/test_core.py::test_map_blocks_with_different_block_shapes - pytest.PytestUnhandledThreadExceptionWarning: Exception in thread Thread-45
FAILED cubed/tests/test_core.py::test_rechunk_intermediate - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_core.py::test_default_spec_config_override - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_core.py::test_partial_reduce - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_core.py::test_visualize - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_core.py::test_plan_scaling[10] - NotADirectoryError: [Errno 20] Not a directory: 'dot'
FAILED cubed/tests/test_core.py::test_plan_scaling[20] - NotADirectoryError: [Errno 20] Not a directory: 'dot'
FAILED cubed/tests/test_core.py::test_plan_scaling[50] - NotADirectoryError: [Errno 20] Not a directory: 'dot'
FAILED cubed/tests/test_core.py::test_plan_scaling[100] - NotADirectoryError: [Errno 20] Not a directory: 'dot'
FAILED cubed/tests/test_core.py::test_plan_scaling[200] - NotADirectoryError: [Errno 20] Not a directory: 'dot'
FAILED cubed/tests/test_core.py::test_plan_scaling[500] - NotADirectoryError: [Errno 20] Not a directory: 'dot'
FAILED cubed/tests/test_core.py::test_plan_scaling[1000] - NotADirectoryError: [Errno 20] Not a directory: 'dot'
FAILED cubed/tests/test_core.py::test_plan_scaling[2000] - NotADirectoryError: [Errno 20] Not a directory: 'dot'
FAILED cubed/tests/test_core.py::test_plan_scaling[5000] - NotADirectoryError: [Errno 20] Not a directory: 'dot'
FAILED cubed/tests/test_executor_features.py::test_compute_arrays_in_parallel[single-threaded-True] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_executor_features.py::test_compute_arrays_in_parallel[single-threaded-False] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_executor_features.py::test_compute_arrays_in_parallel[threads-True] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_executor_features.py::test_compute_arrays_in_parallel[threads-False] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_executor_features.py::test_compute_arrays_in_parallel[processes-True] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_executor_features.py::test_compute_arrays_in_parallel[processes-False] - pytest.PytestUnhandledThreadExceptionWarning: Exception in thread Thread-53
FAILED cubed/tests/test_gufunc.py::test_apply_gufunc_axes_two_kept_coredims - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_nan_functions.py::test_nansum - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_nan_functions.py::test_nansum_allnan - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_ops.py::test_unify_chunks_elemwise[chunks_a0-chunks_b0-expected_chunksize0] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_ops.py::test_unify_chunks_elemwise[chunks_a1-chunks_b1-expected_chunksize1] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_ops.py::test_unify_chunks_elemwise[chunks_a2-chunks_b2-expected_chunksize2] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_ops.py::test_unify_chunks_elemwise[chunks_a3-chunks_b3-expected_chunksize3] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_ops.py::test_unify_chunks_elemwise[chunks_a4-chunks_b4-expected_chunksize4] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_ops.py::test_unify_chunks_elemwise[chunks_a5-chunks_b5-expected_chunksize5] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_ops.py::test_unify_chunks_elemwise[chunks_a6-chunks_b6-expected_chunksize6] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_ops.py::test_unify_chunks_elemwise_2d[chunks_a0-chunks_b0-expected_chunksize0] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_ops.py::test_unify_chunks_elemwise_2d[chunks_a1-chunks_b1-expected_chunksize1] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_ops.py::test_unify_chunks_elemwise_2d[chunks_a2-chunks_b2-expected_chunksize2] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_ops.py::test_unify_chunks_elemwise_2d[chunks_a3-chunks_b3-expected_chunksize3] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_ops.py::test_unify_chunks_elemwise_2d[chunks_a4-chunks_b4-expected_chunksize4] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_ops.py::test_unify_chunks_blockwise_2d[chunks_a0-chunks_b0-expected_chunksize0] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_ops.py::test_unify_chunks_blockwise_2d[chunks_a1-chunks_b1-expected_chunksize1] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_ops.py::test_unify_chunks_blockwise_2d[chunks_a2-chunks_b2-expected_chunksize2] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_ops.py::test_unify_chunks_blockwise_2d[chunks_a3-chunks_b3-expected_chunksize3] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_ops.py::test_unify_chunks_blockwise_2d[chunks_a4-chunks_b4-expected_chunksize4] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_ops.py::test_unify_chunks_broadcast_scalar - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_ops.py::test_unify_chunks_broadcast_2d - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_optimization.py::test_no_fusion - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_optimization.py::test_no_fusion_multiple_edges - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_optimization.py::test_fuse_unary_op - NotADirectoryError: [Errno 20] Not a directory: 'dot'
FAILED cubed/tests/test_optimization.py::test_fuse_binary_op - NotADirectoryError: [Errno 20] Not a directory: 'dot'
FAILED cubed/tests/test_optimization.py::test_fuse_unary_and_binary_op - NotADirectoryError: [Errno 20] Not a directory: 'dot'
FAILED cubed/tests/test_optimization.py::test_fuse_mixed_levels - NotADirectoryError: [Errno 20] Not a directory: 'dot'
FAILED cubed/tests/test_optimization.py::test_fuse_diamond - NotADirectoryError: [Errno 20] Not a directory: 'dot'
FAILED cubed/tests/test_optimization.py::test_fuse_mixed_levels_and_diamond - NotADirectoryError: [Errno 20] Not a directory: 'dot'
FAILED cubed/tests/test_optimization.py::test_fuse_mixed_levels_and_diamond_complex - NotADirectoryError: [Errno 20] Not a directory: 'dot'
FAILED cubed/tests/test_optimization.py::test_fuse_repeated_argument - NotADirectoryError: [Errno 20] Not a directory: 'dot'
FAILED cubed/tests/test_optimization.py::test_fuse_other_dependents - NotADirectoryError: [Errno 20] Not a directory: 'dot'
FAILED cubed/tests/test_optimization.py::test_fuse_unary_large_fan_in - NotADirectoryError: [Errno 20] Not a directory: 'dot'
FAILED cubed/tests/test_optimization.py::test_fuse_large_fan_in_default - NotADirectoryError: [Errno 20] Not a directory: 'dot'
FAILED cubed/tests/test_optimization.py::test_fuse_large_fan_in_override - NotADirectoryError: [Errno 20] Not a directory: 'dot'
FAILED cubed/tests/test_optimization.py::test_fuse_with_merge_chunks_unary - NotADirectoryError: [Errno 20] Not a directory: 'dot'
FAILED cubed/tests/test_optimization.py::test_fuse_with_merge_chunks_binary - NotADirectoryError: [Errno 20] Not a directory: 'dot'
FAILED cubed/tests/test_optimization.py::test_fuse_merge_chunks_unary - NotADirectoryError: [Errno 20] Not a directory: 'dot'
FAILED cubed/tests/test_optimization.py::test_fuse_merge_chunks_binary - NotADirectoryError: [Errno 20] Not a directory: 'dot'
FAILED cubed/tests/test_optimization.py::test_fuse_partial_reduce_unary - NotADirectoryError: [Errno 20] Not a directory: 'dot'
FAILED cubed/tests/test_optimization.py::test_fuse_partial_reduce_binary - NotADirectoryError: [Errno 20] Not a directory: 'dot'
FAILED cubed/tests/test_optimization.py::test_fuse_only_optimize_dag - NotADirectoryError: [Errno 20] Not a directory: 'dot'
FAILED cubed/tests/test_optimization.py::test_optimize_stack - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_random.py::test_random[single-threaded] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_random.py::test_random_add[single-threaded] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_random.py::test_random_seed[single-threaded] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_random.py::test_random[processes] - jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
FAILED cubed/tests/test_random.py::test_random_add[processes] - pytest.PytestUnhandledThreadExceptionWarning: Exception in thread Thread-55
FAILED cubed/tests/test_random.py::test_random_seed[processes] - pytest.PytestUnhandledThreadExceptionWarning: Exception in thread Thread-56
======================================= 104 failed, 258 passed, 18 skipped, 32 deselected, 1 warning in 208.40s (0:03:28)

@alxmrs
Copy link
Contributor Author

alxmrs commented Jul 21, 2024

I've pushed some changes to cut the failed tests down in half locally. I'll definitely need design opinions on my review. The next thing I plan to tackle is randomness, which is a special case for Jax and all GPU acceleration.

Copy link
Member

@tomwhite tomwhite left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for working on this @alxmrs. I'm honoured that you are working on Cubed during your round the world trip!

I have a Mac mini M1 so I should be able to try it when you've got it working. Can you post your environment versions and any pointers to getting it set up please?

.github/workflows/jax-tests.yml Show resolved Hide resolved
cubed/array_api/creation_functions.py Outdated Show resolved Hide resolved
cubed/array_api/creation_functions.py Outdated Show resolved Hide resolved
@tomwhite tomwhite mentioned this pull request Jul 22, 2024
9 tasks
@alxmrs
Copy link
Contributor Author

alxmrs commented Jul 22, 2024

Can you post your environment versions and any pointers to getting it set up please?

I’m on mobile at the moment, but in the meantime: I’m using Python 3.11 (compiled for ARM). I’ve followed these instructions to set up jax for the M1 (which specifically means installing jax-metal):

https://developer.apple.com/metal/jax/

Right now, I’m trying out rye instead of using the usual conda. It’s going good so far (it takes getting used to their opinions). It fits well with Apple’s instructions, which say to create a standard python virtualenv.

@alxmrs
Copy link
Contributor Author

alxmrs commented Jul 23, 2024

I'm hitting the same errors faced in #494, namely that "dot" is not found during visualize.

NotADirectoryError: [Errno 20] Not a directory: 'dot'

@alxmrs alxmrs marked this pull request as ready for review July 23, 2024 13:39
alxmrs added a commit to alxmrs/cubed that referenced this pull request Jul 23, 2024
Created an affordance to apply a jit to blockwise functions after operator fusion. This will let the user better use accelerators.

cubed-dev#508 needs to be merged first.
alxmrs added a commit to alxmrs/cubed that referenced this pull request Jul 23, 2024
Created an affordance to apply a jit to blockwise functions after operator fusion. This will let the user better use accelerators.

cubed-dev#508 needs to be merged first.
alxmrs added a commit to alxmrs/cubed that referenced this pull request Jul 23, 2024
Created an affordance to apply a jit to blockwise functions after operator fusion. This will let the user better use accelerators.

cubed-dev#508 needs to be merged first.
alxmrs added a commit to alxmrs/cubed that referenced this pull request Jul 23, 2024
Created an affordance to apply a jit to blockwise functions after operator fusion. This will let the user better use accelerators.

cubed-dev#508 needs to be merged first.
Copy link
Member

@tomwhite tomwhite left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is looking good. I added a few comments about dtype handling.

The CI workflow is crashing with a seg fault - do you see the same on your machine?

dtype = nxp.arange(start, stop, step * num if num else step).dtype
for k, dtype_ in default_dtypes(device=device).items():
if nxp.isdtype(dtype, k):
dtype = dtype_
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not clear to me what this is doing or why it is needed. If nxp is the jax namespace doesn't the call to arange already return the correct dtype (int32) - or does jax metal just return int64 or fail?

It might help to factor out a function to do this (given it is duplicated below too) with a name describing what it does, and perhaps a comment too.

Copy link
Contributor Author

@alxmrs alxmrs Jul 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or does jax metal just return int64 or fail?

Yes, it looks like jax metal returns 64 bit precisions and fails; this ensures the correct precision.

I've factored this out to a function, good shout.

cubed/array_api/creation_functions.py Show resolved Hide resolved
cubed/tests/test_array_api.py Show resolved Hide resolved
x = np.arange(400, dtype=np.float32).reshape((20, 20))
a = xp.asarray(x, chunks=(5, 4), dtype=xp.float32)
y = np.arange(200, dtype=np.float32).reshape((20, 10))
b = xp.asarray(y, chunks=(4, 5), dtype=xp.float32)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto - we need to test that leaving out dtype works OK in these cases.

@alxmrs
Copy link
Contributor Author

alxmrs commented Jul 24, 2024

Thanks for the review! I’ll address these when I can.

The CI workflow is crashing with a seg fault - do you see the same on your machine?

No! Tests are running without any faults locally for me. I’ll do some digging to better understand the CI environment.

@alxmrs
Copy link
Contributor Author

alxmrs commented Jul 27, 2024

@tomwhite It looks like to use the GPU on the M1 actions, we need to enable a premium action runner: https://github.blog/news-insights/product-news/introducing-the-new-apple-silicon-powered-m1-macos-larger-runner-for-github-actions/

I think this is the cause of the segfault (from what I can tell from related discussions): pytorch/pytorch#111449 (comment).

How do you think we should proceed here? Should we attempt to have CI target GPUs, or should I configure Jax to run on the CPU?

@tomwhite
Copy link
Member

How do you think we should proceed here? Should we attempt to have CI target GPUs, or should I configure Jax to run on the CPU?

We should certainly have JAX running against the CPU in CI as it tests that Cubed works with the JAX array API. For testing on Mac M1 GPUs, I think we can add that later - particularly since it's a paid for option? It would be good to get the work that you have done here merged, so if you change the CI back then I'm happy to merge it.

BTW I tried installing JAX metal on my Mac Mini M1 to run the tests, but I got an error (SystemError: nanobind::detail::nb_func_error_except(): exception could not be translated!). It looks like I need to update the OS to Sonoma (I'm on Ventura), so I'll try to do that at some point - but this shouldn't block merging since you've got it working.

@TomNicholas
Copy link
Member

particularly since it's a paid for option?

There was some stuff at SciPy about quansight being able to give out free access to NVIDIA GPUs for scientific python projects to use in CI. But this seems like a much later concern only for once it works on CPU and seems useful.

@alxmrs alxmrs force-pushed the m1-jax branch 2 times, most recently from e78cd21 to 2f7c324 Compare July 31, 2024 19:52
@tomwhite
Copy link
Member

tomwhite commented Aug 2, 2024

Looks like this is getting close @alxmrs. There are still a few places where you've changed the tests to have a lower precision dtype, where we should also test that it works when it's left at the default. Can we merge once they are resolved?

@alxmrs
Copy link
Contributor Author

alxmrs commented Aug 3, 2024

That sounds good to me Tom. Since my development time is sporadic and limited, I'll try to make the Jax features I work on independent from each other from here on out. Today, I extracted #536; since this is a bigger / flakier PR, I'll probably need a few more sessions to get it to fully land.

alxmrs added a commit to alxmrs/cubed that referenced this pull request Aug 3, 2024
I'm extracting cubed-dev#508 into smaller bites.
tomwhite pushed a commit that referenced this pull request Aug 4, 2024
I'm extracting #508 into smaller bites.
@alxmrs
Copy link
Contributor Author

alxmrs commented Aug 12, 2024

Hey Tom! I should have mentioned this earlier. Can you run the workflows again? I think this PR is ready.

@tomwhite
Copy link
Member

The Dask test failure is a flaky test (#549), but the Array API test failures look like they are real.

@tomwhite
Copy link
Member

@alxmrs thanks for pushing this forward! Do you think it might make things easier to reduce the scope by e.g. targeting JAX on CPU and focusing on the device/dtype inspection stuff.

Or maybe there's another way of splitting things up? I'd be happy to merge smaller PRs!

@alxmrs
Copy link
Contributor Author

alxmrs commented Sep 24, 2024 via email

@tomwhite
Copy link
Member

Great - thanks @alxmrs!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants