diff --git a/pyproject.toml b/pyproject.toml
index 46f50ec0a8..28de271755 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -34,7 +34,7 @@ test = [
     "transformers>=4.38.0",  # numerical comparisons
     "einops>=0.7.0",
     "protobuf>=4.23.4",
-    "lightning-thunder==0.2.0.dev20240623; python_version >= '3.10'",
+    "lightning-thunder @ git+https://github.com/Lightning-AI/lightning-thunder/ ; python_version >= '3.10' and sys_platform == 'linux'",
 ]
 all = [
     "bitsandbytes==0.42.0",  # quantization
diff --git a/tests/test_model.py b/tests/test_model.py
index 918c34ac97..5e9dc943f4 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -742,7 +742,11 @@ def test_sdpa_choice(config):
     torch.set_default_dtype(torch.float16)
 
     def assert_sdpa_backend(original_fn, q, k, v, mask):
-        params = SDPAParams(q, k, v, mask, 0.0, True)
+        # SDPAParams gained an additional argument in PyTorch 2.5
+        args = []
+        if hasattr(SDPAParams, "enable_gqa"):
+            args.append(False)
+        params = SDPAParams(q, k, v, mask, 0.0, True, *args)
         if expected is SDPBackend.FLASH_ATTENTION:
             assert flash_sdp_enabled()
             assert can_use_flash_attention(params, True)
@@ -786,7 +790,11 @@ def test_sdpa_choice_kv_cache(config):
     torch.set_default_dtype(torch.float16)
 
     def assert_sdpa_backend(original_fn, q, k, v, mask):
-        params = SDPAParams(q, k, v, mask, 0.0, True)
+        # SDPAParams gained an additional argument in PyTorch 2.5
+        args = []
+        if hasattr(SDPAParams, "enable_gqa"):
+            args.append(False)
+        params = SDPAParams(q, k, v, mask, 0.0, True, *args)
         if expected is SDPBackend.FLASH_ATTENTION:
            assert flash_sdp_enabled()
             assert can_use_flash_attention(params, True)
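
For context on the PyTorch 2.5 compatibility check above, here is a minimal, self-contained sketch of the same pattern, assuming only torch.backends.cuda.SDPAParams; the make_sdpa_params helper name is hypothetical and not part of this change.

from torch.backends.cuda import SDPAParams


def make_sdpa_params(q, k, v, mask, dropout_p=0.0, is_causal=True):
    # PyTorch 2.5 added a trailing `enable_gqa` field to SDPAParams, so the
    # constructor takes one extra positional argument on newer versions.
    extra = []
    if hasattr(SDPAParams, "enable_gqa"):  # the field only exists on PyTorch >= 2.5
        extra.append(False)  # leave grouped-query attention off, matching older behavior
    return SDPAParams(q, k, v, mask, dropout_p, is_causal, *extra)

Probing the class with hasattr keeps the tests working on both older and newer PyTorch releases without parsing version strings.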