diff --git a/src/cnlpt/CnlpModelForClassification.py b/src/cnlpt/CnlpModelForClassification.py
index db58cc05..bb66adf9 100644
--- a/src/cnlpt/CnlpModelForClassification.py
+++ b/src/cnlpt/CnlpModelForClassification.py
@@ -25,6 +25,12 @@


 def generalize_encoder_forward_kwargs(encoder, **kwargs: Any) -> Dict[str, Any]:
+    """
+    Create a new keyword-argument dictionary that preserves only the input features that are valid for this encoder.
+    Warn if a feature is present but not valid for the encoder.
+    :param encoder: An HF encoder model
+    :return: Dictionary of valid arguments for this encoder
+    """
     new_kwargs = dict()
     params = inspect.signature(encoder.forward).parameters
     for name, value in kwargs.items():
@@ -41,7 +47,13 @@ def generalize_encoder_forward_kwargs(encoder, **kwargs: Any) -> Dict[str, Any]:
     return new_kwargs


-def freeze_encoder_weights(encoder, freeze):
+def freeze_encoder_weights(encoder, freeze: float):
+    """
+    Probabilistically freeze the weights of this HF encoder model according to the freeze parameter.
+    Values of freeze >= 1 are treated as if every parameter should be frozen.
+    :param encoder: HF encoder model
+    :param freeze: Probability of freezing any given parameter (0-1)
+    """
     for param in encoder.parameters():
         if freeze >= 1.0:
             param.requires_grad = False
@@ -243,7 +255,7 @@ def __init__(
 class CnlpModelForClassification(PreTrainedModel):
     """
     The CNLP transformer model.
-
+
     :param config: The CnlpConfig object that configures this model
     :param class_weights: if provided, the weights to use for each task when computing the loss
     :param final_task_weight: the weight to use for the final task
@@ -338,7 +350,16 @@ def __init__(

         # self.init_weights()

-    def predict_relations_with_previous_logits(self, features, logits):
+    def predict_relations_with_previous_logits(
+        self, features: torch.Tensor, logits: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        For the relation prediction task, use previous predictions from the tagging task as additional features in the
+        representation used for making the relation prediction.
+        :param features: The existing feature tensor for the relations
+        :param logits: The predicted logits from the tagging task
+        :return: The augmented feature tensor
+        """
         seq_len = features.shape[1]
         for prior_task_logits in logits:
             if len(features.shape) == 4:
@@ -375,9 +396,9 @@ def compute_loss(
         task_logits: torch.FloatTensor,
         labels: torch.LongTensor,
         task_ind: int,
-        task_num_labels,
-        batch_size,
-        seq_len,
+        task_num_labels: int,
+        batch_size: int,
+        seq_len: int,
         state: dict,
     ) -> None:
         """
diff --git a/src/cnlpt/HierarchicalTransformer.py b/src/cnlpt/HierarchicalTransformer.py
index 0b27a110..3a0c9beb 100644
--- a/src/cnlpt/HierarchicalTransformer.py
+++ b/src/cnlpt/HierarchicalTransformer.py
@@ -26,21 +26,6 @@
 logger = logging.getLogger(__name__)


-def set_seed(seed, n_gpu):
-    """
-    Set the random seeds for ``random``, numpy, and pytorch to a specific value.
-
-    Args:
-        seed: the seed to use
-        n_gpu: the number of GPUs being used
-    """
-    random.seed(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    if n_gpu > 0:
-        torch.cuda.manual_seed_all(seed)
-
-
 @dataclass
 class HierarchicalSequenceClassifierOutput(SequenceClassifierOutput):
     chunk_attentions: Optional[Tuple[torch.FloatTensor]] = None
@@ -296,7 +281,7 @@ def __init__(
         self.label_dictionary = config.label_dictionary
         self.set_class_weights(class_weights)

-    def remove_task_classifiers(self, tasks=None):
+    def remove_task_classifiers(self, tasks: Optional[List[str]] = None):
         if tasks is None:
             self.classifiers = nn.ModuleDict()
             self.tasks = []
diff --git a/src/cnlpt/cnlp_data.py b/src/cnlpt/cnlp_data.py
index 03f1122b..6a581927 100644
--- a/src/cnlpt/cnlp_data.py
+++ b/src/cnlpt/cnlp_data.py
@@ -26,10 +26,6 @@
 logger = logging.getLogger(__name__)


-def list_field(default=None, metadata=None):
-    return field(default_factory=lambda: default, metadata=metadata)
-
-
 class Split(Enum):
     """
     Enum representing the three data splits for model development.
@@ -106,11 +102,6 @@ def cnlp_convert_features_to_hierarchical(
     sep_id: int,
     pad_id: int,
     insert_empty_chunk_at_beginning: bool = False,
-    # cls_token_at_end=False,
-    # sequence_a_segment_id=0,
-    # cls_token_segment_id=0,
-    # pad_token_segment_id=0,
-    # use_special_token=True,
 ) -> BatchEncoding:
     """
     Chunk an instance of InputFeatures into an instance of HierarchicalInputFeatures
@@ -244,7 +235,7 @@ def create_pad_chunk(cls_type=cls_id, sep_type=sep_id, pad_type=pad_id):


 def cnlp_preprocess_data(
-    examples: List[InputExample],
+    examples: Dict[str, Union[List[str], List[int], List[float]]],
     tokenizer: PreTrainedTokenizer,
     max_length: Optional[int] = None,
     tasks: List[str] = None,
@@ -258,18 +249,18 @@ def cnlp_preprocess_data(
     truncate_examples: bool = False,
 ) -> Union[List[InputFeatures], List[HierarchicalInputFeatures]]:
     """
-    Processes the list of :class:`transformers.InputExample` generated by
+    Processes the dictionary of data inputs created by
     the processor defined in :data:`cnlpt.cnlp_processors.cnlp_processors` and converts
     the examples into a list of :class:`InputFeatures` or
     :class:`HierarchicalInputFeatures`, depending on the model.

     :param examples:
-        the list of examples to convert
+        the dictionary containing the input data to convert
     :param tokenizer: the tokenizer
     :param max_length: the maximum sequence length at which to truncate examples
     :param tasks: the task name(s) in a list, used to index the labels in the examples list.
-    :param label_list: a mapping from
+    :param label_lists: a mapping from
         tasks to the list of labels for each task. If not provided explicitly, it will be
         retrieved from the processor with :meth:`transformers.DataProcessor.get_labels`.
     :param output_modes: the output modes for this task.
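
[Review note, not part of the patch] Since the `examples` annotation changed from `List[InputExample]` to a column dictionary, a minimal sketch of the shape the new type implies may help reviewers. The column names below are invented for illustration, and the commented call simply mirrors the signature shown above:

# Hypothetical input for cnlp_preprocess_data under the new annotation:
# one text column plus one label column per task (names are assumptions).
examples = {
    "text": ["The patient denies cough .", "Mild fever was noted ."],
    "negation": ["1", "-1"],
}
# features = cnlp_preprocess_data(examples, tokenizer, max_length=128, tasks=["negation"])
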
@@ -415,12 +406,12 @@ def cnlp_preprocess_data(
 def _build_pytorch_labels(
     result: BatchEncoding,
     tasks: List[str],
-    labels: List,
+    labels: list,
     output_modes: Dict[str, str],
     num_instances: int,
     max_length: int,
     label_lists: List[List[str]],
-):
+) -> list:
     """
     _build_pytorch_labels: we do two things here: map from labels in input space to ints in a softmax, and
     in a data structure that can contain multiple task types such that the Trainer class will be happy with, and then that
@@ -550,6 +541,15 @@ def _build_pytorch_labels(
 def _build_event_mask(
     result: BatchEncoding, num_insts: int, event_start_token_id, event_end_token_id
 ):
+    """
+    Create arrays, corresponding to the input tokens, in which the tokens of the events to be classified are marked.
+    These are used when the --event flag is specified, to classify event tokens rather than the [CLS] token.
+    :param result: The input encodings of the tokens
+    :param num_insts: The number of input instances
+    :param event_start_token_id: The special token index used to indicate the start of an event.
+    :param event_end_token_id: The special token index used to indicate the end of an event.
+    :return: The list of lists of per-instance event mask values corresponding to the input tokens.
+    """
     event_tokens = []
     for i in range(num_insts):
         input_ids = result["input_ids"][i]
@@ -597,13 +597,25 @@ def truncate_features(feature: Union[InputFeatures, HierarchicalInputFeatures])
 )


-def summarize(li):
+def summarize(li) -> str:
+    """
+    Show a summarized version of a list. Used to reduce the amount of text in logs for long input examples.
+    :param li: Input list
+    :return: Summary string
+    :meta private:
+    """
     if li is None:
         return "None"
     return str(truncate_list_of_lists(li)).replace('"', "").replace("'", "")


 def truncate_list_of_lists(li: Union[list, str]) -> Union[list, str]:
+    """
+    For a list with more than 3 items, show the first item, a summary of the middle items, and the final item.
+    If an element of the list is itself a list, recurse into it and summarize it as well.
+    Primarily used by :func:`summarize` to limit the amount of output in log files for very long input texts.
+    :meta private:
+    """
     if isinstance(li, str):
         return li
     if li:
@@ -815,10 +827,11 @@ def __init__(

         self.num_train_instances += self.processed_dataset["train"].num_rows

-    def _reconcile_labels_lists(self, processor):
+    def _reconcile_labels_lists(self, processor: AutoProcessor):
         """
         given a new data processor, which extracted a label list for every task it
         contained, we reconcile it with existing label list for the same task
+        :param processor: An AutoProcessor object that contains a processed dataset
         """
         for task, labels in processor.get_labels().items():
             if task in self.tasks_to_labels:
@@ -842,7 +855,7 @@
         else:
             self.tasks_to_labels[task] = labels

-    def _reconcile_output_modes(self, processor):
+    def _reconcile_output_modes(self, processor: AutoProcessor):
         """
         given a new data processor, which inferred output modes for its tasks,
         make sure those output modes agree with existing inferred output modes for any
@@ -886,7 +899,7 @@ def _reconcile_columns(self):
             if column not in tasks and column not in text_columns:
                 dataset[split_name] = dataset[split_name].remove_columns(column)

-    def _concatenate_datasets(self):
+    def _concatenate_datasets(self) -> datasets.DatasetDict:
         """
         We have multiple dataset dicts, we need to create a single dataset dict
         where we concatenate each of the splits first.
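
[Review note, not part of the patch] The new `truncate_list_of_lists` docstring describes the summarization behavior in words; here is a standalone sketch of that behavior for reviewers. This is an illustrative reimplementation, and the placeholder string format is an assumption, not the module's actual output:

from typing import Union

def summarize_sketch(li: Union[list, str]):
    # Strings and other non-list values pass through untouched.
    if not isinstance(li, list):
        return li
    # For lists with more than 3 items: keep the first and last items,
    # collapse the middle into a summary placeholder.
    if len(li) > 3:
        return [summarize_sketch(li[0]), f"...{len(li) - 2} more...", summarize_sketch(li[-1])]
    return [summarize_sketch(item) for item in li]

print(summarize_sketch([[0] * 100, 1, 2, 3, [4, 5]]))
# -> [[0, '...98 more...', 0], '...3 more...', [4, 5]]
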
diff --git a/src/cnlpt/cnlp_processors.py b/src/cnlpt/cnlp_processors.py
index 4dccac1d..f446898e 100644
--- a/src/cnlpt/cnlp_processors.py
+++ b/src/cnlpt/cnlp_processors.py
@@ -13,6 +13,7 @@
 from dataclasses import dataclass, field
 from typing import Callable, Dict, Optional, List, Union, Any, Set
 from transformers.data.processors.utils import DataProcessor, InputExample
+import datasets
 from datasets import load_dataset
 import torch
 from torch.utils.data.dataset import Dataset
@@ -29,7 +30,15 @@
 relex = "relations"


-def get_unique_labels(dataset, tasks, task_output_modes):
+def get_unique_labels(
+    dataset, tasks: List[str], task_output_modes: Dict[str, str]
+) -> Dict[str, List[str]]:
+    """
+    Return the unique labels defined for each task, found by iterating through the dataset.
+    :param tasks: List of tasks that the caller cares about
+    :param task_output_modes: Dictionary mapping from task names to task output mode
+    :return: Dictionary from task names to a list of unique labels for that task
+    """
     dataset_unique_labels = dict()
     for task_ind, task_name in enumerate(tasks):
         unique_labels = set()
@@ -68,7 +77,12 @@ def get_unique_labels(dataset, tasks, task_output_modes):
     return dataset_unique_labels


-def infer_output_modes(dataset):
+def infer_output_modes(dataset: datasets.DatasetDict) -> Dict[str, str]:
+    """
+    Figure out which output mode each task in the dataset requires by looking at the format of the labels.
+    :param dataset: HF datasets DatasetDict containing the loaded dataset
+    :return: Dictionary mapping from task names to output modes
+    """
     task_output_modes = {}
     for task_ind, task_name in enumerate(dataset.tasks):
         output_mode = classification
@@ -96,7 +110,12 @@ def infer_output_modes(dataset):
     return task_output_modes


-def get_task_pruned_dataset(dataset, tasks, unique_labels):
+def get_task_pruned_dataset(
+    dataset: datasets.DatasetDict, tasks: List[str], unique_labels: Dict[str, List[str]]
+) -> datasets.DatasetDict:
+    """
+    Remove tasks from the dataset that have only one unique label
+    """
     tasks_to_remove = []
     for task_ind, task_name in enumerate(tasks):
         if len(unique_labels[task_name]) == 1:
diff --git a/src/cnlpt/train_system.py b/src/cnlpt/train_system.py
index 1d9a0641..e25cc4c5 100644
--- a/src/cnlpt/train_system.py
+++ b/src/cnlpt/train_system.py
@@ -70,8 +70,12 @@
 logger = logging.getLogger(__name__)


-def is_hub_model(model_name):
-    # check if it's a model on the huggingface model hub:
+def is_hub_model(model_name: str) -> bool:
+    """
+    Check whether a model specification string names a model on the huggingface model hub
+    :param model_name: the string to check
+    :return: whether the model is on the huggingface hub
+    """
     try:
         url = hf_hub_url(model_name, CONFIG_NAME)
         r = requests.head(url)
@@ -95,7 +99,13 @@ def is_cnlpt_model(model_path: str) -> bool:
     return encoder_config.model_type == "cnlpt"


-def encoder_inferred(model_name_or_path: str) -> bool:
+def is_external_encoder(model_name_or_path: str) -> bool:
+    """
+    Check whether a specified model is not a cnlpt model, i.e., an external model such as a
+    huggingface hub model or a downloaded local directory.
+    :param model_name_or_path: the model string or path to check
+    :return: whether the encoder is an external (non-cnlpt) model
+    """
     return is_hub_model(model_name_or_path) or not is_cnlpt_model(model_name_or_path)


@@ -258,7 +268,7 @@ def main(
         if model_args.config_name
         else model_args.encoder_name
     )
-    if encoder_inferred(encoder_name):
+    if is_external_encoder(encoder_name):
         config = CnlpConfig(
             encoder_name=encoder_name,
             finetuning_task=data_args.task_name
@@ -341,7 +351,7 @@ def main(

     # TODO check when download any pretrained language model to local disk, if
     # the following condition "is_hub_model(encoder_name)" works or not.
-    if not encoder_inferred(encoder_name):
+    if not is_external_encoder(encoder_name):
         # we are loading one of our own trained models as a starting point.
         #
         # 1) if training_args.do_train is true:
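
[Review note, not part of the patch] A usage sketch of the renamed helper. The import path assumes the src/cnlpt layout shown above, and the local checkpoint path is a placeholder:

# Hypothetical usage of is_external_encoder; paths and model names are
# illustrative, and the hub check performs a network request.
from cnlpt.train_system import is_external_encoder

# "roberta-base" lives on the huggingface hub, so it is external to cnlpt and
# main() builds a fresh CnlpConfig around it; a saved cnlpt checkpoint
# directory is instead reloaded as a starting point.
for spec in ["roberta-base", "/tmp/my_cnlpt_checkpoint"]:
    print(spec, is_external_encoder(spec))
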