Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Pre Neuron Inf Cache system]Support neff/weights decoupling #402

Merged
merged 24 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions optimum/commands/export/neuronx.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ def parse_args_neuronx(parser: "ArgumentParser"):
type=Path,
help="Path indicating the directory where to store intermediary files generated by Neuronx compiler.",
)
optional_group.add_argument(
"--disable-weights-neff-inline",
JingyaHuang marked this conversation as resolved.
Show resolved Hide resolved
action="store_true",
help="Whether to disable the weights / neff graph inline. You can only replace weights of neuron-compiled models when the weights-neff inlining has been disabled during the compilation.",
)
optional_group.add_argument(
"--disable-validation",
action="store_true",
Expand Down
3 changes: 3 additions & 0 deletions optimum/exporters/neuron/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@ def main_export(
atol: Optional[float] = None,
cache_dir: Optional[str] = None,
compiler_workdir: Optional[Union[str, Path]] = None,
inline_weights_to_neff: bool = True,
optlevel: str = "2",
trust_remote_code: bool = False,
subfolder: str = "",
Expand Down Expand Up @@ -415,6 +416,7 @@ def main_export(
models_and_neuron_configs=models_and_neuron_configs,
output_dir=output,
compiler_workdir=compiler_workdir,
inline_weights_to_neff=inline_weights_to_neff,
optlevel=optlevel,
output_file_names=output_model_names,
compiler_kwargs=compiler_kwargs,
Expand Down Expand Up @@ -523,6 +525,7 @@ def main():
atol=args.atol,
cache_dir=args.cache_dir,
compiler_workdir=args.compiler_workdir,
inline_weights_to_neff=not args.disable_weights_neff_inline,
optlevel=optlevel,
trust_remote_code=args.trust_remote_code,
subfolder=args.subfolder,
Expand Down
15 changes: 15 additions & 0 deletions optimum/exporters/neuron/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ def export_models(
],
output_dir: Path,
compiler_workdir: Optional[Path] = None,
inline_weights_to_neff: bool = True,
optlevel: str = "2",
output_file_names: Optional[Dict[str, str]] = None,
compiler_kwargs: Optional[Dict[str, Any]] = {},
Expand All @@ -288,6 +289,8 @@ def export_models(
Output directory to store the exported Neuron models.
compiler_workdir (`Optional[Path]`, defaults to `None`):
The directory to store intermediary outputs of the neuron compiler.
inline_weights_to_neff (`bool`, defaults to `True`):
Whether to inline the weights to the neff graph. If set to False, weights will be seperated from the neff.
optlevel (`str`, defaults to `"2"`):
The level of optimization the compiler should perform. Can be `"1"`, `"2"` or `"3"`, defaults to "2".
1: enables the core performance optimizations in the compiler, while also minimizing compile time.
Expand Down Expand Up @@ -334,6 +337,7 @@ def export_models(
config=sub_neuron_config,
output=output_path,
compiler_workdir=compiler_workdir_path,
inline_weights_to_neff=inline_weights_to_neff,
optlevel=optlevel,
**compiler_kwargs,
)
Expand Down Expand Up @@ -362,6 +366,7 @@ def export_models(
dynamic_batch_size=sub_neuron_config.dynamic_batch_size,
compiler_type=NEURON_COMPILER_TYPE,
compiler_version=NEURON_COMPILER_VERSION,
inline_weights_to_neff=inline_weights_to_neff,
optlevel=optlevel,
model_type=getattr(sub_neuron_config, "MODEL_TYPE", None),
task=getattr(sub_neuron_config, "task", None),
Expand Down Expand Up @@ -392,6 +397,7 @@ def export(
config: "NeuronDefaultConfig",
output: Path,
compiler_workdir: Optional[Path] = None,
inline_weights_to_neff: bool = True,
optlevel: str = "2",
auto_cast: Optional[str] = None,
auto_cast_type: str = "bf16",
Expand All @@ -406,6 +412,7 @@ def export(
config=config,
output=output,
compiler_workdir=compiler_workdir,
inline_weights_to_neff=inline_weights_to_neff,
optlevel=optlevel,
auto_cast=auto_cast,
auto_cast_type=auto_cast_type,
Expand All @@ -421,6 +428,7 @@ def export_neuronx(
config: "NeuronDefaultConfig",
output: Path,
compiler_workdir: Optional[Path] = None,
inline_weights_to_neff: bool = True,
optlevel: str = "2",
auto_cast: Optional[str] = None,
auto_cast_type: str = "bf16",
Expand All @@ -437,6 +445,8 @@ def export_neuronx(
Directory to store the exported Neuron model.
compiler_workdir (`Optional[Path]`, defaults to `None`):
The directory used by neuronx-cc, where you can find intermediary outputs (neff, weight, hlo...).
inline_weights_to_neff (`bool`, defaults to `True`):
Whether to inline the weights to the neff graph. If set to False, weights will be seperated from the neff.
optlevel (`str`, defaults to `"2"`):
The level of optimization the compiler should perform. Can be `"1"`, `"2"` or `"3"`, defaults to "2".
1: enables the core performance optimizations in the compiler, while also minimizing compile time.
Expand Down Expand Up @@ -504,10 +514,15 @@ def export_neuronx(
dummy_inputs_tuple,
compiler_args=compiler_args,
input_output_aliases=aliases,
inline_weights_to_neff=inline_weights_to_neff,
compiler_workdir=compiler_workdir,
)

if config.dynamic_batch_size is True:
if not inline_weights_to_neff:
raise ValueError(
"Dynamic batching is not yet compatible with the weights/neff non-inlined model. Please set `dynamic_batch_size=False` or `inline_weights_to_neff=True`."
)
neuron_model = neuronx.dynamic_batch(neuron_model)

# diffusers specific
Expand Down
30 changes: 18 additions & 12 deletions optimum/exporters/neuron/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
@register_in_tasks_manager("bert", *COMMON_TEXT_TASKS)
class BertNeuronConfig(TextEncoderNeuronConfig):
NORMALIZED_CONFIG_CLASS = NormalizedConfigManager.get_normalized_config_class("bert")
ATOL_FOR_VALIDATION = 1e-4
ATOL_FOR_VALIDATION = 1e-3

@property
def inputs(self) -> List[str]:
Expand All @@ -83,6 +83,8 @@ class AlbertNeuronConfig(BertNeuronConfig):

@register_in_tasks_manager("convbert", *COMMON_TEXT_TASKS)
class ConvBertNeuronConfig(BertNeuronConfig):
ATOL_FOR_VALIDATION = 1e-1 # TODO: why accuracy more off than other arch

@property
def outputs(self) -> List[str]:
if self.task == "feature-extraction":
Expand All @@ -91,12 +93,16 @@ def outputs(self) -> List[str]:


@register_in_tasks_manager("electra", *COMMON_TEXT_TASKS)
class ElectraNeuronConfig(ConvBertNeuronConfig):
pass
class ElectraNeuronConfig(BertNeuronConfig):
@property
def outputs(self) -> List[str]:
if self.task == "feature-extraction":
return ["last_hidden_state"]
return self._TASK_TO_COMMON_OUTPUTS[self.task]


@register_in_tasks_manager("flaubert", *COMMON_TEXT_TASKS)
class FlaubertNeuronConfig(ConvBertNeuronConfig):
class FlaubertNeuronConfig(ElectraNeuronConfig):
pass


Expand All @@ -106,18 +112,18 @@ class MobileBertNeuronConfig(BertNeuronConfig):


@register_in_tasks_manager("roformer", *COMMON_TEXT_TASKS)
class RoFormerNeuronConfig(ConvBertNeuronConfig):
class RoFormerNeuronConfig(ElectraNeuronConfig):
pass


@register_in_tasks_manager("xlm", *COMMON_TEXT_TASKS)
class XLMNeuronConfig(ConvBertNeuronConfig):
class XLMNeuronConfig(ElectraNeuronConfig):
pass


@register_in_tasks_manager("distilbert", *COMMON_TEXT_TASKS)
class DistilBertNeuronConfig(BertNeuronConfig):
ATOL_FOR_VALIDATION = 1e-4
ATOL_FOR_VALIDATION = 1e-3

@property
def inputs(self) -> List[str]:
Expand All @@ -132,7 +138,7 @@ def outputs(self) -> List[str]:

@register_in_tasks_manager("camembert", *COMMON_TEXT_TASKS)
class CamembertNeuronConfig(BertNeuronConfig):
ATOL_FOR_VALIDATION = 1e-4
ATOL_FOR_VALIDATION = 1e-3

@property
def inputs(self) -> List[str]:
Expand All @@ -156,8 +162,8 @@ class XLMRobertaNeuronConfig(CamembertNeuronConfig):

# https://github.com/aws-neuron/aws-neuron-sdk/issues/642
# Failed only for INF1: 'XSoftmax'
@register_in_tasks_manager("deberta", *COMMON_TEXT_TASKS)
class DebertaNeuronConfig(BertNeuronConfig):
@register_in_tasks_manager("deberta", *([task for task in COMMON_TEXT_TASKS if task != "multiple-choice"]))
class DebertaNeuronConfig(ElectraNeuronConfig):
@property
def inputs(self) -> List[str]:
common_inputs = super().inputs
Expand All @@ -169,8 +175,8 @@ def inputs(self) -> List[str]:

# https://github.com/aws-neuron/aws-neuron-sdk/issues/642
# Failed only for INF1: 'XSoftmax'
@register_in_tasks_manager("deberta-v2", *COMMON_TEXT_TASKS)
class DebertaV2NeuronConfig(DebertaNeuronConfig):
@register_in_tasks_manager("deberta-v2", *([task for task in COMMON_TEXT_TASKS if task != "multiple-choice"]))
class DebertaV2NeuronConfig(ElectraNeuronConfig):
pass


Expand Down
26 changes: 24 additions & 2 deletions optimum/neuron/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,13 @@
from ..exporters.tasks import TasksManager
from ..modeling_base import OptimizedModel
from ..utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
from .utils import NEURON_FILE_NAME, is_neuron_available, store_compilation_config
from .utils import (
NEURON_FILE_NAME,
check_if_weights_replacable,
is_neuron_available,
replace_weights,
store_compilation_config,
)
from .utils.import_utils import is_neuronx_available
from .utils.version_utils import check_compiler_compatibility, get_neuroncc_version, get_neuronxcc_version

Expand Down Expand Up @@ -103,7 +109,13 @@ def load_model(path: Union[str, Path]) -> torch.jit._script.ScriptModule:
path = Path(path)

if path.is_file():
return torch.jit.load(path)
model = torch.jit.load(path)
return model

def replace_weights(self, weights: Optional[Union[Dict[str, torch.Tensor], torch.nn.Module]] = None):
check_if_weights_replacable(self.config, weights)
if weights is not None:
replace_weights(self.model, weights)

def _save_pretrained(self, save_directory: Union[str, Path]):
"""
Expand Down Expand Up @@ -216,6 +228,7 @@ def _export(
force_download: bool = False,
cache_dir: Optional[str] = None,
compiler_workdir: Optional[Union[str, Path]] = None,
inline_weights_to_neff: bool = True,
optlevel: str = "2",
subfolder: str = "",
local_files_only: bool = False,
Expand Down Expand Up @@ -303,6 +316,7 @@ def _export(
config=neuron_config,
output=save_dir_path / NEURON_FILE_NAME,
compiler_workdir=compiler_workdir,
inline_weights_to_neff=inline_weights_to_neff,
optlevel=optlevel,
**compiler_kwargs,
)
Expand All @@ -316,6 +330,7 @@ def _export(
dynamic_batch_size=dynamic_batch_size,
compiler_type=compiler_type,
compiler_version=compiler_version,
inline_weights_to_neff=inline_weights_to_neff,
optlevel=optlevel,
task=task,
)
Expand Down Expand Up @@ -570,3 +585,10 @@ def remove_padding(
]

return outputs

@property
def is_weights_neff_separated(self) -> bool:
"""
Whether the Neuron model has separated weights and neff graph (by setting `inline_weights_to_neff=False` during the compilation).
"""
return not self.config.neuron.get("inline_weights_to_neff", True)
4 changes: 4 additions & 0 deletions optimum/neuron/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,7 @@ def _export(
force_download: bool = True,
cache_dir: Optional[str] = None,
compiler_workdir: Optional[str] = None,
inline_weights_to_neff: bool = True,
optlevel: str = "2",
subfolder: str = "",
local_files_only: bool = False,
Expand Down Expand Up @@ -580,6 +581,8 @@ def _export(
standard cache should not be used.
compiler_workdir (`Optional[str]`, defaults to `None`):
Path to a directory in which the neuron compiler will store all intermediary files during the compilation(neff, weight, hlo graph...).
inline_weights_to_neff (`bool`, defaults to `True`):
Whether to inline the weights to the neff graph. If set to False, weights will be seperated from the neff.
optlevel (`str`, defaults to `"2"`):
The level of optimization the compiler should perform. Can be `"1"`, `"2"` or `"3"`, defaults to "2".
1: enables the core performance optimizations in the compiler, while also minimizing compile time.
Expand Down Expand Up @@ -640,6 +643,7 @@ def _export(
dynamic_batch_size=dynamic_batch_size,
cache_dir=cache_dir,
compiler_workdir=compiler_workdir,
inline_weights_to_neff=inline_weights_to_neff,
optlevel=optlevel,
trust_remote_code=trust_remote_code,
subfolder=subfolder,
Expand Down
2 changes: 2 additions & 0 deletions optimum/neuron/modeling_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ def _export(
force_download: bool = True,
cache_dir: Optional[str] = None,
compiler_workdir: Optional[str] = None,
inline_weights_to_neff: bool = True,
optlevel: str = "2",
subfolder: str = "",
local_files_only: bool = False,
Expand Down Expand Up @@ -302,6 +303,7 @@ def _export(
dynamic_batch_size=dynamic_batch_size,
cache_dir=cache_dir,
compiler_workdir=compiler_workdir,
inline_weights_to_neff=inline_weights_to_neff,
optlevel=optlevel,
trust_remote_code=trust_remote_code,
subfolder=subfolder,
Expand Down
1 change: 1 addition & 0 deletions optimum/neuron/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
is_transformers_neuronx_available,
)
from .input_generators import DummyBeamValuesGenerator
from .misc import check_if_weights_replacable, replace_weights
from .optimization_utils import get_attention_scores_sd, get_attention_scores_sdxl
from .patching import DynamicPatch, ModelPatcher, Patcher, patch_everywhere, patch_within_function
from .training_utils import (
Expand Down
2 changes: 2 additions & 0 deletions optimum/neuron/utils/argument_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def store_compilation_config(
dynamic_batch_size: bool,
compiler_type: str,
compiler_version: str,
inline_weights_to_neff: bool,
optlevel: str,
model_type: Optional[str] = None,
task: str = None,
Expand All @@ -161,6 +162,7 @@ def store_compilation_config(
# Add neuron version to the config, so it can be checked at load time
config_args["compiler_type"] = compiler_type
config_args["compiler_version"] = compiler_version
config_args["inline_weights_to_neff"] = inline_weights_to_neff

# Add input shapes during compilation to the config
for axis, shape in input_shapes.items():
Expand Down
42 changes: 41 additions & 1 deletion optimum/neuron/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import os
import re
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union

import torch
from transformers.modeling_utils import _add_variant
Expand All @@ -42,6 +42,9 @@
from .require_utils import requires_safetensors


if TYPE_CHECKING:
from transformers import PretrainedConfig

logger = logging.get_logger()


Expand Down Expand Up @@ -508,3 +511,40 @@ def download_checkpoints_in_cache(
resolved_archive_file = filenames_to_safetensors_filenames[Path(resolved_archive_file).name]

return resolved_archive_file, sharded_metadata


def replace_weights(
model: torch.jit._script.RecursiveScriptModule,
weights: Union[Dict[str, torch.Tensor], torch.nn.Module],
prefix: str = "model",
):
"""
Replaces the weights in a Neuron Model with weights from another model, the original neuron model should have separated weights(by setting `inline_weights_to_neff=Talse` during the tracing).
"""
if isinstance(weights, torch.nn.Module):
weights = weights.state_dict()

# extract module paths from the weights c module
code = model.weights._c.code
start_str = "__parameters__ = ["
end_str = "]\n"
module_paths = code.split(start_str)[1].split(end_str)[0].strip()[:-1:].replace('"', "").split(", ")
module_paths = [module_path for module_path in module_paths if module_path != ""]

for module_path in module_paths:
if len(re.findall("\w\d+", module_path)) > 0:
continue
else:
model.weights._c.setattr(module_path, weights[module_path.replace(prefix + "->", "").replace("->", ".")])


def check_if_weights_replacable(
config: "PretrainedConfig", weights: Optional[Union[Dict[str, torch.Tensor], torch.nn.Module]]
):
is_weights_neff_separated = (
not config.neuron.get("inline_weights_to_neff", True) if hasattr(config, "neuron") else False
)
if weights is not None and not is_weights_neff_separated:
raise RuntimeError(
"Unable to replace weights of the neuron model since its weights and neff are not separated, please set `inline_weights_to_neff=Talse` when converting the model to Neuron format."
)
Loading
Loading