Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't run prior and posterior predictive when calling fit #365

Merged: 9 commits, Sep 1, 2023
2 changes: 1 addition & 1 deletion pymc_marketing/clv/models/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def fit_summary(self, **kwargs):
def output_var(self):
pass

def generate_and_preprocess_model_data(
def _generate_and_preprocess_model_data(
self,
X: Union[pd.DataFrame, pd.Series],
y: Union[pd.Series, np.ndarray[Any, Any]],
Expand Down
28 changes: 15 additions & 13 deletions pymc_marketing/mmm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(
**kwargs,
) -> None:
self.X: Optional[pd.DataFrame] = None
self.y: Optional[pd.Series] = None
self.y: Optional[Union[pd.Series, np.ndarray]] = None
self.date_column: str = date_column
self.channel_columns: Union[List[str], Tuple[str]] = channel_columns
self.n_channel: int = len(channel_columns)
Expand All @@ -69,8 +69,8 @@ def methods(self) -> List[Any]:
def validation_methods(
self,
) -> Tuple[
List[Callable[["BaseMMM", Union[pd.DataFrame, pd.Series]], None]],
List[Callable[["BaseMMM", Union[pd.DataFrame, pd.Series]], None]],
List[Callable[["BaseMMM", Union[pd.DataFrame, pd.Series, np.ndarray]], None]],
List[Callable[["BaseMMM", Union[pd.DataFrame, pd.Series, np.ndarray]], None]],
]:
"""
A property that provides validation methods for features ("X") and the target variable ("y").
Expand Down Expand Up @@ -98,7 +98,9 @@ def validation_methods(
],
)

def validate(self, target: str, data: Union[pd.DataFrame, pd.Series]) -> None:
def validate(
self, target: str, data: Union[pd.DataFrame, pd.Series, np.ndarray]
) -> None:
"""
Validates the input data based on the specified target type.

Expand All @@ -110,7 +112,7 @@ def validate(self, target: str, data: Union[pd.DataFrame, pd.Series]) -> None:
target : str
The type of target to be validated.
Expected values are "X" for features and "y" for the target variable.
data : Union[pd.DataFrame, pd.Series]
data : Union[pd.DataFrame, pd.Series, np.ndarray]
The input data to be validated.

Raises
Expand All @@ -134,14 +136,14 @@ def preprocessing_methods(
) -> Tuple[
List[
Callable[
["BaseMMM", Union[pd.DataFrame, pd.Series]],
Union[pd.DataFrame, pd.Series],
["BaseMMM", Union[pd.DataFrame, pd.Series, np.ndarray]],
Union[pd.DataFrame, pd.Series, np.ndarray],
]
],
List[
Callable[
["BaseMMM", Union[pd.DataFrame, pd.Series]],
Union[pd.DataFrame, pd.Series],
["BaseMMM", Union[pd.DataFrame, pd.Series, np.ndarray]],
Union[pd.DataFrame, pd.Series, np.ndarray],
]
],
]:
Expand Down Expand Up @@ -171,8 +173,8 @@ def preprocessing_methods(
)

def preprocess(
self, target: str, data: Union[pd.DataFrame, pd.Series]
) -> Union[pd.DataFrame, pd.Series]:
self, target: str, data: Union[pd.DataFrame, pd.Series, np.ndarray]
) -> Union[pd.DataFrame, pd.Series, np.ndarray]:
"""
Preprocess the provided data according to the specified target.

Expand All @@ -184,12 +186,12 @@ def preprocess(
target : str
Indicates whether the data represents features ("X") or the target variable ("y").

data : pd.DataFrame
data : Union[pd.DataFrame, pd.Series, np.ndarray]
The data to be preprocessed.

Returns
-------
Union[pd.DataFrame, pd.Series]
Union[pd.DataFrame, pd.Series, np.ndarray]
The preprocessed data.

Raises
Expand Down
43 changes: 34 additions & 9 deletions pymc_marketing/mmm/delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,11 @@ def default_sampler_config(self) -> Dict:

@property
def output_var(self):
"""Defines target variable for the model"""
return "y"

def generate_and_preprocess_model_data( # type: ignore
self, X: Union[pd.DataFrame, pd.Series], y: pd.Series
def _generate_and_preprocess_model_data( # type: ignore
self, X: Union[pd.DataFrame, pd.Series], y: Union[pd.Series, np.ndarray]
) -> None:
"""
Applies preprocessing to the data before fitting the model.
Expand All @@ -93,8 +94,8 @@ def generate_and_preprocess_model_data( # type: ignore

Parameters
----------
X : array, shape (n_obs, n_features)
y : array, shape (n_obs,)
X : Union[pd.DataFrame, pd.Series], shape (n_obs, n_features)
y : Union[pd.Series, np.ndarray], shape (n_obs,)
"""
date_data = X[self.date_column]
channel_data = X[self.channel_columns]
Expand Down Expand Up @@ -126,11 +127,11 @@ def generate_and_preprocess_model_data( # type: ignore
self.validate("X", X_data)
self.validate("y", y)
self.preprocessed_data: Dict[str, Union[pd.DataFrame, pd.Series]] = {
"X": self.preprocess("X", X_data),
"y": self.preprocess("y", y),
"X": self.preprocess("X", X_data), # type: ignore
"y": self.preprocess("y", y), # type: ignore
}
self.X: pd.DataFrame = X_data
self.y: pd.Series = y
self.y: Union[pd.Series, np.ndarray] = y

def _save_input_params(self, idata) -> None:
"""Saves input parameters to the attrs of idata."""
Expand All @@ -144,11 +145,35 @@ def _save_input_params(self, idata) -> None:
def build_model(
self,
X: pd.DataFrame,
y: pd.Series,
y: Union[pd.Series, np.ndarray],
**kwargs,
) -> None:
"""
Builds a probabilistic model using PyMC for marketing mix modeling.

The model incorporates channels, control variables, and Fourier components, applying
adstock and saturation transformations to the channel data. The final model is
constructed with multiple factors contributing to the response variable.

Parameters
----------
X : pd.DataFrame
The input data for the model, which should include columns for channels,
control variables (if applicable), and Fourier components (if applicable).

y : Union[pd.Series, np.ndarray]
The target/response variable for the modeling.

**kwargs : dict
Additional keyword arguments that might be required by underlying methods or utilities.

Attributes Set
---------------
model : pm.Model
The PyMC model object containing all the defined stochastic and deterministic variables.
"""
model_config = self.model_config
self.generate_and_preprocess_model_data(X, y)
self._generate_and_preprocess_model_data(X, y)
with pm.Model(coords=self.model_coords) as self.model:
channel_data_ = pm.MutableData(
name="channel_data",
Expand Down
Loading