Refactor train loop for easier customization #699

Merged (15 commits into master, Mar 26, 2021)

Conversation

BenjaminBossan (Collaborator)

Small refactoring of some training/evaluation methods

The main change is to "push down" the unpacking of the batch
data (using Xi, yi = unpack_data(batch)) from run_single_epoch to
functions lower down the chain. The reason for this is that if we put
this unpacking rather "high up" the chain, we force the output of the
dataset to be a tuple of 2 elements. If a user wants to change this,
they either have to modify their dataset to always return 2
elements (which can be annoying) or they have to override
run_single_epoch, which we don't want to encourage, since that
method performs a lot of bookkeeping.

With the new changes, small methods further down the chain are now
responsible for dealing with the data, e.g. train_step_single and
validation_step. This affords more flexibility overall, and a user
can easily override these methods without having to deal with
bookkeeping.

Obviously, this results in some API changes, so some existing code
might break as a consequence (most notably, on_batch_begin and
on_batch_end now simply receive a batch, which they need to unpack,
instead of separate X and y).

Another "victim" of this change is the checking for whether y is a
placeholder and then replacing it with None. Tbh. I found this to be
an ugly solution anyway. Now it can happen that when a user uses the
skorch Dataset, they will receive a 0s tensor for y when they set y to
None. I think this is an acceptable loss overall.

To illustrate the gains of the change, I added a test,
test_customize_net_with_custom_dataset_that_returns_3_values, which
shows what is now possible when it comes to customizing skorch. With
the old state, achieving the same would have been more cumbersome.
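
To give a flavor, a minimal sketch of that kind of customization (the
class name, the three-element batch, and the handling of the extra
value are illustrative assumptions, not the actual test code):

from skorch import NeuralNetClassifier


class MyNet(NeuralNetClassifier):
    # hypothetical override for a dataset that yields (X, y, extra)
    def train_step_single(self, batch, **fit_params):
        self.module_.train()
        Xi, yi, extra = batch  # unpack three values instead of two
        y_pred = self.infer(Xi, **fit_params)
        loss = self.get_loss(y_pred, yi, X=Xi, training=True)
        # 'extra' could, e.g., be passed on to the module or be used to
        # weight the loss, depending on the use case
        loss.backward()
        return {'loss': loss, 'y_pred': y_pred}

A similar override of validation_step would typically accompany this
for the validation pass.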

Finally, I added a section to the docs called "Customization". The
intent is to help the user decide how best to customize NeuralNet et
al. It explains which methods are best to override for which purpose
and which are best left untouched. In the future, we could expand
this section to cover more aspects.

Additional minor changes

  • Cleanups to satisfy pylint
  • Add more to .gitignore
  • Remove duplication from .coveragerc

BenjaminBossan added 3 commits September 27, 2020 17:21
@BenjaminBossan (Collaborator, Author)

Forgot to mention that I moved the notification for on_grad_computed from train_step_single to train_step. The reasoning for this was that a user is more likely to modify train_step_single, so if we move this bookkeeping out, users don't need to remember it.

And even if a user changes train_step, they most likely won't call train_step_single anymore, so if the notification was inside train_step_single, they would still need to remember to include it in train_step (actually, remembering this might be easier now, since it's right there in the code of train_step).

@thomasjpfan (Member) left a comment

A quick look at this PR.

Overall, I like pushing unpack_data down a level. My biggest concern is backward compatibility. I do not think that batch-level callbacks are too common, so it is most likely a non-issue.

"""Called at the end of each epoch."""

# old signature:
def on_batch_begin(self, net, X=None, y=None, training=None, **kwargs):
# new signature:
def on_batch_begin(self, net, batch=None, training=None, **kwargs):
Member

I suspect that changing the signature of callbacks would break users' code the most.

@githubnemo (Collaborator), Dec 23, 2020

I suspect that changing the signature of callbacks would break users' code the most.

Agreed. Not every user follows the changelog closely. We could wrap the notify call in an exception handler and, in the case of on_batch_begin, re-raise the original exception together with a note informing the user about the change in signature and how to fix it. This could stay in place for one release.

Collaborator (Author)

I took a stab at this, but it turns out not to be so simple. The reason is that we always have **kwargs in the signature, so even if someone keeps a now-wrong method like def on_batch_begin(net, X=None, y=None, **kwargs), it does not immediately lead to an error. The error occurs only when they do something with X or y, because these will be None instead of the expected values. And that "something" can be anything.

In my attempt, I just caught any kind of Exception and raised a TypeError from it that pointed at the possible error source. However, this broke other tests that checked for specific errors (for instance, using scoring with a non-existent function name). We don't want that to happen.

The problem is thus that we don't know what error will occur, but we also don't want to catch all kinds of errors.

We could theoretically inspect the arguments inside the notify wrapper, but there I would be afraid that the performance penalty could be too big. Furthermore, we want to allow users to be very liberal in how they define their on_* methods (e.g. only accepting the arguments that they need, which might not even include batch), so it's not possible to do a strict check.

This leaves me with no idea how to implement the suggestion. Do you have any ideas?

Member

OK, this is fair. Maybe we can issue a warning on the first call of notify('on_batch_begin') for callbacks that aren't ours, indicating that things might have changed? We can remove this warning in the next or a following release, and you can always filter the warning if it gets annoying.

Collaborator (Author)

So basically, on the very first call to notify, you'd like to go through all callbacks, check whether each is an instance of a built-in callback, and issue a warning if at least one of them isn't? I'm not sure this isn't overkill, given that the warning will mostly be a false alarm.

Member

Oh, no, more like if not cb.__module__.startswith('skorch'): issueWarning().
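
A rough sketch of where such a check could live (helper name and warning
text are made up for illustration, not the actual implementation):

import warnings


def warn_about_non_skorch_callbacks(callbacks):
    # hypothetical helper: warn once if any callback does not come from
    # skorch itself, since it might still rely on the old batch signature
    for cb in callbacks:
        if not cb.__module__.startswith('skorch'):
            warnings.warn(
                "The signature of on_batch_begin/on_batch_end changed from "
                "(net, X, y, ...) to (net, batch, ...); if your callback "
                "overrides them, unpack the data via X, y = batch."
            )
            break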

skorch/net.py Outdated
Comment on lines 699 to 701
step : dict
A dictionary containing a value ``loss`` for the loss scalar
and ``y_pred`` for the prediction.

Member

In this case, I think it returns the loss as a float.

Collaborator (Author)

I wasn't happy with the docstring and adjusted it, mentioning that the loss should be a float.

@BenjaminBossan (Collaborator, Author)

My biggest concern is backward compatibility. I do not think that batch-level callbacks are too common, so it is most likely a non-issue.

Yes, but I think we're still at a stage where we can afford to make such a breaking change. Honestly, I was surprised how little skorch code needed to be changed, which makes me believe, as you say, that these two methods are not used all that frequently.

It would probably help as a "transition guide" to mention:

# change from this:
def on_batch_begin(self, net, X=None, y=None, **kwargs):
    ...

# to this:
def on_batch_begin(self, net, batch=None, **kwargs):
    X, y = batch
    ...

# same for on_batch_end

Not sure where we could put this. But overall, there is really not much to it, so the gain outweighs the inconvenience IMO.

Also, add documentation of return value of train_step_single, which
was missing (same docstring).
@BenjaminBossan marked this pull request as ready for review on October 3, 2020 15:15
@BenjaminBossan (Collaborator, Author)

@ottonemo are you fine with the changes?

@thomasjpfan (Member) left a comment

I am in favor.

Not sure where we could put this. But overall, there is really not much to it, so the gain outweighs the inconvenience IMO.

In the FAQ, with a link in the change log?

@BenjaminBossan (Collaborator, Author)

In the FAQ, with a link in the change log?

Great idea, I added a migration guide section to the FAQ.

named_parameters=TeeGenerator(self.module_.named_parameters()),
X=Xi,
y=yi
)

Collaborator

This is probably a line that can be classified as bookkeeping. The intention of having it in a place where the user might implement their own training code was that, in case the user has two or more models (or sets of parameters), the process of communicating gradient changes to callbacks is kept transparent.

With this change we make that a bit less transparent (to the benefit of recognizability: this looks way more standard now). Does this change therefore introduce the assumption that a user who adds additional modules (or parameter sets that are updated independently) will also override train_step_single?

@BenjaminBossan (Collaborator, Author), Dec 29, 2020

in case the user has two or more models (or sets of parameters), the process of communicating gradient changes to callbacks is kept transparent.

What would be a use case where the user wants to call on_grad_computed explicitly here? E.g. when they only want to clip gradients on a subset of parameters, so that they would only pass self.module_.SUBMODULE.named_parameters()? If so, yes, that would require them to override train_step_single, which is not nice. I would argue, however, that this filtering should happen on the other side (i.e. in the callback), not here.

this looks way more standard now

That was the idea: make this look as much like "normal" PyTorch training code as possible. To a skorch novice, the current code would look quite esoteric: What is notify? What is TeeGenerator?
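
As an aside, a rough sketch of what such callback-side filtering could
look like (the callback name, parameter prefix, and clipping value are
made-up examples):

import torch
from skorch.callbacks import Callback


class ClipSubmoduleGradients(Callback):
    # hypothetical callback: clip gradients of a single submodule only,
    # filtering by parameter name instead of overriding train_step_single
    def __init__(self, prefix='encoder.', max_norm=1.0):
        self.prefix = prefix
        self.max_norm = max_norm

    def on_grad_computed(self, net, named_parameters, **kwargs):
        params = [p for name, p in named_parameters
                  if name.startswith(self.prefix)]
        torch.nn.utils.clip_grad_norm_(params, self.max_norm)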

"""Called at the end of each epoch."""

def on_batch_begin(self, net,
X=None, y=None, training=None, **kwargs):
def on_batch_begin(self, net, batch=None, training=None, **kwargs):
Copy link
Collaborator

@githubnemo githubnemo Dec 23, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect that changing the signature of callbacks would break's users code the most.

Agreed. Not every user might follow the changelog closely, we could wrap the notify call with an exception handler and in case of on_batch_begin we could raise the original exception in addition to a note informing the user about the change in signature and how to fix it. This could stay there for one release.

@BenjaminBossan (Collaborator, Author)

@ottonemo I'd like to refine #707, but I think there might be some conflicts with this PR, so we should probably decide whether to go forward with this one or not. I think the biggest hurdle was the question of whether we can provide a useful exception when a user needs to change the signature. Did you see my comment on this?

BenjaminBossan and others added 3 commits February 7, 2021 16:31
Check if the user uses any custom callback that overrides on_batch_begin
or on_batch_end. If they do, warn about the signature change and explain
how to recover. Also, indicate how to suppress the warning.

@BenjaminBossan (Collaborator, Author)

@ottonemo I implemented your suggestion to check for user-defined callbacks and issue a warning containing the required fix and a note on how to turn off the warning. On top of your suggestion, I also check whether any of the user-defined callbacks even overrides on_batch_{begin,end} and don't issue the warning if none do.
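
Roughly, the check works along these lines (a sketch based on the
description above, not the exact code that was added):

from skorch.callbacks import Callback


def needs_signature_warning(cb):
    # only warn about callbacks that do not ship with skorch *and* that
    # actually override one of the two batch-level hooks
    is_external = not type(cb).__module__.startswith('skorch')
    overrides_hooks = (
        type(cb).on_batch_begin is not Callback.on_batch_begin
        or type(cb).on_batch_end is not Callback.on_batch_end
    )
    return is_external and overrides_hooks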

Forgot to check this in when the warning was added.
@ottonemo merged commit afa697e into master on Mar 26, 2021
@BenjaminBossan deleted the refactor-train-loop-for-easier-customization branch on March 26, 2021 18:06
BenjaminBossan added a commit that referenced this pull request Oct 31, 2021
We are happy to announce the new skorch 0.11 release:

Two basic but very useful features have been added to our collection of callbacks. First, by setting `load_best=True` on the [`Checkpoint` callback](https://skorch.readthedocs.io/en/latest/callbacks.html#skorch.callbacks.Checkpoint), the snapshot of the network with the best score will be loaded automatically when training ends. Second, we added a callback [`InputShapeSetter`](https://skorch.readthedocs.io/en/latest/callbacks.html#skorch.callbacks.InputShapeSetter) that automatically adjusts your input layer to have the size of your input data (useful e.g. when that size is not known beforehand).
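
For illustration, a minimal usage sketch (the toy module is made up, and `InputShapeSetter` is assumed to set the module's `input_dim` argument by default):

```python
import torch.nn as nn
from skorch import NeuralNetClassifier
from skorch.callbacks import Checkpoint, InputShapeSetter


class ClassifierModule(nn.Module):
    # toy module; `input_dim` is filled in by InputShapeSetter from the data
    def __init__(self, input_dim=1):
        super().__init__()
        self.lin = nn.Linear(input_dim, 2)

    def forward(self, X):
        return nn.functional.log_softmax(self.lin(X), dim=-1)


net = NeuralNetClassifier(
    ClassifierModule,
    callbacks=[
        Checkpoint(load_best=True),  # reload the best snapshot when training ends
        InputShapeSetter(),          # set `input_dim` from the training data
    ],
)
```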

When it comes to integrations, the [`MlflowLogger`](https://skorch.readthedocs.io/en/latest/callbacks.html#skorch.callbacks.MlflowLogger) now makes it possible to log automatically to [MLflow](https://mlflow.org/). Thanks to a contributor, some regressions in `net.history` have been fixed, and it even runs faster now.

On top of that, skorch now offers a new module, `skorch.probabilistic`. It contains new classes to work with **Gaussian Processes** using the familiar skorch API. This is made possible by the fantastic [GPyTorch](https://github.com/cornellius-gp/gpytorch) library, which skorch uses under the hood. So if you want to get started with Gaussian Processes in skorch, check out the [documentation](https://skorch.readthedocs.io/en/latest/user/probabilistic.html) and this [notebook](https://nbviewer.org/github/skorch-dev/skorch/blob/master/notebooks/Gaussian_Processes.ipynb). Since we're still learning, it's possible that we will change the API in the future, so please be aware of that.

Moreover, we introduced some changes to make skorch more customizable. First of all, we changed the signature of some methods so that they no longer assume the dataset to always return exactly 2 values. This way, it's easier to work with custom datasets that return e.g. 3 values. Normal users should not notice any difference, but if you often create custom nets, take a look at the [migration guide](https://skorch.readthedocs.io/en/latest/user/FAQ.html#migration-from-0-10-to-0-11).

And finally, we made a change to how custom modules, criteria, and optimizers are handled. They are now "first class citizens" in skorch land, which means: if you add a second module to your custom net, it is treated exactly the same as the normal module. E.g., skorch takes care of moving it to CUDA if needed and of switching it to train or eval mode. This way, customizing your network's architecture with skorch is easier than ever. Check the [docs](https://skorch.readthedocs.io/en/latest/user/customization.html#initialization-and-custom-modules) for more details.

Since these are some big changes, it's possible that you encounter issues. If that's the case, please check our [issue](https://github.com/skorch-dev/skorch/issues) page or create a new one.

As always, this release was made possible by outside contributors. Many thanks to:

- Autumnii
- Cebtenzzre
- Charles Cabergs
- Immanuel Bayer
- Jake Gardner
- Matthias Pfenninger
- Prabhat Kumar Sahu 

Find below the list of all changes:

Added

- Added `load_best` attribute to `Checkpoint` callback to automatically load state of the best result at the end of training
- Added a `get_all_learnable_params` method to retrieve the named parameters of all PyTorch modules defined on the net, including of criteria if applicable
- Added `MlflowLogger` callback for logging to MLflow (#769)
- Added `InputShapeSetter` callback for automatically setting the input dimension of the PyTorch module
- Added a new module to support Gaussian Processes through [GPyTorch](https://gpytorch.ai/). To learn more about it, read the [GP documentation](https://skorch.readthedocs.io/en/latest/user/probabilistic.html) or take a look at the [GP notebook](https://nbviewer.jupyter.org/github/skorch-dev/skorch/blob/master/notebooks/Gaussian_Processes.ipynb). This feature is experimental, i.e. the API could be changed in the future in a backwards incompatible way (#782)

Changed

- Changed the signature of `validation_step`, `train_step_single`, `train_step`, `evaluation_step`, `on_batch_begin`, and `on_batch_end` such that instead of receiving `X` and `y`, they receive the whole batch; this makes it easier to deal with datasets that don't strictly return an `(X, y)` tuple, which is true for quite a few PyTorch datasets; please refer to the [migration guide](https://skorch.readthedocs.io/en/latest/user/FAQ.html#migration-from-0-10-to-0-11) if you encounter problems (#699)
- Checking of arguments to `NeuralNet` now happens during `.initialize()`, not during `__init__`, to avoid raising false positives for yet unknown module or optimizer attributes
- Modules, criteria, and optimizers that are added to a net by the user are now first class: skorch takes care of setting train/eval mode, moving to the indicated device, and updating all learnable parameters during training (check the [docs](https://skorch.readthedocs.io/en/latest/user/customization.html#initialization-and-custom-modules) for more details, #751)
- `CVSplit` is renamed to `ValidSplit` to avoid confusion (#752)

Fixed

- Fixed a few bugs in the `net.history` implementation (#776)
- Fixed a bug in `TrainEndCheckpoint` that prevented it from being unpickled (#773)