
Commit df7e064

ref: inner train loop (intermediate step) 8/n (Lightning-AI#3367)

* ref: inner train loop (intermediate step) 7/n
* ref: inner train loop (intermediate step) 8/n
1 parent dcbfd09 commit df7e064

File tree

pytorch_lightning/trainer/training_loop.py
pytorch_lightning/trainer/training_loop_temp.py

2 files changed: +15 −16 lines

2 files changed

+15
-16
lines changed

pytorch_lightning/trainer/training_loop.py (+10 −11)
@@ -829,9 +829,17 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
         # gradient update with accumulated gradients
         if ((self.batch_idx + 1) % self.accumulate_grad_batches == 0
                 or (self.batch_idx + 1) == self.num_training_batches):
+            # hook
+            grad_norm_dic = self.train_loop.on_before_backward(batch_idx, optimizer)
+
+            # optimizer step
+            self.train_loop.optimizer_step(optimizer, opt_idx, batch_idx, split_batch)
+
+            # hook
+            self.train_loop.on_before_zero_grad(optimizer)
 
-            # backward
-            grad_norm_dic = self.run_batch_backward_pass(split_batch, batch_idx, opt_idx, optimizer)
+            # clear gradients
+            self.train_loop.optimizer_zero_grad(batch_idx, optimizer, opt_idx)
 
             # calculate running loss for display
             self.running_loss.append(self.batch_loss_value.mean() * self.accumulate_grad_batches)

@@ -854,15 +862,6 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
         )
         return result
 
-    def run_batch_backward_pass(self, split_batch, batch_idx, opt_idx, optimizer):
-        # hook
-        grad_norm_dic = self.train_loop.on_before_backward(batch_idx, optimizer)
-
-        # optimizer step (TODO: decouple zero grad)
-        self.train_loop.optimizer_step(optimizer, opt_idx, batch_idx, split_batch)
-
-        return grad_norm_dic
-
     def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens):
         """
         wrap the forward step in a closure so second order methods work
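
Taken together, this file's change inlines what `run_batch_backward_pass` used to do: on an accumulation boundary the batch loop now calls the train-loop helpers directly, in the order `on_before_backward` → `optimizer_step` → `on_before_zero_grad` → `optimizer_zero_grad`, and the old helper is deleted in the second hunk. Below is a minimal, self-contained sketch of that gating and call order; the stub class and simplified hook bodies are illustrative assumptions, not the actual Lightning implementation.

```python
# Minimal sketch of the accumulation-gated update sequence from the hunk above.
# Everything here is an illustrative stand-in (NOT Lightning's real classes);
# only the boundary check and the hook call order mirror the diff.

class _TrainLoopStub:
    """Stands in for the trainer's train_loop object referenced in the diff."""

    def on_before_backward(self, batch_idx, optimizer):
        # Lightning uses this hook to track gradient norms; return a dict like it does.
        return {}

    def optimizer_step(self, optimizer, opt_idx, batch_idx, split_batch):
        optimizer.step()

    def on_before_zero_grad(self, optimizer):
        pass  # would forward to the LightningModule's on_before_zero_grad hook

    def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx):
        optimizer.zero_grad()


def run_accumulated_update(train_loop, optimizer, opt_idx, batch_idx, split_batch,
                           accumulate_grad_batches, num_training_batches):
    """Apply the optimizer only on an accumulation boundary (or on the last batch)."""
    if ((batch_idx + 1) % accumulate_grad_batches == 0
            or (batch_idx + 1) == num_training_batches):
        # hook
        grad_norm_dic = train_loop.on_before_backward(batch_idx, optimizer)

        # optimizer step
        train_loop.optimizer_step(optimizer, opt_idx, batch_idx, split_batch)

        # hook
        train_loop.on_before_zero_grad(optimizer)

        # clear gradients
        train_loop.optimizer_zero_grad(batch_idx, optimizer, opt_idx)

        return grad_norm_dic
    return None
```

Calling the helpers directly is what makes `run_batch_backward_pass` redundant, which is why the second hunk removes it.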

pytorch_lightning/trainer/training_loop_temp.py (+5 −5)
@@ -217,7 +217,6 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
     def optimizer_step(self, optimizer, opt_idx, batch_idx, split_batch):
         # calls .step(), .zero_grad()
         # override function to modify this behavior
-        model = self.trainer.get_model()
 
         with self.trainer.profiler.profile('optimizer_step'):
             lambda_closure = lambda: self.trainer.optimizer_closure(

@@ -231,11 +230,12 @@ def optimizer_step(self, optimizer, opt_idx, batch_idx, split_batch):
         # optimizer step lightningModule hook
         self.trainer.accelerator_backend.optimizer_step(optimizer, batch_idx, opt_idx, lambda_closure)
 
-        # hook
-        model.on_before_zero_grad(optimizer)
+    def on_before_zero_grad(self, optimizer):
+        model = self.trainer.get_model()
+        model.on_before_zero_grad(optimizer)
 
-        # clear gradients
-        self.trainer.accelerator_backend.optimizer_zero_grad(batch_idx, optimizer, opt_idx)
+    def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx):
+        self.trainer.accelerator_backend.optimizer_zero_grad(batch_idx, optimizer, opt_idx)
 
     def on_before_backward(self, batch_idx, optimizer):
         # track gradient norms
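
Here the zero-grad hook and the actual gradient clearing move out of `optimizer_step` into two small delegation methods, which is also why `optimizer_step` no longer needs its own `model = self.trainer.get_model()` line (removed in the first hunk). A minimal sketch of that delegation split follows; the `_ModelStub`, `_AcceleratorStub`, and `_TrainerStub` classes are illustrative assumptions, and only the bodies of the two delegation methods mirror the diff.

```python
# Sketch of the new delegation split: on_before_zero_grad fetches the model and
# forwards the hook, while optimizer_zero_grad forwards to the accelerator backend.
# All stub classes below are assumptions for illustration only.

class _ModelStub:
    def on_before_zero_grad(self, optimizer):
        print("model hook: gradients are about to be cleared")


class _AcceleratorStub:
    def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx):
        optimizer.zero_grad()


class _TrainerStub:
    def __init__(self, model, accelerator_backend):
        self._model = model
        self.accelerator_backend = accelerator_backend

    def get_model(self):
        return self._model


class _TrainLoopSketch:
    """Only the two methods below mirror the diff; the rest is scaffolding."""

    def __init__(self, trainer):
        self.trainer = trainer

    def on_before_zero_grad(self, optimizer):
        model = self.trainer.get_model()
        model.on_before_zero_grad(optimizer)

    def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx):
        self.trainer.accelerator_backend.optimizer_zero_grad(batch_idx, optimizer, opt_idx)
```

With this split, a caller such as `run_training_batch` in the other file can fire the zero-grad hook and clear gradients independently of stepping the optimizer.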
