Added some more potentially robust ways to do learning rate tuning #19867

Open · wants to merge 4 commits into base: master
7 changes: 5 additions & 2 deletions src/lightning/pytorch/callbacks/lr_finder.py
@@ -18,7 +18,7 @@
Finds optimal learning rate
"""

from typing import Optional
from typing import Literal, Optional

from typing_extensions import override

@@ -50,6 +50,7 @@ class LearningRateFinder(Callback):
update_attr: Whether to update the learning rate attribute or not.
attr_name: Name of the attribute which stores the learning rate. The names 'learning_rate' or 'lr' get
automatically detected. Otherwise, set the name here.
opt_method: How the suggested learning rate is picked from the recorded curve. One of ``("gradient", "slide", "valley")``.

Example::

@@ -92,6 +93,7 @@ def __init__(
early_stop_threshold: Optional[float] = 4.0,
update_attr: bool = True,
attr_name: str = "",
opt_method: Literal["gradient", "slide", "valley"] = "gradient",
) -> None:
mode = mode.lower()
if mode not in self.SUPPORTED_MODES:
@@ -104,7 +106,7 @@ def __init__(
self._early_stop_threshold = early_stop_threshold
self._update_attr = update_attr
self._attr_name = attr_name

self._opt_method = opt_method
self._early_exit = False
self.lr_finder: Optional[_LRFinder] = None

@@ -120,6 +122,7 @@ def lr_find(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Non
early_stop_threshold=self._early_stop_threshold,
update_attr=self._update_attr,
attr_name=self._attr_name,
opt_method=self._opt_method,
)

if self._early_exit:
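For reference, a minimal sketch of how the new ``opt_method`` argument is meant to be used from the callback side. Everything except ``opt_method`` is the existing ``LearningRateFinder`` API; the ``trainer.fit`` call is left commented because it needs a user-supplied ``LightningModule``:

import lightning.pytorch as pl
from lightning.pytorch.callbacks import LearningRateFinder

# Use the "valley" heuristic added in this PR instead of the default "gradient" suggestion.
lr_finder_cb = LearningRateFinder(opt_method="valley")

trainer = pl.Trainer(max_epochs=10, callbacks=[lr_finder_cb])
# trainer.fit(model)  # with update_attr=True (the default) the suggested lr is written back to the model
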
93 changes: 81 additions & 12 deletions src/lightning/pytorch/tuner/lr_finder.py
@@ -16,7 +16,7 @@
import os
import uuid
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union

import torch
from lightning_utilities.core.imports import RequirementCache
@@ -78,6 +78,8 @@ class _LRFinder:

num_training: number of steps to take between lr_min and lr_max

opt_method: How the suggested learning rate is picked from the recorded curve. One of ``("gradient", "slide", "valley", "valley_grad")``.

opt_parameters: Optional dictionary of per-method tuning knobs, e.g. ``loss_threshold``, ``lr_diff`` and ``adjust_value`` for the ``"slide"`` method.

Example::
# Run lr finder
lr_finder = trainer.lr_find(model)
@@ -93,17 +95,29 @@ class _LRFinder:

"""

def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int) -> None:
def __init__(
self,
mode: str,
lr_min: float,
lr_max: float,
num_training: int,
opt_method: Literal["gradient", "slide", "valley", "valley_grad"] = "gradient",
opt_parameters: Optional[Dict[str, Union[float, int]]] = None,
) -> None:
assert mode in ("linear", "exponential"), "mode should be either `linear` or `exponential`"

self.mode = mode
self.lr_min = lr_min
self.lr_max = lr_max
self.num_training = num_training

self.opt_method = opt_method
self.results: Dict[str, Any] = {}
self._total_batch_idx = 0 # for debug purpose

self.opt_parameters = opt_parameters
if self.opt_parameters is None:
self.opt_parameters = {}

def _exchange_scheduler(self, trainer: "pl.Trainer") -> None:
# TODO: update docs here
"""Decorate `trainer.strategy.setup_optimizers` method such that it sets the user's originally specified
@@ -167,6 +181,8 @@ def plot(self, suggest: bool = False, show: bool = False, ax: Optional["Axes"] =
_ = self.suggestion()
if self._optimal_idx is not None:
ax.plot(lrs[self._optimal_idx], losses[self._optimal_idx], markersize=10, marker="o", color="red")
elif self._optimal_lr is not None:
# methods such as "slide" yield a learning rate without an index into the recorded curve
ax.axvline(self._optimal_lr, linestyle="--")

if show:
plt.show()
@@ -188,8 +204,10 @@ def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> Optional[float]
"""
losses = torch.tensor(self.results["loss"][skip_begin:-skip_end])
losses = losses[torch.isfinite(losses)]
lrs = self.results["lr"][skip_begin:-skip_end]

if len(losses) < 2:
self._optimal_lr = None
if self.opt_method == "gradient" and len(losses) < 2:
# computing np.gradient requires at least 2 points
log.error(
"Failed to compute suggestion for learning rate because there are not enough points. Increase the loop"
@@ -198,13 +216,62 @@ def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> Optional[float]
self._optimal_idx = None
return None

# TODO: When computing the argmin here, and some losses are non-finite, the expected indices could be
# incorrectly shifted by an offset
gradients = torch.gradient(losses)[0] # Unpack the tuple
min_grad = torch.argmin(gradients).item()

self._optimal_idx = min_grad + skip_begin
return self.results["lr"][self._optimal_idx]
if self.opt_method == "gradient":
# TODO: When computing the argmin here, and some losses are non-finite, the expected indices could be
# incorrectly shifted by an offset
gradients = torch.gradient(losses)[0] # Unpack the tuple
min_grad = torch.argmin(gradients).item()

self._optimal_idx = min_grad + skip_begin
opt_lr = self.results["lr"][self._optimal_idx]
elif self.opt_method == "slide":
# See https://forums.fast.ai/t/automated-learning-rate-suggester/44199 "slide" method
loss_t = self.opt_parameters.get("loss_threshold", 0.5)
lr_diff = self.opt_parameters.get("lr_diff", 15)
adjust_value = self.opt_parameters.get("adjust_value", 1.0)
r_idx = -1
l_idx = r_idx - lr_diff
gradients = torch.gradient(losses)[0]  # Unpack the tuple
local_min_lr = lrs[max(l_idx, -len(lrs))]  # fallback in case the window below never slides
self._optimal_idx = None  # "slide" yields a learning rate, not an index into the recorded curve

while (l_idx >= -len(losses)) and (abs(gradients[r_idx] - gradients[l_idx]) > loss_t):
local_min_lr = lrs[l_idx]
r_idx -= 1
l_idx -= 1
opt_lr = local_min_lr * adjust_value
elif self.opt_method in ["valley", "valley_grad"]:
# See https://forums.fast.ai/t/automated-learning-rate-suggester/44199 "valley" method
n = len(losses)
max_start = 0
max_end = 0

# finding the longest valley.
lds = [1] * n

for i in range(1, n):
for j in range(0, i):
if losses[i] < losses[j] and lds[i] < lds[j] + 1:
lds[i] = lds[j] + 1
if lds[max_end] < lds[i]:
max_end = i
max_start = max_end - lds[max_end]

sections = (max_end - max_start) / 3
# pick something midway, or 2/3rd of the way down the valley to be more aggressive
valley_lip_idx = max_start + int(sections) + int(sections / 2)
if self.opt_method == "valley":
self._optimal_idx = valley_lip_idx + skip_begin
else:
# "valley_grad": look for the gradient minimum between the valley lip and the loss minimum after it
# (these indices are into the trimmed `losses`, so `skip_begin` is added back only at the end)
feasible_end = valley_lip_idx + max(int(losses[valley_lip_idx:].argmin()), 1)
gradients = torch.gradient(losses)[0]  # Unpack the tuple
self._optimal_idx = int(gradients[valley_lip_idx:feasible_end].argmin()) + valley_lip_idx + skip_begin

opt_lr = self.results["lr"][self._optimal_idx]

self._optimal_lr = opt_lr

return opt_lr


def _lr_find(
@@ -217,6 +284,7 @@ def _lr_find(
early_stop_threshold: Optional[float] = 4.0,
update_attr: bool = False,
attr_name: str = "",
opt_method: Literal["gradient", "slide", "valley", "valley_grad"] = "gradient",
) -> Optional[_LRFinder]:
"""Enables the user to do a range test of good initial learning rates, to reduce the amount of guesswork in picking
a good starting learning rate.
@@ -238,6 +306,7 @@ def _lr_find(
update_attr: Whether to update the learning rate attribute or not.
attr_name: Name of the attribute which stores the learning rate. The names 'learning_rate' or 'lr' get
automatically detected. Otherwise, set the name here.
opt_method: How the suggested learning rate is picked from the recorded curve. One of ``("gradient", "slide", "valley", "valley_grad")``.

"""
if trainer.fast_dev_run:
@@ -266,7 +335,7 @@
trainer.progress_bar_callback.disable()

# Initialize lr finder object (stores results)
lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)
lr_finder = _LRFinder(mode, min_lr, max_lr, num_training, opt_method=opt_method)

# Configure optimizer and scheduler
lr_finder._exchange_scheduler(trainer)
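To illustrate what the new ``suggestion()`` branches compute, here is a small self-contained sketch of the "valley" pick on a synthetic loss curve. This is a standalone re-implementation of the heuristic for illustration only, not the ``_LRFinder`` code above; the synthetic curve and the helper name ``valley_suggestion`` are made up:

import torch

# Synthetic LR range test: the loss falls, bottoms out, then diverges.
lrs = torch.logspace(-6, 0, steps=60).tolist()
losses = torch.cat([torch.linspace(2.0, 0.4, 40), torch.linspace(0.4, 5.0, 20)]).tolist()


def valley_suggestion(lrs, losses):
    """Pick a lr roughly 2/3 of the way down the longest run of decreasing losses."""
    n = len(losses)
    lds = [1] * n  # lds[i]: length of the longest decreasing subsequence ending at i
    max_start, max_end = 0, 0
    for i in range(1, n):
        for j in range(i):
            if losses[i] < losses[j] and lds[i] < lds[j] + 1:
                lds[i] = lds[j] + 1
            if lds[max_end] < lds[i]:
                max_end = i
                max_start = max_end - lds[max_end]
    sections = (max_end - max_start) / 3
    idx = max_start + int(sections) + int(sections / 2)
    return lrs[idx]


print(f"valley suggestion: {valley_suggestion(lrs, losses):.2e}")

The ``valley_grad`` variant in the diff then searches between that point and the subsequent loss minimum for the steepest negative gradient, while ``slide`` instead walks a fixed-width window back from the end of the curve until the gradient difference across the window drops below ``loss_threshold``.
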
4 changes: 3 additions & 1 deletion src/lightning/pytorch/tuner/tuning.py
@@ -120,6 +120,7 @@ def lr_find(
early_stop_threshold: Optional[float] = 4.0,
update_attr: bool = True,
attr_name: str = "",
opt_method: Literal["gradient", "slide", "valley"] = "gradient",
) -> Optional["_LRFinder"]:
"""Enables the user to do a range test of good initial learning rates, to reduce the amount of guesswork in
picking a good starting learning rate.
@@ -148,7 +149,7 @@
update_attr: Whether to update the learning rate attribute or not.
attr_name: Name of the attribute which stores the learning rate. The names 'learning_rate' or 'lr' get
automatically detected. Otherwise, set the name here.

opt_method: How the suggested learning rate is picked from the recorded curve. One of ``("gradient", "slide", "valley")``.

Raises:
MisconfigurationException:
If learning rate/lr in ``model`` or ``model.hparams`` isn't overridden,
@@ -172,6 +173,7 @@
early_stop_threshold=early_stop_threshold,
update_attr=update_attr,
attr_name=attr_name,
opt_method=opt_method,
)

lr_finder_callback._early_exit = True
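Finally, a sketch of the tuner-side flow this change targets, assuming the branch is installed. ``DemoModel``, its random dataset, and the numbers are made up for illustration; ``Tuner.lr_find``, ``suggestion()`` and ``plot()`` are the existing API, and ``opt_method`` is the argument added by this PR:

import torch
from torch.utils.data import DataLoader, TensorDataset
import lightning.pytorch as pl
from lightning.pytorch.tuner import Tuner


class DemoModel(pl.LightningModule):
    """Tiny model whose ``lr`` attribute the finder can update in place."""

    def __init__(self, lr: float = 1e-3):
        super().__init__()
        self.lr = lr
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.lr)

    def train_dataloader(self):
        x = torch.randn(3200, 32)
        y = torch.randint(0, 2, (3200,))
        return DataLoader(TensorDataset(x, y), batch_size=32)


model = DemoModel()
trainer = pl.Trainer(max_epochs=1)
tuner = Tuner(trainer)

# Run the LR range test and let the "slide" heuristic pick the suggestion.
lr_finder = tuner.lr_find(model, opt_method="slide")

print(lr_finder.suggestion())       # suggested learning rate; also written to model.lr (update_attr=True)
fig = lr_finder.plot(suggest=True)  # marker / dashed line at the suggestion (requires matplotlib)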