diff --git a/optimum/neuron/accelerate/accelerator.py b/optimum/neuron/accelerate/accelerator.py index 0398f076b..6d7e6baf5 100644 --- a/optimum/neuron/accelerate/accelerator.py +++ b/optimum/neuron/accelerate/accelerator.py @@ -44,7 +44,7 @@ patched_finfo, ) from ..utils.misc import args_and_kwargs_to_kwargs_only -from ..utils.require_utils import requires_neuronx_distributed +from ..utils.require_utils import requires_neuronx_distributed, requires_torch_xla from .optimizer import NeuronAcceleratedOptimizer from .scheduler import NeuronAcceleratedScheduler from .state import NeuronAcceleratorState @@ -460,6 +460,8 @@ def _tie_or_clone_weights_for_mp(self, output_embeddings, input_embeddings): return super().prepare_model(model, device_placement=device_placement, evaluation_mode=evaluation_mode) + @requires_torch_xla + @requires_neuronx_distributed def prepare_model( self, model: torch.nn.Module, device_placement: Optional[bool] = None, evaluation_mode: bool = False ): @@ -477,6 +479,8 @@ def prepare_model( return self._prepare_model_for_mp( model, device_placement=device_placement, evaluation_mode=evaluation_mode ) + move_model_to_device(model, xm.xla_device()) + device_placement = False return super().prepare_model(model, device_placement=device_placement, evaluation_mode=evaluation_mode) def backward_for_xla_fsdp(self, loss, **kwargs):