From b7ffce7ff30745320a681f59fe7a522bb4492c2e Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Tue, 22 Oct 2024 16:01:24 +0000 Subject: [PATCH] wip --- examples/nxd/modules/model_base.py | 19 ++++++++----------- examples/nxd/modules/model_wrapper.py | 17 ++--------------- 2 files changed, 10 insertions(+), 26 deletions(-) diff --git a/examples/nxd/modules/model_base.py b/examples/nxd/modules/model_base.py index ada42b32e..47af0565c 100644 --- a/examples/nxd/modules/model_base.py +++ b/examples/nxd/modules/model_base.py @@ -352,7 +352,7 @@ class NeuronBaseForCausalLM(GenerationMixin): _supports_cache_class = False # _supports_static_cache = False - def __init__(self, model_path: str, config: PretrainedConfig): + def __init__(self, config: PretrainedConfig, model: torch.jit.ScriptModule): super().__init__() self.config = config @@ -363,20 +363,15 @@ def __init__(self, model_path: str, config: PretrainedConfig): self.sampler = None - self.model_path = model_path self.context_encoding_model = ModelWrapper( config=self.config, - model_cls=self._model_cls, + model=model, tag=CONTEXT_ENCODING_MODEL_TAG, - max_input_tokens=self.config.max_context_length, - max_total_tokens=self.config.max_context_length, ) self.token_generation_model = ModelWrapper( config=self.config, - model_cls=self._model_cls, + model=model, tag=TOKEN_GENERATION_MODEL_TAG, - max_input_tokens=1, - max_total_tokens=self.config.max_length, ) def can_generate(self): @@ -437,7 +432,10 @@ def export(cls, model_path: Union[str, Path], config: NeuronInferenceConfig, ser def get_traced_model_path(base_path: Union[str, Path]): return os.path.join(base_path, "model.pt") - def load(self, serialize_base_path): + @classmethod + def load(cls, serialize_base_path): + + config = cls._config_cls.from_pretrained(serialize_base_path) traced_model = torch.jit.load(NeuronBaseForCausalLM.get_traced_model_path(serialize_base_path)) @@ -454,8 +452,7 @@ def get_rank(shard): traced_model.nxd_model.initialize(weights) - self.context_encoding_model.model = traced_model - self.token_generation_model.model = traced_model + return cls(config, traced_model) @property def device(self) -> torch.device: diff --git a/examples/nxd/modules/model_wrapper.py b/examples/nxd/modules/model_wrapper.py index ae1024d7d..1f0f6fcdd 100644 --- a/examples/nxd/modules/model_wrapper.py +++ b/examples/nxd/modules/model_wrapper.py @@ -15,24 +15,11 @@ class ModelWrapper(torch.nn.Module): - def __init__(self, config, model_cls, tag="", max_input_tokens: int = 128, max_total_tokens: int = 128) -> None: + def __init__(self, config, model, tag="") -> None: super().__init__() self.config = config - - if not self.config.torch_dtype: - self.config.torch_dtype = torch.float32 - - if self.config.pad_token_id is None: - self.config.pad_token_id = 0 - - self.model_cls = model_cls - self.model = None - self.is_compiled = False - self.serialize_base_path = None + self.model = model self.tag = tag - self.enable_bucketing = config.enable_bucketing - self.max_input_tokens = max_input_tokens - self.max_total_tokens = max_total_tokens def _forward_with_pad(self, *args): seq_ids = args[3]