Fix label masking for training dataset #40

Merged · 7 commits · Dec 20, 2023
16 changes: 10 additions & 6 deletions skada/_mapping.py
@@ -638,12 +638,16 @@ def adapt(self, X, y=None, sample_domain=None):
         weights : None
             No weights are returned here.
         """
-        X_adapt = np.dot(X, self.cov_source_inv_sqrt_)
-        X_adapt = np.dot(X_adapt, self.cov_target_sqrt_)
-        # xxx(okachaiev): i feel this is straight up incorrect,
-        # in the previous version of the code only source data
-        # was transformed, and the target was never updated. i
-        # guess it should just 'passthrough' for a target space
+        X_source, X_target = check_X_domain(
+            X,
+            sample_domain,
+            allow_multi_source=True,
+            allow_multi_target=True,
+            return_joint=False,
+        )
+        X_source_adapt = np.dot(X_source, self.cov_source_inv_sqrt_)
+        X_source_adapt = np.dot(X_source_adapt, self.cov_target_sqrt_)
+        X_adapt = _merge_source_target(X_source_adapt, X_target, sample_domain)
         return X_adapt


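The change above splits the input by domain so that only source samples receive the correlation-alignment transform, while target samples pass through unchanged and are merged back into their original positions. A minimal standalone sketch of the same idea in plain NumPy (the helper name is hypothetical; the sign convention for sample_domain, positive for source and negative for target, follows the tests later in this PR; the covariance matrices are assumed precomputed):

import numpy as np

def coral_like_adapt(X, sample_domain, cov_source_inv_sqrt, cov_target_sqrt):
    # positive sample_domain ids mark source samples, negative ones mark target
    X_adapt = np.array(X, dtype=float)
    source = sample_domain > 0
    # whiten the source features, then re-color them with the target covariance
    X_adapt[source] = X_adapt[source] @ cov_source_inv_sqrt @ cov_target_sqrt
    # target samples are passed through unchanged
    return X_adapt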
1 change: 1 addition & 0 deletions skada/_pipeline.py
@@ -80,6 +80,7 @@ def make_da_pipeline(
     # note that we generate names before wrapping estimators into the selector
     # xxx(okachaiev): unwrap from the selector when passed explicitly
     steps = _wrap_with_selectors(_name_estimators(steps), default_selector)
+    steps[-1][1]._mark_as_final()
     return Pipeline(steps, memory=memory, verbose=verbose)


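make_da_pipeline now tags the last selector in the pipeline as final, which is what later lets that selector strip masked labels before fitting its wrapped estimator (see skada/base.py below). A hedged usage sketch of the intended end-to-end behavior, using a scikit-learn classifier; the import paths are assumptions, and whether this exact single-step pipeline and fit call run as written depends on parts of skada not shown in this diff:

import numpy as np
from sklearn.linear_model import LogisticRegression
from skada import make_da_pipeline
from skada.datasets import DomainAwareDataset

dataset = DomainAwareDataset()
dataset.add_domain(np.random.randn(20, 2), np.random.randint(1, 3, 20), 's1')  # labeled source
dataset.add_domain(np.random.randn(30, 2), np.random.randint(1, 3, 30), 't1')  # target, labels get masked
X, y, sample_domain = dataset.pack_for_train(as_sources=['s1'], as_targets=['t1'])

# the last step is marked as final, so its selector drops the samples whose
# labels were masked with -1 before calling LogisticRegression.fit
pipe = make_da_pipeline(LogisticRegression())
pipe.fit(X, y, sample_domain=sample_domain)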
34 changes: 34 additions & 0 deletions skada/base.py
@@ -111,6 +111,7 @@ def __init__(self, base_estimator: BaseEstimator, **kwargs):
         super().__init__()
         self.base_estimator = base_estimator
         self.base_estimator.set_params(**kwargs)
+        self._is_final = False

     # xxx(okachaiev): should this be a metadata routing object instead of request?
     def get_metadata_routing(self):

@@ -206,6 +207,37 @@ def decision_function(self, X, **params):
     def score(self, X, y, **params):
         return self._route_to_estimator('score', X, y=y, **params)

+    def _mark_as_final(self) -> 'BaseSelector':
+        """Internal API for keeping track of which estimator is final
+        in the Pipeline.
+        """
+        self._is_final = True
+        return self
+
+    def _remove_masked(self, X, y, routed_params):
+        """Internal API for removing masked samples before passing them
+        to the final estimator. Only applicable for the final estimator
+        within the Pipeline.
+        """
+        if not self._is_final:
+            return X, y, routed_params
+        # in case the estimator is marked as final in the pipeline,
+        # the selector is responsible for removing masked labels
+        # from the targets
+        if y.dtype in (np.float32, np.float64):
+            unmasked_idx = np.isfinite(y)
+        else:
+            unmasked_idx = (y != -1)
+        X = X[unmasked_idx]
+        y = y[unmasked_idx]
+        routed_params = {
+            # this is a somewhat crude way to test whether `v` is indexable
+            k: v[unmasked_idx] if hasattr(v, "__len__") else v
+            for k, v
+            in routed_params.items()
+        }
+        return X, y, routed_params


class Shared(BaseSelector):

@@ -228,6 +260,7 @@ def fit(self, X, y, **params):
             if k != 'X' and k in routed_params:
                 routed_params[k] = v
         X = X['X']
+        X, y, routed_params = self._remove_masked(X, y, routed_params)
         estimator = clone(self.base_estimator)
         estimator.fit(X, y, **routed_params)
         self.base_estimator_ = estimator
@@ -283,6 +316,7 @@ def fit(self, X, y, **params):
             if k != 'X' and k in routed_params:
                 routed_params[k] = v
         X = X['X']
+        X, y, routed_params = self._remove_masked(X, y, routed_params)
         estimators = {}
         # xxx(okachaiev): maybe return_index?
         for domain_label in np.unique(sample_domain):
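Both selectors now call _remove_masked right before fitting, so the final estimator only ever sees samples with real labels: integer label vectors mark unlabeled target samples with -1, float label vectors with NaN. A small self-contained sketch of that filtering convention in plain NumPy (function name and data are illustrative, not part of skada):

import numpy as np

def remove_masked(X, y):
    # NaN marks masked labels for float targets, -1 for integer targets
    keep = np.isfinite(y) if y.dtype.kind == 'f' else (y != -1)
    return X[keep], y[keep]

X = np.array([[0.0], [1.0], [2.0]])
y = np.array([1, -1, 2])              # the middle sample is a masked target
X_kept, y_kept = remove_masked(X, y)  # keeps only the two labeled samples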
17 changes: 10 additions & 7 deletions skada/datasets/_base.py
@@ -191,27 +191,30 @@ def pack(
             sample_domains.append(np.ones_like(y)*domain_id)
             domain_labels[domain_name] = domain_id
         # xxx(okachaiev): code duplication, re-write when API is fixed
+        dtype = None
         for domain_name in as_targets:
             domain_id = self.domain_names_[domain_name]
             target = self.get_domain(domain_name)
             if len(target) == 1:
                 X, = target
+                # xxx(okachaiev): for what it's worth, we should likely
+                # move the decision about dtype to the very end of the list
                 y = -np.ones(X.shape[0], dtype=np.int32)
             elif len(target) == 2:
                 X, y = target
             else:
                 raise ValueError("Invalid definition for domain data")
             if train:
                 if mask is not None:
-                    y = np.array([mask] * X.shape[0])
-                elif y.dtype == np.int32:
-                    y = -np.ones(X.shape[0], dtype=np.int32)
+                    y = np.array([mask] * X.shape[0], dtype=dtype)
+                elif y.dtype in (np.int32, np.int64):
+                    y = -np.ones(X.shape[0], dtype=y.dtype)
                     # make sure that the mask is reused on the next iteration
-                    mask = -1
-                elif y.dtype == np.float32:
-                    y = np.array([np.nan] * X.shape[0])
+                    mask, dtype = -1, y.dtype
+                elif y.dtype in (np.float32, np.float64):
+                    y = np.array([np.nan] * X.shape[0], dtype=y.dtype)
                     # make sure that the mask is reused on the next iteration
-                    mask = np.nan
+                    mask, dtype = np.nan, y.dtype
             # xxx(okachaiev): this is horribly inefficient, rewrite when API is fixed
             Xs.append(X)
             ys.append(y)
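pack() now remembers both the mask value and its dtype after the first target domain is processed, and it recognizes int64/float64 labels in addition to int32/float32, so every target domain in a training pack is masked consistently: -1 for integer labels, NaN for float labels. A hedged sketch of the float case (values are illustrative; it assumes add_domain accepts float label arrays exactly like the integer ones used in the test below):

import numpy as np
from skada.datasets import DomainAwareDataset

dataset = DomainAwareDataset()
dataset.add_domain(np.array([1., 2.]), np.array([0.5, 1.5]), 's1')          # float (regression) labels
dataset.add_domain(np.array([10., 20., 30.]), np.array([1., 2., 3.]), 't1')
X, y, sample_domain = dataset.pack_for_train(as_sources=['s1'], as_targets=['t1'])

# target labels are masked with NaN rather than -1 because y is a float array
assert np.isnan(y[sample_domain < 0]).all()
assert np.isfinite(y[sample_domain > 0]).all()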
38 changes: 38 additions & 0 deletions skada/datasets/tests/test_base.py
@@ -0,0 +1,38 @@
# Author: Oleksii Kachaiev <[email protected]>
#
# License: BSD 3-Clause

import numpy as np
from numpy.testing import assert_array_equal

from skada.datasets import DomainAwareDataset


def test_dataset_train_label_masking():
    dataset = DomainAwareDataset()
    dataset.add_domain(np.array([1., 2.]), np.array([1, 2]), 's1')
    dataset.add_domain(np.array([10., 20., 30.]), np.array([10, 20, 30]), 't1')
    X, y, sample_domain = dataset.pack_for_train(as_sources=['s1'], as_targets=['t1'])

    # test shape of the output
    assert X.shape == (5,)
    assert y.shape == (5,)
    assert sample_domain.shape == (5,)
    assert (sample_domain > 0).sum() == 2
    assert (sample_domain < 0).sum() == 3

    # test label masking
    assert_array_equal(y[sample_domain < 0], np.array([-1, -1, -1]))
    assert np.all(y[sample_domain > 0] > 0)

    # custom mask
    X, y, sample_domain = dataset.pack_for_train(
        as_sources=['s1'], as_targets=['t1'], mask=-10)
    assert_array_equal(y[sample_domain < 0], np.array([-10, -10, -10]))

    # test packing does not perform masking
    X, y, sample_domain = dataset.pack_for_test(as_targets=['t1'])
    assert X.shape == (3,)
    assert y.shape == (3,)
    assert sample_domain.shape == (3,)
    assert np.all(y > 0)