diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index 5dfd871f933f5..8852367a116f6 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -73,6 +73,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed incorrect `precision="mixed"` being used with `DeepSpeedStrategy` and `IPUStrategy` ([#14041](https://github.com/Lightning-AI/lightning/pull/14041))
 
+- Fixed resuming from a checkpoint when using Stochastic Weight Averaging (SWA) ([#9938](https://github.com/Lightning-AI/lightning/pull/9938))
+
+
 - Fixed dtype inference during gradient norm computation ([#14051](https://github.com/Lightning-AI/lightning/pull/14051))
 
 
diff --git a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py
index 20a3dcc3f0f26..6650bb3f0c479 100644
--- a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py
+++ b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -16,7 +16,7 @@
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 """
 from copy import deepcopy
-from typing import Any, Callable, cast, List, Optional, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Union
 
 import torch
 from torch import nn, Tensor
@@ -24,6 +24,7 @@
 
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.callback import Callback
+from pytorch_lightning.strategies import DDPFullyShardedStrategy, DeepSpeedStrategy
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.types import _LRScheduler, LRSchedulerConfig
@@ -112,15 +113,22 @@ def __init__(
         if device is not None and not isinstance(device, (torch.device, str)):
             raise MisconfigurationException(f"device is expected to be a torch.device or a str. Found {device}")
 
+        self.n_averaged: Optional[torch.Tensor] = None
         self._swa_epoch_start = swa_epoch_start
         self._swa_lrs = swa_lrs
         self._annealing_epochs = annealing_epochs
         self._annealing_strategy = annealing_strategy
         self._avg_fn = avg_fn or self.avg_fn
         self._device = device
-        self._max_epochs: int
-        self._model_contains_batch_norm: bool
+        self._model_contains_batch_norm: Optional[bool] = None
         self._average_model: "pl.LightningModule"
+        self._initialized = False
+        self._swa_scheduler: Optional[_LRScheduler] = None
+        self._scheduler_state: Optional[Dict] = None
+        self._init_n_averaged = 0
+        self._latest_update_epoch = -1
+        self.momenta: Optional[Dict[nn.modules.batchnorm._BatchNorm, float]] = None
+        self._max_epochs: int
 
     @property
     def swa_start(self) -> int:
@@ -147,6 +155,9 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
         if len(trainer.lr_scheduler_configs) > 1:
             raise MisconfigurationException("SWA currently not supported for more than 1 `lr_scheduler`.")
 
+        if isinstance(trainer.strategy, (DDPFullyShardedStrategy, DeepSpeedStrategy)):
+            raise MisconfigurationException("SWA does not currently support sharded models.")
+
         if isinstance(self._swa_epoch_start, float):
             self._swa_epoch_start = int(trainer.max_epochs * self._swa_epoch_start)
 
@@ -158,8 +169,13 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
             assert trainer.fit_loop.max_epochs is not None
             trainer.fit_loop.max_epochs += 1
 
+        if self._scheduler_state is not None:
+            self._clear_schedulers(trainer)
+
     def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        if trainer.current_epoch == self.swa_start:
+        if (not self._initialized) and (self.swa_start <= trainer.current_epoch <= self.swa_end):
+            self._initialized = True
+
             # move average model to request device.
             self._average_model = self._average_model.to(self._device or pl_module.device)
 
@@ -180,6 +196,17 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
                     last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1,
                 ),
             )
+            if self._scheduler_state is not None:
+                # Restore scheduler state from checkpoint
+                self._swa_scheduler.load_state_dict(self._scheduler_state)
+            elif trainer.current_epoch != self.swa_start:
+                # Log a warning if we're initializing after start without any checkpoint data,
+                # as behaviour will be different compared to having checkpoint data.
+                rank_zero_warn(
+                    "SWA is initializing after swa_start without any checkpoint data. "
+                    "This may be caused by loading a checkpoint from an older version of PyTorch Lightning."
+                )
+
             # We assert that there is only one optimizer on fit start, so know opt_idx is always 0
             default_scheduler_cfg = LRSchedulerConfig(self._swa_scheduler, opt_idx=0)
             assert default_scheduler_cfg.interval == "epoch" and default_scheduler_cfg.frequency == 1
@@ -196,14 +223,18 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
             else:
                 trainer.lr_scheduler_configs.append(default_scheduler_cfg)
 
-            self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device)
+            if self.n_averaged is None:
+                self.n_averaged = torch.tensor(self._init_n_averaged, dtype=torch.long, device=pl_module.device)
 
-        if self.swa_start <= trainer.current_epoch <= self.swa_end:
+        if (self.swa_start <= trainer.current_epoch <= self.swa_end) and (
+            trainer.current_epoch > self._latest_update_epoch
+        ):
+            assert self.n_averaged is not None
             self.update_parameters(self._average_model, pl_module, self.n_averaged, self._avg_fn)
+            self._latest_update_epoch = trainer.current_epoch
 
         # Note: No > here in case the callback is saved with the model and training continues
         if trainer.current_epoch == self.swa_end + 1:
-
             # Transfer weights from average model to pl_module
             self.transfer_weights(self._average_model, pl_module)
 
@@ -265,6 +296,7 @@ def reset_batch_norm_and_save_state(self, pl_module: "pl.LightningModule") -> No
 
     def reset_momenta(self) -> None:
         """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165."""
+        assert self.momenta is not None
         for bn_module in self.momenta:
             bn_module.momentum = self.momenta[bn_module]
 
@@ -285,3 +317,35 @@ def update_parameters(
     def avg_fn(averaged_model_parameter: Tensor, model_parameter: Tensor, num_averaged: Tensor) -> Tensor:
         """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97."""
         return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1)
+
+    def state_dict(self) -> Dict[str, Any]:
+        return {
+            "n_averaged": 0 if self.n_averaged is None else self.n_averaged.item(),
+            "latest_update_epoch": self._latest_update_epoch,
+            "scheduler_state": None if self._swa_scheduler is None else self._swa_scheduler.state_dict(),
+            "average_model_state": None if self._average_model is None else self._average_model.state_dict(),
+        }
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        self._init_n_averaged = state_dict["n_averaged"]
+        self._latest_update_epoch = state_dict["latest_update_epoch"]
+        self._scheduler_state = state_dict["scheduler_state"]
+        self._load_average_model_state(state_dict["average_model_state"])
+
+    @staticmethod
+    def _clear_schedulers(trainer: "pl.Trainer") -> None:
+        # If we have scheduler state saved, clear the scheduler configs so that we don't try to
+        # load state into the wrong type of schedulers when restoring scheduler checkpoint state.
+        # We'll configure the scheduler and re-load its state in on_train_epoch_start.
+        # Note that this relies on the callback state being restored before the scheduler state is
+        # restored, and doesn't work if restore_checkpoint_after_setup is True, but at the time of
+        # writing that is only True for deepspeed which is already not supported by SWA.
+        # See https://github.com/PyTorchLightning/pytorch-lightning/issues/11665 for background.
+        if trainer.lr_scheduler_configs:
+            assert len(trainer.lr_scheduler_configs) == 1
+            trainer.lr_scheduler_configs.clear()
+
+    def _load_average_model_state(self, model_state: Any) -> None:
+        if self._average_model is None:
+            return
+        self._average_model.load_state_dict(model_state)
diff --git a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py
index 859cf2fa98c0c..65a0fea2fb4a5 100644
--- a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py
+++ b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py
@@ -12,11 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import os
+from pathlib import Path
+from typing import ContextManager, Optional
 from unittest import mock
 
 import pytest
 import torch
 from torch import nn
+from torch.optim.lr_scheduler import LambdaLR
 from torch.optim.swa_utils import SWALR
 from torch.utils.data import DataLoader
 
@@ -30,7 +34,9 @@
 
 
 class SwaTestModel(BoringModel):
-    def __init__(self, batchnorm: bool = True, interval: str = "epoch", iterable_dataset: bool = False):
+    def __init__(
+        self, batchnorm: bool = True, interval: str = "epoch", iterable_dataset: bool = False, crash_on_epoch=None
+    ):
         super().__init__()
         layers = [nn.Linear(32, 32)]
         if batchnorm:
@@ -39,17 +45,18 @@ def __init__(self, batchnorm: bool = True, interval: str = "epoch", iterable_dat
         self.layer = nn.Sequential(*layers)
         self.interval = interval
         self.iterable_dataset = iterable_dataset
+        self.crash_on_epoch = crash_on_epoch
 
     def training_step(self, batch, batch_idx):
+        if self.crash_on_epoch and self.trainer.current_epoch >= self.crash_on_epoch:
+            raise Exception("SWA crash test")
         output = self.forward(batch)
         loss = self.loss(batch, output)
         return {"loss": loss}
 
     def train_dataloader(self):
-
         dset_cls = RandomIterableDataset if self.iterable_dataset else RandomDataset
         dset = dset_cls(32, 64)
-
         return DataLoader(dset, batch_size=2)
 
     def configure_optimizers(self):
@@ -66,6 +73,8 @@ def configure_optimizers(self):
 class SwaTestCallback(StochasticWeightAveraging):
     update_parameters_calls: int = 0
     transfer_weights_calls: int = 0
+    # Record the first epoch, as if we are resuming from a checkpoint this may not be equal to 0
+    first_epoch: Optional[int] = None
 
     def update_parameters(self, *args, **kwargs):
         self.update_parameters_calls += 1
@@ -77,6 +86,11 @@ def transfer_weights(self, *args, **kwargs):
 
     def on_train_epoch_start(self, trainer, *args):
         super().on_train_epoch_start(trainer, *args)
+        if self.first_epoch is None and not trainer.fit_loop.restarting:
+            # since the checkpoint loaded was saved `on_train_epoch_end`, the first `FitLoop` iteration will
+            # not update the model and just call the epoch-level hooks, for that reason, we check that we are not
+            # restarting before choosing the first epoch
+            self.first_epoch = trainer.current_epoch
         assert trainer.fit_loop._skip_backward == (trainer.current_epoch > self.swa_end)
         if self.swa_start <= trainer.current_epoch:
             assert isinstance(trainer.lr_scheduler_configs[0].scheduler, SWALR)
@@ -88,6 +102,7 @@ def on_train_epoch_end(self, trainer, *args):
         if self.swa_start <= trainer.current_epoch <= self.swa_end:
             swa_epoch = trainer.current_epoch - self.swa_start
             assert self.n_averaged == swa_epoch + 1
+            assert self._swa_scheduler is not None
             # Scheduler is stepped once on initialization and then at the end of each epoch
             assert self._swa_scheduler._step_count == swa_epoch + 2
         elif trainer.current_epoch > self.swa_end:
@@ -103,10 +118,13 @@ def on_train_end(self, trainer, pl_module):
 
         if not isinstance(trainer.strategy, DDPSpawnStrategy):
             # check backward call count. the batchnorm update epoch should not backward
-            assert trainer.strategy.backward.call_count == trainer.max_epochs * trainer.limit_train_batches
+            assert trainer.strategy.backward.call_count == (
+                (trainer.max_epochs - self.first_epoch) * trainer.limit_train_batches
+            )
 
         # check call counts
-        assert self.update_parameters_calls == trainer.max_epochs - (self._swa_epoch_start - 1)
+        first_swa_epoch = max(self.first_epoch, self.swa_start)
+        assert self.update_parameters_calls == trainer.max_epochs - first_swa_epoch
         assert self.transfer_weights_calls == 1
 
 
@@ -140,7 +158,7 @@ def train_with_swa(
         devices=devices,
     )
 
-    with mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward):
+    with _backward_patch(trainer):
         trainer.fit(model)
 
     # check the model is the expected
@@ -226,9 +244,10 @@ def test_swa_multiple_lrs(tmpdir):
 
     class TestModel(BoringModel):
         def __init__(self):
-            super(BoringModel, self).__init__()
+            super().__init__()
            self.layer1 = torch.nn.Linear(32, 32)
             self.layer2 = torch.nn.Linear(32, 2)
+            self.on_train_epoch_start_called = False
 
         def forward(self, x):
             x = self.layer1(x)
@@ -255,3 +274,98 @@ def on_train_epoch_start(self):
     )
     trainer.fit(model)
     assert model.on_train_epoch_start_called
+
+
+def _swa_resume_training_from_checkpoint(tmpdir, model, resume_model, ddp=False):
+    swa_start = 3
+    trainer_kwargs = {
+        "default_root_dir": tmpdir,
+        "max_epochs": 5,
+        "accelerator": "cpu",
+        "strategy": "ddp_spawn_find_unused_parameters_false" if ddp else None,
+        "devices": 2 if ddp else 1,
+        "limit_train_batches": 5,
+        "limit_val_batches": 0,
+        "accumulate_grad_batches": 2,
+        "enable_progress_bar": False,
+    }
+    trainer = Trainer(callbacks=SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1), **trainer_kwargs)
+
+    with _backward_patch(trainer), pytest.raises(Exception, match="SWA crash test"):
+        trainer.fit(model)
+
+    checkpoint_dir = Path(tmpdir) / "lightning_logs" / "version_0" / "checkpoints"
+    checkpoint_files = os.listdir(checkpoint_dir)
+    assert len(checkpoint_files) == 1
+    ckpt_path = str(checkpoint_dir / checkpoint_files[0])
+
+    trainer = Trainer(callbacks=SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1), **trainer_kwargs)
+
+    with _backward_patch(trainer):
+        trainer.fit(resume_model, ckpt_path=ckpt_path)
+
+
+class CustomSchedulerModel(SwaTestModel):
+    def configure_optimizers(self):
+        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+
+        def lr_lambda(current_step: int):
+            return 0.1
+
+        scheduler = LambdaLR(optimizer, lr_lambda, -1)
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": {
+                "scheduler": scheduler,
+                "interval": self.interval,
+            },
+        }
+
+
+@pytest.mark.parametrize("crash_on_epoch", [1, 3])
+def test_swa_resume_training_from_checkpoint(tmpdir, crash_on_epoch):
+    model = SwaTestModel(crash_on_epoch=crash_on_epoch)
+    resume_model = SwaTestModel()
+    _swa_resume_training_from_checkpoint(tmpdir, model, resume_model)
+
+
+@pytest.mark.parametrize("crash_on_epoch", [1, 3])
+def test_swa_resume_training_from_checkpoint_custom_scheduler(tmpdir, crash_on_epoch):
+    # Reproduces the bug reported in https://github.com/PyTorchLightning/pytorch-lightning/issues/11665
+    model = CustomSchedulerModel(crash_on_epoch=crash_on_epoch)
+    resume_model = CustomSchedulerModel()
+    _swa_resume_training_from_checkpoint(tmpdir, model, resume_model)
+
+
+@RunIf(skip_windows=True)
+def test_swa_resume_training_from_checkpoint_ddp(tmpdir):
+    model = SwaTestModel(crash_on_epoch=3)
+    resume_model = SwaTestModel()
+    _swa_resume_training_from_checkpoint(tmpdir, model, resume_model, ddp=True)
+
+
+@pytest.mark.parametrize(
+    "strategy",
+    [
+        pytest.param("fsdp", marks=RunIf(fairscale_fully_sharded=True, min_cuda_gpus=1)),
+        pytest.param("deepspeed", marks=RunIf(deepspeed=True, min_cuda_gpus=1)),
+    ],
+)
+def test_misconfiguration_error_with_sharded_model(tmpdir, strategy: str):
+    model = SwaTestModel()
+    swa_callback = SwaTestCallback(swa_epoch_start=2, swa_lrs=0.1)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        enable_progress_bar=False,
+        max_epochs=5,
+        callbacks=[swa_callback],
+        strategy=strategy,
+        accelerator="gpu",
+        devices=1,
+    )
+    with pytest.raises(MisconfigurationException, match="SWA does not currently support sharded models"):
+        trainer.fit(model)
+
+
+def _backward_patch(trainer: Trainer) -> ContextManager:
+    return mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward)
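To make the user-facing behaviour concrete, here is a minimal, hypothetical usage sketch (not part of the diff): a run using the `StochasticWeightAveraging` callback is interrupted and then resumed via `Trainer.fit(..., ckpt_path=...)`. With the `state_dict`/`load_state_dict` hooks added above, the callback's `n_averaged` counter, averaged weights, and SWALR scheduler state are restored from the checkpoint instead of being re-initialized. `MyModel`, the `Trainer` arguments, and the checkpoint path are illustrative placeholders, not values taken from the PR.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
from pytorch_lightning.callbacks import StochasticWeightAveraging


class MyModel(pl.LightningModule):  # placeholder model for illustration only
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).pow(2).mean()

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


def make_trainer() -> pl.Trainer:
    # start averaging at epoch 3; swa_lrs is the learning rate the SWALR scheduler anneals to
    swa = StochasticWeightAveraging(swa_epoch_start=3, swa_lrs=0.05)
    return pl.Trainer(max_epochs=10, callbacks=[swa], enable_progress_bar=False)


if __name__ == "__main__":
    # first run: checkpoints include the SWA callback state
    # (n_averaged, latest update epoch, SWALR scheduler state, averaged model weights)
    make_trainer().fit(MyModel())

    # resumed run: the callback state is restored via load_state_dict, so averaging
    # continues where it left off instead of starting over
    ckpt = "lightning_logs/version_0/checkpoints/last.ckpt"  # placeholder checkpoint path
    make_trainer().fit(MyModel(), ckpt_path=ckpt)
```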