Core Concepts
- Model_T: an instance of `torch.nn.Module`, the teacher model, which usually has more parameters than the student model.
- Model_S: an instance of `torch.nn.Module`, the student model, usually smaller than the teacher model for the purpose of model compression and faster inference.
- optimizer: an instance of `torch.optim.Optimizer`.
- scheduler: an instance of `torch.optim.lr_scheduler`, which allows flexible adjustment of the learning rate.
- dataloader: a data iterator used to generate data batches. A batch can be a tuple or a dict:
```python
for batch in dataloader:
    if batch_postprocessor is not None:
        batch = batch_postprocessor(batch)
    # check the batch datatype (dict or tuple)
    # pass the batch to the model and the adaptors
```
Note:
- During training, the distiller checks whether the batch is a dict. If it is, the model is called as `model(**batch, **args)`; otherwise, the model is called as `model(*batch, **args)`. Hence, if the batch is not a dict, users should make sure that the order of the elements in the batch matches the order of the arguments of `model.forward`. `args` is used for passing additional parameters.
- Users can define a `batch_postprocessor` function to post-process batches if needed. `batch_postprocessor` should take a batch and return a batch. See the explanation of the `train` method of the distillers for more details.
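For example, a `batch_postprocessor` can move tensors to the GPU or convert a tuple batch into a dict whose keys match the arguments of `model.forward`. The sketch below is illustrative only; the field names `input_ids`, `attention_mask`, and `labels` are assumptions and should match your own model:

```python
import torch

def batch_postprocessor(batch):
    # Illustrative sketch: assumes each batch is a tuple of
    # (input_ids, attention_mask, labels); these names are hypothetical
    # and should match the signature of your model.forward.
    input_ids, attention_mask, labels = batch
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return {
        'input_ids': input_ids.to(device),
        'attention_mask': attention_mask.to(device),
        'labels': labels.to(device),
    }
```

Because the returned batch is a dict, the distiller will call the model as `model(**batch, **args)`.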
- TrainingConfig: configuration related to general deep learning model training.
- DistillationConfig: configuration related to distillation methods.
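For illustration, the two configurations might be constructed as below. This is a minimal sketch: the argument names follow the TextBrewer documentation (`device`, `temperature`, `hard_label_weight`), but check the API reference for the full list of options and their defaults:

```python
from textbrewer import TrainingConfig, DistillationConfig

train_config = TrainingConfig(device='cuda')  # general training settings
distill_config = DistillationConfig(
    temperature=4,          # softmax temperature used when computing the KD loss
    hard_label_weight=0,    # weight of the pre-computed losses returned under 'losses'
)
```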
Distillers are in charge of conducting the actual experiments. The following distillers are available:
- `BasicDistiller`: single-teacher, single-task distillation; provides basic distillation strategies.
- `GeneralDistiller` (recommended): single-teacher, single-task distillation; supports intermediate feature matching. Recommended for most cases.
- `MultiTeacherDistiller`: multi-teacher distillation, which distills multiple teacher models (of the same task) into a single student. This class does not support intermediate feature matching.
- `MultiTaskDistiller`: multi-task distillation, which distills multiple teacher models (of different tasks) into a single student. This class does not support intermediate feature matching.
- `BasicTrainer`: supervised training of a single model on a labeled dataset, not for distillation. It can be used to train a teacher model.
In TextBrewer, there are two functions that should be implemented by users: callback and adaptor.
Callback: optional, can be None. At each checkpoint, after saving the model, the distiller calls the callback function with the arguments `model=model_S, step=global_step`. The callback can be used to evaluate the performance of the student model at each checkpoint. If you do an evaluation in the callback, remember to call `model.eval()` inside the callback. The signature is:

`callback(model: torch.nn.Module, step: int) -> Any`
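A minimal sketch of a callback, where `evaluate` is a hypothetical user-defined function that runs the student model on a development set:

```python
import torch

def evaluate_callback(model: torch.nn.Module, step: int) -> None:
    model.eval()                     # switch to evaluation mode before evaluating
    with torch.no_grad():
        accuracy = evaluate(model)   # `evaluate` is hypothetical, defined by the user
    print(f"step {step}: dev accuracy = {accuracy:.4f}")
    model.train()                    # switch back to training mode
```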
Adaptor: it converts the model inputs and outputs to a specified format so that they can be recognized by the distiller and the distillation losses can be computed. At each training step, the batch and the model outputs are passed to the adaptor; the adaptor reorganizes the data and returns a dict (a minimal example is given at the end of this section). The signature is:

`adaptor(batch: Union[Dict, Tuple], model_outputs: Tuple) -> Dict`
The functionality of the adaptor is shown in the figure below:
The available keys of the returned dict and their values are:
- 'logits': `List[torch.Tensor]` or `torch.Tensor`: the inputs to the final softmax. Each tensor should have the shape (batch_size, num_labels) or (batch_size, length, num_labels).
- 'logits_mask': `List[torch.Tensor]` or `torch.Tensor`: a 0/1 matrix that masks logits at specified positions. The positions where mask==0 won't be included in the calculation of the loss on logits. Each tensor should have the shape (batch_size, length).
- 'labels': `List[torch.Tensor]` or `torch.Tensor`: ground-truth labels of the examples. Each tensor should have the shape (batch_size,) or (batch_size, length).
Note:
- logits_mask only works for logits with shape (batch_size, length, num_labels). It masks along the length dimension and is commonly used in sequence labeling tasks.
- logits, logits_mask, and labels should either all be lists of tensors, or all be tensors.
- 'losses': `List[torch.Tensor]`: stores pre-computed losses, for example the cross-entropy between logits and ground-truth labels. All the losses stored here will be summed, weighted by hard_label_weight, and added to the total loss. Each tensor in the list should be a scalar, i.e., of shape [].
- 'attention': `List[torch.Tensor]`: a list of attention matrices, used to compute intermediate feature matching. Each tensor should have the shape (batch_size, num_heads, length, length) or (batch_size, length, length), depending on which attention loss is used. Details about the available loss functions can be found at Intermediate Loss.
- 'hidden': `List[torch.Tensor]`: a list of hidden states used to compute intermediate feature matching. Each tensor should have the shape (batch_size, length, hidden_dim).
- 'inputs_mask': `torch.Tensor`: a 0/1 matrix that masks 'attention' and 'hidden'; it should have the shape (batch_size, length).
These keys are all optional:
- If there is no 'inputs_mask' or 'logits_mask', it is treated as no masking, i.e., equivalent to using a mask whose elements are all 1.
- If you are not using intermediate feature matching, you can ignore 'attention' and 'hidden'.
- If you don't want to add the loss on the original hard labels, you can set hard_label_weight=0 and ignore 'losses'.
- If 'logits' is not provided, the KD loss on the logits will be omitted.
- 'labels' is required if and only if probability_shift==True.
- You shouldn't ignore all the keys, otherwise the training won't start :)
Usually 'logits' should be provided, unless you are doing multi-stage training.
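A minimal sketch of an adaptor, assuming the model returns a tuple (loss, logits, hidden_states) and the batch is a dict containing an 'attention_mask' tensor; both assumptions are illustrative and should be adjusted to your own model and data:

```python
from typing import Dict, Tuple, Union

def simple_adaptor(batch: Union[Dict, Tuple], model_outputs: Tuple) -> Dict:
    # Assumed (hypothetical) model output layout: (loss, logits, hidden_states).
    loss, logits, hidden_states = model_outputs
    return {
        'logits': logits,                        # (batch_size, num_labels)
        'hidden': list(hidden_states),           # each: (batch_size, length, hidden_dim)
        'inputs_mask': batch['attention_mask'],  # (batch_size, length), masks 'hidden'
        'losses': [loss],                        # scalar loss, weighted by hard_label_weight
    }
```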