-
Notifications
You must be signed in to change notification settings - Fork 192
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add datasetter for date and fourier #405
Changes from 4 commits
31606ab
9383e76
9c42739
a3e7600
11eeaa1
c69d93b
02a0df7
65dccfa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,7 +14,7 @@ | |
from pymc_marketing.mmm.base import MMM | ||
from pymc_marketing.mmm.preprocessing import MaxAbsScaleChannels, MaxAbsScaleTarget | ||
from pymc_marketing.mmm.transformers import geometric_adstock, logistic_saturation | ||
from pymc_marketing.mmm.utils import generate_fourier_modes | ||
from pymc_marketing.mmm.utils import generate_yearly_fourier_modes | ||
from pymc_marketing.mmm.validating import ValidateControlColumns | ||
|
||
__all__ = ["DelayedSaturatedMMM"] | ||
|
@@ -85,7 +85,7 @@ def output_var(self): | |
return "y" | ||
|
||
def _generate_and_preprocess_model_data( # type: ignore | ||
self, X: Union[pd.DataFrame, pd.Series], y: Union[pd.Series, np.ndarray] | ||
self, X: pd.DataFrame, y: Union[pd.Series, np.ndarray] | ||
) -> None: | ||
""" | ||
Applies preprocessing to the data before fitting the model. | ||
|
@@ -94,7 +94,7 @@ def _generate_and_preprocess_model_data( # type: ignore | |
|
||
Parameters | ||
---------- | ||
X : Union[pd.DataFrame, pd.Series], shape (n_obs, n_features) | ||
X : pd.DataFrame, shape (n_obs, n_features) | ||
y : Union[pd.Series, np.ndarray], shape (n_obs,) | ||
""" | ||
date_data = X[self.date_column] | ||
|
@@ -326,7 +326,7 @@ def default_model_config(self) -> Dict: | |
} | ||
return model_config | ||
|
||
def _get_fourier_models_data(self, X) -> pd.DataFrame: | ||
def _get_fourier_models_data(self, X: pd.DataFrame) -> pd.DataFrame: | ||
"""Generates fourier modes to model seasonality. | ||
|
||
References | ||
|
@@ -338,10 +338,9 @@ def _get_fourier_models_data(self, X) -> pd.DataFrame: | |
date_data: pd.Series = pd.to_datetime( | ||
arg=X[self.date_column], format="%Y-%m-%d" | ||
) | ||
periods: npt.NDArray[np.float_] = date_data.dt.dayofyear.to_numpy() / 365.25 | ||
return generate_fourier_modes( | ||
periods=periods, | ||
n_order=self.yearly_seasonality, | ||
|
||
return generate_yearly_fourier_modes( | ||
dayofyear=date_data.dt.dayofyear.to_numpy(), n_order=self.yearly_seasonality | ||
) | ||
|
||
def channel_contributions_forward_pass( | ||
|
@@ -486,18 +485,42 @@ def _data_setter( | |
""" | ||
new_channel_data: Optional[np.ndarray] = None | ||
|
||
if isinstance(X, pd.DataFrame): | ||
def from_frame_or_array( | ||
X: Union[pd.DataFrame, np.ndarray], columns, handle_frame_func=None | ||
) -> np.ndarray: | ||
if not isinstance(X, (pd.DataFrame, np.ndarray)): | ||
raise TypeError("X must be either a pandas DataFrame or a numpy array") | ||
|
||
if isinstance(X, np.ndarray): | ||
return X | ||
|
||
if handle_frame_func is None: | ||
|
||
def handle_frame_func(X): | ||
raise RuntimeError(f"New data must contain {columns}!") | ||
|
||
try: | ||
new_channel_data = X[self.channel_columns].to_numpy() | ||
except KeyError as e: | ||
raise RuntimeError("New data must contain channel_data!", e) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. DataFrames already raise |
||
elif isinstance(X, np.ndarray): | ||
new_channel_data = X | ||
else: | ||
raise TypeError("X must be either a pandas DataFrame or a numpy array") | ||
return X[columns].to_numpy() | ||
except KeyError: | ||
return handle_frame_func(X) | ||
|
||
new_channel_data = from_frame_or_array(X, columns=self.channel_columns) | ||
data: Dict[str, Union[np.ndarray, Any]] = {"channel_data": new_channel_data} | ||
|
||
if self.control_columns is not None: | ||
new_control_data = from_frame_or_array(X, columns=self.control_columns) | ||
data["control_data"] = new_control_data | ||
|
||
if self.yearly_seasonality is not None: | ||
|
||
def handle_frame_func(X): | ||
return self._get_fourier_models_data(X).to_numpy() | ||
|
||
new_fourier_data = from_frame_or_array( | ||
X, columns=self.fourier_columns, handle_frame_func=handle_frame_func | ||
) | ||
data["fourier_data"] = new_fourier_data | ||
|
||
if y is not None: | ||
if isinstance(y, pd.Series): | ||
data[ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -67,7 +67,7 @@ def batched_convolution(x, w, axis: int = 0): | |
|
||
|
||
def geometric_adstock( | ||
x, alpha: float = 0.0, l_max: int = 12, normalize: bool = False, axis: int = 0 | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. getting a mypy issue |
||
x, alpha=0.0, l_max: int = 12, normalize: bool = False, axis: int = 0 | ||
): | ||
"""Geometric adstock transformation. | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we should support ndarray when we want to select different sets of columns, unless we want to support the case where only costs are added, i.e.:
isinstance(X, np.ndarray) and self.control_columns is None and self.yearly_seasonality is None
I think this would simplify the code considerably.