Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change: refactor skorch for more consistency when adding custom modules etc. #751

Merged

Conversation

BenjaminBossan
Copy link
Collaborator

@BenjaminBossan BenjaminBossan commented Apr 1, 2021

This PR is still WIP but as discussed with @ottonemo, I will create a draft PR early to discuss the changes and implementation.

Motivation

The initial reason why I wanted to work on this is that I'm currently implementing a gpytorch integration (see this branch). For this, a big part is adding a new custom module called "likelihood". Doing this correctly was actually not trivial and required a lot of more or less duplicated code. Putting such a burden on a user with less experience with skorch would not be possible.

The main reason for this difficulty is that module, criterion and optimizer are treated "special" so far. We assume that they are already there and build everything else around this. If a custom module is added, the user needs to be aware of all the places where this is relevant, which is too error prone.

Previous work

Some changes to facilitate adding custom modules were already implemented in #597. However, they don't go far enough.

Changes

Main changes

With this PR, we remove the special status of module, criterion and optimizer. Instead, all the work that needs to be done when adding any of them to the net is now implemented in a generic manner. This way, custom modules etc. can re-use the same functionality and can therefore expect to be treated the same as these "first class" components.

Here is a list of changed that were added to that effect:

  • Until now, custom module parameters were not trained by the optimizer, now they are
  • Until now, custom modules/criteria were not automatically moved to the indicated device, now they are
  • Until now, custom modules/criteria were not automatically set into train/eval mode, now they are
  • Simplified implementation of initialize_module et al. - they contained a lot of stuff that was irrelevant for the user, like messaging about why something was re-initialized; now this stuff is done inside the newly added methods _initialize_module etc., which are called by initialize() and shouldn't be a bother to the user
  • Adding a custom module no longer requires the attribute name to contain the substring "module" (which was really not a nice solution), same for criterion and optimizer
  • Re-initialization logic was changed: When any module is changed (via set_params), this triggers re-initialization of all modules, criteria and optimizers; when any criterion is changed, this triggers re-initialization of all optimizers (but not modules); this is a bit defensive since it could trigger unnecessary inits but it's better than missing any inits

Additions

  • There is now a get_learnable_params method on the net to retrieve all learnable parameters (instead of just those of module_); it is meant to be overridable by the user (e.g. when they have two optimizers for two modules)
  • Added attributes modules_, criteria_ and optimizers_ to the net to keep track of those; first started as OrderedDicts to mirror nn.Modules, but that was flaky, as the values in the dict would often get out of sync with the attributes on the net
  • If the criterion/criteria have learnable params, those are now passed to the optimizer as well (think GANs)

Minor changes

  • net.set_params(...) no longer initializes the net if it's not yet initialized - this was simply unnecessary and could lead to some unexpected behavior
  • custom module instances now need to be set inside initialize_module (and the name must end on an underscore), else the user will get an appropriate error message; same logic for criterion and optimizer
  • added a bunch of unit tests for the custom modules etc. that cover the cases not covered so far
  • checking of kwargs is now done during initialize, not during __init__ anymore, since at that point, we don't know yet what custom modules could exist
  • run a _check_kwargs during set_params - previously, this was a loophole that allowed users to set params with typos etc.
  • two unconditional print statements are now conditioned on verbosity level

Notes

I took extra effort to write the code as clearly as possible and add lots of comments, since this touches some complicated parts of the code base. But if something is not obvious, please tell me so that I can improve the code since now it's still fresh in my mind.

You will see that a few of the existing tests have been changed to now call initialize on the net when previously they didn't. The reason is that some work like checking kwargs is now moved to initialize.

Also, you will see that some tests now use mocked modules to check for device calls. I found this preferable to actually moving to 'cuda' since this will also work without a cuda device (e.g. during CI).

TODO

  • cover a few remaining edge cases
  • docstrings
  • documentation
  • update CHANGES
  • possibly add a more elaborate example

BenjaminBossan added 14 commits March 28, 2021 13:23
Previously, when a parameter on, say, the module was changed via
set_params (e.g. net.set_params(module__hidden_units=123)), set_params
would always trigger (re-)initialization of the module. However, when
the net was not initialized in the first place, this is unnecessary. It
is sufficient to set the new attribute and wait for the net to be
initialized later.

Fortunately, this change doesn't seem to have any further impact, i.e.
we didn't implicitly rely on this behavior anywhere. The only exceptions
are 2 tests in test_cli.py, but those can easily be adjusted and this
shouldn't have any user impact.
These methods started to become complicated because they did the
following:

1. Check if there is anything to initialize et all
2. Print message about reason for potential re-initialization
3. Moving to device

That made it quite difficult to override them without forgetting about
some aspect. With this change, there are now corresponding _intialize_*
methods that are called by net.initialize() and net.set_params. These
new methods now take care of the points above and call the initialize_*
methods inside.

Now, we can more easily make sure that the user can override
initialize_* without anything important being forgotten.
Add a test to check that set_params doesn't initialize the net if it's
not yet initialized at that time.
There were two instances of printing regardless of verbosity.
Is not relevant at the moment.
Removed code for states that could not be reached because of virtual
params. This simplifies the logic considerably.
Check optimizer-related messages for an initialized net with set_params
applied on module.
This is partly WIP because there is more to come, even though this
change per se is already an improvement on the status quo.

So far, the logic for creating custom modules or optimizers was separate
from the logic that created the default module, criterion and optimizer.
E.g., the "prefixes_" attribute was prefilled with 'module_',
'criterion_' and 'optimizer_'. This makes dealing with custom
modules/optimizers (e.g. creating a second module called 'mymodule_')
more difficult, because the logic for treating those was completely
disjoint from the logic of how the default modules/optimizer were
treated.

This change actually removes most of the "special status" of
module/criterion/optimizer. Therefore, the logic to treat those is now
the same as for any custom module. So for instance, they are no longer
pre-registered but instead are only registered later during their
initialize_* methods.

The this is implemented is to move the registration to the respective
initialize_* methods. This is because during __init__, we don't actually
know if we deal with a module or optimizer yet (passed argument for
'module' can, for instance, be a function, so we cannot type check). But
during 'initialize', when the actual instances are created, we can check
if we deal with a nn.Module or optim.Optimizer. If we do, we register
them.

So overall, the logic and role of 'initialize' have changed. Users will
be expected to set custom modules/optimizers during their respective
'initialize_*' methods from now on (stricter checks and doc updates will
be added). This affords us to no longer rely on the name to infer the
function (remember that previously, a custom module needed to contain
the substring 'module', which is an ugly restriction).

As more of a side effect to these changes, the '_check_kwargs' call was
moved to 'initialize' as well, since we cannot really check for faulty
kwargs as long as we don't know what modules and optimizers will be
registered.
These are only the tests, which will currently fail, hence WIP.

Right now, there is a big hole in the treatment of custom
modules/optimizers that distinguishes them from the assumed
ones ('module', 'criterion', 'optimizer'). This battery of unit tests
covers behaviors that will fail but really shouldn't:

- custom module parameters should be passed to the optimizer
- set_params on a custom module should trigger re-initialization of
  criterion and optimizer
- set_params on a custom criterion should trigger re-initialization of
  optimizer
- custom modules and criteria are not automatically moved to cuda
Since custom components are no longer matched by name, this became
obsolete.
@BenjaminBossan BenjaminBossan self-assigned this Apr 1, 2021
@BenjaminBossan BenjaminBossan marked this pull request as draft April 1, 2021 18:51
@BenjaminBossan
Copy link
Collaborator Author

I updated to documentation to reflect the changes being made. This should help with understanding how this PR will affect the user.

Before this, only the default "optimizer_" was used and all others were
being ignored. With this change, "zero_grad" and "step" are called on
all optimizers automatically.
Copy link
Member

@thomasjpfan thomasjpfan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some early thoughts of this draft.

Comment on lines 214 to 215
def initialize_optimizer(self, *args, **kwargs):
# first initialize the normal optimizer
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The example in the docstring for get_learnable_params says to override get_learnable_params and initialize_optimizer to do this. I think what we have here is simpler.

If a user adds their own module2_, they can still access them through self.module2_.named_parameters. What get_learnable_params provides is a nice wrapper to connect the optimizer and the module together, but it require a user to implement this connection.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right, it is probably better to encourage users to directly use module_.named_parameters(). The get_learnable_params method could be relegated to basically being a convenience method to return all named parameters at once. We could therefore remove the optimizer_name argument, I'm on the fence what the best design would be here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would opt for removing the optimizer_name argument. Maybe even rename the function to get_all_learnable_params to avoid confusion about its purpose?

skorch/net.py Outdated
@@ -230,6 +222,12 @@ class NeuralNet:
listed attributes are mapped to CPU. Expand this list if you
want to add other cuda-dependent attributes.

modules_ : TODO

criteria_ : TODO
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To stay consistent?

Suggested change
criteria_ : TODO
criterias_ : TODO

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"criteria" is already plural :)

skorch/net.py Outdated
@@ -567,56 +540,214 @@ def _apply_virtual_params(self, virtual_kwargs):
def initialize_virtual_params(self):
self.virtual_params_ = {}

def initialize_optimizer(self, triggered_directly=True):
def initialize_optimizer(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changes public API, but should be okay?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed. In my memory, this argument was just added recently and thus I thought it's unlikely that anyone uses it. But it's actually been there for 2.5 years, so I will add a deprecation.

if prefixes:
self.prefixes_ = self.prefixes_[:] + [name]

if cuda_dependent_attributes:
self.cuda_dependent_attributes_ = (
self.cuda_dependent_attributes_[:] + [name + '_'])

if self.init_context_ == 'module':
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From what I can tell we need the context manager because of how we are using __setattr__ to update the state of modules_ and friends. Is this correct?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, exactly. This is my proposal how to "know" what kind of attribute we're dealing with, the old one being to infer from the attribute name, which I find inferior. A small disadvantage is that the context is not triggered if a user were to call initialize_module et al directly, though I don't believe there is a need for that.

skorch/net.py Outdated
return self

def _initialize_history(self):
with self._current_init_context('callbacks'):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not think we actually explicitly use of the 'callbacks' or virtual_params contexts. Should we still have them?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question. My reasoning was to stay consistent, but I don't see a use case at the moment. I could leave the context managers there but add a comment that they're not used at the moment.

BenjaminBossan added 6 commits April 5, 2021 13:13
- Update example in docstring of get_learnable_params
- Add comments about unused init contexts
- Add deprecation for triggered_directly argument

Also: Improved docstring for named_parameters
This case had to be covered yet: When the module/criterion is already
initialized and none of it's parameters changed,
initialize_module/criterion was not called. However, what if a custom
module/criterion does need to be initialized? In that case, not calling
initialize_module/criterion is bad.

With this fix, this bad behavior no longer occurs. Tests were added to
cover this.

In order to achieve this change, we had to unfortunately push down the
checking whether module/criterion is already initialized from
_initialize_module/criterion to initialize_module/criterion. There was
no other way of checking this, since at first, we cannot know which
attributes are modules/criteria.

For the user, this means a little more work if they want to implement
initialize_module/criterion absolutely correctly. However, that's not so
bad because it is only important if the user wants to work with
pre-initialized modules/criteria and with custom modules/criteria, which
should happen very rarely in practice. And even if it does, the user can
just copy the default skorch code and will end up with a correct
implementation.
Until now, only module_ and criterion_ were automatically set into
training/evaluation mode, now custom modules are also set automatically.
This was implemented through a new method, net._set_training. It is
private for now, maybe consider adding a public one. Also, the name
could be changed to "train" as in PyTorch, but that name could be
confusing.
@BenjaminBossan BenjaminBossan marked this pull request as ready for review April 10, 2021 15:10
BenjaminBossan added 6 commits April 24, 2021 13:14
I did not correctly handle virtual params with custom optimizers. This
has been fixed now. The ambiguous 'lr' parameter is only associated with
the default 'optimizer', not any custom optimizer, which need to be
addressed by 'myoptimizer__lr'.

Also, removed some unnecessary code from _initialize_optimizer.
It's not tantamount for the "initialize_*" methods to return self, since
their corresponding "_initialize_*" methods already do so. Their
signature is left as is, but the docs no longer mention that necessity.
Make it clear what happens by default in the customization docs.
Clarify what a user should do when _not_ calling super.
Instead of having get_learnable_params(optimizer_name), just have
get_all_learnable_params(), since there is not really a use case for the
former.
There is now a helper method that abstracts away the logic of
determining if a module/criterion is already initialized, and just
returns the instantiated instance.
@BenjaminBossan
Copy link
Collaborator Author

I would opt for removing the optimizer_name argument. Maybe even rename the function to get_all_learnable_params to avoid confusion about its purpose?

GH doesn't let me reply directly, so here is the answer: Yes, I was thinking the same. Changed it accordingly.

I am a little worried about the statefulness of __setattr__ / initialize_{optimizer,module,criterion} but the risk seems negligible.

Well, these methods have always been stateful ;) but I know what you mean. Yes, it's important that we get the book keeping right, but I believe the current state is unfortunate, since it's so half-baked, and IMO this is a strict improvement (provided no bugs are introduced).

Copy link
Member

@thomasjpfan thomasjpfan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using __setattr__ to dynamically define modules_ and friends is quite magical, but it works.


if 'lr' not in kwargs:
kwargs['lr'] = self.lr
named_parameters = self.get_all_learnable_params()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If one were to define self.module2_, get_all_learnable_params also return the named parameters of module2_ and then we connect the optimizer every module. Is this the expected behavior?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that's the idea. As is, if a user adds self.module2_, it's never updated. I have the feeling that this is not expected and it could indeed be very hard to spot. If a user wanted to update self.module2_, they would need to override initialize_module and add the parameters to the existing optimizer or create a new one.

In contrast, let's assume that the user does not want self.module2_ to be updated by self.optimizer_. In that case, they would most likely have to touch initialize_module anyway. If they do, it's trivial to have self.optimizer_ only update parameters of self.module_.

WDYT? Is the proposed behavior unexpected?


class CheckTrainingCallback(Callback):
def on_batch_end(self, net, batch, training, **kwargs):
assert_net_training_mode(net, training=training)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess the easiest way to assert that on_batch_end is actually being called is to use the nonlocal trick from above.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, we test more than just if the method is called. Therefore, I don't see how nonlocal would help to simplify the test here. Maybe it's just too late in the evening for me :)

Anyway, I could simplify the test by removing the callback completely and instead override on_batch_end of the net itself.

return 0.5 * (self.module_(x) + self.module2_(x))

net = MyNet(module_cls, max_epochs=1, lr=0.5).initialize()
# params1_before = [copy.deepcopy(p) for p in net.module_.parameters()]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From an old commit?

Suggested change
# params1_before = [copy.deepcopy(p) for p in net.module_.parameters()]

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

@@ -2813,7 +3503,7 @@ def forward(self, x0, x1):
class MyNet(NeuralNet):
"""Override train_step_single and validation_step"""
def train_step_single(self, batch, **fit_params):
self.module_.train()
self._set_training(True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.module_ should still work when this PR is merged?

Given that _set_training is private API, I would not expect third-party developers to know about it.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right. So your point is that I should reverse the changes to this test, since it should not depend on using the private API? If so, I agree, I will reverse the changes.


cuda_dependent_attributes_ = ['module_', 'optimizer_', 'criterion_']
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not backward compatible but should be okay.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean pickle compatibility?

skorch/net.py Outdated
Comment on lines 263 to 265
modules_ = []
criteria_ = []
optimizers_ = []
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typically, this would not work with grid search, but the way these attributes are updated with __setattr__ allows it to work.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree it's not the most elegant of solutions, but I couldn't come up with anything better.

return self

def _initialize_criterion(self, reason=None):
# _initialize_criterion and _initialize_module share the same logic
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are ways to try to DRY this code, but I am okay to leave it as is.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's only a single repetition and extracting the common part could make it more difficult to understand than it already is. So yes, let's leave it as is for now :)

skorch/net.py Outdated
Comment on lines 945 to 946
for name in self.optimizers_:
optimizer = getattr(self, name + '_')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using string mangling to get the optimizers or modules is an okay approach.

I am thinking of using WeakValueDictionary to hold a weak reference could be better. It could even keep track of itself to avoid the whole `_unregister_attribute:

from weakref import WeakValueDictionary

class A:
    pass

a_obj = A()

my_dict = WeakValueDictionary() 
my_dict["hello"] = a_obj

# object is in dict
assert "hello" in my_dict

# Remove object
del a_obj

# object not in dict
assert "hello" not in my_dict

If we use this WeakValueDictionary, then modules_ can be a @property returning self.modules_weak_dict_.keys().

I do not have a strong opinion. (WeakValueDictionary is not a well know python object)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the idea because it would allow us to safe a bit of code. And since this is a part of the code that the user should never touch, it would be okay to use some of the less well known Python features.

In practice, however, this seems to be too brittle. It relies on the number of references held to, say, the module_. If a user references it somewhere else, a simple deletion of net.module_ will not clean it up from the net.modules_.

The way I discovered this was actually quite funny. One of the tests contains this code:

class MyNet(net_cls):
    def initialize_module(self):
        super().initialize_module()
        self.module_ = module_cls()  # same attribute name
        return self

The WeakValueDictionary would be empty after the module_ is re-assigned. Except if I added a breakpoint, then everything would work as expected. That's when I decided it's not worth it :D

- simplify one of the tests
- remove uncommented line of code
- reverse changes in one test to not use a private method
@BenjaminBossan
Copy link
Collaborator Author

No idea why all the tests are failing/being canceled. The offending tests seem to be:

FAILED skorch/tests/callbacks/test_logging.py::TestWandb::test_fit_with_real_experiment
FAILED skorch/tests/callbacks/test_logging.py::TestProgressBar::test_pickle
ERROR skorch/tests/callbacks/test_logging.py::TestTensorBoard::test_writer_closed_automatically

Those seem to be unrelated to my latest changes, which were exclusively inside test_net.py (and the tests also seem to be unrelated to each other?). I can't reproduce the error locally. Therefore, I will wait for now and run the tests again later, in case this issue fixes itself. Meanwhile, I believe my latest changes are safe to be reviewed.

Copy link
Member

@thomasjpfan thomasjpfan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is an argument for making modules_ private because it is more of an implementation detail for NeuralNet. Users and subclass developers do not need to interact with it.

If we keep it public, we need to be super clear that modules_ is set automatically and should not be touched directly.

Comment on lines +246 to +247
collected dynamically when the net is initialized. Typically, there is no
reason for a user to modify this list.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any reason for a developer to directly interact with criteria_? Currently, I think its auto populated when one does self.module2_ = ....

BenjaminBossan added 3 commits May 22, 2021 15:41
There should rarely be a need for a user to touch those attributes.
The corresponding methods have no *args or **kwargs anyway.
@BenjaminBossan
Copy link
Collaborator Author

There is an argument for making modules_ private because it is more of an implementation detail for NeuralNet. Users and subclass developers do not need to interact with it.

If we keep it public, we need to be super clear that modules_ is set automatically and should not be touched directly.

You are right, Thomas, there should rarely, if ever, be the necessity for a user to touch these attributes. I renamed modules_, criteria_ and optimizers_ to _modules, _criteria and _optimizers to mark them as private.

Copy link
Member

@thomasjpfan thomasjpfan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is ready. LGTM

@BenjaminBossan BenjaminBossan merged commit 812f54d into master Jun 13, 2021
@BenjaminBossan
Copy link
Collaborator Author

After personal communication with @ottonemo he gave his thumbs up, so I went ahead and merged.

@BenjaminBossan BenjaminBossan deleted the changed/refactor-init-more-consistency-custom-modules branch October 3, 2021 19:40
BenjaminBossan added a commit that referenced this pull request Oct 31, 2021
We are happy to announce the new skorch 0.11 release:

Two basic but very useful features have been added to our collection of callbacks. First, by setting `load_best=True` on the  [`Checkpoint` callback](https://skorch.readthedocs.io/en/latest/callbacks.html#skorch.callbacks.Checkpoint), the snapshot of the network with the best score will be loaded automatically when training ends. Second, we added a callback [`InputShapeSetter`](https://skorch.readthedocs.io/en/latest/callbacks.html#skorch.callbacks.InputShapeSetter) that automatically adjusts your input layer to have the size of your input data (useful e.g. when that size is not known beforehand).

When it comes to integrations, the [`MlflowLogger`](https://skorch.readthedocs.io/en/latest/callbacks.html#skorch.callbacks.MlflowLogger) now allows to automatically log to [MLflow](https://mlflow.org/). Thanks to a contributor, some regressions in `net.history` have been fixed and it even runs faster now.

On top of that, skorch now offers a new module, `skorch.probabilistic`. It contains new classes to work with **Gaussian Processes** using the familiar skorch API. This is made possible by the fantastic [GPyTorch](https://github.com/cornellius-gp/gpytorch) library, which skorch uses for this. So if you want to get started with Gaussian Processes in skorch, check out the [documentation](https://skorch.readthedocs.io/en/latest/user/probabilistic.html) and this [notebook](https://nbviewer.org/github/skorch-dev/skorch/blob/master/notebooks/Gaussian_Processes.ipynb). Since we're still learning, it's possible that we will change the API in the future, so please be aware of that.

Morever, we introduced some changes to make skorch more customizable. First of all, we changed the signature of some methods so that they no longer assume the dataset to always return exactly 2 values. This way, it's easier to work with custom datasets that return e.g. 3 values. Normal users should not notice any difference, but if you often create custom nets, take a look at the [migration guide](https://skorch.readthedocs.io/en/latest/user/FAQ.html#migration-from-0-10-to-0-11).

And finally, we made a change to how custom modules, criteria, and optimizers are handled. They are now "first class citizens" in skorch land, which means: If you add a second module to your custom net, it is treated exactly the same as the normal module. E.g., skorch takes care of moving it to CUDA if needed and of switching it to train or eval mode. This way, customizing your networks architectures with skorch is easier than ever. Check the [docs](https://skorch.readthedocs.io/en/latest/user/customization.html#initialization-and-custom-modules) for more details.

Since these are some big changes, it's possible that you encounter issues. If that's the case, please check our [issue](https://github.com/skorch-dev/skorch/issues) page or create a new one.

As always, this release was made possible by outside contributors. Many thanks to:

- Autumnii
- Cebtenzzre
- Charles Cabergs
- Immanuel Bayer
- Jake Gardner
- Matthias Pfenninger
- Prabhat Kumar Sahu 

Find below the list of all changes:

Added

- Added `load_best` attribute to `Checkpoint` callback to automatically load state of the best result at the end of training
- Added a `get_all_learnable_params` method to retrieve the named parameters of all PyTorch modules defined on the net, including of criteria if applicable
- Added `MlflowLogger` callback for logging to Mlflow (#769)
- Added `InputShapeSetter` callback for automatically setting the input dimension of the PyTorch module
- Added a new module to support Gaussian Processes through [GPyTorch](https://gpytorch.ai/). To learn more about it, read the [GP documentation](https://skorch.readthedocs.io/en/latest/user/probabilistic.html) or take a look at the [GP notebook](https://nbviewer.jupyter.org/github/skorch-dev/skorch/blob/master/notebooks/Gaussian_Processes.ipynb). This feature is experimental, i.e. the API could be changed in the future in a backwards incompatible way (#782)

Changed

- Changed the signature of `validation_step`, `train_step_single`, `train_step`, `evaluation_step`, `on_batch_begin`, and `on_batch_end` such that instead of receiving `X` and `y`, they receive the whole batch; this makes it easier to deal with datasets that don't strictly return an `(X, y)` tuple, which is true for quite a few PyTorch datasets; please refer to the [migration guide](https://skorch.readthedocs.io/en/latest/user/FAQ.html#migration-from-0-10-to-0-11) if you encounter problems (#699)
- Checking of arguments to `NeuralNet` is now during `.initialize()`, not during `__init__`, to avoid raising false positives for yet unknown module or optimizer attributes
- Modules, criteria, and optimizers that are added to a net by the user are now first class: skorch takes care of setting train/eval mode, moving to the indicated device, and updating all learnable parameters during training (check the [docs](https://skorch.readthedocs.io/en/latest/user/customization.html#initialization-and-custom-modules) for more details, #751)
- `CVSplit` is renamed to `ValidSplit` to avoid confusion (#752)

Fixed

- Fixed a few bugs in the `net.history` implementation (#776)
- Fixed a bug in `TrainEndCheckpoint` that prevented it from being unpickled (#773)
BenjaminBossan added a commit that referenced this pull request Jul 28, 2022
This bug occurs when trying to load a net that was trained on CUDA on a
CPU machine. However, it only occurred when there were CUDA-dependent
attribute set via set_params. This bug is now fixed.

The problem started occurring after PR #751, which introduced storing
parameters set via set_params in the private attribute _kwargs.
Normally, for attributes, we make sure that they can be loaded without
CUDA, but attributes within _kwargs were not checked. Thus, loading
those without CUDA failed. Unfortunately, this was not caught by CI
because CI is not CUDA-enabled.

The solution to this problem is that CUDA-dependent attributes are now
removed from _kwargs during __getstate__ and re-added later during
__setstate__ (a test was introduced for checking that). A little problem
occurring there was that the cuda_dependent_attributes_ are not part of
state. Therefore, I add them to state during __getstate__ and remove
them later in __setstate__.
BenjaminBossan added a commit that referenced this pull request Aug 1, 2022
Supersedes #877

This bug could occur if a user has set a parameter with a CUDA
dependency and then tries to load the net without CUDA. Now, this
works (again) as expected.

Underlying reason

The problem started occurring after PR #751, which introduced storing
parameters set via set_params in the private attribute _kwargs.
Normally, for attributes, we make sure that they can be loaded without
CUDA, but attributes within _kwargs were not checked. Thus, loading
those without CUDA failed. Unfortunately, this was not caught by CI
because CI is not CUDA-enabled.

The bugfix consists of making sure that we don't store any values in
_kwargs. Since values are not needed, only the keys (parameter names),
this is more efficient anyway. Thus, there are no more possibly
CUDA-dependent values that can "slip through".

After discussion, we decided to also rename the attribute, as _kwargs
was not very specific. The new attribute is called _params_to_validate
and it is a set instead of a dict. Also, the _check_kwargs method was
renamed to _validate_params and it doesn't take a kwargs argument
anymore. And on top of that, I changed the raised error from TypeError
to ValueError.

The reason for making this change is that it now is similar to sklearn's
_validate_params method on BaseEstimator (same signature and same error
type). However, we don't make use of the actual sklearn machinery since
our validation does a few things differently (e.g. proposing possible
fixes when the name is wrong).

As the attribute was renamed, we would normally get an error when
unpickling nets stored with the old attribute. To prevent this, we catch
the old attribute _kwargs and convert it to the new attribute
_params_to_validate.

Coincidental changes

- Moved an entry in CHANGES.md to a different section
- Added a reference to an existing entry in CHANGES.md
- I adapted the code in hf.py to use the same new scheme
ottonemo pushed a commit that referenced this pull request Sep 5, 2022
* Loading extra arguments w/ cuda dependency on CPU

Supersedes #877

This bug could occur if a user has set a parameter with a CUDA
dependency and then tries to load the net without CUDA. Now, this
works (again) as expected.

Underlying reason

The problem started occurring after PR #751, which introduced storing
parameters set via set_params in the private attribute _kwargs.
Normally, for attributes, we make sure that they can be loaded without
CUDA, but attributes within _kwargs were not checked. Thus, loading
those without CUDA failed. Unfortunately, this was not caught by CI
because CI is not CUDA-enabled.

The bugfix consists of making sure that we don't store any values in
_kwargs. Since values are not needed, only the keys (parameter names),
this is more efficient anyway. Thus, there are no more possibly
CUDA-dependent values that can "slip through".

After discussion, we decided to also rename the attribute, as _kwargs
was not very specific. The new attribute is called _params_to_validate
and it is a set instead of a dict. Also, the _check_kwargs method was
renamed to _validate_params and it doesn't take a kwargs argument
anymore. And on top of that, I changed the raised error from TypeError
to ValueError.

The reason for making this change is that it now is similar to sklearn's
_validate_params method on BaseEstimator (same signature and same error
type). However, we don't make use of the actual sklearn machinery since
our validation does a few things differently (e.g. proposing possible
fixes when the name is wrong).

As the attribute was renamed, we would normally get an error when
unpickling nets stored with the old attribute. To prevent this, we catch
the old attribute _kwargs and convert it to the new attribute
_params_to_validate.

Coincidental changes

- Moved an entry in CHANGES.md to a different section
- Added a reference to an existing entry in CHANGES.md
- I adapted the code in hf.py to use the same new scheme

* Add TODO comment for removing transition code

Give a 1 year grace period to still enable loading old skorch models
with new version.
BenjaminBossan added a commit that referenced this pull request Oct 13, 2022
Resolves #907

The problem initially occurred because the warning did not check for the
value, just the presence of the key.

Even though there is a test for this, the test didn't detect the error.
This is because during a refactor (#751), the parameter validation was
moved to initialize() from __init__() but the test was not adjusted to
take the change into account.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants