diff --git a/kfac_jax/_src/tag_graph_matcher.py b/kfac_jax/_src/tag_graph_matcher.py index 5e3512f..b24e410 100644 --- a/kfac_jax/_src/tag_graph_matcher.py +++ b/kfac_jax/_src/tag_graph_matcher.py @@ -34,7 +34,7 @@ from kfac_jax._src import utils import numpy as np -HIGHER_ORDER_NAMES = ("cond", "while", "scan", "xla_call", "xla_pmap") +HIGHER_ORDER_NAMES = ("cond", "while", "scan", "pjit", "xla_call", "xla_pmap") # Types for annotation Array = utils.Array @@ -331,28 +331,38 @@ def make_jax_graph( tag_ctor: TagCtor | None = None, ) -> JaxprGraph: """Creates a :class:`~JaxGraph` instance from the provided function and arguments.""" + in_tree = jax.tree_util.tree_structure(func_args) closed_jaxpr, out_shapes = jax.make_jaxpr(func, return_shape=True)(*func_args) if compute_only_loss_tags: + make_var_func = jax.core.gensym() eqns = [] sub_graph_vars = set() loss_tags_output_vars = [] + for eqn in reversed(closed_jaxpr.jaxpr.eqns): + if (isinstance(eqn.primitive, tags.LossTag) or any(v in sub_graph_vars for v in eqn.outvars)): + if isinstance(eqn.primitive, tags.LossTag): + new_out_vars = [] for v in eqn.outvars: + if isinstance(v, jax.core.DropVar): new_out_vars.append(make_var_func(v.aval)) else: new_out_vars.append(v) + loss_tags_output_vars.extend(new_out_vars[::-1]) eqns.append(eqn.replace(outvars=new_out_vars)) + else: eqns.append(eqn) + sub_graph_vars.update( v for v in eqn.invars if not isinstance(v, jax.core.Literal) ) @@ -385,13 +395,19 @@ def make_jax_graph( for v in closed_jaxpr.jaxpr.outvars] # pytype:disable=attribute-error if clean_broadcasts: + closed_jaxpr = clean_jaxpr(closed_jaxpr) closed_jaxpr = merge_broadcasts_jaxpr(closed_jaxpr) + closed_jaxpr = clean_jaxpr(closed_jaxpr) + in_vars = jax.tree_util.tree_unflatten(in_tree, closed_jaxpr.jaxpr.invars) + if isinstance(params_index, int): params_vars = in_vars[params_index] else: params_vars = tuple(in_vars[i] for i in params_index) + params_vars, params_tree = jax.tree_util.tree_flatten(params_vars) + return JaxprGraph( name=name, closed_jaxpr=closed_jaxpr, @@ -529,7 +545,7 @@ def create_eqn( env: dict[Var, Var], make_var_func: MakeVarFunc, ) -> JaxprEqns: - """Creates a new ``JaxprEqn`` for the this match.""" + """Creates a new ``JaxprEqn`` for this match.""" in_vars = [self.variables_map[k] for k in self.pattern.graph.jaxpr.invars] in_vars = [env.get(v, v) for v in in_vars] out_vars = [self.variables_map[k] @@ -820,6 +836,8 @@ def read_env( if isinstance(v, jax.core.Literal): # Literals are values baked into the Jaxpr result.append(v.val) + elif isinstance(v, jax.core.DropVar): + result.append(None) else: result.append(env[v]) return result @@ -850,51 +868,81 @@ def to_jaxpr_or_closed_jaxpr(closed_jaxpr: ClosedJaxpr, original: J) -> J: def apply_to_higher_order_primitives(eqn, func, *args, **kwargs): """Applies `func` only to higher order Jax primitives.""" + if eqn.primitive.name not in HIGHER_ORDER_NAMES: return eqn + elif eqn.primitive.name == "cond": params = dict(**eqn.params) params["branches"] = tuple( func(branch, *args, **kwargs) for branch in params["branches"] ) return eqn.replace(params=params) + elif eqn.primitive.name == "while": params = dict(**eqn.params) params["body_jaxpr"] = func(params["body_jaxpr"], *args, **kwargs) return eqn.replace(params=params) - elif eqn.primitive.name == "scan": + + elif eqn.primitive.name in ("scan", "pjit"): params = dict(**eqn.params) params["jaxpr"] = func(params["jaxpr"], *args, **kwargs) return eqn.replace(params=params) + elif eqn.primitive.name in ("xla_call", "xla_pmap"): params = dict(**eqn.params) params["call_jaxpr"] = func(params["call_jaxpr"], *args, **kwargs) return eqn.replace(params=params) + else: raise NotImplementedError() -def clean_jaxpr(jaxpr: J, preserve_tags: bool = True) -> J: +def clean_jaxpr( + jaxpr: J, + preserve_tags: bool = True, + outvar_is_dep: tuple[bool, ...] | None = None, +) -> J: """Runs dead code elimination on a Jaxpr, retaining loss and layer tags.""" + closed_jaxpr = to_closed_jaxpr(jaxpr) eqns = [] - dependants = set(closed_jaxpr.jaxpr.outvars) + + if outvar_is_dep is None: + outvar_is_dep = (True,) * len(closed_jaxpr.jaxpr.outvars) + + dependants = set( + var for var, is_dep in zip(closed_jaxpr.jaxpr.outvars, + outvar_is_dep, strict=True) + if is_dep and not isinstance(var, jax.core.Literal) + ) + for eqn in reversed(closed_jaxpr.jaxpr.eqns): + eqn = apply_to_higher_order_primitives( - eqn, clean_jaxpr, preserve_tags=preserve_tags) + eqn, + functools.partial( + clean_jaxpr, + outvar_is_dep=tuple(var in dependants for var in eqn.outvars)), + preserve_tags=preserve_tags + ) check = False + for v in eqn.outvars: if v in dependants: dependants.remove(v) check = True + if isinstance(eqn.primitive, (tags.LossTag, tags.LayerTag)): check = check or preserve_tags + if check: eqns.append(eqn) new_dependants = set(v for v in eqn.invars if not isinstance(v, jax.core.Literal)) dependants = dependants.union(new_dependants) + # Dependants should only be invars dependants = dependants - set(closed_jaxpr.jaxpr.invars + closed_jaxpr.jaxpr.constvars) @@ -902,8 +950,44 @@ def clean_jaxpr(jaxpr: J, preserve_tags: bool = True) -> J: if dependants: raise ValueError("Something went wrong with the dead code elimination.") + # Because trying to eliminate output variables from higher order primitives is + # too much work, we instead just replace them with zeros, which is moderately + # hacky. + # TODO(jamesmartens,botev): Consider doing something better here. + final_outvars = [] + for var, is_dep in zip(closed_jaxpr.jaxpr.outvars, outvar_is_dep, + strict=True): + + if not isinstance(var, jax.core.Literal) and not is_dep: + + assert isinstance(var.aval, jax.core.ShapedArray) + + if not var.aval.shape: # scalar case + + val = np.zeros(var.aval.shape, dtype=var.aval.dtype) + zero_literal = jax.core.Literal(val, var.aval) + + final_outvars.append(zero_literal) + + else: + + def dummy_func(): + return jnp.zeros(var.aval.shape, dtype=var.aval.dtype) # pylint: disable=cell-var-from-loop + + dummy_jaxpr = jax.make_jaxpr(dummy_func)() + + assert len(dummy_jaxpr.eqns) == 1 + assert len(dummy_jaxpr.eqns[0].outvars) == 1 + + eqns.append(dummy_jaxpr.eqns[0]) + final_outvars.append(dummy_jaxpr.eqns[0].outvars[0]) + + else: + final_outvars.append(var) + closed_jaxpr = ClosedJaxpr( - jaxpr=closed_jaxpr.jaxpr.replace(eqns=list(reversed(eqns))), + jaxpr=closed_jaxpr.jaxpr.replace(eqns=list(reversed(eqns)), + outvars=final_outvars), consts=closed_jaxpr.consts, ) return to_jaxpr_or_closed_jaxpr(closed_jaxpr, jaxpr) @@ -911,20 +995,24 @@ def clean_jaxpr(jaxpr: J, preserve_tags: bool = True) -> J: def merge_broadcasts_jaxpr(jaxpr: J) -> J: """Merges consecutive broadcasts in the given Jaxpr.""" - closed_jaxpr = clean_jaxpr(to_closed_jaxpr(jaxpr)) + + closed_jaxpr = to_closed_jaxpr(jaxpr) broadcasts_outputs = {} eqns = list() for eqn in closed_jaxpr.jaxpr.eqns: + eqn = apply_to_higher_order_primitives(eqn, merge_broadcasts_jaxpr) # We ignore broadcasting of constants if (eqn.primitive.name == "broadcast_in_dim" and not all(isinstance(v, jax.core.Literal) for v in eqn.invars)): + if eqn.invars[0] in broadcasts_outputs: # Construct a merged equation from the previous and current one prev_eqn = broadcasts_outputs[eqn.invars[0]] + broadcasts_outputs[eqn.outvars[0]] = prev_eqn.replace( params={ "shape": eqn.params["shape"], @@ -935,16 +1023,21 @@ def merge_broadcasts_jaxpr(jaxpr: J) -> J: }, outvars=eqn.outvars, ) + else: broadcasts_outputs[eqn.outvars[0]] = eqn + if eqn.outvars[0] in closed_jaxpr.jaxpr.outvars: # We must preserve output equations eqns.append(broadcasts_outputs[eqn.outvars[0]]) + else: for v in eqn.invars: if not isinstance(v, jax.core.Literal) and v in broadcasts_outputs: eqns.append(broadcasts_outputs[v]) + eqns.append(eqn) + closed_jaxpr = ClosedJaxpr( jaxpr=closed_jaxpr.jaxpr.replace(eqns=eqns), consts=closed_jaxpr.consts @@ -1270,34 +1363,52 @@ def base_name(self) -> str: @property def full_name(self) -> str: """The full name of the tag location.""" + prefix = "" param_vars = self.bottom_level_parameters + for eqn, n in reversed(self.parent_equations): + assert eqn.primitive.name in HIGHER_ORDER_NAMES + # Prefix for this higher order primitive prefix = prefix + f"{eqn.primitive.name}_{n}/" + if eqn.primitive.name == "cond": raise NotImplementedError() + elif eqn.primitive.name == "scan": + p_indexes = [eqn.params["jaxpr"].jaxpr.invars.index(p) for p in param_vars] checks = [pi < eqn.params["num_consts"] for pi in p_indexes] + if not (all(checks) or all(not ci for ci in checks)): raise ValueError("Parameters inside scan of the same tag are not both" " carry or const.") + if all(checks): prefix = prefix + "const/" else: prefix = prefix + "carry/" + + elif eqn.primitive.name == "pjit": + p_indexes = [eqn.params["jaxpr"].jaxpr.invars.index(p) + for p in param_vars] + elif eqn.primitive.name == "while": p_indexes = [eqn.params["body_jaxpr"].jaxpr.invars.index(p) for p in param_vars] + elif eqn.primitive.name in ("xla_call", "xla_pmap"): p_indexes = [eqn.params["call_jaxpr"].invars.index(p) for p in param_vars] + else: raise NotImplementedError() + param_vars = [eqn.invars[pi] for pi in p_indexes] + return prefix + self.base_name @property @@ -1308,21 +1419,31 @@ def bottom_level_parameters(self) -> tuple[Var, ...]: @property def top_level_parameters(self) -> tuple[Var, ...]: """The top level parameter variables of the tag location.""" + param_vars = self.bottom_level_parameters + for eqn, _ in reversed(self.parent_equations): + assert eqn.primitive.name in HIGHER_ORDER_NAMES + if eqn.primitive.name == "cond": raise NotImplementedError() - elif eqn.primitive.name == "scan": + + elif eqn.primitive.name in ("scan", "pjit"): invars = eqn.params["jaxpr"].jaxpr.invars + elif eqn.primitive.name == "while": invars = eqn.params["body_jaxpr"].jaxpr.invars + elif eqn.primitive.name in ("xla_call", "xla_pmap"): invars = eqn.params["call_jaxpr"].invars + else: raise NotImplementedError() + p_indexes = [invars.index(p) for p in param_vars] param_vars = tuple(eqn.invars[pi] for pi in p_indexes) + return param_vars def add_parent_eqn(self, eqn: JaxprEqn, counter: int): @@ -1393,6 +1514,7 @@ def _auto_register_tags( "cond": 0, "while": 0, "scan": 0, + "pjit": 0, "xla_call": 0, "xla_pmap": 0, } @@ -1425,7 +1547,7 @@ def _auto_register_tags( sub_jaxprs = eqn.params["branches"] elif eqn_name == "while": sub_jaxprs = [eqn.params["body_jaxpr"]] - elif eqn_name == "scan": + elif eqn_name in ("scan", "pjit"): sub_jaxprs = [eqn.params["jaxpr"]] elif eqn_name in ("xla_call", "xla_pmap"): sub_jaxprs = [eqn.params["call_jaxpr"]] @@ -1471,7 +1593,7 @@ def _auto_register_tags( eqn_params["branches"] = final_jaxprs elif eqn_name == "while": [eqn_params["body_jaxpr"]] = final_jaxprs # pylint:disable=unbalanced-tuple-unpacking - elif eqn_name == "scan": + elif eqn_name in ("scan", "pjit"): [eqn_params["jaxpr"]] = final_jaxprs # pylint:disable=unbalanced-tuple-unpacking elif eqn_name in ("xla_call", "xla_pmap"): [eqn_params["call_jaxpr"]] = final_jaxprs # pylint:disable=unbalanced-tuple-unpacking