
Commit 445c30f

committed
add plr tuning example
1 parent 424b1ec commit 445c30f

File tree: 7 files changed, +247 / -0 lines changed


monte-cover/src/montecover/plm/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,7 @@
 
 from montecover.plm.pliv_late import PLIVLATECoverageSimulation
 from montecover.plm.plr_ate import PLRATECoverageSimulation
+from montecover.plm.plr_ate_tune import PLRATETuningCoverageSimulation
 from montecover.plm.plr_ate_sensitivity import PLRATESensitivityCoverageSimulation
 from montecover.plm.plr_cate import PLRCATECoverageSimulation
 from montecover.plm.plr_gate import PLRGATECoverageSimulation
@@ -12,4 +13,5 @@
     "PLRGATECoverageSimulation",
     "PLRCATECoverageSimulation",
     "PLRATESensitivityCoverageSimulation",
+    "PLRATETuningCoverageSimulation",
 ]
Lines changed: 172 additions & 0 deletions
@@ -0,0 +1,172 @@
from typing import Any, Dict, Optional
import optuna

import doubleml as dml
from doubleml.plm.datasets import make_plr_CCDDHNR2018

from montecover.base import BaseSimulation
from montecover.utils import create_learner_from_config


class PLRATETuningCoverageSimulation(BaseSimulation):
    """Simulation class for coverage properties of DoubleMLPLR for ATE estimation."""

    def __init__(
        self,
        config_file: str,
        suppress_warnings: bool = True,
        log_level: str = "INFO",
        log_file: Optional[str] = None,
    ):
        super().__init__(
            config_file=config_file,
            suppress_warnings=suppress_warnings,
            log_level=log_level,
            log_file=log_file,
        )

        # Calculate oracle values
        self._calculate_oracle_values()

        # tuning specific settings
        # parameter space for the outcome regression tuning
        def ml_l_params(trial):
            return {
                'n_estimators': trial.suggest_int('n_estimators', 100, 500, step=50),
                'learning_rate': trial.suggest_float('learning_rate', 1e-3, 0.1, log=True),
                'min_child_samples': trial.suggest_int('min_child_samples', 20, 100, step=5),
                'max_depth': trial.suggest_int('max_depth', 3, 10, step=1),
                'lambda_l1': trial.suggest_float('lambda_l1', 1e-8, 10.0, log=True),
                'lambda_l2': trial.suggest_float('lambda_l2', 1e-8, 10.0, log=True),
            }

        # parameter space for the propensity score tuning
        def ml_m_params(trial):
            return {
                'n_estimators': trial.suggest_int('n_estimators', 100, 500, step=50),
                'learning_rate': trial.suggest_float('learning_rate', 1e-3, 0.1, log=True),
                'min_child_samples': trial.suggest_int('min_child_samples', 20, 100, step=5),
                'max_depth': trial.suggest_int('max_depth', 3, 10, step=1),
                'lambda_l1': trial.suggest_float('lambda_l1', 1e-8, 10.0, log=True),
                'lambda_l2': trial.suggest_float('lambda_l2', 1e-8, 10.0, log=True),
            }

        self._param_space = {
            'ml_l': ml_l_params,
            'ml_m': ml_m_params
        }

        self._optuna_settings = {
            'n_trials': 500,
            'show_progress_bar': False,
            'verbosity': optuna.logging.WARNING,  # Suppress Optuna logs
        }

    def _process_config_parameters(self):
        """Process simulation-specific parameters from config"""
        # Process ML models in parameter grid
        assert "learners" in self.dml_parameters, "No learners specified in the config file"

        required_learners = ["ml_g", "ml_m"]
        for learner in self.dml_parameters["learners"]:
            for ml in required_learners:
                assert ml in learner, f"No {ml} specified in the config file"

    def _calculate_oracle_values(self):
        """Calculate oracle values for the simulation."""
        self.logger.info("Calculating oracle values")

        self.oracle_values = dict()
        self.oracle_values["theta"] = self.dgp_parameters["theta"]

    def run_single_rep(self, dml_data, dml_params) -> Dict[str, Any]:
        """Run a single repetition with the given parameters."""
        # Extract parameters
        learner_config = dml_params["learners"]
        learner_g_name, ml_g = create_learner_from_config(learner_config["ml_g"])
        learner_m_name, ml_m = create_learner_from_config(learner_config["ml_m"])
        score = dml_params["score"]

        # Model
        dml_model = dml.DoubleMLPLR(
            obj_dml_data=dml_data,
            ml_l=ml_g,
            ml_m=ml_m,
            ml_g=ml_g if score == "IV-type" else None,
            score=score,
        )
        dml_model.fit()

        dml_model_tuned = dml.DoubleMLPLR(
            obj_dml_data=dml_data,
            ml_l=ml_g,
            ml_m=ml_m,
            ml_g=ml_g if score == "IV-type" else None,
            score=score,
        )
        dml_model_tuned.tune_ml_models(
            ml_param_space=self._param_space,
            optuna_settings=self._optuna_settings,
        )
        dml_model_tuned.fit()

        result = {
            "coverage": [],
        }
        for model in [dml_model, dml_model_tuned]:
            for level in self.confidence_parameters["level"]:
                level_result = dict()
                level_result["coverage"] = self._compute_coverage(
                    thetas=model.coef,
                    oracle_thetas=self.oracle_values["theta"],
                    confint=model.confint(level=level),
                    joint_confint=None,
                )

                # add parameters to the result
                for res in level_result.values():
                    res.update(
                        {
                            "Learner g": learner_g_name,
                            "Learner m": learner_m_name,
                            "Score": score,
                            "level": level,
                            "Tuned": model is dml_model_tuned,
                        }
                    )
                for key, res in level_result.items():
                    result[key].append(res)

        return result

    def summarize_results(self):
        """Summarize the simulation results."""
        self.logger.info("Summarizing simulation results")

        # Group by parameter combinations
        groupby_cols = ["Learner g", "Learner m", "Score", "level", "Tuned"]
        aggregation_dict = {
            "Coverage": "mean",
            "CI Length": "mean",
            "Bias": "mean",
            "repetition": "count",
        }

        # Aggregate results (possibly multiple result dfs)
        result_summary = dict()
        for result_name, result_df in self.results.items():
            result_summary[result_name] = result_df.groupby(groupby_cols).agg(aggregation_dict).reset_index()
            self.logger.debug(f"Summarized {result_name} results")

        return result_summary

    def _generate_dml_data(self, dgp_params) -> dml.DoubleMLData:
        """Generate data for the simulation."""
        data = make_plr_CCDDHNR2018(
            alpha=dgp_params["theta"],
            n_obs=dgp_params["n_obs"],
            dim_x=dgp_params["dim_x"],
            return_type="DataFrame",
        )
        dml_data = dml.DoubleMLData(data, "y", "d")
        return dml_data
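
The ml_l_params and ml_m_params callables above follow Optuna's define-by-run convention: each receives a trial object and returns one sampled parameter dict. As a standalone illustration of that convention (not of the internals of tune_ml_models), the sketch below tunes an LGBM regressor on synthetic data with the same search space; the cross-validated MSE objective and the learner choice are assumptions made only for this example.

# Minimal sketch: exercising the LGBM search space with a plain Optuna study.
# The CV-MSE objective and LGBMRegressor choice are illustrative assumptions,
# not the logic used by DoubleML's tune_ml_models.
import optuna
from lightgbm import LGBMRegressor
from sklearn.datasets import make_regression
from sklearn.model_selection import cross_val_score

X, y = make_regression(n_samples=500, n_features=20, random_state=42)

def ml_l_params(trial):
    # same define-by-run search space as in the simulation class
    return {
        "n_estimators": trial.suggest_int("n_estimators", 100, 500, step=50),
        "learning_rate": trial.suggest_float("learning_rate", 1e-3, 0.1, log=True),
        "min_child_samples": trial.suggest_int("min_child_samples", 20, 100, step=5),
        "max_depth": trial.suggest_int("max_depth", 3, 10, step=1),
        "lambda_l1": trial.suggest_float("lambda_l1", 1e-8, 10.0, log=True),
        "lambda_l2": trial.suggest_float("lambda_l2", 1e-8, 10.0, log=True),
    }

def objective(trial):
    learner = LGBMRegressor(**ml_l_params(trial), verbose=-1)
    # cross_val_score returns negative MSE, so flip the sign and minimize
    scores = cross_val_score(learner, X, y, cv=3, scoring="neg_mean_squared_error")
    return -scores.mean()

optuna.logging.set_verbosity(optuna.logging.WARNING)
study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=20)
print(study.best_params)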
Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
simulation_parameters:
  repetitions: 100
  max_runtime: 19800
  random_seed: 42
  n_jobs: -2
dgp_parameters:
  theta:
  - 0.5
  n_obs:
  - 500
  dim_x:
  - 20
learner_definitions:
  lgbm: &id001
    name: LGBM Regr.
dml_parameters:
  learners:
  - ml_g: *id001
    ml_m: *id001
  score:
  - partialling out
confidence_parameters:
  level:
  - 0.95
  - 0.9
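
The &id001 / *id001 pair is a standard YAML anchor and alias (the generated anchor name suggests this file was re-dumped by PyYAML), so ml_g and ml_m resolve to the same learner definition on load. A quick check, assuming this is the config copy that the run script saves to results/plm/plr_ate_tune_config.yml:

import yaml

# Path assumed: the copy written by sim.save_config() in scripts/plm/plr_ate_tune.py.
with open("results/plm/plr_ate_tune_config.yml") as f:
    cfg = yaml.safe_load(f)

learner = cfg["dml_parameters"]["learners"][0]
print(learner["ml_g"])                     # {'name': 'LGBM Regr.'}
print(learner["ml_g"] is learner["ml_m"])  # True: the alias points at the same dict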
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
Learner g,Learner m,Score,level,Tuned,Coverage,CI Length,Bias,repetition
LGBM Regr.,LGBM Regr.,partialling out,0.9,False,0.84,0.1476322636407647,0.0419955778328819,100
LGBM Regr.,LGBM Regr.,partialling out,0.9,True,0.9,0.14472954691929252,0.03509689726085052,100
LGBM Regr.,LGBM Regr.,partialling out,0.95,False,0.91,0.17591469231721352,0.0419955778328819,100
LGBM Regr.,LGBM Regr.,partialling out,0.95,True,0.95,0.17245589200927858,0.03509689726085052,100
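
Each row of this summary averages over the 100 repetitions: the share of confidence intervals that contain the oracle theta (Coverage), the mean interval length, and a bias measure for the point estimate. A minimal sketch of the per-repetition computation, assuming confint is the two-column lower/upper DataFrame returned by model.confint(level=...); the actual _compute_coverage helper inherited from BaseSimulation may differ in detail:

import numpy as np
import pandas as pd

def compute_coverage_sketch(thetas, oracle_theta, confint: pd.DataFrame):
    """Illustrative stand-in for montecover's _compute_coverage; details are assumptions."""
    lower = confint.iloc[:, 0].to_numpy()
    upper = confint.iloc[:, 1].to_numpy()
    return {
        "Coverage": float(np.mean((lower <= oracle_theta) & (oracle_theta <= upper))),
        "CI Length": float(np.mean(upper - lower)),
        # absolute deviation of the estimate from the oracle value (assumed definition of "Bias")
        "Bias": float(np.mean(np.abs(np.asarray(thetas) - oracle_theta))),
    }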
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
DoubleML Version,Script,Date,Total Runtime (minutes),Python Version,Config File
0.11.dev0,PLRATETuningCoverageSimulation,2025-11-17 15:23,25.530170826117196,3.12.9,scripts/plm/plr_ate_tune_config.yml

scripts/plm/plr_ate_tune.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
from montecover.plm import PLRATETuningCoverageSimulation

# Create and run simulation with config file
sim = PLRATETuningCoverageSimulation(
    config_file="scripts/plm/plr_ate_tune_config.yml",
    log_level="INFO",
    log_file="logs/plm/plr_ate_tune_sim.log",
)
sim.run_simulation()
sim.save_results(output_path="results/plm/", file_prefix="plr_ate_tune")

# Save config file for reproducibility
sim.save_config("results/plm/plr_ate_tune_config.yml")
Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
# Simulation parameters for PLR ATE Coverage

simulation_parameters:
  repetitions: 100
  max_runtime: 19800 # 5.5 hours in seconds
  random_seed: 42
  n_jobs: -2

dgp_parameters:
  theta: [0.5] # Treatment effect
  n_obs: [500] # Sample size
  dim_x: [20] # Number of covariates

# Define reusable learner configurations
learner_definitions:
  lgbm: &lgbm
    name: "LGBM Regr."

dml_parameters:
  learners:
    - ml_g: *lgbm
      ml_m: *lgbm

  score: ["partialling out"]

confidence_parameters:
  level: [0.95, 0.90] # Confidence levels
