Skip to content

Commit

Permalink
Move type hints from docstrings into annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
angus-lherrou committed Jul 6, 2023
1 parent 32f1b5f commit 8d24c1c
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 90 deletions.
66 changes: 38 additions & 28 deletions src/cnlpt/CnlpModelForClassification.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
# from transformers.models.auto import AutoModel, AutoConfig
import copy
import inspect
from typing import Optional, List, Any, Dict
from os import PathLike
from typing import Optional, List, Any, Dict, Union

from transformers import AutoModel, AutoConfig
from transformers.modeling_utils import PreTrainedModel
Expand Down Expand Up @@ -72,7 +73,16 @@ class RepresentationProjectionLayer(nn.Module):
:param num_attention_heads - For relations, how many "features" to use
:param head_size - For relations, how big each head should be
"""
def __init__(self, config, layer=10, tokens=False, tagger=False, relations=False, num_attention_heads=-1, head_size=64):
def __init__(
self,
config: 'CnlpConfig',
layer: int = 10,
tokens: bool = False,
tagger: bool = False,
relations: bool = False,
num_attention_heads: int = -1,
head_size: int = 64,
):
super().__init__()
self.dropout = nn.Dropout(config.hidden_dropout_prob)
if relations:
Expand Down Expand Up @@ -143,33 +153,33 @@ class CnlpConfig(PretrainedConfig):
The config class for :class:`CnlpModelForClassification`.
:param encoder_name: the encoder name to use with :meth:`transformers.AutoConfig.from_pretrained`
:param typing.Optional[str] finetuning_task: the tasks for which this model is fine-tuned
:param int layer: the index of the encoder layer to extract features from
:param bool tokens: if true, sentence-level classification is done based on averaged token embeddings for token(s) surrounded by <e> </e> special tokens
:param int num_rel_attention_heads: the number of features/attention heads to use in the NxN relation classifier
:param int rel_attention_head_dims: the number of parameters in each attention head in the NxN relation classifier
:param typing.Dict[str,bool] tagger: for each task, whether the task is a sequence tagging task
:param typing.Dict[str,bool] relations: for each task, whether the task is a relation extraction task
:param bool use_prior_tasks: whether to use the outputs from the previous tasks as additional inputs for subsequent tasks
:param typing.Dict[] hier_head_config: If this is a hierarchical model, this is where the config parameters go
:param typing.Dict[str, typing.List[str]] label_dictionary: A mapping from task names to label sets
:param finetuning_task: the tasks for which this model is fine-tuned
:param layer: the index of the encoder layer to extract features from
:param tokens: if true, sentence-level classification is done based on averaged token embeddings for token(s) surrounded by <e> </e> special tokens
:param num_rel_attention_heads: the number of features/attention heads to use in the NxN relation classifier
:param rel_attention_head_dims: the number of parameters in each attention head in the NxN relation classifier
:param tagger: for each task, whether the task is a sequence tagging task
:param relations: for each task, whether the task is a relation extraction task
:param use_prior_tasks: whether to use the outputs from the previous tasks as additional inputs for subsequent tasks
:param hier_head_config: If this is a hierarchical model, this is where the config parameters go
:param label_dictionary: A mapping from task names to label sets
:param \**kwargs: arguments for :class:`transformers.PretrainedConfig`
"""
model_type='cnlpt'

def __init__(
self,
encoder_name='roberta-base',
finetuning_task=None,
layer=-1,
tokens=False,
num_rel_attention_heads=12,
rel_attention_head_dims=64,
tagger = {},
relations = {},
use_prior_tasks=False,
hier_head_config=None,
label_dictionary = None,
encoder_name: Union[str, PathLike] = 'roberta-base',
finetuning_task: Optional[List[str]] = None,
layer: int = -1,
tokens: bool = False,
num_rel_attention_heads: int = 12,
rel_attention_head_dims: int = 64,
tagger: Dict[str, bool] = {},
relations: Dict[str, bool] = {},
use_prior_tasks: bool = False,
hier_head_config: Dict[str, Any] = None,
label_dictionary: Dict[str, List[str]] = None,
**kwargs
):
super().__init__(**kwargs)
Expand Down Expand Up @@ -206,12 +216,12 @@ class CnlpModelForClassification(PreTrainedModel):
"""
The CNLP transformer model.
:param typing.Optional[typing.List[float]] class_weights: if provided,
:param class_weights: if provided,
the weights to use for each task when computing the loss
:param float final_task_weight: the weight to use for the final task
:param final_task_weight: the weight to use for the final task
when computing the loss; default 1.0.
:param bool freeze: whether to freeze the weights of the encoder
:param bool bias_fit: whether to fine-tune only the bias of the encoder
:param freeze: what proportion of encoder weights to freeze (-1 for none)
:param bias_fit: whether to fine-tune only the bias of the encoder
"""
base_model_prefix = 'cnlpt'
config_class = CnlpConfig
Expand All @@ -222,7 +232,7 @@ def __init__(self,
class_weights: Optional[Dict[str, float]] = None,
final_task_weight: float = 1.0,
freeze: float = -1.0,
bias_fit=False,
bias_fit: bool = False,
):

super().__init__(config)
Expand Down
73 changes: 33 additions & 40 deletions src/cnlpt/cnlp_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
from transformers import BatchEncoding
from transformers import BatchEncoding, InputExample
# from transformers.data.processors.utils import DataProcessor, InputExample
from transformers.tokenization_utils import PreTrainedTokenizer
from datasets import Features
Expand Down Expand Up @@ -113,16 +113,15 @@ def cnlp_convert_features_to_hierarchical(
Chunk an instance of InputFeatures into an instance of HierarchicalInputFeatures
for the hierarchical model.
:param BatchEncoding features: the dictionary containing mappings from properties to lists of values for each instance for each of those properties
:param int chunk_len: the maximum length of a chunk
:param int num_chunks: the maximum number of chunks in the instance
:param int cls_id: the tokenizer's ID representing the CLS token
:param int sep_id: the tokenizer's ID representing the SEP token
:param int pad_id: the tokenizer's ID representing the PAD token
:param bool insert_empty_chunk_at_beginning: whether to insert an
:param features: the dictionary containing mappings from properties to lists of values for each instance for each of those properties
:param chunk_len: the maximum length of a chunk
:param num_chunks: the maximum number of chunks in the instance
:param cls_id: the tokenizer's ID representing the CLS token
:param sep_id: the tokenizer's ID representing the SEP token
:param pad_id: the tokenizer's ID representing the PAD token
:param insert_empty_chunk_at_beginning: whether to insert an
empty chunk at the beginning of the instance
:rtype: HierarchicalInputFeatures
:return: an instance of `HierarchicalInputFeatures` containing the chunked instance
:return: an instance of `BatchEncoding` containing the chunked instance
"""

for ind in range(len(features['input_ids'])):
Expand Down Expand Up @@ -234,7 +233,7 @@ def create_pad_chunk(cls_type=cls_id, sep_type=sep_id, pad_type=pad_id):


def cnlp_preprocess_data(
examples,
examples: List[InputExample],
tokenizer: PreTrainedTokenizer,
max_length: Optional[int] = None,
tasks: List[str] = None,
Expand All @@ -253,30 +252,29 @@ def cnlp_preprocess_data(
and converts the examples into a list of :class:`InputFeatures` or
:class:`HierarchicalInputFeatures`, depending on the model.
:param typing.List[transformers.data.processors.utils.InputExample] examples:
:param examples:
the list of examples to convert
:param transformers.tokenization_utils.PreTrainedTokenizer tokenizer: the tokenizer
:param typing.Optional[int] max_length: the maximum sequence length
:param tokenizer: the tokenizer
:param max_length: the maximum sequence length
at which to truncate examples
:param List[str] tasks: the task name(s) in a list, used to index the labels in the examples list.
:param typing.Optional[typing.Dict[str,List[str]]] label_list: a mapping from
:param tasks: the task name(s) in a list, used to index the labels in the examples list.
:param label_list: a mapping from
tasks to the list of labels for each task. If not provided explicitly, it will be retrieved from
the processor with :meth:`transformers.DataProcessor.get_labels`.
:param typing.Optional[Dict[str,str]] output_modes: the output modes for this task.
:param output_modes: the output modes for this task.
If not provided explicitly, it will be retrieved from
:data:`cnlpt.cnlp_processors.cnlp_output_modes`.
:param bool inference: whether we're doing training or inference only -- if inference mode the labels associated with examples can't be trusted.
:param bool hierarchical: whether to structure the data for the hierarchical
:param inference: whether we're doing training or inference only -- if inference mode the labels associated with examples can't be trusted.
:param hierarchical: whether to structure the data for the hierarchical
model (:class:`cnlpt.HierarchicalTransformer.HierarchicalModel`)
:param int chunk_len: for the hierarchical model, the length of each
:param chunk_len: for the hierarchical model, the length of each
chunk in tokens
:param int num_chunks: for the hierarchical model, the number of chunks
:param bool insert_empty_chunk_at_beginning: for the hierarchical model,
:param num_chunks: for the hierarchical model, the number of chunks
:param insert_empty_chunk_at_beginning: for the hierarchical model,
whether to insert an empty chunk at the beginning of the list of chunks
(equivalent in theory to a CLS chunk).
:param bool truncate_examples: whether to truncate the string representation
:param truncate_examples: whether to truncate the string representation
of the example instances printed to the log
:rtype: typing.Union[typing.List[InputFeatures], typing.List[HierarchicalInputFeatures]]
:return: the list of converted input features
"""

Expand Down Expand Up @@ -526,13 +524,11 @@ def _build_event_mask(result:BatchEncoding, num_insts:int, event_start_token_id,

return event_tokens

def truncate_features(feature: Union[InputFeatures, HierarchicalInputFeatures]):
def truncate_features(feature: Union[InputFeatures, HierarchicalInputFeatures]) -> str:
"""
Method to produce a truncated string representation of a feature.
:param typing.Union[InputFeatures, HierarchicalInputFeatures] feature:
the feature to represent
:rtype: str
:param feature: the feature to represent
:return: the truncated representation of the feature
:meta private:
"""
Expand Down Expand Up @@ -634,13 +630,13 @@ class ClinicalNlpDataset(Dataset):
Copy-pasted from GlueDataset with glue task-specific code changed;
moved into here to be self-contained.
:param DataTrainingArguments args: the data training args for this experiment
:param transformers.tokenization_utils.PreTrainedTokenizer tokenizer: the tokenizer
:param typing.Optional[int] limit_length: if provided, the number of
:param args: the data training args for this experiment
:param tokenizer: the tokenizer
:param limit_length: if provided, the number of
examples to include in the dataset
:param typing.Optional[str] cache_dir: if provided, the directory to save/load a cache
:param cache_dir: if provided, the directory to save/load a cache
of this dataset
:param bool hierarchical: whether to structure the data for the hierarchical
:param hierarchical: whether to structure the data for the hierarchical
model (:class:`cnlpt.HierarchicalTransformer.HierarchicalModel`)
"""
args: DataTrainingArguments
Expand Down Expand Up @@ -793,10 +789,10 @@ def _reconcile_columns(self):
dataset[split_name] = dataset[split_name].remove_columns(column)

def _concatenate_datasets(self):
'''
"""
We have multiple dataset dicts, we need to create a single dataset dict
where we concatenate each of the splits first.
'''
"""
datasets_by_split = {}
for dataset in self.datasets:
for split in dataset:
Expand All @@ -813,26 +809,23 @@ def __len__(self) -> int:
"""
Length method for this class.
:rtype: int
:return: the number of datasets included in this dataset
"""
return len(self.datasets)

def __getitem__(self, i):
def __getitem__(self, i) -> Union[InputFeatures, HierarchicalInputFeatures]:
"""
Getitem method for this class.
:param i: the index of the example to retrieve
:rtype: typing.Union[InputFeatures, HierarchicalInputFeatures]
:return: the example at index `i`
"""
return self.features[i]

def get_labels(self):
def get_labels(self) -> Dict[str, List[str]]:
"""
Retrieve the label lists for all the tasks for the dataset.
:rtype: typing.Dict[str,typing.List[str]]
:return: the dictionary of label lists indexed by task name
"""
return self.tasks_to_labels
54 changes: 36 additions & 18 deletions src/cnlpt/cnlp_metrics.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import logging
from typing import Set, Any, Dict

import numpy as np
from sklearn.metrics import matthews_corrcoef, f1_score, recall_score, precision_score, classification_report, accuracy_score
from seqeval.metrics import f1_score as seq_f1, classification_report as seq_cls
Expand All @@ -20,7 +22,12 @@ def fix_np_types(input_variable):

return input_variable

def tagging_metrics(label_set, preds, labels, task_name):
def tagging_metrics(
label_set: Set[str],
preds: np.ndarray,
labels: np.ndarray,
task_name: str,
) -> Dict[str, Any]:
"""
One of the metrics functions for use in :func:`cnlp_compute_metrics`.
Expand All @@ -38,9 +45,9 @@ def tagging_metrics(label_set, preds, labels, task_name):
}
:param label_set: The set of labels for this task
:param numpy.ndarray preds: the predicted labels from the model
:param numpy.ndarray labels: the true labels
:rtype: typing.Dict[str, typing.Any]
:param preds: the predicted labels from the model
:param labels: the true labels
:param task_name: the name of the relevant task (unused)
:return: a dictionary containing evaluation metrics
"""
preds = preds.flatten()
Expand All @@ -60,7 +67,12 @@ def tagging_metrics(label_set, preds, labels, task_name):

return {'acc': acc, 'token_f1': fix_np_types(f1), 'f1': fix_np_types(seq_f1([label_seq], [pred_seq])), 'report':'\n'+seq_cls([label_seq], [pred_seq])}

def relation_metrics(label_set, preds, labels, task_name):
def relation_metrics(
label_set: Set[str],
preds: np.ndarray,
labels: np.ndarray,
task_name: str,
) -> Dict[str, Any]:
"""
One of the metrics functions for use in :func:`cnlp_compute_metrics`.
Expand All @@ -77,10 +89,9 @@ def relation_metrics(label_set, preds, labels, task_name):
'precision': precision
}
:param label_set: the set of labels for this task
:param numpy.ndarray preds: the predicted labels from the model
:param numpy.ndarray labels: the true labels
:rtype: typing.Dict[str, typing.Any]
:param label_set: the set of labels for this task
:param preds: the predicted labels from the model
:param labels: the true labels
:return: a dictionary containing evaluation metrics
"""

Expand All @@ -103,7 +114,7 @@ def relation_metrics(label_set, preds, labels, task_name):

return {'f1': f1_scores, 'acc': acc, 'recall':fix_np_types(recall), 'precision':fix_np_types(precision), 'report_dict':report_dict, 'report_str':report_str }

def acc_and_f1(preds, labels):
def acc_and_f1(preds: np.ndarray, labels: np.ndarray) -> Dict[str, Any]:
"""
One of the metrics functions for use in :func:`cnlp_compute_metrics`.
Expand All @@ -119,9 +130,8 @@ def acc_and_f1(preds, labels):
'precision': precision
}
:param numpy.ndarray preds: the predicted labels from the model
:param numpy.ndarray labels: the true labels
:rtype: typing.Dict[str, typing.Any]
:param preds: the predicted labels from the model
:param labels: the true labels
:return: a dictionary containing evaluation metrics
"""
acc = accuracy_score(y_pred=preds, y_true=labels)
Expand All @@ -137,7 +147,14 @@ def acc_and_f1(preds, labels):
"precision": fix_np_types(precision)
}

def cnlp_compute_metrics(task_name, preds, labels, output_mode, label_set):

def cnlp_compute_metrics(
task_name: str,
preds: np.ndarray,
labels: np.ndarray,
output_mode: str,
label_set: Set[str],
) -> Dict[str, Any]:
"""
Function that defines and computes the metrics used for each task.
Expand All @@ -146,10 +163,11 @@ def cnlp_compute_metrics(task_name, preds, labels, output_mode, label_set):
If the new task is a simple classification task, a sensible default
is defined; falling back on this will trigger a warning.
:param str task_name: the task name used to index into cnlp_processors
:param numpy.ndarray preds: the predicted labels from the model
:param numpy.ndarray labels: the true labels
:rtype: typing.Dict[str, typing.Any]
:param task_name: the task name used to index into cnlp_processors
:param preds: the predicted labels from the model
:param labels: the true labels
:param output_mode: the output mode of the classifier
:param label_set: the set of output label names for the classifier
:return: a dictionary containing evaluation metrics
"""

Expand Down
Loading

0 comments on commit 8d24c1c

Please sign in to comment.