refactor(nxd): replace compile by export class method
dacorvo committed Oct 22, 2024
1 parent e4d9abe commit d3e7682
Showing 2 changed files with 30 additions and 31 deletions.
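
The net effect of the change: compilation no longer requires instantiating the model first. A minimal before/after sketch of the call site, assuming a hypothetical subclass `MyNeuronModel`, illustrative paths, and a `config` (a `NeuronInferenceConfig`) built elsewhere — only the `compile` and `export` signatures come from this diff:

```python
# Before: compile() was an instance method, so a model object had to be
# constructed before any compiled artifacts could be produced.
model = MyNeuronModel.from_pretrained("/path/to/checkpoint", config)
model.compile(serialize_base_path="/tmp/artifacts")

# After: export() is a classmethod taking the checkpoint path and config
# directly; no instance is created up front.
MyNeuronModel.export("/path/to/checkpoint", config, serialize_base_path="/tmp/artifacts")
```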
25 changes: 10 additions & 15 deletions examples/nxd/modules/model_base.py
```diff
@@ -8,6 +8,7 @@
 import torch
 from modules.autobucketing import slice_lhs, slice_rhs  # noqa: E402
 from modules.checkpoint import load_state_dict
+from modules.config import NeuronInferenceConfig
 from modules.gqa import (  # noqa: E402
     determine_sharding_strategy,  # noqa: E402
     get_shardable_head_counts,  # noqa: E402
@@ -382,22 +383,24 @@ def can_generate(self):
         # Not needed after transformers 4.50
         return True
 
-    def get_compiler_args(self):
+    @staticmethod
+    def get_compiler_args():
         return "--enable-saturate-infinity --auto-cast=none --model-type=transformer --tensorizer-options='--enable-ccop-compute-overlap --cc-pipeline-tiling-factor=2' -O1 "
 
     @classmethod
     def from_pretrained(cls, model_path: str, config: PretrainedConfig):
         return cls(model_path, config)
 
-    def compile(self, serialize_base_path=None):
+    @classmethod
+    def export(cls, model_path: Union[str, Path], config: NeuronInferenceConfig, serialize_base_path=None):
 
         base_compile_work_dir = os.environ.get("BASE_COMPILE_WORK_DIR", "/tmp/nxd_model/")
 
-        checkpoint_loader = CheckPointLoader(self.model_path, self._STATE_DICT_MODEL_PREFIX, self.config.torch_dtype)
+        checkpoint_loader = CheckPointLoader(model_path, cls._STATE_DICT_MODEL_PREFIX, config.torch_dtype)
 
         builder = ModelBuilder(
             router=None,
-            tp_degree=self.config.tp_degree,
+            tp_degree=config.tp_degree,
             checkpoint_loader=checkpoint_loader.load_checkpoint,
             compiler_workdir=base_compile_work_dir,
         )
@@ -410,16 +413,8 @@ def compile(self, serialize_base_path=None):
         # For LLM models, we typically use different sets of SPMDBucketModel for encoding and
         # token generation, each with its own list of buckets.
         exporters = [
-            ContextEncodingModelExporter(
-                self._model_cls,
-                self.config,
-                buckets=self.context_encoding_model.buckets,
-            ),
-            TokenGenerationModelExporter(
-                self._model_cls,
-                self.config,
-                buckets=self.token_generation_model.buckets,
-            ),
+            ContextEncodingModelExporter(cls._model_cls, config),
+            TokenGenerationModelExporter(cls._model_cls, config),
         ]
         for exporter in exporters:
             # We need a pickable object to provide the callbacks required by the Builder
@@ -428,7 +423,7 @@
                 model_instance=exporter.get_model_instance(),
                 example_inputs=exporter.input_generator(),
                 bucket_config=exporter.bucket_config(),
-                compiler_args=self.get_compiler_args(),
+                compiler_args=cls.get_compiler_args(),
                 priority_model_idx=None,
             )
```

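The comment in `export` about distinct SPMDBucketModel sets is the motivation for the two exporters: context-encoding buckets are bounded by `config.max_context_length`, token-generation buckets by `config.max_length` (see the second file below). A sketch of the resulting lists, where `generate_buckets_stub` and its doubling pattern are assumptions standing in for the real `generate_buckets` helper:

```python
# Hypothetical stand-in for generate_buckets: doubling sizes from
# min_size up to and including max_size. The real helper may differ.
def generate_buckets_stub(min_size: int, max_size: int) -> list[int]:
    buckets, size = [], min_size
    while size < max_size:
        buckets.append(size)
        size *= 2
    buckets.append(max_size)
    return buckets

# Illustrative config values, not taken from the repository.
max_context_length, max_length = 1024, 2048

context_buckets = generate_buckets_stub(128, max_context_length)
token_gen_buckets = generate_buckets_stub(128, max_length)
print(context_buckets)    # [128, 256, 512, 1024]
print(token_gen_buckets)  # [128, 256, 512, 1024, 2048]
```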
36 changes: 20 additions & 16 deletions examples/nxd/modules/model_wrapper.py
```diff
@@ -34,12 +34,6 @@ def __init__(self, config, model_cls, tag="", max_input_tokens: int = 128, max_t
         self.max_input_tokens = max_input_tokens
         self.max_total_tokens = max_total_tokens
 
-    @property
-    def buckets(self):
-        if self.enable_bucketing:
-            return generate_buckets(128, self.max_total_tokens)
-        return [self.max_total_tokens]
-
     def _forward_with_pad(self, *args):
         seq_ids = args[3]
@@ -195,8 +189,8 @@ class ModelExporter(ABC):
     tag: str
     model_cls: type
     config: Any
-    max_input_tokens: int = 1
-    buckets: Tuple[int] = ()
+    buckets: Tuple[int]
+    max_input_tokens: int
 
     def input_generator(self):
         inputs = []
@@ -223,19 +217,24 @@ def bucket_config(self):
 
 class ContextEncodingModelExporter(ModelExporter):
 
-    def __init__(self, model_cls, config, buckets: Tuple[int]):
+    def __init__(self, model_cls, config):
+        if config.enable_bucketing:
+            buckets = generate_buckets(128, config.max_context_length)
+        else:
+            buckets = [config.max_context_length]
+        print(buckets)
         super().__init__(
             tag=CONTEXT_ENCODING_MODEL_TAG,
             model_cls=model_cls,
             config=config,
-            max_input_tokens=config.max_context_length,
             buckets=buckets,
+            max_input_tokens=config.max_context_length,
         )
 
     def bucket_config(self):
-        bucket_degree = len(self.buckets)
-        if bucket_degree == 1:
+        if not self.config.enable_bucketing:
             return None
+        bucket_degree = len(self.buckets)
         return BucketModelConfig(
             bucket_kernel=get_context_encoder_bk,
             bucket_kernel_constant_args=(
@@ -250,19 +249,24 @@ def bucket_config(self):
 
 class TokenGenerationModelExporter(ModelExporter):
 
-    def __init__(self, model_cls, config, buckets: Tuple[int]):
+    def __init__(self, model_cls, config):
+        if config.enable_bucketing:
+            buckets = generate_buckets(128, config.max_length)
+        else:
+            buckets = [config.max_length]
+        print(buckets)
         super().__init__(
             tag=TOKEN_GENERATION_MODEL_TAG,
             model_cls=model_cls,
             config=config,
-            max_input_tokens=1,
             buckets=buckets,
+            max_input_tokens=1,
         )
 
     def bucket_config(self):
-        bucket_degree = len(self.buckets)
-        if bucket_degree == 1:
+        if not self.config.enable_bucketing:
             return None
+        bucket_degree = len(self.buckets)
         return BucketModelConfig(
             bucket_kernel=get_token_generation_bk,
             bucket_kernel_constant_args=(torch.tensor(self.buckets), self.config.padding_side),
```
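One behavioral nuance in both `bucket_config` rewrites above: the old guard returned `None` whenever the bucket list had exactly one entry, while the new guard keys off `config.enable_bucketing`, so an enabled config whose list happens to contain a single bucket now still yields a `BucketModelConfig`. A plain-Python sketch of the two predicates (names are illustrative, not from the repository):

```python
from typing import Sequence

def old_skips_bucketing(buckets: Sequence[int]) -> bool:
    # Before: a single-bucket list short-circuited to None,
    # even when bucketing had been requested.
    return len(buckets) == 1

def new_skips_bucketing(enable_bucketing: bool) -> bool:
    # After: only an explicitly disabled config skips the
    # BucketModelConfig; the bucket count no longer matters.
    return not enable_bucketing
```

Relatedly, `ModelExporter` now declares `buckets` and `max_input_tokens` without default values; assuming it is a dataclass-style container, every subclass must now pass both explicitly, which matches the reworked `super().__init__` calls above.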
