diff --git a/egs/librispeech/asr/simple_v1/mmi_att_transformer_decode.py b/egs/librispeech/asr/simple_v1/mmi_att_transformer_decode.py
index 89d5a506..a616bca9 100755
--- a/egs/librispeech/asr/simple_v1/mmi_att_transformer_decode.py
+++ b/egs/librispeech/asr/simple_v1/mmi_att_transformer_decode.py
@@ -242,7 +242,7 @@ def main():
 
     output_beam_size = args.output_beam_size
 
-    exp_dir = Path('exp-' + model_type + '-noam-mmi-att-musan-sa-vgg')
+    exp_dir = Path('exp-' + model_type + '-mmi-att-sa-vgg-normlayer')
     setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug')
 
     logging.info(f'output_beam_size: {output_beam_size}')
@@ -285,7 +285,8 @@ def main():
             num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
             subsampling_factor=4,
             num_decoder_layers=num_decoder_layers,
-            vgg_frontend=True)
+            vgg_frontend=True,
+            is_espnet_structure=True)
     elif model_type == "contextnet":
         model = ContextNet(
             num_features=80,
diff --git a/egs/librispeech/asr/simple_v1/mmi_att_transformer_train.py b/egs/librispeech/asr/simple_v1/mmi_att_transformer_train.py
index cb1f236e..f1ab15b9 100755
--- a/egs/librispeech/asr/simple_v1/mmi_att_transformer_train.py
+++ b/egs/librispeech/asr/simple_v1/mmi_att_transformer_train.py
@@ -465,7 +465,7 @@ def run(rank, world_size, args):
     fix_random_seed(42)
     setup_dist(rank, world_size, args.master_port)
 
-    exp_dir = Path('exp-' + model_type + '-noam-mmi-att-musan-sa-vgg')
+    exp_dir = Path('exp-' + model_type + '-mmi-att-sa-vgg-normlayer')
     setup_logger(f'{exp_dir}/log/log-train-{rank}')
     if args.tensorboard and rank == 0:
         tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard')
@@ -526,7 +526,8 @@ def run(rank, world_size, args):
             num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
             subsampling_factor=4,
             num_decoder_layers=num_decoder_layers,
-            vgg_frontend=True)
+            vgg_frontend=True,
+            is_espnet_structure=True)
     elif model_type == "contextnet":
         model = ContextNet(
             num_features=80,
diff --git a/snowfall/models/conformer.py b/snowfall/models/conformer.py
index 4bf921b2..faf52b45 100644
--- a/snowfall/models/conformer.py
+++ b/snowfall/models/conformer.py
@@ -36,7 +36,8 @@ def __init__(self, num_features: int, num_classes: int, subsampling_factor: int
                  d_model: int = 256, nhead: int = 4, dim_feedforward: int = 2048,
                  num_encoder_layers: int = 12, num_decoder_layers: int = 6,
                  dropout: float = 0.1, cnn_module_kernel: int = 31,
-                 normalize_before: bool = True, vgg_frontend: bool = False) -> None:
+                 normalize_before: bool = True, vgg_frontend: bool = False,
+                 is_espnet_structure: bool = False) -> None:
         super(Conformer, self).__init__(num_features=num_features, num_classes=num_classes, subsampling_factor=subsampling_factor,
                                         d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, num_encoder_layers=num_encoder_layers,
                                         num_decoder_layers=num_decoder_layers,
@@ -44,8 +45,12 @@ def __init__(self, num_features: int, num_classes: int, subsampling_factor: int
 
         self.encoder_pos = RelPositionalEncoding(d_model, dropout)
 
-        encoder_layer = ConformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, cnn_module_kernel, normalize_before)
+        encoder_layer = ConformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, cnn_module_kernel, normalize_before, is_espnet_structure)
         self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
+        self.normalize_before = normalize_before
+        self.is_espnet_structure = is_espnet_structure
+        if self.normalize_before and self.is_espnet_structure:
+            self.after_norm = nn.LayerNorm(d_model)
 
     def encode(self, x: Tensor, supervisions: Optional[Dict] = None) -> Tuple[Tensor, Optional[Tensor]]:
         """
@@ -66,6 +71,9 @@ def encode(self, x: Tensor, supervisions: Optional[Dict] = None) -> Tuple[Tensor
         mask = mask.to(x.device) if mask != None else None
         x = self.encoder(x, pos_emb, src_key_padding_mask=mask)  # (T, B, F)
 
+        if self.normalize_before and self.is_espnet_structure:
+            x = self.after_norm(x)
+
         return x, mask
 
 
@@ -90,9 +98,10 @@ class ConformerEncoderLayer(nn.Module):
     """
 
     def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
-                 cnn_module_kernel: int = 31, normalize_before: bool = True) -> None:
+                 cnn_module_kernel: int = 31, normalize_before: bool = True,
+                 is_espnet_structure: bool = False) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
+        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0, is_espnet_structure=is_espnet_structure)
 
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -319,7 +328,8 @@ class RelPositionMultiheadAttention(nn.Module):
         >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb)
     """
 
-    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.) -> None:
+    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.,
+                 is_espnet_structure: bool = False) -> None:
         super(RelPositionMultiheadAttention, self).__init__()
         self.embed_dim = embed_dim
         self.num_heads = num_heads
@@ -339,6 +349,8 @@ def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.) -> None:
 
         self._reset_parameters()
 
+        self.is_espnet_structure = is_espnet_structure
+
     def _reset_parameters(self) -> None:
         nn.init.xavier_uniform_(self.in_proj.weight)
         nn.init.constant_(self.in_proj.bias, 0.)
@@ -538,7 +550,8 @@ def multi_head_attention_forward(self, query: Tensor,
                 _b = _b[_start:]
             v = nn.functional.linear(value, _w, _b)
 
-        q = q * scaling
+        if not self.is_espnet_structure:
+            q = q * scaling
 
         if attn_mask is not None:
             assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \
@@ -596,7 +609,10 @@ def multi_head_attention_forward(self, query: Tensor,
         matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))  # (batch, head, time1, 2*time1-1)
         matrix_bd = self.rel_shift(matrix_bd)
 
-        attn_output_weights = (matrix_ac + matrix_bd)  # (batch, head, time1, time2)
+        if not self.is_espnet_structure:
+            attn_output_weights = (matrix_ac + matrix_bd)  # (batch, head, time1, time2)
+        else:
+            attn_output_weights = (matrix_ac + matrix_bd) * scaling  # (batch, head, time1, time2)
 
         attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
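
Note (illustration only, not part of the patch): the new is_espnet_structure flag changes two things in the Conformer. First, when normalize_before is set, the encoder output is additionally passed through a final nn.LayerNorm(d_model) (after_norm) at the end of Conformer.encode. Second, the 1/sqrt(head_dim) attention scaling moves from the query (q = q * scaling) to the summed relative-position scores ((matrix_ac + matrix_bd) * scaling), so the pos_bias_u / pos_bias_v contributions are scaled as well, which is presumably what the flag name refers to. The toy sketch below shows that difference with made-up shapes; it omits the head reshaping and the rel_shift step of the real RelPositionMultiheadAttention.

import torch

num_heads, head_dim, time1, time2 = 4, 64, 10, 10
d_model = num_heads * head_dim
scaling = float(head_dim) ** -0.5

# Toy stand-ins for the tensors used in RelPositionMultiheadAttention.
q = torch.randn(1, num_heads, time1, head_dim)   # query, already split into heads
k = torch.randn(1, num_heads, time2, head_dim)   # key
p = torch.randn(1, num_heads, time2, head_dim)   # positional term (real code: 2*time1-1 positions + rel_shift)
pos_bias_u = torch.randn(num_heads, head_dim)    # content bias (matrix_ac term)
pos_bias_v = torch.randn(num_heads, head_dim)    # position bias (matrix_bd term)


def attn_scores(q: torch.Tensor, is_espnet_structure: bool) -> torch.Tensor:
    if not is_espnet_structure:
        q = q * scaling                           # original path: scale the query only
    matrix_ac = torch.matmul(q + pos_bias_u.view(1, num_heads, 1, head_dim), k.transpose(-2, -1))
    matrix_bd = torch.matmul(q + pos_bias_v.view(1, num_heads, 1, head_dim), p.transpose(-2, -1))
    attn_output_weights = matrix_ac + matrix_bd   # (batch, head, time1, time2)
    if is_espnet_structure:
        attn_output_weights = attn_output_weights * scaling  # espnet path: scale the summed scores
    return attn_output_weights


# The two variants differ in whether the bias contributions are scaled by 1/sqrt(head_dim).
print(torch.allclose(attn_scores(q, False), attn_scores(q, True)))  # usually False

# The patch also applies a final LayerNorm to the encoder output when both
# normalize_before and is_espnet_structure are set (self.after_norm in Conformer).
after_norm = torch.nn.LayerNorm(d_model)
encoder_out = torch.randn(time1, 1, d_model)      # (T, B, F), as in Conformer.encode
print(after_norm(encoder_out).shape)              # torch.Size([10, 1, 256])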