Extend context length for sliding window tests (#1742)
rasbt authored Sep 25, 2024
1 parent be8b28d commit f9aed64
Showing 1 changed file with 3 additions and 3 deletions.
tests/test_model.py: 6 changes (3 additions & 3 deletions)
@@ -407,21 +407,21 @@ def test_against_mistral_hf_models(device, dtype, model_name):
        padded_vocab_size=10000,
        block_size=T,
        sliding_window_size=T // 2,
        sliding_window_layer_placing="all",
        n_layer=2,
        n_embd=32,
        n_head=8,
        n_query_groups=2,
        intermediate_size=86,
    )

    T = 5
    theirs_config = MistralConfig(
        vocab_size=ours_config.padded_vocab_size,
        hidden_size=ours_config.n_embd,
        num_attention_heads=ours_config.n_head,
        num_hidden_layers=ours_config.n_layer,
        intermediate_size=ours_config.intermediate_size,
-       max_position_embeddings=T,
+       max_position_embeddings=ours_config.block_size,
        rms_norm_eps=ours_config.norm_eps,
        num_key_value_heads=ours_config.n_query_groups,
        rope_theta=ours_config.rope_base,
@@ -439,7 +439,7 @@ def test_against_mistral_hf_models(device, dtype, model_name):
    ours_model.load_state_dict(state_dict)

    # test end to end
-   x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
+   x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)
    assert x.size(1) == T
    ours_y = ours_model(x)
    theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
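The replacement input samples token ids uniformly from the padded vocabulary, so the prompt is valid for any vocab size instead of relying on a hard-coded list of ids. As a purely illustrative, self-contained sketch of this parity-test pattern (the names and the tiny modules below are made up for the example; they are not the models used in the test), the same random ids are run through two implementations of the same computation and the outputs are compared with torch.testing.assert_close:

import torch

# Hypothetical stand-ins: an Embedding plus a Linear head, written two ways.
torch.manual_seed(0)
vocab_size, n_embd, T = 10_000, 32, 5

emb = torch.nn.Embedding(vocab_size, n_embd)
head = torch.nn.Linear(n_embd, vocab_size, bias=False)

def ours(x):
    # "our" implementation: plain module calls
    return head(emb(x))

def theirs(x):
    # "their" implementation: the same math written by hand
    return emb.weight[x] @ head.weight.T

# Random ids are always within [0, vocab_size), unlike a hard-coded prompt;
# unsqueeze(0) adds the batch dimension, giving shape (1, T).
x = torch.randint(low=0, high=vocab_size, size=(T,)).unsqueeze(0)
torch.testing.assert_close(ours(x), theirs(x))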
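For readers unfamiliar with the feature under test: sliding-window attention restricts each position to the most recent sliding_window_size keys on top of the usual causal mask, and the commit ties the Hugging Face config's max_position_embeddings to ours_config.block_size so both sides are configured with the same context length. Below is a minimal sketch of such a mask, using the same T and T // 2 window as the test; it follows one common convention in which the window includes the current position, and exact boundary handling can differ between implementations.

import torch

T = 5
window = T // 2  # mirrors sliding_window_size=T // 2 in the config above

q = torch.arange(T).unsqueeze(1)  # query positions (rows)
k = torch.arange(T).unsqueeze(0)  # key positions (columns)
# Causal (no attending to the future) AND within the last `window` positions.
mask = (k <= q) & (k > q - window)

print(mask.long())
# tensor([[1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [0, 1, 1, 0, 0],
#         [0, 0, 1, 1, 0],
#         [0, 0, 0, 1, 1]])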
