Merge pull request #129 from toverainc/eric/gpu-batching
Load features onto the GPU in batches to support arbitrarily long audio
kristiankielhofner committed Oct 24, 2023
2 parents 75e6c04 + 50fa0a7 commit dac07ff
Showing 3 changed files with 65 additions and 18 deletions.
64 changes: 49 additions & 15 deletions main.py
@@ -89,6 +89,13 @@
torchaudio.set_audio_backend('soundfile')


# Split a sequence into fixed-size chunks (used to batch GPU work)
def chunkit(lst, num):
"""Yield successive num-sized chunks from list lst."""
for i in range(0, len(lst), num):
yield lst[i:i + num]


# Function to create a wav file from stream data
def write_stream_wav(data, rate, bits, ch):
file = io.BytesIO()
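A quick note on the new helper: chunkit is a plain generator, so it also works on the stacked NumPy mel-feature array used later, since len() and slicing act on the first axis. A tiny usage sketch, not part of the diff:

# Usage sketch for the chunkit helper defined above:
chunks = list(chunkit([0, 1, 2, 3, 4], 2))
assert chunks == [[0, 1], [2, 3], [4]]  # shorter final chunk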
@@ -228,6 +235,8 @@ async def create_datagram_endpoint(self, protocol_factory, local_addr: Tuple[str
# Chatbot model max length
chatbot_max_new_tokens = settings.chatbot_max_new_tokens

concurrent_gpu_chunks = settings.concurrent_gpu_chunks

# Try CUDA
device = "cuda" if torch.cuda.is_available() else "cpu"

@@ -254,9 +263,10 @@ async def create_datagram_endpoint(self, protocol_factory, local_addr: Tuple[str
logger.info(f'CUDA: Device {cuda_dev_num} total memory: {cuda_total_memory} bytes')
logger.info(f'CUDA: Device {cuda_dev_num} free memory: {cuda_free_memory} bytes')

# Disable chunking if card has less than 10GB VRAM (complete guess)
# Disable chunking if card has too little VRAM
# This can still encounter out of memory errors depending on audio length
if cuda_free_memory <= 10000000000:
# XXX: we don't really need this anymore, but leaving it in so short audio on small cards is unaffected
if cuda_free_memory <= settings.chunking_memory_threshold:
logger.warning(f'CUDA: Device {cuda_dev_num} has low memory, disabling chunking support')
support_chunking = False
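For context, the cuda_free_memory figure this check consumes can be read directly from torch; a minimal sketch assuming torch.cuda.mem_get_info (the actual assignment lives outside this hunk):

# Sketch: free/total GPU memory in bytes for a device via torch.
import torch

cuda_dev_num = 0  # hypothetical device index
cuda_free_memory, cuda_total_memory = torch.cuda.mem_get_info(cuda_dev_num)
if cuda_free_memory <= 3798205849:  # settings.chunking_memory_threshold default
    support_chunking = False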

@@ -468,10 +478,10 @@ def do_chatbot(text, max_new_tokens=chatbot_max_new_tokens, temperature=chatbot_
return output


def do_translate(whisper_model, features, batch_size, language, beam_size):
def do_translate(whisper_model, features, total_chunk_count, language, beam_size):
# Set task in token format for processor
task = 'translate'
logger.debug(f'WHISPER: Doing translation with {language} beam size {beam_size} and batch size {batch_size}')
logger.debug(f'WHISPER: Doing translation with {language}, beam size {beam_size}, chunk count {total_chunk_count}')
processor_task = f'<|{task}|>'

# Describe the task in the prompt.
@@ -487,7 +497,7 @@ def do_translate(whisper_model, features, batch_size, language, beam_size):

# Run generation for the 30-second window.
time_start = datetime.datetime.now()
results = whisper_model.generate(features, [prompt]*batch_size, beam_size=beam_size)
results = whisper_model.generate(features, [prompt]*total_chunk_count, beam_size=beam_size)
time_end = datetime.datetime.now()
infer_time = time_end - time_start
infer_time_milliseconds = infer_time.total_seconds() * 1000
@@ -548,14 +558,13 @@ def do_whisper(audio_file, model: str, beam_size: int = beam_size, task: str = "
chunks.append(log_mel_spectrogram(chunk).numpy())
strides.append(stride)
mel_features = np.stack(chunks)
batch_size = len(chunks)
total_chunk_count = len(chunks)
else:
mel_audio = pad_or_trim(audio)
mel_features = log_mel_spectrogram(mel_audio).numpy()
# Ref Whisper returns shape (80, 3000) but model expects (1, 80, 3000)
mel_features = np.expand_dims(mel_features, axis=0)
batch_size = 1
features = ctranslate2.StorageView.from_array(mel_features)
total_chunk_count = 1

time_end = datetime.datetime.now()
infer_time = time_end - time_start
@@ -570,7 +579,12 @@ def do_whisper(audio_file, model: str, beam_size: int = beam_size, task: str = "
processor_language = f'<|{language}|>'

if detect_language and not force_language:
results = whisper_model.detect_language(features)
# Load just the first mel_features chunk onto the GPU for language detection.
# It is deliberately named gpu_features so the batched generation below
# reassigns (and thereby frees) the same GPU reference.
first_mel_features = mel_features[0:1, :, :]
gpu_features = ctranslate2.StorageView.from_array(first_mel_features)
results = whisper_model.detect_language(gpu_features)
language, probability = results[0][0]
processor_language = language
logger.debug(f"WHISPER: Detected language {language} with probability {probability}")
@@ -602,7 +616,21 @@ def do_whisper(audio_file, model: str, beam_size: int = beam_size, task: str = "
# Whisper STEP 3 - run model
time_start = datetime.datetime.now()
logger.debug(f'WHISPER: Using model {model} with beam size {beam_size}')
results = whisper_model.generate(features, [prompt]*batch_size, beam_size=beam_size, return_scores=False)

results = []
for i, mel_features_batch in enumerate(
chunkit(mel_features, concurrent_gpu_chunks)
):
logger.debug("Processing GPU batch %s of expected %s", i+1, len(mel_features) // concurrent_gpu_chunks + 1)
gpu_features = ctranslate2.StorageView.from_array(mel_features_batch)
results.extend(whisper_model.generate(
gpu_features,
[prompt]*len(mel_features_batch),
beam_size=beam_size,
return_scores=False,
))
assert len(results) == total_chunk_count, "Result length doesn't match expected total_chunk_count"

time_end = datetime.datetime.now()
infer_time = time_end - time_start
infer_time_milliseconds = infer_time.total_seconds() * 1000
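The loop above is the heart of the change: mel features stay in host memory as one NumPy array, and only concurrent_gpu_chunks of them are wrapped in a ctranslate2.StorageView per iteration, so audio length is no longer bounded by GPU RAM. A minimal self-contained sketch of the pattern, with a stub standing in for the real ctranslate2 Whisper model:

# Sketch of the GPU batching pattern; StubWhisper stands in for
# ctranslate2.models.Whisper, and the StorageView wrapping is elided.
import numpy as np

def chunkit(lst, num):
    """Yield successive num-sized chunks from lst."""
    for i in range(0, len(lst), num):
        yield lst[i:i + num]

class StubWhisper:
    def generate(self, features, prompts, beam_size=1, return_scores=False):
        # The real model returns one result object per chunk in the batch.
        return [f"result-{n}" for n in range(len(features))]

concurrent_gpu_chunks = 2                          # from settings
mel_features = np.zeros((5, 80, 3000), "float32")  # five 30-second chunks
total_chunk_count = len(mel_features)
whisper_model = StubWhisper()
prompt = ['<|en|>', '<|transcribe|>']              # hypothetical prompt tokens

results = []
for batch in chunkit(mel_features, concurrent_gpu_chunks):
    # Real code wraps each batch: ctranslate2.StorageView.from_array(batch)
    results.extend(whisper_model.generate(batch, [prompt] * len(batch), beam_size=5))
assert len(results) == total_chunk_count  # 5 results from batches of 2, 2, 1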
@@ -611,7 +639,7 @@ def do_whisper(audio_file, model: str, beam_size: int = beam_size, task: str = "
time_start = datetime.datetime.now()
if use_chunking:
assert strides, 'strides needed to compute final tokens when chunking'
tokens = [(results[i].sequences_ids[0], strides[i]) for i in range(batch_size)]
tokens = [(results[i].sequences_ids[0], strides[i]) for i in range(total_chunk_count)]
tokens = find_longest_common_sequence(tokens, models.whisper_processor.tokenizer)
else:
tokens = results[0].sequences_ids[0]
Expand All @@ -626,9 +654,14 @@ def do_whisper(audio_file, model: str, beam_size: int = beam_size, task: str = "
pattern = re.compile("[A-Za-z0-9]+", )
language = pattern.findall(language)[0]

if translate:
# gpu_features was loaded above during generation; when every chunk fit in a
# single GPU batch it still holds all of them, so translation can reuse it
# without another copy to the GPU
if translate and total_chunk_count > concurrent_gpu_chunks:
logger.warning("Cannot translate: audio is too long to fit in GPU memory in a single batch")
translation = None
elif translate:
logger.debug(f'WHISPER: Detected non-preferred language {language}, translating')
translation = do_translate(whisper_model, features, batch_size, language, beam_size=beam_size)
translation = do_translate(whisper_model, gpu_features, total_chunk_count, language, beam_size=beam_size)
# Strip tokens from translation output - brittle but works right now
translation = translation.split('>')[2]
translation = translation.strip()
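The comment above calls the token stripping brittle, and it is: split('>')[2] silently assumes exactly two special tokens prefix the text. A small illustration with a hypothetical model output:

# Why split('>')[2] works on current output, and why it's brittle:
raw = '<|en|><|translate|> hello there'  # hypothetical decoded output
parts = raw.split('>')  # ['<|en|', '<|translate|', ' hello there']
translation = parts[2].strip()
assert translation == 'hello there'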
@@ -1160,7 +1193,6 @@ async def willow(request: Request, response: Response, model: Optional[str] = wh
channel = "1"
codec = "pcm"

body = b''
sample_rate = request.headers.get('x-audio-sample-rate', '').lower()
bits = request.headers.get('x-audio-bits', '').lower()
channel = request.headers.get('x-audio-channel', '').lower()
@@ -1173,8 +1205,10 @@ async def willow(request: Request, response: Response, model: Optional[str] = wh
if willow_id:
logger.debug(f"WILLOW: Got Willow ID {willow_id}")

body = []
async for chunk in request.stream():
body += chunk
body.append(chunk)
body = b''.join(body)

try:
if codec == "pcm":
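One small but meaningful fix above: accumulating the request body with bytes += re-copies the accumulated buffer on each chunk (quadratic in upload size in the worst case), while appending to a list and joining once copies each byte only once. The pattern in isolation, with a stand-in for request.stream():

# Sketch: accumulate chunks in a list, join once at the end.
chunks_in = [b'abc', b'def', b'ghi']  # stand-in for request.stream()
body = []
for chunk in chunks_in:
    body.append(chunk)
body = b''.join(body)
assert body == b'abcdefghi'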
7 changes: 5 additions & 2 deletions nginx/nginx.conf
@@ -36,6 +36,8 @@ http {
# Websocket support
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "upgrade";
# Support very long sessions for GPU batching of large files
proxy_read_timeout 1800;

# Use HTTP 1.1 keepalives to backend gunicorn
upstream keepalive-wis {
@@ -44,8 +46,9 @@ http {
keepalive_timeout 3600s;
}

# Increase max client body size for ASR file uploads, etc. 100MB matches Cloudflare
client_max_body_size 100M;
# Increase max client body size for ASR file uploads, etc.
# Default to very large to support GPU batching of long audio files.
client_max_body_size 2G;

server {
listen 19001;
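Taken together, the two nginx changes are what let a single HTTP request carry very long audio: the proxy read timeout covers the long-running batched inference, and the body-size cap covers the upload itself. Just the relevant directives, as a sketch (the real file has much more in its http block):

# Sketch: the two directives that matter for long-audio uploads.
http {
    proxy_read_timeout 1800;      # allow slow batched inference responses
    client_max_body_size 2G;      # accept very large audio uploads
}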
12 changes: 11 additions & 1 deletion settings.py
@@ -5,7 +5,7 @@

class APISettings(BaseSettings):
# Project metadata
name: str= "Willow Inference Server"
name: str = "Willow Inference Server"
description: str = "High Performance Language Inference API"
version: str = "1.0"

@@ -33,6 +33,16 @@ class APISettings(BaseSettings):
# Enable chunking support
support_chunking: bool = True

# There is really no reason to disable chunking anymore,
# but if you still want to you can raise this threshold.
# The current value corresponds roughly to a 4GB GPU.
chunking_memory_threshold: int = 3798205849

# Maximum number of chunks loaded onto the GPU at once.
# This needs to be tuned based on GPU RAM and the model used.
# 8GB GPUs should support at least 2 chunks, so we start with that.
concurrent_gpu_chunks: int = 2

# Enable TTS
support_tts: bool = True

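Because APISettings extends pydantic's BaseSettings, the two new knobs should be adjustable per deployment through environment variables, without code changes (a sketch assuming the default BaseSettings behavior of matching field names case-insensitively, with no custom env_prefix):

# Hypothetical deployment-time overrides for the new settings:
#   CONCURRENT_GPU_CHUNKS=4 CHUNKING_MEMORY_THRESHOLD=8000000000 python main.py
from settings import APISettings  # the module this diff touches

settings = APISettings()
print(settings.concurrent_gpu_chunks)  # prints 4 when the env var above is set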
