From 557074a9a6ba7365dc9736ae8cb101abdd5e2ade Mon Sep 17 00:00:00 2001 From: Vegetebird <781256005@qq.com> Date: Sun, 5 Jun 2022 11:11:40 +0800 Subject: [PATCH] add demo code --- common/utils.py | 4 ++-- main.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/common/utils.py b/common/utils.py index e287cbf5..22a5305e 100644 --- a/common/utils.py +++ b/common/utils.py @@ -207,8 +207,8 @@ def save_model(previous_name, save_dir, epoch, data_threshold, model, model_name return previous_name -def save_model_epoch(save_dir, epoch, model, model_name): - torch.save(model.state_dict(), '%s/epoch_%s_%d.pth' % (save_dir, model_name, epoch)) +def save_model_epoch(save_dir, epoch, model): + torch.save(model.state_dict(), '%s/epoch_%d.pth' % (save_dir, epoch)) diff --git a/main.py b/main.py index c3fb181d..bc08747e 100644 --- a/main.py +++ b/main.py @@ -197,7 +197,7 @@ def input_augmentation(input_2D, model_trans): p1, p2 = val(opt, actions, test_dataloader, model) if opt.train and not opt.refine: - save_model_epoch(opt.checkpoint, epoch, model['trans'], 'no_refine') + save_model_epoch(opt.checkpoint, epoch, model['trans']) if opt.train and p1 < opt.previous_best_threshold: opt.previous_name = save_model(opt.previous_name, opt.checkpoint, epoch, p1, model['trans'], 'no_refine')