Allowed custom BatchSamplers when instantiated in *_dataloader hook #13640

Merged: 38 commits, Jul 27, 2022
Changes from 6 commits

Commits (38):
848e82b generalize replace_init_method (Jul 13, 2022)
8c2c45a Merge branch 'master' into bugfix/custom_batch_sampler (Jul 13, 2022)
6e01b07 custom batch sampler support (Jul 13, 2022)
d5cee5a changelog (Jul 13, 2022)
834e98c apply suggestions from code review (Jul 14, 2022)
e7831c1 change docstring (Jul 14, 2022)
4268323 code review suggestions (Jul 14, 2022)
0d2475a types (Jul 14, 2022)
589e49f Merge branch 'master' into bugfix/custom_batch_sampler (Jul 14, 2022)
cb50725 merge master (Jul 14, 2022)
b37e638 Merge branch 'master' into bugfix/custom_batch_sampler (Jul 14, 2022)
db0f534 Merge branch 'master' into bugfix/custom_batch_sampler (Jul 14, 2022)
ed44a76 Merge branch 'master' into bugfix/custom_batch_sampler (Jul 14, 2022)
9e106e1 code review suggestion (Jul 14, 2022)
23c703a Merge branch 'master' into bugfix/custom_batch_sampler (Jul 18, 2022)
7557870 Merge branch 'master' into bugfix/custom_batch_sampler (Jul 19, 2022)
812d69b Merge branch 'master' into bugfix/custom_batch_sampler (Jul 20, 2022)
740471d Merge branch 'master' into bugfix/custom_batch_sampler (otaj, Jul 22, 2022)
5e37868 Apply suggestions from code review (otaj, Jul 22, 2022)
309298a Apply suggestions from code review (otaj, Jul 25, 2022)
230dda8 Merge branch 'master' into bugfix/custom_batch_sampler (Jul 25, 2022)
04f1921 make test not fail (Jul 25, 2022)
88764c0 comment the tests (Jul 25, 2022)
b12a7c5 Merge branch 'master' into bugfix/custom_batch_sampler (Jul 25, 2022)
67bb560 Merge branch 'master' into bugfix/custom_batch_sampler (otaj, Jul 26, 2022)
e996117 pass mode in ipu (Jul 26, 2022)
5a6513b batch_size is None when batch_sampler is set (Jul 26, 2022)
4314fa3 don't pass anything more than necessary (Jul 26, 2022)
a0bb6ff test ipu failing test (Jul 26, 2022)
7dc1cf4 return to poptorch dataloader (Jul 26, 2022)
7d41ba9 Merge branch 'master' into bugfix/custom_batch_sampler (Jul 27, 2022)
3c6f730 merge master (Jul 27, 2022)
305730d added a bit of docstring (Jul 27, 2022)
5a51c5a pprint debugging (Jul 27, 2022)
ab20f5f pprint debugging (Jul 27, 2022)
ff14329 Default kwargs handling (Jul 27, 2022)
5fffd33 Revert "pprint debugging" (Jul 27, 2022)
b3d6075 Revert "pprint debugging" (Jul 27, 2022)
2 changes: 2 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -286,6 +286,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Improved support for custom `DataLoader`s when instantiated in `*_dataloader` hook ([#12981](https://github.com/PyTorchLightning/pytorch-lightning/pull/12981))

- Allowed custom `BatchSampler`s when instantiated in `*_dataloader` hook ([#13640](https://github.com/PyTorchLightning/pytorch-lightning/pull/13640))


- Fixed an issue with unsupported torch.inference_mode() on hpu backends by making it use no_grad ([#13014](https://github.com/PyTorchLightning/pytorch-lightning/pull/13014))

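The changelog entry above is the user-facing summary: a `DataLoader` built with a custom `BatchSampler` subclass inside a `*_dataloader` hook can now be captured and re-instantiated by Lightning, for example when a distributed sampler has to be injected. A minimal sketch of the pattern this enables; `MyBatchSampler`, `MyModule`, and the toy dataset are illustrative and not taken from the PR:

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, TensorDataset

from pytorch_lightning import LightningModule


class MyBatchSampler(BatchSampler):
    """Illustrative custom batch sampler that keeps the base `BatchSampler` signature."""

    def __init__(self, sampler, batch_size, drop_last, log_batches=False):
        super().__init__(sampler, batch_size, drop_last)
        self.log_batches = log_batches


class MyModule(LightningModule):
    # training_step, configure_optimizers, etc. omitted for brevity.

    def train_dataloader(self):
        dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
        batch_sampler = MyBatchSampler(SequentialSampler(dataset), batch_size=8, drop_last=False)
        # With this PR, the constructor arguments above are recorded while the hook runs,
        # so the trainer can later rebuild `MyBatchSampler` (e.g. with a DistributedSampler).
        return DataLoader(dataset, batch_sampler=batch_sampler)
```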
8 changes: 5 additions & 3 deletions src/pytorch_lightning/lite/lite.py
@@ -22,7 +22,7 @@
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, SequentialSampler
from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, RandomSampler, SequentialSampler

from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
@@ -35,7 +35,7 @@
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
from pytorch_lightning.utilities.data import (
_auto_add_worker_init_fn,
_replace_dataloader_init_method,
_replace_init_method,
_update_dataloader,
has_iterable_dataset,
)
@@ -409,7 +409,9 @@ def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:

def _run_with_strategy_setup(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
self._strategy.setup_environment()
with self._strategy.model_sharded_context(), _replace_dataloader_init_method():
with self._strategy.model_sharded_context(), _replace_init_method(DataLoader, "dataset"), _replace_init_method(
BatchSampler
):
return run_method(*args, **kwargs)

def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn.Module:
6 changes: 3 additions & 3 deletions src/pytorch_lightning/trainer/connectors/data_connector.py
@@ -17,7 +17,7 @@
from typing import Any, Callable, Collection, List, Optional, Tuple, Union
from weakref import proxy

from torch.utils.data import DataLoader, Sampler, SequentialSampler
from torch.utils.data import BatchSampler, DataLoader, Sampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

import pytorch_lightning as pl
@@ -31,7 +31,7 @@
from pytorch_lightning.utilities.data import (
_auto_add_worker_init_fn,
_is_dataloader_shuffled,
_replace_dataloader_init_method,
_replace_init_method,
_update_dataloader,
has_iterable_dataset,
has_len_all_ranks,
@@ -424,7 +424,7 @@ def _request_dataloader(self, stage: RunningStage) -> Union[DataLoader, List[Dat
"""
source = getattr(self, f"_{stage.dataloader_prefix}_dataloader_source")

with _replace_dataloader_init_method():
with _replace_init_method(DataLoader, "dataset"), _replace_init_method(BatchSampler):
# under this context manager, the arguments passed to `DataLoader.__init__` will be captured and saved as
# attributes on the instance in case the dataloader needs to be re-instantiated later by Lightning
dataloader = source.dataloader()
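The comment inside `_request_dataloader` describes the mechanism both call sites rely on: while the user's hook executes, the `__init__` of `DataLoader` (and now `BatchSampler`) subclasses is temporarily patched so that constructor arguments are saved on the instance. The stripped-down sketch below illustrates the idea only; it is not the Lightning implementation, which is `_replace_init_method`/`_wrap_init_method` in `utilities/data.py` further down and additionally patches every subclass, records positional argument names and default kwargs, and stores everything under `__pl_saved_args`, `__pl_saved_kwargs`, and `__pl_saved_arg_names`:

```python
import functools
from contextlib import contextmanager
from typing import Any, Callable, Generator, Type


def _record_init_args(init: Callable) -> Callable:
    @functools.wraps(init)
    def wrapper(obj: Any, *args: Any, **kwargs: Any) -> None:
        # Save what the caller passed so the object can be re-created later.
        if not hasattr(obj, "_saved_init_args"):
            obj._saved_init_args = args
            obj._saved_init_kwargs = kwargs
        init(obj, *args, **kwargs)

    return wrapper


@contextmanager
def _capture_init_args(base_cls: Type) -> Generator[None, None, None]:
    # Patch only the base class here; the real context manager also patches all subclasses.
    original_init = base_cls.__init__
    base_cls.__init__ = _record_init_args(original_init)
    try:
        yield
    finally:
        base_cls.__init__ = original_init
```

Usage mirrors the call above: `with _capture_init_args(DataLoader), _capture_init_args(BatchSampler): dl = module.train_dataloader()`, after which the recorded arguments are available for re-instantiation.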
14 changes: 1 addition & 13 deletions src/pytorch_lightning/utilities/auto_restart.py
@@ -16,15 +16,7 @@
from functools import partial, wraps
from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, Union

from torch.utils.data import (
BatchSampler,
Dataset,
DistributedSampler,
get_worker_info,
RandomSampler,
Sampler,
SequentialSampler,
)
from torch.utils.data import Dataset, DistributedSampler, get_worker_info, RandomSampler, Sampler, SequentialSampler
from torch.utils.data.dataloader import (
_BaseDataLoaderIter,
_MultiProcessingDataLoaderIter,
@@ -757,10 +749,6 @@ def _validate_map_dataset(dataloader: DataLoader) -> None:
if sampler is not None and type(sampler) not in SUPPORTED_SAMPLERS:
raise TypeError(f"Fault-tolerance supports only {SUPPORTED_SAMPLERS}.")

batch_sampler = getattr(dataloader, "batch_sampler", None)
if batch_sampler is not None and type(batch_sampler) is not BatchSampler:
raise TypeError("Fault-tolerance supports only a `BatchSampler`.")

if type(sampler) is DistributedSampler and sampler.shuffle:
raise TypeError("A `DistributedSampler` sampler shuffle attribute is set to True.")
elif type(sampler) is RandomSampler:
145 changes: 110 additions & 35 deletions src/pytorch_lightning/utilities/data.py
@@ -17,7 +17,7 @@
from contextlib import contextmanager
from dataclasses import fields
from functools import partial
from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Set, Tuple, Type, Union
from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Set, Tuple, Type, Union

import torch
from torch import Tensor
@@ -217,11 +217,11 @@ def _get_dataloader_init_args_and_kwargs(
if not isinstance(dataloader, DataLoader):
raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`")

was_wrapped = hasattr(dataloader, "__pl_dl_args")
was_wrapped = hasattr(dataloader, "__pl_saved_args")
if was_wrapped:
dl_args = dataloader.__pl_dl_args
dl_kwargs = dataloader.__pl_dl_kwargs
arg_names = dataloader.__pl_dl_arg_names
dl_args = dataloader.__pl_saved_args
dl_kwargs = dataloader.__pl_saved_kwargs
arg_names = dataloader.__pl_saved_arg_names
original_dataset = dataloader.__dataset # we have this saved from _wrap_init
else:
# get the dataloader instance attributes
@@ -322,12 +322,54 @@ def _dataloader_init_kwargs_resolve_sampler(
batch_sampler = getattr(dataloader, "batch_sampler")
is_predicting = mode == RunningStage.PREDICTING
# checking the batch sampler type is different than PyTorch default.
if batch_sampler is not None and (type(batch_sampler) is not BatchSampler or is_predicting):
batch_sampler = type(batch_sampler)(
sampler,
batch_size=batch_sampler.batch_size,
drop_last=(False if is_predicting else batch_sampler.drop_last),
)
batch_sampler_cls = type(batch_sampler)
if batch_sampler is not None and (batch_sampler_cls is not BatchSampler or is_predicting):
if hasattr(batch_sampler, "__pl_saved_args"):
args = list(batch_sampler.__pl_saved_args)
kwargs = batch_sampler.__pl_saved_kwargs
arg_names = batch_sampler.__pl_saved_arg_names

if is_predicting:
success, args, kwargs = _replace_value_in_saved_args("drop_last", False, args, kwargs, arg_names)
if not success:
rank_zero_warn(
f"Trying to inject `drop_last=False` into batch sampler since you are predicting, however it "
f"seems the class `{batch_sampler_cls.__qualname__}` does not support it. "
"Your predictions might be incomplete. To mitigate this, expose `drop_last` in the `__init__` "
"method of your custom class."
)

success, args, kwargs = _replace_value_in_saved_args("sampler", sampler, args, kwargs, arg_names)
if not success:
raise MisconfigurationException(
"Trying to inject modified sampler into batch sampler, however it seems the class "
f"`{batch_sampler_cls.__qualname__}` does not support argument called sampler. To mitigate this, "
"expose argument `sampler` in the `__init__` method of your custom class."
)

batch_sampler = batch_sampler_cls(*args, **kwargs)
else:
try:
batch_sampler = batch_sampler_cls(
sampler,
batch_size=batch_sampler.batch_size,
drop_last=(False if is_predicting else batch_sampler.drop_last),
)
except TypeError as e:
import re

match = re.match(r".*__init__\(\) (got multiple values)|(missing \d required)", str(e))
if not match:
# an unexpected `TypeError`, continue failure
raise

# There could either be too few or too many arguments. Customizing the message based on this doesn't
# make much sense since our MisconfigurationException is going to be thrown from the original one.
raise MisconfigurationException(
"We tried to reinstantiate your custom batch sampler and failed. "
"To mitigate this, either follow the API of `BatchSampler` or instantiate "
"your custom batch sampler inside `*_dataloader` hooks of your module."
) from e
if is_predicting:
batch_sampler = IndexBatchSamplerWrapper(batch_sampler)

@@ -350,17 +392,36 @@ def _dataloader_init_kwargs_resolve_sampler(
return {"sampler": sampler, "shuffle": False, "batch_sampler": None}


def _replace_value_in_saved_args(
replace_key: str, replace_value: Any, args: List[Any], kwargs: Dict[str, Any], arg_names: List[str]
) -> Tuple[bool, List[Any], Dict[str, Any]]:
"""Tries to replace an argument value in a saved list of args and kwargs.

Returns a tuple indicating success of the operation and modified saved args and kwargs
"""

if replace_key in arg_names:
replace_index = arg_names.index(replace_key)
args = args[:replace_index] + [replace_value] + args[replace_index + 1 :]
return True, args, kwargs
elif replace_key in kwargs:
kwargs[replace_key] = replace_value
return True, args, kwargs

return False, args, kwargs


def _auto_add_worker_init_fn(dataloader: DataLoader, rank: int) -> None:
if int(os.environ.get("PL_SEED_WORKERS", 0)) and dataloader.worker_init_fn is None:
dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank)


def _wrap_dataloader_init(init: Callable) -> Callable:
"""Wraps the ``__init__`` method of :class:`~torch.utils.data.DataLoader` in order to enable re-instantiation
of custom subclasses."""
def _wrap_init_method(init: Callable, store_explicit_arg: Optional[str] = None) -> Callable:
"""Wraps the ``__init__`` method of classes (currently :class:`~torch.utils.data.DataLoader` and
:class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses."""

@functools.wraps(init)
def wrapper(obj: DataLoader, *args: Any, **kwargs: Any) -> None:
def wrapper(obj: Any, *args: Any, **kwargs: Any) -> None:
# We need to inspect `init`, as inspecting `obj.__init__`
# can lead to inspecting the wrong function with multiple inheritance
params = inspect.signature(init).parameters
@@ -371,18 +432,30 @@ def wrapper(obj: DataLoader, *args: Any, **kwargs: Any) -> None:
)
param_names = param_names[: len(args)]

if not hasattr(obj, "__pl_dl_args"):
obj.__pl_dl_args = args
obj.__pl_dl_kwargs = kwargs
obj.__pl_dl_arg_names = param_names
default_kwargs = {
param.name: param.default
for param in params.values()
if param.name != "self"
and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
and param.default != param.empty
and (param.name not in kwargs and param.name not in param_names)
}

kwargs = {**kwargs, **default_kwargs}

if not hasattr(obj, "__pl_saved_args"):
obj.__pl_saved_args = args
obj.__pl_saved_kwargs = kwargs
obj.__pl_saved_arg_names = param_names

# We want to use the latest possible value for dataset argument (i.e. ideally what gets passed to DataLoader)
# We want to use the latest possible value for explicit argument (i.e. ideally what gets passed to base class)
# so that we can be sure, that it will not get changed anymore.
# That is why we are setting this in every `__init__`
if "dataset" in param_names:
setattr(obj, "__dataset", args[param_names.index("dataset")])
elif "dataset" in kwargs:
setattr(obj, "__dataset", kwargs["dataset"])
if store_explicit_arg is not None:
if store_explicit_arg in param_names:
setattr(obj, f"__{store_explicit_arg}", args[param_names.index(store_explicit_arg)])
elif store_explicit_arg in kwargs:
setattr(obj, f"__{store_explicit_arg}", kwargs[store_explicit_arg])

init(obj, *args, **kwargs)

@@ -404,15 +477,17 @@ def recurse(cl: Type[Any]) -> None:


@contextmanager
def _replace_dataloader_init_method() -> Generator[None, None, None]:
"""This context manager is used to add support for re-instantiation of custom (subclasses) of
:class:`~torch.utils.data.DataLoader`. It patches the ``__init__`` method."""
classes = _get_all_subclasses(DataLoader) | {DataLoader}
def _replace_init_method(base_cls: Type, store_explicit_arg: Optional[str] = None) -> Generator[None, None, None]:
"""This context manager is used to add support for re-instantiation of custom (subclasses) of `base_cls`.

It patches the ``__init__`` method.
"""
classes = _get_all_subclasses(base_cls) | {base_cls}
wrapped = set()
for cls in classes:
if cls.__init__ not in wrapped:
cls._old_init = cls.__init__
cls.__init__ = _wrap_dataloader_init(cls.__init__)
cls.__init__ = _wrap_init_method(cls.__init__, store_explicit_arg)
wrapped.add(cls.__init__)
yield
for cls in classes:
@@ -457,13 +532,13 @@ def _apply_fault_tolerant_automatic_capture_dataset_wrapper(


def _is_dataloader_shuffled(dataloader: object) -> bool:
if hasattr(dataloader, "__pl_dl_kwargs"):
if hasattr(dataloader, "__pl_saved_kwargs"):
# this attribute is not part of PyTorch's DataLoader, but could have been set by
# our `_replace_dataloader_init_method` context manager
if "shuffle" in dataloader.__pl_dl_kwargs:
return dataloader.__pl_dl_kwargs["shuffle"]
if "shuffle" in dataloader.__pl_dl_arg_names:
return dataloader.__pl_dl_args[dataloader.__pl_dl_arg_names.index("shuffle")]
# our `_replace_init_method` context manager
if "shuffle" in dataloader.__pl_saved_kwargs:
return dataloader.__pl_saved_kwargs["shuffle"]
if "shuffle" in dataloader.__pl_saved_arg_names:
return dataloader.__pl_saved_args[dataloader.__pl_saved_arg_names.index("shuffle")]
if isinstance(dataloader.dataset, IterableDataset):
# shuffling is useless with iterable datasets
return False
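The warning and exception messages added to `_dataloader_init_kwargs_resolve_sampler` spell out the contract for custom batch samplers: the `__init__` must accept a `sampler` argument so a modified sampler can be injected, and should also expose `drop_last` so `drop_last=False` can be forced during prediction; `_replace_value_in_saved_args` performs those substitutions on the captured `__pl_saved_*` values. A hedged sketch of a subclass that satisfies the contract; the class and its extra argument are made up for illustration:

```python
from typing import Iterator, List

from torch.utils.data import BatchSampler, Sampler


class ChunkedBatchSampler(BatchSampler):
    """Illustrative subclass: `sampler` and `drop_last` stay in the signature, so Lightning
    can re-instantiate it with a new sampler and, when predicting, with `drop_last=False`."""

    def __init__(self, sampler: Sampler, batch_size: int, drop_last: bool, chunk_size: int = 1) -> None:
        super().__init__(sampler, batch_size, drop_last)
        self.chunk_size = chunk_size

    def __iter__(self) -> Iterator[List[int]]:
        # Custom batching logic would go here; this sketch just defers to the default.
        yield from super().__iter__()


# Roughly what the resolver does with the captured values (args, kwargs, arg_names being the
# saved `__pl_saved_args`, `__pl_saved_kwargs`, `__pl_saved_arg_names`):
#   success, args, kwargs = _replace_value_in_saved_args("drop_last", False, args, kwargs, arg_names)
#   success, args, kwargs = _replace_value_in_saved_args("sampler", new_sampler, args, kwargs, arg_names)
#   batch_sampler = ChunkedBatchSampler(*args, **kwargs)
```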
7 changes: 4 additions & 3 deletions tests/tests_pytorch/lite/test_lite.py
@@ -177,16 +177,17 @@ def test_setup_dataloaders_return_type():
assert lite_dataloader1.dataset is dataset1


@mock.patch("pytorch_lightning.lite.lite._replace_dataloader_init_method")
@mock.patch("pytorch_lightning.lite.lite._replace_init_method")
def test_setup_dataloaders_captures_dataloader_arguments(ctx_manager):
"""Test that Lite intercepts the DataLoader constructor arguments with a context manager in its run method."""

class Lite(LightningLite):
def run(self):
ctx_manager().__enter__.assert_called_once()
# One for BatchSampler, another for DataLoader
assert ctx_manager().__enter__.call_count == 2

Lite().run()
ctx_manager().__exit__.assert_called_once()
assert ctx_manager().__exit__.call_count == 2


def test_setup_dataloaders_raises_for_unknown_custom_args():
10 changes: 0 additions & 10 deletions tests/tests_pytorch/utilities/test_auto_restart.py
@@ -34,7 +34,6 @@
from torch.utils.data._utils.worker import _generate_state, get_worker_info
from torch.utils.data.dataloader import DataLoader, default_collate
from torch.utils.data.dataset import Dataset, IterableDataset
from torch.utils.data.sampler import Sampler

import tests_pytorch.helpers.utils as tutils
from pytorch_lightning import Callback, LightningModule, seed_everything, Trainer
@@ -1186,15 +1185,6 @@ class CustomRandomSampler(RandomSampler):
with pytest.raises(TypeError, match="RandomSampler"):
_validate_fault_tolerant_automatic(dl, RunningStage.TRAINING)

class CustomBatchSampler(BatchSampler):
pass

sampler = Sampler(data())
batch_sampler = CustomBatchSampler(sampler, 2, False)
dl = DataLoader(data(), batch_sampler=batch_sampler)
with pytest.raises(TypeError, match="BatchSampler"):
_validate_fault_tolerant_automatic(dl, RunningStage.TRAINING)

class CustomIterable(IterableDataset):
pass
