diff --git a/setup.py b/setup.py
index c015158..81f2332 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name='video-diffusion-pytorch',
   packages=find_packages(exclude=[]),
-  version='0.6.3',
+  version='0.7.0',
   license='MIT',
   description='Video Diffusion - Pytorch',
   long_description_content_type='text/markdown',
diff --git a/video_diffusion_pytorch/video_diffusion_pytorch.py b/video_diffusion_pytorch/video_diffusion_pytorch.py
index 236601d..cefab06 100644
--- a/video_diffusion_pytorch/video_diffusion_pytorch.py
+++ b/video_diffusion_pytorch/video_diffusion_pytorch.py
@@ -160,6 +160,15 @@ def forward(self, x):
         mean = torch.mean(x, dim = 1, keepdim = True)
         return (x - mean) / (var + self.eps).sqrt() * self.gamma
 
+class RMSNorm(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.scale = dim ** 0.5
+        self.gamma = nn.Parameter(torch.ones(dim, 1, 1, 1))
+
+    def forward(self, x):
+        return F.normalize(x, dim = 1) * self.scale * self.gamma
+
 class PreNorm(nn.Module):
     def __init__(self, dim, fn):
         super().__init__()
@@ -174,10 +183,10 @@ def forward(self, x, **kwargs):
 # building block modules
 
 class Block(nn.Module):
-    def __init__(self, dim, dim_out, groups = 8):
+    def __init__(self, dim, dim_out):
         super().__init__()
         self.proj = nn.Conv3d(dim, dim_out, (1, 3, 3), padding = (0, 1, 1))
-        self.norm = nn.GroupNorm(groups, dim_out)
+        self.norm = RMSNorm(dim_out)
         self.act = nn.SiLU()
 
     def forward(self, x, scale_shift = None):
@@ -191,15 +200,15 @@ def forward(self, x, scale_shift = None):
         return self.act(x)
 
 class ResnetBlock(nn.Module):
-    def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
+    def __init__(self, dim, dim_out, *, time_emb_dim = None):
         super().__init__()
         self.mlp = nn.Sequential(
             nn.SiLU(),
             nn.Linear(time_emb_dim, dim_out * 2)
         ) if exists(time_emb_dim) else None
 
-        self.block1 = Block(dim, dim_out, groups = groups)
-        self.block2 = Block(dim_out, dim_out, groups = groups)
+        self.block1 = Block(dim, dim_out)
+        self.block2 = Block(dim_out, dim_out)
         self.res_conv = nn.Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
 
     def forward(self, x, time_emb = None):
@@ -355,8 +364,7 @@ def __init__(
         init_dim = None,
         init_kernel_size = 7,
         use_sparse_linear_attn = True,
-        block_type = 'resnet',
-        resnet_groups = 8
+        block_type = 'resnet'
     ):
         super().__init__()
         self.channels = channels
@@ -412,7 +420,7 @@ def __init__(
 
         # block type
 
-        block_klass = partial(ResnetBlock, groups = resnet_groups)
+        block_klass = ResnetBlock
         block_klass_cond = partial(block_klass, time_emb_dim = cond_dim)
 
         # modules for all layers
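
For context on the change above: swapping `nn.GroupNorm` for the new `RMSNorm` removes the `groups` hyperparameter entirely, which is why `resnet_groups` disappears from the Unet constructor and `block_klass` no longer needs `functools.partial`. Below is a minimal standalone sketch (not part of the diff) of what the new module computes on 5-D video tensors, plus a quick sanity check; the input shape and the check itself are illustrative, not from the repo:

```python
import torch
import torch.nn.functional as F
from torch import nn

class RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        # learned per-channel gain, broadcast over (frames, height, width)
        self.gamma = nn.Parameter(torch.ones(dim, 1, 1, 1))

    def forward(self, x):
        # F.normalize divides by the L2 norm over the channel dim;
        # multiplying by sqrt(dim) gives unit root-mean-square at init
        return F.normalize(x, dim = 1) * self.scale * self.gamma

# illustrative shape: (batch, channels, frames, height, width)
x = torch.randn(2, 64, 4, 32, 32)
out = RMSNorm(64)(x)

# with gamma = 1, the RMS over the channel dimension is ~1 everywhere
rms = out.pow(2).mean(dim = 1).sqrt()
print(rms.mean())  # ~1.0
```

Unlike GroupNorm (or the file's existing LayerNorm), this rescales each position by its channel-wise RMS without subtracting a mean, and has no group-count knob to tune.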