Skip to content

Commit

Permalink
Add StatefulTupleSelect layer
Browse files Browse the repository at this point in the history
Signed-off-by: Emanuele Ballarin <[email protected]>
  • Loading branch information
emaballarin committed Jul 21, 2024
1 parent 80d78d6 commit f350f58
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 1 deletion.
1 change: 1 addition & 0 deletions ebtorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
del SwiGLU
del TupleDecouple
del TupleSelect
del StatefulTupleSelect
del SilhouetteScore
del Concatenate
del DuplexLinearNeck
Expand Down
1 change: 1 addition & 0 deletions ebtorch/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from .architectures import SharedDuplexLinearNeck
from .architectures import SilhouetteScore
from .architectures import SirenSine
from .architectures import StatefulTupleSelect
from .architectures import SwiGLU
from .architectures import TupleDecouple
from .architectures import TupleSelect
Expand Down
11 changes: 11 additions & 0 deletions ebtorch/nn/architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
"SwiGLU",
"TupleDecouple",
"TupleSelect",
"StatefulTupleSelect",
"SilhouetteScore",
"Concatenate",
"DuplexLinearNeck",
Expand Down Expand Up @@ -917,6 +918,16 @@ def forward(self, xtuple: Tuple[Tensor, ...]) -> Tensor:
return xtuple[self.idx]


class StatefulTupleSelect(nn.Module):
def __init__(self, module, idx: int = 0) -> None:
super().__init__()
self.module: nn.Module = module
self.idx: int = idx

def forward(self, x: Tensor) -> Tensor:
return self.module(x)[self.idx]


class SilhouetteScore(nn.Module):
"""
Layerized computation of the Silhouette Score.
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def read(fname):

setup(
name=PACKAGENAME,
version="0.25.12",
version="0.25.13",
author="Emanuele Ballarin",
author_email="[email protected]",
url="https://github.com/emaballarin/ebtorch",
Expand Down

0 comments on commit f350f58

Please sign in to comment.