Merge pull request #13 from dayyass/release_v0.1.0
release v0.1.1
dayyass authored Jul 18, 2022
2 parents 327f162 + 6ffa055 commit 1a47868
Showing 15 changed files with 225 additions and 39 deletions.
37 changes: 37 additions & 0 deletions .github/workflows/codecov.yml
@@ -0,0 +1,37 @@
# This workflow will install Python dependencies and run codecov
# https://github.com/codecov/codecov-action#example-workflowyml-with-codecov-action

name: codecov

on:
  push:
    branches: [main, develop]
  pull_request:
    branches: [main, develop]

jobs:
  build:
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ubuntu-latest]
    steps:
      - uses: actions/checkout@master
      - name: Set up Python
        uses: actions/setup-python@master
        with:
          python-version: 3.7
      - name: Install dependencies
        run: |
          pip install --upgrade pip
          pip install -r requirements.txt
          pip install pytest pytest-cov
      - name: Generate coverage report
        run: pytest --cov=./ --cov-report=xml
      - name: Upload coverage to Codecov
        uses: codecov/codecov-action@v1
        with:
          flags: unittests
          env_vars: OS,PYTHON
          fail_ci_if_error: true
          verbose: true
37 changes: 37 additions & 0 deletions .github/workflows/linter.yml
@@ -0,0 +1,37 @@
# This workflow will install Python dependencies and run linter
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

# TODO: update linters

name: linter

on:
  push:
    branches: [main, develop]
  pull_request:
    branches: [main, develop]

jobs:
  build:
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ubuntu-latest]
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: 3.7
      - name: Install dependencies
        run: |
          pip install --upgrade pip
          pip install isort black flake8 mypy
      - name: Code format check with isort
        run: isort --check-only --profile black .
      - name: Code format check with black
        run: black --check .
      - name: Code format check with flake8
        run: flake8 --ignore E501,E203,W503 .
      - name: Type check with mypy
        run: mypy --ignore-missing-imports .
30 changes: 30 additions & 0 deletions .github/workflows/tests.yml
@@ -0,0 +1,30 @@
# This workflow will install Python dependencies and run tests with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: tests

on:
  push:
    branches: [main, develop]
  pull_request:
    branches: [main, develop]

jobs:
  build:
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        python-version: ['3.7', '3.8', '3.9']
        os: [ubuntu-latest, macOS-latest, windows-latest]
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          pip install --upgrade pip
          pip install -r requirements.txt
      - name: Unittests
        run: python -m unittest discover
2 changes: 2 additions & 0 deletions .gitignore
@@ -15,3 +15,5 @@ venv

 runs
 dayyass
+
+qaner.egg-info
16 changes: 8 additions & 8 deletions .pre-commit-config.yaml
@@ -29,11 +29,11 @@ repos:
     rev: v0.961
     hooks:
       - id: mypy
-  - repo: local
-    hooks:
-      - id: unittest
-        name: unittest
-        entry: venv/bin/python -m unittest discover
-        language: python
-        always_run: true
-        pass_filenames: false
+# - repo: local
+# hooks:
+# - id: unittest
+# name: unittest
+# entry: venv/bin/python -m unittest discover
+# language: python
+# always_run: true
+# pass_filenames: false
29 changes: 24 additions & 5 deletions README.md
@@ -1,9 +1,28 @@
+[![tests](https://github.com/dayyass/qaner/actions/workflows/tests.yml/badge.svg)](https://github.com/dayyass/qaner/actions/workflows/tests.yml)
+[![linter](https://github.com/dayyass/qaner/actions/workflows/linter.yml/badge.svg)](https://github.com/dayyass/qaner/actions/workflows/linter.yml)
+<!-- [![codecov](https://codecov.io/gh/dayyass/qaner/branch/main/graph/badge.svg?token=S3UKX8BFP3)](https://codecov.io/gh/dayyass/qaner) -->
+
+[![python 3.7](https://img.shields.io/badge/python-3.7-blue.svg)](https://github.com/dayyass/qaner#requirements)
+[![release (latest by date)](https://img.shields.io/github/v/release/dayyass/qaner)](https://github.com/dayyass/qaner/releases/latest)
+[![license](https://img.shields.io/github/license/dayyass/qaner?color=blue)](https://github.com/dayyass/qaner/blob/main/LICENSE)
+
+[![pre-commit](https://img.shields.io/badge/pre--commit-enabled-black)](https://github.com/dayyass/qaner/blob/main/.pre-commit-config.yaml)
+[![code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
+
+[![pypi version](https://img.shields.io/pypi/v/qaner)](https://pypi.org/project/qaner)
+[![pypi downloads](https://img.shields.io/pypi/dm/qaner)](https://pypi.org/project/qaner)
+
 # QaNER
 Unofficial implementation of [*QaNER: Prompting Question Answering Models for Few-shot Named Entity Recognition*](https://arxiv.org/abs/2203.01543).

 You can adopt this pipeline for arbitrary [BIO-markup](https://github.com/dayyass/QaNER/tree/main/data/conll2003) data.

-### CoNLL-2003
+## Installation
+```
+pip install qaner
+```
+
+## CoNLL-2003
 Pipeline results on CoNLL-2003 dataset:
 - [Metrics](https://tensorboard.dev/experiment/FEsbNJdmSd2LGVhga8Ku0Q/)
 - [Trained Hugging Face model](https://huggingface.co/dayyass/qaner-conll-bert-base-uncased)
@@ -12,7 +31,7 @@ Pipeline results on CoNLL-2003 dataset:
 ### Training
 Script for training QaNER model:
 ```
-python qaner/train.py \
+qaner-train \
 --bert_model_name 'bert-base-uncased' \
 --path_to_prompt_mapper 'data/conll2003/prompt_mapper.json' \
 --path_to_train_data 'data/conll2003/train.bio' \
@@ -42,7 +61,7 @@ Optional arguments:
 ### Inference
 Script for running inference with a trained QaNER model:
 ```
-python qaner/inference.py \
+qaner-inference \
 --context 'EU rejects German call to boycott British lamb .' \
 --question 'What is the organization?' \
 --path_to_prompt_mapper 'data/conll2003/prompt_mapper.json' \
@@ -78,10 +97,10 @@ Possible inference questions for CoNLL-2003:
 - What is the organization? (ORG)
 - What is the miscellaneous entity? (MISC)

-### Requirements
+## Requirements
 Python >= 3.7

-### Citation
+## Citation
 ```bibtex
 @misc{liu2022qaner,
     title = {QaNER: Prompting Question Answering Models for Few-shot Named Entity Recognition},
2 changes: 1 addition & 1 deletion qaner/__init__.py
@@ -1 +1 @@
-# TODO
+__version__ = "0.1.1"
3 changes: 2 additions & 1 deletion qaner/dataset.py
@@ -2,9 +2,10 @@

 import torch
 import transformers
-from data_utils import Instance, Span
 from tqdm import tqdm

+from qaner.data_utils import Instance, Span
+

 # TODO: add documentation
 class Dataset(torch.utils.data.Dataset):
23 changes: 18 additions & 5 deletions qaner/inference.py
@@ -2,11 +2,12 @@
 from typing import Any, Dict

 import torch
-from arg_parse import get_inference_args
-from data_utils import Instance
-from inference_utils import get_top_valid_spans
 from transformers import AutoModelForQuestionAnswering, AutoTokenizer
-from utils import set_global_seed
+
+from qaner.arg_parse import get_inference_args
+from qaner.data_utils import Instance
+from qaner.inference_utils import get_top_valid_spans
+from qaner.utils import set_global_seed


 # TODO: add batch inference
@@ -71,7 +72,13 @@ def predict(
     return prediction


-if __name__ == "__main__":
+def main() -> int:
+    """
+    Main inference function.
+    Returns:
+        int: exit code.
+    """

     # argparse
     args = get_inference_args()
@@ -113,3 +120,9 @@ def predict(
     print(f"\nquestion: {prediction.question}\n")
     print(f"context: {prediction.context}")
     print(f"\nanswer: {prediction.answer}\n")
+
+    return 0
+
+
+if __name__ == "__main__":
+    main()
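
Note on the exit code: `main()` now returns an `int`, but the `if __name__ == "__main__":` guard calls it without using the result, so running the module directly exits with status 0 unless an exception is raised. A minimal sketch of one way to propagate the value, assuming a placeholder `main` and a hypothetical helper `run_as_script` (neither is code from this commit):

```python
import sys


def main() -> int:
    """Placeholder standing in for the real entry point; returns an exit code."""
    return 0


def run_as_script() -> None:
    """Hypothetical helper: hand main()'s return value to the interpreter."""
    # sys.exit raises SystemExit carrying the code, which becomes the process status.
    sys.exit(main())


# Demonstration only: catch SystemExit to inspect the code instead of exiting.
try:
    run_as_script()
except SystemExit as exc:
    print(exc.code)
```

In a real module the guard would simply read `sys.exit(main())`; the `try`/`except` here only makes the behaviour observable without terminating the interpreter.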
5 changes: 3 additions & 2 deletions qaner/inference_utils.py
@@ -3,7 +3,8 @@
 import numpy as np
 import torch
 import transformers
-from data_utils import Span
+
+from qaner.data_utils import Span


 def get_top_valid_spans(
@@ -85,7 +86,7 @@ def get_top_valid_spans(
         span = Span(
             token=context[start_context_char_char:end_context_char_char],
             label=inv_prompt_mapper[  # TODO: add inference exception
-                question_list[i].lstrip("What is the ").rstrip("?")
+                question_list[i].split(r"What is the ")[-1].rstrip(r"?")
             ],
             start_context_char_pos=start_context_char_char,
             end_context_char_pos=end_context_char_char,
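
Why the `lstrip` call was replaced: `str.lstrip(chars)` treats its argument as a set of characters to strip, not as a literal prefix, so it can also eat the first letters of the label. The sketch below reproduces both variants from the hunk above as standalone functions; the names `label_old` and `label_new` and the question "What is the time?" are hypothetical and chosen only to trigger the problem.

```python
PREFIX = "What is the "


def label_old(question: str) -> str:
    """Old behaviour: strips any leading characters found in PREFIX."""
    return question.lstrip(PREFIX).rstrip("?")


def label_new(question: str) -> str:
    """New behaviour: keeps everything after the literal prefix."""
    return question.split(PREFIX)[-1].rstrip("?")


# "t" and "i" both occur in PREFIX, so lstrip removes them from "time" too.
print(label_old("What is the time?"))          # me
print(label_new("What is the time?"))          # time
print(label_new("What is the organization?"))  # organization
```

The four CoNLL-2003 labels in the README happen to start with letters that do not occur in the prefix, so the old code would only misbehave with other prompt mappers.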
5 changes: 3 additions & 2 deletions qaner/metrics.py
@@ -1,7 +1,8 @@
 from typing import Dict, List

 import numpy as np
-from data_utils import Span
+
+from qaner.data_utils import Span


 # TODO: add metrics over label types
@@ -33,7 +34,7 @@ def compute_metrics(
     confusion_matrix_pred_denominator = np.zeros(len(entity_mapper))

     for span_true, span_pred in zip(spans_true_batch, spans_pred_batch_top_1):
-        span_pred = span_pred[0]
+        span_pred = span_pred[0]  # type: ignore

         i = entity_mapper[span_true.label]
         j = entity_mapper[span_pred.label]  # type: ignore
26 changes: 20 additions & 6 deletions qaner/train.py
@@ -1,15 +1,23 @@
 import json

 import torch
-from arg_parse import get_train_args
-from data_utils import prepare_sentences_and_spans, read_bio_markup
-from dataset import Collator, Dataset
 from torch.utils.tensorboard import SummaryWriter
-from train_utils import train
 from transformers import AutoModelForQuestionAnswering, AutoTokenizer
-from utils import set_global_seed

-if __name__ == "__main__":
+from qaner.arg_parse import get_train_args
+from qaner.data_utils import prepare_sentences_and_spans, read_bio_markup
+from qaner.dataset import Collator, Dataset
+from qaner.train_utils import train
+from qaner.utils import set_global_seed
+
+
+def main() -> int:
+    """
+    Main train function.
+    Returns:
+        int: exit code.
+    """

     # argparse
     args = get_train_args()
@@ -105,3 +113,9 @@

     model.save_pretrained(args.path_to_save_model)
     tokenizer.save_pretrained(args.path_to_save_model)
+
+    return 0
+
+
+if __name__ == "__main__":
+    main()
7 changes: 4 additions & 3 deletions qaner/train_utils.py
@@ -2,13 +2,14 @@

 import numpy as np
 import torch
-from data_utils import Span
-from inference_utils import get_top_valid_spans
-from metrics import compute_metrics
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
 from transformers import AutoModelForQuestionAnswering

+from qaner.data_utils import Span
+from qaner.inference_utils import get_top_valid_spans
+from qaner.metrics import compute_metrics
+

 # TODO: add metrics calculation
 def train(
33 changes: 33 additions & 0 deletions setup.py
@@ -0,0 +1,33 @@
from setuptools import setup

from qaner import __version__

with open("README.md", mode="r", encoding="utf-8") as fp:
    long_description = fp.read()


setup(
    name="qaner",
    version=__version__,
    description="Unofficial implementation of QaNER: Prompting Question Answering Models for Few-shot Named Entity Recognition.",
    long_description=long_description,
    long_description_content_type="text/markdown",
    author="Dani El-Ayyass",
    author_email="[email protected]",
    license_files=["LICENSE"],
    url="https://github.com/dayyass/qaner",
    packages=["qaner"],
    entry_points={
        "console_scripts": [
            "qaner-train = qaner.train:main",
            "qaner-inference = qaner.inference:main",
        ],
    },
    install_requires=[
        "numpy==1.21.6",
        "tensorboard==2.9.0",
        "torch==1.8.1",
        "tqdm==4.64.0",
        "transformers==4.19.2",
    ],
)
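
How the `console_scripts` entries map to code: each string has the form `command = module:function`, and the installed command imports the module and calls the function. The sketch below is a simplified illustration of that lookup, not the code setuptools generates; `resolve_entry_point` is a hypothetical helper, and a standard-library target is used as a stand-in because `qaner` may not be installed where this runs.

```python
import importlib
from typing import Callable


def resolve_entry_point(spec: str) -> Callable:
    """Resolve a 'command = module:function' string to the callable it names."""
    # Split off the command name, then separate module path and attribute.
    _, target = (part.strip() for part in spec.split("=", 1))
    module_name, func_name = target.split(":", 1)
    module = importlib.import_module(module_name)
    return getattr(module, func_name)


# Stand-in target; for this package the spec would be "qaner-train = qaner.train:main".
func = resolve_entry_point("demo-basename = os.path:basename")
print(func("data/conll2003/train.bio"))  # train.bio
```

This lookup is why the commit moves the script bodies into `main()` functions and switches to absolute `qaner.*` imports: the command has to be able to import `qaner.train` and `qaner.inference` as package modules.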