MPS Mixed-precision Autocast #20497

Open
@laclouis5

Description & Motivation

Support for MPS autocasting has recently been added in PyTorch 2.5.0 here, and there is an ongoing effort to implement gradient scaling here.

PyTorch Lightning does not currently support mixed precision on MPS devices, but support could be added in the near future once gradient scaling is finalized.

Is this feature being considered? It would reduce memory usage and improve training time for some models.
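For reference, here is a minimal sketch of what the plain PyTorch side looks like, assuming PyTorch 2.5.0 or later for the MPS path. The helper names `pick_device_and_dtype` and `autocast_forward` are made up for illustration, the choice of `float16` on MPS is my assumption about what the autocast support covers, and the CPU branch is only there so the snippet runs on machines without an Apple GPU. It covers the forward pass only, since gradient scaling on MPS is still in progress.

```python
import torch
from torch import nn


def pick_device_and_dtype():
    """Return (device_type, autocast dtype) for the current machine."""
    if torch.backends.mps.is_available():
        # Assumption: float16 is the autocast dtype to use on MPS (PyTorch >= 2.5).
        return "mps", torch.float16
    # CPU fallback so the sketch still runs without an Apple GPU.
    return "cpu", torch.bfloat16


def autocast_forward(batch_size=4, in_features=8, out_features=2):
    """Run one forward pass of a small linear layer under autocast."""
    device_type, dtype = pick_device_and_dtype()
    model = nn.Linear(in_features, out_features).to(device_type)
    x = torch.randn(batch_size, in_features, device=device_type)
    # Parameters stay in FP32; eligible ops run in the lower-precision dtype.
    with torch.autocast(device_type=device_type, dtype=dtype):
        out = model(x)
    return out, dtype


if __name__ == "__main__":
    out, dtype = autocast_forward()
    print(out.dtype, dtype)
```

A Lightning code path for MPS would presumably wrap the forward pass in a similar context manager, but that is a guess on my side and not something I have checked against the precision plugins.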

Pitch

Currently, PyTorch Lightning falls back to FP32 when mixed precision is requested on an MPS device and issues a warning that mentions CUDA.

I think adding a code path for MPS mixed precision would be great.

Alternatives

Stick to FP32 training when using an MPS device.

Additional context

Thanks for your work!

cc @Borda

Labels: feature (Is an improvement or enhancement)