Skip to content

Commit

Permalink
- Adding handling of jitted functions to graph scanner.
Browse files Browse the repository at this point in the history
- Better dead code elimination inside higher order primitives.

PiperOrigin-RevId: 661672100
  • Loading branch information
james-martens authored and KfacJaxDev committed Aug 11, 2024
1 parent 5b9dcc1 commit a6d01e7
Showing 1 changed file with 77 additions and 12 deletions.
89 changes: 77 additions & 12 deletions kfac_jax/_src/tag_graph_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from kfac_jax._src import utils
import numpy as np

HIGHER_ORDER_NAMES = ("cond", "while", "scan", "xla_call", "xla_pmap")
HIGHER_ORDER_NAMES = ("cond", "while", "scan", "pjit", "xla_call", "xla_pmap")

# Types for annotation
Array = utils.Array
Expand Down Expand Up @@ -385,13 +385,18 @@ def make_jax_graph(
for v in closed_jaxpr.jaxpr.outvars] # pytype:disable=attribute-error

if clean_broadcasts:
# closed_jaxpr = clean_jaxpr(closed_jaxpr)
closed_jaxpr = merge_broadcasts_jaxpr(closed_jaxpr)

in_vars = jax.tree_util.tree_unflatten(in_tree, closed_jaxpr.jaxpr.invars)

if isinstance(params_index, int):
params_vars = in_vars[params_index]
else:
params_vars = tuple(in_vars[i] for i in params_index)

params_vars, params_tree = jax.tree_util.tree_flatten(params_vars)

return JaxprGraph(
name=name,
closed_jaxpr=closed_jaxpr,
Expand Down Expand Up @@ -820,6 +825,8 @@ def read_env(
if isinstance(v, jax.core.Literal):
# Literals are values baked into the Jaxpr
result.append(v.val)
elif isinstance(v, jax.core.DropVar):
result.append(None)
else:
result.append(env[v])
return result
Expand Down Expand Up @@ -850,81 +857,133 @@ def to_jaxpr_or_closed_jaxpr(closed_jaxpr: ClosedJaxpr, original: J) -> J:

def apply_to_higher_order_primitives(eqn, func, *args, **kwargs):
  """Applies `func` only to higher order Jax primitives.

  Equations whose primitive name is not listed in `HIGHER_ORDER_NAMES` are
  returned unchanged. Otherwise `func` is applied to the nested Jaxpr(s) stored
  in the equation parameters and a copy of the equation carrying the
  transformed parameters is returned.

  Args:
    eqn: The Jaxpr equation to process.
    func: Callable invoked as `func(sub_jaxpr, *args, **kwargs)` on every nested
      Jaxpr of a higher order equation. Its result replaces that nested Jaxpr.
    *args: Extra positional arguments forwarded to `func`.
    **kwargs: Extra keyword arguments forwarded to `func`.

  Returns:
    `eqn` itself for ordinary primitives, otherwise `eqn.replace(params=...)`
    with the transformed nested Jaxpr(s).

  Raises:
    NotImplementedError: If the primitive is listed in `HIGHER_ORDER_NAMES` but
      no handling for it is implemented here.
  """

  name = eqn.primitive.name

  if name not in HIGHER_ORDER_NAMES:
    return eqn

  # Name of the parameter holding the single nested Jaxpr of each supported
  # primitive. Note that for "while" only the loop body is transformed.
  # "cond" is handled separately below since it holds a tuple of branches.
  single_jaxpr_param = {
      "while": "body_jaxpr",
      "scan": "jaxpr",
      "pjit": "jaxpr",
      "xla_call": "call_jaxpr",
      "xla_pmap": "call_jaxpr",
  }

  if name != "cond" and name not in single_jaxpr_param:
    raise NotImplementedError(
        f"No handling implemented for higher order primitive {name!r}."
    )

  # Work on a shallow copy so the parameters of the original equation are left
  # untouched.
  params = dict(**eqn.params)

  if name == "cond":
    params["branches"] = tuple(
        func(branch, *args, **kwargs) for branch in params["branches"]
    )
  else:
    key = single_jaxpr_param[name]
    params[key] = func(params[key], *args, **kwargs)

  return eqn.replace(params=params)


def clean_jaxpr(
    jaxpr: J,
    preserve_tags: bool = True,
    outvar_is_dep: tuple[bool, ...] | None = None,
) -> J:
  """Runs dead code elimination on a Jaxpr, retaining loss and layer tags.

  Equations are visited in reverse order. An equation is kept if at least one
  of its outputs is still needed, or if it is a loss/layer tag and
  `preserve_tags` is set. The non-literal inputs of every kept equation become
  needed in turn. Higher order primitives are cleaned recursively via
  `apply_to_higher_order_primitives`, forwarding which of their outputs are
  needed as the nested call's `outvar_is_dep`.

  Args:
    jaxpr: The Jaxpr (closed or not) to clean.
    preserve_tags: Whether loss and layer tag equations are kept even when none
      of their outputs is needed.
    outvar_is_dep: Optional flags, one per output variable of `jaxpr`, marking
      which outputs are needed. If `None`, every non-literal output is treated
      as needed.

  Returns:
    The cleaned Jaxpr, converted by `to_jaxpr_or_closed_jaxpr` with respect to
    the original `jaxpr`.

  Raises:
    ValueError: If `outvar_is_dep` does not have one entry per output variable,
      or if after the elimination a needed variable remains that is neither an
      input, a constant, nor an output of the Jaxpr.
  """

  closed_jaxpr = to_closed_jaxpr(jaxpr)
  eqns = []

  non_literal_outvars = {
      var for var in closed_jaxpr.jaxpr.outvars
      if not isinstance(var, jax.core.Literal)
  }

  if outvar_is_dep is not None:
    # Literals are excluded here for consistency with every other place in this
    # function where variables are added to `dependants`: they are produced by
    # no equation, so they could never be resolved.
    dependants = {
        var for var, is_dep in zip(
            closed_jaxpr.jaxpr.outvars, outvar_is_dep, strict=True)
        if is_dep and not isinstance(var, jax.core.Literal)
    }
  else:
    # Copy, so that the in-place removals below do not also shrink
    # `non_literal_outvars`, which is needed intact for the final checks.
    dependants = set(non_literal_outvars)

  for eqn in reversed(closed_jaxpr.jaxpr.eqns):

    # NOTE(review): nested Jaxprs are pruned using only which outputs the
    # enclosing Jaxpr needs. For loop primitives an output may also feed the
    # next iteration as a carry - confirm that this is intended.
    eqn = apply_to_higher_order_primitives(
        eqn,
        functools.partial(
            clean_jaxpr,
            outvar_is_dep=tuple(var in dependants for var in eqn.outvars)),
        preserve_tags=preserve_tags,
    )

    # The equation is live if it produces at least one needed variable. Those
    # variables are now accounted for.
    produced = dependants.intersection(eqn.outvars)
    dependants -= produced
    check = bool(produced)

    if isinstance(eqn.primitive, (tags.LossTag, tags.LayerTag)):
      check = check or preserve_tags

    if check:
      eqns.append(eqn)
      dependants |= {
          v for v in eqn.invars if not isinstance(v, jax.core.Literal)
      }

  # Dependants should only be invars
  dependants -= set(closed_jaxpr.jaxpr.invars + closed_jaxpr.jaxpr.constvars)

  if dependants - non_literal_outvars:
    raise ValueError("Something went wrong with the dead code elimination.")

  # Because trying to eliminate output variables from higher order primitives
  # is too much work, we instead just replace them with zeros, which is
  # moderately hacky.
  # TODO(jamesmartens,botev): Do something better here.
  final_outvars = []
  for var in closed_jaxpr.jaxpr.outvars:

    if not isinstance(var, jax.core.Literal) and var in dependants:
      # A needed output that no remaining equation produces: substitute a
      # zero-valued literal of matching shape and dtype.
      assert isinstance(var.aval, jax.core.ShapedArray)
      val = np.zeros(var.aval.shape, dtype=var.aval.dtype)
      dummy_literal = jax.core.Literal(val, var.aval)
      final_outvars.append(dummy_literal)

    else:
      final_outvars.append(var)

  # `eqns` was collected back-to-front, so restore the original order.
  closed_jaxpr = ClosedJaxpr(
      jaxpr=closed_jaxpr.jaxpr.replace(eqns=list(reversed(eqns)),
                                       outvars=final_outvars),
      consts=closed_jaxpr.consts,
  )
  return to_jaxpr_or_closed_jaxpr(closed_jaxpr, jaxpr)


def merge_broadcasts_jaxpr(jaxpr: J) -> J:
"""Merges consecutive broadcasts in the given Jaxpr."""
closed_jaxpr = clean_jaxpr(to_closed_jaxpr(jaxpr))

closed_jaxpr = to_closed_jaxpr(jaxpr)
closed_jaxpr = clean_jaxpr(closed_jaxpr)

broadcasts_outputs = {}
eqns = list()

for eqn in closed_jaxpr.jaxpr.eqns:

eqn = apply_to_higher_order_primitives(eqn, merge_broadcasts_jaxpr)

# We ignore broadcasting of constants
if (eqn.primitive.name == "broadcast_in_dim" and
not all(isinstance(v, jax.core.Literal) for v in eqn.invars)):

if eqn.invars[0] in broadcasts_outputs:
# Construct a merged equation from the previous and current one
prev_eqn = broadcasts_outputs[eqn.invars[0]]

broadcasts_outputs[eqn.outvars[0]] = prev_eqn.replace(
params={
"shape": eqn.params["shape"],
Expand All @@ -935,16 +994,21 @@ def merge_broadcasts_jaxpr(jaxpr: J) -> J:
},
outvars=eqn.outvars,
)

else:
broadcasts_outputs[eqn.outvars[0]] = eqn

if eqn.outvars[0] in closed_jaxpr.jaxpr.outvars:
# We must preserve output equations
eqns.append(broadcasts_outputs[eqn.outvars[0]])

else:
for v in eqn.invars:
if not isinstance(v, jax.core.Literal) and v in broadcasts_outputs:
eqns.append(broadcasts_outputs[v])

eqns.append(eqn)

closed_jaxpr = ClosedJaxpr(
jaxpr=closed_jaxpr.jaxpr.replace(eqns=eqns),
consts=closed_jaxpr.consts
Expand Down Expand Up @@ -1278,7 +1342,7 @@ def full_name(self) -> str:
prefix = prefix + f"{eqn.primitive.name}_{n}/"
if eqn.primitive.name == "cond":
raise NotImplementedError()
elif eqn.primitive.name == "scan":
elif eqn.primitive.name == "scan": # do we need to add "pjit" here too???
p_indexes = [eqn.params["jaxpr"].jaxpr.invars.index(p)
for p in param_vars]
checks = [pi < eqn.params["num_consts"] for pi in p_indexes]
Expand Down Expand Up @@ -1313,7 +1377,7 @@ def top_level_parameters(self) -> tuple[Var, ...]:
assert eqn.primitive.name in HIGHER_ORDER_NAMES
if eqn.primitive.name == "cond":
raise NotImplementedError()
elif eqn.primitive.name == "scan":
elif eqn.primitive.name in ("scan", "pjit"):
invars = eqn.params["jaxpr"].jaxpr.invars
elif eqn.primitive.name == "while":
invars = eqn.params["body_jaxpr"].jaxpr.invars
Expand Down Expand Up @@ -1393,6 +1457,7 @@ def _auto_register_tags(
"cond": 0,
"while": 0,
"scan": 0,
"pjit": 0,
"xla_call": 0,
"xla_pmap": 0,
}
Expand Down Expand Up @@ -1425,7 +1490,7 @@ def _auto_register_tags(
sub_jaxprs = eqn.params["branches"]
elif eqn_name == "while":
sub_jaxprs = [eqn.params["body_jaxpr"]]
elif eqn_name == "scan":
elif eqn_name in ("scan", "pjit"):
sub_jaxprs = [eqn.params["jaxpr"]]
elif eqn_name in ("xla_call", "xla_pmap"):
sub_jaxprs = [eqn.params["call_jaxpr"]]
Expand Down Expand Up @@ -1471,7 +1536,7 @@ def _auto_register_tags(
eqn_params["branches"] = final_jaxprs
elif eqn_name == "while":
[eqn_params["body_jaxpr"]] = final_jaxprs # pylint:disable=unbalanced-tuple-unpacking
elif eqn_name == "scan":
elif eqn_name in ("scan", "pjit"):
[eqn_params["jaxpr"]] = final_jaxprs # pylint:disable=unbalanced-tuple-unpacking
elif eqn_name in ("xla_call", "xla_pmap"):
[eqn_params["call_jaxpr"]] = final_jaxprs # pylint:disable=unbalanced-tuple-unpacking
Expand Down

0 comments on commit a6d01e7

Please sign in to comment.