Graph caching & tf.function
tf.function caches a graph internally depending on the call signature. For Python objects, it uses the object's hash and assumes they are immutable. Some objects, such as PDFs, are mutable though (e.g. via set_*_range) and can invalidate a graph, because changes to the objects are not detected in the hash of the signature. To force a retrace, a new tf.function(py_func_here) has to be created.
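As a minimal sketch (illustrative names, not zfit code): mutating an attribute of a Python object does not change its hash, so tf.function keeps reusing the stale graph until the function is wrapped anew.

```python
import tensorflow as tf


class MutableRange:
    """Illustrative stand-in for a mutable object such as a PDF."""

    def __init__(self, limit):
        self.limit = limit  # mutating this later does not change the object's hash


def integrate(obj):
    # the Python float is baked into the graph at trace time
    return tf.constant(obj.limit, dtype=tf.float64) * 2.0


traced = tf.function(integrate)
obj = MutableRange(limit=1.0)

print(traced(obj))  # traces a graph: returns 2.0
obj.limit = 5.0     # analogous to set_*_range on a PDF
print(traced(obj))  # stale graph reused: still returns 2.0

traced = tf.function(integrate)  # wrapping again forces a retrace
print(traced(obj))  # returns 10.0
```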
An extra decorator, the FunctionCacheRegistry, is created. It stores the signature and the created functions in a FunctionHolder object, which keeps track of any changes in the signature objects, e.g. whether they invalidated caches. The FunctionHolder uses the Python function as a hash (currently) and performs an equality comparison to check whether the signatures are equal. If a change occurred that invalidates the graph, the FunctionHolder has an is_valid attribute, which evaluates to False, signaling that a retrace has to be forced.
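The following is a minimal sketch of that mechanism; the class and attribute names mirror the description above, but the actual implementation in zfit may differ in detail (in particular, Tensor arguments would need dedicated handling instead of a plain equality check).

```python
import tensorflow as tf


class FunctionHolder:
    """Sketch: holds a traced function together with a snapshot of its signature."""

    def __init__(self, py_func, signature):
        self.py_func = py_func
        self.signature = list(signature)     # snapshot of the (Python) signature objects
        self.tf_func = tf.function(py_func)  # freshly wrapped -> traced on first call
        self.is_valid = True                 # set to False when the cache is invalidated

    def matches(self, signature):
        # plain equality check; sufficient for hashable Python objects
        return self.signature == list(signature)

    def invalidate(self):
        self.is_valid = False


class FunctionCacheRegistry:
    """Sketch: decorator that retraces when the signature changed or was invalidated."""

    def __init__(self):
        self._holders = {}  # keyed by the Python function (its hash)

    def __call__(self, py_func):
        def wrapped(*args):
            holder = self._holders.get(py_func)
            if holder is None or not holder.is_valid or not holder.matches(args):
                # create a new tf.function to force a retrace
                holder = FunctionHolder(py_func, args)
                self._holders[py_func] = holder
            return holder.tf_func(*args)

        return wrapped
```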
Currently (https://github.com/tensorflow/tensorflow/issues/35540), recursively tracing a tf.function causes a deadlock. As a simple solution, every function that is currently being traced is guaranteed not to be traced twice: if foo is being traced and foo is supposed to be traced again, the plain Python function is executed instead.
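A sketch of such a guard, assuming a module-level registry of functions whose tf.function wrapper is currently running (names are illustrative):

```python
import tensorflow as tf

_currently_tracing = set()  # functions whose tf.function wrapper is currently running


def no_double_trace(py_func):
    """Sketch: fall back to the plain Python function on recursive calls."""
    tf_func = tf.function(py_func)

    def wrapped(*args):
        if py_func in _currently_tracing:
            # the function is already being traced: run the plain Python function
            return py_func(*args)
        _currently_tracing.add(py_func)
        try:
            return tf_func(*args)
        finally:
            _currently_tracing.discard(py_func)

    return wrapped
```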
This is a slightly suboptimal solution: in principle, a new tf.function could be created to force an independent retracing. Handling the caching of the graph becomes more difficult then, though, since differently wrapped tf.function instances hold different signatures. The effect of this is small, as recursive calls are rare anyway.