Skip to content
This repository has been archived by the owner on Oct 13, 2022. It is now read-only.

n-best rescore with transformer lm #201

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
109 changes: 109 additions & 0 deletions egs/librispeech/asr/simple_v1/espnet_utils/asr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
#!/usr/bin/env python3

# Copyright 2021 Xiaomi Corporation (Author: Guo Liyong)
# Apache 2.0

import argparse
import logging
from typing import Tuple

import numpy as np
import torch

from espnet_utils.common import load_espnet_model_config
from espnet_utils.common import rename_state_dict, combine_qkv
from espnet_utils.frontened import Fbank
from espnet_utils.frontened import GlobalMVN
from espnet_utils.numericalizer import SpmNumericalizer
from snowfall.models.conformer import Conformer

_ESPNET_ENCODER_KEY_TO_SNOWFALL_KEY = [
('frontend.logmel.melmat', 'frontend.melmat'),
('encoder.embed.out.0.weight', 'encoder.embed.out.weight'),
('encoder.embed.out.0.bias', 'encoder.embed.out.bias'),
(r'(encoder.encoders.)(\d+)(.self_attn.)linear_out([\s\S*])',
r'\1\2\3out_proj\4'),
(r'(encoder.encoders.)(\d+)', r'\1layers.\2'),
(r'(encoder.encoders.layers.)(\d+)(.feed_forward.)(w_1)',
r'\1\2.feed_forward.0'),
(r'(encoder.encoders.layers.)(\d+)(.feed_forward.)(w_2)',
r'\1\2.feed_forward.3'),
(r'(encoder.encoders.layers.)(\d+)(.feed_forward_macaron.)(w_1)',
r'\1\2.feed_forward_macaron.0'),
(r'(encoder.encoders.layers.)(\d+)(.feed_forward_macaron.)(w_2)',
r'\1\2.feed_forward_macaron.3'),
(r'(encoder.embed.)([\s\S*])', r'encoder.encoder_embed.\2'),
(r'(encoder.encoders.)([\s\S*])', r'encoder.encoder.\2'),
(r'(ctc.ctc_lo.)([\s\S*])', r'encoder.encoder_output_layer.1.\2'),
]


class ESPnetASRModel(torch.nn.Module):

def __init__(
self,
frontend: None,
normalize: None,
encoder: None,
):

super().__init__()
self.frontend = frontend
self.normalize = normalize
self.encoder = encoder

def encode(
self, speech: torch.Tensor,
speech_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you mind adding doc describing the shape of various tensors?

feats, feats_lengths = self.frontend(speech, speech_lengths)

feats, feats_lengths = self.normalize(feats, feats_lengths)

feats = feats.permute(0, 2, 1)

nnet_output, _, _ = self.encoder(feats)
nnet_output = nnet_output.permute(2, 0, 1)
return nnet_output

@classmethod
def build_model(cls, asr_train_config, asr_model_file, device):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cls is never used.
I would suggest changing @classmethod to @staticmethod and removing cls.

args = load_espnet_model_config(asr_train_config)
# {'fs': '16k', 'hop_length': 256, 'n_fft': 512}
frontend = Fbank(**args.frontend_conf)
normalize = GlobalMVN(**args.normalize_conf)
encoder = Conformer(num_features=80,
num_classes=len(args.token_list),
subsampling_factor=4,
d_model=512,
nhead=8,
dim_feedforward=2048,
num_encoder_layers=12,
cnn_module_kernel=31,
num_decoder_layers=0,
is_espnet_structure=True)

model = ESPnetASRModel(
frontend=frontend,
normalize=normalize,
encoder=encoder,
)

state_dict = torch.load(asr_model_file, map_location=device)

state_dict = {
k: v for k, v in state_dict.items() if not k.startswith('decoder')
}

combine_qkv(state_dict, num_encoder_layers=11)
rename_state_dict(rename_patterns=_ESPNET_ENCODER_KEY_TO_SNOWFALL_KEY,
state_dict=state_dict)

model.load_state_dict(state_dict, strict=False)
model = model.to(torch.device(device))

numericalizer = SpmNumericalizer(tokenizer_type='spm',
tokenizer_file=args.bpemodel,
token_list=args.token_list,
unk_symbol='<unk>')
return model, numericalizer
56 changes: 56 additions & 0 deletions egs/librispeech/asr/simple_v1/espnet_utils/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#!/usr/bin/env python3

# Copyright 2021 Xiaomi Corporation (Author: Guo Liyong)
# Apache 2.0

import argparse
import re
import yaml

from typing import List, Tuple, Dict
from pathlib import Path

import torch


def load_espnet_model_config(config_file):
config_file = Path(config_file)
with config_file.open("r", encoding="utf-8") as f:
args = yaml.safe_load(f)
return argparse.Namespace(**args)


def rename_state_dict(rename_patterns: List[Tuple[str, str]],
state_dict: Dict[str, torch.Tensor]):
# Rename state dict to load espent model
if rename_patterns is not None:
for old_pattern, new_pattern in rename_patterns:
old_keys = [
k for k in state_dict if re.match(old_pattern, k) is not None
]
for k in old_keys:
v = state_dict.pop(k)
new_k = re.sub(old_pattern, new_pattern, k)
state_dict[new_k] = v


def combine_qkv(state_dict: Dict[str, torch.Tensor], num_encoder_layers=11):
for layer in range(num_encoder_layers + 1):
q_w = state_dict[f'encoder.encoders.{layer}.self_attn.linear_q.weight']
k_w = state_dict[f'encoder.encoders.{layer}.self_attn.linear_k.weight']
v_w = state_dict[f'encoder.encoders.{layer}.self_attn.linear_v.weight']
q_b = state_dict[f'encoder.encoders.{layer}.self_attn.linear_q.bias']
k_b = state_dict[f'encoder.encoders.{layer}.self_attn.linear_k.bias']
v_b = state_dict[f'encoder.encoders.{layer}.self_attn.linear_v.bias']

for param_type in ['weight', 'bias']:
for layer_type in ['q', 'k', 'v']:
key_to_remove = f'encoder.encoders.{layer}.self_attn.linear_{layer_type}.{param_type}'
state_dict.pop(key_to_remove)

in_proj_weight = torch.cat([q_w, k_w, v_w], dim=0)
in_proj_bias = torch.cat([q_b, k_b, v_b], dim=0)
key_weight = f'encoder.encoders.{layer}.self_attn.in_proj.weight'
state_dict[key_weight] = in_proj_weight
key_bias = f'encoder.encoders.{layer}.self_attn.in_proj.bias'
state_dict[key_bias] = in_proj_bias
229 changes: 229 additions & 0 deletions egs/librispeech/asr/simple_v1/espnet_utils/frontened.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
#!/usr/bin/env python3

# Copyright 2021 Xiaomi Corporation (Author: Guo Liyong)
# Apache 2.0

import humanfriendly
import librosa
import numpy as np
import torch

from pathlib import Path
from typeguard import check_argument_types
from typing import Optional, Tuple, Union


# Modified from:
# https://github.com/espnet/espnet/blob/08feae5bb93fa8f6dcba66760c8617a4b5e39d70/espnet/nets/pytorch_backend/frontends/feature_transform.py#L135
class GlobalMVN(torch.nn.Module):
"""Apply global mean and variance normalization

TODO(kamo): Make this class portable somehow

Args:
stats_file: npy file
norm_means: Apply mean normalization
norm_vars: Apply var normalization
eps:
"""

def __init__(
self,
stats_file: Union[Path, str],
norm_means: bool = True,
norm_vars: bool = True,
eps: float = 1.0e-20,
):
assert check_argument_types()
super().__init__()
self.norm_means = norm_means
self.norm_vars = norm_vars
self.eps = eps
stats_file = Path(stats_file)

self.stats_file = stats_file
stats = np.load(stats_file)
if isinstance(stats, np.ndarray):
# Kaldi like stats
count = stats[0].flatten()[-1]
mean = stats[0, :-1] / count
var = stats[1, :-1] / count - mean * mean
else:
# New style: Npz file
count = stats["count"]
sum_v = stats["sum"]
sum_square_v = stats["sum_square"]
mean = sum_v / count
var = sum_square_v / count - mean * mean
std = np.sqrt(np.maximum(var, eps))

self.register_buffer("mean", torch.from_numpy(mean))
self.register_buffer("std", torch.from_numpy(std))

def forward(
self,
x: torch.Tensor,
ilens: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward function

Args:
x: (B, L, ...)
ilens: (B,)
"""
if ilens is None:
ilens = x.new_full([x.size(0)], x.size(1))
norm_means = self.norm_means
norm_vars = self.norm_vars
self.mean = self.mean.to(x.device, x.dtype)
self.std = self.std.to(x.device, x.dtype)

# feat: (B, T, D)
if norm_means:
if x.requires_grad:
x = x - self.mean
else:
x -= self.mean

if norm_vars:
x /= self.std
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

norm_means uses a guard requires_grad to choose whether to perform an in-place update. Is there a reason not to do the same here?

The original implementation
https://github.com/espnet/espnet/blob/08feae5bb93fa8f6dcba66760c8617a4b5e39d70/espnet/nets/pytorch_backend/frontends/feature_transform.py#L135
uses self.scale to do a multiplication, which is more efficient than dividing by self.std.


return x, ilens


# Modified from:
# https://github.com/espnet/espnet/blob/08feae5bb93fa8f6dcba66760c8617a4b5e39d70/espnet2/layers/stft.py#L14:7
class Stft(torch.nn.Module):

def __init__(
self,
n_fft: int = 512,
win_length: int = None,
hop_length: int = 128,
window: Optional[str] = "hann",
center: bool = True,
normalized: bool = False,
onesided: bool = True,
):
super().__init__()
self.n_fft = n_fft
if win_length is None:
self.win_length = n_fft
else:
self.win_length = win_length
self.hop_length = hop_length
self.center = center
self.normalized = normalized
self.onesided = onesided
if window is not None and not hasattr(torch, f"{window}_window"):
raise ValueError(f"{window} window is not implemented")
self.window = window

def forward(
self,
input: torch.Tensor,
ilens: torch.Tensor = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""STFT forward function.

Args:
input: (Batch, Nsamples) or (Batch, Nsample, Channels)
ilens: (Batch)
Returns:
output: (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2)

"""
bs = input.size(0)

if self.window is not None:
window_func = getattr(torch, f"{self.window}_window")
window = window_func(self.win_length,
dtype=input.dtype,
device=input.device)
else:
window = None
output = torch.stft(
input,
n_fft=self.n_fft,
win_length=self.win_length,
hop_length=self.hop_length,
center=self.center,
window=window,
normalized=self.normalized,
onesided=self.onesided,
)
output = output.transpose(1, 2)

if self.center:
pad = self.win_length // 2
ilens = ilens + 2 * pad

olens = (ilens - self.win_length) // self.hop_length + 1

return output, olens


# Modified from:
# https://github.com/espnet/espnet/blob/08feae5bb93fa8f6dcba66760c8617a4b5e39d70/espnet2/asr/frontend/default.py#L19
class Fbank(torch.nn.Module):
"""

Stft -> Power-spec -> Mel-Fbank
"""

def __init__(
self,
fs: Union[int, str] = 16000,
n_fft: int = 512,
win_length: int = None,
hop_length: int = 128,
window: Optional[str] = "hann",
center: bool = True,
normalized: bool = False,
onesided: bool = True,
n_mels: int = 80,
fmin: int = None,
fmax: int = None,
):
super().__init__()
if isinstance(fs, str):
fs = humanfriendly.parse_size(fs)

self.stft = Stft(
n_fft=n_fft,
win_length=win_length,
hop_length=hop_length,
center=center,
window=window,
normalized=normalized,
onesided=onesided,
)

fmin = 0 if fmin is None else fmin
fmax = fs / 2 if fmax is None else fmax
_mel_options = dict(
sr=fs,
n_fft=n_fft,
n_mels=n_mels,
fmin=fmin,
fmax=fmax,
)

# _mel_options = {'sr': 16000, 'n_fft': 512, 'n_mels': 80, 'fmin': 0, 'fmax': 8000.0, 'htk': False}
melmat = librosa.filters.mel(**_mel_options)

self.register_buffer("melmat", torch.from_numpy(melmat.T).float())

def forward(
self, input: torch.Tensor,
input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
input_stft, feats_lens = self.stft(input, input_lengths)
input_stft = torch.complex(input_stft[..., 0], input_stft[..., 1])

input_power = input_stft.real**2 + input_stft.imag**2

mel_feat = torch.matmul(input_power, self.melmat)
mel_feat = torch.clamp(mel_feat, min=1e-10)

input_feats = mel_feat.log()

return input_feats, feats_lens
Loading