Skip to content

Commit

Permalink
FIX #123 trial 1: set_layer_trainable
Browse files Browse the repository at this point in the history
  • Loading branch information
jalew188 committed Dec 31, 2023
1 parent fa59f55 commit a4a98d2
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
22 changes: 22 additions & 0 deletions peptdeep/model/model_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,28 @@ def build(self,
self._model_to_device()
self._init_for_training()

def set_bert_trainable(self,
trainable=False,
bert_layer_name="hidden_nn",
bert_layer_idxes=[1,2,3],
):
self.set_layer_trainable(
trainable=trainable,
layer_names=[
f"{bert_layer_name}.bert.layer.{layer}"
for layer in bert_layer_idxes
]
)

def set_layer_trainable(self,
trainable=False,
layer_names=[],
):
for layer in layer_names:
self.model.get_submodule(
layer
).requires_grad_(trainable)

def train_with_warmup(self,
precursor_df: pd.DataFrame,
*,
Expand Down
1 change: 0 additions & 1 deletion peptdeep/model/ms2.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,6 @@ def forward(self,
NCEs:torch.Tensor,
instrument_indices,
):

in_x = self.dropout(self.input_nn(
aa_indices, mod_x
))
Expand Down

0 comments on commit a4a98d2

Please sign in to comment.