Vasilis/drlearner #137

Merged
merged 51 commits on Nov 11, 2019

Commits (51)
f31d86a
separating drlearner
vasilismsr Nov 6, 2019
7c4d76d
better docstring
vasilismsr Nov 6, 2019
6e2c972
notebook
vasilismsr Nov 6, 2019
b9d5238
drlearner documentation finalized.
vasilismsr Nov 6, 2019
ad72eab
removed doubly robust learner from metalearners
vasilismsr Nov 6, 2019
868b5c1
removed drlearner test from metalearners and started new test file fo…
vasilismsr Nov 6, 2019
a293f0f
removed drlearner test from metalearners and started new test file fo…
vasilismsr Nov 6, 2019
c63d6b3
changed tests in dml to adhere to the CATE API. Some calls to effect …
vasilismsr Nov 6, 2019
d2494b6
added exhaustive tests for drlearner. Fixed bugs from corner cases
vasilismsr Nov 7, 2019
e75fd48
bug in docstring regarding how split is called to generate crossfit f…
vasilismsr Nov 7, 2019
fbc3f1d
changed bootstrap tests to conform with keyword only effect and effec…
vasilismsr Nov 7, 2019
3397d95
bootstrap tests, fixing bugs related to positional arguments T0, T1
vasilismsr Nov 7, 2019
b13581e
linting
vasilismsr Nov 7, 2019
e4c355d
more tests for drlearner and small bugs
vasilismsr Nov 7, 2019
8611577
reverted back to coef_ and intercept_. Fixed docstring. Added exhaust…
vasilismsr Nov 7, 2019
0dcea06
added checks of fitted dims for W and Z during scoring in _OrthoLearn…
vasilismsr Nov 7, 2019
7c55fb0
testing notebook
vasilismsr Nov 7, 2019
b5c0a07
linting
vasilismsr Nov 7, 2019
59d9fb1
docstring fix regarding n_splits in multiple places. Fixed notebook f…
vasilismsr Nov 7, 2019
d3274ff
changed statsmodels inference input properties to reflect that we onl…
vasilismsr Nov 7, 2019
988c53d
small change in test drlearner
vasilismsr Nov 7, 2019
e48f7de
docstring for linear drlearner
vasilismsr Nov 7, 2019
95fa08f
improved example in docstring of lineardrlearner
vasilismsr Nov 7, 2019
ef0dff5
adding more slacks to some drlearner coverage tests
vasilismsr Nov 7, 2019
07a79e2
Merge branch 'master' into vasilis/drlearner
vasilismsr Nov 7, 2019
d2b7835
linting
vasilismsr Nov 7, 2019
0474015
linting
vasilismsr Nov 7, 2019
f6d1505
removed OrthoLearner testing notebook
vasilismsr Nov 7, 2019
cb60dc4
comments on fit score dimension mismatch
vasilismsr Nov 8, 2019
9768a76
removed leftover print statement from cate estimator
vasilismsr Nov 8, 2019
c315539
replaced :code: with
vasilismsr Nov 8, 2019
51e0595
added :meth:
vasilismsr Nov 8, 2019
f377e95
added :meth:
vasilismsr Nov 8, 2019
aa21807
removed + for string concat
vasilismsr Nov 8, 2019
a3b7b52
removed redundant test code
vasilismsr Nov 8, 2019
b00052c
added comment on overlapping tests between DML and DRLearner
vasilismsr Nov 8, 2019
cd597f4
removed redundant rand_sol code in test_drlearner
vasilismsr Nov 8, 2019
42fae99
removed + operator for string concat
vasilismsr Nov 8, 2019
8de555b
added TODO for allowing for 2d y of shape (n,1) and also added test t…
vasilismsr Nov 8, 2019
084561d
removed redundant adding and subtracting in statsmodelscateestimator
vasilismsr Nov 8, 2019
d606f2a
changed :attr: to :meth:
vasilismsr Nov 8, 2019
7803c2e
added TODO so that we merge functionality between statsmodelsinferenc…
vasilismsr Nov 8, 2019
610db58
replaced printing with subTests
vasilismsr Nov 8, 2019
68d5180
linting
vasilismsr Nov 8, 2019
fdd536a
fixed docstring in dml. Added utility function of inverse_onehot enco…
vasilismsr Nov 9, 2019
bfb1ff8
removed replacing None weights with np.ones in drlearner scoring sinc…
vasilismsr Nov 9, 2019
179c2cf
typo in error msg
vasilismsr Nov 9, 2019
9c4271e
made statsmodelslinearregression be child of BaseEstimator
vasilismsr Nov 9, 2019
8ea37c8
added comment on code design choice in model_final of drlearner, rela…
vasilismsr Nov 9, 2019
36ddc55
put docstrings in methods and removed them from attributes
vasilismsr Nov 9, 2019
c1549c8
Merge branch 'master' into vasilis/drlearner
vasilismsr Nov 10, 2019
1 change: 1 addition & 0 deletions doc/reference.rst
@@ -9,6 +9,7 @@ Public Module Reference
econml.deepiv
econml.dgp
econml.dml
econml.drlearner
econml.inference
econml.ortho_forest
econml.selective_regularization
25 changes: 20 additions & 5 deletions econml/_ortho_learner.py
@@ -244,7 +244,7 @@ class _OrthoLearner(TreatmentExpansionMixin, LinearCateEstimator):
discrete_treatment: bool
Whether the treatment values should be treated as categorical, rather than continuous, quantities

n_splits: int, cross-validation generator or an iterable, optional
n_splits: int, cross-validation generator or an iterable
Determines the cross-validation splitting strategy.
Possible inputs for cv are:

@@ -411,13 +411,26 @@ def _check_input_dims(self, Y, T, X=None, W=None, Z=None, sample_weight=None, sa
for arr in [X, W, Z, sample_weight, sample_var]:
assert (arr is None) or (arr.shape[0] == Y.shape[0]), "Dimension mismatch"
self._d_x = X.shape[1:] if X is not None else None
self._d_w = W.shape[1:] if W is not None else None
self._d_z = Z.shape[1:] if Z is not None else None

def _check_fitted_dims(self, X):
if X is None:
assert self._d_x is None, "X was not None when fitting, so can't be none for effect"
assert self._d_x is None, "X was not None when fitting, so can't be none for score or effect"
else:
assert self._d_x == X.shape[1:], "Dimension mis-match of X with fitted X"

def _check_fitted_dims_w_z(self, W, Z):
if W is None:
assert self._d_w is None, "W was not None when fitting, so can't be none for score"
else:
assert self._d_w == W.shape[1:], "Dimension mis-match of W with fitted W"

if Z is None:
assert self._d_z is None, "Z was not None when fitting, so can't be none for score"
else:
assert self._d_z == Z.shape[1:], "Dimension mis-match of Z with fitted Z"

def _subinds_check_none(self, var, inds):
return var[inds] if var is not None else None
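
For context beyond the diff: the new `_check_fitted_dims_w_z` mirrors `_check_fitted_dims`, enforcing that `score` is called with the same X/W/Z layout (including None-ness) seen at fit time. A minimal sketch of the rule, with hypothetical shapes:

```python
import numpy as np

# Hypothetical: fit saw X with 3 features and W with 2; the estimator stores
# their trailing shapes (shape[1:]) and re-checks them at score/effect time.
d_x, d_w = (3,), (2,)
X_new, W_new = np.zeros((10, 3)), np.zeros((10, 2))
assert X_new.shape[1:] == d_x   # same feature layout: passes
assert W_new.shape[1:] == d_w   # passes
# Passing W=None now, when W was not None at fit time, fails the check.
```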

@@ -485,14 +498,14 @@ def _fit_nuisances(self, Y, T, X=None, W=None, Z=None, sample_weight=None):
folds = splitter.split(np.ones((T.shape[0], 1)), T)

if self._discrete_treatment:
T = self._label_encoder.fit_transform(T)
T = self._label_encoder.fit_transform(T.ravel())
# drop first column since all columns sum to one
T = self._one_hot_encoder.fit_transform(reshape(T, (-1, 1)))[:, 1:]
self._d_t = shape(T)[1:]
self.transformer = FunctionTransformer(
func=(lambda T:
self._one_hot_encoder.transform(
reshape(self._label_encoder.transform(T), (-1, 1)))[:, 1:]),
reshape(self._label_encoder.transform(T.ravel()), (-1, 1)))[:, 1:]),
validate=False)

nuisances, fitted_models, fitted_inds = _crossfit(self._model_nuisance, folds,
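
Side note on the `T.ravel()` change above: `LabelEncoder.fit_transform` expects a 1-d array, so an (n, 1) treatment column must be raveled before encoding, and the first one-hot column is dropped since the columns sum to one. A self-contained sketch of the same encoding on made-up data:

```python
import numpy as np
from sklearn.preprocessing import LabelEncoder, OneHotEncoder

T = np.array([['a'], ['b'], ['c'], ['a']])            # (n, 1) treatment column
labels = LabelEncoder().fit_transform(T.ravel())      # ravel: LabelEncoder needs 1-d
onehot = OneHotEncoder(categories='auto').fit_transform(
    labels.reshape(-1, 1)).toarray()[:, 1:]           # drop the baseline column
print(onehot)   # 'a' rows are all-zero; 'b' -> [1, 0]; 'c' -> [0, 1]
```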
@@ -524,7 +537,7 @@ def const_marginal_effect_interval(self, X=None, *, alpha=0.1):
return super().const_marginal_effect_interval(X, alpha=alpha)
const_marginal_effect_interval.__doc__ = LinearCateEstimator.const_marginal_effect_interval.__doc__

def effect_interval(self, X=None, T0=0, T1=1, *, alpha=0.1):
def effect_interval(self, X=None, *, T0=0, T1=1, alpha=0.1):
self._check_fitted_dims(X)
return super().effect_interval(X, T0=T0, T1=T1, alpha=alpha)
effect_interval.__doc__ = LinearCateEstimator.effect_interval.__doc__
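
With T0 and T1 now keyword-only, positional calls fail fast instead of silently misbinding. A hypothetical end-to-end sketch (estimator name and `fit` signature per the econml API of this era; treat it as illustrative, not canonical):

```python
import numpy as np
from sklearn.linear_model import LassoCV
from econml.dml import LinearDMLCateEstimator

np.random.seed(123)
X = np.random.normal(size=(500, 3))
T = np.random.normal(size=500)
y = T * (1 + X[:, 0]) + X[:, 1] + np.random.normal(size=500)

est = LinearDMLCateEstimator(model_y=LassoCV(), model_t=LassoCV())
est.fit(y, T, X, inference='statsmodels')
lb, ub = est.effect_interval(X, T0=0, T1=1, alpha=0.05)  # keyword-only T0/T1
# est.effect_interval(X, 0, 1)  # now raises TypeError instead of misbinding
```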
@@ -560,6 +573,8 @@ def score(self, Y, T, X=None, W=None, Z=None):
"""
if not hasattr(self._model_final, 'score'):
raise AttributeError("Final model does not have a score method!")
self._check_fitted_dims(X)
self._check_fitted_dims_w_z(W, Z)
X, T = self._expand_treatments(X, T)
n_splits = len(self._models_nuisance)
for idx, mdl in enumerate(self._models_nuisance):
9 changes: 5 additions & 4 deletions econml/_rlearner.py
@@ -78,7 +78,7 @@ class _RLearner(_OrthoLearner):
discrete_treatment: bool
Whether the treatment values should be treated as categorical, rather than continuous, quantities

n_splits: int, cross-validation generator or an iterable, optional
n_splits: int, cross-validation generator or an iterable
Determines the cross-validation splitting strategy.
Possible inputs for cv are:

:class:`~sklearn.model_selection.KFold` is used
(with a random shuffle in either case).

Unless an iterable is used, we call `split(X,T)` to generate the splits.
Unless an iterable is used, we call `split(concat[W, X], T)` to generate the splits. If all
W, X are None, then we call `split(ones((T.shape[0], 1)), T)`.

random_state: int, :class:`~numpy.random.mtrand.RandomState` instance or None
If int, random_state is the seed used by the random number generator;
@@ -203,8 +204,8 @@ def predict(self, Y, T, X=None, W=None, Z=None, sample_weight=None):
Y_pred = self._model_y.predict(X, W)
T_pred = self._model_t.predict(X, W)
if (X is None) and (W is None): # In this case predict above returns a single row
Y_pred = np.tile(Y_pred, Y.shape[0])
T_pred = np.tile(T_pred, T.shape[0])
Y_pred = np.tile(Y_pred.reshape(1, -1), (Y.shape[0], 1))
T_pred = np.tile(T_pred.reshape(1, -1), (T.shape[0], 1))
Y_res = Y - Y_pred.reshape(Y.shape)
T_res = T - T_pred.reshape(T.shape)
return Y_res, T_res
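
Why this fix matters: when X and W are both None, the first-stage models return a single constant row, and the old `np.tile(Y_pred, Y.shape[0])` flattened it into a 1-d vector, which breaks for multi-output Y. Tiling row-wise preserves the (n, d_y) shape. A small numpy illustration with hypothetical values:

```python
import numpy as np

Y_pred = np.array([0.5, 1.5])   # single constant prediction, d_y = 2
n = 3

old = np.tile(Y_pred, n)                         # shape (6,): flattened, wrong
new = np.tile(Y_pred.reshape(1, -1), (n, 1))     # shape (3, 2): one row per sample
print(old.shape, new.shape)                      # (6,) (3, 2)
```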
157 changes: 155 additions & 2 deletions econml/cate_estimator.py
@@ -11,7 +11,7 @@
from .bootstrap import BootstrapEstimator
from .inference import BootstrapInference
from .utilities import tensordot, ndim, reshape, shape
from .inference import StatsModelsInference
from .inference import StatsModelsInference, StatsModelsInferenceDiscrete


class BaseCateEstimator(metaclass=abc.ABCMeta):
@@ -173,10 +173,51 @@ def call(self, *args, **kwargs):

@_defer_to_inference
def effect_interval(self, X=None, *, T0=0, T1=1, alpha=0.1):
""" Confidence intervals for the quantities :math:`\\tau(X, T0, T1)` produced
by the model. Available only when ``inference`` is not ``None``, when
calling the fit method.

Parameters
----------
X: optional (m, d_x) matrix
Features for each sample
T0: optional (m, d_t) matrix or vector of length m (Default=0)
Base treatments for each sample
T1: optional (m, d_t) matrix or vector of length m (Default=1)
Target treatments for each sample
alpha: optional float in [0, 1] (Default=0.1)
The overall level of confidence of the reported interval.
The alpha/2, 1-alpha/2 confidence interval is reported.

Returns
-------
lower, upper : tuple(type of :meth:`effect(X, T0, T1)<effect>`, type of :meth:`effect(X, T0, T1)<effect>`)
The lower and the upper bounds of the confidence interval for each quantity.
"""
pass

@_defer_to_inference
def marginal_effect_interval(self, T, X=None, *, alpha=0.1):
""" Confidence intervals for the quantities :math:`\\partial \\tau(T, X)` produced
by the model. Available only when ``inference`` is not ``None``, when
calling the fit method.

Parameters
----------
T: (m, d_t) matrix
Base treatments for each sample
X: optional (m, d_x) matrix or None (Default=None)
Features for each sample
alpha: optional float in [0, 1] (Default=0.1)
The overall level of confidence of the reported interval.
The alpha/2, 1-alpha/2 confidence interval is reported.

Returns
-------
lower, upper : tuple(type of :meth:`marginal_effect(T, X)<marginal_effect>`, \
type of :meth:`marginal_effect(T, X)<marginal_effect>` )
The lower and the upper bounds of the confidence interval for each quantity.
"""
pass
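
On the interval convention these docstrings document ("the alpha/2, 1-alpha/2 confidence interval is reported"): for normal-based inference this is just the pair of quantiles around the point estimate. A generic sketch of that convention, not econml code:

```python
from scipy.stats import norm

theta_hat, stderr, alpha = 1.2, 0.3, 0.1
lower = theta_hat + norm.ppf(alpha / 2) * stderr       # alpha/2 quantile
upper = theta_hat + norm.ppf(1 - alpha / 2) * stderr   # 1 - alpha/2 quantile
print(lower, upper)                                    # 90% interval for alpha = 0.1
```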


@@ -276,12 +317,32 @@ def marginal_effect(self, T, X=None):
return np.repeat(eff, shape(T)[0], axis=0) if X is None else eff

def marginal_effect_interval(self, T, X=None, *, alpha=0.1):
X, T = self._expand_treatments(X, T)
effs = self.const_marginal_effect_interval(X=X, alpha=alpha)
return tuple(np.repeat(eff, shape(T)[0], axis=0) if X is None else eff
for eff in effs)
marginal_effect_interval.__doc__ = BaseCateEstimator.marginal_effect_interval.__doc__

@BaseCateEstimator._defer_to_inference
def const_marginal_effect_interval(self, X=None, *, alpha=0.1):
""" Confidence intervals for the quantities :math:`\\theta(X)` produced
by the model. Available only when ``inference`` is not ``None``, when
calling the fit method.

Parameters
----------
X: optional (m, d_x) matrix or None (Default=None)
Features for each sample
alpha: optional float in [0, 1] (Default=0.1)
The overall level of confidence of the reported interval.
The alpha/2, 1-alpha/2 confidence interval is reported.

Returns
-------
lower, upper : tuple(type of :meth:`const_marginal_effect(X)<const_marginal_effect>` ,\
type of :meth:`const_marginal_effect(X)<const_marginal_effect>` )
The lower and the upper bounds of the confidence interval for each quantity.
"""
pass


@@ -314,7 +375,7 @@ def _expand_treatments(self, X=None, *Ts):
return (X,) + tuple(outTs)

# override effect to set defaults, which works with the new definition of _expand_treatments
def effect(self, X=None, T0=0, T1=1):
def effect(self, X=None, *, T0=0, T1=1):
# NOTE: don't explicitly expand treatments here, because it's done in the super call
return super().effect(X, T0=T0, T1=T1)
effect.__doc__ = BaseCateEstimator.effect.__doc__
@@ -348,3 +409,95 @@ def coef__interval(self, *, alpha=0.1):
@BaseCateEstimator._defer_to_inference
def intercept__interval(self, *, alpha=0.1):
pass


class StatsModelsCateEstimatorDiscreteMixin(BaseCateEstimator):
# TODO Create parent StatsModelsCateEstimatorMixin class so that some functionalities can be shared

def _get_inference_options(self):
# add statsmodels to parent's options
options = super()._get_inference_options()
options.update(statsmodels=StatsModelsInferenceDiscrete)
return options

@property
@abc.abstractmethod
def statsmodels(self):
pass

def coef_(self, T):
""" The coefficients in the linear model of the constant marginal treatment
effect associated with treatment T.

Parameters
----------
T: alphanumeric
The input treatment for which we want the coefficients.

Returns
-------
coef: (n_x,) or (n_y, n_x) array like
Where n_x is the number of features that enter the final model (either the
dimension of X or the dimension of featurizer.fit_transform(X) if the CATE
estimator has a featurizer.)
"""
_, T = self._expand_treatments(None, T)
ind = (T @ np.arange(T.shape[1])).astype(int)[0]
return self.statsmodels_fitted[ind].coef_

def intercept_(self, T):
""" The intercept in the linear model of the constant marginal treatment
effect associated with treatment T.

Parameters
----------
T: alphanumeric
The input treatment for which we want the intercept.

Returns
-------
intercept: float or (n_y,) array like
"""
_, T = self._expand_treatments(None, T)
ind = (T @ np.arange(1, T.shape[1] + 1)).astype(int)[0] - 1
return self.statsmodels_fitted[ind].intercept_

@BaseCateEstimator._defer_to_inference
def coef__interval(self, T, *, alpha=0.1):
""" The confidence interval for the coefficients in the linear model of the
constant marginal treatment effect associated with treatment T.

Parameters
----------
T: alphanumeric
The input treatment for which we want the coefficients.
alpha: optional float in [0, 1] (Default=0.1)
The overall level of confidence of the reported interval.
The alpha/2, 1-alpha/2 confidence interval is reported.

Returns
-------
lower, upper: tuple(type of :meth:`coef_(T)<coef_>`, type of :meth:`coef_(T)<coef_>`)
The lower and upper bounds of the confidence interval for each quantity.
"""
pass

@BaseCateEstimator._defer_to_inference
def intercept__interval(self, T, *, alpha=0.1):
""" The intercept in the linear model of the constant marginal treatment
effect associated with treatment T.

Parameters
----------
T: alphanumeric
The input treatment for which we want the intercept.
alpha: optional float in [0, 1] (Default=0.1)
The overall level of confidence of the reported interval.
The alpha/2, 1-alpha/2 confidence interval is reported.

Returns
-------
lower, upper: tuple(type of :meth:`intercept_(T)<intercept_>`, type of :meth:`intercept_(T)<intercept_>`)
The lower and upper bounds of the confidence interval.
"""
pass
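
On the indexing trick in `coef_` and `intercept_` above: `_expand_treatments` one-hot encodes the requested treatment with the baseline column dropped, and a dot product with an `arange` recovers the index of the per-treatment fitted model. A numpy sketch with a hypothetical one-hot row:

```python
import numpy as np

# One-hot row after dropping the baseline column (3 non-baseline treatments);
# this row requests the second non-baseline treatment.
T = np.array([[0, 1, 0]])
ind = (T @ np.arange(T.shape[1])).astype(int)[0]          # -> 1
# intercept_ uses a shifted arange; both recover the same index:
ind2 = (T @ np.arange(1, T.shape[1] + 1)).astype(int)[0] - 1
assert ind == ind2 == 1
```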
22 changes: 13 additions & 9 deletions econml/dml.py
@@ -12,7 +12,7 @@
import numpy as np
import copy
from warnings import warn
from .utilities import (shape, reshape, ndim, hstack, cross_product, transpose,
from .utilities import (shape, reshape, ndim, hstack, cross_product, transpose, inverse_onehot,
broadcast_unit_treatments, reshape_treatmentwise_effects,
StatsModelsLinearRegression, LassoCVWrapper)
from sklearn.model_selection import KFold, StratifiedKFold, check_cv
@@ -57,7 +57,7 @@ class DMLCateEstimator(_RLearner):
discrete_treatment: bool, optional (default is ``False``)
Whether the treatment values should be treated as categorical, rather than continuous, quantities

n_splits: int, cross-validation generator or an iterable, optional
n_splits: int, cross-validation generator or an iterable, optional (Default=2)
Determines the cross-validation splitting strategy.
Possible inputs for cv are:

:class:`~sklearn.model_selection.KFold` is used
(with a random shuffle in either case).

Unless an iterable is used, we call `split(X,T)` to generate the splits.
Unless an iterable is used, we call `split(concat[W, X], T)` to generate the splits. If all
W, X are None, then we call `split(ones((T.shape[0], 1)), T)`.

random_state: int, :class:`~numpy.random.mtrand.RandomState` instance or None, optional (default=None)
If int, random_state is the seed used by the random number generator;
@@ -120,7 +121,10 @@ def fit(self, X, W, Target, sample_weight=None):
# In this case, the Target is the one-hot-encoding of the treatment variable
# We need to go back to the label representation of the one-hot so as to call
# the classifier.
Target = np.matmul(Target, np.arange(1, Target.shape[1] + 1)).flatten()
if np.any(np.all(Target == 0, axis=0)) or (not np.any(np.all(Target == 0, axis=1))):
raise AttributeError("Provided crossfit folds contain training splits that " +
"don't contain all treatments")
Target = inverse_onehot(Target)

if sample_weight is not None:
self._model.fit(self._combine(X, W, Target.shape[0]), Target, sample_weight=sample_weight)
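
The new guard raises exactly when some training fold misses a treatment arm (an all-zero one-hot column), since the first-stage classifier could not then learn all classes. For illustration, `inverse_onehot` (added to `econml.utilities` in this PR) maps the baseline-dropped one-hot back to integer labels; a minimal sketch consistent with how it is used here, not the actual implementation:

```python
import numpy as np

def inverse_onehot_sketch(T):
    """Invert a one-hot matrix whose baseline (first) column was dropped.

    All-zero rows map to label 0; a one in column j maps to label j + 1.
    """
    return (T @ np.arange(1, T.shape[1] + 1)).astype(int)

T = np.array([[0, 0],    # baseline treatment -> 0
              [1, 0],    # treatment 1
              [0, 1]])   # treatment 2
print(inverse_onehot_sketch(T))   # [0 1 2]
```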
@@ -210,8 +214,8 @@ class LinearDMLCateEstimator(StatsModelsCateEstimatorMixin, DMLCateEstimator):
The estimator for fitting the treatment to the features. Must implement
`fit` and `predict` methods.

featurizer: transformer, optional
(default is :class:`PolynomialFeatures(degree=1, include_bias=True) <sklearn.preprocessing.PolynomialFeatures>`)
featurizer: transformer, optional (default is \
:class:`PolynomialFeatures(degree=1, include_bias=True) <sklearn.preprocessing.PolynomialFeatures>`)
The transformer used to featurize the raw features when fitting the final model. Must implement
a `fit_transform` method.

discrete_treatment: bool, optional (default is ``False``)
Whether the treatment values should be treated as categorical, rather than continuous, quantities

n_splits: int, cross-validation generator or an iterable, optional
n_splits: int, cross-validation generator or an iterable, optional (Default=2)
Determines the cross-validation splitting strategy.
Possible inputs for cv are:

@@ -328,7 +332,7 @@ class SparseLinearDMLCateEstimator(DMLCateEstimator):
discrete_treatment: bool, optional (default is ``False``)
Whether the treatment values should be treated as categorical, rather than continuous, quantities

n_splits: int, cross-validation generator or an iterable, optional
n_splits: int, cross-validation generator or an iterable, optional (Default=2)
Determines the cross-validation splitting strategy.
Possible inputs for cv are:

@@ -391,7 +395,7 @@ class KernelDMLCateEstimator(LinearDMLCateEstimator):
discrete_treatment: bool, optional (default is ``False``)
Whether the treatment values should be treated as categorical, rather than continuous, quantities

n_splits: int, cross-validation generator or an iterable, optional
n_splits: int, cross-validation generator or an iterable, optional (Default=2)
Determines the cross-validation splitting strategy.
Possible inputs for cv are:

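
Pulling the documented splitting convention together (not part of the diff): unless an iterable of folds is supplied, nuisance cross-fitting splits on the concatenation of W and X, stratifying on T when the treatment is discrete, and falls back to a column of ones when both are None. A sketch under those assumptions, with a hypothetical helper name:

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

def make_folds(T, X=None, W=None, n_splits=2, discrete_treatment=True):
    """Sketch of the documented convention: split(concat[W, X], T),
    falling back to split(ones((n, 1)), T) when both are None."""
    WX = (np.ones((T.shape[0], 1)) if X is None and W is None
          else np.hstack([a for a in (W, X) if a is not None]))
    splitter = (StratifiedKFold(n_splits, shuffle=True) if discrete_treatment
                else KFold(n_splits, shuffle=True))
    return list(splitter.split(WX, T))

T = np.random.binomial(1, 0.5, size=100)
X = np.random.normal(size=(100, 3))
folds = make_folds(T, X=X)
print([len(train) for train, _ in folds])   # roughly equal training folds
```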