@@ -373,11 +373,11 @@ struct BlockIOConversionBase : public LoadStoreConversionBase {
373373
374374 // Returns the pitch (stride in bytes) of \p ptr.
375375 Value getPitch (ConversionPatternRewriter &rewriter, Value ptr,
376- unsigned elemSizeInBits) const {
376+ unsigned elemSizeInBits, unsigned dim = 0 ) const {
377377 Location loc = ptr.getLoc ();
378378 auto b = TritonLLVMOpBuilder (loc, rewriter);
379379
380- int stride = getStride (ptr, 0 );
380+ int stride = getStride (ptr, dim );
381381 // If the stride is 0, we assume a minimum pitch of 64 bytes.
382382 constexpr int MIN_PITCH = 64 ;
383383 if (stride == 0 )
@@ -1884,17 +1884,6 @@ struct LoadOpToBlockIOConversion
18841884 // HW issue for vblock = 4
18851885 vBlocks = vBlocks == 4 ? 1 : vBlocks;
18861886
1887- // TODO: use the axis info to general the handling for both regular pointer
1888- // and block pointer.
1889- const bool memoryRowMajor = isMemoryRowMajor (op);
1890- unsigned contiguousDim = memoryRowMajor ? 1 : 0 ;
1891- const bool isTransposeRequired = contiguousDim != colDim;
1892-
1893- if (isTransposeRequired) {
1894- // TODO: support load column major data.
1895- return failure ();
1896- }
1897-
18981887 Location loc = op.getLoc ();
18991888 MLIRContext *ctx = op.getContext ();
19001889 auto b = TritonLLVMOpBuilder (loc, rewriter);
@@ -2012,13 +2001,59 @@ struct LoadOpToBlockIOConversion
20122001 otherElems = unpackLLElements (loc, llOther, rewriter);
20132002 }
20142003
2004+ // TODO: use the axis info to general the handling for both regular pointer
2005+ // and block pointer.
2006+ const bool memoryRowMajor = isMemoryRowMajor (op);
2007+ unsigned contiguousDim = memoryRowMajor ? 1 : 0 ;
2008+ const bool isTransposeRequired = contiguousDim != colDim;
2009+
2010+ if (isTransposeRequired) {
2011+ if (numPackedVals > 1 )
2012+ return failure ();
2013+ if (elemSizeInBits > 32 )
2014+ return failure ();
2015+ if (tileWidth > 32 )
2016+ return failure (); // tileWidth is limited to 32 for transpose 2d load.
2017+
2018+ vBlocks = 1 ;
2019+
2020+ // use the d32 for transpose 2d load.
2021+ packedElemSizeInBits = 32 ;
2022+ numPackedVals = packedElemSizeInBits / elemSizeInBits;
2023+ tileHeight = std::min (tileHeight / numPackedVals, 8 );
2024+
2025+ // transpose the width and height of the tile
2026+ std::swap (tileHeight, tileWidth);
2027+ if (tileHeight * tileWidth < threadsPerWarp)
2028+ return failure (); // The tile size is not large enough for IGC scalar
2029+ // backend vectorization.
2030+ // if (oneMatrixPerLoadForBT) {
2031+ // // Only load 1 operand per inst on row.
2032+ // numOperandsPer2DLoadM = 1;
2033+ // tileHeight = elemsPerDPASInst[threadOrder[rank - 2]];
2034+ // } else {
2035+ // // We can decompose the matrix returned by transposed large 2d load
2036+ // // when threads per warp < column size. Otherwise we have to load one
2037+ // // operand per inst.
2038+ // // Note: the tileHeight and numOperandsPer2DLoadM are the column size
2039+ // // now.
2040+ // numOperandsPer2DLoadM =
2041+ // (threadsPerWarp <= tileHeight) ? repCluster[rank - 1] : 1;
2042+ // }
2043+ // // The transpose 2d load only support 1 operand per inst on column.
2044+ // // (vBlocks = 1)
2045+ // numOperandsPer2DloadN = 1;
2046+ // // TODO: support load column major data.
2047+ // return failure();
2048+ }
2049+
20152050 baseWidth = b.i32_val (
20162051 std::max (64u , vBlocks * tileWidth * (packedElemSizeInBits / 8 )));
20172052 // If the stride is 0, we want to load only the first row.
2018- int stride = getStride (ptr, 0 );
2053+ int stride = getStride (ptr, memoryRowMajor ? 0 : 1 );
20192054 baseHeightInt = (stride == 0 ? 1 : tileHeight);
20202055 baseHeight = b.i32_val (baseHeightInt);
2021- pitch = getPitch (rewriter, ptr, elemSizeInBits);
2056+ pitch = getPitch (rewriter, ptr, elemSizeInBits, memoryRowMajor ? 0 : 1 );
20222057 if (!pitch)
20232058 return failure ();
20242059
@@ -2161,7 +2196,7 @@ struct LoadOpToBlockIOConversion
21612196 /* tile_width*/ tileWidth,
21622197 /* tile_height*/ tileHeight,
21632198 /* v_blocks*/ vBlocks,
2164- /* transpose*/ false ,
2199+ /* transpose*/ isTransposeRequired ,
21652200 /* vnni_transform*/ useVNNIFormat);
21662201
21672202 // When strides[0] is 0, we only want to load the first row, so we
0 commit comments