Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Modification source_target_merge function behaviours #71

Merged
merged 26 commits into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
26ccd48
source_target_merge now accepts N-dimensional arrays + name change + te…
YanisLalou Feb 2, 2024
222a773
Merge branch 'main' into source_target_merge_branch
YanisLalou Feb 2, 2024
f069d7e
wrong import
YanisLalou Feb 2, 2024
d1b18e8
Fix + new test cases
YanisLalou Feb 2, 2024
8a69879
Fix typo
YanisLalou Feb 2, 2024
231df81
Merge branch 'main' into source_target_merge_branch
YanisLalou Feb 5, 2024
a79acd8
_merge_arrays() now accepts multiple arrays as input. The len(arrays)…
YanisLalou Feb 6, 2024
5e5134a
Removed a modified file from pull request
YanisLalou Feb 6, 2024
3451e93
Update flake8.yaml to actually run! (#72)
rflamary Feb 5, 2024
1afde5b
Fix flake8 for utils and tests (#73)
kachayev Feb 5, 2024
11921c7
source_target_merge now accepts N-dimensional arrays + name change + te…
YanisLalou Feb 2, 2024
08e6c20
wrong import
YanisLalou Feb 2, 2024
cd8ecfd
_merge_arrays() now accepts multiple arrays as input. The len(arrays)…
YanisLalou Feb 6, 2024
6cc5f9e
Merge branch 'main' into source_target_merge_branch
YanisLalou Feb 6, 2024
2e909fc
Removed a modified file from pull request
YanisLalou Feb 7, 2024
b70a925
flake8
YanisLalou Feb 7, 2024
cf892fa
Loop optimisation + Remove redundant raise errors
YanisLalou Feb 7, 2024
c6a5378
Merge branch 'main' into source_target_merge_branch
kachayev Feb 8, 2024
bcbf177
Add test case
YanisLalou Feb 9, 2024
542eaea
conflict merge
YanisLalou Feb 9, 2024
536bcb2
flake8
YanisLalou Feb 9, 2024
0eb2113
Merge branch 'main' into source_target_merge_branch
rflamary Feb 12, 2024
dd711d1
Merge branch 'main' into source_target_merge_branch
rflamary Feb 16, 2024
e7403d4
Merge branch 'main' into source_target_merge_branch
kachayev Feb 17, 2024
6010c81
Change "duo" for "pair" + simplify small things
YanisLalou Feb 20, 2024
4445f07
Change source_target_merge return type
YanisLalou Feb 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions skada/_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
from .utils import (
check_X_domain,
check_X_y_domain,
source_target_split
source_target_split,
source_target_merge
)
from ._utils import (
_estimate_covariance,
_merge_source_target,
_estimate_covariance
)

from ._pipeline import make_da_pipeline
Expand Down Expand Up @@ -90,7 +90,9 @@ def adapt(self, X, y=None, sample_domain=None):
# thus there's no need to perform any transformations
if X_source.shape[0] > 0:
X_source = self.ot_transport_.transform(Xs=X_source)
X_adapt = _merge_source_target(X_source, X_target, sample_domain)
X_adapt, _ = source_target_merge(
X_source, X_target, sample_domain=sample_domain
)
return X_adapt

@abstractmethod
Expand Down Expand Up @@ -648,7 +650,9 @@ def adapt(self, X, y=None, sample_domain=None):

X_source_adapt = np.dot(X_source, self.cov_source_inv_sqrt_)
X_source_adapt = np.dot(X_source_adapt, self.cov_target_sqrt_)
X_adapt = _merge_source_target(X_source_adapt, X_target, sample_domain)
X_adapt, _ = source_target_merge(
X_source_adapt, X_target, sample_domain=sample_domain
)
return X_adapt


Expand Down
6 changes: 4 additions & 2 deletions skada/_subspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from .base import BaseAdapter
from .utils import check_X_domain, source_target_split
from ._utils import _merge_source_target
from .utils import source_target_merge
from ._pipeline import make_da_pipeline


Expand Down Expand Up @@ -93,7 +93,9 @@ def adapt(self, X, y=None, sample_domain=None, **kwargs):
if X_target.shape[0]:
X_target = np.dot(self.pca_target_.transform(X_target), self.M_)
# xxx(okachaiev): this could be done through a more high-level API
X_adapt = _merge_source_target(X_source, X_target, sample_domain)
X_adapt, _ = source_target_merge(
X_source, X_target, sample_domain=sample_domain
)
return X_adapt

def fit(self, X, y=None, sample_domain=None, **kwargs):
Expand Down
12 changes: 0 additions & 12 deletions skada/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,6 @@ def _estimate_covariance(X, shrinkage):
return s


def _merge_source_target(X_source, X_target, sample_domain) -> np.ndarray:
    """Merge source and target samples back into a single array.

    Rows of ``X_source`` are written at the positions where
    ``sample_domain >= 0`` and rows of ``X_target`` at the positions where
    ``sample_domain < 0``, each group keeping its own order.

    Parameters
    ----------
    X_source : array-like of shape (n_source, ...)
        Source samples. May be empty if ``X_target`` is not.
    X_target : array-like of shape (n_target, ...)
        Target samples. May be empty if ``X_source`` is not.
    sample_domain : array-like of shape (n_source + n_target,)
        Domain marker of every output row: non-negative for source,
        negative for target.

    Returns
    -------
    np.ndarray of shape (n_source + n_target, ...)
        The merged array. Its dtype is the promotion of both input dtypes
        when both sides are non-empty, otherwise the dtype of the
        non-empty side.

    Raises
    ------
    ValueError
        If both inputs are empty, if the number of rows does not agree with
        ``sample_domain``, or if the two non-empty inputs have different
        trailing dimensions.
    """
    X_source = np.asarray(X_source)
    X_target = np.asarray(X_target)
    sample_domain = np.asarray(sample_domain)

    n_source, n_target = X_source.shape[0], X_target.shape[0]
    n_samples = n_source + n_target
    # Explicit raise rather than `assert`: asserts are stripped under `-O`.
    if n_samples == 0:
        raise ValueError(
            "At least one of X_source and X_target must be non-empty"
        )

    source_mask = sample_domain >= 0
    target_mask = sample_domain < 0
    # Without this check a single-row input would be silently broadcast
    # over every matching position of the mask.
    if (
        source_mask.shape != (n_samples,)
        or np.count_nonzero(source_mask) != n_source
        or np.count_nonzero(target_mask) != n_target
    ):
        raise ValueError(
            "Inconsistent number of samples in source-target arrays "
            "and the number inferred from sample_domain"
        )

    if n_source > 0 and n_target > 0:
        if X_source.shape[1:] != X_target.shape[1:]:
            raise ValueError(
                "Inconsistent number of features in source-target arrays"
            )
        # Promote so that e.g. float targets are not truncated to an
        # integer source dtype.
        dtype = np.result_type(X_source.dtype, X_target.dtype)
        trailing_shape = X_source.shape[1:]
    elif n_source > 0:
        dtype, trailing_shape = X_source.dtype, X_source.shape[1:]
    else:
        dtype, trailing_shape = X_target.dtype, X_target.shape[1:]

    output = np.zeros((n_samples,) + trailing_shape, dtype=dtype)
    # Skip the empty side: a placeholder such as `np.array([])` has no
    # trailing dimensions and cannot be broadcast into `output`.
    if n_source > 0:
        output[source_mask] = X_source
    if n_target > 0:
        output[target_mask] = X_target
    return output


def _check_y_masking(y):
"""Check that labels are properly masked
ie. labels are either -1 or >= 0
Expand Down
134 changes: 131 additions & 3 deletions skada/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,16 @@

import numpy as np

from skada.datasets import make_dataset_from_moons_distribution
from skada.datasets import (
make_dataset_from_moons_distribution
)
from skada.utils import (
check_X_domain,
check_X_y_domain,
check_X_domain,
extract_source_indices,
source_target_split
source_target_split,
source_target_merge

)
from skada._utils import _check_y_masking

Expand Down Expand Up @@ -290,3 +294,127 @@ def test_extract_source_indices():
assert len(source_idx) == (len(sample_domain)), "source_idx shape mismatch"
assert np.sum(source_idx) == 2 * n_samples_source, "source_idx sum mismatch"
assert np.sum(~source_idx) == 2 * n_samples_target, "target_idx sum mismatch"


def test_source_target_merge():
    """Exercise ``source_target_merge`` on accepted and rejected inputs."""
    n_source, n_target = 50, 20
    X, y, sample_domain = make_dataset_from_moons_distribution(
        pos_source=0.1,
        pos_target=0.9,
        n_samples_source=n_source,
        n_samples_target=n_target,
        random_state=0,
        return_X_y=True,
    )
    X_source, X_target, y_source, y_target = source_target_split(
        X, y, sample_domain=sample_domain
    )

    # Messages the rejected calls below are expected to match.
    empty_pair_msg = "Only one array can be None or empty in each pair"
    sample_count_msg = (
        "Inconsistent number of samples in source-target arrays "
        "and the number infered in the sample_domain"
    )

    # A pair of 2D arrays merges and keeps every row.
    merged_X, _ = source_target_merge(X_source, X_target, sample_domain=sample_domain)
    expected_rows = X_source.shape[0] + X_target.shape[0]
    assert merged_X.shape[0] == expected_rows, "target_samples shape mismatch"

    # A pair of 1D arrays merges and keeps every entry.
    merged_y, _ = source_target_merge(y_source, y_target, sample_domain=sample_domain)
    expected_labels = y_source.shape[0] + y_target.shape[0]
    assert merged_y.shape[0] == expected_labels, "labels shape mismatch"

    # Both arrays of a pair empty: rejected.
    with pytest.raises(ValueError, match=empty_pair_msg):
        source_target_merge(np.array([]), np.array([]), sample_domain=np.array([]))

    # Exactly one empty array in the pair: accepted, on either side.
    only_source_domain = np.array([1] * X_source.shape[0])
    source_target_merge(X_source, np.array([]), sample_domain=only_source_domain)
    only_target_domain = np.array([-1] * X_target.shape[0])
    source_target_merge(np.array([]), X_target, sample_domain=only_target_domain)

    # Single rows do not match the sample counts implied by sample_domain.
    with pytest.raises(ValueError, match=sample_count_msg):
        source_target_merge(X_source[0], X_target[1], sample_domain=sample_domain)

    # sample_domain may be omitted.
    source_target_merge(X_source, X_target, sample_domain=None)

    # An odd number of arrays is rejected.
    with pytest.raises(ValueError, match="Even number of arrays required as input"):
        source_target_merge(
            X_source, X_target, y_source, sample_domain=sample_domain
        )

    # More than one pair is accepted.
    source_target_merge(
        X_source, X_target, y_source, y_target, sample_domain=sample_domain
    )

    # A single array is rejected.
    with pytest.raises(ValueError, match="At least two array required as input"):
        source_target_merge(X_source, sample_domain=sample_domain)

    # y_target=None with a y_source whose length disagrees with sample_domain.
    with pytest.raises(ValueError, match=sample_count_msg):
        source_target_merge(
            X_source,
            X_target,
            np.ones_like(sample_domain),
            None,
            sample_domain=sample_domain,
        )

    # y_target=None with a y_source of consistent length.
    source_target_merge(
        X_source, X_target, y_source, None, sample_domain=sample_domain
    )

    # Both members of a pair set to None: rejected.
    with pytest.raises(ValueError, match=empty_pair_msg):
        source_target_merge(None, None, sample_domain=sample_domain)

    # One None in each of two pairs: accepted.
    source_target_merge(X_source, None, y_source, None, sample_domain=sample_domain)

    # Source and target with a different number of features: rejected.
    with pytest.raises(
        ValueError, match="Inconsistent number of features in source-target arrays"
    ):
        source_target_merge(X_source[:, :-1], X_target, sample_domain=sample_domain)
Loading
Loading