diff --git a/.github/container/jax_nsys/Analysis.ipynb b/.github/container/jax_nsys/Analysis.ipynb index db8cf3988..fd21b93d5 100644 --- a/.github/container/jax_nsys/Analysis.ipynb +++ b/.github/container/jax_nsys/Analysis.ipynb @@ -16,14 +16,10 @@ " generate_compilation_statistics,\n", " load_profiler_data,\n", " remove_autotuning_detail,\n", - " remove_child_ranges,\n", " xla_module_metadata,\n", ")\n", "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import os\n", - "import pandas as pd # type: ignore\n", - "import sys" + "import numpy as np" ] }, { @@ -35,7 +31,7 @@ "source": [ "# Make sure that the .proto files under protos/ have been compiled to .py, and\n", "# that those generated .py files are importable.]\n", - "ensure_compiled_protos_are_importable();" + "compiled_dir = ensure_compiled_protos_are_importable()" ] }, { diff --git a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/protobuf_utils.py b/.github/container/jax_nsys/python/jax_nsys/jax_nsys/protobuf_utils.py index a6de1cd78..03b1b4816 100644 --- a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/protobuf_utils.py +++ b/.github/container/jax_nsys/python/jax_nsys/jax_nsys/protobuf_utils.py @@ -63,7 +63,7 @@ def ensure_compiled_protos_are_importable(*, prefix: pathlib.Path = pathlib.Path def do_import(): # Use this as a proxy for everything being importable - from xla.service import hlo_pb2 + from xla.service import hlo_pb2 # noqa: F401 try: do_import() diff --git a/.github/container/jax_nsys/python/jax_nsys_analysis/summary.py b/.github/container/jax_nsys/python/jax_nsys_analysis/summary.py index f4a6ade7f..9fd6ae52e 100755 --- a/.github/container/jax_nsys/python/jax_nsys_analysis/summary.py +++ b/.github/container/jax_nsys/python/jax_nsys_analysis/summary.py @@ -7,7 +7,6 @@ load_profiler_data, remove_autotuning_detail, ) -import pandas as pd import pathlib parser = argparse.ArgumentParser( @@ -26,6 +25,7 @@ # Partition the profile data into initialisation and steady-state running init, steady_state = apply_warmup_heuristics(all_data) # Get high-level statistics about the modules that were profiled +assert steady_state.module is not None module_stats = ( steady_state.module.groupby("ProgramId") .agg(