From cb3c008d9b936bdcaffefc440deeac72d46b378b Mon Sep 17 00:00:00 2001 From: Olli Lupton Date: Wed, 26 Jun 2024 14:19:14 +0200 Subject: [PATCH] fix rank assignment for single-process/many-device profiles --- .../python/jax_nsys/jax_nsys/data_loaders.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/data_loaders.py b/.github/container/jax_nsys/python/jax_nsys/jax_nsys/data_loaders.py index 548dc2cb6..e625a268c 100644 --- a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/data_loaders.py +++ b/.github/container/jax_nsys/python/jax_nsys/jax_nsys/data_loaders.py @@ -226,14 +226,18 @@ def _load_nvtx_gpu_proj_trace_single( # Assume that the thunks are launched from one thread per device, this is probably # safe. Also, until https://github.com/openxla/xla/pull/14092 is plumbed through, # assume that thread ID order is local rank order (FIXME!) + tid_to_ordinal = {} + for _, module_df in df[all_thunks].groupby("ProgramId"): + # A given module should have N threads submitting work to N devices, but the + # thread ID submitting work to device 0 is different for N=1 (main thread) and + # N>1 (a worker thread) + for ordinal, tid in enumerate(sorted(module_df["TID"].unique())): + assert tid_to_ordinal.get(tid, ordinal) == ordinal + tid_to_ordinal[tid] = ordinal # This profile contains ranks [process_index*num_devices, (process_index+1)*num_devices] - unique_tids = df.loc[all_thunks, "TID"].unique() - num_devices = len(unique_tids) + num_devices = len(set(tid_to_ordinal.values())) df["Rank"] = df["TID"].map( - { - tid: process_index * num_devices + n_tid - for n_tid, tid in enumerate(unique_tids) - } + {k: process_index * num_devices + v for k, v in tid_to_ordinal.items()} ) if warmup_removal_heuristics: