Skip to content

Commit

Permalink
Cleanup of several functions, adding doc strings and type annotations.
Browse files Browse the repository at this point in the history
  • Loading branch information
tmills committed Aug 24, 2023
1 parent 7cd79ea commit 6fdd718
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 49 deletions.
33 changes: 27 additions & 6 deletions src/cnlpt/CnlpModelForClassification.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@


def generalize_encoder_forward_kwargs(encoder, **kwargs: Any) -> Dict[str, Any]:
"""
Create a new input feature argument that preserves only the features that are valid for this encoder.
Warn if a feature is present but not valid for the encoder.
:param encoder: A HF encoder model
:return: Dictionary of valid arguments for this encoder
"""
new_kwargs = dict()
params = inspect.signature(encoder.forward).parameters
for name, value in kwargs.items():
Expand All @@ -41,7 +47,13 @@ def generalize_encoder_forward_kwargs(encoder, **kwargs: Any) -> Dict[str, Any]:
return new_kwargs


def freeze_encoder_weights(encoder, freeze):
def freeze_encoder_weights(encoder, freeze: float):
"""
Probabilistically freeze the weights of this HF encoder model according to the freeze parameter.
Values of freeze >=1 are treated as if every parameter should be frozen.
:param encoder: HF encoder model
:param freeze: Probability of freezing any given parameter (0-1)
"""
for param in encoder.parameters():
if freeze >= 1.0:
param.requires_grad = False
Expand Down Expand Up @@ -243,7 +255,7 @@ def __init__(
class CnlpModelForClassification(PreTrainedModel):
"""
The CNLP transformer model.
:param config: The CnlpConfig object that configures this model
:param class_weights: if provided,
the weights to use for each task when computing the loss
:param final_task_weight: the weight to use for the final task
Expand Down Expand Up @@ -338,7 +350,16 @@ def __init__(

# self.init_weights()

def predict_relations_with_previous_logits(self, features, logits):
def predict_relations_with_previous_logits(
self, features: torch.Tensor, logits: torch.Tensor
) -> torch.Tensor:
"""
For the relation prediction task, use previous predictions of the tagging task as additional features in the
representation used for making the relation prediction.
:param features: The existing feature vector for the relations
:param logits: The predicted logits from the tagging task
:return: The augmented feature tensor
"""
seq_len = features.shape[1]
for prior_task_logits in logits:
if len(features.shape) == 4:
Expand Down Expand Up @@ -375,9 +396,9 @@ def compute_loss(
task_logits: torch.FloatTensor,
labels: torch.LongTensor,
task_ind: int,
task_num_labels,
batch_size,
seq_len,
task_num_labels: int,
batch_size: int,
seq_len: int,
state: dict,
) -> None:
"""
Expand Down
17 changes: 1 addition & 16 deletions src/cnlpt/HierarchicalTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,6 @@
logger = logging.getLogger(__name__)


def set_seed(seed, n_gpu):
"""
Set the random seeds for ``random``, numpy, and pytorch to a specific value.
Args:
seed: the seed to use
n_gpu: the number of GPUs being used
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if n_gpu > 0:
torch.cuda.manual_seed_all(seed)


@dataclass
class HierarchicalSequenceClassifierOutput(SequenceClassifierOutput):
chunk_attentions: Optional[Tuple[torch.FloatTensor]] = None
Expand Down Expand Up @@ -296,7 +281,7 @@ def __init__(
self.label_dictionary = config.label_dictionary
self.set_class_weights(class_weights)

def remove_task_classifiers(self, tasks=None):
def remove_task_classifiers(self, tasks: List[str] = None):
if tasks is None:
self.classifiers = nn.ModuleDict()
self.tasks = []
Expand Down
51 changes: 32 additions & 19 deletions src/cnlpt/cnlp_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,6 @@
logger = logging.getLogger(__name__)


def list_field(default=None, metadata=None):
return field(default_factory=lambda: default, metadata=metadata)


class Split(Enum):
"""
Enum representing the three data splits for model development.
Expand Down Expand Up @@ -106,11 +102,6 @@ def cnlp_convert_features_to_hierarchical(
sep_id: int,
pad_id: int,
insert_empty_chunk_at_beginning: bool = False,
# cls_token_at_end=False,
# sequence_a_segment_id=0,
# cls_token_segment_id=0,
# pad_token_segment_id=0,
# use_special_token=True,
) -> BatchEncoding:
"""
Chunk an instance of InputFeatures into an instance of HierarchicalInputFeatures
Expand Down Expand Up @@ -244,7 +235,7 @@ def create_pad_chunk(cls_type=cls_id, sep_type=sep_id, pad_type=pad_id):


def cnlp_preprocess_data(
examples: List[InputExample],
examples: Dict[str, Union[List[str], List[int], List[float]]],
tokenizer: PreTrainedTokenizer,
max_length: Optional[int] = None,
tasks: List[str] = None,
Expand All @@ -258,18 +249,18 @@ def cnlp_preprocess_data(
truncate_examples: bool = False,
) -> Union[List[InputFeatures], List[HierarchicalInputFeatures]]:
"""
Processes the list of :class:`transformers.InputExample` generated by
Processes the dictionary of data inputs created by
the processor defined in :data:`cnlpt.cnlp_processors.cnlp_processors`
and converts the examples into a list of :class:`InputFeatures` or
:class:`HierarchicalInputFeatures`, depending on the model.
:param examples:
the list of examples to convert
the dictionary containing the input data to convert
:param tokenizer: the tokenizer
:param max_length: the maximum sequence length
at which to truncate examples
:param tasks: the task name(s) in a list, used to index the labels in the examples list.
:param label_list: a mapping from
:param label_lists: a mapping from
tasks to the list of labels for each task. If not provided explicitly, it will be retrieved from
the processor with :meth:`transformers.DataProcessor.get_labels`.
:param output_modes: the output modes for this task.
Expand Down Expand Up @@ -415,12 +406,12 @@ def cnlp_preprocess_data(
def _build_pytorch_labels(
result: BatchEncoding,
tasks: List[str],
labels: List,
labels: list,
output_modes: Dict[str, str],
num_instances: int,
max_length: int,
label_lists: List[List[str]],
):
) -> list:
"""
_build_pytorch_labels: we do two things here: map from labels in input space to ints in a softmax, and in a data
structure that can contain multiple task types such that the Trainer class will be happy with, and then that
Expand Down Expand Up @@ -550,6 +541,15 @@ def _build_pytorch_labels(
def _build_event_mask(
result: BatchEncoding, num_insts: int, event_start_token_id, event_end_token_id
):
"""
Create arrays corresponding to input tokens where the events to be classified contain special mask tokens.
These are used if the --event flag is specified to classify event tokens rather than the [CLS] token.
:param result: The input encodings of the tokens
:param num_insts: The length of the input
:param event_start_token_id: The special token index used to indicate the start of an event.
:param event_end_token_id: The special token index used to indicate the end of an event.
:return: The list of lists of per-instance event mask values corresponding to the input tokens.
"""
event_tokens = []
for i in range(num_insts):
input_ids = result["input_ids"][i]
Expand Down Expand Up @@ -597,13 +597,25 @@ def truncate_features(feature: Union[InputFeatures, HierarchicalInputFeatures])
)


def summarize(li):
def summarize(li) -> str:
"""
Show a summarized version of a list. Used to reduce amount of text in logs for long input examples.
:param li: Input list
:return: Summary string
:meta private:
"""
if li is None:
return "None"
return str(truncate_list_of_lists(li)).replace('"', "").replace("'", "")


def truncate_list_of_lists(li: Union[list, str]) -> Union[list, str]:
"""
For a list with more than 3 items, give the first item, summarize the middle items, and final item.
If an element of the list is a list, it will recurse into that list and summarize that.
Primarily used by :func:`summarize` to limit the amount of output in log files for really long input texts.
:meta private:
"""
if isinstance(li, str):
return li
if li:
Expand Down Expand Up @@ -815,10 +827,11 @@ def __init__(

self.num_train_instances += self.processed_dataset["train"].num_rows

def _reconcile_labels_lists(self, processor):
def _reconcile_labels_lists(self, processor: AutoProcessor):
"""
given a new data processor, which extracted a label list for every task it contained,
we reconcile it with existing label list for the same task
:param processor: An AutoProcessor object that contains a processed dataset
"""
for task, labels in processor.get_labels().items():
if task in self.tasks_to_labels:
Expand All @@ -842,7 +855,7 @@ def _reconcile_labels_lists(self, processor):
else:
self.tasks_to_labels[task] = labels

def _reconcile_output_modes(self, processor):
def _reconcile_output_modes(self, processor: AutoProcessor):
"""
given a new data processor, which inferred output modes for its tasks, make
sure those output modes agree with existing inferred output modes for any
Expand Down Expand Up @@ -886,7 +899,7 @@ def _reconcile_columns(self):
if column not in tasks and column not in text_columns:
dataset[split_name] = dataset[split_name].remove_columns(column)

def _concatenate_datasets(self):
def _concatenate_datasets(self) -> datasets.DatasetDict:
"""
We have multiple dataset dicts, we need to create a single dataset dict
where we concatenate each of the splits first.
Expand Down
25 changes: 22 additions & 3 deletions src/cnlpt/cnlp_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from dataclasses import dataclass, field
from typing import Callable, Dict, Optional, List, Union, Any, Set
from transformers.data.processors.utils import DataProcessor, InputExample
import datasets
from datasets import load_dataset
import torch
from torch.utils.data.dataset import Dataset
Expand All @@ -29,7 +30,15 @@
relex = "relations"


def get_unique_labels(dataset, tasks, task_output_modes):
def get_unique_labels(
dataset, tasks: List[str], task_output_modes: Dict[str, str]
) -> Dict[str, List[str]]:
"""
Return the set of unique labels defined in a dataset by iterating through the dataset.
:param tasks: List of tasks that the caller cares about
:param task_output_modes: Dictionary mapping from task names to task output mode
:return: Dictionary from task names to a list of unique labels for that task
"""
dataset_unique_labels = dict()
for task_ind, task_name in enumerate(tasks):
unique_labels = set()
Expand Down Expand Up @@ -68,7 +77,12 @@ def get_unique_labels(dataset, tasks, task_output_modes):
return dataset_unique_labels


def infer_output_modes(dataset):
def infer_output_modes(dataset: datasets.DatasetDict) -> Dict[str, str]:
"""
Figure out what output mode each task in the dataset requires by looking at the format of the labels.
:param dataset: HF datasets DatasetDict containing the loaded dataset
:return: Dictionary mapping from task names to output modes
"""
task_output_modes = {}
for task_ind, task_name in enumerate(dataset.tasks):
output_mode = classification
Expand Down Expand Up @@ -96,7 +110,12 @@ def infer_output_modes(dataset):
return task_output_modes


def get_task_pruned_dataset(dataset, tasks, unique_labels):
def get_task_pruned_dataset(
dataset: datasets.DatasetDict, tasks: List[str], unique_labels: Dict[str, List[str]]
) -> datasets.DatasetDict:
"""
Remove tasks from the dataset that only have 1 unique label
"""
tasks_to_remove = []
for task_ind, task_name in enumerate(tasks):
if len(unique_labels[task_name]) == 1:
Expand Down
20 changes: 15 additions & 5 deletions src/cnlpt/train_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,12 @@
logger = logging.getLogger(__name__)


def is_hub_model(model_name):
# check if it's a model on the huggingface model hub:
def is_hub_model(model_name: str) -> bool:
"""
Check for whether a model specification string is on the huggingface model hub
:param model_name: the string to check
:return: whether the model is on the huggingface hub
"""
try:
url = hf_hub_url(model_name, CONFIG_NAME)
r = requests.head(url)
Expand All @@ -95,7 +99,13 @@ def is_cnlpt_model(model_path: str) -> bool:
return encoder_config.model_type == "cnlpt"


def encoder_inferred(model_name_or_path: str) -> bool:
def is_external_encoder(model_name_or_path: str) -> bool:
"""
Check whether a specified model is not a cnlpt model -- an external model like a
huggingface hub model or a downloaded local directory.
:param model_name_or_path: specified model
:return: whether the encoder is an external (non-cnlpt) model
"""
return is_hub_model(model_name_or_path) or not is_cnlpt_model(model_name_or_path)


Expand Down Expand Up @@ -258,7 +268,7 @@ def main(
if model_args.config_name
else model_args.encoder_name
)
if encoder_inferred(encoder_name):
if is_external_encoder(encoder_name):
config = CnlpConfig(
encoder_name=encoder_name,
finetuning_task=data_args.task_name
Expand Down Expand Up @@ -341,7 +351,7 @@ def main(

# TODO check when download any pretrained language model to local disk, if
# the following condition "is_hub_model(encoder_name)" works or not.
if not encoder_inferred(encoder_name):
if not is_external_encoder(encoder_name):
# we are loading one of our own trained models as a starting point.
#
# 1) if training_args.do_train is true:
Expand Down

0 comments on commit 6fdd718

Please sign in to comment.