Skip to content

Commit

Permalink
Add the epsilon kwarg into the sci-kit learn version of the Regressio…
Browse files Browse the repository at this point in the history
…nDiscontinuity class
  • Loading branch information
drbenvincent committed Jul 22, 2023
1 parent 9bda3d9 commit e007b94
Showing 1 changed file with 21 additions and 8 deletions.
29 changes: 21 additions & 8 deletions causalpy/skl_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,13 +346,21 @@ def plot(self):

class RegressionDiscontinuity(ExperimentalDesign):
"""
Analyse data from regression discontinuity experiments.
.. note::
There is no pre/post intervention data distinction for the regression
discontinuity design, we fit all the data available.
A class to analyse regression discontinuity experiments.
:param data:
A pandas dataframe
:param formula:
A statistical model formula
:param treatment_threshold:
A scalar threshold value at which the treatment is applied
:param model:
A sci-kit learn model object
:param running_variable_name:
The name of the predictor variable that the treatment threshold is based upon
:param epsilon:
A small scalar value which determines how far above and below the treatment
threshold to evaluate the causal impact.
"""

def __init__(
Expand All @@ -362,13 +370,15 @@ def __init__(
treatment_threshold,
model=None,
running_variable_name="x",
epsilon: float = 0.001,
**kwargs,
):
super().__init__(model=model, **kwargs)
self.data = data
self.formula = formula
self.running_variable_name = running_variable_name
self.treatment_threshold = treatment_threshold
self.epsilon = epsilon
y, X = dmatrices(formula, self.data)
self._y_design_info = y.design_info
self._x_design_info = X.design_info
Expand Down Expand Up @@ -404,7 +414,10 @@ def __init__(
self.x_discon = pd.DataFrame(
{
self.running_variable_name: np.array(
[self.treatment_threshold - 0.001, self.treatment_threshold + 0.001]
[
self.treatment_threshold - self.epsilon,
self.treatment_threshold + self.epsilon,
]
),
"treated": np.array([0, 1]),
}
Expand Down

0 comments on commit e007b94

Please sign in to comment.