diff --git a/opacus/grad_sample/gsm_base.py b/opacus/grad_sample/gsm_base.py
index f6947fae..ec137789 100644
--- a/opacus/grad_sample/gsm_base.py
+++ b/opacus/grad_sample/gsm_base.py
@@ -24,7 +24,12 @@
 
 logger = logging.getLogger(__name__)
 
-OPACUS_PARAM_MONKEYPATCH_ATTRS = ["_forward_counter", "_current_grad_sample"]
+OPACUS_PARAM_MONKEYPATCH_ATTRS = [
+    "grad_sample",
+    "_forward_counter",
+    "_current_grad_sample",
+    "_norm_sample",
+]
 
 
 class AbstractGradSampleModule(nn.Module, ABC):
@@ -131,18 +136,15 @@ def to_standard_module(self) -> nn.Module:
         return self._module
 
     def _close(self):
-        self.del_grad_sample()
-        self._clean_up_attributes()
-
-    def __repr__(self):
-        return f"{type(self).__name__}({self._module.__repr__()})"
-
-    def _clean_up_attributes(self):
+        # Clean up attributes
         for attr in OPACUS_PARAM_MONKEYPATCH_ATTRS:
             for p in self.parameters():
                 if hasattr(p, attr):
                     delattr(p, attr)
 
+    def __repr__(self):
+        return f"{type(self).__name__}({self._module.__repr__()})"
+
     def forbid_grad_accumulation(self):
         """
         Sets a flag to detect gradient accumulation (multiple forward/backward passes
diff --git a/opacus/tests/grad_sample_module_test.py b/opacus/tests/grad_sample_module_test.py
index bb2d3911..85c4ead4 100644
--- a/opacus/tests/grad_sample_module_test.py
+++ b/opacus/tests/grad_sample_module_test.py
@@ -130,7 +130,7 @@ def test_to_standard_module(self):
             self.original_model.state_dict(),
             strict=True,
         )
-        new_grad_sample_module = GradSampleModule(
+        new_grad_sample_module = self.CLS(
             copy_of_original_model, batch_first=True, loss_reduction="mean"
         )
 
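
Not part of the patch, but a minimal sketch of the behavior the expanded OPACUS_PARAM_MONKEYPATCH_ATTRS list is meant to guarantee: after a forward/backward pass the wrapped parameters carry the monkeypatched per-sample attributes, and to_standard_module() (which goes through _close()) should strip every one of them. This assumes the public GradSampleModule entry point from opacus.grad_sample; the attribute names and constructor arguments are the ones appearing in the diff above.

import torch
import torch.nn as nn
from opacus.grad_sample import GradSampleModule
from opacus.grad_sample.gsm_base import OPACUS_PARAM_MONKEYPATCH_ATTRS

model = nn.Linear(4, 2)
gsm = GradSampleModule(model, batch_first=True, loss_reduction="mean")

# A forward/backward pass attaches per-sample attributes (e.g. p.grad_sample)
# to the trainable parameters via the registered hooks.
gsm(torch.randn(8, 4)).mean().backward()

# Unwrapping calls _close(), which delattr()s every name listed in
# OPACUS_PARAM_MONKEYPATCH_ATTRS from the parameters.
plain = gsm.to_standard_module()
for p in plain.parameters():
    assert not any(hasattr(p, attr) for attr in OPACUS_PARAM_MONKEYPATCH_ATTRS)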