From be8b28dad46caee5dd62fedcd6051485e27c404b Mon Sep 17 00:00:00 2001
From: Sebastian Raschka
Date: Tue, 24 Sep 2024 13:47:23 -0700
Subject: [PATCH] Add sliding window attention to Mistral and Phi 3 (#1741)

---
 litgpt/config.py    |  6 ++++
 tests/test_model.py | 78 +++++++++++++++++++++++++++++++++++++++++----
 2 files changed, 78 insertions(+), 6 deletions(-)

diff --git a/litgpt/config.py b/litgpt/config.py
index 974cd21432..5052ab117b 100644
--- a/litgpt/config.py
+++ b/litgpt/config.py
@@ -1613,6 +1613,8 @@ def norm_class(self) -> Type:
         intermediate_size=8192,
         mlp_class_name="LLaMAMLP",
         parallel_residual=False,
+        sliding_window_size=2048,
+        sliding_window_layer_placing="all",
     ),
     # https://huggingface.co/microsoft/Phi-3.5-mini-instruct/blob/main/config.json
     dict(
@@ -1654,6 +1656,8 @@ def norm_class(self) -> Type:
         norm_eps=1e-05,
         mlp_class_name="LLaMAMLP",
         intermediate_size=14336,
+        sliding_window_size=4096,
+        sliding_window_layer_placing="all",
     )
 )
 
@@ -1673,6 +1677,8 @@ def norm_class(self) -> Type:
         norm_eps=1e-05,
         mlp_class_name="LLaMAMLP",
         intermediate_size=14336,
+        sliding_window_size=4096,
+        sliding_window_layer_placing="all",
     ),
     # https://huggingface.co/mistralai/Mixtral-8x7B-v0.1/blob/main/config.json
     dict(
diff --git a/tests/test_model.py b/tests/test_model.py
index d4a866ff36..50e09de909 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -397,13 +397,78 @@ def test_against_hf_phi_3(model_name, device, dtype):
         ),
     ],
 )
-@pytest.mark.parametrize("model_name", ["Mistral-7B-Instruct-v0.1", "Mathstral-7B-v0.1"])
+@pytest.mark.parametrize("model_name", ["Mistral-7B-Instruct-v0.1", "Mistral-7B-v0.1"])
 def test_against_mistral_hf_models(device, dtype, model_name):
     torch.set_default_dtype(dtype)
 
+    T = 20
     ours_config = Config.from_name(
         model_name,
         padded_vocab_size=10000,
+        block_size=T,
+        sliding_window_size=T // 2,
+        n_layer=2,
+        n_embd=32,
+        n_head=8,
+        n_query_groups=2,
+        intermediate_size=86,
+    )
+
+    theirs_config = MistralConfig(
+        vocab_size=ours_config.padded_vocab_size,
+        hidden_size=ours_config.n_embd,
+        num_attention_heads=ours_config.n_head,
+        num_hidden_layers=ours_config.n_layer,
+        intermediate_size=ours_config.intermediate_size,
+        max_position_embeddings=T,
+        rms_norm_eps=ours_config.norm_eps,
+        num_key_value_heads=ours_config.n_query_groups,
+        rope_theta=ours_config.rope_base,
+        attn_implementation="eager",
+        sliding_window=ours_config.sliding_window_size,
+    )
+
+    assert ours_config.intermediate_size == theirs_config.intermediate_size
+
+    theirs_model = MistralForCausalLM(theirs_config).to(device)
+    theirs_state_dict = theirs_model.state_dict()
+    state_dict = {}
+    copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
+    ours_model = GPT(ours_config).to(device)
+    ours_model.load_state_dict(state_dict)
+
+    # test end to end
+    # use an input longer than the sliding window so the window actually masks past tokens
+    x = torch.tensor([[9856, 23, 491, 1536, 304] * 4], dtype=torch.int32, device=device)
+    assert x.size(1) == T
+    ours_y = ours_model(x)
+    theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
+    torch.testing.assert_close(ours_y, theirs_y)
+
+
+@torch.inference_mode()
+@pytest.mark.parametrize(
+    ("device", "dtype"),
+    [
+        (torch.device("cpu"), torch.float32),
+        pytest.param(
+            torch.device("cuda"),
+            torch.float16,
+            marks=[
+                # the reference does softmax upscaled to fp32 during attention.
+                # additionally, the final layernorm input is slightly different
+                pytest.mark.xfail(raises=AssertionError, strict=False),
+                RunIf(min_cuda_gpus=1),
+            ],
+        ),
+    ],
+)
+def test_against_mathstral_hf_models(device, dtype):
+    torch.set_default_dtype(dtype)
+
+    ours_config = Config.from_name(
+        "Mathstral-7B-v0.1",
+        padded_vocab_size=10000,
         n_layer=2,
         n_embd=32,
         n_head=8,
@@ -750,13 +815,14 @@ def assert_sdpa_backend(original_fn, q, k, v, mask):
             args.append(False)
         params = SDPAParams(q, k, v, mask, 0.0, True, *args)
         if expected is SDPBackend.FLASH_ATTENTION:
-            assert flash_sdp_enabled()
-            assert can_use_flash_attention(params, True)
+            assert flash_sdp_enabled(), "flash_sdp_enabled() is False"
+            if config.sliding_window_size is None:
+                assert can_use_flash_attention(params, True), "can_use_flash_attention(params, True) is False"
         elif expected is SDPBackend.EFFICIENT_ATTENTION:
-            assert mem_efficient_sdp_enabled()
-            assert can_use_efficient_attention(params, True)
+            assert mem_efficient_sdp_enabled(), "mem_efficient_sdp_enabled() is False"
+            assert can_use_efficient_attention(params, True), "can_use_efficient_attention(params, True) is False"
         elif expected is SDPBackend.MATH:
-            assert math_sdp_enabled()
+            assert math_sdp_enabled(), "math_sdp_enabled() is False"
         else:
             raise NotImplementedError
         return original_fn(q, k, v, mask)
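
For context, the masking rule that the new sliding_window_size option expresses can be sketched in a few lines of standalone PyTorch. The snippet below is an illustrative sketch rather than the litgpt implementation; the helper name causal_mask is made up for this example, and it assumes the common convention that each query attends to at most sliding_window_size positions, itself included:

    from typing import Optional

    import torch


    def causal_mask(T: int, sliding_window_size: Optional[int] = None) -> torch.Tensor:
        # True where a query position (row) is allowed to attend to a key position (column)
        mask = torch.tril(torch.ones(T, T, dtype=torch.bool))  # plain causal mask
        if sliding_window_size is not None:
            # additionally drop keys that are sliding_window_size or more positions behind
            # the query, so every row keeps at most sliding_window_size True entries
            mask &= ~torch.tril(torch.ones(T, T, dtype=torch.bool), diagonal=-sliding_window_size)
        return mask


    # with the shapes used in the updated Mistral test (T = 20, window = T // 2), the last
    # query can see all 20 keys without a window but only the 10 most recent ones with it
    print(causal_mask(20)[-1].sum().item())      # 20
    print(causal_mask(20, 10)[-1].sum().item())  # 10

With those shapes the window genuinely restricts attention, which is what the end-to-end comparison against MistralForCausalLM is meant to exercise. The accompanying sliding_window_layer_placing="all" setting applies the window in every transformer block, which matches how the Mistral 7B and Phi-3 reference models use their window; the separate placement option presumably exists because some architectures, such as Gemma 2, apply the window only on alternating layers.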