Update dvae.py and change param gamma to weight #733

Merged: 3 commits, Oct 15, 2024

@zly-idleness zly-idleness commented Aug 29, 2024

Change the parameter key 'gamma' to 'weight'.
Fixes #732.
For compatibility reasons, transformers renames any checkpoint key containing 'gamma' to 'weight' (and 'beta' to 'bias') while loading.
In the dvae model, the following code is affected:

class ConvNeXtBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        kernel: int,
        dilation: int,
        layer_scale_init_value: float = 1e-6,
    ):
        # ConvNeXt Block copied from Vocos.
        super().__init__()
        self.dwconv = nn.Conv1d(
            dim,
            dim,
            kernel_size=kernel,
            padding=dilation * (kernel // 2),
            dilation=dilation,
            groups=dim,
        )  # depthwise conv

        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(
            dim, intermediate_dim
        )  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )
       
# The key renaming in transformers (PreTrainedModel._load_pretrained_model):
    @classmethod
    def _load_pretrained_model(...):
        ...
        def _fix_key(key):
            if "beta" in key:
                return key.replace("beta", "bias")
            if "gamma" in key:
                return key.replace("gamma", "weight")
            return key

        original_loaded_keys = loaded_keys
        loaded_keys = [_fix_key(key) for key in loaded_keys]

The bug produces output like this:

(vqa-audio) (base) jeeves@notebook-5064-cadence:~/ChatTTS/rhapsodyaudio$ python tmp_save_pretrain.py 
bash: warning: setlocale: LC_ALL: cannot change locale (en_US.UTF-8)
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████| 8/8 [00:16<00:00,  2.02s/it]
Some weights of Qwen2AudioForConditionalChatTTS were not initialized from the model checkpoint at /mnt/data/user/tc_agi/luoyuanZ/ChatTTS_default and are newly initialized: 

['tts.dvae.decoder.decoder_block.0.gamma', 'tts.dvae.decoder.decoder_block.1.gamma', 'tts.dvae.decoder.decoder_block.10.gamma', 'tts.dvae.decoder.decoder_block.11.gamma', 'tts.dvae.decoder.decoder_block.2.gamma', 'tts.dvae.decoder.decoder_block.3.gamma', 'tts.dvae.decoder.decoder_block.4.gamma', 'tts.dvae.decoder.decoder_block.5.gamma', 'tts.dvae.decoder.decoder_block.6.gamma', 'tts.dvae.decoder.decoder_block.7.gamma', 'tts.dvae.decoder.decoder_block.8.gamma', 'tts.dvae.decoder.decoder_block.9.gamma', 'tts.dvae.encoder.decoder_block.0.gamma', 'tts.dvae.encoder.decoder_block.1.gamma', 'tts.dvae.encoder.decoder_block.10.gamma', 'tts.dvae.encoder.decoder_block.11.gamma', 'tts.dvae.encoder.decoder_block.2.gamma', 'tts.dvae.encoder.decoder_block.3.gamma', 'tts.dvae.encoder.decoder_block.4.gamma', 'tts.dvae.encoder.decoder_block.5.gamma', 'tts.dvae.encoder.decoder_block.6.gamma', 'tts.dvae.encoder.decoder_block.7.gamma', 'tts.dvae.encoder.decoder_block.8.gamma', 'tts.dvae.encoder.decoder_block.9.gamma']
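The warning above can be reproduced in miniature without loading any model. The sketch below is a simplified model of the loading step, not the actual transformers implementation: it copies the quoted _fix_key logic, applies it to checkpoint keys, and reports which model keys are left unmatched. The helper name missing_after_load and the shortened key lists are hypothetical, chosen only for illustration.

```python
def fix_key(key: str) -> str:
    """Mirror of the quoted _fix_key logic (substring-based renaming)."""
    if "beta" in key:
        return key.replace("beta", "bias")
    if "gamma" in key:
        return key.replace("gamma", "weight")
    return key


def missing_after_load(model_keys, checkpoint_keys):
    """Return the model keys that no renamed checkpoint key matches.

    Hypothetical helper: a simplified stand-in for the missing-key check.
    """
    renamed = {fix_key(k) for k in checkpoint_keys}
    return sorted(k for k in model_keys if k not in renamed)


# Hypothetical, shortened key lists for illustration only.
checkpoint_keys = [
    "tts.dvae.decoder.decoder_block.0.gamma",
    "tts.dvae.decoder.decoder_block.0.norm.weight",
]

# Before the change: the module attribute is named `gamma`.
model_keys_before = list(checkpoint_keys)

# After the change: the module attribute is named `weight`.
model_keys_after = [
    "tts.dvae.decoder.decoder_block.0.weight",
    "tts.dvae.decoder.decoder_block.0.norm.weight",
]

print(missing_after_load(model_keys_before, checkpoint_keys))
print(missing_after_load(model_keys_after, checkpoint_keys))
```

With the attribute named gamma, the renamed checkpoint key no longer matches the model key, so the parameter is reported as newly initialized. With the attribute named weight, the renamed key matches and nothing is reported.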

After renaming the parameter:

(vqa-audio) (base) jeeves@notebook-5064-cadence:~/ChatTTS/rhapsodyaudio$ python tmp_save_pretrain.py 
bash: warning: setlocale: LC_ALL: cannot change locale (en_US.UTF-8)

Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████| 8/8 [00:14<00:00,  1.81s/it]
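One possible side effect, stated here as an assumption rather than something this PR confirms: a checkpoint loaded outside transformers (for example with a plain load_state_dict call) would still carry the old gamma keys and would no longer match the renamed attribute. A minimal sketch of a migration helper follows; the function name rename_layer_scale_keys is hypothetical, and it renames only keys whose last component is exactly gamma, which is deliberately narrower than the substring match transformers uses.

```python
def rename_layer_scale_keys(state_dict: dict) -> dict:
    """Return a copy of state_dict with trailing '.gamma' renamed to '.weight'.

    Hypothetical helper: only the final key component is rewritten, so
    names that merely contain 'gamma' elsewhere are left untouched.
    """
    renamed = {}
    for key, value in state_dict.items():
        if key.endswith(".gamma"):
            key = key[: -len("gamma")] + "weight"
        renamed[key] = value
    return renamed


# Placeholder values stand in for tensors; the keys are illustrative.
old_state = {
    "decoder.decoder_block.0.gamma": "layer-scale tensor",
    "decoder.decoder_block.0.norm.weight": "norm tensor",
}
new_state = rename_layer_scale_keys(old_state)
print(sorted(new_state))
```

The values are passed through unchanged, so the same function would work on a real state dict of tensors before handing it to the model.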

change gamma to weight
@github-actions github-actions bot changed the base branch from main to dev August 29, 2024 10:48
@fumiama fumiama added the bug (Something isn't working) and algorithm (Algorithm improvements & issues) labels Aug 30, 2024
@fumiama fumiama merged commit b3d511b into 2noise:dev Oct 15, 2024
2 of 5 checks passed