
Commit

WIP
michaelbenayoun committed Oct 15, 2024
1 parent a77894e commit dc6769b
Showing 3 changed files with 36 additions and 12 deletions.
2 changes: 1 addition & 1 deletion docs/source/training_tutorials/sft_lora_finetune_llm.py
@@ -54,7 +54,7 @@ def training_function(script_args, training_args):
     trainer = NeuronSFTTrainer(
         args=sft_config,
         model=model,
-        # peft_config=config,
+        peft_config=config,
         tokenizer=tokenizer,
         train_dataset=dataset,
         formatting_func=format_dolly,
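
The tutorial change above re-enables the LoRA adapter by passing peft_config=config to NeuronSFTTrainer. For readers following along without the full script, below is a minimal sketch of the kind of peft.LoraConfig that config typically refers to; the exact hyperparameters in the tutorial may differ.

from peft import LoraConfig

# Illustrative values only; the tutorial's actual `config` may use different settings.
config = LoraConfig(
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)
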
1 change: 1 addition & 0 deletions optimum/neuron/distributed/utils.py
@@ -755,6 +755,7 @@ def create_query_or_output_projection_local_weight_from_regular_weight(
     indices = compute_query_indices_for_rank(
         tp_size, tp_rank, num_attention_heads, num_key_value_heads, kv_size_multiplier
     )
+    print(num_attention_heads, head_dim, hidden_size, weight_data.shape, tp_rank, query_or_output_proj)
     reshaped_weight = weight_data.view(num_attention_heads, head_dim, hidden_size)
     shuffled_weight = reshaped_weight[indices]
     shuffled_weight = shuffled_weight.reshape(-1, hidden_size)
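
The added print is a debugging aid around the head reshuffling that create_query_or_output_projection_local_weight_from_regular_weight performs when sharding query/output projections across tensor-parallel ranks. As a rough, standalone illustration of the view/index/reshape pattern used there (toy shapes, and a random permutation standing in for compute_query_indices_for_rank):

import torch

# Toy dimensions, not the ones used by optimum-neuron.
num_attention_heads, head_dim, hidden_size = 8, 4, 32
weight_data = torch.randn(num_attention_heads * head_dim, hidden_size)

# Stand-in for compute_query_indices_for_rank(...): any permutation of head indices.
indices = torch.randperm(num_attention_heads)

reshaped_weight = weight_data.view(num_attention_heads, head_dim, hidden_size)  # split rows into heads
shuffled_weight = reshaped_weight[indices]                                      # reorder whole heads
shuffled_weight = shuffled_weight.reshape(-1, hidden_size)                      # back to a 2D weight
print(shuffled_weight.shape)  # torch.Size([32, 32])
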
45 changes: 34 additions & 11 deletions tests/peft/test_peft_training.py
@@ -54,14 +54,20 @@
 from neuronx_distributed.utils.model_utils import move_model_to_device
 
 
-def get_peft_config(lora_on_embeddings: bool = False, lora_on_lm_head: bool = False, lora_droupout: float = 0.1):
-    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
+def get_peft_config(lora_on_embeddings: bool = False, lora_on_lm_head: bool = False, lora_droupout: float = 0.1, include_qkv: bool = True, include_out_proj: bool = True):
+
+    target_modules = ["gate_proj", "up_proj", "down_proj"]
+    if include_qkv:
+        target_modules += ["q_proj", "k_proj", "v_proj"]
+    if include_out_proj:
+        target_modules += ["o_proj"]
+
     if lora_on_embeddings:
         target_modules.append("embed_tokens")
     if lora_on_lm_head:
         target_modules.append("lm_head")
     return LoraConfig(
-        r=4,
+        r=16,
         lora_alpha=16,
         target_modules=target_modules,
         lora_dropout=lora_droupout,
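
Bumping the LoRA rank from 4 to 16 quadruples the adapter size per wrapped linear layer, since LoRA adds roughly r * (in_features + out_features) parameters per adapted module. A quick back-of-the-envelope check, assuming a hypothetical 4096 x 4096 projection (not a value taken from the test):

# LoRA adds an (r x in_features) A matrix and an (out_features x r) B matrix per adapted layer.
in_features = out_features = 4096  # hypothetical projection size
for r in (4, 16):
    extra_params = r * (in_features + out_features)
    print(f"r={r}: {extra_params} extra parameters per adapted linear layer")
# r=4:  32768 extra parameters
# r=16: 131072 extra parameters
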
@@ -102,12 +108,20 @@ def test_peft_model_is_converted_to_neuron_peft_model(self):
         model = accelerator.prepare(model)
         assert isinstance(model, NeuronPeftModel)
 
-    def test_save_pretrained(self, parallel_sizes, tmpdir):
-        _, tp_size, pp_size = parallel_sizes
+    @pytest.mark.parametrize(
+        "world_size,tp_size,pp_size,include_out_proj",
+        [
+            [2, 1, 1, True],
+            [2, 2, 1, True],
+            [8, 8, 1, True],
+            [8, 8, 1, False],
+        ])
+    def test_save_pretrained(self, world_size, tp_size, pp_size, include_out_proj, tmpdir):
+        # _, tp_size, pp_size = parallel_sizes
 
         output_dir = Path(tmpdir)
 
-        peft_config = get_peft_config(lora_on_embeddings=True, lora_on_lm_head=True)
+        peft_config = get_peft_config(lora_on_embeddings=True, lora_on_lm_head=True, include_out_proj=include_out_proj)
 
         # PEFT model saved using `PeftModel`.
         seed_patcher = StaticSeedPatcher(42)
@@ -124,18 +138,26 @@ def test_save_pretrained(self, parallel_sizes, tmpdir):
         model_path = output_dir / "peft"
         peft_model = get_peft_model(model, peft_config)
 
+        import torch_xla.core.xla_model as xm
+        xm.master_print(peft_model)
+
         with seed_patcher:
             accelerator = create_accelerator(tp_size, pp_size)
             peft_model = accelerator.prepare_model(peft_model)
+            xm.master_print(peft_model)
             peft_model.save_pretrained(model_path.as_posix(), async_save=False)
 
-        with open(orig_model_path / "adapter_config.json") as fp:
-            orig_adapter_config_content = json.dumps(json.load(fp), sort_keys=True)
+        # with open(orig_model_path / "adapter_config.json") as fp:
+        #     y = fp.read()
+        #     print("y", y)
+        #     orig_adapter_config_content = json.dumps(json.load(fp), sort_keys=True)
 
-        with open(model_path / "adapter_config.json") as fp:
-            adapter_config_content = json.dumps(json.load(fp), sort_keys=True)
+        # with open(model_path / "adapter_config.json") as fp:
+        #     x = fp.read()
+        #     print("x", x)
+        #     adapter_config_content = json.dumps(json.load(fp), sort_keys=True)
 
-        assert orig_adapter_config_content == adapter_config_content, "adapter_config.json files do not match"
+        # assert orig_adapter_config_content == adapter_config_content, "adapter_config.json files do not match"
 
         orig_state_dict = load_file(orig_model_path / "adapter_model.safetensors")
 
Expand All @@ -152,6 +174,7 @@ def test_save_pretrained(self, parallel_sizes, tmpdir):
else:
state_dict = load_file(model_path / "adapter_model.safetensors")

print(set(orig_state_dict.keys()) - set(state_dict.keys()))
assert orig_state_dict.keys() == state_dict.keys()
if xm.is_master_ordinal():
for name, tensor in orig_state_dict.items():
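
The end of the test loads both adapter_model.safetensors files and compares their keys and tensors. A self-contained sketch of that comparison pattern, using toy tensors written to a temporary directory rather than real adapter checkpoints (the key name below is made up for illustration):

import tempfile
from pathlib import Path

import torch
from safetensors.torch import load_file, save_file

# Toy stand-ins for the two adapter_model.safetensors files compared by the test.
tmp = Path(tempfile.mkdtemp())
tensors = {"base_model.model.layer.lora_A.weight": torch.randn(16, 32)}
save_file(tensors, str(tmp / "a.safetensors"))
save_file({k: v.clone() for k, v in tensors.items()}, str(tmp / "b.safetensors"))

orig_state_dict = load_file(str(tmp / "a.safetensors"))
state_dict = load_file(str(tmp / "b.safetensors"))

assert orig_state_dict.keys() == state_dict.keys()
for name, tensor in orig_state_dict.items():
    torch.testing.assert_close(tensor, state_dict[name])
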
