Skip to content

Commit

Permalink
some base logic and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wd60622 committed Jun 13, 2024
1 parent a148e83 commit 3f78436
Show file tree
Hide file tree
Showing 2 changed files with 520 additions and 0 deletions.
307 changes: 307 additions & 0 deletions pymc_marketing/model_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,307 @@
# Copyright 2024 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model configuration utilities.
Example model configuration for model with alpha and beta parameters
```python
model_config = {
"alpha": {
"dist": "Normal",
"kwargs": {
"mu": 0,
"sigma": 1,
},
},
"beta": {
"dist": "Normal",
"kwargs": {
"mu": 0,
"sigma": 1,
},
"dims": "geo",
},
}
```
"""

from collections.abc import Callable
from typing import Any

import numpy as np
import pandas as pd
import pymc as pm
from pymc.distributions.shape_utils import Dims
from pytensor import tensor as pt


class UnsupportedShapeError(Exception):
"""Error for when the shape of the hierarchical variable is not supported."""


def handle_1d(var: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorLike:
return var


def handle_2d(var: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorLike:
if dims == desired_dims:
return var

if dims[::-1] == desired_dims:
return var.T

if dims[0] == desired_dims[-2]:
return var[:, None]

return var


HANDLE_MAPPING = {1: handle_1d, 2: handle_2d}

DimHandler = Callable[[pt.TensorLike, Dims], pt.TensorLike]


def create_dim_handler(desired_dims: Dims) -> DimHandler:
desired_dims = desired_dims if isinstance(desired_dims, tuple) else (desired_dims,)
ndims = len(desired_dims)
if ndims > 2:
raise UnsupportedShapeError(
"At most two dims can be specified. Raise an issue if support for more dims is needed."
)

def handle_shape(
var: pt.TensorLike,
dims: Dims,
) -> pt.TensorLike:
"""Handle the shape for a hierarchical parameter."""
dims = desired_dims if dims is None else dims
dims = dims if isinstance(dims, tuple) else (dims,)

if not set(dims).issubset(set(desired_dims)):
raise UnsupportedShapeError("The dims of the variable are not supported.")

handle = HANDLE_MAPPING[ndims]
return handle(var, dims, desired_dims)

return handle_shape


def get_distribution(name: str) -> type[pm.Distribution]:
"""Retrieve a PyMC distribution class by name.
Parameters
----------
name : str
Name of a PyMC distribution.
Returns
-------
pm.Distribution
A PyMC distribution class that can be used to instantiate a random
variable.
Raises
------
ValueError
If the specified distribution name does not correspond to any
distribution in PyMC.
"""
if not hasattr(pm, name):
raise ValueError(f"Distribution {name} does not exist in PyMC.")

return getattr(pm, name)


def handle_nested_distribution(
name: str,
param: str,
parameter_config: dict[str, Any],
dim_handler: DimHandler,
):
param_name = f"{name}_{param}"
kwargs = {
key: value
for key, value in parameter_config.items()
if key not in ["dist", "kwargs"]
}
var = create_distribution(
param_name,
parameter_config["dist"],
parameter_config["kwargs"],
**kwargs,
)
return dim_handler(var, parameter_config.get("dims"))


class ModelConfigError(Exception):
def __init__(self, param: str) -> None:
self.param = param
self.message = (

Check warning on line 151 in pymc_marketing/model_config.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/model_config.py#L150-L151

Added lines #L150 - L151 were not covered by tests
f"Invalid parameter configuration for '{param}'."
" It must be either a dictionary with 'dist' and 'kwargs' keys or a numeric value."
)

super().__init__(self.message)

Check warning on line 156 in pymc_marketing/model_config.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/model_config.py#L156

Added line #L156 was not covered by tests


def handle_parameter_configurations(
name: str,
param: str,
parameter_config: dict[str, Any],
dim_handler: DimHandler,
) -> Any:
is_nested_distribution = (
isinstance(parameter_config, dict)
and "dist" in parameter_config
and "kwargs" in parameter_config
)
if is_nested_distribution:
return handle_nested_distribution(
name,
param,
parameter_config,
dim_handler=dim_handler,
)

if isinstance(parameter_config, int | float | np.ndarray | pt.TensorVariable):
return parameter_config

raise ModelConfigError(param)

Check warning on line 181 in pymc_marketing/model_config.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/model_config.py#L181

Added line #L181 was not covered by tests


def handle_parameter_distributions(
name: str,
param_distributions: dict[str, dict[str, Any]],
dim_handler: DimHandler,
) -> dict[str, Any]:
return {
param: handle_parameter_configurations(
name,
param,
parameter_config,
dim_handler=dim_handler,
)
for param, parameter_config in param_distributions.items()
}


def create_distribution(
name: str,
distribution_name: str,
distribution_kwargs: dict[str, Any],
**kwargs,
) -> pt.TensorVariable:
dim_handler = create_dim_handler(kwargs.get("dims"))
parameter_distributions = handle_parameter_distributions(
name, distribution_kwargs, dim_handler=dim_handler
)
distribution = get_distribution(name=distribution_name)
return distribution(name, **parameter_distributions, **kwargs)


def create_distribution_from_config(name: str, config) -> pt.TensorVariable:
parameter_config = config[name]
return create_distribution(
name,
parameter_config["dist"],
parameter_config["kwargs"],
dims=parameter_config.get("dims"),
)


LIKELIHOOD_DISTRIBUTIONS: set[str] = {
"Normal",
"Beta",
"StudentT",
"Laplace",
"Logistic",
"LogNormal",
"Wald",
"TruncatedNormal",
"Gamma",
"AsymmetricLaplace",
"VonMises",
}


class UnsupportedDistributionError(Exception):
pass


def create_likelihood_distribution(
name: str,
param_config: dict,
mu: pt.TensorVariable,
observed: np.ndarray | pd.Series,
dims: str,
) -> pt.TensorVariable:
"""
Create and return a likelihood distribution for the model.
This method prepares the distribution and its parameters as specified in the
configuration dictionary, validates them, and constructs the likelihood
distribution using PyMC.
Parameters
----------
dist : Dict
A configuration dictionary that must contain a 'dist' key with the name of
the distribution and a 'kwargs' key with parameters for the distribution.
observed : Union[np.ndarray, pd.Series]
The observed data to which the likelihood distribution will be fitted.
dims : str
The dimensions of the data.
Returns
-------
TensorVariable
The likelihood distribution constructed with PyMC.
Raises
------
ValueError
If 'kwargs' key is missing in `dist`, or the parameter configuration does
not contain 'dist' and 'kwargs' keys, or if 'mu' is present in the nested
'kwargs'
"""
if param_config["dist"] not in LIKELIHOOD_DISTRIBUTIONS:
raise UnsupportedDistributionError(

Check warning on line 280 in pymc_marketing/model_config.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/model_config.py#L279-L280

Added lines #L279 - L280 were not covered by tests
f"""
The distribution used for the likelihood is not allowed.
Please, use one of the following distributions: {list(LIKELIHOOD_DISTRIBUTIONS)}.
"""
)

if "mu" in param_config["kwargs"]:
raise ValueError(

Check warning on line 288 in pymc_marketing/model_config.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/model_config.py#L287-L288

Added lines #L287 - L288 were not covered by tests
"The 'mu' key is not allowed directly within 'kwargs' of the main distribution as it is reserved."
)

param_config["kwargs"]["mu"] = mu

Check warning on line 292 in pymc_marketing/model_config.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/model_config.py#L292

Added line #L292 was not covered by tests

kwargs = {

Check warning on line 294 in pymc_marketing/model_config.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/model_config.py#L294

Added line #L294 was not covered by tests
key: value
for key, value in param_config.items()
if key not in ["dist", "kwargs"]
}
kwargs["dims"] = dims
kwargs["observed"] = observed

Check warning on line 300 in pymc_marketing/model_config.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/model_config.py#L299-L300

Added lines #L299 - L300 were not covered by tests

return create_distribution(

Check warning on line 302 in pymc_marketing/model_config.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/model_config.py#L302

Added line #L302 was not covered by tests
name=name,
distribution_name=param_config["dist"],
distribution_kwargs=param_config["kwargs"],
**kwargs,
)
Loading

0 comments on commit 3f78436

Please sign in to comment.