Add Parallel Cross Entropy (#2017)
zhenglongjiepheonix authored Sep 18, 2024
1 parent 2179d33 commit bf1befd
Showing 7 changed files with 280 additions and 20 deletions.
2 changes: 1 addition & 1 deletion optimum/fx/parallelization/decomp.py
@@ -197,7 +197,7 @@ def run(self, *args, **kwargs):
def decompose_and_functionalize(
    graph_module: GraphModule,
    decomposition_table: Dict[torch._ops.OperatorBase, Callable] = core_aten_decompositions(),
    leaf_function_targets: List[Callable] = [F.scaled_dot_product_attention],
    leaf_function_targets: List[Callable] = [F.scaled_dot_product_attention, F.cross_entropy],
) -> Callable:
    """
    API to decompose and functionalize a high-level graph module.
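The change above only adds `F.cross_entropy` to the default leaf function targets, so the call is kept intact instead of being decomposed into lower-level ops, which lets the later parallelization passes match it as a whole. A hedged way to confirm the new default, assuming `optimum` at this commit is installed:

```python
import inspect

import torch.nn.functional as F

from optimum.fx.parallelization.decomp import decompose_and_functionalize

# The default list is visible through the function signature; both leaf targets
# should be present after this change.
defaults = inspect.signature(decompose_and_functionalize).parameters["leaf_function_targets"].default
assert F.scaled_dot_product_attention in defaults and F.cross_entropy in defaults
```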
35 changes: 27 additions & 8 deletions optimum/fx/parallelization/op_registry/op_handlers.py
@@ -19,7 +19,7 @@
from torch.fx import Node

from ..core import Config
from ..utils import is_activation, is_embedding, is_linear
from ..utils import is_activation, is_cross_entropy, is_cross_entropy_parallel_compatible, is_embedding, is_linear


class Registry:
@@ -334,7 +334,16 @@ def propagate(self) -> List[int]:
        ndim = arg.meta["val"].ndim
        slice_dim = (slice_dim + ndim) % ndim
        if slice_dim == axis:
            # slice on the parallel axis is not allowed
            # slice on the parallel axis is not allowed, unless the slice is a no-op
            start, stop, step = 0, arg.meta["val"].shape[axis], 1
            if len(self.node.args) > 2:
                start = self.node.args[2]
            if len(self.node.args) > 3:
                stop = self.node.args[3]
            if len(self.node.args) > 4:
                step = self.node.args[4]
            if start == 0 and stop >= arg.meta["val"].shape[axis] and step == 1:
                return [axis]
            return []
        return [axis]
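As a small illustration of the no-op condition checked above (plain PyTorch, independent of the parallelization machinery): a slice that spans the whole axis with step 1 leaves the tensor unchanged, so a parallel axis can safely pass through it, while any narrowing slice cannot.

```python
import torch

# A full-range slice (start=0, stop>=size, step=1) on an axis is a no-op, so it
# cannot conflict with sharding on that axis; a narrowing slice changes the size.
x = torch.randn(4, 8)
assert torch.equal(x[:, 0:8:1], x)      # no-op slice on the last axis
assert x[:, 0:4].shape != x.shape       # a real slice shrinks the axis
```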

@@ -404,12 +413,12 @@ def propagate(self) -> List[int]:
        if self.node.op in ["placeholder", "get_attr"]:
            return [None]
        elif self.node.op == "output":
            for node in self.node.all_input_nodes:
                # TODO: allow parallelized nodes in output, and append comm ops in graph to all-gather
                # parallelized output if instructed
                if self.extract_axis(node) is not None:
                    return []
            return [None]
            # We don't check whether the output is parallelized here: if the output is the loss,
            # it cannot be parallelized as long as it comes from the sharded cross entropy.
            # TODO: append all-gather comm ops before all parallelized output nodes if instructed.
            input_arg = self.node.all_input_nodes[0]
            axis = self.extract_axis(input_arg)
            return [axis]
        elif is_linear(self.node):
            input_arg = self.node.all_input_nodes[0]
            axis = self.extract_axis(input_arg)
@@ -438,6 +447,16 @@ def propagate(self) -> List[int]:
                return [1, None] if self.config.enable_sequence_parallel else [None]
            else:
                return []
        elif is_cross_entropy(self.node):
            logits = self.node.all_input_nodes[0]
            axis = self.extract_axis(logits)
            if axis is None or (
                is_cross_entropy_parallel_compatible(self.node) and axis == logits.meta["val"].ndim - 1
            ):
                # for cross entropy, the input logits parallel axis can only be the last axis or None
                return [None]
            else:
                return []
        elif is_activation(self.node):
            return UnaryOpParallelAxisPropagateHandler(self.node, self.meta_key, self.config).propagate()
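A brief illustration of why the cross entropy branch above propagates no parallel axis (plain PyTorch, shapes are illustrative): cross entropy consumes the vocab axis, so even when the logits are sharded on that last axis, the resulting per-token loss has no axis left to shard.

```python
import torch
import torch.nn.functional as F

# (batch, seq, vocab) logits; in the tensor-parallel setting the vocab axis is the sharded one.
logits = torch.randn(2, 5, 32)
target = torch.randint(0, 32, (2, 5))
loss = F.cross_entropy(logits.flatten(0, 1), target.flatten(), reduction="none")
assert loss.shape == (10,)  # the vocab axis is reduced away, nothing remains to parallelize
```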

1 change: 1 addition & 0 deletions optimum/fx/parallelization/parallel_layers/__init__.py
@@ -14,3 +14,4 @@
# limitations under the License.
from .embedding import VocabParallelEmbedding
from .linear import ColumnParallelLinear, RowParallelLinear
from .loss import VocabParallelCrossEntropyLoss, sharded_cross_entropy_wrapper_fn
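With this re-export, downstream code (such as `passes.py` below) can import the new loss utilities from the same place as the existing parallel layers; a minimal import sketch, assuming the package is installed:

```python
# The vocab-parallel loss module and the functional wrapper factory sit next to
# the existing parallel linear/embedding layers in the package namespace.
from optimum.fx.parallelization.parallel_layers import (
    VocabParallelCrossEntropyLoss,
    sharded_cross_entropy_wrapper_fn,
)
```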
163 changes: 163 additions & 0 deletions optimum/fx/parallelization/parallel_layers/loss.py
@@ -0,0 +1,163 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import wraps
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn as nn

from ..core import ParallelExecutionCtx


# Adapted from https://github.com/huggingface/nanotron/blob/main/src/nanotron/parallel/tensor_parallel/functional.py
class _ShardedCrossEntropy(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        sharded_logits: torch.Tensor,  # (batch_size, length, sharded_hidden_size)
        target: torch.Tensor,  # (batch_size, length)
        group: dist.ProcessGroup,
    ):
        # Maximum value along the last dimension across all GPUs.
        logits_max = torch.max(sharded_logits, dim=-1)[0]
        dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=group)
        # Subtract the maximum value.
        sharded_logits = sharded_logits - logits_max.unsqueeze(dim=-1)

        # Get the shard's indices
        sharded_hidden_size = sharded_logits.shape[-1]
        rank = dist.get_rank(group)
        start_index = rank * sharded_hidden_size
        end_index = start_index + sharded_hidden_size

        # Create a mask of target ids that fall outside this shard (True means the id is masked).
        target_mask = (target < start_index) | (target >= end_index)
        masked_target = target.clone() - start_index
        masked_target[target_mask] = 0

        # Get predicted-logits = logits[target].
        # For simplicity, we convert logits to a 2-D tensor with size
        # [*, shard-size] and target to a 1-D tensor of size [*].
        logits_2d = sharded_logits.view(-1, sharded_hidden_size)
        masked_target_1d = masked_target.view(-1)
        arange_1d = torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device)
        predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
        if predicted_logits_1d.is_contiguous():
            predicted_logits_1d = predicted_logits_1d.clone()
        else:
            predicted_logits_1d = predicted_logits_1d.contiguous()
        predicted_logits = predicted_logits_1d.view_as(target)
        predicted_logits[target_mask] = 0.0
        # All reduce is needed to get the chunks from other GPUs.
        dist.all_reduce(predicted_logits, op=dist.ReduceOp.SUM, group=group)

        # Sum of exponential of logits along vocab dimension across all GPUs.
        exp_logits = sharded_logits
        torch.exp(sharded_logits, out=exp_logits)
        sum_exp_logits = exp_logits.sum(dim=-1)
        dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=group)

        # Loss = log(sum(exp(logits))) - predicted-logit.
        loss = torch.log(sum_exp_logits) - predicted_logits

        # Normalize the exponentiated logits into the softmax needed for the backward pass.
        exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))

        # Store softmax, target-mask and masked-target for backward pass.
        ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)

        return loss
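In equation form, the forward pass above computes the standard per-token cross entropy via log-sum-exp: the two all-reduces supply the cross-shard maximum and the cross-shard sum, and the target logit comes from the single rank whose shard owns the target id.

```latex
% m_i: all-reduced max over the full vocab, V: full vocab size, t_i: target id of token i.
\mathcal{L}_i
  = \log\!\Big(\sum_{j=1}^{V} e^{x_{ij} - m_i}\Big) - \big(x_{i,t_i} - m_i\big)
  = -\log \mathrm{softmax}(x_i)_{t_i}
```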

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Retrieve tensors from the forward path.
        softmax, target_mask, masked_target_1d = ctx.saved_tensors

        # All the inputs have softmax as their gradient.
        grad_input = softmax
        # For simplicity, work with the 2D gradient.
        sharded_hidden_size = softmax.size()[-1]
        grad_2d = grad_input.view(-1, sharded_hidden_size)

        # Add the gradient from matching classes.
        arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
        grad_2d[arange_1d, masked_target_1d] -= 1.0 - target_mask.view(-1).float()

        # Finally elementwise multiplication with the output gradients.
        grad_input.mul_(grad_output.unsqueeze(dim=-1))

        return grad_input, None, None
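Correspondingly, the backward pass applies the familiar softmax-minus-one-hot gradient; the `1.0 - target_mask` term ensures the subtraction only happens on the rank whose shard owns the target, and each rank only materializes its own shard of columns.

```latex
% Gradient of the per-token loss w.r.t. the logits, scaled by the incoming gradient.
\frac{\partial \mathcal{L}_i}{\partial x_{ij}}
  = \Big(\mathrm{softmax}(x_i)_j - \mathbf{1}[\,j = t_i\,]\Big)
    \cdot \frac{\partial \ell}{\partial \mathcal{L}_i}
```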


def sharded_cross_entropy(sharded_logits: torch.Tensor, target: torch.Tensor, process_group: dist.ProcessGroup):
    return _ShardedCrossEntropy.apply(sharded_logits, target, process_group)


def sharded_cross_entropy_wrapper_fn(process_group: dist.ProcessGroup):
    @wraps(sharded_cross_entropy)
    def wrapper(
        sharded_logits: torch.Tensor,
        target: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
        size_average: Optional[bool] = None,
        ignore_index: int = -100,
        reduce: Optional[bool] = None,
        reduction: str = "mean",
        label_smoothing: float = 0.0,
    ):
        if weight is not None or ignore_index != -100 or label_smoothing != 0.0:
            raise ValueError(
                "The current parallel cross entropy implementation does not support weighted mode, ignore_index or label smoothing."
            )
        loss: torch.Tensor = sharded_cross_entropy(sharded_logits, target, process_group)

        if size_average is not None or reduce is not None:
            size_average = True if size_average is None else size_average
            reduce = True if reduce is None else reduce

            if size_average and reduce:
                reduction = "mean"
            elif reduce:
                reduction = "sum"
            else:
                reduction = "none"

        if reduction == "mean":
            return loss.mean()
        elif reduction == "sum":
            return loss.sum()
        return loss

    return wrapper


class VocabParallelCrossEntropyLoss(nn.Module):
    """
    Simple parallel cross entropy implementation which does not support weighted mode and label smoothing yet.
    """

    def __init__(self, ctx: ParallelExecutionCtx, reduction: str = "mean") -> None:
        super(VocabParallelCrossEntropyLoss, self).__init__()
        self.process_group = ctx.tp_group
        self.reduction = reduction

    def forward(self, sharded_logits: torch.Tensor, target: torch.Tensor):
        loss: torch.Tensor = _ShardedCrossEntropy.apply(sharded_logits, target, self.process_group)
        if self.reduction == "mean":
            return loss.mean()
        elif self.reduction == "sum":
            return loss.sum()
        return loss
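A hedged, single-process sanity check of the loss above (not part of this commit): with a world size of 1 the local shard is the full vocabulary, so `sharded_cross_entropy` should match `F.cross_entropy` with `reduction="none"`. The gloo setup, port, and tensor shapes below are illustrative assumptions.

```python
import os

import torch
import torch.distributed as dist
import torch.nn.functional as F

from optimum.fx.parallelization.parallel_layers.loss import sharded_cross_entropy

# Single-process "distributed" setup; with world_size == 1 the shard is the full vocab.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

vocab_size = 11
logits = torch.randn(2, 5, vocab_size)
target = torch.randint(0, vocab_size, (2, 5))

loss = sharded_cross_entropy(logits, target, dist.group.WORLD)
reference = F.cross_entropy(logits.flatten(0, 1), target.flatten(), reduction="none").view_as(target)
assert torch.allclose(loss, reference, atol=1e-6)

dist.destroy_process_group()
```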
45 changes: 44 additions & 1 deletion optimum/fx/parallelization/passes.py
@@ -26,8 +26,15 @@
from .decomp import decompose_and_functionalize
from .distributed import scatter
from .op_registry import REGISTRY, FallbackParallelAxisPropagateHandler
from .parallel_layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
from .parallel_layers import (
    ColumnParallelLinear,
    RowParallelLinear,
    VocabParallelCrossEntropyLoss,
    VocabParallelEmbedding,
    sharded_cross_entropy_wrapper_fn,
)
from .utils import (
    is_cross_entropy,
    is_embedding,
    is_linear,
    is_shape_consumer,
@@ -273,6 +280,11 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf
info["sequence_parallel"] = False
self.place_marker_per_node(node, info)

elif is_cross_entropy(node):
axis_before = ParallelAxisSolverPass.get_stored_field_info(node.args[0], "parallel_axis")
if axis_before is not None:
self.place_marker_per_node(node, {"axis": "vocab"})

return graph_module


@@ -343,6 +355,35 @@ def handle_embedding(node: Node, ctx: ParallelExecutionCtx) -> None:
            layer_cache[key] = new_mod
        setattr(parent_mod, field, new_mod)

    @staticmethod
    def handle_cross_entropy(node: Node, ctx: ParallelExecutionCtx) -> None:
        axis = ParallelLayerAnnotatePass.get_stored_field_info(node, field="axis")
        if axis is None:
            return

        assert axis in {"vocab"}, "Only support parallelization on vocab dim for now."
        if node.op == "call_module":
            graph_module = node.graph.owning_module
            prefix_and_field = node.target.rsplit(".", maxsplit=1)
            if len(prefix_and_field) == 2:
                parent_mod = graph_module.get_submodule(prefix_and_field[0])
                field = prefix_and_field[1]
            else:
                parent_mod = graph_module
                field = node.target

            mod: nn.CrossEntropyLoss = graph_module.get_submodule(node.target)
            key, layer_cache = node.target, ctx.parallel_layer_cache
            if key in layer_cache:
                new_mod = layer_cache[key]
            else:
                assert ctx.compile_times == 0, "illegal path for recompilation"
                new_mod = VocabParallelCrossEntropyLoss(ctx, reduction=mod.reduction)
                layer_cache[key] = new_mod
            setattr(parent_mod, field, new_mod)
        else:
            node.target = sharded_cross_entropy_wrapper_fn(process_group=ctx.tp_group)

    @staticmethod
    def handle_hard_coded_axis_param(node: Node, ctx: ParallelExecutionCtx) -> None:
        def extract_shape_from_node(node: Node) -> List[Any]:
@@ -384,6 +425,8 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf
                self.handle_linear(node, ctx)
            elif is_embedding(node):
                self.handle_embedding(node, ctx)
            elif is_cross_entropy(node):
                self.handle_cross_entropy(node, ctx)
            # correct the attention head num in parallel setting
            elif is_shape_consumer(node):
                self.handle_hard_coded_axis_param(node, ctx)
34 changes: 34 additions & 0 deletions optimum/fx/parallelization/utils.py
@@ -82,6 +82,40 @@ def is_shape_generator(node: Node) -> bool:
    return node.op == "call_method" and node.target == "size"


def is_cross_entropy(node: Node) -> bool:
    if node.op == "call_function":
        return node.target is F.cross_entropy
    elif node.op == "call_module":
        mod = node.graph.owning_module
        return isinstance(mod.get_submodule(node.target), nn.CrossEntropyLoss)
    return False


def is_cross_entropy_parallel_compatible(node: Node) -> bool:
    """
    For now `VocabParallelCrossEntropyLoss` does not support weighted mode, `ignore_index` or label smoothing.
    """
    if node.op == "call_function":
        weight = node.kwargs.get("weight", None)
        ignore_index = node.kwargs.get("ignore_index", -100)
        label_smoothing = node.kwargs.get("label_smoothing", 0.0)
        if len(node.args) > 2 and weight is None:
            weight = node.args[2]
        if len(node.args) > 4 and ignore_index == -100:
            ignore_index = node.args[4]
        if len(node.args) > 7 and label_smoothing == 0.0:
            label_smoothing = node.args[7]

        return weight is None and ignore_index == -100 and label_smoothing == 0.0

    elif node.op == "call_module":
        mod: nn.CrossEntropyLoss = node.graph.owning_module.get_submodule(node.target)
        weight, label_smoothing, ignore_index = mod.weight, mod.label_smoothing, mod.ignore_index
        return weight is None and ignore_index == -100 and label_smoothing == 0.0

    return False
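A hedged sketch of how these predicates behave on a traced graph. The toy modules are illustrative, and this assumes `torch.fx.symbolic_trace` records `F.cross_entropy` as a `call_function` node with its keyword arguments, which is how the functional call dispatches through `__torch_function__`.

```python
import torch.nn as nn
import torch.nn.functional as F
from torch.fx import symbolic_trace

from optimum.fx.parallelization.utils import is_cross_entropy, is_cross_entropy_parallel_compatible

# Two toy modules: a plain cross entropy call (accepted) and one with label
# smoothing (rejected by the compatibility check above).
class PlainLoss(nn.Module):
    def forward(self, logits, target):
        return F.cross_entropy(logits, target)

class SmoothedLoss(nn.Module):
    def forward(self, logits, target):
        return F.cross_entropy(logits, target, label_smoothing=0.1)

for module, expected in [(PlainLoss(), True), (SmoothedLoss(), False)]:
    graph = symbolic_trace(module).graph
    ce_nodes = [node for node in graph.nodes if is_cross_entropy(node)]
    assert len(ce_nodes) == 1
    assert is_cross_entropy_parallel_compatible(ce_nodes[0]) == expected
```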


def stable_topological_sort(graph: Graph):
    def _args(n: torch.fx.Node) -> List[torch.fx.node.Argument]:
        args: List[torch.fx.node.Argument] = []