Give up and rewrite build_rope_cache with vectorization instead. Add clarification comment in test
TensorTemplar committed Oct 5, 2024
1 parent 09df0a7 commit 56853fe
Showing 2 changed files with 16 additions and 35 deletions.
49 changes: 15 additions & 34 deletions litgpt/model.py
@@ -444,11 +444,8 @@ def build_rope_cache(
     condense_ratio: int = 1,
     extra_config: Optional[dict] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Enhanced Transformer with Rotary Position Embedding.
-    Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
-    transformers/rope/__init__.py. MIT License:
-    https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
-    """
+    """
+    Enhanced Transformer with Rotary Position Embedding.

     Args:
         seq_len (int): Sequence length.
@@ -462,8 +459,9 @@ def build_rope_cache(
         Tuple[torch.Tensor, torch.Tensor]: Cosine and sine caches for RoPE.
     """
     if device is None:
         print("warning: build_rope_cache called without device, meta device custom ops may fail")
         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

-    # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
+    assert n_elem % 2 == 0, "n_elem (head dimension) must be even"
+    # Compute the inverse frequencies theta
     theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))

     if extra_config is not None:
@@ -480,39 +478,22 @@ def build_rope_cache(
     # Compute wavelengths corresponding to the inverse frequencies
     wavelen = 2 * torch.pi / theta

-    # Initialize adjusted inverse frequencies
-    adjusted_theta = theta.clone()
-
-    # Low Frequency Region: wavelen > low_freq_wavelen
-    mask_low_freq = wavelen > low_freq_wavelen
-    # avoid NotImplementedError: aten::nonzero: attempted to run this operator with Meta tensors
-    if device is not None:
-        adjusted_theta[mask_low_freq] = theta[mask_low_freq] / factor
-    else:
-        adjusted_theta = torch.where(
-            mask_low_freq,
-            theta / factor,
-            adjusted_theta
-        )
-    print(f"theta device: {theta.device}")
-    print(f"mask_low_freq device: {mask_low_freq.device}")
-
-    # Medium Frequency Region: high_freq_wavelen ≤ wavelen ≤ low_freq_wavelen
-    mask_medium_freq = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)
-    # Compute smooth factor for medium frequencies
-    ratio = orig_context_len / wavelen[mask_medium_freq]
-    smooth_factor = (ratio - low_freq_factor) / (high_freq_factor - low_freq_factor)
-    # Interpolate inverse frequencies
-    adjusted_theta[mask_medium_freq] = (
-        (1 - smooth_factor) * (theta[mask_medium_freq] / factor)
-        + smooth_factor * theta[mask_medium_freq]
-    )
+    # Compute ratio across all elements
+    ratio = orig_context_len / wavelen
+    # Compute smooth_factor and clamp between 0 and 1
+    smooth_factor = (ratio - low_freq_factor) / (high_freq_factor - low_freq_factor)
+    smooth_factor = torch.clamp(smooth_factor, min=0.0, max=1.0)
+    # Compute adjusted_theta without masked indexing
+    adjusted_theta = (1 - smooth_factor) * (theta / factor) + smooth_factor * theta

     theta = adjusted_theta

-    # Create position indexes `[0, 1, ..., seq_len - 1]`
+    # Create position indices `[0, 1, ..., seq_len - 1]`
     seq_idx = torch.arange(seq_len, device=device) / condense_ratio

-    # Calculate the product of position index and $\theta_i$
+    # Calculate the product of position index and θ_i
     idx_theta = torch.outer(seq_idx, theta).repeat(1, 2)

     return torch.cos(idx_theta), torch.sin(idx_theta)
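
The removed `if device is not None` branch existed to dodge a meta-device limitation: assigning through a boolean mask dispatches to aten::nonzero, which needs real data and is unsupported on meta tensors, exactly as the deleted comment notes. A minimal sketch of the failure mode and the shape-static alternative (variable names here are illustrative, not from the commit):

    import torch

    theta = torch.ones(8, device="meta")  # meta tensors carry shape/dtype but no storage
    mask = theta > 0.5                    # comparisons are fine: the output shape is static

    # theta[mask] = 0.0  # would raise NotImplementedError: aten::nonzero on meta tensors
    adjusted = torch.where(mask, theta / 2, theta)  # shape-static, runs on meta tensors
    print(adjusted.device)  # meta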
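The clamp-based rewrite reproduces the old three-region logic: for long wavelengths the smooth factor clamps to 0, giving theta / factor; for short wavelengths it clamps to 1, leaving theta unchanged; in between it interpolates exactly as before. A quick CPU check, as a sketch: the helper names and the Llama-3.1-style values (factor=8, low_freq_factor=1, high_freq_factor=4, orig_context_len=8192) are assumptions for illustration, not taken from the commit.

    import torch

    def masked_adjust(theta, factor, low_freq_factor, high_freq_factor, orig_context_len):
        # Old approach: piecewise adjustment via boolean-mask indexing.
        wavelen = 2 * torch.pi / theta
        low_freq_wavelen = orig_context_len / low_freq_factor
        high_freq_wavelen = orig_context_len / high_freq_factor
        adjusted = theta.clone()
        mask_low = wavelen > low_freq_wavelen
        adjusted[mask_low] = theta[mask_low] / factor
        mask_mid = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)
        ratio = orig_context_len / wavelen[mask_mid]
        s = (ratio - low_freq_factor) / (high_freq_factor - low_freq_factor)
        adjusted[mask_mid] = (1 - s) * (theta[mask_mid] / factor) + s * theta[mask_mid]
        return adjusted

    def vectorized_adjust(theta, factor, low_freq_factor, high_freq_factor, orig_context_len):
        # New approach: clamping the smooth factor makes the low/high regions
        # fall out of one interpolation formula, with no data-dependent indexing.
        wavelen = 2 * torch.pi / theta
        ratio = orig_context_len / wavelen
        s = torch.clamp((ratio - low_freq_factor) / (high_freq_factor - low_freq_factor), 0.0, 1.0)
        return (1 - s) * (theta / factor) + s * theta

    theta = 1.0 / (500000 ** (torch.arange(0, 128, 2).float() / 128))
    a = masked_adjust(theta, 8.0, 1.0, 4.0, 8192)
    b = vectorized_adjust(theta, 8.0, 1.0, 4.0, 8192)
    assert torch.allclose(a, b)  # both paths agree across all three frequency regions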
2 changes: 1 addition & 1 deletion tests/test_utils.py
@@ -701,7 +701,7 @@ def test_check_nvlink_connectivity__returns_fully_connected_when_nvidia_all_nvli
 def test_check_nvlink_connectivity_returns_fully_connected_when_amd_all_xgmi(monkeypatch):
     # Mock the GPU device properties to simulate AMD GPUs
     mock_device_properties = mock.MagicMock(name="GPU Device", spec=["name"])
-    mock_device_properties.name = "Advanced Micro Devices [AMD/ATI] MI250X"
+    mock_device_properties.name = "amd instinct mi250x"  # ROCM 6.0.3
     monkeypatch.setattr(torch.cuda, "get_device_properties", lambda idx: mock_device_properties)
     monkeypatch.setattr(torch.cuda, "is_available", lambda: True)

