diff --git a/ebtorch/__init__.py b/ebtorch/__init__.py index c96c52e..6b87714 100644 --- a/ebtorch/__init__.py +++ b/ebtorch/__init__.py @@ -93,6 +93,7 @@ del SwiGLU del TupleDecouple del TupleSelect +del StatefulTupleSelect del SilhouetteScore del Concatenate del DuplexLinearNeck diff --git a/ebtorch/nn/__init__.py b/ebtorch/nn/__init__.py index f09f340..46bce65 100644 --- a/ebtorch/nn/__init__.py +++ b/ebtorch/nn/__init__.py @@ -38,6 +38,7 @@ from .architectures import SharedDuplexLinearNeck from .architectures import SilhouetteScore from .architectures import SirenSine +from .architectures import StatefulTupleSelect from .architectures import SwiGLU from .architectures import TupleDecouple from .architectures import TupleSelect diff --git a/ebtorch/nn/architectures.py b/ebtorch/nn/architectures.py index 8c29df2..57ce343 100644 --- a/ebtorch/nn/architectures.py +++ b/ebtorch/nn/architectures.py @@ -61,6 +61,7 @@ "SwiGLU", "TupleDecouple", "TupleSelect", + "StatefulTupleSelect", "SilhouetteScore", "Concatenate", "DuplexLinearNeck", @@ -917,6 +918,16 @@ def forward(self, xtuple: Tuple[Tensor, ...]) -> Tensor: return xtuple[self.idx] +class StatefulTupleSelect(nn.Module): + def __init__(self, module, idx: int = 0) -> None: + super().__init__() + self.module: nn.Module = module + self.idx: int = idx + + def forward(self, x: Tensor) -> Tensor: + return self.module(x)[self.idx] + + class SilhouetteScore(nn.Module): """ Layerized computation of the Silhouette Score. diff --git a/setup.py b/setup.py index 77e2a66..cb8cd3a 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,7 @@ def read(fname): setup( name=PACKAGENAME, - version="0.25.12", + version="0.25.13", author="Emanuele Ballarin", author_email="emanuele@ballarin.cc", url="https://github.com/emaballarin/ebtorch",