Refactor train loop for easier customization #699
Conversation
The main change is to "push down" the unpacking of the batch data (using "Xi, yi = unpack_data(batch)") from "run_single_epoch" to functions lower down the chain. The reason for this is that if we put this unpacking rather "high up" the chain, we force the output of dataset to be a tuple of 2 elements. If a user wants to change this, they either have to modify their dataset to alwas return 2 elments (which can be annoying) or they would need to override "run_single_epoch", but we don't want to encourage this, as this method performs a lot of book keeping. With the new changes, small methods further down the chain are now responsible for dealing with the data, e.g. "train_step_single" and "validation_step". This affords more flexibility overall, and a user can easily override these methods without having to deal with book keeping. Obviously, this change results in some API changes, so some existing code might break as a consequence (most notably code that relies on "on_batch_begin" or "on_batch_end"). Another "victim" of this change is the checking for whether y is a placeholder and then replacing it with None. Tbh. I found this to be an ugly solution anyway. Now it can happen that when a user uses the skorch Dataset, they will receive a 0s tensor for y when they set y to None. I think this is an acceptable loss overall. To illustrate the gains of the change, I added a test, "test_customize_net_with_custom_dataset_that_returns_3_values", which shows what is now possible when it comes to customizing skorch. With the old state, achieving the same would have been more cumbersome. Finally, I added a section to the docs called "Customization". The intent is to help the user decide how best to customize "NeuralNet" et al. It explains what methods best to override for what purpose and what methods to best leave untouched. In the future, we could expand this section to cover more aspects.
Forgot to mention that I moved the notification for … And even if a user changes …
A quick look at this PR.
Overall, I like pushing the `unpack_data` down a level. My biggest concern is the backward compatibility. I do not think that batch-level callbacks are too common, so it is most likely a non-issue.
"""Called at the end of each epoch.""" | ||
|
||
def on_batch_begin(self, net, | ||
X=None, y=None, training=None, **kwargs): | ||
def on_batch_begin(self, net, batch=None, training=None, **kwargs): |
I suspect that changing the signature of callbacks would break users' code the most.
Agreed. Not every user might follow the changelog closely. We could wrap the `notify` call with an exception handler, and in the case of `on_batch_begin` we could re-raise the original exception together with a note informing the user about the change in signature and how to fix it. This could stay there for one release.
I took a stab at this but it turns out not to be so simple. The reason is that we always have `**kwargs` in the signature, so even if someone writes a wrong method like `def on_batch_begin(net, X=None, y=None, **kwargs)`, it would not immediately lead to an error. The error occurs only when they do something with X or y, because these will be None instead of the expected values. But this "something" can be anything.

In my attempt, I just caught any kind of `Exception` and raised a `TypeError` from it that pointed at the possible error source. However, this broke other tests, which checked for specific errors (for instance, using scoring with a non-existing function name). We don't want that to happen. The problem is thus that we don't know what error will occur, but we also don't want to catch all kinds of errors.

We could theoretically inspect the arguments inside the `notify` wrapper, but there I would be afraid that the performance penalty could be too big. Furthermore, we want to allow users to be very liberal in how they call their `on_*` methods (e.g. only passing the arguments that they need, which might not even include `batch`), so it's not possible to do a strict check.

This leaves me with no idea how to implement the suggestion. Do you have any ideas?
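To make the `**kwargs` issue concrete, here is a standalone toy example (not skorch code): an outdated handler does not fail when called with the new arguments, its old parameters just silently stay `None`.

```python
def on_batch_begin(net, X=None, y=None, **kwargs):
    # an outdated user callback that still expects X and y to be filled in
    print("X is", X, "- y is", y)


# skorch would now call the hook roughly like this (simplified):
on_batch_begin(net=None, batch=([1, 2], [0, 1]), training=True)
# prints "X is None - y is None" instead of raising a TypeError, so a
# try/except around notify() has no specific error to catch at call time
```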
OK, this is fair. Maybe we can issue a warning on the first call of `notify('on_batch_begin')` for callbacks that aren't ours, indicating that things might have changed? We can deprecate this warning in the next or a following release, and you can always filter the warning if you are getting annoyed by it.
So basically, on the very first call to `notify`, you'd like to go through all callbacks and check whether each is an instance of a built-in callback, and if at least one of them isn't, issue a warning? I'm not sure whether this isn't overkill, given that the warning will mostly be a false alarm.
Oh, no, more like `if not cb.__module__.startswith('skorch'): issueWarning()`.
skorch/net.py
step : dict
    A dictionary containing a value ``loss`` for the loss scalar
    and ``y_pred`` for the prediction.
In this case, I think it returns the loss as a float.
I wasn't happy with the docstring and adjusted it. Mentioned that loss should be a float.
Yes, but I think we're still at a stage where we can afford to make such a breaking change. Honestly, I was surprised how little skorch code needed to be changed, which makes me believe, as you say, that the two methods are not used all too frequently. It would probably help as a "transition guide" to mention:

# change from this:
def on_batch_begin(self, net, X, y, ...):
    ...

# to this:
def on_batch_begin(self, net, batch, ...):
    X, y = batch
    ...

# same for on_batch_end

Not sure where we could put this. But overall, there is really not much to it, so the gain outweighs the inconvenience IMO.
Also, add documentation of return value of train_step_single, which was missing (same docstring).
@ottonemo are you fine with the changes?
I am in favor.
> Not sure where we could put this. But overall, there is really not much to it, so the gain outweighs the inconvenience IMO.
In the FAQ, with a link in the change log?
Co-authored-by: Thomas J. Fan <[email protected]>
As suggested by reviewer.
…//github.com/skorch-dev/skorch into refactor-train-loop-for-easier-customization
Great idea, I added a migration guide section to the FAQ.
    named_parameters=TeeGenerator(self.module_.named_parameters()),
    X=Xi,
    y=yi
)
This is probably a line that can be classified as bookkeeping. The intention of having it in a place where the user might implement their own training code was that, in case the user has two or more models (or sets of parameters), the process of communicating gradient changes to callbacks is kept transparent.
With this change we make that a bit less transparent (to the benefit of recognizability: this looks way more standard now). Does this change therefore introduce the assumption that a user who adds additional modules (or parameter sets that are updated independently) will also override `train_step_single`?
> in case the user has two or more models (or sets of parameters) the process of communicating gradient changes to callbacks is kept transparent.

What would be a use case where the user wants to call `on_grad_computed` explicitly here? E.g. when they only want to clip gradients on a subset of parameters, so that they would only pass `self.module_.SUBMODULE.named_parameters()`? If so, yes, that would require them to override `train_step_single`, which is not nice. I would argue, however, that this filtering should happen on the other side (i.e. in the callback), not here.

> this looks way more standard now

That was the idea: make this look as much like "normal" PyTorch training code as possible. The current code, for a skorch novice, would look quite esoteric: what is `notify`? What is `TeeGenerator`??
"""Called at the end of each epoch.""" | ||
|
||
def on_batch_begin(self, net, | ||
X=None, y=None, training=None, **kwargs): | ||
def on_batch_begin(self, net, batch=None, training=None, **kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suspect that changing the signature of callbacks would break's users code the most.
Agreed. Not every user might follow the changelog closely, we could wrap the notify
call with an exception handler and in case of on_batch_begin
we could raise the original exception in addition to a note informing the user about the change in signature and how to fix it. This could stay there for one release.
Co-authored-by: githubnemo <[email protected]>
Co-authored-by: githubnemo <[email protected]>
@ottonemo I'd like to refine #707, but I think there might be some conflicts with this PR, so we should probably decide whether to go forward with this one or not. I think the biggest hurdle was the question of whether we can provide a useful exception when a user needs to change the signature. Did you see my comment on this?
Check if the user uses any custom callback that overrides on_batch_begin or on_batch_end. If they do, warn about the signature change and explain how to recover. Also, indicate how to suppress the warning.
…//github.com/skorch-dev/skorch into refactor-train-loop-for-easier-customization
@ottonemo I implemented your suggestion to check for user-defined callbacks and issue a warning containing the required fix and a note on how to turn off the warning. On top of your suggestion, I also check if any of the user-defined callbacks even overrides …
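For reference, a rough sketch of what such a check could look like; the function name, the warning text, and the shape of `callbacks` are assumptions, not the actual implementation in this PR:

```python
import warnings

from skorch.callbacks import Callback


def warn_about_batch_signature(callbacks):
    """Warn once if a user-defined callback overrides the batch-level hooks.

    The signature changed from (net, X, y, ...) to (net, batch, ...).
    ``callbacks`` is assumed to be a list of (name, callback) pairs,
    as in ``net.callbacks_``.
    """
    for _, cb in callbacks:
        if cb.__module__.startswith('skorch'):
            continue  # built-in callbacks already use the new signature
        overrides_batch_hooks = (
            type(cb).on_batch_begin is not Callback.on_batch_begin
            or type(cb).on_batch_end is not Callback.on_batch_end
        )
        if overrides_batch_hooks:
            warnings.warn(
                "on_batch_begin/on_batch_end now receive 'batch' instead of "
                "'X' and 'y'; unpack it with 'X, y = batch'. If your callback "
                "is already updated, filter this warning with the warnings "
                "module.",
                UserWarning,
            )
            break
```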
Forgot to check this in when the warning was added.
We are happy to announce the new skorch 0.11 release:

Two basic but very useful features have been added to our collection of callbacks. First, by setting `load_best=True` on the [`Checkpoint` callback](https://skorch.readthedocs.io/en/latest/callbacks.html#skorch.callbacks.Checkpoint), the snapshot of the network with the best score will be loaded automatically when training ends. Second, we added a callback [`InputShapeSetter`](https://skorch.readthedocs.io/en/latest/callbacks.html#skorch.callbacks.InputShapeSetter) that automatically adjusts your input layer to have the size of your input data (useful e.g. when that size is not known beforehand).

When it comes to integrations, the [`MlflowLogger`](https://skorch.readthedocs.io/en/latest/callbacks.html#skorch.callbacks.MlflowLogger) now allows automatic logging to [MLflow](https://mlflow.org/). Thanks to a contributor, some regressions in `net.history` have been fixed and it even runs faster now.

On top of that, skorch now offers a new module, `skorch.probabilistic`. It contains new classes to work with **Gaussian Processes** using the familiar skorch API. This is made possible by the fantastic [GPyTorch](https://github.com/cornellius-gp/gpytorch) library, which skorch uses for this. So if you want to get started with Gaussian Processes in skorch, check out the [documentation](https://skorch.readthedocs.io/en/latest/user/probabilistic.html) and this [notebook](https://nbviewer.org/github/skorch-dev/skorch/blob/master/notebooks/Gaussian_Processes.ipynb). Since we're still learning, it's possible that we will change the API in the future, so please be aware of that.

Moreover, we introduced some changes to make skorch more customizable. First of all, we changed the signature of some methods so that they no longer assume the dataset to always return exactly 2 values. This way, it's easier to work with custom datasets that return e.g. 3 values. Normal users should not notice any difference, but if you often create custom nets, take a look at the [migration guide](https://skorch.readthedocs.io/en/latest/user/FAQ.html#migration-from-0-10-to-0-11).

And finally, we made a change to how custom modules, criteria, and optimizers are handled. They are now "first class citizens" in skorch land, which means: if you add a second module to your custom net, it is treated exactly the same as the normal module. E.g., skorch takes care of moving it to CUDA if needed and of switching it to train or eval mode. This way, customizing your network architectures with skorch is easier than ever. Check the [docs](https://skorch.readthedocs.io/en/latest/user/customization.html#initialization-and-custom-modules) for more details.

Since these are some big changes, it's possible that you encounter issues. If that's the case, please check our [issue](https://github.com/skorch-dev/skorch/issues) page or create a new one.

As always, this release was made possible by outside contributors. Many thanks to:

- Autumnii
- Cebtenzzre
- Charles Cabergs
- Immanuel Bayer
- Jake Gardner
- Matthias Pfenninger
- Prabhat Kumar Sahu

Find below the list of all changes:

Added

- Added `load_best` attribute to `Checkpoint` callback to automatically load the state of the best result at the end of training
- Added a `get_all_learnable_params` method to retrieve the named parameters of all PyTorch modules defined on the net, including of criteria if applicable
- Added `MlflowLogger` callback for logging to Mlflow (#769)
- Added `InputShapeSetter` callback for automatically setting the input dimension of the PyTorch module
- Added a new module to support Gaussian Processes through [GPyTorch](https://gpytorch.ai/). To learn more about it, read the [GP documentation](https://skorch.readthedocs.io/en/latest/user/probabilistic.html) or take a look at the [GP notebook](https://nbviewer.jupyter.org/github/skorch-dev/skorch/blob/master/notebooks/Gaussian_Processes.ipynb). This feature is experimental, i.e. the API could be changed in the future in a backwards-incompatible way (#782)

Changed

- Changed the signature of `validation_step`, `train_step_single`, `train_step`, `evaluation_step`, `on_batch_begin`, and `on_batch_end` such that instead of receiving `X` and `y`, they receive the whole batch; this makes it easier to deal with datasets that don't strictly return an `(X, y)` tuple, which is true for quite a few PyTorch datasets; please refer to the [migration guide](https://skorch.readthedocs.io/en/latest/user/FAQ.html#migration-from-0-10-to-0-11) if you encounter problems (#699)
- Checking of arguments to `NeuralNet` now happens during `.initialize()`, not during `__init__`, to avoid raising false positives for yet unknown module or optimizer attributes
- Modules, criteria, and optimizers that are added to a net by the user are now first class: skorch takes care of setting train/eval mode, moving to the indicated device, and updating all learnable parameters during training (check the [docs](https://skorch.readthedocs.io/en/latest/user/customization.html#initialization-and-custom-modules) for more details, #751)
- `CVSplit` is renamed to `ValidSplit` to avoid confusion (#752)

Fixed

- Fixed a few bugs in the `net.history` implementation (#776)
- Fixed a bug in `TrainEndCheckpoint` that prevented it from being unpickled (#773)
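A quick usage sketch of the two new callbacks mentioned above; the module and data are placeholders and the details are assumptions, see the linked docs for the authoritative API:

```python
import numpy as np
import torch
from torch import nn
from skorch import NeuralNetClassifier
from skorch.callbacks import Checkpoint, InputShapeSetter


class ClassifierModule(nn.Module):
    # input_dim is adjusted by InputShapeSetter to match the training data
    def __init__(self, input_dim=1):
        super().__init__()
        self.lin = nn.Linear(input_dim, 2)

    def forward(self, X):
        return torch.log_softmax(self.lin(X), dim=-1)


net = NeuralNetClassifier(
    ClassifierModule,
    max_epochs=5,
    callbacks=[
        Checkpoint(load_best=True),  # reload the best snapshot when training ends
        InputShapeSetter(),          # sets module__input_dim from the data
    ],
)

X = np.random.rand(100, 20).astype(np.float32)
y = np.random.randint(0, 2, size=100).astype(np.int64)
net.fit(X, y)
```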
Small refactoring of some training/evaluation methods
The main change is to "push down" the unpacking of the batch
data (using
Xi, yi = unpack_data(batch)
) fromrun_single_epoch
tofunctions lower down the chain. The reason for this is that if we put
this unpacking rather "high up" the chain, we force the output of
dataset to be a tuple of 2 elements. If a user wants to change this,
they either have to modify their dataset to always return 2
elements (which can be annoying) or they would need to override
run_single_epoch
, but we don't want to encourage this, as thismethod performs a lot of book keeping.
With the new changes, small methods further down the chain are now
responsible for dealing with the data, e.g.
train_step_single
andvalidation_step
. This affords more flexibility overall, and a usercan easily override these methods without having to deal with book
keeping.
Obviously, this change results in some API changes, so some existing
code might break as a consequence (most notably code that relies on
on_batch_begin
oron_batch_end
now simply get abatch
whichthey need to unpack, instead of separate
X
andy
).Another "victim" of this change is the checking for whether y is a
placeholder and then replacing it with None. Tbh. I found this to be
an ugly solution anyway. Now it can happen that when a user uses the
skorch
Dataset
, they will receive a 0s tensor for y when they set y toNone. I think this is an acceptable loss overall.
To illustrate the gains of the change, I added a test,
test_customize_net_with_custom_dataset_that_returns_3_values
, whichshows what is now possible when it comes to customizing skorch. With
the old state, achieving the same would have been more cumbersome.
Finally, I added a section to the docs called "Customization". The
intent is to help the user decide how best to customize
NeuralNet
etal. It explains what methods best to override for what purpose and
what methods to best leave untouched. In the future, we could expand
this section to cover more aspects.
Additional minor changes