add PassthroughScoring #595
Conversation
A scorer that just passes through a score that is already calculated and logged during the batches (e.g., train_loss). Replaces the use of BatchScoring for most use cases. This makes it a lot easier to add a custom loss during a batch: just log it in the history at each batch and add a PassthroughScoring with the name of the history key of that loss. No need to add a fake scoring function that just returns the value.
Also fixed minor bugs:
- missing __all__ declaration in __init__
- forgot to return self in initialize
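For illustration, here is a minimal sketch of the intended usage; MyNet, MyModule, and the l1_loss key are made up for this example (not part of the PR), and the penalty only serves to show the pattern of recording a batch-level value and averaging it with PassthroughScoring:

```python
import torch
from torch import nn
from skorch import NeuralNetClassifier
from skorch.callbacks import PassthroughScoring

class MyModule(nn.Module):
    """Minimal stand-in classifier module."""
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(20, 2)

    def forward(self, X):
        return torch.softmax(self.dense(X), dim=-1)

class MyNet(NeuralNetClassifier):
    def get_loss(self, y_pred, y_true, X=None, training=False):
        loss = super().get_loss(y_pred, y_true, X=X, training=training)
        if not training:
            return loss
        # hypothetical extra penalty; record it in the history for this batch
        l1 = sum(p.abs().sum() for p in self.module_.parameters())
        self.history.record_batch('l1_loss', l1.item())
        return loss + 1e-4 * l1

net = MyNet(
    MyModule,
    # averages the batch-level 'l1_loss' entries and logs the result
    # on the epoch level under the same name
    callbacks=[PassthroughScoring(name='l1_loss', on_train=True)],
)
```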
LGTM
docs/user/callbacks.rst
Outdated
On top of the two described scoring callbacks, skorch also provides
:class:`.PassthroughScoring`. This callback does not actually calculate
any new scores. Instead it uses an existing score that is calculated
for each batch and determines the average of this score, which is
Suggested change:
- for each batch and determines the average of this score, which is
+ for each batch (the train loss, for example) and determines the average of this score, which is
Maybe this makes it easier to see how and when this callback is useful?
fixed
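To make the averaging from the docs snippet above concrete, here is a small standalone sketch; the numbers are invented, and the weighting by batch size is an assumption for illustration:

```python
# Batch-level history entries as PassthroughScoring would find them
# (values invented for illustration).
batches = [
    {'train_loss': 0.90, 'train_batch_size': 128},
    {'train_loss': 0.70, 'train_batch_size': 128},
    {'train_loss': 0.50, 'train_batch_size': 64},
]
weights = [b['train_batch_size'] for b in batches]
scores = [b['train_loss'] for b in batches]

# Weighted mean over the batches; this is what ends up on the epoch level.
epoch_score = sum(w * s for w, s in zip(weights, scores)) / sum(weights)
print(round(epoch_score, 3))  # 0.74
```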
skorch/tests/test_net.py
Outdated
```python
@pytest.mark.xfail
def test_net_initialized_with_initalized_dataset(
        self, net_cls, module_cls, data, dataset_cls):
    # TODO: What to do of this test now??
```
Do we need an issue to track this?
tbh I don't remember why I put this here, but the test xfails with the expected error, so I removed the comment and the skip.
Addresses reviewer comment
…rch-dev/skorch into feature/pass-through-scoring
As mentioned here, I factored out the addition of PassthroughScoring from PR #593. Here is a copy of the reasoning:
Currently, when adding new losses on a batch level, it's very fiddly to log them on an epoch level. You have to define a custom scoring function that picks the right value from the history and passes it through without modification (a noop), then add a BatchScoring callback that uses this mock scoring function and defines the name of the score again.
Here is the code before and after:
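As a hedged reconstruction of the idea (the PR's exact snippets may differ): assuming a hypothetical batch-level history key 'my_loss' that is already being recorded each batch, and reusing the MyModule stand-in from the earlier sketch:

```python
from skorch import NeuralNetClassifier
from skorch.callbacks import BatchScoring, PassthroughScoring

# Before: a mock scoring function whose only job is to fetch the value
# that is already logged in the history for the current batch.
def my_loss_score(net, X=None, y=None):
    return net.history[-1, 'batches', -1, 'my_loss']

net_before = NeuralNetClassifier(
    MyModule,  # stand-in module from the earlier sketch
    callbacks=[BatchScoring(my_loss_score, name='my_loss', on_train=True)],
)

# After: no scoring function needed, just name the history key.
net_after = NeuralNetClassifier(
    MyModule,
    callbacks=[PassthroughScoring(name='my_loss', on_train=True)],
)
```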
I replaced BatchScoring with PassthroughScoring in the existing code for the train and valid loss. Even though skorch.utils.train_loss_score and skorch.utils.valid_loss_score would not be needed anymore with this change, I still left them in place. This is for two reasons, the second being that pickled nets may still reference these functions. The latter could be fixed by introducing a special case in __getstate__ and replacing the old BatchScoring with the new PassthroughScoring, which doesn't require these functions, but this seems to be far too much work for very little gain. Instead, I propose to leave those functions for now and maybe deprecate them at some future time.

I'm not quite certain what to do with test_net_initialized_with_initalized_dataset. I'll think of something, but if anyone has an idea, I'll be glad to hear it.
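For concreteness, a rough sketch of the __getstate__ special case described above, under the assumption that callbacks_ holds (name, callback) pairs; this is the approach the comment rejects as not worth the effort, not code from the PR:

```python
from skorch import NeuralNet
from skorch.callbacks import BatchScoring, PassthroughScoring

class PicklePatchedNet(NeuralNet):  # hypothetical, not part of the PR
    def __getstate__(self):
        state = super().__getstate__()
        # Swap the old-style loss callbacks for PassthroughScoring before
        # pickling, so the pickle no longer references train_loss_score
        # or valid_loss_score.
        callbacks = []
        for name, cb in state.get('callbacks_', []):
            if name in ('train_loss', 'valid_loss') and isinstance(cb, BatchScoring):
                cb = PassthroughScoring(name=name, on_train=(name == 'train_loss'))
            callbacks.append((name, cb))
        state['callbacks_'] = callbacks
        return state
```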