From 072a5cf978dd5c8f5912a47b19206aa1932e5068 Mon Sep 17 00:00:00 2001
From: harryankers
Date: Tue, 11 Feb 2025 13:44:47 +0000
Subject: [PATCH 1/6] fix(mlflow): Enabling multiple callbacks for checkpoint
 reporting

---
 src/lightning/pytorch/loggers/mlflow.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/src/lightning/pytorch/loggers/mlflow.py b/src/lightning/pytorch/loggers/mlflow.py
index e3d99987b7f58..5946d8bec82c1 100644
--- a/src/lightning/pytorch/loggers/mlflow.py
+++ b/src/lightning/pytorch/loggers/mlflow.py
@@ -142,7 +142,7 @@ def __init__(
         self.tags = tags
         self._log_model = log_model
         self._logged_model_time: dict[str, float] = {}
-        self._checkpoint_callback: Optional[ModelCheckpoint] = None
+        self._checkpoint_callbacks: list[ModelCheckpoint] = []
         self._prefix = prefix
         self._artifact_location = artifact_location
         self._log_batch_kwargs = {} if synchronous is None else {"synchronous": synchronous}
@@ -283,8 +283,9 @@ def finalize(self, status: str = "success") -> None:
             status = "FINISHED"
 
         # log checkpoints as artifacts
-        if self._checkpoint_callback:
-            self._scan_and_log_checkpoints(self._checkpoint_callback)
+        if self._checkpoint_callbacks:
+            for callback in self._checkpoint_callbacks:
+                self._scan_and_log_checkpoints(callback)
 
         if self.experiment.get_run(self.run_id):
             self.experiment.set_terminated(self.run_id, status)
@@ -331,7 +332,8 @@ def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
         if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1:
             self._scan_and_log_checkpoints(checkpoint_callback)
         elif self._log_model is True:
-            self._checkpoint_callback = checkpoint_callback
+            if checkpoint_callback not in self._checkpoint_callbacks:
+                self._checkpoint_callbacks.append(checkpoint_callback)
 
     def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
         # get checkpoints to be saved with associated score

From 253986675f110c63ebcb5757e45addd9aa5b538d Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 11 Feb 2025 13:50:00 +0000
Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/lightning/pytorch/loggers/mlflow.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/src/lightning/pytorch/loggers/mlflow.py b/src/lightning/pytorch/loggers/mlflow.py
index 5946d8bec82c1..c876712c7271f 100644
--- a/src/lightning/pytorch/loggers/mlflow.py
+++ b/src/lightning/pytorch/loggers/mlflow.py
@@ -331,9 +331,8 @@ def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
         # log checkpoints as artifacts
         if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1:
             self._scan_and_log_checkpoints(checkpoint_callback)
-        elif self._log_model is True:
-            if checkpoint_callback not in self._checkpoint_callbacks:
-                self._checkpoint_callbacks.append(checkpoint_callback)
+        elif self._log_model is True and checkpoint_callback not in self._checkpoint_callbacks:
+            self._checkpoint_callbacks.append(checkpoint_callback)
 
     def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
         # get checkpoints to be saved with associated score
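Taken together, the two patches above replace the logger's single _checkpoint_callback slot with a list that is deduplicated as callbacks report in and flushed in finalize(). What follows is a minimal standalone sketch of that accumulate-then-flush pattern; the class and the string stand-ins for callbacks are illustrative, not the actual Lightning internals.

# Minimal standalone sketch (not the Lightning source): accumulate
# checkpoint callbacks during training, then flush them all on finalize.
class CheckpointAwareLogger:
    def __init__(self, log_model: bool = True) -> None:
        self._log_model = log_model
        self._checkpoint_callbacks: list[str] = []

    def after_save_checkpoint(self, checkpoint_callback: str) -> None:
        # Register each callback once, no matter how often it saves.
        if self._log_model is True and checkpoint_callback not in self._checkpoint_callbacks:
            self._checkpoint_callbacks.append(checkpoint_callback)

    def finalize(self) -> None:
        # Every registered callback is scanned, not just the last one
        # to have called after_save_checkpoint.
        for callback in self._checkpoint_callbacks:
            self._scan_and_log_checkpoints(callback)

    def _scan_and_log_checkpoints(self, callback: str) -> None:
        print(f"logging checkpoints for {callback}")

logger = CheckpointAwareLogger()
logger.after_save_checkpoint("train_ckpt")
logger.after_save_checkpoint("val_ckpt")
logger.after_save_checkpoint("train_ckpt")  # deduplicated, registered once
logger.finalize()  # logs checkpoints for both train_ckpt and val_ckpt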
From 9b315f1e4199447d8f9fd397cb9fc2fcd081d909 Mon Sep 17 00:00:00 2001
From: harryankers
Date: Tue, 11 Feb 2025 14:13:50 +0000
Subject: [PATCH 3/6] test(mlflow): Added test verifying that multiple
 callbacks are picked up

---
 tests/tests_pytorch/loggers/test_mlflow.py | 75 +++++++++++++++++++++
 1 file changed, 75 insertions(+)

diff --git a/tests/tests_pytorch/loggers/test_mlflow.py b/tests/tests_pytorch/loggers/test_mlflow.py
index c7f9dbe1fe2c6..7e2abd6e8f5ee 100644
--- a/tests/tests_pytorch/loggers/test_mlflow.py
+++ b/tests/tests_pytorch/loggers/test_mlflow.py
@@ -12,12 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from typing import Any
 from unittest import mock
 from unittest.mock import MagicMock, Mock
 
 import pytest
 
 from lightning.pytorch import Trainer
+from lightning.pytorch.callbacks import ModelCheckpoint
+from lightning.pytorch.utilities.types import STEP_OUTPUT
 from lightning.pytorch.demos.boring_classes import BoringModel
 from lightning.pytorch.loggers.mlflow import (
     _MLFLOW_AVAILABLE,
     MLFlowLogger,
     _get_resolve_tags,
 )
@@ -427,3 +430,75 @@ def test_set_tracking_uri(mlflow_mock):
     mlflow_mock.set_tracking_uri.assert_not_called()
     _ = logger.experiment
     mlflow_mock.set_tracking_uri.assert_called_with("the_tracking_uri")
+
+
+@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
+def test_mlflow_multiple_checkpoints_top_k(mlflow_mock, tmp_path):
+    """Test that multiple ModelCheckpoint callbacks with top_k parameters work correctly with MLFlowLogger.
+
+    This test verifies that when using multiple ModelCheckpoint callbacks with save_top_k,
+    both callbacks function correctly and save the expected number of checkpoints when using
+    MLFlowLogger with log_model=True.
+    """
+
+    class CustomBoringModel(BoringModel):
+        def training_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
+            loss = self.step(batch)
+            self.log("train_loss", loss)
+            return {"loss": loss}
+
+        def validation_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
+            loss = self.step(batch)
+            self.log("val_loss", loss)
+            return {"loss": loss}
+
+    client = mlflow_mock.tracking.MlflowClient
+
+    model = CustomBoringModel()
+    logger = MLFlowLogger("test", save_dir=str(tmp_path), log_model=True)
+    logger = mock_mlflow_run_creation(logger, experiment_id="test-id")
+
+    # Create two ModelCheckpoint callbacks monitoring different metrics
+    train_ckpt = ModelCheckpoint(
+        dirpath=str(tmp_path / "train_checkpoints"),
+        monitor="train_loss",
+        filename="best_train_model-{epoch:02d}-{train_loss:.2f}",
+        save_top_k=2,
+        mode="min",
+    )
+    val_ckpt = ModelCheckpoint(
+        dirpath=str(tmp_path / "val_checkpoints"),
+        monitor="val_loss",
+        filename="best_val_model-{epoch:02d}-{val_loss:.2f}",
+        save_top_k=2,
+        mode="min",
+    )
+
+    # Create trainer with both callbacks
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        logger=logger,
+        callbacks=[train_ckpt, val_ckpt],
+        max_epochs=5,
+        limit_train_batches=3,
+        limit_val_batches=3,
+    )
+    trainer.fit(model)
+
+    # Verify both callbacks saved their checkpoints
+    assert len(train_ckpt.best_k_models) > 0, "Train checkpoint callback did not save any models"
+    assert len(val_ckpt.best_k_models) > 0, "Validation checkpoint callback did not save any models"
+
+    # Get all artifact paths that were logged
+    logged_artifacts = [call_args[0][1] for call_args in client.return_value.log_artifact.call_args_list]
+
+    # Verify MLFlow logged artifacts from both callbacks
+    train_artifacts = [path for path in logged_artifacts if "train_checkpoints" in path]
+    val_artifacts = [path for path in logged_artifacts if "val_checkpoints" in path]
+
+    assert len(train_artifacts) > 0, "MLFlow did not log any train checkpoint artifacts"
+    assert len(val_artifacts) > 0, "MLFlow did not log any validation checkpoint artifacts"
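The assertions in the test above inspect calls recorded by unittest.mock rather than a live MLflow server: client.return_value stands for the MlflowClient instance the logger constructs internally, and call_args[0][1] picks out the second positional argument (the artifact path) of each log_artifact call. A small self-contained illustration of that idiom, with made-up paths:

from unittest.mock import MagicMock

client = MagicMock()  # stand-in for the patched MlflowClient class
client.return_value.log_artifact("run-id", "/tmp/train_checkpoints/epoch=1.ckpt")
client.return_value.log_artifact("run-id", "/tmp/val_checkpoints/epoch=2.ckpt")

# Every call is recorded; call_args[0] is the positional-arguments tuple.
logged = [call_args[0][1] for call_args in client.return_value.log_artifact.call_args_list]
assert logged == [
    "/tmp/train_checkpoints/epoch=1.ckpt",
    "/tmp/val_checkpoints/epoch=2.ckpt",
]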
+
+    # Verify the number of logged artifacts matches the save_top_k for each callback
+    assert len(train_artifacts) == train_ckpt.save_top_k, "Number of logged train artifacts doesn't match save_top_k"
+    assert len(val_artifacts) == val_ckpt.save_top_k, "Number of logged val artifacts doesn't match save_top_k"

From 3fc69829111dc206c58a4008f582128356a2aab6 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 11 Feb 2025 14:17:28 +0000
Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tests/tests_pytorch/loggers/test_mlflow.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/tests_pytorch/loggers/test_mlflow.py b/tests/tests_pytorch/loggers/test_mlflow.py
index 7e2abd6e8f5ee..92440ac1357ee 100644
--- a/tests/tests_pytorch/loggers/test_mlflow.py
+++ b/tests/tests_pytorch/loggers/test_mlflow.py
@@ -20,13 +20,13 @@
 
 from lightning.pytorch import Trainer
 from lightning.pytorch.callbacks import ModelCheckpoint
-from lightning.pytorch.utilities.types import STEP_OUTPUT
 from lightning.pytorch.demos.boring_classes import BoringModel
 from lightning.pytorch.loggers.mlflow import (
     _MLFLOW_AVAILABLE,
     MLFlowLogger,
     _get_resolve_tags,
 )
+from lightning.pytorch.utilities.types import STEP_OUTPUT
 
 
 def mock_mlflow_run_creation(logger, experiment_name=None, experiment_id=None, run_id=None):
@@ -436,9 +436,9 @@ def test_mlflow_multiple_checkpoints_top_k(mlflow_mock, tmp_path):
     """Test that multiple ModelCheckpoint callbacks with top_k parameters work correctly with MLFlowLogger.
 
-    This test verifies that when using multiple ModelCheckpoint callbacks with save_top_k,
-    both callbacks function correctly and save the expected number of checkpoints when using
-    MLFlowLogger with log_model=True.
+    This test verifies that when using multiple ModelCheckpoint callbacks with save_top_k, both callbacks function
+    correctly and save the expected number of checkpoints when using MLFlowLogger with log_model=True.
+
     """
 
     class CustomBoringModel(BoringModel):

From 06de33b588f9f3b04baba817397e65fb12f7bb6f Mon Sep 17 00:00:00 2001
From: harryankers
Date: Tue, 11 Feb 2025 14:18:18 +0000
Subject: [PATCH 5/6] fix: Amended typing

---
 src/lightning/pytorch/loggers/mlflow.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lightning/pytorch/loggers/mlflow.py b/src/lightning/pytorch/loggers/mlflow.py
index c876712c7271f..f25cb21c597b7 100644
--- a/src/lightning/pytorch/loggers/mlflow.py
+++ b/src/lightning/pytorch/loggers/mlflow.py
@@ -142,7 +142,7 @@ def __init__(
         self.tags = tags
         self._log_model = log_model
         self._logged_model_time: dict[str, float] = {}
-        self._checkpoint_callbacks: list[ModelCheckpoint] = []
+        self._checkpoint_callbacks: Optional[list[ModelCheckpoint]] = []
         self._prefix = prefix
         self._artifact_location = artifact_location
         self._log_batch_kwargs = {} if synchronous is None else {"synchronous": synchronous}
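Patch 5 widens the attribute's annotation to Optional even though the value is initialised to a list, and patch 6 below then adds a guard to get past the resulting mypy error. A minimal reproduction of that complaint (illustrative names, not the Lightning source): appending through an Optional type fails until mypy can prove the value is not None.

from typing import Optional

class Sketch:
    def __init__(self) -> None:
        # Annotated Optional, so mypy treats every use as "list[str] | None".
        self._checkpoint_callbacks: Optional[list[str]] = []

    def register(self, name: str) -> None:
        # Without narrowing, mypy reports:
        #   Item "None" of "Optional[list[str]]" has no attribute "append"
        if self._checkpoint_callbacks is not None and name not in self._checkpoint_callbacks:
            self._checkpoint_callbacks.append(name)

Since the attribute is never actually assigned None in the code shown here, keeping the plain list annotation would arguably have silenced mypy just as well; the series instead keeps Optional and guards at the use site.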
From c46b47dade94b1f03c79e196bdbcde62964da398 Mon Sep 17 00:00:00 2001
From: harryankers
Date: Thu, 13 Feb 2025 19:26:45 +0000
Subject: [PATCH 6/6] fixed mypy error

---
 src/lightning/pytorch/loggers/mlflow.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/lightning/pytorch/loggers/mlflow.py b/src/lightning/pytorch/loggers/mlflow.py
index f25cb21c597b7..6e881776f685a 100644
--- a/src/lightning/pytorch/loggers/mlflow.py
+++ b/src/lightning/pytorch/loggers/mlflow.py
@@ -331,7 +331,11 @@ def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
         # log checkpoints as artifacts
         if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1:
             self._scan_and_log_checkpoints(checkpoint_callback)
-        elif self._log_model is True and checkpoint_callback not in self._checkpoint_callbacks:
+        elif (
+            self._log_model is True
+            and self._checkpoint_callbacks
+            and checkpoint_callback not in self._checkpoint_callbacks
+        ):
             self._checkpoint_callbacks.append(checkpoint_callback)
 
     def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
         # get checkpoints to be saved with associated score
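One behavioral note on the final hunk: the added "and self._checkpoint_callbacks" clause tests truthiness, and the attribute is initialised to an empty list, which is falsy. Since this elif branch is the only place shown that appends to the list, the condition as written can never become true, so no callback would ever be registered for deferred logging. A short self-contained demonstration of the trap, followed by the None-narrowing variant that satisfies mypy without it (a sketch, not part of the patch series):

from typing import Optional

callbacks: Optional[list[str]] = []

# Truthiness guard: an empty list is falsy, so the append never runs.
if callbacks and "first" not in callbacks:
    callbacks.append("first")
print(callbacks)  # []

# Narrowing against None instead keeps mypy happy and still appends.
if callbacks is not None and "first" not in callbacks:
    callbacks.append("first")
print(callbacks)  # ['first']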