[MRG] JDOT Regressor #76

Merged · 23 commits · Feb 16, 2024
15 changes: 14 additions & 1 deletion conftest.py
@@ -1,6 +1,6 @@
import numpy as np
import pytest
from skada.datasets import DomainAwareDataset, make_shifted_blobs
from skada.datasets import DomainAwareDataset, make_shifted_blobs, make_shifted_datasets


# xxx(okachaiev): old API has to be gone when re-writing is done
@@ -30,6 +30,19 @@ def tmp_da_dataset():
)


@pytest.fixture(scope='session')
def da_reg_dataset():
X, y, sample_domain = make_shifted_datasets(
n_samples_source=20,
n_samples_target=21,
shift="concept_drift",
noise=0.3,
label="regression",
random_state=42,
)
return X, y, sample_domain


@pytest.fixture(scope='session')
def da_dataset() -> DomainAwareDataset:
centers = np.array([[0, 0], [1, 1]])
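For context during review, here is a sketch of how a test could consume the new fixture (the test name and assertion are hypothetical, not part of this PR):

def test_jdot_regressor_smoke(da_reg_dataset):
    # the fixture returns the regression concept-drift data defined above
    X, y, sample_domain = da_reg_dataset

    from skada import JDOTRegressor
    jdot = JDOTRegressor()
    jdot.fit(X, y, sample_domain=sample_domain)
    # one prediction per sample; no warning since no sample_domain is passed
    assert jdot.predict(X).shape == y.shape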
1 change: 1 addition & 0 deletions docs/source/all.rst
@@ -44,6 +44,7 @@ API and modules
ClassRegularizerOTMapping
LinearOTMapping
CORAL
JDOTRegressor
make_da_pipeline


119 changes: 119 additions & 0 deletions examples/methods/plot_jdot_da.py
@@ -0,0 +1,119 @@
"""
Plot JDOT Regressor
===================


"""
# %% Imports
import numpy as np
import matplotlib.pyplot as plt

from sklearn.metrics import mean_squared_error
from sklearn.kernel_ridge import KernelRidge

from skada import JDOTRegressor
from skada.datasets import make_shifted_datasets
from skada import source_target_split


# %%
# Generate concept drift dataset
# ------------------------------

X, y, sample_domain = make_shifted_datasets(
n_samples_source=20,
n_samples_target=20,
shift="concept_drift",
noise=0.3,
label="regression",
random_state=42,
)

# standardize the target values
y = (y - y.mean()) / y.std()

Xs, Xt, ys, yt = source_target_split(X, y, sample_domain=sample_domain)


# %%
# Plot data
# ---------

plt.figure(1, (10, 5))
plt.subplot(1, 2, 1)
plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, label="Source")
plt.title("Source data")
ax = plt.axis()

plt.subplot(1, 2, 2)
plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, label="Target")
plt.title("Target data")
plt.axis(ax)

# %%
# Train on source data
# --------------------


clf = KernelRidge(kernel='rbf', alpha=0.5)
clf.fit(Xs, ys)

# Compute MSE on source and target
ys_pred = clf.predict(Xs)
yt_pred = clf.predict(Xt)

mse_s = mean_squared_error(ys, ys_pred)
mse_t = mean_squared_error(yt, yt_pred)

print(f"MSE on source: {mse_s:.2f}")
print(f"MSE on target: {mse_t:.2f}")

XX, YY = np.meshgrid(np.linspace(ax[0], ax[1], 100), np.linspace(ax[2], ax[3], 100))
Z = clf.predict(np.c_[XX.ravel(), YY.ravel()]).reshape(XX.shape)


plt.figure(2, (10, 5))
plt.subplot(1, 2, 1)
plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, label="Prediction")
plt.imshow(Z, extent=(ax[0], ax[1], ax[2], ax[3]), origin='lower', alpha=0.5)
plt.title(f"KRR Prediction on source (MSE={mse_s:.2f})")
plt.axis(ax)

plt.subplot(1, 2, 2)
plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, label="Prediction")
plt.imshow(Z, extent=(ax[0], ax[1], ax[2], ax[3]), origin='lower', alpha=0.5)
plt.title(f"KRR Prediction on target (MSE={mse_t:.2f})")
plt.axis(ax)


# %%
# Train with JDOT regressor
# -------------------------


jdot = JDOTRegressor(base_estimator=KernelRidge(kernel='rbf', alpha=0.5), alpha=0.01)

jdot.fit(X, y, sample_domain=sample_domain)

ys_pred = jdot.predict(Xs)
yt_pred = jdot.predict(Xt)

mse_s = mean_squared_error(ys, ys_pred)
mse_t = mean_squared_error(yt, yt_pred)

Zjdot = jdot.predict(np.c_[XX.ravel(), YY.ravel()]).reshape(XX.shape)

print(f"JDOT MSE on source: {mse_s:.2f}")
print(f"JDOT MSE on target: {mse_t:.2f}")

plt.figure(3, (10, 5))
plt.subplot(1, 2, 1)
plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, label="Prediction")
plt.imshow(Zjdot, extent=(ax[0], ax[1], ax[2], ax[3]), origin='lower', alpha=0.5)
plt.title(f"JDOT Prediction on source (MSE={mse_s:.2f})")
plt.axis(ax)

plt.subplot(1, 2, 2)
plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, label="Prediction")
plt.imshow(Zjdot, extent=(ax[0], ax[1], ax[2], ax[3]), origin='lower', alpha=0.5)
plt.title(f"JDOT Prediction on target (MSE={mse_t:.2f})")
plt.axis(ax)
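A natural follow-up cell (a review sketch, not part of the merged example) plots the loss histories stored on the fitted object to check convergence of the alternating optimization:

# %%
# Plot JDOT loss curves (review sketch)
# -------------------------------------

plt.figure(4, (5, 4))
plt.plot(jdot.lst_loss_ot_, label='OT loss')
plt.plot(jdot.lst_loss_tgt_labels_, label='Target label loss')
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.legend()
plt.title('JDOT alternating optimization losses')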
4 changes: 4 additions & 0 deletions skada/__init__.py
@@ -38,6 +38,7 @@
TransferComponentAnalysisAdapter,
TransferComponentAnalysis,
)
from ._ot import solve_jdot_regression, JDOTRegressor
from ._pipeline import make_da_pipeline
from .utils import source_target_split

@@ -79,6 +80,9 @@
"TransferComponentAnalysisAdapter",
"TransferComponentAnalysis",

"solve_jdot_regression",
"JDOTRegressor",

"make_da_pipeline",

"source_target_split",
223 changes: 223 additions & 0 deletions skada/_ot.py
@@ -0,0 +1,223 @@
# Author: Remi Flamary <[email protected]>
#
# License: BSD 3-Clause


import numpy as np
from sklearn.base import clone
from sklearn.utils.validation import check_is_fitted
from sklearn.linear_model import LinearRegression
from .base import DAEstimator
from .utils import source_target_split
import ot
import warnings


def solve_jdot_regression(base_estimator, Xs, ys, Xt, alpha=0.5, ws=None, wt=None,
n_iter_max=100, tol=1e-5, verbose=False, **kwargs):
"""Solve the joint distribution optimal transport regression problem

Parameters
----------
base_estimator : object
The base estimator to be used for the regression task. This estimator
should solve a least squares regression problem (regularized or not)
to match the theoretical JDOT regression problem; other approaches can
be used, with the risk that the fixed-point iteration may not converge.
Xs : array-like of shape (n_samples, n_features)
Source domain samples.
ys : array-like of shape (n_samples,)
Source domain labels.
Xt : array-like of shape (m_samples, n_features)
Target domain samples.
alpha : float, default=0.5
The trade-off parameter between the feature and label losses in the OT metric.
ws : array-like of shape (n_samples,)
Source domain weights (will be normalized to sum to 1).
wt : array-like of shape (m_samples,)
Target domain weights (will be normalized to sum to 1).
n_iter_max : int
Maximum number of JDOT alternating optimization iterations.
tol : float > 0
Tolerance on the loss variations (OT and MSE) used to stop the iterations.
verbose : bool
Print the losses along iterations if True.
kwargs : dict
Additional parameters to be passed to the base estimator.


Returns
-------
estimator : object
The fitted estimator.
lst_loss_ot : list
The list of OT losses at each iteration.
lst_loss_tgt_labels : list
The list of target labels losses at each iteration.
sol : object
The solution of the OT problem.

References
----------
[1] N. Courty, R. Flamary, A. Habrard, A. Rakotomamonjy, Joint Distribution
Optimal Transportation for Domain Adaptation, Neural Information Processing
Systems (NIPS), 2017.

"""

estimator = clone(base_estimator)

# compute feature distance matrix
Mf = ot.dist(Xs, Xt)
Mf = Mf / Mf.mean()

nt = Xt.shape[0]
if ws is None:
a = np.ones((len(ys),)) / len(ys)
else:
a = ws / ws.sum()
if wt is None:
b = np.ones((nt,)) / nt
else:
b = wt / wt.sum()
kwargs['sample_weight'] = wt  # pass target weights to the estimator fit

lst_loss_ot = []
lst_loss_tgt_labels = []
y_pred = 0
Ml = ot.dist(ys.reshape(-1, 1), np.zeros((nt, 1)))
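
# JDOT alternates two steps: (1) solve an OT problem whose cost
# M = (1 - alpha) * Mf + alpha * Ml mixes feature and label distances,
# and (2) refit the estimator on the labels transported onto the target
# samples, which updates the label cost Ml for the next iteration.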

for i in range(n_iter_max):

if i > 0:
# update the cost matrix
M = (1 - alpha) * Mf + alpha * Ml
else:
M = (1 - alpha) * Mf

# solve the OT problem
sol = ot.solve(M, a, b)

T = sol.plan
loss_ot = sol.value

if i == 0:
loss_ot += alpha * np.sum(Ml * T)

lst_loss_ot.append(loss_ot)

# compute the transported labels (barycentric mapping of the source labels)
yth = ys.T.dot(T) / b

# fit the estimator
estimator.fit(Xt, yth, **kwargs)
y_pred = estimator.predict(Xt)

Ml = ot.dist(ys.reshape(-1, 1), y_pred.reshape(-1, 1))

# compute the loss
loss_tgt_labels = np.mean((yth - y_pred)**2)
lst_loss_tgt_labels.append(loss_tgt_labels)

if verbose:
print(f'iter={i}, loss_ot={loss_ot}, loss_tgt_labels={loss_tgt_labels}')

# stop when the OT loss variation falls below tol
if i > 0 and abs(lst_loss_ot[-1] - lst_loss_ot[-2]) < tol:
break

# (codecov/patch annotation: added line skada/_ot.py#L127 was not covered by tests)

# stop when the target label loss variation falls below tol
if i > 0 and abs(lst_loss_tgt_labels[-1] - lst_loss_tgt_labels[-2]) < tol:
break

# warn if the alternating optimization did not converge
if i == n_iter_max - 1:
warnings.warn('Maximum number of iterations reached.')

return estimator, lst_loss_ot, lst_loss_tgt_labels, sol
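As a direct-call illustration of the solver's interface (a review sketch, not part of the diff; the kernel ridge base estimator and alpha=0.1 are arbitrary choices):

from sklearn.kernel_ridge import KernelRidge
from skada import solve_jdot_regression, source_target_split
from skada.datasets import make_shifted_datasets

X, y, sample_domain = make_shifted_datasets(
    n_samples_source=20, n_samples_target=20, shift='concept_drift',
    noise=0.3, label='regression', random_state=42)
Xs, Xt, ys, yt = source_target_split(X, y, sample_domain=sample_domain)

# returns the target-fitted estimator, the two loss histories and the OT solution
est, lst_loss_ot, lst_loss_tgt, sol = solve_jdot_regression(
    KernelRidge(kernel='rbf'), Xs, ys, Xt, alpha=0.1)
yt_pred = est.predict(Xt)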


class JDOTRegressor(DAEstimator):
"""Joint Distribution Optimal Transport Regressor

Parameters
----------
base_estimator : object
The base estimator to be used for the regression task. This estimator
should solve a least squares regression problem (regularized or not)
to match the theoretical JDOT regression problem; other approaches can
be used, with the risk that the fixed-point iteration may not converge.
The default value is LinearRegression() from scikit-learn.
alpha : float, default=0.5
The trade-off parameter between the feature and label losses in the OT metric.
n_iter_max : int
Maximum number of JDOT alternating optimization iterations.
tol : float > 0
Tolerance on the loss variations (OT and MSE) used to stop the iterations.
verbose : bool
Print the losses along iterations if True.

Attributes
----------
estimator_ : object
The fitted estimator.
lst_loss_ot_ : list
The list of OT losses at each iteration.
lst_loss_tgt_labels_ : list
The list of target labels losses at each iteration.
sol_ : object
The solution of the OT problem.

References
----------
[1] N. Courty, R. Flamary, A. Habrard, A. Rakotomamonjy, Joint Distribution
Optimal Transportation for Domain Adaptation, Neural Information
Processing Systems (NIPS), 2017.

"""

def __init__(self, base_estimator=None, alpha=0.5, n_iter_max=100,
tol=1e-5, verbose=False, **kwargs):
if base_estimator is None:
base_estimator = LinearRegression()
else:
if not hasattr(base_estimator, 'fit') or not hasattr(
base_estimator, 'predict'):
raise ValueError('base_estimator must be a regressor with'
' fit and predict methods')
self.base_estimator = base_estimator
self.kwargs = kwargs
self.alpha = alpha
self.n_iter_max = n_iter_max
self.tol = tol
self.verbose = verbose

def fit(self, X, y=None, sample_domain=None, *, sample_weight=None):
"""Fit adaptation parameters"""

Xs, Xt, ys, yt, ws, wt = source_target_split(
X, y, sample_weight, sample_domain=sample_domain)

res = solve_jdot_regression(self.base_estimator, Xs, ys, Xt, ws=ws, wt=wt,
alpha=self.alpha, n_iter_max=self.n_iter_max,
tol=self.tol, verbose=self.verbose, **self.kwargs)

self.estimator_, self.lst_loss_ot_, self.lst_loss_tgt_labels_, self.sol_ = res

def predict(self, X, sample_domain=None, *, sample_weight=None):
"""Predict using the model"""
check_is_fitted(self)
if sample_domain is not None and np.any(sample_domain >= 0):
warnings.warn(
'Source domain detected. Predictor is trained on target '
'and prediction might be biased.')
return self.estimator_.predict(X)

def score(self, X, y, sample_domain=None, *, sample_weight=None):
"""Return the coefficient of determination R^2 of the prediction"""
check_is_fitted(self)
if sample_domain is not None and np.any(sample_domain >= 0):
warnings.warn(
'Source domain detected. Predictor is trained on target '
'and score might be biased.')
return self.estimator_.score(X, y, sample_weight=sample_weight)
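Finally, a small sketch (again for review, not part of the diff) of the prediction path, showing the source-domain warning implemented above; the code treats non-negative sample_domain entries as source samples:

from skada import JDOTRegressor
from skada.datasets import make_shifted_datasets

X, y, sample_domain = make_shifted_datasets(
    n_samples_source=20, n_samples_target=20, shift='concept_drift',
    noise=0.3, label='regression', random_state=42)

jdot = JDOTRegressor(alpha=0.1)
jdot.fit(X, y, sample_domain=sample_domain)

# passing a sample_domain that still contains source samples (>= 0)
# triggers the 'Source domain detected' warning
y_pred = jdot.predict(X, sample_domain=sample_domain)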