diff --git a/pymc_marketing/clv/models/pareto_nbd.py b/pymc_marketing/clv/models/pareto_nbd.py index 01c31616..4c4474ff 100644 --- a/pymc_marketing/clv/models/pareto_nbd.py +++ b/pymc_marketing/clv/models/pareto_nbd.py @@ -633,7 +633,7 @@ def _population_distributions( # This is the shape if using fit_method="map" if self.fit_result.dims == {"chain": 1, "draw": 1}: - shape_kwargs = {"shape": 10000} + shape_kwargs = {"shape": 4000} else: shape_kwargs = {} diff --git a/tests/clv/models/test_pareto_nbd.py b/tests/clv/models/test_pareto_nbd.py index 7944e802..63b9fbef 100644 --- a/tests/clv/models/test_pareto_nbd.py +++ b/tests/clv/models/test_pareto_nbd.py @@ -361,7 +361,7 @@ def test_dropout_purchase_distributions(self) -> None: assert isinstance(customer_dropout, xarray.DataArray) assert isinstance(customer_purchase_rate, xarray.DataArray) - N = 10000 + N = 4000 lam = pm.Gamma.dist(self.r_true, self.alpha_true, size=N) mu = pm.Gamma.dist(self.s_true, self.beta_true, size=N) rtol = 0.05