Skip to content

Commit

Permalink
add args to command
Browse the repository at this point in the history
  • Loading branch information
JingyaHuang committed Sep 19, 2024
1 parent f8b5259 commit 9cf51f2
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
6 changes: 6 additions & 0 deletions optimum/commands/export/neuronx.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,12 @@ def parse_args_neuronx(parser: "ArgumentParser"):
choices=["bf16", "fp16", "tf32"],
help='The data type to cast FP32 operations to when auto-cast mode is enabled. Can be `"bf16"`, `"fp16"` or `"tf32"`.',
)
optional_group.add_argument(
"--tensor_parallel_size",
type=int,
default=1,
help="Tensor parallelism degree, the number of devices on which to shard the model.",
)
optional_group.add_argument(
"--dynamic-batch-size",
action="store_true",
Expand Down
6 changes: 6 additions & 0 deletions optimum/exporters/neuron/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ def infer_stable_diffusion_shapes_from_diffusers(
def get_submodels_and_neuron_configs(
model: Union["PreTrainedModel", "DiffusionPipeline"],
input_shapes: Dict[str, int],
tensor_parallel_size: int,
task: str,
output: Path,
library_name: Optional[str] = None,
Expand Down Expand Up @@ -464,6 +465,7 @@ def load_models_and_neuron_configs(
model_name_or_path: str,
output: Path,
model: Optional[Union["PreTrainedModel", "ModelMixin"]],
tensor_parallel_size: int,
task: str,
dynamic_batch_size: bool,
cache_dir: Optional[str],
Expand Down Expand Up @@ -507,6 +509,7 @@ def load_models_and_neuron_configs(
models_and_neuron_configs, output_model_names = get_submodels_and_neuron_configs(
model=model,
input_shapes=input_shapes,
tensor_parallel_size=tensor_parallel_size,
task=task,
library_name=library_name,
output=output,
Expand All @@ -530,6 +533,7 @@ def main_export(
model_name_or_path: str,
output: Union[str, Path],
compiler_kwargs: Dict[str, Any],
tensor_parallel_size: int,
model: Optional[Union["PreTrainedModel", "ModelMixin"]] = None,
task: str = "auto",
dynamic_batch_size: bool = False,
Expand Down Expand Up @@ -567,6 +571,7 @@ def main_export(
model_name_or_path=model_name_or_path,
output=output,
model=model,
tensor_parallel_size=tensor_parallel_size,
task=task,
dynamic_batch_size=dynamic_batch_size,
cache_dir=cache_dir,
Expand Down Expand Up @@ -710,6 +715,7 @@ def main():
model_name_or_path=args.model,
output=args.output,
compiler_kwargs=compiler_kwargs,
tensor_parallel_size=args.tensor_parallel_size,
task=task,
dynamic_batch_size=args.dynamic_batch_size,
atol=args.atol,
Expand Down

0 comments on commit 9cf51f2

Please sign in to comment.