From 94e91ee051d75cdf5592e541dade2c5637fee709 Mon Sep 17 00:00:00 2001 From: Chengdong Liang <1404056823@qq.com> Date: Mon, 4 Mar 2024 19:54:40 +0800 Subject: [PATCH] [ckpt] print more info for debug when loading state_dict (#282) --- wespeaker/utils/checkpoint.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/wespeaker/utils/checkpoint.py b/wespeaker/utils/checkpoint.py index 4d05e27e..78892807 100644 --- a/wespeaker/utils/checkpoint.py +++ b/wespeaker/utils/checkpoint.py @@ -14,11 +14,17 @@ # limitations under the License. import torch +import logging def load_checkpoint(model: torch.nn.Module, path: str): checkpoint = torch.load(path, map_location='cpu') - model.load_state_dict(checkpoint, strict=False) + missing_keys, unexpected_keys = model.load_state_dict(checkpoint, + strict=False) + for key in missing_keys: + logging.warning('missing tensor: {}'.format(key)) + for key in unexpected_keys: + logging.warning('unexpected tensor: {}'.format(key)) def save_checkpoint(model: torch.nn.Module, path: str):