Commit

Remove some code
michaelbenayoun committed Apr 5, 2024
1 parent 326d79b commit 566c7c5
Showing 5 changed files with 11 additions and 123 deletions.
44 changes: 6 additions & 38 deletions optimum/neuron/trainers.py
@@ -27,6 +27,7 @@

import numpy as np
import torch
import transformers
from accelerate import __version__ as accelerate_version
from accelerate.utils import AutocastKwargs
from packaging import version
@@ -157,12 +158,9 @@ def __init__(self, *args, **kwargs):
if is_precompilation():
self.prepare_args_for_precompilation(training_args)

if check_if_transformers_greater(TRANSFORMERS_MIN_VERSION_USE_ACCELERATE):
import transformers

transformers.trainer.Accelerator = NeuronAccelerator

super().__init__(*args, **kwargs)
# TODO: is it needed?
with Patcher([("transformers.trainer.Accelerator", NeuronAccelerator)]):
super().__init__(*args, **kwargs)

# We need to change which process can be seen as "world process zero" to make sure the proper metrics
        # (e.g. loss) are logged and sent to the callbacks (for instance WandbCallback).
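
The change above replaces a permanent monkey-patch of transformers.trainer.Accelerator with a scoped one: Patcher swaps in NeuronAccelerator only while super().__init__ runs, then restores the original attribute. A minimal sketch of that pattern, written as a generic context manager rather than the actual Patcher implementation from optimum-neuron:

import contextlib
import importlib

@contextlib.contextmanager
def patch_attribute(target: str, replacement):
    """Temporarily replace a dotted `module.attribute`, restoring it on exit."""
    module_name, attribute_name = target.rsplit(".", 1)
    module = importlib.import_module(module_name)
    original = getattr(module, attribute_name)
    setattr(module, attribute_name, replacement)
    try:
        yield
    finally:
        setattr(module, attribute_name, original)

# Usage mirroring the trainer change above (inside NeuronTrainer.__init__):
# with patch_attribute("transformers.trainer.Accelerator", NeuronAccelerator):
#     super().__init__(*args, **kwargs)
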
@@ -171,13 +169,6 @@ def __init__(self, *args, **kwargs):
is_world_process_zero=is_main_worker_for_metrics(),
)

# That's the case for Transformers < 4.30.0
if not hasattr(self, "is_fsdp_enabled"):
self.is_fsdp_enabled = False

if self.is_fsdp_enabled and self.args.do_eval:
raise ValueError("Evaluation is not supported with XLA FSDP yet.")

if self.args.local_rank <= 0:
logger.setLevel(logging.INFO)

@@ -240,9 +231,6 @@ def prepare_args_for_precompilation(self, args: "TrainingArguments"):
logger.info("Disabling prediction during precompilation as this is not well supported yet.")
args.do_predict = False

def validate_args(self, args: "TrainingArguments"):
pass

def create_accelerator_and_postprocess(self):
# create accelerator object
self.accelerator = NeuronAccelerator(
@@ -307,7 +295,7 @@ def _get_eval_sampler(self, eval_dataset: torch.utils.data.Dataset) -> Optional[
def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
optimizer_cls, optimizer_kwargs = transformers_get_optimizer_cls_and_kwargs(args)
lazy_load = args.mp_plugin.should_parallelize or args.zero_1
if check_if_transformers_greater("4.30.0") and lazy_load:
if lazy_load:
optimizer_cls = make_optimizer_constructor_lazy(optimizer_cls)
return optimizer_cls, optimizer_kwargs
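
Since the check_if_transformers_greater("4.30.0") guard is dropped, the optimizer class is now wrapped whenever model parallelism or ZeRO-1 is enabled, deferring its real construction. A rough sketch of what a lazy optimizer constructor can look like; the actual make_optimizer_constructor_lazy in optimum-neuron may differ in the details:

import torch

def make_constructor_lazy(optimizer_cls):
    # Illustration only: capture the constructor arguments now and build the real
    # optimizer on first use (e.g. after parameters have been parallelized or
    # moved to their final device).
    class LazyOptimizer:
        def __init__(self, params, **kwargs):
            self._params = list(params)
            self._kwargs = kwargs
            self._optimizer = None

        def _materialize(self):
            if self._optimizer is None:
                self._optimizer = optimizer_cls(self._params, **self._kwargs)

        def step(self, *args, **kwargs):
            self._materialize()
            return self._optimizer.step(*args, **kwargs)

        def zero_grad(self, *args, **kwargs):
            self._materialize()
            return self._optimizer.zero_grad(*args, **kwargs)

    return LazyOptimizer

# Usage sketch:
LazyAdamW = make_constructor_lazy(torch.optim.AdamW)
model = torch.nn.Linear(4, 4)
optimizer = LazyAdamW(model.parameters(), lr=1e-3)  # real AdamW is not built yet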

@@ -615,34 +603,14 @@ def _save_checkpoint(self, model, trial, metrics=None):

def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
# It has been handled during model parallelization.
# TODO: how to handle pp?
if self.accelerator.distributed_type is NeuronDistributedType.MODEL_PARALLELISM:
return
super()._load_from_checkpoint(resume_from_checkpoint, model=model)

def _load_optimizer_and_scheduler_for_xla_fsdp(self, checkpoint):
checkpoint_file_exists = (
glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*")
if is_sagemaker_mp_enabled()
else os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
)
if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
self.accelerator.state.fsdp_plugin.load_optimizer(self.accelerator, self.optimizer, self.model, checkpoint)

with warnings.catch_warnings(record=True) as caught_warnings:
lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu")
reissue_pt_warnings(caught_warnings)
xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device)
self.lr_scheduler.load_state_dict(lr_scheduler_state)

# TODO: load grad scaling?

def _load_optimizer_and_scheduler(self, checkpoint):
if checkpoint is None:
return
if self.accelerator.distributed_type is NeuronDistributedType.XLA_FSDP:
return self._load_optimizer_and_scheduler_for_xla_fsdp(checkpoint)
elif self.accelerator.distributed_type is NeuronDistributedType.MODEL_PARALLELISM:
if self.accelerator.distributed_type is NeuronDistributedType.MODEL_PARALLELISM:
lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu")
xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device)
self.lr_scheduler.load_state_dict(lr_scheduler_state)
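
For model-parallel runs the scheduler state is loaded on CPU first and only then sent to the XLA device. A minimal, self-contained sketch of that load pattern, using a hypothetical checkpoint path and a toy model (the trainer derives the real path from SCHEDULER_NAME):

import os
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(4, 4).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer)

# Hypothetical checkpoint location, for illustration only.
scheduler_path = os.path.join("/tmp/checkpoint-500", "scheduler.pt")

# Load on CPU first, then move the tensors to the XLA device before restoring.
lr_scheduler_state = torch.load(scheduler_path, map_location="cpu")
lr_scheduler_state = xm.send_cpu_data_to_device(lr_scheduler_state, device)
lr_scheduler.load_state_dict(lr_scheduler_state)
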
3 changes: 1 addition & 2 deletions optimum/neuron/training_args.py
@@ -120,7 +120,6 @@ def __post_init__(self):
patch_accelerate_is_tpu_available()

if self.fsdp != "":
# Disabling FSDP until next release because it is still very experimental and not validated.
raise RuntimeError("FSDP is not supported yet.")

if self.fp16:
@@ -178,9 +177,9 @@ def __post_init__(self):
with Patcher([("transformers.training_args.get_xla_device_type", lambda _: "GPU")]):
super().__post_init__()

# TODO: try to use the patcher for NeuronPartialState instead of rewriting the method.
@cached_property
def _setup_devices(self) -> "torch.device":

requires_backends(self, ["torch"])
logger.info("PyTorch: setting up devices")
NeuronAcceleratorState._reset_state()
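
Note that _setup_devices stays a cached_property, so the device initialization runs once per TrainingArguments instance and the result is reused on every later access. A tiny standalone illustration of that caching behavior (toy class, not the Neuron-specific logic):

from functools import cached_property
import torch

class Args:
    @cached_property
    def device(self) -> torch.device:
        print("setting up devices")  # executed only on the first access
        return torch.device("cpu")

args = Args()
args.device  # prints "setting up devices"
args.device  # cached: no print, the same torch.device is returned
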
2 changes: 1 addition & 1 deletion optimum/neuron/utils/argument_utils.py
@@ -145,7 +145,7 @@ def store_compilation_config(
inline_weights_to_neff: bool,
optlevel: str,
model_type: Optional[str] = None,
task: str = None,
task: Optional[str] = None,
input_names: Optional[List[str]] = None,
output_names: Optional[List[str]] = None,
output_attentions: bool = False,
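
The only change in this file is the annotation: a parameter that defaults to None should be typed Optional[str] rather than str, so the signature matches what it actually accepts and type checkers stop flagging the None default. A tiny example with a hypothetical function name:

from typing import Optional

def resolve_task(task: Optional[str] = None) -> str:
    # Optional[str] means "str or None", which matches the None default.
    return task if task is not None else "auto"
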
84 changes: 2 additions & 82 deletions optimum/neuron/utils/training_utils.py
@@ -129,21 +129,8 @@ def _generate_supported_model_class_names(
def is_precompilation() -> bool:
return os.environ.get("NEURON_PARALLEL_COMPILE") == "1"


def is_model_officially_supported(model: "PreTrainedModel") -> bool:
# In theory the type annotation is not correct since we can have also a XlaFullyShardedDataParallel
# but let's ignore it here.
if not is_torch_xla_available():
raise RuntimeError(
"is_model_officially_supported requires torch_xla to run, please install it by running: "
"pip install torch_xla"
)
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel

if isinstance(model, XlaFullyShardedDataParallel):
class_name = model.module.__class__.__name__
else:
class_name = model.__class__.__name__
class_name = model.__class__.__name__
return class_name in _SUPPORTED_MODEL_NAMES


@@ -156,74 +143,6 @@ def is_topology_supported() -> bool:
return num_devices in allowed_number_of_devices or num_devices % 32 == 0


class FirstAndLastDataset(Dataset):
def __init__(
self, dataloader: DataLoader, num_repeat: int = 10, gradient_accumulation_steps: int = 1, world_size: int = 1
):
self.dataloader = dataloader
self.num_repeat = num_repeat * gradient_accumulation_steps * world_size
self.samples = self.create_samples()

def _create_samples_for_map_style_dataset(self):
samples = []
num_samples = len(self.dataloader.dataset)
batch_size = self.dataloader.batch_size
if batch_size is None and self.dataloader.batch_sampler is not None:
batch_size = self.dataloader.batch_sampler.batch_size

# TODO: validate that.
if batch_size is None:
samples = [self.dataloader.dataset[0]] * self.num_repeat + [self.dataloader.dataset[-1]] * self.num_repeat
return samples

num_batches = num_samples // batch_size
remaining = num_samples % batch_size

iterator = iter(self.dataloader)
first_batch = next(iterator)
samples = [first_batch] * self.num_repeat

if num_batches >= 1 and remaining != 0:

def map_fn(example):
if isinstance(example, torch.Tensor):
return example[:remaining]
else:
return example

last_batch = tree_map(map_fn, first_batch)
samples += [last_batch] * self.num_repeat

return samples

def _create_samples_for_iterable_dataset(self):
# Will not work if the iterable dataset yields dynamic batch sizes.
iterator = iter(self.dataloader)
first_batch = next(iterator)
samples = [first_batch] * self.num_repeat
last_batch = None
while True:
try:
last_batch = next(iterator)
except StopIteration:
if last_batch is not None:
samples += [last_batch] * self.num_repeat
break
return samples

def create_samples(self):
if isinstance(self.dataloader.dataset, IterableDataset):
return self._create_samples_for_iterable_dataset()
else:
return self._create_samples_for_map_style_dataset()

def __getitem__(self, idx: int):
return self.samples[idx]

def __len__(self):
return len(self.samples)


def patch_generation_mixin_to_neuron_generation_mixin(model: "PreTrainedModel"):
"""
Changes the vanilla `GenerationMixin` class from Transformers to `NeuronGenerationMixin` in the model's
@@ -250,6 +169,7 @@ def patch_generation_mixin_to_neuron_generation_mixin(model: "PreTrainedModel"):
cls.__bases__ = tuple(new_bases)


# TODO: refactor with `patch_generation_mixin_to_neuron_generation_mixin`
def patch_generation_mixin_to_general_neuron_generation_mixin(model: "PreTrainedModel"):
"""
Changes the vanilla `GenerationMixin` class from Transformers to `GeneralNeuronGenerationMixin` in the model's
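
Both patch_generation_mixin_to_* helpers rely on rewriting a class's __bases__ so that the vanilla GenerationMixin is replaced by a Neuron-aware subclass. A self-contained sketch of that base-swapping technique with toy classes (not the real Transformers ones):

class GenerationMixin:
    def generate(self):
        return "vanilla generation"

class NeuronGenerationMixin(GenerationMixin):
    def generate(self):
        return "neuron-friendly generation"

class MyModel(GenerationMixin):
    pass

# Swap GenerationMixin for NeuronGenerationMixin in the class's bases.
cls = MyModel
new_bases = tuple(
    NeuronGenerationMixin if base is GenerationMixin else base for base in cls.__bases__
)
cls.__bases__ = new_bases

print(MyModel().generate())  # -> "neuron-friendly generation"
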
1 change: 1 addition & 0 deletions optimum/neuron/utils/version_utils.py
@@ -27,6 +27,7 @@
_torch_version: Optional[str] = None


# TODO: how does it compare to the similar function in `cache_utils.py`?
def get_neuronxcc_version() -> str:
global _neuronxcc_version
if _neuronxcc_version is not None:
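
get_neuronxcc_version memoizes the compiler version in a module-level global, so the lookup only happens once per process. A rough sketch of that pattern; the real function may locate the neuronx-cc version differently (hence the TODO comparing it to cache_utils.py):

from importlib.metadata import PackageNotFoundError, version
from typing import Optional

_neuronxcc_version: Optional[str] = None

def get_neuronxcc_version() -> str:
    global _neuronxcc_version
    if _neuronxcc_version is not None:
        # Already resolved once in this process, reuse the cached value.
        return _neuronxcc_version
    try:
        # Assumption for this sketch: the distribution is named "neuronx-cc".
        _neuronxcc_version = version("neuronx-cc")
    except PackageNotFoundError:
        raise RuntimeError("neuronx-cc is not installed.")
    return _neuronxcc_version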
