Skip to content

Commit

Permalink
Merge pull request #2 from lcxrocks/main
Browse files Browse the repository at this point in the history
🤗 Add support for downloading Huggingface weights.
  • Loading branch information
GuozhenZhang1999 committed Sep 3, 2024
2 parents 152cff3 + 917b41f commit eb1d6dd
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 7 deletions.
38 changes: 31 additions & 7 deletions Trainer_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@

from config import *


def convert(param):
return {
k.replace("module.", ""): v
for k, v in param.items()
if "module." in k and 'attn_mask' not in k and 'HW' not in k
}

class Model:
def __init__(self, local_rank):
backbonetype, multiscaletype = MODEL_CONFIG['MODEL_TYPE']
Expand All @@ -30,18 +36,36 @@ def device(self):
self.net.to(torch.device("cuda"))

def load_model(self, name=None, rank=0, real=False):
def convert(param):
return {
k.replace("module.", ""): v
for k, v in param.items()
if "module." in k and 'attn_mask' not in k and 'HW' not in k
}
if rank <= 0 :
if name is None:
name = self.name
print(f"loading {name} ckpt")
self.net.load_state_dict(convert(torch.load(f'ckpt/{name}.pkl')), strict=True)

@classmethod
def from_pretrained(cls, model_id, local_rank=-1):
try:
from huggingface_hub import hf_hub_download
except ImportError:
raise ImportError(
"Model is hosted on the Hugging Face Hub. "
"Please install huggingface_hub by running `pip install huggingface_hub` to load the weights correctly."
)
if "/" not in model_id:
model_id = "MCG-NJU/" + model_id
ckpt_path = hf_hub_download(repo_id=model_id, filename="model.pkl")
print(f"loading {model_id} ckpt")
checkpoint = torch.load(ckpt_path)
from transformers import PretrainedConfig
cfg = PretrainedConfig.from_pretrained(model_id)
MODEL_CONFIG['MODEL_ARCH'] = init_model_config(
F=cfg.F,
depth=cfg.depth,
)
model = cls(local_rank)
model.net.load_state_dict(convert(checkpoint), strict=True)
return model

@torch.no_grad()
def hr_inference(self, img0, img1, local, TTA = False, down_scale = 1.0, timestep = 0.5, fast_TTA = False):
'''
Expand Down
45 changes: 45 additions & 0 deletions hf_demo_2x.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import cv2
import math
import sys
import torch
import numpy as np
import argparse
from imageio import mimsave

'''==========import from our code=========='''
sys.path.append('.')
from Trainer_finetune import Model
from benchmark.utils.padder import InputPadder

parser = argparse.ArgumentParser()
parser.add_argument('--model', default='VFIMamba', type=str)
parser.add_argument('--scale', default=0, type=float)

args = parser.parse_args()
assert args.model in ['VFIMamba_S', 'VFIMamba'], 'Model not exists!'


'''==========Model setting=========='''
TTA = False
model = Model.from_pretrained(args.model)
model.eval()
model.device()


print(f'=========================Start Generating=========================')

I0 = cv2.imread('example/im1.png')
I2 = cv2.imread('example/im2.png')

I0_ = (torch.tensor(I0.transpose(2, 0, 1)).cuda() / 255.).unsqueeze(0)
I2_ = (torch.tensor(I2.transpose(2, 0, 1)).cuda() / 255.).unsqueeze(0)

padder = InputPadder(I0_.shape, divisor=32)
I0_, I2_ = padder.pad(I0_, I2_)

mid = (padder.unpad(model.inference(I0_, I2_, True, TTA=TTA, fast_TTA=TTA, scale=args.scale))[0].detach().cpu().numpy().transpose(1, 2, 0) * 255.0).astype(np.uint8)
images = [I0[:, :, ::-1], mid[:, :, ::-1], I2[:, :, ::-1]]
mimsave('example/out_2x_hf.gif', images, fps=3)


print(f'=========================Done=========================')

0 comments on commit eb1d6dd

Please sign in to comment.