
TypeError: custom_fwd() takes from 0 to 1 positional arguments but 21 positional arguments (and 1 keyword-only argument) were given #609

Open
saurabh-kataria opened this issue Oct 27, 2024 · 2 comments

Comments

@saurabh-kataria

I am unable to use the sample Mamba2 code. Even with the following simple code, the forward pass fails.

import torch
from mamba_ssm import Mamba2
batch, length, dim = 2, 64, 256
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba2(d_model=dim, d_state=64, d_conv=4, expand=2).to('cuda')
y = model(x)

---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[5], line 5
3 x = torch.randn(batch, length, dim).to("cuda")
4 model = Mamba2(d_model=dim, d_state=64, d_conv=4, expand=2).to('cuda')
----> 5 y = model(x)

File /scratch/skataria/anaconda3/envs/tmp5/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)

File /scratch/skataria/anaconda3/envs/tmp5/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()

File /scratch/skataria/anaconda3/envs/tmp5/lib/python3.10/site-packages/mamba_ssm/modules/mamba2.py:185, in Mamba2.forward(self, u, seqlen, seq_idx, cu_seqlens, inference_params)
183 dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
184 if self.use_mem_eff_path and inference_params is None:
--> 185 out = mamba_split_conv1d_scan_combined(
186 zxbcdt,
187 rearrange(self.conv1d.weight, "d 1 w -> d w"),
188 self.conv1d.bias,
189 self.dt_bias,
190 A,
191 D=rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D,
192 chunk_size=self.chunk_size,
193 seq_idx=seq_idx,
194 activation=self.activation,
195 rmsnorm_weight=self.norm.weight if self.rmsnorm else None,
196 rmsnorm_eps=self.norm.eps if self.rmsnorm else 1e-6,
197 outproj_weight=self.out_proj.weight,
198 outproj_bias=self.out_proj.bias,
199 headdim=None if self.D_has_hdim else self.headdim,
200 ngroups=self.ngroups,
201 norm_before_gate=self.norm_before_gate,
202 **dt_limit_kwargs,
203 )
204 if seqlen_og is not None:
205 out = rearrange(out, "b l d -> (b l) d")

File /scratch/skataria/anaconda3/envs/tmp5/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py:930, in mamba_split_conv1d_scan_combined(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)
911 def mamba_split_conv1d_scan_combined(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu", rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None, ngroups=1, norm_before_gate=True):
912 """
913 Argument:
914 zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
(...)
928 out: (batch, seqlen, dim)
929 """
--> 930 return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)

File /scratch/skataria/anaconda3/envs/tmp5/lib/python3.10/site-packages/torch/autograd/function.py:575, in Function.apply(cls, *args, **kwargs)
572 if not torch._C._are_functorch_transforms_active():
573 # See NOTE: [functorch vjp and autograd interaction]
574 args = _functorch.utils.unwrap_dead_wrappers(args)
--> 575 return super().apply(*args, **kwargs) # type: ignore[misc]
577 if not is_setup_ctx_defined:
578 raise RuntimeError(
579 "In order to use an autograd.Function with functorch transforms "
580 "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
581 "staticmethod. For more details, please see "
582 "https://pytorch.org/docs/main/notes/extending.func.html"
583 )

TypeError: custom_fwd() takes from 0 to 1 positional arguments but 21 positional arguments (and 1 keyword-only argument) were given
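
A likely cause (based on the PyTorch 2.4 API change, not stated explicitly in this thread): custom_fwd and custom_bwd moved from torch.cuda.amp to torch.amp and now take a required keyword-only device_type argument. The traceback shows the 21 arguments meant for forward being passed to custom_fwd itself, which is what happens when the decorator never actually wraps the forward function. A minimal sketch of the post-2.4 decorator usage (the Double function is a made-up example, not mamba_ssm code):

import torch
from torch.amp import custom_bwd, custom_fwd  # moved out of torch.cuda.amp in PyTorch 2.4

class Double(torch.autograd.Function):
    # PyTorch >= 2.4: device_type is a required keyword-only argument.
    # The pre-2.4 torch.cuda.amp.custom_fwd had no device_type parameter.
    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, x):
        return x * 2

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad):
        return grad * 2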

@saurabh-kataria
Author

#608 works, BTW.

@epicfilemcnulty
Contributor

Got the same error with the latest master; I can confirm that applying #608 solves the issue.
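
For anyone who cannot update yet, a version-compatible shim along these lines picks the right decorators for the installed PyTorch (this is a sketch of the general approach, not the literal content of #608):

import torch
from functools import partial

# Sketch: PyTorch >= 2.4 exposes custom_fwd/custom_bwd under torch.amp and
# requires a device_type keyword; older releases only have torch.cuda.amp.
if hasattr(torch.amp, "custom_fwd"):
    custom_fwd = partial(torch.amp.custom_fwd, device_type="cuda")
    custom_bwd = partial(torch.amp.custom_bwd, device_type="cuda")
else:
    from torch.cuda.amp import custom_bwd, custom_fwd

Used as a bare decorator (@custom_fwd), the partial forwards the wrapped function together with device_type="cuda" to the new API, while older PyTorch versions fall back to the original torch.cuda.amp decorators.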
