nsys-jax-combine: add report-merging script
Also reorganise the data loading code to yield a more structured data
frame format that's better suited to multi-profile analysis.
olupton committed Jun 26, 2024
1 parent e087481 commit 9ddb3a9
Showing 6 changed files with 490 additions and 216 deletions.
219 changes: 133 additions & 86 deletions .github/container/jax_nsys/Analysis.ipynb
@@ -57,9 +57,99 @@
"module_df = all_data[\"module\"]\n",
"compile_df = all_data[\"compile\"]\n",
"# module_df may contain some entries with ProgramId == -1, which are typically\n",
"# autotuner executions. Throw these away for now.\n",
"module_df = module_df[module_df[\"ProgramId\"] >= 0]\n",
"thunk_df = thunk_df[thunk_df[\"ProgramId\"] >= 0]"
"# autotuner executions. Throw these away for now; ProgramId is the first\n",
"assert module_df.index.names[0] == thunk_df.index.names[0] == \"ProgramId\"\n",
"module_df = module_df.loc[0:]\n",
"thunk_df = thunk_df.loc[0:]"
]
},
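The new filtering relies on `ProgramId` being the first level of a sorted index, so a label-based `.loc` slice from zero drops the negative autotuner IDs. A minimal sketch with hypothetical data (not the real profile frames):

```python
import pandas as pd

# Hypothetical stand-in for module_df/thunk_df, sorted on ProgramId.
idx = pd.MultiIndex.from_tuples(
    [(-1, 0, 0), (3, 0, 0), (3, 1, 0), (7, 0, 0)],
    names=["ProgramId", "ProgramExecution", "Rank"],
)
df = pd.DataFrame({"ProjDurNs": [120, 4500, 4400, 9000]}, index=idx)

# Old approach: boolean mask on a ProgramId column.
old = df.reset_index()
old = old[old["ProgramId"] >= 0]

# New approach: slice the first index level. Label-based .loc slicing is
# inclusive, so this keeps every ProgramId >= 0 and drops the -1 entries.
new = df.loc[0:]
assert list(old["ProjDurNs"]) == list(new["ProjDurNs"])
```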
{
"cell_type": "markdown",
"id": "313c81ca-87b7-4930-9e4d-f878f36ac61a",
"metadata": {},
"source": [
"## Data format\n",
"\n",
"First, look at the high-level format of the profile data frames.\n",
"`module_df` has a single row for each XLA module execution, which typically corresponds to a single JITed JAX function:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ba11b419-4b68-4b74-8449-687a070ac90a",
"metadata": {},
"outputs": [],
"source": [
"module_df"
]
},
{
"cell_type": "markdown",
"id": "0ddc41d5-e3ad-4cf8-adf0-33cf2f9de538",
"metadata": {},
"source": [
"This data frame has a three-level index:\n",
"- `ProgramId` is an integer ID that uniquely identifies the XLA module\n",
"- This is the `ProgramExecution`-th execution of the module within the profiles. You may see this starting from 1, not 0, because of the `warmup_removal_heuristics` option passed to `load_profiler_data`.\n",
"- `Rank` is used in the MPI sense; it is a global index of the GPU on which the module execution took place, across a (potentially distributed) SPMD run\n",
"\n",
"The columns are as follows:\n",
"- `Name`: the name of the XLA module; this should always be the same for a given `ProgramId`\n",
"- `ProjStartNs`: the timestamp of the start of the module execution on the GPU, in nanoseconds\n",
"- `ProjDurNs`: the duration of the module execution on the GPU, in nanoseconds\n",
"- `OrigStartNs`: the timestamp of the start of the module launch **on the host**, in nanoseconds. *i.e.* `ProjStartNs-OrigStartNs` is something like the launch latency of the first kernel\n",
"- `OrigDurNs`: the duration of the module launch **on the host**, in nanoseconds\n",
"\n",
"The other profile data frame for GPU execution is `thunk_df`, which has a single row for each XLA thunk.\n",
"Loosely, each XLA module contains a series of thunks, and each thunk launches a GPU kernel.\n",
"In reality, thunks can be nested and may launch multiple kernels, but this data frame still provides the most granular distribution available of GPU execution time across the XLA module:"
]
},
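Given the index structure described above, slicing and aggregating by level become one-liners. A minimal sketch with a hypothetical stand-in for `module_df` (IDs and durations invented for illustration):

```python
import pandas as pd

# Hypothetical stand-in for module_df, only to illustrate index handling.
idx = pd.MultiIndex.from_product(
    [[3, 7], [0, 1], [0, 1]],
    names=["ProgramId", "ProgramExecution", "Rank"],
)
demo_df = pd.DataFrame(
    {"Name": "jit_train_step", "ProjDurNs": range(8)}, index=idx
)

# All executions of module 7, on every rank:
print(demo_df.loc[7])

# Mean GPU time per module, aggregated over executions and ranks:
print(demo_df.groupby("ProgramId")["ProjDurNs"].mean())

# Cross-section: the 0th execution of every module on every rank:
print(demo_df.xs(0, level="ProgramExecution"))
```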
{
"cell_type": "code",
"execution_count": null,
"id": "ef6f0dc5-a327-463e-be0c-a6a9b408e374",
"metadata": {},
"outputs": [],
"source": [
"thunk_df"
]
},
{
"cell_type": "markdown",
"id": "7727d800-13d3-4505-89e8-80a5fed63512",
"metadata": {},
"source": [
"Here the index has four levels. `ProgramId`, `ProgramExecution` and `Rank` have the same meanings as in `module_df`.\n",
"The fourth level (in the 3rd position) shows that this row is the `ThunkIndex`-th thunk within the `ProgramExecution`-th execution of XLA module `ProgramId`.\n",
"Note that a given thunk can be executed multiple times within the same module, so indexing on the thunk name would not be unique.\n",
"\n",
"The columns are as follows:\n",
"- `Name`: the name of the thunk; this should be unique within a given `ProgramId` and can be used as a key to look up XLA metadata\n",
"- `ProjStartNs`, `OrigStartNs`, `OrigDurNs`: see above, same meaning as in `module_df`.\n",
"- `Communication`: does this thunk represent communication between GPUs (*i.e.* a NCCL collective)? XLA overlaps communication and computation kernels, and `load_profiler_data` triggers an overlap calculation. `ProjDurNs` for a communication kernel shows only the duration that was **not** overlapped with computation kernels, while `ProjDurHiddenNs` shows the duration that **was** overlapped.\n",
"- This is the `ThunkExecution`-th execution of this thunk for this `(ProgramId, ProgramExecution, Rank)`\n",
"\n",
"The third data frame does not show any GPU execution, but is rather a host-side trace:"
]
},
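One consequence of the overlap model described above: a communication thunk's total busy time is `ProjDurNs + ProjDurHiddenNs`, so the hidden fraction falls out directly. A sketch using the `thunk_df` loaded earlier:

```python
# Sketch: how much communication time was hidden behind compute kernels.
# Communication, ProjDurNs and ProjDurHiddenNs are the columns described
# in this section; thunk_df is the frame loaded above.
comm_df = thunk_df[thunk_df["Communication"]]
total_comm_ns = comm_df["ProjDurNs"] + comm_df["ProjDurHiddenNs"]
hidden_frac = comm_df["ProjDurHiddenNs"].sum() / total_comm_ns.sum()
print(f"{hidden_frac:.1%} of communication time was overlapped with compute")
```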
{
"cell_type": "code",
"execution_count": null,
"id": "60913dd1-0d75-4c7b-a311-4ee5b5b02cf4",
"metadata": {},
"outputs": [],
"source": [
"compile_df"
]
},
{
"cell_type": "markdown",
"id": "7aa55141-c75f-458a-b1b4-80326bde58e5",
"metadata": {},
"source": [
"Here the index has two levels; `ProfileName` is important when multiple reports are being analysed together (*i.e.* using `nsys-jax-combine` having run multiple `nsys-jax` processes), as the `RangeId` values referred to in `ParentId` and `RangeStack` are not unique across different `ProfileName` values."
]
},
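A sketch of respecting that scoping when resolving range hierarchies; the null-`ParentId` convention for root ranges is an assumption made for illustration:

```python
# Sketch: RangeId is only unique within one ProfileName, so resolve
# ParentId references per profile. Assumes (for this sketch only) that
# root ranges carry a null ParentId.
for profile_name, ranges in compile_df.groupby("ProfileName"):
    roots = ranges[ranges["ParentId"].isna()]
    print(f"{profile_name}: {len(roots)} root range(s)")
```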
{
@@ -80,7 +170,10 @@
" .cumsum()\n",
")\n",
"top_module_mask = top_module_sum / top_module_sum.max() > threshold\n",
"top_module_ids = top_module_mask[top_module_mask].index"
"top_module_ids = top_module_mask[top_module_mask].index[::-1]\n",
"print(\n",
" f\"{1-threshold:.1%}+ of execution time accounted for by module ID(s): {' '.join(map(str, top_module_ids))}\"\n",
")"
]
},
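The cumulative-sum trick above keeps the smallest set of large modules that jointly account for at least `1 - threshold` of total GPU time. A toy illustration with invented numbers:

```python
import pandas as pd

# Per-module GPU time (hypothetical), sorted ascending then accumulated.
per_module_ns = pd.Series({8: 10.0, 17: 50.0, 3: 940.0})
csum = per_module_ns.sort_values().cumsum()  # 8: 10, 17: 60, 3: 1000
mask = csum / csum.max() > 0.1               # only module 3 exceeds 10%
# Everything after the cumulative fraction crosses the threshold jointly
# covers >= 90% of the time; reverse so the biggest module comes first.
top_ids = mask[mask].index[::-1]             # Index([3])
```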
{
@@ -111,24 +204,7 @@
"metadata": {},
"outputs": [],
"source": [
"# Summarise all the observed compilation time\n",
"# The first compilation triggers a bunch of library loading, things like cuBLAS\n",
"# and cuDNN. Label that explicitly to pull it out of the generic non-leaf time.\n",
"first_xlacompile_index = compile_df[\"Name\"].eq(\"XlaCompile\").idxmax()\n",
"assert compile_df.loc[first_xlacompile_index, \"Name\"] == \"XlaCompile\"\n",
"if compile_df.loc[first_xlacompile_index, \"DurNonChildNs\"] > 0.0:\n",
" new_index = compile_df.index.max() + 1\n",
" new_row = compile_df.loc[first_xlacompile_index, :].copy()\n",
" new_row[\"DurChildNs\"] = 0.0\n",
" new_row[\"Name\"] = \"[non-leaf time in 0th XlaCompile range]\"\n",
" new_row[\"NumChild\"] = 0\n",
" new_row[\"RangeStack\"] += f\":{new_index}\"\n",
" compile_df.loc[first_xlacompile_index, \"DurNonChildNs\"] = 0.0\n",
" compile_df.loc[first_xlacompile_index, \"NumChild\"] += 1\n",
" compile_df = pd.concat([compile_df, pd.DataFrame([new_row], index=[new_index])])\n",
"\n",
"\n",
"# This averages over all profiled compilations and handles parallel compilation\n",
"# Summarise all the observed compilation time; this averages over all profiled compilations and handles parallel compilation\n",
"compile_time_ns = generate_compilation_statistics(compile_df)\n",
"\n",
"\n",
@@ -142,9 +218,6 @@
" name = \"XlaAutotunerMeasurement\"\n",
" # Parallel backend compilation leads to these split_module names for XlaEmitGpuAsm and XlaOptimizeLlvmIr\n",
" name = name.removesuffix(\":#module=split_module#\")\n",
" # Lump all XlaPass[Pipeline] stuff in together\n",
" if name.startswith(\"XlaPass:#\") or name.startswith(\"XlaPassPipeline:#\"):\n",
" name = \"XlaPass\"\n",
" return name\n",
"\n",
"\n",
@@ -156,7 +229,7 @@
")\n",
"total_compile_time = compile_summary[\"DurNonChildNs\"].sum()\n",
"# Print out the largest entries adding up to at least this fraction of the total\n",
"threshold = 0.99\n",
"threshold = 0.97\n",
"compile_summary[\"FracNonChild\"] = compile_summary[\"DurNonChildNs\"] / total_compile_time\n",
"print(f\"Top {threshold:.0%}+ of {total_compile_time*1e-9:.2f}s compilation time\")\n",
"for row in compile_summary[\n",
@@ -168,66 +241,42 @@
{
"cell_type": "code",
"execution_count": null,
"id": "9578b04b-09ca-4065-a5f8-a96eebaf9f4c",
"id": "82c16981-cf5e-490f-b565-25331905a9d4",
"metadata": {},
"outputs": [],
"source": [
"# Summarise all the XLA modules that have been seen in this profile. Note that\n",
"# this does *not* respect the `top_module_ids` list derived above.\n",
"module_stats = defaultdict(list)\n",
"for module_row in module_df.itertuples():\n",
" thunk_mask = thunk_df[\"ModuleId\"] == module_row.Index\n",
" num_thunks = thunk_mask.sum()\n",
" module_stats[module_row.Name].append(\n",
" {\"GPU time [ms]\": 1e-6 * module_row.ProjDurNs, \"#Thunks\": num_thunks}\n",
"# Count the number of thunk ranges corresponding to each program/module execution\n",
"module_df[\"NumThunks\"] = module_df.index.to_frame().apply(\n",
" lambda row: len(\n",
" thunk_df.loc[row[\"ProgramId\"], row[\"ProgramExecution\"], :, row[\"Rank\"]]\n",
" ),\n",
" axis=1,\n",
")\n",
"module_stats = (\n",
" module_df.groupby(\"ProgramId\")\n",
" .agg(\n",
" {\n",
" \"Name\": (\"count\", \"first\"),\n",
" \"ProjDurNs\": (\"sum\", \"std\"),\n",
" \"NumThunks\": (\"mean\", \"std\"),\n",
" }\n",
" )\n",
"\n",
"\n",
"class Summary(NamedTuple):\n",
" mean: float\n",
" std: float\n",
" total: float\n",
"\n",
"\n",
"def reduce_module_stats(module_stats) -> dict[str, Summary]:\n",
" # [{\"a\": 0.3}, {\"a\": 0.4}] -> {\"a\": (0.35, stddev), \"#Instances\": 2}\n",
" num_instances = len(module_stats)\n",
" r = {\"#Instances\": Summary(mean=num_instances, std=0.0, total=num_instances)}\n",
" keys = module_stats[0].keys()\n",
" for stats in module_stats[1:]:\n",
" assert stats.keys() == keys\n",
" for k in keys:\n",
" values = [stats[k] for stats in module_stats]\n",
" r[k] = Summary(mean=np.mean(values), std=np.std(values), total=np.sum(values))\n",
" return r\n",
"\n",
"\n",
"# Aggregate HLO module statistics over repeated executions of them\n",
"agg_module_stats = [(k, reduce_module_stats(v)) for k, v in module_stats.items()]\n",
"\n",
"\n",
"def sort_key(x):\n",
" return x[1][\"GPU time [ms]\"].total\n",
"\n",
"\n",
"agg_module_stats.sort(key=sort_key, reverse=True)\n",
"total = sum(sort_key(x) for x in agg_module_stats)\n",
" .sort_values((\"ProjDurNs\", \"sum\"), ascending=False)\n",
")\n",
"module_total_time = module_stats[(\"ProjDurNs\", \"sum\")].sum()\n",
"print(\" Active GPU time #Exec. #Thunks Module name\")\n",
"accounted_time, top_n = 0.0, None\n",
"for n, tup in enumerate(agg_module_stats):\n",
" module_name, stats = tup\n",
" module_time = sort_key(tup)\n",
"for program_id, row in module_stats.iterrows():\n",
" print(\n",
" \" {:7.2f}% {:9.2f}ms {:5} {:5.0f}±{:<3.0f} {}\".format(\n",
" 100.0 * module_time / total,\n",
" module_time,\n",
" stats[\"#Instances\"].mean,\n",
" stats[\"#Thunks\"].mean,\n",
" stats[\"#Thunks\"].std,\n",
" module_name,\n",
" \" {:7.2f}% {:9.2f}ms {:5} {:5.0f}±{:<3.0f} {} ({})\".format(\n",
" 100.0 * row[(\"ProjDurNs\", \"sum\")] / module_total_time,\n",
" 1e-6 * row[(\"ProjDurNs\", \"sum\")],\n",
" row[(\"Name\", \"count\")],\n",
" row[(\"NumThunks\", \"mean\")],\n",
" row[(\"NumThunks\", \"std\")],\n",
" row[(\"Name\", \"first\")],\n",
" program_id,\n",
" )\n",
" )\n",
" accounted_time += module_time"
" )"
]
},
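For reference, the row-wise `apply` above can also be written as a single `groupby` over the shared index levels; a sketch, assuming the index level names described earlier:

```python
# Sketch: count thunks per (ProgramId, ProgramExecution, Rank) in one
# groupby; the resulting Series aligns with module_df's index on
# assignment because the level names and order match.
num_thunks = thunk_df.groupby(["ProgramId", "ProgramExecution", "Rank"]).size()
module_df["NumThunks"] = num_thunks
```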
{
Expand All @@ -241,8 +290,8 @@
"# `top_module_ids` list derived above, as in particular the definition (3) of\n",
"# the total runtime is sensitive to outliers. This is probably a reasonable\n",
"# default, but it is still a heuristic.\n",
"top_module_thunk_df = thunk_df[thunk_df[\"ProgramId\"].isin(top_module_ids)]\n",
"top_module_df = module_df[module_df[\"ProgramId\"].isin(top_module_ids)].copy()\n",
"top_module_thunk_df = thunk_df.loc[top_module_ids]\n",
"top_module_df = module_df.loc[top_module_ids]\n",
"top_module_df[\"ProjEndNs\"] = top_module_df[\"ProjStartNs\"] + top_module_df[\"ProjDurNs\"]\n",
"thunk_summary = (\n",
" top_module_thunk_df.groupby([\"ProgramId\", \"Name\"])\n",
@@ -266,7 +315,7 @@
"# on a per-GPU basis and then summed over GPUs\n",
"all_thunks_active_ns = thunk_summary[\"ProjDurNs\"].sum() # (1)\n",
"all_modules_active_ns = top_module_df[\"ProjDurNs\"].sum() # (2)\n",
"top_module_duration_df = top_module_df.groupby(\"TID\").agg(\n",
"top_module_duration_df = top_module_df.groupby(\"Rank\").agg(\n",
" {\"ProjStartNs\": \"min\", \"ProjEndNs\": \"max\"}\n",
")\n",
"all_modules_wall_ns = (\n",
@@ -416,7 +465,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "036e3015-b950-42e5-aa3c-8684d51f41ed",
"id": "4bfbb1ba-cadd-4679-a607-52d10f3aef15",
"metadata": {},
"outputs": [],
"source": [
@@ -430,9 +479,7 @@
" # program, there may be different sub-groupings that are participating in smaller\n",
" # collectives in the strict/NCCL sense. TODO: it would be better to identify those\n",
" # sub-groupings and group them, but we currently lack the relevant information.\n",
" collective_df = df.groupby(\n",
" [\"ProgramId\", \"Name\", \"ModuleExecution\", \"ThunkExecution\"]\n",
" )\n",
" collective_df = df.groupby([\"ProgramId\", \"ProgramExecution\", \"ThunkIndex\"])\n",
" # Take the fastest device kernel as a proxy for the actual bandwidth of the\n",
" # collective.\n",
" bandwidth_df = collective_df.agg(\n",