
Load parallel linears directly on device
michaelbenayoun committed Mar 20, 2024
1 parent 4a7df1a · commit 1191f0b
Showing 1 changed file with 1 addition and 2 deletions.
optimum/neuron/accelerate/accelerator.py (1 addition, 2 deletions)
@@ -422,8 +422,7 @@ def _prepare_model_for_mp(
         cpu_ids = {name: id(param) for name, param in model.named_parameters()}
         tied_parameters_dict = get_tied_parameters_dict(model)
         model_main_input_name = getattr(model, "main_input_name", None)
-        # TODO: use self.device.
-        model = self.state.mp_plugin.parallelize_model(model, device=None)
+        model = self.state.mp_plugin.parallelize_model(model, device=self.device)
 
         if model_main_input_name is not None:
             setattr(model, "main_input_name", model_main_input_name)
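
Note on the change: before this commit, parallelize_model was called with device=None, so the parallel linear layers were presumably created off-device and moved to the accelerator afterwards; passing self.device lets them be materialized directly on the target device. Below is a minimal standalone sketch of the general idea in plain PyTorch, not the optimum-neuron internals; the layer size and the device name are illustrative.

import torch
import torch.nn as nn

# Illustrative target device; on a Neuron/XLA setup this would be the
# accelerator device exposed by the Accelerator, not plain CPU.
device = torch.device("cpu")

# Old pattern: allocate the layer first, then move it to the device
# (an extra host-side allocation plus a copy per parameter).
layer = nn.Linear(1024, 1024)
layer = layer.to(device)

# New pattern: pass device= so the parameters are created directly
# on the target device, skipping the intermediate allocation and copy.
layer = nn.Linear(1024, 1024, device=device)

assert layer.weight.device == device

Allocating directly on the target device avoids the host-side allocation and host-to-device copy for each parameter, which adds up when a model is sharded across many parallel workers.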

0 comments on commit 1191f0b
