
Add support to saving.py for loading GPU-trained models on CPU-only machines #19024


Open · wants to merge 11 commits into master
src/lightning/pytorch/CHANGELOG.md (3 additions, 0 deletions)
@@ -232,6 +232,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed exporting `__version__` in `__init__` ([#19221](https://github.com/Lightning-AI/lightning/pull/19221))


- Fixed an issue preventing the user from loading a GPU-trained model with `model.load_from_checkpoint()` on a CPU-only machine with a CPU-only PyTorch installation ([#19024](https://github.com/Lightning-AI/lightning/pull/19024))


## [2.1.3] - 2023-12-21

### Changed
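The changelog entry describes the user-facing scenario: restoring a checkpoint written by a GPU run on a machine whose PyTorch build has no CUDA support. A minimal sketch of that scenario is below; `MyLitModel` and `"model.ckpt"` are illustrative placeholders, not part of the PR.

```python
# Sketch of the scenario covered by the changelog entry above.
# MyLitModel and "model.ckpt" are illustrative, not taken from the PR.
import torch
import lightning.pytorch as pl


class MyLitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def forward(self, x):
        return self.layer(x)


# Per the changelog entry, this call should restore a GPU-trained checkpoint
# on a CPU-only machine without requiring an explicit map_location argument.
model = MyLitModel.load_from_checkpoint("model.ckpt")
print(model.device)  # expected: cpu on a CPU-only machine
```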
src/lightning/pytorch/core/saving.py (2 additions, 0 deletions)
@@ -96,6 +96,8 @@ def _load_from_checkpoint(

         device = next((t for t in state_dict.values() if isinstance(t, torch.Tensor)), torch.tensor(0)).device
         assert isinstance(model, pl.LightningModule)
+        if device.type == "cpu" and model.device.type == "cpu":
+            return model
         return model.to(device)

     raise NotImplementedError(f"Unsupported {cls}")
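This hunk infers the target device from the checkpoint's state dict and now returns early when both the checkpoint tensors and the model are already on CPU, so no `.to()` call is issued on a CPU-only build. A standalone sketch of the same logic, using a plain `nn.Module` and an illustrative helper name rather than Lightning's actual internals:

```python
# Standalone sketch of the device-inference logic above.
# restore_on_inferred_device is an illustrative helper, not Lightning API.
import torch
from torch import nn


def restore_on_inferred_device(model: nn.Module, state_dict: dict) -> nn.Module:
    # Infer the device from the first tensor in the state dict, defaulting to CPU.
    device = next(
        (t for t in state_dict.values() if isinstance(t, torch.Tensor)),
        torch.tensor(0),
    ).device
    model_device = next(model.parameters()).device
    # New behavior sketched from the diff: skip .to() entirely when both sides
    # are already on CPU, so a CPU-only PyTorch build never needs a device move.
    if device.type == "cpu" and model_device.type == "cpu":
        return model
    return model.to(device)


model = nn.Linear(4, 1)
restored = restore_on_inferred_device(model, model.state_dict())
print(next(restored.parameters()).device)  # cpu
```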
src/lightning/pytorch/strategies/single_device.py (2 additions, 1 deletion)
@@ -76,7 +76,8 @@ def root_device(self) -> torch.device:
     @override
     def model_to_device(self) -> None:
         assert self.model is not None, "self.model must be set before self.model.to()"
-        self.model.to(self.root_device)
+        if self.model.device.type != self.root_device.type:
+            self.model.to(self.root_device)

     @property
     @override
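The guard above compares device *types* ("cpu", "cuda", ...), so a model already on the target device type is left untouched; a module on the same type but a different index would also not be moved. A rough sketch of the same pattern outside Lightning, with an illustrative helper name:

```python
# Sketch of the "move only if the device type differs" guard used above.
# move_if_needed is an illustrative helper, not part of Lightning's API.
import torch
from torch import nn


def move_if_needed(module: nn.Module, root_device: torch.device) -> nn.Module:
    current = next(module.parameters()).device
    # Comparing .type avoids a redundant .to() call, which matters on CPU-only
    # builds where no CUDA device handling is available.
    if current.type != root_device.type:
        module = module.to(root_device)
    return module


module = nn.Linear(8, 2)                               # starts on CPU
module = move_if_needed(module, torch.device("cpu"))   # no-op
print(next(module.parameters()).device)                # cpu
```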
src/lightning/pytorch/strategies/strategy.py (2 additions, 1 deletion)
@@ -533,7 +533,8 @@ def teardown(self) -> None:

         if self.lightning_module is not None:
             log.debug(f"{self.__class__.__name__}: moving model to CPU")
-            self.lightning_module.cpu()
+            if self.lightning_module.device.type != "cpu":
+                self.lightning_module.cpu()
         self.precision_plugin.teardown()
         assert self.accelerator is not None
         self.accelerator.teardown()
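The teardown change follows the same pattern: skip the `.cpu()` move when the module already reports a CPU device. A rough sketch of that guard using an illustrative stand-in class, not Lightning's `Strategy`:

```python
# Rough sketch of the guarded teardown move; ToyStrategy and its attributes
# are illustrative stand-ins, not Lightning's Strategy class.
import logging
from typing import Optional

from torch import nn

log = logging.getLogger(__name__)


class ToyStrategy:
    def __init__(self, module: Optional[nn.Module]):
        self.module = module

    def teardown(self) -> None:
        if self.module is not None:
            log.debug(f"{self.__class__.__name__}: moving model to CPU")
            device = next(self.module.parameters()).device
            # Only move when the module is not already on CPU, mirroring the
            # guard added in the diff above.
            if device.type != "cpu":
                self.module.cpu()


ToyStrategy(nn.Linear(2, 2)).teardown()  # no-op on a CPU-only machine
```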