Skip to content

Commit 5cfe24e

Browse files
authored
[mlir][Vector] Add nontemporal attribute, mirroring memref (#76752)
Because vector.load and vector.store on memrefs with scalar element types lower directly to llvm.load and llvm.store, this change adds the ability to tag those loads and stores as nontemporal. This mirrors the nontemporal attribute already available on memref.load and memref.store.
1 parent dc03382 commit 5cfe24e

File tree

3 files changed

+38
-4
lines changed

3 files changed

+38
-4
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1626,7 +1626,8 @@ def Vector_LoadOp : Vector_Op<"load"> {
16261626

16271627
let arguments = (ins Arg<AnyMemRef, "the reference to load from",
16281628
[MemRead]>:$base,
1629-
Variadic<Index>:$indices);
1629+
Variadic<Index>:$indices,
1630+
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
16301631
let results = (outs AnyVectorOfAnyRank:$result);
16311632

16321633
let extraClassDeclaration = [{
@@ -1710,7 +1711,8 @@ def Vector_StoreOp : Vector_Op<"store"> {
17101711
AnyVectorOfAnyRank:$valueToStore,
17111712
Arg<AnyMemRef, "the reference to store to",
17121713
[MemWrite]>:$base,
1713-
Variadic<Index>:$indices
1714+
Variadic<Index>:$indices,
1715+
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal
17141716
);
17151717

17161718
let extraClassDeclaration = [{

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,9 @@ static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
192192
vector::LoadOpAdaptor adaptor,
193193
VectorType vectorTy, Value ptr, unsigned align,
194194
ConversionPatternRewriter &rewriter) {
195-
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy, ptr, align);
195+
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy, ptr, align,
196+
/*volatile_=*/false,
197+
loadOp.getNontemporal());
196198
}
197199

198200
static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
@@ -208,7 +210,8 @@ static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
208210
VectorType vectorTy, Value ptr, unsigned align,
209211
ConversionPatternRewriter &rewriter) {
210212
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
211-
ptr, align);
213+
ptr, align, /*volatile_=*/false,
214+
storeOp.getNontemporal());
212215
}
213216

214217
static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2023,6 +2023,20 @@ func.func @vector_load_op(%memref : memref<200x100xf32>, %i : index, %j : index)
20232023

20242024
// -----
20252025

2026+
func.func @vector_load_op_nontemporal(%memref : memref<200x100xf32>, %i : index, %j : index) -> vector<8xf32> {
2027+
%0 = vector.load %memref[%i, %j] {nontemporal = true} : memref<200x100xf32>, vector<8xf32>
2028+
return %0 : vector<8xf32>
2029+
}
2030+
2031+
// CHECK-LABEL: func @vector_load_op_nontemporal
2032+
// CHECK: %[[c100:.*]] = llvm.mlir.constant(100 : index) : i64
2033+
// CHECK: %[[mul:.*]] = llvm.mul %{{.*}}, %[[c100]] : i64
2034+
// CHECK: %[[add:.*]] = llvm.add %[[mul]], %{{.*}} : i64
2035+
// CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}}[%[[add]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
2036+
// CHECK: llvm.load %[[gep]] {alignment = 4 : i64, nontemporal} : !llvm.ptr -> vector<8xf32>
2037+
2038+
// -----
2039+
20262040
func.func @vector_load_op_index(%memref : memref<200x100xindex>, %i : index, %j : index) -> vector<8xindex> {
20272041
%0 = vector.load %memref[%i, %j] : memref<200x100xindex>, vector<8xindex>
20282042
return %0 : vector<8xindex>
@@ -2049,6 +2063,21 @@ func.func @vector_store_op(%memref : memref<200x100xf32>, %i : index, %j : index
20492063

20502064
// -----
20512065

2066+
func.func @vector_store_op_nontemporal(%memref : memref<200x100xf32>, %i : index, %j : index) {
2067+
%val = arith.constant dense<11.0> : vector<4xf32>
2068+
vector.store %val, %memref[%i, %j] {nontemporal = true} : memref<200x100xf32>, vector<4xf32>
2069+
return
2070+
}
2071+
2072+
// CHECK-LABEL: func @vector_store_op_nontemporal
2073+
// CHECK: %[[c100:.*]] = llvm.mlir.constant(100 : index) : i64
2074+
// CHECK: %[[mul:.*]] = llvm.mul %{{.*}}, %[[c100]] : i64
2075+
// CHECK: %[[add:.*]] = llvm.add %[[mul]], %{{.*}} : i64
2076+
// CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}}[%[[add]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
2077+
// CHECK: llvm.store %{{.*}}, %[[gep]] {alignment = 4 : i64, nontemporal} : vector<4xf32>, !llvm.ptr
2078+
2079+
// -----
2080+
20522081
func.func @vector_store_op_index(%memref : memref<200x100xindex>, %i : index, %j : index) {
20532082
%val = arith.constant dense<11> : vector<4xindex>
20542083
vector.store %val, %memref[%i, %j] : memref<200x100xindex>, vector<4xindex>

0 commit comments

Comments
 (0)