-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
aa8d896
commit 675007c
Showing
2 changed files
with
75 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import torch | ||
import cProfile | ||
from pstats import Stats | ||
|
||
|
||
def profile_torch(func, args, row_limit=10, | ||
save_output=False, func_name=None, | ||
file_name="trace"): | ||
"""Use PyTorch's profiler to profile torch ops. | ||
To see the graph: upload trace.json to chrome://tracing | ||
More details about PyTorch profiler: | ||
https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html | ||
""" | ||
func_name = func_name if func_name else func.__name__ | ||
with torch.profiler.profile() as prof: | ||
with torch.profiler.record_function(func_name): | ||
func(*args) | ||
print(prof.key_averages().table(sort_by="cpu_time_total", | ||
row_limit=row_limit)) | ||
if save_output: | ||
prof.export_chrome_trace(f"{file_name}.json") | ||
|
||
|
||
def profile_python(func, args, row_limit=10, save_output=False, | ||
file_name="stats"): | ||
"""Use cProfile to profile python function calls. | ||
To see the graph, run the following commands: | ||
1. python -m pip install snakeviz | ||
2. snakeviz stats.prof | ||
""" | ||
pr = cProfile.Profile() | ||
pr.enable() | ||
func(*args) | ||
pr.disable() | ||
stats = Stats(pr) | ||
stats.sort_stats('tottime').print_stats(row_limit) | ||
if save_output: | ||
pr.dump_stats(f"{file_name}.prof") | ||
|
||
|
||
if __name__ == "__main__": | ||
# Example usage of the profiler. | ||
from mpact.models.kernels import MMNet | ||
from mpact_benchmark.utils.tensor_generator import generate_tensor | ||
from mpact.mpactbackend import mpact_jit | ||
|
||
|
||
# Generate input tensors. | ||
dense_tensor1 = generate_tensor(seed=0, shape=(32, 32), | ||
sparsity=0.8) | ||
dense_tensor2 = generate_tensor(seed=1, shape=(32, 32), | ||
sparsity=0.8) | ||
sparse_tensor1 = dense_tensor1.to_sparse_csr() | ||
sparse_tensor2 = dense_tensor2.to_sparse_csr() | ||
|
||
# Profile with PyTorch profiler for torch operators. | ||
# MPACT sparse. | ||
profile_torch(mpact_jit, (MMNet(), sparse_tensor1, sparse_tensor2)) | ||
# Torch sparse. | ||
profile_torch(MMNet(), (sparse_tensor1, sparse_tensor2), | ||
func_name="sparsexsparse matmul") | ||
|
||
# Profile with cProfile for Python function calls. | ||
# MPACT sparse. | ||
profile_python(mpact_jit, (MMNet(), sparse_tensor1, sparse_tensor2)) | ||
# Torch sparse. | ||
profile_python(MMNet(), (sparse_tensor1, sparse_tensor2)) |