From 1abfa02441e4fa4e97a4e5c566e1fc6120f5a927 Mon Sep 17 00:00:00 2001
From: Thomas Viehmann
Date: Thu, 29 Aug 2024 12:40:16 +0200
Subject: [PATCH] add support for batched input_pos to model

---
 litgpt/model.py     | 60 +++++++++++++++++++++++++++++++++++++++++----
 tests/test_batch.py | 36 +++++++++++++++++++++++++++
 2 files changed, 91 insertions(+), 5 deletions(-)

diff --git a/litgpt/model.py b/litgpt/model.py
index c7aa2b649e..d42886f27e 100644
--- a/litgpt/model.py
+++ b/litgpt/model.py
@@ -76,11 +76,15 @@ def forward(self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None) -
             raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.")
 
         if input_pos is not None:  # use the kv cache
-            cos = self.cos.index_select(0, input_pos)
-            sin = self.sin.index_select(0, input_pos)
+            cos = batched_index_select(self.cos, 0, input_pos)
+            sin = batched_index_select(self.sin, 0, input_pos)
             if self.mask_cache is None:
                 raise TypeError("You need to call `gpt.set_kv_cache()`")
-            mask = self.mask_cache.index_select(2, input_pos)
+            mask = batched_index_select(self.mask_cache, 2, input_pos)
+            if mask.dim() > 4:
+                # the mask cache has a batch dim of 1 in addition to the one
+                # we get if input_pos has a batch dimension
+                mask = mask.squeeze(1)
         else:
             cos = self.cos[:T]
             sin = self.sin[:T]
@@ -425,11 +429,57 @@ def build_rope_cache(
     return torch.cos(idx_theta), torch.sin(idx_theta)
 
 
+def batched_index_select(t, dim, idx):
+    """index_select for a batched index and an unbatched t"""
+    if idx.dim() == 1:
+        return torch.index_select(t, dim, idx)
+
+    *batch_shape, idx_size = idx.shape
+    res = torch.index_select(t, dim, idx.reshape(-1))  # flat index
+    # split the flat index back into the batch and index dimensions
+    res = res.view(*t.shape[:dim], -1, idx_size, *t.shape[dim + 1 :])
+    # move the batch dim to the front, this is np.rollaxis(res, dim, 0) for tensors
+    dims = [dim] + list(range(res.dim()))
+    del dims[dim + 1]
+    res = res.permute(dims)
+    # unflatten the batch dims
+    res = res.view(*batch_shape, *res.shape[1:])
+    return res
+
+
+def batched_index_copy_(t, dim, idx, val):
+    """index_copy_ for batched t, idx, val"""
+    if idx.dim() == 1:
+        return t.index_copy_(dim, idx, val)
+
+    assert idx.dim() == 2, f"multiple batch dims not yet supported: {idx.shape=}"
+    assert dim != 0, "cannot index the batch dim"
+    batch_size, idx_size = idx.shape
+    assert batch_size == t.size(0)
+    assert batch_size == val.size(0)
+    t_indexed_dim = t.size(dim)
+
+    # if we could view the batch and indexed dimensions together, we could
+    # do index trickery. This is, sadly, not the case for the kv cache, so we
+    # fall back to a for loop
+    for i in range(batch_size):
+        unbatched_dim = dim if dim < 0 else dim - 1
+        t[i].index_copy_(unbatched_dim, idx[i], val[i])
+    return t
+
+
 def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
     head_size = x.size(-1)
     x1 = x[..., : head_size // 2]  # (B, nh, T, hs/2)
     x2 = x[..., head_size // 2 :]  # (B, nh, T, hs/2)
     rotated = torch.cat((-x2, x1), dim=-1)  # (B, nh, T, hs)
+    if cos.dim() > 1:
+        # batch dimensions must align
+        # sin/cos are (B, T, hs) so we unsqueeze -3 for nh
+        # we count from the back because all of apply_rope does
+        cos = cos.unsqueeze(-3)
+        sin = sin.unsqueeze(-3)
+
     roped = (x * cos) + (rotated * sin)
     return roped.to(dtype=x.dtype)
 
@@ -452,8 +502,8 @@ def forward(self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor) ->
         self.v = self.v.to(v.dtype)
         # update the cache
         n = k.size(0)
-        k = self.k[:n, ...].index_copy_(2, input_pos, k)
-        v = self.v[:n, ...].index_copy_(2, input_pos, v)
+        k = batched_index_copy_(self.k[:n, ...], -2, input_pos, k)
+        v = batched_index_copy_(self.v[:n, ...], -2, input_pos, v)
         return k, v
 
     def reset_parameters(self) -> None:
diff --git a/tests/test_batch.py b/tests/test_batch.py
index e59a7a0eb5..b845c741b1 100644
--- a/tests/test_batch.py
+++ b/tests/test_batch.py
@@ -63,3 +63,39 @@ def test_batched_equivalence(tmp_path):
     # Assert that single and batched next token generation are equivalent
     assert all(t == tok_1 for t in toks_1), f"{tok_1} != {toks_1}"
     assert all(t == tok_2 for t in toks_2), f"{tok_2} != {toks_2}"
+
+
+@RunIf(min_cuda_gpus=1)
+def test_simple_batch():
+    config = Config.from_name(
+        "Llama-3.1-8B", padded_vocab_size=10000, n_layer=2, n_head=8, n_embd=256
+    )
+    with torch.device("cuda"):
+        m = GPT(config).requires_grad_(False).eval()
+        x0 = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 7]])
+        input_pos0 = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 2]])
+        x1 = torch.tensor([[1], [2]])
+        input_pos1 = torch.tensor([[4], [3]])
+
+    with torch.device("cuda"):
+        m.set_kv_cache(2)
+    outs0 = m(x0, input_pos0)
+    outs1 = m(x1, input_pos1)
+
+    with torch.device("cuda"):
+        m.set_kv_cache(1)
+
+    outs0_ref0 = m(x0[:1], input_pos0[0])
+    outs1_ref0 = m(x1[:1], input_pos1[0])
+
+    with torch.device("cuda"):
+        m.set_kv_cache(1)
+
+    outs0_ref1 = m(x0[1:], input_pos0[1])
+    outs1_ref1 = m(x1[1:], input_pos1[1])
+
+    outs_ref0 = torch.cat([outs0_ref0, outs0_ref1])
+    outs_ref1 = torch.cat([outs1_ref0, outs1_ref1])
+
+    torch.testing.assert_close(outs0, outs_ref0)
+    torch.testing.assert_close(outs1, outs_ref1)
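
Notes on the new helpers (illustrative sketches, not part of the patch; they
assume litgpt/model.py with this patch applied is importable).

batched_index_select generalizes torch.index_select to a batched index: for a
2-d idx it should agree with one index_select per batch row. A minimal check
of that intended semantics:

    import torch
    from litgpt.model import batched_index_select

    t = torch.randn(10, 4)                      # e.g. a rope cache, shape (T, hs)
    idx = torch.tensor([[0, 1, 2], [3, 3, 5]])  # batched positions, shape (B, n)

    out = batched_index_select(t, 0, idx)       # -> (B, n, hs)
    ref = torch.stack([torch.index_select(t, 0, row) for row in idx])
    torch.testing.assert_close(out, ref)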
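
batched_index_copy_ likewise should behave like one in-place index_copy_ per
batch row; the kv cache update above calls it with dim=-2, the cache's
sequence dimension. A small sketch with an assumed cache shape of
(B, nh, cache_len, hs):

    import torch
    from litgpt.model import batched_index_copy_

    t = torch.zeros(2, 8, 6, 4)     # stand-in kv cache, (B, nh, cache_len, hs)
    idx = torch.tensor([[4], [3]])  # one new position per sequence
    val = torch.ones(2, 8, 1, 4)

    batched_index_copy_(t, -2, idx, val)
    assert t[0, :, 4].eq(1).all() and t[1, :, 3].eq(1).all()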
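
The apply_rope change is pure reshaping: with a batched input_pos, cos/sin
arrive as (B, T, hs) while x is (B, nh, T, hs), so unsqueezing at -3 inserts
the head dimension and lets broadcasting align the batches:

    import torch

    B, nh, T, hs = 2, 8, 4, 16
    x = torch.randn(B, nh, T, hs)
    cos = torch.randn(B, T, hs)  # one row of rope values per sequence
    # (B, T, hs) -> (B, 1, T, hs): broadcasts over the nh dimension of x
    assert (x * cos.unsqueeze(-3)).shape == (B, nh, T, hs)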