
Commit

[warning] Add warning when values are not being reduced (#6417)
* add warning non reduced

* add test

* update test

* update changelog

* Update pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py

Co-authored-by: Justus Schock <[email protected]>

* update

Co-authored-by: Kaushik B <[email protected]>
Co-authored-by: Justus Schock <[email protected]>
3 people authored Mar 26, 2021
1 parent 21fc5eb commit 0e45220
Showing 4 changed files with 54 additions and 4 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -9,6 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added a warning for multi-process runs when a non-metric logged value has not been reduced across processes ([#6417](https://github.com/PyTorchLightning/pytorch-lightning/pull/6417))


- Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))

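For context, the warning targets plain (non-`Metric`) values logged under DDP. The following is a minimal sketch, not taken from the PR, contrasting the logging patterns involved; the module name, layer, and logged keys are illustrative:

```python
import torch
import torch.nn.functional as F
import pytorch_lightning as pl


class IllustrativeModule(pl.LightningModule):
    """Hypothetical module contrasting the logging patterns the new warning targets."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self.layer(x), y)
        # Plain tensor without sync_dist: under DDP each process may log a
        # different value, which is the situation the new warning reports.
        self.log("loss_unreduced", loss, on_step=False, on_epoch=True)
        # Explicit cross-process reduction (or logging a torchmetrics.Metric
        # instead of a raw tensor) avoids the warning.
        self.log("loss_reduced", loss, on_step=False, on_epoch=True, sync_dist=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```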
6 changes: 6 additions & 0 deletions pytorch_lightning/core/step_result.py
@@ -633,6 +633,12 @@ def rename_keys(self, map_dict: dict):
meta[dest] = meta[source]
del meta[source]

def get_non_metrics_keys(self):
"""
Return the keys for which the logged value is not a ``Metric``.
"""
return [k for k, v in self.items() if not isinstance(v, Metric)]


def choose_last(x):
if isinstance(x, (torch.Tensor, list)):
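As a rough illustration (not the actual `Result` class), the helper's filter boils down to an `isinstance` check against `Metric`. A standalone sketch on a plain dict stand-in:

```python
import torch
from torchmetrics import Accuracy, Metric

# Toy stand-in for the logged values held by a Result object.
logged = {
    "loss": torch.tensor(0.25),  # plain tensor -> candidate for the warning
    "acc": Accuracy(),           # Metric -> synchronises itself, so it is skipped
}

non_metric_keys = [k for k, v in logged.items() if not isinstance(v, Metric)]
print(non_metric_keys)  # ['loss']
```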
43 changes: 40 additions & 3 deletions pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
@@ -11,8 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple
from weakref import proxy

import torch
@@ -21,6 +22,19 @@
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import DistributedType, LightningEnum
from pytorch_lightning.utilities.warnings import WarningCache

log = logging.getLogger(__name__)


class MetricWarningCache(WarningCache):

def __init__(self):
super().__init__()
self.warned_metrics = []


warning_cache = MetricWarningCache()


class ResultStoreType(LightningEnum):
@@ -52,8 +66,10 @@ class HookResultStore:
These data structures enable us to properly reduce the Result object when the batch loop is finished.
"""

def __init__(self, fx_name: str) -> None:
def __init__(self, fx_name: str, all_gather_fn: Callable, should_warn: bool) -> None:
self._fx_name = fx_name
self._all_gather_fn = all_gather_fn
self._should_warn = should_warn
self._internals = {}
self._internals_reduced = {}
self._internal_type = None
@@ -109,6 +125,20 @@ def run_epoch_func(self, results, opt_metric, func_name, *args, **kwargs) -> None:

func = getattr(opt_metric, func_name)
metrics_to_log = func(*args, add_dataloader_idx=self.has_several_dataloaders, **kwargs)
if self._should_warn:
for non_metric_key in opt_metric.get_non_metrics_keys():
if non_metric_key in metrics_to_log and non_metric_key not in warning_cache.warned_metrics:
metric = self._all_gather_fn(metrics_to_log[non_metric_key])
if any(metric[0] != m for m in metric[1:]):
warning_cache.warn(
f"The value associated to the key {non_metric_key}: {metric.cpu().tolist()} "
"doesn't appear to be the same accross all processes. "
"HINT: One could either do: `self.log(..., sync_dist=True, sync_fn=torch.mean)`"
" to force mean reduction across processes which can be inaccurate or implement"
" a `torchmetrics.Metric`"
)
warning_cache.warned_metrics.append(non_metric_key)

results.append(metrics_to_log)

def get_epoch_from_func_name(self, func_name, *args, **kwargs) -> List[Dict]:
@@ -227,6 +257,12 @@ class EpochResultStore:

def __init__(self, trainer: 'pl.Trainer') -> None:
self.trainer = proxy(trainer)

# Add the warning only for distributed runs (except RPC, as the main worker is running the code).
_should_warn = trainer.accelerator_connector.is_distributed
_should_warn &= not trainer.training_type_plugin.rpc_enabled
self._should_warn = _should_warn

self.reset()

def __getitem__(self, key: str) -> Any:
@@ -278,7 +314,8 @@ def cache_result(self) -> None:
info = self.info
fx_name = info["fx_name"]

self._internals.setdefault(fx_name, HookResultStore(fx_name))
all_gather_fn = self.trainer.lightning_module.all_gather
self._internals.setdefault(fx_name, HookResultStore(fx_name, all_gather_fn, self._should_warn))

# attach capture batch_size
Result.attach_batch_size(self._batch_size, hook_result)
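The heart of the check is a comparison over the all-gathered values: the logged entry from every process ends up stacked in one tensor, and the warning fires if any element differs from the first (the key is then remembered in `warning_cache.warned_metrics`, so the warning is emitted at most once per key). A standalone sketch of that comparison, using a hand-built tensor in place of a real `all_gather`:

```python
import torch

# Pretend these are the values of one logged key gathered from four DDP processes.
gathered = torch.tensor([0.25, 0.25, 0.30, 0.25])

# Same comparison as in run_epoch_func: flag the key if any process disagrees
# with process 0.
if any(gathered[0] != m for m in gathered[1:]):
    print(
        f"value {gathered.cpu().tolist()} differs across processes; "
        "consider sync_dist=True or implementing a torchmetrics.Metric"
    )
```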
7 changes: 6 additions & 1 deletion tests/trainer/logging_/test_train_loop_logging_1_0.py
@@ -743,6 +743,7 @@ class TestLoggingSyncDistModel(BoringModel):
def training_step(self, batch, batch_idx):
acc = self.step(batch[0])
self.log('foo', 1, on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='SUM')
self.log('cho', acc, on_step=False, on_epoch=True)
return acc

def validation_step(self, batch, batch_idx):
@@ -763,8 +764,12 @@ def validation_step(self, batch, batch_idx):
gpus=2,
profiler="pytorch"
)
trainer.fit(model)

if os.getenv("LOCAL_RANK") == '0':
with pytest.warns(UserWarning, match="The value associated to the key cho:"):
trainer.fit(model)
else:
trainer.fit(model)
assert trainer.logged_metrics['foo'] == 2
assert trainer.logged_metrics['bar'] == 2

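Because every DDP process executes the test body, the `pytest.warns` assertion is applied only on the process whose `LOCAL_RANK` is `'0'`; the other rank simply runs `trainer.fit` unguarded. A minimal sketch of that rank-gated assertion pattern outside of Lightning (the warning text and helper are illustrative):

```python
import os
import warnings

import pytest


def fit():
    # Stand-in for trainer.fit(model); the real warning is raised by the
    # logger connector when gathered values disagree across processes.
    warnings.warn("The value associated to the key cho: ...", UserWarning)


if os.getenv("LOCAL_RANK", "0") == "0":
    with pytest.warns(UserWarning, match="The value associated to the key cho:"):
        fit()
else:
    fit()
```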
