Fix mypy errors attributed to `pytorch_lightning.strategies.sharded_spawn` (#14102)


Co-authored-by: rohitgr7 <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: awaelchli <[email protected]>
4 people authored Aug 11, 2022
1 parent 31ecf9b commit e53c4e8
Showing 3 changed files with 10 additions and 6 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
@@ -58,7 +58,6 @@ module = [
     "pytorch_lightning.profilers.base",
     "pytorch_lightning.profilers.pytorch",
     "pytorch_lightning.strategies.sharded",
-    "pytorch_lightning.strategies.sharded_spawn",
     "pytorch_lightning.trainer.callback_hook",
     "pytorch_lightning.trainer.connectors.data_connector",
     "pytorch_lightning.trainer.supporters",
1 change: 1 addition & 0 deletions src/pytorch_lightning/overrides/base.py
@@ -75,6 +75,7 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any:
         trainer = pl_module._trainer
 
         if trainer is not None:
+            assert isinstance(self.module, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
             if trainer.training:
                 output = self.module.training_step(*inputs, **kwargs)
                 # In manual_optimization, we need to prevent DDP reducer as
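The one-line addition above is the standard `assert isinstance(...)` idiom for narrowing a union-typed attribute so that mypy accepts calls such as `self.module.training_step(...)`. A minimal, self-contained sketch of the same pattern, using made-up class names rather than the repository's own:

```python
from typing import Union


class TrainableModule:
    def training_step(self, batch: int) -> int:
        return batch * 2


class PrecisionWrapper:
    def forward(self, batch: int) -> int:
        return batch + 1


class Runner:
    def __init__(self, module: Union[TrainableModule, PrecisionWrapper]) -> None:
        # mypy only knows the union here, not which concrete type is stored.
        self.module = module

    def run(self, batch: int) -> int:
        # Without this assert, mypy reports that `PrecisionWrapper` has no
        # attribute `training_step`; the isinstance check narrows the type
        # of `self.module` for the remainder of the method.
        assert isinstance(self.module, TrainableModule)
        return self.module.training_step(batch)
```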
14 changes: 9 additions & 5 deletions src/pytorch_lightning/strategies/sharded_spawn.py
@@ -12,13 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from typing import Dict, Generator, List, Optional, Tuple
+from typing import Any, Dict, Generator, List, Optional, Tuple
 
 from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer
 
 import pytorch_lightning as pl
+from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
 from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
 from pytorch_lightning.trainer.states import TrainerFn
@@ -42,7 +43,9 @@ class DDPSpawnShardedStrategy(DDPSpawnStrategy):
 
     def configure_ddp(self) -> None:
         # set up optimizers after the wrapped module has been moved to the device
+        assert self.lightning_module is not None
         self.setup_optimizers(self.lightning_module.trainer)
+        assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
         self.model, self.optimizers = self._setup_model_and_optimizers(
             model=LightningShardedDataParallel(self.model), optimizers=self.optimizers
         )
@@ -69,12 +72,13 @@ def _reinit_optimizers_with_oss(self, optimizers: List[Optimizer]) -> List["OSS"]:
             return optimizers
 
     def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
-        if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING:
+        assert self.lightning_module
+        if self.model is not None and self.lightning_module.trainer.state.fn != TrainerFn.FITTING:
             return optimizers
 
         return self._reinit_optimizers_with_oss(optimizers)
 
-    def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
+    def optimizer_state(self, optimizer: "OSS") -> Dict[str, Any]:
         if isinstance(optimizer, OSS):
             optimizer.consolidate_state_dict()
         return self._optim_state_dict(optimizer)
@@ -93,7 +97,7 @@ def block_backward_sync(self) -> Generator:
             yield None
 
     @rank_zero_only
-    def _optim_state_dict(self, optimizer):
+    def _optim_state_dict(self, optimizer: Optimizer) -> Dict[str, Any]:
         """
         Retrieves state dict only on rank 0, which contains the entire optimizer state after calling
         :meth:`consolidate_state_dict`.
@@ -112,7 +116,7 @@ def lightning_module(self) -> Optional["pl.LightningModule"]:
     def pre_backward(self, closure_loss: Tensor) -> None:
         pass
 
-    def post_training_step(self):
+    def post_training_step(self) -> None:
         pass
 
     @classmethod
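The remaining changes to this file follow two mypy-friendly patterns: asserting that an `Optional` attribute is not `None` before dereferencing it, and replacing loose or missing return annotations with explicit ones such as `Dict[str, Any]`. A small illustrative sketch of both, using made-up names rather than the strategy's real attributes:

```python
from typing import Any, Dict, Optional


class Component:
    def state_dict(self) -> Dict[str, Any]:
        return {"step": 0}


class Strategy:
    def __init__(self) -> None:
        # Filled in later during setup, hence Optional at construction time.
        self.component: Optional[Component] = None

    def component_state(self) -> Dict[str, Any]:  # explicit return annotation
        # Without the assert, mypy flags `self.component.state_dict()`
        # because `self.component` may still be None; the assert narrows
        # it to `Component` for the rest of the method.
        assert self.component is not None
        return self.component.state_dict()


strategy = Strategy()
strategy.component = Component()
print(strategy.component_state())  # -> {'step': 0}
```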
