File /scratch/skataria/anaconda3/envs/tmp5/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File /scratch/skataria/anaconda3/envs/tmp5/lib/python3.10/site-packages/mamba_ssm/modules/mamba2.py:185, in Mamba2.forward(self, u, seqlen, seq_idx, cu_seqlens, inference_params)
183 dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
184 if self.use_mem_eff_path and inference_params is None:
--> 185 out = mamba_split_conv1d_scan_combined(
186 zxbcdt,
187 rearrange(self.conv1d.weight, "d 1 w -> d w"),
188 self.conv1d.bias,
189 self.dt_bias,
190 A,
191 D=rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D,
192 chunk_size=self.chunk_size,
193 seq_idx=seq_idx,
194 activation=self.activation,
195 rmsnorm_weight=self.norm.weight if self.rmsnorm else None,
196 rmsnorm_eps=self.norm.eps if self.rmsnorm else 1e-6,
197 outproj_weight=self.out_proj.weight,
198 outproj_bias=self.out_proj.bias,
199 headdim=None if self.D_has_hdim else self.headdim,
200 ngroups=self.ngroups,
201 norm_before_gate=self.norm_before_gate,
202 **dt_limit_kwargs,
203 )
204 if seqlen_og is not None:
205 out = rearrange(out, "b l d -> (b l) d")
File /scratch/skataria/anaconda3/envs/tmp5/lib/python3.10/site-packages/torch/autograd/function.py:575, in Function.apply(cls, *args, **kwargs)
572 if not torch._C._are_functorch_transforms_active():
573 # See NOTE: [functorch vjp and autograd interaction]
574 args = _functorch.utils.unwrap_dead_wrappers(args)
--> 575 return super().apply(*args, **kwargs) # type: ignore[misc]
577 if not is_setup_ctx_defined:
578 raise RuntimeError(
579 "In order to use an autograd.Function with functorch transforms "
580 "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
581 "staticmethod. For more details, please see "
582 "https://pytorch.org/docs/main/notes/extending.func.html"
583 )
TypeError: custom_fwd() takes from 0 to 1 positional arguments but 21 positional arguments (and 1 keyword-only argument) were given
I am unable to use the sample Mamba2 code. Even with the following simple code, the forward pass fails with the TypeError shown above.
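For context, here is a minimal sketch of the kind of forward pass that hits this error, modeled on the Mamba2 example in the mamba_ssm README; the dimensions and constructor arguments below are illustrative assumptions, not the exact snippet from the issue.

```python
import torch
from mamba_ssm import Mamba2

# Assumed shapes and hyperparameters, following the mamba_ssm README example.
batch, length, dim = 2, 64, 256
x = torch.randn(batch, length, dim, device="cuda")

model = Mamba2(
    d_model=dim,   # model dimension
    d_state=64,    # SSM state dimension
    d_conv=4,      # local convolution width
    expand=2,      # block expansion factor
).to("cuda")

# The forward call below is where the custom_fwd() TypeError is raised.
y = model(x)
assert y.shape == x.shape
```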
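The TypeError above suggests a signature mismatch between the installed PyTorch and the custom_fwd decorator usage inside mamba_ssm. A quick way to capture the relevant versions when reporting this (a generic sketch using standard tooling, nothing mamba-specific):

```python
import importlib.metadata
import torch

# Versions of the packages involved in the traceback above; useful context
# for a decorator/signature mismatch such as the custom_fwd() TypeError.
print("torch:", torch.__version__)
for pkg in ("mamba-ssm", "causal-conv1d", "triton"):
    try:
        print(f"{pkg}:", importlib.metadata.version(pkg))
    except importlib.metadata.PackageNotFoundError:
        print(f"{pkg}: not installed")
```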