From f643f72a440b13d35eab789edc19b2ef01b65274 Mon Sep 17 00:00:00 2001 From: Sam Bailey Date: Thu, 27 Apr 2023 10:19:59 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 527614892 Change-Id: I8ab428bfb6b670b09c2aa5d06600f2fdcaedf964 --- lightweight_mmm/__init__.py | 2 +- lightweight_mmm/models.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/lightweight_mmm/__init__.py b/lightweight_mmm/__init__.py index 0421dcc..98ff0e2 100644 --- a/lightweight_mmm/__init__.py +++ b/lightweight_mmm/__init__.py @@ -17,4 +17,4 @@ Detailed documentation and examples can be found in the [Github repository](https://github.com/google/lightweight_mmm). """ -__version__ = "0.1.7.1" +__version__ = "0.1.8" diff --git a/lightweight_mmm/models.py b/lightweight_mmm/models.py index 443c099..5135470 100644 --- a/lightweight_mmm/models.py +++ b/lightweight_mmm/models.py @@ -353,8 +353,12 @@ def media_mix_model( name="geo_media_plate", size=n_geos, dim=-1): + # Corrects the mean to be the same as in the channel only case. + normalisation_factor = jnp.sqrt(2.0 / jnp.pi) coef_media = numpyro.sample( - name="coef_media", fn=dist.HalfNormal(scale=coef_media)) + name="coef_media", + fn=dist.HalfNormal(scale=coef_media * normalisation_factor) + ) with numpyro.plate(name=f"{_GAMMA_SEASONALITY}_sin_cos_plate", size=2): with numpyro.plate(name=f"{_GAMMA_SEASONALITY}_plate",