Description
This issue is forked from #16421. I had a long discussion with @matthias-springer about the bufferization issue. We are not able to bufferize the scf.for op below when it returns three tensors. There is a write to %arg4, but then there is a later read of the original %arg4. Therefore, %24 must bufferize out-of-place, i.e., into a new copy of %arg4, which means that a different (newly allocated) buffer is passed to the next iteration. A workaround is adding %y2 = bufferization.materialize_in_destination %24 in %arg4 : (tensor<64xf32>, tensor<64xf32>) -> tensor<64xf32> and yielding %y2 instead of %24. The bufferization.materialize_in_destination op is not needed for the other yielded values. @matthias-springer, is this something we can implement in bufferization? If so, would you help with that?
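To illustrate the conflict in isolation, here is a minimal hypothetical sketch (function name, shapes, and bounds are made up for illustration; this is not the repro itself). An iter_arg is written in destination-passing style, and the original value is read again afterwards, so the write cannot happen in place and the yielded tensor no longer aliases the iter_args buffer:

// A write to the iter_arg %a (destination-passing style, like %24 in the
// repro) followed by a read of the original %a (like %27) forces the
// write to bufferize out-of-place.
func.func @raw_conflict(%init: tensor<64xf32>, %in: tensor<64x64xf32>) -> tensor<64xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %r = scf.for %i = %c0 to %c4 step %c1 iter_args(%a = %init) -> (tensor<64xf32>) {
    %w = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%in : tensor<64x64xf32>) outs(%a : tensor<64xf32>) {
    ^bb0(%x: f32, %acc: f32):
      %m = arith.maximumf %x, %acc : f32
      linalg.yield %m : f32
    } -> tensor<64xf32>
    // This read of the *original* %a conflicts with the write above; it
    // only exists to create the read-after-write conflict.
    %e = tensor.empty() : tensor<64xf32>
    %u = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%w, %a : tensor<64xf32>, tensor<64xf32>) outs(%e : tensor<64xf32>) {
    ^bb0(%x: f32, %y: f32, %acc: f32):
      %s = arith.subf %y, %x : f32
      linalg.yield %s : f32
    } -> tensor<64xf32>
    // Yielding %w trips "Yield operand is not equivalent to the
    // corresponding iter bbArg": %w's buffer is a fresh allocation.
    scf.yield %w : tensor<64xf32>
  }
  return %r : tensor<64xf32>
}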
To repro: iree-opt --iree-codegen-llvmcpu-bufferization-pipeline ~/repro.mlir
func.func @main_dispatch_0_attention_20x4096x64xf16() {
  %c4096 = arith.constant 4096 : index
  %c20 = arith.constant 20 : index
  %c0 = arith.constant 0 : index
  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>>
  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>>
  %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>>
  %3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<20x4096x64xf16>>
  %workgroup_id_x = hal.interface.workgroup.id[0] : index
  %workgroup_count_x = hal.interface.workgroup.count[0] : index
  %workgroup_id_y = hal.interface.workgroup.id[1] : index
  %workgroup_count_y = hal.interface.workgroup.count[1] : index
  %4 = affine.apply affine_map<()[s0] -> (s0 * 20)>()[%workgroup_id_y]
  %5 = affine.apply affine_map<()[s0] -> (s0 * 20)>()[%workgroup_count_y]
  scf.for %arg0 = %4 to %c20 step %5 {
    %6 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_x]
    %7 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_x]
    scf.for %arg1 = %6 to %c4096 step %7 {
      %8 = flow.dispatch.tensor.load %3, offsets = [%arg0, %arg1, 0], sizes = [20, 64, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<writeonly:tensor<20x4096x64xf16>> -> tensor<20x64x64xf16>
      %9 = flow.dispatch.tensor.load %0, offsets = [%arg0, %arg1, 0], sizes = [20, 64, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>> -> tensor<20x64x64xf16>
      %10 = flow.dispatch.tensor.load %1, offsets = [%arg0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>> -> tensor<20x4096x64xf16>
      %11 = flow.dispatch.tensor.load %2, offsets = [%arg0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>> -> tensor<20x4096x64xf16>
      %12 = tensor.empty() : tensor<64x64xf32>
      %extracted_slice = tensor.extract_slice %8[0, 0, 0] [1, 64, 64] [1, 1, 1] : tensor<20x64x64xf16> to tensor<64x64xf16>
      %cst = arith.constant 0.000000e+00 : f32
      %13 = linalg.fill ins(%cst : f32) outs(%12 : tensor<64x64xf32>) -> tensor<64x64xf32>
      %cst_0 = arith.constant -1.000000e+30 : f32
      %14 = tensor.empty() : tensor<64xf32>
      %15 = linalg.fill ins(%cst_0 : f32) outs(%14 : tensor<64xf32>) -> tensor<64xf32>
      %16 = tensor.empty() : tensor<64xf32>
      %17 = linalg.fill ins(%cst : f32) outs(%16 : tensor<64xf32>) -> tensor<64xf32>
      %c0_1 = arith.constant 0 : index
      %c64 = arith.constant 64 : index
      %c4096_2 = arith.constant 4096 : index
      %18:3 = scf.for %arg2 = %c0_1 to %c4096_2 step %c64 iter_args(%arg3 = %13, %arg4 = %15, %arg5 = %17) -> (tensor<64x64xf32>, tensor<64xf32>, tensor<64xf32>) {
        %extracted_slice_3 = tensor.extract_slice %10[0, %arg2, 0] [1, 64, 64] [1, 1, 1] : tensor<20x4096x64xf16> to tensor<64x64xf16>
        %extracted_slice_4 = tensor.extract_slice %11[0, %arg2, 0] [1, 64, 64] [1, 1, 1] : tensor<20x4096x64xf16> to tensor<64x64xf16>
        %extracted_slice_5 = tensor.extract_slice %9[0, 0, 0] [1, 64, 64] [1, 1, 1] : tensor<20x64x64xf16> to tensor<64x64xf16>
        %cst_6 = arith.constant 0.000000e+00 : f32
        %21 = tensor.empty() : tensor<64x64xf32>
        %22 = linalg.fill ins(%cst_6 : f32) outs(%21 : tensor<64x64xf32>) -> tensor<64x64xf32>
        %23 = linalg.matmul_transpose_b ins(%extracted_slice_5, %extracted_slice_3 : tensor<64x64xf16>, tensor<64x64xf16>) outs(%22 : tensor<64x64xf32>) -> tensor<64x64xf32>
        %24 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%23 : tensor<64x64xf32>) outs(%arg4 : tensor<64xf32>) {
        ^bb0(%in: f32, %out: f32):
          %34 = arith.maximumf %in, %out : f32
          linalg.yield %34 : f32
        } -> tensor<64xf32>
        %25 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%24 : tensor<64xf32>) outs(%23 : tensor<64x64xf32>) {
        ^bb0(%in: f32, %out: f32):
          %34 = arith.subf %out, %in : f32
          %35 = math.exp2 %34 : f32
          linalg.yield %35 : f32
        } -> tensor<64x64xf32>
        %26 = tensor.empty() : tensor<64xf32>
        %27 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%24, %arg4 : tensor<64xf32>, tensor<64xf32>) outs(%26 : tensor<64xf32>) {
        ^bb0(%in: f32, %in_7: f32, %out: f32):
          %34 = arith.subf %in_7, %in : f32
          %35 = math.exp2 %34 : f32
          linalg.yield %35 : f32
        } -> tensor<64xf32>
        %28 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%27 : tensor<64xf32>) outs(%arg5 : tensor<64xf32>) {
        ^bb0(%in: f32, %out: f32):
          %34 = arith.mulf %in, %out : f32
          linalg.yield %34 : f32
        } -> tensor<64xf32>
        %29 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%25 : tensor<64x64xf32>) outs(%28 : tensor<64xf32>) {
        ^bb0(%in: f32, %out: f32):
          %34 = arith.addf %in, %out : f32
          linalg.yield %34 : f32
        } -> tensor<64xf32>
        %30 = tensor.empty() : tensor<64x64xf16>
        %31 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%25 : tensor<64x64xf32>) outs(%30 : tensor<64x64xf16>) {
        ^bb0(%in: f32, %out: f16):
          %34 = arith.truncf %in : f32 to f16
          linalg.yield %34 : f16
        } -> tensor<64x64xf16>
        %32 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%27 : tensor<64xf32>) outs(%arg3 : tensor<64x64xf32>) {
        ^bb0(%in: f32, %out: f32):
          %34 = arith.mulf %in, %out : f32
          linalg.yield %34 : f32
        } -> tensor<64x64xf32>
        %33 = linalg.matmul ins(%31, %extracted_slice_4 : tensor<64x64xf16>, tensor<64x64xf16>) outs(%32 : tensor<64x64xf32>) -> tensor<64x64xf32>
        scf.yield %33, %24, %29 : tensor<64x64xf32>, tensor<64xf32>, tensor<64xf32>
      }
      %19 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%18#2 : tensor<64xf32>) outs(%18#0 : tensor<64x64xf32>) {
      ^bb0(%in: f32, %out: f32):
        %cst_3 = arith.constant 1.000000e+00 : f32
        %21 = arith.divf %cst_3, %in : f32
        %22 = arith.mulf %21, %out : f32
        linalg.yield %22 : f32
      } -> tensor<64x64xf32>
      %20 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%19 : tensor<64x64xf32>) outs(%extracted_slice : tensor<64x64xf16>) {
      ^bb0(%in: f32, %out: f16):
        %21 = arith.truncf %in : f32 to f16
        linalg.yield %21 : f16
      } -> tensor<64x64xf16>
      %inserted_slice = tensor.insert_slice %20 into %8[0, 0, 0] [1, 64, 64] [1, 1, 1] : tensor<64x64xf16> into tensor<20x64x64xf16>
      flow.dispatch.tensor.store %inserted_slice, %3, offsets = [%arg0, %arg1, 0], sizes = [20, 64, 64], strides = [1, 1, 1] : tensor<20x64x64xf16> -> !flow.dispatch.tensor<writeonly:tensor<20x4096x64xf16>>
    }
  }
  return
}
The error:
/home/hanchung/repro.mlir:83:9: error: Yield operand #1 is not equivalent to the corresponding iter bbArg
scf.yield %33, %24, %29 : tensor<64x64xf32>, tensor<64xf32>, tensor<64xf32>
^
/home/hanchung/repro.mlir:83:9: note: see current operation: "scf.yield"(%49, %40, %45) : (tensor<64x64xf32>, tensor<64xf32>, tensor<64xf32>) -> ()
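For concreteness, the workaround described above amounts to rewriting the tail of the innermost loop roughly as follows (a sketch reusing the value names from the repro):

// Copy %24 into %arg4's buffer so that yield operand #1 is equivalent
// to the iter bbArg again; per the discussion above, %33 and %29 already
// bufferize in place onto %arg3 and %arg5 and need no such copy.
%y2 = bufferization.materialize_in_destination %24 in %arg4 : (tensor<64xf32>, tensor<64xf32>) -> tensor<64xf32>
scf.yield %33, %y2, %29 : tensor<64x64xf32>, tensor<64xf32>, tensor<64xf32>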