Skip to content

Commit

Permalink
Deploy the H100 flash_attention operator
Browse files Browse the repository at this point in the history
Summary: Fix the servicelab test benchmark_tritonbench_flash_attention_fwd

Reviewed By: jialiangqu

Differential Revision: D58902477

fbshipit-source-id: 14ed713a3279b038ccaab6b98c2f42cd0cfe6e3f
  • Loading branch information
xuzhao9 authored and facebook-github-bot committed Jul 25, 2024
1 parent 06cf234 commit 5cc1d01
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
6 changes: 4 additions & 2 deletions torchbenchmark/util/triton_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ class BenchmarkOperatorMetrics:
class BenchmarkOperatorResult:
# Print the result in a table format
op_name: str
op_mode: str
metrics: List[str]
result: List[Tuple[Any, Dict[str, BenchmarkOperatorMetrics]]]
_result_dict: Optional[Dict[Number, Dict[str, BenchmarkOperatorMetrics]]] = None
Expand Down Expand Up @@ -297,15 +298,15 @@ def x_vals(self):
@property
def userbenchmark_dict(self) -> Dict[str, Any]:
# Userbenchmark Metric key format:
# tritonbench_{op_name}[{x_val}-{provider}-{metric}]
# tritonbench_{op_name}_{op_mode}[{x_val}-{provider}-{metric}]
userbenchmark_metrics_dict = {}
headers, table = self._table()
for row in table:
x_val = row[0]
for ind, value in enumerate(row[1:]):
header = headers[ind+1]
provider, _dash, metrics = header.partition("-")
metric_name = f"tritonbench_{self.op_name}[x_{x_val}-{provider}]_{metrics}"
metric_name = f"tritonbench_{self.op_name}_{self.op_mode}[x_{x_val}-{provider}]_{metrics}"
userbenchmark_metrics_dict[metric_name] = value
return userbenchmark_metrics_dict

Expand Down Expand Up @@ -561,6 +562,7 @@ def _reduce_benchmarks(acc, bm_name: str):
finally:
self.output = BenchmarkOperatorResult(
op_name=self.name,
op_mode=self.mode.value,
metrics=self.required_metrics,
result=metrics,
)
Expand Down
4 changes: 3 additions & 1 deletion userbenchmark/triton/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torchbenchmark.operators import load_opbench_by_name

from torchbenchmark.util.triton_op import (
BenchmarkOperatorResult,
DEFAULT_RUN_ITERS,
DEFAULT_WARMUP,
)
Expand Down Expand Up @@ -135,7 +136,7 @@ def get_parser():
parser.add_argument("--log-scuba", action="store_true", help="Log to scuba.")
return parser

def _run(args: argparse.Namespace, extra_args: List[str]) -> None:
def _run(args: argparse.Namespace, extra_args: List[str]) -> BenchmarkOperatorResult:
Opbench = load_opbench_by_name(args.op)
if args.fwd_bwd:
args.mode = "fwd_bwd"
Expand Down Expand Up @@ -167,6 +168,7 @@ def _run(args: argparse.Namespace, extra_args: List[str]) -> None:
os.makedirs(TRITON_BENCH_CSV_DUMP_PATH, exist_ok=True)
path = metrics.write_csv(TRITON_BENCH_CSV_DUMP_PATH)
print(f"[TritonBench] Dumped csv to {path}")
return metrics

def run(args: List[str] = []):
if args == []:
Expand Down

0 comments on commit 5cc1d01

Please sign in to comment.