[WIP] Initial implementation of AMP support #707
Conversation
This feature is about adding support for automatic mixed precision (issue 611). The current state should already be working but is untested so far; it is still missing docs and tests.
Just to give a flavor of what a facade optimizer could look like:

    import torch

    class AmpOptimizerFacade(torch.optim.Optimizer):
        def __init__(self, optimizer, grad_scaler, amp_enabled=True):
            self.optimizer = optimizer
            self.grad_scaler = grad_scaler
            self.amp_enabled = amp_enabled

        def step(self, closure=None):
            if not self.amp_enabled:
                # just use the optimizer as normal
                return self.optimizer.step(closure)

            loss = None
            if closure is not None:
                with torch.enable_grad():
                    # The closure is responsible for scaling the loss, no
                    # scaling here; the reason is that loss.backward() is
                    # called within the closure, at which point the
                    # scaling already needs to be applied, so applying it
                    # here would be too late.
                    loss = closure()
            self.grad_scaler.step(self.optimizer)
            self.grad_scaler.update()
            return loss

        def __getattr__(self, attr):
            # delegate everything else to the wrapped optimizer
            return getattr(self.optimizer, attr)

        def __repr__(self):
            # something useful, e.g. show the wrapped optimizer
            return 'AmpOptimizerFacade({!r})'.format(self.optimizer)
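Hypothetical usage of such a facade, assuming the sketch above (note that GradScaler only does real scaling on a CUDA device):

```python
import torch

# AmpOptimizerFacade and its constructor arguments are taken from the sketch
# above, not from skorch itself.
module = torch.nn.Linear(10, 2)
base_optimizer = torch.optim.SGD(module.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

# The net would only ever see the facade; attribute access is delegated to the
# wrapped optimizer, so non-AMP code paths keep working unchanged.
optimizer = AmpOptimizerFacade(base_optimizer, scaler, amp_enabled=True)
```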
Thanks for taking a look @thomasjpfan
That's what I meant by extra checks. It's not the end of the world, but it's still annoying. But probably better than setting a dummy object, which might confuse people.
Yes, e.g. when users override
This seems to be the most conservative approach. We'd probably also need an upgrade guide for people who already overrode affected methods, to remind them that certain methods now need to perform extra duties. We could put this into CHANGES.md, but it's hard for me to tell whether that actually gets read.
Could you elaborate on this suggestion? What would you change about those?
(btw. trying to understand the fastai code was as delightful as ever, with those strange decorators, patches everywhere, and lots of one-liners like the following.)

    def before_batch(self): self.autocast.__enter__()
    def after_loss(self): self.autocast.__exit__()

Moreover, I wonder if the implementation using a callback cannot lead to bugs introduced by using callbacks in the wrong order. I'd feel more comfortable having more control over that by leaving the logic inside the net itself. fastai also does things like

    self.learn._step,self.learn._backward = self._step,self._backward

This looks dangerous to me. Finally, I don't see that fastai adjusts their gradient clipping to AMP. Okay, I'm ranting a bit, but my conclusion for now is that we shouldn't lean too much on the way fastai implements AMP.
Since unscaling is not an idempotent operation, PyTorch raises a RuntimeError if an already unscaled optimizer is unscaled again. Since we could have multiple callbacks or other components that might want to unscale, we have to protect against this possibility. Therefore, we check whether the optimizer has already been unscaled before unscaling. Unfortunately, we have to rely on a private attribute in PyTorch for this, so this is prone to break in the future.
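A minimal sketch of such a guard, assuming it is acceptable to read GradScaler's private per-optimizer state (an implementation detail that may change between PyTorch releases):

```python
from torch.cuda.amp.grad_scaler import OptState  # private API, may change


def maybe_unscale(scaler, optimizer):
    # Unscale only if this optimizer has not been unscaled since the last
    # step/update cycle; calling unscale_ a second time raises a RuntimeError.
    state = scaler._per_optimizer_states[id(optimizer)]  # private attribute
    if state["stage"] is not OptState.UNSCALED:
        scaler.unscale_(optimizer)
```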
I have spent some time evaluating different options and this one is still the one that seems most appropriate (with the proposed modification). There is no transparent way of implementing this using callbacks and hiding the optimizer/criterion behind a facade makes the whole process too opaque for my taste.
    else:
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
Since PyTorch XLA requires wrapping the optimizer step as well (and this is a feature we might want to support in the future once TPUs become more accessible for smaller companies), I suggest that we introduce something akin to self.optimizer_step(optimizer), which handles things like AMP scaling and XLA optimizations.
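A rough sketch of what such a hook could look like; the names (optimizer_step, amp_enabled, use_xla, grad_scaler_) are illustrative assumptions, not the actual skorch API, and the XLA branch assumes torch_xla is installed:

```python
class OptimizerStepMixin:
    """Illustrative mixin: a single place that decides how the optimizer steps."""

    def optimizer_step(self, optimizer):
        if getattr(self, "amp_enabled", False):
            # AMP: the GradScaler unscales, steps (or skips on inf/nan), updates.
            self.grad_scaler_.step(optimizer)
            self.grad_scaler_.update()
        elif getattr(self, "use_xla", False):
            # XLA/TPU: torch_xla ships its own optimizer_step wrapper.
            import torch_xla.core.xla_model as xm
            xm.optimizer_step(optimizer)
        else:
            # default eager CUDA/CPU path
            optimizer.step()
```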
I'd like to give this a try in the context of this PR to see if it makes things more ergonomic.
Closed in favor of #826
Intro
This feature is about adding support for automatic mixed precision (AMP, solves #611). The current state should already be working but is untested so far; it is still missing docs and tests.
Testing AMP
Unfortunately, I cannot test whether this works or not. Looking at Colab, they mention which GPUs one can get, but it seems AMP requires Turing, Volta, or Ampere. Could someone else run the test? I updated the examples/benchmarks/mnist.py script with a new --amp_enabled argument. Running it once with and without AMP, the acceleration for skorch should be similar to the acceleration for pure PyTorch (not sure if we can expect a big acceleration with the simple architecture being used).
Progress
See here

- Add an option on NeuralNet to enable amp (amp_enabled)
- Use autocast inside the infer and get_loss methods
- Use scaler.scale for backward calls (outside of the autocast context)
- Use scaler.step(optimizer) and scaler.update()
- Use scaler.state_dict / scaler.load_state_dict wherever you save/restore checkpoints
- Make it possible to set a custom GradScaler and to pass arguments using dunder notation: net.set_params(grad_scaler__growth_factor=4)
- Adjust GradientNormClipping to work with scaled gradients
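For reference, these items follow the standard torch.cuda.amp recipe. A generic, self-contained sketch of that recipe (plain PyTorch, not skorch code; requires a CUDA GPU):

```python
import torch
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(10, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.CrossEntropyLoss()
scaler = GradScaler()

# dummy batches standing in for a real DataLoader
batches = [(torch.randn(32, 10).cuda(), torch.randint(0, 2, (32,)).cuda())
           for _ in range(5)]

for X, y in batches:
    optimizer.zero_grad()
    with autocast():                   # mixed-precision forward pass and loss
        y_pred = model(X)
        loss = criterion(y_pred, y)
    scaler.scale(loss).backward()      # scaled backward call, outside autocast
    scaler.step(optimizer)             # unscales, then steps (or skips on inf/nan)
    scaler.update()

# scaler.state_dict() / scaler.load_state_dict(...) go wherever checkpoints
# are saved and restored.
```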
Design
My implementation so far uses the most straightforward way of implementing AMP support. However, I want to keep this a draft until we finalize the design. Parts of the code that I don't really like:

- There is now a net.grad_scaler_ attribute, analogous to net.module_ etc. But it is None if not amp_enabled, which is unlike the other such attributes, which are never None. This requires some checks down the line (if val is not None).
- Similarly, f_grad_scaler on Checkpoint et al. is also None, unlike the other parameters.
- Inside train_step, there is now an if/else branch on amp_enabled (a rough sketch of that kind of branching is shown at the end of this section). This is ugly and also requires awareness from everyone who overrides train_step. If there is custom code out there that already overrides train_step, it will not work with amp_enabled unless adjusted (in fact, it will fail silently).
- The same applies inside get_loss.
- Both get_loss and train_step should ideally be easy to override; this change makes it harder.
- If someone overrides train_step_single and no longer uses infer and get_loss, AMP will not be applied correctly.

I believe that none of these issues are show stoppers; AMP support is important enough that we should accept some increased complexity. But maybe we can come up with a superior design that doesn't sacrifice as much. E.g., we could think about using a facade pattern to hide some of the if ... else ugliness above and to not have net.grad_scaler_ as None, but that would also make the code more opaque. Anyway, I'm up for suggestions.
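For illustration only, a minimal sketch of the kind of if ... else branching meant above. This is hypothetical code, not the actual change in this PR; it assumes the amp_enabled flag and the net.grad_scaler_ / net.optimizer_ attributes described here:

```python
def backward_and_step(net, loss):
    # Hypothetical helper showing the extra branch that AMP support introduces.
    if net.amp_enabled:
        # scale the loss before backward, then let the GradScaler drive the step
        net.grad_scaler_.scale(loss).backward()
        net.grad_scaler_.step(net.optimizer_)
        net.grad_scaler_.update()
    else:
        # plain full-precision path
        loss.backward()
        net.optimizer_.step()
```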