[WIP] Initial implementation of AMP support #707

Closed

Conversation

@BenjaminBossan (Collaborator) commented Oct 3, 2020

Intro

This feature is about adding support for automatic mixed precision (AMP, solves #611). The current state should already be working but is untested so far; it is still missing docs and tests.

Testing AMP

Unfortunately, I cannot test whether this works or not. Looking at Colab, they mention that one gets one of

> Nvidia K80s, T4s, P4s and P100s

but it seems AMP requires Turing, Volta, or Ampere. Could someone else run the test? I updated the examples/benchmarks/mnist.py script with a new --amp_enabled argument. When running it once with and once without AMP, the speedup for skorch should be similar to the speedup for pure PyTorch (though I'm not sure we can expect a big speedup with the simple architecture being used).

Progress

See here

  • Add a global option on NeuralNet to enable amp (amp_enabled)
  • use autocast inside infer and get_loss method
  • add the scaler to the parameters that need to be persisted
  • use scaler.scale for backward calls (outside of autocast context)
  • use scaler.step(optimizer) and scaler.update() (a minimal sketch of how these pieces fit together follows this list)
  • use scaler.state_dict / scaler.load_state_dict wherever you save/restore checkpoints.
  • possibility to change the used GradScaler and to pass arguments using dunder notation: net.set_params(grad_scaler__growth_factor=4).
  • keep the code backwards compatible with older PyTorch versions, warn users when they enable AMP but it has no effect
  • update the docs
  • add tests
  • possibly adjust GradientNormClipping to work with scaled gradients
  • possibly add a benchmark (similar to or based on this script)
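For reference, here is a minimal sketch of the standard PyTorch AMP pattern (public torch.cuda.amp API only) that the steps above map onto; the model, optimizer, criterion, and loader below are placeholders for illustration, not part of this PR:

import torch
from torch.cuda.amp import GradScaler, autocast

# placeholders standing in for the real module/data used in skorch
model = torch.nn.Linear(10, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.CrossEntropyLoss()
scaler = GradScaler()
loader = [(torch.randn(32, 10, device='cuda'),
           torch.randint(0, 2, (32,), device='cuda'))] * 10

for X, y in loader:
    optimizer.zero_grad()
    # forward passes run under autocast (infer / get_loss in skorch terms)
    with autocast():
        y_pred = model(X)
        loss = criterion(y_pred, y)
    # backward on the scaled loss, then step and update via the scaler
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()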

Design

My implementation so far uses the most straightforward way of implementing AMP support. However, I want to keep this a draft until we finalize the design. Parts of the code that I don't really like:

  1. There is now a net.grad_scaler_ attribute, analogous to net.module_ etc. But it is None if not amp_enabled, unlike the other such attributes, which are never None. This requires some checks down the line (if val is not None).

  2. Similarly, f_grad_scaler on Checkpoint et al. is also None, unlike the other parameters.

  3. Inside train_step, there is now:

        if not self.amp_enabled:
            self.optimizer_.step(step_fn)
        else:
            step_fn()  # closure not (yet) supported with AMP
            self.grad_scaler_.step(self.optimizer_)
            self.grad_scaler_.update()

        return step_accumulator.get_step()

This is ugly and also requires everyone who overrides train_step to be aware of it. If there is custom code out there that already overrides train_step, it will not work with amp_enabled unless adjusted (in fact, it will fail silently).

  4. Similar argument for get_loss:
        with self.autocast():
            loss = self.criterion_(y_pred, y_true)
        return self.grad_scaler_.scale(loss) if self.amp_enabled else loss

Both get_loss and train_step should ideally be easy to override; this change makes that harder.

  5. When someone overrides train_step_single and no longer uses infer and get_loss, AMP will not be applied correctly.

I believe that none of these issues are showstoppers; AMP support is important enough that we should accept some increased complexity. But maybe we can come up with a superior design that doesn't sacrifice as much. E.g., we could think about using a facade pattern to hide some of the if ... else ugliness above and to avoid having net.grad_scaler_ be None, but that would also make the code more opaque. Anyway, I'm open to suggestions.

@BenjaminBossan BenjaminBossan linked an issue Oct 3, 2020 that may be closed by this pull request
@BenjaminBossan (Collaborator, Author) commented:

Just to give a flavor of what a facade optimizer could look like:

import torch


class AmpOptimizerFacade(torch.optim.Optimizer):
    # Note: deliberately does not call super().__init__(); attribute
    # access falls through to the wrapped optimizer via __getattr__.
    def __init__(self, optimizer, grad_scaler, amp_enabled=True):
        self.optimizer = optimizer
        self.grad_scaler = grad_scaler
        self.amp_enabled = amp_enabled

    def step(self, closure=None):
        if not self.amp_enabled:
            # just use the optimizer as normal
            return self.optimizer.step(closure)

        loss = None
        if closure is not None:
            with torch.enable_grad():
                # The closure is responsible for scaling the loss, no
                # scaling here; the reason is that loss.backward() is
                # called within the closure, at which point the
                # scaling already needs to be applied, so applying it
                # here would be too late.
                loss = closure()

        self.grad_scaler.step(self.optimizer)
        self.grad_scaler.update()
        return loss

    def __getattr__(self, attr):
        # delegate everything else to the wrapped optimizer
        return getattr(self.optimizer, attr)

    def __repr__(self):
        # something useful
        return 'AmpOptimizerFacade({!r})'.format(self.optimizer)

The idea would be to set net.optimizer_ to this object in case AMP is enabled (otherwise leave it as is). Then we can get rid of issue 3 mentioned above. Unfortunately, we cannot perform the loss scaling inside this wrapper (issue 4): loss.backward() is called inside the closure, and the scaling needs to happen before that call, so scaling inside the wrapper would be too late.
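As a rough illustration of that wiring (initialize_optimizer is just a plausible place to do it; amp_enabled and grad_scaler_ are the attributes added in this PR, everything else here is an assumption, not part of the PR):

from skorch import NeuralNet

class AmpNet(NeuralNet):
    def initialize_optimizer(self, *args, **kwargs):
        super().initialize_optimizer(*args, **kwargs)
        if self.amp_enabled:
            # wrap the freshly created optimizer so that train_step can
            # keep calling self.optimizer_.step(step_fn) unchanged
            self.optimizer_ = AmpOptimizerFacade(
                self.optimizer_, self.grad_scaler_, amp_enabled=True)
        return self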

@thomasjpfan (Member) commented:

> There is now a net.grad_scaler_ attribute, analogous to net.module_ etc. But it is None if not amp_enabled, unlike the other such attributes, which are never None. This requires some checks down the line (if val is not None).

We can make the assumption that amp_enabled==False => net.grad_scaler_ is None and use self.amp_enabled everywhere.

Here are some ideas:

  1. We can extend the callback API to have a "before_backward" and "after_backward" callback, something like fastai's MixedPrecision Callback. But this will still require subclasses to trigger the callbacks in the right places. (A rough sketch follows this list.)

  2. We keep the ideas in this PR and document places that need to be adjusted to make amp work for subclasses.

  3. If we want to extend the facade idea, we can also put criterion_ and module_ behind facades that do the correct thing when AMP is activated. This would make things extra opaque, but it may allow subclasses to not worry about AMP and get it "for free".
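To make idea 1 a bit more concrete, here is a rough sketch; the hook names on_backward_begin/on_backward_end and the way the scaled loss would be handed back are pure assumptions, nothing like this exists in skorch today:

from skorch.callbacks import Callback

class AMPCallback(Callback):
    # hypothetical hooks -- the training loop would have to be changed
    # to call them at the right places and to use the returned loss
    def on_backward_begin(self, net, loss, **kwargs):
        # scale the loss so that backward() runs on the scaled value
        return net.grad_scaler_.scale(loss)

    def on_backward_end(self, net, **kwargs):
        # step through the scaler instead of the optimizer directly
        net.grad_scaler_.step(net.optimizer_)
        net.grad_scaler_.update()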

@BenjaminBossan (Collaborator, Author) commented:

Thanks for taking a look, @thomasjpfan.

> We can make the assumption that amp_enabled==False => net.grad_scaler_ is None and use self.amp_enabled everywhere.

That's what I meant by extra checks. It's not the end of the world, but still annoying. Still, it's probably better than setting a dummy object, which might confuse people when they save_params and suddenly there is a grad scaler file despite not using AMP.

>   1. We can extend the callback API to have a "before_backward" and "after_backward" callback... But this will still require subclasses to trigger the callbacks in the right places.

Yes, e.g. when users override train_step_single, they need to remember to call those. Also, I wonder if it's good to "break up" the training loop further and further by invoking more and more methods -- it could make the overall flow hard to understand.

>   2. We keep the ideas in this PR and document places that need to be adjusted to make AMP work for subclasses.

This seems to be the most conservative approach. We'd probably also need an upgrade guide for people who already overrode affected methods, to remind them that certain methods now need to perform extra duty. We could put this into CHANGES.md, but it's hard for me to tell whether that actually gets read.

>   3. If we want to extend the facade idea, we can also facade the criterion_ and module_ to do the correct thing when AMP is activated.

Could you elaborate on this suggestion? What would you change about those?

> something like fastai's MixedPrecision Callback

MixedPrecision seems to predate the PyTorch native AMP support, so instead I looked at NativeMixedPrecision:

https://github.com/fastai/fastai/blob/72590db2e66af6dd0eaa8c8874a80f11a4b8cbc2/fastai/callback/fp16.py#L154-L167

(btw. trying to understand fastai code was as delightful as ever, with those strange decorators, patches everywhere, and lots of import * ^^ )

    def before_batch(self): self.autocast.__enter__()
    def after_loss(self): self.autocast.__exit__()

So in fastai, autocast is entered on_batch_begin and exited after get_loss (skorch nomenclature). This is thus a "broader" approach, i.e. a lot can happen within the autocast context, whereas I chose to apply it only precisely where it's needed (when module_ and criterion_ are called). Which is better? The sketch below shows the two scopes side by side.
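A small sketch to make the difference concrete; module, criterion, and the tensors below are placeholders for illustration only:

import torch
from torch.cuda.amp import autocast

module = torch.nn.Linear(8, 2).cuda()
criterion = torch.nn.CrossEntropyLoss()
X = torch.randn(4, 8, device='cuda')
y = torch.randint(0, 2, (4,), device='cuda')

# narrow scope (this PR): autocast wraps exactly the calls that need it
with autocast():
    y_pred = module(X)
with autocast():
    loss = criterion(y_pred, y)

# broad scope (fastai-style): one context from batch begin until after
# the loss; everything in between also runs under autocast
with autocast():
    y_pred = module(X)
    loss = criterion(y_pred, y)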

Also, I wonder whether fastai's after_loss is invoked after a prediction is made. I tried to understand get_preds but unfortunately failed. If it's not called, that seems to be wrong: autocast is entered when the prediction starts and never exited.

Moreover, I wonder whether a callback-based implementation could lead to bugs introduced by running callbacks in the wrong order. I'd feel more comfortable having more control over that by leaving the logic inside NeuralNet, tbh.

        self.learn._step,self.learn._backward = self._step,self._backward

This looks dangerous to me. self.learn seems to be the trainer object (like NeuralNet) and just overriding methods ad hoc could lead to surprising results.

Finally, I don't see that fastai adjusts their gradient clipping to AMP.

Okay, I'm ranting a bit, but my conclusion for now is that we shouldn't lean too much on the way fastai implements AMP.

Since unscaling is not an idempotent operation, PyTorch raises a RuntimeError if an already unscaled optimizer is unscaled again. Because we could have multiple callbacks or other components that might want to unscale, we have to protect against this possibility. Therefore, we check whether the optimizer has already been unscaled before unscaling.

Unfortunately, we have to use a private attribute in PyTorch for this, so this is prone to break in the future.
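For illustration, such a check might look roughly like this; _per_optimizer_states and OptState are private details of torch.cuda.amp.GradScaler (as of PyTorch 1.6/1.7), which is exactly why this is fragile:

from torch.cuda.amp.grad_scaler import OptState

def optimizer_is_unscaled(grad_scaler, optimizer):
    # relies on GradScaler's private bookkeeping; may break with future
    # PyTorch releases
    state = grad_scaler._per_optimizer_states[id(optimizer)]
    return state["stage"] is OptState.UNSCALED

def unscale_once(grad_scaler, optimizer):
    # unscale_ raises a RuntimeError if called twice for the same step,
    # so only unscale if it hasn't happened yet
    if not optimizer_is_unscaled(grad_scaler, optimizer):
        grad_scaler.unscale_(optimizer)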
@BenjaminBossan BenjaminBossan marked this pull request as ready for review November 21, 2020 15:55
@ottonemo (Member) left a review:

I have spent some time evaluating different options and this one still seems the most appropriate (with the proposed modification). There is no transparent way of implementing this using callbacks, and hiding the optimizer/criterion behind a facade makes the whole process too opaque for my taste.

Inline comment by @ottonemo (Member) on the following lines of the diff:

    else:
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

Since PyTorch XLA requires wrapping the optimizer step as well (and this is a feature we might want to support in the future once TPUs become more accessible for smaller companies), I suggest that we introduce something akin to self.optimizer_step(optimizer), which handles things like AMP scaling and XLA optimizations.
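For illustration, such a method on the net might look roughly like this (optimizer_step is only the suggested name; amp_enabled and grad_scaler_ are the attributes added in this PR, and the body simply factors out the branch currently living in train_step):

def optimizer_step(self, step_fn):
    # single place for anything that wraps the optimizer step
    # (AMP today, possibly XLA's optimizer wrapping later on)
    if not self.amp_enabled:
        return self.optimizer_.step(step_fn)
    step_fn()  # closures are not (yet) supported with AMP
    self.grad_scaler_.step(self.optimizer_)
    self.grad_scaler_.update()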

@BenjaminBossan (Collaborator, Author) replied:

I'd like to give this a try in the context of this PR to see if it makes things more ergonomic.

@BenjaminBossan (Collaborator, Author) commented:

Closed in favor of #826

Linked issue: Native automatic mixed precision for Skorch (#611)