From e9334208c5afccd68a706aba36537bb5d10e7897 Mon Sep 17 00:00:00 2001
From: "Benjamin T. Vincent"
Date: Sat, 19 Jul 2025 09:11:38 +0100
Subject: [PATCH] Handle integer treated variable in RegressionDiscontinuity

---
 .../experiments/regression_discontinuity.py   |  6 +++
 causalpy/tests/test_input_validation.py       | 49 +++++++++++++++++++
 docs/source/_static/interrogate_badge.svg     |  6 +--
 3 files changed, 58 insertions(+), 3 deletions(-)

diff --git a/causalpy/experiments/regression_discontinuity.py b/causalpy/experiments/regression_discontinuity.py
index a1962b9a..ec24ba0b 100644
--- a/causalpy/experiments/regression_discontinuity.py
+++ b/causalpy/experiments/regression_discontinuity.py
@@ -210,6 +210,12 @@ def input_validation(self):
                 """The treated variable should be dummy coded. Consisting of 0's and 1's only."""  # noqa: E501
             )
 
+        # Cast any integer dummy coding (signed, unsigned, or nullable) to bool
+        if self.data["treated"].dtype.kind in ("i", "u"):
+            # Rebind to a copy first so the cast avoids SettingWithCopyWarning
+            self.data = self.data.copy()
+            self.data["treated"] = self.data["treated"].astype(bool)
+
     def _is_treated(self, x):
         """Returns ``True`` if `x` is greater than or equal to the treatment
         threshold.
diff --git a/causalpy/tests/test_input_validation.py b/causalpy/tests/test_input_validation.py
index 30eaa344..43fd9208 100644
--- a/causalpy/tests/test_input_validation.py
+++ b/causalpy/tests/test_input_validation.py
@@ -383,3 +383,52 @@ def test_rkink_epsilon_check():
             kink_point=kink,
             epsilon=-1,
         )
+
+
+# RegressionDiscontinuity
+
+
+def setup_regression_discontinuity_data(threshold=0.5):
+    """Return a seeded 100-row frame with x, integer-coded treated, and y."""
+    rng = np.random.default_rng(42)  # local generator: leaves global RNG alone
+    x = rng.uniform(0, 1, 100)
+    treated = np.where(x >= threshold, 1, 0)  # 0/1 ints, treated at/above cut
+    y = 2 * x + treated + rng.normal(0, 1, 100)
+    return pd.DataFrame({"x": x, "treated": treated, "y": y})
+
+
+def test_regression_discontinuity_int_treatment():
+    """Test that RegressionDiscontinuity works with integer treatment variables."""
+    threshold = 0.5
+    df = setup_regression_discontinuity_data(threshold)
+    assert pd.api.types.is_integer_dtype(df["treated"])  # int32 or int64
+
+    # Integer dummy coding must be accepted and fitted without error
+    result = cp.RegressionDiscontinuity(
+        df,
+        formula="y ~ 1 + x + treated + x:treated",
+        model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
+        treatment_threshold=threshold,
+    )
+
+    # Check that the treatment variable was converted to bool
+    assert result.data["treated"].dtype == bool
+
+
+def test_regression_discontinuity_bool_treatment():
+    """Test that RegressionDiscontinuity works with boolean treatment variables."""
+    threshold = 0.5
+    df = setup_regression_discontinuity_data(threshold)
+    df["treated"] = df["treated"].astype(bool)  # Convert to bool
+    assert df["treated"].dtype == bool  # Ensure treatment is bool
+
+    # Boolean input needs no conversion and should fit as before
+    result = cp.RegressionDiscontinuity(
+        df,
+        formula="y ~ 1 + x + treated + x:treated",
+        model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
+        treatment_threshold=threshold,
+    )
+
+    # Check that the treatment variable is still bool
+    assert result.data["treated"].dtype == bool
diff --git 
a/docs/source/_static/interrogate_badge.svg b/docs/source/_static/interrogate_badge.svg index d2d886ad..4704ef6c 100644 --- a/docs/source/_static/interrogate_badge.svg +++ b/docs/source/_static/interrogate_badge.svg @@ -1,5 +1,5 @@ - interrogate: 95.4% + interrogate: 95.5% @@ -12,8 +12,8 @@ interrogate interrogate - 95.4% - 95.4% + 95.5% + 95.5%