From b0846e9a5103789ea24bd34c16526f85543e8505 Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Wed, 18 Sep 2024 07:46:51 +0000 Subject: [PATCH] fix(diffusion): deprecate WeightSeparatedDataParallel --- optimum/neuron/modeling_diffusion.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/optimum/neuron/modeling_diffusion.py b/optimum/neuron/modeling_diffusion.py index 30a310bf8..91f137195 100644 --- a/optimum/neuron/modeling_diffusion.py +++ b/optimum/neuron/modeling_diffusion.py @@ -27,7 +27,6 @@ import torch from huggingface_hub import snapshot_download -from packaging.version import Version from transformers import CLIPFeatureExtractor, CLIPTokenizer, PretrainedConfig from transformers.modeling_outputs import ModelOutput @@ -62,7 +61,6 @@ ) from .utils.require_utils import requires_torch_neuronx from .utils.version_utils import get_neuronxcc_version -from .version import __sdk_version__ if is_neuronx_available(): @@ -351,15 +349,6 @@ def load_model( "text_encoder_2": text_encoder_2_path, "controlnet": controlnet_paths, } - # DataParallel class to use (to remove after neuron sdk 2.20) - if to_neuron: - if Version(__sdk_version__) >= Version("2.20.0"): - raise NameError( - "`WeightSeparatedDataParallel` class should be deprecated when neuron sdk 2.20 is out. Please replace it with `torch_neuronx.DataParallel`." - ) - dp_cls = WeightSeparatedDataParallel - else: - dp_cls = torch_neuronx.DataParallel if data_parallel_mode == "all": logger.info("Loading the whole pipeline into both Neuron Cores...") @@ -372,7 +361,7 @@ def load_model( submodel = NeuronTracedModel.load_model( submodel_path, to_neuron=False ) # No need to load to neuron manually when dp - submodel = dp_cls( + submodel = torch_neuronx.DataParallel( submodel, [0, 1], set_dynamic_batching=dynamic_batch_size, @@ -395,7 +384,7 @@ def load_model( unet = NeuronTracedModel.load_model( unet_path, to_neuron=False ) # No need to load to neuron manually when dp - submodels["unet"] = dp_cls( + submodels["unet"] = torch_neuronx.DataParallel( unet, [0, 1], set_dynamic_batching=dynamic_batch_size, @@ -408,7 +397,9 @@ def load_model( controlnet = NeuronTracedModel.load_model( controlnet_path, to_neuron=False ) # No need to load to neuron manually when dp - controlnets.append(dp_cls(controlnet, [0, 1], set_dynamic_batching=dynamic_batch_size)) + controlnets.append( + torch_neuronx.DataParallel(controlnet, [0, 1], set_dynamic_batching=dynamic_batch_size) + ) if controlnets: submodels["controlnet"] = controlnets if len(controlnets) > 1 else controlnets[0] else: