Skip to content

Commit

Permalink
- Adding handling of jitted functions to graph scanner.
Browse files Browse the repository at this point in the history
- Better dead code elimination inside higher order primitives (by tracking dependencies from outside to inside such primitives).

PiperOrigin-RevId: 666524373
  • Loading branch information
james-martens authored and KfacJaxDev committed Aug 22, 2024
1 parent f5c4e03 commit b59b188
Showing 1 changed file with 145 additions and 11 deletions.
156 changes: 145 additions & 11 deletions kfac_jax/_src/tag_graph_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
from kfac_jax._src import utils
import numpy as np

HIGHER_ORDER_NAMES = ("cond", "while", "scan", "xla_call", "xla_pmap")
HIGHER_ORDER_NAMES = ("cond", "while", "scan", "pjit", "xla_call", "xla_pmap")
ITERATIVE_HIGHER_ORDER_NAMES = ("while", "scan")

# Types for annotation
Array = utils.Array
Expand Down Expand Up @@ -331,28 +332,38 @@ def make_jax_graph(
tag_ctor: TagCtor | None = None,
) -> JaxprGraph:
"""Creates a :class:`~JaxGraph` instance from the provided function and arguments."""

in_tree = jax.tree_util.tree_structure(func_args)
closed_jaxpr, out_shapes = jax.make_jaxpr(func, return_shape=True)(*func_args)

if compute_only_loss_tags:

make_var_func = jax.core.gensym()
eqns = []
sub_graph_vars = set()
loss_tags_output_vars = []

for eqn in reversed(closed_jaxpr.jaxpr.eqns):

if (isinstance(eqn.primitive, tags.LossTag) or
any(v in sub_graph_vars for v in eqn.outvars)):

if isinstance(eqn.primitive, tags.LossTag):

new_out_vars = []
for v in eqn.outvars:

if isinstance(v, jax.core.DropVar):
new_out_vars.append(make_var_func(v.aval))
else:
new_out_vars.append(v)

loss_tags_output_vars.extend(new_out_vars[::-1])
eqns.append(eqn.replace(outvars=new_out_vars))

else:
eqns.append(eqn)

sub_graph_vars.update(
v for v in eqn.invars if not isinstance(v, jax.core.Literal)
)
Expand Down Expand Up @@ -385,13 +396,19 @@ def make_jax_graph(
for v in closed_jaxpr.jaxpr.outvars] # pytype:disable=attribute-error

if clean_broadcasts:
closed_jaxpr = clean_jaxpr(closed_jaxpr)
closed_jaxpr = merge_broadcasts_jaxpr(closed_jaxpr)
closed_jaxpr = clean_jaxpr(closed_jaxpr)

in_vars = jax.tree_util.tree_unflatten(in_tree, closed_jaxpr.jaxpr.invars)

if isinstance(params_index, int):
params_vars = in_vars[params_index]
else:
params_vars = tuple(in_vars[i] for i in params_index)

params_vars, params_tree = jax.tree_util.tree_flatten(params_vars)

return JaxprGraph(
name=name,
closed_jaxpr=closed_jaxpr,
Expand Down Expand Up @@ -529,7 +546,7 @@ def create_eqn(
env: dict[Var, Var],
make_var_func: MakeVarFunc,
) -> JaxprEqns:
"""Creates a new ``JaxprEqn`` for the this match."""
"""Creates a new ``JaxprEqn`` for this match."""
in_vars = [self.variables_map[k] for k in self.pattern.graph.jaxpr.invars]
in_vars = [env.get(v, v) for v in in_vars]
out_vars = [self.variables_map[k]
Expand Down Expand Up @@ -820,6 +837,8 @@ def read_env(
if isinstance(v, jax.core.Literal):
# Literals are values baked into the Jaxpr
result.append(v.val)
elif isinstance(v, jax.core.DropVar):
result.append(None)
else:
result.append(env[v])
return result
Expand Down Expand Up @@ -850,81 +869,162 @@ def to_jaxpr_or_closed_jaxpr(closed_jaxpr: ClosedJaxpr, original: J) -> J:

def apply_to_higher_order_primitives(eqn, func, *args, **kwargs):
"""Applies `func` only to higher order Jax primitives."""

if eqn.primitive.name not in HIGHER_ORDER_NAMES:
return eqn

elif eqn.primitive.name == "cond":
params = dict(**eqn.params)
params["branches"] = tuple(
func(branch, *args, **kwargs) for branch in params["branches"]
)
return eqn.replace(params=params)

elif eqn.primitive.name == "while":
params = dict(**eqn.params)
params["body_jaxpr"] = func(params["body_jaxpr"], *args, **kwargs)
return eqn.replace(params=params)
elif eqn.primitive.name == "scan":

elif eqn.primitive.name in ("scan", "pjit"):
params = dict(**eqn.params)
params["jaxpr"] = func(params["jaxpr"], *args, **kwargs)
return eqn.replace(params=params)

elif eqn.primitive.name in ("xla_call", "xla_pmap"):
params = dict(**eqn.params)
params["call_jaxpr"] = func(params["call_jaxpr"], *args, **kwargs)
return eqn.replace(params=params)

else:
raise NotImplementedError()


def clean_jaxpr(jaxpr: J, preserve_tags: bool = True) -> J:
def clean_jaxpr(
jaxpr: J,
preserve_tags: bool = True,
outvar_is_dep: tuple[bool, ...] | None = None,
) -> J:
"""Runs dead code elimination on a Jaxpr, retaining loss and layer tags."""

closed_jaxpr = to_closed_jaxpr(jaxpr)
eqns = []
dependants = set(closed_jaxpr.jaxpr.outvars)

if outvar_is_dep is None:
outvar_is_dep = (True,) * len(closed_jaxpr.jaxpr.outvars)

dependants = set(
var for var, is_dep in zip(closed_jaxpr.jaxpr.outvars,
outvar_is_dep, strict=True)
if is_dep and not isinstance(var, jax.core.Literal)
)

for eqn in reversed(closed_jaxpr.jaxpr.eqns):

# It's much more complicated to trace dependencies through iterative higher
# order primitives, so we don't do it.
if eqn.primitive.name in ITERATIVE_HIGHER_ORDER_NAMES:
outvar_is_dep_for_eqn = None
else:
outvar_is_dep_for_eqn = tuple(var in dependants for var in eqn.outvars)

# Note that we currently only trace dependencies into higher order
# primitives, but not *through* them. If a single output of a higher order
# primtive is a dependency, then all of its inputs are treated as such too.
eqn = apply_to_higher_order_primitives(
eqn, clean_jaxpr, preserve_tags=preserve_tags)
eqn,
functools.partial(
clean_jaxpr,
outvar_is_dep=outvar_is_dep_for_eqn
),
preserve_tags=preserve_tags
)

check = False

for v in eqn.outvars:
if v in dependants:
dependants.remove(v)
check = True

if isinstance(eqn.primitive, (tags.LossTag, tags.LayerTag)):
check = check or preserve_tags

if check:
eqns.append(eqn)
new_dependants = set(v for v in eqn.invars
if not isinstance(v, jax.core.Literal))
dependants = dependants.union(new_dependants)

# Dependants should only be invars
dependants = dependants - set(closed_jaxpr.jaxpr.invars +
closed_jaxpr.jaxpr.constvars)

if dependants:
raise ValueError("Something went wrong with the dead code elimination.")

# Because trying to eliminate output variables from higher order primitives is
# too much work, we instead just replace them with zeros, which is moderately
# hacky.
# TODO(jamesmartens,botev): Consider doing something better here.
final_outvars = []
for var, is_dep in zip(closed_jaxpr.jaxpr.outvars, outvar_is_dep,
strict=True):

if not isinstance(var, jax.core.Literal) and not is_dep:

assert isinstance(var.aval, jax.core.ShapedArray)

if not var.aval.shape: # scalar case

val = np.zeros(var.aval.shape, dtype=var.aval.dtype)
zero_literal = jax.core.Literal(val, var.aval)

final_outvars.append(zero_literal)

else:

def dummy_func():
return jnp.zeros(var.aval.shape, dtype=var.aval.dtype) # pylint: disable=cell-var-from-loop

dummy_jaxpr = jax.make_jaxpr(dummy_func)()

assert len(dummy_jaxpr.eqns) == 1
assert len(dummy_jaxpr.eqns[0].outvars) == 1

eqns.append(dummy_jaxpr.eqns[0])
final_outvars.append(dummy_jaxpr.eqns[0].outvars[0])

else:
final_outvars.append(var)

closed_jaxpr = ClosedJaxpr(
jaxpr=closed_jaxpr.jaxpr.replace(eqns=list(reversed(eqns))),
jaxpr=closed_jaxpr.jaxpr.replace(eqns=list(reversed(eqns)),
outvars=final_outvars),
consts=closed_jaxpr.consts,
)
return to_jaxpr_or_closed_jaxpr(closed_jaxpr, jaxpr)


def merge_broadcasts_jaxpr(jaxpr: J) -> J:
"""Merges consecutive broadcasts in the given Jaxpr."""
closed_jaxpr = clean_jaxpr(to_closed_jaxpr(jaxpr))

closed_jaxpr = to_closed_jaxpr(jaxpr)

broadcasts_outputs = {}
eqns = list()

for eqn in closed_jaxpr.jaxpr.eqns:

eqn = apply_to_higher_order_primitives(eqn, merge_broadcasts_jaxpr)

# We ignore broadcasting of constants
if (eqn.primitive.name == "broadcast_in_dim" and
not all(isinstance(v, jax.core.Literal) for v in eqn.invars)):

if eqn.invars[0] in broadcasts_outputs:
# Construct a merged equation from the previous and current one
prev_eqn = broadcasts_outputs[eqn.invars[0]]

broadcasts_outputs[eqn.outvars[0]] = prev_eqn.replace(
params={
"shape": eqn.params["shape"],
Expand All @@ -935,16 +1035,21 @@ def merge_broadcasts_jaxpr(jaxpr: J) -> J:
},
outvars=eqn.outvars,
)

else:
broadcasts_outputs[eqn.outvars[0]] = eqn

if eqn.outvars[0] in closed_jaxpr.jaxpr.outvars:
# We must preserve output equations
eqns.append(broadcasts_outputs[eqn.outvars[0]])

else:
for v in eqn.invars:
if not isinstance(v, jax.core.Literal) and v in broadcasts_outputs:
eqns.append(broadcasts_outputs[v])

eqns.append(eqn)

closed_jaxpr = ClosedJaxpr(
jaxpr=closed_jaxpr.jaxpr.replace(eqns=eqns),
consts=closed_jaxpr.consts
Expand Down Expand Up @@ -1270,34 +1375,52 @@ def base_name(self) -> str:
@property
def full_name(self) -> str:
"""The full name of the tag location."""

prefix = ""
param_vars = self.bottom_level_parameters

for eqn, n in reversed(self.parent_equations):

assert eqn.primitive.name in HIGHER_ORDER_NAMES

# Prefix for this higher order primitive
prefix = prefix + f"{eqn.primitive.name}_{n}/"

if eqn.primitive.name == "cond":
raise NotImplementedError()

elif eqn.primitive.name == "scan":

p_indexes = [eqn.params["jaxpr"].jaxpr.invars.index(p)
for p in param_vars]
checks = [pi < eqn.params["num_consts"] for pi in p_indexes]

if not (all(checks) or all(not ci for ci in checks)):
raise ValueError("Parameters inside scan of the same tag are not both"
" carry or const.")

if all(checks):
prefix = prefix + "const/"
else:
prefix = prefix + "carry/"

elif eqn.primitive.name == "pjit":
p_indexes = [eqn.params["jaxpr"].jaxpr.invars.index(p)
for p in param_vars]

elif eqn.primitive.name == "while":
p_indexes = [eqn.params["body_jaxpr"].jaxpr.invars.index(p)
for p in param_vars]

elif eqn.primitive.name in ("xla_call", "xla_pmap"):
p_indexes = [eqn.params["call_jaxpr"].invars.index(p)
for p in param_vars]

else:
raise NotImplementedError()

param_vars = [eqn.invars[pi] for pi in p_indexes]

return prefix + self.base_name

@property
Expand All @@ -1308,21 +1431,31 @@ def bottom_level_parameters(self) -> tuple[Var, ...]:
@property
def top_level_parameters(self) -> tuple[Var, ...]:
"""The top level parameter variables of the tag location."""

param_vars = self.bottom_level_parameters

for eqn, _ in reversed(self.parent_equations):

assert eqn.primitive.name in HIGHER_ORDER_NAMES

if eqn.primitive.name == "cond":
raise NotImplementedError()
elif eqn.primitive.name == "scan":

elif eqn.primitive.name in ("scan", "pjit"):
invars = eqn.params["jaxpr"].jaxpr.invars

elif eqn.primitive.name == "while":
invars = eqn.params["body_jaxpr"].jaxpr.invars

elif eqn.primitive.name in ("xla_call", "xla_pmap"):
invars = eqn.params["call_jaxpr"].invars

else:
raise NotImplementedError()

p_indexes = [invars.index(p) for p in param_vars]
param_vars = tuple(eqn.invars[pi] for pi in p_indexes)

return param_vars

def add_parent_eqn(self, eqn: JaxprEqn, counter: int):
Expand Down Expand Up @@ -1393,6 +1526,7 @@ def _auto_register_tags(
"cond": 0,
"while": 0,
"scan": 0,
"pjit": 0,
"xla_call": 0,
"xla_pmap": 0,
}
Expand Down Expand Up @@ -1425,7 +1559,7 @@ def _auto_register_tags(
sub_jaxprs = eqn.params["branches"]
elif eqn_name == "while":
sub_jaxprs = [eqn.params["body_jaxpr"]]
elif eqn_name == "scan":
elif eqn_name in ("scan", "pjit"):
sub_jaxprs = [eqn.params["jaxpr"]]
elif eqn_name in ("xla_call", "xla_pmap"):
sub_jaxprs = [eqn.params["call_jaxpr"]]
Expand Down Expand Up @@ -1471,7 +1605,7 @@ def _auto_register_tags(
eqn_params["branches"] = final_jaxprs
elif eqn_name == "while":
[eqn_params["body_jaxpr"]] = final_jaxprs # pylint:disable=unbalanced-tuple-unpacking
elif eqn_name == "scan":
elif eqn_name in ("scan", "pjit"):
[eqn_params["jaxpr"]] = final_jaxprs # pylint:disable=unbalanced-tuple-unpacking
elif eqn_name in ("xla_call", "xla_pmap"):
[eqn_params["call_jaxpr"]] = final_jaxprs # pylint:disable=unbalanced-tuple-unpacking
Expand Down

0 comments on commit b59b188

Please sign in to comment.