Improve HF integration #1
Conversation
Hi @NielsRogge! Thank you very much for your suggestions on our work! However, one of our model's components, namely the backbone, is passed in as an already-instantiated module:

```python
class MultiScaleFlow(nn.Module, PyTorchModelHubMixin):
    def __init__(self, backbone, **kargs):
        super(MultiScaleFlow, self).__init__()
        self.feature_bone = backbone
```

Because of this, the `backbone` argument cannot be serialized into the model's config by the mixin.
Thanks for the context @lcxrocks. Yes, if models take instantiated modules (like a backbone) as init arguments, the mixin cannot serialize them automatically. So some suggestions here:
```python
class MultiScaleFlow(nn.Module, PyTorchModelHubMixin):
    def __init__(self, **kwargs):
        super(MultiScaleFlow, self).__init__()
        backbone_kwargs = kwargs.pop("backbone_kwargs")
        self.feature_bone = MambaFeature(**backbone_kwargs)
```

where you store each of the parameters regarding the backbone in a dedicated `backbone_kwargs` dictionary.
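As a purely local illustration of why this pattern helps, here is a dependency-free sketch (the class and argument names are stand-ins, not the real VFIMamba code): the init kwargs stay JSON-serializable while the backbone is built internally.

```python
import json

class MambaFeatureStub:
    """Hypothetical stand-in for the real MambaFeature backbone."""
    def __init__(self, depth=4, embed_dim=128):
        self.depth = depth
        self.embed_dim = embed_dim

class MultiScaleFlowSketch:
    """Sketch of the suggested pattern: accept only jsonable kwargs and
    instantiate the backbone inside __init__, so a mixin could persist
    the full init config to config.json."""
    def __init__(self, **kwargs):
        self.config = dict(kwargs)  # everything stored here must be jsonable
        backbone_kwargs = kwargs.pop("backbone_kwargs")
        self.feature_bone = MambaFeatureStub(**backbone_kwargs)

model = MultiScaleFlowSketch(backbone_kwargs={"depth": 2, "embed_dim": 64})
print(json.dumps(model.config))  # serializes without errors
```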
Hey there 👋 Agree with @NielsRogge's suggestions above :) I would lean toward the first option (using some custom coders):

```python
def backbone_to_dict(backbone: BackboneClass) -> Dict:
    return ...  # return a jsonable dictionary with backbone config

def dict_to_backbone(config) -> BackboneClass:
    return BackboneClass(config)  # instantiate backbone based on config

class MultiScaleFlow(
    nn.Module,
    PyTorchModelHubMixin,
    coders={
        BackboneClass: (
            backbone_to_dict,  # Encoder: how to convert a `BackboneClass` to a valid jsonable value?
            dict_to_backbone,  # Decoder: how to reconstruct a `BackboneClass` from a dictionary?
        )
    },
):
    def __init__(self, backbone: BackboneClass, **kwargs):
        super(MultiScaleFlow, self).__init__()
        self.feature_bone = backbone
```

With this solution, the backbone configuration would be saved in the same config as the rest of the model's parameters.
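To make the encoder/decoder contract concrete, here is a minimal, dependency-free round-trip sketch. The `BackboneConfig` dataclass and its fields are illustrative assumptions, not the actual backbone class:

```python
from dataclasses import dataclass, asdict
from typing import Dict

@dataclass
class BackboneConfig:
    """Hypothetical jsonable description of a backbone."""
    depth: int
    embed_dim: int

def backbone_to_dict(cfg: BackboneConfig) -> Dict:
    # Encoder: produce a JSON-serializable dict suitable for config.json.
    return asdict(cfg)

def dict_to_backbone(d: Dict) -> BackboneConfig:
    # Decoder: rebuild the object from the stored dictionary.
    return BackboneConfig(**d)

original = BackboneConfig(depth=4, embed_dim=128)
restored = dict_to_backbone(backbone_to_dict(original))
print(restored == original)  # True
```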
```diff
@@ -76,7 +78,7 @@ def forward(self, x, flow):
         mask = tmp[:, 4:5]
         return flow, mask
 
-class MultiScaleFlow(nn.Module):
+class MultiScaleFlow(nn.Module, PyTorchModelHubMixin, repo_url="https://github.com/MCG-NJU/VFIMamba", pipeline_tag="video-frame-interpolation", license="apache-2.0"):
```
Suggested change:

```diff
-class MultiScaleFlow(nn.Module, PyTorchModelHubMixin, repo_url="https://github.com/MCG-NJU/VFIMamba", pipeline_tag="video-frame-interpolation", license="apache-2.0"):
+class MultiScaleFlow(nn.Module, PyTorchModelHubMixin, repo_url="https://github.com/MCG-NJU/VFIMamba", library_name="vfi-mamba", pipeline_tag="video-frame-interpolation", license="apache-2.0"):
```
This change would add a `library_name` tag in the model card metadata of all uploaded models (even the ones fine-tuned and uploaded by other users). You will then be able to list them using https://huggingface.co/models?other=vfi-mamba to track usage of your library.
@NielsRogge @Wauplin Thank you both so much for your valuable suggestions and support! To ensure we maintain the current code readability, we’ve decided to prioritize the second approach suggested by @NielsRogge (#2). However, we truly appreciate the other suggestions and will revisit and consider them in the future.
Thanks a lot for your feedback! I'll close this one.
Makes perfect sense! Thanks for the feedback. Let us know if you have any questions while building the integration. This guide should be helpful to do it :) https://huggingface.co/docs/huggingface_hub/guides/integrations#a-flexible-approach-helpers
Hi @GuozhenZhang1999 and team,
Thanks for this nice work! I see the checkpoints are already available on the hub, which is great. This PR aims to improve the HF integration, as we usually recommend storing each checkpoint in a separate, dedicated model repo. This way, download stats also work for your models.
I wrote a quick PoC to showcase that you can easily have integration with the 🤗 hub so that you can:

- load the model directly with `from_pretrained` (and push it using `push_to_hub`)
- use `safetensors` for the weights instead of pickle.

It leverages the `PyTorchModelHubMixin` class, which allows the model to inherit these methods.
Usage is as follows:
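As a rough, dependency-free sketch of what the mixin provides (toy class names throughout; the real `PyTorchModelHubMixin` also handles weights and the Hub itself, which this local sketch omits):

```python
import json
import os
import tempfile

class HubMixinSketch:
    """Toy stand-in for PyTorchModelHubMixin: persist init kwargs as config.json."""
    def save_pretrained(self, save_dir):
        os.makedirs(save_dir, exist_ok=True)
        with open(os.path.join(save_dir, "config.json"), "w") as f:
            json.dump(self.config, f)

    @classmethod
    def from_pretrained(cls, save_dir):
        # The real mixin can also resolve a Hub repo id here.
        with open(os.path.join(save_dir, "config.json")) as f:
            return cls(**json.load(f))

class MultiScaleFlowSketch(HubMixinSketch):
    def __init__(self, **kwargs):
        self.config = kwargs

with tempfile.TemporaryDirectory() as d:
    MultiScaleFlowSketch(num_flows=3).save_pretrained(d)
    print(MultiScaleFlowSketch.from_pretrained(d).config)  # {'num_flows': 3}
```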
This means people don't need to manually download a checkpoint first in their local environment, it just loads automatically from the hub. Checkpoints could be pushed to https://huggingface.co/MCG-NJU.
Would you be interested in this integration?
Kind regards,
Niels
Note
Please don't merge this PR before pushing the model to the hub :)