Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TO_REVIEW] Add utilities functions to the doc #227

Merged
merged 12 commits into from
Sep 5, 2024
25 changes: 22 additions & 3 deletions docs/source/all.rst
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,12 @@ DA pipeline
Utilities
^^^^^^^^^

.. autosummary::
:toctree: gen_modules/
:template: function.rst




source_target_split
per_domain_split



Expand Down Expand Up @@ -221,3 +223,20 @@ Datasets :py:mod:`skada.datasets`
make_variable_frequency_dataset


Utilities :py:mod:`skada.utils`
--------------------------------

.. currentmodule:: skada.utils

.. automodule:: skada.utils
:no-members:
:no-inherited-members:

.. autosummary::
:toctree: gen_modules/
:template: function.rst

check_X_y_domain
extract_source_indices
extract_domains_indices
source_target_merge
4 changes: 2 additions & 2 deletions skada/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,9 @@ def _find_y_type(y):
# Check if the target is a classification or regression target.
y_type = type_of_target(y)

if y_type == "continuous":
if y_type in ["continuous", "continuous-multioutput"]:
return Y_Type.CONTINUOUS
elif y_type == "binary" or y_type == "multiclass":
elif y_type in ["binary", "multiclass"]:
return Y_Type.DISCRETE
else:
# Here y_type is 'multilabel-indicator',
Expand Down
34 changes: 33 additions & 1 deletion skada/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
import pytest
from sklearn.utils import check_random_state

from skada._utils import _check_y_masking, _merge_domain_outputs
from skada._utils import (
_DEFAULT_MASKED_TARGET_CLASSIFICATION_LABEL,
_check_y_masking,
_merge_domain_outputs,
)
from skada.datasets import make_dataset_from_moons_distribution
from skada.utils import (
check_X_domain,
Expand Down Expand Up @@ -416,6 +420,34 @@ def test_extract_domains_indices():


def test_source_target_merge():
# Test simple source-target merge with 2 domains
X_source = np.array([[1, 2], [3, 4], [5, 6]])
X_target = np.array([[7, 8], [9, 10]])
y_source = np.array([0, 1, 1])
y_target = None
sample_domain = np.array([1, 1, 1, -2, -2])

X, y, _ = source_target_merge(
X_source, X_target, y_source, y_target, sample_domain=sample_domain
)

np.testing.assert_array_equal(
X, np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
)
np.testing.assert_array_equal(
y,
np.array(
[
0,
1,
1,
_DEFAULT_MASKED_TARGET_CLASSIFICATION_LABEL,
_DEFAULT_MASKED_TARGET_CLASSIFICATION_LABEL,
]
),
)

# Test moons dataset with 2 domains
n_samples_source = 50
n_samples_target = 20
X, y, sample_domain = make_dataset_from_moons_distribution(
Expand Down
51 changes: 28 additions & 23 deletions skada/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import numpy as np
from scipy.optimize import LinearConstraint, minimize
from sklearn.utils import check_array, check_consistent_length
from sklearn.utils.multiclass import type_of_target

from skada._utils import (
_DEFAULT_MASKED_TARGET_CLASSIFICATION_LABEL,
Expand All @@ -20,6 +19,7 @@
_DEFAULT_TARGET_DOMAIN_ONLY_LABEL,
_check_y_masking,
Y_Type,
_find_y_type
)


Expand Down Expand Up @@ -355,7 +355,7 @@ def source_target_merge(
*arrays,
sample_domain: Optional[np.ndarray] = None
) -> Sequence[np.ndarray]:
f""" Merge source and target domain data based on sample domain labels.
"""Merge source and target domain data based on sample domain labels.

Parameters
----------
Expand Down Expand Up @@ -387,33 +387,29 @@ def source_target_merge(
--------
>>> X_source = np.array([[1, 2], [3, 4], [5, 6]])
>>> X_target = np.array([[7, 8], [9, 10]])
>>> sample_domain = np.array([0, 0, 1, 1])
>>> X, _ = source_target_merge(X_source, X_target, sample_domain = sample_domain)
>>> X, sample_domain = source_target_merge(X_source, X_target)
>>> X
np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
array([[ 1, 2],
[ 3, 4],
[ 5, 6],
[ 7, 8],
[ 9, 10]])
>>> sample_domain
array([ 1., 1., 1., -2., -2.])

>>> X_source = np.array([[1, 2], [3, 4], [5, 6]])
>>> X_target = np.array([[7, 8], [9, 10]])
>>> y_source = np.array([0, 1, 1])
>>> y_target = None
>>> sample_domain = np.array([0, 0, 1, 1])
>>> X, y, _ = source_target_merge(
X_source,
X_target,
y_source,
y_target,
sample_domain = sample_domain
)
>>> X, y, _ = source_target_merge(X_source, X_target, y_source, y_target)
>>> X
np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])

array([[ 1, 2],
[ 3, 4],
[ 5, 6],
[ 7, 8],
[ 9, 10]])
>>> y
np.array([0,
1,
1,
{_DEFAULT_MASKED_TARGET_CLASSIFICATION_LABEL},
{_DEFAULT_MASKED_TARGET_CLASSIFICATION_LABEL}
])
array([ 0, 1, 1, {_DEFAULT_MASKED_TARGET_CLASSIFICATION_LABEL}, {_DEFAULT_MASKED_TARGET_CLASSIFICATION_LABEL}])
"""

arrays = list(arrays) # Convert to list to be able to modify it
Expand Down Expand Up @@ -490,10 +486,10 @@ def source_target_merge(

pair_index = i+1 if index_is_empty == i else i

y_type = type_of_target(arrays[pair_index])
y_type = _find_y_type(arrays[pair_index])
if y_type == Y_Type.DISCRETE:
default_masked_label = _DEFAULT_MASKED_TARGET_CLASSIFICATION_LABEL
else:
elif y_type == Y_Type.CONTINUOUS:
default_masked_label = _DEFAULT_MASKED_TARGET_REGRESSION_LABEL

arrays[index_is_empty] = (
Expand Down Expand Up @@ -523,6 +519,15 @@ def source_target_merge(
return (*merges, sample_domain)


# Fill the docstring placeholders with the actual default label values, so
# the rendered documentation shows the real constants.
# ``__doc__`` is None when the interpreter strips docstrings (``python -OO``);
# guard the call so that importing this module does not raise AttributeError
# in that mode.
if source_target_merge.__doc__ is not None:
    source_target_merge.__doc__ = source_target_merge.__doc__.format(
        _DEFAULT_MASKED_TARGET_CLASSIFICATION_LABEL=_DEFAULT_MASKED_TARGET_CLASSIFICATION_LABEL,
        _DEFAULT_MASKED_TARGET_REGRESSION_LABEL=_DEFAULT_MASKED_TARGET_REGRESSION_LABEL,
        _DEFAULT_SOURCE_DOMAIN_LABEL=_DEFAULT_SOURCE_DOMAIN_LABEL,
        _DEFAULT_TARGET_DOMAIN_LABEL=_DEFAULT_TARGET_DOMAIN_LABEL,
    )


def _merge_arrays(
array_source,
array_target,
Expand Down
Loading