diff --git a/Trainer_finetune.py b/Trainer_finetune.py
index 421fcff..b291922 100644
--- a/Trainer_finetune.py
+++ b/Trainer_finetune.py
@@ -42,22 +42,29 @@ def load_model(self, name=None, rank=0, real=False):
             print(f"loading {name} ckpt")
             self.net.load_state_dict(convert(torch.load(f'ckpt/{name}.pkl')), strict=True)
 
-    def from_pretrained(self, model_name):
+    @classmethod
+    def from_pretrained(cls, model_id, local_rank=-1):
         try:
             from huggingface_hub import hf_hub_download
-            assert model_name in ["VFIMamba", "VFIMamba_S"], "Please select a valid model name from ['VFIMamba', 'VFIMamba_S']"
-
-            ckpt_path = hf_hub_download(
-                repo_id=f"MCG-NJU/{model_name}", filename=model_name + ".pkl"
+        except ImportError:
+            raise ImportError(
+                "Model is hosted on the Hugging Face Hub. "
+                "Please install huggingface_hub by running `pip install huggingface_hub` to load the weights correctly."
             )
-            checkpoint = torch.load(ckpt_path)
-        except:
-            # In case the model is not hosted on huggingface
-            # or the user cannot import huggingface_hub correctly, model_name option: VFIMamba, VFIMamba_S
-            _VFIMAMBA_URL = f"https://huggingface.co/MCG-NJU/{model_name}/resolve/main/{model_name}.pkl"
-            checkpoint = torch.hub.load_state_dict_from_url(_VFIMAMBA_URL)
-
-        self.net.load_state_dict(convert(checkpoint), strict=True)
+        if "/" not in model_id:
+            model_id = "MCG-NJU/" + model_id
+        ckpt_path = hf_hub_download(repo_id=model_id, filename="model.pkl")
+        print(f"loading {model_id} ckpt")
+        checkpoint = torch.load(ckpt_path)
+        from transformers import PretrainedConfig
+        cfg = PretrainedConfig.from_pretrained(model_id)
+        MODEL_CONFIG['MODEL_ARCH'] = init_model_config(
+            F=cfg.F,
+            depth=cfg.depth,
+        )
+        model = cls(local_rank)
+        model.net.load_state_dict(convert(checkpoint), strict=True)
+        return model
 
     @torch.no_grad()
     def hr_inference(self, img0, img1, local, TTA = False, down_scale = 1.0, timestep = 0.5, fast_TTA = False):
diff --git a/hf_demo_2x.py b/hf_demo_2x.py
index 60eca50..f19866d 100644
--- a/hf_demo_2x.py
+++ b/hf_demo_2x.py
@@ -8,12 +8,11 @@
 
 '''==========import from our code=========='''
 sys.path.append('.')
-import config as cfg
 from Trainer_finetune import Model
 from benchmark.utils.padder import InputPadder
 
 parser = argparse.ArgumentParser()
-parser.add_argument('--model', default='VFIMamba_S', type=str)
+parser.add_argument('--model', default='VFIMamba', type=str)
 parser.add_argument('--scale', default=0, type=float)
 
 args = parser.parse_args()
@@ -22,15 +21,7 @@
 
 '''==========Model setting=========='''
 TTA = False
-if args.model == 'VFIMamba':
-    TTA = True
-    cfg.MODEL_CONFIG['LOGNAME'] = 'VFIMamba'
-    cfg.MODEL_CONFIG['MODEL_ARCH'] = cfg.init_model_config(
-        F = 32,
-        depth = [2, 2, 2, 3, 3]
-    )
-model = Model(-1)
-model.from_pretrained(args.model)
+model = Model.from_pretrained(args.model)
 model.eval()
 model.device()
 
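
A minimal usage sketch of the classmethod introduced above, assuming the weights are published on the Hugging Face Hub under the MCG-NJU namespace (as the new default expansion implies); the fork repo id in the last comment is purely hypothetical:

    from Trainer_finetune import Model

    # Bare model names are expanded to the MCG-NJU namespace inside
    # from_pretrained, so this resolves to MCG-NJU/VFIMamba on the Hub.
    model = Model.from_pretrained('VFIMamba')
    model.eval()
    model.device()

    # A full repo id is passed through unchanged, e.g. a hypothetical fork:
    # model = Model.from_pretrained('some-user/VFIMamba-finetuned')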