
MLFlow logger - save top k to server on N epochs #20584

Open
@HarryAnkers

Description & Motivation

Hey,
As far as I can tell, if you want to continuously save .ckpt files to an MLflow server during a training run, the best approach is to use the MLFlowLogger flag log_model="all".
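For reference, that setup looks like this (log_model="all" is a real MLFlowLogger option; the experiment name and URI here are placeholders):

from pytorch_lightning.loggers import MLFlowLogger

# Upload every checkpoint to the tracking server as soon as it is saved locally.
mlflow_logger = MLFlowLogger(
    experiment_name="my-experiment",       # placeholder
    tracking_uri="http://localhost:5000",  # placeholder
    log_model="all",
)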

However, this comes with two issues:

1. It significantly increases storage use in the artifact store.
2. It conflicts with the ModelCheckpoint flag save_top_k:
   - When save_top_k is set, the top-k checkpoints are only retained locally.
   - If the experiment crashes mid-run, those checkpoints are lost.

Without guaranteed persistence beyond local storage, this isn't ideal for long-running or cloud-based training workflows.

Pitch

A new feature allowing top-k checkpoints to be upserted in MLflow and other loggers would be incredibly useful.

Proposed behavior (a concrete trace is sketched below):

1. save_top_k=2 is set.
2. A new checkpoint is created → it is upserted in the logger.
3. Another checkpoint is created → it is upserted in the logger.
4. A new, better checkpoint replaces an old one → it is upserted, and the superseded checkpoint is deleted from the logger.
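Concretely, with save_top_k=2 monitoring val_loss (mode="min"), the artifact store would evolve like this (a hypothetical trace; checkpoint names are illustrative):

# epoch 1, val_loss=0.90 -> upsert epoch1.ckpt                 server: {epoch1}
# epoch 2, val_loss=0.80 -> upsert epoch2.ckpt                 server: {epoch1, epoch2}
# epoch 3, val_loss=0.70 -> upsert epoch3.ckpt, delete epoch1  server: {epoch2, epoch3}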
Proposed Implementation

This logic could be integrated here, where files are already being removed locally. The change would involve adding logger-specific removal functionality in addition to the local deletion; a rough sketch follows.
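As a minimal sketch, assuming a ModelCheckpoint subclass is an acceptable shape for a prototype: it hooks the private _save_checkpoint/_remove_checkpoint methods (internal Lightning APIs that may change between releases), and _delete_remote is a hypothetical helper, since MlflowClient exposes no public delete-artifact API and server-side pruning would need artifact-repository-level access.

import os

from mlflow.tracking import MlflowClient
from pytorch_lightning.callbacks import ModelCheckpoint


class SyncedModelCheckpoint(ModelCheckpoint):
    """Hypothetical callback that mirrors the local top-k set to an MLflow run."""

    def __init__(self, run_id: str, tracking_uri: str, **kwargs):
        super().__init__(**kwargs)
        self._client = MlflowClient(tracking_uri=tracking_uri)
        self._run_id = run_id

    def _save_checkpoint(self, trainer, filepath):
        # Write the checkpoint locally (existing behavior), then upsert it
        # into the run's artifacts.
        super()._save_checkpoint(trainer, filepath)
        if trainer.is_global_zero:
            self._client.log_artifact(self._run_id, filepath, artifact_path="checkpoints")

    def _remove_checkpoint(self, trainer, filepath):
        # Delete locally (existing behavior), then prune the matching artifact.
        super()._remove_checkpoint(trainer, filepath)
        if trainer.is_global_zero:
            self._delete_remote(f"checkpoints/{os.path.basename(filepath)}")

    def _delete_remote(self, artifact_path):
        # Placeholder: would delete `artifact_path` via the artifact repository
        # (or a REST endpoint), since MlflowClient has no delete-artifact method.
        raise NotImplementedError

The run_id could be taken from MLFlowLogger.run_id so the checkpoints land in the same run as the metrics.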

Would love to hear thoughts on this!

Alternatives

No response

Additional context

Here is a script to demonstrate the pain:

import os
import pytorch_lightning as pl
from pytorch_lightning.loggers import MLFlowLogger
from pytorch_lightning.callbacks import ModelCheckpoint
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Set credentials here if your MLflow tracking server requires authentication
os.environ["MLFLOW_TRACKING_USERNAME"] = ""
os.environ["MLFLOW_TRACKING_PASSWORD"] = ""


class SimpleModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.MSELoss()(y_hat, y)
        self.log("train_loss", loss)
        if self.current_epoch == 34:
            # Simulate a mid-run crash (current_epoch is 0-indexed, so this is the 35th epoch)
            raise Exception("Forced failure at epoch 35")
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.MSELoss()(y_hat, y)
        self.log("val_loss", loss)

    def configure_optimizers(self):
        return optim.SGD(self.parameters(), lr=0.01)


def train():
    # URI of a locally running MLflow tracking server
    URI = "http://localhost:5000"

    mlflow_logger = MLFlowLogger(
        experiment_name="harry-test",
        tracking_uri=URI,
        log_model=True,  # True uploads checkpoints at the end of training; "all" uploads them as they are saved
    )

    x_train, y_train = torch.randn(100, 10), torch.randn(100, 1)
    train_dataset = TensorDataset(x_train, y_train)
    train_loader = DataLoader(train_dataset, batch_size=32)

    x_val, y_val = torch.randn(20, 10), torch.randn(20, 1)
    val_dataset = TensorDataset(x_val, y_val)
    val_loader = DataLoader(val_dataset, batch_size=32)

    model = SimpleModel()

    checkpoint_callback_train = ModelCheckpoint(
        monitor="train_loss",
        filename="best_train_model-{epoch:02d}-{train_loss:.2f}",
        save_top_k=2,
        mode="min",
    )

    checkpoint_callback_val = ModelCheckpoint(
        monitor="val_loss",
        filename="best_val_model-{epoch:02d}-{val_loss:.2f}",
        save_top_k=2,
        mode="min",
    )

    trainer = pl.Trainer(
        max_epochs=40,
        logger=mlflow_logger,
        callbacks=[checkpoint_callback_train, checkpoint_callback_val],
        val_check_interval=3,  # run validation every 3 training batches
    )

    trainer.fit(model, train_loader, val_loader)


if __name__ == "__main__":
    train()
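
To reproduce: start a tracking server (e.g., mlflow server --port 5000), run the script, and after the forced crash at epoch 35 compare the run's artifacts in the MLflow UI against the local top-k checkpoint files.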

cc @lantiga @Borda
