From 8a0c923f58b5cb7f3e44206f24740eaf710d4275 Mon Sep 17 00:00:00 2001 From: Shuai Wang Date: Thu, 28 Mar 2024 00:15:49 +0800 Subject: [PATCH] [feature] add support for gemini-dfresnet --- .../v2/conf/gemini_dfresnet_adam.yaml | 81 ++++++++ wespeaker/models/gemini_dfresnet.py | 174 ++++++++++++++++++ wespeaker/models/speaker_model.py | 3 + 3 files changed, 258 insertions(+) create mode 100644 examples/voxceleb/v2/conf/gemini_dfresnet_adam.yaml create mode 100644 wespeaker/models/gemini_dfresnet.py diff --git a/examples/voxceleb/v2/conf/gemini_dfresnet_adam.yaml b/examples/voxceleb/v2/conf/gemini_dfresnet_adam.yaml new file mode 100644 index 00000000..a9e4180b --- /dev/null +++ b/examples/voxceleb/v2/conf/gemini_dfresnet_adam.yaml @@ -0,0 +1,81 @@ +### train configuration + +exp_dir: exp/Gemini_DF_ResNet60-TSTP-emb256-fbank80-num_frms200-aug0.6-spTrue-saFalse-ArcMargin-SGD-epoch150 +gpus: "[0,1]" +num_avg: 2 +enable_amp: False # whether to enable automatic mixed precision training + +seed: 42 +num_epochs: 165 +save_epoch_interval: 5 # save model every 5 epochs +log_batch_interval: 100 # log every 100 batches + +dataloader_args: + batch_size: 128 + num_workers: 8 + pin_memory: False + prefetch_factor: 8 + drop_last: True + +dataset_args: + # the sample number which will be traversed within one epoch, if the value equals 0, + # the utterance number in the dataset will be used as the sample_num_per_epoch. 
+ sample_num_per_epoch: 0 + shuffle: True + shuffle_args: + shuffle_size: 2500 + filter: True + filter_args: + min_num_frames: 100 + max_num_frames: 800 + resample_rate: 16000 + speed_perturb: True + num_frms: 200 + aug_prob: 0.6 # prob to add reverb & noise aug per sample + fbank_args: + num_mel_bins: 80 + frame_shift: 10 + frame_length: 25 + dither: 1.0 + spec_aug: False + spec_aug_args: + num_t_mask: 1 + num_f_mask: 1 + max_t: 10 + max_f: 8 + prob: 0.6 + +model: Gemini_DF_ResNet60 # Gemini_DF_ResNet60 Gemini_DF_ResNet114 Gemini_DF_ResNet183 Gemini_DF_ResNet237 +model_init: null +model_args: + feat_dim: 80 + embed_dim: 256 + pooling_func: "TSTP" # TSTP, ASTP, MQMHASTP + two_emb_layer: False +projection_args: + project_type: "arc_margin" # add_margin, arc_margin, sphere, sphereface2, softmax, arc_margin_intertopk_subcenter + scale: 32.0 + easy_margin: False + +margin_scheduler: MarginScheduler +margin_update: + initial_margin: 0.2 + final_margin: 0.2 + increase_start_epoch: 20 + fix_start_epoch: 40 + update_margin: False + increase_type: "exp" # exp, linear + +loss: CrossEntropyLoss +loss_args: {} + +optimizer: AdamW +optimizer_args: + weight_decay: 0.05 + +scheduler: ExponentialDecrease +scheduler_args: + initial_lr: 0.000125 + final_lr: 0.000001 + warm_up_epoch: 0 + warm_from_zero: False diff --git a/wespeaker/models/gemini_dfresnet.py b/wespeaker/models/gemini_dfresnet.py new file mode 100644 index 00000000..02af6473 --- /dev/null +++ b/wespeaker/models/gemini_dfresnet.py @@ -0,0 +1,174 @@ +# Copyright (c) 2024 Shuai Wang (wsstriving@gmail.com) +# 2024 Tianchi Liu (tianchi_liu@u.nus.edu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+'''The implementation of Gemini-DF-ResNet.
+
+Reference:
+[1] Liu, Tianchi, et al. "Golden Gemini is All You Need: Finding the
+    Sweet Spots for Speaker Verification." arXiv:2312.03620 (2023).
+[2] Liu, Bei, et al. "DF-ResNet: Boosting Speaker Verification Performance
+    with Depth-First Design." INTERSPEECH. 2022.
+'''
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import wespeaker.models.pooling_layers as pooling_layers
+
+
+class Inverted_Bottleneck(nn.Module):
+    """Residual block: 1x1 expand (4x), 3x3 depthwise conv, 1x1 project."""
+
+    def __init__(self, dim):
+        super().__init__()
+        hidden = 4 * dim
+        self.conv1 = nn.Conv2d(dim, hidden, kernel_size=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(hidden)
+        # groups == channels makes the 3x3 convolution depthwise
+        self.conv2 = nn.Conv2d(hidden, hidden, kernel_size=3, padding=1,
+                               groups=hidden, bias=False)
+        self.bn2 = nn.BatchNorm2d(hidden)
+        self.conv3 = nn.Conv2d(hidden, dim, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(dim)
+
+    def forward(self, x):
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = F.relu(self.bn2(self.conv2(out)))
+        out = self.bn3(self.conv3(out))
+        out += x  # identity shortcut
+        return F.relu(out)
+
+
+class Gemini_DF_ResNet(nn.Module):
+    """DF-ResNet with the T14c stride strategy of Golden Gemini.
+
+    Args:
+        depths: number of Inverted_Bottleneck blocks in each of the 4 stages.
+        dims: 5 channel widths: the stem, then one per downsampling layer.
+        feat_dim: size of the frequency axis of the input features.
+        embed_dim: size of the output embedding.
+        pooling_func: name of the pooling class in pooling_layers.
+        two_emb_layer: whether to add a second embedding layer (seg_2).
+    """
+
+    def __init__(self, depths=(3, 3, 9, 3), dims=(32, 32, 64, 128, 256),
+                 feat_dim=40, embed_dim=128, pooling_func='TSTP',
+                 two_emb_layer=False):
+        super().__init__()
+        stride_f = [2, 2, 2, 2]
+        stride_t = [1, 2, 1, 1]
+        # A 3x3 conv with padding 1 maps a length n to (n - 1) // stride + 1;
+        # track it so feat_dim does not have to be a multiple of 16.
+        freq_dim = feat_dim
+        for stride in stride_f:
+            freq_dim = (freq_dim - 1) // stride + 1
+
+        self.feat_dim = feat_dim
+        self.embed_dim = embed_dim
+        self.stats_dim = freq_dim * dims[-1]
+        self.two_emb_layer = two_emb_layer
+
+        stem = nn.Sequential(
+            nn.Conv2d(1, dims[0], kernel_size=3, stride=1, padding=1, bias=False),
+            nn.BatchNorm2d(dims[0]),
+            nn.ReLU())
+        self.downsample_layers = nn.ModuleList([stem])
+        for i in range(4):
+            self.downsample_layers.append(nn.Sequential(
+                nn.Conv2d(dims[i], dims[i + 1], kernel_size=3,
+                          stride=(stride_f[i], stride_t[i]),
+                          padding=1, bias=False),
+                nn.BatchNorm2d(dims[i + 1])))
+
+        self.stages = nn.ModuleList()
+        for i in range(4):
+            self.stages.append(nn.Sequential(
+                *[Inverted_Bottleneck(dim=dims[i + 1])
+                  for _ in range(depths[i])]))
+
+        self.pool = getattr(pooling_layers,
+                            pooling_func)(in_dim=self.stats_dim)
+        self.pool_out_dim = self.pool.get_out_dim()
+        self.seg_1 = nn.Linear(self.pool_out_dim, embed_dim)
+        if self.two_emb_layer:
+            self.seg_bn_1 = nn.BatchNorm1d(embed_dim, affine=False)
+            self.seg_2 = nn.Linear(embed_dim, embed_dim)
+        else:
+            self.seg_bn_1 = nn.Identity()
+            self.seg_2 = nn.Identity()
+
+    def forward(self, x):
+        """Embed features x of shape (B, T, F); returns (embed_a, embed_b) or,
+        without the second embedding layer, (torch.tensor(0.0), embed_a)."""
+        x = x.permute(0, 2, 1).unsqueeze(1)  # (B,T,F) => (B,1,F,T)
+        out = self.downsample_layers[0](x)
+        out = self.downsample_layers[1](out)
+        out = self.stages[0](out)
+        out = self.downsample_layers[2](out)
+        out = self.stages[1](out)
+        out = self.downsample_layers[3](out)
+        out = self.stages[2](out)
+        out = self.downsample_layers[4](out)
+        out = self.stages[3](out)
+
+        stats = self.pool(out)
+
+        embed_a = self.seg_1(stats)
+        if self.two_emb_layer:
+            out = F.relu(embed_a)
+            out = self.seg_bn_1(out)
+            embed_b = self.seg_2(out)
+            return embed_a, embed_b
+        else:
+            return torch.tensor(0.0), embed_a
+
+
+# The following model names include the separate downsampling layers in the
+# layer count (the number equals 3 * sum(depths) + 6).
+def Gemini_DF_ResNet60(feat_dim, embed_dim, pooling_func='TSTP', two_emb_layer=False):
+    return Gemini_DF_ResNet(depths=[3, 3, 9, 3], dims=[32, 32, 64, 128, 256],
+                            feat_dim=feat_dim, embed_dim=embed_dim,
+                            pooling_func=pooling_func,
+                            two_emb_layer=two_emb_layer)
+
+
+def Gemini_DF_ResNet114(feat_dim, embed_dim, pooling_func='TSTP', two_emb_layer=False):
+    return Gemini_DF_ResNet(depths=[3, 3, 27, 3], dims=[32, 32, 64, 128, 256],
+                            feat_dim=feat_dim, embed_dim=embed_dim,
+                            pooling_func=pooling_func,
+                            two_emb_layer=two_emb_layer)
+
+
+def Gemini_DF_ResNet183(feat_dim, embed_dim, pooling_func='TSTP', two_emb_layer=False):
+    return Gemini_DF_ResNet(depths=[3, 8, 45, 3], dims=[32, 32, 64, 128, 256],
+                            feat_dim=feat_dim, embed_dim=embed_dim,
+                            pooling_func=pooling_func,
+                            two_emb_layer=two_emb_layer)
+
+
+def Gemini_DF_ResNet237(feat_dim, embed_dim, pooling_func='TSTP', two_emb_layer=False):  # not used
+    return Gemini_DF_ResNet(depths=[3, 8, 63, 3], dims=[32, 32, 64, 128, 256],
+                            feat_dim=feat_dim, embed_dim=embed_dim,
+                            pooling_func=pooling_func,
+                            two_emb_layer=two_emb_layer)
+
+
+if __name__ == '__main__':
+    x = torch.zeros(1, 200, 80)
+    model = Gemini_DF_ResNet183(80, 256, 'TSTP')
+    model.eval()
+    out = model(x)
+    print(out[-1].size())
+
+    num_params = sum(p.numel() for p in model.parameters())
+    print("{} M".format(num_params / 1e6))
diff --git a/wespeaker/models/speaker_model.py b/wespeaker/models/speaker_model.py
index 70a6bc1d..8475f1ae 100644
--- a/wespeaker/models/speaker_model.py
+++ b/wespeaker/models/speaker_model.py
@@ -18,6 +18,7 @@
 import wespeaker.models.repvgg as repvgg
 import wespeaker.models.campplus as campplus
 import wespeaker.models.eres2net as eres2net
+import wespeaker.models.gemini_dfresnet as gemini
 import wespeaker.models.res2net as res2net
 
 
@@ -36,6 +37,8 @@ def get_speaker_model(model_name: str):
         return getattr(eres2net, model_name)
     elif model_name.startswith("Res2Net"):
         return getattr(res2net, model_name)
+    elif model_name.startswith("Gemini"):
+        return getattr(gemini, model_name)
     else:  # model_name error !!!
         print(model_name + " not found !!!")
         exit(1)