From bcf2e561cb6d324d58acecb908d80c8c97695f57 Mon Sep 17 00:00:00 2001 From: zhzluke96 Date: Sat, 12 Oct 2024 17:05:14 +0800 Subject: [PATCH] :bug: fix cosyvoice device #169 --- modules/core/models/tts/CosyVoiceFE.py | 6 ++++-- modules/core/models/tts/FireRed/FireRedInfer.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/modules/core/models/tts/CosyVoiceFE.py b/modules/core/models/tts/CosyVoiceFE.py index 5db3de9..83b665c 100644 --- a/modules/core/models/tts/CosyVoiceFE.py +++ b/modules/core/models/tts/CosyVoiceFE.py @@ -9,6 +9,8 @@ import torchaudio.compliance.kaldi as kaldi import whisper +from modules.devices import devices + class CosyVoiceFrontEnd: @@ -24,7 +26,7 @@ def __init__( ): self.tokenizer = get_tokenizer() self.feat_extractor = feat_extractor - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.device = devices.get_device_for("cosy-voice") option = onnxruntime.SessionOptions() option.graph_optimization_level = ( onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL @@ -45,7 +47,7 @@ def __init__( ], ) if os.path.exists(spk2info): - self.spk2info = torch.load(spk2info, map_location=self.device) + self.spk2info = torch.load(spk2info, map_location="cpu") self.instruct = instruct self.allowed_special = allowed_special self.inflect_parser = inflect.engine() diff --git a/modules/core/models/tts/FireRed/FireRedInfer.py b/modules/core/models/tts/FireRed/FireRedInfer.py index c62132b..150825d 100644 --- a/modules/core/models/tts/FireRed/FireRedInfer.py +++ b/modules/core/models/tts/FireRed/FireRedInfer.py @@ -75,7 +75,7 @@ def __init__(self, config_path: str, pretrained_path: str, device: str = "cuda") stop_audio_token=self.config["gpt"]["gpt_stop_audio_token"], ) - sd = torch.load(self.gpt_path, map_location=device)["model"] + sd = torch.load(self.gpt_path, map_location="cpu")["model"] self.gpt.load_state_dict(sd, strict=True) self.gpt = self.gpt.to(device=device) self.gpt.eval()