diff --git a/.gitignore b/.gitignore
index 71a8771..890df1a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -192,7 +192,7 @@ cython_debug/
 .vscode
 
 tests/test_outputs
-tests/test_inputs
+!tests/test_inputs/**/*.wav
 
 /datasets/data_*
diff --git a/modules/api/impl/openai_api.py b/modules/api/impl/openai_api.py
index 700cda7..f02e222 100644
--- a/modules/api/impl/openai_api.py
+++ b/modules/api/impl/openai_api.py
@@ -1,8 +1,12 @@
+import io
+import logging
 from typing import List, Optional
 
+import numpy as np
 from fastapi import Body, File, Form, HTTPException, UploadFile
 from numpy import clip
 from pydantic import BaseModel, Field
+from pydub import AudioSegment
 
 from modules.api import utils as api_utils
 from modules.api.Api import APIManager
@@ -12,7 +15,9 @@
     EncoderConfig,
 )
 from modules.core.handler.datacls.enhancer_model import EnhancerConfig
+from modules.core.handler.datacls.stt_model import STTConfig, STTOutputFormat
 from modules.core.handler.datacls.tts_model import InferConfig, TTSConfig
+from modules.core.handler.STTHandler import STTHandler
 from modules.core.handler.TTSHandler import TTSHandler
 from modules.core.spk.SpkMgr import spk_mgr
 from modules.core.spk.TTSSpeaker import TTSSpeaker
@@ -153,6 +158,10 @@ class TranscriptionsVerboseResponse(BaseModel):
     segments: list[TranscribeSegment]
 
 
+class TranscriptionsResponse(BaseModel):
+    text: str
+
+
 def setup(app: APIManager):
     app.post(
         "/v1/audio/speech",
@@ -171,17 +180,52 @@ def setup(app: APIManager):
 
     @app.post(
         "/v1/audio/transcriptions",
-        response_model=TranscriptionsVerboseResponse,
+        # NOTE: better not to set a response model here; this endpoint can return several different shapes...
+        # response_model=TranscriptionsResponse,
         description="Transcribes audio into the input language.",
     )
     async def transcribe(
         file: UploadFile = File(...),
-        model: str = Form(...),
+        model: str = Form("whisper.large"),
         language: Optional[str] = Form(None),
         prompt: Optional[str] = Form(None),
-        response_format: str = Form("json"),
+        # TODO: verbose_json is not supported yet
+        response_format: str = Form("txt"),
         temperature: float = Form(0),
+        # TODO: not implemented; requires rewriting whisper's transcribe function
         timestamp_granularities: List[str] = Form(["segment"]),
     ):
-        # TODO: Implement transcribe
-        return api_utils.success_response("not implemented yet")
+        try:
+            response_format = STTOutputFormat(response_format)
+        except Exception:
+            raise HTTPException(status_code=400, detail="Invalid response format.")
+
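+        # Decode the upload with pydub (ffmpeg under the hood) and hand the
+        # raw samples to the STT pipeline as a (sample_rate, samples) tuple.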
+        audio_bytes = await file.read()
+        audio_segment: AudioSegment = AudioSegment.from_file(io.BytesIO(audio_bytes))
+
+        sample_rate = audio_segment.frame_rate
+        samples = np.array(audio_segment.get_array_of_samples())
+
+        input_audio = (sample_rate, samples)
+
+        stt_config = STTConfig(
+            mid=model,
+            prompt=prompt,
+            language=language,
+            tempperature=temperature if temperature > 0 else None,
+            format=response_format,
+        )
+
+        try:
+            handler = STTHandler(input_audio=input_audio, stt_config=stt_config)
+
+            return {"text": handler.enqueue()}
+        except Exception as e:
+            logging.exception(e)
+
+            if isinstance(e, HTTPException):
+                raise e
+            else:
+                raise HTTPException(status_code=500, detail=str(e))
diff --git a/modules/core/handler/STTHandler.py b/modules/core/handler/STTHandler.py
index 2918861..0006a90 100644
--- a/modules/core/handler/STTHandler.py
+++ b/modules/core/handler/STTHandler.py
@@ -1,6 +1,7 @@
 from typing import Generator
 
 from modules.core.handler.datacls.stt_model import STTConfig
+from modules.core.models.stt.Whisper import WhisperModel
 from modules.core.pipeline.processor import NP_AUDIO
 
 
@@ -11,9 +12,18 @@ def __init__(self, input_audio: NP_AUDIO, stt_config: STTConfig) -> None:
         self.input_audio = input_audio
         self.stt_config = stt_config
+        self.model = self.get_model()
+
+    def get_model(self):
+        model_id = self.stt_config.mid.lower()
+        if model_id.startswith("whisper"):
+            return WhisperModel(model_id=model_id)
+
+        raise Exception(f"Model {model_id} is not supported")
 
     def enqueue(self) -> str:
-        raise NotImplementedError("Method 'enqueue' must be implemented by subclass")
+        result = self.model.transcribe(audio=self.input_audio, config=self.stt_config)
+        return result.text
 
     def enqueue_stream(self) -> Generator[str, None, None]:
         raise NotImplementedError(
diff --git a/modules/core/handler/datacls/stt_model.py b/modules/core/handler/datacls/stt_model.py
index 0692b97..610119b 100644
--- a/modules/core/handler/datacls/stt_model.py
+++ b/modules/core/handler/datacls/stt_model.py
@@ -20,14 +20,14 @@ class STTConfig(BaseModel):
     prefix: Optional[Union[str, List[int]]] = None
     language: Optional[str] = None
 
-    tempperature: float = 0.0
+    tempperature: Optional[float] = None
     sample_len: Optional[int] = None
     best_of: Optional[int] = None
     beam_size: Optional[int] = None
     patience: Optional[int] = None
     length_penalty: Optional[float] = None
 
-    format: Optional[STTOutputFormat] = STTOutputFormat.json
+    format: Optional[STTOutputFormat] = STTOutputFormat.txt
 
     highlight_words: Optional[bool] = False
     max_line_count: Optional[int] = None
diff --git a/modules/core/models/stt/Whisper.py b/modules/core/models/stt/Whisper.py
index 414d2d9..6e9ef6d 100644
--- a/modules/core/models/stt/Whisper.py
+++ b/modules/core/models/stt/Whisper.py
@@ -1,8 +1,10 @@
+import logging
 import threading
 from pathlib import Path
 from typing import Optional
 
 import librosa
+import numpy as np
 import torch
 from whisper import Whisper, audio, load_model
 
@@ -38,9 +40,15 @@ class WhisperModel(STTModel):
 
     lock = threading.Lock()
 
+    logger = logging.getLogger(__name__)
+
+    model: Optional[Whisper] = None
+
     def __init__(self, model_id: str):
         # example: `whisper.large` or `whisper` or `whisper.small`
-        model_ver = model_id.split(".")
+        model_ver = model_id.lower().split(".")
+
+        assert model_ver[0] == "whisper", f"Invalid model id: {model_id}"
 
         self.model_size = model_ver[1] if len(model_ver) > 1 else "large"
         self.model_dir = Path("./models/whisper")
@@ -48,30 +56,38 @@ def __init__(self, model_id: str):
         self.device = devices.get_device_for("whisper")
         self.dtype = devices.dtype
 
-        self.model: Whisper = None
-
     def load(self):
-        if self.model is None:
+        if WhisperModel.model is None:
             with self.lock:
-                self.model = load_model(
+                self.logger.info(f"Loading Whisper model [{self.model_size}]...")
+                WhisperModel.model = load_model(
                     name=self.model_size,
                     download_root=str(self.model_dir),
                     device=self.device,
                 )
-        return self.model
+                self.logger.info("Whisper model loaded.")
+        return WhisperModel.model
 
     @devices.after_gc()
     def unload(self):
-        if self.model is None:
+        if WhisperModel.model is None:
             return
-        del self.model
-        self.model = None
+        with self.lock:
+            del WhisperModel.model
+            WhisperModel.model = None
 
     def resample_audio(self, audio: NP_AUDIO):
         sr, data = audio
+
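+        # whisper expects float32 PCM in [-1.0, 1.0]; uploads decoded via
+        # pydub typically arrive as int16, so normalize before resampling.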
+        if data.dtype == np.int16:
+            data = data.astype(np.float32)
+            data /= np.iinfo(np.int16).max
+
         if sr == self.SAMPLE_RATE:
             return sr, data
-        data = librosa.resample(data, sr, self.SAMPLE_RATE)
+        data = librosa.resample(data, orig_sr=sr, target_sr=self.SAMPLE_RATE)
         return self.SAMPLE_RATE, data
 
     def transcribe(self, audio: NP_AUDIO, config: STTConfig) -> TranscribeResult:
@@ -97,8 +113,14 @@ def transcribe(self, audio: NP_AUDIO, config: STTConfig) -> TranscribeResult:
         model = self.load()
 
+        _, audio_data = self.resample_audio(audio=audio)
+
+        # ref https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-temperature
+        if tempperature is None or tempperature <= 0:
+            tempperature = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)
+
         result: WhisperTranscribeResult = model.transcribe(
-            audio,
+            audio=audio_data,
             language=language,
             prompt=prompt,
             prefix=prefix,
@@ -127,7 +149,6 @@ def transcribe(self, audio: NP_AUDIO, config: STTConfig) -> TranscribeResult:
 
 if __name__ == "__main__":
     import json
-    import numpy as np
     from scipy.io import wavfile
 
     devices.reset_device()
diff --git a/modules/core/spk/TTSSpeaker.py b/modules/core/spk/TTSSpeaker.py
index 56f45c7..eeb27db 100644
--- a/modules/core/spk/TTSSpeaker.py
+++ b/modules/core/spk/TTSSpeaker.py
@@ -3,7 +3,7 @@
 import dataclasses
 import json
 import uuid
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Optional, Union
 
 import numpy as np
 import torch
@@ -17,6 +18,7 @@
     DcSpkTrainInfo,
     DcSpkVoiceToken,
 )
+from modules.utils import audio_utils
 
 dclses = [
     DcSpk,
@@ -180,6 +182,18 @@ def get_ref(
             return found_ref
         return ref0
 
+    def get_ref_wav(
+        self, get_func: Optional[Callable[[DcSpkReference], bool]] = None
+    ) -> Union[tuple[int, np.ndarray, str], tuple[None, None, None]]:
+        ref0 = self.get_ref(get_func)
+        if ref0 is None:
+            return None, None, None
+        sr = ref0.wav_sr
+        wav_bytes = ref0.wav
+        wav = audio_utils.bytes_to_librosa_array(audio_bytes=wav_bytes, sample_rate=sr)
+        text = ref0.text
+        return sr, wav, text
+
     def get_recommend_config(self) -> Optional[DcSpkInferConfig]:
         if self._data.recommend_config:
             return self._data.recommend_config
diff --git a/scripts/spk/get_wav.py b/scripts/spk/get_wav.py
new file mode 100644
index 0000000..1b1fa6a
--- /dev/null
+++ b/scripts/spk/get_wav.py
@@ -0,0 +1,41 @@
+import argparse
+
+import soundfile as sf
+
+from modules.core.spk.TTSSpeaker import TTSSpeaker
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Inspect TTSSpeaker reference wav")
+    parser.add_argument(
+        "--spk",
+        required=True,
+        help="Speaker file path",
+    )
+    parser.add_argument(
+        "--out",
+        required=True,
+        help="Output file path",
+    )
+    args = parser.parse_args()
+    return args
+
+
+if __name__ == "__main__":
+    """
+    This script inspects the reference wav audio stored in a spk file.
+
+    NOTE: only the first reference is checked for now; more options can be added later.
+    """
+    args = parse_args()
+
+    spk = TTSSpeaker.from_file(args.spk)
+
+    sr, wav, text = spk.get_ref_wav()
+
+    print(f"{wav.shape[0] / sr} seconds")
+    print(f"{wav.shape[0]} samples")
+    print(f"{sr} Hz")
+    print(f"Text: {text}")
+
+    sf.write(args.out, wav, sr)
diff --git a/tests/api/test_openai_stt.py b/tests/api/test_openai_stt.py
new file mode 100644
index 0000000..1bc029d
--- /dev/null
+++ b/tests/api/test_openai_stt.py
@@ -0,0 +1,29 @@
+import pytest
+from fastapi.testclient import TestClient
+
+
+@pytest.mark.openai_api_stt
+def test_openai_transcriptions_api(client: TestClient):
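+    # NOTE: relies on tests/test_inputs/cosyvoice_out1.wav (added in this
+    # change); the whisper "large" weights are downloaded on first run.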
+    file_path = "./tests/test_inputs/cosyvoice_out1.wav"
+
+    with open(file_path, "rb") as file:
+        response = client.post(
+            "/v1/audio/transcriptions",
+            files={"file": (file_path, file, "audio/wav")},
+            data={
+                "model": "whisper.large",
+                "language": "zh",
+                "prompt": "",
+                "response_format": "txt",
+                "temperature": 0,
+                "timestamp_granularities": "segment",
+            },
+        )
+
+    assert response.status_code == 200
+    response_data = response.json()
+    assert isinstance(response_data["text"], str)
+    expected_text = "我们走的每一步都是我们策略的一部分\n你看到的所有一切\n包括我此刻与你交谈\n所做的一切\n所说的每一句话\n都有深远的含义\n"
+    assert response_data["text"] == expected_text
diff --git a/tests/test_inputs/cosyvoice_out1.wav b/tests/test_inputs/cosyvoice_out1.wav
new file mode 100644
index 0000000..f778558
Binary files /dev/null and b/tests/test_inputs/cosyvoice_out1.wav differ