Skip to content

Commit

Permalink
Various MMM small documentation fixes (#854)
Browse files Browse the repository at this point in the history
* make nb ruff astral-sh/ruff-vscode#546

* fix docs

* fix quickstart

* add types
  • Loading branch information
juanitorduz authored Jul 22, 2024
1 parent a98815f commit e6f844f
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 19 deletions.
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@ repos:
- --exclude=docs/
- --exclude=scripts/
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.5.3
rev: v0.5.4
hooks:
- id: ruff
types_or: [ python, pyi, jupyter ]
types_or: [python, pyi, jupyter]
args: ["--fix", "--output-format=full"]
exclude: ^docs/source/notebooks/clv/dev/
- id: ruff-format
types_or: [ python, pyi, jupyter ]
types_or: [python, pyi, jupyter]
exclude: ^docs/source/notebooks/clv/dev/
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.10.1
Expand Down
16 changes: 11 additions & 5 deletions docs/source/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,30 +111,36 @@ Start VS Code and ensure that the "Jupyter" extension is installed. Press Ctrl +

```python
import pandas as pd
from pymc_marketing.mmm import DelayedSaturatedMMM

from pymc_marketing.mmm import (
GeometricAdstock,
LogisticSaturation,
MMM,
)


data_url = "https://raw.githubusercontent.com/pymc-labs/pymc-marketing/main/data/mmm_example.csv"
data = pd.read_csv(data_url, parse_dates=['date_week'])

mmm = DelayedSaturatedMMM(
mmm = MMM(
adstock=GeometricAdstock(l_max=8),
saturation=LogisticSaturation(),
date_column="date_week",
channel_columns=["x1", "x2"],
control_columns=[
"event_1",
"event_2",
"t",
],
adstock_max_lag=8,
yearly_seasonality=2,
)
```

Initiate fitting and get a visualization of some of the outputs with:

```python
X = data.drop('y',axis=1)
y = data['y']
X = data.drop("y",axis=1)
y = data["y"]
mmm.fit(X,y)
mmm.plot_components_contributions();
```
Expand Down
24 changes: 13 additions & 11 deletions pymc_marketing/mmm/delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1683,20 +1683,22 @@ def sample_posterior_predictive(
Sample from the model's posterior predictive distribution.
Parameters
---------
----------
X_pred : array, shape (n_pred, n_features)
The input data used for prediction.
extend_idata : Boolean determining whether the predictions should be added to inference data object.
extend_idata : bool, optional
Boolean determining whether the predictions should be added to inference data object. Defaults to True.
combined: bool, optional
Combine chain and draw dims into sample. Won't work if a dim named sample already exists. Defaults to True.
include_last_observations: bool, optional
Boolean determining whether to include the last observations of the training data in order to carry over
costs with the adstock transformation. Assumes that X_pred are the next predictions following the
training data.Defaults to False.
original_scale: bool, optional
Boolean determining whether to return the predictions in the original scale of the target variable.
Defaults to True.
combined: Combine chain and draw dims into sample. Won't work if a dim named sample already exists.
Defaults to True.
include_last_observations: Boolean determining whether to include the last observations of the training
data in order to carry over costs with the adstock transformation.
Assumes that X_pred are the next predictions following the training data.
Defaults to False.
original_scale: Boolean determining whether to return the predictions in the original scale
of the target variable. Defaults to True.
**sample_posterior_predictive_kwargs: Additional arguments to pass to pymc.sample_posterior_predictive
**sample_posterior_predictive_kwargs
Additional arguments to pass to pymc.sample_posterior_predictive
Returns
-------
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ ignore = [
[tool.ruff.lint.pycodestyle]
max-line-length = 120

[tool.ruff]
extend-include = ["*.ipynb"]

[tool.pytest.ini_options]
addopts = [
"-v",
Expand Down

0 comments on commit e6f844f

Please sign in to comment.