diff --git a/conftest.py b/conftest.py
index a0c8bd49..7d76df13 100644
--- a/conftest.py
+++ b/conftest.py
@@ -1,6 +1,6 @@
 import numpy as np
 import pytest
-from skada.datasets import DomainAwareDataset, make_shifted_blobs
+from skada.datasets import DomainAwareDataset, make_shifted_blobs, make_shifted_datasets
 
 # xxx(okachaiev): old API has to be gone when re-writing is done
@@ -30,6 +30,19 @@ def tmp_da_dataset():
     )
 
 
+@pytest.fixture(scope='session')
+def da_reg_dataset():
+    X, y, sample_domain = make_shifted_datasets(
+        n_samples_source=20,
+        n_samples_target=21,
+        shift="concept_drift",
+        noise=0.3,
+        label="regression",
+        random_state=42,
+    )
+    return X, y, sample_domain
+
+
 @pytest.fixture(scope='session')
 def da_dataset() -> DomainAwareDataset:
     centers = np.array([[0, 0], [1, 1]])
diff --git a/docs/source/all.rst b/docs/source/all.rst
index e8cbc88f..8be04fdb 100644
--- a/docs/source/all.rst
+++ b/docs/source/all.rst
@@ -44,6 +44,7 @@ API and modules
    ClassRegularizerOTMapping
    LinearOTMapping
    CORAL
+   JDOTRegressor
 
    make_da_pipeline
 
diff --git a/examples/methods/plot_jdot_da.py b/examples/methods/plot_jdot_da.py
new file mode 100644
index 00000000..82532760
--- /dev/null
+++ b/examples/methods/plot_jdot_da.py
@@ -0,0 +1,119 @@
+"""
+Plot JDOT Regressor
+===================
+
+This example shows how to use the JDOTRegressor on a concept drift dataset.
+"""
+# %% Imports
+import numpy as np
+import matplotlib.pyplot as plt
+
+from sklearn.metrics import mean_squared_error
+from sklearn.kernel_ridge import KernelRidge
+
+from skada import JDOTRegressor
+from skada.datasets import make_shifted_datasets
+from skada import source_target_split
+
+
+# %%
+# Generate concept drift dataset
+# ------------------------------
+
+X, y, sample_domain = make_shifted_datasets(
+    n_samples_source=20,
+    n_samples_target=20,
+    shift="concept_drift",
+    noise=0.3,
+    label="regression",
+    random_state=42,
+)
+
+y = (y - y.mean()) / y.std()
+
+Xs, Xt, ys, yt = source_target_split(X, y, sample_domain=sample_domain)
+
+
+# %%
+# Plot data
+# ---------
+
+plt.figure(1, (10, 5))
+plt.subplot(1, 2, 1)
+plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, label="Source")
+plt.title("Source data")
+ax = plt.axis()
+
+plt.subplot(1, 2, 2)
+plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, label="Target")
+plt.title("Target data")
+plt.axis(ax)
+
+# %%
+# Train on source data
+# --------------------
+
+
+clf = KernelRidge(kernel='rbf', alpha=0.5)
+clf.fit(Xs, ys)
+
+# Compute MSE on source and target
+ys_pred = clf.predict(Xs)
+yt_pred = clf.predict(Xt)
+
+mse_s = mean_squared_error(ys, ys_pred)
+mse_t = mean_squared_error(yt, yt_pred)
+
+print(f"MSE on source: {mse_s:.2f}")
+print(f"MSE on target: {mse_t:.2f}")
+
+XX, YY = np.meshgrid(np.linspace(ax[0], ax[1], 100), np.linspace(ax[2], ax[3], 100))
+Z = clf.predict(np.c_[XX.ravel(), YY.ravel()]).reshape(XX.shape)
+
+
+plt.figure(2, (10, 5))
+plt.subplot(1, 2, 1)
+plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, label="Prediction")
+plt.imshow(Z, extent=(ax[0], ax[1], ax[2], ax[3]), origin='lower', alpha=0.5)
+plt.title(f"KRR Prediction on source (MSE={mse_s:.2f})")
+plt.axis(ax)
+
+plt.subplot(1, 2, 2)
+plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, label="Prediction")
+plt.imshow(Z, extent=(ax[0], ax[1], ax[2], ax[3]), origin='lower', alpha=0.5)
+plt.title(f"KRR Prediction on target (MSE={mse_t:.2f})")
+plt.axis(ax)
+
+
+# %%
+# Train with JDOT regressor
+# -------------------------
+
+
+jdot = JDOTRegressor(base_estimator=KernelRidge(kernel='rbf', alpha=0.5), alpha=0.01)
+
+jdot.fit(X, y, sample_domain=sample_domain)
+
+ys_pred = jdot.predict(Xs)
+yt_pred = jdot.predict(Xt)
+
+mse_s = mean_squared_error(ys, ys_pred)
+mse_t = mean_squared_error(yt, yt_pred)
+
+Zjdot = jdot.predict(np.c_[XX.ravel(), YY.ravel()]).reshape(XX.shape)
+
+print(f"JDOT MSE on source: {mse_s:.2f}")
+print(f"JDOT MSE on target: {mse_t:.2f}")
+
+plt.figure(3, (10, 5))
+plt.subplot(1, 2, 1)
+plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, label="Prediction")
+plt.imshow(Zjdot, extent=(ax[0], ax[1], ax[2], ax[3]), origin='lower', alpha=0.5)
+plt.title(f"JDOT Prediction on source (MSE={mse_s:.2f})")
+plt.axis(ax)
+
+plt.subplot(1, 2, 2)
+plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, label="Prediction")
+plt.imshow(Zjdot, extent=(ax[0], ax[1], ax[2], ax[3]), origin='lower', alpha=0.5)
+plt.title(f"JDOT Prediction on target (MSE={mse_t:.2f})")
+plt.axis(ax)
diff --git a/skada/__init__.py b/skada/__init__.py
index 252556a9..7129dbca 100644
--- a/skada/__init__.py
+++ b/skada/__init__.py
@@ -38,6 +38,7 @@
     TransferComponentAnalysisAdapter,
     TransferComponentAnalysis,
 )
+from ._ot import solve_jdot_regression, JDOTRegressor
 from ._pipeline import make_da_pipeline
 from .utils import source_target_split
 
@@ -79,6 +80,9 @@
     "TransferComponentAnalysisAdapter",
     "TransferComponentAnalysis",
 
+    "solve_jdot_regression",
+    "JDOTRegressor",
+
     "make_da_pipeline",
     "source_target_split",
 
diff --git a/skada/_ot.py b/skada/_ot.py
new file mode 100644
index 00000000..881cd321
--- /dev/null
+++ b/skada/_ot.py
@@ -0,0 +1,223 @@
+# Author: Remi Flamary
+#
+# License: BSD 3-Clause
+
+
+import numpy as np
+from sklearn.base import clone
+from sklearn.utils.validation import check_is_fitted
+from sklearn.linear_model import LinearRegression
+from .base import DAEstimator
+from .utils import source_target_split
+import ot
+import warnings
+
+
+def solve_jdot_regression(base_estimator, Xs, ys, Xt, alpha=0.5, ws=None, wt=None,
+                          n_iter_max=100, tol=1e-5, verbose=False, **kwargs):
+    """Solve the joint distribution optimal transport regression problem.
+
+    Parameters
+    ----------
+    base_estimator : object
+        The base estimator to be used for the regression task. This estimator
+        should solve a least squares regression problem (regularized or not)
+        to correspond to the JDOT theoretical regression problem, but other
+        approaches can be used, with the risk that the fixed-point iteration
+        might not converge.
+    Xs : array-like of shape (n_samples, n_features)
+        Source domain samples.
+    ys : array-like of shape (n_samples,)
+        Source domain labels.
+    Xt : array-like of shape (m_samples, n_features)
+        Target domain samples.
+    alpha : float, default=0.5
+        The trade-off parameter between the feature and label loss in the OT
+        metric.
+    ws : array-like of shape (n_samples,)
+        Source domain weights (will be normalized to sum to 1).
+    wt : array-like of shape (m_samples,)
+        Target domain weights (will be normalized to sum to 1).
+    n_iter_max : int
+        Maximum number of JDOT alternating optimization iterations.
+    tol : float>0
+        Tolerance on loss variations (OT and MSE) to stop iterations.
+    verbose : bool
+        Print loss along iterations if True.
+    kwargs : dict
+        Additional parameters to be passed to the base estimator.
+
+    Returns
+    -------
+    estimator : object
+        The fitted estimator.
+    lst_loss_ot : list
+        The list of OT losses at each iteration.
+    lst_loss_tgt_labels : list
+        The list of target label losses at each iteration.
+    sol : object
+        The solution of the OT problem.
+
+    References
+    ----------
+    [1] N. Courty, R. Flamary, A. Habrard, A. Rakotomamonjy, Joint Distribution
+        Optimal Transportation for Domain Adaptation, Neural Information
+        Processing Systems (NIPS), 2017.
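+
+    Examples
+    --------
+    A minimal, illustrative sketch on random toy data (estimator choice and
+    values are assumptions, not a benchmark):
+
+    >>> import numpy as np
+    >>> from sklearn.linear_model import Ridge
+    >>> from skada import solve_jdot_regression
+    >>> rng = np.random.RandomState(0)
+    >>> Xs, Xt = rng.randn(20, 2), rng.randn(25, 2) + 1
+    >>> ys = Xs[:, 0] ** 2
+    >>> est, loss_ot, loss_tgt, sol = solve_jdot_regression(Ridge(), Xs, ys, Xt)
+    >>> yt_pred = est.predict(Xt)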
+
+    """
+
+    estimator = clone(base_estimator)
+
+    # compute feature distance matrix
+    Mf = ot.dist(Xs, Xt)
+    Mf = Mf / Mf.mean()
+
+    nt = Xt.shape[0]
+    if ws is None:
+        a = np.ones((len(ys),)) / len(ys)
+    else:
+        a = ws / ws.sum()
+    if wt is None:
+        b = np.ones((nt,)) / nt
+    else:
+        b = wt / wt.sum()
+        kwargs['sample_weight'] = wt  # pass it as sample_weight to fit
+
+    lst_loss_ot = []
+    lst_loss_tgt_labels = []
+    y_pred = 0
+    Ml = ot.dist(ys.reshape(-1, 1), np.zeros((nt, 1)))
+
+    for i in range(n_iter_max):
+
+        if i > 0:
+            # update the cost matrix
+            M = (1 - alpha) * Mf + alpha * Ml
+        else:
+            M = (1 - alpha) * Mf
+
+        # solve the OT problem
+        sol = ot.solve(M, a, b)
+
+        T = sol.plan
+        loss_ot = sol.value
+
+        if i == 0:
+            loss_ot += alpha * np.sum(Ml * T)
+
+        lst_loss_ot.append(loss_ot)
+
+        # compute the transported labels
+        yth = ys.T.dot(T) / b
+
+        # fit the estimator
+        estimator.fit(Xt, yth, **kwargs)
+        y_pred = estimator.predict(Xt)
+
+        Ml = ot.dist(ys.reshape(-1, 1), y_pred.reshape(-1, 1))
+
+        # compute the loss
+        loss_tgt_labels = np.mean((yth - y_pred) ** 2)
+        lst_loss_tgt_labels.append(loss_tgt_labels)
+
+        if verbose:
+            print(f'iter={i}, loss_ot={loss_ot}, loss_tgt_labels={loss_tgt_labels}')
+
+        # break on tol OT loss
+        if i > 0 and abs(lst_loss_ot[-1] - lst_loss_ot[-2]) < tol:
+            break
+
+        # break on tol target loss
+        if i > 0 and abs(lst_loss_tgt_labels[-1] - lst_loss_tgt_labels[-2]) < tol:
+            break
+
+        # warn if the fixed-point iteration did not converge
+        if i == n_iter_max - 1:
+            warnings.warn('Maximum number of iterations reached.')
+
+    return estimator, lst_loss_ot, lst_loss_tgt_labels, sol
+
+
+class JDOTRegressor(DAEstimator):
+    """Joint Distribution Optimal Transport Regressor.
+
+    Parameters
+    ----------
+    base_estimator : object
+        The base estimator to be used for the regression task. This estimator
+        should solve a least squares regression problem (regularized or not)
+        to correspond to the JDOT theoretical regression problem, but other
+        approaches can be used, with the risk that the fixed-point iteration
+        might not converge. The default value is LinearRegression() from
+        scikit-learn.
+    alpha : float, default=0.5
+        The trade-off parameter between the feature and label loss in the OT
+        metric.
+    n_iter_max : int
+        Maximum number of JDOT alternating optimization iterations.
+    tol : float>0
+        Tolerance on loss variations (OT and MSE) to stop iterations.
+    verbose : bool
+        Print loss along iterations if True.
+
+    Attributes
+    ----------
+    estimator_ : object
+        The fitted estimator.
+    lst_loss_ot_ : list
+        The list of OT losses at each iteration.
+    lst_loss_tgt_labels_ : list
+        The list of target label losses at each iteration.
+    sol_ : object
+        The solution of the OT problem.
+
+    References
+    ----------
+    [1] N. Courty, R. Flamary, A. Habrard, A. Rakotomamonjy, Joint Distribution
+        Optimal Transportation for Domain Adaptation, Neural Information
+        Processing Systems (NIPS), 2017.
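+
+    Examples
+    --------
+    A minimal, illustrative sketch on a synthetic shift (dataset parameters
+    are assumptions borrowed from the example script):
+
+    >>> from sklearn.kernel_ridge import KernelRidge
+    >>> from skada import JDOTRegressor
+    >>> from skada.datasets import make_shifted_datasets
+    >>> X, y, sample_domain = make_shifted_datasets(
+    ...     n_samples_source=20, n_samples_target=20, shift="concept_drift",
+    ...     noise=0.3, label="regression", random_state=42)
+    >>> jdot = JDOTRegressor(base_estimator=KernelRidge(kernel='rbf'))
+    >>> jdot = jdot.fit(X, y, sample_domain=sample_domain)
+    >>> y_pred = jdot.predict(X[sample_domain < 0])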
+
+    """
+
+    def __init__(self, base_estimator=None, alpha=0.5, n_iter_max=100,
+                 tol=1e-5, verbose=False, **kwargs):
+        if base_estimator is None:
+            base_estimator = LinearRegression()
+        else:
+            if not hasattr(base_estimator, 'fit') or not hasattr(
+                    base_estimator, 'predict'):
+                raise ValueError('base_estimator must be a regressor with'
+                                 ' fit and predict methods')
+        self.base_estimator = base_estimator
+        self.kwargs = kwargs
+        self.alpha = alpha
+        self.n_iter_max = n_iter_max
+        self.tol = tol
+        self.verbose = verbose
+
+    def fit(self, X, y=None, sample_domain=None, *, sample_weight=None):
+        """Fit adaptation parameters."""
+
+        Xs, Xt, ys, yt, ws, wt = source_target_split(
+            X, y, sample_weight, sample_domain=sample_domain)
+
+        res = solve_jdot_regression(self.base_estimator, Xs, ys, Xt, ws=ws, wt=wt,
+                                    alpha=self.alpha, n_iter_max=self.n_iter_max,
+                                    tol=self.tol, verbose=self.verbose,
+                                    **self.kwargs)
+
+        self.estimator_, self.lst_loss_ot_, self.lst_loss_tgt_labels_, self.sol_ = res
+
+        return self
+
+    def predict(self, X, sample_domain=None, *, sample_weight=None):
+        """Predict using the model."""
+        check_is_fitted(self)
+        if sample_domain is not None and np.any(sample_domain >= 0):
+            warnings.warn(
+                'Source domain detected. Predictor is trained on target '
+                'and prediction might be biased.')
+        return self.estimator_.predict(X)
+
+    def score(self, X, y, sample_domain=None, *, sample_weight=None):
+        """Return the coefficient of determination R^2 of the prediction."""
+        check_is_fitted(self)
+        if sample_domain is not None and np.any(sample_domain >= 0):
+            warnings.warn(
+                'Source domain detected. Predictor is trained on target '
+                'and score might be biased.')
+        return self.estimator_.score(X, y, sample_weight=sample_weight)
diff --git a/skada/base.py b/skada/base.py
index 869200bf..658c188c 100644
--- a/skada/base.py
+++ b/skada/base.py
@@ -110,6 +110,32 @@ def transform(
     )
 
 
+class DAEstimator(BaseEstimator):
+    """Generic DA estimator class."""
+
+    __metadata_request__fit = {'sample_domain': True}
+    __metadata_request__partial_fit = {'sample_domain': True}
+    __metadata_request__predict = {'sample_domain': True, 'allow_source': True}
+    __metadata_request__predict_proba = {'sample_domain': True, 'allow_source': True}
+    __metadata_request__predict_log_proba = {
+        'sample_domain': True, 'allow_source': True}
+    __metadata_request__score = {'sample_domain': True, 'allow_source': True}
+    __metadata_request__decision_function = {
+        'sample_domain': True, 'allow_source': True}
+
+    @abstractmethod
+    def fit(self, X, y=None, sample_domain=None, *, sample_weight=None):
+        """Fit adaptation parameters"""
+        pass
+
+    @abstractmethod
+    def predict(self, X, sample_domain=None, *, sample_weight=None):
+        """Predict using the model"""
+        pass
+
+
 class BaseSelector(BaseEstimator):
 
     def __init__(self, base_estimator: BaseEstimator, **kwargs):
diff --git a/skada/model_selection.py b/skada/model_selection.py
index 53dfd484..d808c1b3 100644
--- a/skada/model_selection.py
+++ b/skada/model_selection.py
@@ -118,6 +118,7 @@ class SourceTargetShuffleSplit(BaseDomainAwareShuffleSplit):
     a single domain as a target followed up by the single train/test shuffle
     split.
     """
+
     def __init__(
         self, n_splits=10, *, test_size=None, train_size=None, random_state=None
     ):
@@ -173,6 +174,7 @@ class LeaveOneDomainOut(SplitSampleDomainRequesterMixin):
     Provides train/test indices to split data in train/test sets.
""" + def __init__( self, max_n_splits=10, *, test_size=None, train_size=None, random_state=None ): diff --git a/skada/tests/test_base.py b/skada/tests/test_base.py new file mode 100644 index 00000000..22338cf7 --- /dev/null +++ b/skada/tests/test_base.py @@ -0,0 +1,32 @@ +# Author: Yanis Lalou +# +# License: BSD 3-Clause + +import numpy as np + +from skada.base import BaseAdapter, DAEstimator + + +def test_BaseAdapter(): + + X = np.random.rand(10, 2) + + cls = BaseAdapter() + + cls.fit(X=X, y=None, sample_domain=None) + # set one attribute to shohat something fitted + cls.something_ = 1 + cls.transform(X=X, y=None, sample_domain=None) + cls.fit_transform(X=X, y=None, sample_domain=None) + + +def test_DAEstimator(): + + X = np.random.rand(10, 2) + + cls = DAEstimator() + + cls.fit(X=X, y=None, sample_domain=None) + # set one attribute to shohat something fitted + cls.something_ = 1 + cls.predict(X=X, sample_domain=None) diff --git a/skada/tests/test_ot.py b/skada/tests/test_ot.py new file mode 100644 index 00000000..e1f5e04c --- /dev/null +++ b/skada/tests/test_ot.py @@ -0,0 +1,63 @@ +# Author: Remi Flamary +# +# License: BSD 3-Clause + +import numpy as np +from skada import JDOTRegressor +from sklearn.linear_model import Ridge +from sklearn.preprocessing import StandardScaler +from skada.utils import source_target_split +from skada import make_da_pipeline + + +def test_JDOTRegressor(da_reg_dataset): + + X, y, sample_domain = da_reg_dataset + w = np.random.rand(X.shape[0]) + + Xs, Xt, ys, yt = source_target_split(X, y, sample_domain=sample_domain) + + # standard case + jdot = JDOTRegressor(base_estimator=Ridge(), alpha=.1, verbose=True) + + jdot.fit(X, y, sample_domain=sample_domain) + + ypred = jdot.predict(Xt) + + assert ypred.shape[0] == Xt.shape[0] + + # JDOT with weights + jdot = JDOTRegressor(base_estimator=Ridge(), verbose=True, n_iter_max=1) + jdot.fit(X, y, sample_weight=w, sample_domain=sample_domain) + + score = jdot.score(X, y, sample_domain=sample_domain) + + assert score >= 0 + + # JDOT with default base estimator + jdot = JDOTRegressor() + jdot.fit(X, y, sample_domain=sample_domain) + + with np.testing.assert_raises(ValueError): + jdot = JDOTRegressor(StandardScaler()) + + +def test_JDOTRegressor_pipeline(da_reg_dataset): + + X, y, sample_domain = da_reg_dataset + + Xs, Xt, ys, yt = source_target_split(X, y, sample_domain=sample_domain) + + jdot = make_da_pipeline( + StandardScaler(), JDOTRegressor( + Ridge(), alpha=.1, verbose=True)) + + jdot.fit(X, y, sample_domain=sample_domain) + + ypred = jdot.predict(Xt) + + assert ypred.shape[0] == Xt.shape[0] + + ypred2 = jdot.predict(X, sample_domain=sample_domain) + + assert ypred2.shape[0] == X.shape[0] diff --git a/skada/tests/test_utils.py b/skada/tests/test_utils.py index 7f652c5d..ac5adc2c 100644 --- a/skada/tests/test_utils.py +++ b/skada/tests/test_utils.py @@ -114,6 +114,14 @@ def test_source_target_split(): with pytest.raises(IndexError): source_target_split(X, y[:-2], sample_domain=sample_domain) + X_source, X_target, weights_source, weights_target = source_target_split( + X, None, sample_domain=sample_domain) + + assert X_source.shape == (2 * n_samples_source, 2), "X_source shape mismatch" + assert X_target.shape == (2 * n_samples_target, 2), "X_target shape mismatch" + assert weights_source is None, "weights_source should be None" + assert weights_target is None, "weights_target should be None" + def test_check_X_y_allow_exceptions(): X, y, sample_domain = make_dataset_from_moons_distribution( diff --git 
index a217baeb..a4b6c237 100644
--- a/skada/utils.py
+++ b/skada/utils.py
@@ -223,7 +223,10 @@ def source_target_split(
     Parameters
     ----------
     *arrays : sequence of array-like of identical shape (n_samples, n_features)
-        Input features
+        Input features, target variable(s), and/or sample weights to be
+        split. All arrays must have the same length. If an array is None,
+        a pair of None values is returned in its place, which allows for
+        an optional sample_weight.
     sample_domain : array-like of shape (n_samples,)
         Array specifying the domain labels for each sample.
 
@@ -241,5 +244,6 @@ def source_target_split(
     source_idx = extract_source_indices(sample_domain)
 
     return list(chain.from_iterable(
-        (a[source_idx], a[~source_idx]) for a in arrays
+        (a[source_idx], a[~source_idx]) if a is not None else (None, None)
+        for a in arrays
    ))
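
Note on the source_target_split change above: arrays passed as None now come
back as a (None, None) pair, which keeps unpacking uniform when sample_weight
is optional. A minimal sketch of the new behavior (toy data, illustrative
only):

    import numpy as np
    from skada.utils import source_target_split

    X = np.random.randn(6, 2)
    # skada convention: positive sample_domain values are source, negative are target
    sample_domain = np.array([1, 1, 1, -2, -2, -2])

    # sample_weight omitted: pass None and get a (None, None) pair back
    Xs, Xt, ws, wt = source_target_split(X, None, sample_domain=sample_domain)
    assert ws is None and wt is None
    assert Xs.shape == (3, 2) and Xt.shape == (3, 2)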