diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 387cbc3bd7482..d6d9cb8fb0ae7 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -137,6 +137,56 @@ This metrics API is independent of PyTorch Lightning. Metrics can directly be us To change this, after initializing the metric, the method ``.persistent(mode)`` can be used to enable (``mode=True``) or disable (``mode=False``) this behaviour. +******************* +Metrics and devices +******************* + +Metrics are simple subclasses of :class:`~torch.nn.Module` and their metric states behave +similar to buffers and parameters of modules. This means that metrics states should +be moved to the same device as the input of the metric: + +.. code-block:: python + + import torch + from pytorch_lightning.metrics import Accuracy + + target = torch.tensor([1, 1, 0, 0], device=torch.device("cuda", 0)) + preds = torch.tensor([0, 1, 0, 0], device=torch.device("cuda", 0)) + + # Metric states are always initialized on cpu, and needs to be moved to + # the correct device + confmat = Accuracy(num_classes=2).to(torch.device("cuda", 0)) + out = confmat(preds, target) + print(out.device) # cuda:0 + +However, when **properly defined** inside a :class:`~pytorch_lightning.core.lightning.LightningModule` +, Lightning will automatically move the metrics to the same device as the data. Being +**properly defined** means that the metric is correctly identified as a child module of the +model (check ``.children()`` attribute of the model). Therefore, metrics cannot be placed +in native python ``list`` and ``dict``, as they will not be correctly identified +as child modules. Instead of ``list`` use :class:`~torch.nn.ModuleList` and instead of +``dict`` use :class:`~torch.nn.ModuleDict`. + +.. testcode:: + + class MyModule(LightningModule): + def __init__(self): + ... + # valid ways metrics will be identified as child modules + self.metric1 = pl.metrics.Accuracy() + self.metric2 = torch.nn.ModuleList(pl.metrics.Accuracy()) + self.metric3 = torch.nn.ModuleDict({'accuracy': Accuracy()}) + + def training_step(self, batch, batch_idx): + # all metrics will be on the same device as the input batch + data, target = batch + preds = self(data) + ... + val1 = self.metric1(preds, target) + val2 = self.metric2[0](preds, target) + val3 = self.metric3['accuracy'](preds, target) + + ********************* Implementing a Metric *********************