
Missing support in scf.for bufferization #16956

Open · llvm/llvm-project #87594 · @hanhanW

Description


This issue is forked from #16421. I had a long discussion with @matthias-springer about the bufferization issue. We are not able to bufferize the scf.for op below, which returns three tensors. There is a write to %arg4, but there is also a later read from the original %arg4. Therefore, %24 must bufferize out-of-place, which means a new copy of %arg4 is allocated and that newly allocated buffer is passed to the next iteration. A workaround is adding %y2 = bufferization.materialize_in_destination %24 in %arg4 : (tensor<64xf32>, tensor<64xf32>) -> tensor<64xf32> and yielding %y2 instead of %24. The bufferization.materialize_in_destination op is not needed for the other yielded values. @matthias-springer is this something we can implement in bufferization? If so, would you help on that?
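
For intuition, here is a minimal, hypothetical sketch of the conflict pattern (heavily reduced from the repro below; the function and value names are made up). The first linalg.generic writes into an iter_arg, and the second one reads the original iter_arg afterwards, so the first result must bufferize out-of-place and a freshly allocated buffer reaches the scf.yield:

// Hypothetical reduced example, not part of the repro.
func.func @conflict_sketch(%init_max: tensor<64xf32>, %init_sum: tensor<64xf32>,
                           %v: tensor<64xf32>) -> (tensor<64xf32>, tensor<64xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %res:2 = scf.for %i = %c0 to %c4 step %c1
      iter_args(%arg_max = %init_max, %arg_sum = %init_sum) -> (tensor<64xf32>, tensor<64xf32>) {
    // Write: update the running max in-place into %arg_max (plays the role of %24).
    %new_max = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%v : tensor<64xf32>) outs(%arg_max : tensor<64xf32>) {
    ^bb0(%in: f32, %out: f32):
      %m = arith.maximumf %in, %out : f32
      linalg.yield %m : f32
    } -> tensor<64xf32>
    // Read: the original %arg_max is read again (plays the role of %27), so
    // %new_max must bufferize out-of-place into a new allocation.
    %corr = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%new_max, %arg_max : tensor<64xf32>, tensor<64xf32>) outs(%arg_sum : tensor<64xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %d = arith.subf %in_0, %in : f32
      %e = math.exp2 %d : f32
      %s = arith.mulf %e, %out : f32
      linalg.yield %s : f32
    } -> tensor<64xf32>
    // Yielding %new_max passes the new allocation, not %arg_max's buffer,
    // to the next iteration.
    scf.yield %new_max, %corr : tensor<64xf32>, tensor<64xf32>
  }
  return %res#0, %res#1 : tensor<64xf32>, tensor<64xf32>
}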

To repro: iree-opt --iree-codegen-llvmcpu-bufferization-pipeline ~/repro.mlir

func.func @main_dispatch_0_attention_20x4096x64xf16() {
  %c4096 = arith.constant 4096 : index
  %c20 = arith.constant 20 : index
  %c0 = arith.constant 0 : index
  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>>
  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>>
  %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>>
  %3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<20x4096x64xf16>>
  %workgroup_id_x = hal.interface.workgroup.id[0] : index
  %workgroup_count_x = hal.interface.workgroup.count[0] : index
  %workgroup_id_y = hal.interface.workgroup.id[1] : index
  %workgroup_count_y = hal.interface.workgroup.count[1] : index
  %4 = affine.apply affine_map<()[s0] -> (s0 * 20)>()[%workgroup_id_y]
  %5 = affine.apply affine_map<()[s0] -> (s0 * 20)>()[%workgroup_count_y]
  scf.for %arg0 = %4 to %c20 step %5 {
    %6 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_x]
    %7 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_x]
    scf.for %arg1 = %6 to %c4096 step %7 {
      %8 = flow.dispatch.tensor.load %3, offsets = [%arg0, %arg1, 0], sizes = [20, 64, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<writeonly:tensor<20x4096x64xf16>> -> tensor<20x64x64xf16>
      %9 = flow.dispatch.tensor.load %0, offsets = [%arg0, %arg1, 0], sizes = [20, 64, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>> -> tensor<20x64x64xf16>
      %10 = flow.dispatch.tensor.load %1, offsets = [%arg0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>> -> tensor<20x4096x64xf16>
      %11 = flow.dispatch.tensor.load %2, offsets = [%arg0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>> -> tensor<20x4096x64xf16>
      %12 = tensor.empty() : tensor<64x64xf32>
      %extracted_slice = tensor.extract_slice %8[0, 0, 0] [1, 64, 64] [1, 1, 1] : tensor<20x64x64xf16> to tensor<64x64xf16>
      %cst = arith.constant 0.000000e+00 : f32
      %13 = linalg.fill ins(%cst : f32) outs(%12 : tensor<64x64xf32>) -> tensor<64x64xf32>
      %cst_0 = arith.constant -1.000000e+30 : f32
      %14 = tensor.empty() : tensor<64xf32>
      %15 = linalg.fill ins(%cst_0 : f32) outs(%14 : tensor<64xf32>) -> tensor<64xf32>
      %16 = tensor.empty() : tensor<64xf32>
      %17 = linalg.fill ins(%cst : f32) outs(%16 : tensor<64xf32>) -> tensor<64xf32>
      %c0_1 = arith.constant 0 : index
      %c64 = arith.constant 64 : index
      %c4096_2 = arith.constant 4096 : index
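      // Inner reduction loop with three loop-carried tensors: %arg3 (64x64 accumulator),
      // %arg4 (running max, initialized to -1e30), %arg5 (running sum, initialized to 0).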
      %18:3 = scf.for %arg2 = %c0_1 to %c4096_2 step %c64 iter_args(%arg3 = %13, %arg4 = %15, %arg5 = %17) -> (tensor<64x64xf32>, tensor<64xf32>, tensor<64xf32>) {
        %extracted_slice_3 = tensor.extract_slice %10[0, %arg2, 0] [1, 64, 64] [1, 1, 1] : tensor<20x4096x64xf16> to tensor<64x64xf16>
        %extracted_slice_4 = tensor.extract_slice %11[0, %arg2, 0] [1, 64, 64] [1, 1, 1] : tensor<20x4096x64xf16> to tensor<64x64xf16>
        %extracted_slice_5 = tensor.extract_slice %9[0, 0, 0] [1, 64, 64] [1, 1, 1] : tensor<20x64x64xf16> to tensor<64x64xf16>
        %cst_6 = arith.constant 0.000000e+00 : f32
        %21 = tensor.empty() : tensor<64x64xf32>
        %22 = linalg.fill ins(%cst_6 : f32) outs(%21 : tensor<64x64xf32>) -> tensor<64x64xf32>
        %23 = linalg.matmul_transpose_b ins(%extracted_slice_5, %extracted_slice_3 : tensor<64x64xf16>, tensor<64x64xf16>) outs(%22 : tensor<64x64xf32>) -> tensor<64x64xf32>
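        // %24 writes the updated running max into %arg4.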
        %24 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%23 : tensor<64x64xf32>) outs(%arg4 : tensor<64xf32>) {
        ^bb0(%in: f32, %out: f32):
          %34 = arith.maximumf %in, %out : f32
          linalg.yield %34 : f32
        } -> tensor<64xf32>
        %25 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%24 : tensor<64xf32>) outs(%23 : tensor<64x64xf32>) {
        ^bb0(%in: f32, %out: f32):
          %34 = arith.subf %out, %in : f32
          %35 = math.exp2 %34 : f32
          linalg.yield %35 : f32
        } -> tensor<64x64xf32>
        %26 = tensor.empty() : tensor<64xf32>
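        // %27 reads the original %arg4 after %24 has written to it; this
        // read-after-write conflict forces %24 to bufferize out-of-place.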
        %27 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%24, %arg4 : tensor<64xf32>, tensor<64xf32>) outs(%26 : tensor<64xf32>) {
        ^bb0(%in: f32, %in_7: f32, %out: f32):
          %34 = arith.subf %in_7, %in : f32
          %35 = math.exp2 %34 : f32
          linalg.yield %35 : f32
        } -> tensor<64xf32>
        %28 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%27 : tensor<64xf32>) outs(%arg5 : tensor<64xf32>) {
        ^bb0(%in: f32, %out: f32):
          %34 = arith.mulf %in, %out : f32
          linalg.yield %34 : f32
        } -> tensor<64xf32>
        %29 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%25 : tensor<64x64xf32>) outs(%28 : tensor<64xf32>) {
        ^bb0(%in: f32, %out: f32):
          %34 = arith.addf %in, %out : f32
          linalg.yield %34 : f32
        } -> tensor<64xf32>
        %30 = tensor.empty() : tensor<64x64xf16>
        %31 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%25 : tensor<64x64xf32>) outs(%30 : tensor<64x64xf16>) {
        ^bb0(%in: f32, %out: f16):
          %34 = arith.truncf %in : f32 to f16
          linalg.yield %34 : f16
        } -> tensor<64x64xf16>
        %32 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%27 : tensor<64xf32>) outs(%arg3 : tensor<64x64xf32>) {
        ^bb0(%in: f32, %out: f32):
          %34 = arith.mulf %in, %out : f32
          linalg.yield %34 : f32
        } -> tensor<64x64xf32>
        %33 = linalg.matmul ins(%31, %extracted_slice_4 : tensor<64x64xf16>, tensor<64x64xf16>) outs(%32 : tensor<64x64xf32>) -> tensor<64x64xf32>
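        // Yielding %24 hands the newly allocated (out-of-place) buffer to the
        // next iteration; the workaround copies it back into %arg4 first.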
        scf.yield %33, %24, %29 : tensor<64x64xf32>, tensor<64xf32>, tensor<64xf32>
      }
      %19 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%18#2 : tensor<64xf32>) outs(%18#0 : tensor<64x64xf32>) {
      ^bb0(%in: f32, %out: f32):
        %cst_3 = arith.constant 1.000000e+00 : f32
        %21 = arith.divf %cst_3, %in : f32
        %22 = arith.mulf %21, %out : f32
        linalg.yield %22 : f32
      } -> tensor<64x64xf32>
      %20 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%19 : tensor<64x64xf32>) outs(%extracted_slice : tensor<64x64xf16>) {
      ^bb0(%in: f32, %out: f16):
        %21 = arith.truncf %in : f32 to f16
        linalg.yield %21 : f16
      } -> tensor<64x64xf16>
      %inserted_slice = tensor.insert_slice %20 into %8[0, 0, 0] [1, 64, 64] [1, 1, 1] : tensor<64x64xf16> into tensor<20x64x64xf16>
      flow.dispatch.tensor.store %inserted_slice, %3, offsets = [%arg0, %arg1, 0], sizes = [20, 64, 64], strides = [1, 1, 1] : tensor<20x64x64xf16> -> !flow.dispatch.tensor<writeonly:tensor<20x4096x64xf16>>
    }
  }
  return
}
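
For reference, the workaround described above amounts to rewriting the tail of the inner loop as follows (only %24 needs the copy back into %arg4; the other yielded values are fine as-is):

        %y2 = bufferization.materialize_in_destination %24 in %arg4 : (tensor<64xf32>, tensor<64xf32>) -> tensor<64xf32>
        scf.yield %33, %y2, %29 : tensor<64x64xf32>, tensor<64xf32>, tensor<64xf32>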

The error:

/home/hanchung/repro.mlir:83:9: error: Yield operand #1 is not equivalent to the corresponding iter bbArg
        scf.yield %33, %24, %29 : tensor<64x64xf32>, tensor<64xf32>, tensor<64xf32>
        ^
/home/hanchung/repro.mlir:83:9: note: see current operation: "scf.yield"(%49, %40, %45) : (tensor<64x64xf32>, tensor<64xf32>, tensor<64xf32>) -> ()

Metadata

Labels

codegen: Shared code generation infrastructure and dialects
