```python
import numpy as np
import tensorflow_probability as tfp
import tensorflow as tf

time_series_with_nans = [-1.0, 1.0, np.nan, 2.4, np.nan, 5]
observed_time_series = tfp.sts.MaskedTimeSeries(
    time_series=time_series_with_nans, is_missing=tf.math.is_nan(time_series_with_nans)
)

# Build model using observed time series to set heuristic priors.
linear_trend_model = tfp.sts.LocalLinearTrend(observed_time_series=observed_time_series)
model = tfp.sts.Sum([linear_trend_model], observed_time_series=observed_time_series)

# Fit model to data
parameter_samples, _ = tf.function(
    func=lambda ots: tfp.sts.fit_with_hmc(model, ots), jit_compile=True, autograph=False
)(observed_time_series)
```
I'm trying to run the code above. Using JIT, as suggested in this comment: #1704 (comment), gives me the following error:
It looks like `tfp.sts.fit_with_hmc` creates variables as part of its execution, which raises the question: does this mean `fit_with_hmc` cannot be JIT-compiled? If so, since GPU doesn't work for STS, and neither does JAX (#1646 (comment)), are there any other ways to speed up `fit_with_hmc`?
I'm using:
Using v2.16 also produces the same error: