
Improve HF integration #1

Closed

Conversation

@NielsRogge commented Sep 1, 2024

Hi @GuozhenZhang1999 and team,

Thanks for this nice work! I see the checkpoints are already available on the hub, which is great. This PR aims to improve the HF integration, as we usually recommend storing each checkpoint in a separate, dedicated model repo. This way, download counts also work for your models.

I wrote a quick PoC to showcase how easily you can integrate with the 🤗 hub, so that you can

  • automatically load the model using from_pretrained (and push it using push_to_hub)
  • track download numbers for your models (similar to models in the Transformers library)
  • have nice model cards on a per-model basis along with tags (so that people find them when filtering https://hf.co/models) => we could collaborate on adding a dedicated "video-frame-interpolation" tag
  • perhaps most importantly, leverage safetensors for the weights instead of pickle.

It leverages the PyTorchModelHubMixin class, which lets the model inherit these methods.

Usage is as follows:

from model.flow_estimation import MultiScaleFlow

# instantiate model
model = MultiScaleFlow(...)

# equip model with weights
model.load_state_dict(...)

# push to hub
model.push_to_hub("MCG-NJU/vfi-mamba")

# reload
model = MultiScaleFlow.from_pretrained("MCG-NJU/vfi-mamba")

This means people don't need to manually download a checkpoint into their local environment first; it just loads automatically from the hub. Checkpoints could be pushed to https://huggingface.co/MCG-NJU.

Would you be interested in this integration?

Kind regards,

Niels

Note

Please don't merge this PR before pushing the model to the hub :)

@lcxrocks (Contributor) commented Sep 2, 2024

Hi @NielsRogge! Thank you very much for your suggestions for our work! However, one of our model's components, namely MambaFeature, is passed to MultiScaleFlow as an initialization parameter:

class MultiScaleFlow(nn.Module, PyTorchModelHubMixin):
    def __init__(self, backbone, **kargs):
        super(MultiScaleFlow, self).__init__()
        self.feature_bone = backbone

Because of this, MultiScaleFlow.from_pretrained("MCG-NJU/vfi-mamba") seems unable to correctly initialize self.feature_bone and load its weights. I’m not sure how to integrate your pull request while preserving our current implementation structure. Any ideas?

@NielsRogge (Author)

Thanks for the context @lcxrocks.

Yes, if models take nn.Module instances in their __init__, those arguments can't be JSON-serialized into a config.json (cc @Wauplin).

So some suggestions here:

  • either you can instantiate MambaFeature inside the MultiScaleFlow class, taking only JSON-serializable arguments as input:
class MultiScaleFlow(nn.Module, PyTorchModelHubMixin):
    def __init__(self, **kwargs):
        super(MultiScaleFlow, self).__init__()
        backbone_kwargs = kwargs.pop("backbone_kwargs")
        self.feature_bone = MambaFeature(**backbone_kwargs)

where you store each backbone parameter in a dedicated backbone_kwargs field in the config.json.
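
To illustrate (the backbone hyperparameter names below are made up), pushing and reloading would then look roughly like:

# Hypothetical backbone hyperparameters; the mixin serializes the
# __init__ kwargs into config.json when pushing
model = MultiScaleFlow(backbone_kwargs={"embed_dim": 96, "depths": [2, 2, 6, 2]})
model.push_to_hub("MCG-NJU/vfi-mamba")

# from_pretrained reads config.json and calls __init__ with the stored
# kwargs, so MambaFeature is rebuilt before the weights are loaded
model = MultiScaleFlow.from_pretrained("MCG-NJU/vfi-mamba")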

  • if you prefer to keep the structure as-is, then we can remove the PyTorchModelHubMixin class and instead define a from_pretrained method ourselves, which uses the hf_hub_download method behind the scenes (a rough sketch follows). See here for an example.
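
A minimal sketch of that second option (the checkpoint filename "model.pt" is an assumption; adapt it to whatever is actually stored on the Hub):

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download

class MultiScaleFlow(nn.Module):
    # ... existing __init__ and forward unchanged ...

    @classmethod
    def from_pretrained(cls, repo_id: str, backbone, filename: str = "model.pt", **kwargs):
        # Download the checkpoint from the Hub (cached locally after the first call)
        checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
        model = cls(backbone, **kwargs)
        model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
        return model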

@Wauplin commented Sep 2, 2024

Hey there 👋 I agree with @NielsRogge's suggestions above :) I would lean toward the first option (using some backbone_kwargs) if that's possible. Otherwise, another solution would be to add a type annotation to the backbone argument and let the model know how to encode/decode this specific attribute. See the VoiceCraft example in this guide. It would look like this:

from typing import Dict

def backbone_to_dict(backbone: BackboneClass) -> Dict:
    return ...  # return a JSON-serializable dictionary with the backbone config

def dict_to_backbone(config: Dict) -> BackboneClass:
    return BackboneClass(config)  # instantiate the backbone based on its config

class MultiScaleFlow(nn.Module, PyTorchModelHubMixin,
    coders={
        BackboneClass: (
            backbone_to_dict,  # Encoder: how to convert a `BackboneClass` to a valid jsonable value?
            dict_to_backbone,  # Decoder: how to reconstruct a `BackboneClass` from a dictionary?
        )
    }
):
    def __init__(self, backbone: BackboneClass, **kwargs):
        super(MultiScaleFlow, self).__init__()
        self.feature_bone = backbone

With this solution, the backbone configuration is saved in the same config.json file alongside the other kwargs. You should be able to define any logic you want (even downloading files from the Hub), as long as the config is JSON-serializable.
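
For completeness, a usage sketch of that coders-based setup (repo id as in the PR; the MambaFeature constructor arguments are placeholders):

backbone = MambaFeature(...)  # constructed as in the current codebase
model = MultiScaleFlow(backbone=backbone)

# The backbone is encoded into config.json via backbone_to_dict on push...
model.push_to_hub("MCG-NJU/vfi-mamba")

# ...and reconstructed via dict_to_backbone on reload
model = MultiScaleFlow.from_pretrained("MCG-NJU/vfi-mamba")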

@@ -76,7 +78,7 @@ def forward(self, x, flow):
         mask = tmp[:, 4:5]
         return flow, mask

-class MultiScaleFlow(nn.Module):
+class MultiScaleFlow(nn.Module, PyTorchModelHubMixin, repo_url="https://github.com/MCG-NJU/VFIMamba", pipeline_tag="video-frame-interpolation", license="apache-2.0"):

Suggested change:
-class MultiScaleFlow(nn.Module, PyTorchModelHubMixin, repo_url="https://github.com/MCG-NJU/VFIMamba", pipeline_tag="video-frame-interpolation", license="apache-2.0"):
+class MultiScaleFlow(nn.Module, PyTorchModelHubMixin, repo_url="https://github.com/MCG-NJU/VFIMamba", library_name="vfi-mamba", pipeline_tag="video-frame-interpolation", license="apache-2.0"):

This change would add a library_name tag to the model card metadata of all uploaded models (even ones fine-tuned and uploaded by other users). You will then be able to list them at https://huggingface.co/models?other=vfi-mamba to track usage of your library.
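
The same listing should also be possible programmatically via huggingface_hub (a sketch, assuming the vfi-mamba tag from the suggestion above):

from huggingface_hub import list_models

# List all Hub models carrying the "vfi-mamba" tag
for model_info in list_models(filter="vfi-mamba"):
    print(model_info.id)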

@lcxrocks (Contributor) commented Sep 2, 2024

@NielsRogge @Wauplin Thank you both so much for your valuable suggestions and support! To preserve the readability of our current code, we’ve decided to prioritize the second approach suggested by @NielsRogge (#2). However, we truly appreciate the other suggestions and will revisit them in the future.

@NielsRogge (Author)

Thanks a lot for your feedback! I'll close this one.

@NielsRogge closed this Sep 2, 2024
@Wauplin commented Sep 2, 2024

Makes perfect sense! Thanks for the feedback. Let us know if you have any questions while building the integration. This guide should be helpful: https://huggingface.co/docs/huggingface_hub/guides/integrations#a-flexible-approach-helpers
