Support empty batches #530
opacus/data_loader.py:

```diff
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import Any, Optional, Sequence
+from typing import Any, Optional, Sequence, Tuple, Type, Union
 
 import torch
 from opacus.utils.uniform_sampler import (
@@ -29,7 +29,10 @@
 def wrap_collate_with_empty(
-    collate_fn: Optional[_collate_fn_t], sample_empty_shapes: Sequence
+    *,
+    collate_fn: Optional[_collate_fn_t],
+    sample_empty_shapes: Sequence[Tuple],
+    dtypes: Sequence[Union[torch.dtype, Type]],
 ):
     """
     Wraps given collate function to handle empty batches.
@@ -49,12 +52,15 @@ def collate(batch):
         if len(batch) > 0:
             return collate_fn(batch)
         else:
-            return [torch.zeros(x) for x in sample_empty_shapes]
+            return [
+                torch.zeros(shape, dtype=dtype)
+                for shape, dtype in zip(sample_empty_shapes, dtypes)
+            ]
 
     return collate
 
 
-def shape_safe(x: Any):
+def shape_safe(x: Any) -> Tuple:
     """
     Exception-safe getter for ``shape`` attribute
```
```diff
@@ -67,6 +73,19 @@ def shape_safe(x: Any):
     return x.shape if hasattr(x, "shape") else ()
 
 
+def dtype_safe(x: Any) -> Union[torch.dtype, Type]:
+    """
+    Exception-safe getter for ``dtype`` attribute
+
+    Args:
+        x: any object
+
+    Returns:
+        ``x.dtype`` if attribute exists, type of x otherwise
+    """
+    return x.dtype if hasattr(x, "dtype") else type(x)
+
+
 class DPDataLoader(DataLoader):
     """
     DataLoader subclass that always does Poisson sampling and supports empty batches
```
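Why thread dtypes through at all? `torch.zeros(shape)` always produces `float32`, so an empty batch of integer labels would otherwise come out with the wrong dtype. A minimal sketch of the two helpers in action (the helpers are inlined from the diff above; the sample is a made-up image-classification example):

```python
import torch

def shape_safe(x):
    return x.shape if hasattr(x, "shape") else ()

def dtype_safe(x):
    return x.dtype if hasattr(x, "dtype") else type(x)

# A typical (image, label) sample: the label is a plain Python int.
sample = (torch.randn(3, 32, 32), 7)

sample_empty_shapes = [(0, *shape_safe(x)) for x in sample]  # [(0, 3, 32, 32), (0,)]
dtypes = [dtype_safe(x) for x in sample]                     # [torch.float32, <class 'int'>]

# With the dtype threaded through, the empty label batch is int64,
# matching what default_collate produces for non-empty batches.
empty_labels = torch.zeros(sample_empty_shapes[1], dtype=dtypes[1])
print(empty_labels.shape, empty_labels.dtype)  # torch.Size([0]) torch.int64
```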
```diff
@@ -143,7 +162,8 @@ def __init__(
             sample_rate=sample_rate,
             generator=generator,
         )
-        sample_empty_shapes = [[0, *shape_safe(x)] for x in dataset[0]]
+        sample_empty_shapes = [(0, *shape_safe(x)) for x in dataset[0]]
+        dtypes = [dtype_safe(x) for x in dataset[0]]
```
**Reviewer:** If this is consumed by `wrap_collate_with_empty`, we need the dtype to be an actual `torch.dtype`, not a plain Python type.

**Author:** `torch.zeros` supports normal Python types (`int`/`float`/`bool`) just fine:
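(The snippet that followed is not preserved here; a quick check of the claim, with `torch.zeros` mapping builtin types to tensor dtypes:)

```python
import torch

print(torch.zeros(2, dtype=int).dtype)    # torch.int64
print(torch.zeros(2, dtype=float).dtype)  # torch.float64
print(torch.zeros(2, dtype=bool).dtype)   # torch.bool
```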
**Community member:** Hi there! I've been following the discussion since I ran into the empty-batches problem when using Poisson sampling with a small sampling rate. When I tested this implementation with my own project, it didn't work for me, since it does not support more complex labels. If labels are just numbers (plain `int`s, say), this works fine, but not when each label is a tuple or another nested structure. I don't know if you expect this kind of thing to be supported, but since it works in standard PyTorch (and in fact in Opacus when batches are not empty), it may be something worth being aware of.

**Author:** That's an interesting point, thanks. Just thinking out loud, what would be a good way to support this? In case the original label is a tuple, how does the collate function handle it? Would it output multiple tensors per label?

**Community member:** If you have a dataset in which each sample is a nested structure (say, a tensor paired with a tuple of labels), the default collate function recurses into the structure and collates each position separately. This is the code snippet that manages these cases:
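(The referenced snippet is also missing from this capture. For illustration, here is roughly how `default_collate` behaves on nested labels; the exact container type of the collated label varies across PyTorch versions:)

```python
import torch
from torch.utils.data.dataloader import default_collate

# Each sample is (input, (class_id, weight)): the label is itself a tuple.
batch = [
    (torch.randn(3), (0, 1.5)),
    (torch.randn(3), (1, 2.5)),
]
inputs, labels = default_collate(batch)
print(inputs.shape)  # torch.Size([2, 3])
print(labels)        # [tensor([0, 1]), tensor([1.5000, 2.5000], dtype=torch.float64)]
```

`default_collate` emits one tensor per slot of the label tuple, so a flat list of `sample_empty_shapes` cannot describe the empty-batch equivalent of such a dataset.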
I guess a proper solution for supporting empty batches would be to recycle this code, but returning tuples/dicts/lists/... of empty tensors with the appropriate shapes and dtypes.

**Author:** That makes perfect sense, thanks! You're absolutely right that this is the right way forward. However, I don't see it as a blocker for landing this PR: it does solve the problem for a subset of cases and can be considered an atomic improvement. So as not to delay merging, I created an issue to track the proposed improvement (#534); hopefully someone will pick it up soon.
```diff
         if collate_fn is None:
             collate_fn = default_collate
@@ -156,7 +176,11 @@ def __init__(
             dataset=dataset,
             batch_sampler=batch_sampler,
             num_workers=num_workers,
-            collate_fn=wrap_collate_with_empty(collate_fn, sample_empty_shapes),
+            collate_fn=wrap_collate_with_empty(
+                collate_fn=collate_fn,
+                sample_empty_shapes=sample_empty_shapes,
+                dtypes=dtypes,
+            ),
             pin_memory=pin_memory,
             timeout=timeout,
             worker_init_fn=worker_init_fn,
```
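End to end, the wrapped collate function now produces correctly shaped and typed tensors when Poisson sampling draws zero samples. A standalone sketch (the inner function mirrors the patched `wrap_collate_with_empty`; the shapes and dtypes are example values for an image dataset with integer labels):

```python
import torch
from torch.utils.data.dataloader import default_collate

def wrap_collate_with_empty(*, collate_fn, sample_empty_shapes, dtypes):
    def collate(batch):
        if len(batch) > 0:
            return collate_fn(batch)
        # Empty batch: emit zero-length, correctly typed tensors.
        return [
            torch.zeros(shape, dtype=dtype)
            for shape, dtype in zip(sample_empty_shapes, dtypes)
        ]
    return collate

collate = wrap_collate_with_empty(
    collate_fn=default_collate,
    sample_empty_shapes=[(0, 3, 32, 32), (0,)],
    dtypes=[torch.float32, int],
)

images, labels = collate([])       # the empty-batch path
print(images.shape, images.dtype)  # torch.Size([0, 3, 32, 32]) torch.float32
print(labels.shape, labels.dtype)  # torch.Size([0]) torch.int64
```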
opacus/grad_sample/conv.py:

```diff
@@ -42,6 +42,14 @@ def compute_conv_grad_sample(
     """
     activations = activations[0]
     n = activations.shape[0]
+    if n == 0:
+        # Empty batch
+        ret = {}
+        ret[layer.weight] = torch.zeros_like(layer.weight).unsqueeze(0)
+        if layer.bias is not None and layer.bias.requires_grad:
+            ret[layer.bias] = torch.zeros_like(layer.bias).unsqueeze(0)
+        return ret
```
**Reviewer:** Why does conv need special treatment? What happens with other layers?

**Author:** That's down to the specific grad sampler implementation. That said, it is a good point that we want to be sure all of the layers can handle it, which the current tests only partially do. We have a PrivacyEngine test for an empty batch, but nothing at the grad sampler level. I'll check if there's an easy way to do it and update the PR.
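Such a grad-sampler-level check might look roughly like the sketch below. This is hypothetical, not the test that was eventually added; it assumes Opacus's `GradSampleModule` wrapper and that the empty-batch path above populates `grad_sample` with zeros:

```python
import torch
import torch.nn as nn
from opacus.grad_sample import GradSampleModule

def test_conv_grad_sample_empty_batch():
    model = GradSampleModule(nn.Conv2d(3, 8, kernel_size=3))
    x = torch.zeros(0, 3, 32, 32)  # an empty batch
    out = model(x)
    out.sum().backward()           # sum of an empty tensor is a zero scalar
    for p in model.parameters():
        # Per-sample gradients should exist and be all-zero, with trailing
        # dims matching the parameter shape.
        assert p.grad_sample is not None
        assert p.grad_sample.shape[1:] == p.shape
        assert torch.all(p.grad_sample == 0)
```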
```diff
+
     # get activations and backprops in shape depending on the Conv layer
     if type(layer) == nn.Conv2d:
         activations = unfold2d(
```