🐛 fix cosyvoice device #169
zhzLuke96 committed Oct 12, 2024
1 parent b74aa7a commit bcf2e56
Showing 2 changed files with 5 additions and 3 deletions.
6 changes: 4 additions & 2 deletions modules/core/models/tts/CosyVoiceFE.py
@@ -9,6 +9,8 @@
 import torchaudio.compliance.kaldi as kaldi
 import whisper
 
+from modules.devices import devices
+
 
 class CosyVoiceFrontEnd:
 
@@ -24,7 +26,7 @@ def __init__(
     ):
         self.tokenizer = get_tokenizer()
         self.feat_extractor = feat_extractor
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.device = devices.get_device_for("cosy-voice")
         option = onnxruntime.SessionOptions()
         option.graph_optimization_level = (
             onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
@@ -45,7 +47,7 @@ def __init__(
             ],
         )
         if os.path.exists(spk2info):
-            self.spk2info = torch.load(spk2info, map_location=self.device)
+            self.spk2info = torch.load(spk2info, map_location="cpu")
         self.instruct = instruct
         self.allowed_special = allowed_special
         self.inflect_parser = inflect.engine()
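
The net effect: the frontend asks the project's devices.get_device_for helper which device to use, while the spk2info checkpoint is always deserialized on the CPU. A minimal sketch of that loading pattern, assuming a standalone helper (load_speaker_info and the per-speaker tensor-dict layout are illustrative, not taken from the repo):

import torch

def load_speaker_info(path: str, device: torch.device) -> dict:
    # Deserialize on the CPU; this succeeds even if the file was saved
    # from a GPU that is not available on the current machine.
    spk2info = torch.load(path, map_location="cpu")
    # Move tensors to the target device only when they are actually used.
    return {
        name: {k: v.to(device) if torch.is_tensor(v) else v for k, v in info.items()}
        for name, info in spk2info.items()
    }

Inside the repo, device would come from devices.get_device_for("cosy-voice"), as the hunk above shows.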
2 changes: 1 addition & 1 deletion modules/core/models/tts/FireRed/FireRedInfer.py
@@ -75,7 +75,7 @@ def __init__(self, config_path: str, pretrained_path: str, device: str = "cuda")
             stop_audio_token=self.config["gpt"]["gpt_stop_audio_token"],
         )
 
-        sd = torch.load(self.gpt_path, map_location=device)["model"]
+        sd = torch.load(self.gpt_path, map_location="cpu")["model"]
         self.gpt.load_state_dict(sd, strict=True)
         self.gpt = self.gpt.to(device=device)
         self.gpt.eval()
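
The FireRed change follows the same pattern: deserialize the state dict on the CPU, load it into the module, then move the module to the target device, so map_location never has to name a GPU. A minimal standalone sketch (load_gpt and the nn.Linear stand-in are hypothetical, not the actual FireRed GPT class):

import torch
import torch.nn as nn

def load_gpt(ckpt_path: str, device: str = "cuda") -> nn.Module:
    model = nn.Linear(4, 4)  # stand-in for the real GPT module
    # Deserialize on the CPU so the device the checkpoint was saved on
    # (e.g. cuda:1) does not need to exist here.
    sd = torch.load(ckpt_path, map_location="cpu")["model"]
    model.load_state_dict(sd, strict=True)
    # Move the fully loaded module to the requested device in one step.
    model = model.to(device=device)
    model.eval()
    return model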
