[Metrics] [Docs] Add section about device placement #5280

Merged
merged 15 commits into from
Dec 29, 2020
50 changes: 50 additions & 0 deletions docs/source/metrics.rst
@@ -137,6 +137,56 @@ This metrics API is independent of PyTorch Lightning. Metrics can directly be us
To change this, after initializing the metric, the method ``.persistent(mode)`` can
be used to enable (``mode=True``) or disable (``mode=False``) this behaviour.

*******************
Metrics and devices
*******************

Metrics are simple subclasses of ``torch.nn.Module``, and their metric states behave
similarly to the buffers and parameters of modules. This means that metric states should
be moved to the same device as the input of the metric:

.. code-block:: python

    import torch
    from pytorch_lightning.metrics import Accuracy

    target = torch.tensor([1, 1, 0, 0], device=torch.device("cuda", 0))
    preds = torch.tensor([0, 1, 0, 0], device=torch.device("cuda", 0))

    # Metric states are always initialized on the CPU and need to be moved to
    # the correct device
    accuracy = Accuracy(num_classes=2).to(torch.device("cuda", 0))
    out = accuracy(preds, target)
    print(out.device)  # cuda:0
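
The comparison with buffers can be checked without Lightning or a GPU. The following is a minimal sketch, assuming a hypothetical ``RunningCount`` module (not part of Lightning) that keeps one state tensor in a registered buffer. It only illustrates how ``.to()`` moves module state; it does not claim to show how ``Metric`` stores its states internally.

```python
import torch


class RunningCount(torch.nn.Module):
    """Hypothetical toy module; stands in for a metric with one state tensor."""

    def __init__(self):
        super().__init__()
        # Buffers are created on the CPU unless a device is given explicitly.
        self.register_buffer("total", torch.tensor(0))

    def forward(self, x):
        # Accumulate the number of seen elements into the state tensor.
        self.total += x.numel()
        return self.total


# Use a GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")

counter = RunningCount()
initial_device = counter.total.device  # CPU right after construction

counter = counter.to(device)  # moves every registered buffer and parameter
out = counter(torch.ones(4, device=device))
print(initial_device, out.device)
```

The result lives on the device of the state tensor, which is why the module has to be moved before it receives inputs from another device.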

However, when **properly defined** inside a ``LightningModule``, Lightning will
automatically move the metrics to the same device as the data. **Properly defined**
means that the metric is correctly identified as a child module of the model
(check the ``.children()`` method of the model). Therefore, metrics cannot be placed
in native Python ``list`` or ``dict`` containers, as they will not be identified
as child modules. Instead of ``list`` use ``torch.nn.ModuleList``, and instead of
``dict`` use ``torch.nn.ModuleDict``.

.. code-block:: python

    import torch
    import pytorch_lightning as pl

    class MyModule(pl.LightningModule):
        def __init__(self):
            ...
            # valid ways for metrics to be identified as child modules
            self.metric1 = pl.metrics.Accuracy()
            # ModuleList expects an iterable of modules, hence the list
            self.metric2 = torch.nn.ModuleList([pl.metrics.Accuracy()])
            self.metric3 = torch.nn.ModuleDict({'accuracy': pl.metrics.Accuracy()})

        def training_step(self, batch, batch_idx):
            # all metrics will be on the same device as the input batch
            data, target = batch
            preds = self(data)
            ...
            val1 = self.metric1(preds, target)
            val2 = self.metric2[0](preds, target)
            val3 = self.metric3['accuracy'](preds, target)
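
Whether a container hides its contents from the parent module can be checked directly with ``.children()``. The following is a minimal sketch that uses ``torch.nn.Identity`` as a stand-in for a metric; the classes ``WithPlainList`` and ``WithModuleList`` are hypothetical names introduced only for this illustration.

```python
import torch


class WithPlainList(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # A native list is stored as an ordinary attribute; its items are not registered.
        self.metrics = [torch.nn.Identity()]


class WithModuleList(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # ModuleList is itself a module, so it and its items are registered.
        self.metrics = torch.nn.ModuleList([torch.nn.Identity()])


plain_children = list(WithPlainList().children())
registered_children = list(WithModuleList().children())

print(len(plain_children))       # 0
print(len(registered_children))  # 1
```

Since ``.to()`` only traverses registered submodules, a module held in the plain list would stay on its original device when the parent is moved.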


*********************
Implementing a Metric
*********************