From d481680d72d8066cc46e43051a3d63ba75d4bdf9 Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 17 Jul 2024 21:40:03 +0100 Subject: [PATCH] Improve inference batch manager (#47) * fix prefill CUDA mem leakage * update batch manager * fix logic * switch back to spawn and add cli arg --- amt/inference/transcribe.py | 230 +++++++++++++++++++++++++----------- amt/run.py | 9 ++ amt/tokenizer.py | 7 +- 3 files changed, 175 insertions(+), 71 deletions(-) diff --git a/amt/inference/transcribe.py b/amt/inference/transcribe.py index f18705f..7913998 100644 --- a/amt/inference/transcribe.py +++ b/amt/inference/transcribe.py @@ -13,8 +13,12 @@ import numpy as np import concurrent -from torch.multiprocessing import Queue +from multiprocessing import Queue, Manager +from multiprocessing.synchronize import Lock as LockType +from queue import Empty +from collections import deque from concurrent.futures import ThreadPoolExecutor +from typing import Tuple, List, Deque from tqdm import tqdm from functools import wraps from torch.cuda import is_bf16_supported @@ -185,7 +189,7 @@ def prefill( @optional_bf16_autocast @torch.no_grad() def process_segments( - tasks: list, + tasks: List, model: AmtEncoderDecoder, audio_transform: AudioTransform, tokenizer: AmtTokenizer, @@ -278,22 +282,25 @@ def process_segments( def gpu_manager( gpu_batch_queue: Queue, + gpu_waiting_dict: dict, + gpu_waiting_dict_lock: LockType, result_queue: Queue, model: AmtEncoderDecoder, batch_size: int, compile_mode: str | bool = False, gpu_id: int | None = None, ): + if gpu_id is not None: + torch.cuda.set_device(gpu_id) + if gpu_id is not None: logger = _setup_logger(name=f"GPU-{gpu_id}") else: logger = _setup_logger(name=f"GPU") + gpu_id = 0 logger.info("Started GPU manager") - if gpu_id is not None: - torch.cuda.set_device(gpu_id) - model.decoder.setup_cache( batch_size=batch_size, max_seq_len=MAX_BLOCK_LEN, @@ -315,10 +322,15 @@ def gpu_manager( try: while True: try: + with gpu_waiting_dict_lock: + gpu_waiting_dict[gpu_id] = time.time() batch = gpu_batch_queue.get(timeout=60) - except Exception as e: - logger.info(f"GPU timed out waiting for batch") - break + with gpu_waiting_dict_lock: + del gpu_waiting_dict[gpu_id] + except Empty as e: + with gpu_waiting_dict_lock: + del gpu_waiting_dict[gpu_id] + raise e else: try: results = process_segments( @@ -342,10 +354,11 @@ def gpu_manager( except Exception as e: logger.error(f"GPU manager failed with exception: {e}") finally: + del gpu_waiting_dict[gpu_id] logger.info(f"GPU manager terminated") -def _find_min_diff_batch(tasks: list, batch_size: int): +def _find_min_diff_batch(tasks: List, batch_size: int): prefix_lens = [ (len(prefix), idx) for idx, ((audio_seg, prefix), _) in enumerate(tasks) ] @@ -371,58 +384,108 @@ def _find_min_diff_batch(tasks: list, batch_size: int): ] +# NOTE: +# - For some reason copying gpu_waiting_dict is not working properly and is +# leading to race conditions. I've implemented a lock to stop it. +# - The size of gpu_batch_queue decreases before the code for deleting the +# corresponding entry in gpu_waiting_dict get processed. 
Adding a short sleep +# is a workaround def gpu_batch_manager( gpu_task_queue: Queue, gpu_batch_queue: Queue, + gpu_waiting_dict: dict, + gpu_waiting_dict_lock: LockType, batch_size: int, + max_wait_time: float = 0.25, + min_batch_size: int = 1, ): logger = _setup_logger(name="B") logger.info("Started batch manager") + + tasks: Deque[Tuple[object, int]] = deque() + gpu_wait_time = 0 + try: - tasks = [] while True: try: - task, pid = gpu_task_queue.get(timeout=0.1) - except Exception as e: + while not gpu_task_queue.empty(): + task, pid = gpu_task_queue.get_nowait() + tasks.append((task, pid)) + except Empty: pass - else: - tasks.append((task, pid)) - if gpu_batch_queue.empty() is False: - continue - # No tasks in queue -> check gpu batch queue - if gpu_batch_queue.empty() is False: - continue - elif len(tasks) == 0: - continue + with gpu_waiting_dict_lock: + curr_time = time.time() + num_tasks_in_batch_queue = gpu_batch_queue.qsize() + num_gpus_waiting = len(gpu_waiting_dict) + gpu_wait_time = ( + max( + [ + curr_time - wait_time_abs + for gpu_id, wait_time_abs in gpu_waiting_dict.items() + ] + ) + if gpu_waiting_dict + else 0.0 + ) - # Get new batch and add to batch queue - if len(tasks) < batch_size: - logger.warning( - f"Not enough tasks ({len(tasks)}) - padding batch" + should_create_batch = ( + len(tasks) >= 4 * batch_size + or ( + num_gpus_waiting > num_tasks_in_batch_queue + and len(tasks) >= batch_size + ) + or ( + num_gpus_waiting > num_tasks_in_batch_queue + and len(tasks) >= min_batch_size + and gpu_wait_time > max_wait_time ) - while len(tasks) < batch_size: - _pad_task, _pid = tasks[0] - tasks.append((_pad_task, -1)) - - assert len(tasks) >= batch_size, "batch error" - new_batch_idxs = _find_min_diff_batch( - tasks, - batch_size=batch_size, ) - gpu_batch_queue.put([tasks[_idx] for _idx in new_batch_idxs]) - tasks = [ - task - for _idx, task in enumerate(tasks) - if _idx not in new_batch_idxs - ] + + if should_create_batch: + logger.debug( + f"Creating batch: " + f"num_gpus_waiting={num_gpus_waiting}, " + f"gpu_wait_time={round(gpu_wait_time, 4)}s, " + f"num_tasks_ready={len(tasks)}, " + f"num_batches_ready={num_tasks_in_batch_queue}" + ) + batch = create_batch(tasks, batch_size, min_batch_size, logger) + gpu_batch_queue.put(batch) + time.sleep(0.025) + except Exception as e: logger.error(f"GPU batch manager failed with exception: {e}") finally: - logger.info(f"GPU batch manager terminated") + logger.info("GPU batch manager terminated") + +def create_batch( + tasks: Deque[Tuple[object, int]], + batch_size: int, + min_batch_size: int, + logger: logging.Logger, +): + assert len(tasks) >= min_batch_size, "Insufficient number of tasks" -def _shift_onset(seq: list, shift_ms: int): + if len(tasks) < batch_size: + logger.info(f"Creating a partial padded batch with {len(tasks)} tasks") + batch_idxs = list(range(len(tasks))) + batch = [tasks.popleft() for _ in batch_idxs] + + while len(batch) < batch_size: + pad_task, _ = batch[0] + batch.append((pad_task, -1)) + else: + batch_idxs = _find_min_diff_batch(list(tasks), batch_size) + batch = [tasks[idx] for idx in batch_idxs] + for idx in sorted(batch_idxs, reverse=True): + del tasks[idx] + + return batch + + +def _shift_onset(seq: List, shift_ms: int): res = [] for tok in seq: if type(tok) is tuple and tok[0] == "onset": @@ -434,7 +497,7 @@ def _shift_onset(seq: list, shift_ms: int): def _truncate_seq( - seq: list, + seq: List, start_ms: int, end_ms: int, tokenizer: AmtTokenizer = AmtTokenizer(), @@ -460,8 +523,10 @@ def _truncate_seq( # 
TODO: Add detection for pedal messages which occur before notes are played -def process_silent_intervals( - seq: list, intervals: list, tokenizer: AmtTokenizer +def _process_silent_intervals( + seq: List, + intervals: List, + tokenizer: AmtTokenizer, ): def adjust_onset(_onset: int): # Adjusts the onset according to the silence intervals @@ -552,7 +617,7 @@ def adjust_onset(_onset: int): return res -def get_silent_intervals(wav: torch.Tensor): +def _get_silent_intervals(wav: torch.Tensor): FRAME_LEN = 2048 HOP_LEN = 512 MIN_WINDOW_S = 5 @@ -614,12 +679,12 @@ def transcribe_file( # Add to gpu queue and wait for results curr_audio_segment = audio_segments.pop(0) - silent_intervals = get_silent_intervals(curr_audio_segment) + silent_intervals = _get_silent_intervals(curr_audio_segment) input_seq = copy.deepcopy(seq) gpu_task_queue.put(((curr_audio_segment, seq), pid)) while True: try: - gpu_result = result_queue.get(timeout=0.1) + gpu_result = result_queue.get(timeout=0.01) except Exception as e: pass else: @@ -634,7 +699,7 @@ def transcribe_file( f"Seen silent intervals in segment {idx}: {silent_intervals}" ) - seq_adj = process_silent_intervals( + seq_adj = _process_silent_intervals( seq, intervals=silent_intervals, tokenizer=tokenizer ) @@ -724,7 +789,7 @@ def process_file( input_dir: str, logger: logging.Logger, ): - def _save_seq(_seq: list, _save_path: str): + def _save_seq(_seq: List, _save_path: str): if os.path.exists(_save_path): logger.info(f"Already exists {_save_path} - overwriting") @@ -787,7 +852,7 @@ def remove_failures_from_queue_(_queue: Queue, _pid: int): logger.info(f"{file_queue.qsize()} file(s) remaining in queue") -def watchdog(main_pids: list, child_pids: list): +def watchdog(main_pids: List, child_pids: List): while True: if not all(os.path.exists(f"/proc/{pid}") for pid in main_pids): print("Cleaning up children...") @@ -797,6 +862,7 @@ def watchdog(main_pids: list, child_pids: list): except ProcessLookupError: pass + print("Finished cleaning up children") return time.sleep(1) @@ -816,12 +882,14 @@ def worker( def process_file_wrapper(): while True: try: - file_path = file_queue.get(timeout=5) - except Exception as e: + file_path = file_queue.get(timeout=15) + except Empty as e: if file_queue.empty(): logger.info("File queue empty") break else: + # I'm pretty sure empty is thrown due to timeout too + logger.info("Processes timed out waiting for file queue") continue process_file( @@ -835,6 +903,9 @@ def process_file_wrapper(): logger, ) + if file_queue.empty(): + return + try: with ThreadPoolExecutor(max_workers=tasks_per_worker) as executor: futures = [ @@ -849,12 +920,13 @@ def process_file_wrapper(): def batch_transcribe( - file_paths: list, + file_paths: List, model: AmtEncoderDecoder, save_dir: str, - batch_size: int = 16, + batch_size: int = 8, input_dir: str | None = None, gpu_ids: int | None = None, + num_workers: int | None = None, quantize: bool = False, compile_mode: str | bool = False, ): @@ -889,17 +961,24 @@ def batch_transcribe( logger.info(f"Files to process: {file_queue.qsize()}/{len(file_paths)}") - num_workers = min( - min(batch_size * num_gpus, multiprocessing.cpu_count() - num_gpus), - file_queue.qsize(), - ) + if num_workers is None: + num_workers = min( + min(batch_size * num_gpus, multiprocessing.cpu_count() - num_gpus), + file_queue.qsize(), + ) + num_processes_per_worker = min(5, file_queue.qsize() // num_workers) - gpu_task_queue = Queue() + mp_manager = Manager() + gpu_waiting_dict = mp_manager.dict() + gpu_waiting_dict_lock = 
mp_manager.Lock() gpu_batch_queue = Queue() + gpu_task_queue = Queue() result_queue = Queue() child_pids = [] - logger.info(f"Creating {num_workers} file worker(s)") + logger.info( + f"Creating {num_workers} file worker(s) with {num_processes_per_worker} sub-processes" + ) worker_processes = [ multiprocessing.Process( target=worker, @@ -909,7 +988,7 @@ def batch_transcribe( result_queue, save_dir, input_dir, - 5, + num_processes_per_worker, ), ) for _ in range(num_workers) @@ -921,20 +1000,26 @@ def batch_transcribe( gpu_batch_manager_process = multiprocessing.Process( target=gpu_batch_manager, - args=(gpu_task_queue, gpu_batch_queue, batch_size), + args=( + gpu_task_queue, + gpu_batch_queue, + gpu_waiting_dict, + gpu_waiting_dict_lock, + batch_size, + ), ) gpu_batch_manager_process.start() child_pids.append(gpu_batch_manager_process.pid) - time.sleep(5) start_time = time.time() - if num_gpus > 1: gpu_manager_processes = [ multiprocessing.Process( target=gpu_manager, args=( gpu_batch_queue, + gpu_waiting_dict, + gpu_waiting_dict_lock, result_queue, model, batch_size, @@ -945,10 +1030,19 @@ def batch_transcribe( for gpu_id in range(len(gpu_ids)) ] for p in gpu_manager_processes: - child_pids.append(p.pid) p.start() + child_pids.append(p.pid) + watchdog_process = multiprocessing.Process( - target=watchdog, args=(os.getpid(), child_pids) + target=watchdog, + args=( + [ + os.getpid(), + gpu_batch_manager_process.pid, + ] + + [p.pid for p in gpu_manager_processes], + child_pids, + ), ) watchdog_process.start() else: @@ -956,14 +1050,16 @@ def batch_transcribe( target=gpu_manager, args=( gpu_batch_queue, + gpu_waiting_dict, + gpu_waiting_dict_lock, result_queue, model, batch_size, compile_mode, ), ) - child_pids.append(_gpu_manager_process.pid) _gpu_manager_process.start() + child_pids.append(_gpu_manager_process.pid) gpu_manager_processes = [_gpu_manager_process] watchdog_process = multiprocessing.Process( diff --git a/amt/run.py b/amt/run.py index d22c319..b57d670 100644 --- a/amt/run.py +++ b/amt/run.py @@ -95,6 +95,11 @@ def _add_transcribe_args(subparser): action="store_true", default=False, ) + subparser.add_argument( + "-num_workers", + help="numer of file worker processes", + type=int, + ) subparser.add_argument("-bs", help="batch size", type=int, default=16) @@ -355,6 +360,7 @@ def transcribe( maestro: bool = False, batch_size: int = 8, multi_gpu: bool = False, + num_workers: int | None = None, quantize: bool = False, compile_mode: str | bool = False, ): @@ -454,6 +460,7 @@ def transcribe( batch_size=batch_size, input_dir=load_dir, gpu_ids=gpu_ids, + num_workers=num_workers, quantize=quantize, compile_mode=compile_mode, ) @@ -465,6 +472,7 @@ def transcribe( save_dir=save_dir, batch_size=batch_size, input_dir=load_dir, + num_workers=num_workers, quantize=quantize, compile_mode=compile_mode, ) @@ -534,6 +542,7 @@ def main(): batch_size=args.bs, multi_gpu=args.multi_gpu, quantize=args.q8, + num_workers=args.num_workers, compile_mode=( "max-autotune" if args.compile and args.max_autotune diff --git a/amt/tokenizer.py b/amt/tokenizer.py index dd34d0b..3146e32 100644 --- a/amt/tokenizer.py +++ b/amt/tokenizer.py @@ -343,11 +343,10 @@ def _detokenize_midi_dict( # Process note and add to note msgs note_to_close = notes_to_close.pop(tok_1_data, None) if note_to_close is None: - print( - f"No 'on' token corresponding to 'off' token: {tok_1, tok_2}" - ) if DEBUG: - raise Exception + print( + f"No 'on' token corresponding to 'off' token: {tok_1, tok_2}" + ) continue else: _pitch = tok_1_data
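
Reviewer note (not part of the patch): the GPU side of the new idle-GPU handshake added to gpu_manager works by registering a wall-clock timestamp in gpu_waiting_dict while the process blocks on gpu_batch_queue, and removing it again once a batch arrives or the wait times out, always under gpu_waiting_dict_lock. The sketch below restates that pattern only; gpu_id, gpu_batch_queue, gpu_waiting_dict, and gpu_waiting_dict_lock mirror the names in the patch, while consume_one_batch is a hypothetical stand-in for the real process_segments() call, so treat this as an illustration rather than the shipped code.

import time
from queue import Empty


def consume_one_batch(gpu_id, batch):
    # Hypothetical placeholder for the real decoding step (process_segments).
    pass


def wait_and_consume(gpu_id, gpu_batch_queue, gpu_waiting_dict, gpu_waiting_dict_lock):
    # Advertise that this GPU is idle, stamped with the time it started waiting.
    with gpu_waiting_dict_lock:
        gpu_waiting_dict[gpu_id] = time.time()
    try:
        batch = gpu_batch_queue.get(timeout=60)
    except Empty:
        # Timed out: re-raise so the caller can decide whether to shut down.
        raise
    finally:
        # Whether a batch arrived or the wait timed out, this GPU is no longer
        # idle from the batch manager's point of view.
        with gpu_waiting_dict_lock:
            gpu_waiting_dict.pop(gpu_id, None)
    consume_one_batch(gpu_id, batch)

In the patch itself the dict and lock are created with multiprocessing.Manager() in batch_transcribe and passed to every gpu_manager process, so the timestamps are visible to the batch manager across processes.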
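
Reviewer note (not part of the patch): the core of this change is the dispatch policy in gpu_batch_manager, which trades batch fullness against GPU idle time. The sketch below restates that policy as a standalone predicate; batch_size, min_batch_size, max_wait_time, and gpu_waiting_dict mirror the patch, while the function name should_dispatch_batch and the arguments num_tasks_ready, num_ready_batches, and now are introduced here for illustration and do not exist in the codebase.

import time
from typing import Dict


def should_dispatch_batch(
    num_tasks_ready: int,
    num_ready_batches: int,
    gpu_waiting_dict: Dict[int, float],
    batch_size: int,
    min_batch_size: int = 1,
    max_wait_time: float = 0.25,
    now: float | None = None,
) -> bool:
    """Return True when the batch manager should cut a batch (sketch)."""
    now = time.time() if now is None else now
    num_gpus_waiting = len(gpu_waiting_dict)
    # Longest time any GPU has been blocked waiting for a batch.
    gpu_wait_time = max(
        (now - started for started in gpu_waiting_dict.values()), default=0.0
    )
    gpus_starved = num_gpus_waiting > num_ready_batches
    return (
        # Plenty of work queued: always worth cutting a full batch.
        num_tasks_ready >= 4 * batch_size
        # A GPU is idle and a full batch is available.
        or (gpus_starved and num_tasks_ready >= batch_size)
        # A GPU has been idle longer than max_wait_time: accept a partial batch.
        or (
            gpus_starved
            and num_tasks_ready >= min_batch_size
            and gpu_wait_time > max_wait_time
        )
    )

When the partial-batch branch fires, create_batch in the patch pads the batch by repeating its first task tagged with pid -1, so padded entries can be told apart from real results downstream.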