diff --git a/slides/2023sciware_jax/multi_node_benchmark.py b/slides/2023sciware_jax/multi_node_benchmark.py
index 54a4d3433c2..dfa0b343b0e 100644
--- a/slides/2023sciware_jax/multi_node_benchmark.py
+++ b/slides/2023sciware_jax/multi_node_benchmark.py
@@ -6,16 +6,18 @@
 
 import time
 import jax
-import jax.numpy as jnp
 from jax import sharding
 from jax.sharding import Mesh
 import numpy as np
 from jax.sharding import PartitionSpec as P
-from jax.experimental.pjit import pjit
-from jax.experimental import mesh_utils
-from jax.experimental import multihost_utils
 import math
 from jax._src.distributed import initialize
+import time
+import argparse
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--size", type=int, default=8000)
+args = parser.parse_args()
 
 initialize()
 global_mesh = Mesh(np.array(jax.devices()), ('b'))
@@ -35,7 +37,7 @@
 
 
 print("Number of device on this process: ",jax.local_device_count())
-local_shape = (8, 2)
+local_shape = (args.size // jax.process_count(), args.size)
 global_shape = (jax.process_count() * local_shape[0], ) + local_shape[1:]
 local_array = np.arange(math.prod(local_shape)).reshape(local_shape) + jax.process_index()*1.0
 arrays = jax.device_put(
@@ -45,17 +47,22 @@
 arr = jax.make_array_from_single_device_arrays(global_shape, sharding, arrays)
 
 print(arrays[0].devices())
+print(arrays[0].shape)
 print(arr.shape)
-
-
-
+print(arr.devices())
 
 @jax.jit
 def f(x, y):
     return x@y + x
 
-# print(jnp.sum(multihost_utils.process_allgather(arr)))
-# print(local_array.shape, arr.shape)
-# print(arrays)
-# print(multihost_utils.process_allgather(jax.value_and_grad(f)(arr)))
+f(arr, arr).block_until_ready() # Precompile
+
+current_time = time.time()
+
+f(arr, arr).block_until_ready()
+
+multi_gpu_time = time.time() - current_time
+
+if jax.process_index() == 0:
+    print("Time to compute on 4 devices:", multi_gpu_time)
 