From 7b13ab6eee6aa7b32d35473bc9b72694f49bd8e2 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Mon, 20 Nov 2023 15:09:15 +0100 Subject: [PATCH] Fix --- docs/source/_toctree.yml | 4 ++++ docs/source/guides/distributed_training.mdx | 2 -- docs/source/package_reference/distributed.mdx | 2 +- .../distributed/parallelizers_manager.py | 18 ++++++++++++++++++ 4 files changed, 23 insertions(+), 3 deletions(-) diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 431b11138..aec2b334a 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -26,6 +26,8 @@ title: Fine-tune Transformers with AWS Trainium - local: guides/export_model title: Export a model to Inferentia + - local: guides/distributed_training + title: Distributed Training - local: guides/models title: Neuron models for inference - local: guides/pipelines @@ -34,6 +36,8 @@ - sections: - local: package_reference/trainer title: Neuron Trainer + - local: package_reference/distributed + title: Neuron Distributed - local: package_reference/export title: Neuron Exporter - local: package_reference/modeling diff --git a/docs/source/guides/distributed_training.mdx b/docs/source/guides/distributed_training.mdx index 14e2bab1d..9d643780c 100644 --- a/docs/source/guides/distributed_training.mdx +++ b/docs/source/guides/distributed_training.mdx @@ -9,8 +9,6 @@ Unless required by applicable law or agreed to in writing, software distributed an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --> - - # Distributed Training with `optimum-neuron` [AWS Trainium instances](https://aws.amazon.com/machine-learning/trainium/) are great to train models. They can contain up to 16 Neuron devices, each device containing 2 Neuron cores and has 32GB of memory (16GB per core). 
For example a `trn1.32xlarge` instance has 32 x 16 = 512GB of memory, that is a lot of memory to train models! diff --git a/docs/source/package_reference/distributed.mdx b/docs/source/package_reference/distributed.mdx index 85ef819f0..4e870bcf2 100644 --- a/docs/source/package_reference/distributed.mdx +++ b/docs/source/package_reference/distributed.mdx @@ -19,7 +19,7 @@ The `optimum.neuron.distributed` module provides a set of tools to perform distr The [`~optimum.neuron.distributed.Parallelizer`] class is the abstract base class being derived for every model supporting model parallelism. It provides methods to parallelize the model and save and load sharded checkpoints. -[[autodoc]] distributed.base.Parallelizer +[[autodoc]] distributed.Parallelizer - _parallelize - parallelize - optimizer_for_tp diff --git a/optimum/neuron/distributed/parallelizers_manager.py b/optimum/neuron/distributed/parallelizers_manager.py index bf3524ce0..ac2dff7bb 100644 --- a/optimum/neuron/distributed/parallelizers_manager.py +++ b/optimum/neuron/distributed/parallelizers_manager.py @@ -62,6 +62,9 @@ class ParallelizersManager: @classmethod def get_supported_model_types(cls) -> List[str]: + """ + Provides the list of supported model types for parallelization. + """ return list(cls._MODEL_TYPE_TO_PARALLEL_MODEL_CLASS.keys()) @classmethod @@ -74,11 +77,26 @@ def _get_model_type(cls, model_type_or_model: Union[str, PreTrainedModel]) -> st @classmethod def is_model_supported(cls, model_type_or_model: Union[str, PreTrainedModel]) -> bool: + """ + Returns `True` if the model can be parallelized, `False` otherwise. + + Args: + model_type_or_model (`Union[str, PreTrainedModel]`): + Either the model type or an instance of the model. 
+ """ model_type = cls._get_model_type(model_type_or_model) return model_type in cls._MODEL_TYPE_TO_PARALLEL_MODEL_CLASS @classmethod def parallelizer_for_model(cls, model_type_or_model: Union[str, PreTrainedModel]) -> Type[Parallelizer]: + """ + Returns the parallelizer class associated with the model. + + Args: + model_type_or_model (`Union[str, PreTrainedModel]`): + Either the model type or an instance of the model. + + """ model_type = cls._get_model_type(model_type_or_model) if not cls.is_model_supported(model_type_or_model): supported_models = ", ".join(cls._MODEL_TYPE_TO_PARALLEL_MODEL_CLASS.keys())