Commit bf827ea

add logs

thejaminator committed May 3, 2023
1 parent d3a8f29 commit bf827ea
Showing 2 changed files with 25 additions and 27 deletions.

elk/utils/hf_utils.py (25 additions, 26 deletions)
@@ -27,36 +27,35 @@ def determine_dtypes(
     is_cpu: bool,
     load_in_8bit: bool,
 ) -> torch.dtype | str:
-    with prevent_name_conflicts():
-        model_cfg = AutoConfig.from_pretrained(model_str)
+    model_cfg = AutoConfig.from_pretrained(model_str)
 
-        # When the torch_dtype is None, this generally means the model is fp32, because
-        # the config was probably created before the `torch_dtype` field was added.
-        fp32_weights = model_cfg.torch_dtype in (None, torch.float32)
+    # When the torch_dtype is None, this generally means the model is fp32, because
+    # the config was probably created before the `torch_dtype` field was added.
+    fp32_weights = model_cfg.torch_dtype in (None, torch.float32)
 
-        # Required by `bitsandbytes` to load in 8-bit.
-        if load_in_8bit:
-            # Sanity check: we probably shouldn't be loading in 8-bit if the checkpoint
-            # is in fp32. `bitsandbytes` only supports mixed fp16/int8 inference, and
-            # we can't guarantee that there won't be overflow if we downcast to fp16.
-            if fp32_weights:
-                raise ValueError("Cannot load in 8-bit if weights are fp32")
+    # Required by `bitsandbytes` to load in 8-bit.
+    if load_in_8bit:
+        # Sanity check: we probably shouldn't be loading in 8-bit if the checkpoint
+        # is in fp32. `bitsandbytes` only supports mixed fp16/int8 inference, and
+        # we can't guarantee that there won't be overflow if we downcast to fp16.
+        if fp32_weights:
+            raise ValueError("Cannot load in 8-bit if weights are fp32")
 
-            torch_dtype = torch.float16
+        torch_dtype = torch.float16
 
-        # CPUs generally don't support anything other than fp32.
-        elif is_cpu:
-            torch_dtype = torch.float32
+    # CPUs generally don't support anything other than fp32.
+    elif is_cpu:
+        torch_dtype = torch.float32
 
-        # If the model is fp32 but bf16 is available, convert to bf16.
-        # Usually models with fp32 weights were actually trained in bf16, and
-        # converting them doesn't hurt performance.
-        elif fp32_weights and torch.cuda.is_bf16_supported():
-            torch_dtype = torch.bfloat16
-            print("Weights seem to be fp32, but bf16 is available. Loading in bf16.")
-        else:
-            torch_dtype = "auto"
-        return torch_dtype
+    # If the model is fp32 but bf16 is available, convert to bf16.
+    # Usually models with fp32 weights were actually trained in bf16, and
+    # converting them doesn't hurt performance.
+    elif fp32_weights and torch.cuda.is_bf16_supported():
+        torch_dtype = torch.bfloat16
+        print("Weights seem to be fp32, but bf16 is available. Loading in bf16.")
+    else:
+        torch_dtype = "auto"
+    return torch_dtype
 
 
 def instantiate_model(
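For context, `determine_dtypes` reduces to a small decision rule: 8-bit loading forces fp16 (and rejects fp32 checkpoints), CPU forces fp32, fp32 checkpoints on bf16-capable GPUs are loaded in bf16, and everything else falls back to "auto". A minimal usage sketch, not part of this commit; the import path, keyword names, and model id are assumptions based on the truncated signature shown above:

```python
import torch
from transformers import AutoModel

from elk.utils.hf_utils import determine_dtypes  # assumed import path

model_str = "gpt2"  # placeholder checkpoint

# load_in_8bit=True               -> torch.float16 (ValueError if weights are fp32)
# is_cpu=True                     -> torch.float32
# fp32 weights + bf16-capable GPU -> torch.bfloat16
# otherwise                       -> "auto" (keep the checkpoint's dtype)
torch_dtype = determine_dtypes(model_str=model_str, is_cpu=False, load_in_8bit=False)

model = AutoModel.from_pretrained(model_str, torch_dtype=torch_dtype)
```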
@@ -88,7 +87,7 @@ def instantiate_model(
             if arch_str.endswith(suffix):
                 model_cls = getattr(transformers, arch_str)
                 return model_cls.from_pretrained(model_str, **kwargs)
-
+    print(f"Loading model with {kwargs}")
     return AutoModel.from_pretrained(model_str, **kwargs)
 
 
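The only change in `instantiate_model` is the new log line, which echoes whatever keyword arguments are about to be forwarded to `from_pretrained`. A hypothetical example of its output (the kwargs below are illustrative, not taken from the repo):

```python
import torch

# Illustrative kwargs of the kind instantiate_model might forward.
kwargs = {"torch_dtype": torch.float16, "device_map": {"": "cuda:0"}}
print(f"Loading model with {kwargs}")
# Loading model with {'torch_dtype': torch.float16, 'device_map': {'': 'cuda:0'}}
```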

elk/utils/multi_gpu.py (0 additions, 1 deletion)
@@ -84,7 +84,6 @@ def create_device_map(
         model_str=model_str,
         load_in_8bit=False,
         is_cpu=False,
-        torch_dtype=torch.float16,
     )
 
     # e.g. {"cuda:0": 16000, "cuda:1": 16000}
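`create_device_map` is only shown at its call site here, but the trailing comment suggests it works from a per-GPU memory budget such as `{"cuda:0": 16000, "cuda:1": 16000}`. A hypothetical sketch of how a dict like that could be assembled (the helper name and MB units are assumptions, not code from the repo):

```python
import torch

def gpu_memory_budget_mb() -> dict[str, int]:
    """Return free memory per visible GPU in MB, e.g. {"cuda:0": 16000, "cuda:1": 16000}."""
    budget: dict[str, int] = {}
    for i in range(torch.cuda.device_count()):
        free_bytes, _total_bytes = torch.cuda.mem_get_info(i)  # (free, total) in bytes
        budget[f"cuda:{i}"] = int(free_bytes) // (1024 * 1024)
    return budget
```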
