Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
olupton committed Jul 8, 2024
1 parent 81d8725 commit 34f8266
Showing 1 changed file with 28 additions and 14 deletions.
42 changes: 28 additions & 14 deletions .github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pathlib
from typing import Any

from .protobuf import xla_module_metadata
from .protobuf import HloProto, xla_module_metadata
from .utils import make_child_mask, ProfilerData

pd.options.mode.copy_on_write = True
Expand Down Expand Up @@ -158,18 +158,9 @@ def _collective_correction(kind: str, size: int) -> tuple[float, float]:
assert False, f"Unknown collective kind {kind}"


@functools.lru_cache
def get_message_size(
program_id: int, instruction: str, prefix: pathlib.Path
) -> pd.Series:
"""
Given the name of a collective instruction (e.g. all-gather-start.N), calculate the
message size in bytes. See https://openxla.org/xla/operation_semantics#allgather,
https://openxla.org/xla/operation_semantics#allreduce and so on for more explanation
of the semantics. This implementation aims to follow the same conventions that NCCL
uses in its NVTX payloads and tests.
"""
module_proto = xla_module_metadata(program_id, prefix=prefix)
def _get_message_size(
module_proto: HloProto, instruction: str
) -> tuple[int, str, int, float, float]:
_, inst = module_proto.find_instruction(instruction)
assert (
inst.opcode
Expand Down Expand Up @@ -211,8 +202,31 @@ def get_message_size(

collective = inst.opcode.removesuffix("-start")
bw_correction, bus_correction = _collective_correction(collective, collective_size)
return (total_msg_size, collective, collective_size, bw_correction, bus_correction)


@functools.lru_cache
def get_message_size(
program_id: int, instruction: str, prefix: pathlib.Path
) -> pd.Series:
"""
Given the name of a collective instruction (e.g. all-gather-start.N), calculate the
message size in bytes. See https://openxla.org/xla/operation_semantics#allgather,
https://openxla.org/xla/operation_semantics#allreduce and so on for more explanation
of the semantics. This implementation aims to follow the same conventions that NCCL
uses in its NVTX payloads and tests.
"""
results = {
_get_message_size(module_proto, instruction)
for module_proto in xla_module_metadata(
program_id, prefix=prefix, policy="all"
).values()
}
assert (
len(results) == 1
), f"Got inconsistent collective stats for {instruction} ({program_id}): {results}"
return pd.Series(
[total_msg_size, collective, collective_size, bw_correction, bus_correction],
list(next(iter(results))),
index=[
"MessageSize",
"Collective",
Expand Down

0 comments on commit 34f8266

Please sign in to comment.