Skip to content

Commit

Permalink
More unified integration for huggingface.
Browse files Browse the repository at this point in the history
  • Loading branch information
lcxrocks committed Sep 3, 2024
1 parent d967d95 commit 917b41f
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 24 deletions.
33 changes: 20 additions & 13 deletions Trainer_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,22 +42,29 @@ def load_model(self, name=None, rank=0, real=False):
print(f"loading {name} ckpt")
self.net.load_state_dict(convert(torch.load(f'ckpt/{name}.pkl')), strict=True)

def from_pretrained(self, model_name):
@classmethod
def from_pretrained(cls, model_id, local_rank=-1):
try:
from huggingface_hub import hf_hub_download
assert model_name in ["VFIMamba", "VFIMamba_S"], "Please select a valid model name from ['VFIMamba', 'VFIMamba_S']"

ckpt_path = hf_hub_download(
repo_id=f"MCG-NJU/{model_name}", filename=model_name + ".pkl"
except ImportError:
raise ImportError(
"Model is hosted on the Hugging Face Hub. "
"Please install huggingface_hub by running `pip install huggingface_hub` to load the weights correctly."
)
checkpoint = torch.load(ckpt_path)
except:
# In case the model is not hosted on huggingface
# or the user cannot import huggingface_hub correctly, model_name option: VFIMamba, VFIMamba_S
_VFIMAMBA_URL = f"https://huggingface.co/MCG-NJU/{model_name}/resolve/main/{model_name}.pkl"
checkpoint = torch.hub.load_state_dict_from_url(_VFIMAMBA_URL)

self.net.load_state_dict(convert(checkpoint), strict=True)
if "/" not in model_id:
model_id = "MCG-NJU/" + model_id
ckpt_path = hf_hub_download(repo_id=model_id, filename="model.pkl")
print(f"loading {model_id} ckpt")
checkpoint = torch.load(ckpt_path)
from transformers import PretrainedConfig
cfg = PretrainedConfig.from_pretrained(model_id)
MODEL_CONFIG['MODEL_ARCH'] = init_model_config(
F=cfg.F,
depth=cfg.depth,
)
model = cls(local_rank)
model.net.load_state_dict(convert(checkpoint), strict=True)
return model

@torch.no_grad()
def hr_inference(self, img0, img1, local, TTA = False, down_scale = 1.0, timestep = 0.5, fast_TTA = False):
Expand Down
13 changes: 2 additions & 11 deletions hf_demo_2x.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,11 @@

'''==========import from our code=========='''
sys.path.append('.')
import config as cfg
from Trainer_finetune import Model
from benchmark.utils.padder import InputPadder

parser = argparse.ArgumentParser()
parser.add_argument('--model', default='VFIMamba_S', type=str)
parser.add_argument('--model', default='VFIMamba', type=str)
parser.add_argument('--scale', default=0, type=float)

args = parser.parse_args()
Expand All @@ -22,15 +21,7 @@

'''==========Model setting=========='''
TTA = False
if args.model == 'VFIMamba':
TTA = True
cfg.MODEL_CONFIG['LOGNAME'] = 'VFIMamba'
cfg.MODEL_CONFIG['MODEL_ARCH'] = cfg.init_model_config(
F = 32,
depth = [2, 2, 2, 3, 3]
)
model = Model(-1)
model.from_pretrained(args.model)
model = Model.from_pretrained(args.model)
model.eval()
model.device()

Expand Down

0 comments on commit 917b41f

Please sign in to comment.