Skip to content

Commit

Permalink
Assign coords in ParetoNBDModel
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Mar 15, 2024
1 parent 72238c5 commit 10c8dc2
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
4 changes: 4 additions & 0 deletions pymc_marketing/clv/models/pareto_nbd.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,10 +384,13 @@ def _extract_predictive_variables(
must_be_unique=["customer_id"],
)

customer_id = data["customer_id"]
model_coords = self.model.coords # type: ignore
if self.purchase_covariate_cols:
purchase_xarray = xarray.DataArray(
data[self.purchase_covariate_cols],
dims=["customer_id", "purchase_covariate"],
coords=[customer_id, list(model_coords["purchase_covariate"])],
)
alpha_scale = self.fit_result["alpha_scale"]
purchase_coefficient = self.fit_result["purchase_coefficient"]
Expand All @@ -404,6 +407,7 @@ def _extract_predictive_variables(
dropout_xarray = xarray.DataArray(
data[self.dropout_covariate_cols],
dims=["customer_id", "dropout_covariate"],
coords=[customer_id, list(model_coords["dropout_covariate"])],
)
beta_scale = self.fit_result["beta_scale"]
dropout_coefficient = self.fit_result["dropout_coefficient"]
Expand Down
12 changes: 10 additions & 2 deletions tests/clv/models/test_pareto_nbd.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,12 +485,20 @@ def test_extract_predictive_covariates(self):
new_data = self.data.assign(
purchase_cov1=1.0,
dropout_cov=1.0,
customer_id=self.data["customer_id"] + 1,
)
different_vars = model._extract_predictive_variables(data=new_data)
different_alpha = different_vars["alpha"]
different_beta = different_vars["beta"]

different_alpha = different_vars["alpha"]
assert np.all(
different_alpha.customer_id.values == alpha_model.customer_id.values + 1
)
assert not np.allclose(alpha_model, different_alpha)

different_beta = different_vars["beta"]
assert np.all(
different_beta.customer_id.values == beta_model.customer_id.values + 1
)
assert not np.allclose(beta_model, different_beta)

def test_logp(self):
Expand Down

0 comments on commit 10c8dc2

Please sign in to comment.