From bf827ea03edc58998eecbe37d06fe49456da1460 Mon Sep 17 00:00:00 2001
From: James Chua
Date: Thu, 4 May 2023 01:57:47 +0800
Subject: [PATCH] add logs

---
 elk/utils/hf_utils.py  | 51 +++++++++++++++++++++---------------------
 elk/utils/multi_gpu.py |  1 -
 2 files changed, 25 insertions(+), 27 deletions(-)

diff --git a/elk/utils/hf_utils.py b/elk/utils/hf_utils.py
index a1fc8bf5..4c03c942 100644
--- a/elk/utils/hf_utils.py
+++ b/elk/utils/hf_utils.py
@@ -27,36 +27,35 @@ def determine_dtypes(
     is_cpu: bool,
     load_in_8bit: bool,
 ) -> torch.dtype | str:
-    with prevent_name_conflicts():
-        model_cfg = AutoConfig.from_pretrained(model_str)
+    model_cfg = AutoConfig.from_pretrained(model_str)
 
-        # When the torch_dtype is None, this generally means the model is fp32, because
-        # the config was probably created before the `torch_dtype` field was added.
-        fp32_weights = model_cfg.torch_dtype in (None, torch.float32)
+    # When the torch_dtype is None, this generally means the model is fp32, because
+    # the config was probably created before the `torch_dtype` field was added.
+    fp32_weights = model_cfg.torch_dtype in (None, torch.float32)
 
-        # Required by `bitsandbytes` to load in 8-bit.
-        if load_in_8bit:
-            # Sanity check: we probably shouldn't be loading in 8-bit if the checkpoint
-            # is in fp32. `bitsandbytes` only supports mixed fp16/int8 inference, and
-            # we can't guarantee that there won't be overflow if we downcast to fp16.
-            if fp32_weights:
-                raise ValueError("Cannot load in 8-bit if weights are fp32")
+    # Required by `bitsandbytes` to load in 8-bit.
+    if load_in_8bit:
+        # Sanity check: we probably shouldn't be loading in 8-bit if the checkpoint
+        # is in fp32. `bitsandbytes` only supports mixed fp16/int8 inference, and
+        # we can't guarantee that there won't be overflow if we downcast to fp16.
+        if fp32_weights:
+            raise ValueError("Cannot load in 8-bit if weights are fp32")
 
-            torch_dtype = torch.float16
+        torch_dtype = torch.float16
 
-        # CPUs generally don't support anything other than fp32.
-        elif is_cpu:
-            torch_dtype = torch.float32
+    # CPUs generally don't support anything other than fp32.
+    elif is_cpu:
+        torch_dtype = torch.float32
 
-        # If the model is fp32 but bf16 is available, convert to bf16.
-        # Usually models with fp32 weights were actually trained in bf16, and
-        # converting them doesn't hurt performance.
-        elif fp32_weights and torch.cuda.is_bf16_supported():
-            torch_dtype = torch.bfloat16
-            print("Weights seem to be fp32, but bf16 is available. Loading in bf16.")
-        else:
-            torch_dtype = "auto"
-        return torch_dtype
+    # If the model is fp32 but bf16 is available, convert to bf16.
+    # Usually models with fp32 weights were actually trained in bf16, and
+    # converting them doesn't hurt performance.
+    elif fp32_weights and torch.cuda.is_bf16_supported():
+        torch_dtype = torch.bfloat16
+        print("Weights seem to be fp32, but bf16 is available. Loading in bf16.")
+    else:
+        torch_dtype = "auto"
+    return torch_dtype
 
 
 def instantiate_model(
@@ -88,7 +87,7 @@ def instantiate_model(
         if arch_str.endswith(suffix):
             model_cls = getattr(transformers, arch_str)
             return model_cls.from_pretrained(model_str, **kwargs)
-
+    print(f"Loading model with {kwargs}")
     return AutoModel.from_pretrained(model_str, **kwargs)
 
 
diff --git a/elk/utils/multi_gpu.py b/elk/utils/multi_gpu.py
index 4fe4c0de..477cc395 100644
--- a/elk/utils/multi_gpu.py
+++ b/elk/utils/multi_gpu.py
@@ -84,7 +84,6 @@ def create_device_map(
         model_str=model_str,
         load_in_8bit=False,
         is_cpu=False,
-        torch_dtype=torch.float16,
     )
     # e.g. {"cuda:0": 16000, "cuda:1": 16000}