Skip to content

Commit

Permalink
Finalize multi_node
Browse files Browse the repository at this point in the history
  • Loading branch information
kazewong committed Nov 15, 2023
1 parent fa39961 commit 20e8958
Showing 1 changed file with 19 additions and 12 deletions.
31 changes: 19 additions & 12 deletions slides/2023sciware_jax/multi_node_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,18 @@
import time

import jax
import jax.numpy as jnp
from jax import sharding
from jax.sharding import Mesh
import numpy as np
from jax.sharding import PartitionSpec as P
from jax.experimental.pjit import pjit
from jax.experimental import mesh_utils
from jax.experimental import multihost_utils
import math
from jax._src.distributed import initialize
import time
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--size", type=int, default=8000)
args = parser.parse_args()

initialize()
global_mesh = Mesh(np.array(jax.devices()), ('b'))
Expand All @@ -35,7 +37,7 @@
print("Number of device on this process: ",jax.local_device_count())


local_shape = (8, 2)
local_shape = (args.size // jax.process_count(), args.size)
global_shape = (jax.process_count() * local_shape[0], ) + local_shape[1:]
local_array = np.arange(math.prod(local_shape)).reshape(local_shape) + jax.process_index()*1.0
arrays = jax.device_put(
Expand All @@ -45,17 +47,22 @@
arr = jax.make_array_from_single_device_arrays(global_shape, sharding, arrays)

print(arrays[0].devices())
print(arrays[0].shape)
print(arr.shape)



print(arr.devices())

@jax.jit
def f(x, y):
return x@y + x

# print(jnp.sum(multihost_utils.process_allgather(arr)))
# print(local_array.shape, arr.shape)
# print(arrays)

# print(multihost_utils.process_allgather(jax.value_and_grad(f)(arr)))
f(arr, arr).block_until_ready() # Precompile

current_time = time.time()

f(arr, arr).block_until_ready()

multi_gpu_time = time.time() - current_time

if jax.process_index() == 0:
print("Time to compute on 4 devices:", multi_gpu_time)

0 comments on commit 20e8958

Please sign in to comment.