Open
Description
Outline & Motivation
It could be useful to add an overridable callback method to `WandbLogger` to allow custom handling of checkpoint artifacts. Example use cases:
- I'm already writing checkpoints to persistent storage (e.g. using a `ModelCheckpoint` that writes to S3), so I just want `WandbLogger` to log reference artifacts pointing to them.
- I want to add additional files or metadata to my W&B checkpoint artifacts.
We could refactor `WandbLogger` slightly:
class WandbLogger:
    def on_log_checkpoint_artifact(self, artifact, checkpoint_timestamp, path, score, tag):
        """Hook to customize the W&B artifact built for a single checkpoint.

        Called once per newly found checkpoint, after the artifact (including its
        metadata) has been created and before it is logged. Override in a subclass
        to change what the artifact contains, e.g. add a reference instead of
        uploading the file, or attach extra files/metadata.

        Args:
            artifact: The ``wandb.Artifact`` created for this checkpoint.
            checkpoint_timestamp: Timestamp yielded by ``_scan_checkpoints`` for this
                checkpoint.
            path: Path of the checkpoint file.
            score: Monitored score of the checkpoint (a ``Tensor`` or a plain value).
            tag: Tag yielded by ``_scan_checkpoints`` for this checkpoint.

        Returns:
            The artifact to log. Overrides may modify ``artifact`` in place and return
            it, or return a different artifact. If an override returns ``None``, the
            (possibly mutated) ``artifact`` that was passed in is logged.
        """
        # Default behaviour: upload the checkpoint file under a fixed name.
        artifact.add_file(path, name="model.ckpt")
        return artifact

    def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
        """Log every new checkpoint found by ``checkpoint_callback`` as a W&B model artifact."""
        # get checkpoints to be saved with associated score
        checkpoints = _scan_checkpoints(checkpoint_callback, self._logged_model_time)

        # log iteratively all new checkpoints
        for t, p, s, tag in checkpoints:
            metadata = (
                {
                    "score": s.item() if isinstance(s, Tensor) else s,
                    "original_filename": Path(p).name,
                    checkpoint_callback.__class__.__name__: {
                        k: getattr(checkpoint_callback, k)
                        for k in [
                            "monitor",
                            "mode",
                            "save_last",
                            "save_top_k",
                            "save_weights_only",
                            "_every_n_train_steps",
                        ]
                        # ensure it does not break if `ModelCheckpoint` args change
                        if hasattr(checkpoint_callback, k)
                    },
                }
                if _WANDB_GREATER_EQUAL_0_10_22
                else None
            )
            if not self._checkpoint_name:
                self._checkpoint_name = f"model-{self.experiment.id}"
            artifact = wandb.Artifact(name=self._checkpoint_name, type="model", metadata=metadata)

            # Let the (possibly overridden) hook populate the artifact. Fall back to the
            # original object if an override mutates it in place but forgets to return it,
            # instead of passing ``None`` to ``log_artifact``.
            customized = self.on_log_checkpoint_artifact(artifact, t, p, s, tag)
            if customized is not None:
                artifact = customized

            aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]
            self.experiment.log_artifact(artifact, aliases=aliases)
            # remember logged models - timestamp needed in case filename didn't change (last.ckpt or custom name)
            self._logged_model_time[p] = t
Then, if users want custom artifact logging, they can subclass `WandbLogger` and override `on_log_checkpoint_artifact`:
class ReferenceArtifactLogger(WandbLogger):
    """``WandbLogger`` that logs checkpoints as reference artifacts instead of uploading them."""

    def on_log_checkpoint_artifact(self, artifact, checkpoint_timestamp, path, score, tag):
        """Add a reference to the checkpoint at ``path`` rather than uploading its contents.

        NOTE(review): ``path`` presumably needs to be a URI W&B can reference
        (e.g. ``s3://...`` when ``ModelCheckpoint`` writes to S3); plain local paths
        may need a ``file://`` prefix — confirm against the wandb version in use.
        """
        # Use the same in-artifact name as the default hook so consumers that look
        # for ``model.ckpt`` inside the artifact keep working.
        artifact.add_reference(path, name="model.ckpt")
        return artifact
### Pitch
_No response_
### Additional context
_No response_
cc @justusschock @awaelchli