Commit

Final cleanup FSDP
michaelbenayoun committed Apr 8, 2024
1 parent 353a38d commit 39c8372
Showing 6 changed files with 9 additions and 28 deletions.
optimum/neuron/accelerate/accelerator.py (8 changes: 1 addition & 7 deletions)
@@ -16,7 +16,6 @@

import collections
import contextlib
-import inspect
import os
import re
import shutil
@@ -54,7 +53,6 @@
AutocastBackend,
ModelParallelismPlugin,
NeuronDistributedType,
-NeuronFullyShardedDataParallelPlugin,
get_tied_parameters_dict,
patch_accelerate_is_tpu_available,
tie_parameters,
@@ -420,11 +418,7 @@ def prepare_model(
model.config.output_attentions = False
model.config.output_hidden_states = False

-if self.distributed_type is NeuronDistributedType.XLA_FSDP:
-return self.prepare_model_for_xla_fsdp(
-model, device_placement=device_placement, evaluation_mode=evaluation_mode
-)
-elif self.distributed_type is NeuronDistributedType.MODEL_PARALLELISM:
+if self.distributed_type is NeuronDistributedType.MODEL_PARALLELISM:
return self._prepare_model_for_mp(
model, device_placement=device_placement, evaluation_mode=evaluation_mode
)
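
Note: with the XLA_FSDP branch gone, prepare_model only special-cases model parallelism and everything else falls through to the generic path. A minimal sketch of the remaining control flow; the enum and function below are illustrative stand-ins, not the real NeuronAccelerator API.

# Schematic of the dispatch left in prepare_model after this cleanup.
# The names are stand-ins for illustration only.
import enum


class DistributedTypeStub(enum.Enum):
    MODEL_PARALLELISM = "model_parallelism"
    XLA = "xla"


def prepare_model_sketch(distributed_type: DistributedTypeStub, model: str) -> str:
    # Only model parallelism is special-cased now; there is no FSDP branch anymore.
    if distributed_type is DistributedTypeStub.MODEL_PARALLELISM:
        return f"prepared for model parallelism: {model}"
    return f"prepared by the default path: {model}"


print(prepare_model_sketch(DistributedTypeStub.MODEL_PARALLELISM, "my-model"))
print(prepare_model_sketch(DistributedTypeStub.XLA, "my-model"))
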
optimum/neuron/accelerate/state.py (6 changes: 3 additions & 3 deletions)
@@ -32,7 +32,7 @@
parse_choice_from_env,
parse_flag_from_env,
)
-from accelerate.utils.dataclasses import FullyShardedDataParallelPlugin, SageMakerDistributedType
+from accelerate.utils.dataclasses import SageMakerDistributedType

from ...utils import logging
from ..utils import is_neuronx_distributed_available, is_torch_xla_available
@@ -41,7 +41,7 @@
set_common_neuron_cc_flags,
set_neuron_cc_flags_for_torch_amp,
)
-from .utils import NeuronDistributedType, NeuronFullyShardedDataParallelPlugin
+from .utils import NeuronDistributedType
from .utils.dataclasses import AutocastBackend, ModelParallelismPlugin


@@ -201,7 +201,7 @@ def __init__(self, cpu: bool = False, **kwargs):
self.fork_launched = parse_flag_from_env("FORK_LAUNCHED", 0)

def wait_for_everyone(self):
if self.distributed_type is NeuronDistributedType.MODEL_PARALLELISM:
if self.distributed_type is NeuronDistributedType.MODEL_PARALLELISM:
xm.rendezvous("accelerate.utils.wait_for_everyone")
else:
super().wait_for_everyone()
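
Note: wait_for_everyone now only routes model-parallel runs through an XLA rendezvous (xm.rendezvous is torch_xla's named barrier) and defers to the parent Accelerate state otherwise. A small sketch of that routing; the callables stand in for the real barrier primitives.

# Sketch of the barrier routing kept in wait_for_everyone. `xla_rendezvous` stands in for
# torch_xla.core.xla_model.rendezvous and `parent_barrier` for super().wait_for_everyone().
from typing import Callable


def wait_for_everyone_sketch(
    model_parallelism: bool,
    xla_rendezvous: Callable[[str], None],
    parent_barrier: Callable[[], None],
) -> None:
    if model_parallelism:
        # Every worker blocks on the same tag until all workers have reached it.
        xla_rendezvous("accelerate.utils.wait_for_everyone")
    else:
        parent_barrier()


wait_for_everyone_sketch(True, lambda tag: print(f"rendezvous: {tag}"), lambda: print("parent barrier"))
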
optimum/neuron/accelerate/utils/dataclasses.py (7 changes: 1 addition & 6 deletions)
@@ -15,23 +15,18 @@
"""Custom dataclasses for Neuron."""

import enum
import os
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union

import torch
from accelerate.utils.constants import MODEL_NAME, OPTIMIZER_NAME
-from accelerate.utils.dataclasses import FullyShardedDataParallelPlugin

from ...distributed import ParallelizersManager
from ...utils import is_torch_xla_available


if is_torch_xla_available():
-import torch_xla.core.xla_model as xm
-from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
-from torch_xla.distributed.fsdp.state_dict_utils import consolidate_sharded_model_checkpoints
+pass

if TYPE_CHECKING:
from transformers import PreTrainedModel
optimum/neuron/trainers.py (5 changes: 1 addition & 4 deletions)
@@ -15,7 +15,6 @@
"""Defines Trainer subclasses to perform training on AWS Neuron instances."""

import copy
-import glob
import math
import os
import random
@@ -27,7 +26,6 @@

import numpy as np
import torch
-import transformers
from accelerate import __version__ as accelerate_version
from accelerate.utils import AutocastKwargs
from packaging import version
@@ -65,7 +63,7 @@
from transformers.training_args import ParallelMode
from transformers.utils import WEIGHTS_NAME, is_apex_available, is_sagemaker_mp_enabled

-from ..utils import check_if_transformers_greater, logging
+from ..utils import logging
from .accelerate import NeuronAccelerator, NeuronDistributedType
from .distributed import Parallelizer, ParallelizersManager
from .distributed.utils import make_optimizer_constructor_lazy
@@ -87,7 +85,6 @@
from .utils.patching import patch_everywhere
from .utils.require_utils import requires_neuronx_distributed, requires_torch_neuronx
from .utils.training_utils import (
-TRANSFORMERS_MIN_VERSION_USE_ACCELERATE,
get_model_param_count,
is_main_worker_for_metrics,
is_main_worker_for_metrics_method,
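
Note: dropping the module-level transformers import and TRANSFORMERS_MIN_VERSION_USE_ACCELERATE removes the last transformers-version gate from the trainer. For reference, a gate of that shape looks like the generic sketch below; it is not the exact deleted code, only the threshold value matches the constant removed from training_utils.py later in this commit.

# Generic minimum-version gate of the kind the removed constant supported.
import transformers
from packaging import version

MIN_TRANSFORMERS_VERSION = "4.30.0.dev0"


def transformers_is_recent_enough() -> bool:
    return version.parse(transformers.__version__) >= version.parse(MIN_TRANSFORMERS_VERSION)


print(transformers_is_recent_enough())
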
optimum/neuron/training_args.py (2 changes: 1 addition & 1 deletion)
@@ -120,7 +120,7 @@ def __post_init__(self):
patch_accelerate_is_tpu_available()

if self.fsdp != "":
raise RuntimeError("FSDP is not supported yet.")
raise RuntimeError("FSDP is not supported.")

if self.fp16:
raise ValueError("The fp16 data type is not supported in Neuron, please use bf16 instead.")
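
Note: this is the user-facing part of the cleanup: any fsdp value now fails fast in __post_init__ instead of being deferred as "not supported yet". A usage sketch, assuming the class defined in this file is the NeuronTrainingArguments exported by optimum.neuron and that it keeps the standard fsdp field from transformers.TrainingArguments.

# Illustration of the fast failure when FSDP is requested through the Neuron training arguments.
# NeuronTrainingArguments is assumed to be the dataclass defined in optimum/neuron/training_args.py.
from optimum.neuron import NeuronTrainingArguments

try:
    NeuronTrainingArguments(output_dir="out", fsdp="full_shard")
except RuntimeError as err:
    print(err)  # "FSDP is not supported."
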
optimum/neuron/utils/training_utils.py (9 changes: 2 additions & 7 deletions)
@@ -20,8 +20,6 @@
import torch
import transformers
from accelerate import skip_first_batches as accelerate_skip_first_batches
-from torch.utils._pytree import tree_map
-from torch.utils.data import DataLoader, Dataset, IterableDataset
from transformers import GenerationMixin
from transformers.models.auto.modeling_auto import (
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
@@ -47,7 +45,7 @@

from ...utils.logging import set_verbosity as set_verbosity_optimum
from ..generation import GeneralNeuronGenerationMixin, NeuronGenerationMixin
-from . import is_neuronx_distributed_available, is_torch_xla_available
+from . import is_neuronx_distributed_available
from .require_utils import requires_neuronx_distributed, requires_safetensors, requires_torch_xla


@@ -59,10 +57,6 @@
from transformers import PreTrainedModel


-TRANSFORMERS_MIN_VERSION_FOR_XLA_FSDP = "4.30.0.dev0"
-TRANSFORMERS_MIN_VERSION_USE_ACCELERATE = "4.30.0.dev0"


def _generate_supported_model_class_names(
model_type: str,
supported_tasks: Optional[Union[str, List[str]]] = None,
@@ -129,6 +123,7 @@ def _generate_supported_model_class_names(
def is_precompilation() -> bool:
return os.environ.get("NEURON_PARALLEL_COMPILE") == "1"


def is_model_officially_supported(model: "PreTrainedModel") -> bool:
class_name = model.__class__.__name__
return class_name in _SUPPORTED_MODEL_NAMES
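
Note: the only behavioural code shown in this file is is_precompilation, which reads the NEURON_PARALLEL_COMPILE environment variable (typically set to "1" by the neuron_parallel_compile wrapper); the remaining hunks are import and constant cleanup. A self-contained version of that guard:

# Stand-alone copy of the environment-variable guard shown above: True when the process is
# running a Neuron precompilation pass.
import os


def is_precompilation() -> bool:
    return os.environ.get("NEURON_PARALLEL_COMPILE") == "1"


os.environ["NEURON_PARALLEL_COMPILE"] = "1"
print(is_precompilation())  # True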
