Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
dacorvo committed Oct 22, 2024
1 parent d3e7682 commit b7ffce7
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 26 deletions.
19 changes: 8 additions & 11 deletions examples/nxd/modules/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ class NeuronBaseForCausalLM(GenerationMixin):
_supports_cache_class = False
# _supports_static_cache = False

def __init__(self, model_path: str, config: PretrainedConfig):
def __init__(self, config: PretrainedConfig, model: torch.jit.ScriptModule):
super().__init__()

self.config = config
Expand All @@ -363,20 +363,15 @@ def __init__(self, model_path: str, config: PretrainedConfig):

self.sampler = None

self.model_path = model_path
self.context_encoding_model = ModelWrapper(
config=self.config,
model_cls=self._model_cls,
model=model,
tag=CONTEXT_ENCODING_MODEL_TAG,
max_input_tokens=self.config.max_context_length,
max_total_tokens=self.config.max_context_length,
)
self.token_generation_model = ModelWrapper(
config=self.config,
model_cls=self._model_cls,
model=model,
tag=TOKEN_GENERATION_MODEL_TAG,
max_input_tokens=1,
max_total_tokens=self.config.max_length,
)

def can_generate(self):
Expand Down Expand Up @@ -437,7 +432,10 @@ def export(cls, model_path: Union[str, Path], config: NeuronInferenceConfig, ser
def get_traced_model_path(base_path: Union[str, Path]):
    """Return the location of the serialized traced model under *base_path*.

    The traced model is always stored as a file named ``model.pt`` directly
    inside the given base directory; the result is returned as a string path.
    """
    return str(Path(base_path) / "model.pt")

def load(self, serialize_base_path):
@classmethod
def load(cls, serialize_base_path):

config = cls._config_cls.from_pretrained(serialize_base_path)

traced_model = torch.jit.load(NeuronBaseForCausalLM.get_traced_model_path(serialize_base_path))

Expand All @@ -454,8 +452,7 @@ def get_rank(shard):

traced_model.nxd_model.initialize(weights)

self.context_encoding_model.model = traced_model
self.token_generation_model.model = traced_model
return cls(config, traced_model)

@property
def device(self) -> torch.device:
Expand Down
17 changes: 2 additions & 15 deletions examples/nxd/modules/model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,11 @@


class ModelWrapper(torch.nn.Module):
def __init__(self, config, model_cls, tag="", max_input_tokens: int = 128, max_total_tokens: int = 128) -> None:
def __init__(self, config, model, tag="") -> None:
super().__init__()
self.config = config

if not self.config.torch_dtype:
self.config.torch_dtype = torch.float32

if self.config.pad_token_id is None:
self.config.pad_token_id = 0

self.model_cls = model_cls
self.model = None
self.is_compiled = False
self.serialize_base_path = None
self.model = model
self.tag = tag
self.enable_bucketing = config.enable_bucketing
self.max_input_tokens = max_input_tokens
self.max_total_tokens = max_total_tokens

def _forward_with_pad(self, *args):
seq_ids = args[3]
Expand Down

0 comments on commit b7ffce7

Please sign in to comment.