Skip to content

Commit

Permalink
Add apple silicon GPU Acceleration Support
Browse files Browse the repository at this point in the history
  • Loading branch information
NripeshN committed Jul 14, 2023
1 parent 40693ab commit 15970f3
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 0 deletions.
12 changes: 12 additions & 0 deletions benchmark/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@ def time_func(func, x):
try:
if torch.cuda.is_available():
torch.cuda.synchronize()
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
try:
import torch.mps
torch.mps.synchronize()
except ImportError:
pass
t = time.perf_counter()

if not args.with_backward:
Expand All @@ -77,6 +83,12 @@ def time_func(func, x):

if torch.cuda.is_available():
torch.cuda.synchronize()
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
try:
import torch.mps
torch.mps.synchronize()
except ImportError:
pass
return time.perf_counter() - t
except RuntimeError as e:
if 'out of memory' not in str(e):
Expand Down
2 changes: 2 additions & 0 deletions torch_sparse/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
devices = [torch.device('cpu')]
if torch.cuda.is_available():
devices += [torch.device('cuda:0')]
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
devices += [torch.device('mps')]


def tensor(x: Any, dtype: torch.dtype, device: torch.device):
Expand Down

0 comments on commit 15970f3

Please sign in to comment.