[RFC]: Get rid of optim_bwd checks via wrapper. #2370
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2370
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
@felipemello1 Hey! Requesting review. I haven't found a better design than the wrapper at this point.
Hey @krammnic. We're currently working through a separate issue (#2360) with our optimizer-in-backward feature. I wonder if it'd be best to investigate those fixes first? cc @ebsmothers
@SalmanMohammadi Hey! Sure, will check it out.
Hey Mark, thanks for this PR! At first glance, I love it :)
There are a few things that worry me. Let me know your thoughts about them:
- Initially, when I thought about the opt_in_bwd wrapper, the idea was for it to mimic the expected behavior of the optimizer. But in this PR we are also wrapping the regular optimizer, for example:
def state_dict(self) -> Dict[str, Any]:
if self._optimizer_in_bwd:
return self._optim_ckpt_wrapper.state_dict()
else:
return self._optimizer.state_dict()
My fear is that we are adding layers on something that may not need layers, hurting hackability.
On the other hand, I like how you removed a lot of utilities/boilerplate from the recipe and added it here.
Q: should we wrap the optimizer? It's fine if the answer is yes, I just want to make sure it makes sense.
- I believe that in the training loop, the optimizer is called differently when it's opt_in_bwd and when it isn't. This happens because of grad clipping and normalizing the loss by the number of tokens (see the sketch after this list).
check lines 814, 825, 868: https://github.com/pytorch/torchtune/blob/main/recipes/full_finetune_distributed.py#L814
Q: How should we handle it? Probably still check if self._optimizer_in_bwd?
- Some recipes don't have optimizer_in_bwd, like RL and LoRA recipes. It may be worth sanity checking that this implementation will also work for them. But I am ok with doing it in a follow-up PR.
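For context, here is a minimal sketch of the kind of branch point 2 refers to; the function and argument names are illustrative, not the actual recipe code:

# Minimal sketch (illustrative names, not the recipe's code): grad scaling and
# clipping only apply when gradients are materialized, i.e. when we are NOT
# using optimizer-in-backward.
from typing import Optional

import torch
from torch import nn


def optimizer_step(
    model: nn.Module,
    optimizer: Optional[torch.optim.Optimizer],
    optimizer_in_bwd: bool,
    num_tokens: int,
    clip_grad_norm: Optional[float] = None,
) -> None:
    if optimizer_in_bwd:
        # Each parameter's optimizer already stepped inside its backward hook,
        # so there is nothing left to normalize, clip, or step here.
        return
    # Normalize gradients by the number of tokens seen this step.
    for p in model.parameters():
        if p.grad is not None:
            p.grad.div_(num_tokens)
    if clip_grad_norm is not None:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad_norm)
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)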
Can we have this live in memory with the other memory optimizations?
Sure!
else:
    self._optimizer.step()

def get_lr(self) -> float:
I'm in favor of slightly more descriptive names like get_learning_rate
Yeah, sounds reasonable! Meanwhile, we have lr: 5e-8 in the recipes.
@felipemello1 Thanks for the answer! In my opinion there is no problem with getting point 3 done in this PR; it is actually a pretty small procedure, isn't it? (Correct me if I'm missing something, please.) Speaking about the design overall... To be honest, it is hard for me to choose either way. For instance, I love how some parts of the recipes have moved into this optimizer abstraction. But it is an abstraction! Therefore, it is a complication. Will think more about it.
Co-authored-by: Joe Cummings <[email protected]>
Let's get approval for one (or two) recipes, maybe a LoRA one. And when everyone is OK with it, we can look towards updating the others, so it minimizes the work for you. Thanks again!
I left some comments and then realized that we already have OptimizerInBackwardWrapper. I should have started with that. But can't we just extend OptimizerInBackwardWrapper to have all the safety checks that are currently inside of the recipe instead of wrapping it in an additional layer?
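For illustration, extending the existing wrapper could look roughly like the sketch below; the added check is hypothetical and only shows where recipe-level safety checks might move:

# Hypothetical sketch: subclass the existing wrapper and move recipe-level
# checks into it, instead of wrapping it in yet another object.
from typing import Any, Dict

from torchtune import training


class CheckedOptimInBwdWrapper(training.OptimizerInBackwardWrapper):
    def load_state_dict(self, optim_ckpt_map: Dict[str, Any]) -> None:
        try:
            super().load_state_dict(optim_ckpt_map)
        except KeyError as e:
            # Surface a recipe-friendly error instead of a bare KeyError when
            # the checkpoint does not match the current model's parameters.
            raise RuntimeError(
                "Optimizer-in-backward checkpoint does not match the current "
                "model parameters."
            ) from e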
    for param in self._model.parameters()
}
training.register_optim_in_bwd_hooks(model=self._model, optim_dict=self._optim_dict)
self._optim_ckpt_wrapper = training.create_optim_in_bwd_wrapper(
This class should replace create_optim_in_bwd_wrapper so we don't have wrapper inception
def __init__(
    self,
    model: nn.Module,
    cfg_optimizer: DictConfig,
As a rule of thumb we try to keep DictConfig (and all config code like instantiate) for the recipe only and all utils just work directly with python parameters. Instead you can partially initialize the optimizer (using functools) inside the recipe and then pass that into this object.
This could also be handled by extending training.create_optim_in_bwd_wrapper to instead return this object.
We use it like this:
param: config.instantiate(cfg_optimizer, [param]) for param in model.parameters()
or
param: config.instantiate(cfg_optimizer, model.parameters())
we also pass a DictConfig to setup_optimizer
Are you thinking about something like below? Not sure if this is more intuitive / debuggable
self._optimizer = OptimizerWrapper(
model=self._model,
cfg_optimizer=partial(cfg_optimizer),
optimizer_in_bwd=self._optimizer_in_bwd,
opt_state_dict=(
checkpoint_dict[training.OPT_KEY]
if self._resume_from_checkpoint
else None
),
)
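For reference, the functools approach described above might boil down to something like this illustration (not code from this PR):

# Illustration only: the recipe resolves the optimizer config once into a plain
# callable, and the wrapper decides how to call it.
from functools import partial

import torch
from torch import nn

model = nn.Linear(8, 8)

# Recipe side: equivalent of config.instantiate, but as a reusable factory.
optim_factory = partial(torch.optim.AdamW, lr=2e-5, weight_decay=0.01)

# Wrapper side: one global optimizer (regular path) ...
global_optim = optim_factory(model.parameters())

# ... or one optimizer per parameter (optimizer-in-backward path).
per_param_optims = {name: optim_factory([p]) for name, p in model.named_parameters()}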
from torchtune.training.lr_schedulers import get_lr


class OptimizerWrapper:
Can this inherit from torch optim so it maintains the same API as the wrapped optimizer? It would also mean that opt_state_dict wouldn't be passed into the init.
I don't think that inheriting would help. Even for a regular optimizer, we still need it when resuming_from_checkpoint
so we can call:
training.load_from_full_optimizer_state_dict(
    model,
    optimizer,
    opt_state_dict,
    self._device,
)
from torchtune.training.lr_schedulers import get_lr


class OptimizerWrapper:
I think the name should be more specific. Maybe MemoryEfficientOptimizer.
It is a wrapper for both the optimizer and opt_in_bwd. Basically it removes all the boilerplate from the recipe and puts it into this class. So it is not for memory only. It's more of an abstraction.
else:
    return self._optimizer.state_dict()

def get_optimizer(self) -> tuple:
where is this used?
else:
    return get_lr(self._optimizer)

def set_learning_rate_scheduler
where is this used?
else:
    return self._optimizer

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
This seems to be the same code as in __init__; should we just call it instead?
Here is what I am understanding:
a) There are optimizer methods (zero_grad, step, state_dict) ...
b) ...
For (a), I feel like these should belong to ...
For (b), IMO your ...
I just don't know if we should do something like: ...
@krammnic @pbontrager do you agree?
@krammnic, when you have time, can you also address the other parts of the recipe (lines 814, 825, 868: main/recipes/full_finetune_distributed.py#L814)?
@felipemello1 Yep, reasonable comments! Let me fix them then.
Speaking about ...
@pbontrager Hey Philip! I reformatted this as an RFC and would love to hear your comments on the desired design for this (apart from the changes @felipemello1 already mentioned). I assume such a cleanup should be done, but it mostly depends on the chosen design.
#2052
We have a lot of similar code related to the optim_bwd optimizer setup in our recipes. Basically, we can create a wrapper/abstraction over the basic optimizer, where we put all the checks, making the recipes cleaner. One possible way of doing such a cleanup is presented in this RFC. Basically, we should discuss the balance between abstraction and simplicity in the given solution. Would love to hear some comments on it.
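For concreteness, the proposed abstraction boils down to a dispatch pattern like the following simplified sketch (an illustration, not the exact code in the PR):

# Simplified sketch of the proposed wrapper: every method forwards either to
# the single optimizer or to the optimizer-in-backward wrapper, so recipes no
# longer branch on the mode themselves.
from typing import Any, Dict, Optional

import torch


class OptimizerWrapperSketch:
    def __init__(
        self,
        optimizer: Optional[torch.optim.Optimizer] = None,
        optim_ckpt_wrapper=None,
        optimizer_in_bwd: bool = False,
    ):
        self._optimizer = optimizer
        self._optim_ckpt_wrapper = optim_ckpt_wrapper
        self._optimizer_in_bwd = optimizer_in_bwd

    def step(self) -> None:
        # In-backward mode already stepped each parameter's optimizer inside
        # its backward hook, so there is nothing to do here.
        if not self._optimizer_in_bwd:
            self._optimizer.step()

    def zero_grad(self, set_to_none: bool = True) -> None:
        if not self._optimizer_in_bwd:
            self._optimizer.zero_grad(set_to_none=set_to_none)

    def state_dict(self) -> Dict[str, Any]:
        if self._optimizer_in_bwd:
            return self._optim_ckpt_wrapper.state_dict()
        return self._optimizer.state_dict()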