Skip to content

Commit

Permalink
add sac part to fsdp
Browse files Browse the repository at this point in the history
  • Loading branch information
xuanzhang816 committed Jul 25, 2024
1 parent aa0bfe5 commit f0c88ce
Showing 1 changed file with 108 additions and 22 deletions.
130 changes: 108 additions & 22 deletions fsdp_sac_ilp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
Follow instructions here: https://ergo-code.github.io/HiGHS/dev/interfaces/cpp/
Some example commands to run:
python fsdp_ilp.py --in_file=GPT_modules_info.json --memory_budget=3
python fsdp_ilp.py --in_file=GPT_modules_info.json --memory_budget=4 --verbose
python fsdp_ilp.py --in_file=GPT_modules_info.json --memory_budget=4 --verbose \
python fsdp_sac_ilp.py --in_file=GPT_modules_info.json --memory_budget=3
python fsdp_sac_ilp.py --in_file=GPT_modules_info.json --memory_budget=4 --verbose
python fsdp_sac_ilp.py --in_file=GPT_modules_info.json --memory_budget=4 --verbose \
--fsdp_units GPT.transformer.h.0 GPT.transformer.h.1 GPT.transformer.h.2 \
GPT.transformer.h.3 GPT.transformer.h.4 GPT.transformer.h.5
"""

import argparse
import json
import logging
import time
from dataclasses import dataclass
Expand Down Expand Up @@ -79,7 +80,7 @@ def fsdp_milp(
MEM_MULTIPLIER = 2**30

# Create a MILP problem
prob = LpProblem("FSDP", LpMinimize)
prob = LpProblem("FSDP_SAC", LpMinimize)

# Create decision variables
x = LpVariable.matrix("x", list(range(num_nodes)), 0, 1, LpInteger)
Expand All @@ -102,6 +103,12 @@ def fsdp_milp(
t4 = LpVariable.matrix("t4", list(range(num_nodes)), 0)
bw_e = LpVariable.matrix("bw_e", list(range(num_nodes)), 0)

y = LpVariable.matrix("y", list(range(num_nodes)), 0, 1, LpInteger)
r = LpVariable.matrix("r", list(range(num_nodes)), 0, 1)
d = LpVariable.matrix("d", list(range(num_nodes)), 0)
rcp = LpVariable.matrix("rcp", list(range(num_nodes)), 0)
rct = LpVariable.matrix("rct", list(range(num_nodes)), 0)

# Add constraints
P_1 = graph.nodes[0]["param_per_module"] / MEM_MULTIPLIER
G_1 = graph.nodes[0]["grad_per_module"] / MEM_MULTIPLIER
Expand All @@ -122,6 +129,23 @@ def fsdp_milp(
if graph.ad_matrix[i][j] == 1:
prob += x[i] + x[j] <= 1

# [Constraint] No nested AC units
for i in range(num_nodes):
for j in range(i + 1, num_nodes):
if graph.ad_matrix[i][j] == 1:
prob += y[i] + y[j] <= 1

# [Constraint] Do not AC leaf modules
for i in range(num_nodes):
if graph.nodes[i]["is_leaf"]:
prob += y[i] == 0

# [Constraint] Composability between FSDP and AC
for i in range(num_nodes):
for j in range(i + 1, num_nodes):
if graph.ad_matrix[i][j] == 1:
prob += y[i] + x[j] <= 1

# [Constraint] Express parameter taken care of by each module for FSDP
for i in range(1, num_nodes):
P_i = graph.nodes[i]["param_per_module"] / MEM_MULTIPLIER
Expand All @@ -145,11 +169,34 @@ def fsdp_milp(
m[i] == (P_1 + TG_i) / world_size + lpDot(p, coeff) + lpDot(g, coeff) + a[i]
)

# [Constraint] Express amount of discarded activation memory
for i in range(num_nodes):
ACM_i = graph.nodes[i]["ac_memory"] / MEM_MULTIPLIER
IA_i = graph.nodes[i]["act_fw_per_module"] / MEM_MULTIPLIER
prob += d[i] == ACM_i * r[i] - (ACM_i - IA_i) * y[i]

# [Constraint] Express total activation memory in the backward pass
for i in range(num_nodes):
AG_i = graph.nodes[i]["act_grad_per_module"] / MEM_MULTIPLIER
TA_i = graph.nodes[i]["act_total"] / MEM_MULTIPLIER
prob += a[i] == TA_i + AG_i
ACM_i = graph.nodes[i]["ac_memory"] / MEM_MULTIPLIER
IA_i = graph.nodes[i]["act_fw_per_module"] / MEM_MULTIPLIER
# related to discarded amount of memory
pos = graph.nodes[i]["pos_fw_post_order"]
coeff = np.zeros(num_nodes)
for k in range(pos):
j = graph.name2node[graph.fw_post_order[k]]["index"]
coeff[j] = 1
prob += a[i] + lpDot(coeff, d) == TA_i + AG_i

# [Constraint] Ensure correctness of r_i
for i in range(num_nodes):
prob += y[i] >= r[i]
if graph.nodes[i]["is_leaf"]:
continue
ACM_i = graph.nodes[i]["ac_memory"] / MEM_MULTIPLIER
IA_i = graph.nodes[i]["act_fw_per_module"] / MEM_MULTIPLIER
prob += r[i] >= (ACM_i - IA_i) / ACM_i * y[i]

# [Constraint] Express peak memory
for i in range(num_nodes):
Expand All @@ -162,6 +209,20 @@ def fsdp_milp(
# [Constraint] Respect memory budget
prob += max_m + 2 * max_p <= memory_budget

# [Constraint] Express percentage of recomputation time
for i in range(num_nodes):
for s in range(graph.nodes[i]["n_segments"]):
slope = graph.nodes[i]["slopes"][s]
intercept = graph.nodes[i]["intercepts"][s]
prob += rcp[i] - slope * r[i] >= intercept

# [Constraint] Express recomputation time rct_i = y_i * (rcp_i * ACT_i)
for i in range(num_nodes):
ACT_i = graph.nodes[i]["ac_runtime"]
prob += rct[i] <= BIG_M * y[i]
prob += rct[i] <= ACT_i * rcp[i]
prob += rct[i] >= ACT_i * rcp[i] - BIG_M * (1 - y[i])

# [Constraint] Express the all gather communication time of each FSDP unit
comm_model = comm_params["all_gather"]
for i in range(num_nodes):
Expand Down Expand Up @@ -233,14 +294,22 @@ def fsdp_milp(
# [Constraint] Express the exposed computation time in the backward pass
for i in range(1, num_nodes):
BCP_i = graph.nodes[i]["bw_runtime_per_module"]
prob += t4[i] >= bw_ag[i] + bw_rs[i] - BCP_i
prob += t4[i] >= bw_ag[i] + bw_rs[i] - BCP_i - rct[i]
prob += bw_e[i] <= BIG_M * x[i]
prob += bw_e[i] <= t4[i]
prob += bw_e[i] >= t4[i] - BIG_M * (1 - x[i])
prob += bw_e[0] == 0

# Set Objective
prob += lpSum(fw_e[1:]) + lpSum(bw_e[1:]) + ag[0] + rs[0] + fw_ag[0] + bw_rs[0]
prob += (
lpSum(fw_e[1:])
+ lpSum(bw_e[1:])
+ ag[0]
+ rs[0]
+ fw_ag[0]
+ bw_rs[0]
+ lpSum(rct)
)

# Solve
start_time = time.time()
Expand All @@ -252,28 +321,36 @@ def fsdp_milp(
return

# Print solution
ac_decisions = {}
for i in range(num_nodes):
if round(y[i].varValue) == 1:
ac_decisions[graph.nodes[i]["fqn"]] = round(r[i].varValue, 4)
logger.info(f"AC decisions are {json.dumps(ac_decisions, indent=2)}")
fsdp_decisions = set()
for i in range(num_nodes):
if round(value(x[i]) if x[i] else 0) == 1:
fsdp_decisions.add(graph.nodes[i]["fqn"])
peak_mem = (max_m.varValue + 2 * max_p.varValue) * MEM_MULTIPLIER
obj = round(value(prob.objective), 4)

logger.info(
f"On {world_size} GPUs\n"
+ f" FSDP units are {fsdp_decisions}\n"
+ f" peak memory is {display_bytes(peak_mem, 'GiB')}\n"
+ f" total exposed computation time is {obj} ms"
+ f" total exposed computation time + recomputation time is {obj} ms"
)

if verbose:
logger.info("\n\n --------- DETAILS ---------")
for i in range(num_nodes):
x_i = value(x[i]) if x[i] else 0
y_i = value(y[i]) if y[i] else 0
p_i = p[i].varValue * MEM_MULTIPLIER
g_i = g[i].varValue * MEM_MULTIPLIER
a_i = a[i].varValue * MEM_MULTIPLIER
m_i = m[i].varValue * MEM_MULTIPLIER
y_i = y[i].varValue
r_i = r[i].varValue
d_i = d[i].varValue * MEM_MULTIPLIER
ag_i = ag[i].varValue if ag[i] else 0
fw_ag_i = fw_ag[i].varValue if fw_ag[i] else 0
bw_ag_i = bw_ag[i].varValue if bw_ag[i] else 0
Expand All @@ -283,22 +360,31 @@ def fsdp_milp(
BCP_i = graph.nodes[i]["bw_runtime_per_module"]
fw_e_i = fw_e[i].varValue if fw_e[i] else 0
bw_e_i = bw_e[i].varValue if bw_e[i] else 0
rcp_i = rcp[i].varValue if rcp[i].varValue else 0
rct_i = rct[i].varValue
logger.info(
("FSDP" if round(x_i) == 1 else " ")
+ (" AC" if round(y_i) == 1 else " ")
+ f" {graph.nodes[i]['fqn']:<40}: "
+ f"p_i = {display_bytes(p_i, 'GiB'):<10} "
+ f"g_i = {display_bytes(g_i, 'GiB'):<10} "
+ f"a_i = {display_bytes(a_i, 'GiB'):<10} "
+ f"m_i = {display_bytes(m_i, 'GiB'):<10} "
+ f"ag_i = {round(ag_i, 2):5.2f} ms "
+ f"fw_ag_i = {round(fw_ag_i, 2):5.2f} ms "
+ f"bw_ag_i = {round(bw_ag_i, 2):5.2f} ms "
+ f"rs_i = {round(rs_i, 2):5.2f} ms "
+ f"bw_rs_i = {round(bw_rs_i, 2):5.2f} ms "
+ f"FCP_i = {FCP_i:8.2f} ms "
+ f"BCP_i = {BCP_i:8.2f} ms "
+ f"fw_e_i = {round(fw_e_i, 2):5.2f} ms "
+ f"bw_e_i = {round(bw_e_i, 2):5.2f} ms "
+ f" p_i = {display_bytes(p_i, 'GiB'):>10} "
+ f" g_i = {display_bytes(g_i, 'GiB'):>10} "
+ f" a_i = {display_bytes(a_i, 'GiB'):>10} "
+ f" d_i = {display_bytes(d_i, 'GiB'):>10} "
+ f" r_i = {round(r_i, 2):6.2f} "
+ f" m_i = {display_bytes(m_i, 'GiB'):>10} \n"
+ " " * 50
+ f" ag_i = {round(ag_i, 2):6.2f} ms "
+ f" fw_ag_i = {round(fw_ag_i, 2):6.2f} ms "
+ f" bw_ag_i = {round(bw_ag_i, 2):6.2f} ms "
+ f" rs_i = {round(rs_i, 2):6.2f} ms "
+ f" bw_rs_i = {round(bw_rs_i, 2):6.2f} ms \n"
+ " " * 50
+ f" FCP_i = {FCP_i:6.2f} ms "
+ f" BCP_i = {BCP_i:6.2f} ms "
+ f" rcp_i = {round(rcp_i, 2):6.2f} ms "
+ f" rct_i = {round(rct_i, 2):6.2f} ms "
+ f" fw_e_i = {round(fw_e_i, 2):6.2f} ms "
+ f" bw_e_i = {round(bw_e_i, 2):6.2f} ms "
)


Expand Down

0 comments on commit f0c88ce

Please sign in to comment.