From 2ba3432086cf4c1e87aa13866b1b79fe37c77eb8 Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Tue, 22 Oct 2024 16:01:24 +0000 Subject: [PATCH] refactor(nxd): make load a class method --- examples/nxd/llama2/llama2_runner.py | 14 -------------- examples/nxd/llama2/neuron_modeling_llama.py | 1 + examples/nxd/modules/model_base.py | 20 +++++++++----------- examples/nxd/modules/model_wrapper.py | 17 ++--------------- examples/nxd/runner.py | 15 +++++++++------ 5 files changed, 21 insertions(+), 46 deletions(-) diff --git a/examples/nxd/llama2/llama2_runner.py b/examples/nxd/llama2/llama2_runner.py index 3a30ba76b..538b9afae 100644 --- a/examples/nxd/llama2/llama2_runner.py +++ b/examples/nxd/llama2/llama2_runner.py @@ -1,4 +1,3 @@ -import torch from runner import InferenceRunner from llama2.neuron_modeling_llama import ( @@ -9,19 +8,6 @@ class LlamaRunner(InferenceRunner): - def load_neuron_model(self, traced_model_path): - config = NeuronLlamaConfig.from_pretrained(traced_model_path) - model = NeuronLlamaForCausalLM.from_pretrained("", config) - self.config = config - - model.load(traced_model_path) - if config.torch_dtype == torch.bfloat16: - model.context_encoding_model.bfloat16() - if model.token_generation_model is not None: - model.token_generation_model.bfloat16() - - return model - def get_config_cls(self): return NeuronLlamaConfig diff --git a/examples/nxd/llama2/neuron_modeling_llama.py b/examples/nxd/llama2/neuron_modeling_llama.py index 66630f2b7..777df7de9 100644 --- a/examples/nxd/llama2/neuron_modeling_llama.py +++ b/examples/nxd/llama2/neuron_modeling_llama.py @@ -342,3 +342,4 @@ class NeuronLlamaForCausalLM(NeuronBaseForCausalLM): """ _model_cls = NeuronLlamaModel + _config_cls = NeuronLlamaConfig diff --git a/examples/nxd/modules/model_base.py b/examples/nxd/modules/model_base.py index ada42b32e..db505a018 100644 --- a/examples/nxd/modules/model_base.py +++ b/examples/nxd/modules/model_base.py @@ -346,13 +346,14 @@ class NeuronBaseForCausalLM(GenerationMixin): _STATE_DICT_MODEL_PREFIX = "model." _model_cls = None + _config_cls = None # Required by GenerationMixin, but present in PreTrainedModel main_input_name = "input_ids" _supports_cache_class = False # _supports_static_cache = False - def __init__(self, model_path: str, config: PretrainedConfig): + def __init__(self, config: PretrainedConfig, model: torch.jit.ScriptModule): super().__init__() self.config = config @@ -363,20 +364,15 @@ def __init__(self, model_path: str, config: PretrainedConfig): self.sampler = None - self.model_path = model_path self.context_encoding_model = ModelWrapper( config=self.config, - model_cls=self._model_cls, + model=model, tag=CONTEXT_ENCODING_MODEL_TAG, - max_input_tokens=self.config.max_context_length, - max_total_tokens=self.config.max_context_length, ) self.token_generation_model = ModelWrapper( config=self.config, - model_cls=self._model_cls, + model=model, tag=TOKEN_GENERATION_MODEL_TAG, - max_input_tokens=1, - max_total_tokens=self.config.max_length, ) def can_generate(self): @@ -437,7 +433,10 @@ def export(cls, model_path: Union[str, Path], config: NeuronInferenceConfig, ser def get_traced_model_path(base_path: Union[str, Path]): return os.path.join(base_path, "model.pt") - def load(self, serialize_base_path): + @classmethod + def load(cls, serialize_base_path): + + config = cls._config_cls.from_pretrained(serialize_base_path) traced_model = torch.jit.load(NeuronBaseForCausalLM.get_traced_model_path(serialize_base_path)) @@ -454,8 +453,7 @@ def get_rank(shard): traced_model.nxd_model.initialize(weights) - self.context_encoding_model.model = traced_model - self.token_generation_model.model = traced_model + return cls(config, traced_model) @property def device(self) -> torch.device: diff --git a/examples/nxd/modules/model_wrapper.py b/examples/nxd/modules/model_wrapper.py index ae1024d7d..1f0f6fcdd 100644 --- a/examples/nxd/modules/model_wrapper.py +++ b/examples/nxd/modules/model_wrapper.py @@ -15,24 +15,11 @@ class ModelWrapper(torch.nn.Module): - def __init__(self, config, model_cls, tag="", max_input_tokens: int = 128, max_total_tokens: int = 128) -> None: + def __init__(self, config, model, tag="") -> None: super().__init__() self.config = config - - if not self.config.torch_dtype: - self.config.torch_dtype = torch.float32 - - if self.config.pad_token_id is None: - self.config.pad_token_id = 0 - - self.model_cls = model_cls - self.model = None - self.is_compiled = False - self.serialize_base_path = None + self.model = model self.tag = tag - self.enable_bucketing = config.enable_bucketing - self.max_input_tokens = max_input_tokens - self.max_total_tokens = max_total_tokens def _forward_with_pad(self, *args): seq_ids = args[3] diff --git a/examples/nxd/runner.py b/examples/nxd/runner.py index c0a330b54..3d56f5188 100644 --- a/examples/nxd/runner.py +++ b/examples/nxd/runner.py @@ -42,8 +42,13 @@ def __init__(self, model_path: str = None, generation_config: GenerationConfig = self.generation_config = generation_config def load_neuron_model(self, traced_model_path): - # Implement per model - raise NotImplementedError + model = self.get_model_cls().load(traced_model_path) + + if model.config.torch_dtype == torch.bfloat16: + model.context_encoding_model.bfloat16() + model.token_generation_model.bfloat16() + + return model def get_config_cls(self): # Implement per model @@ -109,7 +114,7 @@ def get_config_for_nxd( config.trace_tokengen_model = kwargs.get("trace_tokengen_model", True) - config.pad_token_id = kwargs.get("pad_token_id", None) + config.pad_token_id = kwargs.get("pad_token_id", config.eos_token_id) return config @@ -192,6 +197,4 @@ def trace( # We have the config in the trace_model_path config.save_pretrained(traced_model_path) - model = self.get_model_cls().from_pretrained(self.model_path, config) - - model.compile(serialize_base_path=traced_model_path) + self.get_model_cls().export(self.model_path, config, serialize_base_path=traced_model_path)