[MRG] Make sure all API methods accept sample_domain as None (#53)
YanisLalou authored Jan 26, 2024
1 parent 895f862 commit d4eb894
Showing 17 changed files with 714 additions and 312 deletions.
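Editor's note: for orientation, a minimal sketch of the calling convention the hunks below converge on. `source_target_split` is imported from `skada.utils`, takes `sample_domain` as a keyword argument (which may be None, the point of this commit), and returns the split arrays grouped per input array (`X_source, X_target, y_source, y_target`). The toy arrays and the "positive labels are sources, negative labels are targets" convention are assumptions for illustration, not shown in this diff.

```python
import numpy as np
from skada.utils import source_target_split  # new import location after this commit

# Toy data: 4 source samples (domain +1) and 2 target samples (domain -2).
# Assumption (not shown in this diff): positive sample_domain values mark
# source samples, negative values mark target samples.
X = np.arange(12, dtype=float).reshape(6, 2)
y = np.array([0, 1, 0, 1, 0, 1])
sample_domain = np.array([1, 1, 1, 1, -2, -2])

# New return order: the split arrays are grouped per input array.
X_source, X_target, y_source, y_target = source_target_split(
    X, y, sample_domain=sample_domain
)
print(X_source.shape, X_target.shape)  # expected: (4, 2) (2, 2)
```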
8 changes: 5 additions & 3 deletions examples/datasets/plot_dataset_from_moons_distribution.py
@@ -28,7 +28,9 @@
random_state=RANDOM_SEED
)

- X_source, y_source, X_target, y_target = source_target_split(X, y, sample_domain)
+ X_source, X_target, y_source, y_target = source_target_split(
+ X, y, sample_domain=sample_domain
+ )

fig, (ax1, ax2) = plt.subplots(1, 2, sharex="row", sharey="row", figsize=(8, 4))
fig.suptitle('One source and one target', fontsize=14)
@@ -75,8 +77,8 @@
random_state=RANDOM_SEED
)

- X_source, y_source, domain_source, X_target, y_target, domain_target = (
- source_target_split(X, y, sample_domain, return_domain=True)
+ X_source, X_target, y_source, y_target, domain_source, domain_target = (
+ source_target_split(X, y, sample_domain, sample_domain=sample_domain)
)
fig, (ax1, ax2) = plt.subplots(1, 2, sharex="row", sharey="row", figsize=(8, 4))
fig.suptitle('Multi-source and Multi-target', fontsize=14)
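Editor's note: the second hunk above drops the old `return_domain=True` flag. A hedged sketch of the replacement idiom, using the same toy arrays and the assumed convention (positive domain labels are sources, negative ones targets):

```python
import numpy as np
from skada.utils import source_target_split

X = np.arange(12, dtype=float).reshape(6, 2)
y = np.array([0, 1, 0, 1, 0, 1])
sample_domain = np.array([1, 1, 1, 1, -2, -2])  # assumed: >0 source, <0 target

# Passing sample_domain both as an extra array to split and as the keyword
# that drives the split returns the per-sample domain labels already
# separated, which is what return_domain=True used to provide.
X_source, X_target, y_source, y_target, domain_source, domain_target = (
    source_target_split(X, y, sample_domain, sample_domain=sample_domain)
)
print(domain_source)  # expected: [1 1 1 1]
print(domain_target)  # expected: [-2 -2]
```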
4 changes: 3 additions & 1 deletion examples/datasets/plot_shifted_dataset.py
@@ -41,7 +41,9 @@ def plot_shifted_dataset(shift, random_state=42):
random_state=random_state,
)

- X_source, y_source, X_target, y_target = source_target_split(X, y, sample_domain)
+ X_source, X_target, y_source, y_target = source_target_split(
+ X, y, sample_domain=sample_domain
+ )

fig, (ax1, ax2) = plt.subplots(1, 2, sharex="row", sharey="row", figsize=(8, 4))
fig.suptitle(shift.replace("_", " ").title(), fontsize=14)
4 changes: 3 additions & 1 deletion examples/datasets/plot_variable_frequency_dataset.py
@@ -33,7 +33,9 @@
random_state=RANDOM_SEED
)

- X_source, y_source, X_target, y_target = source_target_split(X, y, sample_domain)
+ X_source, X_target, y_source, y_target = source_target_split(
+ X, y, sample_domain=sample_domain
+ )

# %% Visualize the signal

4 changes: 3 additions & 1 deletion examples/methods/plot_optimal_transport_da.py
@@ -44,7 +44,9 @@
)


- X_source, y_source, X_target, y_target = source_target_split(X, y, sample_domain)
+ X_source, X_target, y_source, y_target = source_target_split(
+ X, y, sample_domain=sample_domain
+ )


n_tot_source = X_source.shape[0]
2 changes: 1 addition & 1 deletion examples/plot_method_comparison.py
@@ -117,7 +117,7 @@
# preprocess dataset, split into training and test part
X, y, sample_domain = ds

- Xs, ys, Xt, yt = source_target_split(X, y, sample_domain=sample_domain)
+ Xs, Xt, ys, yt = source_target_split(X, y, sample_domain=sample_domain)

x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5
y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5
4 changes: 3 additions & 1 deletion examples/validation/plot_cross_val_score_for_da.py
@@ -39,7 +39,9 @@
)

X, y, sample_domain = dataset.pack_train(as_sources=['s'], as_targets=['t'])
- X_source, y_source, X_target, y_target = source_target_split(X, y, sample_domain)
+ X_source, X_target, y_source, y_target = source_target_split(
+ X, y, sample_domain=sample_domain
+ )
cv = ShuffleSplit(n_splits=5, test_size=0.3, random_state=0)

# %%
2 changes: 1 addition & 1 deletion skada/__init__.py
@@ -39,7 +39,7 @@
TransferComponentAnalysis,
)
from ._pipeline import make_da_pipeline
- from ._utils import source_target_split
+ from .utils import source_target_split


# make sure that the usage of the library is not possible
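Editor's note: since `skada/__init__.py` now re-exports the helper from `skada.utils`, both import paths should resolve to the same function (assuming a skada build that includes this commit). A tiny check:

```python
# Both import paths expose the same object after this change.
from skada import source_target_split
from skada.utils import source_target_split as sts_from_utils

assert source_target_split is sts_from_utils
```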
35 changes: 18 additions & 17 deletions skada/_mapping.py
@@ -10,9 +10,12 @@
from ot import da

from .base import BaseAdapter, clone
- from ._utils import (
+ from .utils import (
check_X_domain,
check_X_y_domain,
+ source_target_split
+ )
+ from ._utils import (
_estimate_covariance,
_merge_source_target,
)
@@ -45,13 +48,9 @@ def fit(self, X, y=None, sample_domain=None):
self : object
Returns self.
"""
- X, y, X_target, y_target = check_X_y_domain(
- X,
- y,
- sample_domain,
- allow_multi_source=True,
- allow_multi_target=True,
- return_joint=False,
+ X, y, sample_domain = check_X_y_domain(X, y, sample_domain)
+ X, X_target, y, y_target = source_target_split(
+ X, y, sample_domain=sample_domain
)
transport = self._create_transport_estimator()
self.ot_transport_ = clone(transport)
@@ -80,13 +79,13 @@ def adapt(self, X, y=None, sample_domain=None):
The weights of the samples.
"""
# xxx(okachaiev): implement auto-infer for sample_domain
- X_source, X_target = check_X_domain(
+ X, sample_domain = check_X_domain(
X,
sample_domain,
allow_multi_source=True,
- allow_multi_target=True,
- return_joint=False,
+ allow_multi_target=True
)
+ X_source, X_target = source_target_split(X, sample_domain=sample_domain)
# in case of prediction we would get only target samples here,
# thus there's no need to perform any transformations
if X_source.shape[0] > 0:
@@ -604,13 +603,14 @@ def fit(self, X, y=None, sample_domain=None):
self : object
Returns self.
"""
- X_source, X_target = check_X_domain(
+ X, sample_domain = check_X_domain(
X,
sample_domain,
allow_multi_source=True,
- allow_multi_target=True,
- return_joint=False,
+ allow_multi_target=True
)
+ X_source, X_target = source_target_split(X, sample_domain=sample_domain)

cov_source_ = _estimate_covariance(X_source, shrinkage=self.reg)
cov_target_ = _estimate_covariance(X_target, shrinkage=self.reg)
self.cov_source_inv_sqrt_ = _invsqrtm(cov_source_)
@@ -638,13 +638,14 @@ def adapt(self, X, y=None, sample_domain=None):
weights : None
No weights are returned here.
"""
- X_source, X_target = check_X_domain(
+ X, sample_domain = check_X_domain(
X,
sample_domain,
allow_multi_source=True,
- allow_multi_target=True,
- return_joint=False,
+ allow_multi_target=True
)
+ X_source, X_target = source_target_split(X, sample_domain=sample_domain)

X_source_adapt = np.dot(X_source, self.cov_source_inv_sqrt_)
X_source_adapt = np.dot(X_source_adapt, self.cov_target_sqrt_)
X_adapt = _merge_source_target(X_source_adapt, X_target, sample_domain)
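Editor's note: the recurring refactor in this file (and the ones below) is a two-step pattern: the validation helper now returns the joint `X, sample_domain` pair, and the source/target separation is done explicitly with `source_target_split`. A hedged standalone sketch that only mirrors the calls visible in the hunks; `split_for_adaptation` is a hypothetical helper, not part of skada:

```python
import numpy as np
from skada.utils import check_X_domain, source_target_split

def split_for_adaptation(X, sample_domain=None):
    # 1) Validate X together with sample_domain; None is accepted, which is
    #    what this commit is about.
    X, sample_domain = check_X_domain(X, sample_domain)
    # 2) Separate source and target rows explicitly.
    return source_target_split(X, sample_domain=sample_domain)

X = np.random.randn(6, 2)
sample_domain = np.array([1, 1, 1, 1, -2, -2])  # assumed: >0 source, <0 target
X_source, X_target = split_for_adaptation(X, sample_domain)
```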
59 changes: 34 additions & 25 deletions skada/_reweight.py
@@ -16,7 +16,8 @@
from sklearn.utils.validation import check_is_fitted

from .base import AdaptationOutput, BaseAdapter, clone
- from ._utils import _estimate_covariance, check_X_domain
+ from .utils import check_X_domain, source_target_split, extract_source_indices
+ from ._utils import _estimate_covariance
from ._pipeline import make_da_pipeline


@@ -62,11 +63,12 @@ def fit(self, X, y=None, sample_domain=None):
Returns self.
"""
# xxx(okachaiev): that's the reason we need a way to cache this call
- X_source, X_target = check_X_domain(
+ X, sample_domain = check_X_domain(
X,
- sample_domain,
- return_joint=False,
+ sample_domain
)
+ X_source, X_target = source_target_split(X, sample_domain=sample_domain)

self.weight_estimator_source_ = clone(self.weight_estimator)
self.weight_estimator_target_ = clone(self.weight_estimator)
self.weight_estimator_source_.fit(X_source)
@@ -94,11 +96,12 @@ def adapt(self, X, y=None, sample_domain=None):
weights : array-like, shape (n_samples,)
The weights of the samples.
"""
- source_idx = check_X_domain(
+ X, sample_domain = check_X_domain(
X,
- sample_domain,
- return_indices=True,
+ sample_domain
)
+ source_idx = extract_source_indices(sample_domain)

# xxx(okachaiev): move this to API
if source_idx.sum() > 0:
source_idx, = np.where(source_idx)
@@ -197,11 +200,12 @@ def fit(self, X, y=None, sample_domain=None):
self : object
Returns self.
"""
- X_source, X_target = check_X_domain(
+ X, sample_domain = check_X_domain(
X,
- sample_domain,
- return_joint=False
+ sample_domain
)
+ X_source, X_target = source_target_split(X, sample_domain=sample_domain)

self.mean_source_ = X_source.mean(axis=0)
self.cov_source_ = _estimate_covariance(X_source, shrinkage=self.reg)
self.mean_target_ = X_target.mean(axis=0)
@@ -230,11 +234,12 @@ def adapt(self, X, y=None, sample_domain=None):
The weights of the samples.
"""
check_is_fitted(self)
- source_idx = check_X_domain(
+ X, sample_domain = check_X_domain(
X,
- sample_domain,
- return_indices=True,
+ sample_domain
)
+ source_idx = extract_source_indices(sample_domain)

# xxx(okachaiev): move this to API
if source_idx.sum() > 0:
source_idx, = np.where(source_idx)
@@ -337,11 +342,12 @@ def fit(self, X, y=None, sample_domain=None):
self : object
Returns self.
"""
- source_idx = check_X_domain(
+ X, sample_domain = check_X_domain(
X,
- sample_domain,
- return_indices=True
+ sample_domain
)
+ source_idx = extract_source_indices(sample_domain)

source_idx, = np.where(source_idx)
self.domain_classifier_ = clone(self.domain_classifier)
y_domain = np.ones(X.shape[0], dtype=np.int32)
@@ -371,11 +377,12 @@ def adapt(self, X, y=None, sample_domain=None, **kwargs):
The weights of the samples.
"""
check_is_fitted(self)
- source_idx = check_X_domain(
+ X, sample_domain = check_X_domain(
X,
- sample_domain,
- return_indices=True,
+ sample_domain
)
+ source_idx = extract_source_indices(sample_domain)

# xxx(okachaiev): move this to API
if source_idx.sum() > 0:
source_idx, = np.where(source_idx)
@@ -505,13 +512,14 @@ def fit(self, X, y=None, sample_domain=None, **kwargs):
self : object
Returns self.
"""
- X_source, X_target = check_X_domain(
+ X, sample_domain = check_X_domain(
X,
sample_domain,
allow_multi_source=True,
- allow_multi_target=True,
- return_joint=False,
+ allow_multi_target=True
)
+ X_source, X_target = source_target_split(X, sample_domain=sample_domain)

if isinstance(self.gamma, list):
self.best_gamma_ = self._likelihood_cross_validation(
self.gamma, X_source, X_target
@@ -599,11 +607,12 @@ def adapt(self, X, y=None, sample_domain=None, **kwargs):
weights : array-like, shape (n_samples,)
The weights of the samples.
"""
- source_idx = check_X_domain(
+ X, sample_domain = check_X_domain(
X,
- sample_domain,
- return_indices=True,
+ sample_domain
)
+ source_idx = extract_source_indices(sample_domain)

if source_idx.sum() > 0:
source_idx, = np.where(source_idx)
A = pairwise_kernels(
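Editor's note: the weight-based adapters above now recover the boolean source mask with `extract_source_indices` instead of asking `check_X_domain` for indices. A hedged sketch of that mask-then-index step, mirroring the hunks (toy arrays and the source/target sign convention are assumptions):

```python
import numpy as np
from skada.utils import check_X_domain, extract_source_indices

X = np.arange(12, dtype=float).reshape(6, 2)
sample_domain = np.array([1, 1, 1, 1, -2, -2])  # assumed: >0 source, <0 target

X, sample_domain = check_X_domain(X, sample_domain)
# Boolean mask: True for source rows, False for target rows.
source_idx = extract_source_indices(sample_domain)
if source_idx.sum() > 0:
    # Same idiom as the adapters above: convert the mask to integer positions.
    source_idx, = np.where(source_idx)
    X_source = X[source_idx]
```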
22 changes: 13 additions & 9 deletions skada/_subspace.py
@@ -12,7 +12,8 @@
from sklearn.svm import SVC

from .base import BaseAdapter
- from ._utils import check_X_domain, _merge_source_target
+ from .utils import check_X_domain, source_target_split
+ from ._utils import _merge_source_target
from ._pipeline import make_da_pipeline


@@ -79,13 +80,14 @@ def adapt(self, X, y=None, sample_domain=None, **kwargs):
weights : None
No weights are returned here.
"""
- X_source, X_target = check_X_domain(
+ X, sample_domain = check_X_domain(
X,
sample_domain,
allow_multi_source=True,
allow_multi_target=True,
- return_joint=False,
)
+ X_source, X_target = source_target_split(X, sample_domain=sample_domain)

if X_source.shape[0]:
X_source = np.dot(self.pca_source_.transform(X_source), self.M_)
if X_target.shape[0]:
@@ -111,13 +113,14 @@ def fit(self, X, y=None, sample_domain=None, **kwargs):
self : object
Returns self.
"""
- X_source, X_target = check_X_domain(
+ X, sample_domain = check_X_domain(
X,
sample_domain,
allow_multi_source=True,
allow_multi_target=True,
- return_joint=False,
)
+ X_source, X_target = source_target_split(X, sample_domain=sample_domain)

if self.n_components is None:
n_components = min(min(X_source.shape), min(X_target.shape))
else:
@@ -245,13 +248,13 @@ def fit(self, X, y=None, sample_domain=None, **kwargs):
self : object
Returns self.
"""
- self.X_source_, self.X_target_ = check_X_domain(
+ X, sample_domain = check_X_domain(
X,
sample_domain,
allow_multi_source=True,
allow_multi_target=True,
- return_joint=False,
)
+ self.X_source_, self.X_target_ = source_target_split(X, sample_domain=sample_domain)

Check failure on line 257 in skada/_subspace.py (GitHub Actions / Flake8-Codestyle-Check): line too long (92 > 88 characters)

Kss = pairwise_kernels(self.X_source_, metric=self.kernel)
Ktt = pairwise_kernels(self.X_target_, metric=self.kernel)
@@ -306,13 +309,14 @@ def adapt(self, X, y=None, sample_domain=None, **kwargs):
weights : None
No weights are returned here.
"""
- X_source, X_target = check_X_domain(
+ X, sample_domain = check_X_domain(
X,
sample_domain,
allow_multi_source=True,
allow_multi_target=True,
- return_joint=False,
)
+ X_source, X_target = source_target_split(X, sample_domain=sample_domain)

if np.array_equal(X_source, self.X_source_) and np.array_equal(
X_target, self.X_target_
):
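Editor's note: the Flake8 failure flagged above comes from the new 92-character assignment in the `fit()` hunk of skada/_subspace.py. One possible wrap that stays within the 88-character limit, offered only as a sketch (the class here is a hypothetical stand-in, and this is not the fix that was eventually committed):

```python
import numpy as np
from skada.utils import source_target_split

class _FitSketch:
    """Hypothetical stand-in showing one way to keep the flagged line short."""

    def fit(self, X, sample_domain=None):
        # Wrapping the call keeps the assignment under the Flake8 limit.
        self.X_source_, self.X_target_ = source_target_split(
            X, sample_domain=sample_domain
        )
        return self

_FitSketch().fit(np.random.randn(6, 2), np.array([1, 1, 1, 1, -2, -2]))
```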
