Skip to content
This repository has been archived by the owner on Oct 13, 2022. It is now read-only.

espnet-style attn_output_weight scaling and extra after-norm layer #204

Merged
merged 3 commits into from
Jun 4, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion egs/librispeech/asr/simple_v1/mmi_att_transformer_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,8 @@ def main():
num_classes=len(phone_ids) + 1, # +1 for the blank symbol
subsampling_factor=4,
num_decoder_layers=num_decoder_layers,
vgg_frontend=True)
vgg_frontend=True,
is_espnet_structure=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should have this in training script too

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.. and it's better if you change the directory name, when changing the model structure.
you can remove a couple of older components of the filename, to stop it getting too long.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should have this in training script too

added.

.. and it's better if you change the directory name, when changing the model structure.
you can remove a couple of older components of the filename, to stop it getting too long.

-    -noam-mmi-att-musan-sa-vgg
+    -mmi-att-sa-vgg-normlayer

elif model_type == "contextnet":
model = ContextNet(
num_features=80,
Expand Down
30 changes: 23 additions & 7 deletions snowfall/models/conformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,21 @@ def __init__(self, num_features: int, num_classes: int, subsampling_factor: int
d_model: int = 256, nhead: int = 4, dim_feedforward: int = 2048,
num_encoder_layers: int = 12, num_decoder_layers: int = 6,
dropout: float = 0.1, cnn_module_kernel: int = 31,
normalize_before: bool = True, vgg_frontend: bool = False) -> None:
normalize_before: bool = True, vgg_frontend: bool = False,
is_espnet_structure: bool = False) -> None:
super(Conformer, self).__init__(num_features=num_features, num_classes=num_classes, subsampling_factor=subsampling_factor,
d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward,
num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers,
dropout=dropout, normalize_before=normalize_before, vgg_frontend=vgg_frontend)

self.encoder_pos = RelPositionalEncoding(d_model, dropout)

encoder_layer = ConformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, cnn_module_kernel, normalize_before)
encoder_layer = ConformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, cnn_module_kernel, normalize_before, is_espnet_structure)
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
self.normalize_before = normalize_before
self.is_espnet_structure = is_espnet_structure
if self.normalize_before and self.is_espnet_structure:
self.after_norm = nn.LayerNorm(d_model)

def encode(self, x: Tensor, supervisions: Optional[Dict] = None) -> Tuple[Tensor, Optional[Tensor]]:
"""
Expand All @@ -66,6 +71,9 @@ def encode(self, x: Tensor, supervisions: Optional[Dict] = None) -> Tuple[Tensor
mask = mask.to(x.device) if mask != None else None
x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F)

if self.normalize_before and self.is_espnet_structure:
x = self.after_norm(x)

return x, mask


Expand All @@ -90,9 +98,10 @@ class ConformerEncoderLayer(nn.Module):
"""

def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
cnn_module_kernel: int = 31, normalize_before: bool = True) -> None:
cnn_module_kernel: int = 31, normalize_before: bool = True,
is_espnet_structure: bool = False) -> None:
super(ConformerEncoderLayer, self).__init__()
self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0, is_espnet_structure=is_espnet_structure)

self.feed_forward = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
Expand Down Expand Up @@ -319,7 +328,8 @@ class RelPositionMultiheadAttention(nn.Module):
>>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb)
"""

def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.) -> None:
def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.,
is_espnet_structure: bool = False) -> None:
super(RelPositionMultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
Expand All @@ -339,6 +349,8 @@ def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.) -> None:

self._reset_parameters()

self.is_espnet_structure = is_espnet_structure

def _reset_parameters(self) -> None:
nn.init.xavier_uniform_(self.in_proj.weight)
nn.init.constant_(self.in_proj.bias, 0.)
Expand Down Expand Up @@ -538,7 +550,8 @@ def multi_head_attention_forward(self, query: Tensor,
_b = _b[_start:]
v = nn.functional.linear(value, _w, _b)

q = q * scaling
if not self.is_espnet_structure:
q = q * scaling

if attn_mask is not None:
assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \
Expand Down Expand Up @@ -596,7 +609,10 @@ def multi_head_attention_forward(self, query: Tensor,
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) # (batch, head, time1, 2*time1-1)
matrix_bd = self.rel_shift(matrix_bd)

attn_output_weights = (matrix_ac + matrix_bd) # (batch, head, time1, time2)
if not self.is_espnet_structure:
attn_output_weights = (matrix_ac + matrix_bd) # (batch, head, time1, time2)
else:
attn_output_weights = (matrix_ac + matrix_bd) * scaling # (batch, head, time1, time2)

attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)

Expand Down