Skip to content

Upgrade to PyTorch 2.3. #546

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Mar 27, 2024
Merged

Upgrade to PyTorch 2.3. #546

merged 18 commits into from
Mar 27, 2024

Conversation

stellaraccident
Copy link
Contributor

@stellaraccident stellaraccident commented Mar 22, 2024

As discussed on Discord, this is a significant upgrade because it is the first stable release that has a fully functional torch.export.export with the preferred dynamic shapes support. It is also just prior to nightlies that completely remove support for the old constraints based API, so is therefore a good point to stop for a moment and support both styles.

This patch makes a number of API changes:

  • Issues deprecation warnings if the constraints= keyword for jittable is used, otherwise not passing it to PyTorch. This should make jittable not immediately incompatible with later nightlies unless if that feature is used.
  • Adds the ability for a CompiledModule to directly have an attribute of a torch.export.ExportedProgram, allowing the user to pre-export with Torch and then construct a compiled module from that (vs the jittable approach where the CompiledModule API was directly invoking Torch internals to do so). This defaults to exporting as public if given a name not starting with an underscore and private otherwise. Private ExportedPrograms can be called from procedures just as with jittable.
  • shark_turbine.aot.export() now accepts either an CompiledModule, nn.Module, a or a torch.export.ExportedProgram. For the last two, a new external_params= bool is available to control whether parameters are inlined or externalized. For an nn.Module arguments corresponding to torch.export.export are added. Internally, for an nn.Module, it simply calls torch.export.export. jittable is no longer used internally.

Some attempt has been made to be backwards compatible with Torch 2.1.0. New features will not work, but we should be able to support a short buffer window where older pinned systems are not completely broken. The repository prior to this patch will be branched to torch_2.1.

Breaking changes:

  • ops.iree.trace_tensors (plural) had to be removed because the PyTorch auto functionalization thing has a TODO around lists of tensors. We can add a wrapper that takes a list and invokves trace_tensors multiple times and/or ass a functional_trace_tensors which works a bit better with the infra.
  • stateless_llama_test.py::test_rerotated_torch_comparison marked as expectedFailure. Filed stateless_llama test_rerotated_torch_comparison test broken with PyTorch 2.3 #560

@stellaraccident stellaraccident changed the title Pytorch 2.3 prep Upgrade to PyTorch 2.3. Mar 26, 2024
@stellaraccident stellaraccident marked this pull request as ready for review March 26, 2024 01:50
Copy link
Member

@dan-garvey dan-garvey left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some small nits but wow this is amazing. 👏 👏

@stellaraccident stellaraccident merged commit b73c5c3 into main Mar 27, 2024
@stellaraccident stellaraccident deleted the pytorch_2.3_prep branch March 27, 2024 18:18
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants