Skip to content

Commit

Permalink
add demo code
Browse files Browse the repository at this point in the history
  • Loading branch information
Vegetebird committed Jun 5, 2022
1 parent 04e1f92 commit 557074a
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,8 @@ def save_model(previous_name, save_dir, epoch, data_threshold, model, model_name
return previous_name


def save_model_epoch(save_dir, epoch, model, model_name):
torch.save(model.state_dict(), '%s/epoch_%s_%d.pth' % (save_dir, model_name, epoch))
def save_model_epoch(save_dir, epoch, model):
torch.save(model.state_dict(), '%s/epoch_%d.pth' % (save_dir, epoch))



Expand Down
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def input_augmentation(input_2D, model_trans):
p1, p2 = val(opt, actions, test_dataloader, model)

if opt.train and not opt.refine:
save_model_epoch(opt.checkpoint, epoch, model['trans'], 'no_refine')
save_model_epoch(opt.checkpoint, epoch, model['trans'])

if opt.train and p1 < opt.previous_best_threshold:
opt.previous_name = save_model(opt.previous_name, opt.checkpoint, epoch, p1, model['trans'], 'no_refine')
Expand Down

0 comments on commit 557074a

Please sign in to comment.