Skip to content
This repository has been archived by the owner on Oct 13, 2022. It is now read-only.

WIP: BPE Training ctc loss and label smooth loss #219

Open
wants to merge 8 commits into
base: master
Choose a base branch
from

Conversation

glynpu
Copy link
Contributor

@glynpu glynpu commented Jun 24, 2021

As metioned in #217,currently bpe training with ctcLoss and labelSmoothLoss in snowfall obtain higher wer than that of espnet.

decoding algorithm training tool encoder + k2 ctc decode+no rescore
k2 ctc decode in #217 espnet 2.97
k2 ctc decode in #217 snowfall 3.97 (updated on June 24 with wrong data preparation)
k2 ctc decode in #217 snowfall 3.32 (updated on June 28 by fixing data preparation mistake)
k2 ctc decode in #217 snowfall 2.98 (updated on July 21 by adding feature batch_norm and tune training hyper-parameters)
avg (10epochs, 26-35 epochs)
INFO:root:[test-clean] %WER 3.33% [1749 / 52576, 171 ins, 140 del, 1438 sub ]
INFO:root:[test-other] %WER 8.06% [4218 / 52343, 397 ins, 403 del, 3418 sub ]
avg (10 epochs, 29 - 38 epochs)
INFO:root:[test-clean] %WER 3.32% [1744 / 52576, 184 ins, 136 del, 1424 sub ]
INFO:root:[test-other] %WER 7.96% [4167 / 52343, 402 ins, 367 del, 3398 sub ]

The PROBLEM I am facing is:
Wer of snowfall trained models is still a little higher than the model of espnet trained, by 3.32% > 2.97%.
(fixed by correcting datapreparation mistake)During espnet training: loss_att and loss_ctc always have the same order of magnitude, i.e. they decrease at the same pace.
However during snowfall training, loss_att decrease sharply to even below 1.0 while loss_ctc keeps more than [30 to 100] times larger than loss_att.

espnet training log file: https://github.com/glynpu/bpe_training_log_files/blob/master/espnet-egs2-librispeech-asr1-exp-asr_train_asr_conformer7_n_fft512_hop_length256_raw_en_bpe5000_sp-train.log
snowfall training log file of wer 3.97%: https://github.com/glynpu/bpe_training_log_files/blob/master/snowfall-egs-librispeech-asr-simple_v1-train_log.txt
snowfall training log file of wer 3.32% experiment:
https://github.com/glynpu/bpe_training_log_files/blob/master/wer_3.32_June_26_snowfall_egs_librispeech-asr-simple_v1-train_log.txt

What I have tried to make compariable between espnet and snowfall are:

  1. model structures: by loading espnet released models into snowfall successfully(using regular expression to chage key's name in state_dict, similar to this), I believe model structures are identical except parameter names.
  2. loss functions: both use torch.nn.CTCLoss and torch.nn.KLDivLoss
  3. normalization: both loss are normalized by batch_size in espnet and snowfall
  4. learning rate schedule: espnet use WarmupLR, while snowfall use Noam. WramupLR(optimizer.lr=0.0025, warmup_steps=40000) and Noam(model_size=512, factor=10.0, warm_step=40000) are quite similar, though not 100 percent identical.
  5. each batch contains utts in silimar duration: espnet use NumElementBatchSampler and snowfall use BucketingSampler
  6. token_ids: in total 5000 tokens. After spm tokenizer is trained, are removed while other 4997 tokens are kept. Then three tokens, "blank_id = 0; oov_id = 1; sos_eos_id=4999", are manully added(reference).

optimizer=optimizer,
)

total_objf += curr_batch_objf
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be: total_objf += curr_batch_objf * curr_batch_num_utts, because you'll later be normalizing by epoch_num_utts.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed.

@glynpu
Copy link
Contributor Author

glynpu commented Jun 28, 2021

After fixing some bugs, wer on test-clean decrease from 3.97% to 3.32%, though it's still higher than espnet's 2.97%.

@danpovey
Copy link
Contributor

Great work!!

@sw005320
Copy link

Sorry for my late response.
I did not notice @glynpu's comment.

First, I just want to know whether the difference comes from the training part (probably so) or other parts.

  • I want to check the learning curve. Could you share it?
  • Did you compare the best validation accuracy (note that espnet uses the teacher forcing when computing the accuracy and it would be ~95.7% as written in the log)? We can compare them if we use the same (similar) BPE size.

Also, could you point out the main script to me?
I want to check the overall training and inference flows and several hyper-parameters.

@glynpu
Copy link
Contributor Author

glynpu commented Jul 16, 2021

Thanks for your kindly help! @sw005320

I want to check the learning curve. Could you share it?

image

tensorboard for above screenshot

Did you compare the best validation accuracy (note that espnet uses the teacher forcing when computing the accuracy and it would be ~95.7% as written in the log)? We can compare them if we use the same (similar) BPE size.

Not yet. I will compare the differences between espnet and snowfall.

Also, could you point out the main script to me?I want to check the overall training and inference flows and several hyper-parameters.

Currently this pr is about training part. and #227 is focusing on decoding part.

For training, this shell scripts is the entrance.

  export CUDA_VISIBLE_DEVICES="0, 1, 2, 3"
  python3 bpe_ctc_att_conformer_train.py \
    --bucketing-sampler True \
    --lr-factor 10.0 \
    --num-epochs 50 \
    --full-libri True \
    --max-duration 200 \
    --concatenate-cuts False \
    --world-size 4 \
    > train_log.txt

Some hyper-parameters is hard-coded in bpe_ctc_att_conformer_train.py or constructor of class Conformer. Some of them are listed below:

name value
warmup_step 40,000
att-rate 0.7, i.e. ctc_weight=0.3
lsm_weight 0.1
num_encoder_layerss 12
num_decoder_layers 6
nhead 8
atten_dim 512

For decoding part, the latest decoding implementations is #217, and I plan to port them to espnet after it's approved and finally merged into sowfall. The entrance of decoding is here

if [ $stage -le 3 ]; then
  export CUDA_VISIBLE_DEVICES=2
  python bpe_ctc_att_conformer_decode.py \
    --max-duration=20 \
    --generate-release-model=False \
    --decode_with_released_model=True \
    --num-paths-for-decoder-rescore=500
fi

A core function of decoding in bpe_ctc_att_conformer_decode.py is here

@sw005320
Copy link

Thanks!
Did you use model averaging?
If so, how do you pick up models (best loss?), and how many?

@sw005320
Copy link

I found it https://github.com/k2-fsa/snowfall/pull/217/files#diff-fd4e35e8e4b530ddf5ca285f24f2f92dfb6a0db691e75b4efe1dc59309654883R146-R151

Did you tune it?
It may not have a big difference, but we usually pick up the 10-best models based on the validation accuracy for averaging.

@glynpu
Copy link
Contributor Author

glynpu commented Jul 21, 2021

Latest results are:

  before current
Encoder + ctc 3.32 2.98( wer of espnet released model is 2.97/3.00)
Encoder + TLG + 4-gram lattice rescore + nbest rescore with transformer decoder with log_semering=False and remove repeated tokens 2.73 2.54

Result difference between current pr and espnet is sovled by tune training hyper-parameters with following modifications:

  feat-norm learning-factor warm-up steps epoch
before no 10 40,000 40 epoch (avg=10, with 26-35 epoch)
current yes 5 80,000(around 10 epochs) 50 epochs (avg=20 with 31-50 epochs)

Reason of previous modifications are:
I realized that in espnet, 1 epoch contains around 3000 batchs;
however, in my implementation, with max_duration=200, one epoch contains 6000 batchs.

As a matter of experience, smaller batch_size is compatible with smaller learning rate, so half the learning rate.
Since 1 epoch contains 6000 batchs ranther than 3000 batchs now, I doulbe warm-up steps.

The module feat_batch_norm also helps, resulting 3.32 --> 3.17.

As 35 epochs --> 50 eochs, I just set it arbitrarily to see what will happen with more epochs.

BTW, I failed to increase max_duration=200 because larger max_duration easily cause OOM. 200 seems the largest with my GPUs.

@sw005320
Copy link

I feel that 80,000 warm-up steps are too large. It requires larger epochs to make training converged. I think you can find some optimal points with fewer warm-up steps and comparable performance.

Also, how about using the 3000 batches?

@glynpu
Copy link
Contributor Author

glynpu commented Jul 21, 2021

I feel that 80,000 warm-up steps are too large. It requires larger epochs to make training converged. I think you can find some optimal points with fewer warm-up steps and comparable performance.

80,000 is calculated from:
around 10 epochs = 40,000/ 3000 (in espnet) = 80,000 / 6000(current pr)

Also, how about using the 3000 batches?

6000 batches --> 3000 batches means max_duration = 200 --> max_duration = 400;
which will cause OOM in some batches.
I am still analysising the reason.

@danpovey
Copy link
Contributor

As we mentioned in person, I believe a problem with the current setup is that the transformer loss is being normalized (divided by the minibatch size) twice, once in a library function and once in the training script, while the CTC loss is only normalized once.
If we had logged the 2 objectives separately, we likely would have noticed this.
I think that normalizing even once is not right, and that we should not normalize either of these objectives. The reason is that Librispeech has a wide range of durations, and the Lhotse sampler that we are using actually puts minibatches in bins where they have about the same duration (and approximately constant total duration in seconds), so in effect, right now we have a weight per frame that rises linearly with the sentence length. This will tend to cause convergence problems because longer sentences are harder to align. Removing the normalization (division by len(texts)) should not require changes to learning rates, because we are using Adam and there is no weight decay.
[However, as a separate issue, I think we should experiment with a very small weight decay, which will cause the system to train/converge faster.]

@sw005320
Copy link

FYI, espnet did not normalize the CTC and attention loss by the length.

ctc_loss = ctc_loss.sum() / bno

if att_rate != 0.0:
loss = ((1.0 - att_rate) * ctc_loss + att_rate * att_loss) * accum_grad
Copy link
Contributor Author

@glynpu glynpu Jul 25, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As we mentioned in person, I believe a problem with the current setup is that the transformer loss is being normalized (divided by the minibatch size) twice, once in a library function and once in the training script, while the CTC loss is only normalized once.
If we had logged the 2 objectives separately, we likely would have noticed this.

@danpovey I don't think att_loss is normalized twice. After it is computed at line 85, there is no extra normalization for att_loss anymore.

BTW, the reduplicated normization for att_loss your mentioned maybe about another code(not this pr), which is

        if att_rate != 0.0:
            loss = (- (1.0 - att_rate) * tot_score + att_rate * att_loss) / (len(texts) * accum_grad)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, OK.
In any case, IMO we shouldn't be normalizing even once. (so remove the "/ bno" above; and remove the same thing in the library function that computes the att_loss).

IMO, we should also remove the gradient-clipping step; my feeling is that if it was helping before, it was helping because it was compensating for the normalization that shouldn't have been happening. This setup is non-recurrent so gradient clipping should not be needed. (However, if we encounter instabilities we can revisit this).

Copy link
Contributor

@pkufool pkufool Jul 26, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The model that removing gradient-clipping and normalizing is training now.
The loss curve is as follows, in the beginning of each epoch, the loss value increases suddenly. Did it encounter instabilities?

total loss ctc loss att loss
image image image

The displaying loss value in tensorboard is normalized by num of utterances.

sys.exit(-1)

# TODO(Liyong Guo) make this configurable.
lang_dir = Path('data/en_token_list/bpe_unigram5000/')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do you generate the directory data/en_token_list/bpe_unigram5000/?
I don't find any code responsible for it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will the code to generate bpe related files.
I will make a pr to your branch. @glynpu

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, they are downloaded from the model zoo.

Please refer to the bpe_run.sh, which contains following downloading code:

    git clone https://huggingface.co/GuoLiyong/snowfall_bpe_model
    for sub_dir in data; do
      ln -sf snowfall_bpe_model/$sub_dir ./
    done

Actually, I deliberately don't sumit the code about training bpe model, because this pr is mainly about training pipeline.

@@ -87,6 +91,9 @@ def __init__(self, num_features: int, num_classes: int, subsampling_factor: int
else:
self.decoder_criterion = None

# Reference: https://github.com/espnet/espnet/blob/master/espnet2/asr/ctc.py#L37
self.ctc_loss_fn = torch.nn.CTCLoss(reduction='none')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's not retain this pattern, we can make the CTC loss part of the training code. This may not carry over so easily to k2 code, and in any case is a little inflexible for our purposes.

@danpovey
Copy link
Contributor

Does anyone have any pointers to visualizations of the decoder attention in the application of transformers to ASR? I want to get a feel for how it works.

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants