Commit 424b1ec
add first tuning simulation
1 parent 6e0f15c

File tree

7 files changed, +242 -0 lines changed

monte-cover/src/montecover/irm/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -5,6 +5,7 @@
 from montecover.irm.cvar import CVARCoverageSimulation
 from montecover.irm.iivm_late import IIVMLATECoverageSimulation
 from montecover.irm.irm_ate import IRMATECoverageSimulation
+from montecover.irm.irm_ate_tune import IRMATETuningCoverageSimulation
 from montecover.irm.irm_ate_sensitivity import IRMATESensitivityCoverageSimulation
 from montecover.irm.irm_atte import IRMATTECoverageSimulation
 from montecover.irm.irm_atte_sensitivity import IRMATTESensitivityCoverageSimulation
@@ -18,6 +19,7 @@
     "APOSCoverageSimulation",
     "CVARCoverageSimulation",
     "IRMATECoverageSimulation",
+    "IRMATETuningCoverageSimulation",
     "IIVMLATECoverageSimulation",
     "IRMATESensitivityCoverageSimulation",
     "IRMATTECoverageSimulation",
monte-cover/src/montecover/irm/irm_ate_tune.py

Lines changed: 166 additions & 0 deletions

from typing import Any, Dict, Optional

import optuna

import doubleml as dml
from doubleml.irm.datasets import make_irm_data

from montecover.base import BaseSimulation
from montecover.utils import create_learner_from_config


class IRMATETuningCoverageSimulation(BaseSimulation):
    """Simulation class for coverage properties of DoubleMLIRM for ATE estimation with hyperparameter tuning."""

    def __init__(
        self,
        config_file: str,
        suppress_warnings: bool = True,
        log_level: str = "INFO",
        log_file: Optional[str] = None,
    ):
        super().__init__(
            config_file=config_file,
            suppress_warnings=suppress_warnings,
            log_level=log_level,
            log_file=log_file,
        )

        # Calculate oracle values
        self._calculate_oracle_values()

        # Tuning-specific settings
        # parameter space for the outcome regression tuning
        def ml_g_params(trial):
            return {
                'n_estimators': trial.suggest_int('n_estimators', 100, 500, step=50),
                'learning_rate': trial.suggest_float('learning_rate', 1e-3, 0.1, log=True),
                'min_child_samples': trial.suggest_int('min_child_samples', 20, 100, step=5),
                'max_depth': trial.suggest_int('max_depth', 3, 10, step=1),
                'lambda_l1': trial.suggest_float('lambda_l1', 1e-8, 10.0, log=True),
                'lambda_l2': trial.suggest_float('lambda_l2', 1e-8, 10.0, log=True),
            }

        # parameter space for the propensity score tuning
        def ml_m_params(trial):
            return {
                'n_estimators': trial.suggest_int('n_estimators', 100, 500, step=50),
                'learning_rate': trial.suggest_float('learning_rate', 1e-3, 0.1, log=True),
                'min_child_samples': trial.suggest_int('min_child_samples', 20, 100, step=5),
                'max_depth': trial.suggest_int('max_depth', 3, 10, step=1),
                'lambda_l1': trial.suggest_float('lambda_l1', 1e-8, 10.0, log=True),
                'lambda_l2': trial.suggest_float('lambda_l2', 1e-8, 10.0, log=True),
            }

        self._param_space = {
            'ml_g': ml_g_params,
            'ml_m': ml_m_params,
        }

        self._optuna_settings = {
            'n_trials': 500,
            'show_progress_bar': False,
            'verbosity': optuna.logging.WARNING,  # Suppress Optuna logs
        }

    def _process_config_parameters(self):
        """Process simulation-specific parameters from config."""
        # Process ML models in parameter grid
        assert "learners" in self.dml_parameters, "No learners specified in the config file"

        required_learners = ["ml_g", "ml_m"]
        for learner in self.dml_parameters["learners"]:
            for ml in required_learners:
                assert ml in learner, f"No {ml} specified in the config file"

    def _calculate_oracle_values(self):
        """Calculate oracle values for the simulation."""
        self.logger.info("Calculating oracle values")

        self.oracle_values = dict()
        self.oracle_values["theta"] = self.dgp_parameters["theta"]

    def run_single_rep(self, dml_data: dml.DoubleMLData, dml_params: Dict[str, Any]) -> Dict[str, Any]:
        """Run a single repetition with the given parameters."""
        # Extract parameters
        learner_config = dml_params["learners"]
        learner_g_name, ml_g = create_learner_from_config(learner_config["ml_g"])
        learner_m_name, ml_m = create_learner_from_config(learner_config["ml_m"])

        # Model
        dml_model = dml.DoubleMLIRM(
            obj_dml_data=dml_data,
            ml_g=ml_g,
            ml_m=ml_m,
        )
        dml_model.fit()

        dml_model_tuned = dml.DoubleMLIRM(
            obj_dml_data=dml_data,
            ml_g=ml_g,
            ml_m=ml_m,
        )
        dml_model_tuned.tune_ml_models(
            ml_param_space=self._param_space,
            optuna_settings=self._optuna_settings,
        )
        dml_model_tuned.fit()

        result = {
            "coverage": [],
        }
        for model in [dml_model, dml_model_tuned]:
            for level in self.confidence_parameters["level"]:
                level_result = dict()
                level_result["coverage"] = self._compute_coverage(
                    thetas=model.coef,
                    oracle_thetas=self.oracle_values["theta"],
                    confint=model.confint(level=level),
                    joint_confint=None,
                )

                # add parameters to the result
                for res_metric in level_result.values():
                    res_metric.update(
                        {
                            "Learner g": learner_g_name,
                            "Learner m": learner_m_name,
                            "level": level,
                            "Tuned": model is dml_model_tuned,
                        }
                    )
                for key, res in level_result.items():
                    result[key].append(res)

        return result

    def summarize_results(self):
        """Summarize the simulation results."""
        self.logger.info("Summarizing simulation results")

        # Group by parameter combinations
        groupby_cols = ["Learner g", "Learner m", "level", "Tuned"]
        aggregation_dict = {
            "Coverage": "mean",
            "CI Length": "mean",
            "Bias": "mean",
            "repetition": "count",
        }

        # Aggregate results (possibly multiple result dfs)
        result_summary = dict()
        for result_name, result_df in self.results.items():
            result_summary[result_name] = result_df.groupby(groupby_cols).agg(aggregation_dict).reset_index()
            self.logger.debug(f"Summarized {result_name} results")

        return result_summary

    def _generate_dml_data(self, dgp_params: Dict[str, Any]) -> dml.DoubleMLData:
        """Generate data for the simulation."""
        data = make_irm_data(
            theta=dgp_params["theta"],
            n_obs=dgp_params["n_obs"],
            dim_x=dgp_params["dim_x"],
            return_type="DataFrame",
        )
        dml_data = dml.DoubleMLData(data, "y", "d")
        return dml_data
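
For reference, the _param_space entries are plain callables that take an optuna trial and return one sampled hyperparameter configuration; tune_ml_models consumes them together with _optuna_settings. Below is a minimal standalone sketch of that pattern. The cross-validated R^2 objective, the synthetic data, and the trimmed search space are illustrative assumptions, not what tune_ml_models runs internally.

import optuna
from lightgbm import LGBMRegressor
from sklearn.datasets import make_regression
from sklearn.model_selection import cross_val_score

X, y = make_regression(n_samples=500, n_features=5, random_state=42)

# Same shape as ml_g_params above, trimmed to two dimensions for brevity
def ml_g_params(trial):
    return {
        'n_estimators': trial.suggest_int('n_estimators', 100, 500, step=50),
        'learning_rate': trial.suggest_float('learning_rate', 1e-3, 0.1, log=True),
    }

# Assumed objective: mean cross-validated R^2 of the sampled configuration
def objective(trial):
    params = ml_g_params(trial)
    model = LGBMRegressor(**params, verbose=-1)
    return cross_val_score(model, X, y, cv=3, scoring='r2').mean()

optuna.logging.set_verbosity(optuna.logging.WARNING)  # mirrors the 'verbosity' setting above
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=20, show_progress_bar=False)
print(study.best_params)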
results/irm/irm_ate_tune_config.yml

Lines changed: 25 additions & 0 deletions

simulation_parameters:
  repetitions: 100
  max_runtime: 19800
  random_seed: 42
  n_jobs: -2
dgp_parameters:
  theta:
  - 0.5
  n_obs:
  - 1000
  dim_x:
  - 5
learner_definitions:
  lgbmr: &id001
    name: LGBM Regr.
  lgbmc: &id002
    name: LGBM Clas.
dml_parameters:
  learners:
  - ml_g: *id001
    ml_m: *id002
confidence_parameters:
  level:
  - 0.95
  - 0.9
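
This is the copy of the configuration written back by sim.save_config() in the script below. The content matches scripts/irm/irm_ate_tune_config.yml, but the comments are stripped and the hand-written anchors (&lgbmr, &lgbmc) are replaced by auto-numbered ones (&id001, &id002), which is the expected round-trip behavior when the parsed config is re-serialized by a YAML dumper.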
Lines changed: 5 additions & 0 deletions

Learner g,Learner m,level,Tuned,Coverage,CI Length,Bias,repetition
LGBM Regr.,LGBM Clas.,0.9,False,0.98,1.917963098754478,0.3759301980407579,100
LGBM Regr.,LGBM Clas.,0.9,True,0.9,0.3286032416773015,0.08161915599017401,100
LGBM Regr.,LGBM Clas.,0.95,False,0.99,2.285393992292617,0.3759301980407579,100
LGBM Regr.,LGBM Clas.,0.95,True,0.94,0.3915549130558737,0.08161915599017401,100
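
Read row by row, the untuned LightGBM learners over-cover (0.98 and 0.99 at nominal levels 0.90 and 0.95) with wide intervals, while the tuned versions land close to nominal coverage (0.90 and 0.94) with confidence intervals roughly six times shorter and bias reduced from about 0.38 to 0.08.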
Lines changed: 2 additions & 0 deletions

DoubleML Version,Script,Date,Total Runtime (minutes),Python Version,Config File
0.11.dev0,IRMATETuningCoverageSimulation,2025-11-14 18:20,22.557860120137533,3.12.9,scripts/irm/irm_ate_tune_config.yml

scripts/irm/irm_ate_tune.py

Lines changed: 13 additions & 0 deletions

from montecover.irm import IRMATETuningCoverageSimulation

# Create and run simulation with config file
sim = IRMATETuningCoverageSimulation(
    config_file="scripts/irm/irm_ate_tune_config.yml",
    log_level="INFO",
    log_file="logs/irm/irm_ate_tune_sim.log",
)
sim.run_simulation()
sim.save_results(output_path="results/irm/", file_prefix="irm_ate_tune")

# Save config file for reproducibility
sim.save_config("results/irm/irm_ate_tune_config.yml")
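
All paths in the script (config, log file, results) are relative, so it is presumably meant to be run from the repository root, e.g. python scripts/irm/irm_ate_tune.py.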
scripts/irm/irm_ate_tune_config.yml

Lines changed: 29 additions & 0 deletions

# Simulation parameters for IRM ATE Coverage with Tuning

simulation_parameters:
  repetitions: 100
  max_runtime: 19800  # 5.5 hours in seconds
  random_seed: 42
  n_jobs: -2

dgp_parameters:
  theta: [0.5]  # Treatment effect
  n_obs: [1000]  # Sample size
  dim_x: [5]  # Number of covariates

# Define reusable learner configurations
learner_definitions:
  lgbmr: &lgbmr
    name: "LGBM Regr."

  lgbmc: &lgbmc
    name: "LGBM Clas."

dml_parameters:
  learners:
    - ml_g: *lgbmr
      ml_m: *lgbmc

confidence_parameters:
  level: [0.95, 0.90]  # Confidence levels
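
The learner entries here carry only a name field; create_learner_from_config (imported from montecover.utils in the new simulation class) turns each mapping into a (display name, estimator) pair. A rough, hypothetical sketch of that contract, assuming the helper looks the estimator class up by name and forwards any remaining keys as parameters:

# Hypothetical sketch; the actual montecover.utils.create_learner_from_config may differ.
from lightgbm import LGBMClassifier, LGBMRegressor

def create_learner_from_config(config):
    # Everything except the display name is treated as an estimator parameter.
    params = {k: v for k, v in config.items() if k != 'name'}
    registry = {'LGBM Regr.': LGBMRegressor, 'LGBM Clas.': LGBMClassifier}
    return config['name'], registry[config['name']](**params)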
