From 56853fe65421d11c904fb6266562cad4a17a4022 Mon Sep 17 00:00:00 2001 From: Tensor Templar Date: Sat, 5 Oct 2024 17:29:42 +0300 Subject: [PATCH] Give up and rewrite build_rope_cache with vectorization instead. Add clarification comment in test --- litgpt/model.py | 49 ++++++++++++++------------------------------- tests/test_utils.py | 2 +- 2 files changed, 16 insertions(+), 35 deletions(-) diff --git a/litgpt/model.py b/litgpt/model.py index 8d9c245580..7306e54b5d 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -444,11 +444,8 @@ def build_rope_cache( condense_ratio: int = 1, extra_config: Optional[dict] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - """Enhanced Transformer with Rotary Position Embedding. - - Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ - transformers/rope/__init__.py. MIT License: - https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. + """ + Enhanced Transformer with Rotary Position Embedding. Args: seq_len (int): Sequence length. @@ -462,8 +459,9 @@ def build_rope_cache( Tuple[torch.Tensor, torch.Tensor]: Cosine and sine caches for RoPE. """ if device is None: - print("warning: build_rope_cache called without device, meta device custom ops may fail") - # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ assert n_elem % 2 == 0, "n_elem (head dimension) must be even" + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + # Compute the inverse frequencies theta theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem)) if extra_config is not None: @@ -480,39 +478,22 @@ def build_rope_cache( # Compute wavelengths corresponding to the inverse frequencies wavelen = 2 * torch.pi / theta - # Initialize adjusted inverse frequencies - adjusted_theta = theta.clone() + # Compute ratio across all elements + ratio = orig_context_len / wavelen - # Low Frequency Region: wavelen > low_freq_wavelen - mask_low_freq = wavelen > low_freq_wavelen - # avoid NotImplementedError: aten::nonzero: attempted to run this operator with Meta tensors - if device is not None: - adjusted_theta[mask_low_freq] = theta[mask_low_freq] / factor - else: - adjusted_theta = torch.where( - mask_low_freq, - theta / factor, - adjusted_theta - ) - print(f"theta device: {theta.device}") - print(f"mask_low_freq device: {mask_low_freq.device}") - - # Medium Frequency Region: high_freq_wavelen ≤ wavelen ≤ low_freq_wavelen - mask_medium_freq = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen) - # Compute smooth factor for medium frequencies - ratio = orig_context_len / wavelen[mask_medium_freq] + # Compute smooth_factor and clamp between 0 and 1 smooth_factor = (ratio - low_freq_factor) / (high_freq_factor - low_freq_factor) - # Interpolate inverse frequencies - adjusted_theta[mask_medium_freq] = ( - (1 - smooth_factor) * (theta[mask_medium_freq] / factor) - + smooth_factor * theta[mask_medium_freq] - ) + smooth_factor = torch.clamp(smooth_factor, min=0.0, max=1.0) + + # Compute adjusted_theta without masked indexing + adjusted_theta = (1 - smooth_factor) * (theta / factor) + smooth_factor * theta + theta = adjusted_theta - # Create position indexes `[0, 1, ..., seq_len - 1]` + # Create position indices `[0, 1, ..., seq_len - 1]` seq_idx = torch.arange(seq_len, device=device) / condense_ratio - # Calculate the product of position index and $\theta_i$ + # Calculate the product of position index and θ_i idx_theta = torch.outer(seq_idx, theta).repeat(1, 2) return torch.cos(idx_theta), torch.sin(idx_theta) diff --git a/tests/test_utils.py b/tests/test_utils.py index da98e3e81a..12929df3c2 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -701,7 +701,7 @@ def test_check_nvlink_connectivity__returns_fully_connected_when_nvidia_all_nvli def test_check_nvlink_connectivity_returns_fully_connected_when_amd_all_xgmi(monkeypatch): # Mock the GPU device properties to simulate AMD GPUs mock_device_properties = mock.MagicMock(name="GPU Device", spec=["name"]) - mock_device_properties.name = "Advanced Micro Devices [AMD/ATI] MI250X" + mock_device_properties.name = "amd instinct mi250x" # ROCM 6.0.3 monkeypatch.setattr(torch.cuda, "get_device_properties", lambda idx: mock_device_properties) monkeypatch.setattr(torch.cuda, "is_available", lambda: True)