I was trying to learn how to make changes to the fwd & bwd CUDA kernels by attempting a simple modification:
$$
\begin{align*}
x &= (e^{\Delta A} + \Delta B)x + \Delta Bu \\
y &= Cx + Du
\end{align*}
$$
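To make the modified recurrence concrete, here is a minimal scalar sketch in plain Python. It is not the repository's `selective_scan_ref()`: the function name `modified_scan_ref`, the scalar (single channel, single state) shapes, and the zero initial state are assumptions made only for illustration.

```python
import math


def modified_scan_ref(u, delta, A, B, C, D, x0=0.0):
    """Scalar reference for x_t = (exp(delta_t*A) + delta_t*B_t) * x_{t-1} + delta_t*B_t*u_t,
    y_t = C_t*x_t + D*u_t.

    u, delta, B, C are equal-length sequences of floats; A and D are floats.
    Hypothetical helper for illustration, not part of the Mamba code base.
    """
    x = x0
    ys = []
    for u_t, d_t, B_t, C_t in zip(u, delta, B, C):
        # State update with the extra delta*B term in the multiplier.
        x = (math.exp(d_t * A) + d_t * B_t) * x + d_t * B_t * u_t
        # Output with the skip connection D*u.
        ys.append(C_t * x + D * u_t)
    return ys
```

For example, with `A = 0`, `D = 0` and all other inputs equal to 1, the state goes 1, 3 over two steps, so `modified_scan_ref([1, 1], [1, 1], 0.0, [1, 1], [1, 1], 0.0)` returns `[1.0, 3.0]`.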
I manually calculated the backward pass derivatives:
$$
\begin{align*}
dx &= dy \cdot C \\
d\Delta &= dx \cdot (x \cdot e^{\Delta A} \cdot A + Bx + Bu) \\
du &= dx \cdot B \cdot \Delta + dy \cdot D \\
dA &= dx \cdot e^{\Delta A} \cdot x \cdot \Delta \\
dB &= dx \cdot \Delta \cdot (x + u) \\
dC &= dy \cdot x \\
dD &= dy \cdot u
\end{align*}
$$
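One way to sanity-check hand-derived single-step gradients before touching the kernels is a finite-difference comparison. The sketch below is a scalar, one-timestep check in plain Python. The names `step`, `hand_grads`, `numeric_grads` and `max_abs_error` are hypothetical, and it only covers the local derivatives of $y_t$ (it does not model the accumulation of $dx$ across timesteps that the reverse scan performs). It assumes $\partial y / \partial u$ through the skip term is $D$ and that $dC$ uses the updated state $x_t$, as follows from $y = Cx + Du$.

```python
import math

PARAM_NAMES = ("delta", "u", "A", "B", "C", "D")


def step(x_prev, u, delta, A, B, C, D):
    """One step of the modified recurrence; returns (y_t, x_t)."""
    x = (math.exp(delta * A) + delta * B) * x_prev + delta * B * u
    return C * x + D * u, x


def hand_grads(x_prev, u, delta, A, B, C, D, dy=1.0):
    """Hand-derived local gradients of y_t with respect to each input."""
    _, x = step(x_prev, u, delta, A, B, C, D)
    e = math.exp(delta * A)
    dx = dy * C
    return {
        "delta": dx * (x_prev * e * A + B * x_prev + B * u),
        "u": dx * B * delta + dy * D,   # skip term contributes D, not u
        "A": dx * e * x_prev * delta,
        "B": dx * delta * (x_prev + u),
        "C": dy * x,                    # x here is the updated state x_t
        "D": dy * u,
    }


def numeric_grads(params, eps=1e-6):
    """Central finite differences of y_t with respect to each input."""
    grads = {}
    for name in PARAM_NAMES:
        hi = dict(params)
        lo = dict(params)
        hi[name] += eps
        lo[name] -= eps
        grads[name] = (step(**hi)[0] - step(**lo)[0]) / (2 * eps)
    return grads


def max_abs_error(params):
    """Largest gap between hand-derived and numeric gradients."""
    h = hand_grads(**params)
    n = numeric_grads(params)
    return max(abs(h[k] - n[k]) for k in PARAM_NAMES)
```

Calling `max_abs_error(dict(x_prev=0.7, u=-0.3, delta=0.5, A=-1.2, B=0.4, C=1.5, D=0.25))` should give a value near zero if the formulas agree. A large value for one key points at the term to re-derive.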
All the $x$'s in the bwd pass above are $x_{t-1}$. Accordingly, I modified `thread_data`, `thread_reverse_data`, `smem_delta_a`, and the derivative calculations in the backward and forward kernels (`.cuh` files), and updated `selective_scan_ref()` in `selective_scan_interface.py`. After building the modified code and testing (`test_selective_scan.py::test_selective_scan()`) via `pytest` without changing tolerances, I got these results:

[Results table: only the configuration and sequence-length labels survive.]

- `delta_softplus` and `delta_bias`
- `delta_softplus`: `seq_len < 8_192`, `seq_len >= 8_192`
- `delta_bias`: `seq_len < 8_192`, `seq_len >= 8_192`
- `delta_softplus` and `delta_bias`: `seq_len < 4_096`, `seq_len >= 4_096`

Issues:

- `delta_softplus` should automatically handle scaling (`selective_scan_bwd_kernel.cuh#L452-L456`). Where could the error be?
- [...] `seq_len` increases in this case, but not for Mamba. Am I missing some other change?

Thank you for your help.
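
One property that may be relevant to the `seq_len` dependence: in the modified recurrence the per-step state multiplier is $e^{\Delta A} + \Delta B$ rather than $e^{\Delta A}$. Assuming $A < 0$ and $\Delta > 0$, $e^{\Delta A}$ stays below 1, whereas the modified multiplier can exceed 1 whenever $\Delta B > 1 - e^{\Delta A}$, in which case the state (and any absolute error in it) can grow with sequence length. The sketch below illustrates this for constant scalar parameters. The function names are hypothetical, and this is a possible contributing factor under those assumptions, not a diagnosis of the kernel changes.

```python
import math


def state_multiplier(delta, A, B):
    """Per-step multiplier on x_{t-1} in the modified recurrence."""
    return math.exp(delta * A) + delta * B


def homogeneous_growth(delta, A, B, seq_len):
    """|x_T / x_0| when u = 0 and delta, A, B are constant scalars."""
    return abs(state_multiplier(delta, A, B)) ** seq_len
```

With `delta = 0.1` and `A = -1.0`, setting `B = 0.0` gives a multiplier below 1 (the state decays), while `B = 2.0` gives a multiplier above 1 (the state grows geometrically in `seq_len`).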