Skip to content

Commit

Permalink
[ctc] KWS with CTCloss training and CTC prefix beam search detection. (
Browse files Browse the repository at this point in the history
…#135)

* add ctcloss training scripts.

* update compute_det_ctc

* fix typo.

* add fsmn model, can use pretrained kws model from modelscope.

* Add streaming detection of CTC model. Add CTC model onnx export. Add CTC model's result in README; For now CTC model runtime is not supported yet.

* QA run.sh, maxpooling training scripts is compatible. Ready to PR.

* Add a streaming kws demo, support fsmn online forward

* fix typo.

* Align Stream FSMN and Non-Stream FSMN, both in feature extraction and model forward.

* fix repeat activation, add a interval restrict.

* fix timestamp when subsampling!=1.

* fix flake8, update training script and README, give pretrained ckpt.

* fix quickcheck and flake8

* Add realtime CTC-KWS demo in README.

---------

Co-authored-by: dujing <[email protected]>
  • Loading branch information
duj12 and dujing authored Aug 16, 2023
1 parent 85350c3 commit b233d46
Show file tree
Hide file tree
Showing 22 changed files with 3,328 additions and 19 deletions.
57 changes: 57 additions & 0 deletions examples/hi_xiaowen/s0/README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
Comparison among different backbones,
all models use Max-Pooling loss.
FRRs with FAR fixed at once per hour:

| model | params(K) | epoch | hi_xiaowen | nihao_wenwen |
Expand All @@ -8,3 +10,58 @@ FRRs with FAR fixed at once per hour:
| DS_TCN(spec_aug) | 287 | 80(avg30) | 0.008176 | 0.005075 |
| MDTC | 156 | 80(avg10) | 0.007142 | 0.005920 |
| MDTC_Small | 31 | 80(avg10) | 0.005357 | 0.005920 |

Next, we use CTC loss to train the model, with DS_TCN and FSMN backbones.
and we use CTC prefix beam search to decode and detect keywords,
the detection is either in non-streaming or streaming fashion.

Since the FAR is pretty low when using CTC loss,
the follow results are FRRs with FAR fixed at once per 12 hours:

Comparison between Max-pooling and CTC loss.
The CTC model is fine-tuned with base model pretrained on WenetSpeech(23 epoch, not converged).
FRRs with FAR fixed at once per 12 hours

| model | loss | hi_xiaowen | nihao_wenwen | model ckpt |
|-----------------------|-------------|------------|--------------|------------|
| DS_TCN(spec_aug) | Max-pooling | 0.051217 | 0.021896 | [dstcn-maxpooling](https://modelscope.cn/models/thuduj12/kws_wenwen_dstcn/files) |
| DS_TCN(spec_aug) | CTC | 0.056574 | 0.056856 | [dstcn-ctc](https://modelscope.cn/models/thuduj12/kws_wenwen_dstcn_ctc/files) |


Comparison between DS_TCN(Pretrained with Wenetspeech, 23 epoch, not converged)
and FSMN(Pretained with modelscope released xiaoyunxiaoyun model, fully converged).
FRRs with FAR fixed at once per 12 hours:

| model | params(K) | hi_xiaowen | nihao_wenwen | model ckpt |
|-----------------------|-------------|------------|--------------|-------------------------------------------------------------------------------|
| DS_TCN(spec_aug) | 955 | 0.056574 | 0.056856 | [dstcn-ctc](https://modelscope.cn/models/thuduj12/kws_wenwen_dstcn_ctc/files) |
| FSMN(spec_aug) | 756 | 0.031012 | 0.022460 | [fsmn-ctc](https://modelscope.cn/models/thuduj12/kws_wenwen_fsmn_ctc/files) |

Now, the DSTCN model with CTC loss may not get the best performance, because the
pretraining phase is not sufficiently converged. We recommend you use pretrained
FSMN model as initial checkpoint to train your own model.

Comparison Between stream_score_ctc and score_ctc.
FRRs with FAR fixed at once per 12 hours:

| model | stream | hi_xiaowen | nihao_wenwen |
|-----------------------|-------------|------------|--------------|
| DS_TCN(spec_aug) | no | 0.056574 | 0.056856 |
| DS_TCN(spec_aug) | yes | 0.132694 | 0.057044 |
| FSMN(spec_aug) | no | 0.031012 | 0.022460 |
| FSMN(spec_aug) | yes | 0.115215 | 0.020205 |

Note: when using CTC prefix beam search to detect keywords in streaming case(detect in each frame),
we record the probability of a keyword in a decoding path once the keyword appears in this path.
Actually the probability will increase through the time, so we record a lower value of probability,
which result in a higher False Rejection Rate in Detection Error Tradeoff result.
The actual FRR will be lower than the DET curve gives in a given threshold.

On some small data KWS tasks, we believe the FSMN-CTC model is more robust
compared with the classification model using CE/Max-pooling loss.
For more infomation and results of FSMN-CTC KWS model, you can click [modelscope](https://modelscope.cn/models/damo/speech_charctc_kws_phone-wenwen/summary).

For realtime CTC-KWS, we should process wave input on streaming-fashion,
include feature extraction, keyword decoding and detection and some postprocessing.
Here is a [demo](https://modelscope.cn/studios/thuduj12/KWS_Nihao_Xiaojing/summary) in python,
the core code is in wekws/bin/stream_kws_ctc.py, you can refer it to implement the runtime code.
50 changes: 50 additions & 0 deletions examples/hi_xiaowen/s0/conf/ds_tcn_ctc.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
dataset_conf:
filter_conf:
max_length: 2048
min_length: 0
resample_conf:
resample_rate: 16000
speed_perturb: false
feature_extraction_conf:
feature_type: 'fbank'
num_mel_bins: 40
frame_shift: 10
frame_length: 25
dither: 1.0
spec_aug: true
spec_aug_conf:
num_t_mask: 1
num_f_mask: 1
max_t: 20
max_f: 10
shuffle: true
shuffle_conf:
shuffle_size: 1500
batch_conf:
batch_size: 256

model:
hidden_dim: 256
preprocessing:
type: linear
backbone:
type: tcn
ds: true
num_layers: 4
kernel_size: 8
dropout: 0.1
activation:
type: identity


optim: adam
optim_conf:
lr: 0.001
weight_decay: 0.0001

training_config:
grad_clip: 5
max_epoch: 80
log_interval: 10
criterion: ctc

50 changes: 50 additions & 0 deletions examples/hi_xiaowen/s0/conf/ds_tcn_ctc_base.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
dataset_conf:
filter_conf:
max_length: 2048
min_length: 0
resample_conf:
resample_rate: 16000
speed_perturb: false
feature_extraction_conf:
feature_type: 'fbank'
num_mel_bins: 40
frame_shift: 10
frame_length: 25
dither: 1.0
spec_aug: true
spec_aug_conf:
num_t_mask: 1
num_f_mask: 1
max_t: 20
max_f: 10
shuffle: true
shuffle_conf:
shuffle_size: 1500
batch_conf:
batch_size: 200

model:
hidden_dim: 256
preprocessing:
type: linear
backbone:
type: tcn
ds: true
num_layers: 4
kernel_size: 8
dropout: 0.1
activation:
type: identity


optim: adam
optim_conf:
lr: 0.001
weight_decay: 0.0001

training_config:
grad_clip: 5
max_epoch: 50
log_interval: 100
criterion: ctc

64 changes: 64 additions & 0 deletions examples/hi_xiaowen/s0/conf/fsmn_ctc.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
dataset_conf:
filter_conf:
max_length: 2048
min_length: 0
resample_conf:
resample_rate: 16000
speed_perturb: false
feature_extraction_conf:
feature_type: 'fbank'
num_mel_bins: 80
frame_shift: 10
frame_length: 25
dither: 1.
context_expansion: true
context_expansion_conf:
left: 2
right: 2
frame_skip: 3
spec_aug: true
spec_aug_conf:
num_t_mask: 1
num_f_mask: 1
max_t: 20
max_f: 10
shuffle: true
shuffle_conf:
shuffle_size: 1500
batch_conf:
batch_size: 256

model:
input_dim: 400
preprocessing:
type: none
hidden_dim: 128
backbone:
type: fsmn
input_affine_dim: 140
num_layers: 4
linear_dim: 250
proj_dim: 128
left_order: 10
right_order: 2
left_stride: 1
right_stride: 1
output_affine_dim: 140
classifier:
type: identity
dropout: 0.1
activation:
type: identity


optim: adam
optim_conf:
lr: 0.001
weight_decay: 0.0001

training_config:
grad_clip: 5
max_epoch: 80
log_interval: 10
criterion: ctc

11 changes: 9 additions & 2 deletions examples/hi_xiaowen/s0/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

. ./path.sh

stage=0
stop_stage=4
stage=$1
stop_stage=$2
num_keywords=2

config=conf/ds_tcn.yaml
Expand Down Expand Up @@ -98,6 +98,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
python wekws/bin/score.py \
--config $dir/config.yaml \
--test_data data/test/data.list \
--gpu 0 \
--batch_size 256 \
--checkpoint $score_checkpoint \
--score_file $result_dir/score.txt \
Expand All @@ -111,6 +112,12 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
--score_file $result_dir/score.txt \
--stats_file $result_dir/stats.${keyword}.txt
done

# plot det curve
python wekws/bin/plot_det_curve.py \
--keywords_dict dict/words.txt \
--stats_dir $result_dir \
--figure_file $result_dir/det.png
fi


Expand Down
Loading

0 comments on commit b233d46

Please sign in to comment.