Skip to content

Commit

Permalink
update benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
kazewong committed Nov 9, 2023
1 parent 4683f3e commit 99b78fb
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
2 changes: 1 addition & 1 deletion slides/2023sciware_jax/multi_node_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,4 @@
def f(x):
return jnp.sum(x*x)

# print(multihost_utils.process_allgather(jax.value_and_grad(f)(arr)))
print(multihost_utils.process_allgather(jax.value_and_grad(f)(arr)))
23 changes: 23 additions & 0 deletions slides/2023sciware_jax/single_node_benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import jax
import jax.numpy as jnp
from IPython import get_ipython
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding
ipython = get_ipython()

print(jax.process_count())
print(jax.devices())
print(jax.local_device_count())

x = jnp.zeros((4000,4000))

print(x.device())

ipython.run_line_magic("timeit", "jnp.matmul(x,x)")

sharding = PositionalSharding(mesh_utils.create_device_mesh((4,)))
y = jax.device_put(x, sharding.reshape(4,1))

print(y.devices())

ipython.run_line_magic("timeit", "jnp.matmul(y,y)")

0 comments on commit 99b78fb

Please sign in to comment.