
Commit 5fc8e87

[MLIR][XeGPU] Retain anchor op layouts for XeGPU nD ops (#170934)
This PR retains the anchor op layouts on XeGPU nD ops (after dropping the layout components each transform has already resolved) during the workgroup-to-subgroup and unroll transformations.
1 parent 8fe38c4 commit 5fc8e87

8 files changed: +64 −22 lines
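Across the transforms in this commit the same pattern recurs: read the op's anchor layout, drop the component the transform has just resolved, and pass the remainder to the rebuilt op. A minimal sketch of that pattern, assuming only the accessors visible in the diffs below (getLayoutAttr, dropInstData, dropSgLayoutAndData); pruneAnchorLayout itself is a hypothetical helper, not part of the patch:

#include "mlir/Dialect/XeGPU/IR/XeGPU.h"

using namespace mlir;

// Prune an anchor layout to what is still meaningful after a transform
// stage: WG->SG distribution consumes sg_layout/sg_data, while unrolling
// consumes inst_data. Illustration only.
static xegpu::DistributeLayoutAttr
pruneAnchorLayout(xegpu::DistributeLayoutAttr layout, bool wgToSg) {
  if (!layout)
    return nullptr; // op carried no anchor layout; nothing to retain
  return wgToSg ? layout.dropSgLayoutAndData() : layout.dropInstData();
}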

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 6 additions & 3 deletions
@@ -329,7 +329,8 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
                    "ArrayRef<OpFoldResult>": $offsets,
                    "xegpu::CachePolicyAttr": $l1_hint,
                    "xegpu::CachePolicyAttr": $l2_hint,
-                   "xegpu::CachePolicyAttr": $l3_hint)>
+                   "xegpu::CachePolicyAttr": $l3_hint,
+                   "xegpu::DistributeLayoutAttr": $layout)>
   ];

   let hasVerifier = 1;
@@ -453,7 +454,8 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
                    "UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
                    "xegpu::CachePolicyAttr": $l1_hint,
                    "xegpu::CachePolicyAttr": $l2_hint,
-                   "xegpu::CachePolicyAttr": $l3_hint)>
+                   "xegpu::CachePolicyAttr": $l3_hint,
+                   "xegpu::DistributeLayoutAttr": $layout)>
   ];

   let hasVerifier = 1;
@@ -564,7 +566,8 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
                    "ArrayRef<OpFoldResult>": $offsets,
                    "xegpu::CachePolicyAttr": $l1_hint,
                    "xegpu::CachePolicyAttr": $l2_hint,
-                   "xegpu::CachePolicyAttr": $l3_hint)>
+                   "xegpu::CachePolicyAttr": $l3_hint,
+                   "xegpu::DistributeLayoutAttr": $layout)>
   ];
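With the new $layout builder argument, callers can attach an anchor layout when creating the op. A hypothetical call site; rewriter, loc, tdesc, offsets, hint, and anchorLayout are illustrative names, while the parameter order follows the builders above:

// Illustration only: create a prefetch that carries its anchor layout.
xegpu::PrefetchNdOp::create(rewriter, loc, tdesc, offsets,
                            /*l1_hint=*/hint, /*l2_hint=*/hint,
                            /*l3_hint=*/hint, /*layout=*/anchorLayout);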

mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp

Lines changed: 8 additions & 4 deletions
@@ -567,7 +567,8 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
     auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
                                           /*packed=*/nullptr, transposeAttr,
                                           /*l1_hint=*/hint,
-                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
+                                          /*l2_hint=*/hint, /*l3_hint=*/hint,
+                                          /*layout=*/nullptr);
     rewriter.replaceOp(readOp, loadOp);

     return success();
@@ -621,7 +622,8 @@ struct TransferWriteLowering
     auto storeOp = xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(),
                                             ndDesc, indices,
                                             /*l1_hint=*/hint,
-                                            /*l2_hint=*/hint, /*l3_hint=*/hint);
+                                            /*l2_hint=*/hint, /*l3_hint=*/hint,
+                                            /*layout=*/nullptr);
     rewriter.replaceOp(writeOp, storeOp);

     return success();
@@ -725,7 +727,8 @@ struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
         xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
                                 /*packed=*/nullptr, /*transpose=*/nullptr,
                                 /*l1_hint=*/hint,
-                                /*l2_hint=*/hint, /*l3_hint=*/hint);
+                                /*l2_hint=*/hint, /*l3_hint=*/hint,
+                                /*layout=*/nullptr);
     rewriter.replaceOp(loadOp, loadNdOp);

     return success();
@@ -763,7 +766,8 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
     auto storeNdOp =
         xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc, indices,
                                  /*l1_hint=*/hint,
-                                 /*l2_hint=*/hint, /*l3_hint=*/hint);
+                                 /*l2_hint=*/hint, /*l3_hint=*/hint,
+                                 /*layout=*/nullptr);

     rewriter.replaceOp(storeOp, storeNdOp);

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 9 additions & 6 deletions
@@ -472,15 +472,16 @@ void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
                          Value tensorDesc, ArrayRef<OpFoldResult> offsets,
                          xegpu::CachePolicyAttr l1_hint,
                          xegpu::CachePolicyAttr l2_hint,
-                         xegpu::CachePolicyAttr l3_hint) {
+                         xegpu::CachePolicyAttr l3_hint,
+                         xegpu::DistributeLayoutAttr layout) {
   SmallVector<Value> dynamicOffsets;
   SmallVector<int64_t> staticOffsets;
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);

   auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);

   build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint,
-        l2_hint, l3_hint, /*anchor_layout=*/nullptr);
+        l2_hint, l3_hint, /*anchor_layout=*/layout);
 }

 LogicalResult PrefetchNdOp::verify() {
@@ -527,7 +528,8 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
                      UnitAttr packed, DenseI64ArrayAttr transpose,
                      xegpu::CachePolicyAttr l1_hint,
                      xegpu::CachePolicyAttr l2_hint,
-                     xegpu::CachePolicyAttr l3_hint) {
+                     xegpu::CachePolicyAttr l3_hint,
+                     xegpu::DistributeLayoutAttr layout) {
   SmallVector<Value> dynamicOffsets;
   SmallVector<int64_t> staticOffsets;
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -536,7 +538,7 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,

   build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr,
         packed, transpose, l1_hint, l2_hint, l3_hint,
-        /*anchor_layout=*/nullptr);
+        /*anchor_layout=*/layout);
 }

 LogicalResult LoadNdOp::verify() {
@@ -647,15 +649,16 @@ void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
                       Value tensorDesc, ArrayRef<OpFoldResult> offsets,
                       xegpu::CachePolicyAttr l1_hint,
                       xegpu::CachePolicyAttr l2_hint,
-                      xegpu::CachePolicyAttr l3_hint) {
+                      xegpu::CachePolicyAttr l3_hint,
+                      xegpu::DistributeLayoutAttr layout) {
   SmallVector<Value> dynamicOffsets;
   SmallVector<int64_t> staticOffsets;
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);

   auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);

   build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr,
-        l1_hint, l2_hint, l3_hint, /*anchor_layout=*/nullptr);
+        l1_hint, l2_hint, l3_hint, /*anchor_layout=*/layout);
 }

 LogicalResult StoreNdOp::verify() {

mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp

Lines changed: 3 additions & 2 deletions
@@ -528,7 +528,8 @@ transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter,
     xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
                                 newDescOp.getResult(),
                                 getPrefetchOffsets(initForOp.getInductionVar()),
-                                readCacheHint, readCacheHint, readCacheHint);
+                                readCacheHint, readCacheHint, readCacheHint,
+                                /*layout=*/nullptr);

   // Insert prefetch op in main loop.
   // Calculate prefetch offset after the init prefetches have been issued.
@@ -539,7 +540,7 @@ transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter,
     xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
                                 newDescOp.getResult(),
                                 getPrefetchOffsets(prefetchOffset), readCacheHint,
-                                readCacheHint, readCacheHint);
+                                readCacheHint, readCacheHint, /*layout=*/nullptr);

   // Unroll the init loop.
   if (failed(loopUnrollFull(initForOp)))

mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp

Lines changed: 1 addition & 1 deletion
@@ -214,7 +214,7 @@ static Value generateLoads(ConversionPatternRewriter &rewriter,
         newTensorDesc, ArrayRef<OpFoldResult>{loadOffsetX, loadOffsetY},
         origLoadOp.getPackedAttr(), origLoadOp.getTransposeAttr(),
         origLoadOp.getL1HintAttr(), origLoadOp.getL2HintAttr(),
-        origLoadOp.getL3HintAttr());
+        origLoadOp.getL3HintAttr(), origLoadOp.getLayoutAttr());
   // Set the layout for the loadOp.
   auto layoutAttr = newTensorDesc.getType().getLayoutAttr();
   xegpu::setDistributeLayoutAttr(loadOp->getOpResult(0), layoutAttr);

mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp

Lines changed: 12 additions & 3 deletions
@@ -238,6 +238,9 @@ struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
     if (!targetShape)
       return failure();

+    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+    if (layout)
+      layout = layout.dropInstData();
     int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
     bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();

@@ -255,7 +258,7 @@ struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
     auto createPrefetch = [&](SmallVector<OpFoldResult> offsets) -> Value {
       xegpu::PrefetchNdOp::create(rewriter, loc, convertedTdesc[0], offsets,
                                   op.getL1HintAttr(), op.getL2HintAttr(),
-                                  op.getL3HintAttr());
+                                  op.getL3HintAttr(), layout);
       // return dummy Value to satisfy function's signature
       return nullptr;
     };
@@ -282,6 +285,9 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
     if (!targetShape)
       return failure();

+    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+    if (layout)
+      layout = layout.dropInstData();
     int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
     bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();

@@ -306,7 +312,7 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
       return xegpu::LoadNdOp::create(
           rewriter, loc, newValueTy, convertedTdescs[0], offsets,
           op.getPackedAttr(), op.getTransposeAttr(), op.getL1HintAttr(),
-          op.getL2HintAttr(), op.getL3HintAttr());
+          op.getL2HintAttr(), op.getL3HintAttr(), layout);
     };
     newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
                                     *targetShape, createLoad, loc, rewriter);
@@ -331,6 +337,9 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
     if (!targetShape)
       return failure();

+    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+    if (layout)
+      layout = layout.dropInstData();
     int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
     bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();

@@ -354,7 +363,7 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
       xegpu::StoreNdOp::create(rewriter, loc, convertedValues[valueIndex++],
                                convertedTdescs[0], offsets,
                                op.getL1HintAttr(), op.getL2HintAttr(),
-                               op.getL3HintAttr());
+                               op.getL3HintAttr(), layout);
       // return dummy Value to satisfy function's signature
       return nullptr;
     };
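Unrolling materializes inst_data, so the retained layout keeps its subgroup and lane components. Sketched on a concrete attribute; the values are illustrative and the drop semantics are inferred from this patch and its test:

// Before: #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32],
//                       inst_data = [32, 16],
//                       lane_layout = [1, 16], lane_data = [1, 1]>
// After:  #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32],
//                       lane_layout = [1, 16], lane_data = [1, 1]>
xegpu::DistributeLayoutAttr retained = layout.dropInstData();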

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 12 additions & 3 deletions
@@ -317,6 +317,9 @@ struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
     if (failed(genOffsetsList(rewriter, op, offsetsList)))
       return failure();

+    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+    if (layout)
+      layout = layout.dropSgLayoutAndData();
     SmallVector<Value> newOps;
     for (auto [tdesc, offsets] :
          llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
@@ -326,7 +329,7 @@ struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
       auto newOp = xegpu::LoadNdOp::create(
           rewriter, op.getLoc(), newResTy, tdesc, offsets,
           /*packed = */ nullptr, /*transpose = */ nullptr, op.getL1HintAttr(),
-          op.getL2HintAttr(), op.getL3HintAttr());
+          op.getL2HintAttr(), op.getL3HintAttr(), layout);
       newOps.push_back(newOp);
     }
     rewriter.replaceOpWithMultiple(op, {newOps});
@@ -347,11 +350,14 @@ struct WgToSgStoreNdOpWithOffset
     if (failed(genOffsetsList(rewriter, op, offsetsList)))
       return failure();

+    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+    if (layout)
+      layout = layout.dropSgLayoutAndData();
     for (auto [v, tdesc, offsets] :
          llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) {
       xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, tdesc, offsets,
                                op.getL1HintAttr(), op.getL2HintAttr(),
-                               op.getL3HintAttr());
+                               op.getL3HintAttr(), layout);
     }
     rewriter.eraseOp(op);

@@ -371,11 +377,14 @@ struct WgToSgPrefetchNdOpWithOffset
     if (failed(genOffsetsList(rewriter, op, offsetsList)))
       return failure();

+    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+    if (layout)
+      layout = layout.dropSgLayoutAndData();
     for (auto [tdesc, offsets] :
          llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
       xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), tdesc, offsets,
                                   op.getL1HintAttr(), op.getL2HintAttr(),
-                                  op.getL3HintAttr());
+                                  op.getL3HintAttr(), layout);
     }
     rewriter.eraseOp(op);
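WG-to-SG distribution consumes the subgroup components, so dropSgLayoutAndData keeps inst_data plus the lane components; the new test below checks exactly this retained form. Sketched on the layout from that test:

// Before: #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32],
//                       inst_data = [32, 16],
//                       lane_layout = [1, 16], lane_data = [1, 1]>
// After:  #xegpu.layout<inst_data = [32, 16],
//                       lane_layout = [1, 16], lane_data = [1, 1]>
xegpu::DistributeLayoutAttr retained = layout.dropSgLayoutAndData();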

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir

Lines changed: 13 additions & 0 deletions
@@ -633,4 +633,17 @@ gpu.module @test_distribution {
       #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>} : vector<256xf32> to vector<256x256xf32>
     gpu.return
   }
+
+  // CHECK-LABEL: load_nd_tdesc_with_anchor_layout
+  gpu.func @load_nd_tdesc_with_anchor_layout(%src: memref<256x128xf32>) {
+    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+    // CHECK: xegpu.load_nd %[[TDESC]][{{%.*}}, {{%.*}}] <{layout = #xegpu.layout<inst_data = [32, 16], lane_layout = [1, 16], lane_data = [1, 1]>}>
+    // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<32x32xf32>
+    %load = xegpu.load_nd %tdesc[0, 0] <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16], lane_layout = [1, 16], lane_data = [1, 1]>}>
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+      -> vector<256x128xf32>
+    gpu.return
+  }
 }
