Add Phi-3-mini-4k-instruct checkpoint #1341

Merged on Jul 1, 2024 (36 commits).
The diff below shows changes from 3 commits.

Commits:
93f3024  Add phi-3 checkpoint (rasbt, Apr 23, 2024)
1012eaf  progress (rasbt, Apr 23, 2024)
581e27f  weight loading works (rasbt, Apr 23, 2024)
40fe01f  Convert Phi3 qkv into an interleaved one (Andrei-Aksionov, Apr 25, 2024)
0322ecd  Config: Phi3 doesn't use parallel residual (Andrei-Aksionov, Apr 25, 2024)
7f33850  Fix layer shapes in Phi3MLP (Andrei-Aksionov, Apr 25, 2024)
1b217ba  Config: update vocab size (Andrei-Aksionov, Apr 25, 2024)
6fc4a7c  Add prompt (Andrei-Aksionov, Apr 25, 2024)
ba1c930  Merge branch 'main' into phi-3-checkpoint (Andrei-Aksionov, Apr 25, 2024)
2ee1e0d  Add test for Phi3 model (Andrei-Aksionov, Apr 25, 2024)
29760ab  Update litgpt/prompts.py (rasbt, May 3, 2024)
6c4cd25  Merge branch 'main' into phi-3-checkpoint (rasbt, May 3, 2024)
fbc45b4  Merge branch 'main' into phi-3-checkpoint (Andrei-Aksionov, Jun 26, 2024)
efb8388  Fix prompt (Andrei-Aksionov, Jun 26, 2024)
7f092fa  The prompt has been changed. Update it (Andrei-Aksionov, Jun 26, 2024)
a2acd37  A workaround for a Phi-3 tokenizer (Andrei-Aksionov, Jun 27, 2024)
ef21d37  Convert in copy_weihght_phi without Phi3MLP (Andrei-Aksionov, Jun 27, 2024)
aa184e7  Config: Phi3MLP -> LlaMAMLP (Andrei-Aksionov, Jun 27, 2024)
4f941bb  test_model.py: add test for Phi-3 (Andrei-Aksionov, Jun 27, 2024)
3bd0692  model.py: drop Phi3MLP (Andrei-Aksionov, Jun 27, 2024)
9583cd7  convert_hf: copy_weight_llama without Phi3 related code (Andrei-Aksionov, Jun 27, 2024)
1c661be  Merge branch 'main' into phi-3-checkpoint (Andrei-Aksionov, Jun 27, 2024)
c25e533  test_model.py: update test for Phi-3 (Andrei-Aksionov, Jun 27, 2024)
6e484eb  test_covert_hf: add test for qkv_reassemble (Andrei-Aksionov, Jun 27, 2024)
c8a1e03  Update test_tokenzer to match AutoTokenizers (Andrei-Aksionov, Jun 27, 2024)
39614e7  Merge branch 'main' into phi-3-checkpoint (Andrei-Aksionov, Jun 28, 2024)
81e56b6  convert_lit: add Phi-3 code (Andrei-Aksionov, Jun 28, 2024)
b483ca2  test_convert_lit: prettify test for qkv_split (Andrei-Aksionov, Jun 28, 2024)
82b4124  Update test_prompts.py (Andrei-Aksionov, Jun 28, 2024)
d252505  Add Phi-3-mini to the list of supported models (Andrei-Aksionov, Jun 28, 2024)
0eb288e  Update README.md (rasbt, Jun 28, 2024)
a2683d7  Merge branch 'main' into phi-3-checkpoint (rasbt, Jul 1, 2024)
a923c5e  Update tutorials/download_model_weights.md (rasbt, Jul 1, 2024)
bcbddec  Update tutorials/download_model_weights.md (rasbt, Jul 1, 2024)
9dc0330  Apply Sebastian's suggestion: model_name.lower()... (Andrei-Aksionov, Jul 1, 2024)
ceae946  Merge branch 'main' into phi-3-checkpoint (rasbt, Jul 1, 2024)
19 changes: 17 additions & 2 deletions litgpt/config.py
@@ -54,7 +54,7 @@ class Config:
shared_attention_norm: bool = False
norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
norm_eps: float = 1e-5
mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = "GptNeoxMLP"
mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE", "Phi3MLP"] = "GptNeoxMLP"
gelu_approximate: str = "none"
intermediate_size: Optional[int] = None
rope_condense_ratio: int = 1
@@ -836,7 +836,7 @@ def norm_class(self) -> Type:
copy["name"] = c["name"].format(kind)
copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
configs.append(copy)


###############
# Meta LLaMA 3
@@ -1413,6 +1413,21 @@ def norm_class(self) -> Type:
lm_head_bias=True,
gelu_approximate="tanh",
),
# https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/config.json
dict(
name="Phi-3-mini-4k-instruct",
hf_config=dict(org="microsoft", name="microsoft/Phi-3-mini-4k-instruct"),
vocab_size=32064,
padded_vocab_size=32064,
block_size=4096,
n_embd=3072,
n_layer=32,
rotary_percentage=1.0,
bias=False,
norm_class_name="RMSNorm",
intermediate_size=16384,
mlp_class_name="Phi3MLP",
),
]
configs.extend(phi)

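As a quick illustration (not part of the diff), the new entry can be resolved by name through litgpt's `Config.from_name` lookup; the printed values simply restate the dict added above. This is a sketch and assumes the `from_name` classmethod defined elsewhere in `litgpt/config.py`:

```python
# Sketch only: assumes litgpt is installed at this commit and that
# Config.from_name resolves config entries by their "name" field.
from litgpt.config import Config

config = Config.from_name("Phi-3-mini-4k-instruct")
print(config.mlp_class_name)                    # "Phi3MLP" at this point in the PR
print(config.n_embd, config.intermediate_size)  # 3072 16384
print(config.vocab_size, config.block_size)     # 32064 4096
```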
14 changes: 14 additions & 0 deletions litgpt/model.py
@@ -298,6 +298,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.proj(x)


class Phi3MLP(nn.Module):
def __init__(self, config: Config) -> None:
super().__init__()
self.gate_up_proj = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
self.down_proj = nn.Linear(config.intermediate_size//2, config.n_embd, bias=config.bias)
self.config = config

def forward(self, x: torch.Tensor) -> torch.Tensor:
y = self.gate_up_proj(x)
gate, y = y.chunk(2, dim=-1)
y = y * torch.nn.functional.silu(gate)
return self.down_proj(y)


class LLaMAMLP(nn.Module):
def __init__(self, config: Config) -> None:
super().__init__()
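For context on the shapes involved, here is a minimal, self-contained sketch (not part of the diff, with assumed toy batch and sequence sizes) of the data flow in `Phi3MLP`: `gate_up_proj` produces the gate and up halves in a single matmul, and `down_proj` maps the gated half back to `n_embd`. Dimensions follow the Phi-3-mini config above.

```python
# Standalone sketch of the Phi3MLP data flow; batch/sequence sizes are made up.
import torch
import torch.nn as nn
import torch.nn.functional as F

n_embd, intermediate_size = 3072, 16384          # from the Phi-3-mini-4k-instruct config
gate_up_proj = nn.Linear(n_embd, intermediate_size, bias=False)
down_proj = nn.Linear(intermediate_size // 2, n_embd, bias=False)

x = torch.randn(2, 8, n_embd)                    # (batch, seq, n_embd)
y = gate_up_proj(x)                              # (2, 8, 16384): fused [gate | up]
gate, up = y.chunk(2, dim=-1)                    # two (2, 8, 8192) halves, gate first
out = down_proj(up * F.silu(gate))               # SwiGLU-style gating, back to n_embd
print(out.shape)                                 # torch.Size([2, 8, 3072])
```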
47 changes: 29 additions & 18 deletions litgpt/scripts/convert_hf_checkpoint.py
@@ -121,6 +121,7 @@ def copy_weights_hf_llama(
"model.layers.{}.self_attn.q_proj.weight": None,
"model.layers.{}.self_attn.k_proj.weight": None,
"model.layers.{}.self_attn.v_proj.weight": None,
"model.layers.{}.self_attn.qkv_proj.weight": None,
"model.layers.{}.self_attn.o_proj.weight": "transformer.h.{l}.attn.proj.weight",
"model.layers.{}.self_attn.rotary_emb.inv_freq": None,
"model.layers.{}.post_attention_layernorm.weight": "transformer.h.{l}.norm_2.weight",
@@ -146,6 +147,13 @@ def copy_weights_hf_llama(
"model.layers.{}.mlp.down_proj.weight": "transformer.h.{l}.mlp.proj.weight",
}
)
elif config.mlp_class_name in ("Phi3MLP",):
weight_map.update(
{
"model.layers.{}.mlp.gate_up_proj.weight": "transformer.h.{l}.mlp.gate_up_proj.weight",
"model.layers.{}.mlp.down_proj.weight": "transformer.h.{l}.mlp.down_proj.weight",
}
)
else:
raise NotImplementedError

@@ -156,7 +164,9 @@ def copy_weights_hf_llama(
if "block_sparse_moe.experts" in name:
from_name, e = layer_template(from_name, 5)
qkv = qkv_weights.setdefault(l, [None, None, None])
if "q_proj" in name:
if "qkv_proj" in name:
state_dict[f"transformer.h.{l}.attn.attn.weight"] = load_param(param, f"layer {l} qkv", dtype)
elif "q_proj" in name:
qkv[0] = param
elif "k_proj" in name:
qkv[1] = param
state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"]

# convert separate q, k, v matrices into an interleaved qkv
for i, (q, k, v) in list(qkv_weights.items()):
if q is None or k is None or v is None:
# split across different .bin files
continue
q = load_param(q, f"layer {i} q", dtype)
k = load_param(k, f"layer {i} k", dtype)
v = load_param(v, f"layer {i} v", dtype)
q_per_kv = config.n_head // config.n_query_groups
qs = torch.split(q, config.head_size * q_per_kv)
ks = torch.split(k, config.head_size)
vs = torch.split(v, config.head_size)
cycled = [t for group in zip(qs, ks, vs) for t in group]
qkv = torch.cat(cycled)
state_dict[f"transformer.h.{i}.attn.attn.weight"] = qkv
del qkv_weights[i]
if "qkv_proj" not in name:
for i, (q, k, v) in list(qkv_weights.items()):
if q is None or k is None or v is None:
# split across different .bin files
continue
q = load_param(q, f"layer {i} q", dtype)
k = load_param(k, f"layer {i} k", dtype)
v = load_param(v, f"layer {i} v", dtype)
q_per_kv = config.n_head // config.n_query_groups
qs = torch.split(q, config.head_size * q_per_kv)
ks = torch.split(k, config.head_size)
vs = torch.split(v, config.head_size)
cycled = [t for group in zip(qs, ks, vs) for t in group]
qkv = torch.cat(cycled)
state_dict[f"transformer.h.{i}.attn.attn.weight"] = qkv
del qkv_weights[i]


def copy_weights_phi(
@@ -312,7 +323,7 @@ def convert_hf_checkpoint(

if "falcon" in model_name:
copy_fn = partial(copy_weights_falcon, model_name)
elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"):
elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE", "Phi3MLP"):
# holder to reconstitute the split q, k, v
qkv_weights = {}
copy_fn = partial(copy_weights_hf_llama, config, qkv_weights)
@@ -354,4 +365,4 @@ def convert_hf_checkpoint(
if __name__ == "__main__":
from jsonargparse import CLI

CLI(convert_hf_checkpoint)
CLI(convert_hf_checkpoint)
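To make the interleaving step in `copy_weights_hf_llama` concrete, here is a small toy example (separate from the conversion script, with made-up dimensions) showing how per-group q, k, and v blocks are cycled into a single interleaved `attn.attn.weight` matrix:

```python
# Toy reconstruction of the q/k/v interleaving above.
# Dimensions are assumed for illustration: 4 heads, 2 query groups, head_size 2.
import torch

n_head, n_query_groups, head_size, n_embd = 4, 2, 2, 8
q = torch.randn(n_head * head_size, n_embd)             # (8, 8)
k = torch.randn(n_query_groups * head_size, n_embd)     # (4, 8)
v = torch.randn(n_query_groups * head_size, n_embd)     # (4, 8)

q_per_kv = n_head // n_query_groups                      # 2 query heads per kv head
qs = torch.split(q, head_size * q_per_kv)                # per-group q blocks, 4 rows each
ks = torch.split(k, head_size)                           # per-group k blocks, 2 rows each
vs = torch.split(v, head_size)                           # per-group v blocks, 2 rows each
cycled = [t for group in zip(qs, ks, vs) for t in group] # q0, k0, v0, q1, k1, v1
qkv = torch.cat(cycled)                                  # (16, 8) interleaved weight
print(qkv.shape)                                         # torch.Size([16, 8])
```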
1 change: 1 addition & 0 deletions tutorials/download_model_weights.md
@@ -112,6 +112,7 @@ meta-llama/Meta-Llama-3-8B
meta-llama/Meta-Llama-3-8B-Instruct
microsoft/phi-1_5
microsoft/phi-2
microsoft/Phi-3-mini-4k-instruct
mistralai/Mistral-7B-Instruct-v0.1
mistralai/Mistral-7B-Instruct-v0.2
mistralai/Mistral-7B-v0.1