
Facing errors with modified SSM equations in the bwd CUDA kernel (wrt $\Delta$ bias & softplus) #604

Open
SudhanshuBokade opened this issue Oct 23, 2024 · 0 comments


SudhanshuBokade commented Oct 23, 2024

I was trying to learn how to make changes to the fwd & bwd CUDA kernels by attempting a simple modification:

$$ \begin{align*} x &= (e^{\Delta A} + \Delta B)x + \Delta Bu \\ y &= Cx + Du \end{align*} $$

I manually calculated the backward pass derivatives:

$$ \begin{align*} dx &= dy \cdot C \\ d\Delta &= dx \cdot (x \cdot e^{\Delta A} \cdot A + Bx + Bu) \\ du &= dx \cdot B \cdot \Delta + dy \cdot D \\ dA &= dx \cdot e^{\Delta A} \cdot x \cdot \Delta \\ dB &= dx \cdot \Delta \cdot (x + u) \\ dC &= dy \cdot x \\ dD &= dy \cdot u \end{align*} $$
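These derivatives can be sanity-checked numerically. Below is a standalone scalar sketch (my own variable names `dt`, `x_prev`, etc., not the repo's kernel code) that compares the hand-derived gradients against central finite differences for a single recurrence step:

```python
import numpy as np

# Scalar, single-step sketch of the modified recurrence
#   x_t = (exp(dt*A) + dt*B) * x_{t-1} + dt*B*u,   y = C*x_t + D*u
# used to verify the hand-derived backward pass numerically.

def forward(dt, A, B, C, D, x_prev, u):
    x = (np.exp(dt * A) + dt * B) * x_prev + dt * B * u
    return x, C * x + D * u

def backward(dt, A, B, C, D, x_prev, u, dy):
    dx = dy * C                          # gradient w.r.t. the new state x_t
    ddt = dx * (x_prev * np.exp(dt * A) * A + B * x_prev + B * u)
    du = dx * B * dt + dy * D
    dA = dx * np.exp(dt * A) * x_prev * dt
    dB = dx * dt * (x_prev + u)
    dC = dy * forward(dt, A, B, C, D, x_prev, u)[0]   # dC uses the updated x_t
    dD = dy * u
    return ddt, du, dA, dB, dC, dD

def fd(f, v, eps=1e-6):                  # central finite difference
    return (f(v + eps) - f(v - eps)) / (2 * eps)

dt, A, B, C, D, x_prev, u, dy = 0.3, -0.7, 0.5, 1.2, 0.4, 0.9, 1.1, 1.0
ddt, du, dA, dB, dC, dD = backward(dt, A, B, C, D, x_prev, u, dy)

assert np.isclose(fd(lambda v: forward(v, A, B, C, D, x_prev, u)[1], dt), ddt)
assert np.isclose(fd(lambda v: forward(dt, v, B, C, D, x_prev, u)[1], A), dA)
assert np.isclose(fd(lambda v: forward(dt, A, v, C, D, x_prev, u)[1], B), dB)
assert np.isclose(fd(lambda v: forward(dt, A, B, C, D, x_prev, v)[1], u), du)
```

All four checks pass for this single step, so any remaining error is more likely in how the derivatives were wired into the scan than in the derivation itself.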

All the $x$'s in the backward pass above are $x_{t-1}$ (except in $dC$, where $x$ is the updated state $x_t$).

Accordingly, I modified thread_data, thread_reverse_data, smem_delta_a, and the derivative calculations in the forward and backward kernels (.cuh files), and updated selective_scan_ref() in selective_scan_interface.py. After building the modified code and running test_selective_scan.py::test_selective_scan() via pytest with the tolerances unchanged, I got these results:

| Configuration | Tests passing ✅ | Tests failing ❌ |
| --- | --- | --- |
| Disabling both delta_softplus and delta_bias | All | None |
| Disabling only delta_softplus | seq_len < 8_192 | seq_len >= 8_192 |
| Disabling only delta_bias | seq_len < 8_192 | seq_len >= 8_192 |
| Enabling both delta_softplus and delta_bias | seq_len < 4_096 | seq_len >= 4_096 |
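Since the failures only appear at long sequence lengths, one thing worth measuring is how much float32 rounding alone accumulates in this recurrence. Below is a standalone sketch (the parameter ranges are my assumptions, not the repo's test settings) that runs the modified recurrence in float32 and float64 and reports the final-state discrepancy:

```python
import numpy as np

# Run the modified recurrence x = (exp(dt*A) + dt*B)*x + dt*B*u over a
# sequence in float32 and float64 and compare the final states.
# Parameter ranges below are illustrative assumptions.

def final_state_error(seq_len, seed=0):
    rng = np.random.default_rng(seed)
    dt = rng.uniform(0.01, 0.1, seq_len)
    u = rng.standard_normal(seq_len)
    A, B = -0.5, 0.3

    def run(dtype):
        x = dtype(0.0)
        dtv, uv = dt.astype(dtype), u.astype(dtype)
        Av, Bv = dtype(A), dtype(B)
        for t in range(seq_len):
            x = (np.exp(dtv[t] * Av) + dtv[t] * Bv) * x + dtv[t] * Bv * uv[t]
        return float(x)

    return abs(run(np.float32) - run(np.float64))

err_2k = final_state_error(2_048)
err_16k = final_state_error(16_384)
print(err_2k, err_16k)
```

Comparing these magnitudes against the test tolerances would show whether the seq_len threshold can be explained by precision alone or whether a logic bug is amplifying with sequence length.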

Issues:

  1. I haven't changed $\Delta$, and delta_softplus should automatically handle scaling (selective_scan_bwd_kernel.cuh#L452-L456). Where could the error be?
  2. There seems to be error accumulation as seq_len increases in this case, but not for Mamba. Am I missing some other change?
  3. Changing gridDim & blockDim for the kernel launch alters the error magnitudes (errors increase). Why?
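On issue 1: with delta_softplus enabled, the scan consumes $\Delta = \mathrm{softplus}(\Delta_{\text{raw}} + \Delta_{\text{bias}})$, so the backward pass must multiply $d\Delta$ by $\sigma(\Delta_{\text{raw}} + \Delta_{\text{bias}})$ before producing $d\Delta_{\text{raw}}$ (which is, as far as I can tell, what the linked kernel lines do). A quick numerical check of that chain rule, as a standalone sketch:

```python
import numpy as np

# d/dz softplus(z) = sigmoid(z): the extra factor the backward pass must
# apply to d(delta) when delta_softplus is enabled.

def softplus(z):
    return np.log1p(np.exp(z))

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

z = np.array([-2.0, -0.3, 0.0, 0.7, 3.0])
eps = 1e-6
numeric = (softplus(z + eps) - softplus(z - eps)) / (2 * eps)
assert np.allclose(numeric, sigmoid(z), atol=1e-6)
```

If this factor really is applied unchanged by the existing code path, and the $d\Delta$ expression itself checks out against finite differences, the seq_len-dependent failures may point at precision or at the modified prefix-scan coefficients rather than at the softplus/bias handling.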

Thank you for your help.
