[MRG] Make sure all API methods accept sample_domain as None #53

Merged on Jan 26, 2024 (21 commits)
Commits
67841e9
Change default value for allow_auto_sample_domain + check if the mask…
YanisLalou Jan 12, 2024
d178f6c
Fix + Test _check_y_masking function
YanisLalou Jan 12, 2024
f1c1c5d
Add raise Exception to check_X_domain + Test
YanisLalou Jan 12, 2024
07b8846
Merge branch 'main' into Issue_17_branch
kachayev Jan 12, 2024
728a706
Flake8 warnings
YanisLalou Jan 12, 2024
c9a8a94
Merge branch 'main' into Issue_17_branch
YanisLalou Jan 15, 2024
1d88ccf
Docstring + Tests for exception handling
YanisLalou Jan 15, 2024
9e8042c
variable name change
YanisLalou Jan 16, 2024
f8ec8e9
Add Global variables
YanisLalou Jan 16, 2024
b124290
Split check_X_Y_domain and check_X_domain into more functions
YanisLalou Jan 18, 2024
68de2e6
fix typo error
YanisLalou Jan 18, 2024
f633160
Add test for extract_source_indices(), split_source_target_X_y(), spl…
YanisLalou Jan 18, 2024
c4f5627
Merge branch 'main' into Issue_17_branch
kachayev Jan 19, 2024
1d0b0ec
Rename test name
YanisLalou Jan 24, 2024
85ec64c
Merge branch 'main' into Issue_17_branch
kachayev Jan 24, 2024
06f2390
Change to the source_target_split function --> can now accept *ar…
YanisLalou Jan 25, 2024
8d22250
Fix plot_shifted_dataset.py
YanisLalou Jan 25, 2024
900d670
Merge branch 'main' into Issue_17_branch
kachayev Jan 25, 2024
4d3e192
fix plot_dataset_from_moons_distribution.py
YanisLalou Jan 25, 2024
8f3d286
remove unwanted changes to metrics.py
YanisLalou Jan 25, 2024
2f428a8
remove old comments + flake8
YanisLalou Jan 25, 2024
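
For reference, the calling convention that the updated examples below rely on looks roughly like the following sketch. It assumes source_target_split(*arrays, sample_domain=...) splits every positional array into its source and target parts (source part first), that the function is importable from skada.utils as introduced by this PR, and that positive domain labels mark source samples while negative labels mark targets (skada's usual convention).

import numpy as np
from skada.utils import source_target_split

X = np.random.randn(6, 2)
y = np.array([0, 1, 0, 1, 0, 1])
# Assumption: positive labels mark source samples, negative labels mark targets.
sample_domain = np.array([1, 1, 1, -2, -2, -2])

# Each array passed positionally is split into its source and target slices, in order.
X_source, X_target, y_source, y_target = source_target_split(
    X, y, sample_domain=sample_domain
)
assert X_source.shape == (3, 2) and X_target.shape == (3, 2)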
8 changes: 5 additions & 3 deletions examples/datasets/plot_dataset_from_moons_distribution.py
@@ -28,7 +28,9 @@
random_state=RANDOM_SEED
)

X_source, y_source, X_target, y_target = source_target_split(X, y, sample_domain)
X_source, X_target, y_source, y_target = source_target_split(
X, y, sample_domain=sample_domain
)

fig, (ax1, ax2) = plt.subplots(1, 2, sharex="row", sharey="row", figsize=(8, 4))
fig.suptitle('One source and one target', fontsize=14)
@@ -75,8 +77,8 @@
random_state=RANDOM_SEED
)

X_source, y_source, domain_source, X_target, y_target, domain_target = (
source_target_split(X, y, sample_domain, return_domain=True)
X_source, X_target, y_source, y_target, domain_source, domain_target = (
source_target_split(X, y, sample_domain, sample_domain=sample_domain)
)
fig, (ax1, ax2) = plt.subplots(1, 2, sharex="row", sharey="row", figsize=(8, 4))
fig.suptitle('Multi-source and Multi-target', fontsize=14)
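
The second hunk above also needs the per-sample domain labels after the split. With return_domain dropped, the same effect comes from passing sample_domain both as an array to be split and as the grouping keyword. A short sketch, reusing X, y and sample_domain from the sketch above:

# sample_domain appears twice: once as data to split, once as the grouping key,
# so the call returns six arrays (source/target parts of X, y and sample_domain).
X_source, X_target, y_source, y_target, domain_source, domain_target = (
    source_target_split(X, y, sample_domain, sample_domain=sample_domain)
)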
4 changes: 3 additions & 1 deletion examples/datasets/plot_shifted_dataset.py
@@ -41,7 +41,9 @@ def plot_shifted_dataset(shift, random_state=42):
random_state=random_state,
)

X_source, y_source, X_target, y_target = source_target_split(X, y, sample_domain)
X_source, X_target, y_source, y_target = source_target_split(
X, y, sample_domain=sample_domain
)

fig, (ax1, ax2) = plt.subplots(1, 2, sharex="row", sharey="row", figsize=(8, 4))
fig.suptitle(shift.replace("_", " ").title(), fontsize=14)
4 changes: 3 additions & 1 deletion examples/datasets/plot_variable_frequency_dataset.py
@@ -33,7 +33,9 @@
random_state=RANDOM_SEED
)

X_source, y_source, X_target, y_target = source_target_split(X, y, sample_domain)
X_source, X_target, y_source, y_target = source_target_split(
X, y, sample_domain=sample_domain
)

# %% Visualize the signal

4 changes: 3 additions & 1 deletion examples/methods/plot_optimal_transport_da.py
@@ -44,7 +44,9 @@
)


X_source, y_source, X_target, y_target = source_target_split(X, y, sample_domain)
X_source, X_target, y_source, y_target = source_target_split(
X, y, sample_domain=sample_domain
)


n_tot_source = X_source.shape[0]
2 changes: 1 addition & 1 deletion examples/plot_method_comparison.py
@@ -117,7 +117,7 @@
# preprocess dataset, split into training and test part
X, y, sample_domain = ds

Xs, ys, Xt, yt = source_target_split(X, y, sample_domain=sample_domain)
Xs, Xt, ys, yt = source_target_split(X, y, sample_domain=sample_domain)

x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5
y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5
4 changes: 3 additions & 1 deletion examples/validation/plot_cross_val_score_for_da.py
@@ -39,7 +39,9 @@
)

X, y, sample_domain = dataset.pack_train(as_sources=['s'], as_targets=['t'])
X_source, y_source, X_target, y_target = source_target_split(X, y, sample_domain)
X_source, X_target, y_source, y_target = source_target_split(
X, y, sample_domain=sample_domain
)
cv = ShuffleSplit(n_splits=5, test_size=0.3, random_state=0)

# %%
2 changes: 1 addition & 1 deletion skada/__init__.py
@@ -39,7 +39,7 @@
TransferComponentAnalysis,
)
from ._pipeline import make_da_pipeline
from ._utils import source_target_split
from .utils import source_target_split


# make sure that the usage of the library is not possible
35 changes: 18 additions & 17 deletions skada/_mapping.py
@@ -10,9 +10,12 @@
from ot import da

from .base import BaseAdapter, clone
from ._utils import (
from .utils import (
check_X_domain,
check_X_y_domain,
source_target_split
)
from ._utils import (
_estimate_covariance,
_merge_source_target,
)
@@ -45,13 +48,9 @@ def fit(self, X, y=None, sample_domain=None):
self : object
Returns self.
"""
X, y, X_target, y_target = check_X_y_domain(
X,
y,
sample_domain,
allow_multi_source=True,
allow_multi_target=True,
return_joint=False,
X, y, sample_domain = check_X_y_domain(X, y, sample_domain)
X, X_target, y, y_target = source_target_split(
X, y, sample_domain=sample_domain
)
transport = self._create_transport_estimator()
self.ot_transport_ = clone(transport)
@@ -80,13 +79,13 @@ def adapt(self, X, y=None, sample_domain=None):
The weights of the samples.
"""
# xxx(okachaiev): implement auto-infer for sample_domain
X_source, X_target = check_X_domain(
X, sample_domain = check_X_domain(
X,
sample_domain,
allow_multi_source=True,
allow_multi_target=True,
return_joint=False,
allow_multi_target=True
)
X_source, X_target = source_target_split(X, sample_domain=sample_domain)
# in case of prediction we would get only target samples here,
# thus there's no need to perform any transformations
if X_source.shape[0] > 0:
@@ -604,13 +603,14 @@ def fit(self, X, y=None, sample_domain=None):
self : object
Returns self.
"""
X_source, X_target = check_X_domain(
X, sample_domain = check_X_domain(
X,
sample_domain,
allow_multi_source=True,
allow_multi_target=True,
return_joint=False,
allow_multi_target=True
)
X_source, X_target = source_target_split(X, sample_domain=sample_domain)

cov_source_ = _estimate_covariance(X_source, shrinkage=self.reg)
cov_target_ = _estimate_covariance(X_target, shrinkage=self.reg)
self.cov_source_inv_sqrt_ = _invsqrtm(cov_source_)
@@ -638,13 +638,14 @@ def adapt(self, X, y=None, sample_domain=None):
weights : None
No weights are returned here.
"""
X_source, X_target = check_X_domain(
X, sample_domain = check_X_domain(
X,
sample_domain,
allow_multi_source=True,
allow_multi_target=True,
return_joint=False,
allow_multi_target=True
)
X_source, X_target = source_target_split(X, sample_domain=sample_domain)

X_source_adapt = np.dot(X_source, self.cov_source_inv_sqrt_)
X_source_adapt = np.dot(X_source_adapt, self.cov_target_sqrt_)
X_adapt = _merge_source_target(X_source_adapt, X_target, sample_domain)
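
Across the adapters, the PR replaces the old return_joint=False / return_indices=True outputs of the check_* helpers with a uniform two-step pattern: validate first, then split. A minimal sketch of that pattern, assuming check_X_y_domain returns the validated X, y and sample_domain (auto-inferring the domain labels when None is passed):

from skada.utils import check_X_y_domain, source_target_split

def fit_pattern(X, y=None, sample_domain=None):
    # Step 1: validate the inputs and recover a usable sample_domain.
    X, y, sample_domain = check_X_y_domain(X, y, sample_domain)
    # Step 2: split into source/target views only where the estimator needs them.
    X_source, X_target, y_source, y_target = source_target_split(
        X, y, sample_domain=sample_domain
    )
    # ... fit on (X_source, y_source), adapt towards X_target ...
    return X_source, X_target, y_source, y_target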
59 changes: 34 additions & 25 deletions skada/_reweight.py
@@ -16,7 +16,8 @@
from sklearn.utils.validation import check_is_fitted

from .base import AdaptationOutput, BaseAdapter, clone
from ._utils import _estimate_covariance, check_X_domain
from .utils import check_X_domain, source_target_split, extract_source_indices
from ._utils import _estimate_covariance
from ._pipeline import make_da_pipeline


@@ -62,11 +63,12 @@ def fit(self, X, y=None, sample_domain=None):
Returns self.
"""
# xxx(okachaiev): that's the reason we need a way to cache this call
X_source, X_target = check_X_domain(
X, sample_domain = check_X_domain(
X,
sample_domain,
return_joint=False,
sample_domain
)
X_source, X_target = source_target_split(X, sample_domain=sample_domain)

self.weight_estimator_source_ = clone(self.weight_estimator)
self.weight_estimator_target_ = clone(self.weight_estimator)
self.weight_estimator_source_.fit(X_source)
@@ -94,11 +96,12 @@ def adapt(self, X, y=None, sample_domain=None):
weights : array-like, shape (n_samples,)
The weights of the samples.
"""
source_idx = check_X_domain(
X, sample_domain = check_X_domain(
X,
sample_domain,
return_indices=True,
sample_domain
)
source_idx = extract_source_indices(sample_domain)

# xxx(okachaiev): move this to API
if source_idx.sum() > 0:
source_idx, = np.where(source_idx)
@@ -197,11 +200,12 @@ def fit(self, X, y=None, sample_domain=None):
self : object
Returns self.
"""
X_source, X_target = check_X_domain(
X, sample_domain = check_X_domain(
X,
sample_domain,
return_joint=False
sample_domain
)
X_source, X_target = source_target_split(X, sample_domain=sample_domain)

self.mean_source_ = X_source.mean(axis=0)
self.cov_source_ = _estimate_covariance(X_source, shrinkage=self.reg)
self.mean_target_ = X_target.mean(axis=0)
@@ -230,11 +234,12 @@ def adapt(self, X, y=None, sample_domain=None):
The weights of the samples.
"""
check_is_fitted(self)
source_idx = check_X_domain(
X, sample_domain = check_X_domain(
X,
sample_domain,
return_indices=True,
sample_domain
)
source_idx = extract_source_indices(sample_domain)

# xxx(okachaiev): move this to API
if source_idx.sum() > 0:
source_idx, = np.where(source_idx)
@@ -337,11 +342,12 @@ def fit(self, X, y=None, sample_domain=None):
self : object
Returns self.
"""
source_idx = check_X_domain(
X, sample_domain = check_X_domain(
X,
sample_domain,
return_indices=True
sample_domain
)
source_idx = extract_source_indices(sample_domain)

source_idx, = np.where(source_idx)
self.domain_classifier_ = clone(self.domain_classifier)
y_domain = np.ones(X.shape[0], dtype=np.int32)
@@ -371,11 +377,12 @@ def adapt(self, X, y=None, sample_domain=None, **kwargs):
The weights of the samples.
"""
check_is_fitted(self)
source_idx = check_X_domain(
X, sample_domain = check_X_domain(
X,
sample_domain,
return_indices=True,
sample_domain
)
source_idx = extract_source_indices(sample_domain)

# xxx(okachaiev): move this to API
if source_idx.sum() > 0:
source_idx, = np.where(source_idx)
@@ -505,13 +512,14 @@ def fit(self, X, y=None, sample_domain=None, **kwargs):
self : object
Returns self.
"""
X_source, X_target = check_X_domain(
X, sample_domain = check_X_domain(
X,
sample_domain,
allow_multi_source=True,
allow_multi_target=True,
return_joint=False,
allow_multi_target=True
)
X_source, X_target = source_target_split(X, sample_domain=sample_domain)

if isinstance(self.gamma, list):
self.best_gamma_ = self._likelihood_cross_validation(
self.gamma, X_source, X_target
@@ -599,11 +607,12 @@ def adapt(self, X, y=None, sample_domain=None, **kwargs):
weights : array-like, shape (n_samples,)
The weights of the samples.
"""
source_idx = check_X_domain(
X, sample_domain = check_X_domain(
X,
sample_domain,
return_indices=True,
sample_domain
)
source_idx = extract_source_indices(sample_domain)

if source_idx.sum() > 0:
source_idx, = np.where(source_idx)
A = pairwise_kernels(
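
The reweighting adapters follow the same idea but need a source mask rather than a full split. Where they previously relied on return_indices=True, they now call extract_source_indices. A sketch, assuming extract_source_indices(sample_domain) returns a boolean mask that is True for source samples:

import numpy as np
from skada.utils import check_X_domain, extract_source_indices

def adapt_pattern(X, sample_domain=None):
    X, sample_domain = check_X_domain(X, sample_domain)
    source_mask = extract_source_indices(sample_domain)  # True for source samples
    weights = np.ones(X.shape[0])
    if source_mask.sum() > 0:
        source_idx, = np.where(source_mask)  # positional indices of source samples
        # ... estimate importance weights for X[source_idx] here ...
    return weights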
22 changes: 13 additions & 9 deletions skada/_subspace.py
@@ -12,7 +12,8 @@
from sklearn.svm import SVC

from .base import BaseAdapter
from ._utils import check_X_domain, _merge_source_target
from .utils import check_X_domain, source_target_split
from ._utils import _merge_source_target
from ._pipeline import make_da_pipeline


@@ -79,13 +80,14 @@ def adapt(self, X, y=None, sample_domain=None, **kwargs):
weights : None
No weights are returned here.
"""
X_source, X_target = check_X_domain(
X, sample_domain = check_X_domain(
X,
sample_domain,
allow_multi_source=True,
allow_multi_target=True,
return_joint=False,
)
X_source, X_target = source_target_split(X, sample_domain=sample_domain)

if X_source.shape[0]:
X_source = np.dot(self.pca_source_.transform(X_source), self.M_)
if X_target.shape[0]:
@@ ... @@
self : object
Returns self.
"""
X_source, X_target = check_X_domain(
X, sample_domain = check_X_domain(
X,
sample_domain,
allow_multi_source=True,
allow_multi_target=True,
return_joint=False,
)
X_source, X_target = source_target_split(X, sample_domain=sample_domain)

if self.n_components is None:
n_components = min(min(X_source.shape), min(X_target.shape))
else:
@@ -245,13 +248,13 @@ def fit(self, X, y=None, sample_domain=None, **kwargs):
self : object
Returns self.
"""
self.X_source_, self.X_target_ = check_X_domain(
X, sample_domain = check_X_domain(
X,
sample_domain,
allow_multi_source=True,
allow_multi_target=True,
return_joint=False,
)
self.X_source_, self.X_target_ = source_target_split(X, sample_domain=sample_domain)

Kss = pairwise_kernels(self.X_source_, metric=self.kernel)
Ktt = pairwise_kernels(self.X_target_, metric=self.kernel)
@@ -306,13 +309,14 @@ def adapt(self, X, y=None, sample_domain=None, **kwargs):
weights : None
No weights are returned here.
"""
X_source, X_target = check_X_domain(
X, sample_domain = check_X_domain(
X,
sample_domain,
allow_multi_source=True,
allow_multi_target=True,
return_joint=False,
)
X_source, X_target = source_target_split(X, sample_domain=sample_domain)

if np.array_equal(X_source, self.X_source_) and np.array_equal(
X_target, self.X_target_
):