improve _optim_ckpt_wrapper so it is a drop in replacement of optimizer #2052

Open
felipemello1 opened this issue Nov 22, 2024 · 4 comments
Labels
best practice Things we should be doing but aren't


@felipemello1
Contributor

felipemello1 commented Nov 22, 2024

Our recipes are cluttered with logic that checks "if optim_in_bwd".

With a bit of engineering, we can make the wrapper a drop-in replacement for the optimizer and avoid code like this:

if not self._optimizer_in_bwd:
    self._optimizer.zero_grad()
else:
    for opt in self._optim_ckpt_wrapper.optim_map.values():
        opt.zero_grad()

That can be replaced with:

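class MyOptWrapper:
    def __init__(self, optimizers):
        # `optimizers` is the existing wrapper object that exposes `optim_map`
        self.optimizers = optimizers

    def zero_grad(self):
        # Forward the call to every per-parameter optimizer
        for opt in self.optimizers.optim_map.values():
            opt.zero_grad()

optimizer = MyOptWrapper(optimizers)
optimizer.zero_grad()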

It may break things from time to time, but good testing should keep errors from hitting prod. For overly complex situations, e.g. checkpointing, we can still use if/else, but we definitely don't need all 8 of the if/else checks that we have today.
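As a rough idea of the kind of test meant here, the sketch below checks that the wrapper's zero_grad() reaches every inner optimizer. It is only an illustration: make_fake_ckpt_wrapper is a hypothetical helper that mimics nothing more than the optim_map attribute used in the snippet above, and the mocks stand in for real optimizers.

```python
from unittest.mock import MagicMock


class MyOptWrapper:
    """Wrapper from the snippet above; only zero_grad is sketched here."""

    def __init__(self, optimizers):
        self.optimizers = optimizers

    def zero_grad(self):
        for opt in self.optimizers.optim_map.values():
            opt.zero_grad()


def make_fake_ckpt_wrapper(num_params):
    """Hypothetical stand-in for the existing wrapper object.

    Assumption: the only attribute the recipe relies on is `optim_map`,
    a dict mapping parameter names to per-parameter optimizers.
    """
    fake = MagicMock()
    fake.optim_map = {f"param_{i}": MagicMock() for i in range(num_params)}
    return fake


def test_zero_grad_reaches_every_optimizer():
    inner = make_fake_ckpt_wrapper(num_params=3)
    MyOptWrapper(inner).zero_grad()
    # Each per-parameter optimizer should have been zeroed exactly once.
    for opt in inner.optim_map.values():
        opt.zero_grad.assert_called_once_with()
```

The same pattern could be repeated for each method the wrapper takes over, so that removing an if/else from a recipe is backed by a test.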

felipemello1 added the "best practice" and "community help wanted" labels Nov 22, 2024
@felipemello1 felipemello1 changed the title improve _optim_ckpt_wrapper to its a drop in replacement of optimizer improve _optim_ckpt_wrapper so it is a drop in replacement of optimizer Nov 22, 2024
@RdoubleA
Contributor

I've thought about this approach and I do like that it cleans up the recipe. But it adds some indirection and forces users to have to learn what the wrapper even does, and wouldn't it require all optimizers to be wrapped in this regardless of if they're using optimizer in bwd or not?

@felipemello1
Contributor Author

felipemello1 commented Nov 22, 2024

wouldn't it require all optimizers to be wrapped in this regardless of if they're using optimizer in bwd or not?

I don't think so. The implementation would be something like this:

optimizer = config.instantiate(my_opt)
if opt_in_bwd:
    optimizer = MyOptWrapper(optimizer)

The idea is that, for example, when you call optimizer.zero_grad(), the caller doesn't need to know whether it's the optimizer or the wrapper, because the wrapper behaves like the optimizer.
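To make that concrete, here is a minimal, self-contained sketch of the duck-typing idea. Everything except MyOptWrapper and optim_map is made up for illustration: FakeOptimizer and FakeOptimCkptWrapper are hypothetical stand-ins for a real optimizer and the existing wrapper, and build_optimizer and end_of_step are hypothetical pieces of recipe code.

```python
class FakeOptimizer:
    """Hypothetical stand-in for a real optimizer; it only counts calls."""

    def __init__(self):
        self.zero_grad_calls = 0

    def zero_grad(self):
        self.zero_grad_calls += 1


class FakeOptimCkptWrapper:
    """Hypothetical stand-in for the existing wrapper: one optimizer per parameter."""

    def __init__(self, optim_map):
        self.optim_map = optim_map


class MyOptWrapper:
    """Exposes the optimizer method the recipe calls and fans it out."""

    def __init__(self, optimizers):
        self.optimizers = optimizers

    def zero_grad(self):
        for opt in self.optimizers.optim_map.values():
            opt.zero_grad()


def build_optimizer(opt_in_bwd):
    # The only place that needs to know about opt_in_bwd (assumed setup code).
    if opt_in_bwd:
        per_param = FakeOptimCkptWrapper({"w": FakeOptimizer(), "b": FakeOptimizer()})
        return MyOptWrapper(per_param)
    return FakeOptimizer()


def end_of_step(optimizer):
    # No "if opt_in_bwd" branch here: both objects answer zero_grad().
    optimizer.zero_grad()


for flag in (False, True):
    end_of_step(build_optimizer(opt_in_bwd=flag))
```

The branch on opt_in_bwd stays in one setup function, and the rest of the recipe treats both objects the same way.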

@krammnic
Contributor

So the point is to avoid all conditions on opt_in_bwd except the checkpointing one?

@felipemello1
Contributor Author

@krammnic, yes! It litters our recipe in many places. It would make the code cleaner and easier to maintain.

felipemello1 removed the "community help wanted" label Feb 13, 2025
3 participants