From d45131daa8a9837a7cfc49f1d0acc79e0a502ba1 Mon Sep 17 00:00:00 2001 From: Olli Lupton Date: Fri, 9 Aug 2024 12:30:02 +0200 Subject: [PATCH] fixups --- .github/container/jax_nsys/Analysis.ipynb | 4 ++-- .../container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/container/jax_nsys/Analysis.ipynb b/.github/container/jax_nsys/Analysis.ipynb index 30d87cdd6..82e32263d 100644 --- a/.github/container/jax_nsys/Analysis.ipynb +++ b/.github/container/jax_nsys/Analysis.ipynb @@ -498,7 +498,7 @@ "metadata": {}, "outputs": [], "source": [ - "if len(comm_df):\n", + "if len(steady_state.communication):\n", " fig, axs2d = plt.subplots(\n", " ncols=3, figsize=[15, 5], squeeze=False, tight_layout=True\n", " )\n", @@ -647,7 +647,7 @@ "metadata": {}, "outputs": [], "source": [ - "if len(comm_df):\n", + "if len(steady_state.communication):\n", " fig, grid = plt.subplots(\n", " nrows=len(top_module_ids), figsize=[15, 5], squeeze=False, tight_layout=True\n", " )\n", diff --git a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py b/.github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py index ce3126a91..44507e889 100644 --- a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py +++ b/.github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py @@ -100,7 +100,9 @@ def apply_warmup_heuristics(frames: ProfilerData) -> tuple[ProfilerData, Profile # expected to launch closer to in lockstep across processes. init = ProfilerData(compile=frames.compile) steady = ProfilerData() - steady_state_threshold = 1 if len(frames.communication) else 0 + steady_state_threshold = ( + 1 if frames.communication is not None and len(frames.communication) else 0 + ) for k in ["communication", "thunk", "module"]: df = getattr(frames, k) if df is None: