diff --git a/test/test_ops.py b/test/test_ops.py
index 3f0d8312c01..c7c415e0ab3 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -929,6 +929,7 @@ def test_batched_nms_implementations(self, seed):
 
 class TestDeformConv:
     dtype = torch.float64
+    mps_dtype = torch.float32
 
     def expected_fn(self, x, weight, offset, mask, bias, stride=1, padding=0, dilation=1):
         stride_h, stride_w = _pair(stride)
@@ -1050,12 +1051,11 @@ def test_is_leaf_node(self, device):
         assert len(graph_node_names[0]) == len(graph_node_names[1])
         assert len(graph_node_names[0]) == 1 + op_obj.n_inputs
 
-    @pytest.mark.parametrize("device", cpu_and_cuda())
+    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
     @pytest.mark.parametrize("contiguous", (True, False))
     @pytest.mark.parametrize("batch_sz", (0, 33))
-    @pytest.mark.opcheck_only_one()
     def test_forward(self, device, contiguous, batch_sz, dtype=None):
-        dtype = dtype or self.dtype
+        dtype = self.mps_dtype if device == "mps" else dtype or self.dtype
         x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype)
         in_channels = 6
         out_channels = 2
@@ -1201,13 +1201,50 @@ def test_forward_scriptability(self):
         torch.jit.script(ops.DeformConv2d(in_channels=8, out_channels=8, kernel_size=3))
 
 
-optests.generate_opcheck_tests(
-    testcase=TestDeformConv,
-    namespaces=["torchvision"],
-    failures_dict_path=os.path.join(os.path.dirname(__file__), "optests_failures_dict.json"),
-    additional_decorators=[],
-    test_utils=OPTESTS,
-)
+@pytest.mark.parametrize("dtype", (torch.float16, torch.float32, torch.float64))
+@pytest.mark.parametrize("device", cpu_and_cuda())
+@pytest.mark.parametrize("requires_grad", (True, False))
+def test_deform_conv2d_opcheck(dtype, device, requires_grad):
+    batch_size, channels_in, height, width = 1, 6, 10, 10
+    kernel_size = (3, 3)
+    stride = (1, 1)
+    padding = (1, 1)
+    dilation = (1, 1)
+    groups = 2
+    out_channels = 4
+    out_h = (height + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) // stride[0] + 1
+    out_w = (width + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[1] + 1
+    x = torch.randn(batch_size, channels_in, height, width, dtype=dtype, device=device, requires_grad=requires_grad)
+    offset = torch.randn(batch_size, 2 * kernel_size[0] * kernel_size[1], out_h, out_w,
+                         dtype=dtype, device=device, requires_grad=requires_grad)
+    weight = torch.randn(out_channels, channels_in // groups, kernel_size[0], kernel_size[1],
+                         dtype=dtype, device=device, requires_grad=requires_grad)
+    bias = torch.randn(out_channels, dtype=dtype, device=device, requires_grad=requires_grad)
+    use_mask = True
+    mask = torch.sigmoid(torch.randn(
+        batch_size,
+        kernel_size[0] * kernel_size[1],
+        out_h,
+        out_w,
+        dtype=dtype, device=device, requires_grad=requires_grad
+    ))
+    kwargs = {
+        "offset": offset,
+        "weight": weight,
+        "bias": bias,
+        "stride_h": stride[0],
+        "stride_w": stride[1],
+        "pad_h": padding[0],
+        "pad_w": padding[1],
+        "dilation_h": dilation[0],
+        "dilation_w": dilation[1],
+        "groups": groups,
+        "offset_groups": 1,
+        "use_mask": use_mask,
+        "mask": mask,  # no modulation in this test
+    }
+    optests.opcheck(torch.ops.torchvision.deform_conv2d, args=(x,), kwargs=kwargs)
+
 
 
 class TestFrozenBNT:
diff --git a/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm b/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm
new file mode 100644
index 00000000000..1d390a37f43
--- /dev/null
+++ b/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm
@@ -0,0 +1,134 @@
+#include <ATen/ATen.h>
+#include <ATen/mps/MPSProfiler.h>
+#include <ATen/native/mps/OperationUtils.h>
+#include "mps_kernels.h"
+
+namespace vision {
+namespace ops {
+
+namespace {
+
+at::Tensor deform_conv2d_forward_kernel(
+    const at::Tensor& input,
+    const at::Tensor& weight,
+    const at::Tensor& offset,
+    const at::Tensor& mask,
+    const at::Tensor& bias,
+    int64_t stride_h,
+    int64_t stride_w,
+    int64_t pad_h,
+    int64_t pad_w,
+    int64_t dilation_h,
+    int64_t dilation_w,
+    int64_t n_weight_grps,
+    int64_t n_offset_grps,
+    bool use_mask) {
+  using namespace at::native::mps;
+  at::Tensor input_c = input.contiguous();
+  at::Tensor weight_c = weight.contiguous();
+  at::Tensor offset_c = offset.contiguous();
+  at::Tensor mask_c = mask.contiguous();
+  at::Tensor bias_c = bias.contiguous();
+
+  TORCH_CHECK(input_c.ndimension() == 4, "Input tensor must be 4D");
+  TORCH_CHECK(weight_c.ndimension() == 4, "Weight tensor must be 4D");
+  TORCH_CHECK(offset_c.ndimension() == 4, "Offset tensor must be 4D");
+  TORCH_CHECK(!use_mask || mask_c.ndimension() == 4, "Mask tensor must be 4D if use_mask is true");
+  TORCH_CHECK(input_c.is_mps(), "input must be a MPS tensor");
+
+  at::DeviceGuard guard(input_c.device());
+
+  int batch = input_c.size(0);
+  int in_channels = input_c.size(1);
+  int in_h = input_c.size(2);
+  int in_w = input_c.size(3);
+  int weight_h = weight_c.size(2);
+  int weight_w = weight_c.size(3);
+  int out_channels = weight_c.size(0);
+  int ker_h = dilation_h * (weight_h - 1) + 1;
+  int ker_w = dilation_w * (weight_w - 1) + 1;
+  int out_h = ((in_h + 2 * pad_h - ker_h) / stride_h) + 1;
+  int out_w = ((in_w + 2 * pad_w - ker_w) / stride_w) + 1;
+
+  TORCH_CHECK(weight_c.size(1) * n_weight_grps == in_channels,
+    "Input channels (", in_channels, 
+    ") must equal weight.size(1) * n_weight_grps (", weight_c.size(1), " * ", n_weight_grps, ")");
+  TORCH_CHECK(weight_c.size(0) % n_weight_grps == 0,
+    "Weight tensor's out channels (", weight_c.size(0), 
+    ") must be divisible by n_weight_grps (", n_weight_grps, ")");
+  TORCH_CHECK(offset_c.size(1) == n_offset_grps * 2 * weight_h * weight_w,
+    "Offset tensor shape[1] is invalid: got ", offset_c.size(1), 
+    ", expected ", n_offset_grps * 2 * weight_h * weight_w);
+  TORCH_CHECK(!use_mask || mask_c.size(1) == n_offset_grps * weight_h * weight_w,
+    "Mask tensor shape[1] is invalid: got ", mask_c.size(1), 
+    ", expected ", n_offset_grps * weight_h * weight_w);
+  TORCH_CHECK(in_channels % n_offset_grps == 0,
+    "Input tensor channels (", in_channels, 
+    ") must be divisible by n_offset_grps (", n_offset_grps, ")");
+  TORCH_CHECK(offset_c.size(0) == batch,
+    "Offset tensor batch size (", offset_c.size(0),
+    ") must match input tensor batch size (", batch, ")");
+  TORCH_CHECK(offset_c.size(2) == out_h && offset_c.size(3) == out_w,
+    "Offset tensor spatial dimensions (", offset_c.size(2), ", ", offset_c.size(3), 
+    ") must match calculated output dimensions (", out_h, ", ", out_w, ")");
+  TORCH_CHECK(!use_mask || mask_c.size(0) == batch,
+    "Mask tensor batch size (", mask_c.size(0),
+    ") must match input tensor batch size (", batch, ")");
+  TORCH_CHECK(!use_mask || (mask_c.size(2) == out_h && mask_c.size(3) == out_w),
+    "Mask tensor spatial dimensions (", mask_c.size(2), ", ", mask_c.size(3),
+    ") must match calculated output dimensions (", out_h, ", ", out_w, ")");
+  TORCH_CHECK(out_h > 0 && out_w > 0,
+    "Calculated output size too small - out_h: ", out_h, " out_w: ", out_w);
+
+  auto columns = at::empty({in_channels * weight_h * weight_w, batch * out_h * out_w}, input_c.options());
+
+  id<MTLBuffer> inputBuffer  = getMTLBufferStorage(input_c);
+  id<MTLBuffer> offsetBuffer = getMTLBufferStorage(offset_c);
+  id<MTLBuffer> maskBuffer   = use_mask ? getMTLBufferStorage(mask_c) : nil;
+  id<MTLBuffer> outputBuffer = getMTLBufferStorage(columns);
+
+  id<MTLDevice> device = MPSDevice::getInstance()->device();
+  std::string kernelName = "deformable_im2col_" + scalarToMetalTypeString(input.scalar_type());
+  id<MTLComputePipelineState> pipelineState = mps::visionPipelineState(device, kernelName);
+
+  int num_kernels = in_channels * out_h * out_w * batch;
+  NSUInteger threadsPerThreadgroup = pipelineState.maxTotalThreadsPerThreadgroup;
+  NSUInteger threadgroups = (num_kernels + threadsPerThreadgroup - 1) / threadsPerThreadgroup;
+  MTLSize threadGroupSize = MTLSizeMake(threadsPerThreadgroup, 1, 1);
+  MTLSize threadgroupsPerGrid = MTLSizeMake(threadgroups, 1, 1);
+
+  MPSStream* mpsStream = getCurrentMPSStream();
+  dispatch_sync(mpsStream->queue(), ^{
+    @autoreleasepool {
+      id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
+      [computeEncoder setComputePipelineState:pipelineState];
+      at::native::mps::mtl_setArgs(computeEncoder, inputBuffer, offsetBuffer, maskBuffer,
+                                   in_h, in_w, weight_h, weight_w, pad_h, pad_w, stride_h, stride_w, 
+                                   dilation_h, dilation_w, batch, in_channels, n_offset_grps, out_h, out_w,
+                                   use_mask, outputBuffer);
+      [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
+    }
+  });
+  int in_channels_per_grp = in_channels / n_weight_grps;
+  int out_channels_per_grp = out_channels / n_weight_grps;
+  auto weight_grouped = weight_c.view({n_weight_grps, out_channels_per_grp, in_channels_per_grp, weight_h, weight_w});
+  auto columns_grouped = columns.view({n_weight_grps,
+                                      (in_channels * weight_h * weight_w) / n_weight_grps,
+                                      batch * out_h * out_w});
+  auto weight_reshaped = weight_grouped.reshape({n_weight_grps, out_channels_per_grp, -1});
+  auto out_grouped = at::bmm(weight_reshaped, columns_grouped);
+  auto out = out_grouped.reshape({n_weight_grps * out_channels_per_grp, batch, out_h, out_w})
+              .transpose(0, 1);
+  return out + bias_c.view({1, out_channels, 1, 1});
+}
+
+} // namespace
+
+TORCH_LIBRARY_IMPL(torchvision, MPS, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"),
+      TORCH_FN(deform_conv2d_forward_kernel));
+}
+
+} // namespace ops
+} // namespace vision
\ No newline at end of file
diff --git a/torchvision/csrc/ops/mps/mps_kernels.h b/torchvision/csrc/ops/mps/mps_kernels.h
index f85546a6c41..2f24c86c6bf 100644
--- a/torchvision/csrc/ops/mps/mps_kernels.h
+++ b/torchvision/csrc/ops/mps/mps_kernels.h
@@ -91,6 +91,52 @@ inline T bilinear_interpolate(
   return val;
 }
 
+template <typename T, typename integer_t>
+inline T bilinear_interpolate_deformable_conv2d(
+    constant T* input,
+    integer_t height,
+    integer_t width,
+    T y,
+    T x,
+    uint index /* index for debug only*/) {
+  if (y <= -1.0 || y >= height || x <= -1.0 || x >= width) {
+    return 0;
+  }
+  integer_t y_low = static_cast<integer_t>(floor(y));
+  integer_t x_low = static_cast<integer_t>(floor(x));
+  integer_t y_high = y_low + 1;
+  integer_t x_high = x_low + 1;
+
+  T ly = y - static_cast<T>(y_low);
+  T lx = x - static_cast<T>(x_low);
+  T hh = 1.0 - ly;
+  T hw = 1.0 - lx;
+
+  T v1 = 0;
+  if (y_low >= 0 && x_low >= 0)
+    v1 = input[y_low * width + x_low];
+  
+  T v2 = 0;
+  if (y_low >= 0 && x_high <= width - 1)
+    v2 = input[y_low * width + x_high];
+  
+  T v3 = 0;
+  if (y_high <= height - 1 && x_low >= 0)
+    v3 = input[y_high * width + x_low];
+  
+  T v4 = 0;
+  if (y_high <= height - 1 && x_high <= width - 1)
+    v4 = input[y_high * width + x_high];
+
+  T w1 = hh * hw;
+  T w2 = hh * lx;
+  T w3 = ly * hw;
+  T w4 = ly * lx;
+
+  T val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4;
+  return val;
+}
+
 template <typename T, typename integer_t>
 inline void bilinear_interpolate_gradient(
     integer_t height,
@@ -225,6 +271,117 @@ kernel void nms<DTYPE ## 4, DTYPE>(                        \
   uint2    tgid   [[threadgroup_position_in_grid]],        \
   uint2    tid2   [[thread_position_in_threadgroup]]);
 
+
+template<typename T>
+kernel void deformable_im2col_kernel(
+    constant T*           input_ptr     [[ buffer(0) ]],
+    constant T*           offset_ptr    [[ buffer(1) ]],
+    constant T*           mask_ptr      [[ buffer(2) ]],
+    constant int&         height        [[ buffer(3) ]],
+    constant int&         width         [[ buffer(4) ]],
+    constant int&         weight_h      [[ buffer(5) ]],
+    constant int&         weight_w      [[ buffer(6) ]],
+    constant int&         pad_h         [[ buffer(7) ]],
+    constant int&         pad_w         [[ buffer(8) ]],
+    constant int&         stride_h      [[ buffer(9) ]],
+    constant int&         stride_w      [[ buffer(10)]],
+    constant int&         dilation_h    [[ buffer(11)]],
+    constant int&         dilation_w    [[ buffer(12)]],
+    constant int&         batch_size      [[ buffer(13)]],
+    constant int&         n_in_channels [[ buffer(14)]],
+    constant int&         n_offset_grps [[ buffer(15)]],
+    constant int&         out_h         [[ buffer(16)]],
+    constant int&         out_w         [[ buffer(17)]],
+    constant bool&        use_mask      [[ buffer(18)]],
+    device T*             columns_ptr   [[ buffer(19)]],
+    uint                  tid           [[ thread_position_in_grid ]],
+    uint                  tpg           [[ threads_per_grid ]])
+{
+    int total = out_w * out_h * batch_size * n_in_channels;
+    int gridSize = tpg;
+    if (tid >= total) {
+        return;
+    }
+
+    int out_x = tid % out_w;
+    int out_y = (tid / out_w) % out_h;
+    int out_b = (tid / (out_w * out_h)) % batch_size;
+    int in_c  = tid / (out_w * out_h * batch_size);
+    int out_c = in_c * weight_h * weight_w;
+    
+    int c_per_offset_grp = n_in_channels / n_offset_grps;
+    int grp_idx = in_c / c_per_offset_grp;
+    
+    int col_offset = out_c * (batch_size * out_h * out_w)
+                      + out_b * (out_h * out_w)
+                      + out_y * out_w + out_x;
+    device T* local_columns_ptr = columns_ptr + col_offset;
+    
+    int input_offset = out_b * (n_in_channels * height * width)
+                        + in_c * (height * width);
+    constant T* local_input_ptr = input_ptr + input_offset;
+    
+    int offset_offset = (out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w * out_h * out_w;
+    constant T* local_offset_ptr = offset_ptr + offset_offset;
+    
+    constant T* local_mask_ptr = nullptr;
+    if (use_mask) {
+        int mask_offset = (out_b * n_offset_grps + grp_idx) * weight_h * weight_w * out_h * out_w;
+        local_mask_ptr = mask_ptr + mask_offset;
+    }
+    
+    for (int i = 0; i < weight_h; ++i) {
+        for (int j = 0; j < weight_w; ++j) {
+            int mask_index = i * weight_w + j;
+            int offset_index = 2 * mask_index;
+            
+            T mask_value = 1;
+            if (use_mask) {
+                mask_value = local_mask_ptr[mask_index * (out_h * out_w) + out_y * out_w + out_x];
+            }
+            
+            T offset_h_val = local_offset_ptr[offset_index * (out_h * out_w) + out_y * out_w + out_x];
+            T offset_w_val = local_offset_ptr[(offset_index + 1) * (out_h * out_w) + out_y * out_w + out_x];
+            
+            T y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h_val;
+            T x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w_val;
+            
+            T interp = bilinear_interpolate_deformable_conv2d(local_input_ptr, height, width, y, x, tid);
+            
+            *local_columns_ptr = mask_value * interp;
+            
+            local_columns_ptr += batch_size * out_h * out_w;
+        }
+    }
+}
+
+#define REGISTER_DEFORMABLE_IM2COL_OP(DTYPE)                                  \
+template                                                                      \
+[[host_name("deformable_im2col_" #DTYPE)]]                                    \
+kernel void deformable_im2col_kernel<DTYPE>(                                  \
+    constant DTYPE*           input_ptr    [[ buffer(0) ]],                   \
+    constant DTYPE*           offset_ptr   [[ buffer(1) ]],                   \
+    constant DTYPE*           mask_ptr     [[ buffer(2) ]],                   \
+    constant int&                 height       [[ buffer(3) ]],               \
+    constant int&                 width        [[ buffer(4) ]],               \
+    constant int&                 weight_h     [[ buffer(5) ]],               \
+    constant int&                 weight_w     [[ buffer(6) ]],               \
+    constant int&                 pad_h        [[ buffer(7) ]],               \
+    constant int&                 pad_w        [[ buffer(8) ]],               \
+    constant int&                 stride_h     [[ buffer(9) ]],               \
+    constant int&                 stride_w     [[ buffer(10)]],               \
+    constant int&                 dilation_h   [[ buffer(11)]],               \
+    constant int&                 dilation_w   [[ buffer(12)]],               \
+    constant int&                 batch_sz     [[ buffer(13)]],               \
+    constant int&                 n_in_channels[[ buffer(14)]],               \
+    constant int&                 n_offset_grps[[ buffer(15)]],               \
+    constant int&                 out_h        [[ buffer(16)]],               \
+    constant int&                 out_w        [[ buffer(17)]],               \
+    constant bool&                use_mask     [[ buffer(18)]],               \
+    device DTYPE*                 columns_ptr  [[ buffer(19)]],               \
+    uint                          tid          [[ thread_position_in_grid ]], \
+    uint                          tpg           [[ threads_per_grid ]]);
+
 template<typename T, typename integer_t>
 kernel void roi_align(
     constant T       * input          [[buffer(0)]],
@@ -1013,6 +1170,8 @@ kernel void ps_roi_pool_backward<DTYPE, INT_DTYPE>(          \
 
 REGISTER_NMS_OP(float);
 REGISTER_NMS_OP(half);
+REGISTER_DEFORMABLE_IM2COL_OP(float);
+REGISTER_DEFORMABLE_IM2COL_OP(half);
 REGISTER_ROI_ALIGN_OP(float, int64_t);
 REGISTER_ROI_ALIGN_OP(half, int64_t);
 REGISTER_ROI_ALIGN_BACKWARD_OP(float, int64_t);