From 0f7f86c89af86d6f0f687287ba9a2c51428c4db9 Mon Sep 17 00:00:00 2001 From: Michal Raczycki Date: Fri, 11 Aug 2023 13:45:13 +0200 Subject: [PATCH] implementing abstract methods --- pymc_marketing/clv/models/basic.py | 18 +++++++++++++++++- pymc_marketing/mmm/delayed_saturated_mmm.py | 2 +- pymc_marketing/model_builder.py | 14 +++++++++----- tests/mmm/test_base.py | 4 ++++ 4 files changed, 31 insertions(+), 7 deletions(-) diff --git a/pymc_marketing/clv/models/basic.py b/pymc_marketing/clv/models/basic.py index b8c134d0..9173d6fc 100644 --- a/pymc_marketing/clv/models/basic.py +++ b/pymc_marketing/clv/models/basic.py @@ -2,9 +2,11 @@ import types import warnings from pathlib import Path -from typing import Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, Union import arviz as az +import numpy as np +import pandas as pd import pymc as pm from pymc import str_for_dist from pymc.backends import NDArray @@ -286,3 +288,17 @@ def fit_summary(self, **kwargs): return res["mean"].rename("value") else: return az.summary(self.fit_result, **kwargs) + + @property + def output_var(self): + pass + + def generate_and_preprocess_model_data( + self, + X: Union[pd.DataFrame, pd.Series], + y: Union[pd.Series, np.ndarray[Any, Any]], + ) -> None: + pass + + def _data_setter(self): + pass diff --git a/pymc_marketing/mmm/delayed_saturated_mmm.py b/pymc_marketing/mmm/delayed_saturated_mmm.py index 7333d617..4c476a28 100644 --- a/pymc_marketing/mmm/delayed_saturated_mmm.py +++ b/pymc_marketing/mmm/delayed_saturated_mmm.py @@ -83,7 +83,7 @@ def default_sampler_config(self) -> Dict: def output_var(self): return "y" - def generate_and_preprocess_model_data( + def generate_and_preprocess_model_data( # type: ignore self, X: Union[pd.DataFrame, pd.Series], y: pd.Series ) -> None: """ diff --git a/pymc_marketing/model_builder.py b/pymc_marketing/model_builder.py index 264349d8..1e3654b0 100644 --- a/pymc_marketing/model_builder.py +++ b/pymc_marketing/model_builder.py @@ -196,7 +196,9 @@ def default_sampler_config(self) -> Dict: @abstractmethod def generate_and_preprocess_model_data( - self, X: Union[pd.DataFrame, pd.Series], y: pd.Series + self, + X: Union[pd.DataFrame, pd.Series], + y: Union[pd.Series, np.ndarray[Any, Any]], ) -> None: """ Applies preprocessing to the data before fitting the model. @@ -505,9 +507,11 @@ def fit( predictor_names = [] if y is None: y = np.zeros(X.shape[0]) - y = pd.Series({self.output_var: y}) - self.generate_and_preprocess_model_data(X, y) - self.build_model(self.X, self.y) # type: ignore + y_df = pd.DataFrame({self.output_var: y}) + self.generate_and_preprocess_model_data(X, y_df.values.flatten()) + if self.X is None or self.y is None: + raise ValueError("X and y must be set before calling build_model!") + self.build_model(self.X, self.y) sampler_config = self.sampler_config.copy() sampler_config["progressbar"] = progressbar @@ -516,7 +520,7 @@ def fit( self.idata = self.sample_model(**sampler_config) X_df = pd.DataFrame(X, columns=X.columns) - combined_data = pd.concat([X_df, y], axis=1) + combined_data = pd.concat([X_df, y_df], axis=1) assert all(combined_data.columns), "All columns must have non-empty names" with warnings.catch_warnings(): warnings.filterwarnings( diff --git a/tests/mmm/test_base.py b/tests/mmm/test_base.py index e1f18370..55a643b2 100644 --- a/tests/mmm/test_base.py +++ b/tests/mmm/test_base.py @@ -74,6 +74,10 @@ def default_model_config(self): def default_sampler_config(self): pass + @property + def output_var(self): + pass + def _data_setter(self, X, y=None): pass