[feature] add support for gemini-dfresnet
wsstriving committed Mar 27, 2024
1 parent 14e6d2a commit 8a0c923
Showing 3 changed files with 258 additions and 0 deletions.
81 changes: 81 additions & 0 deletions examples/voxceleb/v2/conf/gemini_dfresnet_adam.yaml
@@ -0,0 +1,81 @@
### train configuration

exp_dir: exp/Gemini_DF_ResNet60-TSTP-emb256-fbank80-num_frms200-aug0.6-spTrue-saFalse-ArcMargin-AdamW-epoch165
gpus: "[0,1]"
num_avg: 2
enable_amp: False # whether to enable automatic mixed precision training

seed: 42
num_epochs: 165
save_epoch_interval: 5 # save model every 5 epochs
log_batch_interval: 100 # log every 100 batches

dataloader_args:
  batch_size: 128
  num_workers: 8
  pin_memory: False
  prefetch_factor: 8
  drop_last: True

dataset_args:
  # the sample number which will be traversed within one epoch; if the value equals 0,
  # the utterance number in the dataset will be used as the sample_num_per_epoch.
  sample_num_per_epoch: 0
  shuffle: True
  shuffle_args:
    shuffle_size: 2500
  filter: True
  filter_args:
    min_num_frames: 100
    max_num_frames: 800
  resample_rate: 16000
  speed_perturb: True
  num_frms: 200
  aug_prob: 0.6 # prob to add reverb & noise aug per sample
  fbank_args:
    num_mel_bins: 80
    frame_shift: 10
    frame_length: 25
    dither: 1.0
  spec_aug: False
  spec_aug_args:
    num_t_mask: 1
    num_f_mask: 1
    max_t: 10
    max_f: 8
    prob: 0.6

model: Gemini_DF_ResNet60 # Gemini_DF_ResNet60 Gemini_DF_ResNet114 Gemini_DF_ResNet183 Gemini_DF_ResNet237
model_init: null
model_args:
  feat_dim: 80
  embed_dim: 256
  pooling_func: "TSTP" # TSTP, ASTP, MQMHASTP
  two_emb_layer: False
projection_args:
  project_type: "arc_margin" # add_margin, arc_margin, sphere, sphereface2, softmax, arc_margin_intertopk_subcenter
  scale: 32.0
  easy_margin: False

margin_scheduler: MarginScheduler
margin_update:
  initial_margin: 0.2
  final_margin: 0.2
  increase_start_epoch: 20
  fix_start_epoch: 40
  update_margin: False
  increase_type: "exp" # exp, linear

loss: CrossEntropyLoss
loss_args: {}

optimizer: AdamW
optimizer_args:
  weight_decay: 0.05

scheduler: ExponentialDecrease
scheduler_args:
  initial_lr: 0.000125
  final_lr: 0.000001
  warm_up_epoch: 0
  warm_from_zero: False
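A minimal sketch of how this config would be consumed, assuming the usual wespeaker pattern in which the "model" entry is resolved through get_speaker_model (extended in speaker_model.py below) and "model_args" is passed through as keyword arguments; the batch shape is illustrative only:

import torch
import yaml

from wespeaker.models.speaker_model import get_speaker_model

# Load the config shipped in this commit.
with open('examples/voxceleb/v2/conf/gemini_dfresnet_adam.yaml') as f:
    configs = yaml.safe_load(f)

# Resolve Gemini_DF_ResNet60 by name and build it from model_args.
model = get_speaker_model(configs['model'])(**configs['model_args'])
model.eval()

feats = torch.randn(4, 200, 80)  # (B, T, F) fbank features, matching num_frms and fbank_args
_, embed = model(feats)          # with two_emb_layer False, the first output is a dummy 0.0
print(embed.shape)               # torch.Size([4, 256])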
174 changes: 174 additions & 0 deletions wespeaker/models/gemini_dfresnet.py
@@ -0,0 +1,174 @@
# Copyright (c) 2024 Shuai Wang ([email protected])
# 2024 Tianchi Liu ([email protected])
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''The implementation of Gemini-DF-ResNet.
Reference:
[1] Liu, Tianchi, et al. "Golden Gemini is All You Need: Finding the
Sweet Spots for Speaker Verification." arXiv:2312.03620 (2023).
[2] Liu, Bei, et al. "DF-ResNet: Boosting Speaker Verification Performance
with Depth-First Design." INTERSPEECH. 2022.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import wespeaker.models.pooling_layers as pooling_layers


class Inverted_Bottleneck(nn.Module):
    """Depth-first inverted bottleneck: 1x1 expand (4x), 3x3 depthwise, 1x1 project."""

    def __init__(self, dim):
        super(Inverted_Bottleneck, self).__init__()
        # 1x1 pointwise conv expands channels by 4x.
        self.conv1 = nn.Conv2d(dim, 4 * dim, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(4 * dim)
        # 3x3 depthwise conv (groups == channels) in the expanded space.
        self.conv2 = nn.Conv2d(4 * dim, 4 * dim,
                               kernel_size=3, padding=1, groups=4 * dim,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(4 * dim)
        # 1x1 pointwise conv projects back to the input width.
        self.conv3 = nn.Conv2d(4 * dim, dim, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(dim)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += x  # residual connection; channels and resolution are unchanged
        out = F.relu(out)
        return out
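# Shape sanity check (hypothetical usage, not part of the committed file):
# the block keeps channels and resolution unchanged, so it can be stacked
# freely inside a stage:
#   blk = Inverted_Bottleneck(dim=32)
#   blk(torch.randn(2, 32, 40, 100)).shape  # -> torch.Size([2, 32, 40, 100])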


class Gemini_DF_ResNet(nn.Module):
    # DF-ResNet with the T14c stride strategy of Golden Gemini: frequency is
    # halved at every stage (16x in total), while time is downsampled only
    # once (2x), preserving temporal resolution.
    def __init__(self,
                 depths=[3, 3, 9, 3],
                 dims=[32, 32, 64, 128, 256],  # dims[0] is the stem width
                 feat_dim=40,
                 embed_dim=128,
                 pooling_func='TSTP',
                 two_emb_layer=False):
        super(Gemini_DF_ResNet, self).__init__()
        self.feat_dim = feat_dim
        self.embed_dim = embed_dim
        # Frequency axis is downsampled 16x (2*2*2*2) before pooling.
        self.stats_dim = int(feat_dim / 8 / 2) * dims[-1]
        self.two_emb_layer = two_emb_layer

        self.downsample_layers = nn.ModuleList()
        stem = nn.Sequential(
            nn.Conv2d(1, dims[0], kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(dims[0]),
            nn.ReLU()
        )
        self.downsample_layers.append(stem)

        stride_f = [2, 2, 2, 2]  # frequency stride per stage
        stride_t = [1, 2, 1, 1]  # time stride per stage (T14c)

        for i in range(4):
            downsample_layer = nn.Sequential(
                nn.Conv2d(
                    dims[i], dims[i + 1], kernel_size=3,
                    stride=(stride_f[i], stride_t[i]),
                    padding=1, bias=False),
                nn.BatchNorm2d(dims[i + 1])
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList()
        for i in range(4):
            stage = nn.Sequential(
                *[Inverted_Bottleneck(dim=dims[i + 1]) for _ in range(depths[i])]
            )
            self.stages.append(stage)

        self.pool = getattr(pooling_layers,
                            pooling_func)(in_dim=self.stats_dim)
        self.pool_out_dim = self.pool.get_out_dim()
        self.seg_1 = nn.Linear(self.pool_out_dim, embed_dim)
        if self.two_emb_layer:
            self.seg_bn_1 = nn.BatchNorm1d(embed_dim, affine=False)
            self.seg_2 = nn.Linear(embed_dim, embed_dim)
        else:
            self.seg_bn_1 = nn.Identity()
            self.seg_2 = nn.Identity()

    def forward(self, x):
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
        x = x.unsqueeze(1)      # (B,F,T) => (B,1,F,T)
        out = self.downsample_layers[0](x)  # stem, stride 1
        out = self.downsample_layers[1](out)
        out = self.stages[0](out)
        out = self.downsample_layers[2](out)
        out = self.stages[1](out)
        out = self.downsample_layers[3](out)
        out = self.stages[2](out)
        out = self.downsample_layers[4](out)
        out = self.stages[3](out)           # (B, dims[-1], F/16, T/2)

        stats = self.pool(out)

        embed_a = self.seg_1(stats)
        if self.two_emb_layer:
            out = F.relu(embed_a)
            out = self.seg_bn_1(out)
            embed_b = self.seg_2(out)
            return embed_a, embed_b
        else:
            return torch.tensor(0.0), embed_a


# The following constructors include the separate downsampling layers in the
# layer count: e.g. Gemini_DF_ResNet60 has (3+3+9+3) blocks * 3 convs
# + 1 stem conv + 4 downsampling convs + 1 embedding layer = 60 layers.
def Gemini_DF_ResNet60(feat_dim, embed_dim, pooling_func='TSTP', two_emb_layer=False):
    return Gemini_DF_ResNet(depths=[3, 3, 9, 3],
                            dims=[32, 32, 64, 128, 256],
                            feat_dim=feat_dim,
                            embed_dim=embed_dim,
                            pooling_func=pooling_func,
                            two_emb_layer=two_emb_layer)


def Gemini_DF_ResNet114(feat_dim, embed_dim, pooling_func='TSTP', two_emb_layer=False):
    return Gemini_DF_ResNet(depths=[3, 3, 27, 3],
                            dims=[32, 32, 64, 128, 256],
                            feat_dim=feat_dim,
                            embed_dim=embed_dim,
                            pooling_func=pooling_func,
                            two_emb_layer=two_emb_layer)


def Gemini_DF_ResNet183(feat_dim, embed_dim, pooling_func='TSTP', two_emb_layer=False):
    return Gemini_DF_ResNet(depths=[3, 8, 45, 3],
                            dims=[32, 32, 64, 128, 256],
                            feat_dim=feat_dim,
                            embed_dim=embed_dim,
                            pooling_func=pooling_func,
                            two_emb_layer=two_emb_layer)


def Gemini_DF_ResNet237(feat_dim, embed_dim, pooling_func='TSTP', two_emb_layer=False):  # not used
    return Gemini_DF_ResNet(depths=[3, 8, 63, 3],
                            dims=[32, 32, 64, 128, 256],
                            feat_dim=feat_dim,
                            embed_dim=embed_dim,
                            pooling_func=pooling_func,
                            two_emb_layer=two_emb_layer)


if __name__ == '__main__':
    x = torch.zeros(1, 200, 80)
    model = Gemini_DF_ResNet183(80, 256, 'TSTP')
    model.eval()
    out = model(x)
    print(out[-1].size())

    num_params = sum(p.numel() for p in model.parameters())
    print("{} M".format(num_params / 1e6))
3 changes: 3 additions & 0 deletions wespeaker/models/speaker_model.py
@@ -18,6 +18,7 @@
import wespeaker.models.repvgg as repvgg
import wespeaker.models.campplus as campplus
import wespeaker.models.eres2net as eres2net
import wespeaker.models.gemini_dfresnet as gemini
import wespeaker.models.res2net as res2net


@@ -36,6 +37,8 @@ def get_speaker_model(model_name: str):
        return getattr(eres2net, model_name)
    elif model_name.startswith("Res2Net"):
        return getattr(res2net, model_name)
    elif model_name.startswith("Gemini"):
        return getattr(gemini, model_name)
    else:  # model_name error !!!
        print(model_name + " not found !!!")
        exit(1)
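With this dispatch in place, any model name with the "Gemini" prefix resolves to a constructor in gemini_dfresnet.py via getattr. A hedged usage sketch (arguments mirror the config above):

from wespeaker.models.speaker_model import get_speaker_model

# Any of the four variants can be selected purely by name.
model = get_speaker_model('Gemini_DF_ResNet114')(feat_dim=80,
                                                 embed_dim=256,
                                                 pooling_func='TSTP',
                                                 two_emb_layer=False)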
