Skip to content

Commit

Permalink
Improved definition of modern VAEs
Browse files Browse the repository at this point in the history
Signed-off-by: Emanuele Ballarin <[email protected]>
  • Loading branch information
emaballarin committed Aug 15, 2024
1 parent a8991ca commit 74575a6
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 1 deletion.
1 change: 1 addition & 0 deletions ebtorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
del ArgMaxLayer
del BasicAE
del BasicVAE
del SingleNeckVAE
del beta_gaussian_kldiv
del var_of_lap
del BinarizeLayer
Expand Down
1 change: 1 addition & 0 deletions ebtorch/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from .architectures import SGRUHCell
from .architectures import SharedDuplexLinearNeck
from .architectures import SilhouetteScore
from .architectures import SingleNeckVAE
from .architectures import SirenSine
from .architectures import StatefulTupleSelect
from .architectures import SwiGLU
Expand Down
45 changes: 45 additions & 0 deletions ebtorch/nn/architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
"SirenSine",
"BasicAE",
"BasicVAE",
"SingleNeckVAE",
"Clamp",
"SwiGLU",
"TupleDecouple",
Expand Down Expand Up @@ -865,6 +866,7 @@ def forward(self, x: Tensor) -> Union[
shared: Tensor = self.encoder(x)
mean: Tensor = self.mean_neck(shared)
logvar: Tensor = self.logvar_neck(shared)
# noinspection DuplicatedCode
z: Tensor = self.grps(mean, logvar)
y: Tensor = self.decoder(z)

Expand All @@ -878,6 +880,49 @@ def forward(self, x: Tensor) -> Union[
return y


################################################################################
class SingleNeckVAE(nn.Module):
def __init__(
self,
encoder: nn.Module,
neck: nn.Module,
decoder: nn.Module,
extract_z: bool = False,
extract_mv: bool = False,
) -> None:
super().__init__()
self.encoder: nn.Module = encoder
self.neck: nn.Module = neck
self.grps: GaussianReparameterizerSampler = GaussianReparameterizerSampler()
self.decoder: nn.Module = decoder
self.extract_z: bool = extract_z
self.extract_mv: bool = extract_mv

def forward(self, x: Tensor) -> Union[
Tensor,
Tuple[Tensor, Tensor],
Tuple[Tensor, Tensor, Tensor],
Tuple[Tensor, Tensor, Tensor, Tensor],
]:
shared: Tensor = self.encoder(x)
mean, logvar = self.neck(shared)
# noinspection DuplicatedCode
z: Tensor = self.grps(mean, logvar)
y: Tensor = self.decoder(z)

if self.extract_z and self.extract_mv:
return y, z, mean, logvar
elif self.extract_z:
return y, z
elif self.extract_mv:
return y, mean, logvar
else:
return y


################################################################################


class SwiGLU(nn.Module):
def __init__(
self,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def read(fname):

setup(
name=PACKAGENAME,
version="0.25.14",
version="0.25.15",
author="Emanuele Ballarin",
author_email="[email protected]",
url="https://github.com/emaballarin/ebtorch",
Expand Down

0 comments on commit 74575a6

Please sign in to comment.