From 4120f57030b8ee19cafa567722a5722fa4afa624 Mon Sep 17 00:00:00 2001
From: Louis
Date: Sun, 29 Sep 2024 17:21:11 +0100
Subject: [PATCH] Change from no_grad to inference_mode (#57)

---
 amt/inference/transcribe.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/amt/inference/transcribe.py b/amt/inference/transcribe.py
index 996e469..85b32d3 100644
--- a/amt/inference/transcribe.py
+++ b/amt/inference/transcribe.py
@@ -151,7 +151,7 @@ def wrapper(*args, **kwargs):
     return wrapper
 
 
-@torch.no_grad()
+@torch.inference_mode()
 def decode_token(
     model: AmtEncoderDecoder,
     x: torch.Tensor,
@@ -170,7 +170,7 @@ def decode_token(
     return logits, next_tok_ids
 
 
-@torch.no_grad()
+@torch.inference_mode()
 def prefill(
     model: AmtEncoderDecoder,
     x: torch.Tensor,
@@ -211,7 +211,7 @@ def calculate_input_pos(prefix_lens: torch.Tensor, prefill: bool):
 
 
 @optional_bf16_autocast
-@torch.no_grad()
+@torch.inference_mode()
 def process_segments(
     tasks: List,
     model: AmtEncoderDecoder,
@@ -304,6 +304,7 @@ def process_segments(
     return results
 
 
+@torch.inference_mode()
 def gpu_manager(
     gpu_batch_queue: Queue,
     gpu_waiting_dict: dict,
@@ -1024,7 +1025,9 @@ def batch_transcribe(
         min(batch_size, multiprocessing.cpu_count() - num_gpus),
         file_queue.qsize(),
     )
-    num_processes_per_worker = min(10, file_queue.qsize() // num_workers)
+    num_processes_per_worker = min(
+        3 * (batch_size // num_workers), file_queue.qsize() // num_workers
+    )
 
     mp_manager = Manager()
     gpu_waiting_dict = mp_manager.dict()
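
Note (not part of the patch): a minimal standalone sketch of the practical difference between the two decorators swapped here. torch.inference_mode() skips more autograd bookkeeping than torch.no_grad(), but tensors created under it are "inference tensors" and cannot later be used in autograd. The toy tensors below are illustrative only and do not come from transcribe.py.

    import torch

    x = torch.ones(3, requires_grad=True)

    with torch.no_grad():
        y = x * 2  # y.requires_grad is False, but y is an ordinary tensor

    with torch.inference_mode():
        z = x * 2  # z is an inference tensor: less autograd overhead,
                   # but it cannot participate in autograd afterwards

    print(y.requires_grad, z.requires_grad)  # False False
    print(z.is_inference())                  # True
    # z.requires_grad_(True)  # would raise: inference tensors can't track gradients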
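
Note (not part of the patch): a worked example of the revised num_processes_per_worker sizing in batch_transcribe, using hypothetical values (batch_size=32, num_workers=6, 120 queued files). The old fixed cap of 10 is replaced by a cap that scales with the per-worker share of the batch size.

    batch_size, num_workers, queued_files = 32, 6, 120

    old = min(10, queued_files // num_workers)                               # min(10, 20) -> 10
    new = min(3 * (batch_size // num_workers), queued_files // num_workers)  # min(15, 20) -> 15
    print(old, new)  # 10 15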