
EMNLP 23 - Integrating Whisper Encoder to LLaMA Decoder for Generative ASR Error Correction


Srijith-rkr/Whispering-LLaMA


Whispering-LLaMA: Integrate Whisper Encoder to LLaMA Decoder

  • Accepted at EMNLP 2023 (Main Track) | Paper | Slides | PyTorch | HuggingFace CKPT | N-best Dataset
  • ASR Generative Error Correction by leveraging foundational Audio (Whisper) and Language (LLaMA) models.
  • Fusing Whisper Encoder and LLaMA decoder

Introduction

We introduce a novel cross-modal fusion technique for generative error correction in Automatic Speech Recognition (ASR). In simplified terms, we use in-context learning to feed the n-best hypotheses produced by an acoustic model into a Large Language Model (LLM) and prompt it to predict the most accurate sentence, as shown below.
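As a minimal sketch of the prompting idea, the snippet below formats an n-best list into an instruction-style prompt. The template wording, the Alpaca-like section markers, and the function name `build_nbest_prompt` are assumptions made for illustration; they are not necessarily the exact prompt used in the paper or in this repository.

```python
def build_nbest_prompt(hypotheses, instruction=None):
    """Format an n-best list as an instruction-style prompt.

    The template below is hypothetical and only illustrates the idea of
    placing all hypotheses in the LLM context.
    """
    if not hypotheses:
        raise ValueError("need at least one hypothesis")
    if instruction is None:
        instruction = (
            "You are given the n-best hypotheses from a speech recognizer. "
            "Predict the most accurate transcription."
        )
    # Number the hypotheses so the model can tell them apart.
    numbered = "\n".join(
        f"{i}. {hyp.strip()}" for i, hyp in enumerate(hypotheses, start=1)
    )
    return (
        f"### Instruction:\n{instruction}\n\n"
        f"### Input:\n{numbered}\n\n"
        f"### Response:\n"
    )


if __name__ == "__main__":
    nbest = [
        "i scream for ice cream",
        "ice cream for ice cream",
        "i scream for i scream",
    ]
    print(build_nbest_prompt(nbest))
```

The LLM is then expected to continue the text after the response marker with its corrected transcription.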

We propose a novel mechanism to fuse the acoustic features from the audio input into the LLM to significantly enhance the performance (28.83% -> 37.66% WERR) by leveraging an Audio Foundational model as a feature extractor. We further design our system in a parameter-efficient manner with only 7.97M trainable parameters as shown below. Please refer to the paper [YET] for further information.
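To make the metric concrete, here is a small sketch of how word error rate (WER) and a relative WER reduction can be computed. It assumes WERR denotes the relative reduction in WER against a baseline, i.e. (WER_baseline - WER_system) / WER_baseline; the paper's exact evaluation protocol (text normalization, choice of baseline) may differ, and the function names are illustrative.

```python
def word_error_rate(reference, hypothesis):
    """Word-level Levenshtein distance divided by the reference length."""
    ref, hyp = reference.split(), hypothesis.split()
    if not ref:
        raise ValueError("reference must contain at least one word")
    # prev[j] holds the edit distance between the processed part of ref
    # and the first j words of hyp.
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        curr = [i]
        for j, hyp_word in enumerate(hyp, start=1):
            cost = 0 if ref_word == hyp_word else 1
            curr.append(
                min(
                    prev[j] + 1,         # deletion
                    curr[j - 1] + 1,     # insertion
                    prev[j - 1] + cost,  # substitution or match
                )
            )
        prev = curr
    return prev[-1] / len(ref)


def relative_wer_reduction(baseline_wer, system_wer):
    """Relative WER reduction (assumed definition of WERR)."""
    if baseline_wer <= 0:
        raise ValueError("baseline WER must be positive")
    return (baseline_wer - system_wer) / baseline_wer


if __name__ == "__main__":
    baseline = word_error_rate("the cat sat down", "the bat sat town")
    corrected = word_error_rate("the cat sat down", "the cat sat town")
    print(f"baseline WER:  {baseline:.2f}")
    print(f"corrected WER: {corrected:.2f}")
    print(f"WERR:          {relative_wer_reduction(baseline, corrected):.2%}")
```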

Setup

Clone the repo

git clone https://github.com/Srijith-rkr/Whispering-LLaMA
cd Whispering-LLaMA

Then use the environment.yml file to install the dependencies with Anaconda:

conda env create -f environment.yml

Alternatively, install them from requirements.txt:

pip install -r requirements.txt

  • To obtain the pre-trained Alpaca weights, please refer here. You can then use convert_hf_checkpoint.py to rename the state_dict keys to match the lit-llama implementation.
  • Alternatively, you can use the Alpaca weights hosted on Hugging Face (Whispering-LLaMA). Refer to demo.py for how to use them.

You are all set! 🎉

 

Dataset

We have uploaded our N-best hypotheses dataset, generated using Whisper-Tiny, to Hugging Face (PeacefulData). The hypotheses were generated from the M subset of the Hugging Face GigaSpeech dataset. You can map the hypotheses in our dataset to the audio clips from the GigaSpeech dataset using the 'ID' tag.
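The ID-based mapping can be sketched as a simple dictionary join. The record layout here is hypothetical: the field names `hypotheses` and `audio` and the helper `attach_audio` are assumptions for illustration, and the identifier field on the GigaSpeech side may be named differently from 'ID', so check both datasets before relying on the defaults.

```python
def attach_audio(nbest_records, audio_records,
                 nbest_id_key="ID", audio_id_key="ID", audio_field="audio"):
    """Join n-best records with audio records that share the same ID.

    Returns (matched, missing_ids). All field names are assumptions and
    should be adapted to the actual dataset schemas.
    """
    # Index the audio side once so each lookup is O(1).
    audio_by_id = {rec[audio_id_key]: rec for rec in audio_records}
    matched, missing_ids = [], []
    for rec in nbest_records:
        audio_rec = audio_by_id.get(rec[nbest_id_key])
        if audio_rec is None:
            missing_ids.append(rec[nbest_id_key])
        else:
            # Copy the n-best record and add the matching audio entry.
            matched.append({**rec, audio_field: audio_rec[audio_field]})
    return matched, missing_ids


if __name__ == "__main__":
    # Toy records standing in for the real datasets.
    nbest = [
        {"ID": "clip_001", "hypotheses": ["hello world", "hello word"]},
        {"ID": "clip_999", "hypotheses": ["no matching audio"]},
    ]
    audio = [{"ID": "clip_001", "audio": "clip_001.wav"}]
    matched, missing = attach_audio(nbest, audio)
    print(matched)
    print(missing)
```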

Model Weights

The model and tokenizer weights are hosted on Hugging Face (Whispering-LLaMA) for easier setup. Refer to demo.py for how to use them.

Training & Inference

Please refer to:

  • data_preparation to generate your custom n-best hypothesis dataset

  • training/WL-M.py to train our best model on your dataset

  • Inference/WL-M.py to run inference

Once you have set up your dataset, you can train your models with:

python training/WL-S.py --lr 1e-3 --d 1 --pretrained_path 'weights/alpaca.pth' --tokenizer_path 'weights/tokenizer.model' --data 'path to your dataset'

You can configure the following flags.

--lr: learning rate (1e-3 is recommended)
--d: Number of GPUs you are using to run the DDP strategy (You can uncomment lines in the code to switch to DeepSpeed)
--pretrained_path: Path to the Alpaca model weights
--tokenizer_path: Path to the LLaMA tokenizer
--data: Path to your dataset
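If you launch training from another script, the command above can be assembled programmatically. The sketch below only uses the flags listed above; the helper name `build_train_command` and its default values (taken from the example command) are illustrative assumptions, not part of this repository.

```python
import shlex


def build_train_command(data,
                        lr=1e-3,
                        num_gpus=1,
                        pretrained_path="weights/alpaca.pth",
                        tokenizer_path="weights/tokenizer.model",
                        script="training/WL-S.py"):
    """Return the training command as an argument list (hypothetical helper)."""
    return [
        "python", script,
        "--lr", str(lr),                       # learning rate
        "--d", str(num_gpus),                  # number of GPUs for DDP
        "--pretrained_path", pretrained_path,  # Alpaca model weights
        "--tokenizer_path", tokenizer_path,    # LLaMA tokenizer
        "--data", data,                        # path to your dataset
    ]


if __name__ == "__main__":
    cmd = build_train_command("path/to/your/dataset", num_gpus=2)
    # shlex.join quotes arguments so the line can be pasted into a shell.
    print(shlex.join(cmd))
```

Passing the returned list to `subprocess.run(cmd, check=True)` would start the run without going through a shell.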

Acknowledgements

This implementation builds on

  • lit-llama for the Training pipeline.

  • stanford_alpaca for the pre-trained instruction following Language model.

  • Whisper to obtain acoustic embeddings.

Reference

If you find this work related or useful for your research, please consider citing this paper. Thank you!

@inproceedings{radhakrishnan2023whispering,
  title={Whispering LLaMA: A Cross-Modal Generative Error Correction Framework for Speech Recognition},
  author={Srijith Radhakrishnan and Chao-Han Huck Yang and Sumeer Ahmad Khan and Rohit Kumar and Narsis A. Kiani and David Gomez-Cabrero and Jesper N. Tegner},
  booktitle={Proc. of EMNLP},
  year={2023}
}
