Benchmarking and Optimizing Aesara Code #1174
-
Regarding PyMC, keep in mind that the gradient probably dominates the sampling time, so the result can easily be worse than what you would expect from examining the logp performance alone. I read somewhere that the Scan grad is 4-6x slower than the value evaluation, but that obviously depends on the graph itself.
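One way to check how much the gradient costs on a given graph is to compile the scan output and its gradient separately and time both. A minimal sketch with a toy scan-based "logp" (the graph and names here are purely illustrative, not the Kalman filter):

```python
import timeit

import numpy as np
import aesara
import aesara.tensor as at

x = at.dvector("x")

# A toy scan-based "logp": a running sum of log terms, just so a Scan is in the graph.
acc_seq, _ = aesara.scan(
    fn=lambda x_t, acc: acc + at.log(1 + x_t ** 2),
    sequences=[x],
    outputs_info=[at.constant(0.0)],
)
logp = acc_seq[-1]

f_val = aesara.function([x], logp)
f_grad = aesara.function([x], aesara.grad(logp, x))

data = np.random.rand(100)
print("logp :", timeit.timeit(lambda: f_val(data), number=1_000))
print("grad :", timeit.timeit(lambda: f_grad(data), number=1_000))
```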
-
Printing the graphs is often good for relative comparisons (e.g. the scalar-case graph vs. the matrix-case graph). Sometimes one can see which optimizations were performed in one case but not the other, and, if those optimizations are meaningful, they could explain any large relative differences in performance.
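For instance, a small side-by-side comparison of two optimized graphs (here with a toy scalar-vs-matrix pair rather than the Kalman filter itself) could look like this:

```python
import aesara
import aesara.tensor as at

# The same arithmetic expressed with a scalar and with a matrix.
x_s = at.dscalar("x")
x_m = at.dmatrix("X")

f_scalar = aesara.function([x_s], x_s * x_s + 1.0)
f_matrix = aesara.function([x_m], x_m.dot(x_m) + 1.0)

# Print both optimized graphs and compare which Ops/rewrites ended up in each.
aesara.dprint(f_scalar)
aesara.dprint(f_matrix)
```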
-
Don't forget that a …
-
Based on the discussion and the profiling so far, it seems the main bottleneck is in the linear algebra operators. There is another implementation of the Kalman filter that avoids computing determinants or inverting any matrices by introducing a second scan over the states of the system. Introducing this second scan reduces all the matrix-matrix operations to at most vector-vector products, with no matrix inversions at all. Combined with the observed speed of the scalar implementation, especially when sampling in PyMC, I thought this might be a fruitful way to go. Here is the so-called "univariate Kalman filter", which has a scan inside a scan. It ends up being significantly slower than the linear-algebra version, which surprised me, especially since the inner scan only runs over a single state (so it should amount to only a single step).
Univariate Kalman filter with a scan inside a scan, to avoid `at.linalg.solve` and `at.linalg.det`:
```python
import aesara
import aesara.tensor as at
from aesara.tensor.nlinalg import matrix_dot

# N_CONST and the symbolic inputs a_data, a_x0, a_P0, a_T, a_Z, a_R, a_H, a_Q
# are defined elsewhere in the gist (N_CONST is presumably log(2 * pi)).


def univariate_inner_filter_step(y, Z_row, sigma_H, a, P):
    # Update for a single observed component: everything here is at most
    # vector-vector, so no matrix inversions or determinants are needed.
    v = y - Z_row.dot(a)
    PZT = P.dot(Z_row.T)
    F = Z_row.dot(PZT) + sigma_H
    K = PZT / F
    a_filtered = a + K * v
    P_filtered = P - at.outer(K, K) * F
    ll_inner = at.log(F) + v ** 2 / F
    return a_filtered, P_filtered, ll_inner


def matrix_predict2(a_filtered, P_filtered, T, R, Q):
    a_predicted = T.dot(a_filtered)
    P_predicted = matrix_dot(T, P_filtered, T.T) + matrix_dot(R, Q, R.T)
    # Force P_predicted to be symmetric
    P_predicted = 0.5 * (P_predicted + P_predicted.T)
    return a_predicted, P_predicted


def univariate_kalman_step(y, a, P, T, Z, R, H, Q):
    y = y[:, None]
    # Inner scan over the components of the observation vector
    result, updates = aesara.scan(univariate_inner_filter_step,
                                  sequences=[y, Z, at.diag(H)],
                                  outputs_info=[a, P, None],
                                  name='scan_over_states',
                                  profile=True)
    a_filtered, P_filtered, ll_inner = result
    a_filtered, P_filtered = a_filtered[-1], P_filtered[-1]
    a_predicted, P_predicted = matrix_predict2(a_filtered, P_filtered, T=T, R=R, Q=Q)
    ll = -0.5 * ((at.neq(ll_inner, 0).sum()) * N_CONST + ll_inner.sum())
    return a_filtered, a_predicted, P_filtered, P_predicted, ll


# Outer scan over time steps
filter_result, updates = aesara.scan(univariate_kalman_step,
                                     sequences=[a_data],
                                     outputs_info=[None, a_x0, None, a_P0, None],
                                     non_sequences=[a_T, a_Z, a_R, a_H, a_Q],
                                     name='univariate_filter',
                                     profile=True)
```
Compiling the inner scan by itself and timing it shows that it is a very fast function, just a few microseconds (which makes sense, since it has only one step). When I profile the whole function, though, the inner scan ends up dominating the execution time:
Outer scan profile
Inner scan profile
I don't understand the discrepancy between the execution times in the two profiles. The outer scan reports that the inner scan took 6s to execute, but the profile of the inner scan says that it executed for only 2.4s. Sampling time for this set-up in PyMC is the worst of all, so I guess @ricardoV94 was right that gradients of Scan are expensive.
-
Hi everyone,
I've been working on implementing a Kalman filter in Aesara, with the objective of computing the log-likelihood of a class of time series models. The filter is pretty simple: it recursively predicts new data using the model dynamics, computes the error between the prediction and the observation at a single time step, and then combines the prediction and the observation into an optimal fused state.
In its most general form, the filter is a bunch of matrix equations, but a special case where everything is scalar can also be implemented. Here is this special case in Aesara:
Code block for the scalar kalman filter
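The collapsed code block is not reproduced here; as a rough idea of what a scalar Kalman filter scan in Aesara can look like, here is a minimal sketch (the recursion and names are a textbook-style illustration, not necessarily the gist's exact code):

```python
import numpy as np
import aesara
import aesara.tensor as at

# Scalar state-space model: a_{t+1} = T*a_t + R*eta_t,  y_t = Z*a_t + eps_t,
# with Var(eps) = H and Var(eta) = Q.
T, Z, R, H, Q = (at.dscalar(name) for name in ["T", "Z", "R", "H", "Q"])
a0, P0 = at.dscalar("a0"), at.dscalar("P0")
y_data = at.dvector("y")


def scalar_kalman_step(y, a, P, T, Z, R, H, Q):
    v = y - Z * a                              # one-step-ahead prediction error
    F = Z * P * Z + H                          # prediction-error variance
    K = T * P * Z / F                          # Kalman gain
    a_next = T * a + K * v                     # predicted state mean
    P_next = T * P * (T - K * Z) + R * Q * R   # predicted state variance
    ll = -0.5 * (at.log(2 * np.pi) + at.log(F) + v ** 2 / F)
    return a_next, P_next, ll


(a_seq, P_seq, ll_seq), _ = aesara.scan(
    scalar_kalman_step,
    sequences=[y_data],
    outputs_info=[a0, P0, None],
    non_sequences=[T, Z, R, H, Q],
    name="scalar_kalman_filter",
)
loglike = ll_seq.sum()
kalman_filter = aesara.function([y_data, a0, P0, T, Z, R, H, Q], loglike)
```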
On 100 data points with fixed parameters, this scalar function runs extremely fast: 511 µs ± 4.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each). Plugging the log-likelihood into PyMC, things also sample very fast (22 seconds). Here is the function profile, as a basis for comparison:
Aesara profile output for the scalar kalman filter
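For readers who want to reproduce this kind of output: a profile like the one in the collapsed block can be obtained by compiling with profile=True and calling the profile summary afterwards, roughly like this (toy graph, not the filter itself):

```python
import numpy as np
import aesara
import aesara.tensor as at

x = at.dvector("x")
acc_seq, _ = aesara.scan(
    lambda x_t, acc: acc + x_t ** 2,
    sequences=[x],
    outputs_info=[at.constant(0.0)],
)

# profile=True attaches profiling stats to the compiled function.
fn = aesara.function([x], acc_seq[-1], profile=True)

data = np.random.rand(100)
for _ in range(1_000):
    fn(data)

fn.profile.summary()  # prints the per-Op / per-apply-node timing breakdown
```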
Question 1: Checking the profile for `kalman_filter`, 98% of the execution time is spent in the scan. That makes perfect sense, since there isn't much else here. The scan is of type Py rather than C, is that correct? Is there anything else interesting in this profile output that should catch my attention? So far, everything seems fine.
What I really want, though, is to implement the filter with matrices, so that it can generalize to more complex time series models. Here is the code for a Kalman filter that uses matrices:
Matrix Kalman filter code
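The collapsed block is not reproduced above; a textbook-style sketch of the kind of update step involved, using `at.linalg.solve` and `at.linalg.det` (an illustration of the general structure, not the gist's exact code):

```python
import numpy as np
import aesara.tensor as at
from aesara.tensor.nlinalg import matrix_dot


def matrix_kalman_step(y, a, P, T, Z, R, H, Q):
    # One predict/update step of a multivariate Kalman filter, intended to be
    # used inside aesara.scan over the rows of the observed data.
    v = y - Z.dot(a)                                 # innovation
    F = matrix_dot(Z, P, Z.T) + H                    # innovation covariance
    F_inv = at.linalg.solve(F, at.eye(F.shape[0]))   # inverse via solve
    K = matrix_dot(T, P, Z.T, F_inv)                 # Kalman gain
    a_next = T.dot(a) + K.dot(v)
    P_next = matrix_dot(T, P, (T - K.dot(Z)).T) + matrix_dot(R, Q, R.T)
    ll = -0.5 * (y.shape[0] * at.log(2 * np.pi)
                 + at.log(at.linalg.det(F))
                 + v.dot(F_inv).dot(v))
    return a_next, P_next, ll
```

The scalar filter above is the special case of this step where every matrix is 1x1.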
Implementing the same scalar model using a set of 1x1 matrices gives the following time: 11.3 ms ± 150 µs per loop (mean ± std. dev. of 7 runs, 100 loops each), or a 22x slowdown. Passing the likelihood to PyMC, sampling goes from 22 seconds (scalar model) to 4 minutes (matrix model).
Once again, I look at the profile for `matrix_kalman_filter`, which is dominated by the scan Op, with a bit of `Dot22` thrown in.
Matrix Kalman Filter function profile
Question 2: How should I interpret these profile times in the presence of a single dominating apply node (the scan itself)? Can I really attribute the entirety of the speed differential to the calls to `Dot22`, even though it accounts for only 1.1% of the execution time? (Obviously there are more differences between the two implementations, addressed below; I am trying to ask a concrete question about this single output.)
Now, these are not apples-to-apples comparisons, because aside from switching from multiplication and division to dot-products and matrix inversion, there are several additional operations introduced for numerical stability. The scalar filter has 32 apply nodes, while the matrix filter has 37, so obviously I expect the code to be somewhat slower. How much slower, and whether I am leaving any optimizations on the table in the matrix form of the filter, are open questions.
To get a sense of how much slowdown I should expect when moving from scalar math to linear algebra, I ran everything in pure Python and saw a 25x slowdown, so Aesara is actually doing a bit better here. To get a sense of what optimizations are possible, I put `njit` decorators on all the functions and re-ran the speed tests. This brought times down to 9.33 µs ± 500 ns and 733 µs ± 3.13 µs for the scalar and matrix versions, respectively. So linear algebra slows Numba down the most, by 78x, but in absolute time it is still the fastest implementation (a sketch of the Numba scalar version is shown below). Here is a little table summarizing the results:
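The Numba scalar version mentioned above presumably looked something like this: the same recursion written as a plain NumPy loop under `njit` (a sketch under that assumption, not the gist's exact code):

```python
import numpy as np
from numba import njit


@njit
def scalar_kalman_filter_nb(y, a0, P0, T, Z, R, H, Q):
    # Scalar Kalman filter: accumulate the log-likelihood over the time series.
    a, P, ll = a0, P0, 0.0
    for t in range(y.shape[0]):
        v = y[t] - Z * a
        F = Z * P * Z + H
        K = T * P * Z / F
        a = T * a + K * v
        P = T * P * (T - K * Z) + R * Q * R
        ll += -0.5 * (np.log(2 * np.pi) + np.log(F) + v ** 2 / F)
    return ll


# Toy parameter values just to exercise the compiled function.
y = np.random.randn(100)
print(scalar_kalman_filter_nb(y, 0.0, 1.0, 0.9, 1.0, 1.0, 0.1, 0.5))
```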
Back to Aesara: @brandonwillard told me that all efforts start with `aesara.dprint`, so I checked the graph for the log-likelihood of the matrix function, but nothing jumps out at me as being "off"; there's no obvious duplicate computation or bug happening.
dprint output for matrix kalman filter log-likelihood term
So what have I learned? Linear algebra is slower than scalar multiplication. Thanks for coming to my TED talk. In all seriousness, what can be done to speed up the matrix formulation of the filter? From looking at the profile output and the `aesara.dprint` output, I don't see any obvious paths forward. The final goal is to be able to sample using this log-likelihood function, so speed is absolutely critical. With the simplest possible model already taking 5 minutes, it's quite disheartening. For comparison, an ARMA(1,1) model takes over 15 minutes, and that one involves only 2x2 matrices. Scaling gets worse from there.
All the code used to make this post is available in this gist if anyone is curious. I am really hoping that I am missing something obvious. At the very least, helpful pointers about how to read the profile summary and dprint outputs would be appreciated.