support export after save model
zhangyubo0722 committed Sep 9, 2024
1 parent 1e20933 commit 19855ad
Showing 4 changed files with 50 additions and 8 deletions.
19 changes: 14 additions & 5 deletions ppdet/engine/callbacks.py
@@ -26,7 +26,7 @@
import paddle
import paddle.distributed as dist

from ppdet.utils.checkpoint import save_model, save_semi_model
from ppdet.utils.checkpoint import save_model, save_semi_model, save_model_info
from ppdet.metrics import get_infer_results

from ppdet.utils.logger import setup_logger
@@ -178,11 +178,12 @@ def __init__(self, model):
super(Checkpointer, self).__init__(model)
self.best_ap = -1000.
self.save_dir = self.model.cfg.save_dir
self.model_export_enabled = self.model.cfg.get("model_export_enabled", False)
if hasattr(self.model.model, 'student_model'):
self.weight = self.model.model.student_model
else:
self.weight = self.model.model

def on_epoch_end(self, status):
# Checkpointer only performed during training
mode = status['mode']
@@ -226,8 +227,10 @@ def on_epoch_end(self, status):
'metric': abs(epoch_ap),
'epoch': epoch_id + 1
}
save_path = os.path.join(self.save_dir, f"{save_name}.pdstates")
save_path = os.path.join(os.path.join(self.save_dir, save_name) if self.model_export_enabled else self.save_dir, f"{save_name}.pdstates")
paddle.save(epoch_metric, save_path)
if self.model_export_enabled:
save_model_info(epoch_metric, self.save_dir, save_name)
if 'save_best_model' in status and status['save_best_model']:
if epoch_ap >= self.best_ap:
self.best_ap = epoch_ap
@@ -239,6 +242,8 @@ def on_epoch_end(self, status):
}
save_path = os.path.join(self.save_dir, "best_model.pdstates")
paddle.save(best_metric, save_path)
if self.model_export_enabled:
save_model_info(best_metric, self.save_dir, save_name)
logger.info("Best test {} {} is {:0.3f}.".format(
key, eval_func, abs(self.best_ap)))
if weight:
@@ -250,10 +255,12 @@ def on_epoch_end(self, status):
save_model(
status['weight'],
self.model.optimizer,
self.save_dir,
os.path.join(self.save_dir, save_name) if self.model_export_enabled else self.save_dir,
save_name,
epoch_id + 1,
ema_model=weight)
if self.model_export_enabled:
self.model.export(output_dir=os.path.join(self.save_dir, save_name, "inference"), for_fd=True)
else:
# save model(student model) and ema_model(teacher model)
# in DenseTeacher SSOD, the teacher model will be higher,
Expand All @@ -270,8 +277,10 @@ def on_epoch_end(self, status):
del teacher_model
del student_model
else:
save_model(weight, self.model.optimizer, self.save_dir,
save_model(weight, self.model.optimizer, os.path.join(self.save_dir, save_name) if self.model_export_enabled else self.save_dir,
save_name, epoch_id + 1)
if self.model_export_enabled:
self.model.export(output_dir=os.path.join(self.save_dir, save_name, "inference"), for_fd=True)


class WiferFaceEval(Callback):
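Taken together, when model_export_enabled is set, a regular (non-SSOD) epoch-end save in Checkpointer.on_epoch_end boils down to the sequence below. This is a condensed, illustrative sketch of the hunks above rather than the exact control flow: the best-model and DenseTeacher branches are omitted, the helper name is ours, and cb stands for the Checkpointer instance.

import os
import paddle
from ppdet.utils.checkpoint import save_model, save_model_info

def _save_and_export(cb, epoch_metric, save_name, epoch_id, weight):
    # Condensed sketch of the flow in on_epoch_end when model_export_enabled is True.
    ckpt_dir = os.path.join(cb.save_dir, save_name)
    paddle.save(epoch_metric, os.path.join(ckpt_dir, f"{save_name}.pdstates"))  # metric states
    save_model_info(epoch_metric, cb.save_dir, save_name)  # writes <save_dir>/<save_name>/<save_name>.info.json
    save_model(weight, cb.model.optimizer, ckpt_dir, save_name, epoch_id + 1)  # weights + optimizer state
    cb.model.export(output_dir=os.path.join(ckpt_dir, "inference"), for_fd=True)  # static inference model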
3 changes: 2 additions & 1 deletion ppdet/engine/export_utils.py
@@ -366,7 +366,8 @@ def _dump_infer_config(config, path, image_shape, model):
else:
reader_cfg = config['TestReader']
dataset_cfg = config['TestDataset']

# print(dataset_cfg)
# exit()
infer_cfg['Preprocess'], infer_cfg['label_list'] = _parse_reader(
reader_cfg, dataset_cfg, config['metric'], label_arch, image_shape[1:])

24 changes: 22 additions & 2 deletions ppdet/engine/trainer.py
@@ -75,7 +75,7 @@ def __init__(self, cfg, mode='train'):
self.custom_white_list = self.cfg.get('custom_white_list', None)
self.custom_black_list = self.cfg.get('custom_black_list', None)
self.use_master_grad = self.cfg.get('master_grad', False)
if 'slim' in cfg and cfg['slim_type'] == 'PTQ':
if ('slim' in cfg and cfg['slim_type'] == 'PTQ') or 'model_export_enabled' in cfg:
self.cfg['TestDataset'] = create('TestDataset')()
log_ranks = cfg.get('log_ranks', '0')
if isinstance(log_ranks, str):
@@ -1212,8 +1212,13 @@ def _get_infer_cfg_and_input_spec(self,
"img_name": str,
})
if prune_input:
if self.cfg.get("model_export_enabled", False):
model = ExportModel(self.model)
model.eval()
else:
model = self.model
static_model = paddle.jit.to_static(
self.model, input_spec=input_spec, full_graph=True)
model, input_spec=input_spec, full_graph=True)
# NOTE: dy2st do not pruned program, but jit.save will prune program
# input spec, prune input spec here and save with pruned input spec
pruned_input_spec = _prune_input_spec(
@@ -1490,3 +1495,18 @@ def reset_norm_param_attr(self, layer, **kwargs):
setattr(layer, name, new_sublayer)

return layer

class ExportModel(nn.Layer):
    def __init__(self, model):
        super().__init__()
        self.base_model = model

    def eval(self):
        self.training = False
        for layer in self.sublayers():
            layer.training = False
            layer.eval()

    def forward(self, x):
        x = self.base_model(x)
        return x
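The ExportModel wrapper above gives the exporter a layer whose overridden eval() explicitly forces training = False on every sublayer before paddle.jit.to_static traces it. A minimal standalone sketch of that behavior, assuming the class is importable from ppdet.engine.trainer after this commit and using a toy backbone in place of a real detector:

import paddle
import paddle.nn as nn
from ppdet.engine.trainer import ExportModel  # added by this commit

# Toy stand-in for a detection model; any Layer with a tensor forward works here.
backbone = nn.Sequential(nn.Conv2D(3, 8, 3), nn.BatchNorm2D(8), nn.ReLU())

wrapped = ExportModel(backbone)
wrapped.eval()  # sets training=False on the wrapper and every sublayer (norm layers included)
assert all(not layer.training for layer in wrapped.sublayers())

out = wrapped(paddle.randn([1, 3, 32, 32]))  # forward simply delegates to the wrapped model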
12 changes: 12 additions & 0 deletions ppdet/utils/checkpoint.py
@@ -18,6 +18,7 @@
from __future__ import unicode_literals

import os
import json
import numpy as np
import paddle
import paddle.nn as nn
@@ -375,3 +376,14 @@ def save_semi_model(teacher_model, student_model, optimizer, save_dir,
state_dict['last_iter'] = last_iter
paddle.save(state_dict, save_path + str(last_epoch) + "epoch.pdopt")
logger.info("Save checkpoint: {}".format(save_dir))

def save_model_info(model_info, save_path, prefix):
    """
    save model info to the target path
    """
    save_path = os.path.join(save_path, prefix)
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    with open(os.path.join(save_path, f'{prefix}.info.json'), 'w') as f:
        json.dump(model_info, f)
    logger.info("Saved model info to {}".format(save_path))
