
Why does the KV Cache implementation concatenate a zeros matrix to xq? #18

Open
Zbaoli opened this issue Sep 13, 2024 · 1 comment

Labels
question (Further information is requested)

Comments

Zbaoli commented Sep 13, 2024

In the Attention method, the part that implements the KV cache:

past_key, past_value = past_kv
xq = torch.cat((torch.zeros_like(x[:, :-1, :]), self.wq(current_token)), dim=1)
xk = torch.cat((past_key, self.wk(current_token)), dim=1)
xv = torch.cat((past_value, self.wv(current_token)), dim=1)

Why does xq need to be concatenated with a zeros matrix?
Is it just to keep xq's seqlen dimension the same as xk and xv? That adds extra computation, and even if the dimensions differ, the subsequent operations should still run correctly.
For example, the LLaMA 3 implementation does not concatenate a zeros matrix.
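
A minimal standalone sketch (made-up shapes, not this repository's code) showing that attention still runs when q keeps only the current token while k and v carry the full cached length:

import math
import torch

bsz, n_heads, head_dim, cache_len = 1, 8, 32, 20
q = torch.randn(bsz, n_heads, 1, head_dim)          # only the current token's query
k = torch.randn(bsz, n_heads, cache_len, head_dim)  # keys for the whole prefix
v = torch.randn(bsz, n_heads, cache_len, head_dim)  # values for the whole prefix

scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)  # (1, 8, 1, 20)
out = torch.softmax(scores, dim=-1) @ v                 # (1, 8, 1, 32)
print(scores.shape, out.shape)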

jingyaogong (Owner) commented Sep 13, 2024

In the Attention method, the part that implements the KV cache:

past_key, past_value = past_kv
xq = torch.cat((torch.zeros_like(x[:, :-1, :]), self.wq(current_token)), dim=1)
xk = torch.cat((past_key, self.wk(current_token)), dim=1)
xv = torch.cat((past_value, self.wv(current_token)), dim=1)

Why does xq need to be concatenated with a zeros matrix? Is it just to keep xq's seqlen dimension the same as xk and xv? That adds extra computation, and even if the dimensions differ, the subsequent operations should still run correctly. For example, the LLaMA 3 implementation does not concatenate a zeros matrix.

In LLaMA 3, seqlen is 1 at inference time: its generate function feeds only current_token into the attention layer at each step.

# llama3 attention
def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
):
    # at inference time, seqlen == 1
    bsz, seqlen, _ = x.shape
    xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

    xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
    xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
    xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
    # up to this point, xq, xk, xv all have shape [bsz, 1, *, self.head_dim]
    # the freqs_cis passed in is likewise cis[-1:, :], i.e. shape [1, head_dim]
    xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
    # q, k, v need to share the same seqlen dimension before the RoPE embedding is computed

    self.cache_k = self.cache_k.to(xq)
    self.cache_v = self.cache_v.to(xq)

    self.cache_k[:bsz, start_pos: start_pos + seqlen] = xk
    self.cache_v[:bsz, start_pos: start_pos + seqlen] = xv

    keys = self.cache_k[:bsz, : start_pos + seqlen]
    values = self.cache_v[:bsz, : start_pos + seqlen]

    # repeat k/v heads if n_kv_heads < n_heads
    keys = repeat_kv(
        keys, self.n_rep
    )  # (bs, cache_len + seqlen, n_local_heads, head_dim)
    values = repeat_kv(
        values, self.n_rep
    )  # (bs, cache_len + seqlen, n_local_heads, head_dim)

    xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
    keys = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
    values = values.transpose(
        1, 2
    )  # (bs, n_local_heads, cache_len + seqlen, head_dim)
    scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
  • Suppose q_head = 16 and the total sequence length (cache + current token) is 20:
    q = (1, 16, 1, 32)
    k = (1, 16, 20, 32)
    $q \times k^T + \text{mask} = (1, 16, 1, 20) + (1, 16, 20, 20) = (1, 16, 20, 20)$
    v = (1, 16, 20, 32)
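
A quick shape check of the example above (a throwaway sketch with random tensors and a simple (20, 20) causal mask; only the shapes matter):

import math
import torch

q = torch.randn(1, 16, 1, 32)    # query for the current token only
k = torch.randn(1, 16, 20, 32)   # 20 cached keys
v = torch.randn(1, 16, 20, 32)   # 20 cached values
mask = torch.full((20, 20), float('-inf')).triu(1)   # simple causal mask

scores = q @ k.transpose(-2, -1) / math.sqrt(32)   # (1, 16, 1, 20)
scores = scores + mask                             # broadcasts to (1, 16, 20, 20)
print(scores.shape, v.shape)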

When computing attention, using only the current token as $q$ is perfectly fine, with no need to concatenate the queries of the previous $n-1$ tokens, because at inference time the input sequence length is $seqlen = 1$.

Unlike LLaMA 3, this model's generate function works the same way as training: at every step it feeds the full token sequence of length seqlen into the attention layer.
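
Schematically, the two decoding styles differ like this (a toy sketch with a stand-in model just to show what gets fed in each step; real prefill and sampling logic are omitted):

import torch

# toy stand-in for a language model: random logits over a 32-word vocab, cache passed through untouched
def toy_model(tokens, past_kv=None):
    bsz, seqlen = tokens.shape
    return torch.randn(bsz, seqlen, 32), past_kv

tokens = torch.zeros(1, 1, dtype=torch.long)   # a single starting token
past_kv = None

# LLaMA-3 style: only the newest token goes in (seqlen == 1 per call)
for _ in range(5):
    logits, past_kv = toy_model(tokens[:, -1:], past_kv)
    tokens = torch.cat([tokens, logits[:, -1:].argmax(-1)], dim=1)

# the style discussed here: the whole sequence goes in every step, just like training
for _ in range(5):
    logits, past_kv = toy_model(tokens, past_kv)
    tokens = torch.cat([tokens, logits[:, -1:].argmax(-1)], dim=1)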

When only the current token's $xq$ is computed, the queries of the previous $n-1$ tokens have to be concatenated back in so that $xq$, $xk$, and $xv$ keep the same $seqlen$ dimension through the RoPE encoding and the computations before it. Those previous $n-1$ queries do not actually affect the current token's q·k computation, so they are replaced with torch.zeros_like(x[:, :-1, :]).
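
A shape-only illustration of that padding trick (hypothetical wq/wk/wv linear layers, RoPE omitted; the point is just that all three tensors end up with the same seqlen):

import torch
import torch.nn as nn

bsz, seqlen, dim = 1, 20, 64
wq, wk, wv = nn.Linear(dim, dim), nn.Linear(dim, dim), nn.Linear(dim, dim)

x = torch.randn(bsz, seqlen, dim)                          # full input sequence
current_token = x[:, -1:, :]                               # only the newest token
past_key, past_value = wk(x[:, :-1, :]), wv(x[:, :-1, :])  # stand-in for the cached K/V

xq = torch.cat((torch.zeros_like(x[:, :-1, :]), wq(current_token)), dim=1)
xk = torch.cat((past_key, wk(current_token)), dim=1)
xv = torch.cat((past_value, wv(current_token)), dim=1)
print(xq.shape, xk.shape, xv.shape)   # all (1, 20, 64), so RoPE and the later reshapes line up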

Indeed, the former's attention cost is $O(N \times \text{dim}_{head})$, far below the latter's $O(N^2 \times \text{dim}_{head})$, so having the generate function use only the current token at each inference step is more efficient. Also, the KV cache here was not modeled on LLaMA 3; it started out as a sloppy cache variable written on a whim.
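
A back-of-the-envelope comparison (made-up N and head dim, counting only the q·k score terms):

N, dim_head = 512, 64
single_q = N * dim_head       # one query row against N keys
full_q = N * N * dim_head     # N query rows against N keys
print(full_q // single_q)     # 512: the full-sequence step does N times the score work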

I just took a quick pass at it. Here is a modified version that bolts the low-complexity current_token path onto the existing scheme:

def forward(
        self,
        x: torch.Tensor,
        pos_cis: torch.Tensor,
        use_kv_cache: bool = False,
        past_kv: Tuple[torch.Tensor] = None
):
    bsz, seqlen, _ = x.shape

    keys, values = None, None
    flag = 1
    # QKV
    # inference
    if use_kv_cache:
        current_token = x[:, -1:, :]

        if not past_kv:
            xq = self.wq(x)
            xk, xv = self.wk(x), self.wv(x)
            flag = 1
            past_kv = (xk, xv)
        else:
            past_key, past_value = past_kv
            xq = self.wq(current_token)
            xk = self.wk(current_token)
            xv = self.wv(current_token)
            keys = torch.cat((past_key, xk), dim=1)
            values = torch.cat((past_value, xv), dim=1)
            past_kv = (keys, values)
            flag = 2
    else:
        xq = self.wq(x)
        xk, xv = self.wk(x), self.wv(x)

    if flag == 2:
        xq = xq.view(bsz, 1, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, 1, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, 1, self.n_local_kv_heads, self.head_dim)
    else:
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

    if flag == 1:
        # RoPE relative positional embeddings
        xq, xk = apply_rotary_emb(xq, xk, pos_cis)
    else:
        xq, xk = apply_rotary_emb(xq, xk, pos_cis[-1:, :])

    if flag == 2:
        past_key, past_value = past_kv
        keys = torch.cat((past_key[:, :-1, :], xk.view(bsz, 1, self.n_local_kv_heads * self.head_dim)), dim=1)
        values = torch.cat((past_value[:, :-1, :], xv.view(bsz, 1, self.n_local_kv_heads * self.head_dim)), dim=1)
        past_kv = (keys, values)
        keys = keys.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        values = values.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        xk = keys
        xv = values

    # grouped multiquery attention: expand out keys and values
    xk = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
    xv = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)

    # make heads into a batch dimension
    xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
    xk = xk.transpose(1, 2)
    xv = xv.transpose(1, 2)

    # manual implementation
    scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
    assert hasattr(self, 'mask')
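    # when flag == 2, scores is (bs, n_local_heads, 1, seqlen); broadcasting against the mask
    # below expands it to (bs, n_local_heads, seqlen, seqlen), and only the last query row
    # corresponds to the current token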
    scores = scores + self.mask[:, :, :seqlen, :seqlen]  # (bs, n_local_heads, seqlen, cache_len + seqlen)
    scores = F.softmax(scores.float(), dim=-1).type_as(xq)
    scores = self.attn_dropout(scores)
    output = torch.matmul(scores, xv)  # (bs, n_local_heads, seqlen, head_dim)

    output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

    # final projection into the residual stream
    output = self.wo(output)
    output = self.resid_dropout(output)
    return output, past_kv
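
For reference, the calling pattern this forward expects, shown with a toy stand-in that has the same signature (not the real module; it only mimics how past_kv is threaded through successive calls):

import torch

def toy_forward(x, pos_cis, use_kv_cache=False, past_kv=None):
    # mimic the cache handling only: first call stores full K/V, later calls append the last position
    if use_kv_cache and past_kv:
        past_kv = (torch.cat((past_kv[0], x[:, -1:, :]), dim=1),
                   torch.cat((past_kv[1], x[:, -1:, :]), dim=1))
    elif use_kv_cache:
        past_kv = (x, x)
    return x, past_kv

x = torch.randn(1, 4, 8)                                  # 4-token prompt, dim 8
out, past_kv = toy_forward(x, None, use_kv_cache=True)    # first call: cache the whole prompt
x = torch.cat((x, torch.randn(1, 1, 8)), dim=1)           # append the next token's embedding
out, past_kv = toy_forward(x, None, use_kv_cache=True, past_kv=past_kv)
print(past_kv[0].shape)                                   # torch.Size([1, 5, 8]): cache grows by one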

For now I won't make major changes to the inference function.

Feel free to keep the discussion going and point out anything that's off.

jingyaogong added the question (Further information is requested) label on Sep 14, 2024