Skip to content

Commit

Permalink
Support FP8, new XLA collective group descriptions, and targeted profiles
Browse files Browse the repository at this point in the history
  • Loading branch information
olupton committed Aug 9, 2024
1 parent 15e65e9 commit 2822ce8
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 10 deletions.
20 changes: 14 additions & 6 deletions .github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,11 @@ def element_type_width(element_type: int) -> int:
# https://github.com/openxla/xla/blob/664a36a2b5e5be9179c5841830da56799b6dfe60/xla/service/gpu/runtime/nccl_api.cc#L116-L118
return 8

# There are several 8-bit floating point types of the form F8E{n}M{m}...
if enum_name.startswith("F8E"):
return 8

# S32 is a 32-bit type and so on.
# FIXME: does not handle FP8 yet
for prefix in ["BF", "C", "F", "S", "U"]:
if enum_name.startswith(prefix):
return int(enum_name[len(prefix) :])
Expand Down Expand Up @@ -202,11 +205,16 @@ def _get_message_size(
replica_groups = comm_inst.collective_device_list.replica_groups
except AttributeError:
replica_groups = comm_inst.replica_groups
collective_sizes = set(len(group.replica_ids) for group in replica_groups)
assert (
len(collective_sizes) == 1
), f"Heterogeneous collective {comm_inst} could not be interpreted"
collective_size = next(iter(collective_sizes))
if len(replica_groups) == 0:
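# No explicit replica groups were listed; perhaps the newer
# iota_replica_group_list format is in use, so read the group size from it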
iota_group_list = comm_inst.collective_device_list.iota_replica_group_list
collective_size = iota_group_list.num_devices_per_group
else:
collective_sizes = set(len(group.replica_ids) for group in replica_groups)
assert (
len(collective_sizes) == 1
), f"Heterogeneous collective {comm_inst} could not be interpreted"
collective_size = next(iter(collective_sizes))
total_msg_size = 0
for operand_id in comm_inst.operand_ids:
_, operand = module_proto.find_instruction_by_id(operand_id)
Expand Down
8 changes: 4 additions & 4 deletions .github/container/jax_nsys/python/jax_nsys/jax_nsys/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,21 +58,21 @@ def remove_autotuning_detail(
- compile frame loses granular detail within autotuner compilation and measurement
"""
# Ignore autotuning executions with ProgramId < 0
if module_frame and data.module is not None:
if module_frame and data.module is not None and len(data.module):
assert data.module.index.names[0] == "ProgramId"
data.module = data.module.loc[0:]
if thunk_frame and data.thunk is not None:
if thunk_frame and data.thunk is not None and len(data.thunk):
assert data.thunk.index.names[0] == "ProgramId"
data.thunk = data.thunk.loc[0:]
if measurement and data.compile is not None:
if measurement and data.compile is not None and len(data.compile):
# Removing child ranges of XlaAutotunerMeasurement ranges. The GEMM fusion
# autotuner creates small modules/thunks when measuring, which emit XlaModule
# and XlaThunk ranges
mask = data.compile["Name"].str.startswith("XlaAutotunerMeasurement")
# Erase the name of the op being autotuned
data.compile.loc[mask, "Name"] = "XlaAutotunerMeasurement"
data.compile = remove_child_ranges(data.compile, mask)
if compilation and data.compile is not None:
if compilation and data.compile is not None and len(data.compile):
# Remove the detail of the constituent parts (EmitLlvmIr etc.) of autotuner
# compilation
data.compile = remove_child_ranges(
Expand Down

0 comments on commit 2822ce8

Please sign in to comment.