Remove lr_scheduler requirement in lora_dpo_single_device #1991

Merged
16 changes: 12 additions & 4 deletions recipes/lora_dpo_single_device.py
@@ -243,7 +243,7 @@ def setup(self, cfg: DictConfig) -> None:
  # Learning rate scheduler can only be set up after number of steps
  # has been computed
  self._lr_scheduler = self._setup_lr_scheduler(
-     cfg_lr_scheduler=cfg.lr_scheduler,
+     cfg_lr_scheduler=cfg.get("lr_scheduler", None),
Contributor:
It may be less indirection to check directly whether lr_scheduler exists and set it to None here, instead of calling the setup method only to have it return None:

cfg_lr_scheduler = cfg.get("lr_scheduler", None)
self._lr_scheduler = self._setup_lr_scheduler(...) if cfg_lr_scheduler is not None else None

Contributor:
That makes sense to me, Rafi, but on second thought I like the idea of handling everything inside _setup_lr_scheduler, including the log_info call. What do you think?

Contributor @RdoubleA (Nov 14, 2024):
Sure, either works, no strong opinions
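The two call-site patterns discussed above can be sketched in isolation. Here `setup_lr_scheduler` and the plain-dict `cfg` are simplified stand-ins for the recipe's `_setup_lr_scheduler` method and its DictConfig, and the `my.cosine_scheduler` component path is hypothetical:

```python
# Stand-in for the recipe's _setup_lr_scheduler: the merged approach handles
# a missing config inside the setup method, including the log message.
def setup_lr_scheduler(cfg_lr_scheduler):
    if cfg_lr_scheduler is None:
        print("No learning rate scheduler configured. Using constant learning rate.")
        return None
    # Stand-in for config.instantiate(cfg_lr_scheduler, self._optimizer, ...)
    return f"scheduler<{cfg_lr_scheduler['_component_']}>"

# Recipe config without an lr_scheduler section: setup returns None.
cfg = {"optimizer": {"_component_": "torch.optim.AdamW"}}
assert setup_lr_scheduler(cfg.get("lr_scheduler", None)) is None

# With a scheduler section (hypothetical component path), it is instantiated.
cfg["lr_scheduler"] = {"_component_": "my.cosine_scheduler"}
assert setup_lr_scheduler(cfg.get("lr_scheduler", None)) == "scheduler<my.cosine_scheduler>"
```

The alternative from the first comment would hoist the `cfg.get(...)` check to the call site and skip the method call entirely; both patterns yield the same `None` scheduler for configs that omit the section.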

num_training_steps=self.total_epochs * self._steps_per_epoch,
last_epoch=self.global_step - 1,
)
@@ -325,10 +325,16 @@ def _setup_optimizer(

  def _setup_lr_scheduler(
      self,
-     cfg_lr_scheduler: DictConfig,
+     cfg_lr_scheduler: Optional[DictConfig],
      num_training_steps: int,
      last_epoch: int,
- ) -> Optimizer:
+ ) -> Optional[Optimizer]:
+     if cfg_lr_scheduler is None:
+         log.info(
+             "No learning rate scheduler configured. Using constant learning rate."
+         )
+         return None
+
      lr_scheduler = config.instantiate(
          cfg_lr_scheduler,
          self._optimizer,
@@ -543,7 +549,9 @@ def train(self) -> None:
  if (idx + 1) % self._gradient_accumulation_steps == 0:
      self._optimizer.step()
      self._optimizer.zero_grad(set_to_none=True)
-     self._lr_scheduler.step()
+
+     if self._lr_scheduler is not None:
+         self._lr_scheduler.step()
      # Update the number of steps when the weights are updated
      self.global_step += 1
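The guarded scheduler step in the training loop can be exercised with a minimal, torch-free sketch; `DummyOptimizer` and `DummyScheduler` below are invented stand-ins for illustration, not torchtune or PyTorch APIs:

```python
class DummyOptimizer:
    """Minimal stand-in for a torch optimizer (hypothetical)."""
    def __init__(self, lr):
        self.lr = lr
        self.steps = 0
    def step(self):
        self.steps += 1
    def zero_grad(self, set_to_none=True):
        pass

class DummyScheduler:
    """Stand-in scheduler that halves the learning rate on every step."""
    def __init__(self, optimizer):
        self.optimizer = optimizer
    def step(self):
        self.optimizer.lr *= 0.5

def run_steps(optimizer, lr_scheduler, num_steps):
    global_step = 0
    for _ in range(num_steps):
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        # The fix from this PR: only step the scheduler when one was configured.
        if lr_scheduler is not None:
            lr_scheduler.step()
        global_step += 1
    return global_step

opt = DummyOptimizer(lr=1.0)
assert run_steps(opt, None, 3) == 3   # no scheduler: lr stays constant
assert opt.lr == 1.0

opt2 = DummyOptimizer(lr=1.0)
assert run_steps(opt2, DummyScheduler(opt2), 3) == 3
assert opt2.lr == 0.125               # halved three times
```

Without the `None` guard, the first variant would raise `AttributeError` on `None.step()`, which is exactly the failure this PR removes for configs with no `lr_scheduler` section.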
