Fix regex
michaelbenayoun committed Jul 13, 2023
1 parent a05f8ce commit d80b3d7
Showing 2 changed files with 6 additions and 5 deletions.
optimum/neuron/accelerate/accelerator.py: 2 changes (1 addition, 1 deletion)
@@ -239,7 +239,7 @@ def _prepare_model_for_tp(
         cpu_ids = [id(v) for v in model.parameters()]
         # TODO: enable self.device (if needed).
         model = self.state.tp_plugin.parallelize_model(model, return_orig_to_parallel=False, device=None)
-        # model.to(torch.float32)
+        model.to(torch.float32)
         parallel_layers.move_model_to_device(model, self.device)
         model.tie_weights()
         self._model_cpu_parameters_to_xla[id(model)] = dict(zip(cpu_ids, model.parameters()))
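For context, the re-enabled line uses torch.nn.Module.to(torch.float32), which casts every floating-point parameter and buffer of the module before the model is moved to the XLA device. A minimal sketch of that behavior on a stand-in module (the Linear layer below is hypothetical, not the actual parallelized model):

    import torch

    # Stand-in for the parallelized model; the real one comes from tp_plugin.parallelize_model.
    model = torch.nn.Linear(4, 4).to(torch.bfloat16)
    print(next(model.parameters()).dtype)  # torch.bfloat16

    # The line re-enabled by this commit: upcast parameters and buffers to full precision.
    model = model.to(torch.float32)
    print(next(model.parameters()).dtype)  # torch.float32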
optimum/neuron/utils/misc.py: 9 changes (5 additions, 4 deletions)
@@ -87,8 +87,8 @@ def args_and_kwargs_to_kwargs_only(
 
 def _original_filename_to_safetensors_filename(filename: str) -> str:
     """Transforms the filename for any kind of checkpoint to a safetensors equivalent."""
-    name, extension = filename.rsplit(".", maxsplit=1)
-    pattern = rf"{name}(-[0-9]*-of-[0-9]*)?\.{extension}"
+    _, extension = filename.rsplit(".", maxsplit=1)
+    pattern = rf"\w+(-[0-9]*-of-[0-9]*)?\.{extension}"
     match_ = re.match(pattern, filename)
     if not match_:
         raise ValueError(f"Could not convert {filename} to a safetensor filename.")
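A plausible reading of the regex fix: the old pattern interpolated `name`, which already contains any `-XXXXX-of-XXXXX` shard suffix, so the optional capture group could never match it; with `\w+` (which does not match `-`), the suffix lands in the group instead. A quick, self-contained check against a hypothetical sharded checkpoint name:

    import re

    filename = "pytorch_model-00001-of-00002.bin"  # hypothetical sharded checkpoint
    name, extension = filename.rsplit(".", maxsplit=1)

    old_pattern = rf"{name}(-[0-9]*-of-[0-9]*)?\.{extension}"
    new_pattern = rf"\w+(-[0-9]*-of-[0-9]*)?\.{extension}"

    # Old pattern: the literal `name` swallows the shard suffix, so the group is empty.
    print(re.match(old_pattern, filename).group(1))  # None
    # New pattern: `\w+` stops at the first "-", so the group captures the suffix.
    print(re.match(new_pattern, filename).group(1))  # -00001-of-00002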
@@ -139,11 +139,12 @@ def convert_checkpoint_to_safetensors(
 
     already_exists = safetensors_path.is_file()
     is_distributed = torch.distributed.is_initialized()
-    is_main_process = is_distributed and torch.distributed.get_rank() > 0
+    is_main_process = is_distributed and torch.distributed.get_rank() == 0
 
     # Only one worker will load the checkpoint (potentially huge) and perform the conversion.
     if not already_exists and (not is_distributed or is_main_process):
-        logger.info(f"Converting {weight_file} to safetensors")
+        if log:
+            logger.info(f"Converting {weight_file} to safetensors")
         checkpoint = torch.load(weight_file)
         data_pointers = set()
         for k, v in checkpoint.items():
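The other two changes in this hunk are straightforward: with `get_rank() > 0`, every rank except 0 was treated as the main process, so in a distributed run all non-zero ranks performed the conversion while rank 0 skipped it; `get_rank() == 0` restricts it to rank 0, matching the comment above the guard. The new `if log:` check simply makes the info message optional. A small sketch of the corrected guard, with the condition lifted into a standalone helper for illustration (the helper itself is hypothetical):

    def should_convert(already_exists: bool, is_distributed: bool, rank: int) -> bool:
        # Mirrors the condition in convert_checkpoint_to_safetensors after this commit.
        is_main_process = is_distributed and rank == 0
        return not already_exists and (not is_distributed or is_main_process)

    print(should_convert(already_exists=False, is_distributed=True, rank=0))   # True: rank 0 converts
    print(should_convert(already_exists=False, is_distributed=True, rank=1))   # False: other ranks skip
    print(should_convert(already_exists=False, is_distributed=False, rank=0))  # True: single-process run converts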
