-
Notifications
You must be signed in to change notification settings - Fork 21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG] Make sure all API methods accept sample_domain as None #53
Conversation
…ing of y is well done in the check_X_y_domain function
Codecov ReportAttention:
Additional details and impacted files@@ Coverage Diff @@
## main #53 +/- ##
==========================================
+ Coverage 84.43% 86.16% +1.72%
==========================================
Files 35 37 +2
Lines 2191 2334 +143
==========================================
+ Hits 1850 2011 +161
+ Misses 341 323 -18 ☔ View full report in Codecov by Sentry. |
skada/tests/test_utils.py
Outdated
|
||
def test_check_y_masking_regression(): | ||
y_properly_masked = np.array([np.nan, 1, 2.5, -1, np.nan, 0, -1.5]) | ||
y_wrongfuly_masked = np.array([-1, -2, 2.5, -1, 2, 0, 1]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess the type of this array is 'float', right? In this case we should assume it's a regression task with no labels being masked 🤔
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't get it, isn't it possible to have masked arrays for regression tasks?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's possible to have masks for regression using nan
s. The variable that's called y_wrongfuly_masked
doesn't have a 'wrongly' masked array, it does have 'non masked' array. That's why I was confused about the name.
skada/_utils.py
Outdated
@@ -59,11 +59,18 @@ def check_X_y_domain( | |||
X = check_array(X, input_name='X', allow_nd=allow_nd) | |||
y = check_array(y, force_all_finite=True, ensure_2d=False, input_name='y') | |||
check_consistent_length(X, y) | |||
if sample_domain is None and allow_auto_sample_domain: | |||
if sample_domain is None and not allow_auto_sample_domain: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We need to make sure that all of those clauses are properly covered with unit tests. I guess just updating check_X_y_domain_exceptions
with a bunch of correct and incorrect inputs would do.
skada/_utils.py
Outdated
sample_domain = np.ones_like(y) | ||
# labels masked with -1 are recognized as targets, | ||
# the rest is treated as a source | ||
sample_domain[y == -1] = -2 | ||
if y_type == 'classification': | ||
sample_domain[y == -1] = -2 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I always try to avoid using 'magic' numbers in the code, as they can be extremely difficult to understand and modify later on. Let's create a constant named _DEFAULT_TARGET_DOMAIN_LABEL
= -2 (or something like this) in the current module namespace.
skada/_utils.py
Outdated
if y_type == 'classification': | ||
sample_domain[y == -1] = -2 | ||
else: | ||
sample_domain[np.isnan(y)] = -2 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same constant here.
skada/_utils.py
Outdated
@@ -114,13 +121,17 @@ def check_X_domain( | |||
return_indices: bool = False, | |||
# xxx(okachaiev): most likely this needs to be removed as it doesn't fit new API | |||
return_joint: bool = True, | |||
allow_auto_sample_domain: bool = False, | |||
allow_auto_sample_domain: bool = True, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's put a docstring for this function.
@YanisLalou I can't seem to push the update from my local branch here (not sure why, github previously allowed me to do so). Would you please check |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A few comments maybe we should discuss that in the chat
skada/_utils.py
Outdated
@@ -40,30 +51,87 @@ def _estimate_covariance(X, shrinkage): | |||
def check_X_y_domain( | |||
X, | |||
y, | |||
sample_domain, | |||
sample_domain=None, | |||
allow_source: bool = True, | |||
allow_multi_source: bool = True, | |||
allow_target: bool = True, | |||
allow_multi_target: bool = True, | |||
return_indices: bool = False, | |||
# xxx(okachaiev): most likely this needs to be removed as it doesn't fit new API |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if we shoudl remove something aybe do it while we are working on ythis function?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I bet there's a separate task for this)
⬆️ Issue #58 |
…it_source_target_X()
…rays + output the data the same way as sklearn
Issue #17
allow_auto_sample_domain
incheck_X_y_domain
check_X_y_domain