diff --git a/examples/datasets/plot_dataset_from_moons_distribution.py b/examples/datasets/plot_dataset_from_moons_distribution.py index ea1f64f4..f8f66388 100644 --- a/examples/datasets/plot_dataset_from_moons_distribution.py +++ b/examples/datasets/plot_dataset_from_moons_distribution.py @@ -28,7 +28,9 @@ random_state=RANDOM_SEED ) -X_source, y_source, X_target, y_target = source_target_split(X, y, sample_domain) +X_source, X_target, y_source, y_target = source_target_split( + X, y, sample_domain=sample_domain +) fig, (ax1, ax2) = plt.subplots(1, 2, sharex="row", sharey="row", figsize=(8, 4)) fig.suptitle('One source and one target', fontsize=14) @@ -75,8 +77,8 @@ random_state=RANDOM_SEED ) -X_source, y_source, domain_source, X_target, y_target, domain_target = ( - source_target_split(X, y, sample_domain, return_domain=True) +X_source, X_target, y_source, y_target, domain_source, domain_target = ( + source_target_split(X, y, sample_domain, sample_domain=sample_domain) ) fig, (ax1, ax2) = plt.subplots(1, 2, sharex="row", sharey="row", figsize=(8, 4)) fig.suptitle('Multi-source and Multi-target', fontsize=14) diff --git a/examples/datasets/plot_shifted_dataset.py b/examples/datasets/plot_shifted_dataset.py index 6db96ab9..e35b9f25 100644 --- a/examples/datasets/plot_shifted_dataset.py +++ b/examples/datasets/plot_shifted_dataset.py @@ -41,7 +41,9 @@ def plot_shifted_dataset(shift, random_state=42): random_state=random_state, ) - X_source, y_source, X_target, y_target = source_target_split(X, y, sample_domain) + X_source, X_target, y_source, y_target = source_target_split( + X, y, sample_domain=sample_domain + ) fig, (ax1, ax2) = plt.subplots(1, 2, sharex="row", sharey="row", figsize=(8, 4)) fig.suptitle(shift.replace("_", " ").title(), fontsize=14) diff --git a/examples/datasets/plot_variable_frequency_dataset.py b/examples/datasets/plot_variable_frequency_dataset.py index 1b03a17c..61aedefc 100644 --- a/examples/datasets/plot_variable_frequency_dataset.py +++ b/examples/datasets/plot_variable_frequency_dataset.py @@ -33,7 +33,9 @@ random_state=RANDOM_SEED ) -X_source, y_source, X_target, y_target = source_target_split(X, y, sample_domain) +X_source, X_target, y_source, y_target = source_target_split( + X, y, sample_domain=sample_domain +) # %% Visualize the signal diff --git a/examples/methods/plot_optimal_transport_da.py b/examples/methods/plot_optimal_transport_da.py index 5e26e6c5..525b2e06 100644 --- a/examples/methods/plot_optimal_transport_da.py +++ b/examples/methods/plot_optimal_transport_da.py @@ -44,7 +44,9 @@ ) -X_source, y_source, X_target, y_target = source_target_split(X, y, sample_domain) +X_source, X_target, y_source, y_target = source_target_split( + X, y, sample_domain=sample_domain +) n_tot_source = X_source.shape[0] diff --git a/examples/plot_method_comparison.py b/examples/plot_method_comparison.py index 6d00af17..8f346eea 100644 --- a/examples/plot_method_comparison.py +++ b/examples/plot_method_comparison.py @@ -117,7 +117,7 @@ # preprocess dataset, split into training and test part X, y, sample_domain = ds - Xs, ys, Xt, yt = source_target_split(X, y, sample_domain=sample_domain) + Xs, Xt, ys, yt = source_target_split(X, y, sample_domain=sample_domain) x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5 y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5 diff --git a/examples/validation/plot_cross_val_score_for_da.py b/examples/validation/plot_cross_val_score_for_da.py index db79db63..ba108378 100644 --- a/examples/validation/plot_cross_val_score_for_da.py +++ b/examples/validation/plot_cross_val_score_for_da.py @@ -39,7 +39,9 @@ ) X, y, sample_domain = dataset.pack_train(as_sources=['s'], as_targets=['t']) -X_source, y_source, X_target, y_target = source_target_split(X, y, sample_domain) +X_source, X_target, y_source, y_target = source_target_split( + X, y, sample_domain=sample_domain +) cv = ShuffleSplit(n_splits=5, test_size=0.3, random_state=0) # %% diff --git a/skada/__init__.py b/skada/__init__.py index 02ce0fc6..252556a9 100644 --- a/skada/__init__.py +++ b/skada/__init__.py @@ -39,7 +39,7 @@ TransferComponentAnalysis, ) from ._pipeline import make_da_pipeline -from ._utils import source_target_split +from .utils import source_target_split # make sure that the usage of the library is not possible diff --git a/skada/_mapping.py b/skada/_mapping.py index 046b78f3..5ea6a0ca 100644 --- a/skada/_mapping.py +++ b/skada/_mapping.py @@ -10,9 +10,12 @@ from ot import da from .base import BaseAdapter, clone -from ._utils import ( +from .utils import ( check_X_domain, check_X_y_domain, + source_target_split +) +from ._utils import ( _estimate_covariance, _merge_source_target, ) @@ -45,13 +48,9 @@ def fit(self, X, y=None, sample_domain=None): self : object Returns self. """ - X, y, X_target, y_target = check_X_y_domain( - X, - y, - sample_domain, - allow_multi_source=True, - allow_multi_target=True, - return_joint=False, + X, y, sample_domain = check_X_y_domain(X, y, sample_domain) + X, X_target, y, y_target = source_target_split( + X, y, sample_domain=sample_domain ) transport = self._create_transport_estimator() self.ot_transport_ = clone(transport) @@ -80,13 +79,13 @@ def adapt(self, X, y=None, sample_domain=None): The weights of the samples. """ # xxx(okachaiev): implement auto-infer for sample_domain - X_source, X_target = check_X_domain( + X, sample_domain = check_X_domain( X, sample_domain, allow_multi_source=True, - allow_multi_target=True, - return_joint=False, + allow_multi_target=True ) + X_source, X_target = source_target_split(X, sample_domain=sample_domain) # in case of prediction we would get only target samples here, # thus there's no need to perform any transformations if X_source.shape[0] > 0: @@ -604,13 +603,14 @@ def fit(self, X, y=None, sample_domain=None): self : object Returns self. """ - X_source, X_target = check_X_domain( + X, sample_domain = check_X_domain( X, sample_domain, allow_multi_source=True, - allow_multi_target=True, - return_joint=False, + allow_multi_target=True ) + X_source, X_target = source_target_split(X, sample_domain=sample_domain) + cov_source_ = _estimate_covariance(X_source, shrinkage=self.reg) cov_target_ = _estimate_covariance(X_target, shrinkage=self.reg) self.cov_source_inv_sqrt_ = _invsqrtm(cov_source_) @@ -638,13 +638,14 @@ def adapt(self, X, y=None, sample_domain=None): weights : None No weights are returned here. """ - X_source, X_target = check_X_domain( + X, sample_domain = check_X_domain( X, sample_domain, allow_multi_source=True, - allow_multi_target=True, - return_joint=False, + allow_multi_target=True ) + X_source, X_target = source_target_split(X, sample_domain=sample_domain) + X_source_adapt = np.dot(X_source, self.cov_source_inv_sqrt_) X_source_adapt = np.dot(X_source_adapt, self.cov_target_sqrt_) X_adapt = _merge_source_target(X_source_adapt, X_target, sample_domain) diff --git a/skada/_reweight.py b/skada/_reweight.py index a74f74c2..0fe37fa0 100644 --- a/skada/_reweight.py +++ b/skada/_reweight.py @@ -16,7 +16,8 @@ from sklearn.utils.validation import check_is_fitted from .base import AdaptationOutput, BaseAdapter, clone -from ._utils import _estimate_covariance, check_X_domain +from .utils import check_X_domain, source_target_split, extract_source_indices +from ._utils import _estimate_covariance from ._pipeline import make_da_pipeline @@ -62,11 +63,12 @@ def fit(self, X, y=None, sample_domain=None): Returns self. """ # xxx(okachaiev): that's the reason we need a way to cache this call - X_source, X_target = check_X_domain( + X, sample_domain = check_X_domain( X, - sample_domain, - return_joint=False, + sample_domain ) + X_source, X_target = source_target_split(X, sample_domain=sample_domain) + self.weight_estimator_source_ = clone(self.weight_estimator) self.weight_estimator_target_ = clone(self.weight_estimator) self.weight_estimator_source_.fit(X_source) @@ -94,11 +96,12 @@ def adapt(self, X, y=None, sample_domain=None): weights : array-like, shape (n_samples,) The weights of the samples. """ - source_idx = check_X_domain( + X, sample_domain = check_X_domain( X, - sample_domain, - return_indices=True, + sample_domain ) + source_idx = extract_source_indices(sample_domain) + # xxx(okachaiev): move this to API if source_idx.sum() > 0: source_idx, = np.where(source_idx) @@ -197,11 +200,12 @@ def fit(self, X, y=None, sample_domain=None): self : object Returns self. """ - X_source, X_target = check_X_domain( + X, sample_domain = check_X_domain( X, - sample_domain, - return_joint=False + sample_domain ) + X_source, X_target = source_target_split(X, sample_domain=sample_domain) + self.mean_source_ = X_source.mean(axis=0) self.cov_source_ = _estimate_covariance(X_source, shrinkage=self.reg) self.mean_target_ = X_target.mean(axis=0) @@ -230,11 +234,12 @@ def adapt(self, X, y=None, sample_domain=None): The weights of the samples. """ check_is_fitted(self) - source_idx = check_X_domain( + X, sample_domain = check_X_domain( X, - sample_domain, - return_indices=True, + sample_domain ) + source_idx = extract_source_indices(sample_domain) + # xxx(okachaiev): move this to API if source_idx.sum() > 0: source_idx, = np.where(source_idx) @@ -337,11 +342,12 @@ def fit(self, X, y=None, sample_domain=None): self : object Returns self. """ - source_idx = check_X_domain( + X, sample_domain = check_X_domain( X, - sample_domain, - return_indices=True + sample_domain ) + source_idx = extract_source_indices(sample_domain) + source_idx, = np.where(source_idx) self.domain_classifier_ = clone(self.domain_classifier) y_domain = np.ones(X.shape[0], dtype=np.int32) @@ -371,11 +377,12 @@ def adapt(self, X, y=None, sample_domain=None, **kwargs): The weights of the samples. """ check_is_fitted(self) - source_idx = check_X_domain( + X, sample_domain = check_X_domain( X, - sample_domain, - return_indices=True, + sample_domain ) + source_idx = extract_source_indices(sample_domain) + # xxx(okachaiev): move this to API if source_idx.sum() > 0: source_idx, = np.where(source_idx) @@ -505,13 +512,14 @@ def fit(self, X, y=None, sample_domain=None, **kwargs): self : object Returns self. """ - X_source, X_target = check_X_domain( + X, sample_domain = check_X_domain( X, sample_domain, allow_multi_source=True, - allow_multi_target=True, - return_joint=False, + allow_multi_target=True ) + X_source, X_target = source_target_split(X, sample_domain=sample_domain) + if isinstance(self.gamma, list): self.best_gamma_ = self._likelihood_cross_validation( self.gamma, X_source, X_target @@ -599,11 +607,12 @@ def adapt(self, X, y=None, sample_domain=None, **kwargs): weights : array-like, shape (n_samples,) The weights of the samples. """ - source_idx = check_X_domain( + X, sample_domain = check_X_domain( X, - sample_domain, - return_indices=True, + sample_domain ) + source_idx = extract_source_indices(sample_domain) + if source_idx.sum() > 0: source_idx, = np.where(source_idx) A = pairwise_kernels( diff --git a/skada/_subspace.py b/skada/_subspace.py index 71a7155c..1e031e60 100644 --- a/skada/_subspace.py +++ b/skada/_subspace.py @@ -12,7 +12,8 @@ from sklearn.svm import SVC from .base import BaseAdapter -from ._utils import check_X_domain, _merge_source_target +from .utils import check_X_domain, source_target_split +from ._utils import _merge_source_target from ._pipeline import make_da_pipeline @@ -79,13 +80,14 @@ def adapt(self, X, y=None, sample_domain=None, **kwargs): weights : None No weights are returned here. """ - X_source, X_target = check_X_domain( + X, sample_domain = check_X_domain( X, sample_domain, allow_multi_source=True, allow_multi_target=True, - return_joint=False, ) + X_source, X_target = source_target_split(X, sample_domain=sample_domain) + if X_source.shape[0]: X_source = np.dot(self.pca_source_.transform(X_source), self.M_) if X_target.shape[0]: @@ -111,13 +113,14 @@ def fit(self, X, y=None, sample_domain=None, **kwargs): self : object Returns self. """ - X_source, X_target = check_X_domain( + X, sample_domain = check_X_domain( X, sample_domain, allow_multi_source=True, allow_multi_target=True, - return_joint=False, ) + X_source, X_target = source_target_split(X, sample_domain=sample_domain) + if self.n_components is None: n_components = min(min(X_source.shape), min(X_target.shape)) else: @@ -245,13 +248,13 @@ def fit(self, X, y=None, sample_domain=None, **kwargs): self : object Returns self. """ - self.X_source_, self.X_target_ = check_X_domain( + X, sample_domain = check_X_domain( X, sample_domain, allow_multi_source=True, allow_multi_target=True, - return_joint=False, ) + self.X_source_, self.X_target_ = source_target_split(X, sample_domain=sample_domain) Kss = pairwise_kernels(self.X_source_, metric=self.kernel) Ktt = pairwise_kernels(self.X_target_, metric=self.kernel) @@ -306,13 +309,14 @@ def adapt(self, X, y=None, sample_domain=None, **kwargs): weights : None No weights are returned here. """ - X_source, X_target = check_X_domain( + X, sample_domain = check_X_domain( X, sample_domain, allow_multi_source=True, allow_multi_target=True, - return_joint=False, ) + X_source, X_target = source_target_split(X, sample_domain=sample_domain) + if np.array_equal(X_source, self.X_source_) and np.array_equal( X_target, self.X_target_ ): diff --git a/skada/_utils.py b/skada/_utils.py index 8af294c1..9f2863e1 100644 --- a/skada/_utils.py +++ b/skada/_utils.py @@ -6,7 +6,6 @@ import logging from numbers import Real -from typing import Optional, Set import numpy as np @@ -16,12 +15,22 @@ ledoit_wolf, shrunk_covariance, ) -from sklearn.utils import check_array, check_consistent_length - +from sklearn.utils.multiclass import type_of_target _logger = logging.getLogger('skada') _logger.setLevel(logging.DEBUG) +# Default label for datasets with source and target domains +_DEFAULT_TARGET_DOMAIN_LABEL = -2 +_DEFAULT_SOURCE_DOMAIN_LABEL = 1 + +# Default label for datasets without source domain +_DEFAULT_TARGET_DOMAIN_ONLY_LABEL = -1 + +# Default label for datasets with masked target labels +_DEFAULT_MASKED_TARGET_CLASSIFICATION_LABEL = -1 +_DEFAULT_MASKED_TARGET_REGRESSION_LABEL = np.nan + def _estimate_covariance(X, shrinkage): if shrinkage is None: @@ -37,131 +46,6 @@ def _estimate_covariance(X, shrinkage): return s -def check_X_y_domain( - X, - y, - sample_domain, - allow_source: bool = True, - allow_multi_source: bool = True, - allow_target: bool = True, - allow_multi_target: bool = True, - return_indices: bool = False, - # xxx(okachaiev): most likely this needs to be removed as it doesn't fit new API - return_joint: bool = False, - allow_auto_sample_domain: bool = False, - allow_nd: bool = False, -): - """Input validation for DA estimator. - If we work in single-source and single target mode, return source and target - separately to avoid additional scan for 'sample_domain' array. - """ - - X = check_array(X, input_name='X', allow_nd=allow_nd) - y = check_array(y, force_all_finite=True, ensure_2d=False, input_name='y') - check_consistent_length(X, y) - if sample_domain is None and allow_auto_sample_domain: - sample_domain = np.ones_like(y) - # labels masked with -1 are recognized as targets, - # the rest is treated as a source - sample_domain[y == -1] = -2 - else: - sample_domain = check_array( - sample_domain, - dtype=np.int32, - ensure_2d=False, - input_name='sample_domain' - ) - check_consistent_length(X, sample_domain) - - source_idx = (sample_domain >= 0) - # xxx(okachaiev): this needs to be re-written to accommodate for a - # a new domain labeling convention without "intersections" - n_sources = np.unique(sample_domain[source_idx]).shape[0] - n_targets = np.unique(sample_domain[~source_idx]).shape[0] - if not allow_source and n_sources > 0: - raise ValueError(f"Number of sources provided is {n_sources} " - "and 'allow_source' is set to False") - if not allow_target and n_targets > 0: - raise ValueError(f"Number of targets provided is {n_targets} " - "and 'allow_target' is set to False") - if not allow_multi_source and n_sources > 1: - raise ValueError(f"Number of sources provided is {n_sources} " - "and 'allow_multi_source' is set to False") - if not allow_multi_target and n_sources > 1: - raise ValueError(f"Number of targets provided is {n_targets} " - "and 'allow_multi_target' is set to False") - - if return_indices: - # only source indices are given, target indices are ~source_idx - return source_idx - elif not return_joint: - # commonly used X, y, X_target, y_target format - return X[source_idx], y[source_idx], X[~source_idx], y[~source_idx] - else: - return X, y, sample_domain - - -# xxx(okachaiev): code duplication, just for testing -def check_X_domain( - X, - sample_domain, - *, - allow_domains: Optional[Set[int]] = None, - allow_source: bool = True, - allow_multi_source: bool = True, - allow_target: bool = True, - allow_multi_target: bool = True, - return_indices: bool = False, - # xxx(okachaiev): most likely this needs to be removed as it doesn't fit new API - return_joint: bool = True, - allow_auto_sample_domain: bool = False, -): - X = check_array(X, input_name='X') - if sample_domain is None and allow_auto_sample_domain: - # default target domain when sample_domain is not given - # xxx(okachaiev): I guess this should be -inf instead of a number - sample_domain = -2*np.ones(X.shape[0], dtype=np.int32) - else: - sample_domain = check_array( - sample_domain, - dtype=np.int32, - ensure_2d=False, - input_name='sample_domain' - ) - check_consistent_length(X, sample_domain) - if allow_domains is not None: - for domain in np.unique(sample_domain): - # xxx(okachaiev): re-definition of the wildcards - wildcard = np.inf if domain >= 0 else -np.inf - if domain not in allow_domains and wildcard not in allow_domains: - raise ValueError(f"Unknown domain label '{domain}' given") - - source_idx = (sample_domain >= 0) - n_sources = np.unique(sample_domain[source_idx]).shape[0] - n_targets = np.unique(sample_domain[~source_idx]).shape[0] - if not allow_source and n_sources > 0: - raise ValueError(f"Number of sources provided is {n_sources} " - "and 'allow_source' is set to False") - if not allow_target and n_targets > 0: - raise ValueError(f"Number of targets provided is {n_targets} " - "and 'allow_target' is set to False") - if not allow_multi_source and n_sources > 1: - raise ValueError(f"Number of sources provided is {n_sources} " - "and 'allow_multi_source' is set to False") - if not allow_multi_target and n_sources > 1: - raise ValueError(f"Number of targets provided is {n_targets} " - "and 'allow_multi_target' is set to False") - - if return_indices: - # only source indices are given, target indices are ~source_idx - return source_idx - elif not return_joint: - # commonly used X, y, X_target, y_target format - return X[source_idx], X[~source_idx] - else: - return X, sample_domain - - def _merge_source_target(X_source, X_target, sample_domain) -> np.ndarray: n_samples = X_source.shape[0] + X_target.shape[0] assert n_samples > 0 @@ -174,74 +58,40 @@ def _merge_source_target(X_source, X_target, sample_domain) -> np.ndarray: return output -def source_target_split(X, y, sample_domain=None, - sample_weight=None, return_domain=False): - r""" Split data into source and target domains +def _check_y_masking(y): + """Check that labels are properly masked + ie. labels are either -1 or >= 0 + Parameters ---------- - X : array-like of shape (n_samples, n_features) - Data to be split y : array-like of shape (n_samples,) Labels for the data - sample_domain : array-like of shape (n_samples,) - Domain labels for the data. Positive values are treated as source - domains, negative values are treated as target domains. If not given, - all samples are treated as source domains except those with y==-1. - sample_weight : array-like of shape (n_samples,), default=None - Sample weights - return_domain : bool, default=False - Whether to return domain labels - - Returns - ------- - X_s : array-like of shape (n_samples_s, n_features) - Source data - y_s : array-like of shape (n_samples_s,) - Source labels - domain_s : array-like of shape (n_samples_s,) - Source domain labels (returned only if `return_domain` is True) - sample_weight_s : array-like of shape (n_samples_s,), default=None - Source sample weights (returned only if `sample_weight` is not None) - X_t : array-like of shape (n_samples_t, n_features) - Target data - y_t : array-like of shape (n_samples_t,) - Target labels - domain_t : array-like of shape (n_samples_t,) - Target domain labels (returned only if `return_domain` is True) - sample_weight_t : array-like of shape (n_samples_t,), - Target sample weights (returned only if `sample_weight` is not None) - """ - if sample_domain is None: - sample_domain = np.ones_like(y) - # labels masked with -1 are recognized as targets, - # the rest is treated as a source - sample_domain[y == -1] = -1 - - X_s = X[sample_domain >= 0] - y_s = y[sample_domain >= 0] - domain_s = sample_domain[sample_domain >= 0] - - X_t = X[sample_domain < 0] - y_t = y[sample_domain < 0] - domain_t = sample_domain[sample_domain < 0] - - if sample_weight is not None: - sample_weight_s = sample_weight[sample_domain >= 0] - sample_weight_t = sample_weight[sample_domain < 0] - - if return_domain: - return ( - X_s, y_s, domain_s, sample_weight_s, - X_t, y_t, domain_t, sample_weight_t - ) + # We need to check for this case first because + # type_of_target() doesnt handle nan values + if np.any(np.isnan(y)): + if y.ndim != 1: + raise ValueError("For a regression task, " + "more than 1D labels are not supported") else: - return X_s, y_s, sample_weight_s, X_t, y_t, sample_weight_t - else: - - if return_domain: - return X_s, y_s, domain_s, X_t, y_t, domain_t + return 'continuous' + + # Check if the target is a classification or regression target. + y_type = type_of_target(y) + + if y_type == 'continuous': + raise ValueError("For a regression task, " + "masked labels should be, " + f"{_DEFAULT_MASKED_TARGET_REGRESSION_LABEL}") + elif y_type == 'binary' or y_type == 'multiclass': + if (np.any(y < _DEFAULT_MASKED_TARGET_CLASSIFICATION_LABEL) or + not np.any(y == _DEFAULT_MASKED_TARGET_CLASSIFICATION_LABEL)): + raise ValueError("For a classification task, " + "masked labels should be, " + f"{_DEFAULT_MASKED_TARGET_CLASSIFICATION_LABEL}") else: - return X_s, y_s, X_t, y_t + return 'classification' + else: + raise ValueError("Uncompatible label type: %r" % y_type) diff --git a/skada/base.py b/skada/base.py index 11229536..aac6fa7f 100644 --- a/skada/base.py +++ b/skada/base.py @@ -17,7 +17,7 @@ # xxx(okachaiev): this should be `skada.utils.check_X_y_domain` # rather than `skada._utils.check_X_y_domain` -from ._utils import check_X_domain +from .utils import check_X_domain def _estimator_has(attr): diff --git a/skada/datasets/tests/test_samples_generator.py b/skada/datasets/tests/test_samples_generator.py index e0a36956..3282f721 100644 --- a/skada/datasets/tests/test_samples_generator.py +++ b/skada/datasets/tests/test_samples_generator.py @@ -15,7 +15,7 @@ make_shifted_datasets, make_variable_frequency_dataset ) -from skada._utils import check_X_y_domain +from skada.utils import check_X_y_domain, source_target_split def test_make_dataset_from_moons_distribution(): @@ -27,11 +27,9 @@ def test_make_dataset_from_moons_distribution(): random_state=0, return_X_y=True, ) - X_source, y_source, X_target, y_target = check_X_y_domain( - X, - y=y, - sample_domain=sample_domain, - return_joint=False, + X, y, sample_domain = check_X_y_domain(X, y, sample_domain) + X_source, X_target, y_source, y_target = source_target_split( + X, y, sample_domain=sample_domain ) assert X_source.shape == (2 * 50, 2), "X source shape mismatch" @@ -53,11 +51,9 @@ def test_make_dataset_from_multi_moons_distribution(): random_state=0, return_X_y=True, ) - X_source, y_source, X_target, y_target = check_X_y_domain( - X, - y=y, - sample_domain=sample_domain, - return_joint=False, + X, y, sample_domain = check_X_y_domain(X, y, sample_domain) + X_source, X_target, y_source, y_target = source_target_split( + X, y, sample_domain=sample_domain ) assert X_source.shape == (3 * 2 * 50, 2), "X source shape mismatch" @@ -89,11 +85,9 @@ def test_make_shifted_blobs(): cluster_std=cluster_stds, random_state=None, ) - X_source, y_source, X_target, y_target = check_X_y_domain( - X, - y=y, - sample_domain=sample_domain, - return_joint=False, + X, y, sample_domain = check_X_y_domain(X, y, sample_domain) + X_source, X_target, y_source, y_target = source_target_split( + X, y, sample_domain=sample_domain ) assert X_source.shape == (50, 2), "X source shape mismatch" @@ -117,11 +111,9 @@ def test_make_shifted_datasets(shift): noise=None, label="binary", ) - X_source, y_source, X_target, y_target = check_X_y_domain( - X, - y=y, - sample_domain=sample_domain, - return_joint=False, + X, y, sample_domain = check_X_y_domain(X, y, sample_domain) + X_source, X_target, y_source, y_target = source_target_split( + X, y, sample_domain=sample_domain ) assert X_source.shape == (10 * 8, 2), "X source shape mismatch" @@ -145,11 +137,9 @@ def test_make_multi_source_shifted_datasets(shift): noise=None, label="multiclass", ) - X_source, y_source, X_target, y_target = check_X_y_domain( - X, - y=y, - sample_domain=sample_domain, - return_joint=False, + X, y, sample_domain = check_X_y_domain(X, y, sample_domain) + X_source, X_target, y_source, y_target = source_target_split( + X, y, sample_domain=sample_domain ) assert X_source.shape == (10 * 8, 2), "X source shape mismatch" @@ -168,11 +158,9 @@ def test_make_subspace_datasets(): noise=None, label="binary", ) - X_source, y_source, X_target, y_target = check_X_y_domain( - X, - y=y, - sample_domain=sample_domain, - return_joint=False, + X, y, sample_domain = check_X_y_domain(X, y, sample_domain) + X_source, X_target, y_source, y_target = source_target_split( + X, y, sample_domain=sample_domain ) assert X_source.shape == (10 * 4, 2), "X source shape mismatch" @@ -194,12 +182,14 @@ def test_make_variable_frequency_dataset(): noise=None, random_state=None ) - X_source, y_source, X_target, y_target = check_X_y_domain( + X, y, sample_domain = check_X_y_domain( X, - y=y, - sample_domain=sample_domain, - return_joint=False, - allow_nd=True, + y, + sample_domain, + allow_nd=True + ) + X_source, X_target, y_source, y_target = source_target_split( + X, y, sample_domain=sample_domain ) assert X_source.shape == (3 * 10, 1, 3000), "X source shape mismatch" diff --git a/skada/metrics.py b/skada/metrics.py index 9ead547d..b185728b 100644 --- a/skada/metrics.py +++ b/skada/metrics.py @@ -17,7 +17,7 @@ from sklearn.utils.extmath import softmax from sklearn.utils.metadata_routing import _MetadataRequester, get_routing_for_object -from ._utils import check_X_y_domain +from .utils import check_X_y_domain, extract_source_indices, source_target_split # xxx(okachaiev): maybe it would be easier to reuse _BaseScorer? @@ -74,7 +74,10 @@ def _score( **params ): scorer = check_scoring(estimator, self.scoring) - source_idx = check_X_y_domain(X, y, sample_domain, return_indices=True) + + X, y, sample_domain = check_X_y_domain(X, y, sample_domain) + source_idx = extract_source_indices(sample_domain) + return self._sign * scorer( estimator, X[~source_idx], @@ -165,8 +168,10 @@ def _score(self, estimator, X, y, sample_domain=None, **params): f"The estimator {estimator!r} does not." ) - X_source, y_source, X_target, _ = check_X_y_domain( - X, y, sample_domain, return_joint=False) + X, y, sample_domain = check_X_y_domain(X, y, sample_domain) + X_source, X_target, y_source, _ = source_target_split( + X, y, sample_domain=sample_domain + ) self._fit(X_source, X_target) ws = self.weight_estimator_source_.score_samples(X_source) wt = self.weight_estimator_target_.score_samples(X_source) @@ -213,7 +218,8 @@ def _score(self, estimator, X, y, sample_domain=None, **params): f"The estimator {estimator!r} does not." ) - source_idx = check_X_y_domain(X, y, sample_domain, return_indices=True) + X, y, sample_domain = check_X_y_domain(X, y, sample_domain) + source_idx = extract_source_indices(sample_domain) proba = estimator.predict_proba( X[~source_idx], sample_domain=sample_domain[~source_idx], @@ -265,7 +271,8 @@ def _score(self, estimator, X, y, sample_domain=None, **params): f"The estimator {estimator!r} does not." ) - source_idx = check_X_y_domain(X, y, sample_domain, return_indices=True) + X, y, sample_domain = check_X_y_domain(X, y, sample_domain) + source_idx = extract_source_indices(sample_domain) proba = estimator.predict_proba( X[~source_idx], sample_domain=sample_domain[~source_idx], @@ -326,12 +333,12 @@ def __init__( self._sign = 1 if greater_is_better else -1 def _no_reduc_log_loss(self, y, y_pred): - return np.array( - [ - log_loss(y[i : i + 1], y_pred[i : i + 1], labels=np.unique(y)) - for i in range(len(y)) - ] - ) + return np.array( + [ + log_loss(y[i : i + 1], y_pred[i : i + 1], labels=np.unique(y)) + for i in range(len(y)) + ] + ) def _fit_adapt(self, features, features_target): domain_classifier = self.domain_classifier @@ -345,7 +352,8 @@ def _fit_adapt(self, features, features_target): return self def _score(self, estimator, X, y, sample_domain=None, **kwargs): - source_idx = check_X_y_domain(X, y, sample_domain, return_indices=True) + X, y, sample_domain = check_X_y_domain(X, y, sample_domain) + source_idx = extract_source_indices(sample_domain) rng = check_random_state(self.random_state) X_train, X_val, _, y_val, _, sample_domain_val = train_test_split( X[source_idx], y[source_idx], sample_domain[source_idx], @@ -354,12 +362,12 @@ def _score(self, estimator, X, y, sample_domain=None, **kwargs): features_train = estimator.get_features(X_train) features_val = estimator.get_features(X_val) features_target = estimator.get_features(X[~source_idx]) - + self._fit_adapt(features_train, features_target) N, N_target = len(features_train), len(features_target) predictions = self.domain_classifier_.predict_proba(features_val) weights = N / N_target * predictions[:, :1] / predictions[:, 1:] - + y_pred = estimator.predict_proba(X_val, sample_domain=sample_domain_val) error = self._loss_func(y_val, y_pred) assert weights.shape[0] == error.shape[0] diff --git a/skada/model_selection.py b/skada/model_selection.py index 0818919c..53dfd484 100644 --- a/skada/model_selection.py +++ b/skada/model_selection.py @@ -16,7 +16,7 @@ from sklearn.utils import check_random_state, indexable from sklearn.utils.metadata_routing import _MetadataRequester -from ._utils import check_X_domain +from .utils import check_X_domain, extract_source_indices class SplitSampleDomainRequesterMixin(_MetadataRequester): @@ -108,7 +108,16 @@ def __repr__(self): class SourceTargetShuffleSplit(BaseDomainAwareShuffleSplit): + """Source-Target-Shuffle-Split cross-validator. + Provides train/test indices to split data in train/test sets. + Each sample is used once as a test set (singleton) while the + remaining samples form the training set. + + Default split is implemented hierarchically. If first designates + a single domain as a target followed up by the single train/test + shuffle split. + """ def __init__( self, n_splits=10, *, test_size=None, train_size=None, random_state=None ): @@ -121,7 +130,8 @@ def __init__( self._default_test_size = 0.1 def _iter_indices(self, X, y=None, sample_domain=None): - indices = check_X_domain(X, sample_domain, return_indices=True) + X, sample_domain = check_X_domain(X, sample_domain) + indices = extract_source_indices(sample_domain) source_idx, = np.where(indices) target_idx, = np.where(~indices) n_source_samples = _num_samples(source_idx) @@ -162,14 +172,7 @@ class LeaveOneDomainOut(SplitSampleDomainRequesterMixin): """Leave-One-Domain-Out cross-validator. Provides train/test indices to split data in train/test sets. - Each sample is used once as a test set (singleton) while the - remaining samples form the training set. - - Default split is implemented hierarchically. If first designates - a single domain as a target followed up by the single train/test - shuffle split. """ - def __init__( self, max_n_splits=10, *, test_size=None, train_size=None, random_state=None ): @@ -246,7 +249,8 @@ def split(self, X, y=None, sample_domain=None): yield split_idx[train_idx], split_idx[test_idx] def _iter_indices(self, X, y=None, sample_domain=None): - indices = check_X_domain(X, sample_domain, return_indices=True) + X, sample_domain = check_X_domain(X, sample_domain) + indices = extract_source_indices(sample_domain) source_idx, = np.where(indices) target_idx, = np.where(~indices) n_source_samples = _num_samples(source_idx) diff --git a/skada/tests/test_utils.py b/skada/tests/test_utils.py new file mode 100644 index 00000000..d7f7479e --- /dev/null +++ b/skada/tests/test_utils.py @@ -0,0 +1,281 @@ +# Author: Yanis Lalou +# +# License: BSD 3-Clause + +import pytest + +import numpy as np +from skada.datasets import ( + make_dataset_from_moons_distribution +) + +from skada.utils import ( + check_X_y_domain, check_X_domain, + extract_source_indices +) +from skada.utils import source_target_split +from skada._utils import _check_y_masking + + +def test_check_y_masking_classification(): + y_properly_masked = np.array([-1, 1, 2, -1, 2, 1, 1]) + y_wrongfuly_masked_1 = np.array([-1, -2, 2, -1, 2, 1, 1]) + y_not_masked = np.array([1, 2, 2, 1, 2, 1, 1]) + + # Test that no ValueError is raised + _check_y_masking(y_properly_masked) + + with pytest.raises(ValueError): + _check_y_masking(y_wrongfuly_masked_1) + + with pytest.raises(ValueError): + _check_y_masking(y_not_masked) + + +def test_check_y_masking_regression(): + y_properly_masked = np.array([np.nan, 1, 2.5, -1, np.nan, 0, -1.5]) + y_not_masked = np.array([-1, -2, 2.5, -1, 2, 0, 1]) + + # Test that no ValueError is raised + _check_y_masking(y_properly_masked) + + with pytest.raises(ValueError): + _check_y_masking(y_not_masked) + + +def test_check_2d_y_masking(): + y_wrong_dim = np.array([[-1, 2], [1, 2], [1, 2]]) + + with pytest.raises(ValueError): + _check_y_masking(y_wrong_dim) + + +def test_check_X_y_domain_exceptions(): + X, y, sample_domain = make_dataset_from_moons_distribution( + pos_source=0.1, + pos_target=0.9, + n_samples_source=50, + n_samples_target=20, + random_state=0, + return_X_y=True, + ) + + # Test that no ValueError is raised + check_X_y_domain(X, y, sample_domain=sample_domain) + + with pytest.raises(ValueError): + check_X_y_domain(X, y, sample_domain=None, allow_auto_sample_domain=False) + + +def test_check_X_domain_exceptions(): + X, y, sample_domain = make_dataset_from_moons_distribution( + pos_source=0.1, + pos_target=0.9, + n_samples_source=50, + n_samples_target=20, + random_state=0, + return_X_y=True, + ) + + # Test that no ValueError is raised + check_X_domain(X, sample_domain=sample_domain) + + with pytest.raises(ValueError): + check_X_domain(X, sample_domain=None, allow_auto_sample_domain=False) + + +def test_source_target_split(): + n_samples_source = 50 + n_samples_target = 20 + X, y, sample_domain = make_dataset_from_moons_distribution( + pos_source=0.1, + pos_target=0.9, + n_samples_source=n_samples_source, + n_samples_target=n_samples_target, + random_state=0, + return_X_y=True, + ) + + # Test that no ValueError is raised + _, _ = source_target_split(X, sample_domain=sample_domain) + + X_source, X_target, y_source, y_target = source_target_split( + X, y, sample_domain=sample_domain + ) + + assert X_source.shape == (2 * n_samples_source, 2), "X_source shape mismatch" + assert y_source.shape == (2 * n_samples_source, ), "y_source shape mismatch" + assert X_target.shape == (2 * n_samples_target, 2), "X_target shape mismatch" + assert y_target.shape == (2 * n_samples_target, ), "y_target shape mismatch" + + with pytest.raises(IndexError): + source_target_split(X, y[:-2], sample_domain=sample_domain) + + +def test_check_X_y_allow_exceptions(): + X, y, sample_domain = make_dataset_from_moons_distribution( + pos_source=0.1, + pos_target=0.9, + n_samples_source=50, + n_samples_target=20, + random_state=0, + return_X_y=True, + ) + + # Generate a random_sample_domain of size len(y) + # with random integers between -5 and 5 (excluding 0) + random_sample_domain = np.random.choice( + np.concatenate((np.arange(-5, 0), np.arange(1, 6))), size=len(y) + ) + allow_source = False + allow_target = False + allow_multi_source = False + allow_multi_target = False + + positive_numbers = random_sample_domain[random_sample_domain > 0] + negative_numbers = random_sample_domain[random_sample_domain < 0] + # Count unique positive numbers + n_sources = len(np.unique(positive_numbers)) + n_targets = len(np.unique(negative_numbers)) + + with pytest.raises( + ValueError, + match=( + f"Number of sources provided is {n_sources} " + f"and 'allow_source' is set to {allow_source}" + ) + ): + check_X_y_domain( + X, y, sample_domain=random_sample_domain, + allow_auto_sample_domain=False, allow_source=allow_source + ) + + with pytest.raises( + ValueError, + match=( + f"Number of targets provided is {n_targets} " + f"and 'allow_target' is set to {allow_target}" + ) + ): + check_X_y_domain( + X, y, sample_domain=random_sample_domain, + allow_auto_sample_domain=False, allow_target=allow_target + ) + + with pytest.raises( + ValueError, + match=( + f"Number of sources provided is {n_sources} " + f"and 'allow_multi_source' is set to {allow_multi_source}" + ) + ): + check_X_y_domain( + X, y, sample_domain=random_sample_domain, + allow_auto_sample_domain=False, allow_multi_source=allow_multi_source + ) + + with pytest.raises( + ValueError, + match=( + f"Number of targets provided is {n_targets} " + f"and 'allow_multi_target' is set to {allow_multi_target}" + ) + ): + check_X_y_domain( + X, y, sample_domain=random_sample_domain, + allow_auto_sample_domain=False, allow_multi_target=allow_multi_target + ) + + +def test_check_X_allow_exceptions(): + X, y, sample_domain = make_dataset_from_moons_distribution( + pos_source=0.1, + pos_target=0.9, + n_samples_source=50, + n_samples_target=20, + random_state=0, + return_X_y=True, + ) + + # Generate a random_sample_domain of size len(y) + # with random integers between -5 and 5 (excluding 0) + random_sample_domain = np.random.choice( + np.concatenate((np.arange(-5, 0), np.arange(1, 6))), size=len(y) + ) + allow_source = False + allow_target = False + allow_multi_source = False + allow_multi_target = False + + positive_numbers = random_sample_domain[random_sample_domain > 0] + negative_numbers = random_sample_domain[random_sample_domain < 0] + + # Count unique positive numbers + n_sources = len(np.unique(positive_numbers)) + n_targets = len(np.unique(negative_numbers)) + + with pytest.raises( + ValueError, + match=( + f"Number of sources provided is {n_sources} " + f"and 'allow_source' is set to {allow_source}" + ) + ): + check_X_domain( + X, sample_domain=random_sample_domain, + allow_auto_sample_domain=False, allow_source=allow_source + ) + + with pytest.raises( + ValueError, + match=( + f"Number of targets provided is {n_targets} " + f"and 'allow_target' is set to {allow_target}" + ) + ): + check_X_domain( + X, sample_domain=random_sample_domain, + allow_auto_sample_domain=False, allow_target=allow_target + ) + + with pytest.raises( + ValueError, + match=( + f"Number of sources provided is {n_sources} " + f"and 'allow_multi_source' is set to {allow_multi_source}" + ) + ): + check_X_domain( + X, sample_domain=random_sample_domain, + allow_auto_sample_domain=False, allow_multi_source=allow_multi_source + ) + + with pytest.raises( + ValueError, + match=( + f"Number of targets provided is {n_targets} " + f"and 'allow_multi_target' is set to {allow_multi_target}" + ) + ): + check_X_domain( + X, sample_domain=random_sample_domain, + allow_auto_sample_domain=False, allow_multi_target=allow_multi_target + ) + + +def test_extract_source_indices(): + n_samples_source = 50 + n_samples_target = 20 + X, y, sample_domain = make_dataset_from_moons_distribution( + pos_source=0.1, + pos_target=0.9, + n_samples_source=n_samples_source, + n_samples_target=n_samples_target, + random_state=0, + return_X_y=True, + ) + source_idx = extract_source_indices(sample_domain) + + assert len(source_idx) == (len(sample_domain)), "source_idx shape mismatch" + assert np.sum(source_idx) == 2 * n_samples_source, "source_idx sum mismatch" + assert np.sum(~source_idx) == 2 * n_samples_target, "target_idx sum mismatch" diff --git a/skada/utils.py b/skada/utils.py new file mode 100644 index 00000000..a217baeb --- /dev/null +++ b/skada/utils.py @@ -0,0 +1,245 @@ +# Author: Yanis Lalou +# +# License: BSD 3-Clause + +from typing import Optional, Set + +import numpy as np +from itertools import chain + +from sklearn.utils import check_array, check_consistent_length + +from skada._utils import _check_y_masking +from skada._utils import ( + _DEFAULT_SOURCE_DOMAIN_LABEL, _DEFAULT_TARGET_DOMAIN_LABEL, + _DEFAULT_MASKED_TARGET_CLASSIFICATION_LABEL, _DEFAULT_TARGET_DOMAIN_ONLY_LABEL +) + + +def check_X_y_domain( + X, + y, + sample_domain=None, + allow_source: bool = True, + allow_multi_source: bool = True, + allow_target: bool = True, + allow_multi_target: bool = True, + allow_auto_sample_domain: bool = True, + allow_nd: bool = False, +): + """ + Input validation for domain adaptation (DA) estimator. + If we work in single-source and single target mode, return source and target + separately to avoid additional scan for 'sample_domain' array. + + Parameters: + ---------- + X : array-like of shape (n_samples, n_features) + Input features + y : array-like of shape (n_samples,) + Target variable + sample_domain : array-like or None, optional (default=None) + Array specifying the domain labels for each sample. + allow_source : bool, optional (default=True) + Allow the presence of source domains. + allow_multi_source : bool, optional (default=True) + Allow multiple source domains. + allow_target : bool, optional (default=True) + Allow the presence of target domains. + allow_multi_target : bool, optional (default=True) + Allow multiple target domains. + allow_auto_sample_domain : bool, optional (default=True) + Allow automatic generation of sample_domain if not provided. + allow_nd : bool, optional (default=False) + Allow X and y to be N-dimensional arrays. + + Returns: + ---------- + X : array + Input features + y : array + Target variable + sample_domain : array + Array specifying the domain labels for each sample. + """ + + X = check_array(X, input_name='X', allow_nd=allow_nd) + y = check_array(y, force_all_finite=True, ensure_2d=False, input_name='y') + check_consistent_length(X, y) + + if sample_domain is None and not allow_auto_sample_domain: + raise ValueError("Either 'sample_domain' or 'allow_auto_sample_domain' " + "should be set") + elif sample_domain is None and allow_auto_sample_domain: + y_type = _check_y_masking(y) + sample_domain = _DEFAULT_SOURCE_DOMAIN_LABEL*np.ones_like(y) + # labels masked with -1 (for classification) are recognized as targets, + # labels masked with nan (for regression) are recognized as targets, + # the rest is treated as a source + if y_type == 'classification': + mask = (y == _DEFAULT_MASKED_TARGET_CLASSIFICATION_LABEL) + else: + mask = (np.isnan(y)) + sample_domain[mask] = _DEFAULT_TARGET_DOMAIN_LABEL + + source_idx = extract_source_indices(sample_domain) + + # xxx(okachaiev): this needs to be re-written to accommodate for a + # a new domain labeling convention without "intersections" + n_sources = np.unique(sample_domain[source_idx]).shape[0] + n_targets = np.unique(sample_domain[~source_idx]).shape[0] + + if not allow_source and n_sources > 0: + raise ValueError(f"Number of sources provided is {n_sources} " + "and 'allow_source' is set to False") + if not allow_target and n_targets > 0: + raise ValueError(f"Number of targets provided is {n_targets} " + "and 'allow_target' is set to False") + if not allow_multi_source and n_sources > 1: + raise ValueError(f"Number of sources provided is {n_sources} " + "and 'allow_multi_source' is set to False") + if not allow_multi_target and n_sources > 1: + raise ValueError(f"Number of targets provided is {n_targets} " + "and 'allow_multi_target' is set to False") + + return X, y, sample_domain + + +# xxx(okachaiev): code duplication, just for testing +def check_X_domain( + X, + sample_domain, + *, + allow_domains: Optional[Set[int]] = None, + allow_source: bool = True, + allow_multi_source: bool = True, + allow_target: bool = True, + allow_multi_target: bool = True, + allow_auto_sample_domain: bool = True, +): + """ + Input validation for domain adaptation (DA) estimator. + If we work in single-source and single target mode, return source and target + separately to avoid additional scan for 'sample_domain' array. + + Parameters: + ---------- + X : array-like of shape (n_samples, n_features) + Input features. + sample_domain : array-like of shape (n_samples,) + Domain labels for each sample. + allow_domains : set of int, optional (default=None) + Set of allowed domain labels. If provided, only these domain labels are allowed. + allow_source : bool, optional (default=True) + Allow the presence of source domains. + allow_multi_source : bool, optional (default=True) + Allow multiple source domains. + allow_target : bool, optional (default=True) + Allow the presence of target domains. + allow_multi_target : bool, optional (default=True) + Allow multiple target domains. + allow_auto_sample_domain : bool, optional (default=True) + Allow automatic generation of sample_domain if not provided. + + Returns: + ---------- + X : array + Input features. + sample_domain : array + Combined domain labels for source and target domains. + """ + X = check_array(X, input_name='X') + + if sample_domain is None and not allow_auto_sample_domain: + raise ValueError("Either 'sample_domain' or 'allow_auto_sample_domain' " + "should be set") + elif sample_domain is None and allow_auto_sample_domain: + # default target domain when sample_domain is not given + # The idea is that with no labels we always assume + # target domain (_DEFAULT_TARGET_DOMAIN_ONLY_LABEL) + sample_domain = ( + _DEFAULT_TARGET_DOMAIN_ONLY_LABEL * np.ones(X.shape[0], dtype=np.int32) + ) + + source_idx = extract_source_indices(sample_domain) + check_consistent_length(X, sample_domain) + + if allow_domains is not None: + for domain in np.unique(sample_domain): + # xxx(okachaiev): re-definition of the wildcards + wildcard = np.inf if domain >= 0 else -np.inf + if domain not in allow_domains and wildcard not in allow_domains: + raise ValueError(f"Unknown domain label '{domain}' given") + + n_sources = np.unique(sample_domain[source_idx]).shape[0] + n_targets = np.unique(sample_domain[~source_idx]).shape[0] + + if not allow_source and n_sources > 0: + raise ValueError(f"Number of sources provided is {n_sources} " + "and 'allow_source' is set to False") + if not allow_target and n_targets > 0: + raise ValueError(f"Number of targets provided is {n_targets} " + "and 'allow_target' is set to False") + if not allow_multi_source and n_sources > 1: + raise ValueError(f"Number of sources provided is {n_sources} " + "and 'allow_multi_source' is set to False") + if not allow_multi_target and n_sources > 1: + raise ValueError(f"Number of targets provided is {n_targets} " + "and 'allow_multi_target' is set to False") + + return X, sample_domain + + +def extract_source_indices(sample_domain): + """Extract the indices of the source samples. + + Parameters: + ---------- + sample_domain : array-like of shape (n_samples,) + Array specifying the domain labels for each sample. + + Returns: + ---------- + source_idx : array + Boolean array indicating source indices. + """ + sample_domain = check_array( + sample_domain, + dtype=np.int32, + ensure_2d=False, + input_name='sample_domain' + ) + + source_idx = (sample_domain >= 0) + return source_idx + + +def source_target_split( + *arrays, + sample_domain +): + r""" Split data into source and target domains + + Parameters + ---------- + *arrays : sequence of array-like of identical shape (n_samples, n_features) + Input features + sample_domain : array-like of shape (n_samples,) + Array specifying the domain labels for each sample. + + Returns + ------- + splits : list, length=2 * len(arrays) + List containing source-target split of inputs. + """ + + if len(arrays) == 0: + raise ValueError("At least one array required as input") + + check_consistent_length(arrays) + + source_idx = extract_source_indices(sample_domain) + + return list(chain.from_iterable( + (a[source_idx], a[~source_idx]) for a in arrays + ))