Skip to content

Commit

Permalink
🐛 fix model loader
Browse files (browse the repository at this point in the history)
  • Loading branch information
zhzLuke96 committed Aug 1, 2024
1 parent 928a9b4 commit d904679
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 12 deletions.
23 changes: 17 additions & 6 deletions modules/core/models/tts/CosyVoiceModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import threading
from functools import partial
from pathlib import Path
from typing import Generator
from typing import Generator, Optional

import librosa
import numpy as np
Expand Down Expand Up @@ -43,6 +43,9 @@ class CosyVoiceTTSModel(TTSModel):

load_lock = threading.Lock()

model: Optional[CosyVoiceModel] = None
frontend: Optional[CosyVoiceFrontEnd] = None

def __init__(self) -> None:
super().__init__("cosy-voice")

Expand All @@ -59,21 +62,22 @@ def __init__(self) -> None:
self.logger.info(f"Found CosyVoice model: {paths}")

self.model_dir = paths[0]
self.model: CosyVoiceModel = None
self.frontend: CosyVoiceFrontEnd = None

self.device = devices.get_device_for(self.model_id)
self.dtype = devices.dtype

self.model = CosyVoiceTTSModel.model
self.frontend = CosyVoiceTTSModel.frontend

def reset(self) -> None:
return super().reset()

def load(
self, context: TTSPipelineContext = None
) -> tuple[CosyVoiceModel, CosyVoiceFrontEnd]:
with self.load_lock:
if self.model is not None:
return self.model, self.frontend
if CosyVoiceTTSModel.model is not None:
return CosyVoiceTTSModel.model, CosyVoiceTTSModel.frontend
self.logger.info("Loading CosyVoice model...")

device = self.device
Expand Down Expand Up @@ -111,16 +115,23 @@ def load(
devices.torch_gc()
self.logger.info("CosyVoice model loaded.")

CosyVoiceTTSModel.model = model
CosyVoiceTTSModel.frontend = frontend

return model, frontend

def unload(self, context: TTSPipelineContext = None) -> None:
with self.load_lock:
if self.model is None:
if CosyVoiceTTSModel.model is None:
return
del self.model
del self.frontend
self.model = None
self.frontend = None
del CosyVoiceTTSModel.model
del CosyVoiceTTSModel.frontend
CosyVoiceTTSModel.model = None
CosyVoiceTTSModel.frontend = None
devices.torch_gc()

def encode(self, text: str) -> list[int]:
Expand Down
21 changes: 15 additions & 6 deletions modules/core/models/tts/FishSpeechModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,15 @@
class FishSpeechModel(TTSModel):
lock = threading.Lock()

model: FISH_SPEECH_LLAMA = None
vqgan: FireflyArchitecture = None

def __init__(self) -> None:
super().__init__("fish-speech")

self.model: FISH_SPEECH_LLAMA = None
self.model: FISH_SPEECH_LLAMA = FishSpeechModel.model
self.vqgan: FireflyArchitecture = FishSpeechModel.vqgan
self.token_decoder: callable = None
self.vqgan: FireflyArchitecture = None

self.device = devices.get_device_for("fish-speech")
self.dtype = devices.dtype
Expand All @@ -54,8 +57,8 @@ def load(
return llama, vqgan

def load_llama(self) -> FISH_SPEECH_LLAMA:
if self.model:
return self.model
if FishSpeechModel.model:
return FishSpeechModel.model

with self.lock:
logger.info(
Expand All @@ -73,12 +76,13 @@ def load_llama(self) -> FISH_SPEECH_LLAMA:

self.model = model
self.token_decoder = token_decoder
FishSpeechModel.model = model
devices.torch_gc()
return model

def load_vqgan(self) -> FireflyArchitecture:
if self.vqgan:
return self.vqgan
if FishSpeechModel.vqgan:
return FishSpeechModel.vqgan

with self.lock:
logger.info(
Expand All @@ -96,6 +100,7 @@ def load_vqgan(self) -> FireflyArchitecture:
model = model.to(device=self.device, dtype=self.dtype)

self.vqgan = model
FishSpeechModel.vqgan = model
return model

def unload(self, context: TTSPipelineContext = None) -> None:
Expand All @@ -106,6 +111,10 @@ def unload(self, context: TTSPipelineContext = None) -> None:
self.model = None
self.token_decoder = None
self.vqgan = None
del FishSpeechModel.vqgan
del FishSpeechModel.model
FishSpeechModel.model = None
FishSpeechModel.vqgan = None
devices.torch_gc()

def encode(self, text: str) -> list[int]:
Expand Down

0 comments on commit d904679

Please sign in to comment.