Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
juanitorduz committed Aug 20, 2024
1 parent b956c96 commit 1f5d335
Show file tree
Hide file tree
Showing 8 changed files with 118 additions and 17 deletions.
5 changes: 3 additions & 2 deletions pymc_marketing/mmm/budget_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,9 @@ class BudgetOptimizer(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)

def objective(self, budgets: list[float]) -> float:
"""Calculate the total response during a period of time given the budgets,
considering the saturation and adstock transformations.
"""Calculate the total response during a period of time given the budgets.
It considers the saturation and adstock transformations.
Parameters
----------
Expand Down
45 changes: 33 additions & 12 deletions pymc_marketing/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ def check_array(X, **kwargs):


class ModelBuilder(ABC):
"""ModelBuilder can be used to provide an easy-to-use API (similar to scikit-learn) for models
"""Base class for building models with PyMC Marketing.
It provides an easy-to-use API (similar to scikit-learn) for models
and help with deployment.
"""

Expand All @@ -61,7 +63,7 @@ def __init__(
model_config: dict | None = None,
sampler_config: dict | None = None,
):
"""Initializes model configuration and sampler configuration for the model
"""Initialize model configuration and sampler configuration for the model.
Parameters
----------
Expand Down Expand Up @@ -107,7 +109,7 @@ def _data_setter(
X: np.ndarray | pd.DataFrame,
y: np.ndarray | pd.Series | None = None,
) -> None:
"""Sets new data in the model.
"""Set new data in the model.
Parameters
----------
Expand Down Expand Up @@ -147,7 +149,9 @@ def output_var(self) -> str:
@property
@abstractmethod
def default_model_config(self) -> dict:
"""Returns a class default config dict for model builder if no model_config is provided on class initialization
"""Return a class default configuration dictionary.
For model builder if no model_config is provided on class initialization
Useful for understanding structure of required model_config to allow its customization by users
Examples
Expand Down Expand Up @@ -176,7 +180,9 @@ def default_model_config(self) -> dict:
@property
@abstractmethod
def default_sampler_config(self) -> dict:
"""Returns a class default sampler dict for model builder if no sampler_config is provided on class initialization
"""Return a class default sampler configuration dictionary.
For model builder if no sampler_config is provided on class initialization
Useful for understanding structure of required sampler_config to allow its customization by users
Examples
Expand All @@ -201,7 +207,8 @@ def default_sampler_config(self) -> dict:
def _generate_and_preprocess_model_data(
self, X: pd.DataFrame | pd.Series, y: np.ndarray
) -> None:
"""Applies preprocessing to the data before fitting the model.
"""Apply preprocessing to the data before fitting the model.
if validate is True, it will check if the data is valid for the model.
sets self.model_coords based on provided dataset
Expand Down Expand Up @@ -236,8 +243,9 @@ def build_model(
y: pd.Series | np.ndarray,
**kwargs,
) -> None:
"""Creates an instance of pm.Model based on provided data and model_config, and
attaches it to self.
"""Create an instance of `pm.Model` based on provided data and model_config.
It attaches the model to self.model.
Parameters
----------
Expand Down Expand Up @@ -265,6 +273,14 @@ def build_model(
"""

def create_idata_attrs(self) -> dict[str, str]:
"""Create attributes for the inference data.
Returns
-------
dict[str, str]
A dictionary of attributes for the inference data.
"""

def default(x):
if isinstance(x, Prior):
return x.to_json()
Expand Down Expand Up @@ -390,7 +406,9 @@ def save(self, fname: str) -> None:

@classmethod
def _model_config_formatting(cls, model_config: dict) -> dict:
"""Because of json serialization, model_config values that were originally tuples
"""Format the model configuration.
Because of json serialization, model_config values that were originally tuples
or numpy are being encoded as lists. This function converts them back to tuples
and numpy arrays to ensure correct id encoding.
"""
Expand Down Expand Up @@ -421,7 +439,7 @@ def attrs_to_init_kwargs(cls, attrs) -> dict[str, Any]:

@classmethod
def load(cls, fname: str):
"""Creates a ModelBuilder instance from a file,
"""Create a ModelBuilder instance from a file.
Loads inference data for the model.
Expand Down Expand Up @@ -486,6 +504,7 @@ def fit(
**kwargs: Any,
) -> az.InferenceData:
"""Fit a model using the data passed as a parameter.
Sets attrs to inference data of the model.
Parameters
Expand Down Expand Up @@ -569,8 +588,9 @@ def predict(
extend_idata: bool = True,
**kwargs,
) -> np.ndarray:
"""Uses model to predict on unseen data and return point prediction of all the samples. The point prediction
for each input row is the expected output value, computed as the mean of MCMC samples.
"""Use a model to predict on unseen data and return point prediction of all the samples.
The point prediction for each input row is the expected output value, computed as the mean of MCMC samples.
Parameters
----------
Expand Down Expand Up @@ -726,6 +746,7 @@ def set_params(self, **params):
@abstractmethod
def _serializable_model_config(self) -> dict[str, int | float | dict]:
"""Converts non-serializable values from model_config to their serializable reversable equivalent.
Data types like pandas DataFrame, Series or datetime aren't JSON serializable,
so in order to save the model they need to be formatted.
Expand Down
18 changes: 17 additions & 1 deletion pymc_marketing/version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Copyright 2024 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Version of the package."""

# Copyright 2024 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -16,7 +31,8 @@
here = os.path.dirname(os.path.realpath(__file__))


def read_version():
def read_version() -> str:
"""Read the version from the version file."""
version_file = os.path.join(here, "version.txt")
with open(version_file, encoding="utf-8") as buff:
return buff.read().splitlines()[0]
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ ignore = [
"D",
"S101", # Use of assert
]
"scripts/*" = ["D"]

[tool.ruff.lint.pycodestyle]
max-line-length = 120
Expand Down
15 changes: 15 additions & 0 deletions streamlit/mmm-explainer/Visualise_Priors.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Copyright 2024 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Streamlit page for visualising priors."""

# Copyright 2024 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down
15 changes: 15 additions & 0 deletions streamlit/mmm-explainer/pages/Adstock.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Copyright 2024 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Streamlit page for adstock transformations."""

# Copyright 2024 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down
15 changes: 15 additions & 0 deletions streamlit/mmm-explainer/pages/Saturation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Copyright 2024 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Streamlit page for saturation curves."""

# Copyright 2024 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down
21 changes: 19 additions & 2 deletions streamlit/mmm-explainer/prior_functions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Copyright 2024 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for plotting prior distributions."""

# Copyright 2024 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -51,8 +66,10 @@ def get_distribution(distribution_name=pz.distributions, **params):
def plot_prior_distribution(
draws, nbins=100, opacity=0.1, title="Prior Distribution - Visualised"
):
"""Plots samples of a prior distribution as a histogram with a KDE (Kernel Density Estimate) overlay
and a violin plot along the top too with quartile values.
"""Plot samples of a prior distribution as a histogram.
It uses a KDE (Kernel Density Estimate) overlay and a violin plot along the top too
with quartile values.
Parameters
----------
Expand Down

0 comments on commit 1f5d335

Please sign in to comment.