Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

- Adding handling of jitted functions to graph scanner. #254

Merged
merged 1 commit into from
Aug 22, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 145 additions & 11 deletions kfac_jax/_src/tag_graph_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
from kfac_jax._src import utils
import numpy as np

HIGHER_ORDER_NAMES = ("cond", "while", "scan", "xla_call", "xla_pmap")
HIGHER_ORDER_NAMES = ("cond", "while", "scan", "pjit", "xla_call", "xla_pmap")
ITERATIVE_HIGHER_ORDER_NAMES = ("while", "scan")

# Types for annotation
Array = utils.Array
Expand Down Expand Up @@ -331,28 +332,38 @@ def make_jax_graph(
tag_ctor: TagCtor | None = None,
) -> JaxprGraph:
"""Creates a :class:`~JaxprGraph` instance from the provided function and arguments."""

in_tree = jax.tree_util.tree_structure(func_args)
closed_jaxpr, out_shapes = jax.make_jaxpr(func, return_shape=True)(*func_args)

if compute_only_loss_tags:

make_var_func = jax.core.gensym()
eqns = []
sub_graph_vars = set()
loss_tags_output_vars = []

for eqn in reversed(closed_jaxpr.jaxpr.eqns):

if (isinstance(eqn.primitive, tags.LossTag) or
any(v in sub_graph_vars for v in eqn.outvars)):

if isinstance(eqn.primitive, tags.LossTag):

new_out_vars = []
for v in eqn.outvars:

if isinstance(v, jax.core.DropVar):
new_out_vars.append(make_var_func(v.aval))
else:
new_out_vars.append(v)

loss_tags_output_vars.extend(new_out_vars[::-1])
eqns.append(eqn.replace(outvars=new_out_vars))

else:
eqns.append(eqn)

sub_graph_vars.update(
v for v in eqn.invars if not isinstance(v, jax.core.Literal)
)
Expand Down Expand Up @@ -385,13 +396,19 @@ def make_jax_graph(
for v in closed_jaxpr.jaxpr.outvars] # pytype:disable=attribute-error

if clean_broadcasts:
closed_jaxpr = clean_jaxpr(closed_jaxpr)
closed_jaxpr = merge_broadcasts_jaxpr(closed_jaxpr)
closed_jaxpr = clean_jaxpr(closed_jaxpr)

in_vars = jax.tree_util.tree_unflatten(in_tree, closed_jaxpr.jaxpr.invars)

if isinstance(params_index, int):
params_vars = in_vars[params_index]
else:
params_vars = tuple(in_vars[i] for i in params_index)

params_vars, params_tree = jax.tree_util.tree_flatten(params_vars)

return JaxprGraph(
name=name,
closed_jaxpr=closed_jaxpr,
Expand Down Expand Up @@ -529,7 +546,7 @@ def create_eqn(
env: dict[Var, Var],
make_var_func: MakeVarFunc,
) -> JaxprEqns:
"""Creates a new ``JaxprEqn`` for the this match."""
"""Creates a new ``JaxprEqn`` for this match."""
in_vars = [self.variables_map[k] for k in self.pattern.graph.jaxpr.invars]
in_vars = [env.get(v, v) for v in in_vars]
out_vars = [self.variables_map[k]
Expand Down Expand Up @@ -820,6 +837,8 @@ def read_env(
if isinstance(v, jax.core.Literal):
# Literals are values baked into the Jaxpr
result.append(v.val)
elif isinstance(v, jax.core.DropVar):
result.append(None)
else:
result.append(env[v])
return result
Expand Down Expand Up @@ -850,81 +869,162 @@ def to_jaxpr_or_closed_jaxpr(closed_jaxpr: ClosedJaxpr, original: J) -> J:

def apply_to_higher_order_primitives(eqn, func, *args, **kwargs):
"""Applies `func` only to higher order Jax primitives."""

if eqn.primitive.name not in HIGHER_ORDER_NAMES:
return eqn

elif eqn.primitive.name == "cond":
params = dict(**eqn.params)
params["branches"] = tuple(
func(branch, *args, **kwargs) for branch in params["branches"]
)
return eqn.replace(params=params)

elif eqn.primitive.name == "while":
params = dict(**eqn.params)
params["body_jaxpr"] = func(params["body_jaxpr"], *args, **kwargs)
return eqn.replace(params=params)
elif eqn.primitive.name == "scan":

elif eqn.primitive.name in ("scan", "pjit"):
params = dict(**eqn.params)
params["jaxpr"] = func(params["jaxpr"], *args, **kwargs)
return eqn.replace(params=params)

elif eqn.primitive.name in ("xla_call", "xla_pmap"):
params = dict(**eqn.params)
params["call_jaxpr"] = func(params["call_jaxpr"], *args, **kwargs)
return eqn.replace(params=params)

else:
raise NotImplementedError()


def clean_jaxpr(jaxpr: J, preserve_tags: bool = True) -> J:
def clean_jaxpr(
jaxpr: J,
preserve_tags: bool = True,
outvar_is_dep: tuple[bool, ...] | None = None,
) -> J:
"""Runs dead code elimination on a Jaxpr, retaining loss and layer tags."""

closed_jaxpr = to_closed_jaxpr(jaxpr)
eqns = []
dependants = set(closed_jaxpr.jaxpr.outvars)

if outvar_is_dep is None:
outvar_is_dep = (True,) * len(closed_jaxpr.jaxpr.outvars)

dependants = set(
var for var, is_dep in zip(closed_jaxpr.jaxpr.outvars,
outvar_is_dep, strict=True)
if is_dep and not isinstance(var, jax.core.Literal)
)

for eqn in reversed(closed_jaxpr.jaxpr.eqns):

# It's much more complicated to trace dependencies through iterative higher
# order primitives, so we don't do it.
if eqn.primitive.name in ITERATIVE_HIGHER_ORDER_NAMES:
outvar_is_dep_for_eqn = None
else:
outvar_is_dep_for_eqn = tuple(var in dependants for var in eqn.outvars)

# Note that we currently only trace dependencies into higher order
# primitives, but not *through* them. If a single output of a higher order
# primitive is a dependency, then all of its inputs are treated as such too.
eqn = apply_to_higher_order_primitives(
eqn, clean_jaxpr, preserve_tags=preserve_tags)
eqn,
functools.partial(
clean_jaxpr,
outvar_is_dep=outvar_is_dep_for_eqn
),
preserve_tags=preserve_tags
)

check = False

for v in eqn.outvars:
if v in dependants:
dependants.remove(v)
check = True

if isinstance(eqn.primitive, (tags.LossTag, tags.LayerTag)):
check = check or preserve_tags

if check:
eqns.append(eqn)
new_dependants = set(v for v in eqn.invars
if not isinstance(v, jax.core.Literal))
dependants = dependants.union(new_dependants)

# Dependants should only be invars
dependants = dependants - set(closed_jaxpr.jaxpr.invars +
closed_jaxpr.jaxpr.constvars)

if dependants:
raise ValueError("Something went wrong with the dead code elimination.")

# Because trying to eliminate output variables from higher order primitives is
# too much work, we instead just replace them with zeros, which is moderately
# hacky.
# TODO(jamesmartens,botev): Consider doing something better here.
final_outvars = []
for var, is_dep in zip(closed_jaxpr.jaxpr.outvars, outvar_is_dep,
strict=True):

if not isinstance(var, jax.core.Literal) and not is_dep:

assert isinstance(var.aval, jax.core.ShapedArray)

if not var.aval.shape: # scalar case

val = np.zeros(var.aval.shape, dtype=var.aval.dtype)
zero_literal = jax.core.Literal(val, var.aval)

final_outvars.append(zero_literal)

else:

def dummy_func():
return jnp.zeros(var.aval.shape, dtype=var.aval.dtype) # pylint: disable=cell-var-from-loop

dummy_jaxpr = jax.make_jaxpr(dummy_func)()

assert len(dummy_jaxpr.eqns) == 1
assert len(dummy_jaxpr.eqns[0].outvars) == 1

eqns.append(dummy_jaxpr.eqns[0])
final_outvars.append(dummy_jaxpr.eqns[0].outvars[0])

else:
final_outvars.append(var)

closed_jaxpr = ClosedJaxpr(
jaxpr=closed_jaxpr.jaxpr.replace(eqns=list(reversed(eqns))),
jaxpr=closed_jaxpr.jaxpr.replace(eqns=list(reversed(eqns)),
outvars=final_outvars),
consts=closed_jaxpr.consts,
)
return to_jaxpr_or_closed_jaxpr(closed_jaxpr, jaxpr)


def merge_broadcasts_jaxpr(jaxpr: J) -> J:
"""Merges consecutive broadcasts in the given Jaxpr."""
closed_jaxpr = clean_jaxpr(to_closed_jaxpr(jaxpr))

closed_jaxpr = to_closed_jaxpr(jaxpr)

broadcasts_outputs = {}
eqns = list()

for eqn in closed_jaxpr.jaxpr.eqns:

eqn = apply_to_higher_order_primitives(eqn, merge_broadcasts_jaxpr)

# We ignore broadcasting of constants
if (eqn.primitive.name == "broadcast_in_dim" and
not all(isinstance(v, jax.core.Literal) for v in eqn.invars)):

if eqn.invars[0] in broadcasts_outputs:
# Construct a merged equation from the previous and current one
prev_eqn = broadcasts_outputs[eqn.invars[0]]

broadcasts_outputs[eqn.outvars[0]] = prev_eqn.replace(
params={
"shape": eqn.params["shape"],
Expand All @@ -935,16 +1035,21 @@ def merge_broadcasts_jaxpr(jaxpr: J) -> J:
},
outvars=eqn.outvars,
)

else:
broadcasts_outputs[eqn.outvars[0]] = eqn

if eqn.outvars[0] in closed_jaxpr.jaxpr.outvars:
# We must preserve output equations
eqns.append(broadcasts_outputs[eqn.outvars[0]])

else:
for v in eqn.invars:
if not isinstance(v, jax.core.Literal) and v in broadcasts_outputs:
eqns.append(broadcasts_outputs[v])

eqns.append(eqn)

closed_jaxpr = ClosedJaxpr(
jaxpr=closed_jaxpr.jaxpr.replace(eqns=eqns),
consts=closed_jaxpr.consts
Expand Down Expand Up @@ -1270,34 +1375,52 @@ def base_name(self) -> str:
@property
def full_name(self) -> str:
"""The full name of the tag location."""

prefix = ""
param_vars = self.bottom_level_parameters

for eqn, n in reversed(self.parent_equations):

assert eqn.primitive.name in HIGHER_ORDER_NAMES

# Prefix for this higher order primitive
prefix = prefix + f"{eqn.primitive.name}_{n}/"

if eqn.primitive.name == "cond":
raise NotImplementedError()

elif eqn.primitive.name == "scan":

p_indexes = [eqn.params["jaxpr"].jaxpr.invars.index(p)
for p in param_vars]
checks = [pi < eqn.params["num_consts"] for pi in p_indexes]

if not (all(checks) or all(not ci for ci in checks)):
raise ValueError("Parameters inside scan of the same tag are not both"
" carry or const.")

if all(checks):
prefix = prefix + "const/"
else:
prefix = prefix + "carry/"

elif eqn.primitive.name == "pjit":
p_indexes = [eqn.params["jaxpr"].jaxpr.invars.index(p)
for p in param_vars]

elif eqn.primitive.name == "while":
p_indexes = [eqn.params["body_jaxpr"].jaxpr.invars.index(p)
for p in param_vars]

elif eqn.primitive.name in ("xla_call", "xla_pmap"):
p_indexes = [eqn.params["call_jaxpr"].invars.index(p)
for p in param_vars]

else:
raise NotImplementedError()

param_vars = [eqn.invars[pi] for pi in p_indexes]

return prefix + self.base_name

@property
Expand All @@ -1308,21 +1431,31 @@ def bottom_level_parameters(self) -> tuple[Var, ...]:
@property
def top_level_parameters(self) -> tuple[Var, ...]:
"""The top level parameter variables of the tag location."""

param_vars = self.bottom_level_parameters

for eqn, _ in reversed(self.parent_equations):

assert eqn.primitive.name in HIGHER_ORDER_NAMES

if eqn.primitive.name == "cond":
raise NotImplementedError()
elif eqn.primitive.name == "scan":

elif eqn.primitive.name in ("scan", "pjit"):
invars = eqn.params["jaxpr"].jaxpr.invars

elif eqn.primitive.name == "while":
invars = eqn.params["body_jaxpr"].jaxpr.invars

elif eqn.primitive.name in ("xla_call", "xla_pmap"):
invars = eqn.params["call_jaxpr"].invars

else:
raise NotImplementedError()

p_indexes = [invars.index(p) for p in param_vars]
param_vars = tuple(eqn.invars[pi] for pi in p_indexes)

return param_vars

def add_parent_eqn(self, eqn: JaxprEqn, counter: int):
Expand Down Expand Up @@ -1393,6 +1526,7 @@ def _auto_register_tags(
"cond": 0,
"while": 0,
"scan": 0,
"pjit": 0,
"xla_call": 0,
"xla_pmap": 0,
}
Expand Down Expand Up @@ -1425,7 +1559,7 @@ def _auto_register_tags(
sub_jaxprs = eqn.params["branches"]
elif eqn_name == "while":
sub_jaxprs = [eqn.params["body_jaxpr"]]
elif eqn_name == "scan":
elif eqn_name in ("scan", "pjit"):
sub_jaxprs = [eqn.params["jaxpr"]]
elif eqn_name in ("xla_call", "xla_pmap"):
sub_jaxprs = [eqn.params["call_jaxpr"]]
Expand Down Expand Up @@ -1471,7 +1605,7 @@ def _auto_register_tags(
eqn_params["branches"] = final_jaxprs
elif eqn_name == "while":
[eqn_params["body_jaxpr"]] = final_jaxprs # pylint:disable=unbalanced-tuple-unpacking
elif eqn_name == "scan":
elif eqn_name in ("scan", "pjit"):
[eqn_params["jaxpr"]] = final_jaxprs # pylint:disable=unbalanced-tuple-unpacking
elif eqn_name in ("xla_call", "xla_pmap"):
[eqn_params["call_jaxpr"]] = final_jaxprs # pylint:disable=unbalanced-tuple-unpacking
Expand Down
Loading