diff --git a/torchbenchmark/util/triton_op.py b/torchbenchmark/util/triton_op.py
index 67a2dbd0e..a0107b703 100644
--- a/torchbenchmark/util/triton_op.py
+++ b/torchbenchmark/util/triton_op.py
@@ -194,6 +194,7 @@ class BenchmarkOperatorMetrics:
 class BenchmarkOperatorResult:
     # Print the result in a table format
     op_name: str
+    op_mode: str
     metrics: List[str]
     result: List[Tuple[Any, Dict[str, BenchmarkOperatorMetrics]]]
     _result_dict: Optional[Dict[Number, Dict[str, BenchmarkOperatorMetrics]]] = None
@@ -297,7 +298,7 @@ def x_vals(self):
     @property
     def userbenchmark_dict(self) -> Dict[str, Any]:
         # Userbenchmark Metric key format:
-        # tritonbench_{op_name}[{x_val}-{provider}-{metric}]
+        # tritonbench_{op_name}_{op_mode}[{x_val}-{provider}-{metric}]
         userbenchmark_metrics_dict = {}
         headers, table = self._table()
         for row in table:
@@ -305,7 +306,7 @@ def userbenchmark_dict(self) -> Dict[str, Any]:
             for ind, value in enumerate(row[1:]):
                 header = headers[ind+1]
                 provider, _dash, metrics = header.partition("-")
-                metric_name = f"tritonbench_{self.op_name}[x_{x_val}-{provider}]_{metrics}"
+                metric_name = f"tritonbench_{self.op_name}_{self.op_mode}[x_{x_val}-{provider}]_{metrics}"
                 userbenchmark_metrics_dict[metric_name] = value
         return userbenchmark_metrics_dict
@@ -561,6 +562,7 @@ def _reduce_benchmarks(acc, bm_name: str):
         finally:
             self.output = BenchmarkOperatorResult(
                 op_name=self.name,
+                op_mode=self.mode.value,
                 metrics=self.required_metrics,
                 result=metrics,
             )
diff --git a/userbenchmark/triton/run.py b/userbenchmark/triton/run.py
index f7e518551..fb0eae74c 100644
--- a/userbenchmark/triton/run.py
+++ b/userbenchmark/triton/run.py
@@ -8,6 +8,7 @@
 from torchbenchmark.operators import load_opbench_by_name
 from torchbenchmark.util.triton_op import (
+    BenchmarkOperatorResult,
     DEFAULT_RUN_ITERS,
     DEFAULT_WARMUP,
 )
@@ -135,7 +136,7 @@ def get_parser():
     parser.add_argument("--log-scuba", action="store_true", help="Log to scuba.")
     return parser

-def _run(args: argparse.Namespace, extra_args: List[str]) -> None:
+def _run(args: argparse.Namespace, extra_args: List[str]) -> BenchmarkOperatorResult:
     Opbench = load_opbench_by_name(args.op)
     if args.fwd_bwd:
         args.mode = "fwd_bwd"
@@ -167,6 +168,7 @@ def _run(args: argparse.Namespace, extra_args: List[str]) -> None:
         os.makedirs(TRITON_BENCH_CSV_DUMP_PATH, exist_ok=True)
         path = metrics.write_csv(TRITON_BENCH_CSV_DUMP_PATH)
         print(f"[TritonBench] Dumped csv to {path}")
+    return metrics

 def run(args: List[str] = []):
     if args == []:
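
Note: with this change, each userbenchmark metric key embeds the operator mode alongside the operator name. A minimal sketch of the resulting key, using hypothetical placeholder values ("softmax", "fwd", 2048, "triton", "latency" are illustrative only, not taken from the patch):

    # Hypothetical values illustrating the updated key format.
    op_name = "softmax"   # self.op_name
    op_mode = "fwd"       # self.op_mode, i.e. self.mode.value in the patch
    x_val = 2048          # row[0] from the result table
    provider = "triton"   # parsed from the column header via partition("-")
    metrics = "latency"   # metric name following the provider prefix
    metric_name = f"tritonbench_{op_name}_{op_mode}[x_{x_val}-{provider}]_{metrics}"
    # -> "tritonbench_softmax_fwd[x_2048-triton]_latency"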