Skip to content

Commit

Permalink
Move model to device by default
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelbenayoun committed Jan 16, 2024
1 parent 7bdad6a commit 410a77b
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion optimum/neuron/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
patched_finfo,
)
from ..utils.misc import args_and_kwargs_to_kwargs_only
from ..utils.require_utils import requires_neuronx_distributed
from ..utils.require_utils import requires_neuronx_distributed, requires_torch_xla
from .optimizer import NeuronAcceleratedOptimizer
from .scheduler import NeuronAcceleratedScheduler
from .state import NeuronAcceleratorState
Expand Down Expand Up @@ -460,6 +460,8 @@ def _tie_or_clone_weights_for_mp(self, output_embeddings, input_embeddings):

return super().prepare_model(model, device_placement=device_placement, evaluation_mode=evaluation_mode)

@requires_torch_xla
@requires_neuronx_distributed
def prepare_model(
self, model: torch.nn.Module, device_placement: Optional[bool] = None, evaluation_mode: bool = False
):
Expand All @@ -477,6 +479,8 @@ def prepare_model(
return self._prepare_model_for_mp(
model, device_placement=device_placement, evaluation_mode=evaluation_mode
)
move_model_to_device(model, xm.xla_device())
device_placement = False
return super().prepare_model(model, device_placement=device_placement, evaluation_mode=evaluation_mode)

def backward_for_xla_fsdp(self, loss, **kwargs):
Expand Down

0 comments on commit 410a77b

Please sign in to comment.