Support empty batches #530

Closed · wants to merge 11 commits
36 changes: 30 additions & 6 deletions opacus/data_loader.py
@@ -13,7 +13,7 @@
# limitations under the License.

import logging
from typing import Any, Optional, Sequence
from typing import Any, Optional, Sequence, Tuple, Type, Union

import torch
from opacus.utils.uniform_sampler import (
@@ -29,7 +29,10 @@


def wrap_collate_with_empty(
collate_fn: Optional[_collate_fn_t], sample_empty_shapes: Sequence
*,
collate_fn: Optional[_collate_fn_t],
Contributor:

Suggested change:
collate_fn: Optional[_collate_fn_t],
*,
collate_fn: Optional[_collate_fn_t],
sample_empty_shapes: Sequence[Tuple],
dtypes: Sequence[Union[torch.dtype, Type]],
):
"""
Wraps given collate function to handle empty batches.
@@ -49,12 +52,15 @@ def collate(batch):
if len(batch) > 0:
return collate_fn(batch)
else:
return [torch.zeros(x) for x in sample_empty_shapes]
return [
torch.zeros(shape, dtype=dtype)
for shape, dtype in zip(sample_empty_shapes, dtypes)
]

return collate


def shape_safe(x: Any):
def shape_safe(x: Any) -> Tuple:
"""
Exception-safe getter for ``shape`` attribute

@@ -67,6 +73,19 @@ def shape_safe(x: Any):
return x.shape if hasattr(x, "shape") else ()


def dtype_safe(x: Any) -> Union[torch.dtype, Type]:
"""
Exception-safe getter for ``dtype`` attribute

Args:
x: any object

Returns:
``x.dtype`` if attribute exists, type of x otherwise
"""
return x.dtype if hasattr(x, "dtype") else type(x)
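
For reference, a quick illustrative sketch (not part of the diff) of what the two helpers return for tensor and scalar inputs:

> shape_safe(torch.randn(3, 4)), dtype_safe(torch.randn(3, 4))
(torch.Size([3, 4]), torch.float32)

> shape_safe(7), dtype_safe(7)   # plain Python scalars have neither attribute
((), <class 'int'>)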


class DPDataLoader(DataLoader):
"""
DataLoader subclass that always does Poisson sampling and supports empty batches
@@ -143,7 +162,8 @@ def __init__(
sample_rate=sample_rate,
generator=generator,
)
sample_empty_shapes = [[0, *shape_safe(x)] for x in dataset[0]]
sample_empty_shapes = [(0, *shape_safe(x)) for x in dataset[0]]
dtypes = [dtype_safe(x) for x in dataset[0]]
Contributor:
If this is consumed by wrap_collate_with_empty, we need the dtype to be an actual dtype and not a type, right?

Contributor Author:
torch.zeros supports normal Python types (int/float/bool) just fine:

> torch.zeros((2,2), dtype=int)
tensor([[0, 0],
        [0, 0]])

> torch.zeros((2,2), dtype=float)  
tensor([[0., 0.],
        [0., 0.]], dtype=torch.float64)

> torch.zeros((2,2), dtype=bool)   
tensor([[False, False],
        [False, False]])

@joserapa98 commented on Oct 27, 2022:
Hi there! I've been following the discussion since I ran into the empty-batches problem when using Poisson sampling with a small sampling rate. When I tested this implementation with my own project, it didn't work for me, since it does not support more complex labels. If labels are just numbers (int, float, bool) it's fine, but if you have, say, a label formed by a tuple of numbers, its type will be tuple, which causes an error in torch.zeros.

I don't know if you expect this kind of thing to be supported, but since it works in standard PyTorch (and in fact in Opacus when batches are not empty), it may be something worth being aware of.
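
A minimal reproduction of the failure mode described above (an illustrative sketch; the exact error text may vary across PyTorch versions). For a tuple label, dtype_safe falls back to type(x), i.e. tuple, which torch.zeros does not accept as a dtype:

> label = (1, 2)
> shape_safe(label), dtype_safe(label)
((), <class 'tuple'>)

> torch.zeros((0,), dtype=tuple)   # what the empty-batch collate would attempt
TypeError   # dtype must be a torch.dtype (or int/float/bool), not tuple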

Contributor Author:
That's an interesting point, thanks.

Just thinking out loud, what would be a good way to support this? In case the original label is a tuple, how does the collate function handle it - would it output multiple tensors per label?

Reply:

If you have a dataset in which each sample is of type, say, tuple(torch.Tensor, tuple(int, int)), the dataloader would return tuple(torch.Tensor, tuple(torch.Tensor, torch.Tensor)), where now each tensor has an extra dimension for the batch. Something similar would happen if labels are given as lists, dicts, etc.

This is the code snippet that manages these cases:

if isinstance(elem, collections.abc.Mapping):
    try:
        return elem_type({key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem})
    except TypeError:
        # The mapping type may not support `__init__(iterable)`.
        return {key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem}
elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
    return elem_type(*(collate(samples, collate_fn_map=collate_fn_map) for samples in zip(*batch)))
elif isinstance(elem, collections.abc.Sequence):
    # check to make sure that the elements in batch have consistent size
    it = iter(batch)
    elem_size = len(next(it))
    if not all(len(elem) == elem_size for elem in it):
        raise RuntimeError('each element in list of batch should be of equal size')
    transposed = list(zip(*batch))  # It may be accessed twice, so we use a list.


    if isinstance(elem, tuple):
        return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]  # Backwards compatibility.
    else:
        try:
            return elem_type([collate(samples, collate_fn_map=collate_fn_map) for samples in transposed])
        except TypeError:
            # The sequence type may not support `__init__(iterable)` (e.g., `range`).
            return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]

https://github.com/pytorch/pytorch/blob/8349bf1cd1d5df7be73b194940bcf96209159f40/torch/utils/data/_utils/collate.py#L126-L149

I guess a proper solution for supporting empty batches would be to recycle this code but return tuples/dicts/lists/... of empty tensors built with torch.zeros, so that the container types are preserved while the tensors themselves are empty.
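
For illustration, a rough sketch of that idea (not part of this PR; the name collate_empty is hypothetical): recurse over one dataset element the same way default_collate recurses over the batch, but emit 0-sized tensors at the leaves so container types are preserved.

import collections.abc
import torch

def collate_empty(elem):
    # Mirror the structure of a single dataset element, but put 0-sized
    # tensors at the leaves so downstream code sees a well-typed empty batch.
    if isinstance(elem, torch.Tensor):
        return torch.zeros((0, *elem.shape), dtype=elem.dtype)
    if isinstance(elem, collections.abc.Mapping):
        return {key: collate_empty(value) for key, value in elem.items()}
    if isinstance(elem, tuple) and hasattr(elem, "_fields"):  # namedtuple
        return type(elem)(*(collate_empty(value) for value in elem))
    if isinstance(elem, collections.abc.Sequence) and not isinstance(elem, str):
        return [collate_empty(value) for value in elem]
    # Plain Python scalars (int/float/bool) map directly onto torch dtypes.
    return torch.zeros((0,), dtype=type(elem))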

Contributor Author:
That makes perfect sense, thanks! You're absolutely right, this is the right way forward.

However, I don't see this as a blocker for landing this PR. This PR does solve the problem for a subset of cases and can be considered an atomic improvement. So as not to delay merging it, I created an issue to track the proposed improvement (#534); hopefully someone will pick it up soon.

if collate_fn is None:
collate_fn = default_collate

@@ -156,7 +176,11 @@ def __init__(
dataset=dataset,
batch_sampler=batch_sampler,
num_workers=num_workers,
collate_fn=wrap_collate_with_empty(collate_fn, sample_empty_shapes),
collate_fn=wrap_collate_with_empty(
collate_fn=collate_fn,
sample_empty_shapes=sample_empty_shapes,
dtypes=dtypes,
),
pin_memory=pin_memory,
timeout=timeout,
worker_init_fn=worker_init_fn,
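
End-to-end, the new behaviour looks roughly like this (a sketch assuming a plain TensorDataset; not taken from the PR or its tests):

import torch
from torch.utils.data import TensorDataset
from torch.utils.data._utils.collate import default_collate
from opacus.data_loader import dtype_safe, shape_safe, wrap_collate_with_empty

dataset = TensorDataset(torch.randn(8, 5), torch.randint(0, 2, (8,)))

# Derive per-element empty shapes and dtypes from the first sample,
# exactly as DPDataLoader does in the diff above.
sample_empty_shapes = [(0, *shape_safe(x)) for x in dataset[0]]
dtypes = [dtype_safe(x) for x in dataset[0]]
collate = wrap_collate_with_empty(
    collate_fn=default_collate,
    sample_empty_shapes=sample_empty_shapes,
    dtypes=dtypes,
)

x, y = collate([])   # an empty Poisson batch
# x: shape (0, 5), dtype torch.float32
# y: shape (0,),   dtype torch.int64
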
1 change: 1 addition & 0 deletions opacus/grad_sample/README.md
@@ -74,6 +74,7 @@ Please note that these are known limitations and we plan to improve Expanded Weights
| `batch_first=False` | ✅ Supported | Not supported | ✅ Supported |
| Recurrent networks | ✅ Supported | Not supported | ✅ Supported |
| Padding `same` in Conv | ✅ Supported | Not supported | ✅ Supported |
| Empty poisson batches | ✅ Supported | Not supported | Not supported |

† Note that performance differences are unstable and can vary a lot depending on the exact model and batch size.
Numbers above are averaged over benchmarks with small models consisting of convolutional and linear layers.
8 changes: 8 additions & 0 deletions opacus/grad_sample/conv.py
@@ -42,6 +42,14 @@ def compute_conv_grad_sample(
"""
activations = activations[0]
n = activations.shape[0]
if n == 0:
# Empty batch
ret = {}
ret[layer.weight] = torch.zeros_like(layer.weight).unsqueeze(0)
if layer.bias is not None and layer.bias.requires_grad:
ret[layer.bias] = torch.zeros_like(layer.bias).unsqueeze(0)
return ret
Contributor:
Why does conv need special treatment? What happens with other layers?

Contributor Author:
That's down to the specific grad sampler implementation.
Most grad samplers we have rely on einsums, which are generally pretty good at handling 0-sized vectors.
With conv in particular the culprit is the following line:

backprops = backprops.reshape(n, -1, activations.shape[-1])

That said, it is a good point that we want to be sure all of the layers can handle it - which current tests only partially do. We have a PrivacyEngine test for an empty batch, but nothing on the grad sampler level - I'll check if there's an easy way to do it and update the PR
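
For illustration, a minimal sketch of that failure (exact error wording depends on the PyTorch version):

n = 0
backprops = torch.zeros(n, 4, 3)                 # an empty batch
backprops.reshape(n, -1, backprops.shape[-1])
# raises RuntimeError: the -1 dimension is ambiguous for a tensor with 0 elements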


# get activations and backprops in shape depending on the Conv layer
if type(layer) == nn.Conv2d:
activations = unfold2d(
4 changes: 4 additions & 0 deletions opacus/grad_sample/embedding.py
@@ -40,6 +40,10 @@ def compute_embedding_grad_sample(
torch.backends.cudnn.deterministic = True

batch_size = activations.shape[0]
if batch_size == 0:
ret[layer.weight] = torch.zeros_like(layer.weight).unsqueeze(0)
return ret

index = (
activations.unsqueeze(-1)
.expand(*activations.shape, layer.embedding_dim)
18 changes: 11 additions & 7 deletions opacus/optimizers/optimizer.py
@@ -394,13 +394,17 @@ def clip_and_accumulate(self):
Stores clipped and aggregated gradients into ``p.summed_grad``
"""

per_param_norms = [
g.reshape(len(g), -1).norm(2, dim=-1) for g in self.grad_samples
]
per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1)
per_sample_clip_factor = (self.max_grad_norm / (per_sample_norms + 1e-6)).clamp(
max=1.0
)
if len(self.grad_samples[0]) == 0:
# Empty batch
per_sample_clip_factor = torch.zeros((0,))
else:
per_param_norms = [
g.reshape(len(g), -1).norm(2, dim=-1) for g in self.grad_samples
]
per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1)
per_sample_clip_factor = (
self.max_grad_norm / (per_sample_norms + 1e-6)
).clamp(max=1.0)

for p in self.params:
_check_processed_flag(p.grad_sample)
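
For intuition, a toy sketch of the clipping math above (illustrative only, not the actual Opacus code path): each per-sample gradient is scaled by min(1, C / ||g_i||) and then summed over the batch dimension; with an empty batch the contraction runs over zero samples and yields all-zero summed gradients.

import torch

max_grad_norm = 1.0
grad_sample = torch.randn(4, 10)   # 4 per-sample gradients for one parameter
per_sample_norms = grad_sample.reshape(len(grad_sample), -1).norm(2, dim=-1)
per_sample_clip_factor = (max_grad_norm / (per_sample_norms + 1e-6)).clamp(max=1.0)
summed_grad = torch.einsum("i,i...", per_sample_clip_factor, grad_sample)

# Empty batch: a 0-length clip factor contracted against a 0-sized
# grad_sample still produces a well-defined, all-zero gradient.
empty = torch.einsum("i,i...", torch.zeros((0,)), torch.zeros(0, 10))
# empty.shape == (10,), and empty is all zeros
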
88 changes: 78 additions & 10 deletions opacus/tests/batch_memory_manager_test.py
@@ -37,37 +37,42 @@ class BatchMemoryManagerTest(unittest.TestCase):
GSM_MODE = "hooks"

def setUp(self) -> None:
self.data_size = 100
self.batch_size = 10
self.data_size = 256
self.inps = torch.randn(self.data_size, 5)
self.tgts = torch.randn(
self.data_size,
)

self.dataset = TensorDataset(self.inps, self.tgts)

def _init_training(self, **data_loader_kwargs):
def _init_training(self, batch_size=10, **data_loader_kwargs):
model = Model()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data_loader = DataLoader(
self.dataset, batch_size=self.batch_size, **data_loader_kwargs
self.dataset, batch_size=batch_size, **data_loader_kwargs
)

return model, optimizer, data_loader

@given(
num_workers=st.integers(0, 4),
pin_memory=st.booleans(),
batch_size=st.sampled_from([8, 16, 64]),
max_physical_batch_size=st.sampled_from([4, 8]),
)
@settings(deadline=10000)
def test_basic(
self,
num_workers: int,
pin_memory: bool,
batch_size: int,
max_physical_batch_size: int,
):
batches_per_step = max(1, batch_size // max_physical_batch_size)
model, optimizer, data_loader = self._init_training(
num_workers=num_workers,
pin_memory=pin_memory,
batch_size=batch_size,
)

privacy_engine = PrivacyEngine()
@@ -80,22 +85,19 @@ def test_basic(
poisson_sampling=False,
grad_sample_mode=self.GSM_MODE,
)
max_physical_batch_size = 3
with BatchMemoryManager(
data_loader=data_loader,
max_physical_batch_size=max_physical_batch_size,
optimizer=optimizer,
) as new_data_loader:
self.assertEqual(
len(data_loader), len(data_loader.dataset) // self.batch_size
)
self.assertEqual(len(data_loader), len(data_loader.dataset) // batch_size)
self.assertEqual(
len(new_data_loader),
len(data_loader.dataset) // max_physical_batch_size,
)
weights_before = torch.clone(model._module.fc.weight)
for i, (x, y) in enumerate(new_data_loader):
self.assertTrue(x.shape[0] <= 3)
self.assertTrue(x.shape[0] <= max_physical_batch_size)

out = model(x)
loss = (y - out).mean()
@@ -104,7 +106,63 @@ def test_basic(
optimizer.step()
optimizer.zero_grad()

if i % 4 < 3:
if (i + 1) % batches_per_step > 0:
self.assertTrue(
torch.allclose(model._module.fc.weight, weights_before)
)
else:
self.assertFalse(
torch.allclose(model._module.fc.weight, weights_before)
)
weights_before = torch.clone(model._module.fc.weight)

@given(
num_workers=st.integers(0, 4),
pin_memory=st.booleans(),
)
@settings(deadline=10000)
def test_empty_batch(
self,
num_workers: int,
pin_memory: bool,
):
batch_size = 2
max_physical_batch_size = 10
torch.manual_seed(30)

model, optimizer, data_loader = self._init_training(
num_workers=num_workers,
pin_memory=pin_memory,
batch_size=batch_size,
)

privacy_engine = PrivacyEngine()
model, optimizer, data_loader = privacy_engine.make_private(
module=model,
optimizer=optimizer,
data_loader=data_loader,
noise_multiplier=0.0,
max_grad_norm=1e5,
poisson_sampling=True,
grad_sample_mode=self.GSM_MODE,
)
with BatchMemoryManager(
data_loader=data_loader,
max_physical_batch_size=max_physical_batch_size,
optimizer=optimizer,
) as new_data_loader:
weights_before = torch.clone(model._module.fc.weight)
for i, (x, y) in enumerate(new_data_loader):
self.assertTrue(x.shape[0] <= max_physical_batch_size)

out = model(x)
loss = (y - out).mean()

loss.backward()
optimizer.step()
optimizer.zero_grad()

if len(x) == 0:
self.assertTrue(
torch.allclose(model._module.fc.weight, weights_before)
)
@@ -174,3 +232,13 @@ def test_equivalent_to_one_batch(self):
)
class BatchMemoryManagerTestWithExpandedWeights(BatchMemoryManagerTest):
GSM_MODE = "ew"

def test_empty_batch(self):
pass


@unittest.skipIf(
torch.__version__ >= API_CUTOFF_VERSION, "not supported in this torch version"
)
class BatchMemoryManagerTestWithFunctorch(BatchMemoryManagerTest):
GSM_MODE = "functorch"
36 changes: 24 additions & 12 deletions opacus/tests/grad_samples/common.py
@@ -15,7 +15,7 @@

import io
import unittest
from typing import Dict, List, Union
from typing import Dict, Iterable, List, Tuple, Union

import numpy as np
import torch
@@ -36,6 +36,13 @@ def shrinker(x, factor: int = 2):
return max(1, x // factor)  # avoid returning 0 for x == 1


def is_batch_empty(batch: Union[torch.Tensor, Iterable[torch.Tensor]]):
if type(batch) is torch.Tensor:
return batch.numel() == 0
else:
return batch[0].numel() == 0
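
A couple of illustrative calls (assuming the helper above), since batches can arrive either as a single tensor or as a tuple/list of tensors:

is_batch_empty(torch.zeros(0, 3))                    # True
is_batch_empty(torch.randn(2, 3))                    # False
is_batch_empty((torch.zeros(0, 3), torch.zeros(0)))  # True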


class ModelWithLoss(nn.Module):
"""
To test the gradients of a module, we need to have a loss.
@@ -221,7 +228,7 @@ def compute_opacus_grad_sample(

def run_test(
self,
x: Union[torch.Tensor, PackedSequence],
x: Union[torch.Tensor, PackedSequence, Tuple],
module: nn.Module,
batch_first=True,
atol=10e-6,
@@ -235,7 +242,9 @@ def run_test(
except ImportError:
grad_sample_modes = ["hooks"]

if type(module) is nn.EmbeddingBag:
if type(module) is nn.EmbeddingBag or (
type(x) is not PackedSequence and is_batch_empty(x)
):
grad_sample_modes = ["hooks"]

for grad_sample_mode in grad_sample_modes:
@@ -277,6 +286,14 @@ def run_test_with_reduction(
grad_sample_mode="hooks",
chunk_method=iter,
):
opacus_grad_samples = self.compute_opacus_grad_sample(
x,
module,
batch_first=batch_first,
loss_reduction=loss_reduction,
grad_sample_mode=grad_sample_mode,
)

if type(x) is PackedSequence:
x_unpacked = _unpack_packedsequences(x)
microbatch_grad_samples = self.compute_microbatch_grad_sample(
Expand All @@ -285,22 +302,17 @@ def run_test_with_reduction(
batch_first=batch_first,
loss_reduction=loss_reduction,
)
else:
elif not is_batch_empty(x):
microbatch_grad_samples = self.compute_microbatch_grad_sample(
x,
module,
batch_first=batch_first,
loss_reduction=loss_reduction,
chunk_method=chunk_method,
)

opacus_grad_samples = self.compute_opacus_grad_sample(
x,
module,
batch_first=batch_first,
loss_reduction=loss_reduction,
grad_sample_mode=grad_sample_mode,
)
else:
# We've checked opacus can handle 0-sized batch. Microbatch doesn't make sense
return

if microbatch_grad_samples.keys() != opacus_grad_samples.keys():
raise ValueError(
6 changes: 4 additions & 2 deletions opacus/tests/grad_samples/conv1d_test.py
@@ -25,7 +25,7 @@

class Conv1d_test(GradSampleHooks_test):
@given(
N=st.integers(1, 4),
N=st.integers(0, 4),
C=st.sampled_from([1, 3, 32]),
W=st.integers(6, 10),
out_channels_mapper=st.sampled_from([expander, shrinker]),
@@ -67,4 +67,6 @@ def test_conv1d(
dilation=dilation,
groups=groups,
)
self.run_test(x, conv, batch_first=True, atol=10e-5, rtol=10e-4)
self.run_test(
x, conv, batch_first=True, atol=10e-5, rtol=10e-4, ew_compatible=N > 0
)