
[RFC]: Get rid of optim_bwd checks via wrapper. #2370

Open · wants to merge 2 commits into base: main
Conversation

krammnic (Contributor) commented Feb 9, 2025

#2052

We have a lot of similar code related to the optim_bwd optimizer setup in our recipes. Basically, we can create a wrapper/abstraction over the basic optimizer where we put all the checks, making the recipes cleaner. One possible way of doing such a cleanup is presented in this RFC. Essentially, we should discuss the balance between abstraction and the simplicity of the given solution. Would love to hear some comments on it.
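To make the idea concrete, here is a minimal, hedged sketch of the direction (names and signatures are illustrative, not the final API in this PR): the recipe builds either a regular optimizer or the in-backward checkpoint wrapper, and all optimizer_in_bwd checks move into one class.

    from typing import Any, Dict

    class OptimizerWrapper:
        """Single home for the optimizer_in_bwd branching that currently lives in each recipe.

        Hypothetical sketch: holds either a regular torch.optim optimizer or the result of
        training.create_optim_in_bwd_wrapper(...); every mode check lives here.
        """

        def __init__(self, optimizer=None, optim_ckpt_wrapper=None, optimizer_in_bwd: bool = False):
            self._optimizer = optimizer                    # regular optimizer (or None)
            self._optim_ckpt_wrapper = optim_ckpt_wrapper  # in-backward wrapper (or None)
            self._optimizer_in_bwd = optimizer_in_bwd

        def state_dict(self) -> Dict[str, Any]:
            if self._optimizer_in_bwd:
                return self._optim_ckpt_wrapper.state_dict()
            return self._optimizer.state_dict()

        def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
            if self._optimizer_in_bwd:
                self._optim_ckpt_wrapper.load_state_dict(state_dict)
            else:
                self._optimizer.load_state_dict(state_dict)

With something like this, the recipe calls state_dict() / load_state_dict() at checkpoint time without knowing which mode is active.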

pytorch-bot bot commented Feb 9, 2025

🔗 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2370. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)

facebook-github-bot added the CLA Signed label on Feb 9, 2025
krammnic (Contributor, Author) commented Feb 9, 2025

@felipemello1 Hey! Requesting review. I haven't found a better design for the wrapper at this point.

SalmanMohammadi (Collaborator) commented Feb 9, 2025

Hey @krammnic. We're currently working through a separate issue (#2360) with our optimizer-in-backward feature. I wonder if it'd be best to investigate those fixes first? cc @ebsmothers

krammnic (Contributor, Author) commented Feb 9, 2025

@SalmanMohammadi Hey! Sure, will check it out.

felipemello1 (Contributor) left a comment

Hey Mark, thanks for this PR! At first glance, I love it :)

There are a few things that worry me. Let me know your thoughts about them:

1. Initially, when I thought about the opt_in_bwd wrapper, it was for it to mimic the expected behavior of the optimizer. But in this PR we are also wrapping the regular optimizer, for example:

def state_dict(self) -> Dict[str, Any]:
        if self._optimizer_in_bwd:
            return self._optim_ckpt_wrapper.state_dict()
        else:
            return self._optimizer.state_dict()

My fear is that we are adding layers on something that may not need layers, hurting hackability.

On the other hand, I like how you removed a lot of utilities/boilerplate from the recipe and added it here.

Q: should we wrap the optimizer? It's fine if the answer is yes, just want to make sure it makes sense.

2. I believe that in the training loop, the optimizer is called differently when it's opt_in_bwd and when it isn't. This happens because of grad clipping and normalizing the loss by the number of tokens (a rough sketch of this divergence follows after this list).

Check lines 814, 825, 868: https://github.com/pytorch/torchtune/blob/main/recipes/full_finetune_distributed.py#L814

Q: How should we handle it? Probably still check if self._optimizer_in_bwd?

3. Some recipes don't have optimizer_in_bwd, like the RL and LoRA recipes. It may be worth sanity-checking that this implementation will also work for them. But I am OK with doing it in a follow-up PR.
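For reference on item 2, here is a rough, hedged sketch (a hypothetical helper, not the actual recipe code) of the end-of-step logic that still has to branch: grad normalization by token count, gradient clipping, and an explicit step() only apply on the regular-optimizer path, because with optimizer-in-backward each parameter is already stepped inside its backward hook.

    from typing import Optional
    import torch
    from torch import nn

    def finish_step(
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        num_tokens: int,
        clip_grad_norm: Optional[float],
        optimizer_in_bwd: bool,
    ) -> None:
        # With optimizer-in-backward, every parameter was stepped inside its backward
        # hook and its gradient freed, so there is nothing left to scale, clip, or step.
        if optimizer_in_bwd:
            return
        for p in model.parameters():  # normalize grads by tokens seen this step
            if p.grad is not None:
                p.grad.div_(num_tokens)
        if clip_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad_norm)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)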

Contributor commented:

Can we have this live in memory with the other memory optimizations?

krammnic (Contributor, Author) replied:

Sure!

        else:
            self._optimizer.step()

    def get_lr(self) -> float:
Contributor commented:

I'm in favor of slightly more descriptive names like get_learning_rate

krammnic (Contributor, Author) replied:

Meanwhile, it's lr: 5e-8 in the recipes. Yeah, sounds reasonable!

krammnic (Contributor, Author) commented

(quoting @felipemello1's review comment above)

@felipemello1 Thanks for the answer! In my opinion, there is no problem with getting point 3 done during this PR; it is actually a pretty small procedure, isn't it? (Please correct me if I'm missing something.) Speaking about the design overall... to be honest, it is hard for me to choose either way. For instance, I love how some parts of the recipes have moved into this optimizer abstraction. But it is an abstraction! Therefore, it is a complication. I will think more about it.

felipemello1 (Contributor) commented

Let's get approval for one (or two) recipes, maybe a LoRA one. And when everyone is OK with it, we can look toward updating the others, so it minimizes the work for you. Thanks again!

pbontrager (Contributor) left a comment

I left some comments and then realized that we already have OptimizerInBackwardWrapper. I should have started with that. But can't we just extend OptimizerInBackwardWrapper to have all the safety checks that are currently inside of the recipe instead of wrapping it in an additional layer?

for param in self._model.parameters()
}
training.register_optim_in_bwd_hooks(model=self._model, optim_dict=self._optim_dict)
self._optim_ckpt_wrapper = training.create_optim_in_bwd_wrapper(
Contributor commented:

This class should replace create_optim_in_bwd_wrapper so we don't have wrapper inception

    def __init__(
        self,
        model: nn.Module,
        cfg_optimizer: DictConfig,
Contributor commented:

As a rule of thumb, we try to keep DictConfig (and all config code like instantiate) in the recipe only, and have all utils work directly with plain Python parameters. Instead, you can partially initialize the optimizer (using functools) inside the recipe and then pass that into this object.

This could also be handled by extending training.create_optim_in_bwd_wrapper to instead return this object.
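To illustrate the suggestion, a small hedged sketch (the optimizer choice and hyperparameters are made up, and no wrapper is constructed here): the recipe resolves the config into a plain factory with functools.partial, so no DictConfig ever leaves the recipe.

    from functools import partial
    import torch
    from torch import nn

    # In the recipe: turn the optimizer config into a plain callable once.
    # (AdamW and these hyperparameters are illustrative, not from the PR.)
    optimizer_factory = partial(torch.optim.AdamW, lr=2e-5, weight_decay=0.01)

    model = nn.Linear(8, 8)  # stand-in model

    # Regular path: one optimizer over all parameters.
    optimizer = optimizer_factory(model.parameters())

    # Optimizer-in-backward path: one optimizer per parameter, as the recipes do today.
    optim_dict = {p: optimizer_factory([p]) for p in model.parameters()}

Either the factory or the resulting optim_dict could then be handed to the wrapper (or to training.create_optim_in_bwd_wrapper) without any config.instantiate calls inside utils.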

felipemello1 (Contributor) commented Feb 13, 2025

We use it like this:

    param: config.instantiate(cfg_optimizer, [param]) for param in model.parameters()

or

    param: config.instantiate(cfg_optimizer, model.parameters())

We also pass a DictConfig to setup_optimizer.

Are you thinking about something like below? Not sure if this is more intuitive / debuggable:

self._optimizer = OptimizerWrapper(
            model=self._model,
            cfg_optimizer=partial(cfg_optimizer),
            optimizer_in_bwd=self._optimizer_in_bwd,
            opt_state_dict=(
                checkpoint_dict[training.OPT_KEY]
                if self._resume_from_checkpoint
                else None
            ),
        )

from torchtune.training.lr_schedulers import get_lr


class OptimizerWrapper:
Contributor commented:

Can this inherit from torch optim so it maintains the same API as the wrapped optimizer? It would also mean that opt_state_dict wouldn't be passed into the init.

Contributor replied:

I don't think that inheriting would help. Even for the regular optimizer, we still need it when resuming from checkpoint so we can call:

training.load_from_full_optimizer_state_dict(
                    model,
                    optimizer,
                    opt_state_dict,
                    self._device,
                )

from torchtune.training.lr_schedulers import get_lr


class OptimizerWrapper:
Contributor commented:

I think the name should be more specific. Maybe MemoryEfficientOptimizer.

Contributor replied:

It is a wrapper for both the regular optimizer and opt_in_bwd. Basically, it removes all the boilerplate from the recipe and puts it into this class, so it is not for memory only. It's more of an abstraction.

        else:
            return self._optimizer.state_dict()

    def get_optimizer(self) -> tuple:
Contributor commented:

where is this used?

        else:
            return get_lr(self._optimizer)

    def set_learning_rate_scheduler
Contributor commented:

where is this used?

        else:
            return self._optimizer

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
Contributor commented:

This seems to be the same code as in __init__; should we just call that instead?

felipemello1 (Contributor) commented Feb 13, 2025

Here is what I am understanding:

a) There are optimizer methods (zero_grad, step, state_dict).
b) There are utilities / boilerplate that need to call the optimizer differently based on opt_in_bwd (get_lr, load_state_dict, set_learning_rate_scheduler).

For (a), I feel like these should belong to training.create_optim_in_bwd_wrapper, so that it has the same API as other optimizers.

For (b), IMO your OptimizerWrapper makes sense: it removes the burden of if/else from the recipe and puts it in a separate class that handles it, working for both the regular optimizer and opt_in_bwd.

I just don't know if we should do something like optimizer_wrapper.optimizer.zero_grad() or optimizer_wrapper.zero_grad().

@krammnic @pbontrager do you agree?

@krammnic, when you have time, can you also address the other parts of the recipe (lines 814, 825, 868: main/recipes/full_finetune_distributed.py#L814)?
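On the optimizer_wrapper.optimizer.zero_grad() vs optimizer_wrapper.zero_grad() question, plain delegation keeps the call sites identical in both modes. A hedged sketch (hypothetical names, regular-optimizer internals only, not the actual PR code):

    from typing import Optional
    import torch

    class OptimizerWrapper:
        """Hypothetical facade: the recipe always calls wrapper.step() / wrapper.zero_grad(),
        never wrapper.optimizer.step(), regardless of optimizer_in_bwd."""

        def __init__(self, optimizer: Optional[torch.optim.Optimizer], optimizer_in_bwd: bool):
            self._optimizer = optimizer          # None when stepping happens in backward hooks
            self._optimizer_in_bwd = optimizer_in_bwd

        def step(self) -> None:
            # In-backward mode: each parameter was already stepped in its hook, so no-op.
            if not self._optimizer_in_bwd:
                self._optimizer.step()

        def zero_grad(self, set_to_none: bool = True) -> None:
            if not self._optimizer_in_bwd:
                self._optimizer.zero_grad(set_to_none=set_to_none)

With this, the training loop always writes optimizer_wrapper.zero_grad() and optimizer_wrapper.step(), whichever mode is configured, which also matches the preference expressed in the follow-up comment below.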

krammnic (Contributor, Author) commented

@felipemello1 Yep, reasonable comments! Let me fix them, then.

krammnic (Contributor, Author) commented

Speaking about optimizer_wrapper.optimizer.zero_grad() vs optimizer_wrapper.zero_grad(): I assume the plain .zero_grad() without .optimizer is less confusing.

krammnic changed the title from "[WIP]: Get rid of optim_bwd checks via wrapper." to "[RFC]: Get rid of optim_bwd checks via wrapper." on Feb 21, 2025
krammnic (Contributor, Author) commented

@pbontrager Hey Philip! I reformatted this as an RFC and would love to hear your comments on the desired design (besides the changes @felipemello1 already mentioned). I assume such a cleanup should be done, but it mostly depends on the chosen design.
