diff --git a/pyfixest/estimation/FixestMulti_.py b/pyfixest/estimation/FixestMulti_.py index ec68bf0f..8d62e91e 100644 --- a/pyfixest/estimation/FixestMulti_.py +++ b/pyfixest/estimation/FixestMulti_.py @@ -71,7 +71,8 @@ def __init__( self._reps = reps if use_compression else None self._seed = seed if use_compression else None - self._run_split = split or fsplit + self._run_split = split is not None or fsplit is not None + self._run_full = not (split and not fsplit) self._splitvar: Optional[str] = None if self._run_split: @@ -82,8 +83,6 @@ def __init__( else: self._splitvar = None - self._run_full = not (split and not fsplit) - data = _polars_to_pandas(data) if self._copy_data: @@ -171,7 +170,7 @@ def _prepare_estimation( FML = FixestFormulaParser(fml) FML.set_fixest_multi_flag() - self._is_multiple_estimation = FML._is_multiple_estimation + self._is_multiple_estimation = FML._is_multiple_estimation or self._run_split self.FixestFormulaDict = FML.FixestFormulaDict self._method = estimation self._is_iv = FML.is_iv @@ -231,9 +230,7 @@ def _estimate_all_models( _fixef_keys = list(FixestFormulaDict.keys()) all_splits = (["all"] if _run_full else []) + ( - _data[_splitvar].dropna().unique().tolist() - if _run_split is not None - else [] + _data[_splitvar].dropna().unique().tolist() if _run_split else [] ) for sample_split_value in all_splits: diff --git a/tests/test_vs_fixest.py b/tests/test_vs_fixest.py index 72b755fb..4587cd08 100644 --- a/tests/test_vs_fixest.py +++ b/tests/test_vs_fixest.py @@ -11,6 +11,7 @@ import pyfixest as pf from pyfixest.estimation.estimation import feols +from pyfixest.estimation.FixestMulti_ import FixestMulti from pyfixest.utils.set_rpy2_path import update_r_paths from pyfixest.utils.utils import get_data, ssc @@ -547,11 +548,6 @@ def test_single_fit_iv( @pytest.mark.parametrize( "fml_multi", [ - ("Y ~X1"), - ("Y ~X1+X2"), - ("Y~X1|f1"), - ("Y~X1|f1+f2"), - ("Y~X2|f2+f3"), ("Y~ sw(X1, X2)"), ("Y~ sw(X1, X2) |f1 "), ("Y~ csw(X1, X2)"), @@ -607,6 +603,7 @@ def test_multi_fit(N, seed, beta_type, error_type, dropna, fml_multi): try: pyfixest = feols(fml=fml_multi, data=data) + assert isinstance(pyfixest, FixestMulti) except ValueError as e: if "is not of type 'O' or 'category'" in str(e): data["f1"] = pd.Categorical(data.f1.astype(str)) @@ -680,6 +677,7 @@ def test_split_fit(N, seed, beta_type, error_type, dropna, fml_multi, split, fsp try: pyfixest = feols(fml=fml_multi, data=data, split=split, fsplit=fsplit) + assert isinstance(pyfixest, FixestMulti) except ValueError as e: if "is not of type 'O' or 'category'" in str(e): data["f1"] = pd.Categorical(data.f1.astype(str))