Skip to content

Commit

Permalink
Fix MambaChunkScanCombinedFn returning incorrect number of gradients
Browse files Browse the repository at this point in the history
  • Loading branch information
tridao committed Jul 1, 2024
1 parent 95a2041 commit ddce0c1
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion mamba_ssm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "2.2.0"
__version__ = "2.2.1"

from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
from mamba_ssm.modules.mamba_simple import Mamba
Expand Down
2 changes: 1 addition & 1 deletion mamba_ssm/ops/triton/ssd_combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ def backward(ctx, dout, *args):
assert not ctx.return_varlen_states, "return_varlen_states is not supported in backward"
dfinal_states = args[0] if ctx.return_final_states else None
dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=ctx.dt_softplus, dt_limit=ctx.dt_limit)
return dx, ddt, dA, dB, dC, None, dD, dz, ddt_bias, dinitial_states, None, None, None, None
return dx, ddt, dA, dB, dC, None, dD, dz, ddt_bias, dinitial_states, None, None, None, None, None, None


def mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf")), return_final_states=False, return_varlen_states=False):
Expand Down

0 comments on commit ddce0c1

Please sign in to comment.