Skip to content

Commit d861a3a

Browse files
authored
Merge pull request #3 from DoubleML/m-iivm-subgroups
implemented the subgroups for the IIVM model for the serverless class
2 parents a826c31 + 57d676d commit d861a3a

File tree

2 files changed

+42
-15
lines changed

2 files changed

+42
-15
lines changed

doubleml_serverless/double_ml_iivm_aws_lambda.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -74,32 +74,49 @@ def _ml_nuisance_aws_lambda(self, cv_params):
7474
self._dml_data.z_cols[0], self._dml_data.x_cols,
7575
method='predict_proba')
7676

77-
_attach_learner(payload_ml_r0,
78-
'ml_r0', self.learner['ml_r'],
79-
self._dml_data.d_cols[0], self._dml_data.x_cols,
80-
method='predict_proba')
81-
82-
_attach_learner(payload_ml_r1,
83-
'ml_r1', self.learner['ml_r'],
84-
self._dml_data.d_cols[0], self._dml_data.x_cols,
85-
method='predict_proba')
86-
87-
all_payloads = [payload_ml_g0, payload_ml_g1, payload_ml_m, payload_ml_r0, payload_ml_r1]
88-
all_smpls = [smpls_z0, smpls_z1, self.smpls, smpls_z0, smpls_z1]
77+
all_payloads = [payload_ml_g0, payload_ml_g1, payload_ml_m]
78+
all_smpls = [smpls_z0, smpls_z1, self.smpls]
79+
send_train_ids = [True, True, False]
80+
params_names = ['ml_g0', 'ml_g1', 'ml_m']
81+
82+
if self.subgroups['always_takers']:
83+
_attach_learner(payload_ml_r0,
84+
'ml_r0', self.learner['ml_r'],
85+
self._dml_data.d_cols[0], self._dml_data.x_cols,
86+
method='predict_proba')
87+
all_payloads.append(payload_ml_r0)
88+
all_smpls.append(smpls_z0)
89+
send_train_ids.append(True)
90+
params_names.append('ml_r0')
91+
92+
if self.subgroups['never_takers']:
93+
_attach_learner(payload_ml_r1,
94+
'ml_r1', self.learner['ml_r'],
95+
self._dml_data.d_cols[0], self._dml_data.x_cols,
96+
method='predict_proba')
97+
all_payloads.append(payload_ml_r1)
98+
all_smpls.append(smpls_z1)
99+
send_train_ids.append(True)
100+
params_names.append('ml_r1')
89101

90102
payloads = _attach_smpls(all_payloads,
91103
all_smpls,
92104
self.n_folds,
93105
self.n_rep,
94106
self._dml_data.n_obs,
95107
cv_params['n_lambdas_cv'],
96-
[True, True, False, True, True],
108+
send_train_ids,
97109
cv_params['seed'])
98110

99-
preds = self.invoke_lambdas(payloads, self.smpls, self.params_names,
111+
preds = self.invoke_lambdas(payloads, self.smpls, params_names,
100112
self._dml_data.n_obs, self.n_rep,
101113
cv_params['n_lambdas_cv'])
102114

115+
if not self.subgroups['always_takers']:
116+
preds['ml_r0'] = np.zeros_like(preds['ml_g0'])
117+
if not self.subgroups['never_takers']:
118+
preds['ml_r1'] = np.ones_like(preds['ml_g1'])
119+
103120
for i_rep in range(self.n_rep):
104121
# compute score elements
105122

doubleml_serverless/tests/test_iivm.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,16 @@ def trimming_threshold(request):
5151
return request.param
5252

5353

54+
@pytest.fixture(scope='module',
55+
params=[{'always_takers': True, 'never_takers': True},
56+
{'always_takers': False, 'never_takers': True},
57+
{'always_takers': True, 'never_takers': False}])
58+
def subgroups(request):
59+
return request.param
60+
61+
5462
@pytest.fixture(scope="module")
55-
def dml_iivm_fixture(generate_data_iivm, idx, learner, score, dml_procedure, trimming_threshold):
63+
def dml_iivm_fixture(generate_data_iivm, idx, learner, score, dml_procedure, trimming_threshold, subgroups):
5664
boot_methods = ['normal']
5765
n_folds = 4
5866
n_rep_boot = 502
@@ -77,6 +85,7 @@ def dml_iivm_fixture(generate_data_iivm, idx, learner, score, dml_procedure, tri
7785
ml_g, ml_m, ml_r,
7886
n_folds,
7987
score=score,
88+
subgroups=subgroups,
8089
dml_procedure=dml_procedure)
8190

8291
dml_iivm_lambda.fit_aws_lambda()
@@ -87,6 +96,7 @@ def dml_iivm_fixture(generate_data_iivm, idx, learner, score, dml_procedure, tri
8796
ml_g, ml_m, ml_r,
8897
n_folds,
8998
score=score,
99+
subgroups=subgroups,
90100
dml_procedure=dml_procedure)
91101

92102
dml_iivm.fit()

0 commit comments

Comments
 (0)