fix(diffusion): deprecate WeightSeparatedDataParallel
dacorvo committed Sep 18, 2024
1 parent b5af576 commit b0846e9
Showing 1 changed file with 5 additions and 14 deletions.
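For context, the commit removes the transitional `dp_cls` switch and wraps every traced submodel directly in `torch_neuronx.DataParallel`. A minimal sketch of that call pattern, mirroring the calls in the diff below (the traced-model path and the example inputs are hypothetical):

    import torch
    import torch_neuronx

    # Hypothetical path to an already-traced submodel; the pipeline obtains
    # these via NeuronTracedModel.load_model(submodel_path, to_neuron=False).
    traced_unet = torch.jit.load("traced_unet.pt")

    # Replicate the module on NeuronCores 0 and 1, as the diff does for the
    # unet, text encoders and controlnets; set_dynamic_batching follows the
    # pipeline's dynamic_batch_size flag.
    unet_dp = torch_neuronx.DataParallel(traced_unet, [0, 1], set_dynamic_batching=False)

    # Inference batches are then split along dim 0 across both cores:
    # noise_pred = unet_dp(sample, timestep, encoder_hidden_states)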
optimum/neuron/modeling_diffusion.py (19 changes: 5 additions & 14 deletions)
@@ -27,7 +27,6 @@

import torch
from huggingface_hub import snapshot_download
-from packaging.version import Version
from transformers import CLIPFeatureExtractor, CLIPTokenizer, PretrainedConfig
from transformers.modeling_outputs import ModelOutput

@@ -62,7 +61,6 @@
)
from .utils.require_utils import requires_torch_neuronx
from .utils.version_utils import get_neuronxcc_version
-from .version import __sdk_version__


if is_neuronx_available():
@@ -351,15 +349,6 @@ def load_model(
"text_encoder_2": text_encoder_2_path,
"controlnet": controlnet_paths,
}
-        # DataParallel class to use (to remove after neuron sdk 2.20)
-        if to_neuron:
-            if Version(__sdk_version__) >= Version("2.20.0"):
-                raise NameError(
-                    "`WeightSeparatedDataParallel` class should be deprecated when neuron sdk 2.20 is out. Please replace it with `torch_neuronx.DataParallel`."
-                )
-            dp_cls = WeightSeparatedDataParallel
-        else:
-            dp_cls = torch_neuronx.DataParallel

if data_parallel_mode == "all":
logger.info("Loading the whole pipeline into both Neuron Cores...")
@@ -372,7 +361,7 @@ def load_model(
submodel = NeuronTracedModel.load_model(
submodel_path, to_neuron=False
) # No need to load to neuron manually when dp
-submodel = dp_cls(
+submodel = torch_neuronx.DataParallel(
submodel,
[0, 1],
set_dynamic_batching=dynamic_batch_size,
@@ -395,7 +384,7 @@ def load_model(
unet = NeuronTracedModel.load_model(
unet_path, to_neuron=False
) # No need to load to neuron manually when dp
submodels["unet"] = dp_cls(
submodels["unet"] = torch_neuronx.DataParallel(
unet,
[0, 1],
set_dynamic_batching=dynamic_batch_size,
@@ -408,7 +397,9 @@ def load_model(
controlnet = NeuronTracedModel.load_model(
controlnet_path, to_neuron=False
) # No need to load to neuron manually when dp
-controlnets.append(dp_cls(controlnet, [0, 1], set_dynamic_batching=dynamic_batch_size))
+controlnets.append(
+    torch_neuronx.DataParallel(controlnet, [0, 1], set_dynamic_batching=dynamic_batch_size)
+)
if controlnets:
submodels["controlnet"] = controlnets if len(controlnets) > 1 else controlnets[0]
else:
