diff --git a/peptdeep/model/model_interface.py b/peptdeep/model/model_interface.py index 030f03a9..327fde07 100644 --- a/peptdeep/model/model_interface.py +++ b/peptdeep/model/model_interface.py @@ -244,6 +244,28 @@ def build(self, self._model_to_device() self._init_for_training() + def set_bert_trainable(self, + trainable=False, + bert_layer_name="hidden_nn", + bert_layer_idxes=[1,2,3], + ): + self.set_layer_trainable( + trainable=trainable, + layer_names=[ + f"{bert_layer_name}.bert.layer.{layer}" + for layer in bert_layer_idxes + ] + ) + + def set_layer_trainable(self, + trainable=False, + layer_names=[], + ): + for layer in layer_names: + self.model.get_submodule( + layer + ).requires_grad_(trainable) + def train_with_warmup(self, precursor_df: pd.DataFrame, *, diff --git a/peptdeep/model/ms2.py b/peptdeep/model/ms2.py index b53364bb..f13a4fdd 100644 --- a/peptdeep/model/ms2.py +++ b/peptdeep/model/ms2.py @@ -209,7 +209,6 @@ def forward(self, NCEs:torch.Tensor, instrument_indices, ): - in_x = self.dropout(self.input_nn( aa_indices, mod_x ))