Bug description
Usually, I would expect `Fabric.load(ckpt_path, state)` to operate in place, i.e. to call the `load_state_dict` methods of the loadable objects in `state`. This does indeed happen when `state` is a flat dictionary. However, when `state` is nested (e.g. because one handles multiple models with corresponding optimizers), this changes: the modules in the `state` dictionary will reference different objects after the call to `load`. I have provided a minimal reproducible example below.
To me this is clearly unexpected behavior, and it will lead to problems when references to the old modules are used while resuming training (e.g. imagine that in the example below we use `model` during training rather than `state["generator"]["model"]`). Currently, one needs to explicitly update the old references, which is cumbersome and error-prone, especially for users who are not aware of this behavior.
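To illustrate the aliasing hazard independently of Lightning, here is a minimal plain-Python sketch (my own mock `Module` class, not Fabric code): an in-place load via `load_state_dict` preserves object identity, while a load that rebuilds nested entries leaves previously held references stale.

```python
class Module:
    """Tiny stand-in for a model with an in-place load_state_dict."""
    def __init__(self, weight):
        self.weight = weight

    def load_state_dict(self, sd):
        # Mutates the existing object, so all references stay valid.
        self.weight = sd["weight"]


def load_in_place(state, ckpt):
    # Flat-dict style: call load_state_dict on the existing object.
    state["model"].load_state_dict(ckpt["model"])


def load_by_replacement(state, ckpt):
    # Nested style (the reported behavior): the entry is rebuilt,
    # so the dict now holds a different object.
    state["generator"]["model"] = Module(ckpt["generator"]["model"]["weight"])


model = Module(0)
flat = {"model": model}
load_in_place(flat, {"model": {"weight": 1}})
assert model is flat["model"]  # identity preserved

model2 = Module(0)
nested = {"generator": {"model": model2}}
load_by_replacement(nested, {"generator": {"model": {"weight": 1}}})
assert model2 is not nested["generator"]["model"]  # stale reference
```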
What version are you seeing the problem on?
v2.3
How to reproduce the bug
```python
from lightning import Fabric
from torch import nn

if __name__ == "__main__":
    fabric = Fabric(accelerator="cpu")
    model = nn.Linear(2, 2)
    model = fabric.setup(model)

    state_flat = {"model": model}
    state_nested = {"generator": {"model": model, "optim": None}}

    fabric.save("flat.pt", state_flat)
    fabric.save("nested.pt", state_nested)

    fabric.load("flat.pt", state_flat, strict=True)
    assert model is state_flat["model"]  # This is fine

    fabric.load("nested.pt", state_nested, strict=True)
    assert model is state_nested["generator"]["model"]  # This will fail
```
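Until this is fixed, the only recourse seems to be re-binding local names after every load. A small hypothetical helper (`rebind_from_state` is my own name, not part of the Fabric API) makes that explicit; the replacement is simulated with plain objects here so the sketch runs without Lightning.

```python
def rebind_from_state(state, *path):
    """Return the object currently stored at a nested path in the state dict.

    Useful after fabric.load(...), since nested entries may point to
    different objects than the ones originally passed in.
    """
    obj = state
    for key in path:
        obj = obj[key]
    return obj


# Simulate the reported behavior: "loading" replaces the nested entry,
# so the old local reference goes stale.
old = object()
state = {"generator": {"model": old}}
state["generator"]["model"] = object()  # stands in for fabric.load(...)

model = rebind_from_state(state, "generator", "model")
assert model is state["generator"]["model"]  # references agree again
assert model is not old
```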
Error messages and logs
No response
Environment
Current environment
- CUDA:
- GPU:
- Quadro RTX 5000
- Quadro RTX 5000
- available: True
- version: 12.1
- Lightning:
- graph-transformer-pytorch: 0.1.1
- lightning: 2.3.0
- lightning-utilities: 0.11.2
- pytorch-lightning: 2.2.2
- rotary-embedding-torch: 0.6.2
- torch: 2.3.0
- torch-geometric: 2.5.3
- torchmetrics: 1.4.0.post0
- Packages:
- aiohttp: 3.9.5
- aiosignal: 1.3.1
- annotated-types: 0.7.0
- antlr4-python3-runtime: 4.9.3
- async-timeout: 4.0.3
- attrs: 23.2.0
- autoregressive-graph-generation: 0.0.0
- beartype: 0.18.5
- biopandas: 0.4.1
- brotli: 1.1.0
- certifi: 2024.6.2
- cffi: 1.16.0
- cfgv: 3.4.0
- charset-normalizer: 3.3.2
- click: 8.1.7
- colorama: 0.4.6
- contourpy: 1.2.1
- cycler: 0.12.1
- deepspeed: 0.14.4
- distlib: 0.3.8
- docker-pycreds: 0.4.0
- docopt: 0.6.2
- einops: 0.8.0
- et-xmlfile: 1.1.0
- fastavro: 1.9.5
- filelock: 3.14.0
- fonttools: 4.53.0
- freesasa: 2.2.1
- frozenlist: 1.4.1
- fsspec: 2024.5.0
- ftpretty: 0.4.0
- gitdb: 4.0.11
- gitpython: 3.1.43
- goatools: 1.4.12
- graph-transformer-pytorch: 0.1.1
- gudhi: 3.9.0
- h5py: 3.11.0
- heapdict: 1.0.1
- hjson: 3.1.0
- hydra-core: 1.3.2
- identify: 2.5.36
- idna: 3.7
- imageio: 2.34.1
- jinja2: 3.1.4
- joblib: 1.4.2
- kiwisolver: 1.4.5
- lightning: 2.3.0
- lightning-utilities: 0.11.2
- loguru: 0.7.2
- markdown-it-py: 3.0.0
- markupsafe: 2.1.5
- matplotlib: 3.8.4
- mdurl: 0.1.2
- mpmath: 1.3.0
- multidict: 6.0.5
- munkres: 1.1.4
- networkx: 3.3
- ninja: 1.11.1.1
- nodeenv: 1.9.0
- numpy: 1.26.4
- nvidia-cublas-cu12: 12.1.3.1
- nvidia-cuda-cupti-cu12: 12.1.105
- nvidia-cuda-nvrtc-cu12: 12.1.105
- nvidia-cuda-runtime-cu12: 12.1.105
- nvidia-cudnn-cu12: 8.9.2.26
- nvidia-cufft-cu12: 11.0.2.54
- nvidia-curand-cu12: 10.3.2.106
- nvidia-cusolver-cu12: 11.4.5.107
- nvidia-cusparse-cu12: 12.1.0.106
- nvidia-ml-py: 12.555.43
- nvidia-nccl-cu12: 2.20.5
- nvidia-nvjitlink-cu12: 12.5.40
- nvidia-nvtx-cu12: 12.1.105
- omegaconf: 2.3.0
- openpyxl: 3.1.5
- packaging: 24.0
- pandas: 2.2.2
- patsy: 0.5.6
- pillow: 10.3.0
- pip: 24.0
- platformdirs: 4.2.2
- pot: 0.9.3
- pre-commit: 3.7.1
- proteinshake: 0.3.14
- protobuf: 4.25.3
- psutil: 5.9.8
- py-cpuinfo: 9.0.0
- pycairo: 1.25.0
- pycparser: 2.22
- pydantic: 2.7.4
- pydantic-core: 2.18.4
- pydot: 3.0.1
- pyemd: 1.0.0
- pygments: 2.18.0
- pygobject: 3.46.0
- pygsp: 0.5.1
- pyparsing: 3.1.2
- pysocks: 1.7.1
- python-dateutil: 2.9.0
- pytorch-lightning: 2.2.2
- pytz: 2024.1
- pyyaml: 6.0.1
- rdkit: 2024.3.3
- rdkit-pypi: 2022.9.5
- requests: 2.32.3
- rich: 13.7.1
- rotary-embedding-torch: 0.6.2
- rustworkx: 0.14.2
- scikit-learn: 1.5.0
- scipy: 1.13.1
- sentry-sdk: 2.3.1
- setproctitle: 1.3.3
- setuptools: 69.5.1
- six: 1.16.0
- smmap: 5.0.1
- statsmodels: 0.14.2
- sympy: 1.12.1
- threadpoolctl: 3.5.0
- torch: 2.3.0
- torch-geometric: 2.5.3
- torchmetrics: 1.4.0.post0
- tqdm: 4.66.4
- triton: 2.3.0
- typing-extensions: 4.12.1
- tzdata: 2024.1
- unicodedata2: 15.1.0
- urllib3: 2.2.1
- virtualenv: 20.26.2
- wandb: 0.17.0
- wheel: 0.43.0
- xlsxwriter: 3.2.0
- yarl: 1.9.4
- zstandard: 0.22.0
- System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.10.14
- release: 5.14.21-150400.24.97-default
- version: #1 SMP PREEMPT_DYNAMIC Fri Oct 27 10:29:06 UTC 2023 (8546fda)
More info
No response