
Add WandbLogger callback for customizing checkpoint artifact logging #17913

@schmidt-ai

Description

### Outline & Motivation

It could be useful to add a callback to WandbLogger to allow custom handling of checkpoint artifacts. Examples of use cases:

  1. I'm already writing checkpoints to persistent storage (e.g. using a ModelCheckpoint that writes to S3), so I just want WandbLogger to log reference artifacts that point to them.
  2. I want to add additional files or metadata to my WandB checkpoint artifacts.

We could refactor WandbLogger slightly, moving the artifact-population logic into an overridable hook:

class WandbLogger:
    # sketch of a change to Lightning's WandbLogger; everything not shown is unchanged

    def on_log_checkpoint_artifact(self, artifact, checkpoint_timestamp, path, score, tag):
        # new hook, called once per checkpoint; the default reproduces today's behavior
        artifact.add_file(path, name="model.ckpt")
        return artifact
    
    def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
        # get checkpoints to be saved with associated score
        checkpoints = _scan_checkpoints(checkpoint_callback, self._logged_model_time)

        # log iteratively all new checkpoints
        for t, p, s, tag in checkpoints:
            metadata = (
                {
                    "score": s.item() if isinstance(s, Tensor) else s,
                    "original_filename": Path(p).name,
                    checkpoint_callback.__class__.__name__: {
                        k: getattr(checkpoint_callback, k)
                        for k in [
                            "monitor",
                            "mode",
                            "save_last",
                            "save_top_k",
                            "save_weights_only",
                            "_every_n_train_steps",
                        ]
                        # ensure it does not break if `ModelCheckpoint` args change
                        if hasattr(checkpoint_callback, k)
                    },
                }
                if _WANDB_GREATER_EQUAL_0_10_22
                else None
            )
            if not self._checkpoint_name:
                self._checkpoint_name = f"model-{self.experiment.id}"
            artifact = wandb.Artifact(name=self._checkpoint_name, type="model", metadata=metadata)

            # hand the artifact to the hook (or a subclass override) before logging it
            artifact = self.on_log_checkpoint_artifact(artifact, t, p, s, tag)

            aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]
            self.experiment.log_artifact(artifact, aliases=aliases)
            # remember logged models - timestamp needed in case filename didn't change (last.ckpt or custom name)
            self._logged_model_time[p] = t
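
Since the default hook simply reproduces the current behavior, nothing changes for existing users: the logger is wired up exactly as today, with log_model enabled so that checkpoints are picked up. A minimal sketch (project name, dirpath, and monitor key are placeholders):

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger

# log_model="all" makes WandbLogger log checkpoints as artifacts during training
logger = WandbLogger(project="my-project", log_model="all")
trainer = Trainer(
    logger=logger,
    callbacks=[ModelCheckpoint(dirpath="checkpoints/", monitor="val_loss", save_top_k=3)],
)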

If users want custom artifact logging, they can subclass WandbLogger and override on_log_checkpoint_artifact. For the first use case above, a subclass can log a reference artifact instead of uploading the file:

class ReferenceArtifactLogger(WandbLogger):
    def on_log_checkpoint_artifact(self, artifact, checkpoint_timestamp, path, score, tag):
        # store a reference (e.g. an s3:// URI) instead of uploading the checkpoint itself
        artifact.add_reference(path)
        return artifact
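
The second use case works the same way: a subclass can attach extra files or metadata before the artifact is logged. A rough sketch, where the extra config.yaml file and the added metadata keys are purely illustrative:

class EnrichedArtifactLogger(WandbLogger):
    def on_log_checkpoint_artifact(self, artifact, checkpoint_timestamp, path, score, tag):
        # keep the default upload of the checkpoint itself
        artifact.add_file(path, name="model.ckpt")
        # bundle an extra file with the checkpoint (illustrative)
        artifact.add_file("config.yaml")
        # record the tag and the checkpoint's timestamp alongside the existing metadata
        artifact.metadata = {**(artifact.metadata or {}), "tag": tag, "logged_at": checkpoint_timestamp}
        return artifact

In both cases the custom logger is simply dropped in place of WandbLogger in the Trainer wiring above (with ModelCheckpoint's dirpath pointing at the remote store for the reference case).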

### Pitch

_No response_

### Additional context

_No response_

cc @justusschock @awaelchli
