Skip to content

Commit

Permalink
fix: Add type ignore args to all sample-args
Browse files Browse the repository at this point in the history
  • Loading branch information
louismagowan committed Oct 18, 2024
1 parent 7e4e987 commit 2b9768a
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions pymc_marketing/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,6 @@ def log_inference_data(


def log_evaluation_metrics(
mmm: MMM,
y_true: np.ndarray,
y_pred: np.ndarray,
metrics_to_calculate: list[str] | None = None,
Expand All @@ -471,8 +470,6 @@ def log_evaluation_metrics(
Parameters
----------
mmm : MMM
The fitted MMM object.
y_true : np.ndarray
The true values of the target variable.
y_pred : np.ndarray
Expand Down Expand Up @@ -590,7 +587,7 @@ def predict(
include_last_observations=self.include_last_observations,
original_scale=self.original_scale,
var_names=self.var_names,
**self.sample_kwargs,
**self.sample_kwargs, # type: ignore[arg-type]
)
elif predict_method == "sample_posterior_predictive":
return self.model.sample_posterior_predictive(
Expand All @@ -600,7 +597,7 @@ def predict(
include_last_observations=self.include_last_observations,
original_scale=self.original_scale,
var_names=self.var_names,
**self.sample_kwargs,
**self.sample_kwargs, # type: ignore[arg-type]
)
elif predict_method == "sample_prior_predictive":
return self.model.sample_prior_predictive(
Expand Down Expand Up @@ -947,7 +944,6 @@ def new_fit(self, *args, **kwargs):

posterior_preds = self.sample_posterior_predictive(self.X)
log_evaluation_metrics(
self,
y_true=self.y,
y_pred=posterior_preds[
self.output_var[0]
Expand Down

0 comments on commit 2b9768a

Please sign in to comment.