Skip to content

Commit

Permalink
[models] update frame-level feature extraction interface (#367)
Browse files Browse the repository at this point in the history
  • Loading branch information
czy97 authored Sep 25, 2024
1 parent 3ccc791 commit ecd36be
Show file tree
Hide file tree
Showing 8 changed files with 24 additions and 24 deletions.
6 changes: 3 additions & 3 deletions wespeaker/models/ecapa_tdnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def __init__(self,
else:
self.bn2 = nn.Identity()

def __get_frame_level_feat(self, x):
def _get_frame_level_feat(self, x):
# for inner class usage
x = x.permute(0, 2, 1) # (B,T,F) -> (B,F,T)

Expand All @@ -221,11 +221,11 @@ def __get_frame_level_feat(self, x):

def get_frame_level_feat(self, x):
# for outer interface
out = self.__get_frame_level_feat(x).permute(0, 2, 1)
out = self._get_frame_level_feat(x).permute(0, 2, 1)
return out # (B, T, D)

def forward(self, x):
out = F.relu(self.__get_frame_level_feat(x))
out = F.relu(self._get_frame_level_feat(x))
out = self.bn(self.pool(out))
out = self.linear(out)
if self.emb_bn:
Expand Down
6 changes: 3 additions & 3 deletions wespeaker/models/eres2net.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def _make_layer(self,
self.in_planes = planes * self.expansion
return nn.Sequential(*layers)

def __get_frame_level_feat(self, x):
def _get_frame_level_feat(self, x):
# for inner class usage
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
x = x.unsqueeze_(1)
Expand All @@ -371,14 +371,14 @@ def __get_frame_level_feat(self, x):

def get_frame_level_feat(self, x):
# for outer interface
out = self.__get_frame_level_feat(x)
out = self._get_frame_level_feat(x)
out = out.transpose(1, 3)
out = torch.flatten(out, 2, -1)

return out # (B, T, D)

def forward(self, x):
fuse_out1234 = self.__get_frame_level_feat(x)
fuse_out1234 = self._get_frame_level_feat(x)
stats = self.pool(fuse_out1234)

embed_a = self.seg_1(stats)
Expand Down
6 changes: 3 additions & 3 deletions wespeaker/models/gemini_dfresnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def __init__(self,
self.seg_bn_1 = nn.Identity()
self.seg_2 = nn.Identity()

def __get_frame_level_feat(self, x):
def _get_frame_level_feat(self, x):
# for inner class usage
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
x = x.unsqueeze_(1)
Expand All @@ -120,15 +120,15 @@ def __get_frame_level_feat(self, x):

def get_frame_level_feat(self, x):
# for outer interface
out = self.__get_frame_level_feat(x)
out = self._get_frame_level_feat(x)
out = out.transpose(1, 3)
out = torch.flatten(out, 2, -1)

return out # (B, T, D)

def forward(self, x):

out = self.__get_frame_level_feat(x)
out = self._get_frame_level_feat(x)
stats = self.pool(out)

embed_a = self.seg_1(stats)
Expand Down
6 changes: 3 additions & 3 deletions wespeaker/models/redimnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,7 +843,7 @@ def __init__(
self.seg_bn_1 = nn.Identity()
self.seg_2 = nn.Identity()

def __get_frame_level_feat(self, x):
def _get_frame_level_feat(self, x):
# for inner class usage
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
x = x.unsqueeze_(1)
Expand All @@ -853,12 +853,12 @@ def __get_frame_level_feat(self, x):

def get_frame_level_feat(self, x):
# for outer interface
out = self.__get_frame_level_feat(x).permute(0, 2, 1)
out = self._get_frame_level_feat(x).permute(0, 2, 1)

return out # (B, T, D)

def forward(self, x):
out = self.__get_frame_level_feat(x)
out = self._get_frame_level_feat(x)

stats = self.pool(out)
embed_a = self.seg_1(stats)
Expand Down
6 changes: 3 additions & 3 deletions wespeaker/models/repvgg.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@ def get_downsample_multiple(self):
def get_output_planes(self):
return self.output_planes

def __get_frame_level_feat(self, x):
def _get_frame_level_feat(self, x):
# for inner class usage
x = x.permute(0, 2, 1) # (B,T,F) -> (B,F,T)
x = x.unsqueeze_(1)
Expand All @@ -573,14 +573,14 @@ def __get_frame_level_feat(self, x):

def get_frame_level_feat(self, x):
# for outer interface
out = self.__get_frame_level_feat(x)
out = self._get_frame_level_feat(x)
out = out.transpose(1, 3)
out = torch.flatten(out, 2, -1)

return out # (B, T, D)

def forward(self, x):
x = self.__get_frame_level_feat(x)
x = self._get_frame_level_feat(x)
stats = self.pool(x)
embed = self.seg(stats)

Expand Down
6 changes: 3 additions & 3 deletions wespeaker/models/res2net.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def _make_layer(self, block, planes, num_blocks, stride):
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)

def __get_frame_level_feat(self, x):
def _get_frame_level_feat(self, x):
# for inner class usage
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)

Expand All @@ -169,14 +169,14 @@ def __get_frame_level_feat(self, x):

def get_frame_level_feat(self, x):
# for outer interface
out = self.__get_frame_level_feat(x)
out = self._get_frame_level_feat(x)
out = out.transpose(1, 3)
out = torch.flatten(out, 2, -1)

return out # (B, T, D)

def forward(self, x):
out = self.__get_frame_level_feat(x)
out = self._get_frame_level_feat(x)
stats = self.pool(out)

embed_a = self.seg_1(stats)
Expand Down
6 changes: 3 additions & 3 deletions wespeaker/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def _make_layer(self, block, planes, num_blocks, stride):
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)

def __get_frame_level_feat(self, x):
def _get_frame_level_feat(self, x):
# for inner class usage
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)

Expand All @@ -183,14 +183,14 @@ def __get_frame_level_feat(self, x):

def get_frame_level_feat(self, x):
# for outer interface
out = self.__get_frame_level_feat(x)
out = self._get_frame_level_feat(x)
out = out.transpose(1, 3)
out = torch.flatten(out, 2, -1)

return out # (B, T, D)

def forward(self, x):
out = self.__get_frame_level_feat(x)
out = self._get_frame_level_feat(x)

stats = self.pool(out)

Expand Down
6 changes: 3 additions & 3 deletions wespeaker/models/tdnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __init__(self,
self.seg_bn_1 = nn.BatchNorm1d(embed_dim, affine=False)
self.seg_2 = nn.Linear(embed_dim, embed_dim)

def __get_frame_level_feat(self, x):
def _get_frame_level_feat(self, x):
# for inner class usage
x = x.permute(0, 2, 1) # (B,T,F) -> (B,F,T)

Expand All @@ -100,12 +100,12 @@ def __get_frame_level_feat(self, x):

def get_frame_level_feat(self, x):
# for outer interface
out = self.__get_frame_level_feat(x).permute(0, 2, 1)
out = self._get_frame_level_feat(x).permute(0, 2, 1)

return out # (B, T, D)

def forward(self, x):
out = self.__get_frame_level_feat(x)
out = self._get_frame_level_feat(x)
stats = self.pool(out)
embed_a = self.seg_1(stats)
out = F.relu(embed_a)
Expand Down

0 comments on commit ecd36be

Please sign in to comment.