diff --git a/skada/base.py b/skada/base.py index b4a7dd99..869200bf 100644 --- a/skada/base.py +++ b/skada/base.py @@ -239,7 +239,7 @@ def _remove_masked(self, X, y, routed_params): if y_type == 'classification': unmasked_idx = (y != _DEFAULT_MASKED_TARGET_CLASSIFICATION_LABEL) elif y_type == 'continuous': - unmasked_idx = ~np.isfinite(y) + unmasked_idx = np.isfinite(y) X = X[unmasked_idx] y = y[unmasked_idx] diff --git a/skada/tests/test_selector.py b/skada/tests/test_selector.py index f1a4caf4..ddf36574 100644 --- a/skada/tests/test_selector.py +++ b/skada/tests/test_selector.py @@ -108,8 +108,11 @@ def test_base_selector_remove_masked_continuous(): source_idx = rng.choice([False, True], size=n_samples) # mask target labels y[~source_idx] = _DEFAULT_MASKED_TARGET_REGRESSION_LABEL + assert np.any(~np.isfinite(y)), 'at least one label is masked' + X_output, y_output, _ = selector._remove_masked(X, y, {}) + assert np.all(np.isfinite(y_output)), 'masks are removed' - n_target_samples = X.shape[0] - np.sum(source_idx) - assert X_output.shape[0] == n_target_samples, "X output shape mismatch" + n_source_samples = np.sum(source_idx) + assert X_output.shape[0] == n_source_samples, 'X output shape mismatch' assert X_output.shape[0] == y_output.shape[0]