[WIP] Add ncu report analyzer #2497

Open
wants to merge 3 commits into base: main
135 changes: 135 additions & 0 deletions torchbenchmark/_components/ncu/analyzer.py
@@ -0,0 +1,135 @@
import os
import shutil
import sys
from collections import defaultdict
from typing import List

"""
A dictionary mapping short metric names to their corresponding NVIDIA Nsight Compute
(NCU) metric names. Don't directly use the NCU metric names in the code, use these short
names instead.
"""
short_ncu_metric_name = {
"inst_executed_ffma_peak": "sm__sass_thread_inst_executed_op_ffma_pred_on.sum.peak_sustained",
"inst_executed_dfma_peak": "sm__sass_thread_inst_executed_op_dfma_pred_on.sum.peak_sustained",
"inst_executed_fadd": "smsp__sass_thread_inst_executed_op_fadd_pred_on.sum.per_cycle_elapsed",
"inst_executed_fmul": "smsp__sass_thread_inst_executed_op_fmul_pred_on.sum.per_cycle_elapsed",
"inst_executed_ffma": "smsp__sass_thread_inst_executed_op_ffma_pred_on.sum.per_cycle_elapsed",
"inst_executed_dadd": "smsp__sass_thread_inst_executed_op_dadd_pred_on.sum.per_cycle_elapsed",
"inst_executed_dmul": "smsp__sass_thread_inst_executed_op_dmul_pred_on.sum.per_cycle_elapsed",
"inst_executed_dfma": "smsp__sass_thread_inst_executed_op_dfma_pred_on.sum.per_cycle_elapsed",
"dram_bytes_write": "dram__bytes_write.sum",
"dram_bytes_read": "dram__bytes_read.sum",
"dram_bytes_per_second": "dram__bytes.sum.per_second",
"sm_freq": "smsp__cycles_elapsed.avg.per_second",
"dram_bandwidth": "dram__bytes.sum.per_second",
"duration": "gpu__time_duration.sum",
}
# A dictionary mapping benchmark metric names to their corresponding short NCU metric names.
bench_metric_to_short_ncu_metric = {
"memory_traffic": ["dram_bytes_write", "dram_bytes_read"],
"arithmetic_intensity": [
"inst_executed_ffma_peak",
"inst_executed_dfma_peak",
"inst_executed_fadd",
"inst_executed_fmul",
"inst_executed_ffma",
"inst_executed_dadd",
"inst_executed_dmul",
"inst_executed_dfma",
"dram_bytes_write",
"dram_bytes_read",
"sm_freq",
"dram_bandwidth",
"duration",
],
}
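# Example (illustrative only): expanding a benchmark metric into the full NCU metric
# names that are passed to ncu via --metrics.
# >>> [short_ncu_metric_name[m] for m in bench_metric_to_short_ncu_metric["memory_traffic"]]
# ['dram__bytes_write.sum', 'dram__bytes_read.sum']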


def import_ncu_python_path():
"""
This function modifies the Python path to include the NVIDIA Nsight Compute (NCU) Python modules.
It searches for the 'ncu' command in the system PATH, determines its location, and appends the
'extras/python' directory to the Python path.

Raises:
FileNotFoundError: If the 'ncu' command is not found in the system PATH.
FileNotFoundError: If the 'extras/python' directory does not exist in the determined NCU path.
"""
ncu_path = shutil.which("ncu")
if not ncu_path:
raise FileNotFoundError("Could not find 'ncu' command in PATH.")
ncu_path = os.path.dirname(ncu_path)
if not os.path.exists(os.path.join(ncu_path, "extras/python")):
raise FileNotFoundError(
f"'extras/python' does not exist in the provided ncu_path: {ncu_path}"
)
sys.path.append(os.path.join(ncu_path, "extras/python"))


def get_mem_traffic(kernel):
return (
kernel.metric_by_name(short_ncu_metric_name["dram_bytes_read"]).value(),
kernel.metric_by_name(short_ncu_metric_name["dram_bytes_write"]).value(),
)


# Reference: ncu_install_path/sections/SpeedOfLight_Roofline.py
# and ncu_install_path/sections/SpeedOfLight_RooflineChart.section
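# Illustrative numbers only (not from a real profile): a kernel sustaining
# 1.5e12 FP32 FLOP/s while dram__bytes.sum.per_second reports 5.0e11 bytes/s has an
# FP32 arithmetic intensity of 1.5e12 / 5.0e11 = 3.0 FLOP/byte.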
def get_arithmetic_intensity(kernel):
fp32_add_achieved = kernel.metric_by_name(
short_ncu_metric_name["inst_executed_fadd"]
).value()
fp32_mul_achieved = kernel.metric_by_name(
short_ncu_metric_name["inst_executed_fmul"]
).value()
fp32_fma_achieved = kernel.metric_by_name(
short_ncu_metric_name["inst_executed_ffma"]
).value()
fp32_achieved = fp32_add_achieved + fp32_mul_achieved + 2 * fp32_fma_achieved
fp64_add_achieved = kernel.metric_by_name(
short_ncu_metric_name["inst_executed_dadd"]
).value()
fp64_mul_achieved = kernel.metric_by_name(
short_ncu_metric_name["inst_executed_dmul"]
).value()
fp64_fma_achieved = kernel.metric_by_name(
short_ncu_metric_name["inst_executed_dfma"]
).value()
fp64_achieved = fp64_add_achieved + fp64_mul_achieved + 2 * fp64_fma_achieved
sm_freq = kernel.metric_by_name(short_ncu_metric_name["sm_freq"]).value()
fp32_flops = fp32_achieved * sm_freq
fp64_flops = fp64_achieved * sm_freq
dram_bandwidth = kernel.metric_by_name(
short_ncu_metric_name["dram_bandwidth"]
).value()
fp32_arithmetic_intensity = fp32_flops / dram_bandwidth
fp64_arithmetic_intensity = fp64_flops / dram_bandwidth
return fp32_arithmetic_intensity, fp64_arithmetic_intensity


def read_ncu_report(report_path: str, required_metrics: List[str]):
assert os.path.exists(
report_path
), f"The NCU report at {report_path} does not exist. Ensure you add --metrics ncu_rep to your benchmark run."
import_ncu_python_path()
import ncu_report

# save all kernels' metrics. {metric_name: [kernel1_metric_value, kernel2_metric_value, ...]}
results = defaultdict(list)
Member Author commented:
@xuzhao9
Any suggestions on how we should save this data? We need to keep the metric results for each kernel, but we also need aggregated results, right? For example, the memory traffic (both read and write) for the whole operator should be the sum of all kernels' read and write traffic.
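
A minimal sketch of one possible aggregation, assuming the per-kernel (read_bytes, write_bytes) tuples that read_ncu_report collects under results["memory_traffic"]; the helper name is hypothetical:

def aggregate_memory_traffic(per_kernel_traffic):
    # Sum per-kernel DRAM read and write bytes to get operator-level traffic.
    total_read = sum(read_bytes for read_bytes, _ in per_kernel_traffic)
    total_write = sum(write_bytes for _, write_bytes in per_kernel_traffic)
    return total_read, total_write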

Member Author commented:
@xuzhao9 @eellison
Do you think the arithmetic intensity of the whole operator can be represented as a weighted average based on execution time?
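
A sketch of one possibility, assuming per-kernel durations (the "duration" metric) are collected alongside the intensities; read_ncu_report does not collect them yet, and the helper name is hypothetical:

def weighted_arithmetic_intensity(per_kernel_ai, per_kernel_durations):
    # per_kernel_ai: [(fp32_ai, fp64_ai), ...], one tuple per kernel
    # per_kernel_durations: per-kernel execution times, e.g. gpu__time_duration.sum
    total_time = sum(per_kernel_durations)
    fp32 = sum(fp32_ai * t for (fp32_ai, _), t in zip(per_kernel_ai, per_kernel_durations)) / total_time
    fp64 = sum(fp64_ai * t for (_, fp64_ai), t in zip(per_kernel_ai, per_kernel_durations)) / total_time
    return fp32, fp64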

test_report = ncu_report.load_report(report_path)
assert (
test_report.num_ranges() > 0
), f"No profile data found in the NCU report at {report_path}"
default_range = test_report.range_by_idx(0)
assert (
default_range.num_actions() > 0
), f"No profile data found in the default range of the NCU report at {report_path}"
for i in range(default_range.num_actions()):
kernel = default_range.action_by_idx(i)
if "memory_traffic" in required_metrics:
mem_traffic = get_mem_traffic(kernel)
results["memory_traffic"].append(mem_traffic)
if "arithmetic_intensity" in required_metrics:
results["arithmetic_intensity"].append(get_arithmetic_intensity(kernel))
return results
39 changes: 35 additions & 4 deletions torchbenchmark/util/triton_op.py
@@ -24,6 +24,7 @@
import torch
import triton

from torchbenchmark._components.ncu import analyzer as ncu_analyzer
from torchbenchmark.util.env_check import fresh_triton_cache, set_random_seed
from torchbenchmark.util.experiment.metrics import get_peak_memory
from torchbenchmark.util.extra_args import apply_decoration_args, parse_decoration_args
@@ -865,8 +866,30 @@ def _init_extra_metrics() -> Dict[str, Any]:
metrics.compile_time = self.compile_time(input_id, fn_name, metrics)
if "ncu_trace" in self.required_metrics:
metrics.ncu_trace = self.ncu_trace(input_id, fn_name)
# Collect NCU metrics if any required metrics match the ncu analyzer
# metrics. Only profile with the necessary metrics to avoid excessive
# overhead.
ncu_metrics = [
ncu_analyzer.short_ncu_metric_name[short_ncu_metric]
for bench_metric, short_ncu_metrics in ncu_analyzer.bench_metric_to_short_ncu_metric.items()
if bench_metric in self.required_metrics
for short_ncu_metric in short_ncu_metrics
]
if "ncu_rep" in self.required_metrics:
metrics.ncu_rep = self.ncu_trace(input_id, fn_name, replay=True)
if ncu_metrics:
extend_ncu_args = ["--metrics", ",".join(ncu_metrics)]
else:
extend_ncu_args = None
metrics.ncu_rep = self.ncu_trace(
input_id, fn_name, replay=True, extend_ncu_args=extend_ncu_args
)
# Read and update NCU metrics if any required metrics match the NCU metrics
if ncu_metrics:
ncu_analyzer_results = ncu_analyzer.read_ncu_report(
metrics.ncu_rep, self.required_metrics
)
for metric_name, metric_value in ncu_analyzer_results.items():
metrics.extra_metrics[metric_name] = metric_value
if "ncu_rep_ir" in self.required_metrics:
metrics.ncu_rep_ir = self.ncu_trace(
input_id, fn_name, replay=True, profile_ir=True
@@ -1007,14 +1030,23 @@ def nsys_rep(self, input_id: int, fn_name: str) -> str:
return str(nsys_output_file.resolve())

def ncu_trace(
self, input_id: int, fn_name: str, replay: bool = False, profile_ir=False
self,
input_id: int,
fn_name: str,
replay: bool = False,
profile_ir=False,
extend_ncu_args: List[str] = None,
) -> str:
import shutil
import subprocess

# collect the ncu trace
import sys

extend_ncu_args = extend_ncu_args or [
"--set",
"full",
]
op_task_args = [] if IS_FBCODE else [sys.executable]
op_task_args.extend(copy.deepcopy(sys.argv))
for override_option in ["--only", "--input-id", "--num-inputs", "--metrics"]:
@@ -1076,8 +1108,6 @@ def service_exists(service_name):
).resolve()
ncu_args = [
"ncu",
"--set",
"full",
"--nvtx",
"--nvtx-include",
f"{_RANGE_NAME}/",
@@ -1086,6 +1116,7 @@
"--import-source",
"yes",
]
ncu_args.extend(extend_ncu_args)
if replay:
ncu_args.extend(
[