Skip to content

Commit

Permalink
rename parent class method, fix type hint
Browse files Browse the repository at this point in the history
  • Loading branch information
cluhmann committed Oct 19, 2023
1 parent 850e099 commit a6e8251
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions pymc_marketing/mmm/delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pymc as pm
import seaborn as sns
from xarray import DataArray
from pytensor.tensor.variable import TensorVariable

from pymc_marketing.mmm.base import MMM
from pymc_marketing.mmm.preprocessing import MaxAbsScaleChannels, MaxAbsScaleTarget
Expand Down Expand Up @@ -344,9 +345,9 @@ def _get_fourier_models_data(self, X) -> pd.DataFrame:
n_order=self.yearly_seasonality,
)

def channel_contributions_forward_pass(
def channel_contributions_forward_pass_untransformed(
self, channel_data: npt.NDArray[np.float_]
) -> npt.NDArray[np.float_]:
) -> TensorVariable:
"""Evaluate the channel contribution for a given channel data and a fitted model, ie. the forward pass.
Parameters
----------
Expand Down Expand Up @@ -556,15 +557,15 @@ def channel_contributions_forward_pass(
array-like
Transformed channel data.
"""
channel_contribution_forward_pass = super().channel_contributions_forward_pass(
channel_contribution_forward_pass = super().channel_contributions_forward_pass_untransformed(
channel_data=channel_data
)
target_transformed_vectorized = np.vectorize(
self.target_transformer.inverse_transform,
excluded=[1, 2],
signature="(m, n) -> (m, n)",
)
return target_transformed_vectorized(channel_contribution_forward_pass).eval()
return target_transformed_vectorized(channel_contribution_forward_pass.eval())

def get_channel_contributions_forward_pass_grid(
self, start: float, stop: float, num: int
Expand Down

0 comments on commit a6e8251

Please sign in to comment.