Add sliding window attention to Mistral and Phi 3 (#1741)
rasbt authored Sep 24, 2024
1 parent af11fb5 commit be8b28d
Showing 2 changed files with 78 additions and 6 deletions.
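For context on the feature being wired up in this commit: sliding window attention restricts each query token to attending over at most the last `sliding_window_size` positions instead of the full causal prefix, which bounds the memory and compute of the attention step for long sequences. The snippet below is only a minimal sketch of that masking rule with made-up sizes; it is not the litgpt implementation touched by this diff.

```python
import torch

# Minimal sketch of a sliding-window causal mask (illustrative sizes).
T, window = 8, 4

causal = torch.tril(torch.ones(T, T, dtype=torch.bool))             # j <= i
too_old = torch.tril(torch.ones(T, T, dtype=torch.bool), -window)   # j <= i - window
mask = causal & ~too_old                                            # i - window < j <= i

# Row i is True only for the `window` most recent key positions, so each query
# can no longer attend to tokens more than `window` steps back.
print(mask.int())
```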
6 changes: 6 additions & 0 deletions litgpt/config.py
@@ -1613,6 +1613,8 @@ def norm_class(self) -> Type:
intermediate_size=8192,
mlp_class_name="LLaMAMLP",
parallel_residual=False,
sliding_window_size=2048,
sliding_window_layer_placing="all",
),
# https://huggingface.co/microsoft/Phi-3.5-mini-instruct/blob/main/config.json
dict(
@@ -1654,6 +1656,8 @@ def norm_class(self) -> Type:
norm_eps=1e-05,
mlp_class_name="LLaMAMLP",
intermediate_size=14336,
sliding_window_size=4096,
sliding_window_layer_placing="all",
)
)

@@ -1673,6 +1677,8 @@ def norm_class(self) -> Type:
norm_eps=1e-05,
mlp_class_name="LLaMAMLP",
intermediate_size=14336,
sliding_window_size=4096,
sliding_window_layer_placing="all",
),
# https://huggingface.co/mistralai/Mixtral-8x7B-v0.1/blob/main/config.json
dict(
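The config.py hunks above add two fields to the affected Phi-3 and Mistral entries: `sliding_window_size` (2048 and 4096 in the hunks shown) and `sliding_window_layer_placing="all"`, which presumably applies the window to every transformer block rather than to a subset of layers. A quick way to inspect the new defaults is sketched below; the checkpoint name is taken from the tests in the next file, and the expected values are assumptions based on the hunks above (the surrounding entry names are collapsed in this view).

```python
# Illustrative check only; expected values are assumptions based on the diff above.
from litgpt.config import Config

cfg = Config.from_name("Mistral-7B-Instruct-v0.1")  # one of the checkpoints used in the tests below
print(cfg.sliding_window_size)            # expected: 4096
print(cfg.sliding_window_layer_placing)   # expected: "all"
```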
78 changes: 72 additions & 6 deletions tests/test_model.py
@@ -397,13 +397,78 @@ def test_against_hf_phi_3(model_name, device, dtype):
),
],
)
@pytest.mark.parametrize("model_name", ["Mistral-7B-Instruct-v0.1", "Mathstral-7B-v0.1"])
@pytest.mark.parametrize("model_name", ["Mistral-7B-Instruct-v0.1", "Mistral-7B-v0.1"])
def test_against_mistral_hf_models(device, dtype, model_name):
torch.set_default_dtype(dtype)

T = 20
ours_config = Config.from_name(
model_name,
padded_vocab_size=10000,
block_size=T,
sliding_window_size=T // 2,
n_layer=2,
n_embd=32,
n_head=8,
n_query_groups=2,
intermediate_size=86,
)

T = 5
theirs_config = MistralConfig(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=T,
rms_norm_eps=ours_config.norm_eps,
num_key_value_heads=ours_config.n_query_groups,
rope_theta=ours_config.rope_base,
attn_implementation="eager",
sliding_window=ours_config.sliding_window_size,
)

assert ours_config.intermediate_size == theirs_config.intermediate_size

theirs_model = MistralForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
state_dict = {}
copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
ours_model = GPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)

# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y)
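The rewritten test above now builds the litgpt model with `block_size=T` and `sliding_window_size=T // 2`, passes the same window to the Hugging Face `MistralConfig` via its `sliding_window` argument, and compares logits end to end against `MistralForCausalLM`. The small calculation below just spells out what a window of `T // 2` means at the largest configured context; it reuses the `T = 20` from the hunk above and is not part of the test.

```python
# Sliding-window arithmetic for the configuration above.
T = 20
window = T // 2                  # sliding_window_size = 10
last = T - 1                     # final query position of a full-length sequence
visible = [j for j in range(T) if last - window < j <= last]
print(visible)                   # [10, 11, ..., 19]: the first half of the context is masked out
```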


@torch.inference_mode()
@pytest.mark.parametrize(
("device", "dtype"),
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"),
torch.float16,
marks=[
# the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
# is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
RunIf(min_cuda_gpus=1),
],
),
],
)
def test_against_mathstral_hf_models(device, dtype):
torch.set_default_dtype(dtype)

ours_config = Config.from_name(
"Mathstral-7B-v0.1",
padded_vocab_size=10000,
n_layer=2,
n_embd=32,
n_head=8,
@@ -750,13 +815,14 @@ def assert_sdpa_backend(original_fn, q, k, v, mask):
args.append(False)
params = SDPAParams(q, k, v, mask, 0.0, True, *args)
if expected is SDPBackend.FLASH_ATTENTION:
assert flash_sdp_enabled()
assert can_use_flash_attention(params, True)
assert flash_sdp_enabled(), "flash_sdp_enabled() is False"
if config.sliding_window_size is None:
assert can_use_flash_attention(params, True), "can_use_flash_attention(params, True) is False"
elif expected is SDPBackend.EFFICIENT_ATTENTION:
assert mem_efficient_sdp_enabled()
assert can_use_efficient_attention(params, True)
assert mem_efficient_sdp_enabled(), "mem_efficient_sdp_enabled() is False"
assert can_use_efficient_attention(params, True), "can_use_efficient_attention(params, True) is False"
elif expected is SDPBackend.MATH:
assert math_sdp_enabled()
assert math_sdp_enabled(), "math_sdp_enabled() is False"
else:
raise NotImplementedError
return original_fn(q, k, v, mask)
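The final hunk relaxes the SDPA backend assertions: with a sliding window configured, the model presumably has to supply an explicit attention mask, and PyTorch's flash-attention backend is generally not eligible when an arbitrary dense mask is passed, so `can_use_flash_attention(params, True)` is now only asserted when `config.sliding_window_size is None`. The sketch below reproduces that effect with the public SDPA API; the shapes are arbitrary and this is not the code under test.

```python
import torch
import torch.nn.functional as F

# Arbitrary illustrative shapes.
B, H, T, D = 1, 8, 20, 16
window = 10
q = k = v = torch.randn(B, H, T, D)

# Plain causal attention: the flash kernel can take the is_causal fast path.
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Sliding-window attention needs an explicit boolean mask (True = may attend),
# which typically routes SDPA to the efficient/math kernels instead of flash.
causal = torch.tril(torch.ones(T, T, dtype=torch.bool))
mask = causal & ~torch.tril(torch.ones(T, T, dtype=torch.bool), -window)
out_windowed = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```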
