refactor(nxd): make load a class method
dacorvo committed Oct 22, 2024
1 parent d3e7682 commit 2ba3432
Showing 5 changed files with 21 additions and 46 deletions.

examples/nxd/llama2/llama2_runner.py (0 additions, 14 deletions)

@@ -1,4 +1,3 @@
-import torch
 from runner import InferenceRunner
 
 from llama2.neuron_modeling_llama import (
@@ -9,19 +8,6 @@
 
 class LlamaRunner(InferenceRunner):
 
-    def load_neuron_model(self, traced_model_path):
-        config = NeuronLlamaConfig.from_pretrained(traced_model_path)
-        model = NeuronLlamaForCausalLM.from_pretrained("", config)
-        self.config = config
-
-        model.load(traced_model_path)
-        if config.torch_dtype == torch.bfloat16:
-            model.context_encoding_model.bfloat16()
-            if model.token_generation_model is not None:
-                model.token_generation_model.bfloat16()
-
-        return model
-
     def get_config_cls(self):
         return NeuronLlamaConfig
 

examples/nxd/llama2/neuron_modeling_llama.py (1 addition, 0 deletions)

@@ -342,3 +342,4 @@ class NeuronLlamaForCausalLM(NeuronBaseForCausalLM):
     """
 
     _model_cls = NeuronLlamaModel
+    _config_cls = NeuronLlamaConfig
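
The new _config_cls attribute mirrors the existing _model_cls hook: each subclass pins its config type once, and the shared load classmethod added in model_base.py below can rebuild the config polymorphically. The following is a minimal runnable sketch of that pattern using toy stand-in classes, not the repo's actual implementations:

# Toy illustration of the class-attribute hook pattern; these classes
# are simplified stand-ins for the ones changed in this commit.
class BaseConfig:
    @classmethod
    def from_pretrained(cls, path):
        # Stand-in for deserializing a config from `path`.
        return cls()

class BaseForCausalLM:
    _config_cls = None  # pinned by each concrete subclass

    def __init__(self, config, model):
        self.config = config
        self.model = model

    @classmethod
    def load(cls, path):
        # cls resolves to the concrete subclass, so the matching config
        # type is chosen without instantiating the model first.
        config = cls._config_cls.from_pretrained(path)
        model = object()  # stand-in for torch.jit.load(...)
        return cls(config, model)

class LlamaConfig(BaseConfig):
    pass

class LlamaForCausalLM(BaseForCausalLM):
    _config_cls = LlamaConfig

m = LlamaForCausalLM.load("some/path")
assert isinstance(m.config, LlamaConfig)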

examples/nxd/modules/model_base.py (9 additions, 11 deletions)

@@ -346,13 +346,14 @@ class NeuronBaseForCausalLM(GenerationMixin):
     _STATE_DICT_MODEL_PREFIX = "model."
 
     _model_cls = None
+    _config_cls = None
 
     # Required by GenerationMixin, but present in PreTrainedModel
     main_input_name = "input_ids"
     _supports_cache_class = False
     # _supports_static_cache = False
 
-    def __init__(self, model_path: str, config: PretrainedConfig):
+    def __init__(self, config: PretrainedConfig, model: torch.jit.ScriptModule):
         super().__init__()
 
         self.config = config
@@ -363,20 +364,15 @@ def __init__(self, model_path: str, config: PretrainedConfig):
 
         self.sampler = None
 
-        self.model_path = model_path
         self.context_encoding_model = ModelWrapper(
             config=self.config,
-            model_cls=self._model_cls,
+            model=model,
             tag=CONTEXT_ENCODING_MODEL_TAG,
-            max_input_tokens=self.config.max_context_length,
-            max_total_tokens=self.config.max_context_length,
         )
         self.token_generation_model = ModelWrapper(
             config=self.config,
-            model_cls=self._model_cls,
+            model=model,
             tag=TOKEN_GENERATION_MODEL_TAG,
-            max_input_tokens=1,
-            max_total_tokens=self.config.max_length,
         )
 
     def can_generate(self):
@@ -437,7 +433,10 @@ def export(cls, model_path: Union[str, Path], config: NeuronInferenceConfig, ser
     def get_traced_model_path(base_path: Union[str, Path]):
         return os.path.join(base_path, "model.pt")
 
-    def load(self, serialize_base_path):
+    @classmethod
+    def load(cls, serialize_base_path):
+
+        config = cls._config_cls.from_pretrained(serialize_base_path)
 
         traced_model = torch.jit.load(NeuronBaseForCausalLM.get_traced_model_path(serialize_base_path))
 
@@ -454,8 +453,7 @@ def get_rank(shard):
 
         traced_model.nxd_model.initialize(weights)
 
-        self.context_encoding_model.model = traced_model
-        self.token_generation_model.model = traced_model
+        return cls(config, traced_model)
 
     @property
     def device(self) -> torch.device:
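
The effect at call sites is easiest to see side by side. The "before" lines are taken from the load_neuron_model method deleted from llama2_runner.py above; the "after" line is the new classmethod:

# Before: two-phase construction, then in-place mutation of the wrappers.
config = NeuronLlamaConfig.from_pretrained(traced_model_path)
model = NeuronLlamaForCausalLM.from_pretrained("", config)
model.load(traced_model_path)

# After: a single polymorphic constructor call returns a ready-to-use instance.
model = NeuronLlamaForCausalLM.load(traced_model_path)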

examples/nxd/modules/model_wrapper.py (2 additions, 15 deletions)

@@ -15,24 +15,11 @@
 
 
 class ModelWrapper(torch.nn.Module):
-    def __init__(self, config, model_cls, tag="", max_input_tokens: int = 128, max_total_tokens: int = 128) -> None:
+    def __init__(self, config, model, tag="") -> None:
         super().__init__()
         self.config = config
-
-        if not self.config.torch_dtype:
-            self.config.torch_dtype = torch.float32
-
-        if self.config.pad_token_id is None:
-            self.config.pad_token_id = 0
-
-        self.model_cls = model_cls
-        self.model = None
-        self.is_compiled = False
-        self.serialize_base_path = None
+        self.model = model
         self.tag = tag
-        self.enable_bucketing = config.enable_bucketing
-        self.max_input_tokens = max_input_tokens
-        self.max_total_tokens = max_total_tokens
 
     def _forward_with_pad(self, *args):
         seq_ids = args[3]

examples/nxd/runner.py (9 additions, 6 deletions)

@@ -42,8 +42,13 @@ def __init__(self, model_path: str = None, generation_config: GenerationConfig =
         self.generation_config = generation_config
 
     def load_neuron_model(self, traced_model_path):
-        # Implement per model
-        raise NotImplementedError
+        model = self.get_model_cls().load(traced_model_path)
+
+        if model.config.torch_dtype == torch.bfloat16:
+            model.context_encoding_model.bfloat16()
+            model.token_generation_model.bfloat16()
+
+        return model
 
     def get_config_cls(self):
         # Implement per model
@@ -109,7 +114,7 @@ def get_config_for_nxd(
 
         config.trace_tokengen_model = kwargs.get("trace_tokengen_model", True)
 
-        config.pad_token_id = kwargs.get("pad_token_id", None)
+        config.pad_token_id = kwargs.get("pad_token_id", config.eos_token_id)
 
         return config
 
@@ -192,6 +197,4 @@ def trace(
         # We have the config in the trace_model_path
         config.save_pretrained(traced_model_path)
 
-        model = self.get_model_cls().from_pretrained(self.model_path, config)
-
-        model.compile(serialize_base_path=traced_model_path)
+        self.get_model_cls().export(self.model_path, config, serialize_base_path=traced_model_path)
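
With loading hoisted into the base classes, a per-model runner no longer overrides load_neuron_model; it only declares which classes to use. A sketch of the minimal subclass this leaves (get_config_cls comes from llama2_runner.py above; the get_model_cls override is an assumption inferred from its use in load_neuron_model):

# Hypothetical minimal runner subclass after this commit; loading and
# dtype casting are inherited from InferenceRunner.load_neuron_model.
class LlamaRunner(InferenceRunner):
    def get_config_cls(self):
        return NeuronLlamaConfig

    def get_model_cls(self):
        return NeuronLlamaForCausalLM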
