Merge pull request #48 from simai-ml/add-sample_weight
Add sample weight
gmartinonQM authored May 27, 2021
2 parents 43966bf + 36c8c7d commit c7b6430
Showing 3 changed files with 131 additions and 7 deletions.
5 changes: 5 additions & 0 deletions HISTORY.rst
@@ -2,6 +2,11 @@
History
=======

0.2.1 (2021-XX-XX)
------------------

* Add sample_weight argument in fit method

0.2.0 (2021-05-21)
------------------

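
In short, `MapieRegressor.fit` now accepts an optional `sample_weight` argument, following the usual scikit-learn convention. A minimal usage sketch of the new argument, assuming the import path shown in the changed files below and a generic scikit-learn regression dataset (the data and weights are illustrative, not part of this PR):

```python
import numpy as np
from sklearn.datasets import make_regression
from mapie.estimators import MapieRegressor  # import path taken from the diff below

# Illustrative data and weights.
X, y = make_regression(n_samples=200, n_features=4, noise=1.0, random_state=0)
sample_weight = np.ones(len(y))
sample_weight[:5] = 0.0    # zero-weighted samples are dropped before fitting
sample_weight[100:] = 0.5  # non-uniform weights affect the fits, not the residual weighting

mapie = MapieRegressor(alpha=0.05)
mapie.fit(X, y, sample_weight=sample_weight)
y_preds = mapie.predict(X)  # predictions and interval bounds, per this version's API
```

Zero-weight observations are excluded from both fitting and residual computation, as described in the updated docstring of `fit` below.
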
106 changes: 100 additions & 6 deletions mapie/estimators.py
@@ -1,10 +1,11 @@
from __future__ import annotations
from typing import Optional, Union, Iterable, Tuple, List
from inspect import signature

import numpy as np
from joblib import Parallel, delayed
from sklearn.utils import check_X_y, check_array
from sklearn.utils.validation import check_is_fitted
from sklearn.utils.validation import check_is_fitted, _check_sample_weight
from sklearn.base import clone
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.linear_model import LinearRegression
@@ -276,14 +277,91 @@ def _check_alpha(self, alpha: Union[float, Iterable[float]]) -> np.ndarray:
raise ValueError("Invalid alpha. Allowed values are between 0 and 1.")
return alpha_np

def _check_null_weight(
self,
sample_weight: ArrayLike,
X: ArrayLike,
y: ArrayLike
) -> Tuple[ArrayLike, ArrayLike, ArrayLike]:
"""
Check sample weights and remove samples with null sample weights.
Parameters
----------
sample_weight : ArrayLike
Sample weights.
X : ArrayLike
Training samples.
y : ArrayLike
Training labels.
Returns
-------
sample_weight : ArrayLike
Non-null sample weights.
X : ArrayLike
Training samples with non-null weights.
y : ArrayLike
Training labels with non-null weights.
"""
if sample_weight is not None:
sample_weight = _check_sample_weight(sample_weight, X)
non_null_weight = sample_weight != 0
X, y = X[non_null_weight, :], y[non_null_weight]
sample_weight = sample_weight[non_null_weight]
return sample_weight, X, y

def _fit_estimator(
self,
estimator: RegressorMixin,
X: ArrayLike,
y: ArrayLike,
supports_sw: bool,
sample_weight: ArrayLike
) -> RegressorMixin:
"""
Fit an estimator on training data by distinguishing two cases:
- the estimator supports sample weights and sample weights are provided.
- the estimator does not support sample weights or sample weights are not provided.
Parameters
----------
estimator : RegressorMixin
Estimator to train.
X : ArrayLike of shape (n_samples, n_features)
Input data.
y : ArrayLike of shape (n_samples,)
Input labels.
supports_sw : bool
Whether or not estimator supports sample weights.
sample_weight : ArrayLike of shape (n_samples,)
Sample weights. If None, then samples are equally weighted. By default None.
Returns
-------
RegressorMixin
Fitted estimator.
"""
if sample_weight is not None and supports_sw:
estimator.fit(X, y, sample_weight=sample_weight)
else:
estimator.fit(X, y)
return estimator

def _fit_and_predict_oof_model(
self,
estimator: RegressorMixin,
X: ArrayLike,
y: ArrayLike,
train_index: ArrayLike,
val_index: ArrayLike,
k: int
k: int,
supports_sw: bool,
sample_weight: Optional[ArrayLike] = None
) -> Tuple[RegressorMixin, ArrayLike, ArrayLike, ArrayLike]:
"""
Fit a single out-of-fold model on a given training set and
@@ -309,6 +387,12 @@ def _fit_and_predict_oof_model(
k : int
Split identification number.
supports_sw : bool
Whether or not estimator supports sample weights.
sample_weight : ArrayLike of shape (n_samples,)
Sample weights. If None, then samples are equally weighted. By default None.
Returns
-------
Tuple[RegressorMixin, ArrayLike, ArrayLike, ArrayLike]
@@ -319,12 +403,13 @@ def _fit_and_predict_oof_model(
- [3]: Validation data indices, of shapes (n_samples_val,)
"""
X_train, y_train, X_val = X[train_index], y[train_index], X[val_index]
estimator.fit(X_train, y_train)
sample_weight_train = sample_weight[train_index] if sample_weight is not None else None
estimator = self._fit_estimator(estimator, X_train, y_train, supports_sw, sample_weight_train)
y_pred = estimator.predict(X_val)
val_id = np.full_like(y_pred, k)
return estimator, y_pred, val_id, val_index

def fit(self, X: ArrayLike, y: ArrayLike) -> MapieRegressor:
def fit(self, X: ArrayLike, y: ArrayLike, sample_weight: Optional[ArrayLike] = None) -> MapieRegressor:
"""
Fit estimator and compute residuals used for prediction intervals.
Fit the base estimator under the ``single_estimator_`` attribute.
@@ -339,6 +424,12 @@ def fit(self, X: ArrayLike, y: ArrayLike) -> MapieRegressor:
y : ArrayLike of shape (n_samples,)
Training labels.
sample_weight : ArrayLike of shape (n_samples,), default=None
Sample weights for fitting the out-of-fold models. If None, then samples are equally weighted.
If some weights are null, their corresponding observations are removed before the fitting process and
hence have no residuals.
If weights are non-uniform, residuals are still uniformly weighted.
Returns
-------
MapieRegressor
@@ -348,17 +439,20 @@ def fit(self, X: ArrayLike, y: ArrayLike) -> MapieRegressor:
cv = self._check_cv(self.cv)
estimator = self._check_estimator(self.estimator)
X, y = check_X_y(X, y, force_all_finite=False, dtype=["float64", "object"])
fit_parameters = signature(estimator.fit).parameters
supports_sw = "sample_weight" in fit_parameters
sample_weight, X, y = self._check_null_weight(sample_weight, X, y)
y_pred = np.empty_like(y, dtype=float)
self.estimators_: List[RegressorMixin] = []
self.n_features_in_ = X.shape[1]
self.k_ = np.empty_like(y, dtype=int)
self.single_estimator_ = clone(estimator).fit(X, y)
self.single_estimator_ = self._fit_estimator(clone(estimator), X, y, supports_sw, sample_weight)
if self.method == "naive":
y_pred = self.single_estimator_.predict(X)
else:
cv_outputs = Parallel(n_jobs=self.n_jobs, verbose=self.verbose)(
delayed(self._fit_and_predict_oof_model)(
clone(estimator), X, y, train_index, val_index, k
clone(estimator), X, y, train_index, val_index, k, supports_sw, sample_weight
) for k, (train_index, val_index) in enumerate(cv.split(X))
)
self.estimators_, predictions, val_ids, val_indices = map(list, zip(*cv_outputs))
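
To summarize the two helpers introduced above: support for `sample_weight` is detected by inspecting the signature of the base estimator's `fit`, and zero-weight samples are dropped before fitting so that they produce no residuals. A standalone sketch of that logic, meant as an illustration rather than the exact library code (the function names are local to this sketch):

```python
from inspect import signature
from typing import Optional, Tuple

import numpy as np
from sklearn.base import RegressorMixin
from sklearn.utils.validation import _check_sample_weight


def supports_sample_weight(estimator: RegressorMixin) -> bool:
    """Return True if the estimator's fit method accepts a sample_weight parameter."""
    return "sample_weight" in signature(estimator.fit).parameters


def drop_null_weights(
    sample_weight: Optional[np.ndarray],
    X: np.ndarray,
    y: np.ndarray,
) -> Tuple[Optional[np.ndarray], np.ndarray, np.ndarray]:
    """Validate weights and remove samples whose weight is exactly zero."""
    if sample_weight is None:
        return None, X, y
    sample_weight = _check_sample_weight(sample_weight, X)
    keep = sample_weight != 0
    return sample_weight[keep], X[keep], y[keep]
```

Estimators whose `fit` does not expose `sample_weight` are then simply fitted without weights, which is the fallback implemented in `_fit_estimator`.
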
27 changes: 26 additions & 1 deletion mapie/tests/test_estimators.py
@@ -1,5 +1,6 @@
from typing import Any, Union, Optional
from typing_extensions import TypedDict
from inspect import signature

import pytest
import numpy as np
@@ -54,7 +55,7 @@
SKLEARN_EXCLUDED_CHECKS = {
"check_regressors_train",
"check_pipeline_consistency",
"check_fit_score_takes_y"
"check_fit_score_takes_y",
}


@@ -75,6 +76,13 @@ def test_default_parameters() -> None:
assert mapie.n_jobs is None


def test_default_sample_weight() -> None:
"""Test default sample weights"""
mapie = MapieRegressor()
mapie.fit(X_toy, y_toy)
assert signature(mapie.fit).parameters["sample_weight"].default is None


def test_fit() -> None:
"""Test that fit raises no errors."""
mapie = MapieRegressor()
@@ -354,3 +362,20 @@ def test_results_single_and_multi_jobs(strategy: str) -> None:
y_preds_single = mapie_single.predict(X_toy)
y_preds_multi = mapie_multi.predict(X_toy)
np.testing.assert_almost_equal(y_preds_single, y_preds_multi)


@pytest.mark.parametrize("strategy", [*STRATEGIES])
def test_results_with_constant_sample_weights(strategy: str) -> None:
"""Test PIs when sample weights are None or constant with different values."""
n_samples = len(X_reg)
mapie0 = MapieRegressor(alpha=0.05, **STRATEGIES[strategy])
mapie0.fit(X_reg, y_reg, sample_weight=None)
mapie1 = MapieRegressor(alpha=0.05, **STRATEGIES[strategy])
mapie1.fit(X_reg, y_reg, sample_weight=np.ones(shape=n_samples))
mapie2 = MapieRegressor(alpha=0.05, **STRATEGIES[strategy])
mapie2.fit(X_reg, y_reg, sample_weight=np.ones(shape=n_samples)*5)
y_preds0 = mapie0.predict(X_reg)
y_preds1 = mapie1.predict(X_reg)
y_preds2 = mapie2.predict(X_reg)
np.testing.assert_almost_equal(y_preds0, y_preds1)
np.testing.assert_almost_equal(y_preds1, y_preds2)
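
A note on why the last test should hold with the default `LinearRegression` base estimator: multiplying a constant weight vector by a positive scalar rescales the weighted least-squares objective without changing its minimizer, so predictions (and hence residuals and intervals) stay the same. A quick standalone check of that property (not part of the PR):

```python
import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + rng.normal(scale=0.1, size=50)

lr_ones = LinearRegression().fit(X, y, sample_weight=np.ones(50))
lr_fives = LinearRegression().fit(X, y, sample_weight=np.ones(50) * 5)

# Rescaling constant weights leaves the least-squares solution unchanged.
np.testing.assert_almost_equal(lr_ones.coef_, lr_fives.coef_)
np.testing.assert_almost_equal(lr_ones.predict(X), lr_fives.predict(X))
```
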
