Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

nsys-jax: roll back to 2024.4.1 and handle single-gpu profiles of single module executions better #988

Merged: 5 commits were merged on Aug 9, 2024.
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/container/install-nsight.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ export DEBIAN_FRONTEND=noninteractive
export TZ=America/Los_Angeles

apt-get update
apt-get install -y nsight-compute nsight-systems-cli
# TODO: revert to the unpinned nsight-systems-cli package instead of explicitly pinning version 2024.4.1
apt-get install -y nsight-compute nsight-systems-cli-2024.4.1
apt-get clean

rm -rf /var/lib/apt/lists/*
Expand Down
332 changes: 169 additions & 163 deletions .github/container/jax_nsys/Analysis.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -184,21 +184,23 @@
"metadata": {},
"outputs": [],
"source": [
"seen_devices = [False] * alignment_metadata[\"collective_size\"]\n",
"data: list[list[float]] = [[]] * alignment_metadata[\"collective_size\"]\n",
"for device, delta_ms in alignment_metadata[\"collective_end_time_skews_ms\"].groupby(\n",
" \"Device\"\n",
"):\n",
" assert not seen_devices[device]\n",
" seen_devices[device] = True\n",
" data[device] = delta_ms\n",
"fig, ax = plt.subplots()\n",
"ax.violinplot(data, positions=range(len(data)))\n",
"ax.set_title(\n",
" f\"Estimated clock skew from N={alignment_metadata['collective_size']} collectives\"\n",
")\n",
"ax.set_xlabel(\"Device\")\n",
"ax.set_ylabel(\"Clock skew [ms]\");"
"# If no collectives were profiled, this metadata is not available\n",
"if len(alignment_metadata):\n",
" seen_devices = [False] * alignment_metadata[\"collective_size\"]\n",
" data: list[list[float]] = [[]] * alignment_metadata[\"collective_size\"]\n",
" for device, delta_ms in alignment_metadata[\"collective_end_time_skews_ms\"].groupby(\n",
" \"Device\"\n",
" ):\n",
" assert not seen_devices[device]\n",
" seen_devices[device] = True\n",
" data[device] = delta_ms\n",
" fig, ax = plt.subplots()\n",
" ax.violinplot(data, positions=range(len(data)))\n",
" ax.set_title(\n",
" f\"Estimated clock skew from N={alignment_metadata['collective_size']} collectives\"\n",
" )\n",
" ax.set_xlabel(\"Device\")\n",
" ax.set_ylabel(\"Clock skew [ms]\")"
]
},
{
Expand Down Expand Up @@ -434,11 +436,11 @@
"\n",
"# When summarising over source locations use total time (3) as the top level of\n",
"# the hierarchy, assuming that the visualisation will be able to handle this.\n",
"src_runtime[tuple(gpu_idle_inside_modules)] = (\n",
" all_modules_active_ms - all_thunks_active_ms\n",
"src_runtime[tuple(gpu_idle_inside_modules)] = max(\n",
" 0.0, all_modules_active_ms - all_thunks_active_ms\n",
")\n",
"src_runtime[tuple(gpu_idle_between_modules)] = (\n",
" all_modules_wall_ms - all_modules_active_ms\n",
"src_runtime[tuple(gpu_idle_between_modules)] = max(\n",
" 0.0, all_modules_wall_ms - all_modules_active_ms\n",
")\n",
"op_name_runtime[tuple(gpu_idle_inside_modules)] = src_runtime[\n",
" tuple(gpu_idle_inside_modules)\n",
Expand Down Expand Up @@ -496,60 +498,63 @@
"metadata": {},
"outputs": [],
"source": [
"fig, axs2d = plt.subplots(ncols=3, figsize=[15, 5], squeeze=False, tight_layout=True)\n",
"axs = axs2d[0]\n",
"wait_data, wait_data_labels = [], []\n",
"comm_df = steady_state.communication\n",
"comm_df[\"ProjDurFullMs\"] = comm_df[\"ProjDurMs\"] + comm_df[\"ProjDurHiddenMs\"]\n",
"comm_df[\"ProjEndMs\"] = comm_df[\"ProjStartMs\"] + comm_df[\"ProjDurFullMs\"]\n",
"for comm, df in comm_df.groupby(\"Collective\"):\n",
" # The grouped data frame will have a row for each device that is participating in\n",
" # this instance of this collective, in the loose SPMD sense. Depending on the JAX\n",
" # program, there may be different sub-groupings that are participating in smaller\n",
" # collectives in the strict/NCCL sense. TODO: it would be better to identify those\n",
" # sub-groupings and group them, but we currently lack the relevant information.\n",
" collective_df = df.groupby([\"ProgramId\", \"ProgramExecution\", \"ThunkIndex\"])\n",
" # Take the fastest device kernel as a proxy for the actual bandwidth of the\n",
" # collective.\n",
" bandwidth_df = collective_df.agg(\n",
" {\n",
" \"BusBandwidthGBPerSec\": \"max\",\n",
" \"MessageSize\": \"min\",\n",
" \"ProjStartMs\": \"min\",\n",
" \"ProjDurFullMs\": \"min\",\n",
" \"ProjEndMs\": \"max\",\n",
" \"Name\": \"count\",\n",
" }\n",
" )\n",
" axs[0].plot(\n",
" bandwidth_df[\"MessageSize\"],\n",
" bandwidth_df[\"BusBandwidthGBPerSec\"],\n",
" \"o\",\n",
" label=comm,\n",
"if len(steady_state.communication):\n",
" fig, axs2d = plt.subplots(\n",
" ncols=3, figsize=[15, 5], squeeze=False, tight_layout=True\n",
" )\n",
" # Take last_end - first_start - fastest_duration as a proxy for time lost due\n",
" # to stragglers / failing to operate in neat lockstep.\n",
" wait_time_ms = (\n",
" bandwidth_df[\"ProjEndMs\"]\n",
" - bandwidth_df[\"ProjStartMs\"]\n",
" - bandwidth_df[\"ProjDurFullMs\"]\n",
" )\n",
" wait_data.append(wait_time_ms)\n",
" wait_data_labels.append(comm)\n",
" axs[2].plot(bandwidth_df[\"MessageSize\"], wait_time_ms, \"o\", label=comm)\n",
"axs[0].legend()\n",
"axs[0].set_xlabel(\"Message size (B)\")\n",
"axs[0].set_xscale(\"log\")\n",
"axs[0].set_ylabel(\"Bus bandwidth (GB/s)\")\n",
"axs[1].boxplot(wait_data, vert=True)\n",
"axs[1].set_xticks([y + 1 for y in range(len(wait_data))], labels=wait_data_labels)\n",
"axs[1].set_xlabel(\"Collective\")\n",
"axs[1].set_ylabel(\"Wait time [ms]\")\n",
"axs[1].set_yscale(\"log\")\n",
"axs[2].set_xlabel(\"Message size (B)\")\n",
"axs[2].set_ylabel(\"Wait time [ms]\")\n",
"axs[2].set_xscale(\"log\")\n",
"axs[2].set_yscale(\"log\")"
" axs = axs2d[0]\n",
" wait_data, wait_data_labels = [], []\n",
" comm_df = steady_state.communication\n",
" comm_df[\"ProjDurFullMs\"] = comm_df[\"ProjDurMs\"] + comm_df[\"ProjDurHiddenMs\"]\n",
" comm_df[\"ProjEndMs\"] = comm_df[\"ProjStartMs\"] + comm_df[\"ProjDurFullMs\"]\n",
" for comm, df in comm_df.groupby(\"Collective\"):\n",
" # The grouped data frame will have a row for each device that is participating in\n",
" # this instance of this collective, in the loose SPMD sense. Depending on the JAX\n",
" # program, there may be different sub-groupings that are participating in smaller\n",
" # collectives in the strict/NCCL sense. TODO: it would be better to identify those\n",
" # sub-groupings and group them, but we currently lack the relevant information.\n",
" collective_df = df.groupby([\"ProgramId\", \"ProgramExecution\", \"ThunkIndex\"])\n",
" # Take the fastest device kernel as a proxy for the actual bandwidth of the\n",
" # collective.\n",
" bandwidth_df = collective_df.agg(\n",
" {\n",
" \"BusBandwidthGBPerSec\": \"max\",\n",
" \"MessageSize\": \"min\",\n",
" \"ProjStartMs\": \"min\",\n",
" \"ProjDurFullMs\": \"min\",\n",
" \"ProjEndMs\": \"max\",\n",
" \"Name\": \"count\",\n",
" }\n",
" )\n",
" axs[0].plot(\n",
" bandwidth_df[\"MessageSize\"],\n",
" bandwidth_df[\"BusBandwidthGBPerSec\"],\n",
" \"o\",\n",
" label=comm,\n",
" )\n",
" # Take last_end - first_start - fastest_duration as a proxy for time lost due\n",
" # to stragglers / failing to operate in neat lockstep.\n",
" wait_time_ms = (\n",
" bandwidth_df[\"ProjEndMs\"]\n",
" - bandwidth_df[\"ProjStartMs\"]\n",
" - bandwidth_df[\"ProjDurFullMs\"]\n",
" )\n",
" wait_data.append(wait_time_ms)\n",
" wait_data_labels.append(comm)\n",
" axs[2].plot(bandwidth_df[\"MessageSize\"], wait_time_ms, \"o\", label=comm)\n",
" axs[0].legend()\n",
" axs[0].set_xlabel(\"Message size (B)\")\n",
" axs[0].set_xscale(\"log\")\n",
" axs[0].set_ylabel(\"Bus bandwidth (GB/s)\")\n",
" axs[1].boxplot(wait_data, vert=True)\n",
" axs[1].set_xticks([y + 1 for y in range(len(wait_data))], labels=wait_data_labels)\n",
" axs[1].set_xlabel(\"Collective\")\n",
" axs[1].set_ylabel(\"Wait time [ms]\")\n",
" axs[1].set_yscale(\"log\")\n",
" axs[2].set_xlabel(\"Message size (B)\")\n",
" axs[2].set_ylabel(\"Wait time [ms]\")\n",
" axs[2].set_xscale(\"log\")\n",
" axs[2].set_yscale(\"log\")"
]
},
{
Expand Down Expand Up @@ -642,103 +647,104 @@
"metadata": {},
"outputs": [],
"source": [
"fig, grid = plt.subplots(\n",
" nrows=len(top_module_ids), figsize=[15, 5], squeeze=False, tight_layout=True\n",
")\n",
"time_df = steady_state.thunk.loc[\n",
" ~steady_state.thunk[\"Communication\"], (\"ProjStartMs\", \"ProjDurMs\")\n",
"]\n",
"time_df[\"ProjEndMs\"] = time_df[\"ProjStartMs\"] + time_df.pop(\"ProjDurMs\")\n",
"\n",
"\n",
"def interleave(df):\n",
" s, e = df[\"ProjStartMs\"], df[\"ProjEndMs\"]\n",
" r = np.empty((s.size + e.size,), dtype=s.dtype)\n",
" r[0::2] = s\n",
" r[1::2] = e\n",
" return r\n",
"\n",
"\n",
"devices_to_show = 8\n",
"for n_row, program_id in enumerate(top_module_ids):\n",
" x_values = []\n",
" y_values = defaultdict(list)\n",
" ax = grid[n_row][0]\n",
" for module_execution, exec_df in time_df.loc[program_id].groupby(\n",
" \"ProgramExecution\"\n",
" ):\n",
" # Mean over devices to get a single [thunk0_start, thunk0_end, thunk1_start, ...]\n",
" # array for this execution of this module\n",
" mean_times = interleave(exec_df.groupby(\"ThunkIndex\").agg(\"mean\"))\n",
" # x axis of the plot will be the average over executions of the module\n",
" x_values.append(mean_times - mean_times[0])\n",
" for device, device_values in exec_df.groupby(\"Device\"):\n",
" # [thunk0_start, thunk0_end, ...] array for one device within one module exec\n",
" # with the average over devices subtracted\n",
" y_values[device].append(interleave(device_values) - mean_times)\n",
" mean_start_time_ms = np.mean(x_values, axis=0)\n",
" all_values = np.array(list(y_values.values()))\n",
" ax.plot(\n",
" mean_start_time_ms,\n",
" np.min(all_values, axis=(0, 1)),\n",
" \"k:\",\n",
" lw=1,\n",
" label=\"min/max\",\n",
"if len(steady_state.communication):\n",
" fig, grid = plt.subplots(\n",
" nrows=len(top_module_ids), figsize=[15, 5], squeeze=False, tight_layout=True\n",
" )\n",
" ax.plot(mean_start_time_ms, np.max(all_values, axis=(0, 1)), \"k:\", lw=1)\n",
" std = np.std(all_values, axis=(0, 1))\n",
" ax.fill_between(mean_start_time_ms, -std, +std, alpha=0.2, label=r\"$\\pm1\\sigma$\")\n",
" # max abs(bias) over ProgramExecution within a device, summed over ThunkIndex\n",
" outlier_devices = np.sum(np.max(np.abs(all_values), axis=1), axis=1)\n",
" for _, device in sorted(\n",
" zip(outlier_devices, range(all_values.shape[0])), reverse=True\n",
" )[:devices_to_show]:\n",
" time_df = steady_state.thunk.loc[\n",
" ~steady_state.thunk[\"Communication\"], (\"ProjStartMs\", \"ProjDurMs\")\n",
" ]\n",
" time_df[\"ProjEndMs\"] = time_df[\"ProjStartMs\"] + time_df.pop(\"ProjDurMs\")\n",
"\n",
" def interleave(df):\n",
" s, e = df[\"ProjStartMs\"], df[\"ProjEndMs\"]\n",
" r = np.empty((s.size + e.size,), dtype=s.dtype)\n",
" r[0::2] = s\n",
" r[1::2] = e\n",
" return r\n",
"\n",
" devices_to_show = 8\n",
" for n_row, program_id in enumerate(top_module_ids):\n",
" x_values = []\n",
" y_values = defaultdict(list)\n",
" ax = grid[n_row][0]\n",
" for module_execution, exec_df in time_df.loc[program_id].groupby(\n",
" \"ProgramExecution\"\n",
" ):\n",
" # Mean over devices to get a single [thunk0_start, thunk0_end, thunk1_start, ...]\n",
" # array for this execution of this module\n",
" mean_times = interleave(exec_df.groupby(\"ThunkIndex\").agg(\"mean\"))\n",
" # x axis of the plot will be the average over executions of the module\n",
" x_values.append(mean_times - mean_times[0])\n",
" for device, device_values in exec_df.groupby(\"Device\"):\n",
" # [thunk0_start, thunk0_end, ...] array for one device within one module exec\n",
" # with the average over devices subtracted\n",
" y_values[device].append(interleave(device_values) - mean_times)\n",
" mean_start_time_ms = np.mean(x_values, axis=0)\n",
" all_values = np.array(list(y_values.values()))\n",
" ax.plot(\n",
" mean_start_time_ms,\n",
" np.mean(all_values[device], axis=0),\n",
" label=f\"Device {device}\",\n",
" )\n",
"\n",
" comm_x_values = defaultdict(list)\n",
" for module_execution, exec_df in comm_df.loc[program_id].groupby(\n",
" \"ProgramExecution\"\n",
" ):\n",
" exec_df[\"EndInModuleMs\"] = (\n",
" exec_df[\"ProjEndMs\"]\n",
" - steady_state.module.loc[(program_id, module_execution), \"ProjStartMs\"]\n",
" )\n",
" tmp = exec_df.groupby(\"ThunkIndex\").agg(\n",
" {\n",
" \"Name\": \"first\",\n",
" \"Collective\": \"first\",\n",
" \"CollectiveSize\": \"first\",\n",
" \"EndInModuleMs\": \"mean\",\n",
" }\n",
" np.min(all_values, axis=(0, 1)),\n",
" \"k:\",\n",
" lw=1,\n",
" label=\"min/max\",\n",
" )\n",
" for coll_size, values in tmp.groupby(\"CollectiveSize\"):\n",
" comm_x_values[coll_size].append(values[\"EndInModuleMs\"])\n",
" (_, xmax), (ymin, ymax) = ax.get_xlim(), ax.get_ylim()\n",
" ax.set_xlim(0, xmax)\n",
" ax.set_ylim(ymin, ymax)\n",
" largest_collective = max(comm_x_values.keys())\n",
" for n_color, (coll_size, values) in enumerate(comm_x_values.items()):\n",
" collective_times = np.mean(values, axis=0)\n",
" ax.vlines(\n",
" collective_times,\n",
" ymin,\n",
" # Draw taller vertical lines for collectives involving more devices\n",
" ymin * (1 - coll_size / largest_collective),\n",
" color=f\"C{n_color}\",\n",
" label=f\"{coll_size}-device collective\",\n",
" linestyle=\"--\",\n",
" ax.plot(mean_start_time_ms, np.max(all_values, axis=(0, 1)), \"k:\", lw=1)\n",
" std = np.std(all_values, axis=(0, 1))\n",
" ax.fill_between(\n",
" mean_start_time_ms, -std, +std, alpha=0.2, label=r\"$\\pm1\\sigma$\"\n",
" )\n",
" # max abs(bias) over ProgramExecution within a device, summed over ThunkIndex\n",
" outlier_devices = np.sum(np.max(np.abs(all_values), axis=1), axis=1)\n",
" for _, device in sorted(\n",
" zip(outlier_devices, range(all_values.shape[0])), reverse=True\n",
" )[:devices_to_show]:\n",
" ax.plot(\n",
" mean_start_time_ms,\n",
" np.mean(all_values[device], axis=0),\n",
" label=f\"Device {device}\",\n",
" )\n",
"\n",
" ax.set_title(\n",
" f\"{steady_state.module.loc[program_id, 'Name'].iloc[0]} ({program_id}), {min(outlier_devices.size, devices_to_show)} most extreme devices\"\n",
" )\n",
" ax.set_xlabel(\"Mean time within module [ms]\")\n",
" ax.set_ylabel(\"Mean(executions) bias from mean(executions&devices) [ms]\")\n",
" ax.legend(ncols=2)"
" comm_x_values = defaultdict(list)\n",
" for module_execution, exec_df in comm_df.loc[program_id].groupby(\n",
" \"ProgramExecution\"\n",
" ):\n",
" exec_df[\"EndInModuleMs\"] = (\n",
" exec_df[\"ProjEndMs\"]\n",
" - steady_state.module.loc[(program_id, module_execution), \"ProjStartMs\"]\n",
" )\n",
" tmp = exec_df.groupby(\"ThunkIndex\").agg(\n",
" {\n",
" \"Name\": \"first\",\n",
" \"Collective\": \"first\",\n",
" \"CollectiveSize\": \"first\",\n",
" \"EndInModuleMs\": \"mean\",\n",
" }\n",
" )\n",
" for coll_size, values in tmp.groupby(\"CollectiveSize\"):\n",
" comm_x_values[coll_size].append(values[\"EndInModuleMs\"])\n",
" (_, xmax), (ymin, ymax) = ax.get_xlim(), ax.get_ylim()\n",
" ax.set_xlim(0, xmax)\n",
" ax.set_ylim(ymin, ymax)\n",
" largest_collective = max(comm_x_values.keys())\n",
" for n_color, (coll_size, values) in enumerate(comm_x_values.items()):\n",
" collective_times = np.mean(values, axis=0)\n",
" ax.vlines(\n",
" collective_times,\n",
" ymin,\n",
" # Draw taller vertical lines for collectives involving more devices\n",
" ymin * (1 - coll_size / largest_collective),\n",
" color=f\"C{n_color}\",\n",
" label=f\"{coll_size}-device collective\",\n",
" linestyle=\"--\",\n",
" )\n",
"\n",
" ax.set_title(\n",
" f\"{steady_state.module.loc[program_id, 'Name'].iloc[0]} ({program_id}), {min(outlier_devices.size, devices_to_show)} most extreme devices\"\n",
" )\n",
" ax.set_xlabel(\"Mean time within module [ms]\")\n",
" ax.set_ylabel(\"Mean(executions) bias from mean(executions&devices) [ms]\")\n",
" ax.legend(ncols=2)"
]
},
{
Expand Down
Loading
Loading