Skip to content

Commit

Permalink
Don't fail hard if the profile is too short for heuristics to cope
Browse files Browse the repository at this point in the history
  • Loading branch information
olupton committed Aug 9, 2024
1 parent 1d967ab commit 45c8908
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions .github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,20 @@ def apply_warmup_heuristics(frames: ProfilerData) -> tuple[ProfilerData, Profile
compile_mask = df.index.get_level_values("ProgramId").isin(compilation_ids_seen)
prog_exec_values = df.index.get_level_values("ProgramExecution")
init_mask = compile_mask & (prog_exec_values == 0)
steady_mask = ~compile_mask | (prog_exec_values > 1)
assert (
len(df) == 0 or steady_mask.any()
), "No steady-state executions identified, profile collection may have been too short"
assert (prog_exec_values[~init_mask & ~steady_mask] == 1).all()
setattr(init, k, df[init_mask])
setattr(steady, k, df[steady_mask])
steady_mask = ~compile_mask | (prog_exec_values > steady_state_threshold)
if len(df) != 0 and not steady_mask.any():
print(
f"WARNING: heuristics could not identify steady-state execution in {k} frame, assuming EVERYTHING is steady-state. You may want to increase the number of profiled executions."
)
setattr(init, k, df[steady_mask])
setattr(steady, k, df[~steady_mask])
else:
assert (
steady_state_threshold == 0
or (prog_exec_values[~init_mask & ~steady_mask] == 1).all()
)
setattr(init, k, df[init_mask])
setattr(steady, k, df[steady_mask])
return init, steady


Expand Down

0 comments on commit 45c8908

Please sign in to comment.