if condition in tie weights added (#989)
* if condition in tie weights added

* unit test for tie weights
megha95 committed Feb 23, 2024
1 parent 2478f0a commit e5fffac
Showing 3 changed files with 28 additions and 1 deletion.
3 changes: 2 additions & 1 deletion llmfoundry/models/mpt/modeling_mpt.py
@@ -826,7 +826,8 @@ def set_output_embeddings(
         self.transformer.set_input_embeddings(new_embeddings)
 
     def tie_weights(self) -> None:
-        self.lm_head = None
+        if getattr(self.config, 'tie_word_embeddings', True):
+            self.lm_head = None
 
     def set_decoder(self, decoder: MPTModel) -> None:
         self.transformer = decoder
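
For context, the Hugging Face PreTrainedModel plumbing (for example post_init() and from_pretrained()) calls tie_weights(), so the unguarded version dropped lm_head even when the config requested untied embeddings. Below is a minimal sketch of the tied-versus-untied output projection this guard preserves; TinyCausalLM is a hypothetical stand-in for illustration, not the actual MPTForCausalLM implementation.

# Illustrative sketch only (not llm-foundry code): a tied model drops its
# separate lm_head and reuses the input embedding weights for the output
# projection, while an untied model keeps lm_head as a real module.
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyCausalLM(nn.Module):

    def __init__(self, vocab_size: int, d_model: int,
                 tie_word_embeddings: bool):
        super().__init__()
        self.tie_word_embeddings = tie_word_embeddings
        self.wte = nn.Embedding(vocab_size, d_model)
        # Untied models materialize a separate output head; tied models do not.
        self.lm_head = None if tie_word_embeddings else nn.Linear(
            d_model, vocab_size, bias=False)

    def tie_weights(self) -> None:
        # Mirror the guarded behavior: only drop the head when the config
        # actually asks for tied embeddings.
        if self.tie_word_embeddings:
            self.lm_head = None

    def logits(self, hidden: torch.Tensor) -> torch.Tensor:
        if self.lm_head is None:
            # Tied: reuse the embedding matrix as the output projection.
            return F.linear(hidden, self.wte.weight)
        return self.lm_head(hidden)


tied = TinyCausalLM(vocab_size=100, d_model=16, tie_word_embeddings=True)
untied = TinyCausalLM(vocab_size=100, d_model=16, tie_word_embeddings=False)
tied.tie_weights()
untied.tie_weights()
assert tied.lm_head is None and untied.lm_head is not None
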
25 changes: 25 additions & 0 deletions tests/models/hf/test_hf_config.py
@@ -49,6 +49,31 @@ def test_remote_code_false_mpt(
         tokenizer)
 
 
+@pytest.mark.parametrize('tie_word_embeddings', [True, False])
+def test_tie_weights(tie_word_embeddings: bool):
+    # Test that the tie_weights function sets lm_head correctly
+    hf_config = MPTConfig(init_device='cpu',
+                          d_model=128,
+                          n_heads=4,
+                          n_layers=2,
+                          expansion_ratio=2,
+                          max_seq_len=2048,
+                          attn_config={
+                              'attn_impl': 'torch',
+                          },
+                          no_bias=True,
+                          tie_word_embeddings=tie_word_embeddings)
+
+    mpt = MPTForCausalLM(hf_config)
+
+    assert mpt.config.tie_word_embeddings == tie_word_embeddings
+    mpt.tie_weights()
+    if tie_word_embeddings:
+        assert mpt.lm_head is None
+    else:
+        assert mpt.lm_head is not None
+
+
 @pytest.mark.parametrize('model_cfg_overrides', [
     {
         'max_seq_len': 1024
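
The guard also matters on the Hugging Face save/load path, since from_pretrained() re-invokes tie_weights(). Below is a hedged round-trip sketch reusing the config values from the new test; the llmfoundry.models.mpt import path and the round trip itself are illustrative assumptions, not part of this commit.

# Hedged sketch: an untied MPT should keep its separate lm_head across a
# save/load round trip, because tie_weights() is now a no-op when
# tie_word_embeddings is False.
import tempfile

from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM

config = MPTConfig(init_device='cpu',
                   d_model=128,
                   n_heads=4,
                   n_layers=2,
                   expansion_ratio=2,
                   max_seq_len=2048,
                   attn_config={'attn_impl': 'torch'},
                   no_bias=True,
                   tie_word_embeddings=False)
model = MPTForCausalLM(config)

with tempfile.TemporaryDirectory() as tmp:
    model.save_pretrained(tmp)
    reloaded = MPTForCausalLM.from_pretrained(tmp)

assert reloaded.lm_head is not None  # head survives tie_weights() during load
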
1 change: 1 addition & 0 deletions tests/models/test_model.py
@@ -1222,6 +1222,7 @@ def test_generate_with_device_map(tmp_path: pathlib.Path, world_size: int,
         'transformer.blocks.0': 0,
         'transformer.blocks.1': 1 if world_size == 2 else 0,
         'transformer.norm_f': 1 if world_size == 2 else 0,
+        'lm_head': 1 if world_size == 2 else 0,
     }
 
     pipe = pipeline(
