Skip to content

Commit 67f59a6

Browse files
authored
[mlir][xegpu] Improve scatter attribute definition (#126540)
Refactors XeGPU scatter attribute introducing following: - improved docs formatting - default initialized parameters - invariant checks in attribute verifier - removal of additional parsing error The attribute's getters now provide default values simplifying their usage and scattered tensor descriptor handling. Related descriptor verifier is updated to avoid check duplication.
1 parent 701223a commit 67f59a6

File tree

6 files changed

+53
-23
lines changed

6 files changed

+53
-23
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,19 +59,29 @@ def XeGPU_BlockTensorDescAttr: XeGPU_TensorDescAttr<"BlockTensorDesc", "block_td
5959

6060
def XeGPU_ScatterTensorDescAttr: XeGPU_TensorDescAttr<"ScatterTensorDesc", "scatter_tdesc_attr"> {
6161
let summary = [{a composite attribute for `TensorDescType`}];
62-
let description = [{`ScatterTensorDesc` (or `scatter_tdesc_attr`) is a composite
63-
attribute defined for `TensorDescType` for describing following
64-
properties of a `TensorDesc`.
62+
let description = [{
63+
`ScatterTensorDesc` is a composite attribute defined for `TensorDescType`
64+
for describing following properties of a `TensorDesc`:
65+
6566
1. `memory_space`: It describes where the data block described by the
6667
TensorDesc is located, `Global` device memory or `Shared` local memory.
6768
It is default to `Global`.
68-
2. `chunk_size`: indicates number of continious elements accessed for each
69+
70+
2. `chunk_size`: indicates number of contiguous elements accessed for each
6971
offset, default is 1. It is used with `scattered` attr only.
7072
}];
7173

7274
let parameters = (ins
73-
OptionalParameter<"MemorySpaceAttr">: $memory_space,
74-
OptionalParameter<"IntegerAttr", "1">: $chunk_size
75+
DefaultValuedParameter<
76+
"MemorySpaceAttr",
77+
"MemorySpaceAttr::get($_ctxt, xegpu::MemorySpace::Global)",
78+
"Data memory location"
79+
>: $memory_space,
80+
DefaultValuedParameter<
81+
"IntegerAttr",
82+
"IntegerAttr::get(IntegerType::get($_ctxt, 64), 1)",
83+
"Number of contiguous elements"
84+
>: $chunk_size
7585
);
7686

7787
let builders = [
@@ -80,6 +90,8 @@ def XeGPU_ScatterTensorDescAttr: XeGPU_TensorDescAttr<"ScatterTensorDesc", "scat
8090
CArg<"int", "1">: $chunk_size
8191
)>
8292
];
93+
94+
let genVerifyDecl = 1;
8395
}
8496

8597
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
172172
auto attr = getEncoding();
173173
auto scatter_attr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(attr);
174174
assert((!attr || scatter_attr) && "invalid on non ScatterTensorDescAttr.");
175-
if (scatter_attr && scatter_attr.getChunkSize())
175+
if (scatter_attr)
176176
return scatter_attr.getChunkSize().getInt();
177177
return 1;
178178
}

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,18 @@ ScatterTensorDescAttr::get(mlir::MLIRContext *context,
5555
return Base::get(context, scopeAttr, chunkSizeAttr);
5656
}
5757

58+
LogicalResult ScatterTensorDescAttr::verify(
59+
llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
60+
MemorySpaceAttr memory_space, IntegerAttr chunk_size) {
61+
int64_t chunkSize = chunk_size.getInt();
62+
SmallVector<int64_t> supportedChunkSizes = {1, 2, 3, 4, 8,
63+
16, 32, 64, 128, 256};
64+
if (!llvm::is_contained(supportedChunkSizes, chunkSize))
65+
return emitError() << "invalid chunk size";
66+
67+
return success();
68+
}
69+
5870
//===----------------------------------------------------------------------===//
5971
// XeGPU_SGMapAttr
6072
//===----------------------------------------------------------------------===//
@@ -166,8 +178,6 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
166178
continue;
167179
}
168180
}
169-
parser.emitError(parser.getCurrentLocation(),
170-
"Failed to parse the attribute.\n");
171181
return {};
172182
}
173183

@@ -237,8 +247,7 @@ LogicalResult TensorDescType::verify(
237247
// Expected tensor ranks for scattered data:
238248
// - 1D tensor for fully non-contiguous elements (chunk size == 1)
239249
// - 2D tensor for scattered blocks (chunk size > 1)
240-
IntegerAttr chunkAttr = scatterAttr.getChunkSize();
241-
unsigned chunkSize = chunkAttr ? chunkAttr.getInt() : 1;
250+
unsigned chunkSize = scatterAttr.getChunkSize().getInt();
242251
if (rank == 1 && chunkSize != 1)
243252
return emitError() << "expected non-contiguous elements for 1D tensor";
244253
if (rank == 2 && chunkSize < 2)
@@ -273,8 +282,7 @@ LogicalResult TensorDescType::verify(
273282
return emitError()
274283
<< "cannot map over non-contiguous scattered row elements";
275284

276-
IntegerAttr chunkAttr = scatterAttr.getChunkSize();
277-
unsigned chunkSize = chunkAttr ? chunkAttr.getInt() : 1;
285+
unsigned chunkSize = scatterAttr.getChunkSize().getInt();
278286
if (wiData[1] != chunkSize)
279287
return emitError() << "work item data mapping must match the number of "
280288
"contiguous elements";

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -419,16 +419,8 @@ LogicalResult CreateDescOp::verify() {
419419
<< " Source: " << srcMemorySpace
420420
<< ", TensorDesc: " << tdescMemorySpace;
421421

422-
auto chunkSize = tdescTy.getChunkSize();
423-
424-
// check chunk_size
425-
llvm::SmallVector<int64_t> supportedChunkSizes = {1, 2, 3, 4, 8,
426-
16, 32, 64, 128, 256};
427-
if (!llvm::is_contained(supportedChunkSizes, chunkSize))
428-
return emitOpError("Invalid chunk_size. Supported values are 1, 2, 3, 4, "
429-
"8, 16, 32, 64, 128, or 256.");
430-
431422
// check total size
423+
auto chunkSize = tdescTy.getChunkSize();
432424
auto elemBits = tdescTy.getElementType().getIntOrFloatBitWidth();
433425
auto bitsPerLane = elemBits * chunkSize;
434426
if (chunkSize > 1 && bitsPerLane % 32) {

mlir/test/Dialect/XeGPU/XeGPUOps.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,15 @@ gpu.func @test_create_tdesc_vc_1(%src: memref<?xf32, 3>) {
181181
gpu.return
182182
}
183183

184+
// CHECK: gpu.func @test_create_tdesc_vc_2(%[[arg0:.*]]: memref<?xf32>) {
185+
gpu.func @test_create_tdesc_vc_2(%src: memref<?xf32>) {
186+
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
187+
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
188+
//CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : memref<?xf32>, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>
189+
%1 = xegpu.create_tdesc %src, %0 : memref<?xf32>, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<chunk_size = 1>>
190+
gpu.return
191+
}
192+
184193
// CHECK: gpu.func @test_create_tdesc_vc_with_sg_map(%[[arg0:.*]]: ui64) {
185194
gpu.func @test_create_tdesc_vc_with_sg_map(%src: ui64) {
186195
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>

mlir/test/Dialect/XeGPU/invalid.mlir

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,14 +190,23 @@ func.func @test_create_tdesc_vc_2(%src: ui64) {
190190
}
191191

192192
// -----
193-
func.func @test_create_tdesc_vc_1(%src: memref<?xf32>) {
193+
func.func @test_create_tdesc_vc_3(%src: memref<?xf32>) {
194194
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
195195
// expected-error@+1 {{Memory space mismatch}}
196196
%1 = xegpu.create_tdesc %src, %0 : memref<?xf32>, vector<4xindex>
197197
-> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<memory_space = slm, chunk_size = 2>>
198198
return
199199
}
200200

201+
// -----
202+
func.func @test_create_tdesc_vc_4(%src: memref<?xf32>) {
203+
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
204+
%1 = xegpu.create_tdesc %src, %0 : memref<?xf32>, vector<4xindex>
205+
// expected-error@+1 {{invalid chunk size}}
206+
-> !xegpu.tensor_desc<4x5xf32, #xegpu.scatter_tdesc_attr<chunk_size = 5>>
207+
return
208+
}
209+
201210
// -----
202211
func.func @test_prefetch_vc_1(%src: memref<24x32xf16>) {
203212
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16>

0 commit comments

Comments
 (0)