This repository has been archived by the owner on Oct 13, 2022. It is now read-only.

n-best rescore with transformer lm #201

Open
wants to merge 1 commit into master

Conversation

glynpu
Contributor

@glynpu glynpu commented May 24, 2021

WER results of this PR (with models loaded from the ESPnet model zoo):

test-clean: 2.43%
test-other: 5.79%

This PR implements the procedure shown in the attached diagram, using models loaded from the ESPnet model zoo.

Added benefits of loading the ESPnet-trained conformer encoder model into an equivalent snowfall model definition:

  1. It identifies differences between the ESPnet and snowfall conformer implementations. As shown in snowfall/models/conformer.py, snowfall scales only q, while ESPnet scales attn_output_weights.
  2. The ESPnet conformer has an extra layer_norm after the encoder.

Also, the loaded ESPnet transformer LM can serve as a baseline for snowfall LM training tasks.

@danpovey
Contributor

Great!!
I assume the modeling units are BPE pieces? I think a good step towards resolving the difference would be to train
(i) a CTC model
(ii) an LF-MMI model
using those same BPE pieces.

@glynpu
Contributor Author

glynpu commented May 25, 2021

> Great!!
> I assume the modeling units are BPE pieces? I think a good step towards resolving the difference would be to train
> (i) a CTC model
> (ii) an LF-MMI model
> using those same BPE pieces.

Yes, the modeling units are 5000 tokens including "<blank>".
I will do the suggested experiments.

@danpovey
Contributor

danpovey commented May 25, 2021 via email

    b_to_a_map=b_to_a_map,
    sorted_match_a=True)
lm_path_lats = k2.top_sort(k2.connect(lm_path_lats.to('cpu'))).to(device)
lm_scores = lm_path_lats.get_tot_scores(True, True)
Contributor

@danpovey danpovey Jun 9, 2021


The 2nd arg to get_tot_scores() here, representing log_semiring, should be False, because ARPA-type language models are constructed in such a way that the backoff probability is already included in the direct arc. I.e. we would be double-counting if we were to sum the probabilities of the non-backoff and backoff arcs.
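A toy illustration of the double-counting concern, with made-up log-probabilities (not from any actual LM): in the tropical semiring (log_semiring=False) only the best path counts, while the log semiring sums the direct and backoff paths and so over-counts mass the direct arc already includes.

```python
import math

# Hypothetical numbers: the direct (non-backoff) arc already carries
# the full probability mass for the word; the backoff route is an
# alternative path through the same LM.
direct_arc = -1.2      # log-prob on the direct arc
backoff_path = -2.5    # log-prob along the backoff route

# Tropical semiring (log_semiring=False): best single path.
tropical_total = max(direct_arc, backoff_path)

# Log semiring (log_semiring=True): probabilities are summed, so the
# backoff mass is counted a second time.
log_total = math.log(math.exp(direct_arc) + math.exp(backoff_path))

print(tropical_total)              # -1.2, the intended score
print(log_total > tropical_total)  # True: the sum inflates the score
```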

Collaborator

@csukuangfj csukuangfj left a comment


Please add more documentation to your code.

x -= self.mean

if norm_vars:
    x /= self.std
Collaborator


norm_means uses a requires_grad guard to choose whether to perform an in-place update. Is there a reason not to do the same here?

The original implementation
https://github.com/espnet/espnet/blob/08feae5bb93fa8f6dcba66760c8617a4b5e39d70/espnet/nets/pytorch_backend/frontends/feature_transform.py#L135
uses self.scale to do a multiplication, which is more efficient than dividing by self.std.
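A minimal sketch of the suggestion, with made-up class and variable names (plain Python rather than torch, just to show the idea): the reciprocal of the standard deviation is computed once at construction time, so the hot path only multiplies.

```python
# Hypothetical CMVN sketch, not the PR's actual class. The one-time
# division happens in __init__; __call__ then multiplies by
# self.scale instead of dividing by self.std on every frame.
class GlobalCMVN:
    def __init__(self, mean, std, norm_vars=True):
        self.mean = mean
        self.norm_vars = norm_vars
        # one division per dimension, done once
        self.scale = [1.0 / s for s in std]

    def __call__(self, x):
        # subtract the mean ...
        out = [v - m for v, m in zip(x, self.mean)]
        if self.norm_vars:
            # ... then multiply by the precomputed reciprocal
            out = [v * s for v, s in zip(out, self.scale)]
        return out

cmvn = GlobalCMVN(mean=[1.0, 2.0], std=[2.0, 4.0])
print(cmvn([3.0, 6.0]))  # [1.0, 1.0]
```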

def encode(
        self, speech: torch.Tensor,
        speech_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

Collaborator


Would you mind adding documentation describing the shapes of the various tensors?
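Something like the following docstring sketch could work; the shapes shown are assumptions based on common ESPnet conventions, not taken from the PR itself.

```python
from typing import Tuple


def encode(self, speech, speech_lengths) -> Tuple:
    """Encode a batch of speech features.

    Args:
        speech: float tensor of shape (batch, time, feature_dim).
        speech_lengths: int tensor of shape (batch,) giving the number
            of valid frames in each utterance.

    Returns:
        (encoder_out, encoder_out_lens): encoder_out has shape
        (batch, subsampled_time, encoder_dim); encoder_out_lens has
        shape (batch,).
    """
    ...
```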

    return nnet_output

@classmethod
def build_model(cls, asr_train_config, asr_model_file, device):
Collaborator


cls is never used.
I would suggest changing @classmethod to @staticmethod and removing cls.
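A sketch of the suggested change, with a placeholder class name and body (not the PR's actual code): since cls is unused, the method can be a @staticmethod and drop the cls parameter.

```python
# Hypothetical enclosing class; only the decorator change matters.
class ModelBuilder:
    @staticmethod
    def build_model(asr_train_config, asr_model_file, device):
        # ... load the config, construct the model, move it to `device` ...
        return {'config': asr_train_config,
                'model_file': asr_model_file,
                'device': device}

m = ModelBuilder.build_model('asr.yaml', 'model.pth', 'cpu')
print(m['device'])  # cpu
```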

"""
model = TransformerLM(**config)

assert model_file is not None, f"model file doesn't exist"
Collaborator


f"{model_file} doesn't exist"

if model_type == 'espnet':
    return load_espnet_model(config, model_file)
elif model_type == 'snowfall':
    raise NotImplementedError(f'Snowfall model to be suppported')
Collaborator


No need to use f-string here.

self.unk_idx = self.token2idx['<unk>']


@dataclass
Collaborator


Do we really need to use dataclass here?

Also, could you remove the class NumericalizerMixin?
The extra level of inheritance makes the code hard to read.

# The original link of these models is:
# https://zenodo.org/record/4604066#.YKtNrqgzZPY
# which is accessible by espnet utils
# The are ported to following link for users who don't have espnet dependencies.
Collaborator


Nit: The -> They

# The are ported to following link for users who don't have espnet dependencies.
if [ ! -d snowfall_model_zoo ]; then
    echo "About to download pretrained models."
    git clone https://huggingface.co/GuoLiyong/snowfall_model_zoo
Collaborator


I would suggest using git clone --depth 1. It improves the clone speed.
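A quick demonstration of the effect, on a throwaway local repository so it runs offline (in the script above the URL would of course stay the Hugging Face one): --depth 1 fetches only the most recent commit, which is what speeds up cloning large model repositories.

```shell
set -e
tmp=$(mktemp -d)

# Build a tiny repo with two commits.
git init -q "$tmp/src"
git -C "$tmp/src" -c user.email=ci@example.com -c user.name=ci \
    commit -q --allow-empty -m "first"
git -C "$tmp/src" -c user.email=ci@example.com -c user.name=ci \
    commit -q --allow-empty -m "second"

# Shallow clone: only the latest commit is transferred.
git clone -q --depth 1 "file://$tmp/src" "$tmp/clone"
git -C "$tmp/clone" rev-list --count HEAD   # prints 1, not 2
```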

blank_bias = -1.0
nnet_output[:, :, 0] += blank_bias

supervision_segments = torch.tensor([[0, 0, nnet_output.shape[1]]],
Collaborator


Is the batch size always 1? A larger batch size can improve decoding speed.
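For reference, a hypothetical sketch (made-up frame counts, plain Python lists rather than torch) of what supervision_segments would look like for a batch of N utterances instead of 1: one row (utterance_index, start_frame, num_frames) per utterance, so the whole batch is decoded in one pass. Note that k2's DenseFsaVec conventionally expects the rows sorted by decreasing num_frames.

```python
# Assumed per-utterance frame counts, already in decreasing order.
num_frames = [93, 80, 75]

# One row per utterance: (utterance_index, start_frame, num_frames).
supervision_segments = [[i, 0, n] for i, n in enumerate(num_frames)]
print(supervision_segments)  # [[0, 0, 93], [1, 0, 80], [2, 0, 75]]
```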


ref = batch['supervisions']['text']
for i in range(len(ref)):
    hyp_words = text.split(' ')
Collaborator


What's the format of text?
Does text depend on i? If not, you can split it outside of the for loop.
