diff --git a/peptdeep/model/model_interface.py b/peptdeep/model/model_interface.py index 327fde07..bac4ccc5 100644 --- a/peptdeep/model/model_interface.py +++ b/peptdeep/model/model_interface.py @@ -256,15 +256,13 @@ def set_bert_trainable(self, for layer in bert_layer_idxes ] ) - + def set_layer_trainable(self, trainable=False, layer_names=[], ): for layer in layer_names: - self.model.get_submodule( - layer - ).requires_grad_(trainable) + self.model.get_submodule(layer).requires_grad_(trainable) def train_with_warmup(self, precursor_df: pd.DataFrame,