A TensorFlow 2/Keras implementation of the POS tagging task using a Bidirectional Long Short-Term Memory network (BiLSTM) with a Conditional Random Field (CRF) on top of the BiLSTM layer as the inference layer, predicting the most likely tag sequence. This is not the first work to apply a BiLSTM-CRF model to NLP sequence tagging benchmarks, but it might achieve results at or near the state of the art on POS tagging, and the same architecture could do so for NER. Experimental results on the Penn Treebank POS tagging corpus (approximately 1 million tokens from the Wall Street Journal) show that the model reaches 98.93% word-level accuracy. It uses the following hyperparameters (a model sketch follows the list):
- batch size: 256
- learning rate: 0.001
- number of epochs: 3
- max length (the number of timesteps): 141
- embedding size: 100
- number of tags: 45
- hidden BiLSTM layers: 1
- hidden BiLSTM units: 100
- recurrent dropout: 0.01
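The snippet below is a minimal sketch of how such a BiLSTM-CRF can be wired up in TensorFlow 2/Keras with the hyperparameters above, using the CRF utilities from `tensorflow_addons`. Variable names and the vocabulary size are illustrative assumptions, not the repository's actual code.

```python
import tensorflow as tf

MAX_LEN = 141       # max length (number of timesteps)
VOCAB_SIZE = 30000  # placeholder; the real size depends on the corpus
EMBED_DIM = 100     # embedding size
NUM_TAGS = 45       # number of tags
UNITS = 100         # hidden BiLSTM units

# Token ids -> embeddings -> contextual features -> per-tag emission scores.
inputs = tf.keras.Input(shape=(MAX_LEN,), dtype=tf.int32)
x = tf.keras.layers.Embedding(VOCAB_SIZE, EMBED_DIM, mask_zero=True)(inputs)
x = tf.keras.layers.Bidirectional(
    tf.keras.layers.LSTM(UNITS, return_sequences=True,
                         recurrent_dropout=0.01))(x)
logits = tf.keras.layers.Dense(NUM_TAGS)(x)  # unary potentials for the CRF
model = tf.keras.Model(inputs, logits)

# CRF transition scores between tags, learned jointly with the network.
transition_params = tf.Variable(
    tf.random.uniform((NUM_TAGS, NUM_TAGS)), name="transitions")
```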
I test the BiLSTM-CRF network on the Penn Treebank POS tagging task. The table below shows the number of sentences and tokens in the training, validation, and test sets, along with the number of labels.
| PTB POS | sentence # | token # |
|---|---|---|
| training | 33,458 | 798,284 |
| validation | 6,374 | 151,744 |
| test | 1,346 | 32,853 |

Label #: 45
For word representation, I used pretrained GloVe embeddings, in which each word corresponds to a 100-dimensional vector.
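A hedged sketch of building the embedding matrix from a GloVe file follows; the file name and the toy `word_index` are assumptions standing in for the real preprocessing.

```python
import numpy as np

EMBED_DIM = 100

# word_index maps each corpus word to its integer id; index 0 is reserved
# for padding. A toy vocabulary stands in for the real one here.
word_index = {"the": 1, "heart": 2, "ghosts": 3}

# Parse the GloVe text file: one word per line, followed by its vector.
embeddings_index = {}
with open("glove.6B.100d.txt", encoding="utf-8") as f:
    for line in f:
        values = line.split()
        embeddings_index[values[0]] = np.asarray(values[1:], dtype="float32")

# Rows of the matrix line up with token ids; OOV words stay all-zero.
embedding_matrix = np.zeros((len(word_index) + 1, EMBED_DIM))
for word, i in word_index.items():
    vector = embeddings_index.get(word)
    if vector is not None:
        embedding_matrix[i] = vector
```

The matrix can then be handed to the `Embedding` layer, e.g. via `embeddings_initializer=tf.keras.initializers.Constant(embedding_matrix)`, optionally with `trainable=False` to keep the vectors frozen.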
First, I set the batch size to 64 and the model started overfitting at epoch 2; with a batch size of 128, it overfit at epoch 3. Eventually, I set the batch size to 256 and reached the highest word-level accuracy: 98.93%.
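The natural training objective for this architecture is the CRF negative log-likelihood. Below is a minimal custom training step, reusing `model` and `transition_params` from the sketch above; this is an assumed formulation, not necessarily the repository's exact loop.

```python
import tensorflow as tf
import tensorflow_addons as tfa

optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

@tf.function
def train_step(tokens, tags, seq_lens):
    """One step over a padded batch of token/tag id tensors."""
    with tf.GradientTape() as tape:
        logits = model(tokens, training=True)
        # CRF negative log-likelihood of the true tag sequences.
        log_likelihood, _ = tfa.text.crf_log_likelihood(
            logits, tags, seq_lens, transition_params)
        loss = -tf.reduce_mean(log_likelihood)
    # Transition scores are trained together with the network weights.
    variables = model.trainable_variables + [transition_params]
    grads = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(grads, variables))
    return loss
```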
My implementation is based on the following paper:
Huang, Z., Xu, W., and Yu, K. "Bidirectional LSTM-CRF Models for Sequence Tagging." arXiv preprint arXiv:1508.01991 (2015).
Requirements:

- TensorFlow 2/Keras
- NumPy
- JSON
- NLTK
- argparse
$ pip install -r requirements.txt
$ python train.py
Output:
$ Viterbi accuracy: 98.93%
[Plots: training accuracy and loss curves]
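The reported word-level accuracy comes from Viterbi decoding. A sketch of how it could be computed with `tfa.text.crf_decode`, ignoring padding positions and reusing names from the sketches above:

```python
import tensorflow as tf
import tensorflow_addons as tfa

def viterbi_accuracy(tokens, tags, seq_lens):
    logits = model(tokens, training=False)
    # Best tag sequence under the learned transition scores (Viterbi).
    pred_tags, _ = tfa.text.crf_decode(logits, transition_params, seq_lens)
    # Count matches only over real tokens, not padding.
    mask = tf.sequence_mask(seq_lens, maxlen=tf.shape(tags)[1])
    matches = tf.boolean_mask(tf.equal(pred_tags, tags), mask)
    return tf.reduce_mean(tf.cast(matches, tf.float32))
```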
$ python test.py --sent "My heart is always breaking for the ghosts that haunt this room."
Output:
$ [('My', 'prp$'), ('heart', 'nn'), ('is', 'vbz'),
('always', 'rb'), ('breaking', 'vbg'), ('for', 'in'),
('the', 'dt'), ('ghosts', 'nns'), ('that', 'wdt'),
('haunt', 'vbp'), ('this', 'dt'),
('room', 'nn'), ('.', '.')]
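For orientation, here is a sketch of what `test.py` plausibly does for a raw sentence: tokenize with NLTK, map tokens to ids, pad to the model's max length, Viterbi-decode, and zip tokens with their predicted tags. `word_index` and `index_tag` (an id-to-tag lookup) are assumed preprocessing artifacts, not names from the repository.

```python
import nltk
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa

MAX_LEN = 141

def tag_sentence(sent, model, transition_params, word_index, index_tag):
    tokens = nltk.word_tokenize(sent)
    # OOV tokens fall back to id 0 here; a dedicated UNK id is more robust.
    ids = [word_index.get(w.lower(), 0) for w in tokens]
    padded = np.zeros((1, MAX_LEN), dtype="int32")
    padded[0, :len(ids)] = ids
    logits = model(padded, training=False)
    pred, _ = tfa.text.crf_decode(
        logits, transition_params, tf.constant([len(ids)]))
    # zip stops at the real sentence length, dropping padding positions.
    return [(w, index_tag[int(t)]) for w, t in zip(tokens, pred[0])]
```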
Note: The standard dataset for POS tagging is the Wall Street Journal (WSJ) portion of the Penn Treebank, containing 45 different POS tags. Sections 0-18 are used for training, sections 19-21 for development, and sections 22-24 for testing, with models evaluated on accuracy. However, I only have sections 2-21, which I used for training (holding out 16% of them for development), and section 24 for testing. This split differs slightly from the standard one, but I believe the model would perform at or near SOTA on the standard split as well. The dataset is not public; contact me via email if you want to use it.