diff --git a/kfac_jax/_src/tag_graph_matcher.py b/kfac_jax/_src/tag_graph_matcher.py index 5e3512f..c58aabc 100644 --- a/kfac_jax/_src/tag_graph_matcher.py +++ b/kfac_jax/_src/tag_graph_matcher.py @@ -34,7 +34,7 @@ from kfac_jax._src import utils import numpy as np -HIGHER_ORDER_NAMES = ("cond", "while", "scan", "xla_call", "xla_pmap") +HIGHER_ORDER_NAMES = ("cond", "while", "scan", "pjit", "xla_call", "xla_pmap") # Types for annotation Array = utils.Array @@ -385,13 +385,18 @@ def make_jax_graph( for v in closed_jaxpr.jaxpr.outvars] # pytype:disable=attribute-error if clean_broadcasts: + # closed_jaxpr = clean_jaxpr(closed_jaxpr) closed_jaxpr = merge_broadcasts_jaxpr(closed_jaxpr) + in_vars = jax.tree_util.tree_unflatten(in_tree, closed_jaxpr.jaxpr.invars) + if isinstance(params_index, int): params_vars = in_vars[params_index] else: params_vars = tuple(in_vars[i] for i in params_index) + params_vars, params_tree = jax.tree_util.tree_flatten(params_vars) + return JaxprGraph( name=name, closed_jaxpr=closed_jaxpr, @@ -820,6 +825,8 @@ def read_env( if isinstance(v, jax.core.Literal): # Literals are values baked into the Jaxpr result.append(v.val) + elif isinstance(v, jax.core.DropVar): + result.append(None) else: result.append(env[v]) return result @@ -850,60 +857,107 @@ def to_jaxpr_or_closed_jaxpr(closed_jaxpr: ClosedJaxpr, original: J) -> J: def apply_to_higher_order_primitives(eqn, func, *args, **kwargs): """Applies `func` only to higher order Jax primitives.""" + if eqn.primitive.name not in HIGHER_ORDER_NAMES: return eqn + elif eqn.primitive.name == "cond": params = dict(**eqn.params) params["branches"] = tuple( func(branch, *args, **kwargs) for branch in params["branches"] ) return eqn.replace(params=params) + elif eqn.primitive.name == "while": params = dict(**eqn.params) params["body_jaxpr"] = func(params["body_jaxpr"], *args, **kwargs) return eqn.replace(params=params) - elif eqn.primitive.name == "scan": + + elif eqn.primitive.name in ("scan", "pjit"): params = dict(**eqn.params) params["jaxpr"] = func(params["jaxpr"], *args, **kwargs) return eqn.replace(params=params) + elif eqn.primitive.name in ("xla_call", "xla_pmap"): params = dict(**eqn.params) params["call_jaxpr"] = func(params["call_jaxpr"], *args, **kwargs) return eqn.replace(params=params) + else: raise NotImplementedError() -def clean_jaxpr(jaxpr: J, preserve_tags: bool = True) -> J: +def clean_jaxpr( + jaxpr: J, + preserve_tags: bool = True, + outvar_is_dep: tuple[bool, ...] | None = None, +) -> J: """Runs dead code elimination on a Jaxpr, retaining loss and layer tags.""" + closed_jaxpr = to_closed_jaxpr(jaxpr) eqns = [] - dependants = set(closed_jaxpr.jaxpr.outvars) + + non_literal_outvars = set(var for var in closed_jaxpr.jaxpr.outvars + if not isinstance(var, jax.core.Literal)) + + if outvar_is_dep is not None: + dependants = set(var for var, is_dep in zip( + closed_jaxpr.jaxpr.outvars, outvar_is_dep, strict=True) if is_dep) + else: + dependants = non_literal_outvars + for eqn in reversed(closed_jaxpr.jaxpr.eqns): + eqn = apply_to_higher_order_primitives( - eqn, clean_jaxpr, preserve_tags=preserve_tags) + eqn, + functools.partial( + clean_jaxpr, + outvar_is_dep=tuple(var in dependants for var in eqn.outvars)), + preserve_tags=preserve_tags + ) check = False + for v in eqn.outvars: if v in dependants: dependants.remove(v) check = True + if isinstance(eqn.primitive, (tags.LossTag, tags.LayerTag)): check = check or preserve_tags + if check: eqns.append(eqn) new_dependants = set(v for v in eqn.invars if not isinstance(v, jax.core.Literal)) dependants = dependants.union(new_dependants) + # Dependants should only be invars dependants = dependants - set(closed_jaxpr.jaxpr.invars + closed_jaxpr.jaxpr.constvars) - if dependants: + if dependants - non_literal_outvars: raise ValueError("Something went wrong with the dead code elimination.") + # Because trying to elimiate output variables from higher order primitives is + # too much work, we instead just replace them with zeros, which is moderately + # hacky. + # TODO(jamesmartens,botev): Do something better here. + final_outvars = [] + for var in closed_jaxpr.jaxpr.outvars: + + if not isinstance(var, jax.core.Literal) and var in dependants: + assert isinstance(var.aval, jax.core.ShapedArray) + val = np.zeros(var.aval.shape, dtype=var.aval.dtype) + dummy_literal = jax.core.Literal(val, var.aval) + final_outvars.append(dummy_literal) + + else: + final_outvars.append(var) + closed_jaxpr = ClosedJaxpr( - jaxpr=closed_jaxpr.jaxpr.replace(eqns=list(reversed(eqns))), + jaxpr=closed_jaxpr.jaxpr.replace(eqns=list(reversed(eqns)), + outvars=final_outvars), consts=closed_jaxpr.consts, ) return to_jaxpr_or_closed_jaxpr(closed_jaxpr, jaxpr) @@ -911,20 +965,25 @@ def clean_jaxpr(jaxpr: J, preserve_tags: bool = True) -> J: def merge_broadcasts_jaxpr(jaxpr: J) -> J: """Merges consecutive broadcasts in the given Jaxpr.""" - closed_jaxpr = clean_jaxpr(to_closed_jaxpr(jaxpr)) + + closed_jaxpr = to_closed_jaxpr(jaxpr) + closed_jaxpr = clean_jaxpr(closed_jaxpr) broadcasts_outputs = {} eqns = list() for eqn in closed_jaxpr.jaxpr.eqns: + eqn = apply_to_higher_order_primitives(eqn, merge_broadcasts_jaxpr) # We ignore broadcasting of constants if (eqn.primitive.name == "broadcast_in_dim" and not all(isinstance(v, jax.core.Literal) for v in eqn.invars)): + if eqn.invars[0] in broadcasts_outputs: # Construct a merged equation from the previous and current one prev_eqn = broadcasts_outputs[eqn.invars[0]] + broadcasts_outputs[eqn.outvars[0]] = prev_eqn.replace( params={ "shape": eqn.params["shape"], @@ -935,16 +994,21 @@ def merge_broadcasts_jaxpr(jaxpr: J) -> J: }, outvars=eqn.outvars, ) + else: broadcasts_outputs[eqn.outvars[0]] = eqn + if eqn.outvars[0] in closed_jaxpr.jaxpr.outvars: # We must preserve output equations eqns.append(broadcasts_outputs[eqn.outvars[0]]) + else: for v in eqn.invars: if not isinstance(v, jax.core.Literal) and v in broadcasts_outputs: eqns.append(broadcasts_outputs[v]) + eqns.append(eqn) + closed_jaxpr = ClosedJaxpr( jaxpr=closed_jaxpr.jaxpr.replace(eqns=eqns), consts=closed_jaxpr.consts @@ -1278,7 +1342,7 @@ def full_name(self) -> str: prefix = prefix + f"{eqn.primitive.name}_{n}/" if eqn.primitive.name == "cond": raise NotImplementedError() - elif eqn.primitive.name == "scan": + elif eqn.primitive.name == "scan": # do we need to add "pjit" here too??? p_indexes = [eqn.params["jaxpr"].jaxpr.invars.index(p) for p in param_vars] checks = [pi < eqn.params["num_consts"] for pi in p_indexes] @@ -1313,7 +1377,7 @@ def top_level_parameters(self) -> tuple[Var, ...]: assert eqn.primitive.name in HIGHER_ORDER_NAMES if eqn.primitive.name == "cond": raise NotImplementedError() - elif eqn.primitive.name == "scan": + elif eqn.primitive.name in ("scan", "pjit"): invars = eqn.params["jaxpr"].jaxpr.invars elif eqn.primitive.name == "while": invars = eqn.params["body_jaxpr"].jaxpr.invars @@ -1393,6 +1457,7 @@ def _auto_register_tags( "cond": 0, "while": 0, "scan": 0, + "pjit": 0, "xla_call": 0, "xla_pmap": 0, } @@ -1425,7 +1490,7 @@ def _auto_register_tags( sub_jaxprs = eqn.params["branches"] elif eqn_name == "while": sub_jaxprs = [eqn.params["body_jaxpr"]] - elif eqn_name == "scan": + elif eqn_name in ("scan", "pjit"): sub_jaxprs = [eqn.params["jaxpr"]] elif eqn_name in ("xla_call", "xla_pmap"): sub_jaxprs = [eqn.params["call_jaxpr"]] @@ -1471,7 +1536,7 @@ def _auto_register_tags( eqn_params["branches"] = final_jaxprs elif eqn_name == "while": [eqn_params["body_jaxpr"]] = final_jaxprs # pylint:disable=unbalanced-tuple-unpacking - elif eqn_name == "scan": + elif eqn_name in ("scan", "pjit"): [eqn_params["jaxpr"]] = final_jaxprs # pylint:disable=unbalanced-tuple-unpacking elif eqn_name in ("xla_call", "xla_pmap"): [eqn_params["call_jaxpr"]] = final_jaxprs # pylint:disable=unbalanced-tuple-unpacking