From 4b0cd9182d4ae60651496f3c8d0481b0b85fd9c3 Mon Sep 17 00:00:00 2001 From: Iden Kalemaj Date: Wed, 15 Jan 2025 10:22:56 -0800 Subject: [PATCH 1/8] Add **kwargs to all optimizer classes (#710) Summary: Pull Request resolved: https://github.com/pytorch/opacus/pull/710 Purpose: To enable creating custom PrivacyEngines that extend the PrivacyEngine class and take in additional parameters. Fix prior diff: D67456352 Reviewed By: HuanyuZhang Differential Revision: D67953655 fbshipit-source-id: 70aef7571e012a370d6a0fd04948eccee06c9a0d --- opacus/optimizers/adaclipoptimizer.py | 1 + opacus/optimizers/ddp_perlayeroptimizer.py | 2 ++ opacus/optimizers/ddpoptimizer.py | 2 ++ opacus/optimizers/ddpoptimizer_fast_gradient_clipping.py | 2 ++ opacus/optimizers/optimizer.py | 1 + opacus/optimizers/optimizer_fast_gradient_clipping.py | 2 ++ opacus/optimizers/perlayeroptimizer.py | 2 ++ opacus/privacy_engine.py | 1 + 8 files changed, 13 insertions(+) diff --git a/opacus/optimizers/adaclipoptimizer.py b/opacus/optimizers/adaclipoptimizer.py index 7144f06b..9498613d 100644 --- a/opacus/optimizers/adaclipoptimizer.py +++ b/opacus/optimizers/adaclipoptimizer.py @@ -53,6 +53,7 @@ def __init__( loss_reduction: str = "mean", generator=None, secure_mode: bool = False, + **kwargs, ): super().__init__( optimizer, diff --git a/opacus/optimizers/ddp_perlayeroptimizer.py b/opacus/optimizers/ddp_perlayeroptimizer.py index c9b9bdfa..30a50633 100644 --- a/opacus/optimizers/ddp_perlayeroptimizer.py +++ b/opacus/optimizers/ddp_perlayeroptimizer.py @@ -48,6 +48,7 @@ def __init__( loss_reduction: str = "mean", generator=None, secure_mode: bool = False, + **kwargs, ): self.rank = torch.distributed.get_rank() self.world_size = torch.distributed.get_world_size() @@ -79,6 +80,7 @@ def __init__( loss_reduction: str = "mean", generator=None, secure_mode: bool = False, + **kwargs, ): self.rank = torch.distributed.get_rank() self.world_size = torch.distributed.get_world_size() diff --git a/opacus/optimizers/ddpoptimizer.py b/opacus/optimizers/ddpoptimizer.py index 1b4e472d..06048a82 100644 --- a/opacus/optimizers/ddpoptimizer.py +++ b/opacus/optimizers/ddpoptimizer.py @@ -38,6 +38,7 @@ def __init__( loss_reduction: str = "mean", generator=None, secure_mode: bool = False, + **kwargs, ): super().__init__( optimizer, @@ -47,6 +48,7 @@ def __init__( loss_reduction=loss_reduction, generator=generator, secure_mode=secure_mode, + **kwargs, ) self.rank = torch.distributed.get_rank() self.world_size = torch.distributed.get_world_size() diff --git a/opacus/optimizers/ddpoptimizer_fast_gradient_clipping.py b/opacus/optimizers/ddpoptimizer_fast_gradient_clipping.py index b2245303..9442380f 100644 --- a/opacus/optimizers/ddpoptimizer_fast_gradient_clipping.py +++ b/opacus/optimizers/ddpoptimizer_fast_gradient_clipping.py @@ -38,6 +38,7 @@ def __init__( loss_reduction: str = "mean", generator=None, secure_mode: bool = False, + **kwargs, ): super().__init__( optimizer, @@ -47,6 +48,7 @@ def __init__( loss_reduction=loss_reduction, generator=generator, secure_mode=secure_mode, + **kwargs, ) self.rank = torch.distributed.get_rank() self.world_size = torch.distributed.get_world_size() diff --git a/opacus/optimizers/optimizer.py b/opacus/optimizers/optimizer.py index 7a22eeec..58b4c990 100644 --- a/opacus/optimizers/optimizer.py +++ b/opacus/optimizers/optimizer.py @@ -205,6 +205,7 @@ def __init__( loss_reduction: str = "mean", generator=None, secure_mode: bool = False, + **kwargs, ): """ diff --git a/opacus/optimizers/optimizer_fast_gradient_clipping.py b/opacus/optimizers/optimizer_fast_gradient_clipping.py index 21489779..a5a7425e 100644 --- a/opacus/optimizers/optimizer_fast_gradient_clipping.py +++ b/opacus/optimizers/optimizer_fast_gradient_clipping.py @@ -63,6 +63,7 @@ def __init__( loss_reduction: str = "mean", generator=None, secure_mode: bool = False, + **kwargs, ): """ @@ -91,6 +92,7 @@ def __init__( loss_reduction=loss_reduction, generator=generator, secure_mode=secure_mode, + **kwargs, ) @property diff --git a/opacus/optimizers/perlayeroptimizer.py b/opacus/optimizers/perlayeroptimizer.py index 6d0029bf..6ebc9724 100644 --- a/opacus/optimizers/perlayeroptimizer.py +++ b/opacus/optimizers/perlayeroptimizer.py @@ -39,6 +39,7 @@ def __init__( loss_reduction: str = "mean", generator=None, secure_mode: bool = False, + **kwargs, ): assert len(max_grad_norm) == len(params(optimizer)) self.max_grad_norms = max_grad_norm @@ -51,6 +52,7 @@ def __init__( loss_reduction=loss_reduction, generator=generator, secure_mode=secure_mode, + **kwargs, ) def clip_and_accumulate(self): diff --git a/opacus/privacy_engine.py b/opacus/privacy_engine.py index bdddafe4..558c8f8e 100644 --- a/opacus/privacy_engine.py +++ b/opacus/privacy_engine.py @@ -136,6 +136,7 @@ def _prepare_optimizer( loss_reduction=loss_reduction, generator=generator, secure_mode=self.secure_mode, + **kwargs, ) def _prepare_data_loader( From b4c075de16558189bc9fc6aa9cb9e20afd53c887 Mon Sep 17 00:00:00 2001 From: Xinwei Date: Thu, 16 Jan 2025 18:48:59 -0800 Subject: [PATCH 2/8] Disk (#706) Summary: ## Types of changes - [ ] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [x] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Docs change / refactoring / dependency upgrade ## Motivation and Context / Related issue It introduces a set of new optimizers called DiSK, which uses a simplified Kalman filter to improve optimizer performance. ## How Has This Been Tested (if it applies) It is tested with the mnist.py from the example folder (with modifications for DiSK) to ensure all the functions work. ## Checklist Not sure whether to add documents. - [ ] The documentation is up-to-date with the changes I made. - [x] I have read the **CONTRIBUTING** document and completed the CLA (see **CONTRIBUTING**). - [x] All tests passed, and additional code has been covered with new tests. Pull Request resolved: https://github.com/pytorch/opacus/pull/706 Reviewed By: HuanyuZhang Differential Revision: D67626897 Pulled By: iden-kalemaj fbshipit-source-id: 3ac3caf5212920afdae7b4a8ef71bd3868073731 --- research/disk_optimizer/KFprivacy_engine.py | 56 ++++++ research/disk_optimizer/ReadMe.md | 69 ++++++++ .../optimizers/KFadaclipoptimizer.py | 95 +++++++++++ .../optimizers/KFddp_perlayeroptimizer.py | 160 ++++++++++++++++++ .../optimizers/KFddpoptimizer.py | 97 +++++++++++ .../KFddpoptimizer_fast_gradient_clipping.py | 97 +++++++++++ .../disk_optimizer/optimizers/KFoptimizer.py | 147 ++++++++++++++++ .../KFoptimizer_fast_gradient_clipping.py | 89 ++++++++++ .../optimizers/KFperlayeroptimizer.py | 65 +++++++ .../disk_optimizer/optimizers/__init__.py | 129 ++++++++++++++ 10 files changed, 1004 insertions(+) create mode 100644 research/disk_optimizer/KFprivacy_engine.py create mode 100644 research/disk_optimizer/ReadMe.md create mode 100644 research/disk_optimizer/optimizers/KFadaclipoptimizer.py create mode 100644 research/disk_optimizer/optimizers/KFddp_perlayeroptimizer.py create mode 100644 research/disk_optimizer/optimizers/KFddpoptimizer.py create mode 100644 research/disk_optimizer/optimizers/KFddpoptimizer_fast_gradient_clipping.py create mode 100644 research/disk_optimizer/optimizers/KFoptimizer.py create mode 100644 research/disk_optimizer/optimizers/KFoptimizer_fast_gradient_clipping.py create mode 100644 research/disk_optimizer/optimizers/KFperlayeroptimizer.py create mode 100644 research/disk_optimizer/optimizers/__init__.py diff --git a/research/disk_optimizer/KFprivacy_engine.py b/research/disk_optimizer/KFprivacy_engine.py new file mode 100644 index 00000000..c890c224 --- /dev/null +++ b/research/disk_optimizer/KFprivacy_engine.py @@ -0,0 +1,56 @@ +from typing import List, Union + +from opacus.optimizers import DPOptimizer +from opacus.privacy_engine import PrivacyEngine +from torch import optim + +from .optimizers import KF_DPOptimizer, get_optimizer_class + + +class KF_PrivacyEngine(PrivacyEngine): + def __init__(self, *, accountant: str = "prv", secure_mode: bool = False): + super().__init__(accountant=accountant, secure_mode=secure_mode) + + def _prepare_optimizer( + self, + *, + optimizer: optim.Optimizer, + noise_multiplier: float, + max_grad_norm: Union[float, List[float]], + expected_batch_size: int, + loss_reduction: str = "mean", + distributed: bool = False, + clipping: str = "flat", + noise_generator=None, + grad_sample_mode="hooks", + kalman: bool = False, + **kwargs, + ) -> DPOptimizer: + if kalman and isinstance(optimizer, KF_DPOptimizer): + optimizer = optimizer.original_optimizer + elif not kalman and isinstance(optimizer, DPOptimizer): + optimizer = optimizer.original_optimizer + + generator = None + if self.secure_mode: + generator = self.secure_rng + elif noise_generator is not None: + generator = noise_generator + + optim_class = get_optimizer_class( + clipping=clipping, + distributed=distributed, + grad_sample_mode=grad_sample_mode, + kalman=kalman, + ) + + return optim_class( + optimizer=optimizer, + noise_multiplier=noise_multiplier, + max_grad_norm=max_grad_norm, + expected_batch_size=expected_batch_size, + loss_reduction=loss_reduction, + generator=generator, + secure_mode=self.secure_mode, + **kwargs, + ) diff --git a/research/disk_optimizer/ReadMe.md b/research/disk_optimizer/ReadMe.md new file mode 100644 index 00000000..7c25bda1 --- /dev/null +++ b/research/disk_optimizer/ReadMe.md @@ -0,0 +1,69 @@ +# DiSK: Differentially Private Optimizer with Simplified Kalman Filter for Noise Reduction + +## Introduction +This part of the code introduces a new component to the optimizer named DiSK. The code uses a simplifed Kalman to improve the privatized gradient estimate. Speficially, the privatized minibatch gradient is replaced with: + +$$\mathbb{g_{t+\frac{1}{2}}} = \frac{1}{B}\sum_{\xi \in \mathcal{B}_t} \mathrm{clip}_C\left(\frac{1-\kappa}{\kappa\gamma}\nabla f(x_t + \gamma(x_t-x_{t-1});\xi) + \Big(1- \frac{1-\kappa}{\kappa\gamma}\Big)\nabla f(x_t;\xi)\right) + w_t$$ +$$g_{t}= (1-\kappa)g_{t-1} + \kappa g_{t+\frac{1}{2}}$$ + +A detailed description of the algorithm can be found at [Here](https://arxiv.org/abs/2410.03883). + +## Usage +The code provides a modified privacy engine with three extra arguments: +* kamlan: bool=False +* kappa: float=0.7 +* gamma: float=0.5 + +To use DiSK, follow the steps: + +**Step I:** Import KF_PrivacyEngine from KFprivacy_engine.py and set ```kalman=True``` + +**Step II:** Define a closure (see [here](https://pytorch.org/docs/stable/optim.html#optimizer-step-closure) for example) to compute loss and backward **without** ```zero_grad()``` and perform ```optimizer.step(closure)``` + +Example of using the DiSK optimizers: + +```python +from KFprivacy_engine import KF_PrivacyEngine +# ... +# follow the same steps as original opacus training scripts +privacy_engine = KF_PrivacyEngine() +model, optimizer, train_loader = privacy_engine.make_private( + module=model, + optimizer=optimizer, + data_loader=train_loader, + noise_multiplier=args.sigma, + max_grad_norm=max_grad_norm, + clipping=clipping, + grad_sample_mode=args.grad_sample_mode, + kalman=True, # need this argument + kappa=0.7, # optional + gamma=0.5 # optional + ) + +# ... +# during training: +def closure(): # compute loss and backward, an example adapting the one used in examples/cifar10.py + output = model(images) + loss = criterion(output, target) + loss.backward() + return output, loss +output, loss = optimizer.step(closure) +optimizer.zero_grad() +# compute other matrices +# ... +``` + +## Citation +Consider citing the paper is you use DiSK in your papers, as follows: + +``` +@article{zhang2024disk, + title={{DiSK}: Differentially private optimizer with simplified kalman filter for noise reduction}, + author={Zhang, Xinwei and Bu, Zhiqi and Balle, Borja and Hong, Mingyi and Razaviyayn, Meisam and Mirrokni, Vahab}, + journal={arXiv preprint arXiv:2410.03883}, + year={2024} +} +``` + +Contributer: Xinwei Zhang. Email: [zhan6234@umn.edu](mailto:zhan6234@umn.edu) + diff --git a/research/disk_optimizer/optimizers/KFadaclipoptimizer.py b/research/disk_optimizer/optimizers/KFadaclipoptimizer.py new file mode 100644 index 00000000..b8721903 --- /dev/null +++ b/research/disk_optimizer/optimizers/KFadaclipoptimizer.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +import logging +import math +from typing import Optional + +import torch +from opacus.optimizers.adaclipoptimizer import AdaClipDPOptimizer +from torch.optim import Optimizer +from torch.optim.optimizer import required + +from .KFoptimizer import KF_DPOptimizer + + +logger = logging.getLogger(__name__) + + +class KF_AdaClipDPOptimizer(AdaClipDPOptimizer, KF_DPOptimizer): + def __init__( + self, + optimizer: Optimizer, + *, + noise_multiplier: float, + target_unclipped_quantile: float, + clipbound_learning_rate: float, + max_clipbound: float, + min_clipbound: float, + unclipped_num_std: float, + max_grad_norm: float, + expected_batch_size: Optional[int], + loss_reduction: str = "mean", + generator=None, + secure_mode: bool = False, + kappa: float = 0.7, + gamma: float = 0.5, + ): + if gamma == 0 or abs(gamma - (1 - kappa) / kappa) < 1e-3: + gamma = (1 - kappa) / kappa + self.kf_compute_grad_at_original = False + else: + self.scaling_factor = (1 - kappa) / ( + gamma * kappa + ) # (gamma*kappa+kappa-1)/(1-kappa) + self.kf_compute_grad_at_original = True + c = (1 - kappa) / (gamma * kappa) + norm_factor = math.sqrt(c**2 + (1 - c) ** 2) + noise_multiplier = noise_multiplier / norm_factor + super(AdaClipDPOptimizer).__init__( + optimizer, + noise_multiplier=noise_multiplier, + max_grad_norm=max_grad_norm, + expected_batch_size=expected_batch_size, + loss_reduction=loss_reduction, + generator=generator, + secure_mode=secure_mode, + target_unclipped_quantile=target_unclipped_quantile, + clipbound_learning_rate=clipbound_learning_rate, + max_clipbound=max_clipbound, + min_clipbound=min_clipbound, + unclipped_num_std=unclipped_num_std, + ) + self.kappa = kappa + self.gamma = gamma + + def step(self, closure=required) -> Optional[float]: + if self.kf_compute_grad_at_original: + loss = self._compute_two_closure(closure) + else: + loss = self._compute_one_closure(closure) + + if self.pre_step(): + tmp_states = [] + first_step = False + for p in self.params: + grad = p.grad + state = self.state[p] + if "kf_d_t" not in state: + state = dict() + first_step = True + state["kf_d_t"] = torch.zeros_like(p.data).to(p.data) + state["kf_m_t"] = grad.clone().to(p.data) + state["kf_m_t"].lerp_(grad, weight=self.kappa) + p.grad = state["kf_m_t"].clone().to(p.data) + state["kf_d_t"] = -p.data.clone().to(p.data) + if first_step: + tmp_states.append(state) + self.original_optimizer.step() + for p in self.params: + if first_step: + tmp_state = tmp_states.pop(0) + self.state[p]["kf_d_t"] = tmp_state["kf_d_t"] + self.state[p]["kf_m_t"] = tmp_state["kf_m_t"] + del tmp_state + self.state[p]["kf_d_t"].add_(p.data, alpha=1.0) + return loss diff --git a/research/disk_optimizer/optimizers/KFddp_perlayeroptimizer.py b/research/disk_optimizer/optimizers/KFddp_perlayeroptimizer.py new file mode 100644 index 00000000..662e3612 --- /dev/null +++ b/research/disk_optimizer/optimizers/KFddp_perlayeroptimizer.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +from functools import partial +from typing import Callable, List, Optional + +import torch +from opacus.optimizers.ddp_perlayeroptimizer import _clip_and_accumulate_parameter +from opacus.optimizers.optimizer import _generate_noise +from torch import nn +from torch.optim import Optimizer + +from .KFddpoptimizer import KF_DistributedDPOptimizer +from .KFoptimizer import KF_DPOptimizer +from .KFperlayeroptimizer import KF_DPPerLayerOptimizer + + +class KF_SimpleDistributedPerLayerOptimizer( + KF_DPPerLayerOptimizer, KF_DistributedDPOptimizer +): + def __init__( + self, + optimizer: Optimizer, + *, + noise_multiplier: float, + max_grad_norm: float, + expected_batch_size: Optional[int], + loss_reduction: str = "mean", + generator=None, + secure_mode: bool = False, + kappa: float = 0.7, + gamma: float = 0.5, + ): + self.rank = torch.distributed.get_rank() + self.world_size = torch.distributed.get_world_size() + + super().__init__( + optimizer, + noise_multiplier=noise_multiplier, + max_grad_norm=max_grad_norm, + expected_batch_size=expected_batch_size, + loss_reduction=loss_reduction, + generator=generator, + secure_mode=secure_mode, + kappa=kappa, + gamma=gamma, + ) + + +class KF_DistributedPerLayerOptimizer(KF_DPOptimizer): + """ + :class:`~opacus.optimizers.optimizer.DPOptimizer` that implements + per layer clipping strategy and is compatible with distributed data parallel + """ + + def __init__( + self, + optimizer: Optimizer, + *, + noise_multiplier: float, + max_grad_norm: List[float], + expected_batch_size: Optional[int], + loss_reduction: str = "mean", + generator=None, + secure_mode: bool = False, + kappa: float = 0.7, + gamma: float = 0.5, + ): + self.rank = torch.distributed.get_rank() + self.world_size = torch.distributed.get_world_size() + self.max_grad_norms = max_grad_norm + max_grad_norm = torch.norm(torch.Tensor(self.max_grad_norms), p=2).item() + super().__init__( + optimizer, + noise_multiplier=noise_multiplier, + max_grad_norm=max_grad_norm, + expected_batch_size=expected_batch_size, + loss_reduction=loss_reduction, + generator=generator, + secure_mode=secure_mode, + kappa=kappa, + gamma=gamma, + ) + self._register_hooks() + + def _add_noise_parameter(self, p: nn.Parameter): + """ + The reason why we need self is because of generator for secure_mode + """ + noise = _generate_noise( + std=self.noise_multiplier * self.max_grad_norm, + reference=p.summed_grad, + generator=None, + secure_mode=self.secure_mode, + ) + p.grad = p.summed_grad + noise + + @property + def accumulated_iterations(self) -> int: + return max([p.accumulated_iterations for p in self.params]) + + def _scale_grad_parameter(self, p: nn.Parameter): + if not hasattr(p, "accumulated_iterations"): + p.accumulated_iterations = 0 + p.accumulated_iterations += 1 + if self.loss_reduction == "mean": + p.grad /= ( + self.expected_batch_size * p.accumulated_iterations * self.world_size + ) + + def clip_and_accumulate(self): + raise NotImplementedError( + "Clip and accumulate is added per layer in DPDDP Per Layer." + ) + + def add_noise(self): + raise NotImplementedError("Noise is added per layer in DPDDP Per Layer.") + + def pre_step( + self, closure: Optional[Callable[[], float]] = None + ) -> Optional[float]: + if self._check_skip_next_step(): + self._is_last_step_skipped = True + return False + + if self.step_hook: + self.step_hook(self) + + for p in self.params: + p.accumulated_iterations = 0 + + self._is_last_step_skipped = False + return True + + def _ddp_per_layer_hook( + self, p: nn.Parameter, max_grad_norm: float, _: torch.Tensor + ): + _clip_and_accumulate_parameter(p, max_grad_norm) + # Equivalent ot _check_skip_next_step but without popping because it has to be done for every parameter p + if self._check_skip_next_step(pop_next=False): + return + + if self.rank == 0: + self._add_noise_parameter(p) + else: + p.grad = p.summed_grad + self._scale_grad_parameter(p) + + return p.grad + + def _register_hooks(self): + for p, max_grad_norm in zip(self.params, self.max_grad_norms): + if not p.requires_grad: + continue + + if not hasattr(p, "ddp_hooks"): + p.ddp_hooks = [] + + p.ddp_hooks.append( + p.register_hook(partial(self._ddp_per_layer_hook, p, max_grad_norm)) + ) diff --git a/research/disk_optimizer/optimizers/KFddpoptimizer.py b/research/disk_optimizer/optimizers/KFddpoptimizer.py new file mode 100644 index 00000000..00c3bbcd --- /dev/null +++ b/research/disk_optimizer/optimizers/KFddpoptimizer.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +import logging +from typing import Optional + +import torch +from torch.optim import Optimizer +from torch.optim.optimizer import required + +from .KFoptimizer import KF_DPOptimizer + + +logger = logging.getLogger(__name__) +logger.disabled = True + + +class KF_DistributedDPOptimizer(KF_DPOptimizer): + """ + :class:`~opacus.optimizers.optimizer.DPOptimizer` compatible with + distributed data processing + """ + + def __init__( + self, + optimizer: Optimizer, + *, + noise_multiplier: float, + max_grad_norm: float, + expected_batch_size: Optional[int], + loss_reduction: str = "mean", + generator=None, + secure_mode: bool = False, + kappa=0.7, + gamma=0.5, + ): + super().__init__( + optimizer, + noise_multiplier=noise_multiplier, + max_grad_norm=max_grad_norm, + expected_batch_size=expected_batch_size, + loss_reduction=loss_reduction, + generator=generator, + secure_mode=secure_mode, + kappa=kappa, + gamma=gamma, + ) + self.rank = torch.distributed.get_rank() + self.world_size = torch.distributed.get_world_size() + + def add_noise(self): + # Noise only gets added to the first worker + if self.rank == 0: + super().add_noise() + else: + for p in self.params: + p.grad = p.summed_grad.view_as(p) + + def reduce_gradients(self): + for p in self.params: + if not p.requires_grad: + continue + torch.distributed.all_reduce(p.grad, op=torch.distributed.ReduceOp.SUM) + if self.loss_reduction == "mean": + p.grad /= self.world_size + + def step(self, closure=required) -> Optional[float]: + if self.kf_compute_grad_at_original: + loss = self._compute_two_closure(closure) + else: + loss = self._compute_one_closure(closure) + + if self.pre_step(): + tmp_states = [] + first_step = False + for p in self.params: + grad = p.grad + state = self.state[p] + if "kf_d_t" not in state: + state = dict() + first_step = True + state["kf_d_t"] = torch.zeros_like(p.data).to(p.data) + state["kf_m_t"] = grad.clone().to(p.data) + state["kf_m_t"].lerp_(grad, weight=self.kappa) + p.grad = state["kf_m_t"].clone().to(p.data) + state["kf_d_t"] = -p.data.clone().to(p.data) + if first_step: + tmp_states.append(state) + self.reduce_gradients() + self.original_optimizer.step() + for p in self.params: + if first_step: + tmp_state = tmp_states.pop(0) + self.state[p]["kf_d_t"] = tmp_state["kf_d_t"] + self.state[p]["kf_m_t"] = tmp_state["kf_m_t"] + del tmp_state + self.state[p]["kf_d_t"].add_(p.data, alpha=1.0) + return loss diff --git a/research/disk_optimizer/optimizers/KFddpoptimizer_fast_gradient_clipping.py b/research/disk_optimizer/optimizers/KFddpoptimizer_fast_gradient_clipping.py new file mode 100644 index 00000000..f3d61b06 --- /dev/null +++ b/research/disk_optimizer/optimizers/KFddpoptimizer_fast_gradient_clipping.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +import logging +from typing import Optional + +import torch +from torch.optim import Optimizer +from torch.optim.optimizer import required + +from .KFoptimizer_fast_gradient_clipping import KF_DPOptimizerFastGradientClipping + + +logger = logging.getLogger(__name__) +logger.disabled = True + + +class KF_DistributedDPOptimizerFastGradientClipping(KF_DPOptimizerFastGradientClipping): + """ + :class:`~opacus.optimizers.optimizer.DPOptimizer` compatible with + distributed data processing + """ + + def __init__( + self, + optimizer: Optimizer, + *, + noise_multiplier: float, + max_grad_norm: float, + expected_batch_size: Optional[int], + loss_reduction: str = "mean", + generator=None, + secure_mode: bool = False, + kappa=0.7, + gamma=0.5, + ): + super().__init__( + optimizer, + noise_multiplier=noise_multiplier, + max_grad_norm=max_grad_norm, + expected_batch_size=expected_batch_size, + loss_reduction=loss_reduction, + generator=generator, + secure_mode=secure_mode, + kappa=kappa, + gamma=gamma, + ) + self.rank = torch.distributed.get_rank() + self.world_size = torch.distributed.get_world_size() + + def add_noise(self): + # Noise only gets added to the first worker + if self.rank == 0: + super().add_noise() + else: + for p in self.params: + p.grad = p.summed_grad.view_as(p) + + def reduce_gradients(self): + for p in self.params: + if not p.requires_grad: + continue + torch.distributed.all_reduce(p.grad, op=torch.distributed.ReduceOp.SUM) + if self.loss_reduction == "mean": + p.grad /= self.world_size + + def step(self, closure=required) -> Optional[float]: + if self.kf_compute_grad_at_original: + loss = self._compute_two_closure(closure) + else: + loss = self._compute_one_closure(closure) + + if self.pre_step(): + tmp_states = [] + first_step = False + for p in self.params: + grad = p.grad + state = self.state[p] + if "kf_d_t" not in state: + state = dict() + first_step = True + state["kf_d_t"] = torch.zeros_like(p.data).to(p.data) + state["kf_m_t"] = grad.clone().to(p.data) + state["kf_m_t"].lerp_(grad, weight=self.kappa) + p.grad = state["kf_m_t"].clone().to(p.data) + state["kf_d_t"] = -p.data.clone().to(p.data) + if first_step: + tmp_states.append(state) + self.reduce_gradients() + self.original_optimizer.step() + for p in self.params: + if first_step: + tmp_state = tmp_states.pop(0) + self.state[p]["kf_d_t"] = tmp_state["kf_d_t"] + self.state[p]["kf_m_t"] = tmp_state["kf_m_t"] + del tmp_state + self.state[p]["kf_d_t"].add_(p.data, alpha=1.0) + return loss diff --git a/research/disk_optimizer/optimizers/KFoptimizer.py b/research/disk_optimizer/optimizers/KFoptimizer.py new file mode 100644 index 00000000..30597653 --- /dev/null +++ b/research/disk_optimizer/optimizers/KFoptimizer.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +import logging +import math +from typing import Optional + +import torch +from opacus.optimizers.optimizer import DPOptimizer +from torch.optim import Optimizer +from torch.optim.optimizer import required + + +logger = logging.getLogger(__name__) +logger.disabled = True + + +class KF_DPOptimizer(DPOptimizer): + def __init__( + self, + optimizer: Optimizer, + *, + noise_multiplier: float, + max_grad_norm: float, + expected_batch_size: Optional[int], + loss_reduction: str = "mean", + generator=None, + secure_mode: bool = False, + kappa=0.7, + gamma=0.5, + ): + if gamma == 0 or abs(gamma - (1 - kappa) / kappa) < 1e-3: + gamma = (1 - kappa) / kappa + self.kf_compute_grad_at_original = False + else: + self.scaling_factor = (1 - kappa) / ( + gamma * kappa + ) # (gamma*kappa+kappa-1)/(1-kappa) + self.kf_compute_grad_at_original = True + c = (1 - kappa) / (gamma * kappa) + norm_factor = math.sqrt(c**2 + (1 - c) ** 2) + noise_multiplier = noise_multiplier / norm_factor + super().__init__( + optimizer=optimizer, + noise_multiplier=noise_multiplier, + expected_batch_size=expected_batch_size, + max_grad_norm=max_grad_norm, + loss_reduction=loss_reduction, + generator=generator, + secure_mode=secure_mode, + ) + self.kappa = kappa + self.gamma = gamma + + @DPOptimizer.grad_samples.setter + def grad_samples(self, value): + """ + Set the per sample gradient tensors to zero + """ + if value is not None: + for p, v in zip(self.params, value): + p.grad_sample = v + else: + for p in self.params: + if hasattr(p, "grad_sample"): + p.grad_sample = None + + def _compute_one_closure(self, closure=required): + loss = None + has_kf_d_t = True + for p in self.params: + state = self.state[p] + if "kf_d_t" not in state: + has_kf_d_t = False + continue + # perturb + p.data.add_(state["kf_d_t"], alpha=self.gamma) + with torch.enable_grad(): + loss = closure() + if has_kf_d_t: + for p in self.params: + state = self.state[p] + # perturb back + p.data.add_(state["kf_d_t"], alpha=-self.gamma) + return loss + + def _compute_two_closure(self, closure=required): + loss = None + has_kf_d_t = True + with torch.enable_grad(): + loss = closure() + for p in self.params: + state = self.state[p] + if "kf_d_t" not in state: + has_kf_d_t = False + continue + # perturb + p.data.add_(state["kf_d_t"], alpha=self.gamma) + # store first set of gradient + if has_kf_d_t: + if self.grad_samples is not None and len(self.grad_samples) != 0: + self.past_grad_samples = self.grad_samples + for grad in self.past_grad_samples: + grad.mul_(1.0 - self.scaling_factor) + self.grad_samples = None + with torch.enable_grad(): + loss = closure() + for p in self.params: + state = self.state[p] + # perturb back + p.data.add_(state["kf_d_t"], alpha=-self.gamma) + if self.grad_samples is not None and len(self.grad_samples) != 0: + for grad, past_grad in zip(self.grad_samples, self.past_grad_samples): + grad.mul_(self.scaling_factor).add_(past_grad) + self.past_grad_samples = None + return loss + + def step(self, closure=required) -> Optional[float]: + if self.kf_compute_grad_at_original: + loss = self._compute_two_closure(closure) + else: + loss = self._compute_one_closure(closure) + + if self.pre_step(): + tmp_states = [] + first_step = False + for p in self.params: + grad = p.grad + state = self.state[p] + if "kf_d_t" not in state: + state = dict() + first_step = True + state["kf_d_t"] = torch.zeros_like(p.data).to(p.data) + state["kf_m_t"] = grad.clone().to(p.data) + state["kf_m_t"].lerp_(grad, weight=self.kappa) + p.grad = state["kf_m_t"].clone().to(p.data) + state["kf_d_t"] = -p.data.clone().to(p.data) + if first_step: + tmp_states.append(state) + self.original_optimizer.step() + for p in self.params: + if first_step: + tmp_state = tmp_states.pop(0) + self.state[p]["kf_d_t"] = tmp_state["kf_d_t"] + self.state[p]["kf_m_t"] = tmp_state["kf_m_t"] + del tmp_state + self.state[p]["kf_d_t"].add_(p.data, alpha=1.0) + return loss diff --git a/research/disk_optimizer/optimizers/KFoptimizer_fast_gradient_clipping.py b/research/disk_optimizer/optimizers/KFoptimizer_fast_gradient_clipping.py new file mode 100644 index 00000000..3b105418 --- /dev/null +++ b/research/disk_optimizer/optimizers/KFoptimizer_fast_gradient_clipping.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +import logging +from typing import Optional + +import torch +from torch.optim import Optimizer +from torch.optim.optimizer import required + +from .KFoptimizer import KF_DPOptimizer + + +logger = logging.getLogger(__name__) +logger.disabled = True + + +class KF_DPOptimizerFastGradientClipping(KF_DPOptimizer): + def __init__( + self, + optimizer: Optimizer, + *, + noise_multiplier: float, + max_grad_norm: float, + expected_batch_size: Optional[int], + loss_reduction: str = "mean", + generator=None, + secure_mode: bool = False, + kappa=0.7, + gamma=0.5, + ): + super().__init__( + optimizer, + noise_multiplier=noise_multiplier, + max_grad_norm=max_grad_norm, + expected_batch_size=expected_batch_size, + loss_reduction=loss_reduction, + generator=generator, + secure_mode=secure_mode, + kappa=kappa, + gamma=gamma, + ) + + def _compute_one_closure(self, closure=required): + loss = None + has_kf_d_t = True + for p in self.params: + state = self.state[p] + if "kf_d_t" not in state: + has_kf_d_t = False + continue + # perturb + p.data.add_(state["kf_d_t"], alpha=self.gamma) + with torch.enable_grad(): + loss = closure() + if has_kf_d_t: + for p in self.params: + state = self.state[p] + # perturb back + p.data.add_(state["kf_d_t"], alpha=-self.gamma) + return loss + + def _compute_two_closure(self, closure=required): + loss = None + has_kf_d_t = True + with torch.enable_grad(): + loss = closure() + for p in self.params: + state = self.state[p] + if "kf_d_t" not in state: + has_kf_d_t = False + continue + # perturb + p.data.add_(state["kf_d_t"], alpha=self.gamma) + # store first set of gradient + if has_kf_d_t: + for p in self.params: + p.past_grad = p.grad + p.past_grad.mul_(1.0 - self.scaling_factor) + p.grad = None + with torch.enable_grad(): + loss = closure() + for p in self.params: + state = self.state[p] + # perturb back + p.data.add_(state["kf_d_t"], alpha=-self.gamma) + for p in self.params: + p.grad.mul_(self.scaling_factor).add_(p.past_grad) + p.past_grad = None + return loss diff --git a/research/disk_optimizer/optimizers/KFperlayeroptimizer.py b/research/disk_optimizer/optimizers/KFperlayeroptimizer.py new file mode 100644 index 00000000..fa48a246 --- /dev/null +++ b/research/disk_optimizer/optimizers/KFperlayeroptimizer.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +import logging +from typing import List, Optional + +import torch +from opacus.optimizers.optimizer import _check_processed_flag, _mark_as_processed +from opacus.optimizers.utils import params +from torch.optim import Optimizer + +from .KFoptimizer import KF_DPOptimizer + + +logger = logging.getLogger(__name__) +logger.disabled = True + + +class KF_DPPerLayerOptimizer(KF_DPOptimizer): + def __init__( + self, + optimizer: Optimizer, + *, + noise_multiplier: float, + max_grad_norm: List[float], + expected_batch_size: Optional[int], + loss_reduction: str = "mean", + generator=None, + secure_mode: bool = False, + kappa=0.7, + gamma=0.5, + ): + assert len(max_grad_norm) == len(params(optimizer)) + self.max_grad_norms = max_grad_norm + max_grad_norm = torch.norm(torch.Tensor(self.max_grad_norms), p=2).item() + super().__init__( + optimizer, + noise_multiplier=noise_multiplier, + max_grad_norm=max_grad_norm, + expected_batch_size=expected_batch_size, + loss_reduction=loss_reduction, + generator=generator, + secure_mode=secure_mode, + kappa=kappa, + gamma=gamma, + ) + + def clip_and_accumulate(self): + for p, max_grad_norm in zip(self.params, self.max_grad_norms): + _check_processed_flag(p.grad_sample) + + grad_sample = self._get_flat_grad_sample(p) + per_sample_norms = grad_sample.norm( + 2, dim=tuple(range(1, grad_sample.ndim)) + ) + per_sample_clip_factor = (max_grad_norm / (per_sample_norms + 1e-6)).clamp( + max=1.0 + ) + grad = torch.einsum("i,i...", per_sample_clip_factor, grad_sample) + + if p.summed_grad is not None: + p.summed_grad += grad + else: + p.summed_grad = grad + + _mark_as_processed(p.grad_sample) diff --git a/research/disk_optimizer/optimizers/__init__.py b/research/disk_optimizer/optimizers/__init__.py new file mode 100644 index 00000000..cbf4f04f --- /dev/null +++ b/research/disk_optimizer/optimizers/__init__.py @@ -0,0 +1,129 @@ +from opacus.optimizers import ( + AdaClipDPOptimizer, + DistributedDPOptimizer, + DistributedDPOptimizerFastGradientClipping, + DistributedPerLayerOptimizer, + DPOptimizer, + DPOptimizerFastGradientClipping, + DPPerLayerOptimizer, + SimpleDistributedPerLayerOptimizer, +) + +from .KFadaclipoptimizer import KF_AdaClipDPOptimizer +from .KFddp_perlayeroptimizer import ( + KF_DistributedPerLayerOptimizer, + KF_SimpleDistributedPerLayerOptimizer, +) +from .KFddpoptimizer import KF_DistributedDPOptimizer +from .KFddpoptimizer_fast_gradient_clipping import ( + KF_DistributedDPOptimizerFastGradientClipping, +) +from .KFoptimizer import KF_DPOptimizer +from .KFoptimizer_fast_gradient_clipping import KF_DPOptimizerFastGradientClipping +from .KFperlayeroptimizer import KF_DPPerLayerOptimizer + + +__all__ = [ + "KF_AdaClipDPOptimizer", + "KF_DistributedPerLayerOptimizer", + "KF_DistributedDPOptimizer", + "KF_DPOptimizer", + "KF_DPOptimizerFastGradientClipping", + "KF_DistributedDPOptimizerFastGradientlipping", + "KF_DPPerLayerOptimizer", + "KF_SimpleDistributedPerLayerOptimizer", +] + + +def get_optimizer_key(clipping: str, distributed: bool, grad_sample_mode: str = None): + key = clipping + "_" + str(distributed) + if grad_sample_mode == "ghost" or (clipping == "per_layer" and distributed): + key += "_" + grad_sample_mode + return key + + +def get_optimizer_class( + clipping: str, distributed: bool, grad_sample_mode: str = None, kalman: bool = False +): + if kalman: + optimizer_dict = { + "flat_false_ghost": KF_DPOptimizerFastGradientClipping, + "flat_true_ghost": KF_DistributedDPOptimizerFastGradientClipping, + "flat_false": KF_DPOptimizer, + "flat_true": KF_DistributedDPOptimizer, + "per_layer_false": KF_DPPerLayerOptimizer, + "per_layer_true_hook": KF_DistributedPerLayerOptimizer, + "per_layer_true_ew": KF_SimpleDistributedPerLayerOptimizer, + "adaptive_false": KF_AdaClipDPOptimizer, + } + else: + optimizer_dict = { + "flat_false_ghost": DPOptimizerFastGradientClipping, + "flat_true_ghost": DistributedDPOptimizerFastGradientClipping, + "flat_false": DPOptimizer, + "flat_true": DistributedDPOptimizer, + "per_layer_false": DPPerLayerOptimizer, + "per_layer_true_hook": DistributedPerLayerOptimizer, + "per_layer_true_ew": SimpleDistributedPerLayerOptimizer, + "adaptive_false": AdaClipDPOptimizer, + } + optimizer_key = get_optimizer_key(clipping, distributed, grad_sample_mode) + if optimizer_key not in optimizer_dict: + err_str = "Unsupported combination of parameters." + err_str += f"Clipping: {clipping}, distributed: {str(distributed)} and grad_sample_mode: {grad_sample_mode}" + raise ValueError(err_str) + else: + return optimizer_dict[optimizer_key] + + # if grad_sample_mode == "ghost": + # if clipping == "flat" and distributed is False: + # return KF_DPOptimizerFastGradientClipping + # elif clipping == "flat" and distributed is True: + # return KF_DistributedDPOptimizerFastGradientClipping + # else: + # + # raise ValueError(err_str) + # elif clipping == "flat" and distributed is False: + # return KF_DPOptimizer + # elif clipping == "flat" and distributed is True: + # return KF_DistributedDPOptimizer + # elif clipping == "per_layer" and distributed is False: + # return KF_DPPerLayerOptimizer + # elif clipping == "per_layer" and distributed is True: + # if grad_sample_mode == "hooks": + # return KF_DistributedPerLayerOptimizer + # elif grad_sample_mode == "ew": + # return KF_SimpleDistributedPerLayerOptimizer + # else: + # raise ValueError(f"Unexpected grad_sample_mode: {grad_sample_mode}") + # elif clipping == "adaptive" and distributed is False: + # return KF_AdaClipDPOptimizer + # elif grad_sample_mode == "ghost": + # if clipping == "flat" and distributed is False: + # return DPOptimizerFastGradientClipping + # elif clipping == "flat" and distributed is True: + # return DistributedDPOptimizerFastGradientClipping + # else: + # err_str = "Unsupported combination of parameters." + # err_str+= f"Clipping: {clipping} and grad_sample_mode: {grad_sample_mode}" + # raise ValueError( + # err_str + # ) + # elif clipping == "flat" and distributed is False: + # return DPOptimizer + # elif clipping == "flat" and distributed is True: + # return DistributedDPOptimizer + # elif clipping == "per_layer" and distributed is False: + # return DPPerLayerOptimizer + # elif clipping == "per_layer" and distributed is True: + # if grad_sample_mode == "hooks": + # return DistributedPerLayerOptimizer + # elif grad_sample_mode == "ew": + # return SimpleDistributedPerLayerOptimizer + # else: + # raise ValueError(f"Unexpected grad_sample_mode: {grad_sample_mode}") + # elif clipping == "adaptive" and distributed is False: + # return AdaClipDPOptimizer + # raise ValueError( + # f"Unexpected optimizer parameters. Clipping: {clipping}, distributed: {distributed}" + # ) From 3e9744787d6f2ac638dae39061963b72885d4bbb Mon Sep 17 00:00:00 2001 From: Iden Kalemaj Date: Tue, 21 Jan 2025 10:48:03 -0800 Subject: [PATCH 3/8] Edits to isort and contribution instructions (#714) Summary: Pull Request resolved: https://github.com/pytorch/opacus/pull/714 1. Add command for earlier versions of isort 2. For contributions in the `research` folder, add instructions about code fomratting. 3. Specify version of isort in dev requirements 4. Remove -v (verbose) arg since this arg seems to alter the behavior of isort in some cases. Reviewed By: EnayatUllah Differential Revision: D68078254 fbshipit-source-id: 21b75ade5c651cb4b890a4906afca002394e7aee --- .github/workflows/ci_cpu.yml | 2 +- CONTRIBUTING.md | 11 ++++++++--- dev_requirements.txt | 2 +- research/README.md | 2 ++ 4 files changed, 12 insertions(+), 5 deletions(-) diff --git a/.github/workflows/ci_cpu.yml b/.github/workflows/ci_cpu.yml index 428d9741..4196cd23 100644 --- a/.github/workflows/ci_cpu.yml +++ b/.github/workflows/ci_cpu.yml @@ -32,7 +32,7 @@ jobs: - name: Lint with black run: black --check --diff --color . - name: Check import order with isort - run: isort -v -l 88 -o opacus --lines-after-imports 2 -m 3 --trailing-comma --check-only . + run: isort -l 88 -o opacus --lines-after-imports 2 -m 3 --trailing-comma --check-only . ########### UNIT TESTS ############## unittest_py38_torch_release: diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index e5a35b8e..22b44ba9 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -31,11 +31,16 @@ for advanced usage). Opacus also uses [isort](https://github.com/timothycrosley/isort) to sort imports alphabetically and separate into sections. isort is installed easily via -pip using `pip install isort`, and run locally by calling +pip using `pip install isort --upgrade`, and run locally by calling ```bash -isort -v -l 88 -o opacus --lines-after-imports 2 -m 3 --trailing-comma . +isort -l 88 -o opacus --lines-after-imports 2 -m 3 --trailing-comma . ``` from the repository root. Configuration for isort is located in .isort.cfg. +If using `isort` versions `<5.0.0` call +```bash +isort -l 88 -o opacus --lines-after-imports 2 -m 3 --trailing-comma --recursive +``` + We feel strongly that having a consistent code style is extremely important, so CircleCI will fail on your PR if it does not adhere to the black or flake8 formatting style or isort import ordering. @@ -96,7 +101,7 @@ Run following command from `website` folder. It will build the docs and serve th ``` You can also perform spell checks on documentation automatically (besides IDEs) using [```sphinxcontrib-spelling```](https://sphinxcontrib-spelling.readthedocs.io/en/latest/install.html) -Note that you will also need [```PyEnchant```](https://pyenchant.github.io/pyenchant/) to run ```sphinxcontrib-spelling```, and thus the Enchant C library. Use this guide for ```PyEnchant```. +Note that you will also need [```PyEnchant```](https://pyenchant.github.io/pyenchant/) to run ```sphinxcontrib-spelling```, and thus the Enchant C library. Use this guide for ```PyEnchant```. Steps: 1. Install the extension with pip: ```pip install sphinxcontrib-spelling``` diff --git a/dev_requirements.txt b/dev_requirements.txt index fe57472f..0ba47b6e 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -8,7 +8,7 @@ flake8 sphinx sphinx-autodoc-typehints mypy>=0.760 -isort +isort>=5.0.0 hypothesis tensorboard datasets diff --git a/research/README.md b/research/README.md index 952b0d40..806c7f83 100644 --- a/research/README.md +++ b/research/README.md @@ -9,6 +9,8 @@ We warmly welcome and encourage contributions of new methods! To contribute, ple 1. Fork the repo and create your branch from `main`. 2. Place the new method in a separate subfolder within the `research` directory. 3. The new folder should include a `README.md` that explains the method at a high level, demonstrates usage (e.g., introducing new parameters to the `PrivacyEngine`), and cites relevant sources. The subfolder name should aptly represent the method. +4. Format code using `black`, `flake8`, and `isort` following the instructions under `Code Style` [here](https://github.com/pytorch/opacus/blob/main/CONTRIBUTING.md). +5. Add copyright headers to each `.py` file contributed in the format `# Copyright (c) [copy-right holder]`. More detailed PR instructions can be found [here](https://github.com/pytorch/opacus/blob/main/CONTRIBUTING.md). From ea4cb955ab3fd18e88355fca3646965b3200930d Mon Sep 17 00:00:00 2001 From: Huanyu Zhang Date: Wed, 22 Jan 2025 09:31:06 -0800 Subject: [PATCH 4/8] Fix broken ``coveralls`` (#713) Summary: Pull Request resolved: https://github.com/pytorch/opacus/pull/713 Fix the broken communication between CI and coveralls Reviewed By: EnayatUllah Differential Revision: D68002443 fbshipit-source-id: ee81fbdb57b059a1aad15b0c0d190a0eda6354f2 --- .github/workflows/ci_cpu.yml | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci_cpu.yml b/.github/workflows/ci_cpu.yml index 4196cd23..1f480434 100644 --- a/.github/workflows/ci_cpu.yml +++ b/.github/workflows/ci_cpu.yml @@ -52,13 +52,21 @@ jobs: - name: Run unit tests run: | mkdir unittest-py38-release-reports - coverage run -m pytest --doctest-modules -p conftest --junitxml=unittest-py38-release-reports/junit.xml opacus + coverage run -m pytest --doctest-modules -p conftest opacus coverage report -i -m + # Format into xml to be used for coveralls + coverage xml -i - name: Store test results uses: actions/upload-artifact@v4 with: name: unittest-py38-release-reports path: unittest-py38-release-reports + - name: Send coverage to Coveralls (parallel) + uses: coverallsapp/github-action@v2 + with: + format: cobertura + parallel: true + flag-name: run-1 unittest_py39_torch_release: runs-on: ubuntu-latest @@ -77,13 +85,19 @@ jobs: - name: Run unit tests run: | mkdir unittest-py39-release-reports - coverage run -m pytest --doctest-modules -p conftest --junitxml=unittest-py39-release-reports/junit.xml opacus - coverage report -i -m + coverage run -m pytest --doctest-modules -p conftest opacus + coverage xml -i - name: Store test results uses: actions/upload-artifact@v4 with: name: unittest-py39-release-reports path: unittest-py39-release-reports + - name: Send coverage to Coveralls (parallel) + uses: coverallsapp/github-action@v2 + with: + format: cobertura + parallel: true + flag-name: run-2 prv_accountant_values: runs-on: ubuntu-latest @@ -150,11 +164,18 @@ jobs: coverage run examples/mnist.py --lr 0.25 --sigma 0.7 -c 1.5 --batch-size 64 --epochs 1 --data-root runs/mnist/data --n-runs 1 --device cpu python -c "import torch; accuracy = torch.load('run_results_mnist_0.25_0.7_1.5_64_1.pt'); exit(0) if (accuracy[0]>0.78 and accuracy[0]<0.95) else exit(1)" coverage report -i -m + coverage xml -i - name: Store test results uses: actions/upload-artifact@v4 with: name: mnist-cpu-reports path: runs/mnist/test-reports + - name: Send coverage to Coveralls (parallel) + uses: coverallsapp/github-action@v2 + with: + format: cobertura + parallel: true + flag-name: run-3 ######## FINISH COVERALLS ########## finish_coveralls_parallel: @@ -168,3 +189,4 @@ jobs: with: github_token: ${{ secrets.GITHUB_TOKEN }} parallel-finished: true + carryforward: "run-1,run-2,run-3" From 9741fe24ceeae1df1c4865f75e86f8747698d5fe Mon Sep 17 00:00:00 2001 From: Huanyu Zhang Date: Thu, 23 Jan 2025 15:19:25 -0800 Subject: [PATCH 5/8] Adding PyPI downloads information (#721) Summary: Pull Request resolved: https://github.com/pytorch/opacus/pull/721 Leveraging query information from "pepy.tech". Reviewed By: iden-kalemaj Differential Revision: D68580105 fbshipit-source-id: f3b9ed1127918c153407cb2c5fffd1a06b4c62c7 --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index c46147a6..1bc2e0dd 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,7 @@
+[![PyPI Downloads](https://static.pepy.tech/badge/opacus)](https://pepy.tech/projects/opacus) [![GitHub Actions](https://github.com/pytorch/opacus/actions/workflows/ci_cpu.yml/badge.svg)](https://github.com/pytorch/opacus/actions/workflows/ci_cpu.yml) [![Coverage Status](https://coveralls.io/repos/github/pytorch/opacus/badge.svg?branch=main)](https://coveralls.io/github/pytorch/opacus?branch=main) [![PRs Welcome](https://img.shields.io/badge/PRs-welcome-brightgreen.svg)](CONTRIBUTING.md) From c7d6144378392b9eab82ebd68a9051ba8bb1b4d9 Mon Sep 17 00:00:00 2001 From: Aparna Aketi Date: Sat, 25 Jan 2025 17:01:45 -0800 Subject: [PATCH 6/8] Modifying DPLossFastGradientClipping to add support for generative tasks with ghost clipping (#722) Summary: Pull Request resolved: https://github.com/pytorch/opacus/pull/722 Generative tasks for NLP output predictions of shape (B,T,C) i.e., (batch_size, sequence_length, vocab_size). To compute the cross-entropy loss in this case, usually the predictions are reshaped to (BxT, C) and targets to (BxT). This creates an issue with Ghost Clipping per sample loss computation as BxT is seen as the batch_size. In particular, the current implementation of Ghost Clipping results in loss_per_sample, coeff variables to have a shape of BxT and B respectively. This causes a shape mismatch error. This diff fixes that error by collapsing the loss_per_sample variable to shape B i.e., the loss across the sequence_length dim is averaged/summed. Reviewed By: EnayatUllah Differential Revision: D68047256 fbshipit-source-id: ad7614e2cdba59869d762d810a14b96b465ee513 --- opacus/utils/fast_gradient_clipping_utils.py | 26 +++++++++++++++----- 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/opacus/utils/fast_gradient_clipping_utils.py b/opacus/utils/fast_gradient_clipping_utils.py index 290341e1..42ecd94e 100644 --- a/opacus/utils/fast_gradient_clipping_utils.py +++ b/opacus/utils/fast_gradient_clipping_utils.py @@ -68,7 +68,9 @@ def backward(self): reduced_loss.backward(retain_graph=True) self.optimizer.zero_grad() coeff = self.module.get_clipping_coef() - second_loss_per_sample = coeff * self.loss_per_sample + second_loss_per_sample = ( + coeff.to(self.loss_per_sample.device) * self.loss_per_sample + ) second_loss = torch.sum(second_loss_per_sample) self.module.disable_hooks() second_loss.backward() @@ -104,15 +106,27 @@ def __init__( self.loss_reduction = loss_reduction self.criterion.reduction = "none" - def __call__(self, input, target) -> DPTensorFastGradientClipping: + def __call__(self, input, target, shape=None) -> DPTensorFastGradientClipping: """ Redefining the forward function to compute per-sample loss and wrap it in DPTensorFastGradientClipping """ - loss_per_sample = self.criterion( - input, - target, - ) + loss_per_sample = self.criterion(input, target) + + if shape is not None and loss_per_sample.shape[0] == shape[0] * shape[1]: + # Note that the privacy unit for generative NLP tasks is per sequence. + # The shape variable is the shape of the logits before flattening i.e., [batch_size, sequence_lenght, vocab_size]. + # This variable is necessary for ghost clipping to work with generative NLP tasks. + loss_per_sample = loss_per_sample.view(shape[0], shape[1]) # BxT + if self.loss_reduction == "mean": + loss_per_sample = loss_per_sample.mean(dim=1) # B + elif self.loss_reduction == "sum": + loss_per_sample = loss_per_sample.sum(dim=1) # B + else: + raise ValueError( + f"loss_reduction = {self.loss_reduction}. Only 'sum' and 'mean' losses are supported" + ) + return DPTensorFastGradientClipping( self.module, self.optimizer, loss_per_sample, self.loss_reduction ) From 8e028f46b8f2f5c2b106498f617303a398a51e4b Mon Sep 17 00:00:00 2001 From: Huanyu Zhang Date: Mon, 27 Jan 2025 21:02:31 -0800 Subject: [PATCH 7/8] Improve documentation of Github and Website (#723) Summary: Pull Request resolved: https://github.com/pytorch/opacus/pull/723 Improve the documentation of Github and Opacus website, specifically: 1. Added a "Latest updates" section in Github `readme`. 2. Updated outdated documentation, and highlighted new features like Ghost clipping. 3. Fixed the API library from the website which did not include some newly added files (e.g., `fast_gradient_clipping_utils`). Reviewed By: iden-kalemaj Differential Revision: D68637848 fbshipit-source-id: d8a46d88f13e68e858787dc0ff983adcb4cac39c --- README.md | 18 +++++++-------- docs/faq.md | 6 ++--- tutorials/README.md | 2 +- website/pages/tutorials/index.js | 23 ++++++++++--------- ...d_sample_module_fast_gradient_clipping.rst | 5 ++++ website/sphinx/source/index.rst | 1 + ...p_ddp_optimizer_fast_gradient_clipping.rst | 5 ++++ .../dp_optimizer_fast_gradient_clipping.rst | 5 ++++ website/sphinx/source/optim/optimizers.rst | 3 ++- .../utils/fast_gradient_clipping_utils.rst | 5 ++++ website/sphinx/source/utils/utils.rst | 1 + 11 files changed, 49 insertions(+), 25 deletions(-) create mode 100644 website/sphinx/source/grad_sample_module_fast_gradient_clipping.rst create mode 100644 website/sphinx/source/optim/dp_ddp_optimizer_fast_gradient_clipping.rst create mode 100644 website/sphinx/source/optim/dp_optimizer_fast_gradient_clipping.rst create mode 100644 website/sphinx/source/utils/fast_gradient_clipping_utils.rst diff --git a/README.md b/README.md index 1bc2e0dd..69866ef6 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,13 @@ This code release is aimed at two target audiences: 2. Differential Privacy researchers will find this easy to experiment and tinker with, allowing them to focus on what matters. + +## Latest updates + +2024-12-18: We updated this [tutorial](https://github.com/pytorch/opacus/blob/main/tutorials/building_text_classifier.ipynb) to show how [LoRA](https://arxiv.org/abs/2106.09685) and [peft](https://huggingface.co/docs/peft/en/index) library could be used in conjuncture with DP-SGD. + +2024-08-20: We introduced [Fast Gradient Clipping](https://arxiv.org/abs/2009.03106) and Ghost Clipping(https://arxiv.org/abs/2110.05679) to Opacus, significantly reducing the memory requirements of DP-SGD. Please refer to our [blogpost](https://pytorch.org/blog/clipping-in-opacus/) for more information. + ## Installation The latest release of Opacus can be installed via `pip`: @@ -76,13 +83,6 @@ shows an end-to-end run using Opacus. The [examples](https://github.com/pytorch/opacus/tree/main/examples/) folder contains more such examples. -### Migrating to 1.0 - -Opacus 1.0 introduced many improvements to the library, but also some breaking -changes. If you've been using Opacus 0.x and want to update to the latest -release, please use this -[Migration Guide](https://github.com/pytorch/opacus/blob/main/Migration_Guide.md) - ## Learn more ### Interactive tutorials @@ -90,9 +90,9 @@ release, please use this We've built a series of IPython-based tutorials as a gentle introduction to training models with privacy and using various Opacus features. +- [Building text classifier with Differential Privacy on BERT](https://github.com/pytorch/opacus/blob/main/tutorials/building_text_classifier.ipynb) - [Building an Image Classifier with Differential Privacy](https://github.com/pytorch/opacus/blob/main/tutorials/building_image_classifier.ipynb) - [Training a differentially private LSTM model for name classification](https://github.com/pytorch/opacus/blob/main/tutorials/building_lstm_name_classifier.ipynb) -- [Building text classifier with Differential Privacy on BERT](https://github.com/pytorch/opacus/blob/main/tutorials/building_text_classifier.ipynb) - [Opacus Guide: Introduction to advanced features](https://github.com/pytorch/opacus/blob/main/tutorials/intro_to_advanced_features.ipynb) - [Opacus Guide: Grad samplers](https://github.com/pytorch/opacus/blob/main/tutorials/guide_to_grad_sampler.ipynb) - [Opacus Guide: Module Validator and Fixer](https://github.com/pytorch/opacus/blob/main/tutorials/guide_to_module_validator.ipynb) @@ -119,12 +119,12 @@ Consider citing the report if you use Opacus in your papers, as follows: If you want to learn more about DP-SGD and related topics, check out our series of blogposts and talks: +- [Enabling Fast Gradient Clipping and Ghost Clipping in Opacus](https://pytorch.org/blog/clipping-in-opacus/) - [Differential Privacy Series Part 1 | DP-SGD Algorithm Explained](https://medium.com/pytorch/differential-privacy-series-part-1-dp-sgd-algorithm-explained-12512c3959a3) - [Differential Privacy Series Part 2 | Efficient Per-Sample Gradient Computation in Opacus](https://medium.com/pytorch/differential-privacy-series-part-2-efficient-per-sample-gradient-computation-in-opacus-5bf4031d9e22) - [PriCon 2020 Tutorial: Differentially Private Model Training with Opacus](https://www.youtube.com/watch?v=MWPwofiQMdE&list=PLUNOsx6Az_ZGKQd_p4StdZRFQkCBwnaY6&index=52) - [Differential Privacy on PyTorch | PyTorch Developer Day 2020](https://www.youtube.com/watch?v=l6fbl2CBnq0) - [Opacus v1.0 Highlights | PyTorch Developer Day 2021](https://www.youtube.com/watch?v=U1mszp8lzUI) -- [Enabling Fast Gradient Clipping and Ghost Clipping in Opacus](https://pytorch.org/blog/clipping-in-opacus/) ## FAQ diff --git a/docs/faq.md b/docs/faq.md index ea12bebe..5c1cbdde 100644 --- a/docs/faq.md +++ b/docs/faq.md @@ -13,8 +13,8 @@ Yes! Opacus is open-source for public use, and it is licensed under the [Apache ## How can I report a bug or ask a question? -You can report bugs by submitting GitHub issues. To submit a GitHub issue, please [click here](https://github.com/pytorch/opacus/issues). -You can ask questions in our dedicated PyTorch [Discussion Forum](https://discuss.pytorch.org/c/opacus/29). We actively monitor questions in the PyTorch forums with the category `Opacus`. +You can report bugs or ask questions by submitting GitHub issues. To submit a GitHub issue, please [click here](https://github.com/pytorch/opacus/issues). + ## I'd like to contribute to Opacus. How can I do that? @@ -76,7 +76,7 @@ If these interventions don’t help (or the model starts to converge but its pri ## How to deal with out-of-memory errors? -Dealing with per-sample gradients will inevitably put more pressure on your memory: after all, if you want to train with batch size 64, you are looking to keep 64 copies of your parameter gradients. The first sanity check to do is to make sure that you don’t go out of memory with "standard" training (without DP). That should guarantee that you can train with batch size of 1 at least. Then, you can check your memory usage with e.g. `nvidia-smi` as usual, gradually increasing the batch size until you find your sweet spot. Note that this may mean that you still train with small batch size, which comes with its own training behavior (i.e. higher variance between batches). Training with larger batch sizes can be beneficial, and we built `virtual_step` to make this possible while still memory efficient (see *what is virtual batch size* in these FAQs). +Dealing with per-sample gradients will inevitably put more pressure on your memory: after all, if you want to train with batch size 64, you are looking to keep 64 copies of your parameter gradients. The first sanity check to do is to make sure that you don’t go out of memory with "standard" training (without DP). That should guarantee that you can train with batch size of 1 at least. Then, you can check your memory usage with e.g. `nvidia-smi` as usual, gradually increasing the batch size until you find your sweet spot. Note that this may mean that you still train with small batch size, which comes with its own training behavior (i.e. higher variance between batches). Training with larger batch sizes can be beneficial. To this end, we built [Fast Gradient Clipping](https://pytorch.org/blog/clipping-in-opacus/) and `virtual_step` (see *what is virtual batch size* in these FAQs) to make DP-SGD memory efficient. ## What does epsilon=1.1 really mean? How about delta? diff --git a/tutorials/README.md b/tutorials/README.md index 84f9b4f5..476f62fd 100644 --- a/tutorials/README.md +++ b/tutorials/README.md @@ -1,5 +1,5 @@ # Tutorials -This folder contains multiple tutorials to get you started on training differentially private models! +This folder contains multiple tutorials to get you started on training differentially private models! We recommend "building_text_classifier.ipynb" to experiment with latest Opacus features such as Fast Gradient Clipping, LoRA, and fine-tuning Hugging Face Transformers. Note that you may not have all the required packages. You can install opacus's dev version, which will bring in all the required packages in these tutorials: diff --git a/website/pages/tutorials/index.js b/website/pages/tutorials/index.js index a358b256..fdf9ea52 100644 --- a/website/pages/tutorials/index.js +++ b/website/pages/tutorials/index.js @@ -20,7 +20,9 @@ const React = require('react'); const CWD = process.cwd(); -const CompLibrary = require(`${CWD}/node_modules/docusaurus/lib/core/CompLibrary.js`); +const CompLibrary = require( + `${CWD}/node_modules/docusaurus/lib/core/CompLibrary.js`, +); const Container = CompLibrary.Container; const MarkdownBlock = CompLibrary.MarkdownBlock; @@ -69,7 +71,8 @@ class TutorialHome extends React.Component { - Efficient Per-Sample Gradient Computation for More Layers in Opacus + Efficient Per-Sample Gradient Computation for More Layers in + Opacus
  • @@ -81,13 +84,18 @@ class TutorialHome extends React.Component {
  • Videos*

    -

    * Note that Opacus API has changed over time and some of the code samples and demos in the videos may not work. The concepts presented in the videos though are concrete and still valid.

    +

    + * Note that Opacus API has changed over time and some of the code + samples and demos in the videos may not work. The concepts + presented in the videos though are concrete and still valid. +

    1. - PyTorch Developer Day 2021: Fast and Flexible Differential Privacy Framework for PyTorch + PyTorch Developer Day 2021: Fast and Flexible Differential + Privacy Framework for PyTorch
    2. @@ -114,13 +122,6 @@ class TutorialHome extends React.Component { Differentially Private Deep Learning In 20 Lines Of Code
    3. -
    4. - - PySyft + Opacus: Federated Learning With Differential Privacy - -
    diff --git a/website/sphinx/source/grad_sample_module_fast_gradient_clipping.rst b/website/sphinx/source/grad_sample_module_fast_gradient_clipping.rst new file mode 100644 index 00000000..f9d82109 --- /dev/null +++ b/website/sphinx/source/grad_sample_module_fast_gradient_clipping.rst @@ -0,0 +1,5 @@ +GradSampleModuleFastGradientClipping +================ + +.. automodule:: opacus.grad_sample.grad_sample_module_fast_gradient_clipping + :members: diff --git a/website/sphinx/source/index.rst b/website/sphinx/source/index.rst index 06ad71a7..821094cb 100755 --- a/website/sphinx/source/index.rst +++ b/website/sphinx/source/index.rst @@ -13,6 +13,7 @@ Opacus API Reference privacy_engine grad_sample_module + grad_sample_module_fast_gradient_clipping optim/optimizers data_loader accounting/accounting diff --git a/website/sphinx/source/optim/dp_ddp_optimizer_fast_gradient_clipping.rst b/website/sphinx/source/optim/dp_ddp_optimizer_fast_gradient_clipping.rst new file mode 100644 index 00000000..baaf9494 --- /dev/null +++ b/website/sphinx/source/optim/dp_ddp_optimizer_fast_gradient_clipping.rst @@ -0,0 +1,5 @@ +DistributedDPOptimizerFastGradientClipping +============== + +.. automodule:: opacus.optimizers.ddpoptimizer_fast_gradient_clipping + :members: diff --git a/website/sphinx/source/optim/dp_optimizer_fast_gradient_clipping.rst b/website/sphinx/source/optim/dp_optimizer_fast_gradient_clipping.rst new file mode 100644 index 00000000..8d3c9e7f --- /dev/null +++ b/website/sphinx/source/optim/dp_optimizer_fast_gradient_clipping.rst @@ -0,0 +1,5 @@ +DPOptimizerFastGradientClipping +============== + +.. automodule:: opacus.optimizers.optimizer_fast_gradient_clipping + :members: diff --git a/website/sphinx/source/optim/optimizers.rst b/website/sphinx/source/optim/optimizers.rst index 5a3c1fd3..7bda1032 100644 --- a/website/sphinx/source/optim/optimizers.rst +++ b/website/sphinx/source/optim/optimizers.rst @@ -3,7 +3,8 @@ Optimizers .. toctree:: dp_optimizer + dp_optimizer_fast_gradient_clipping dp_per_layer_optimizer dp_ddp_optimizer + dp_ddp_optimizer_fast_gradient_clipping dp_ddp_per_layer_optimizer - diff --git a/website/sphinx/source/utils/fast_gradient_clipping_utils.rst b/website/sphinx/source/utils/fast_gradient_clipping_utils.rst new file mode 100644 index 00000000..89670148 --- /dev/null +++ b/website/sphinx/source/utils/fast_gradient_clipping_utils.rst @@ -0,0 +1,5 @@ +Fast Gradient Clipping Utils +============= + +.. automodule:: opacus.utils.fast_gradient_clipping_utils + :members: diff --git a/website/sphinx/source/utils/utils.rst b/website/sphinx/source/utils/utils.rst index 9257a70f..1eb32a6e 100644 --- a/website/sphinx/source/utils/utils.rst +++ b/website/sphinx/source/utils/utils.rst @@ -6,3 +6,4 @@ Utils tensor_utils packed_sequences uniform_sampler + fast_gradient_clipping_utils From 57f03494ffb1ce9d3ed6d50ed7e7ad4b39496cc6 Mon Sep 17 00:00:00 2001 From: Iden Kalemaj Date: Fri, 31 Jan 2025 10:51:17 -0800 Subject: [PATCH 8/8] Adaptive Clipping (with Ghost Clipping) (#711) Summary: Pull Request resolved: https://github.com/pytorch/opacus/pull/711 Added adaptive clipping as a capability for Opacus. Supported only with ghost-clipping. Distributed data parallel training is supported. Reviewed By: HuanyuZhang Differential Revision: D67522957 fbshipit-source-id: 1ea05d22cf9ee1e16f891fbe8dc9d46a48dfb92e --- opacus/tests/multigpu_adaptive_clipping.py | 155 ++++++++++ opacus/utils/adaptive_clipping/README.md | 53 ++++ opacus/utils/adaptive_clipping/__init__.py | 12 + .../adaptive_clipping_utils.py | 272 ++++++++++++++++++ 4 files changed, 492 insertions(+) create mode 100644 opacus/tests/multigpu_adaptive_clipping.py create mode 100644 opacus/utils/adaptive_clipping/README.md create mode 100644 opacus/utils/adaptive_clipping/__init__.py create mode 100644 opacus/utils/adaptive_clipping/adaptive_clipping_utils.py diff --git a/opacus/tests/multigpu_adaptive_clipping.py b/opacus/tests/multigpu_adaptive_clipping.py new file mode 100644 index 00000000..ffd7c6f1 --- /dev/null +++ b/opacus/tests/multigpu_adaptive_clipping.py @@ -0,0 +1,155 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import unittest + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn as nn +import torch.optim as optim +from opacus.optimizers.ddpoptimizer_fast_gradient_clipping import ( + DistributedDPOptimizerFastGradientClipping, +) +from opacus.utils.adaptive_clipping.adaptive_clipping_utils import ( + PrivacyEngineAdaptiveClipping, +) +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import DataLoader, TensorDataset +from torch.utils.data.distributed import DistributedSampler + + +def setup(rank, world_size): + if sys.platform == "win32": + raise ValueError("Windows platform is not supported for this test") + else: + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + + # initialize the process group + + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + torch.distributed.init_process_group( + init_method="env://", + backend="nccl", + ) + + +def cleanup(): + dist.destroy_process_group() + + +class ToyModel(nn.Module): + def __init__(self): + super(ToyModel, self).__init__() + self.net1 = nn.Linear(10, 10) + self.relu = nn.ReLU() + self.net2 = nn.Linear(10, 5) + + def forward(self, x): + return self.net2(self.relu(self.net1(x))) + + +def demo_basic(rank, weight, world_size, dp): + torch.manual_seed(world_size) + batch_size = 32 + setup(rank, world_size) + + # create model and move it to GPU with id rank + model = ToyModel().to(rank) + model.net1.weight.data.zero_() + optimizer = optim.SGD(model.parameters(), lr=1) + + # create dataset + labels = torch.randn(2 * batch_size, 5).to(rank) + data = torch.randn(2 * batch_size, 10) + dataset = TensorDataset(data, labels) + + criterion = nn.CrossEntropyLoss(reduction="mean") + + max_grad_norm = 1e8 + + ddp_model = DDP(model, device_ids=[rank]) + + privacy_engine = PrivacyEngineAdaptiveClipping() + + sampler = DistributedSampler( + dataset, num_replicas=world_size, rank=rank, shuffle=False + ) + data_loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler) + + if dp: + ddp_model, optimizer, criterion, data_loader = privacy_engine.make_private( + module=ddp_model, + optimizer=optimizer, + criterion=criterion, + data_loader=data_loader, + noise_multiplier=0, + max_grad_norm=max_grad_norm, + poisson_sampling=False, + grad_sample_mode="ghost", + target_unclipped_quantile=1.0, + ) + assert isinstance(optimizer, DistributedDPOptimizerFastGradientClipping) + + for x, y in data_loader: + outputs = ddp_model(x.to(rank)) + loss = criterion(outputs, y) + optimizer.zero_grad() + loss.backward() + optimizer.step() + break + + weight.copy_(model.net1.weight.data.cpu()) + cleanup() + + +def run_demo(demo_fn, weight, world_size, dp): + mp.spawn( + demo_fn, + args=(weight, world_size, dp), + nprocs=world_size, + join=True, + ) + + +class GradientComputationTestAdaptiveClipping(unittest.TestCase): + def test_gradient_correct_adaptive(self) -> None: + + # Tests that gradient is the same with DP or without DP in the distributed setting + n_gpus = torch.cuda.device_count() + self.assertTrue( + n_gpus >= 2, f"Need at least 2 gpus but was provided only {n_gpus}." + ) + + weight_dp, weight_nodp = torch.ones(10, 10), torch.ones(10, 10) + + run_demo( + demo_basic, + weight_nodp, + 2, + dp=False, + ) + run_demo( + demo_basic, + weight_dp, + 2, + dp=True, + ) + + self.assertTrue(torch.allclose(weight_dp, weight_nodp, atol=1e-5, rtol=1e-3)) diff --git a/opacus/utils/adaptive_clipping/README.md b/opacus/utils/adaptive_clipping/README.md new file mode 100644 index 00000000..1f7923eb --- /dev/null +++ b/opacus/utils/adaptive_clipping/README.md @@ -0,0 +1,53 @@ +# Adaptive Clipping (with Ghost Clipping) + +Adaptive clipping [1] adapts the clipping norm (and amount of noise) during training to a quantile of per-sample gradient norms. It can reduce hyper-parameter tuning efforts and improve model accuracy by injecting less noise. + +It is supported with: +- Ghost clipping +- Distributed data parallel training + +It is **not** currently supported with: +- Vanilla DP-SGD +- Virtual batch sizes via Batch Memory Manager + +## Overview + +`PrivacyEngineAdaptiveClipping` is the entry-point for adaptive clipping training. It extends `PrivacyEngine` with additional arguments for adaptive clipping: + +* `target_unclipped_quantile`: the quantile of per-sample gradient norms at which to clip (between 0 and 1) +* `min_clipbound`: the minimum allowed clipping norm +* `max_clipbound`: the maximum allowed clipping norm +* `clipbound_learning_rate`: the learning rate for tracking the true quantile +* `max_grad_norm`: the initial clipping norm (used at step 0) + +The main hyper-parameter to tune is `target_unclipped_quantile`, which replaces tuning the clipping norm (`max_grad_norm`) in constant clipping DP-SGD. This parameter can be easier to tune, since the search is over a smaller range of values. + + +## Example usage + +```python +from opacus.utils.adaptive_clipping.adaptive_clipping_utils import PrivacyEngineAdaptiveClipping + +# ... +privacy_engine = PrivacyEngineAdaptiveClipping() +model, optimizer, criterion, train_loader = privacy_engine.make_private( + module=model, + optimizer=optimizer, + data_loader=train_loader, + criterion=criterion, + noise_multiplier=args.sigma, + max_grad_norm=10, # initial clipping norm + grad_sample_mode="ghost", + target_unclipped_quantile=0.5, # key parameter, may need tuning + min_clipbound=1, # default value + max_clipbound=1e8, # default value + clipbound_learning_rate=0.2 # default value, tuning not recommended +) +# ... +``` + +Note that `grad_sample_mode` must be set to `"ghost"` for adaptive clipping to work. + +## References + +[1] Galen Andrew, Om Thakkar, H. Brendan McMahan, Swaroop Ramaswamy, "Differentially Private Learning with Adaptive Clipping", NeurIPS, 2021. diff --git a/opacus/utils/adaptive_clipping/__init__.py b/opacus/utils/adaptive_clipping/__init__.py new file mode 100644 index 00000000..9a5f1913 --- /dev/null +++ b/opacus/utils/adaptive_clipping/__init__.py @@ -0,0 +1,12 @@ +from .adaptive_clipping_utils import ( + DPLossFastGradientAdaptiveClipping, + DPTensorFastGradientAdaptiveClipping, + PrivacyEngineAdaptiveClipping, +) + + +__all__ = [ + "DPTensorFastGradientAdaptiveClipping", + "DPLossFastGradientAdaptiveClipping", + "PrivacyEngineAdaptiveClipping", +] diff --git a/opacus/utils/adaptive_clipping/adaptive_clipping_utils.py b/opacus/utils/adaptive_clipping/adaptive_clipping_utils.py new file mode 100644 index 00000000..0862c675 --- /dev/null +++ b/opacus/utils/adaptive_clipping/adaptive_clipping_utils.py @@ -0,0 +1,272 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from opacus.distributed import DifferentiallyPrivateDistributedDataParallel as DPDDP +from opacus.grad_sample import GradSampleModule +from opacus.grad_sample.grad_sample_module_fast_gradient_clipping import ( + GradSampleModuleFastGradientClipping, +) +from opacus.optimizers import DPOptimizerFastGradientClipping +from opacus.privacy_engine import PrivacyEngine +from opacus.utils.fast_gradient_clipping_utils import ( + DPLossFastGradientClipping, + DPTensorFastGradientClipping, +) +from torch.nn.parallel import DistributedDataParallel as DDP + + +class DPTensorFastGradientAdaptiveClipping(DPTensorFastGradientClipping): + """ + Packages the training loop for Adaptive clipping (with Fast Gradient and Ghost Clipping) into loss.backward(). + Differently from DPTensorFastGradientClipping, the clipping norm is updated with each backward pass. + The clipping norm track a quantile of the per-sample gradient norms. + """ + + def __init__( + self, + module: GradSampleModuleFastGradientClipping, + optimizer: DPOptimizerFastGradientClipping, + loss_per_sample: torch.Tensor, + loss_reduction: str = "mean", + target_unclipped_quantile: float = 0.5, + min_clipbound: float = 1, + max_clipbound: float = 1e8, + clipbound_learning_rate: float = 0.2, + initial_noise_multiplier: float = 1.0, + ): + """ + + Args: + module: the module to train + optimizer: the optimizer used to train the module + loss_per_sample: loss on each sample in the mini-batch of size [batch_size, 1] + target_unclipped_quantile: target quantile for unclipped gradients, between 0 and 1 + min_clipbound: minimum clipping norm allowed + max_clipbound: maximum clipping norm allowed + clipbound_learning_rate: learning rate for the descent algorithm that finds the target unclipped quantile + initial_noise_multiplier: initial noise multiplier provided at step 0 + + """ + + super().__init__(module, optimizer, loss_per_sample, loss_reduction) + + self.target_unclipped_quantile = target_unclipped_quantile + self.min_clipbound = min_clipbound + self.max_clipbound = max_clipbound + self.clipbound_learning_rate = clipbound_learning_rate + self.initial_clipping_norm = self.optimizer.max_grad_norm + self.initial_noise_multiplier = initial_noise_multiplier + + def backward(self): + """ + Repurposes loss.backward() to perform two backward passes, as well as the loss rescaling and hook operations in between. + In addition, the clipping norm is updated between the two backward passes according to a quantile of the per-sample gradient norms. + """ + + if self.loss_reduction == "mean": + reduced_loss = torch.mean(self.loss_per_sample, dim=0) + elif self.loss_reduction == "sum": + reduced_loss = torch.sum(self.loss_per_sample, dim=0) + else: + raise ValueError( + f"loss_reduction = {self.loss_reduction}. Only 'sum' and 'mean' losses are supported" + ) + reduced_loss.backward(retain_graph=True) + self.optimizer.zero_grad() + + # calc per_sample gradient norms + per_sample_norms = self.module.get_norm_sample() + + # calculate new max grad norm and noise multiplier + new_max_grad_norm, new_noise_multiplier = self._update_clip_and_noise( + per_sample_norms + ) + + # update max grad norm and noise multiplier + self.module.max_grad_norm = new_max_grad_norm + self.optimizer.max_grad_norm = new_max_grad_norm + self.optimizer.noise_multiplier = new_noise_multiplier + + # get the loss rescaling coefficients using the updated max_grad_norm + coeff = torch.where( + per_sample_norms <= self.module.max_grad_norm, + torch.ones_like(per_sample_norms), + self.module.max_grad_norm / per_sample_norms, + ) # per-sample coeff, shape = [batch_size] + + second_loss_per_sample = coeff * self.loss_per_sample + second_loss = torch.sum(second_loss_per_sample) + self.module.disable_hooks() + second_loss.backward() + self.module.enable_hooks() + + def _is_distributed(self): + + return isinstance(self.module, (DPDDP, DDP)) + + def _update_clip_and_noise(self, per_sample_norms): + + assert ( + self.module.max_grad_norm == self.optimizer.max_grad_norm + ), "Max grad norm does not match between optimizer and model." + + # calculate new max_grad_norm + current_max_norm = self.module.max_grad_norm + local_batch_size = len(per_sample_norms) + local_unclipped_num = (per_sample_norms <= current_max_norm).sum() + + if self._is_distributed(): + # pair the two variables in one tensor to perform only one all_reduce call + global_unclipped_and_batch = torch.tensor( + [local_unclipped_num, local_batch_size] + ) + torch.distributed.all_reduce( + global_unclipped_and_batch, op=torch.distributed.ReduceOp.SUM + ) + unclipped_num = global_unclipped_and_batch[0].item() + batch_size = global_unclipped_and_batch[1].item() + else: + unclipped_num = local_unclipped_num + batch_size = local_batch_size + + unclipped_num_std = ( + batch_size / 20.0 + ) # use heuristic from [ATMR'22, https://arxiv.org/pdf/1905.03871] + unclipped_num = ( + unclipped_num + + torch.normal(mean=0.0, std=unclipped_num_std, size=(1,)).item() + ) + unclipped_frac = unclipped_num / batch_size + + new_max_grad_norm = current_max_norm * torch.exp( + -self.clipbound_learning_rate + * (unclipped_frac - self.target_unclipped_quantile) + ) + new_max_grad_norm = new_max_grad_norm.clamp( + min=self.min_clipbound, max=self.max_clipbound + ).item() + + # the following ensures that the updated noise multiplier is a real number + assert ( + batch_size > 10 * self.initial_noise_multiplier + ), "Batch size is too small. For adaptive clipping, please use a batch size larger than 10 * noise_multiplier." + if self.initial_noise_multiplier > 0: + # From Theorem 1 in [ATMR'22, https://arxiv.org/pdf/1905.03871] + # The factor of 2.0 comes from the recentering of a binary bit + # Privacy definition is for add/remove DP + # Note: For uniform batches, the batch size is public + # For Poisson batches, the batch size is private, and the computation below leaks some privacy, as it assumes a known batch size. + # We currently ignore the privacy leak due to the the private batch size for Poisson subsampling of batches + new_noise_multiplier = ( + self.initial_noise_multiplier ** (-2) + - (2.0 * unclipped_num_std) ** (-2) + ) ** (-1 / 2.0) + else: + new_noise_multiplier = self.initial_noise_multiplier + + return new_max_grad_norm, new_noise_multiplier + + +class DPLossFastGradientAdaptiveClipping(DPLossFastGradientClipping): + """ + Wrapper on the loss function to be used with Adaptive Clipping (together with Fast Gradient and Ghost Clipping). + It computes the per-sample loss, and wraps it in DPTensorFastGradientAdaptiveClipping. + """ + + def __init__( + self, + module: GradSampleModuleFastGradientClipping, + optimizer: DPOptimizerFastGradientClipping, + criterion, + loss_reduction: str = "mean", + target_unclipped_quantile: float = 0.5, + min_clipbound: float = 1, + max_clipbound: float = 1e8, + clipbound_learning_rate: float = 0.2, + initial_noise_multiplier: float = 1.0, + ): + + super().__init__(module, optimizer, criterion, loss_reduction) + + self.target_unclipped_quantile = target_unclipped_quantile + self.min_clipbound = min_clipbound + self.max_clipbound = max_clipbound + self.clipbound_learning_rate = clipbound_learning_rate + self.initial_noise_multiplier = initial_noise_multiplier + + def __call__(self, input, target) -> DPTensorFastGradientAdaptiveClipping: + """ + Redefining the forward function to compute per-sample loss and wrap it in DPTensorFastGradientAdaptiveClipping + """ + + loss_per_sample = self.criterion( + input, + target, + ) + return DPTensorFastGradientAdaptiveClipping( + self.module, + self.optimizer, + loss_per_sample, + self.loss_reduction, + self.target_unclipped_quantile, + self.min_clipbound, + self.max_clipbound, + self.clipbound_learning_rate, + self.initial_noise_multiplier, + ) + + +class PrivacyEngineAdaptiveClipping(PrivacyEngine): + + def __init__(self, *, accountant: str = "prv", secure_mode: bool = False): + super().__init__(accountant=accountant, secure_mode=secure_mode) + + def _prepare_criterion( + self, + *, + module: GradSampleModule, + optimizer: DPOptimizerFastGradientClipping, + criterion=torch.nn.CrossEntropyLoss(), + loss_reduction: str = "mean", + target_unclipped_quantile: float = 0.5, + min_clipbound: float = 1, + max_clipbound: float = 1e8, + clipbound_learning_rate: float = 0.2, + **kwargs, + ) -> DPLossFastGradientAdaptiveClipping: + """ + Args: + module: the module to train + optimizer: the optimizer used to train the module + criterion: the loss function used to train the module + loss_reduction: "mean" or "sum", indicates if the loss reduction (for aggregating the gradients) + target_unclipped_quantile: target quantile for unclipped gradients, between 0 and 1 + min_clipbound: minimum clipping norm allowed + max_clipbound: maximum clipping norm allowed + clipbound_learning_rate: learning rate for the descent algorithm that finds the target unclipped quantile + """ + + return DPLossFastGradientAdaptiveClipping( + module, + optimizer, + criterion, + loss_reduction=loss_reduction, + target_unclipped_quantile=target_unclipped_quantile, + min_clipbound=min_clipbound, + max_clipbound=max_clipbound, + clipbound_learning_rate=clipbound_learning_rate, + initial_noise_multiplier=optimizer.noise_multiplier, + )