Skip to content

Commit

Permalink
make deepspeed model initialization fatser
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Feb 27, 2024
1 parent 97e0b9e commit df79a8f
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 14 deletions.
5 changes: 2 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ run_docker_cuda:
--rm \
--pid host \
--shm-size 64G \
--gpus '"device=0,1"' \
--gpus all \
--entrypoint /bin/bash \
--volume $(PWD):/workspace \
--workdir /workspace \
Expand All @@ -81,8 +81,7 @@ run_docker_rocm:
--pid host \
--shm-size 64G \
--device /dev/kfd \
--device /dev/dri/renderD128 \
--device /dev/dri/renderD129 \
--device /dev/dri/ \
--entrypoint /bin/bash \
--volume $(PWD):/workspace \
--workdir /workspace \
Expand Down
33 changes: 22 additions & 11 deletions optimum_benchmark/backends/pytorch/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,18 +142,33 @@ def load_model_from_pretrained(self) -> None:
LOGGER.info(f"\t+ Moving pipeline to device: {self.config.device}")
self.pretrained_model.to(self.config.device)
elif self.config.deepspeed_inference:
with torch.device("cpu"):
LOGGER.info("\t+ Loading DeepSpeed model directly on CPU to avoid OOM")
self.pretrained_model = self.automodel_class.from_pretrained(
pretrained_model_name_or_path=self.config.model, **self.config.hub_kwargs, **self.automodel_kwargs
)
if self.config.no_weights:
with torch.device("meta"):
LOGGER.info("\t+ Loading model on meta device for fast initialization")
self.pretrained_model = self.automodel_class.from_pretrained(
pretrained_model_name_or_path=self.config.model,
**self.config.hub_kwargs,
**self.automodel_kwargs,
)
LOGGER.info("\t+ Materializing model on CPU")
self.pretrained_model.to_empty(device="cpu")
LOGGER.info("\t+ Tying model weights")
self.pretrained_model.tie_weights()
else:
LOGGER.info("\t+ Loading model on cpu to avoid OOM")
with torch.device("cpu"):
self.pretrained_model = self.automodel_class.from_pretrained(
pretrained_model_name_or_path=self.config.model,
**self.config.hub_kwargs,
**self.automodel_kwargs,
)

torch.distributed.barrier() # better safe than hanging
LOGGER.info("\t+ Initializing DeepSpeed Inference")
LOGGER.info("\t+ Initializing DeepSpeed Inference Engine")
self.pretrained_model = init_inference(self.pretrained_model, config=self.config.deepspeed_inference_config)
torch.distributed.barrier() # better safe than hanging
elif self.is_quantized:
# we can't use device context manager since the model is quantized
# we can't use device context manager on quantized models
LOGGER.info("\t+ Loading Quantized model")
self.pretrained_model = self.automodel_class.from_pretrained(
pretrained_model_name_or_path=self.config.model,
Expand Down Expand Up @@ -218,10 +233,6 @@ def load_model_with_no_weights(self) -> None:
self.load_model_from_pretrained()
self.config.model = original_model

# dunno how necessary this is
LOGGER.info("\t+ Tying model weights")
self.pretrained_model.tie_weights()

def process_quantization_config(self) -> None:
if self.is_gptq_quantized:
LOGGER.info("\t+ Processing GPTQ config")
Expand Down

0 comments on commit df79a8f

Please sign in to comment.