Commit

release v0.1.0
dayyass committed Jun 22, 2022
1 parent 2e35f3d commit ac8156c
Showing 3 changed files with 30 additions and 24 deletions.
22 changes: 19 additions & 3 deletions README.md
@@ -1,6 +1,13 @@
 # QaNER: Prompting Question Answering Models for Few-shot Named Entity Recognition
 Unofficial implementation of [QaNER](https://arxiv.org/abs/2203.01543).

+You can adopt this pipeline for arbitrary [CoNLL-2003-like format](https://github.com/dayyass/QaNER/tree/main/data/conll2003) data.
+
+### CoNLL-2003
+Pipeline results on CoNLL-2003 dataset:
+- [Metrics](https://tensorboard.dev/experiment/FEsbNJdmSd2LGVhga8Ku0Q/)
+- [Trained Hugging Face model](https://huggingface.co/dayyass/qaner-conll-bert-base-uncased)
+
 ## How to use
 ### Training
 Script for training QaNER model:
@@ -21,7 +28,7 @@ python qaner/train.py \
 Required arguments:
 - **--bert_model_name** - base bert model for QaNER fine-tuning
 - **--path_to_prompt_mapper** - path to prompt mapper json file
-- **--path_to_train_data** - path to train data ([CoNLL-2003 like format](https://github.com/dayyass/QaNER/tree/main/data/conll2003))
+- **--path_to_train_data** - path to train data ([CoNLL-2003-like format](https://github.com/dayyass/QaNER/tree/main/data/conll2003))
 - **--path_to_test_data** - path to test data ([CoNLL-2003-like format](https://github.com/dayyass/QaNER/tree/main/data/conll2003))
 - **--path_to_save_model** - path to save trained QaNER model
 - **--n_epochs** - number of epochs to fine-tune
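
Review note on the data arguments above: `--path_to_train_data` and `--path_to_test_data` expect CoNLL-2003-like files. The sketch below shows one way such a file could be read. It assumes the usual CoNLL-2003 layout (whitespace-separated columns, token in the first column, NER tag in the last, blank lines between sentences, `-DOCSTART-` lines as document separators). The function name `read_conll` and the sample text are hypothetical and are not taken from the QaNER code base.

```python
from typing import List, Tuple

# One sentence is a list of (token, ner_tag) pairs.
Sentence = List[Tuple[str, str]]


def read_conll(text: str) -> List[Sentence]:
    """Parse CoNLL-2003-like text into sentences (hypothetical helper)."""
    sentences: List[Sentence] = []
    current: Sentence = []
    for raw_line in text.splitlines():
        line = raw_line.strip()
        # Blank lines and -DOCSTART- markers end the current sentence.
        if not line or line.startswith("-DOCSTART-"):
            if current:
                sentences.append(current)
                current = []
            continue
        columns = line.split()
        # Assumption: token is the first column, NER tag is the last one.
        current.append((columns[0], columns[-1]))
    if current:
        sentences.append(current)
    return sentences


# Illustrative sample only; the tag scheme in a real file may differ.
sample = """-DOCSTART- -X- -X- O

EU NNP B-NP B-ORG
rejects VBZ B-VP O
German JJ B-NP B-MISC
call NN I-NP O
"""
parsed = read_conll(sample)
```

Taking the last column as the tag keeps the sketch usable for both the four-column CoNLL-2003 layout and simpler two-column token/tag files.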
@@ -40,11 +47,20 @@ python qaner/inference.py \
 --question 'What is the organization?' \
 --path_to_prompt_mapper 'prompt_mapper.json' \
 --path_to_trained_model 'dayyass/qaner-conll-bert-base-uncased' \
---n_best_size 5 \
+--n_best_size 1 \
 --max_answer_length 100 \
 --seed 42
 ```
+
+Result:
+```
+question: What is the organization?
+context: EU rejects German call to boycott British lamb .
+answer: [Span(token='EU', label='ORG', start_context_char_pos=0, end_context_char_pos=2)]
+```

 Required arguments:
 - **--context** - sentence to extract entities from
 - **--question** - question prompt with entity name to extract (examples below)
@@ -67,7 +83,7 @@ Python >= 3.7

 ### Citation
 ```bibtex
-@misc{liu2021qaner,
+@misc{liu2022qaner,
 title = {QaNER: Prompting Question Answering Models for Few-shot Named Entity Recognition},
 author = {Andy T. Liu and Wei Xiao and Henghui Zhu and Dejiao Zhang and Shang-Wen Li and Andrew Arnold},
 year = {2022},
30 changes: 10 additions & 20 deletions qaner/inference.py
@@ -3,7 +3,7 @@

 import torch
 from arg_parse import get_inference_args
-from data_utils import Instance, Span
+from data_utils import Instance
 from inference_utils import get_top_valid_spans
 from transformers import AutoModelForQuestionAnswering, AutoTokenizer
 from utils import set_global_seed
@@ -43,7 +43,7 @@ def predict(
     with torch.no_grad():
         outputs = model(**inputs)

-    spans_pred_batch_top_1 = get_top_valid_spans(
+    spans_pred_batch_top = get_top_valid_spans(
         context_list=[context],
         question_list=[question],
         prompt_mapper=prompt_mapper,
@@ -52,30 +52,20 @@ def predict(
         offset_mapping_batch=offset_mapping_batch,
         n_best_size=n_best_size,
         max_answer_length=max_answer_length,
-    )
-
-    # TODO: maybe move into get_top_valid_spans
-    # TODO: maybe remove it
-    for i in range(len(spans_pred_batch_top_1)):
-        if not spans_pred_batch_top_1[i]:
-            empty_span = Span(
-                token="",
-                label="O",  # TODO: maybe not "O" label
-                start_context_char_pos=0,
-                end_context_char_pos=0,
-            )
-            spans_pred_batch_top_1[i] = [empty_span]
+    )[0]

-    predicted_answer_span = spans_pred_batch_top_1[0][0]  # TODO: remove hardcode
+    # TODO: validate it
+    spans_pred_batch_top = [span for span in spans_pred_batch_top if span]

-    start_pos = predicted_answer_span.start_context_char_pos
-    end_pos = predicted_answer_span.end_context_char_pos
-    assert predicted_answer_span.token == context[start_pos:end_pos]
+    for predicted_answer_span in spans_pred_batch_top:
+        start_pos = predicted_answer_span.start_context_char_pos
+        end_pos = predicted_answer_span.end_context_char_pos
+        assert predicted_answer_span.token == context[start_pos:end_pos]

     prediction = Instance(
         context=context,
         question=question,
-        answer=predicted_answer_span,
+        answer=spans_pred_batch_top,
     )

     return prediction
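
Review note on the `predict` change above: the answer is now a list of spans, each checked against the context by character offsets. The following is a minimal sketch of that filter-and-validate step, not the repository's code. The `Span` fields are copied from the result shown in the README hunk, but defining it as a `NamedTuple` is an assumption, and `keep_valid_spans` is a hypothetical name. Depending on how `Span` is defined, `if span` may be true for every instance, so the sketch filters on an empty `token` instead. It also raises `ValueError` where the diff uses `assert`, because assertions are skipped when Python runs with `-O`.

```python
from typing import List, NamedTuple, Optional


class Span(NamedTuple):
    # Field names follow the README output; the NamedTuple base is an assumption.
    token: str
    label: str
    start_context_char_pos: int
    end_context_char_pos: int


def keep_valid_spans(context: str, spans: List[Optional[Span]]) -> List[Span]:
    """Drop missing or empty spans and verify offsets (hypothetical helper)."""
    valid: List[Span] = []
    for span in spans:
        # Skip placeholders: None entries and spans with an empty token.
        if span is None or not span.token:
            continue
        start = span.start_context_char_pos
        end = span.end_context_char_pos
        # The token must be exactly the context slice its offsets point at.
        if context[start:end] != span.token:
            raise ValueError(
                f"span {span!r} does not match context[{start}:{end}]"
            )
        valid.append(span)
    return valid


context = "EU rejects German call to boycott British lamb ."
spans = [Span("EU", "ORG", 0, 2), Span("", "O", 0, 0)]
kept = keep_valid_spans(context, spans)
```

With this shape, a question that yields no entity produces an empty list rather than a placeholder span with label `O`.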
2 changes: 1 addition & 1 deletion qaner/inference_utils.py
@@ -84,7 +84,7 @@ def get_top_valid_spans(
 ]
 span = Span(
     token=context[start_context_char_char:end_context_char_char],
-    label=inv_prompt_mapper[
+    label=inv_prompt_mapper[  # TODO: add inference exception
         question_list[i].lstrip("What is the ").rstrip("?")
     ],
     start_context_char_pos=start_context_char_char,
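
Review note on the label lookup above: `str.lstrip` and `str.rstrip` take a set of characters to strip, not a prefix or suffix, so `lstrip("What is the ")` can also eat the first letters of the entity name. The sketch below shows the effect and one possible alternative that slices off an exact prefix and suffix. It avoids `str.removeprefix` because the README states Python >= 3.7. The function names, the `event` entity, and the mapper contents are hypothetical; the error handling is only a guess at what the `TODO: add inference exception` comment may intend.

```python
from typing import Dict


def entity_name_from_question(
    question: str, prefix: str = "What is the ", suffix: str = "?"
) -> str:
    """Return the entity name between an exact prefix and suffix (hypothetical helper)."""
    if not (question.startswith(prefix) and question.endswith(suffix)):
        raise ValueError(f"unexpected question format: {question!r}")
    return question[len(prefix) : len(question) - len(suffix)]


def label_for_question(question: str, inv_prompt_mapper: Dict[str, str]) -> str:
    """Look up the label and fail with a readable message (hypothetical helper)."""
    name = entity_name_from_question(question)
    if name not in inv_prompt_mapper:
        raise ValueError(f"no label is mapped to entity name {name!r}")
    return inv_prompt_mapper[name]


# Hypothetical mapper contents, for illustration only.
inv_prompt_mapper = {"organization": "ORG", "event": "EVT"}

# Character-set stripping removes the leading "e" of "event".
stripped = "What is the event?".lstrip("What is the ").rstrip("?")

# Exact prefix and suffix removal keeps the name intact.
exact = entity_name_from_question("What is the event?")
label = label_for_question("What is the organization?", inv_prompt_mapper)
```

Entity names such as `organization` happen to survive `lstrip` because their first letter is not among the characters of the prefix, which is why the issue only shows up for some prompts.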
