
Unexpected Behavior: Fabric.load operates out-of-place on nested states #20208

Open
@Markus28

Bug description

I would expect Fabric.load(ckpt_path, state) to operate in place, i.e., to call the load_state_dict method of each loadable object in state. This is indeed what happens when state is a flat dictionary. However, when state is nested (e.g., because one handles multiple models, each with its own optimizer), this no longer holds: after the call to load, the modules in the nested dictionary refer to different objects than before. A minimal reproducible example is provided below.
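For illustration, the in-place semantics described above can be sketched as a recursive walk over the state: any object exposing load_state_dict is updated from the matching checkpoint entry, nested dictionaries are descended into, and only plain values are overwritten. This is a minimal, duck-typed sketch under stated assumptions, not how Fabric.load is implemented; the helper load_in_place and the Counter stand-in class are hypothetical, and Fabric-specific details (wrapped modules, strict key checking) are not modeled.

```python
from typing import Any, Dict


def load_in_place(state: Dict[str, Any], checkpoint: Dict[str, Any]) -> None:
    """Hypothetical helper: restore `checkpoint` into `state` without replacing loadable objects."""
    for key, obj in list(state.items()):
        if key not in checkpoint:
            continue  # nothing was saved under this key; leave the entry untouched
        saved = checkpoint[key]
        if hasattr(obj, "load_state_dict"):
            # Stateful object (e.g. a module or optimizer): update it in place.
            obj.load_state_dict(saved)
        elif isinstance(obj, dict) and isinstance(saved, dict):
            # Nested state: recurse so that inner objects keep their identity as well.
            load_in_place(obj, saved)
        else:
            # Plain value (int, str, None, ...): there is no identity worth preserving.
            state[key] = saved


class Counter:
    """Toy stateful object standing in for a module or optimizer."""

    def __init__(self) -> None:
        self.value = 0

    def state_dict(self) -> Dict[str, int]:
        return {"value": self.value}

    def load_state_dict(self, state_dict: Dict[str, int]) -> None:
        self.value = state_dict["value"]


model = Counter()
state = {"generator": {"model": model, "optim": None}, "step": 0}
load_in_place(state, {"generator": {"model": {"value": 42}, "optim": None}, "step": 100})
assert state["generator"]["model"] is model  # the outside reference is still valid
assert model.value == 42 and state["step"] == 100
```

With these semantics, a reference such as model stays usable after loading regardless of how deeply the object is nested.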

To me this is clearly unexpected behavior, and it causes problems whenever the old module references are used to resume training (e.g., imagine that the example below trains model rather than state_nested["generator"]["model"]). Currently, one has to update the old references manually, which is cumbersome and error-prone, especially for users who are unaware of this behavior.
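Until nested states are handled, one possible workaround (an assumption, not an officially documented pattern) is to flatten the nested state into a single-level dictionary with composite keys before calling fabric.save and fabric.load, so that every module and optimizer sits at the top level, where this report observes that loading happens in place. The helpers flatten_state and write_back and the "/" separator are hypothetical choices; the same flattened layout must be used for both saving and loading.

```python
from typing import Any, Dict


def flatten_state(state: Dict[str, Any], sep: str = "/", prefix: str = "") -> Dict[str, Any]:
    """Hypothetical helper: {"generator": {"model": m}} -> {"generator/model": m}.

    The values are the same objects as in `state`, so anything that loads them
    in place also updates references held elsewhere (e.g. the variable `model`).
    """
    flat: Dict[str, Any] = {}
    for key, value in state.items():
        if sep in key:
            raise ValueError(f"key {key!r} contains the separator {sep!r}")
        full_key = prefix + key
        if isinstance(value, dict) and value:
            flat.update(flatten_state(value, sep, prefix=full_key + sep))
        else:
            flat[full_key] = value
    return flat


def write_back(state: Dict[str, Any], flat: Dict[str, Any], sep: str = "/") -> None:
    """Copy every entry of `flat` back to its position in the nested `state`.

    Needed for plain values (counters, config, ...) that a loader may replace
    in the flat dictionary; in-place-loaded objects are reassigned to themselves.
    """
    for full_key, value in flat.items():
        *parents, leaf = full_key.split(sep)
        target = state
        for parent in parents:
            target = target[parent]
        target[leaf] = value


# Possible usage with Fabric (untested assumption; save and load must use the flat layout):
#   flat = flatten_state(state_nested)
#   fabric.save("nested.pt", flat)
#   ...
#   fabric.load("nested.pt", flat)
#   write_back(state_nested, flat)
```

The design keeps the user's nested structure as the source of truth and only uses the flat view at the save/load boundary.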

What version are you seeing the problem on?

v2.3

How to reproduce the bug

```python
from lightning import Fabric
from torch import nn

if __name__ == "__main__":
    fabric = Fabric(accelerator="cpu")
    model = nn.Linear(2, 2)
    model = fabric.setup(model)
    state_flat = {"model": model}
    state_nested = {"generator": {"model": model, "optim": None}}
    fabric.save("flat.pt", state_flat)
    fabric.save("nested.pt", state_nested)

    fabric.load("flat.pt", state_flat, strict=True)
    assert model is state_flat["model"]  # Passes: the top-level module is loaded in place

    fabric.load("nested.pt", state_nested, strict=True)
    assert model is state_nested["generator"]["model"]  # Fails: the nested entry now refers to a different object
```
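To find out which entries of a nested state were swapped out by a load call (and therefore which outside references have gone stale), one can record the object stored at each path before loading and compare identities afterwards. This is a small diagnostic sketch with hypothetical helper names; it relies only on Python's is comparison, and because it keeps the original objects alive, their identities cannot be reused in between. Note that equal cached immutables (e.g. None or small ints) will not show up as replaced, so it is mainly meaningful for modules and optimizers.

```python
from typing import Any, Dict, List, Tuple

Path = Tuple[str, ...]


def collect_leaves(state: Dict[str, Any], prefix: Path = ()) -> Dict[Path, Any]:
    """Map each leaf path (e.g. ("generator", "model")) to the object stored there."""
    leaves: Dict[Path, Any] = {}
    for key, value in state.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            leaves.update(collect_leaves(value, path))
        else:
            leaves[path] = value
    return leaves


def replaced_paths(before: Dict[Path, Any], state: Dict[str, Any]) -> List[Path]:
    """Return the recorded paths whose object is missing or no longer the same object."""
    after = collect_leaves(state)
    return [path for path, obj in before.items() if path not in after or after[path] is not obj]


# Possible usage with the reproduction above:
#   before = collect_leaves(state_nested)
#   fabric.load("nested.pt", state_nested, strict=True)
#   print(replaced_paths(before, state_nested))  # per this report, expected to include ("generator", "model")
```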

Error messages and logs

No response

Environment

Current environment
  • CUDA:
    - GPU:
        - Quadro RTX 5000
        - Quadro RTX 5000
    - available: True
    - version: 12.1
  • Lightning:
    - graph-transformer-pytorch: 0.1.1
    - lightning: 2.3.0
    - lightning-utilities: 0.11.2
    - pytorch-lightning: 2.2.2
    - rotary-embedding-torch: 0.6.2
    - torch: 2.3.0
    - torch-geometric: 2.5.3
    - torchmetrics: 1.4.0.post0
  • Packages:
    - aiohttp: 3.9.5
    - aiosignal: 1.3.1
    - annotated-types: 0.7.0
    - antlr4-python3-runtime: 4.9.3
    - async-timeout: 4.0.3
    - attrs: 23.2.0
    - autoregressive-graph-generation: 0.0.0
    - beartype: 0.18.5
    - biopandas: 0.4.1
    - brotli: 1.1.0
    - certifi: 2024.6.2
    - cffi: 1.16.0
    - cfgv: 3.4.0
    - charset-normalizer: 3.3.2
    - click: 8.1.7
    - colorama: 0.4.6
    - contourpy: 1.2.1
    - cycler: 0.12.1
    - deepspeed: 0.14.4
    - distlib: 0.3.8
    - docker-pycreds: 0.4.0
    - docopt: 0.6.2
    - einops: 0.8.0
    - et-xmlfile: 1.1.0
    - fastavro: 1.9.5
    - filelock: 3.14.0
    - fonttools: 4.53.0
    - freesasa: 2.2.1
    - frozenlist: 1.4.1
    - fsspec: 2024.5.0
    - ftpretty: 0.4.0
    - gitdb: 4.0.11
    - gitpython: 3.1.43
    - goatools: 1.4.12
    - graph-transformer-pytorch: 0.1.1
    - gudhi: 3.9.0
    - h5py: 3.11.0
    - heapdict: 1.0.1
    - hjson: 3.1.0
    - hydra-core: 1.3.2
    - identify: 2.5.36
    - idna: 3.7
    - imageio: 2.34.1
    - jinja2: 3.1.4
    - joblib: 1.4.2
    - kiwisolver: 1.4.5
    - lightning: 2.3.0
    - lightning-utilities: 0.11.2
    - loguru: 0.7.2
    - markdown-it-py: 3.0.0
    - markupsafe: 2.1.5
    - matplotlib: 3.8.4
    - mdurl: 0.1.2
    - mpmath: 1.3.0
    - multidict: 6.0.5
    - munkres: 1.1.4
    - networkx: 3.3
    - ninja: 1.11.1.1
    - nodeenv: 1.9.0
    - numpy: 1.26.4
    - nvidia-cublas-cu12: 12.1.3.1
    - nvidia-cuda-cupti-cu12: 12.1.105
    - nvidia-cuda-nvrtc-cu12: 12.1.105
    - nvidia-cuda-runtime-cu12: 12.1.105
    - nvidia-cudnn-cu12: 8.9.2.26
    - nvidia-cufft-cu12: 11.0.2.54
    - nvidia-curand-cu12: 10.3.2.106
    - nvidia-cusolver-cu12: 11.4.5.107
    - nvidia-cusparse-cu12: 12.1.0.106
    - nvidia-ml-py: 12.555.43
    - nvidia-nccl-cu12: 2.20.5
    - nvidia-nvjitlink-cu12: 12.5.40
    - nvidia-nvtx-cu12: 12.1.105
    - omegaconf: 2.3.0
    - openpyxl: 3.1.5
    - packaging: 24.0
    - pandas: 2.2.2
    - patsy: 0.5.6
    - pillow: 10.3.0
    - pip: 24.0
    - platformdirs: 4.2.2
    - pot: 0.9.3
    - pre-commit: 3.7.1
    - proteinshake: 0.3.14
    - protobuf: 4.25.3
    - psutil: 5.9.8
    - py-cpuinfo: 9.0.0
    - pycairo: 1.25.0
    - pycparser: 2.22
    - pydantic: 2.7.4
    - pydantic-core: 2.18.4
    - pydot: 3.0.1
    - pyemd: 1.0.0
    - pygments: 2.18.0
    - pygobject: 3.46.0
    - pygsp: 0.5.1
    - pyparsing: 3.1.2
    - pysocks: 1.7.1
    - python-dateutil: 2.9.0
    - pytorch-lightning: 2.2.2
    - pytz: 2024.1
    - pyyaml: 6.0.1
    - rdkit: 2024.3.3
    - rdkit-pypi: 2022.9.5
    - requests: 2.32.3
    - rich: 13.7.1
    - rotary-embedding-torch: 0.6.2
    - rustworkx: 0.14.2
    - scikit-learn: 1.5.0
    - scipy: 1.13.1
    - sentry-sdk: 2.3.1
    - setproctitle: 1.3.3
    - setuptools: 69.5.1
    - six: 1.16.0
    - smmap: 5.0.1
    - statsmodels: 0.14.2
    - sympy: 1.12.1
    - threadpoolctl: 3.5.0
    - torch: 2.3.0
    - torch-geometric: 2.5.3
    - torchmetrics: 1.4.0.post0
    - tqdm: 4.66.4
    - triton: 2.3.0
    - typing-extensions: 4.12.1
    - tzdata: 2024.1
    - unicodedata2: 15.1.0
    - urllib3: 2.2.1
    - virtualenv: 20.26.2
    - wandb: 0.17.0
    - wheel: 0.43.0
    - xlsxwriter: 3.2.0
    - yarl: 1.9.4
    - zstandard: 0.22.0
  • System:
    - OS: Linux
    - architecture:
        - 64bit
        - ELF
    - processor: x86_64
    - python: 3.10.14
    - release: 5.14.21-150400.24.97-default
    - version: #1 SMP PREEMPT_DYNAMIC Fri Oct 27 10:29:06 UTC 2023 (8546fda)

More info

No response

Metadata

Assignees: No one assigned

Labels: bug (Something isn't working), needs triage (Waiting to be triaged by maintainers), ver: 2.3.x

Type: No type

Projects: No projects

Milestone: No milestone

Relationships: None yet

Development: No branches or pull requests