Commit: Merge branch 'main' into add-cuda-image-arg
Showing 37 changed files with 1,124 additions and 202 deletions.
@@ -0,0 +1,253 @@
#!/usr/bin/env python
import argparse
from ctypes import byref, cdll, c_int, POINTER
from functools import partial
import jax
from jax.experimental.multihost_utils import sync_global_devices
from jax.experimental.shard_map import shard_map
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
import os
import time


# Bind the CUDA runtime directly so the device count can be queried before JAX
# initialises, and so the measurement loop can be bracketed with the profiler API.
libcudart = cdll.LoadLibrary("libcudart.so")
cudaGetDeviceCount = libcudart.cudaGetDeviceCount
cudaGetDeviceCount.argtypes = [POINTER(c_int)]
cudaGetDeviceCount.restype = c_int
cudaProfilerStart = libcudart.cudaProfilerStart
cudaProfilerStop = libcudart.cudaProfilerStop

def visible_device_count() -> int:
    """
    Query the number of local devices visible to this process.
    """
    count = c_int()
    assert cudaGetDeviceCount(byref(count)) == 0  # 0 == cudaSuccess
    return count.value


def int_or_env(value) -> int:
    try:
        return int(value)
    except ValueError:
        return int(os.environ[value])

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Pure-JAX implementation of a NCCL performance test"
    )
    parser.add_argument(
        "--coordinator-address",
        help="Distributed coordinator address:port; used if --distributed is passed.",
    )
    parser.add_argument(
        "--distributed",
        action="store_true",
        help="Run jax.distributed.initialize()",
    )
    parser.add_argument(
        "--gpus-per-process",
        help=(
            "Number of GPUs driven by each controller process. "
            "Defaults to 1 with --distributed and all of them otherwise."
        ),
        type=int,
    )
    parser.add_argument(
        "--process-count",
        help=(
            "When --distributed is passed this gives the total number of processes. "
            "This can either be an integer or the name of an environment variable."
        ),
        type=int_or_env,
    )
    parser.add_argument(
        "--process-id",
        help=(
            "When --distributed is passed this gives the global index of this process. "
            "This can either be an integer or the name of an environment variable."
        ),
        type=int_or_env,
    )
    args = parser.parse_args()

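    # Note: --process-count and --process-id accept either a literal integer or the
    # name of an environment variable holding one, resolved by int_or_env() above
    # (e.g. a scheduler-provided variable; the exact name depends on the launcher).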
    assert (
        args.process_id is None or args.distributed
    ), "--process-id is only relevant with --distributed"
    if args.distributed:
        null_args = {
            args.coordinator_address is None,
            args.gpus_per_process is None,
            args.process_count is None,
            args.process_id is None,
        }
        if all(null_args):
            # Use default behaviour
            jax.distributed.initialize()
        else:
            assert not any(null_args), (
                "All of --coordinator-address, --gpus-per-process, --process-count and "
                "--process-id must be passed if any of them are."
            )
            visible_devices = visible_device_count()
            local_processes, rem = divmod(visible_devices, args.gpus_per_process)
            assert rem == 0, (
                f"--gpus-per-process={args.gpus_per_process} does not divide the "
                f"visible device count {visible_devices}"
            )
            # Assume processes within a node are globally numbered contiguously.
            local_process_id = args.process_id % local_processes
            first_local_device = local_process_id * args.gpus_per_process
            local_device_ids = list(
                range(first_local_device, first_local_device + args.gpus_per_process)
            )
            print(
                f"Rank {args.process_id} has local rank {local_process_id} and "
                f"devices {local_device_ids} from a total of {visible_devices} "
                f"visible on this node, {args.process_count} processes and "
                f"{args.process_count*args.gpus_per_process} total devices.",
                flush=True,
            )
            jax.distributed.initialize(
                coordinator_address=args.coordinator_address,
                local_device_ids=local_device_ids,
                num_processes=args.process_count,
                process_id=args.process_id,
            )
    elif args.gpus_per_process is not None:
        # Respect --gpus-per-process even without --distributed
        jax.config.update(
            "jax_cuda_visible_devices",
            ",".join(str(x) for x in range(args.gpus_per_process)),
        )

    if jax.process_index() == 0:
        print(f"JAX devices: {jax.devices()}")
    n_devices = jax.device_count()
    assert (
        args.gpus_per_process is None
        or jax.local_device_count() == args.gpus_per_process
    ), (
        f"Got {jax.local_device_count()} local devices despite "
        f"--gpus-per-process={args.gpus_per_process}"
    )
    mesh = Mesh(jax.devices(), axis_names=("i",))
    min_size_power = 0
    max_size_power = 30
    max_elements = 2**32
    sharding = partial(
        shard_map,
        mesh=mesh,
        in_specs=(P("i"), P("i", None), None),
        check_rep=False,
        out_specs=P("i"),
    )

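    # measure_collective sweeps per-device buffer sizes from 2**0 up to
    # 2**max_size_power elements, issuing the requested collective at each size
    # (sizes exceeding the max_elements cap are skipped) and folding each result
    # back into `output` so XLA cannot elide the communication.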
    @partial(jax.jit, static_argnames="collective")
    @sharding
    def measure_collective(sync, big_input, collective):
        with jax.named_scope(collective):
            output = 1.0
            big_input = big_input * jax.lax.psum(sync, "i")
            assert big_input.shape == (1, 2**max_size_power), big_input.shape
            for size in range(max_size_power + 1):
                values_per_device = 2**size
                input = output * jax.lax.slice(
                    big_input, (0, 0), (1, values_per_device)
                )
                assert input.shape == (1, values_per_device), input.shape
                result = None
                # Trigger the collective we want to measure
                if collective == "all_gather":
                    if input.size * n_devices < max_elements:
                        result = jax.lax.all_gather(input, "i")
                        assert result.shape == (n_devices, *input.shape), result.shape
                elif collective == "all_reduce":
                    if input.size < max_elements:
                        result = jax.lax.psum(input, "i")
                        assert result.shape == (1, values_per_device), result.shape
                elif collective == "broadcast":
                    if input.size < max_elements:
                        # FIXME: need https://github.com/google/jax/pull/20705 to re-land
                        result = jax.lax.pbroadcast(input, "i", 0)
                        assert result.shape == (1, values_per_device), result.shape
                elif collective == "permute":
                    if input.size < max_elements:
                        # TODO: make this sensitive to whether the permutation does or
                        # does not cross NVLink domain boundaries
                        permutation = [
                            (i, (i + 1) % n_devices) for i in range(n_devices)
                        ]
                        result = jax.lax.ppermute(input, "i", permutation)
                        assert result.shape == (1, values_per_device), result.shape
                else:
                    assert collective == "reduce_scatter", collective
                    if values_per_device >= n_devices:
                        # Need to be able to scatter at least 1 value of the result on
                        # each device. This results in the largest message size (NCCL
                        # convention) for reduce-scatter being a factor `n_devices`
                        # smaller than for the other collectives.
                        result = jax.lax.psum_scatter(
                            input, "i", scatter_dimension=1, tiled=True
                        )
                        assert result.shape == (
                            1,
                            values_per_device // n_devices,
                        ), result.shape
                # Do something with the results to stop them getting combined/removed
                if result is not None:
                    output *= 1.5 + jnp.tanh(jnp.mean(result))  # factor in (0.5, 2.5)
        return jnp.array([output])

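    # measure_collective is jitted, so the first call for each collective also
    # includes compilation time; the host-timed "First {op} duration" prints in
    # measure() therefore reflect compile + first-run cost, not steady state.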
    def measure(sync, input, host_timer=False):
        for op in ["all_gather", "all_reduce", "permute", "reduce_scatter"]:
            start = time.time()
            result = measure_collective(sync, input, op)
            if host_timer:
                result.block_until_ready()
                if jax.process_index() == 0:
                    print(f"First {op} duration {time.time()-start:.2f}s")
        return result

    def device_put_local(x: jax.Array):
        return [jax.device_put(x, d) for d in jax.local_devices()]

    # This helper is used to trigger a small barrier before the main measurement,
    # to improve measurement quality. It's always the same and is sharded with one
    # value per device.
    sync = jax.make_array_from_single_device_arrays(
        (n_devices,),
        NamedSharding(mesh, P("i")),
        device_put_local(jnp.ones((1,))),
    )
    input = jax.make_array_from_single_device_arrays(
        (n_devices, 2**max_size_power),
        NamedSharding(mesh, P("i")),
        device_put_local(jax.random.normal(jax.random.key(1), (1, 2**max_size_power))),
    )
    if jax.process_index() == 0:
        print(f"Data for pre-measurement synchronisation {sync.shape}")
        jax.debug.visualize_array_sharding(sync)
        print(f"Data for collective measurements {input.shape}")
        jax.debug.visualize_array_sharding(input)

    start = time.time()
    sync_global_devices("init")
    sync_time = time.time() - start
    if jax.process_index() == 0:
        print(f"Barrier time (NCCL init): {sync_time:.2f}s")

    measure(sync, input, host_timer=True)
    sync_global_devices("warmup_done")
    cudaProfilerStart()
    sync_global_devices("profiling_started")
    for _ in range(10):
        measure(sync, input)
    sync_global_devices("measurements_completed")
    cudaProfilerStop()
    sync_global_devices("profiling_ended")
    if jax.process_index() == 0:
        print("Exiting...")