From 9eb7875da2637a659e95723039ea9c2d0288b56b Mon Sep 17 00:00:00 2001 From: Enayat Ullah Date: Tue, 11 Feb 2025 13:46:50 -0800 Subject: [PATCH] Fix freezing modules in Ghost Clipping (#729) Summary: Freezing modules with ghost clipping throws an error as corresponding per-sample norms are (not) calculated. Fix: keep in memory the list of all parameters and checking if corresponding requires_grad is True when calculating norms. Further, unfreezing modules (with and without ghost clipping) wasn't supported because the hooks aren't present for the corresponding modules. Fix: rewrite `requires_grad_' to add the hook. Facebook We initially used a `trainable_parameters(module)` to traverse the list of trainable modules upon norm computation. It was slow because `trainable_parameters(module)` is a generator and it traverses the neural network graph overtime. We replaced it with a list of trainable parameters fixed during model creation time. This is what lead to issues with freezing modules as this list is not updated. Fix: Use **all parameters** **list** -- not a generator, so no traversal happens. Further, we check `requires_grad` when calculating per-sample norm to ascertain whether to compute it or not. This is how this check is done in (non-private) [optimizer](https://github.com/pytorch/pytorch/blob/5725462cd8679dd1dea8a469b1bf2e71f226b664/torch/optim/optimizer.py#L963) to determine which parameters are frozen or not. Differential Revision: D68656459 --- opacus/grad_sample/grad_sample_module.py | 14 ++++++++++++++ .../grad_sample_module_fast_gradient_clipping.py | 9 +++++++-- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/opacus/grad_sample/grad_sample_module.py b/opacus/grad_sample/grad_sample_module.py index 19b5ffa6..12111531 100644 --- a/opacus/grad_sample/grad_sample_module.py +++ b/opacus/grad_sample/grad_sample_module.py @@ -145,6 +145,20 @@ def __init__( force_functorch=force_functorch, ) + def requires_grad_(self, requires_grad: bool = True) -> nn.Module: + "Rewrite requires_grad_ to add/remove hooks based on requires_grad value" + if requires_grad: + # Attack hook to the module + self.add_hooks( + loss_reduction=self.loss_reduction, + batch_first=self.batch_first, + force_functorch=self.force_functorch, + ) + else: + # Remove hooks + self.remove_hooks() + return super().requires_grad_(requires_grad) + def forward(self, *args, **kwargs): return self._module(*args, **kwargs) diff --git a/opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py b/opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py index 5a9adbb9..3214843f 100644 --- a/opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py +++ b/opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py @@ -117,7 +117,7 @@ def __init__( strict=strict, force_functorch=force_functorch, ) - self.trainable_parameters = [p for _, p in trainable_parameters(self._module)] + self.all_parameters = [p for p in self.parameters()] self.max_grad_norm = max_grad_norm self.use_ghost_clipping = use_ghost_clipping self._per_sample_gradient_norms = None @@ -130,7 +130,12 @@ def get_clipping_coef(self) -> torch.Tensor: def get_norm_sample(self) -> torch.Tensor: """Get per-example gradient norms.""" norm_sample = torch.stack( - [param._norm_sample for param in self.trainable_parameters], dim=0 + [ + param._norm_sample + for param in self.all_parameters + if param.requires_grad + ], + dim=0, ).norm(2, dim=0) self.per_sample_gradient_norms = norm_sample return norm_sample