@@ -74,32 +74,49 @@ def _ml_nuisance_aws_lambda(self, cv_params):
7474 self ._dml_data .z_cols [0 ], self ._dml_data .x_cols ,
7575 method = 'predict_proba' )
7676
77- _attach_learner (payload_ml_r0 ,
78- 'ml_r0' , self .learner ['ml_r' ],
79- self ._dml_data .d_cols [0 ], self ._dml_data .x_cols ,
80- method = 'predict_proba' )
81-
82- _attach_learner (payload_ml_r1 ,
83- 'ml_r1' , self .learner ['ml_r' ],
84- self ._dml_data .d_cols [0 ], self ._dml_data .x_cols ,
85- method = 'predict_proba' )
86-
87- all_payloads = [payload_ml_g0 , payload_ml_g1 , payload_ml_m , payload_ml_r0 , payload_ml_r1 ]
88- all_smpls = [smpls_z0 , smpls_z1 , self .smpls , smpls_z0 , smpls_z1 ]
77+ all_payloads = [payload_ml_g0 , payload_ml_g1 , payload_ml_m ]
78+ all_smpls = [smpls_z0 , smpls_z1 , self .smpls ]
79+ send_train_ids = [True , True , False ]
80+ params_names = ['ml_g0' , 'ml_g1' , 'ml_m' ]
81+
82+ if self .subgroups ['always_takers' ]:
83+ _attach_learner (payload_ml_r0 ,
84+ 'ml_r0' , self .learner ['ml_r' ],
85+ self ._dml_data .d_cols [0 ], self ._dml_data .x_cols ,
86+ method = 'predict_proba' )
87+ all_payloads .append (payload_ml_r0 )
88+ all_smpls .append (smpls_z0 )
89+ send_train_ids .append (True )
90+ params_names .append ('ml_r0' )
91+
92+ if self .subgroups ['never_takers' ]:
93+ _attach_learner (payload_ml_r1 ,
94+ 'ml_r1' , self .learner ['ml_r' ],
95+ self ._dml_data .d_cols [0 ], self ._dml_data .x_cols ,
96+ method = 'predict_proba' )
97+ all_payloads .append (payload_ml_r1 )
98+ all_smpls .append (smpls_z1 )
99+ send_train_ids .append (True )
100+ params_names .append ('ml_r1' )
89101
90102 payloads = _attach_smpls (all_payloads ,
91103 all_smpls ,
92104 self .n_folds ,
93105 self .n_rep ,
94106 self ._dml_data .n_obs ,
95107 cv_params ['n_lambdas_cv' ],
96- [ True , True , False , True , True ] ,
108+ send_train_ids ,
97109 cv_params ['seed' ])
98110
99- preds = self .invoke_lambdas (payloads , self .smpls , self . params_names ,
111+ preds = self .invoke_lambdas (payloads , self .smpls , params_names ,
100112 self ._dml_data .n_obs , self .n_rep ,
101113 cv_params ['n_lambdas_cv' ])
102114
115+ if not self .subgroups ['always_takers' ]:
116+ preds ['ml_r0' ] = np .zeros_like (preds ['ml_g0' ])
117+ if not self .subgroups ['never_takers' ]:
118+ preds ['ml_r1' ] = np .ones_like (preds ['ml_g1' ])
119+
103120 for i_rep in range (self .n_rep ):
104121 # compute score elements
105122
0 commit comments