add train/eval/export on cpu support
tiankongdeguiji committed Nov 6, 2024
1 parent 1f9b401 commit 8409415
Showing 2 changed files with 124 additions and 60 deletions.
10 changes: 7 additions & 3 deletions tzrec/main.py
@@ -27,7 +27,6 @@
 from torch.utils.tensorboard import SummaryWriter
 from torchrec.distributed.model_parallel import (
     DistributedModelParallel,
-    get_default_sharders,
 )

 # NOQA
@@ -80,7 +79,7 @@
 from tzrec.protos.train_pb2 import TrainConfig
 from tzrec.utils import checkpoint_util, config_util
 from tzrec.utils.logging_util import ProgressLogger, logger
-from tzrec.utils.plan_util import create_planner
+from tzrec.utils.plan_util import create_planner, get_default_sharders
 from tzrec.version import __version__ as tzrec_version


@@ -547,13 +546,18 @@ def train_and_evaluate(
         # pyre-ignore [16]
         batch_size=train_dataloader.dataset.sampled_batch_size,
     )
+
     plan = planner.collective_plan(
         model, get_default_sharders(), dist.GroupMember.WORLD
     )
     if is_rank_zero:
         logger.info(str(plan))

-    model = DistributedModelParallel(module=model, device=device, plan=plan)
+    model = DistributedModelParallel(
+        module=model,
+        device=device,
+        plan=plan,
+    )

     dense_optim_cls, dense_optim_kwargs = optimizer_builder.create_dense_optimizer(
         pipeline_config.train_config.dense_optimizer
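Taken together, the main.py changes make the process group and target device follow CUDA availability and feed an explicitly computed plan into DistributedModelParallel. Below is a minimal sketch of that flow on a CPU-only host; it is not part of this commit, and it assumes a gloo process group plus a made-up single-table EmbeddingBagCollection in place of the tzrec model:

import torch
from torch import distributed as dist
from torchrec import EmbeddingBagCollection, EmbeddingBagConfig
from torchrec.distributed.model_parallel import DistributedModelParallel

from tzrec.utils.plan_util import create_planner, get_default_sharders

# nccl requires GPUs; fall back to gloo when running on CPU
dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# toy single-table model; tzrec builds its model from the pipeline config instead
model = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="t1", embedding_dim=16, num_embeddings=1000, feature_names=["f1"]
        )
    ],
    device=torch.device("meta"),
)

planner = create_planner(device=device, batch_size=8192)  # batch size made up
plan = planner.collective_plan(model, get_default_sharders(), dist.GroupMember.WORLD)
model = DistributedModelParallel(module=model, device=device, plan=plan)

Launched with torchrun (so RANK, WORLD_SIZE, and MASTER_ADDR are set), the same script should run on both GPU and CPU hosts; only the backend and device differ.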
174 changes: 117 additions & 57 deletions tzrec/utils/plan_util.py
@@ -16,6 +16,7 @@
 import psutil
 import torch
 from torch import distributed as dist
+from torch import nn
 from torchrec.distributed.comm import get_local_size
 from torchrec.distributed.planner import EmbeddingShardingPlanner
 from torchrec.distributed.planner.proposers import UniformProposer
@@ -28,6 +29,10 @@
     ShardingOption,
     Topology,
 )
+from torchrec.distributed.sharding_plan import ModuleSharder
+from torchrec.distributed.sharding_plan import (
+    get_default_sharders as _get_default_sharders,
+)


 def _bytes_to_float_bin(num_bytes: Union[float, int], bin_size: float) -> float:
@@ -42,6 +47,7 @@ def create_planner(device: torch.device, batch_size: int) -> EmbeddingShardingPlanner:
     topo_kwargs = {}
     if torch.cuda.is_available():
         topo_kwargs["hbm_cap"] = torch.cuda.get_device_properties(device).total_memory
+
     ddr_cap_per_rank = int(float(psutil.virtual_memory().total) / local_world_size)
     topo_kwargs["ddr_cap"] = ddr_cap_per_rank
     if "INTRA_NODE_BANDWIDTH" in os.environ:
@@ -76,13 +82,26 @@
     return planner


+def get_default_sharders() -> List[ModuleSharder[nn.Module]]:
+    """Get embedding module default sharders."""
+    if torch.cuda.is_available():
+        return _get_default_sharders()
+    else:
+        # ShardedEmbeddingCollection is not supported on CPU yet.
+        sharders = []
+        for sharder in _get_default_sharders():
+            if "EmbeddingCollection" not in str(sharder):
+                sharders.append(sharder)
+        return sharders
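The CPU branch filters by the sharder's string form rather than by type: "EmbeddingCollection" is not a substring of "EmbeddingBagCollectionSharder", so EmbeddingBagCollection sharders survive while EmbeddingCollection sharders are dropped. A quick way to see the effect on a CPU-only host (exact sharder names depend on the installed torchrec version):

from tzrec.utils.plan_util import get_default_sharders

# e.g. keeps EmbeddingBagCollectionSharder, drops EmbeddingCollectionSharder on CPU
print([type(s).__name__ for s in get_default_sharders()])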


 class DynamicProgrammingProposer(Proposer):
     r"""Proposes sharding plans in dynamic programming fashion.

     The problem of the Embedding Sharding Plan can be framed as follows: Given
     :math:`M` tables and their corresponding :math:`N` Sharding Options, we need to
     select one sharding option for each table such that the total performance is
-    minimized, while keeping the overall HBM constraint :math:`K` in check. This can
+    minimized, while keeping the overall mem constraint :math:`K` in check. This can
     be abstracted into the following mathematical formulation:

     Given a matrix :math:`A` of dimensions :math:`(M, N)` and another matrix :math:`B`
@@ -106,24 +125,29 @@ class DynamicProgrammingProposer(Proposer):
     .. math::

         dp[i][f(k)] = \min_{j=0}^{N-1} \left( dp[i-1][f(k - A[i][j])] + B[i][j] \right)

-    Since :math:`K` is the sum allocated across all HBM, simply satisfying that the
-    total HBM in the plan equals :math:`K` does not guarantee that the allocation will
+    Since :math:`K` is the sum allocated across all mem, simply satisfying that the
+    total mem in the plan equals :math:`K` does not guarantee that the allocation will
     fit on all cards. Therefore, it is essential to maintain all the states of the last
     layer of :math:`dp`. This allows us to propose different plans under varying total
-    HBM constraints.
+    mem constraints.

     Args:
-        hbm_bins_per_device (int): hbm bins for dynamic programming precision.
+        mem_bins_per_device (int): mem bins for dynamic programming precision.
     """

-    def __init__(self, hbm_bins_per_device: int = 100) -> None:
+    def __init__(self, mem_bins_per_device: int = 100) -> None:
         self._inited: bool = False
-        self._hbm_bins_per_device: int = max(hbm_bins_per_device, 1)
+        self._mem_bins_per_device: int = max(mem_bins_per_device, 1)
         self._sharding_options_by_fqn: OrderedDict[str, List[ShardingOption]] = (
             OrderedDict()
         )
-        self._proposal_indices: List[List[int]] = []
+        # list of proposals with different total_mem, a proposal is a list of
+        # indices of sharding_options
+        self._proposal_list: List[List[int]] = []
         self._current_proposal: int = -1
+        self._plan_by_hbm = True
+        if not torch.cuda.is_available():
+            self._plan_by_hbm = False

     def load(
         self,
@@ -132,15 +156,21 @@ def load(
     ) -> None:
         """Load search space."""
         self._reset()
-        for sharding_option in sorted(search_space, key=lambda x: x.total_storage.hbm):
+        # order the sharding_option by total storage (hbm or ddr) from low to high
+        for sharding_option in sorted(
+            search_space,
+            key=lambda x: x.total_storage.hbm
+            if self._plan_by_hbm
+            else x.total_storage.ddr,
+        ):
             fqn = sharding_option.fqn
             if fqn not in self._sharding_options_by_fqn:
                 self._sharding_options_by_fqn[fqn] = []
             self._sharding_options_by_fqn[fqn].append(sharding_option)

     def _reset(self) -> None:
         self._sharding_options_by_fqn = OrderedDict()
-        self._proposal_indices = []
+        self._proposal_list = []
         self._current_proposal = -1

     def propose(self) -> Optional[List[ShardingOption]]:
@@ -151,7 +181,7 @@ def propose(self) -> Optional[List[ShardingOption]]:
                 for sharding_options in self._sharding_options_by_fqn.values()
             ]
         elif self._current_proposal >= 0:
-            proposal_index = self._proposal_indices[self._current_proposal]
+            proposal_index = self._proposal_list[self._current_proposal]
             return [
                 self._sharding_options_by_fqn[fqn][index]
                 for fqn, index in zip(
@@ -171,60 +201,90 @@ def feedback(
"""Feedback last proposed plan."""
if not self._inited:
self._inited = True
M = len(self._sharding_options_by_fqn)
N = max([len(x) for x in self._sharding_options_by_fqn.values()])
table_count = len(self._sharding_options_by_fqn)
option_count = max([len(x) for x in self._sharding_options_by_fqn.values()])

assert storage_constraint is not None
hbm_total = sum([x.storage.hbm for x in storage_constraint.devices])
K = self._hbm_bins_per_device * len(storage_constraint.devices)
bin_size = float(hbm_total) / K
# are we assuming the table will be evenly sharded on all devices?
mem_total = sum(
[
x.storage.hbm if self._plan_by_hbm else x.storage.ddr
for x in storage_constraint.devices
]
)

dp = [[(float("inf"), float("inf"))] * K for _ in range(M)]
backtrack = [[(-1, -1)] * K for _ in range(M)]
bin_count = self._mem_bins_per_device * len(storage_constraint.devices)
bin_size = float(mem_total) / bin_count

hbm_by_fqn = [[float("inf") for _ in range(N)] for _ in range(M)]
perf_by_fqn = [[float("inf") for _ in range(N)] for _ in range(M)]
for m, sharding_options in enumerate(
dp = [
[(float("inf"), float("inf"))] * bin_count for _ in range(table_count)
] # [table_id][mem_bin][perf, mem]

backtrack = [
[(-1, -1)] * bin_count for _ in range(table_count)
] # [table_id][mem_bin][opt_id, prev_mem_bin]

mem_by_fqn = [
[float("inf") for _ in range(option_count)] for _ in range(table_count)
] # memory constraint lookup table: [table_id][sharding_option_id]
perf_by_fqn = [
[float("inf") for _ in range(option_count)] for _ in range(table_count)
] # performance metrics lookup table: [table_id][sharding_option_id]

# populate mem and perf for each sharding option and table:
# A[table_id][sharding_option_id]
for table_id, sharding_options in enumerate(
self._sharding_options_by_fqn.values()
):
for n, sharding_option in enumerate(sharding_options):
hbm_by_fqn[m][n] = _bytes_to_float_bin(
sharding_option.total_storage.hbm, bin_size
for opt_id, sharding_option in enumerate(sharding_options):
mem_by_fqn[table_id][opt_id] = _bytes_to_float_bin(
sharding_option.total_storage.hbm
if self._plan_by_hbm
else sharding_option.total_storage.ddr,
bin_size,
)
perf_by_fqn[m][n] = sharding_option.total_perf

for j in range(N):
if hbm_by_fqn[0][j] < K:
hbm_i = int(hbm_by_fqn[0][j])
if dp[0][hbm_i][0] > perf_by_fqn[0][j]:
dp[0][hbm_i] = (perf_by_fqn[0][j], hbm_by_fqn[0][j])
backtrack[0][hbm_i] = (j, -1)

for i in range(1, M):
for j in range(N):
for c in range(K):
prev_perf, perv_hbm = dp[i - 1][c]
perf_by_fqn[table_id][opt_id] = sharding_option.total_perf

table_0 = 0
for opt_j in range(option_count):
if mem_by_fqn[0][opt_j] < bin_count:
mem_i = int(mem_by_fqn[0][opt_j])
# options are ordered in increasing order of mem, we only want to
# consider a sharding option that has higher mem and better perf
# (the smaller the better)
if dp[table_0][mem_i][0] > perf_by_fqn[table_0][opt_j]:
dp[table_0][mem_i] = (
perf_by_fqn[table_0][opt_j],
mem_by_fqn[table_0][opt_j],
)
backtrack[table_0][mem_i] = (opt_j, -1)

# dp: table_count x option_count x bin_count
for table_i in range(1, table_count):
for opt_j in range(option_count):
for mem in range(bin_count):
prev_perf, perv_mem = dp[table_i - 1][mem]
if prev_perf < float("inf"):
new_hbm = perv_hbm + hbm_by_fqn[i][j]
if new_hbm < K:
new_hbm_i = int(new_hbm)
new_perf = prev_perf + perf_by_fqn[i][j]
if dp[i][new_hbm_i][0] > new_perf:
dp[i][new_hbm_i] = (new_perf, new_hbm)
backtrack[i][new_hbm_i] = (j, c)

self._proposal_indices = []
for c in range(K - 1, -1, -1):
cur_col_idx, cur_hbm_idx = backtrack[M - 1][c]
if cur_col_idx >= 0:
column_indices = [-1] * M
column_indices[M - 1] = cur_col_idx
for i in range(M - 2, -1, -1):
column_indices[i], cur_hbm_idx = backtrack[i][cur_hbm_idx]
self._proposal_indices.append(column_indices)
if len(self._proposal_indices) > 0:
new_mem = perv_mem + mem_by_fqn[table_i][opt_j]
if new_mem < bin_count:
new_mem_i = int(new_mem)
new_perf = prev_perf + perf_by_fqn[table_i][opt_j]
if dp[table_i][new_mem_i][0] > new_perf:
dp[table_i][new_mem_i] = (new_perf, new_mem)
backtrack[table_i][new_mem_i] = (opt_j, mem)
self._proposal_list = []
# fill in all the proposals, starting from highest mem to lowest mem
for c in range(bin_count - 1, -1, -1):
cur_opt_idx, cur_mem_idx = backtrack[table_count - 1][c]
if cur_opt_idx >= 0:
proposal_indices = [-1] * table_count
proposal_indices[table_count - 1] = cur_opt_idx
for i in range(table_count - 2, -1, -1):
proposal_indices[i], cur_mem_idx = backtrack[i][cur_mem_idx]
self._proposal_list.append(proposal_indices)
if len(self._proposal_list) > 0:
self._current_proposal = 0
else:
self._current_proposal += 1
if self._current_proposal >= len(self._proposal_indices):
if self._current_proposal >= len(self._proposal_list):
                 self._current_proposal = -1
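As a reading aid, here is a self-contained toy (made-up numbers, not tzrec code) of the dynamic program described in the DynamicProgrammingProposer docstring: pick one sharding option per table so that total perf is minimized under a binned memory budget, then backtrack one proposal per reachable memory bin, from the largest budget down:

from typing import List

def dp_propose(
    mem: List[List[float]], perf: List[List[float]], bin_count: int
) -> List[List[int]]:
    """Return option-index proposals, one per reachable total-memory bin."""
    table_count, option_count = len(mem), len(mem[0])
    inf = float("inf")
    dp = [[(inf, inf)] * bin_count for _ in range(table_count)]  # (perf, mem)
    backtrack = [[(-1, -1)] * bin_count for _ in range(table_count)]  # (opt, prev_bin)

    # seed the first layer with the first table's options
    for j in range(option_count):
        if mem[0][j] < bin_count:
            b = int(mem[0][j])
            if dp[0][b][0] > perf[0][j]:
                dp[0][b] = (perf[0][j], mem[0][j])
                backtrack[0][b] = (j, -1)

    # extend table by table, keeping the best perf per memory bin
    for i in range(1, table_count):
        for j in range(option_count):
            for c in range(bin_count):
                prev_perf, prev_mem = dp[i - 1][c]
                if prev_perf < inf and prev_mem + mem[i][j] < bin_count:
                    nb = int(prev_mem + mem[i][j])
                    if dp[i][nb][0] > prev_perf + perf[i][j]:
                        dp[i][nb] = (prev_perf + perf[i][j], prev_mem + mem[i][j])
                        backtrack[i][nb] = (j, c)

    # backtrack one proposal per reachable bin, highest memory budget first
    proposals = []
    for c in range(bin_count - 1, -1, -1):
        opt, prev = backtrack[table_count - 1][c]
        if opt >= 0:
            picks = [-1] * table_count
            picks[table_count - 1] = opt
            for i in range(table_count - 2, -1, -1):
                picks[i], prev = backtrack[i][prev]
            proposals.append(picks)
    return proposals

# two tables, two options each: option 0 is big but fast, option 1 small but slow
print(dp_propose(mem=[[3.0, 1.0], [3.0, 1.0]], perf=[[1.0, 4.0], [2.0, 5.0]], bin_count=8))
# -> [[0, 0], [0, 1], [1, 1]]: fastest plan first, lower-memory plans after

The planner then tries these proposals in turn until one actually partitions onto the devices, which is why all states of the last dp layer are kept rather than just the global minimum.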
