Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 482487281
Change-Id: I1a278e72d5de8d364454a8918374e3e238d32546
  • Loading branch information
michevan authored and copybara-github committed Oct 20, 2022
1 parent a34b827 commit b6ff3b3
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
5 changes: 3 additions & 2 deletions lightweight_mmm/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -1287,8 +1287,9 @@ def plot_prior_and_posterior(

if feature in seasonal_features:
for i_season in range(media_mix_model._degrees_seasonality):
for j_season in range(media_mix_model._degrees_seasonality):
subplot_title = f"{feature}, seasonal mode {i_season}:{j_season}"
for j_season in range(2):
sin_or_cos = "sin" if j_season == 0 else "cos"
subplot_title = f"{feature}, seasonal mode {i_season}:{sin_or_cos}"
posterior_samples = np.array(media_mix_model.trace[feature][:,
i_season,
j_season])
Expand Down
13 changes: 7 additions & 6 deletions lightweight_mmm/plot_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,32 +494,32 @@ def test_prior_posterior_plot_clipping_bounds_for_kdeplots(
testcase_name="carryover_national_model",
model_name="carryover",
is_geo_model=False,
expected_number_of_subplots=29),
expected_number_of_subplots=31),
dict(
testcase_name="carryover_geo_model",
model_name="carryover",
is_geo_model=True,
expected_number_of_subplots=45),
expected_number_of_subplots=47),
dict(
testcase_name="adstock_national_model",
model_name="adstock",
is_geo_model=False,
expected_number_of_subplots=24),
expected_number_of_subplots=26),
dict(
testcase_name="adstock_geo_model",
model_name="adstock",
is_geo_model=True,
expected_number_of_subplots=40),
expected_number_of_subplots=42),
dict(
testcase_name="hill_adstock_national_model",
model_name="hill_adstock",
is_geo_model=False,
expected_number_of_subplots=29),
expected_number_of_subplots=31),
dict(
testcase_name="hill_adstock_geo_model",
model_name="hill_adstock",
is_geo_model=True,
expected_number_of_subplots=45),
expected_number_of_subplots=47),
])
def test_prior_posterior_plot_makes_correct_number_of_subplots(
self, model_name, is_geo_model, expected_number_of_subplots):
Expand All @@ -538,6 +538,7 @@ def test_prior_posterior_plot_makes_correct_number_of_subplots(
target=target,
media_prior=jnp.ones(5) * 50,
extra_features=extra_features,
degrees_seasonality=3,
number_warmup=2,
number_samples=2,
number_chains=1)
Expand Down

0 comments on commit b6ff3b3

Please sign in to comment.