Skip to content

Commit

Permalink
implementing abstract methods
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelraczycki committed Aug 11, 2023
1 parent 9f11e99 commit 0f7f86c
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 7 deletions.
18 changes: 17 additions & 1 deletion pymc_marketing/clv/models/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import types
import warnings
from pathlib import Path
from typing import Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple, Union

import arviz as az
import numpy as np
import pandas as pd
import pymc as pm
from pymc import str_for_dist
from pymc.backends import NDArray
Expand Down Expand Up @@ -286,3 +288,17 @@ def fit_summary(self, **kwargs):
return res["mean"].rename("value")
else:
return az.summary(self.fit_result, **kwargs)

@property
def output_var(self):
pass

Check warning on line 294 in pymc_marketing/clv/models/basic.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/clv/models/basic.py#L294

Added line #L294 was not covered by tests

def generate_and_preprocess_model_data(
self,
X: Union[pd.DataFrame, pd.Series],
y: Union[pd.Series, np.ndarray[Any, Any]],
) -> None:
pass

Check warning on line 301 in pymc_marketing/clv/models/basic.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/clv/models/basic.py#L301

Added line #L301 was not covered by tests

def _data_setter(self):
pass

Check warning on line 304 in pymc_marketing/clv/models/basic.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/clv/models/basic.py#L304

Added line #L304 was not covered by tests
2 changes: 1 addition & 1 deletion pymc_marketing/mmm/delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def default_sampler_config(self) -> Dict:
def output_var(self):
return "y"

def generate_and_preprocess_model_data(
def generate_and_preprocess_model_data( # type: ignore
self, X: Union[pd.DataFrame, pd.Series], y: pd.Series
) -> None:
"""
Expand Down
14 changes: 9 additions & 5 deletions pymc_marketing/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,9 @@ def default_sampler_config(self) -> Dict:

@abstractmethod
def generate_and_preprocess_model_data(
self, X: Union[pd.DataFrame, pd.Series], y: pd.Series
self,
X: Union[pd.DataFrame, pd.Series],
y: Union[pd.Series, np.ndarray[Any, Any]],
) -> None:
"""
Applies preprocessing to the data before fitting the model.
Expand Down Expand Up @@ -505,9 +507,11 @@ def fit(
predictor_names = []
if y is None:
y = np.zeros(X.shape[0])

Check warning on line 509 in pymc_marketing/model_builder.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/model_builder.py#L509

Added line #L509 was not covered by tests
y = pd.Series({self.output_var: y})
self.generate_and_preprocess_model_data(X, y)
self.build_model(self.X, self.y) # type: ignore
y_df = pd.DataFrame({self.output_var: y})
self.generate_and_preprocess_model_data(X, y_df.values.flatten())
if self.X is None or self.y is None:
raise ValueError("X and y must be set before calling build_model!")

Check warning on line 513 in pymc_marketing/model_builder.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/model_builder.py#L513

Added line #L513 was not covered by tests
self.build_model(self.X, self.y)

sampler_config = self.sampler_config.copy()
sampler_config["progressbar"] = progressbar
Expand All @@ -516,7 +520,7 @@ def fit(
self.idata = self.sample_model(**sampler_config)

X_df = pd.DataFrame(X, columns=X.columns)
combined_data = pd.concat([X_df, y], axis=1)
combined_data = pd.concat([X_df, y_df], axis=1)
assert all(combined_data.columns), "All columns must have non-empty names"
with warnings.catch_warnings():
warnings.filterwarnings(
Expand Down
4 changes: 4 additions & 0 deletions tests/mmm/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ def default_model_config(self):
def default_sampler_config(self):
pass

@property
def output_var(self):
pass

def _data_setter(self, X, y=None):
pass

Expand Down

0 comments on commit 0f7f86c

Please sign in to comment.