diff --git a/pymc_marketing/mmm/validating.py b/pymc_marketing/mmm/validating.py index 16983b6e..f4688b08 100644 --- a/pymc_marketing/mmm/validating.py +++ b/pymc_marketing/mmm/validating.py @@ -12,14 +12,6 @@ ] -def _has_valid_indices(df: pd.DataFrame) -> bool: - return ( - isinstance(df.index, pd.RangeIndex) - and df.index.start == 0 - and df.index.stop == len(df) - ) - - def validation_method_y(method: Callable) -> Callable: if not hasattr(method, "_tags"): method._tags = {} # type: ignore @@ -49,7 +41,11 @@ def validate_date_col(self, data: pd.DataFrame) -> None: if self.date_column not in data.columns: raise ValueError(f"date_col {self.date_column} not in data") - if not _has_valid_indices(data): + if ( + not isinstance(data.index, pd.RangeIndex) + or data.index.start != 0 + or data.index.stop != len(data) + ): raise ValueError( "X or y has incorrect indices. Try to reset with `data.reset_index(inplace=True)`" )