
Graph caching & tf.function


tf.function

tf.function caches a graph internally, keyed on the call signature. For Python objects, it uses their hash and assumes they are immutable. Some objects, such as PDFs, are mutable though (e.g. via set_*_range) and can invalidate a graph, since changes to the objects are not detected in the hash of the signature.
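
A minimal sketch of the behaviour (the Box class and scaled_py function are made up for illustration):

```python
import tensorflow as tf

class Box:
    """A mutable Python object passed to a tf.function (stand-in for e.g. a PDF)."""
    def __init__(self, value):
        self.value = value

def scaled_py(box, x):
    # box.value is read during tracing and baked into the graph as a constant
    return box.value * x

scaled = tf.function(scaled_py)

box = Box(2.0)
x = tf.constant(3.0)
print(scaled(box, x))  # 6.0, traced with box.value == 2.0

box.value = 5.0        # mutate the object; its hash (by default the id) is unchanged
print(scaled(box, x))  # still 6.0: the cached graph is reused, no retrace happens
```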

To force a retrace, a new tf.function(py_func_here) has to be made.
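
Continuing the sketch above, wrapping the Python function again gives an empty graph cache and therefore forces a retrace:

```python
scaled = tf.function(scaled_py)  # a fresh tf.function starts with an empty graph cache
print(scaled(box, x))            # 15.0: retraced, picking up box.value == 5.0
```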

zfit decorator z.function

For the extra decorator z.function, a FunctionCacheRegistry is created. It stores the signature and the created functions in a FunctionHolder object, which keeps track of any changes in the signature objects, e.g. whether they invalidate caches. The FunctionHolder (currently) uses the Python function as its hash and performs an equality comparison to check whether the signatures are equal.

If a change occurred that invalidates the graph, the FunctionHolder's is_valid attribute becomes False, signaling that a retrace has to be forced.
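
A minimal sketch of this scheme, assuming only the names FunctionHolder and is_valid from the description above; the registry's API and the bookkeeping details are illustrative, not zfit's actual implementation:

```python
import tensorflow as tf

class FunctionHolder:
    """Holds a traced tf.function together with the signature it was traced for."""

    def __init__(self, python_func, signature):
        self.python_func = python_func
        self.signature = signature          # the (possibly mutable) objects seen at trace time
        self.wrapped = tf.function(python_func)
        self.is_valid = True                # set to False when a signature object changes

    def __hash__(self):
        # as described above: the Python function serves as the hash
        return hash(self.python_func)

    def __eq__(self, other):
        return (self.python_func is other.python_func
                and self.signature == other.signature)


class FunctionCacheRegistry:
    """Maps Python functions to FunctionHolders and retraces when a holder is invalid."""

    def __init__(self):
        self._holders = {}

    def get(self, python_func, signature):
        holder = self._holders.get(python_func)
        if holder is None or not holder.is_valid or holder.signature != signature:
            # a new tf.function forces an independent (re)trace
            holder = FunctionHolder(python_func, signature)
            self._holders[python_func] = holder
        return holder.wrapped
```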

Deadlock in tracing

Currently (see https://github.com/tensorflow/tensorflow/issues/35540), recursively tracing a tf.function causes a deadlock. As a simple solution, every function that is currently being traced is guaranteed not to be traced twice: if foo is being traced and foo is supposed to be traced again, the plain Python function is executed instead.
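
A minimal sketch of such a guard (illustrative; zfit's actual implementation may differ):

```python
import functools
import tensorflow as tf

_currently_tracing = set()   # Python functions whose tracing is in progress

def traced_function(python_func):
    """Wrap python_func with tf.function, avoiding recursive tracing."""
    tf_func = tf.function(python_func)

    @functools.wraps(python_func)
    def wrapper(*args, **kwargs):
        if python_func in _currently_tracing:
            # foo is already being traced and is called again: run the plain
            # Python function instead (its ops are inlined into the outer graph)
            return python_func(*args, **kwargs)
        _currently_tracing.add(python_func)
        try:
            return tf_func(*args, **kwargs)
        finally:
            _currently_tracing.discard(python_func)

    return wrapper
```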

This is a slightly suboptimal solution; in principle, a new tf.function could be created to force an independent retracing. However, handling the caching of the graph becomes more difficult, since different tf.function-wrapped functions hold different signatures. As recursive calls are rare anyway, the effect of this is small.