Skip to content

Commit

Permalink
✨ add stt api #92
Browse files Browse the repository at this point in the history
- add `/v1/audio/transcriptions` api
- support whisper model
- add stt tests
  • Loading branch information
zhzLuke96 committed Aug 1, 2024
1 parent d904679 commit 92b992f
Show file tree
Hide file tree
Showing 9 changed files with 179 additions and 22 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ cython_debug/
.vscode

tests/test_outputs
tests/test_inputs
!tests/test_inputs/**/*.wav

/datasets/data_*

Expand Down
54 changes: 49 additions & 5 deletions modules/api/impl/openai_api.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import io
from typing import List, Optional

import numpy as np
from fastapi import Body, File, Form, HTTPException, UploadFile
from numpy import clip
from pydantic import BaseModel, Field
from pydub import AudioSegment

from modules.api import utils as api_utils
from modules.api.Api import APIManager
Expand All @@ -12,7 +15,9 @@
EncoderConfig,
)
from modules.core.handler.datacls.enhancer_model import EnhancerConfig
from modules.core.handler.datacls.stt_model import STTConfig, STTOutputFormat
from modules.core.handler.datacls.tts_model import InferConfig, TTSConfig
from modules.core.handler.STTHandler import STTHandler
from modules.core.handler.TTSHandler import TTSHandler
from modules.core.spk.SpkMgr import spk_mgr
from modules.core.spk.TTSSpeaker import TTSSpeaker
Expand Down Expand Up @@ -153,6 +158,10 @@ class TranscriptionsVerboseResponse(BaseModel):
segments: list[TranscribeSegment]


class TranscriptionsResponse(BaseModel):
    """Plain (non-verbose) transcription response body: just the recognized text."""

    # the full transcription as a single string
    text: str


def setup(app: APIManager):
app.post(
"/v1/audio/speech",
Expand All @@ -171,17 +180,52 @@ def setup(app: APIManager):

@app.post(
    "/v1/audio/transcriptions",
    # NOTE: intentionally no response_model — depending on response_format
    # this endpoint may return several different shapes.
    # response_model=TranscriptionsResponse,
    description="Transcribes audio into the input language.",
)
async def transcribe(
    file: UploadFile = File(...),
    model: str = Form("whisper.large"),
    language: Optional[str] = Form(None),
    prompt: Optional[str] = Form(None),
    # TODO: verbose_json is not supported yet
    response_format: str = Form("txt"),
    temperature: float = Form(0),
    # TODO: not implemented; requires rewriting whisper's transcribe function
    timestamp_granularities: List[str] = Form(["segment"]),
):
    """Transcribe an uploaded audio file with the configured STT model.

    Returns a JSON object ``{"text": ...}``.
    Raises HTTP 400 for an unknown response_format, HTTP 500 on model errors.
    """
    try:
        response_format = STTOutputFormat(response_format)
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid response format.")

    audio_bytes = await file.read()
    audio_segment: AudioSegment = AudioSegment.from_file(io.BytesIO(audio_bytes))

    sample_rate = audio_segment.frame_rate
    samples = np.array(audio_segment.get_array_of_samples())

    input_audio = (sample_rate, samples)

    stt_config = STTConfig(
        mid=model,
        prompt=prompt,
        language=language,
        # OpenAI's default of 0 means "let the model decide" — map it to None
        # so the backend falls back to its own temperature schedule.
        # (field name `tempperature` matches the typo in STTConfig)
        tempperature=temperature if temperature > 0 else None,
        format=response_format,
    )

    try:
        handler = STTHandler(input_audio=input_audio, stt_config=stt_config)
        return {"text": handler.enqueue()}
    except HTTPException:
        # Re-raise client errors untouched instead of logging and wrapping
        # them into a 500 like the generic handler below would.
        raise
    except Exception as e:
        import logging

        logging.exception(e)
        raise HTTPException(status_code=500, detail=str(e))
12 changes: 11 additions & 1 deletion modules/core/handler/STTHandler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Generator

from modules.core.handler.datacls.stt_model import STTConfig
from modules.core.models.stt.Whisper import WhisperModel
from modules.core.pipeline.processor import NP_AUDIO


Expand All @@ -11,9 +12,18 @@ def __init__(self, input_audio: NP_AUDIO, stt_config: STTConfig) -> None:

self.input_audio = input_audio
self.stt_config = stt_config
self.model = self.get_model()

def get_model(self):
    """Resolve the STT backend implementation for the configured model id."""
    mid = self.stt_config.mid.lower()
    # Only whisper-family models are supported at the moment.
    if not mid.startswith("whisper"):
        raise Exception(f"Model {mid} is not supported")
    return WhisperModel(model_id=mid)

def enqueue(self) -> str:
    """Run the transcription synchronously and return the full text.

    Delegates to the backend model selected in `get_model`.
    """
    result = self.model.transcribe(audio=self.input_audio, config=self.stt_config)
    return result.text

def enqueue_stream(self) -> Generator[str, None, None]:
raise NotImplementedError(
Expand Down
4 changes: 2 additions & 2 deletions modules/core/handler/datacls/stt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ class STTConfig(BaseModel):
prefix: Optional[Union[str, List[int]]] = None

language: Optional[str] = None
tempperature: float = 0.0
tempperature: Optional[float] = None
sample_len: Optional[int] = None
best_of: Optional[int] = None
beam_size: Optional[int] = None
patience: Optional[int] = None
length_penalty: Optional[float] = None

format: Optional[STTOutputFormat] = STTOutputFormat.json
format: Optional[STTOutputFormat] = STTOutputFormat.txt

highlight_words: Optional[bool] = False
max_line_count: Optional[int] = None
Expand Down
45 changes: 33 additions & 12 deletions modules/core/models/stt/Whisper.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import logging
import threading
from pathlib import Path
from typing import Optional

import librosa
import numpy as np
import torch
from whisper import Whisper, audio, load_model

Expand Down Expand Up @@ -38,40 +40,54 @@ class WhisperModel(STTModel):

lock = threading.Lock()

logger = logging.getLogger(__name__)

model: Optional[Whisper] = None

def __init__(self, model_id: str):
    """Create a Whisper STT wrapper.

    model_id examples: "whisper", "whisper.large", "whisper.small" —
    the part after the dot selects the model size (default "large").
    The loaded model itself lives in the class-level `WhisperModel.model`
    so it is shared across instances; no per-instance model attribute.
    """
    model_ver = model_id.lower().split(".")

    # NOTE(review): assert is stripped under `python -O`; callers currently
    # rely on AssertionError, so the exception type is kept as-is.
    assert model_ver[0] == "whisper", f"Invalid model id: {model_id}"

    self.model_size = model_ver[1] if len(model_ver) > 1 else "large"
    self.model_dir = Path("./models/whisper")

    self.device = devices.get_device_for("whisper")
    self.dtype = devices.dtype

def load(self):
    """Load the process-wide shared Whisper model (once) and return it."""
    if WhisperModel.model is None:
        with self.lock:
            # Re-check under the lock: another thread may have finished
            # loading between the unlocked check and acquiring the lock,
            # otherwise both threads would load the model.
            if WhisperModel.model is None:
                self.logger.info(f"Loading Whisper model [{self.model_size}]...")
                WhisperModel.model = load_model(
                    name=self.model_size,
                    download_root=str(self.model_dir),
                    device=self.device,
                )
                self.logger.info("Whisper model loaded.")
    return WhisperModel.model

@devices.after_gc()
def unload(self):
    """Release the shared Whisper model so its memory can be reclaimed."""
    with self.lock:
        # Check inside the lock so a concurrent load/unload can't race.
        if WhisperModel.model is None:
            return
        # `model` exists only as a class attribute (see `load`), so
        # `del self.model` would raise AttributeError on an instance that
        # never set it — clear the class-level reference instead and let
        # the decorator's GC hook reclaim the weights.
        WhisperModel.model = None

def resample_audio(self, audio: NP_AUDIO):
    """Convert an (sr, samples) pair to float32 at the model sample rate.

    int16 PCM is normalized into [-1.0, 1.0]; other dtypes are passed
    through unchanged (assumed already float — TODO confirm for int32
    sources). Resampling uses keyword arguments, which librosa >= 0.10
    requires (`orig_sr`/`target_sr` are keyword-only there).
    """
    sr, data = audio

    if data.dtype == np.int16:
        data = data.astype(np.float32)
        data /= np.iinfo(np.int16).max

    if sr == self.SAMPLE_RATE:
        return sr, data
    data = librosa.resample(data, orig_sr=sr, target_sr=self.SAMPLE_RATE)
    return self.SAMPLE_RATE, data

def transcribe(self, audio: NP_AUDIO, config: STTConfig) -> TranscribeResult:
Expand All @@ -97,8 +113,14 @@ def transcribe(self, audio: NP_AUDIO, config: STTConfig) -> TranscribeResult:

model = self.load()

_, audio_data = self.resample_audio(audio=audio)

# ref https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-temperature
if tempperature is None or tempperature <= 0:
tempperature = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)

result: WhisperTranscribeResult = model.transcribe(
audio,
audio=audio_data,
language=language,
prompt=prompt,
prefix=prefix,
Expand Down Expand Up @@ -127,7 +149,6 @@ def transcribe(self, audio: NP_AUDIO, config: STTConfig) -> TranscribeResult:
if __name__ == "__main__":
import json

import numpy as np
from scipy.io import wavfile

devices.reset_device()
Expand Down
16 changes: 15 additions & 1 deletion modules/core/spk/TTSSpeaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import dataclasses
import json
import uuid
from typing import Any, Callable, Optional
from tempfile import _TemporaryFileWrapper
from typing import Any, Callable, Optional, Union

import numpy as np
import torch
Expand All @@ -17,6 +18,7 @@
DcSpkTrainInfo,
DcSpkVoiceToken,
)
from modules.utils import audio_utils

dclses = [
DcSpk,
Expand Down Expand Up @@ -180,6 +182,18 @@ def get_ref(
return found_ref
return ref0

def get_ref_wav(
    self, get_func: Optional[Callable[[DcSpkReference], bool]] = None
) -> Union[tuple[int, np.ndarray, str], tuple[None, None, None]]:
    """Return (sample_rate, waveform, text) for the selected reference.

    Returns (None, None, None) when the speaker has no reference audio.
    """
    ref = self.get_ref(get_func)
    if ref is None:
        return None, None, None
    waveform = audio_utils.bytes_to_librosa_array(
        audio_bytes=ref.wav, sample_rate=ref.wav_sr
    )
    return ref.wav_sr, waveform, ref.text

def get_recommend_config(self) -> Optional[DcSpkInferConfig]:
if self._data.recommend_config:
return self._data.recommend_config
Expand Down
41 changes: 41 additions & 0 deletions scripts/spk/get_wav.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import argparse

import soundfile as sf

from modules.core.spk.TTSSpeaker import TTSSpeaker


def parse_args():
    """Parse the --spk/--out command line arguments for this script."""
    parser = argparse.ArgumentParser(description="Edit TTSSpeaker data")
    parser.add_argument("--spk", required=True, help="Speaker file path")
    parser.add_argument("--out", required=True, help="Output file path")
    return parser.parse_args()


if __name__ == "__main__":
    """
    This script inspects the reference wav audio stored in a spk file.
    NOTE: only the first reference is checked for now; other parameters
    will be added later.
    """
    args = parse_args()

    spk = TTSSpeaker.from_file(args.spk)

    sr, wav, text = spk.get_ref_wav()

    # duration assumes a mono (1-D) waveform — TODO confirm
    print(f"{wav.shape[0] / sr} seconds")
    print(f"{wav.shape[0]} samples")
    print(f"{sr} Hz")
    print(f"Text: {text}")

    # Write at the original sample rate; the previous `sr * 2` made the
    # exported file play back at double speed.
    sf.write(args.out, wav, sr)
27 changes: 27 additions & 0 deletions tests/api/test_openai_stt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import pytest
from fastapi.testclient import TestClient


@pytest.mark.openai_api_stt
def test_openai_speech_api_with_invalid_style(client: TestClient):
    # NOTE(review): the name looks copy-pasted from a TTS test — this is a
    # plain happy-path STT transcription test; nothing here involves an
    # "invalid style". Consider renaming to test_openai_stt_transcriptions.
    file_path = "./tests/test_inputs/cosyvoice_out1.wav"

    # Upload the fixture wav to the OpenAI-compatible transcription endpoint
    # with the default whisper model and plain-text output format.
    with open(file_path, "rb") as file:
        response = client.post(
            "/v1/audio/transcriptions",
            files={"file": (file_path, file, "audio/wav")},
            data={
                "model": "whisper.large",
                "language": "zh",
                "prompt": "",
                "response_format": "txt",
                "temperature": 0,
                "timestamp_granularities": "segment",
            },
        )

    assert response.status_code == 200
    response_data = response.json()
    assert isinstance(response_data["text"], str)
    # NOTE(review): exact-match on raw model output is brittle across whisper
    # versions/hardware; a substring or similarity check would be sturdier.
    expected_text = "我们走的每一步都是我们策略的一部分\n你看到的所有一切\n包括我此刻与你交谈\n所做的一切\n所说的每一句话\n都有深远的含义\n"
    assert response_data["text"] == expected_text
Binary file added tests/test_inputs/cosyvoice_out1.wav
Binary file not shown.

0 comments on commit 92b992f

Please sign in to comment.