From 21f75415fd42a6e3beb355f89115a0294f2eeed1 Mon Sep 17 00:00:00 2001 From: lcxrocks <1255538705@qq.com> Date: Mon, 2 Sep 2024 20:35:13 +0800 Subject: [PATCH 1/3] Add huggingface support. --- Trainer_finetune.py | 30 +++++++++++++++++++------ hf_demo_2x.py | 54 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+), 7 deletions(-) create mode 100644 hf_demo_2x.py diff --git a/Trainer_finetune.py b/Trainer_finetune.py index a03d898..5307106 100644 --- a/Trainer_finetune.py +++ b/Trainer_finetune.py @@ -5,7 +5,13 @@ from config import * - +def convert(param): + return { + k.replace("module.", ""): v + for k, v in param.items() + if "module." in k and 'attn_mask' not in k and 'HW' not in k + } + class Model: def __init__(self, local_rank): backbonetype, multiscaletype = MODEL_CONFIG['MODEL_TYPE'] @@ -30,18 +36,28 @@ def device(self): self.net.to(torch.device("cuda")) def load_model(self, name=None, rank=0, real=False): - def convert(param): - return { - k.replace("module.", ""): v - for k, v in param.items() - if "module." in k and 'attn_mask' not in k and 'HW' not in k - } if rank <= 0 : if name is None: name = self.name print(f"loading {name} ckpt") self.net.load_state_dict(convert(torch.load(f'ckpt/{name}.pkl')), strict=True) + def from_pretrained(self, model_name): + try: + from huggingface_hub import hf_hub_download + + ckpt_path = hf_hub_download( + repo_id="MCG-NJU/VFIMamba", filename="ckpt/" + model_name + ".pkl" + ) + checkpoint = torch.load(ckpt_path) + except: + # In case the model is not hosted on huggingface + # or the user cannot import huggingface_hub correctly, model_name option: VFIMamba, VFIMamba_S + _VFIMAMBA_URL = f"https://huggingface.co/MCG-NJU/VFIMamba/resolve/main/ckpt/{model_name}.pkl" + checkpoint = torch.hub.load_state_dict_from_url(_VFIMAMBA_URL) + + self.net.load_state_dict(convert(checkpoint), strict=True) + @torch.no_grad() def hr_inference(self, img0, img1, local, TTA = False, down_scale = 1.0, timestep = 0.5, fast_TTA = False): ''' diff --git a/hf_demo_2x.py b/hf_demo_2x.py new file mode 100644 index 0000000..60eca50 --- /dev/null +++ b/hf_demo_2x.py @@ -0,0 +1,54 @@ +import cv2 +import math +import sys +import torch +import numpy as np +import argparse +from imageio import mimsave + +'''==========import from our code==========''' +sys.path.append('.') +import config as cfg +from Trainer_finetune import Model +from benchmark.utils.padder import InputPadder + +parser = argparse.ArgumentParser() +parser.add_argument('--model', default='VFIMamba_S', type=str) +parser.add_argument('--scale', default=0, type=float) + +args = parser.parse_args() +assert args.model in ['VFIMamba_S', 'VFIMamba'], 'Model not exists!' 
+ + +'''==========Model setting==========''' +TTA = False +if args.model == 'VFIMamba': + TTA = True + cfg.MODEL_CONFIG['LOGNAME'] = 'VFIMamba' + cfg.MODEL_CONFIG['MODEL_ARCH'] = cfg.init_model_config( + F = 32, + depth = [2, 2, 2, 3, 3] + ) +model = Model(-1) +model.from_pretrained(args.model) +model.eval() +model.device() + + +print(f'=========================Start Generating=========================') + +I0 = cv2.imread('example/im1.png') +I2 = cv2.imread('example/im2.png') + +I0_ = (torch.tensor(I0.transpose(2, 0, 1)).cuda() / 255.).unsqueeze(0) +I2_ = (torch.tensor(I2.transpose(2, 0, 1)).cuda() / 255.).unsqueeze(0) + +padder = InputPadder(I0_.shape, divisor=32) +I0_, I2_ = padder.pad(I0_, I2_) + +mid = (padder.unpad(model.inference(I0_, I2_, True, TTA=TTA, fast_TTA=TTA, scale=args.scale))[0].detach().cpu().numpy().transpose(1, 2, 0) * 255.0).astype(np.uint8) +images = [I0[:, :, ::-1], mid[:, :, ::-1], I2[:, :, ::-1]] +mimsave('example/out_2x_hf.gif', images, fps=3) + + +print(f'=========================Done=========================') \ No newline at end of file From d967d957fc3f12ae9bb3fbb5e609b9830c95c574 Mon Sep 17 00:00:00 2001 From: lcxrocks <1255538705@qq.com> Date: Mon, 2 Sep 2024 21:15:47 +0800 Subject: [PATCH 2/3] Add huggingface support. --- Trainer_finetune.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/Trainer_finetune.py b/Trainer_finetune.py index 5307106..421fcff 100644 --- a/Trainer_finetune.py +++ b/Trainer_finetune.py @@ -45,15 +45,16 @@ def load_model(self, name=None, rank=0, real=False): def from_pretrained(self, model_name): try: from huggingface_hub import hf_hub_download + assert model_name in ["VFIMamba", "VFIMamba_S"], "Please select a valid model name from ['VFIMamba', 'VFIMamba_S']" ckpt_path = hf_hub_download( - repo_id="MCG-NJU/VFIMamba", filename="ckpt/" + model_name + ".pkl" + repo_id=f"MCG-NJU/{model_name}", filename=model_name + ".pkl" ) checkpoint = torch.load(ckpt_path) except: # In case the model is not hosted on huggingface # or the user cannot import huggingface_hub correctly, model_name option: VFIMamba, VFIMamba_S - _VFIMAMBA_URL = f"https://huggingface.co/MCG-NJU/VFIMamba/resolve/main/ckpt/{model_name}.pkl" + _VFIMAMBA_URL = f"https://huggingface.co/MCG-NJU/{model_name}/resolve/main/{model_name}.pkl" checkpoint = torch.hub.load_state_dict_from_url(_VFIMAMBA_URL) self.net.load_state_dict(convert(checkpoint), strict=True) From 917b41fe4de1e50110b32c23edf7b979d28e9083 Mon Sep 17 00:00:00 2001 From: lcxrocks <1255538705@qq.com> Date: Tue, 3 Sep 2024 10:19:10 +0800 Subject: [PATCH 3/3] More unified integration for huggingface. 
--- Trainer_finetune.py | 33 ++++++++++++++++++++------------- hf_demo_2x.py | 13 ++----------- 2 files changed, 22 insertions(+), 24 deletions(-) diff --git a/Trainer_finetune.py b/Trainer_finetune.py index 421fcff..b291922 100644 --- a/Trainer_finetune.py +++ b/Trainer_finetune.py @@ -42,22 +42,29 @@ def load_model(self, name=None, rank=0, real=False): print(f"loading {name} ckpt") self.net.load_state_dict(convert(torch.load(f'ckpt/{name}.pkl')), strict=True) - def from_pretrained(self, model_name): + @classmethod + def from_pretrained(cls, model_id, local_rank=-1): try: from huggingface_hub import hf_hub_download - assert model_name in ["VFIMamba", "VFIMamba_S"], "Please select a valid model name from ['VFIMamba', 'VFIMamba_S']" - - ckpt_path = hf_hub_download( - repo_id=f"MCG-NJU/{model_name}", filename=model_name + ".pkl" + except ImportError: + raise ImportError( + "Model is hosted on the Hugging Face Hub. " + "Please install huggingface_hub by running `pip install huggingface_hub` to load the weights correctly." ) - checkpoint = torch.load(ckpt_path) - except: - # In case the model is not hosted on huggingface - # or the user cannot import huggingface_hub correctly, model_name option: VFIMamba, VFIMamba_S - _VFIMAMBA_URL = f"https://huggingface.co/MCG-NJU/{model_name}/resolve/main/{model_name}.pkl" - checkpoint = torch.hub.load_state_dict_from_url(_VFIMAMBA_URL) - - self.net.load_state_dict(convert(checkpoint), strict=True) + if "/" not in model_id: + model_id = "MCG-NJU/" + model_id + ckpt_path = hf_hub_download(repo_id=model_id, filename="model.pkl") + print(f"loading {model_id} ckpt") + checkpoint = torch.load(ckpt_path) + from transformers import PretrainedConfig + cfg = PretrainedConfig.from_pretrained(model_id) + MODEL_CONFIG['MODEL_ARCH'] = init_model_config( + F=cfg.F, + depth=cfg.depth, + ) + model = cls(local_rank) + model.net.load_state_dict(convert(checkpoint), strict=True) + return model @torch.no_grad() def hr_inference(self, img0, img1, local, TTA = False, down_scale = 1.0, timestep = 0.5, fast_TTA = False): diff --git a/hf_demo_2x.py b/hf_demo_2x.py index 60eca50..f19866d 100644 --- a/hf_demo_2x.py +++ b/hf_demo_2x.py @@ -8,12 +8,11 @@ '''==========import from our code==========''' sys.path.append('.') -import config as cfg from Trainer_finetune import Model from benchmark.utils.padder import InputPadder parser = argparse.ArgumentParser() -parser.add_argument('--model', default='VFIMamba_S', type=str) +parser.add_argument('--model', default='VFIMamba', type=str) parser.add_argument('--scale', default=0, type=float) args = parser.parse_args() @@ -22,15 +21,7 @@ '''==========Model setting==========''' TTA = False -if args.model == 'VFIMamba': - TTA = True - cfg.MODEL_CONFIG['LOGNAME'] = 'VFIMamba' - cfg.MODEL_CONFIG['MODEL_ARCH'] = cfg.init_model_config( - F = 32, - depth = [2, 2, 2, 3, 3] - ) -model = Model(-1) -model.from_pretrained(args.model) +model = Model.from_pretrained(args.model) model.eval() model.device()
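
The loader added in PATCH 3/3 downloads "model.pkl" from the Hub repo named by model_id and then reads the architecture hyperparameters through transformers.PretrainedConfig.from_pretrained(model_id), so each model repo is expected to carry both the checkpoint and a config.json exposing F and depth. Below is a minimal sketch of how such a repo could be populated; the target repo id, the local checkpoint path, and the F=32 / depth=[2, 2, 2, 3, 3] values (taken from the VFIMamba settings in the PATCH 1/3 demo) are assumptions for illustration, not something the patches themselves provide.

# Sketch: publish a checkpoint in the layout that Model.from_pretrained expects.
# Assumes write access to the target repo (e.g. after `huggingface-cli login`)
# and a local checkpoint at ckpt/VFIMamba.pkl; adjust names and values as needed.
import json
from huggingface_hub import HfApi

repo_id = "MCG-NJU/VFIMamba"  # assumed target; any <user>/<name> repo works

# config.json must expose the fields read via PretrainedConfig (cfg.F, cfg.depth).
# F=32 and depth=[2, 2, 2, 3, 3] are the VFIMamba values from the original demo.
with open("config.json", "w") as f:
    json.dump({"F": 32, "depth": [2, 2, 2, 3, 3]}, f, indent=2)

api = HfApi()
api.create_repo(repo_id=repo_id, exist_ok=True)
api.upload_file(path_or_fileobj="ckpt/VFIMamba.pkl",  # local weights
                path_in_repo="model.pkl",              # filename from_pretrained downloads
                repo_id=repo_id)
api.upload_file(path_or_fileobj="config.json",
                path_in_repo="config.json",
                repo_id=repo_id)

With a repo laid out this way, the demo reduces to Model.from_pretrained("VFIMamba"); bare names are prefixed with "MCG-NJU/" by the loader, so a fine-tuned variant hosted under another account has to be passed as the full <user>/<name> id.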