diff --git a/tests/benchmark/benchmark_gptq.py b/tests/benchmark/benchmark_gptq.py
index 5cd26e4cc6..0a565f01df 100644
--- a/tests/benchmark/benchmark_gptq.py
+++ b/tests/benchmark/benchmark_gptq.py
@@ -1,13 +1,19 @@
 import argparse
-
+import time
 import torch
 from tqdm import tqdm
 from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
+import numpy as np
 
 from optimum.exporters import TasksManager
 from optimum.gptq import load_quantized_model
 from accelerate import init_empty_weights
+import json
+from memory_tracker import MemoryTracker
+import os
+import gc
+
 
 
 def get_parser():
     parser = argparse.ArgumentParser()
@@ -54,40 +60,84 @@ def get_parser():
     )
     parser.add_argument(
         "--gptq",
-        action='store_true',
+        action="store_true",
         help="Indicate that the model to benchmark is a GPTQ model.",
     )
     parser.add_argument(
         "--sweep",
-        action='store_true',
+        action="store_true",
         help="Use the parameter ranges for (batch_size, prompt_length, new_tokens) defined in the .py file instead of the CLI ones.",
     )
+    parser.add_argument(
+        "--disable-exllama",
+        action="store_true",
+        help="Disable the exllama kernel and instead use the AutoGPTQ CUDA kernel (act-order case) or the CUDA-old kernel (no act-order case).",
+    )
     return parser
 
 
-def timing_cuda(model, num_batches: int, input_ids: torch.Tensor, masks: torch.Tensor, is_decoder: bool, generation_config=None):
-    start_event = torch.cuda.Event(enable_timing=True)
-    end_event = torch.cuda.Event(enable_timing=True)
+def timing_cuda(
+    model, num_batches: int, input_ids: torch.Tensor, masks: torch.Tensor, is_decoder: bool, generation_config=None
+):
+    assert generation_config.min_new_tokens == generation_config.max_new_tokens
 
-    torch.cuda.reset_peak_memory_stats(device)
-    torch.cuda.empty_cache()
     torch.cuda.synchronize()
 
-    start_event.record()
+    # Do NOT call torch.cuda.empty_cache() here, as it appears to negate the effect of the warmup.
+
+    latencies = []
     for _ in tqdm(range(num_batches)):
+        start_event = torch.cuda.Event(enable_timing=True)
+        end_event = torch.cuda.Event(enable_timing=True)
+        torch.cuda.synchronize()
+        start_event.record()
+
         if is_decoder:
             _ = model.generate(input_ids, attention_mask=masks, generation_config=generation_config)
         else:
             _ = model(input_ids, masks)
-    end_event.record()
-    torch.cuda.synchronize()
-    max_memory = torch.cuda.max_memory_allocated(device)
+        end_event.record()
+        torch.cuda.synchronize()
+
+        latency_ms = start_event.elapsed_time(end_event)
+        print(f"\nLatency per token: {latency_ms / generation_config.min_new_tokens:.3f} ms")
+        latencies.append(latency_ms)
 
-    return start_event.elapsed_time(end_event) / num_batches, max_memory
+    return np.mean(latencies)
 
 
-def benchmark(model, input_ids: torch.Tensor, masks: torch.Tensor, num_batches: int, is_decoder: bool, new_tokens: int, pad_token_id: int):
-    # Warmup
+def memory_cuda(
+    model,
+    input_ids: torch.Tensor,
+    masks: torch.Tensor,
+    is_decoder: bool,
+    memory_tracker: MemoryTracker,
+    generation_config=None,
+):
+    with memory_tracker.track():
+        if is_decoder:
+            _ = model.generate(input_ids, attention_mask=masks, generation_config=generation_config)
+        else:
+            _ = model(input_ids, masks)
+
+    return memory_tracker.peak_memory
+
+
+def benchmark(
+    model,
+    input_ids: torch.Tensor,
+    masks: torch.Tensor,
+    num_batches: int,
+    is_decoder: bool,
+    new_tokens: int,
+    pad_token_id: int,
+    memory_tracker: MemoryTracker,
+):
+    torch.cuda.empty_cache()
+    gc.collect()
+
+    # It appears that running the warmup only once is not always enough to get a low variance on the latency of the later timed runs.
+    print("Warmup...")
     if is_decoder:
         gen_config = GenerationConfig(
             max_new_tokens=new_tokens,
@@ -96,18 +146,21 @@ def benchmark(model, input_ids: torch.Tensor, masks: torch.Tensor, num_batches:
             pad_token_id=pad_token_id,
             num_beams=1,
             do_sample=False,
+            eos_token_id=None,  # This is required for min_new_tokens to actually have an effect.
         )
-        _ = model.generate(input_ids, attention_mask=masks, generation_config=gen_config)
-        torch.cuda.synchronize()
+        model.generation_config.eos_token_id = None  # greedy_search falls back on model.generation_config.eos_token_id, which must also be None for min_new_tokens to have an effect.
+        res = model.generate(input_ids, attention_mask=masks, generation_config=gen_config)
+        assert res.shape[1] == new_tokens + input_ids.shape[1]
+        del res
     else:
+        gen_config = None
         _ = model(input_ids, masks)
-        torch.cuda.synchronize()
+    torch.cuda.synchronize()
 
-    # Benchmark
-    if is_decoder:
-        total_time, max_mem = timing_cuda(model, num_batches, input_ids, masks, is_decoder, gen_config)
-    else:
-        total_time, max_mem = timing_cuda(model, num_batches, input_ids, masks, is_decoder)
+    print("Measuring latency...")
+    total_time = timing_cuda(model, num_batches, input_ids, masks, is_decoder, gen_config)
+    print("Measuring peak memory...")
+    max_mem = memory_cuda(model, input_ids, masks, is_decoder, memory_tracker, gen_config)
 
     return total_time, max_mem
 
@@ -116,7 +169,7 @@ def benchmark(model, input_ids: torch.Tensor, masks: torch.Tensor, num_batches:
 args = parser.parse_args()
 
 if args.sweep:
-    batch_sizes = [1, 4, 16, 32]
+    batch_sizes = [1, 4, 8, 16]
     prompt_lengths = [512]
     new_tokens = [512]
 else:
@@ -127,8 +180,14 @@ def benchmark(model, input_ids: torch.Tensor, masks: torch.Tensor, num_batches:
 
 if not torch.cuda.is_available():
     raise ValueError("A cuda device is necessary to benchmark GPTQ.")
+if len(os.environ["CUDA_VISIBLE_DEVICES"].split(",")) != 1:
+    raise ValueError(
+        "Please set the CUDA_VISIBLE_DEVICES variable to a single device index. This benchmark code is not tested for multi-device setups."
+    )
 device = torch.device("cuda:0")
 
+memory_tracker = MemoryTracker()
+
 tokenizer = AutoTokenizer.from_pretrained(args.model)
 
 if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
@@ -151,20 +210,55 @@ def benchmark(model, input_ids: torch.Tensor, masks: torch.Tensor, num_batches:
 else:
     is_decoder = False
 
+act_order = None
+bits = None
+group_size = None
+kernel = None
 if args.gptq:
     if not args.gptq_model:
         raise ValueError("The argument --gptq-model needs to be provided when benchmarking GPTQ.")
-
+
+    with open(os.path.join(args.gptq_model, "quantization_config.json"), "r", encoding="utf-8") as f:
+        quantize_config_dict = json.load(f)
+
+    act_order = quantize_config_dict["desc_act"]
+    bits = quantize_config_dict["bits"]
+    group_size = quantize_config_dict["group_size"]
+
+    if not args.disable_exllama:
+        kernel = "exllama"
+    elif act_order:
+        kernel = "autogptq-cuda"
+    else:
+        kernel = "autogptq-cuda-old"
+
+load_start = time.time_ns()
+if args.gptq:
     with init_empty_weights():
         empty_model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=torch.float16)
         empty_model.tie_weights()
-    model = load_quantized_model(empty_model, save_folder=args.gptq_model, state_dict_name="model.safetensors", device_map="auto")
+    model = load_quantized_model(
+        empty_model,
+        save_folder=args.gptq_model,
+        state_dict_name="model.safetensors",
+        device_map="auto",
+        disable_exllama=args.disable_exllama,
+    )
 else:
     with device:
         model = autoclass.from_pretrained(args.model, torch_dtype=torch.float16)
-
+torch.cuda.synchronize()
+load_end = time.time_ns()
+
+load_time = (load_end - load_start) * 1e-9
+
 uses_gptq = args.gptq
 print(f"Model uses GPTQ: {uses_gptq}")
+print(f"Using accelerate hooks: {hasattr(model, '_hf_hook')}")
+print(f"Bits: {bits}")
+print(f"group_size: {group_size}")
+print(f"act_order: {act_order}")
+print(f"kernel: {kernel}")
 
 model = model.eval()
 
@@ -177,7 +271,7 @@ def benchmark(model, input_ids: torch.Tensor, masks: torch.Tensor, num_batches:
 
 output_file = open(file_name, "w")
 output_file.write(
-    "gptq, num_batches, batch_size, prompt_length, new_tokens, Per-token latency (ms), Throughput (tok/s), Max memory (MB)\n"
+    "gptq, act_order, bits, group_size, kernel, num_batches, batch_size, prompt_length, new_tokens, Load time (s), Per-token latency (ms), Throughput (tok/s), Max memory (MB)\n"
 )
 
 latencies = {}
@@ -188,7 +282,7 @@ def benchmark(model, input_ids: torch.Tensor, masks: torch.Tensor, num_batches:
     for prompt_length in tqdm(prompt_lengths):
         for new_token in tqdm(new_tokens):
             print(f"---- Running: batch_size={batch_size}, prompt_length={prompt_length}, new_tokens={new_token}")
-
+
             input_ids = torch.randint(1, model.config.vocab_size - 1, size=(batch_size, prompt_length)).to(device)
             masks = torch.ones(batch_size, prompt_length, dtype=torch.int32).to(device)
 
@@ -201,30 +295,37 @@ def benchmark(model, input_ids: torch.Tensor, masks: torch.Tensor, num_batches:
                 is_decoder,
                 new_token,
                 tokenizer.pad_token_id,
+                memory_tracker=memory_tracker,
             )
-            max_mem = max_mem * 1e-6  # in MB
 
            index = (batch_size, prompt_length, new_token)
 
            per_token_latency = mean_latency / new_token
            latencies[index] = per_token_latency
-
+
            throughput = batch_size / (per_token_latency * 1e-3)
            throughputs[index] = throughput
 
            all_max_mem[index] = max_mem  # TODO: validate that maxmem is correct
 
-            print(f"Latency per token: {per_token_latency:.3f} ms, throughput: {throughput:.3f} tok/s, peak mem: {max_mem:.2f} MB")
+            print(
+                f"Latency per token: {per_token_latency:.3f} ms, throughput: {throughput:.3f} tok/s, peak mem: {max_mem:.2f} MB"
+            )
 
            output_file.write(
-                "{},{},{},{},{},{},{},{}\n".format(
+                "{},{},{},{},{},{},{},{},{},{},{},{},{}\n".format(
                    uses_gptq,
+                    act_order,
+                    bits,
+                    group_size,
+                    kernel,
                    args.num_batches,
                    batch_size,
                    prompt_length,
                    new_token,
-                    f"{throughput:.4f}",
+                    f"{load_time:.2f}",
                    f"{per_token_latency:.4f}",
+                    f"{throughput:.4f}",
                    f"{max_mem:.4f}",
                )
            )
diff --git a/tests/benchmark/memory_tracker.py b/tests/benchmark/memory_tracker.py
new file mode 100644
index 0000000000..2d4caf1972
--- /dev/null
+++ b/tests/benchmark/memory_tracker.py
@@ -0,0 +1,58 @@
+from multiprocessing.connection import Connection
+from multiprocessing import Pipe, Process
+from contextlib import contextmanager
+import os
+import subprocess
+
+# Adapted from optimum-benchmark: PyTorch's peak memory stats cannot be trusted when external libs allocate CUDA memory.
+class MemoryTracker:
+    def __init__(self):
+        self.peak_memory: int = 0
+        self.device_index = int(os.environ["CUDA_VISIBLE_DEVICES"].split(",")[0])
+
+    @contextmanager
+    def track(self, interval: float = 0.1):
+        print(f"Tracking memory for device {self.device_index}")
+        yield from self._track_peak_memory(interval)
+
+    def _track_peak_memory(self, interval: float):
+        child_connection, parent_connection = Pipe()
+        # instantiate the measurement process
+        mem_process: Process = PeakMemoryMeasureProcess(self.device_index, child_connection, interval)
+        mem_process.start()
+        # block until the child process signals that it is ready
+        parent_connection.recv()
+        yield
+        # signal the child process to stop measuring
+        parent_connection.send(0)
+        # receive peak memory
+        self.peak_memory = parent_connection.recv()
+
+
+class PeakMemoryMeasureProcess(Process):
+    def __init__(self, device_index: int, child_connection: Connection, interval: float):
+        super().__init__()
+        self.device_index = device_index
+        self.interval = interval
+        self.connection = child_connection
+        self.mem_usage = 0
+
+    def run(self):
+        self.connection.send(0)
+        stop = False
+
+        command = f"nvidia-smi --query-gpu=memory.used --format=csv --id={self.device_index}"
+
+        while True:
+            # py3nvml is broken since it outputs only the reserved memory, and nvidia-smi has only the MiB precision.
+            gpu_mem_mb = subprocess.check_output(command.split()).decode("ascii").split("\n")[1].split()[0]
+            gpu_mem_mb = int(gpu_mem_mb) * 1.048576  # convert MiB to MB
+            self.mem_usage = max(self.mem_usage, gpu_mem_mb)
+
+            if stop:
+                break
+            stop = self.connection.poll(self.interval)
+
+        # send results to parent pipe
+        self.connection.send(self.mem_usage)
+        self.connection.close()
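
Note (not part of the diff): below is a minimal, hypothetical sketch of how the `MemoryTracker` added above can be exercised on its own. It assumes `CUDA_VISIBLE_DEVICES` is set to a single device index (as both the tracker and `benchmark_gptq.py` require), that `nvidia-smi` is on the PATH, and that `memory_tracker.py` is importable from the working directory; the matmul is only a stand-in for the `model.generate()` call that `memory_cuda` wraps.

```python
# Hypothetical standalone check of memory_tracker.py (not part of the PR itself).
import os

os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")  # the tracker requires this variable to be set

import torch

from memory_tracker import MemoryTracker


def main():
    tracker = MemoryTracker()

    with tracker.track():
        # Any CUDA workload can go here; a matmul stands in for the model.generate() call
        # that memory_cuda() tracks in benchmark_gptq.py.
        x = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
        _ = x @ x
        torch.cuda.synchronize()

    # peak_memory is reported in MB (converted from the MiB values returned by nvidia-smi)
    # and reflects the total memory used on the device as seen by the driver.
    print(f"Peak GPU memory: {tracker.peak_memory:.2f} MB")


if __name__ == "__main__":
    main()
```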