diff --git a/tzrec/main.py b/tzrec/main.py index 92b7e06..e267c76 100644 --- a/tzrec/main.py +++ b/tzrec/main.py @@ -27,7 +27,6 @@ from torch.utils.tensorboard import SummaryWriter from torchrec.distributed.model_parallel import ( DistributedModelParallel, - get_default_sharders, ) # NOQA @@ -80,7 +79,7 @@ from tzrec.protos.train_pb2 import TrainConfig from tzrec.utils import checkpoint_util, config_util from tzrec.utils.logging_util import ProgressLogger, logger -from tzrec.utils.plan_util import create_planner +from tzrec.utils.plan_util import create_planner, get_default_sharders from tzrec.version import __version__ as tzrec_version @@ -547,13 +546,18 @@ def train_and_evaluate( # pyre-ignore [16] batch_size=train_dataloader.dataset.sampled_batch_size, ) + plan = planner.collective_plan( model, get_default_sharders(), dist.GroupMember.WORLD ) if is_rank_zero: logger.info(str(plan)) - model = DistributedModelParallel(module=model, device=device, plan=plan) + model = DistributedModelParallel( + module=model, + device=device, + plan=plan, + ) dense_optim_cls, dense_optim_kwargs = optimizer_builder.create_dense_optimizer( pipeline_config.train_config.dense_optimizer diff --git a/tzrec/utils/plan_util.py b/tzrec/utils/plan_util.py index 354f6e2..f64684e 100644 --- a/tzrec/utils/plan_util.py +++ b/tzrec/utils/plan_util.py @@ -16,6 +16,7 @@ import psutil import torch from torch import distributed as dist +from torch import nn from torchrec.distributed.comm import get_local_size from torchrec.distributed.planner import EmbeddingShardingPlanner from torchrec.distributed.planner.proposers import UniformProposer @@ -28,6 +29,10 @@ ShardingOption, Topology, ) +from torchrec.distributed.sharding_plan import ModuleSharder +from torchrec.distributed.sharding_plan import ( + get_default_sharders as _get_default_sharders, +) def _bytes_to_float_bin(num_bytes: Union[float, int], bin_size: float) -> float: @@ -42,6 +47,7 @@ def create_planner(device: torch.device, batch_size: int) 
-> EmbeddingShardingPl topo_kwargs = {} if torch.cuda.is_available(): topo_kwargs["hbm_cap"] = torch.cuda.get_device_properties(device).total_memory + ddr_cap_per_rank = int(float(psutil.virtual_memory().total) / local_world_size) topo_kwargs["ddr_cap"] = ddr_cap_per_rank if "INTRA_NODE_BANDWIDTH" in os.environ: @@ -76,13 +82,26 @@ def create_planner(device: torch.device, batch_size: int) -> EmbeddingShardingPl return planner +def get_default_sharders() -> List[ModuleSharder[nn.Module]]: + """Get embedding module default sharder.""" + if torch.cuda.is_available(): + return _get_default_sharders() + else: + # ShardedEmbeddingCollection is not supported yet. + sharders = [] + for sharder in _get_default_sharders(): + if "EmbeddingCollection" not in str(sharder): + sharders.append(sharder) + return sharders + + class DynamicProgrammingProposer(Proposer): r"""Proposes sharding plans in dynamic programming fashion. - The problem of the Embedding Sharding Plan can be framed as follows: Given + The problem of the Embedding Sharding Plan can be framed as follows: Given :math:`M` tables and their corresponding :math:`N` Sharding Options, we need to select one sharding option for each table such that the total performance is - minimized, while keeping the overall HBM constraint :math:`K` in check. This can + minimized, while keeping the overall mem constraint :math:`K` in check. This can be abstracted into the following mathematical formulation: Given a matrix :math:`A` of dimensions :math:`(M, N)` and another matrix :math:`B` @@ -106,24 +125,29 @@ class DynamicProgrammingProposer(Proposer): .. 
math:: dp[i][f(k)] = \min_{j=0}^{N-1} \left( dp[i-1][f(k - A[i][j])] + B[i][j] \right) - Since :math:`K` is the sum allocated across all HBM, simply satisfying that the - total HBM in the plan equals :math:`K` does not guarantee that the allocation will + Since :math:`K` is the sum allocated across all mem, simply satisfying that the + total mem in the plan equals :math:`K` does not guarantee that the allocation will fit on all cards. Therefore, it is essential to maintain all the states of the last layer of :math:`dp`. This allows us to propose different plans under varying total - HBM constraints. + mem constraints. Args: - hbm_bins_per_device (int): hdm bins for dynamic programming precision. + mem_bins_per_device (int): mem bins for dynamic programming precision. - def __init__(self, hbm_bins_per_device: int = 100) -> None: + def __init__(self, mem_bins_per_device: int = 100) -> None: self._inited: bool = False - self._hbm_bins_per_device: int = max(hbm_bins_per_device, 1) + self._mem_bins_per_device: int = max(mem_bins_per_device, 1) self._sharding_options_by_fqn: OrderedDict[str, List[ShardingOption]] = ( OrderedDict() ) - self._proposal_indices: List[List[int]] = [] + # list of proposals with different total_mem, a proposal is a list of + # indices of sharding_options + self._proposal_list: List[List[int]] = [] self._current_proposal: int = -1 + self._plan_by_hbm = True + if not torch.cuda.is_available(): + self._plan_by_hbm = False def load( self, @@ -132,7 +156,13 @@ ) -> None: """Load search space.""" self._reset() - for sharding_option in sorted(search_space, key=lambda x: x.total_storage.hbm): + # order the sharding options by total storage (hbm, or ddr on cpu) low to high + for sharding_option in sorted( + search_space, + key=lambda x: x.total_storage.hbm + if self._plan_by_hbm + else x.total_storage.ddr, + ): fqn = sharding_option.fqn if fqn not in self._sharding_options_by_fqn: self._sharding_options_by_fqn[fqn] = [] @@ -140,7 +170,7 @@
def _reset(self) -> None: self._sharding_options_by_fqn = OrderedDict() - self._proposal_indices = [] + self._proposal_list = [] self._current_proposal = -1 def propose(self) -> Optional[List[ShardingOption]]: @@ -151,7 +181,7 @@ def propose(self) -> Optional[List[ShardingOption]]: for sharding_options in self._sharding_options_by_fqn.values() ] elif self._current_proposal >= 0: - proposal_index = self._proposal_indices[self._current_proposal] + proposal_index = self._proposal_list[self._current_proposal] return [ self._sharding_options_by_fqn[fqn][index] for fqn, index in zip( @@ -171,60 +201,90 @@ def feedback( """Feedback last proposed plan.""" if not self._inited: self._inited = True - M = len(self._sharding_options_by_fqn) - N = max([len(x) for x in self._sharding_options_by_fqn.values()]) + table_count = len(self._sharding_options_by_fqn) + option_count = max([len(x) for x in self._sharding_options_by_fqn.values()]) assert storage_constraint is not None - hbm_total = sum([x.storage.hbm for x in storage_constraint.devices]) - K = self._hbm_bins_per_device * len(storage_constraint.devices) - bin_size = float(hbm_total) / K + # are we assuming the table will be evenly sharded on all devices? 
+ mem_total = sum( + [ + x.storage.hbm if self._plan_by_hbm else x.storage.ddr + for x in storage_constraint.devices + ] + ) - dp = [[(float("inf"), float("inf"))] * K for _ in range(M)] - backtrack = [[(-1, -1)] * K for _ in range(M)] + bin_count = self._mem_bins_per_device * len(storage_constraint.devices) + bin_size = float(mem_total) / bin_count - hbm_by_fqn = [[float("inf") for _ in range(N)] for _ in range(M)] - perf_by_fqn = [[float("inf") for _ in range(N)] for _ in range(M)] - for m, sharding_options in enumerate( + dp = [ + [(float("inf"), float("inf"))] * bin_count for _ in range(table_count) + ] # [table_id][mem_bin][perf, mem] + + backtrack = [ + [(-1, -1)] * bin_count for _ in range(table_count) + ] # [table_id][mem_bin][opt_id, prev_mem_bin] + + mem_by_fqn = [ + [float("inf") for _ in range(option_count)] for _ in range(table_count) + ] # memory constraint lookup table: [table_id][sharding_option_id] + perf_by_fqn = [ + [float("inf") for _ in range(option_count)] for _ in range(table_count) + ] # performance metrics lookup table: [table_id][sharding_option_id] + + # populate mem and perf for each sharding option and table: + # A[table_id][sharding_option_id] + for table_id, sharding_options in enumerate( self._sharding_options_by_fqn.values() ): - for n, sharding_option in enumerate(sharding_options): - hbm_by_fqn[m][n] = _bytes_to_float_bin( - sharding_option.total_storage.hbm, bin_size + for opt_id, sharding_option in enumerate(sharding_options): + mem_by_fqn[table_id][opt_id] = _bytes_to_float_bin( + sharding_option.total_storage.hbm + if self._plan_by_hbm + else sharding_option.total_storage.ddr, + bin_size, ) - perf_by_fqn[m][n] = sharding_option.total_perf - - for j in range(N): - if hbm_by_fqn[0][j] < K: - hbm_i = int(hbm_by_fqn[0][j]) - if dp[0][hbm_i][0] > perf_by_fqn[0][j]: - dp[0][hbm_i] = (perf_by_fqn[0][j], hbm_by_fqn[0][j]) - backtrack[0][hbm_i] = (j, -1) - - for i in range(1, M): - for j in range(N): - for c in range(K): - prev_perf, 
perv_hbm = dp[i - 1][c] + perf_by_fqn[table_id][opt_id] = sharding_option.total_perf + + table_0 = 0 + for opt_j in range(option_count): + if mem_by_fqn[0][opt_j] < bin_count: + mem_i = int(mem_by_fqn[0][opt_j]) + # options are ordered in increasing order of mem, we only want to + # consider a sharding option that has higher mem and better perf + # (the smaller the better) + if dp[table_0][mem_i][0] > perf_by_fqn[table_0][opt_j]: + dp[table_0][mem_i] = ( + perf_by_fqn[table_0][opt_j], + mem_by_fqn[table_0][opt_j], + ) + backtrack[table_0][mem_i] = (opt_j, -1) + + # dp: table_count x option_count x bin_count + for table_i in range(1, table_count): + for opt_j in range(option_count): + for mem in range(bin_count): + prev_perf, perv_mem = dp[table_i - 1][mem] if prev_perf < float("inf"): - new_hbm = perv_hbm + hbm_by_fqn[i][j] - if new_hbm < K: - new_hbm_i = int(new_hbm) - new_perf = prev_perf + perf_by_fqn[i][j] - if dp[i][new_hbm_i][0] > new_perf: - dp[i][new_hbm_i] = (new_perf, new_hbm) - backtrack[i][new_hbm_i] = (j, c) - - self._proposal_indices = [] - for c in range(K - 1, -1, -1): - cur_col_idx, cur_hbm_idx = backtrack[M - 1][c] - if cur_col_idx >= 0: - column_indices = [-1] * M - column_indices[M - 1] = cur_col_idx - for i in range(M - 2, -1, -1): - column_indices[i], cur_hbm_idx = backtrack[i][cur_hbm_idx] - self._proposal_indices.append(column_indices) - if len(self._proposal_indices) > 0: + new_mem = perv_mem + mem_by_fqn[table_i][opt_j] + if new_mem < bin_count: + new_mem_i = int(new_mem) + new_perf = prev_perf + perf_by_fqn[table_i][opt_j] + if dp[table_i][new_mem_i][0] > new_perf: + dp[table_i][new_mem_i] = (new_perf, new_mem) + backtrack[table_i][new_mem_i] = (opt_j, mem) + self._proposal_list = [] + # fill in all the proposals, starting from highest mem to lowest mem + for c in range(bin_count - 1, -1, -1): + cur_opt_idx, cur_mem_idx = backtrack[table_count - 1][c] + if cur_opt_idx >= 0: + proposal_indices = [-1] * table_count + 
proposal_indices[table_count - 1] = cur_opt_idx + for i in range(table_count - 2, -1, -1): + proposal_indices[i], cur_mem_idx = backtrack[i][cur_mem_idx] + self._proposal_list.append(proposal_indices) + if len(self._proposal_list) > 0: self._current_proposal = 0 else: self._current_proposal += 1 - if self._current_proposal >= len(self._proposal_indices): + if self._current_proposal >= len(self._proposal_list): self._current_proposal = -1