@@ -829,9 +829,17 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
                # gradient update with accumulated gradients
                if ((self.batch_idx + 1) % self.accumulate_grad_batches == 0
                        or (self.batch_idx + 1) == self.num_training_batches):
+                   # hook
+                   grad_norm_dic = self.train_loop.on_before_backward(batch_idx, optimizer)
+
+                   # optimizer step
+                   self.train_loop.optimizer_step(optimizer, opt_idx, batch_idx, split_batch)
+
+                   # hook
+                   self.train_loop.on_before_zero_grad(optimizer)

-                   # backward
-                   grad_norm_dic = self.run_batch_backward_pass(split_batch, batch_idx, opt_idx, optimizer)
+                   # clear gradients
+                   self.train_loop.optimizer_zero_grad(batch_idx, optimizer, opt_idx)

                    # calculate running loss for display
                    self.running_loss.append(self.batch_loss_value.mean() * self.accumulate_grad_batches)
@@ -854,15 +862,6 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
        )
        return result

-    def run_batch_backward_pass(self, split_batch, batch_idx, opt_idx, optimizer):
-        # hook
-        grad_norm_dic = self.train_loop.on_before_backward(batch_idx, optimizer)
-
-        # optimizer step (TODO: decouple zero grad)
-        self.train_loop.optimizer_step(optimizer, opt_idx, batch_idx, split_batch)
-
-        return grad_norm_dic
-
    def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens):
        """
        wrap the forward step in a closure so second order methods work
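Taken together, the two hunks inline the old run_batch_backward_pass helper into the accumulation branch of run_training_batch: when an accumulation window closes, the on_before_backward hook runs, then the optimizer step, then on_before_zero_grad, then optimizer_zero_grad. The snippet below is a minimal, self-contained sketch of that accumulated-gradient update pattern using plain PyTorch objects rather than Lightning's trainer internals; the model, optimizer, and loop variables are illustrative assumptions, not part of this diff.

import torch
from torch import nn

# Sketch of the accumulated-gradient update pattern (illustrative names only).
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accumulate_grad_batches = 4
num_training_batches = 10

for batch_idx in range(num_training_batches):
    x = torch.randn(8, 10)
    y = torch.randn(8, 1)

    # scale the loss so the accumulated gradients average over the window
    loss = nn.functional.mse_loss(model(x), y) / accumulate_grad_batches
    loss.backward()

    # step on every Nth batch, or on the final (possibly shorter) window
    if ((batch_idx + 1) % accumulate_grad_batches == 0
            or (batch_idx + 1) == num_training_batches):
        optimizer.step()       # apply the accumulated gradients
        optimizer.zero_grad()  # clear them before the next window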