diff --git a/test/test_ops.py b/test/test_ops.py
index 88124f7ba17..dd184ddde2d 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -228,12 +228,12 @@ def func(z):
         ):
             gradcheck(func, (x,))
 
-    @needs_cuda
+    @pytest.mark.parametrize("device", cpu_and_cuda())
     @pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
     @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half))
-    def test_autocast(self, x_dtype, rois_dtype):
-        with torch.cuda.amp.autocast():
-            self.test_forward(torch.device("cuda"), contiguous=False, x_dtype=x_dtype, rois_dtype=rois_dtype)
+    def test_autocast(self, device, x_dtype, rois_dtype):
+        with torch.amp.autocast(device):
+            self.test_forward(torch.device(device), contiguous=False, x_dtype=x_dtype, rois_dtype=rois_dtype)
 
     def _helper_boxes_shape(self, func):
         # test boxes as Tensor[N, 5]
@@ -490,32 +490,18 @@ def test_forward(self, device, contiguous, deterministic, aligned, x_dtype, rois
             aligned=aligned,
         )
 
-    @needs_cuda
+    @pytest.mark.parametrize("device", cpu_and_cuda())
     @pytest.mark.parametrize("aligned", (True, False))
     @pytest.mark.parametrize("deterministic", (True, False))
-    @pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
-    @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half))
+    @pytest.mark.parametrize("x_dtype", (torch.float, torch.half, torch.bfloat16))
+    @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half, torch.bfloat16))
     @pytest.mark.opcheck_only_one()
-    def test_autocast(self, aligned, deterministic, x_dtype, rois_dtype):
-        with torch.cuda.amp.autocast():
-            self.test_forward(
-                torch.device("cuda"),
-                contiguous=False,
-                deterministic=deterministic,
-                aligned=aligned,
-                x_dtype=x_dtype,
-                rois_dtype=rois_dtype,
-            )
-
-    @pytest.mark.skip(reason="1/5000 flaky failure")
-    @pytest.mark.parametrize("aligned", (True, False))
-    @pytest.mark.parametrize("deterministic", (True, False))
-    @pytest.mark.parametrize("x_dtype", (torch.float, torch.bfloat16))
-    @pytest.mark.parametrize("rois_dtype", (torch.float, torch.bfloat16))
-    def test_autocast_cpu(self, aligned, deterministic, x_dtype, rois_dtype):
-        with torch.cpu.amp.autocast():
+    def test_autocast(self, device, aligned, deterministic, x_dtype, rois_dtype):
+        if device == "cpu" and x_dtype is torch.bfloat16:
+            pytest.skip("1/5000 flaky failure")
+        with torch.amp.autocast(device):
             self.test_forward(
-                torch.device("cpu"),
+                torch.device(device),
                 contiguous=False,
                 deterministic=deterministic,
                 aligned=aligned,
@@ -856,14 +842,14 @@ def test_nms_gpu(self, iou, device, dtype=torch.float64):
     @pytest.mark.parametrize("dtype", (torch.float, torch.half))
     @pytest.mark.opcheck_only_one()
     def test_autocast(self, iou, dtype):
-        with torch.cuda.amp.autocast():
+        with torch.amp.autocast("cuda"):
             self.test_nms_gpu(iou=iou, dtype=dtype, device="cuda")
 
     @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
     @pytest.mark.parametrize("dtype", (torch.float, torch.bfloat16))
     def test_autocast_cpu(self, iou, dtype):
         boxes, scores = self._create_tensors_with_iou(1000, iou)
-        with torch.cpu.amp.autocast():
+        with torch.amp.autocast("cpu"):
             keep_ref_float = ops.nms(boxes.to(dtype).float(), scores.to(dtype).float(), iou)
             keep_dtype = ops.nms(boxes.to(dtype), scores.to(dtype), iou)
         torch.testing.assert_close(keep_ref_float, keep_dtype)
@@ -1188,13 +1174,13 @@ def test_compare_cpu_cuda_grads(self, contiguous):
                 res_grads = init_weight.grad.to("cpu")
                 torch.testing.assert_close(true_cpu_grads, res_grads)
 
-    @needs_cuda
+    @pytest.mark.parametrize("device", cpu_and_cuda())
     @pytest.mark.parametrize("batch_sz", (0, 33))
     @pytest.mark.parametrize("dtype", (torch.float, torch.half))
     @pytest.mark.opcheck_only_one()
-    def test_autocast(self, batch_sz, dtype):
-        with torch.cuda.amp.autocast():
-            self.test_forward(torch.device("cuda"), contiguous=False, batch_sz=batch_sz, dtype=dtype)
+    def test_autocast(self, device, batch_sz, dtype):
+        with torch.amp.autocast(device):
+            self.test_forward(torch.device(device), contiguous=False, batch_sz=batch_sz, dtype=dtype)
 
     def test_forward_scriptability(self):
         # Non-regression test for https://github.com/pytorch/vision/issues/4078
diff --git a/torchvision/csrc/ops/autocast/deform_conv2d_kernel.cpp b/torchvision/csrc/ops/autocast/deform_conv2d_kernel.cpp
index 4f082fa0006..fb7b953cd2d 100644
--- a/torchvision/csrc/ops/autocast/deform_conv2d_kernel.cpp
+++ b/torchvision/csrc/ops/autocast/deform_conv2d_kernel.cpp
@@ -9,6 +9,7 @@ namespace ops {
 
 namespace {
 
+template <c10::DispatchKey autocast_key, c10::DeviceType device_type>
 at::Tensor deform_conv2d_autocast(
     const at::Tensor& input,
     const at::Tensor& weight,
@@ -24,13 +25,13 @@ at::Tensor deform_conv2d_autocast(
     int64_t groups,
     int64_t offset_groups,
     bool use_mask) {
-  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
+  c10::impl::ExcludeDispatchKeyGuard no_autocast(autocast_key);
   return deform_conv2d(
-             at::autocast::cached_cast(at::kFloat, input),
-             at::autocast::cached_cast(at::kFloat, weight),
-             at::autocast::cached_cast(at::kFloat, offset),
-             at::autocast::cached_cast(at::kFloat, mask),
-             at::autocast::cached_cast(at::kFloat, bias),
+             at::autocast::cached_cast(at::kFloat, input, device_type),
+             at::autocast::cached_cast(at::kFloat, weight, device_type),
+             at::autocast::cached_cast(at::kFloat, offset, device_type),
+             at::autocast::cached_cast(at::kFloat, mask, device_type),
+             at::autocast::cached_cast(at::kFloat, bias, device_type),
              stride_h,
              stride_w,
              pad_h,
@@ -48,7 +49,25 @@ at::Tensor deform_conv2d_autocast(
 TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
   m.impl(
       TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"),
-      TORCH_FN(deform_conv2d_autocast));
+      TORCH_FN((deform_conv2d_autocast<
+                c10::DispatchKey::Autocast,
+                c10::DeviceType::CUDA>)));
+}
+
+TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"),
+      TORCH_FN((deform_conv2d_autocast<
+                c10::DispatchKey::AutocastCPU,
+                c10::DeviceType::CPU>)));
+}
+
+TORCH_LIBRARY_IMPL(torchvision, AutocastXPU, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"),
+      TORCH_FN((deform_conv2d_autocast<
+                c10::DispatchKey::AutocastXPU,
+                c10::DeviceType::XPU>)));
 }
 
 } // namespace ops
diff --git a/torchvision/csrc/ops/autocast/ps_roi_align_kernel.cpp b/torchvision/csrc/ops/autocast/ps_roi_align_kernel.cpp
index bce987b0f71..d6cd0c471d1 100644
--- a/torchvision/csrc/ops/autocast/ps_roi_align_kernel.cpp
+++ b/torchvision/csrc/ops/autocast/ps_roi_align_kernel.cpp
@@ -9,6 +9,7 @@ namespace ops {
 
 namespace {
 
+template <c10::DispatchKey autocast_key, c10::DeviceType device_type>
 std::tuple<at::Tensor, at::Tensor> ps_roi_align_autocast(
     const at::Tensor& input,
     const at::Tensor& rois,
@@ -16,10 +17,10 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_align_autocast(
     int64_t pooled_height,
     int64_t pooled_width,
     int64_t sampling_ratio) {
-  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
+  c10::impl::ExcludeDispatchKeyGuard no_autocast(autocast_key);
   auto result = ps_roi_align(
-      at::autocast::cached_cast(at::kFloat, input),
-      at::autocast::cached_cast(at::kFloat, rois),
+      at::autocast::cached_cast(at::kFloat, input, device_type),
+      at::autocast::cached_cast(at::kFloat, rois, device_type),
       spatial_scale,
       pooled_height,
       pooled_width,
@@ -35,7 +36,25 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_align_autocast(
 TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
   m.impl(
       TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"),
-      TORCH_FN(ps_roi_align_autocast));
+      TORCH_FN((ps_roi_align_autocast<
+               c10::DispatchKey::Autocast,
+               c10::DeviceType::CUDA>)));
+}
+
+TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"),
+      TORCH_FN((ps_roi_align_autocast<
+               c10::DispatchKey::AutocastCPU,
+               c10::DeviceType::CPU>)));
+}
+
+TORCH_LIBRARY_IMPL(torchvision, AutocastXPU, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"),
+      TORCH_FN((ps_roi_align_autocast<
+               c10::DispatchKey::AutocastXPU,
+               c10::DeviceType::XPU>)));
 }
 
 } // namespace ops
diff --git a/torchvision/csrc/ops/autocast/ps_roi_pool_kernel.cpp b/torchvision/csrc/ops/autocast/ps_roi_pool_kernel.cpp
index 3cf1e7f80d7..a623c42312b 100644
--- a/torchvision/csrc/ops/autocast/ps_roi_pool_kernel.cpp
+++ b/torchvision/csrc/ops/autocast/ps_roi_pool_kernel.cpp
@@ -9,16 +9,17 @@ namespace ops {
 
 namespace {
 
+template <c10::DispatchKey autocast_key, c10::DeviceType device_type>
 std::tuple<at::Tensor, at::Tensor> ps_roi_pool_autocast(
     const at::Tensor& input,
     const at::Tensor& rois,
     double spatial_scale,
     int64_t pooled_height,
     int64_t pooled_width) {
-  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
+  c10::impl::ExcludeDispatchKeyGuard no_autocast(autocast_key);
   auto result = ps_roi_pool(
-      at::autocast::cached_cast(at::kFloat, input),
-      at::autocast::cached_cast(at::kFloat, rois),
+      at::autocast::cached_cast(at::kFloat, input, device_type),
+      at::autocast::cached_cast(at::kFloat, rois, device_type),
       spatial_scale,
       pooled_height,
       pooled_width);
@@ -33,7 +34,25 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_pool_autocast(
 TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
   m.impl(
       TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"),
-      TORCH_FN(ps_roi_pool_autocast));
+      TORCH_FN((ps_roi_pool_autocast<
+                c10::DispatchKey::Autocast,
+                c10::DeviceType::CUDA>)));
+}
+
+TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"),
+      TORCH_FN((ps_roi_pool_autocast<
+                c10::DispatchKey::AutocastCPU,
+                c10::DeviceType::CPU>)));
+}
+
+TORCH_LIBRARY_IMPL(torchvision, AutocastXPU, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"),
+      TORCH_FN((ps_roi_pool_autocast<
+                c10::DispatchKey::AutocastXPU,
+                c10::DeviceType::XPU>)));
 }
 
 } // namespace ops
diff --git a/torchvision/csrc/ops/autocast/roi_pool_kernel.cpp b/torchvision/csrc/ops/autocast/roi_pool_kernel.cpp
index 3aaa038a9b4..936ce1dc5f5 100644
--- a/torchvision/csrc/ops/autocast/roi_pool_kernel.cpp
+++ b/torchvision/csrc/ops/autocast/roi_pool_kernel.cpp
@@ -9,16 +9,17 @@ namespace ops {
 
 namespace {
 
+template <c10::DispatchKey autocast_key, c10::DeviceType device_type>
 std::tuple<at::Tensor, at::Tensor> roi_pool_autocast(
     const at::Tensor& input,
     const at::Tensor& rois,
     double spatial_scale,
     int64_t pooled_height,
     int64_t pooled_width) {
-  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
+  c10::impl::ExcludeDispatchKeyGuard no_autocast(autocast_key);
   auto result = roi_pool(
-      at::autocast::cached_cast(at::kFloat, input),
-      at::autocast::cached_cast(at::kFloat, rois),
+      at::autocast::cached_cast(at::kFloat, input, device_type),
+      at::autocast::cached_cast(at::kFloat, rois, device_type),
       spatial_scale,
       pooled_height,
       pooled_width);
@@ -33,7 +34,25 @@ std::tuple<at::Tensor, at::Tensor> roi_pool_autocast(
 TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
   m.impl(
       TORCH_SELECTIVE_NAME("torchvision::roi_pool"),
-      TORCH_FN(roi_pool_autocast));
+      TORCH_FN((roi_pool_autocast<
+                c10::DispatchKey::Autocast,
+                c10::DeviceType::CUDA>)));
+}
+
+TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::roi_pool"),
+      TORCH_FN((roi_pool_autocast<
+                c10::DispatchKey::AutocastCPU,
+                c10::DeviceType::CPU>)));
+}
+
+TORCH_LIBRARY_IMPL(torchvision, AutocastXPU, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::roi_pool"),
+      TORCH_FN((roi_pool_autocast<
+                c10::DispatchKey::AutocastXPU,
+                c10::DeviceType::XPU>)));
 }
 
 } // namespace ops