A general caching solution that would work for the adjoint (I think) #2946
A counterpoint to this is that, in memory-critical regimes, any sort of cache would exceed the available memory of the machine. I think that having a configurable cache size may fix this. I'll have a stab at implementing this caching solution for …
I've worked on this problem in a few forms. It's a serious issue for small/idealized ocean calculations consisting of a long sequence of cheap solves. A few things that you might encounter:
…
Firedrake adjoint, topical because of @dham's series of lectures on it, is plagued by cache-related performance issues. This is because many of the assumptions we use to optimise the forward application do not hold when we start going backwards in time. In particular, I think that we run into issues with the SSA duplication of variables: data-carrying objects are no longer persistent through the program and are therefore not suitable candidates for attaching caches.
I have had an idea for how we might be able to resolve these caching issues in a generic way.
Caching in Firedrake
There are effectively 3 types of object in Firedrake that we want to cache:
… `NonlinearVariationalSolver`, `Assembler`, `Interpolator`, `Assigner`, `Projector` … `.constant_jacobian`).
Type 1 is the most straightforward. Given sufficient symbolic information, we can establish an appropriate cache key and cache the results. Since these objects do not reference any "heavy" data structures, a global cache is suitable because little memory is leaked.
The chief difficulty in obtaining performance in the adjoint is with objects of the second type. These objects are expensive to create and hold references to "heavy" data structures like `Function`s and the mesh. This naturally causes a conflict between wanting to reuse them wherever possible and trying to avoid leaking memory.
The main ways we reuse such objects are either to tell the user to instantiate them directly (e.g. we do this with `NonlinearVariationalSolver`), or to cache things on the form (e.g. `assemble`). The former is intrusive for the user and requires more operations to be taped for the adjoint. The latter is also less than ideal: users typically write `assemble(f*v*dx)` instead of `form = f*v*dx; ...; assemble(form)`. This means that the form only exists for the duration of the function call, so any objects cached on it will never get reused.
My suggestion
Since no object with the correct lifetime (for both the forward and adjoint problems) exists on which to cache solvers/assemblers/etc., I believe that they need to be stored in a bounded (i.e. LRU) global cache (this could be done in a parallel-safe way). The caching problem then becomes a hashing one: when do we know that it is valid to reuse solvers, assemblers, etc.?
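A bounded global cache of this kind can be sketched with `collections.OrderedDict`. This is a minimal, single-process illustration: `LRUCache` and `get_or_create` are hypothetical names, the cached values stand in for solvers/assemblers, and it does not attempt the parallel safety mentioned above (e.g. keeping hits and evictions consistent across MPI ranks). The `maxsize` argument is the kind of configurable bound suggested in the replies for memory-critical runs.

```python
from collections import OrderedDict
from typing import Any, Callable, Hashable


class LRUCache:
    """Bounded global cache: once more than ``maxsize`` entries are held,
    the least recently used one is evicted (releasing its references)."""

    def __init__(self, maxsize: int = 128) -> None:
        self.maxsize = maxsize
        self._data: "OrderedDict[Hashable, Any]" = OrderedDict()

    def get_or_create(self, key: Hashable, factory: Callable[[], Any]) -> Any:
        if key in self._data:
            self._data.move_to_end(key)  # mark as most recently used
            return self._data[key]
        value = factory()  # cache miss: build the expensive object
        self._data[key] = value
        if len(self._data) > self.maxsize:
            self._data.popitem(last=False)  # evict least recently used
        return value

    def __contains__(self, key: Hashable) -> bool:
        return key in self._data

    def __len__(self) -> int:
        return len(self._data)


# Usage: the values are placeholders for solvers/assemblers.
cache = LRUCache(maxsize=2)
a = cache.get_or_create("a", object)
cache.get_or_create("b", object)
cache.get_or_create("a", object)  # hit: "a" becomes most recently used
cache.get_or_create("c", object)  # over the bound: evicts "b"
print(sorted(k for k in ("a", "b", "c") if k in cache))  # ['a', 'c']
```

Because eviction drops the only reference the cache holds, a solver's `Function`s and mesh can be garbage collected once it falls out of the cache, which bounds the memory that caching can pin.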
I think that the key observation here is that we can safely reuse these objects if (a) the symbolic expression is the same as before, and (b) the same data/`Dat`s/buffers/pointers are being used again. Crucially, the form/expression object itself is allowed to be different. In other words, given forms `form0` and `form1`, the solver/assembler/etc. can be reused if `form0.signature() == form1.signature()` and `extract_coefficients(form0) == extract_coefficients(form1)`, without requiring that `form0 is form1`.
For the adjoint, I think that with this caching approach we could ensure solver/assembler reuse for recompute/TLM/adjoint/Hessian by making sure that we reuse the same `Function`s as we run through the tape. We wouldn't even need to instantiate any solvers or assemblers: calls to `solve` and `assemble` would be sufficient.
Hashing
I think we can get the right cache key for these UFL expressions very simply. Currently, computing a form signature renumbers all of the coefficients. This means that different but symbolically equivalent `Coefficient`s (which would have different numbers) produce the same form signature. I propose relaxing this so that one can compute the form signature without applying a renumbering (i.e. `form.signature(renumber=False)`). This would allow us to differentiate between forms with the same symbolic meaning but different attached terminals.
Aside
Another idea I have had is that we could remove the need to pass `constant_jacobian` when creating solvers. Instead, we could determine whether the Jacobian needs to be reassembled by inspecting the `dat_version` of the various data-carrying terminals.
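The hashing proposal and the `dat_version` aside can be combined in one toy model. This is a sketch under stated assumptions, not Firedrake or UFL code: `Coefficient`, `Form`, `ToySolver` and `solve` are hypothetical stand-ins, `signature(renumber=False)` models the proposed option rather than an existing UFL argument, and a plain dict replaces the bounded global cache. It shows that a fresh form object with the same symbolics and terminals hits the cache, that a symbolically identical form on different data does not, and that the "Jacobian" is reassembled only after a terminal's data version changes.

```python
# Toy model: all names are hypothetical stand-ins, not Firedrake/UFL API.
import hashlib
import itertools

_count = itertools.count()  # global numbering, like coefficient counts


class Coefficient:
    """Toy data-carrying terminal with a version bumped on every write."""

    def __init__(self) -> None:
        self.count = next(_count)
        self.dat_version = 0

    def write(self) -> None:
        self.dat_version += 1  # mimics a dat_version counter


class Form:
    """Toy form: a symbolic template plus the coefficients plugged into it."""

    def __init__(self, template: str, coefficients) -> None:
        self.template = template
        self.coefficients = tuple(coefficients)

    def signature(self, renumber: bool = True) -> str:
        if renumber:  # current behaviour: coefficients numbered 0, 1, ...
            labels = [f"w{i}" for i in range(len(self.coefficients))]
        else:  # proposed: keep each coefficient's own global number
            labels = [f"w{c.count}" for c in self.coefficients]
        text = self.template.format(*labels)
        return hashlib.sha256(text.encode()).hexdigest()


class ToySolver:
    """Reassembles its 'Jacobian' only when some terminal's data changed."""

    def __init__(self, form: Form) -> None:
        self.form = form
        self.assembly_count = 0
        self._versions = None

    def solve(self) -> None:
        versions = tuple(c.dat_version for c in self.form.coefficients)
        if versions != self._versions:
            self.assembly_count += 1  # stand-in for Jacobian assembly
            self._versions = versions


_solvers = {}  # would be the bounded global cache in practice


def solve(form: Form) -> ToySolver:
    key = form.signature(renumber=False)  # symbolics + identity of terminals
    solver = _solvers.get(key)
    if solver is None:
        solver = _solvers[key] = ToySolver(form)
    solver.solve()
    return solver


f, g = Coefficient(), Coefficient()
template = "{0}*u*v*dx"
s0 = solve(Form(template, [f]))  # a fresh Form object on every call...
s1 = solve(Form(template, [f]))  # ...but same key, so the solver is reused
s2 = solve(Form(template, [g]))  # same symbolics, different data: new solver
f.write()
solve(Form(template, [f]))  # f's data changed: Jacobian reassembled
print(s0 is s1, s0 is s2, s0.assembly_count)  # True False 2
```

With the default `renumber=True`, the forms on `f` and `g` give the same signature, which is why the key has to be the un-renumbered signature (in this toy model, equivalent to pairing the renumbered signature with the coefficient identities).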