Skip to content

Commit

Permalink
[MRG] Handle scalar sample domain (#267)
Browse files Browse the repository at this point in the history
* handle scalar sample domain

* add tests

---------

Co-authored-by: Rémi Flamary <[email protected]>
  • Loading branch information
antoinecollas and rflamary authored Oct 25, 2024
1 parent 54f91d6 commit 65f1659
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 4 deletions.
42 changes: 42 additions & 0 deletions skada/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,48 @@ def test_check_X_domain_exceptions():
check_X_domain(X, sample_domain=None, allow_auto_sample_domain=False)


def test_check_X_domain_scalar():
    """Check that ``check_X_domain`` accepts a scalar ``sample_domain``.

    A moons dataset is generated, the samples whose domain label is -2 are
    selected, and ``check_X_domain`` is called on them with the scalar ``-2``
    instead of a per-sample array. The returned features must be unchanged
    and the returned domain array must hold one ``-2`` entry per sample.
    """
    features, _, domains = make_dataset_from_moons_distribution(
        pos_source=0.1,
        pos_target=0.9,
        n_samples_source=50,
        n_samples_target=20,
        random_state=0,
        return_X_y=True,
    )

    # Keep only the samples labelled -2, then hand that label over as a
    # scalar rather than as an array.
    target_mask = domains == -2
    X_target = features[target_mask]
    checked_X, checked_domain = check_X_domain(X_target, sample_domain=-2)

    n_target = X_target.shape[0]
    # The scalar must have been expanded to one label per sample ...
    assert checked_domain.shape[0] == n_target
    # ... the features must come back untouched ...
    assert np.array_equal(checked_X, X_target)
    # ... and every expanded label must equal the scalar that was given.
    assert np.array_equal(checked_domain, np.full(n_target, -2.0))


def test_check_X_y_domain_scalar():
    """Check that ``check_X_y_domain`` accepts a scalar ``sample_domain``.

    A moons dataset is generated, the samples whose domain label is -2 are
    selected together with their targets, and ``check_X_y_domain`` is called
    on them with the scalar ``-2`` instead of a per-sample array. Features
    and targets must come back unchanged and the returned domain array must
    hold one ``-2`` entry per sample.
    """
    features, labels, domains = make_dataset_from_moons_distribution(
        pos_source=0.1,
        pos_target=0.9,
        n_samples_source=50,
        n_samples_target=20,
        random_state=0,
        return_X_y=True,
    )

    # Keep only the samples (and their targets) labelled -2, then hand that
    # label over as a scalar rather than as an array.
    X_target = features[domains == -2]
    y_target = labels[domains == -2]
    checked = check_X_y_domain(X_target, y_target, sample_domain=-2)
    checked_X, checked_y, checked_domain = checked

    n_target = X_target.shape[0]
    # The scalar must have been expanded to one label per sample ...
    assert checked_domain.shape[0] == n_target
    # ... features and targets must come back untouched ...
    assert np.array_equal(checked_X, X_target)
    assert np.array_equal(checked_y, y_target)
    # ... and every expanded label must equal the scalar that was given.
    assert np.array_equal(checked_domain, np.full(n_target, -2.0))


def test_source_target_split():
n_samples_source = 50
n_samples_target = 20
Expand Down
16 changes: 12 additions & 4 deletions skada/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,9 @@ def check_X_y_domain(
Input features
y : array-like of shape (n_samples,)
Target variable
sample_domain : array-like or None, optional (default=None)
Array specifying the domain labels for each sample.
sample_domain : array-like, scalar, or None, optional (default=None)
Array specifying the domain labels for each sample. A scalar value
can be provided to assign all samples to the same domain.
allow_source : bool, optional (default=True)
Allow the presence of source domains.
allow_multi_source : bool, optional (default=True)
Expand Down Expand Up @@ -91,6 +92,9 @@ def check_X_y_domain(
mask = (np.isnan(y))
sample_domain[mask] = _DEFAULT_TARGET_DOMAIN_LABEL

if np.isscalar(sample_domain):
sample_domain = sample_domain*np.ones_like(y)

source_idx = extract_source_indices(sample_domain)

# xxx(okachaiev): this needs to be re-written to accommodate for a
Expand Down Expand Up @@ -136,8 +140,9 @@ def check_X_domain(
----------
X : array-like of shape (n_samples, n_features)
Input features.
sample_domain : array-like of shape (n_samples,)
Domain labels for each sample.
sample_domain : array-like, scalar, or None, optional (default=None)
Array specifying the domain labels for each sample. A scalar value
can be provided to assign all samples to the same domain.
allow_domains : set of int, optional (default=None)
Set of allowed domain labels. If provided, only these domain labels are allowed.
allow_source : bool, optional (default=True)
Expand Down Expand Up @@ -173,6 +178,9 @@ def check_X_domain(
_DEFAULT_TARGET_DOMAIN_ONLY_LABEL * np.ones(X.shape[0], dtype=np.int32)
)

if np.isscalar(sample_domain):
sample_domain = sample_domain * np.ones(X.shape[0], dtype=np.int32)

source_idx = extract_source_indices(sample_domain)
check_consistent_length(X, sample_domain)

Expand Down

0 comments on commit 65f1659

Please sign in to comment.