From f0802453b212fe81ff4ba830219723602278f2b7 Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Mon, 18 Dec 2023 12:02:51 +0100 Subject: [PATCH 1/6] Fix label masking for the dataset pack function --- skada/datasets/_base.py | 17 ++++++++++------- skada/datasets/tests/test_base.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 7 deletions(-) create mode 100644 skada/datasets/tests/test_base.py diff --git a/skada/datasets/_base.py b/skada/datasets/_base.py index 57cb7f33..4524de00 100644 --- a/skada/datasets/_base.py +++ b/skada/datasets/_base.py @@ -191,11 +191,14 @@ def pack( sample_domains.append(np.ones_like(y)*domain_id) domain_labels[domain_name] = domain_id # xxx(okachaiev): code duplication, re-write when API is fixed + dtype = None for domain_name in as_targets: domain_id = self.domain_names_[domain_name] target = self.get_domain(domain_name) if len(target) == 1: X, = target + # xxx(okachaiev): for what it's worth, we should likely to + # move the decision about dtype to the very end of the list y = -np.ones(X.shape[0], dtype=np.int32) elif len(target) == 2: X, y = target @@ -203,15 +206,15 @@ def pack( raise ValueError("Invalid definition for domain data") if train: if mask is not None: - y = np.array([mask] * X.shape[0]) - elif y.dtype == np.int32: - y = -np.ones(X.shape[0], dtype=np.int32) + y = np.array([mask] * X.shape[0], dtype=dtype) + elif y.dtype in (np.int32, np.int64): + y = -np.ones(X.shape[0], dtype=y.dtype) # make sure that the mask is reused on the next iteration - mask = -1 - elif y.dtype == np.float32: - y = np.array([np.nan] * X.shape[0]) + mask, dtype = -1, y.dtype + elif y.dtype in (np.float32, np.float64): + y = np.array([np.nan] * X.shape[0], dtype=y.dtype) # make sure that the mask is reused on the next iteration - mask = np.nan + mask, dtype = np.nan, y.dtype # xxx(okachaiev): this is horribly inefficient, rewrite when API is fixed Xs.append(X) ys.append(y) diff --git a/skada/datasets/tests/test_base.py 
b/skada/datasets/tests/test_base.py new file mode 100644 index 00000000..065f1d47 --- /dev/null +++ b/skada/datasets/tests/test_base.py @@ -0,0 +1,30 @@ +# Author: Oleksii Kachaiev +# +# License: BSD 3-Clause + +import numpy as np +from numpy.testing import assert_array_equal + +from skada.datasets import DomainAwareDataset + + +def test_dataset_train_label_masking(): + dataset = DomainAwareDataset() + dataset.add_domain(np.array([1., 2.]), np.array([1, 2]), 's1') + dataset.add_domain(np.array([10., 20., 30.]), np.array([10, 20, 30]), 't1') + X, y, sample_domain = dataset.pack_for_train(as_sources=['s1'], as_targets=['t1']) + + # test shape of the output + assert X.shape == (5,) + assert y.shape == (5,) + assert sample_domain.shape == (5,) + assert (sample_domain > 0).sum() == 2 + assert (sample_domain < 0).sum() == 3 + + # test label masking + assert_array_equal(y[sample_domain < 0], np.array([-1, -1, -1])) + assert np.all(y[sample_domain > 0]) + + # custom mask + X, y, sample_domain = dataset.pack_for_train(as_sources=['s1'], as_targets=['t1'], mask=-10) + assert_array_equal(y[sample_domain < 0], np.array([-10, -10, -10])) \ No newline at end of file From ed21873417977f2b9332e26217ef87a4bc2223df Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Mon, 18 Dec 2023 22:54:07 +0100 Subject: [PATCH 2/6] Formatting --- skada/datasets/tests/test_base.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/skada/datasets/tests/test_base.py b/skada/datasets/tests/test_base.py index 065f1d47..a8c810a2 100644 --- a/skada/datasets/tests/test_base.py +++ b/skada/datasets/tests/test_base.py @@ -24,7 +24,8 @@ def test_dataset_train_label_masking(): # test label masking assert_array_equal(y[sample_domain < 0], np.array([-1, -1, -1])) assert np.all(y[sample_domain > 0]) - + # custom mask - X, y, sample_domain = dataset.pack_for_train(as_sources=['s1'], as_targets=['t1'], mask=-10) - assert_array_equal(y[sample_domain < 0], np.array([-10, -10, -10])) \ No 
newline at end of file + X, y, sample_domain = dataset.pack_for_train( + as_sources=['s1'], as_targets=['t1'], mask=-10) + assert_array_equal(y[sample_domain < 0], np.array([-10, -10, -10])) From 76710a64df34c108c2739a9906107363145e795b Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Mon, 18 Dec 2023 23:21:23 +0100 Subject: [PATCH 3/6] Remove masked samples before fitting --- skada/_pipeline.py | 1 + skada/base.py | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/skada/_pipeline.py b/skada/_pipeline.py index c318c7f1..82af25b4 100644 --- a/skada/_pipeline.py +++ b/skada/_pipeline.py @@ -80,6 +80,7 @@ def make_da_pipeline( # note that we generate names before wrapping estimators into the selector # xxx(okachaiev): unwrap from the selector when passed explicitly steps = _wrap_with_selectors(_name_estimators(steps), default_selector) + steps[-1][1]._mark_as_final() return Pipeline(steps, memory=memory, verbose=verbose) diff --git a/skada/base.py b/skada/base.py index f443f6ba..91b926ad 100644 --- a/skada/base.py +++ b/skada/base.py @@ -111,6 +111,7 @@ def __init__(self, base_estimator: BaseEstimator, **kwargs): super().__init__() self.base_estimator = base_estimator self.base_estimator.set_params(**kwargs) + self._is_final = False # xxx(okachaiev): should this be a metadata routing object instead of request? 
def get_metadata_routing(self): @@ -206,6 +207,23 @@ def decision_function(self, X, **params): def score(self, X, y, **params): return self._route_to_estimator('score', X, y=y, **params) + def _mark_as_final(self) -> 'BaseSelector': + self._is_final = True + return self + + def _remove_masked(self, X, y, routed_params): + if not self._is_final: + return X, y, routed_params + # in case the estimator is marked as final in the pipeline, + # the selector is responsible for removing masked labels + # from the targets + # xxx(okachaiev): make sure to recognize all labels + idx = (y != -1).copy() + X = X[idx] + y = y[idx] + routed_params = {k: v[idx] for k, v in routed_params.items()} + return X, y, routed_params + class Shared(BaseSelector): @@ -228,6 +246,7 @@ def fit(self, X, y, **params): if k != 'X' and k in routed_params: routed_params[k] = v X = X['X'] + X, y, routed_params = self._remove_masked(X, y, routed_params) estimator = clone(self.base_estimator) estimator.fit(X, y, **routed_params) self.base_estimator_ = estimator @@ -283,6 +302,7 @@ def fit(self, X, y, **params): if k != 'X' and k in routed_params: routed_params[k] = v X = X['X'] + X, y, routed_params = self._remove_masked(X, y, routed_params) estimators = {} # xxx(okachaiev): maybe return_index? 
for domain_label in np.unique(sample_domain): From 2eb2750717a45ca9105911c39c1512334b69b1f7 Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Tue, 19 Dec 2023 09:45:51 +0100 Subject: [PATCH 4/6] Extend test to include pack_test as well --- skada/datasets/tests/test_base.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/skada/datasets/tests/test_base.py b/skada/datasets/tests/test_base.py index a8c810a2..385fae0e 100644 --- a/skada/datasets/tests/test_base.py +++ b/skada/datasets/tests/test_base.py @@ -23,9 +23,16 @@ def test_dataset_train_label_masking(): # test label masking assert_array_equal(y[sample_domain < 0], np.array([-1, -1, -1])) - assert np.all(y[sample_domain > 0]) + assert np.all(y[sample_domain > 0] > 0) # custom mask X, y, sample_domain = dataset.pack_for_train( as_sources=['s1'], as_targets=['t1'], mask=-10) assert_array_equal(y[sample_domain < 0], np.array([-10, -10, -10])) + + # test packing does not perform masking + X, y, sample_domain = dataset.pack_for_test(as_targets=['t1']) + assert X.shape == (3,) + assert y.shape == (3,) + assert sample_domain.shape == (3,) + assert np.all(y > 0) From 9a076ada3c9fe0a6aa31e0371619c15f37331c37 Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Tue, 19 Dec 2023 12:15:58 +0100 Subject: [PATCH 5/6] Fix CORAL --- skada/_mapping.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/skada/_mapping.py b/skada/_mapping.py index a4bf3241..046b78f3 100644 --- a/skada/_mapping.py +++ b/skada/_mapping.py @@ -638,12 +638,16 @@ def adapt(self, X, y=None, sample_domain=None): weights : None No weights are returned here. """ - X_adapt = np.dot(X, self.cov_source_inv_sqrt_) - X_adapt = np.dot(X_adapt, self.cov_target_sqrt_) - # xxx(okachaiev): i feel this is straight up incorrect, - # in the previous version of the code only source data - # was transformed, and the target was never updated. 
i - # guess it should just 'passthrough' for a target space + X_source, X_target = check_X_domain( + X, + sample_domain, + allow_multi_source=True, + allow_multi_target=True, + return_joint=False, + ) + X_source_adapt = np.dot(X_source, self.cov_source_inv_sqrt_) + X_source_adapt = np.dot(X_source_adapt, self.cov_target_sqrt_) + X_adapt = _merge_source_target(X_source_adapt, X_target, sample_domain) return X_adapt From becddd00129618a62ee5e892a2039779951a7625 Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Tue, 19 Dec 2023 12:36:32 +0100 Subject: [PATCH 6/6] Make sure we can deal with different masks for regression vs. classification --- skada/base.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/skada/base.py b/skada/base.py index 91b926ad..11229536 100644 --- a/skada/base.py +++ b/skada/base.py @@ -208,20 +208,34 @@ def score(self, X, y, **params): return self._route_to_estimator('score', X, y=y, **params) def _mark_as_final(self) -> 'BaseSelector': + """Internal API for keeping track of which estimator is final + in the Pipeline. + """ self._is_final = True return self def _remove_masked(self, X, y, routed_params): + """Internal API for removing masked samples before passing them + to the final estimator. Only applicable for the final estimator + within the Pipeline. 
+ """ if not self._is_final: return X, y, routed_params # in case the estimator is marked as final in the pipeline, # the selector is responsible for removing masked labels # from the targets - # xxx(okachaiev): make sure to recognize all labels - idx = (y != -1).copy() - X = X[idx] - y = y[idx] - routed_params = {k: v[idx] for k, v in routed_params.items()} + if y.dtype in (np.float32, np.float64): + unmasked_idx = np.isfinite(y) + else: + unmasked_idx = (y != -1) + X = X[unmasked_idx] + y = y[unmasked_idx] + routed_params = { + # this is somewhat crude way to test if `v` is indexable + k: v[unmasked_idx] if hasattr(v, "__len__") else v + for k, v + in routed_params.items() + } return X, y, routed_params