From 99b78fbe5e291e803b6de8b51c5aec7b0e980854 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Thu, 9 Nov 2023 14:46:46 -0500 Subject: [PATCH] update benchmark --- .../2023sciware_jax/multi_node_benchmark.py | 2 +- .../2023sciware_jax/single_node_benchmark.py | 23 +++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) create mode 100644 slides/2023sciware_jax/single_node_benchmark.py diff --git a/slides/2023sciware_jax/multi_node_benchmark.py b/slides/2023sciware_jax/multi_node_benchmark.py index 0ac675584c4..3d96f8fcbbe 100644 --- a/slides/2023sciware_jax/multi_node_benchmark.py +++ b/slides/2023sciware_jax/multi_node_benchmark.py @@ -50,4 +50,4 @@ def f(x): return jnp.sum(x*x) -# print(multihost_utils.process_allgather(jax.value_and_grad(f)(arr))) \ No newline at end of file +print(multihost_utils.process_allgather(jax.value_and_grad(f)(arr))) \ No newline at end of file diff --git a/slides/2023sciware_jax/single_node_benchmark.py b/slides/2023sciware_jax/single_node_benchmark.py new file mode 100644 index 00000000000..1e37dc7b4b3 --- /dev/null +++ b/slides/2023sciware_jax/single_node_benchmark.py @@ -0,0 +1,23 @@ +import jax +import jax.numpy as jnp +from IPython import get_ipython +from jax.experimental import mesh_utils +from jax.sharding import PositionalSharding +ipython = get_ipython() + +print(jax.process_count()) +print(jax.devices()) +print(jax.local_device_count()) + +x = jnp.zeros((4000,4000)) + +print(x.device()) + +ipython.run_line_magic("timeit", "jnp.matmul(x,x)") + +sharding = PositionalSharding(mesh_utils.create_device_mesh((4,))) +y = jax.device_put(x, sharding.reshape(4,1)) + +print(y.devices()) + +ipython.run_line_magic("timeit", "jnp.matmul(y,y)")