diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 6b153fd3c1..7a401629f3 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -7951,7 +7951,7 @@ def forward( assert ( key_layer.shape[-2] == self.num_gqa_groups_per_partition and value_layer.shape[-2] == self.num_gqa_groups_per_partition - ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!" + ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups_per_partition} heads!" assert qkv_format in [ "sbhd", "bshd",