Various fixes for TP #260

Merged: 3 commits, Oct 23, 2023
37 changes: 27 additions & 10 deletions optimum/neuron/accelerate/accelerator.py
@@ -69,7 +69,6 @@
xm = None

if is_neuronx_distributed_available():
from neuronx_distributed import parallel_layers
from neuronx_distributed.utils.model_utils import move_model_to_device


@@ -143,15 +142,26 @@ def __init__(self, *args, tp_plugin: Optional[TensorParallelismPlugin] = None, z
if num_steps != 1:
self.gradient_accumulation_steps = num_steps

def _prepare_data_loader_for_tp(self, data_loader: DataLoader) -> DataLoader:
def _prepare_data_loader_for_distributed(
self, data_loader: DataLoader, num_replicas: int, rank: int
) -> DataLoader:
# TODO: make it more robust, similar to the prepare_data_loader function in `accelerate`.
if isinstance(data_loader.sampler, DistributedSampler):
return data_loader
sampler = DistributedSampler(
data_loader.dataset,
num_replicas=parallel_layers.parallel_state.get_data_parallel_size(),
rank=parallel_layers.parallel_state.get_data_parallel_rank(),
)

orig_sampler = data_loader.sampler
if isinstance(orig_sampler, torch.utils.data.SequentialSampler):
shuffle = False
else:
shuffle = True
if not isinstance(orig_sampler, torch.utils.data.RandomSampler):
logger.warning(
f"The sampler {orig_sampler} is going to be replaced by a torch.utils.data.DistributedSampler. This "
"new sampler will shuffle the dataset, it might not be the expected behaviour."
)

sampler = DistributedSampler(data_loader.dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)

data_loader_for_tp = DataLoader(
data_loader.dataset,
batch_size=data_loader.batch_size,
@@ -166,8 +176,15 @@ def _prepare_data_loader_for_tp(self, data_loader: DataLoader) -> DataLoader:

def prepare_data_loader(self, data_loader: DataLoader, device_placement: Optional[bool] = None):
if self.state.distributed_type is NeuronDistributedType.TENSOR_PARALLELISM:
data_loader = self._prepare_data_loader_for_tp(data_loader)
from neuronx_distributed import parallel_layers

num_replicas = parallel_layers.parallel_state.get_data_parallel_size()
rank = parallel_layers.parallel_state.get_data_parallel_rank()
else:
num_replicas = xm.xrt_world_size()
rank = xm.get_ordinal()
if self.state.num_processes > 1:
data_loader = self._prepare_data_loader_for_distributed(data_loader, num_replicas=num_replicas, rank=rank)
data_loader = MpDeviceLoader(data_loader, self.device)
return data_loader
# TODO: fix that.
@@ -204,7 +221,7 @@ def _prepare_optimizer_for_zero_1(self, optimizer: torch.optim.Optimizer, device
model_parallel_is_initialized,
)

if not is_neuronx_distributed_available() or not model_parallel_is_initialized():
if not model_parallel_is_initialized():
sharding_groups = None
grad_norm_groups = None
else:
@@ -329,7 +346,7 @@ def _prepare_model_for_tp(
cpu_ids = [id(v) for v in model.parameters()]
# TODO: enable self.device (if needed).
model = self.state.tp_plugin.parallelize_model(model, device=None)
if os.environ.get("XLA_USE_BF16", "0") == "1":
if os.environ.get("XLA_USE_BF16", "0") == "1" or os.environ.get("XLA_DOWNCAST_BF16", "0") == "1":
model.to(torch.bfloat16)
else:
model.to(torch.float32)
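The dataloader change above generalizes the old TP-only path: the sampler is now replaced for any multi-process run, and the replacement preserves whether the original sampler shuffled. As a standalone illustration of that sampler-replacement pattern — a minimal sketch using plain PyTorch only; the function name and its arguments are illustrative, not optimum-neuron API:

```python
# Sketch of the sampler-replacement pattern: wrap an existing DataLoader with a
# DistributedSampler while preserving the original shuffling behaviour.
# `num_replicas` and `rank` stand in for the data-parallel size/rank that the
# accelerator queries from neuronx_distributed or torch_xla.
from torch.utils.data import DataLoader, DistributedSampler, SequentialSampler


def prepare_data_loader_for_distributed(data_loader: DataLoader, num_replicas: int, rank: int) -> DataLoader:
    if isinstance(data_loader.sampler, DistributedSampler):
        return data_loader  # already distributed, nothing to do

    # A SequentialSampler means the original loader did not shuffle; keep it that way.
    shuffle = not isinstance(data_loader.sampler, SequentialSampler)
    sampler = DistributedSampler(data_loader.dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)

    return DataLoader(
        data_loader.dataset,
        batch_size=data_loader.batch_size,
        sampler=sampler,
        num_workers=data_loader.num_workers,
        collate_fn=data_loader.collate_fn,
        pin_memory=data_loader.pin_memory,
        drop_last=data_loader.drop_last,
    )
```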
27 changes: 25 additions & 2 deletions optimum/neuron/accelerate/utils/operations.py
@@ -14,16 +14,28 @@
# limitations under the License.
"""Custom operations related to accelerate for Neuron."""


import torch
from accelerate.utils.operations import recursively_apply

from ...utils import is_neuronx_distributed_available
from ...utils.require_utils import requires_torch_xla


@requires_torch_xla
def _xla_gather(tensor, out_of_graph: bool = False):
import torch_xla.core.xla_model as xm

groups = None
if is_neuronx_distributed_available():
from neuronx_distributed.parallel_layers.parallel_state import (
get_data_parallel_group,
model_parallel_is_initialized,
)

if model_parallel_is_initialized():
groups = get_data_parallel_group(as_list=True)

def _xla_gather_one(tensor):
if tensor.ndim == 0:
tensor = tensor.clone()[None]
@@ -32,9 +44,20 @@ def _xla_gather_one(tensor):
tensor = tensor.contiguous()

if out_of_graph:
gathered = xm.mesh_reduce("nested_xla_gather", tensor, torch.cat)
gathered_tensors = xm.mesh_reduce("nested_xla_gather", tensor, lambda x: x)
if groups is not None:
new_gathered_tensors = []
# Since `groups` contains the list of replica groups, visiting the first group of
# replicas is enough because the values should be the same across the other axes.
replicas_to_consider = set(groups[0])
for idx, tensor in enumerate(gathered_tensors):
if idx not in replicas_to_consider:
continue
new_gathered_tensors.append(tensor)
gathered_tensors = new_gathered_tensors
gathered = torch.cat(gathered_tensors)
else:
gathered = xm.all_gather(tensor)
gathered = xm.all_gather(tensor, groups=groups, pin_layout=False)
return gathered

res = recursively_apply(_xla_gather_one, tensor, error_on_other_type=True)
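The gather fix above restricts collectives to the data-parallel group, so tensor-parallel ranks that hold identical copies are not counted twice. A condensed sketch of the in-graph path, assuming a torch_xla environment and an initialized neuronx_distributed parallel state (the function name is illustrative):

```python
# Sketch: all-gather a tensor across the data-parallel group only, so replicas that
# merely hold tensor-parallel copies of the same data do not add duplicates.
import torch
import torch_xla.core.xla_model as xm
from neuronx_distributed.parallel_layers.parallel_state import (
    get_data_parallel_group,
    model_parallel_is_initialized,
)


def gather_over_data_parallel_group(tensor: torch.Tensor) -> torch.Tensor:
    groups = None
    if model_parallel_is_initialized():
        # Each inner list holds the global ranks belonging to one data-parallel group.
        groups = get_data_parallel_group(as_list=True)
    if tensor.ndim == 0:
        tensor = tensor.clone()[None]
    return xm.all_gather(tensor.contiguous(), groups=groups, pin_layout=False)
```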
15 changes: 10 additions & 5 deletions optimum/neuron/distributed/base.py
@@ -461,6 +461,15 @@ def save_model_checkpoint_as_sharded(
optimizer: Optional["torch.optim.Optimizer"] = None,
):
cls._check_model_was_parallelized(model)

from neuronx_distributed.parallel_layers.parallel_state import (
get_data_parallel_rank,
get_tensor_model_parallel_rank,
)

data_parallel_rank = get_data_parallel_rank()
tensor_parallel_rank = get_tensor_model_parallel_rank()

if not isinstance(output_dir, Path):
output_dir = Path(output_dir)

@@ -474,12 +483,8 @@
state_dict["optimizer_state_dict"] = optimizer.state_dict()

output_path = output_dir / TENSOR_PARALLEL_SHARDS_DIR_NAME
from neuronx_distributed.parallel_layers.parallel_state import (
get_data_parallel_rank,
get_tensor_model_parallel_rank,
)

if get_data_parallel_rank() == 0 and get_tensor_model_parallel_rank() == 0:
if data_parallel_rank == 0 and tensor_parallel_rank == 0:
if output_path.is_dir():
shutil.rmtree(output_path, ignore_errors=True)
output_path.mkdir()
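Moving the rank queries to the top of `save_model_checkpoint_as_sharded` makes the guard explicit: only the worker with data-parallel rank 0 and tensor-parallel rank 0 touches the shards directory. A minimal sketch of that guard in isolation (the helper name and the `shards_dir_name` argument are illustrative; the real code uses the `TENSOR_PARALLEL_SHARDS_DIR_NAME` constant):

```python
# Sketch: recreate the shards directory from exactly one worker, mirroring the
# (dp_rank == 0 and tp_rank == 0) guard used above.
import shutil
from pathlib import Path

from neuronx_distributed.parallel_layers.parallel_state import (
    get_data_parallel_rank,
    get_tensor_model_parallel_rank,
)


def prepare_shards_dir(output_dir: Path, shards_dir_name: str) -> Path:
    output_path = output_dir / shards_dir_name
    if get_data_parallel_rank() == 0 and get_tensor_model_parallel_rank() == 0:
        if output_path.is_dir():
            shutil.rmtree(output_path, ignore_errors=True)
        output_path.mkdir()
    return output_path
```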
3 changes: 2 additions & 1 deletion optimum/neuron/generation/utils.py
@@ -1269,6 +1269,8 @@ def greedy_search(
else:
next_token_logits = outputs.logits[:, -1, :]

xm.mark_step()

# pre-process distribution
# Move to cpu to handle arbitrary logits_processor
next_tokens_scores = logits_processor(input_ids.to("cpu")[:, :seq_length], next_token_logits.to("cpu"))
@@ -1302,7 +1304,6 @@
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

# update generated ids, model inputs, and length for next step

batch_size, _ = input_ids.shape
update_indices = torch.stack(
[torch.arange(batch_size), torch.tensor(seq_length).repeat(batch_size)], dim=-1
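The only functional change in `greedy_search` is the added `xm.mark_step()`, which materializes the pending lazy-XLA graph before the logits are pulled to the CPU for the logits processors. A generic sketch of that pattern, assuming an XLA device is available:

```python
# Sketch: cut the lazy-XLA graph before a device-to-host transfer, so the .to("cpu")
# copy works on a materialized tensor instead of triggering one large deferred graph.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
logits = torch.randn(4, 32000, device=device)  # stand-in for model output logits

xm.mark_step()                  # execute everything queued so far
cpu_logits = logits.to("cpu")   # transfer the materialized result
next_tokens = cpu_logits.argmax(dim=-1)
print(next_tokens.shape)
```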
81 changes: 79 additions & 2 deletions optimum/neuron/trainers.py
@@ -35,8 +35,13 @@
TRAINER_STATE_NAME,
TRAINING_ARGS_NAME,
)
from transformers.trainer_pt_utils import reissue_pt_warnings
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, EvalLoopOutput
from transformers.trainer_pt_utils import (
reissue_pt_warnings,
)
from transformers.trainer_utils import (
PREFIX_CHECKPOINT_DIR,
EvalLoopOutput,
)
from transformers.utils import is_sagemaker_mp_enabled

from ..utils import check_if_transformers_greater, logging
@@ -55,6 +60,7 @@
TRANSFORMERS_MIN_VERSION_USE_ACCELERATE,
get_model_param_count,
is_precompilation,
is_topology_supported,
patch_generation_mixin_to_neuron_generation_mixin,
patched_finfo,
prepare_environment_for_neuron,
@@ -130,6 +136,12 @@ def __init__(self, *args, **kwargs):
if not isinstance(self, Trainer):
raise TypeError(f"{self.__class__.__name__} can only be mixed with Trainer subclasses.")

if not is_topology_supported():
num_devices = xm.xrt_world_size()
raise ValueError(
f"Topology not supported. Supported number of devices: 1, 2, 8 or a multiple of 32. Got: {num_devices}."
)

training_args = kwargs.get("args", None)
if training_args is None and len(args) >= 2:
training_args = args[1]
@@ -255,6 +267,9 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
return None
return super()._get_train_sampler()

def _get_eval_sampler(self, eval_dataset: torch.utils.data.Dataset) -> Optional[torch.utils.data.Sampler]:
return torch.utils.data.SequentialSampler(eval_dataset)

@staticmethod
def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
optimizer_cls, optimizer_kwargs = transformers_get_optimizer_cls_and_kwargs(args)
@@ -295,6 +310,68 @@ def _inner_training_loop(
ignore_keys_for_eval=ignore_keys_for_eval,
)

def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
if self.control.should_log:
logs: Dict[str, float] = {}

xm.mark_step()

if self.args.tp_plugin.tensor_parallel_size > 1:
from neuronx_distributed.parallel_layers.parallel_state import (
get_data_parallel_group,
get_data_parallel_size,
)

dp_size = get_data_parallel_size()
tr_loss_div = tr_loss / dp_size
tr_loss_scalar = xm.all_reduce(
xm.REDUCE_SUM,
tr_loss_div,
groups=get_data_parallel_group(as_list=True),
)
tr_loss_scalar = tr_loss_scalar.detach().item()
else:
# all_gather + mean() to get average loss over all processes
tr_loss_scalar = self._nested_gather(tr_loss).mean().item()

# reset tr_loss to zero
tr_loss -= tr_loss

logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
logs["learning_rate"] = self._get_learning_rate()

self._total_loss_scalar += tr_loss_scalar
self._globalstep_last_logged = self.state.global_step
self.store_flos()

self.log(logs)

metrics = None
if self.control.should_evaluate:
if isinstance(self.eval_dataset, dict):
metrics = {}
for eval_dataset_name, eval_dataset in self.eval_dataset.items():
dataset_metrics = self.evaluate(
eval_dataset=eval_dataset,
ignore_keys=ignore_keys_for_eval,
metric_key_prefix=f"eval_{eval_dataset_name}",
)
metrics.update(dataset_metrics)
else:
metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
self._report_to_hp_search(trial, self.state.global_step, metrics)

# Run delayed LR scheduler now that metrics are populated
if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
metric_to_check = self.args.metric_for_best_model
if not metric_to_check.startswith("eval_"):
metric_to_check = f"eval_{metric_to_check}"
self.lr_scheduler.step(metrics[metric_to_check])

if self.control.should_save:
self._save_checkpoint(model, trial, metrics=metrics)
self.control = self.callback_handler.on_save(self.args, self.state, self.control)

def _save_checkpoint_with_accelerator(self, model, trial, metrics=None):
if self.accelerator.distributed_type is NeuronDistributedType.XLA_FSDP and not self.is_fsdp_enabled:
# TODO: handle this case better?
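When tensor parallelism is enabled, the overridden `_maybe_log_save_evaluate` averages the loss over the data-parallel group instead of gathering it from every process. The reduction boils down to the following sketch (assumes torch_xla and an initialized neuronx_distributed parallel state; `tr_loss` must live on the XLA device, and the function name is illustrative):

```python
# Sketch: average a per-replica training loss across the data-parallel group.
# Dividing by the group size before the SUM all-reduce yields the mean.
import torch
import torch_xla.core.xla_model as xm
from neuronx_distributed.parallel_layers.parallel_state import (
    get_data_parallel_group,
    get_data_parallel_size,
)


def mean_loss_over_data_parallel_group(tr_loss: torch.Tensor) -> float:
    dp_size = get_data_parallel_size()
    reduced = xm.all_reduce(
        xm.REDUCE_SUM,
        tr_loss / dp_size,
        groups=get_data_parallel_group(as_list=True),
    )
    return reduced.detach().item()
```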
15 changes: 13 additions & 2 deletions optimum/neuron/utils/training_utils.py
@@ -49,13 +49,12 @@
from ...utils.logging import set_verbosity as set_verbosity_optimum
from ..generation import NeuronGenerationMixin
from . import is_torch_xla_available
from .require_utils import requires_torch_xla


if TYPE_CHECKING:
from transformers import PreTrainedModel

if is_torch_xla_available():
import torch_xla.distributed.parallel_loader as pl

TRANSFORMERS_MIN_VERSION_FOR_XLA_FSDP = "4.30.0.dev0"
TRANSFORMERS_MIN_VERSION_USE_ACCELERATE = "4.30.0.dev0"
@@ -145,6 +144,15 @@ def is_model_officially_supported(model: "PreTrainedModel") -> bool:
return class_name in _SUPPORTED_MODEL_NAMES


@requires_torch_xla
def is_topology_supported() -> bool:
import torch_xla.core.xla_model as xm

num_devices = xm.xrt_world_size()
allowed_number_of_devices = [1, 2, 8]
return num_devices in allowed_number_of_devices or num_devices % 32 == 0


class FirstAndLastDataset(Dataset):
def __init__(
self, dataloader: DataLoader, num_repeat: int = 10, gradient_accumulation_steps: int = 1, world_size: int = 1
@@ -270,11 +278,14 @@ def patch_transformers_for_neuron_sdk():
transformers.utils.logging.set_verbosity = set_verbosity


@requires_torch_xla
def skip_first_batches(dataloader, num_batches=0):
"""
Wrapper around `accelerate.data_loader.skip_first_batches` to handle `pl.ParallelLoader` when using
`torch_xla.distributed`, for instance with XLA FSDP.
"""
import torch_xla.distributed.parallel_loader as pl

if isinstance(dataloader, (pl.ParallelLoader, pl.PerDeviceLoader)):
dataloader._loader = skip_first_batches(dataloader._loader, num_batches=num_batches)
else:
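`is_topology_supported` simply checks the world size against the device counts Neuron supports (1, 2, 8, or a multiple of 32). The rule itself can be expressed over a plain integer, which is convenient for checking without a Neuron device; `check_topology` below is an illustrative name, not part of optimum-neuron:

```python
# Sketch: the device-count rule enforced by is_topology_supported(), applied to a
# plain integer so it can be exercised without torch_xla.
def check_topology(num_devices: int) -> bool:
    return num_devices in (1, 2, 8) or num_devices % 32 == 0


# 1, 2, 8, 32, 64, ... are accepted; counts such as 4 or 16 are rejected.
assert check_topology(32)
assert not check_topology(16)
```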