Merge pull request #204 from glynpu/espnet_scaling_and_layernorm
espnet-style attn_output_weight scaling and extra after-norm layer
danpovey authored Jun 4, 2021
2 parents 2a9fe04 + 4ce43f8 commit f863026
Showing 3 changed files with 29 additions and 11 deletions.
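In short: the new `is_espnet_structure` flag, threaded down from the two LibriSpeech scripts, makes the relative-position self-attention apply the 1/sqrt(head_dim) factor to the assembled attention logits instead of pre-scaling the query, and, when `normalize_before=True`, adds one extra `nn.LayerNorm` over the encoder output, matching ESPnet's conformer. A condensed sketch of the attention-side toggle (a standalone function with assumed names, not code from the patch):

# Sketch only; argument names are assumed, mirroring the diff further below.
def attention_logits(matrix_ac, matrix_bd, scaling, is_espnet_structure):
    if not is_espnet_structure:
        # Default path: q was already multiplied by `scaling` before
        # matrix_ac and matrix_bd were formed, so the sum is used as-is.
        return matrix_ac + matrix_bd
    # ESPnet path: q stays unscaled; scale the summed logits instead.
    return (matrix_ac + matrix_bd) * scaling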
5 changes: 3 additions & 2 deletions egs/librispeech/asr/simple_v1/mmi_att_transformer_decode.py
@@ -242,7 +242,7 @@ def main():
 
     output_beam_size = args.output_beam_size
 
-    exp_dir = Path('exp-' + model_type + '-noam-mmi-att-musan-sa-vgg')
+    exp_dir = Path('exp-' + model_type + '-mmi-att-sa-vgg-normlayer')
     setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug')
 
     logging.info(f'output_beam_size: {output_beam_size}')
@@ -285,7 +285,8 @@ def main():
             num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
             subsampling_factor=4,
             num_decoder_layers=num_decoder_layers,
-            vgg_frontend=True)
+            vgg_frontend=True,
+            is_espnet_structure=True)
     elif model_type == "contextnet":
         model = ContextNet(
             num_features=80,
5 changes: 3 additions & 2 deletions egs/librispeech/asr/simple_v1/mmi_att_transformer_train.py
@@ -465,7 +465,7 @@ def run(rank, world_size, args):
     fix_random_seed(42)
     setup_dist(rank, world_size, args.master_port)
 
-    exp_dir = Path('exp-' + model_type + '-noam-mmi-att-musan-sa-vgg')
+    exp_dir = Path('exp-' + model_type + '-mmi-att-sa-vgg-normlayer')
     setup_logger(f'{exp_dir}/log/log-train-{rank}')
     if args.tensorboard and rank == 0:
         tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard')
@@ -526,7 +526,8 @@ def run(rank, world_size, args):
             num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
             subsampling_factor=4,
             num_decoder_layers=num_decoder_layers,
-            vgg_frontend=True)
+            vgg_frontend=True,
+            is_espnet_structure=True)
     elif model_type == "contextnet":
         model = ContextNet(
             num_features=80,
30 changes: 23 additions & 7 deletions snowfall/models/conformer.py
@@ -36,16 +36,21 @@ def __init__(self, num_features: int, num_classes: int, subsampling_factor: int
                  d_model: int = 256, nhead: int = 4, dim_feedforward: int = 2048,
                  num_encoder_layers: int = 12, num_decoder_layers: int = 6,
                  dropout: float = 0.1, cnn_module_kernel: int = 31,
-                 normalize_before: bool = True, vgg_frontend: bool = False) -> None:
+                 normalize_before: bool = True, vgg_frontend: bool = False,
+                 is_espnet_structure: bool = False) -> None:
         super(Conformer, self).__init__(num_features=num_features, num_classes=num_classes, subsampling_factor=subsampling_factor,
                                         d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward,
                                         num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers,
                                         dropout=dropout, normalize_before=normalize_before, vgg_frontend=vgg_frontend)
 
         self.encoder_pos = RelPositionalEncoding(d_model, dropout)
 
-        encoder_layer = ConformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, cnn_module_kernel, normalize_before)
+        encoder_layer = ConformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, cnn_module_kernel, normalize_before, is_espnet_structure)
         self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
+        self.normalize_before = normalize_before
+        self.is_espnet_structure = is_espnet_structure
+        if self.normalize_before and self.is_espnet_structure:
+            self.after_norm = nn.LayerNorm(d_model)
 
     def encode(self, x: Tensor, supervisions: Optional[Dict] = None) -> Tuple[Tensor, Optional[Tensor]]:
         """
@@ -66,6 +71,9 @@ def encode(self, x: Tensor, supervisions: Optional[Dict] = None) -> Tuple[Tensor, Optional[Tensor]]:
         mask = mask.to(x.device) if mask != None else None
         x = self.encoder(x, pos_emb, src_key_padding_mask=mask)  # (T, B, F)
 
+        if self.normalize_before and self.is_espnet_structure:
+            x = self.after_norm(x)
+
         return x, mask
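
The motivation for `after_norm`: with pre-norm (`normalize_before=True`), every sub-layer normalizes its input, so the residual stream returned by the last encoder layer is itself never normalized; ESPnet closes that gap with one final LayerNorm, which is what the added branch in `encode()` does. A minimal toy sketch of the pattern (hypothetical module, with plain Linear blocks standing in for conformer layers):

import torch
import torch.nn as nn

class PreNormStackSketch(nn.Module):
    """Toy pre-norm stack; only illustrates where the after-norm sits."""
    def __init__(self, d_model: int = 256, num_layers: int = 2) -> None:
        super().__init__()
        self.norms = nn.ModuleList(nn.LayerNorm(d_model) for _ in range(num_layers))
        self.ffs = nn.ModuleList(nn.Linear(d_model, d_model) for _ in range(num_layers))
        self.after_norm = nn.LayerNorm(d_model)  # the layer this commit adds

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for norm, ff in zip(self.norms, self.ffs):
            x = x + ff(norm(x))  # pre-norm residual: normalize input, not output
        # Without this, the output would be an un-normalized residual sum.
        return self.after_norm(x)

out = PreNormStackSketch()(torch.randn(10, 4, 256))  # (T, B, F), as in encode()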


@@ -90,9 +98,10 @@ class ConformerEncoderLayer(nn.Module):
     """
 
     def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
-                 cnn_module_kernel: int = 31, normalize_before: bool = True) -> None:
+                 cnn_module_kernel: int = 31, normalize_before: bool = True,
+                 is_espnet_structure: bool = False) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
+        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0, is_espnet_structure=is_espnet_structure)
 
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -319,7 +328,8 @@ class RelPositionMultiheadAttention(nn.Module):
         >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb)
     """
 
-    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.) -> None:
+    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.,
+                 is_espnet_structure: bool = False) -> None:
         super(RelPositionMultiheadAttention, self).__init__()
         self.embed_dim = embed_dim
         self.num_heads = num_heads
@@ -339,6 +349,8 @@ def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.) -> None:
 
         self._reset_parameters()
 
+        self.is_espnet_structure = is_espnet_structure
+
     def _reset_parameters(self) -> None:
         nn.init.xavier_uniform_(self.in_proj.weight)
         nn.init.constant_(self.in_proj.bias, 0.)
@@ -538,7 +550,8 @@ def multi_head_attention_forward(self, query: Tensor,
                 _b = _b[_start:]
             v = nn.functional.linear(value, _w, _b)
 
-        q = q * scaling
+        if not self.is_espnet_structure:
+            q = q * scaling
 
         if attn_mask is not None:
             assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \
@@ -596,7 +609,10 @@ def multi_head_attention_forward(self, query: Tensor,
         matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))  # (batch, head, time1, 2*time1-1)
         matrix_bd = self.rel_shift(matrix_bd)
 
-        attn_output_weights = (matrix_ac + matrix_bd)  # (batch, head, time1, time2)
+        if not self.is_espnet_structure:
+            attn_output_weights = (matrix_ac + matrix_bd)  # (batch, head, time1, time2)
+        else:
+            attn_output_weights = (matrix_ac + matrix_bd) * scaling  # (batch, head, time1, time2)
 
         attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
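
Why this is a real change and not a refactor: the logits here include the learned position-bias terms (`pos_bias_u`, `pos_bias_v` added to `q` upstream), so pre-scaling `q` leaves those bias contributions unscaled, whereas scaling `(matrix_ac + matrix_bd)` scales them as well. A toy demonstration for the `matrix_ac` term (shapes and names illustrative only, not the patch's tensors):

import torch

torch.manual_seed(0)
head_dim = 64
scaling = float(head_dim) ** -0.5
q = torch.randn(2, 4, 10, head_dim)          # (batch, head, time1, head_dim)
k = torch.randn(2, 4, 10, head_dim)          # (batch, head, time2, head_dim)
pos_bias_u = torch.randn(4, head_dim)        # learned bias added to q
bias = pos_bias_u.unsqueeze(0).unsqueeze(2)  # broadcasts as (1, head, 1, head_dim)

# is_espnet_structure=False: q is scaled before the bias is added.
ac_default = torch.matmul(q * scaling + bias, k.transpose(-2, -1))

# is_espnet_structure=True: the summed logits are scaled afterwards.
ac_espnet = torch.matmul(q + bias, k.transpose(-2, -1)) * scaling

# The two differ by (1 - scaling) * (bias @ k^T), so the softmax weights,
# and hence the trained model, differ as well.
print(torch.allclose(ac_default, ac_espnet))  # False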

