[WIP] peft + tp support
michaelbenayoun committed Jun 3, 2024
1 parent ff6c228 commit 1a90b5d
Showing 3 changed files with 24 additions and 7 deletions.
optimum/neuron/accelerate/accelerator.py (2 additions, 0 deletions)
@@ -429,6 +429,8 @@ def _tie_or_clone_weights_for_mp(self, output_embeddings, input_embeddings):
                 cpu_ids[name]: xla_params[name] for name, _ in model.local_named_parameters()
             }
         else:
+            for n, p in model.named_parameters():
+                print(f"{n} => {p.device}")
             move_model_to_device(model, self.device)
             tie_parameters(model, tied_parameters_dict)
             xla_params = dict(model.named_parameters())
optimum/neuron/distributed/base.py (0 additions, 2 deletions)
@@ -801,8 +801,6 @@ def should_parallelize_layer_predicate_func(layer):
         model._gqa_qkv_metadata = gqa_qkv_metadata
 
-        print("Parallelized PEFT model", model)
-        assert 3 == 2
 
         return model
 
     @classmethod
optimum/neuron/distributed/utils.py (22 additions, 5 deletions)
@@ -898,7 +898,7 @@ def peft_tuner_linear_to_parallel_linear(
     skip_weight_load: bool = False,
     device: Optional["torch.device"] = None,
 ) -> "BaseTunerLayer":
-    from peft.tuners import LoraLayer
+    from peft.tuners.lora import LoraLayer
     from peft.tuners.tuners_utils import BaseTunerLayer
 
     # This is necessary for the case that the tuner layer wraps another tuner layer.
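The comment above refers to PEFT tuner layers that wrap other tuner layers. A minimal sketch, separate from the diff, of how such nesting can be unwrapped with PEFT's BaseTunerLayer.get_base_layer(); the toy Sequential model and its shapes are hypothetical and only serve to produce a LoRA-wrapped linear layer:

import torch
from peft import LoraConfig, inject_adapter_in_model
from peft.tuners.tuners_utils import BaseTunerLayer

# Hypothetical toy module: a single Linear that LoRA will wrap in a tuner layer.
model = torch.nn.Sequential(torch.nn.Linear(16, 32))
model = inject_adapter_in_model(LoraConfig(r=4, target_modules=["0"]), model)

for name, module in model.named_modules():
    if isinstance(module, BaseTunerLayer):
        # get_base_layer() follows nested `base_layer` references, so it returns
        # the underlying nn.Linear even when one tuner layer wraps another.
        print(name, "->", type(module.get_base_layer()).__name__)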
@@ -935,10 +935,27 @@ def peft_tuner_linear_to_parallel_linear(
         # 2. The base linear layer is a ColumnParallelLinear, then:
         #   - The lora A matrix does not need to be parallelized,
         #   - The lora B matrix needs to be a ColumnParallelLinear as well.
-        if axis == "row":
-            pass
-            # parent.lora_a
-
+        print(parent)
+        for adapter_name in parent.active_adapters:
+            if axis == "row":
+                layer_to_parallelize = parent.lora_A[adapter_name]
+            else:
+                layer_to_parallelize = parent.lora_B[adapter_name]
+            # TODO: handle the case where weights already exist for this adapter.
+            parallel_layer = linear_to_parallel_linear(
+                layer_to_parallelize,
+                axis,
+                input_is_parallel=input_is_parallel,
+                gather_output=gather_output,
+                stride=stride,
+                sequence_parallel_enabled=sequence_parallel_enabled,
+                skip_weight_load=skip_weight_load,
+                device=device,
+            )
+            if axis == "row":
+                parent.lora_A[adapter_name] = parallel_layer
+            else:
+                parent.lora_B[adapter_name] = parallel_layer
     else:
         raise NotImplementedError(f"{parent.__class__.__name__} is not supported yet for model parallelism.")
 
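The comment block in the hunk above states the rule this loop implements: when the base layer is row-parallel (input dimension sharded), the LoRA A matrix is the one that must be parallelized along the row axis, and when the base layer is column-parallel (output dimension sharded), the LoRA B matrix must be column-parallel. A self-contained sketch of the arithmetic behind that rule, with hypothetical shapes and plain torch tensors standing in for the Neuron RowParallelLinear/ColumnParallelLinear layers:

import torch

torch.manual_seed(0)
tp_size = 2                   # pretend tensor-parallel degree
in_f, out_f, r = 8, 6, 4      # hypothetical shapes
x = torch.randn(in_f)

A = torch.randn(r, in_f)      # lora_A weight (r x in_features)
B = torch.randn(out_f, r)     # lora_B weight (out_features x r)
full_delta = B @ (A @ x)      # the un-parallelized LoRA contribution

# Column-parallel base layer: shard B along its output rows, keep A replicated.
# Each "rank" produces a slice of the output; concatenating the slices
# (conceptually what gather_output does) recovers the full LoRA delta.
col_shards = [B_shard @ (A @ x) for B_shard in B.chunk(tp_size, dim=0)]
assert torch.allclose(torch.cat(col_shards), full_delta)

# Row-parallel base layer: shard A along its input columns, keep B replicated.
# Each "rank" only sees its slice of x and produces a partial result; summing
# the partials (conceptually the all-reduce) recovers the full LoRA delta.
row_partials = [
    B @ (A_shard @ x_shard)
    for A_shard, x_shard in zip(A.chunk(tp_size, dim=1), x.chunk(tp_size))
]
assert torch.allclose(sum(row_partials), full_delta, atol=1e-5)

In both cases the replicated matrix needs no extra communication, which is presumably why the loop in the diff converts only one of lora_A/lora_B per adapter.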
