Skip to content

Commit

Permalink
Fix imports
Browse files Browse the repository at this point in the history
  • Loading branch information
tridao committed Jun 3, 2024
1 parent 60dadf2 commit 33aeaa0
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 20 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
![Mamba](assets/selection.png "Selective State Space")
> **Mamba: Linear-Time Sequence Modeling with Selective State Spaces**\
> Albert Gu*, Tri Dao*\
> Paper: https://arxiv.org/abs/2312.00752
> **Transformers are {SSM}s: Generalized Models and Efficient Algorithms Through Structured State Space Duality**\
> Paper: https://arxiv.org/abs/2312.00752\
> **Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality**\
> Tri Dao*, Albert Gu*\
> Paper: https://arxiv.org/abs/2405.21060
Expand Down
2 changes: 1 addition & 1 deletion mamba_ssm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "2.0.0"
__version__ = "2.0.1"

from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
from mamba_ssm.modules.mamba_simple import Mamba
Expand Down
2 changes: 1 addition & 1 deletion mamba_ssm/distributed/tensor_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from einops import rearrange

from src.distributed.distributed_utils import (
from mamba_ssm.distributed.distributed_utils import (
all_gather_raw,
all_reduce,
all_reduce_raw,
Expand Down
2 changes: 1 addition & 1 deletion mamba_ssm/ops/triton/ssd_chunk_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from einops import rearrange, repeat

from src.ops.triton.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd
from mamba_ssm.ops.triton.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd

TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')

Expand Down
30 changes: 15 additions & 15 deletions mamba_ssm/ops/triton/ssd_combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,20 @@
except ImportError:
causal_conv1d_fn, causal_conv1d_cuda = None, None

from src.ops.triton.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd
from src.ops.triton.ssd_chunk_state import _chunk_cumsum_fwd, _chunk_cumsum_bwd
from src.ops.triton.ssd_chunk_state import _chunk_state_fwd, _chunk_state_bwd_db
from src.ops.triton.ssd_chunk_state import _chunk_state_bwd_ddAcs_stable
from src.ops.triton.ssd_chunk_state import chunk_state, chunk_state_ref
from src.ops.triton.ssd_state_passing import _state_passing_fwd, _state_passing_bwd
from src.ops.triton.ssd_state_passing import state_passing, state_passing_ref
from src.ops.triton.ssd_chunk_scan import _chunk_scan_fwd, _chunk_scan_bwd_dz, _chunk_scan_bwd_dstates
from src.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_dC, _chunk_scan_bwd_dcb
from src.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_stable
from src.ops.triton.ssd_chunk_scan import chunk_scan, chunk_scan_ref
from src.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_prev
from src.ops.triton.layernorm_gated import rmsnorm_fn, _layer_norm_fwd, _layer_norm_bwd
from src.ops.triton.k_activations import _swiglu_fwd, _swiglu_bwd
from mamba_ssm.ops.triton.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd
from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_cumsum_fwd, _chunk_cumsum_bwd
from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_fwd, _chunk_state_bwd_db
from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_bwd_ddAcs_stable
from mamba_ssm.ops.triton.ssd_chunk_state import chunk_state, chunk_state_ref
from mamba_ssm.ops.triton.ssd_state_passing import _state_passing_fwd, _state_passing_bwd
from mamba_ssm.ops.triton.ssd_state_passing import state_passing, state_passing_ref
from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_fwd, _chunk_scan_bwd_dz, _chunk_scan_bwd_dstates
from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_dC, _chunk_scan_bwd_dcb
from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_stable
from mamba_ssm.ops.triton.ssd_chunk_scan import chunk_scan, chunk_scan_ref
from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_prev
from mamba_ssm.ops.triton.layernorm_gated import rmsnorm_fn, _layer_norm_fwd, _layer_norm_bwd
from mamba_ssm.ops.triton.k_activations import _swiglu_fwd, _swiglu_bwd

TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')

Expand Down Expand Up @@ -651,7 +651,7 @@ def ssd_selective_scan(x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus
Return:
out: (batch, seqlen, nheads, headdim)
"""
from src.ops.selective_scan_interface import selective_scan_fn
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn

batch, seqlen, nheads, headdim = x.shape
_, _, ngroups, dstate = B.shape
Expand Down

0 comments on commit 33aeaa0

Please sign in to comment.