[Inference] Fix inference latency issue when weights/neff are separated #584

Merged 36 commits from fix-non-inlined-perf into main on May 28, 2024.

Commits (36)
5bc3516
something not working with data parallel
JingyaHuang Apr 30, 2024
469a58c
Merge branch 'main' into fix-non-inlined-perf
JingyaHuang May 7, 2024
1e12227
add workaround
JingyaHuang May 7, 2024
0c63ce8
fix
JingyaHuang May 7, 2024
7cc8288
fix style
JingyaHuang May 7, 2024
532b2a2
remove comments
JingyaHuang May 7, 2024
96144d8
fix doc build
JingyaHuang May 7, 2024
7416320
fix doc build
JingyaHuang May 7, 2024
353858b
fix doc build
JingyaHuang May 7, 2024
604ba9a
Merge branch 'main' into fix-non-inlined-perf
JingyaHuang May 8, 2024
04b2e14
bump dev version
JingyaHuang May 8, 2024
237e159
lazy loading
JingyaHuang May 8, 2024
83acdfe
move custom dp class under sd modeling
JingyaHuang May 8, 2024
1a54150
fix?
JingyaHuang May 8, 2024
45f2a4f
fix naming conflict on importing
JingyaHuang May 8, 2024
d15d22d
Merge branch 'main' into fix-non-inlined-perf
JingyaHuang May 20, 2024
af35486
add docstring
JingyaHuang May 20, 2024
b47967e
fix import
JingyaHuang May 20, 2024
052447e
fix tests
JingyaHuang May 20, 2024
2425138
fix test
JingyaHuang May 21, 2024
4f3377a
fix test
JingyaHuang May 21, 2024
1e924b0
fix for decoder as well
JingyaHuang May 21, 2024
a9345d9
try fix
JingyaHuang May 22, 2024
00d1d5d
try fix
JingyaHuang May 22, 2024
e211d41
try fix
JingyaHuang May 22, 2024
fda3303
try fix
JingyaHuang May 22, 2024
a17a3e8
try fix
JingyaHuang May 22, 2024
abf45ce
try fix
JingyaHuang May 22, 2024
c37c9d1
Merge branch 'main' into fix-non-inlined-perf
JingyaHuang May 24, 2024
00a48e8
fix style
JingyaHuang May 24, 2024
8107d83
fix typo
JingyaHuang May 24, 2024
87c3902
add back previous fix
JingyaHuang May 24, 2024
0ad907a
add back test with subprocess to pass ddp
JingyaHuang May 27, 2024
1a91d1a
Merge branch 'main' into fix-non-inlined-perf
JingyaHuang May 27, 2024
30c6736
leave test not on ddp until the cleanup on neuron device fixed in neu…
JingyaHuang May 27, 2024
068cf90
for sdxl as well
JingyaHuang May 27, 2024
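Context for the change, with a minimal sketch of the workflow it targets. When a model is compiled with `inline_weights_to_neff=False`, the weights are stored separately from the NEFF graph; loading such a module previously placed only the NEFF on the Neuron device, so every forward pass paid a host-to-device transfer for the weights. This PR moves the weights onto the NeuronCore(s) at load time as well. The snippet below is illustrative only (the model id, shapes, and the `inline_weights_to_neff` keyword are assumptions about the export API, not taken from this PR):

from optimum.neuron import NeuronStableDiffusionPipeline

# Export a pipeline whose weights are kept separate from the NEFF,
# e.g. to allow swapping weights without recompiling. Before this fix,
# inference with such artifacts was markedly slower than with inlined
# weights.
pipe = NeuronStableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1",  # illustrative model id
    export=True,
    inline_weights_to_neff=False,
    batch_size=1,
    height=512,
    width=512,
)
pipe.save_pretrained("sd2_neuron/")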
2 changes: 1 addition & 1 deletion optimum/exporters/neuron/convert.py
@@ -31,7 +31,7 @@
     store_compilation_config,
 )
 from ...neuron.utils.cache_utils import get_model_name_or_path
-from ...neuron.utils.hub_neuronx_cache import (
+from ...neuron.utils.hub_cache_utils import (
     ModelCacheEntry,
     build_cache_config,
     cache_traced_neuron_artifacts,
88 changes: 74 additions & 14 deletions optimum/neuron/modeling_diffusion.py
@@ -27,6 +27,7 @@
 
 import torch
 from huggingface_hub import snapshot_download
+from packaging.version import Version
 from transformers import CLIPFeatureExtractor, CLIPTokenizer, PretrainedConfig
 from transformers.modeling_outputs import ModelOutput
 
@@ -53,13 +54,14 @@
     replace_weights,
     store_compilation_config,
 )
-from .utils.hub_neuronx_cache import (
+from .utils.hub_cache_utils import (
     ModelCacheEntry,
     build_cache_config,
     create_hub_compile_cache_proxy,
 )
 from .utils.require_utils import requires_torch_neuronx
 from .utils.version_utils import get_neuronxcc_version
+from .version import __sdk_version__
 
 
 if is_neuronx_available():
@@ -114,15 +116,15 @@ def __init__(
         unet: torch.jit._script.ScriptModule,
         vae_decoder: Union[torch.jit._script.ScriptModule, "NeuronModelVaeDecoder"],
         config: Dict[str, Any],
+        configs: Dict[str, "PretrainedConfig"],
+        neuron_configs: Dict[str, "NeuronDefaultConfig"],
         tokenizer: CLIPTokenizer,
         scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, LCMScheduler],
         data_parallel_mode: str,
         vae_encoder: Optional[Union[torch.jit._script.ScriptModule, "NeuronModelVaeEncoder"]] = None,
         text_encoder_2: Optional[Union[torch.jit._script.ScriptModule, "NeuronModelTextEncoder"]] = None,
         tokenizer_2: Optional[CLIPTokenizer] = None,
         feature_extractor: Optional[CLIPFeatureExtractor] = None,
-        configs: Optional[Dict[str, "PretrainedConfig"]] = None,
-        neuron_configs: Optional[Dict[str, "NeuronDefaultConfig"]] = None,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         model_and_config_save_paths: Optional[Dict[str, Tuple[str, Path]]] = None,
     ):
@@ -137,6 +139,10 @@
             config (`Dict[str, Any]`):
                 A config dictionary from which the model components will be instantiated. Make sure to only load
                 configuration files of compatible classes.
+            configs (`Dict[str, "PretrainedConfig"]`):
+                A dictionary of configurations for each component of the pipeline.
+            neuron_configs (`Dict[str, "NeuronDefaultConfig"]`):
+                A dictionary of Neuron configurations for each component of the pipeline.
             tokenizer (`CLIPTokenizer`):
                 Tokenizer of class
                 [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
@@ -154,10 +160,6 @@
                 [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
             feature_extractor (`Optional[CLIPFeatureExtractor]`, defaults to `None`):
                 A model extracting features from generated images to be used as inputs for the `safety_checker`
-            configs (Optional[Dict[str, "PretrainedConfig"]], defaults to `None`):
-                A dictionary configurations for components of the pipeline.
-            neuron_configs (Optional["NeuronDefaultConfig"], defaults to `None`):
-                A list of Neuron configurations.
             model_save_dir (`Optional[Union[str, Path, TemporaryDirectory]]`, defaults to `None`):
                 The directory under which the exported Neuron models were saved.
             model_and_config_save_paths (`Optional[Dict[str, Tuple[str, Path]]]`, defaults to `None`):
@@ -274,6 +276,7 @@ def load_model(
         vae_encoder_path: Optional[Union[str, Path]] = None,
         text_encoder_2_path: Optional[Union[str, Path]] = None,
         dynamic_batch_size: bool = False,
+        to_neuron: bool = False,
     ):
         """
         Loads Stable Diffusion TorchScript modules compiled by neuron(x)-cc compiler. It will be first loaded onto CPU and then moved to
@@ -295,6 +298,8 @@
                 Path of the compiled second frozen text encoder. SDXL only.
             dynamic_batch_size (`bool`, defaults to `False`):
                 Whether enable dynamic batch size for neuron compiled model. If `True`, the input batch size can be a multiple of the batch size during the compilation.
+            to_neuron (`bool`, defaults to `False`):
+                Whether to manually move the traced model to a NeuronCore. This is only needed when `inline_weights_to_neff=False`; otherwise the model is loaded onto a Neuron device automatically.
         """
         submodels = {
             "text_encoder": text_encoder_path,
@@ -303,12 +308,25 @@
             "vae_encoder": vae_encoder_path,
             "text_encoder_2": text_encoder_2_path,
         }
+        # DataParallel class to use (to remove after neuron sdk 2.20)
+        if to_neuron:
+            if Version(__sdk_version__) >= Version("2.20.0"):
+                raise NameError(
+                    "`WeightSeparatedDataParallel` class should be deprecated when neuron sdk 2.20 is out. Please replace it with `torch_neuronx.DataParallel`."
+                )
+            dp_cls = WeightSeparatedDataParallel
+        else:
+            dp_cls = torch_neuronx.DataParallel
+
         if data_parallel_mode == "all":
             logger.info("Loading the whole pipeline into both Neuron Cores...")
             for submodel_name, submodel_path in submodels.items():
                 if submodel_path is not None and submodel_path.is_file():
-                    submodels[submodel_name] = torch_neuronx.DataParallel(
-                        torch.jit.load(submodel_path),
+                    submodel = NeuronTracedModel.load_model(
+                        submodel_path, to_neuron=False
+                    )  # No need to load to neuron manually when dp
+                    submodels[submodel_name] = dp_cls(
+                        submodel,
                         [0, 1],
                         set_dynamic_batching=dynamic_batch_size,
                     )
@@ -319,19 +337,24 @@
             submodels.pop("unet")
             for submodel_name, submodel_path in submodels.items():
                 if submodel_path is not None and submodel_path.is_file():
-                    submodels[submodel_name] = NeuronTracedModel.load_model(submodel_path)
+                    submodels[submodel_name] = NeuronTracedModel.load_model(submodel_path, to_neuron=to_neuron)
                 else:
                     submodels[submodel_name] = None
-            submodels["unet"] = torch_neuronx.DataParallel(
-                torch.jit.load(unet_path),
+            unet = NeuronTracedModel.load_model(
+                unet_path, to_neuron=False
+            )  # No need to load to neuron manually when dp
+            submodels["unet"] = dp_cls(
+                unet,
                 [0, 1],
                 set_dynamic_batching=dynamic_batch_size,
             )
         elif data_parallel_mode == "none":
             logger.info("Loading the pipeline without any data parallelism...")
             for submodel_name, submodel_path in submodels.items():
                 if submodel_path is not None and submodel_path.is_file():
-                    submodels[submodel_name] = NeuronTracedModel.load_model(submodel_path)
+                    submodels[submodel_name] = NeuronTracedModel.load_model(submodel_path, to_neuron=to_neuron)
                 else:
                     submodels[submodel_name] = None
         else:
             raise ValueError("You need to pass `data_parallel_mode` to define Neuron Core allocation.")
 
@@ -524,6 +547,10 @@ def _from_pretrained(
             model_config = DiffusersPretrainedConfig.from_json_file(file_paths[1])
             configs[name] = model_config
             neuron_configs[name] = cls._neuron_config_init(model_config)
+        inline_weights_to_neff = all(
+            neuron_config._config.neuron.get("inline_weights_to_neff", False)
+            for _, neuron_config in neuron_configs.items()
+        )
 
         if data_parallel_mode is None:
             data_parallel_mode = cls.set_default_dp_mode(configs["unet"])
@@ -536,6 +563,7 @@
             vae_encoder_path=model_and_config_save_paths["vae_encoder"][0] if vae_encoder is None else None,
             text_encoder_2_path=model_and_config_save_paths["text_encoder_2"][0] if text_encoder_2 is None else None,
             dynamic_batch_size=neuron_configs[DIFFUSION_MODEL_UNET_NAME].dynamic_batch_size,
+            to_neuron=not inline_weights_to_neff,
         )
 
         if model_save_dir is None:
@@ -745,7 +773,7 @@ def _export(
                 optlevel=optlevel,
                 model_type=getattr(neuron_config, "MODEL_TYPE", None),
                 task=getattr(neuron_config, "task", None),
-                output_hidden_states=output_hidden_states,
+                output_hidden_states=getattr(neuron_config, "output_hidden_states", False),
             )
             compilation_configs[name] = compilation_config
 
@@ -1046,3 +1074,35 @@ class NeuronStableDiffusionXLInpaintPipeline(
     NeuronStableDiffusionXLPipelineBase, NeuronStableDiffusionXLInpaintPipelineMixin
 ):
     __call__ = NeuronStableDiffusionXLInpaintPipelineMixin.__call__
+
+
+if is_neuronx_available():
+    # TO REMOVE: This class will be included directly in the DDP API of Neuron SDK 2.20
+    class WeightSeparatedDataParallel(torch_neuronx.DataParallel):
+
+        def _load_modules(self, module):
+            try:
+                self.device_ids.sort()
+
+                loaded_modules = [module]
+                # If device_ids is non-consecutive, perform deepcopies and load onto each core independently.
+                for i in range(len(self.device_ids) - 1):
+                    loaded_modules.append(copy.deepcopy(module))
+                for i, nc_index in enumerate(self.device_ids):
+                    torch_neuronx.experimental.placement.set_neuron_cores(loaded_modules[i], nc_index, 1)
+                    torch_neuronx.move_trace_to_device(loaded_modules[i], nc_index)
+
+            except ValueError as err:
+                self.dynamic_batching_failed = True
+                logger.warning(f"Automatic dynamic batching failed due to {err}.")
+                logger.warning(
+                    "Please disable dynamic batching by calling `disable_dynamic_batching()` "
+                    "on your DataParallel module."
+                )
+            self.num_workers = 2 * len(loaded_modules)
+            return loaded_modules
+
+else:
+
+    class WeightSeparatedDataParallel:
+        pass
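A minimal usage sketch of the loading path above (the repo id is a placeholder; `data_parallel_mode` accepts "all", "unet", or "none", as in `load_model`). When the stored Neuron configs indicate `inline_weights_to_neff=False`, `_from_pretrained` passes `to_neuron=True` down to `load_model`, and data-parallel submodels go through `WeightSeparatedDataParallel`, which places each replica on its NeuronCore via `move_trace_to_device`:

from optimum.neuron import NeuronStableDiffusionPipeline

# Load pre-compiled, weight-separated artifacts: the unet is replicated
# across both NeuronCores, while the remaining submodels are moved to
# the device manually (to_neuron=True) at load time.
pipe = NeuronStableDiffusionPipeline.from_pretrained(
    "my-org/sd21-neuronx-non-inlined",  # placeholder repo id
    data_parallel_mode="unet",
)
image = pipe(prompt="a watercolor lighthouse at dusk").images[0]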
27 changes: 21 additions & 6 deletions optimum/neuron/modeling_traced.py
@@ -38,7 +38,7 @@
     replace_weights,
     store_compilation_config,
 )
-from .utils.hub_neuronx_cache import ModelCacheEntry, build_cache_config, create_hub_compile_cache_proxy
+from .utils.hub_cache_utils import ModelCacheEntry, build_cache_config, create_hub_compile_cache_proxy
 from .utils.import_utils import is_neuronx_available
 from .utils.misc import maybe_load_preprocessors
 from .utils.version_utils import check_compiler_compatibility, get_neuroncc_version, get_neuronxcc_version
@@ -54,6 +54,8 @@
     NEURON_COMPILER_VERSION = get_neuroncc_version()
 
 if is_neuronx_available():
+    from torch_neuronx import move_trace_to_device
+
     NEURON_COMPILER_TYPE = "neuronx-cc"
     NEURON_COMPILER_VERSION = get_neuronxcc_version()
 
@@ -100,25 +102,34 @@ def __init__(
         self.model = model
         self.model_file_name = model_file_name or NEURON_FILE_NAME
         self.config = config
-        self.neuron_config = self._neuron_config_init(self.config) if neuron_config is None else neuron_config
+        self.neuron_config = neuron_config
         self.input_static_shapes = NeuronTracedModel.get_input_static_shapes(self.neuron_config)
         self._attributes_init(model_save_dir, preprocessors, **kwargs)
 
     @staticmethod
-    def load_model(path: Union[str, Path]) -> torch.jit._script.ScriptModule:
+    def load_model(
+        path: Union[str, Path], to_neuron: bool = False, device_id: int = 0
+    ) -> torch.jit._script.ScriptModule:
         """
         Loads a TorchScript module compiled by neuron(x)-cc compiler. It will be first loaded onto CPU and then moved to
         one or multiple [NeuronCore](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/arch/neuron-hardware/neuroncores-arch.html).
 
         Args:
             path (`Union[str, Path]`):
                 Path of the compiled model.
+            to_neuron (`bool`, defaults to `False`):
+                Whether to manually move the traced model to a NeuronCore. This is only needed when `inline_weights_to_neff=False`; otherwise the model is loaded onto a Neuron device automatically.
+            device_id (`int`, defaults to 0):
+                Index of the NeuronCore to load the traced model onto.
         """
         if not isinstance(path, Path):
             path = Path(path)
 
         if path.is_file():
             model = torch.jit.load(path)
+            # For non-inlined models, send the module to the device manually. When weights and neff are separated, loading the module moves the neff to Neuron automatically but not the weights; the weights must be moved as well to avoid a heavy host-to-device IO penalty.
+            if is_neuronx_available() and to_neuron:
+                move_trace_to_device(model, device_id)
             return model
 
     def replace_weights(self, weights: Optional[Union[Dict[str, torch.Tensor], torch.nn.Module]] = None):
@@ -186,9 +197,13 @@ def _from_pretrained(
         model_compiler_version = config.neuron.get("compiler_version")
         check_compiler_compatibility(model_compiler_type, model_compiler_version)
 
+        # reconstruct neuron config
+        neuron_config = cls._neuron_config_init(config) if neuron_config is None else neuron_config
+        inline_weights_to_neff = config.neuron.get("inline_weights_to_neff", False)
+
         preprocessors = None
         if model_path.is_dir():
-            model = NeuronTracedModel.load_model(model_path / file_name)
+            model = NeuronTracedModel.load_model(model_path / file_name, to_neuron=not inline_weights_to_neff)
             new_model_save_dir = model_path
         else:
             model_cache_path = hf_hub_download(
@@ -202,7 +217,7 @@ def _from_pretrained(
                 local_files_only=local_files_only,
             )
 
-            model = NeuronTracedModel.load_model(model_cache_path)
+            model = NeuronTracedModel.load_model(model_cache_path, to_neuron=not inline_weights_to_neff)
             new_model_save_dir = Path(model_cache_path).parent
 
         preprocessors = maybe_load_preprocessors(model_id, subfolder=subfolder)
@@ -608,7 +623,7 @@ def is_weights_neff_separated(self) -> bool:
         """
         Whether the Neuron model has separated weights and neff graph (by setting `inline_weights_to_neff=False` during the compilation).
         """
-        return not self.config.neuron.get("inline_weights_to_neff", True)
+        return not self.config.neuron.get("inline_weights_to_neff")
 
     def can_generate(self) -> bool:
         """
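For a single traced module, the new `to_neuron` flag can be exercised directly; a minimal sketch assuming a locally compiled artifact (the path is a placeholder):

from optimum.neuron.modeling_traced import NeuronTracedModel

# With inlined weights, torch.jit.load alone suffices; with separated
# weights/neff, to_neuron=True triggers move_trace_to_device so the
# weights land on the NeuronCore alongside the NEFF.
model = NeuronTracedModel.load_model(
    "compiled_artifacts/model.neuron",  # placeholder path
    to_neuron=True,
    device_id=0,
)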
2 changes: 1 addition & 1 deletion optimum/neuron/trainers.py
@@ -78,7 +78,7 @@
     get_num_neuron_cores_used,
     has_write_access_to_repo,
 )
-from .utils.hub_neuronx_cache import ModelCacheEntry, hub_neuronx_cache, patch_neuron_cc_wrapper, synchronize_hub_cache
+from .utils.hub_cache_utils import ModelCacheEntry, hub_neuronx_cache, patch_neuron_cc_wrapper, synchronize_hub_cache
 from .utils.misc import is_main_worker, is_precompilation
 from .utils.require_utils import requires_neuronx_distributed, requires_torch_neuronx
 from .utils.training_utils import (