Skip to content

Commit

Permalink
Add BetaGeoBetaBinomModel (#1031)
Browse files Browse the repository at this point in the history
* beta_geo_beta_binom.py copy from prev pr

* test_beta_geo_beta_binom.py copy from prev PR

* beta_geo_beta_binom imports

* basic.py validate homogeneous T

* copy _logp fix from prev PR

* notebook and WIP _distribution_new_customers

* test_distribution_new_customers

* TODOs and test coverage

* docstrings

* docstrings

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
ColtAllen and pre-commit-ci[bot] authored Sep 13, 2024
1 parent 561e5c3 commit 660686b
Show file tree
Hide file tree
Showing 8 changed files with 7,310 additions and 446 deletions.
6,498 changes: 6,072 additions & 426 deletions docs/source/notebooks/clv/dev/beta_geo_beta_binom.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pymc_marketing/clv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""CLV models and utilities."""

from pymc_marketing.clv.models import (
BetaGeoBetaBinomModel,
BetaGeoModel,
GammaGammaModel,
GammaGammaModelIndividual,
Expand All @@ -34,6 +35,7 @@

__all__ = (
"BetaGeoModel",
"BetaGeoBetaBinomModel",
"ParetoNBDModel",
"GammaGammaModel",
"GammaGammaModelIndividual",
Expand Down
8 changes: 1 addition & 7 deletions pymc_marketing/clv/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,23 +601,17 @@ def logp(value, alpha, beta, gamma, delta, T):
"""Log-likelihood of the distribution."""
t_x = pt.atleast_1d(value[..., 0])
x = pt.atleast_1d(value[..., 1])
scalar_case = t_x.type.broadcastable == (True,)

for param in (t_x, x, alpha, beta, gamma, delta, T):
if param.type.ndim > 1:
raise NotImplementedError(
f"BetaGeoBetaBinom logp only implemented for vector parameters, got ndim={param.type.ndim}"
)
if scalar_case:
if param.type.broadcastable == (False,):
raise NotImplementedError(
f"Parameter {param} cannot be larger than scalar value"
)

# Broadcast all the parameters so they are sequences.
# Potentially inefficient, but otherwise ugly logic needed to unpack arguments in the scan function,
# since sequences always precede non-sequences.
_, alpha, beta, gamma, delta, T = pt.broadcast_arrays(
t_x, alpha, beta, gamma, delta, T = pt.broadcast_arrays(
t_x, alpha, beta, gamma, delta, T
)

Expand Down
2 changes: 2 additions & 0 deletions pymc_marketing/clv/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from pymc_marketing.clv.models.basic import CLVModel
from pymc_marketing.clv.models.beta_geo import BetaGeoModel
from pymc_marketing.clv.models.beta_geo_beta_binom import BetaGeoBetaBinomModel
from pymc_marketing.clv.models.gamma_gamma import (
GammaGammaModel,
GammaGammaModelIndividual,
Expand All @@ -25,6 +26,7 @@

__all__ = (
"CLVModel",
"BetaGeoBetaBinomModel",
"GammaGammaModel",
"GammaGammaModelIndividual",
"BetaGeoModel",
Expand Down
6 changes: 6 additions & 0 deletions pymc_marketing/clv/models/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def _validate_cols(
data: pd.DataFrame,
required_cols: Sequence[str],
must_be_unique: Sequence[str] = (),
must_be_homogenous: Sequence[str] = (),
):
existing_columns = set(data.columns)
n = data.shape[0]
Expand All @@ -71,6 +72,11 @@ def _validate_cols(
if required_col in must_be_unique:
if data[required_col].nunique() != n:
raise ValueError(f"Column {required_col} has duplicate entries")
if required_col in must_be_homogenous:
if data[required_col].nunique() != 1:
raise ValueError(
f"Column {required_col} has non-homogeneous entries"
)

def __repr__(self) -> str:
"""Representation of the model."""
Expand Down
Loading

0 comments on commit 660686b

Please sign in to comment.