diff --git a/benchmark/main.py b/benchmark/main.py
index 4ed18328..820adf3f 100644
--- a/benchmark/main.py
+++ b/benchmark/main.py
@@ -62,6 +62,12 @@ def time_func(func, x):
     try:
         if torch.cuda.is_available():
             torch.cuda.synchronize()
+        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+            try:
+                import torch.mps
+                torch.mps.synchronize()
+            except ImportError:
+                pass
 
         t = time.perf_counter()
         if not args.with_backward:
@@ -77,6 +83,12 @@ def time_func(func, x):
 
         if torch.cuda.is_available():
             torch.cuda.synchronize()
+        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+            try:
+                import torch.mps
+                torch.mps.synchronize()
+            except ImportError:
+                pass
         return time.perf_counter() - t
     except RuntimeError as e:
         if 'out of memory' not in str(e):
diff --git a/torch_sparse/testing.py b/torch_sparse/testing.py
index 9383ee07..ac863ca3 100644
--- a/torch_sparse/testing.py
+++ b/torch_sparse/testing.py
@@ -16,6 +16,8 @@
 devices = [torch.device('cpu')]
 if torch.cuda.is_available():
     devices += [torch.device('cuda:0')]
+elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+    devices += [torch.device('mps')]
 
 
 def tensor(x: Any, dtype: torch.dtype, device: torch.device):
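
For reference, a minimal sketch (not part of this patch) of how the duplicated CUDA/MPS branches in time_func could be folded into a single helper; the synchronize name and its placement are assumptions, not the patch's actual structure:

    import torch

    def synchronize() -> None:
        # Block until pending kernels on the active accelerator have finished,
        # so that time.perf_counter() brackets the real device work.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            try:
                import torch.mps
                torch.mps.synchronize()
            except ImportError:
                # Older PyTorch builds expose the MPS backend but not the
                # torch.mps module; skip syncing there, as in the patch.
                pass

time_func would then call synchronize() once before starting the timer and once before reading it, identical in effect to the inlined branches above.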