Description & Motivation
Hey,
As far as I can tell, if you want to continuously save .ckpt files to an MLflow server during a training run, the best approach is to use the MLFlowLogger flag log_model="all".
However, this comes with two issues:
1. It significantly increases storage use in the artifact store.
2. It conflicts with the ModelCheckpoint flag save_top_k: when save_top_k is set, only the local copies reflect the top-k selection, and if the experiment crashes mid-run those checkpoints are lost.
Since local persistence can't be guaranteed, this isn't ideal for long-running or cloud-based training workflows.
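For reference, this is roughly the setup I mean: a minimal sketch assuming a reachable MLflow tracking server (the experiment name, tracking URI, and monitored metric are just placeholders).

import pytorch_lightning as pl
from pytorch_lightning.loggers import MLFlowLogger
from pytorch_lightning.callbacks import ModelCheckpoint

mlflow_logger = MLFlowLogger(
    experiment_name="demo",                 # placeholder
    tracking_uri="http://localhost:5000",   # placeholder
    log_model="all",                        # upload every checkpoint as an artifact
)
checkpoint_cb = ModelCheckpoint(monitor="val_loss", save_top_k=2, mode="min")
trainer = pl.Trainer(max_epochs=40, logger=mlflow_logger, callbacks=[checkpoint_cb])
# save_top_k=2 prunes checkpoints on the local disk only; the artifact store
# keeps accumulating every uploaded checkpoint.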
Pitch
A new feature allowing the top-k checkpoints to be upserted into MLflow (and other loggers) would be incredibly useful.
Proposed behavior, with save_top_k=2:
1. A new checkpoint is created → it is upserted into the logger.
2. Another checkpoint is created → it is upserted into the logger.
3. A new, better checkpoint replaces one of the two → it is upserted, and the displaced checkpoint is deleted from the logger.
In other words, the logger's artifact store would always mirror the local top-k set; a quick way to check that from the MLflow side is sketched below.
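To make "upserted" concrete: after step 3, the run's artifact listing should contain exactly the two surviving checkpoints. Something like the following could verify that (the run ID, tracking URI, and "checkpoints" artifact path are placeholders):

from mlflow.tracking import MlflowClient

client = MlflowClient(tracking_uri="http://localhost:5000")
# With the proposed behavior, only the current top-k files would show up here.
for artifact in client.list_artifacts("<run-id>", "checkpoints"):
    print(artifact.path)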
Proposed Implementation
This logic could be integrated here, where files are already being removed locally. The change would involve adding logger-specific removal functionality in addition to the local deletion.
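To make the idea a bit more concrete, here is a rough sketch of the kind of subclass I have in mind. It assumes ModelCheckpoint's private hooks _save_checkpoint / _remove_checkpoint keep their current signatures, and as far as I can tell MLflow's client doesn't expose a per-artifact delete, so the removal half only marks the intent (the class name and "checkpoints" artifact path are made up):

import os

from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import MLFlowLogger


class MLflowTopKCheckpoint(ModelCheckpoint):
    """Keep the MLflow artifact store in sync with the local top-k set."""

    def _save_checkpoint(self, trainer, filepath):
        super()._save_checkpoint(trainer, filepath)
        logger = trainer.logger
        if isinstance(logger, MLFlowLogger) and trainer.is_global_zero:
            # Upload the new top-k member as soon as it is written locally.
            logger.experiment.log_artifact(
                logger.run_id, filepath, artifact_path="checkpoints"
            )

    def _remove_checkpoint(self, trainer, filepath):
        super()._remove_checkpoint(trainer, filepath)
        # This is where the logger-side removal would go; for now it only
        # reports the file that should disappear from the artifact store.
        print(f"would also remove {os.path.basename(filepath)} from the artifact store")

The same hook-based approach should translate to any other logger that can upload and delete artifacts.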
Would love to hear thoughts on this!
Alternatives
No response
Additional context
Here is a script to demonstrate the pain:
import os

import pytorch_lightning as pl
from pytorch_lightning.loggers import MLFlowLogger
from pytorch_lightning.callbacks import ModelCheckpoint
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Change these as wanted
os.environ["MLFLOW_TRACKING_USERNAME"] = ""
os.environ["MLFLOW_TRACKING_PASSWORD"] = ""


class SimpleModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.MSELoss()(y_hat, y)
        self.log("train_loss", loss)
        # Force a failure partway through training to simulate a crash.
        if self.current_epoch == 34:
            raise Exception("Forced failure at epoch 35")
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.MSELoss()(y_hat, y)
        self.log("val_loss", loss)

    def configure_optimizers(self):
        return optim.SGD(self.parameters(), lr=0.01)


def train():
    # If running MLflow locally, use this URI
    URI = "http://localhost:5000"
    mlflow_logger = MLFlowLogger(
        experiment_name="harry-test",
        tracking_uri=URI,
        log_model=True,  # not "all", so checkpoints are not uploaded continuously during training
    )

    x_train, y_train = torch.randn(100, 10), torch.randn(100, 1)
    train_dataset = TensorDataset(x_train, y_train)
    train_loader = DataLoader(train_dataset, batch_size=32)

    x_val, y_val = torch.randn(20, 10), torch.randn(20, 1)
    val_dataset = TensorDataset(x_val, y_val)
    val_loader = DataLoader(val_dataset, batch_size=32)

    model = SimpleModel()

    checkpoint_callback_train = ModelCheckpoint(
        monitor="train_loss",
        filename="best_train_model-{epoch:02d}-{train_loss:.2f}",
        save_top_k=2,
        mode="min",
    )
    checkpoint_callback_val = ModelCheckpoint(
        monitor="val_loss",
        filename="best_val_model-{epoch:02d}-{val_loss:.2f}",
        save_top_k=2,
        mode="min",
    )

    trainer = pl.Trainer(
        max_epochs=40,
        logger=mlflow_logger,
        callbacks=[checkpoint_callback_train, checkpoint_callback_val],
        val_check_interval=3,
    )
    trainer.fit(model, train_loader, val_loader)


if __name__ == "__main__":
    train()