From 65f16595bfcfa3f385bc68a9fac59df0609164ef Mon Sep 17 00:00:00 2001 From: Antoine Collas Date: Fri, 25 Oct 2024 11:27:39 +0200 Subject: [PATCH] [MRG] Handle scalar sample domain (#267) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * handle scalar sample domain * add tests --------- Co-authored-by: Rémi Flamary --- skada/tests/test_utils.py | 42 +++++++++++++++++++++++++++++++++++++++ skada/utils.py | 16 +++++++++++---- 2 files changed, 54 insertions(+), 4 deletions(-) diff --git a/skada/tests/test_utils.py b/skada/tests/test_utils.py index cd065565..3cf2745b 100644 --- a/skada/tests/test_utils.py +++ b/skada/tests/test_utils.py @@ -103,6 +103,48 @@ def test_check_X_domain_exceptions(): check_X_domain(X, sample_domain=None, allow_auto_sample_domain=False) +def test_check_X_domain_scalar(): + X, _, sample_domain = make_dataset_from_moons_distribution( + pos_source=0.1, + pos_target=0.9, + n_samples_source=50, + n_samples_target=20, + random_state=0, + return_X_y=True, + ) + + # Test scalar sample_domain + X_target = X[sample_domain == -2] + returned_X_target, sample_domain_target = check_X_domain(X_target, sample_domain=-2) + + assert sample_domain_target.shape[0] == X_target.shape[0] + assert np.array_equal(returned_X_target, X_target) + assert np.array_equal(sample_domain_target, -2 * np.ones(X_target.shape[0])) + + +def test_check_X_y_domain_scalar(): + X, y, sample_domain = make_dataset_from_moons_distribution( + pos_source=0.1, + pos_target=0.9, + n_samples_source=50, + n_samples_target=20, + random_state=0, + return_X_y=True, + ) + + # Test scalar sample_domain + X_target = X[sample_domain == -2] + y_target = y[sample_domain == -2] + returned_X_target, returned_y_target, sample_domain_target = check_X_y_domain( + X_target, y_target, sample_domain=-2 + ) + + assert sample_domain_target.shape[0] == X_target.shape[0] + assert np.array_equal(returned_X_target, X_target) + assert np.array_equal(returned_y_target, 
y_target) + assert np.array_equal(sample_domain_target, -2 * np.ones(X_target.shape[0])) + + def test_source_target_split(): n_samples_source = 50 n_samples_target = 20 diff --git a/skada/utils.py b/skada/utils.py index ff555202..79332ecc 100644 --- a/skada/utils.py +++ b/skada/utils.py @@ -46,8 +46,9 @@ def check_X_y_domain( Input features y : array-like of shape (n_samples,) Target variable - sample_domain : array-like or None, optional (default=None) - Array specifying the domain labels for each sample. + sample_domain : array-like, scalar, or None, optional (default=None) + Array specifying the domain labels for each sample. A scalar value + can be provided to assign all samples to the same domain. allow_source : bool, optional (default=True) Allow the presence of source domains. allow_multi_source : bool, optional (default=True) @@ -91,6 +92,9 @@ def check_X_y_domain( mask = (np.isnan(y)) sample_domain[mask] = _DEFAULT_TARGET_DOMAIN_LABEL + if np.isscalar(sample_domain): + sample_domain = sample_domain*np.ones_like(y) + source_idx = extract_source_indices(sample_domain) # xxx(okachaiev): this needs to be re-written to accommodate for a @@ -136,8 +140,9 @@ def check_X_domain( ---------- X : array-like of shape (n_samples, n_features) Input features. - sample_domain : array-like of shape (n_samples,) - Domain labels for each sample. + sample_domain : array-like, scalar, or None, optional (default=None) + Array specifying the domain labels for each sample. A scalar value + can be provided to assign all samples to the same domain. allow_domains : set of int, optional (default=None) Set of allowed domain labels. If provided, only these domain labels are allowed. 
allow_source : bool, optional (default=True) @@ -173,6 +178,9 @@ def check_X_domain( _DEFAULT_TARGET_DOMAIN_ONLY_LABEL * np.ones(X.shape[0], dtype=np.int32) ) + if np.isscalar(sample_domain): + sample_domain = sample_domain * np.ones(X.shape[0], dtype=np.int32) + source_idx = extract_source_indices(sample_domain) check_consistent_length(X, sample_domain)