add PassthroughScoring #595

Merged: 8 commits merged from feature/pass-through-scoring into master on Apr 10, 2020
Conversation

BenjaminBossan (Collaborator)

As mentioned here, I factored out the addition of PassthroughScoring from PR #593.

Here is a copy of the reasoning:


Currently, when adding new losses on a batch level, it's very fiddly to log them on an epoch level. You have to define a custom function that picks the right value from history and passes it through unmodified, then add a BatchScoring callback that uses this mock scoring function, repeats the name of the score, and passes noop as the target_extractor.

Here is the code before and after:

# somewhere in the net code
    ...
    self.history.record_batch('my_loss', my_loss)
    ...

# before
from skorch.callbacks import BatchScoring
from skorch.utils import noop

def my_loss_score(net, X=None, y=None):
    return net.history[-1, 'batches', -1, 'my_loss']

callbacks = [
    BatchScoring(
        my_loss_score,
        name='my_loss',
        target_extractor=noop,
    ),
]

# after
from skorch.callbacks import PassthroughScoring

callbacks = [
    PassthroughScoring(name='my_loss'),
]
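
For context, here is a minimal end-to-end sketch of how the "after" variant could be wired up. This is illustrative only: MyModule, X, and y are placeholders not taken from this PR, and the get_loss override stands in for the "somewhere in the net code" snippet above.

# illustrative sketch (not part of the PR); MyModule, X, and y are placeholders
from skorch import NeuralNetClassifier
from skorch.callbacks import PassthroughScoring

class MyNet(NeuralNetClassifier):
    def get_loss(self, y_pred, y_true, X=None, training=False):
        loss = super().get_loss(y_pred, y_true, X=X, training=training)
        # record an additional value once per batch, as in the snippet above
        self.history.record_batch('my_loss', loss.item())
        return loss

net = MyNet(
    MyModule,
    max_epochs=5,
    # PassthroughScoring picks up 'my_loss' from the batch records and logs
    # its per-epoch average under the same name
    callbacks=[PassthroughScoring(name='my_loss')],
)
net.fit(X, y)
print(net.history[:, 'my_loss'])  # one aggregated value per epoch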

I replaced BatchScoring with PassthroughScoring in the existing code for train and valid loss. Even though skorch.utils.train_loss_score and skorch.utils.valid_loss_score would not be needed anymore with this change, I still left them in place. This is for 2 reasons:

  • someone might use them directly (after all, they're public functions)
  • if they are removed, old pickled nets cannot be unpickled anymore

The latter could be fixed by introducing a special case in __getstate__ and replacing the old BatchScoring with the new PassthroughScoring, which doesn't require these functions, but this seems to be far too much work for very little gain. Instead, I propose to leave those functions in place for now and maybe deprecate them at some future time.

I'm not quite certain what to do with test_net_initialized_with_initalized_dataset. I'll think of something, but if anyone has an idea, I'll be glad to hear it.

BenjaminBossan added 3 commits February 21, 2020 22:33
A scorer that just passes through a score that is already calculated
and logged during the batches (as, e.g., train_loss). Replaces use of
BatchScoring for most use cases.

This makes it a lot easier to add a custom loss during a batch. Just
log it in the history at each batch and add PassthroughScoring with
the name of the key of that loss. No need to add a fake scoring
function that just returns the value.
Also fixed minor bugs:

* missing __all__ declaration in __init__
* forgot to return self in initialize
@ottonemo ottonemo (Member) left a comment:

LGTM

The review also included an inline comment on this docs excerpt:

On top of the two described scoring callbacks, skorch also provides
:class:`.PassthroughScoring`. This callback does not actually calculate
any new scores. Instead it uses an existing score that is calculated
for each batch and determines the average of this score, which is
ottonemo (Member) suggested a change:

Suggested change
- for each batch and determines the average of this score, which is
+ for each batch (the train loss, for example) and determines the average of this score, which is

Maybe this makes it easier to relate how and when this callback is useful?

BenjaminBossan (Collaborator, Author):

fixed
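
To make the quoted docs sentence concrete, here is a rough sketch of the kind of per-epoch aggregation it describes. The dict layout mirrors skorch's batch history entries, but the weighting by batch size is an assumption for illustration, not necessarily what the callback implements.

# illustrative only: turn the per-batch values of one epoch into a single score;
# whether the real callback weights by batch size is an assumption here
batches = [
    {'my_loss': 0.50, 'train_batch_size': 128},
    {'my_loss': 0.40, 'train_batch_size': 128},
    {'my_loss': 0.30, 'train_batch_size': 64},  # the last batch may be smaller
]

weighted_sum = sum(b['my_loss'] * b['train_batch_size'] for b in batches)
total_size = sum(b['train_batch_size'] for b in batches)
epoch_score = weighted_sum / total_size
print(epoch_score)  # ~0.42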

@pytest.mark.xfail
def test_net_initialized_with_initalized_dataset(
        self, net_cls, module_cls, data, dataset_cls):
    # TODO: What to do of this test now??
ottonemo (Member):

Do we need an issue to track this?

BenjaminBossan (Collaborator, Author):

tbh I don't remember why I put this here, but the test xfails with the expected error, so I removed the comment and the skip

@BenjaminBossan BenjaminBossan requested a review from ottonemo April 9, 2020 19:52
@ottonemo ottonemo merged commit 8665375 into master Apr 10, 2020
@BenjaminBossan BenjaminBossan deleted the feature/pass-through-scoring branch July 30, 2020 22:26