Commit

allocate GPU memory directly
chrisxcai committed Jun 9, 2024
1 parent 00772ed commit 1fa3fb1
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion fairscale/nn/misc/flatten_params_wrapper.py
@@ -431,7 +431,7 @@ def _unflatten_params_as_views(self) -> None:

         if self.optimize_backward_concat and self.fp32_grads is None:
             total_numels = sum([torch.numel(p) for p in param_views])
-            self.fp32_grads = torch.zeros(total_numels, dtype=torch.float32).cuda()
+            self.fp32_grads = torch.zeros(total_numels, dtype=torch.float32, device=torch.cuda.current_device())


# Save param views for easy access if anyone still wants to access
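The change replaces a CPU-side allocation followed by a `.cuda()` host-to-device copy with a single allocation made directly on the current CUDA device. A minimal sketch of the two patterns, not part of the commit; the tensor size is illustrative:

    import torch

    if torch.cuda.is_available():
        n = 1024  # illustrative element count

        # Before: torch.zeros allocates the buffer in host memory,
        # then .cuda() allocates again on the GPU and copies the data over.
        buf_old = torch.zeros(n, dtype=torch.float32).cuda()

        # After: the buffer is allocated directly on the current CUDA device,
        # skipping the intermediate host allocation and the copy.
        buf_new = torch.zeros(n, dtype=torch.float32, device=torch.cuda.current_device())

        assert buf_old.device == buf_new.device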
