From 5eaf7a85f0b94c97b98458f5f66b9e18ab0df07b Mon Sep 17 00:00:00 2001
From: Kaze Wong
Date: Wed, 15 Nov 2023 14:04:40 -0500
Subject: [PATCH] Single node benchmark.py

---
 .../2023sciware_jax/single_node_benchmark.py | 46 +++++++++++++++----
 1 file changed, 36 insertions(+), 10 deletions(-)

diff --git a/slides/2023sciware_jax/single_node_benchmark.py b/slides/2023sciware_jax/single_node_benchmark.py
index 3b1d075d099..53fd7eb26be 100644
--- a/slides/2023sciware_jax/single_node_benchmark.py
+++ b/slides/2023sciware_jax/single_node_benchmark.py
@@ -3,25 +3,51 @@ from IPython import get_ipython
 from jax.experimental import mesh_utils
 from jax.sharding import PositionalSharding
+import argparse
+import time
 
 ipython = get_ipython()
 
-print(jax.process_count())
-print(jax.devices())
-print(jax.local_device_count())
+print("Number of processes:" , jax.process_count())
+print("Number of devices:", jax.local_device_count())
+
+# parse command line arguments of how many numbers to process
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--size", type=int, default=8000)
+args = parser.parse_args()
 
 @jax.jit
 def f(x,y):
     return x@y + x
 
-x = jnp.zeros((8000,8000))
+x = jnp.zeros((args.size, args.size))
 
-print(x.device())
+sharding = PositionalSharding(mesh_utils.create_device_mesh((4,)))
+y = jax.device_put(x, sharding.reshape(2,2))
 
-ipython.run_line_magic("timeit", "f(x,x)")
+print("Number of devices on tensor y:", len(y.devices()))
 
-sharding = PositionalSharding(mesh_utils.create_device_mesh((4,)))
-y = jax.device_put(x, sharding.reshape(4,1))
+f(y,y).block_until_ready() # Precompile
+
+current_time = time.time()
+
+f(y,y).block_until_ready()
+
+multi_gpu_time = time.time() - current_time
+
+print("Time to compute on 4 devices:", multi_gpu_time)
+
+print("Number of devices on tensor x:", len(x.devices()))
+
+f(x,x).block_until_ready() # Precompile
+
+current_time = time.time()
+
+f(x,x).block_until_ready()
+
+single_gpu_time = time.time() - current_time
+
+print("Time to compute on 1 device:", single_gpu_time)
 
-print(y.devices())
+print("Speedup:", single_gpu_time / multi_gpu_time)
 
-ipython.run_line_magic("timeit", "f(y,y)")