Skip to content

Commit

Permalink
Fix bug that lead use of the "split" and "fsplit" arguments to not pr…
Browse files Browse the repository at this point in the history
…oduce FixestMulti objects (#658)

* fix bug that lead split calls to not produce FixestMulti objects

* fix test bug with single fits evaluated in FixestMulti tests
  • Loading branch information
s3alfisc authored Oct 16, 2024
1 parent d18aba3 commit ea02434
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 12 deletions.
11 changes: 4 additions & 7 deletions pyfixest/estimation/FixestMulti_.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ def __init__(
self._reps = reps if use_compression else None
self._seed = seed if use_compression else None

self._run_split = split or fsplit
self._run_split = split is not None or fsplit is not None
self._run_full = not (split and not fsplit)

self._splitvar: Optional[str] = None
if self._run_split:
Expand All @@ -82,8 +83,6 @@ def __init__(
else:
self._splitvar = None

self._run_full = not (split and not fsplit)

data = _polars_to_pandas(data)

if self._copy_data:
Expand Down Expand Up @@ -171,7 +170,7 @@ def _prepare_estimation(

FML = FixestFormulaParser(fml)
FML.set_fixest_multi_flag()
self._is_multiple_estimation = FML._is_multiple_estimation
self._is_multiple_estimation = FML._is_multiple_estimation or self._run_split
self.FixestFormulaDict = FML.FixestFormulaDict
self._method = estimation
self._is_iv = FML.is_iv
Expand Down Expand Up @@ -231,9 +230,7 @@ def _estimate_all_models(
_fixef_keys = list(FixestFormulaDict.keys())

all_splits = (["all"] if _run_full else []) + (
_data[_splitvar].dropna().unique().tolist()
if _run_split is not None
else []
_data[_splitvar].dropna().unique().tolist() if _run_split else []
)

for sample_split_value in all_splits:
Expand Down
8 changes: 3 additions & 5 deletions tests/test_vs_fixest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import pyfixest as pf
from pyfixest.estimation.estimation import feols
from pyfixest.estimation.FixestMulti_ import FixestMulti
from pyfixest.utils.set_rpy2_path import update_r_paths
from pyfixest.utils.utils import get_data, ssc

Expand Down Expand Up @@ -547,11 +548,6 @@ def test_single_fit_iv(
@pytest.mark.parametrize(
"fml_multi",
[
("Y ~X1"),
("Y ~X1+X2"),
("Y~X1|f1"),
("Y~X1|f1+f2"),
("Y~X2|f2+f3"),
("Y~ sw(X1, X2)"),
("Y~ sw(X1, X2) |f1 "),
("Y~ csw(X1, X2)"),
Expand Down Expand Up @@ -607,6 +603,7 @@ def test_multi_fit(N, seed, beta_type, error_type, dropna, fml_multi):

try:
pyfixest = feols(fml=fml_multi, data=data)
assert isinstance(pyfixest, FixestMulti)
except ValueError as e:
if "is not of type 'O' or 'category'" in str(e):
data["f1"] = pd.Categorical(data.f1.astype(str))
Expand Down Expand Up @@ -680,6 +677,7 @@ def test_split_fit(N, seed, beta_type, error_type, dropna, fml_multi, split, fsp

try:
pyfixest = feols(fml=fml_multi, data=data, split=split, fsplit=fsplit)
assert isinstance(pyfixest, FixestMulti)
except ValueError as e:
if "is not of type 'O' or 'category'" in str(e):
data["f1"] = pd.Categorical(data.f1.astype(str))
Expand Down

0 comments on commit ea02434

Please sign in to comment.