297 changes: 179 additions & 118 deletions econml/_ortho_learner.py

Large diffs are not rendered by default.
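The _ortho_learner.py changes are collapsed here, but the hooks they add are visible from the call sites in _bootstrap.py below (_gen_cloned_ortho_learner_model_finals, _set_bootstrap_params, _set_current_cloned_ortho_learner_model_final). A rough sketch of what those hooks plausibly do — the method names come from the call sites, while the bodies are assumptions, not the PR's actual code:

from sklearn.base import clone

class _OrthoLearnerOnlyFinalHooks:
    # Hypothetical sketch; the real implementation is in the unrendered diff.
    def _gen_cloned_ortho_learner_model_finals(self, n_bootstrap_samples):
        # One independent copy of the final-stage model per bootstrap replicate,
        # so each can be refit on a resample of the residualized data.
        self._cloned_final_models = [
            clone(self._ortho_learner_model_final, safe=False)
            for _ in range(n_bootstrap_samples)]

    def _set_bootstrap_params(self, indices, n_bootstrap_samples, verbose):
        # Stash the resampled row indices; fit() can then cross-fit the
        # nuisances once on the full data and refit only the final models.
        self._bootstrap_indices = indices
        self._n_bootstrap_samples = n_bootstrap_samples
        self._bootstrap_verbose = verbose

    def _set_current_cloned_ortho_learner_model_final(self, i):
        # Route subsequent attribute access through the i-th bootstrapped final model.
        self._ortho_learner_model_final = self._cloned_final_models[i]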

2 changes: 1 addition & 1 deletion econml/dml/_rlearner.py
@@ -335,7 +335,7 @@ def _gen_ortho_learner_model_final(self):
return _ModelFinal(self._gen_rlearner_model_final())

def fit(self, Y, T, *, X=None, W=None, sample_weight=None, freq_weight=None, sample_var=None, groups=None,
cache_values=False, inference=None):
cache_values=False, inference=None, only_final=False):
"""
Estimate the counterfactual model from data, i.e. estimates function :math:`\\theta(\\cdot)`.

2 changes: 1 addition & 1 deletion econml/dml/dml.py
@@ -514,7 +514,7 @@ def _gen_rlearner_model_final(self):

# override only so that we can update the docstring to indicate support for `LinearModelFinalInference`
def fit(self, Y, T, *, X=None, W=None, sample_weight=None, freq_weight=None, sample_var=None, groups=None,
cache_values=False, inference='auto'):
cache_values=False, inference='auto', only_final=False):
"""
Estimate the counterfactual model from data, i.e. estimates functions τ(·,·,·), ∂τ(·,·).

126 changes: 87 additions & 39 deletions econml/inference/_bootstrap.py
@@ -39,6 +39,10 @@ class BootstrapEstimator:
n_jobs: int, default: None
The maximum number of concurrently running jobs, as in joblib.Parallel.

only_final : bool, default True
Whether to bootstrap only the final model, for estimators that do cross-fitting.
Ignored for estimators where this does not apply.

verbose: int, default: 0
Verbosity level

@@ -56,12 +60,16 @@ class BootstrapEstimator:
def __init__(self, wrapped,
n_bootstrap_samples=100,
n_jobs=None,
only_final=True,
verbose=0,
compute_means=True,
bootstrap_type='pivot'):
if not hasattr(wrapped, "_gen_ortho_learner_model_final"):
only_final = False
self._instances = [clone(wrapped, safe=False) for _ in range(n_bootstrap_samples)]
self._n_bootstrap_samples = n_bootstrap_samples
self._n_jobs = n_jobs
self._only_final = only_final
self._verbose = verbose
self._compute_means = compute_means
self._bootstrap_type = bootstrap_type
@@ -86,44 +94,76 @@ def fit(self, *args, **named_args):
The full signature of this method is the same as that of the wrapped object's `fit` method.
"""
from .._cate_estimator import BaseCateEstimator # need to nest this here to avoid circular import
from ..panel.dml import DynamicDML

index_chunks = None
if isinstance(self._instances[0], BaseCateEstimator):
index_chunks = self._instances[0]._strata(*args, **named_args)
if index_chunks is not None:
index_chunks = self.__stratified_indices(index_chunks)
if index_chunks is None:
n_samples = np.shape(args[0] if args else named_args[(*named_args,)[0]])[0]
index_chunks = [np.arange(n_samples)] # one chunk with all indices

indices = []
for chunk in index_chunks:
n_samples = len(chunk)
indices.append(chunk[np.random.choice(n_samples,
size=(self._n_bootstrap_samples, n_samples),
replace=True)])

indices = np.hstack(indices)

if self._only_final:
self._wrapped._gen_cloned_ortho_learner_model_finals(self._n_bootstrap_samples)

def fit(x, *args, **kwargs):
x.fit(*args, **kwargs)
return x # Explicitly return x in case fit fails to return its target

def convertArg(arg, inds):
def convertArg_(arg, inds):
arr = np.asarray(arg)
if arr.ndim > 0:
return arr[inds]
else: # arg was a scalar, so we shouldn't have converted it
return arg
if arg is None:
return None
arr = np.asarray(arg)
if arr.ndim > 0:
return arr[inds]
else: # arg was a scalar, so we shouldn't have converted it
return arg

self._instances = Parallel(n_jobs=self._n_jobs, prefer='threads', verbose=self._verbose)(
delayed(fit)(obj,
*[convertArg(arg, inds) for arg in args],
**{arg: convertArg(named_args[arg], inds) for arg in named_args})
for obj, inds in zip(self._instances, indices)
)
if isinstance(arg, tuple):
converted_arg = []
for arg_param in arg:
converted_arg.append(convertArg_(arg_param, inds))
return tuple(converted_arg)
return convertArg_(arg, inds)
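For instance, the new tuple branch lets a composite argument be resampled element-wise while leaving None and scalars alone (a toy illustration, not part of the diff; convertArg is a closure inside fit):

import numpy as np

inds = np.array([2, 0, 1])
# Arrays are row-indexed; None and 0-d scalars pass through unchanged.
convertArg((np.arange(5), None, 3.0), inds)
# -> (array([2, 0, 1]), None, 3.0)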

Collaborator
I think it's probably worth reworking the API here a bit to pull this and some of the neighboring code out to a new top-level function in this module (say, fit_with_subsets(clones, strata, *args, **kwargs)) that calls fit on each clone in parallel with the correctly modified arguments. This could then be used both here as well as directly within OrthoLearner for the only_final case.
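A minimal sketch of the helper this comment proposes — the signature follows the comment, the body is an assumption pieced together from the surrounding code, and convertArg is assumed to be factored out alongside it:

from joblib import Parallel, delayed

def fit_with_subsets(clones, indices, n_jobs, verbose, *args, **named_args):
    """Fit each clone in parallel on its own resampled rows.

    `indices` holds one array of row indices per clone, as produced by the
    stratified (or panel-level) resampling above.
    """
    def fit(obj, *fit_args, **fit_kwargs):
        obj.fit(*fit_args, **fit_kwargs)
        return obj  # return explicitly in case fit does not return its target

    return Parallel(n_jobs=n_jobs, prefer='threads', verbose=verbose)(
        delayed(fit)(obj,
                     *[convertArg(arg, inds) for arg in args],
                     **{name: convertArg(named_args[name], inds) for name in named_args})
        for obj, inds in zip(clones, indices))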

"""
For DynamicDML only:
draw n_bootstrap_samples sets of panel indices, each of length n_panels, with
replacement from arange(n_panels); each sampled set of panels then maps to the
corresponding index chunks.

"""
Collaborator
Make this a comment rather than a docstring, and put it with the corresponding logic.

index_chunks = None
indices = []

if isinstance(self._wrapped, BaseCateEstimator):
index_chunks = self._instances[0]._strata(*args, **named_args)
if index_chunks is not None:
index_chunks = self.__stratified_indices(index_chunks)
if index_chunks is None:
n_samples = np.shape(args[0] if args else named_args[(*named_args,)[0]])[0]
index_chunks = [np.arange(n_samples)] # one chunk with all indices
if isinstance(self._wrapped, DynamicDML):
n_index_chunks = len(index_chunks)
bootstrapped_chunk_indices = np.random.choice(n_index_chunks,
size=(self._n_bootstrap_samples, n_index_chunks),
replace=True)
for i in range(self._n_bootstrap_samples):
samples = bootstrapped_chunk_indices[i]
sample_chunk_indices = [index_chunks[j] for j in samples]
indices_sample = np.hstack(sample_chunk_indices)
indices.append(indices_sample)
indices = np.array(indices)
else:
for chunk in index_chunks:
n_samples = len(chunk)
sample = chunk[np.random.choice(n_samples,
size=(self._n_bootstrap_samples, n_samples),
replace=True)]
indices.append(sample)
indices = np.hstack(indices)
if not self._only_final:
self._instances = Parallel(n_jobs=self._n_jobs, prefer='threads', verbose=self._verbose)(
delayed(fit)(obj,
*[convertArg(arg, inds) for arg in args],
**{arg: convertArg(named_args[arg], inds) for arg in named_args})
for obj, inds in zip(self._instances, indices)
)
else:
self._wrapped._set_bootstrap_params(indices, self._n_bootstrap_samples, self._verbose)
self._wrapped.fit(*args, **named_args)
self._instances = [clone(self._wrapped, safe=False)]
return self
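The DynamicDML branch above resamples whole panels (index chunks) rather than individual rows, so every replicate keeps complete groups. A self-contained toy illustration of that cluster-style resampling, separate from the diff:

import numpy as np

# Three panels of three periods each; each chunk holds one panel's row indices.
index_chunks = [np.array([0, 1, 2]), np.array([3, 4, 5]), np.array([6, 7, 8])]
n_bootstrap_samples = 2
rng = np.random.default_rng(0)

indices = []
for _ in range(n_bootstrap_samples):
    # Sample panels with replacement, then concatenate their row indices so
    # every resample keeps complete panels (all periods of a group together).
    chosen = rng.choice(len(index_chunks), size=len(index_chunks), replace=True)
    indices.append(np.hstack([index_chunks[j] for j in chosen]))
indices = np.array(indices)
print(indices.shape)  # (2, 9)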

def __getattr__(self, name):
@@ -139,8 +179,16 @@ def __getattr__(self, name):

def proxy(make_call, name, summary):
def summarize_with(f):
results = np.array(Parallel(n_jobs=self._n_jobs, prefer='threads', verbose=self._verbose)(
(f, (obj, name), {}) for obj in self._instances)), f(self._wrapped, name)
instance_results = []
obj = clone(self._wrapped, safe=False)
for i in range(self._n_bootstrap_samples):
if self._only_final:
obj._set_current_cloned_ortho_learner_model_final(i)
else:
obj = self._instances[i]
instance_results.append(f(obj, name))
instance_results = np.array(instance_results)
results = instance_results, f(self._wrapped, name)
Comment on lines +171 to +178
Collaborator
I think it would be better to continue to use parallelism here if possible.

Author
Based on our discussion, it seems like keeping this in _bootstrap.py and using parallel processes or threads will not be possible. In the future, we would like to make this call in _ortho_learner.py, just as we do with fit.

return summary(*results)
if make_call:
def call(*args, **kwargs):
@@ -151,11 +199,11 @@ def call(*args, **kwargs):

def get_mean():
# for attributes that exist on the wrapped object, just compute the mean of the wrapped calls
return proxy(callable(getattr(self._instances[0], name)), name, lambda arr, _: np.mean(arr, axis=0))
return proxy(callable(getattr(self._wrapped, name)), name, lambda arr, _: np.mean(arr, axis=0))

def get_std():
prefix = name[: - len('_std')]
return proxy(callable(getattr(self._instances[0], prefix)), prefix,
return proxy(callable(getattr(self._wrapped, prefix)), prefix,
lambda arr, _: np.std(arr, axis=0))

def get_interval():
@@ -182,7 +230,7 @@ def normal_bootstrap(arr, est):
'pivot': pivot_bootstrap}[self._bootstrap_type]
return proxy(can_call, prefix, fn)

can_call = callable(getattr(self._instances[0], prefix))
can_call = callable(getattr(self._wrapped, prefix))
if can_call:
# collect extra arguments and pass them through, if the wrapped attribute was callable
def call(*args, lower=5, upper=95, **kwargs):
@@ -208,10 +256,10 @@ def fname_transformer(x):
inf_type = 'effect'
elif prefix == 'coef_':
inf_type = 'coefficient'
if (hasattr(self._instances[0], 'cate_feature_names') and
callable(self._instances[0].cate_feature_names)):
if (hasattr(self._wrapped, 'cate_feature_names') and
callable(self._wrapped.cate_feature_names)):
def fname_transformer(x):
return self._instances[0].cate_feature_names(x)
return self._wrapped.cate_feature_names(x)
elif prefix == 'intercept_':
inf_type = 'intercept'
else:
@@ -223,7 +271,7 @@ def fname_transformer(x):
d_t = None
d_y = self._wrapped._d_y[0] if self._wrapped._d_y else 1

can_call = callable(getattr(self._instances[0], prefix))
can_call = callable(getattr(self._wrapped, prefix))

kind = self._bootstrap_type
if kind == 'percentile' or kind == 'pivot':
9 changes: 7 additions & 2 deletions econml/inference/_inference.py
@@ -71,6 +71,10 @@ class BootstrapInference(Inference):
verbose: int, default: 0
Verbosity level

only_final : bool, default True
Whether to bootstrap only the final model, for estimators that do cross-fitting.
Ignored for estimators where this does not apply.

bootstrap_type: 'percentile', 'pivot', or 'normal', default 'pivot'
Bootstrap method used to compute results.
'percentile' will result in using the empirical CDF of the replicated computations of the statistics.
@@ -79,14 +83,15 @@
'normal' will instead compute a pivot interval assuming the replicates are normally distributed.
"""

def __init__(self, n_bootstrap_samples=100, n_jobs=-1, bootstrap_type='pivot', verbose=0):
def __init__(self, n_bootstrap_samples=100, n_jobs=-1, only_final=True, bootstrap_type='pivot', verbose=0):
self._n_bootstrap_samples = n_bootstrap_samples
self._n_jobs = n_jobs
self._only_final = only_final
self._bootstrap_type = bootstrap_type
self._verbose = verbose

def fit(self, estimator, *args, **kwargs):
est = BootstrapEstimator(estimator, self._n_bootstrap_samples, self._n_jobs, compute_means=False,
est = BootstrapEstimator(estimator, self._n_bootstrap_samples, self._n_jobs, self._only_final, compute_means=False,
bootstrap_type=self._bootstrap_type, verbose=self._verbose)
filtered_kwargs = filter_none_kwargs(**kwargs)
est.fit(*args, **filtered_kwargs)
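End to end, the new flag would be exercised roughly like this — a hypothetical usage sketch, with the estimator choice and data shapes purely illustrative:

import numpy as np
from econml.dml import LinearDML
from econml.inference import BootstrapInference

X = np.random.normal(size=(500, 3))
T = np.random.binomial(1, 0.5, size=(500,))
Y = T * X[:, 0] + np.random.normal(size=(500,))

est = LinearDML()
# only_final=True bootstraps just the final stage, reusing the
# cross-fitted nuisance models instead of refitting them per replicate.
est.fit(Y, T, X=X,
        inference=BootstrapInference(n_bootstrap_samples=100, only_final=True))
lb, ub = est.effect_interval(X, alpha=0.05)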
13 changes: 12 additions & 1 deletion econml/panel/dml/_dml.py
@@ -23,12 +23,19 @@


def _get_groups_period_filter(groups, n_periods):
"""
Compute a dictionary mapping each period number to the indices of the samples
observed at that period.

If n_periods = 3, group_period_filter maps each of the three period numbers to
the indices of the full dataset collected at that period; these indices are
later used to index into the full dataset.
"""
group_counts = {}
group_period_filter = {i: [] for i in range(n_periods)}
for i, g in enumerate(groups):
if g not in group_counts:
group_counts[g] = 0
group_period_filter[group_counts[g]].append(i)
group_period_filter[group_counts[g] % n_periods].append(i)
group_counts[g] += 1
return group_period_filter
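A quick worked example of the modified helper; with the new % n_periods, a panel that the bootstrap draws twice (6 rows with n_periods = 3) wraps back to periods 0, 1, 2 instead of indexing past the last period:

groups = [7, 7, 7, 9, 9, 9]  # two panels, three periods each
_get_groups_period_filter(groups, 3)
# -> {0: [0, 3], 1: [1, 4], 2: [2, 5]}: row i of each panel lands in period i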

@@ -157,6 +164,10 @@ def __init__(self, model_final, n_periods):

def fit(self, Y, T, X=None, W=None, Z=None, nuisances=None, sample_weight=None, sample_var=None, groups=None):
# NOTE: sample weight, sample var are not passed in
_, group_counts = np.unique(groups, return_counts=True)
unique_group_counts = np.unique(group_counts)
assert np.all(unique_group_counts % self.n_periods == 0), \
"Each group should appear in whole multiples in bootstrapping"
period_filters = _get_groups_period_filter(groups, self.n_periods)
Y_res, T_res = nuisances
self._d_y = Y.shape[1:]