diff --git a/tests/test_hf_mpt_gen.py b/tests/test_hf_mpt_gen.py index a2c5bf7f98..bd07f61636 100644 --- a/tests/test_hf_mpt_gen.py +++ b/tests/test_hf_mpt_gen.py @@ -134,7 +134,7 @@ def test_mpt_generate_callback(callback_generate: Any, tmpdir: Path): # build mpt model model_config = DictConfig({ - 'name': 'hf_causal_lm', + 'name': 'mpt_causal_lm', 'pretrained_model_name_or_path': 'mosaicml/mpt-7b', 'pretrained': False, 'config_overrides': { @@ -143,8 +143,7 @@ def test_mpt_generate_callback(callback_generate: Any, tmpdir: Path): 'n_layers': 2, 'expansion_ratio': 2, 'attn_config': { - 'attn_impl': 'torch', - 'attn_uses_sequence_id': True, + 'attn_impl': 'triton', }, }, })