Description & Motivation
Support for MPS autocasting was recently added in PyTorch 2.5.0 here, and there is an ongoing effort to implement gradient scaling here.
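For reference, a minimal sketch of what plain-PyTorch autocast on MPS looks like since 2.5 (assuming an Apple Silicon machine where the MPS backend is available):

```python
import torch

# Minimal sketch, assuming PyTorch >= 2.5 with the MPS backend available:
# autocast runs eligible ops in float16 while parameters stay in float32.
device = torch.device("mps")
model = torch.nn.Linear(64, 8).to(device)
x = torch.randn(32, 64, device=device)

with torch.autocast(device_type="mps", dtype=torch.float16):
    out = model(x)

print(out.dtype)  # expected: torch.float16 for autocast-eligible ops
```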
PyTorch Lightning does not currently support mixed precision on the MPS device, but it could be added in the near future once gradient scaling is finalized.
Is this feature being considered? It would reduce memory usage and improve training time for some models.
Pitch
Currently, PyTorch Lightning falls back to FP32 when mixed precision is requested on MPS and issues a warning that mentions CUDA.
I think adding a dedicated path for MPS mixed precision would be great; a rough sketch of the desired usage is below.
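This is only a sketch of what I would like to work, not existing behavior: today this configuration triggers the CUDA-flavored warning and falls back to FP32.

```python
import lightning as L

# Hypothetical desired behavior: the same precision flag that works on CUDA
# today ("16-mixed") would be honored on the MPS accelerator instead of
# silently falling back to FP32.
trainer = L.Trainer(accelerator="mps", devices=1, precision="16-mixed")
```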
Alternatives
Stick to FP32 training when using the MPS device.
Additional context
Thanks for your work!
cc @Borda