From 50872f3d27c97946dd4deed0c13e94a2e140e9a8 Mon Sep 17 00:00:00 2001 From: "Lu,Chengjun" Date: Mon, 21 Jul 2025 15:49:00 +0000 Subject: [PATCH] [LoadStoreOpToLLVM] Transposed 2d load. Signed-off-by: Lu,Chengjun --- python/test/unit/intel/test_block_io.py | 27 +- .../tensor-pointer-load-block-2d.mlir | 87 ++- .../LoadStoreOpToLLVM.cpp | 705 ++---------------- 3 files changed, 161 insertions(+), 658 deletions(-) diff --git a/python/test/unit/intel/test_block_io.py b/python/test/unit/intel/test_block_io.py index 6c6d5f1250..2f3a2f9fdb 100644 --- a/python/test/unit/intel/test_block_io.py +++ b/python/test/unit/intel/test_block_io.py @@ -120,8 +120,9 @@ def warps_per_cta(layout): @pytest.mark.parametrize("layout", layouts) @pytest.mark.parametrize("load_block_ptr, store_block_ptr", [(True, True), (False, False), (True, False), (False, True)]) +@pytest.mark.parametrize("transpose", [True, False]) @pytest.mark.skipif(not is_xpu(), reason="Block store tests are specific to the XPU backend") -def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, device, tmp_path: pathlib.Path): +def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, transpose, device, tmp_path: pathlib.Path): warps = warps_per_cta(layout) num_warps = int(np.prod(warps)) @@ -132,16 +133,20 @@ def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, devi support_block_io = torch.xpu.get_device_capability()['has_subgroup_2d_block_io'] + block_io = "\"column_major\"" if transpose else "\"row_major\"" + + strides = "[%c1_i64, %M_i64]" if transpose else "[%N_i64, %c1_i64]" + if load_block_ptr: load_ops = f""" - %src_ptr = tt.make_tensor_ptr %src, [%M_i64, %N_i64], [%N_i64, %c1_i64], [%c0_i32, %c0_i32] {{order = array}} : > - %store_val = tt.load %src_ptr {{ttig.block_io = "row_major", boundaryCheck = array, padding = 1 : i32}} : !tt.ptr> + %src_ptr = tt.make_tensor_ptr %src, [%M_i64, %N_i64], {strides}, [%c0_i32, %c0_i32] {{order = array}} : > + %store_val = tt.load %src_ptr {{ttig.block_io = {block_io}, boundaryCheck = array, padding = 1 : i32}} : !tt.ptr> """ else: load_ops = f""" %src_base = tt.splat %src : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout> - %src_ptr = tt.addptr %src_base, %row_major_off : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout> - %store_val = tt.load %src_ptr {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout> + %src_ptr = tt.addptr %src_base, {"%col_major_off" if transpose else "%row_major_off" } : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout> + %store_val = tt.load %src_ptr {{ttig.block_io = {block_io}}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout> """ if store_block_ptr: store_ops = f""" @@ -175,6 +180,12 @@ def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, devi %7 = tt.broadcast %5 : tensor<1x{N}xi32, #layout> -> tensor<{M}x{N}xi32, #layout> %row_major_off = arith.addi %6, %7 : tensor<{M}x{N}xi32, #layout> + %stride_M = arith.constant dense<{M}> : tensor<1x{N}xi32, #layout> + %col_stride = arith.muli %5, %stride_M : tensor<1x{N}xi32, #layout> + %8 = tt.broadcast %2 : tensor<{M}x1xi32, #layout> -> tensor<{M}x{N}xi32, #layout> + %9 = tt.broadcast %col_stride : tensor<1x{N}xi32, #layout> -> tensor<{M}x{N}xi32, #layout> + %col_major_off = arith.addi %8, %9 : tensor<{M}x{N}xi32, #layout> + {load_ops} {store_ops} @@ -195,10 +206,14 @@ def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, devi temp_file.write_text(ir) kernel = 
triton.compile(str(temp_file)) + a = a.permute(1, 0).contiguous().permute(1, 0) if transpose else a + kernel[(1, 1, 1)](a, x) assert torch.equal(a, x) if support_block_io: if not load_block_ptr: - assert 'spirv_Subgroup2DBlockLoad' in kernel.asm['llir'] or 'GenISA.LSC2DBlockRead' in kernel.asm['llir'] + if not ((transpose and type(layout) in [SliceLayout]) or + (transpose and dtype_str in ["float16", "int8"])): # TODO: add support for these cases + assert 'spirv_Subgroup2DBlockLoad' in kernel.asm['llir'] or 'GenISA.LSC2DBlockRead' in kernel.asm['llir'] assert 'spirv_Subgroup2DBlockStoreINTEL' in kernel.asm['llir'] or 'GenISA.LSC2DBlockWrite' in kernel.asm['llir'] diff --git a/test/TritonIntelGPU/tensor-pointer-load-block-2d.mlir b/test/TritonIntelGPU/tensor-pointer-load-block-2d.mlir index ed6caad5cf..ed7023e2dd 100644 --- a/test/TritonIntelGPU/tensor-pointer-load-block-2d.mlir +++ b/test/TritonIntelGPU/tensor-pointer-load-block-2d.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm +// RUN: env TRITON_INTEL_ENABLE_BLOCK_IO_ALL_LAYOUTS=1 triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm #mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 33280 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32} { @@ -566,3 +566,88 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32, "ttg.thr tt.return } } + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {ttg.global_scratch_memory_alignment = 1 : i32, ttg.global_scratch_memory_size = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 0 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 32 : i32, ttig.support_sg_2d_block} { + tt.func public @trans_block_load_i32(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {ttg.global_scratch_memory_alignment = 1 : i32, ttg.global_scratch_memory_size = 0 : i32} { + %cst = arith.constant dense<64> : tensor<32x1xi32, #blocked> + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> + %3 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %cst_0 = arith.constant dense<32> : tensor<1x64xi32, #blocked> + %8 = arith.muli %4, %cst_0 : tensor<1x64xi32, #blocked> + %9 = tt.broadcast %1 : tensor<32x1xi32, #blocked> -> tensor<32x64xi32, #blocked> + %10 = tt.broadcast %8 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked> + %11 = arith.addi %9, %10 : tensor<32x64xi32, #blocked> + %12 = tt.splat %arg0 : !tt.ptr -> tensor<32x64x!tt.ptr, #blocked> + %13 = tt.addptr %12, 
%11 : tensor<32x64x!tt.ptr, #blocked>, tensor<32x64xi32, #blocked> + // COM: Transpose 2D block load with i32 type. + // CHECK-COUNT-16: triton_gen.2Dblockload {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} {elem_size_in_bits = 32, tile_width = 2, tile_height = 16, v_blocks = 1, transpose = true, vnni_transform = false, cache_control = Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<1xi32> + %14 = tt.load %13 {ttig.block_io = "column_major"} : tensor<32x64x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 4, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 2], A = [8, 32], B = [32, 32], C = [8, 32]}> +module attributes {ttg.global_scratch_memory_alignment = 1 : i32, ttg.global_scratch_memory_size = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 0 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32, ttig.support_sg_2d_block} { + tt.func public @trans_block_load_i16(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {ttg.global_scratch_memory_alignment = 1 : i32, ttg.global_scratch_memory_size = 0 : i32} { + %cst = arith.constant dense<64> : tensor<32x1xi32, #mma> + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>> + %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<32x1xi32, #mma> + %3 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> + %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x64xi32, #mma> + %cst_0 = arith.constant dense<32> : tensor<1x64xi32, #mma> + %8 = arith.muli %4, %cst_0 : tensor<1x64xi32, #mma> + %9 = tt.broadcast %1 : tensor<32x1xi32, #mma> -> tensor<32x64xi32, #mma> + %10 = tt.broadcast %8 : tensor<1x64xi32, #mma> -> tensor<32x64xi32, #mma> + %11 = arith.addi %9, %10 : tensor<32x64xi32, #mma> + %12 = tt.splat %arg0 : !tt.ptr -> tensor<32x64x!tt.ptr, #mma> + %13 = tt.addptr %12, %11 : tensor<32x64x!tt.ptr, #mma>, tensor<32x64xi32, #mma> + // COM: Transpose 2D block load with f16 type. Pack the loaded vector to the i32 type. Then transpose the loaded i32 vector with bitcast op. 
+ // CHECK: %[[LOADED:.*]] = triton_gen.2Dblockload {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} {elem_size_in_bits = 32, tile_width = 8, tile_height = 16, v_blocks = 1, transpose = true, vnni_transform = false, cache_control = Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
+ // CHECK: %[[PACKED_I32:.*]] = llvm.shufflevector %[[LOADED]], %[[LOADED]] [0, 1, 2, 3] : vector<8xi32>
+ // CHECK: llvm.bitcast %[[PACKED_I32]] : vector<4xi32> to vector<8xf16>
+ // CHECK-COUNT-3: triton_gen.2Dblockload {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} {elem_size_in_bits = 32, tile_width = 8, tile_height = 16, v_blocks = 1, transpose = true, vnni_transform = false, cache_control = Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
+ %14 = tt.load %13 {ttig.block_io = "column_major"} : tensor<32x64x!tt.ptr, #mma>
+ tt.return
+ }
+}
+
+// -----
+
+#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 4, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 2], A = [8, 32], B = [32, 32], C = [8, 32]}>
+module attributes {ttg.global_scratch_memory_alignment = 1 : i32, ttg.global_scratch_memory_size = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 0 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32, ttig.support_sg_2d_block} {
+ tt.func public @trans_block_load_i8(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {ttg.global_scratch_memory_alignment = 1 : i32, ttg.global_scratch_memory_size = 0 : i32} {
+ %cst = arith.constant dense<128> : tensor<128x1xi32, #mma>
+ %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>>
+ %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xi32, #mma>
+ %2 = arith.muli %1, %cst : tensor<128x1xi32, #mma>
+ %3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #mma}>>
+ %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x128xi32, #mma>
+ %5 = tt.broadcast %2 : tensor<128x1xi32, #mma> -> tensor<128x128xi32, #mma>
+ %6 = tt.broadcast %4 : tensor<1x128xi32, #mma> -> tensor<128x128xi32, #mma>
+ %7 = arith.addi %5, %6 : tensor<128x128xi32, #mma>
+ %cst_0 = arith.constant dense<128> : tensor<1x128xi32, #mma>
+ %8 = arith.muli %4, %cst_0 : tensor<1x128xi32, #mma>
+ %9 = tt.broadcast %1 : tensor<128x1xi32, #mma> -> tensor<128x128xi32, #mma>
+ %10 = tt.broadcast %8 : tensor<1x128xi32, #mma> -> tensor<128x128xi32, #mma>
+ %11 = arith.addi %9, %10 : tensor<128x128xi32, #mma>
+ %12 = tt.splat %arg0 : !tt.ptr -> tensor<128x128x!tt.ptr, #mma>
+ %13 = tt.addptr %12, %11 : tensor<128x128x!tt.ptr, #mma>, tensor<128x128xi32, #mma>
+ // COM: Transpose 2D block load with i8 type. Pack the loaded vector to the i32 type. Then transpose the loaded i32 vector with bitcast op.
+ // COM: We currently do the shuffle and then the bitcast; it might be more efficient to do the bitcast first and then the shuffle.
+ // CHECK: %[[PACKED_1ST_HALF:.*]] = llvm.shufflevector %[[LOADED]], %[[LOADED]] [0, 1] : vector<8xi32> + // CHECK: llvm.bitcast %[[PACKED_1ST_HALF]] : vector<2xi32> to vector<8xi8> + // CHECK: %[[PACKED_2ND_HALF:.*]] = llvm.shufflevector %[[LOADED]], %[[LOADED]] [2, 3] : vector<8xi32> + // CHECK: llvm.bitcast %[[PACKED_2ND_HALF]] : vector<2xi32> to vector<8xi8> + // CHECK-COUNT-7: triton_gen.2Dblockload {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} {elem_size_in_bits = 32, tile_width = 8, tile_height = 16, v_blocks = 1, transpose = true, vnni_transform = false, cache_control = Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32> + %14 = tt.load %13 {ttig.block_io = "column_major"} : tensor<128x128x!tt.ptr, #mma> + tt.return + } +} diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp index cfe07a7d5b..7c6950ebad 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -155,7 +155,17 @@ struct LoadStoreConversionBase { AxisInfo *axisInfo = const_cast(axisAnalysisPass) .getAxisInfo(ptr); - return axisInfo ? axisInfo->getStride(dim) : -1; + if (axisInfo) { + const SmallVector &stride = axisInfo->getStride(); + if (dim < stride.size()) { + return stride[dim]; + } + // There is only one case that the regular pointer is defined as the + // function args. + assert(stride.size() == 1 && stride[0] == -1 && + "get the stride of invalid dim from regular pointer"); + } + return -1; } unsigned getContiguity(Value ptr) const { @@ -1831,640 +1841,6 @@ struct LoadOpToBlockIOConversion return success(); } - // FIXME: Temp solution for supporting transpose load. - LogicalResult - matchAndRewriteTranspose(triton::LoadOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Location loc = op.getLoc(); - MLIRContext *ctx = rewriter.getContext(); - auto b = TritonLLVMOpBuilder(loc, rewriter); - Value ptr = op.getPtr(); - Value mask = op.getMask(); - Type resultType = op.getType(); - auto tensorType = cast(resultType); - const bool memoryRowMajor = isMemoryRowMajor(op); - DpasEncodingAttr::OpIdx opIdx = getOpIdx(tensorType); - - Attribute encoding = tensorType.getEncoding(); - std::optional llEncoding = - cast(encoding).toLinearLayout( - tensorType.getShape()); - assert(llEncoding.has_value() && - "unexpected failure when getting linear layout"); - - Type eltTy = getTypeConverter()->convertType(tensorType.getElementType()); - unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth(); - - auto llAttr = LinearEncodingAttr::get(rewriter.getContext(), *llEncoding); - SmallVector threadOrder(llAttr.getThreadOrder()); - size_t rank = threadOrder.size(); - const bool isTransposeRequired = true; - - // Step 2: Right now we only support DPAS related layout to simplify the - // lowering. 
- DpasEncodingAttr dpasLayout = getDpasLayout(tensorType); - const ArrayRef tensorShape = tensorType.getShape(); - unsigned numElems = getTotalElemsPerThread(resultType); - SmallVector repetitons = - dpasLayout.getDPASRepetitions(tensorShape, opIdx); - assert(repetitons.size() == 3 && - "getDPASRepetitions always return rank 3 size"); - assert(repetitons[0] == 1 && "Only supports rank of 2 for now"); - SmallVector numReps{repetitons[1], repetitons[2]}; - ArrayRef warpsPerCTA = dpasLayout.getWarpsPerCTA(); - SmallVector dpasWarpsOrder = - getMatrixOrder(warpsPerCTA.size(), /*rowMajor*/ true); - unsigned threadsPerWarp = - product(getThreadsPerWarp(dpasLayout, tensorShape)); - - Value warpId = rewriter.create( - loc, i32_ty, - rewriter.create(loc, /*upperBound=*/nullptr)); - - SmallVector multiDimWarpId = - delinearize(rewriter, loc, warpId, warpsPerCTA, dpasWarpsOrder); - - // By default, use the unpacked type for the 2D load result type. - Type loadResultElemType = typeConverter->convertType(eltTy); - bool usePackedType = false; - unsigned packedElemsNum = 1; - // The tensor values are distributed as DotOp layout of DPAS. - // If the element size of the tensor matches the DPAS packed layout, then - // use the packed type for the 2D load result type. For example, - // The intermediate ops generated by ConvertTritonGPUToLLVM: - // %0 = load_2d %ptr : vector<8 x i32> - // %1 = bitcast %0 : vector<8 x i32> -> vector<16 x f16> - // %2 = bitcast %1 : vector<16 x f16> -> vector<8 x i32> - // %3 = dpas %2 - // And the LLVM dialect optimization pass can eliminate the duplicated - // bitcast. Then there is a shortcut to use the load result directly as the - // input operands to DPAS. - // TODO: add support for int4 and int2. - - // OperandA: outer dim -> M, inner dim -> K. - // OperandB: outer dim -> N, inner dim -> K. - // OperandC: outer dim -> M, inner dim -> N. - // Round the warp id fit into the tensor shape. - unsigned dimOuter; - unsigned dimInner; - SmallVector repCluster(dpasLayout.getRepCluster()); - SmallVector warpShape; - SmallVector dpasInstShape; - - switch (opIdx) { - case DpasEncodingAttr::OpIdx::OperandA: { - warpShape = std::move(dpasLayout.getShapeA()); - dpasInstShape = std::move(dpasLayout.getDPASInstShapeA()); - dimOuter = rank - 2; - dimInner = rank - 1; - repCluster[dimInner] = 1; - - unsigned opsPerChannel = dpasLayout.getOpsPerChannel(); - if ((opsPerChannel == 4 && elemSizeInBits == 8) || - (opsPerChannel == 2 && elemSizeInBits == 16) || - (opsPerChannel == 1 && elemSizeInBits == 32)) { - loadResultElemType = elemSizeInBits == 32 ? i32_ty : i16_ty; - packedElemsNum = opsPerChannel == 4 ? 2 : 1; - usePackedType = true; - } else if (opsPerChannel == 4) { - packedElemsNum = 2; - unsigned packedBitWidht = elemSizeInBits * packedElemsNum; - if (packedBitWidht > 64) { - // Be conservative to avoid the packed type exceeds 64 bits. - return failure(); - } - // Need to pack two column into one to work around vectorization - // limitation. 
- loadResultElemType = int_ty(packedBitWidht); - usePackedType = true; - } - } break; - case DpasEncodingAttr::OpIdx::OperandB: { - warpShape = std::move(dpasLayout.getShapeB()); - dpasInstShape = std::move(dpasLayout.getDPASInstShapeB()); - dimOuter = rank - 1; - dimInner = rank - 2; - repCluster[dimInner] = 1; - - unsigned opsPerChannel = dpasLayout.getOpsPerChannel(); - if ((opsPerChannel == 4 && elemSizeInBits == 8) || - (opsPerChannel == 2 && elemSizeInBits == 16) || - (opsPerChannel == 1 && elemSizeInBits == 32)) { - loadResultElemType = i32_ty; - packedElemsNum = opsPerChannel; - usePackedType = true; - } - } break; - case DpasEncodingAttr::OpIdx::OperandC: - warpShape = std::move(dpasLayout.getShapeC()); - dpasInstShape = std::move(dpasLayout.getDPASInstShapeC()); - dimOuter = rank - 2; - dimInner = rank - 1; - usePackedType = false; - break; - default: - llvm_unreachable("unknown DPAS operands index type."); - break; - } - unsigned elemsPerLanePerDPASInst = - product(dpasInstShape) / threadsPerWarp; - LLVMTypeConverter *typeConverter = getTypeConverter(); - Type unpackedDPASOperandType = LLVM::getVectorType( - typeConverter->convertType(eltTy), elemsPerLanePerDPASInst); - - unsigned packedElemsPerLanePerDPASInst = - elemsPerLanePerDPASInst / packedElemsNum; - Type packedDPASOperandType = - LLVM::getVectorType(loadResultElemType, packedElemsPerLanePerDPASInst); - - unsigned outerDimTileNum = - mlir::ceil(tensorShape[dimOuter], warpShape[dimOuter]); - unsigned outerDimWarpNum = - std::min(warpsPerCTA[dimOuter], outerDimTileNum); - Value outerDimWarpId = - b.urem(multiDimWarpId[dimOuter], b.i32_val(outerDimWarpNum)); - unsigned innerDimRequiredWarpNum = - mlir::ceil(tensorShape[dimInner], warpShape[dimInner]); - unsigned innerDimWarpNum = - std::min(warpsPerCTA[dimInner], innerDimRequiredWarpNum); - - // Step 3: Get the tile size of load. - unsigned tileWidth = dpasInstShape[threadOrder[rank - 2]]; - unsigned tileHeight = dpasInstShape[threadOrder[rank - 1]]; - unsigned vBlocks = 1; - unsigned numOperandsOuterDimPerLoad = 1; - unsigned numOperandsInnerDimPerLoad = 1; - unsigned maskConstancyHor = 1, maskConstancyVer = 1; - unsigned instWidth = dpasInstShape[threadOrder[rank - 2]]; - unsigned instHeight = dpasInstShape[threadOrder[rank - 1]]; - - std::map, Value> ptrs; - std::map, Value> masks; - std::map, Value> others; - - Value llPtr = adaptor.getPtr(); - Value llMask = adaptor.getMask(); - - SmallVector ptrElems, maskElems, otherElems; - // Get the LLVM values for pointers - ptrElems = unpackLLElements(loc, llPtr, rewriter); - assert(ptrElems.size() == numElems && - "the number of pointer values is not matched with the number of " - "elements"); - - // Get the LLVM values for mask - if (llMask) { - maskElems = unpackLLElements(loc, llMask, rewriter); - assert(maskElems.size() == numElems && - "the number of mask values is not matched with the number of " - "elements"); - auto axisInfo = - const_cast(axisAnalysisPass) - .getAxisInfo(mask); - if (axisInfo) { - maskConstancyHor = axisInfo->getConstancy(rank - 1); - maskConstancyVer = axisInfo->getConstancy(rank - 2); - } else { - maskConstancyHor = 1; - maskConstancyVer = 1; - } - } else { - // no mask - maskConstancyHor = std::numeric_limits::max(); - maskConstancyVer = std::numeric_limits::max(); - } - - // Check the constancy of the mask support to load the memory in 2D block. 
- if (!(maskConstancyHor >= instWidth && maskConstancyVer >= instHeight)) - return failure(); - - // Get the LLVM values for `other` - Value other = op.getOther(); - Value llOther = adaptor.getOther(); - DenseElementsAttr constAttr; - if (other) { - if (matchPattern(other, m_Constant(&constAttr)) && constAttr.isSplat()) { - Type elemTy = constAttr.getElementType(); - auto handleSplatValue = [&](auto splatVal) { - if (!splatVal.isZero()) { - otherElems = SmallVector( - numElems, - rewriter.create(loc, elemTy, splatVal)); - } - }; - - TypeSwitch(elemTy) - .Case([&](FloatType) { - handleSplatValue(constAttr.getSplatValue()); - }) - .Case([&](IntegerType) { - handleSplatValue(constAttr.getSplatValue()); - }); - } else { - otherElems = unpackLLElements(loc, llOther, rewriter); - } - } - - // re-arrange the ptrs and masks to for large 2D block IO. - // Layout is unrelated to the scalar type. - SmallVector> offsets = - mlir::emitOffsetForLayout(encoding, tensorType); - for (size_t i = 0; i < ptrElems.size(); ++i) { - SmallVector offset = offsets[i]; - ptrs[offset] = ptrElems[i]; - if (llMask) - masks[offset] = maskElems[i]; - if (otherElems.size()) - others[offset] = otherElems[i]; - } - - unsigned numOperandsPer2DLoadM, numOperandsPer2DLoadN; - if (opIdx == DpasEncodingAttr::OpIdx::OperandA) - return failure(); - - if (!usePackedType) - return failure(); - - std::swap(tileHeight, tileWidth); - - // We can decompose the matrix returned by transposed large 2d load - // when threads per warp < column size. Otherwise we have to load one - // operand per inst. - // Note: the tileHeight and numOperandsPer2DLoadM are the column size - // now. - numOperandsPer2DLoadM = - (threadsPerWarp <= tileHeight) ? repCluster[rank - 1] : 1; - // The transpose 2d load only support 1 operand per inst on column. - // (vBlocks = 1) - numOperandsPer2DLoadN = 1; - - // adjust the mask constancy to fit the 2D load. - numOperandsPer2DLoadM = - std::min(numOperandsPer2DLoadM, maskConstancyHor / instWidth); - numOperandsPer2DLoadN = - std::min(numOperandsPer2DLoadN, maskConstancyVer / instHeight); - - // PVC 2D load supports 32 rows at most. Load multiple dot operands in by - // enlarging the tileHeight. - constexpr unsigned MAX_TILE_HEIGHT = 32; - numOperandsPer2DLoadM = - std::min(numOperandsPer2DLoadM, - static_cast(MAX_TILE_HEIGHT / tileHeight)); - - // PVC 2D load supports 64 bytes per row at most. Load multiple dot operands - // by enlarging the vBlocks. - unsigned totalBytesPerRowPerDPASOp = tileWidth * elemSizeInBits / 8; - constexpr int MAX_WIDTH = 64; - if (totalBytesPerRowPerDPASOp > MAX_WIDTH) - return failure(); - numOperandsPer2DLoadN = - std::min(numOperandsPer2DLoadN, MAX_WIDTH / totalBytesPerRowPerDPASOp); - // vBlocks has HW limitation of 4. - numOperandsPer2DLoadN = std::min(numOperandsPer2DLoadN, 4u); - - tileHeight = instHeight * numOperandsPer2DLoadM; - tileWidth = instWidth; - vBlocks = numOperandsPer2DLoadN; - - numOperandsOuterDimPerLoad = opIdx != DpasEncodingAttr::OpIdx::OperandB - ? numOperandsPer2DLoadM - : numOperandsPer2DLoadN; - numOperandsInnerDimPerLoad = opIdx != DpasEncodingAttr::OpIdx::OperandB - ? 
numOperandsPer2DLoadN - : numOperandsPer2DLoadM; - - std::swap(numOperandsOuterDimPerLoad, numOperandsInnerDimPerLoad); - - unsigned numLoadPerOutRepCluster = - mlir::ceil(repCluster[dimOuter], numOperandsOuterDimPerLoad); - unsigned numLoadPerInnerRepCluster = - mlir::ceil(repCluster[dimInner], numOperandsInnerDimPerLoad); - - unsigned numValuesPerLoad = packedElemsPerLanePerDPASInst * - numOperandsOuterDimPerLoad * - numOperandsInnerDimPerLoad; - Type load2DGenXType = - LLVM::getVectorType(loadResultElemType, numValuesPerLoad); - - // Step 4: Generates the load instruction. - // The stride for the tile replicates. - unsigned numRepOuter; - unsigned numRepInner; - unsigned repOuterStride = warpShape[dimOuter] * outerDimWarpNum; - unsigned repInnerStride; - switch (opIdx) { - case DpasEncodingAttr::OpIdx::OperandA: - case DpasEncodingAttr::OpIdx::OperandB: - numRepOuter = numReps[dimOuter]; - numRepInner = - mlir::ceil(numReps[dimInner], numOperandsInnerDimPerLoad); - repInnerStride = warpShape[dimInner] * numOperandsInnerDimPerLoad; - break; - case DpasEncodingAttr::OpIdx::OperandC: - numRepOuter = numReps[dimOuter]; - numRepInner = numReps[dimInner]; - repInnerStride = warpShape[dimInner] * innerDimWarpNum; - break; - default: - llvm_unreachable("unknown DPAS operands index type."); - break; - } - - Value pitch = - getPitch(rewriter, ptr, elemSizeInBits, memoryRowMajor ? 0 : 1); - if (!pitch) - return failure(); - - // If the stride is 0, we want to load only the first row. - int stride = getStride(ptr, memoryRowMajor ? 0 : 1); - unsigned baseHeightInt = (stride == 0 ? 1 : tileHeight); - Value baseHeight = b.i32_val(baseHeightInt); - Value baseWidth = - b.i32_val(std::max(64u, vBlocks * tileWidth * (elemSizeInBits / 8))); - - StringAttr kRegister = str_attr("register"); - StringAttr kLane = str_attr("lane"); - StringAttr kWarp = str_attr("warp"); - StringAttr kBlock = str_attr("block"); - - const unsigned originalElemBits = elemSizeInBits; - - LDBG("Block io tile shape: [" - << tileHeight << ", " << tileWidth << "], vblocks: " << vBlocks - << ", numOperandsPerLoad: [" - << (opIdx != DpasEncodingAttr::OpIdx::OperandB - ? numOperandsOuterDimPerLoad - : numOperandsInnerDimPerLoad) - << ", " - << (opIdx != DpasEncodingAttr::OpIdx::OperandB - ? numOperandsInnerDimPerLoad - : numOperandsOuterDimPerLoad) - << "], number loads per repCluster: [" - << (opIdx != DpasEncodingAttr::OpIdx::OperandB - ? numLoadPerOutRepCluster - : numLoadPerInnerRepCluster) - << ", " - << (opIdx != DpasEncodingAttr::OpIdx::OperandB - ? numLoadPerInnerRepCluster - : numLoadPerOutRepCluster) - << "], number repCluster: [" - << (opIdx != DpasEncodingAttr::OpIdx::OperandB ? numRepOuter - : numRepInner) - << ", " - << (opIdx != DpasEncodingAttr::OpIdx::OperandB ? numRepInner - : numRepOuter) - << "]"); - - ValueTable loadVals; - for (int inner = 0; inner < numRepInner; ++inner) { - for (int outer = 0; outer < numRepOuter; ++outer) { - for (int loadInner = 0; loadInner < numLoadPerInnerRepCluster; - ++loadInner) { - for (int loadOuter = 0; loadOuter < numLoadPerOutRepCluster; - ++loadOuter) { - unsigned offsetOuter = - outer * repOuterStride + loadOuter * dpasInstShape[dimOuter] * - numOperandsOuterDimPerLoad; - unsigned offsetInner = - inner * repInnerStride + loadInner * dpasInstShape[dimInner] * - numOperandsInnerDimPerLoad; - unsigned offsetM = - (opIdx != DpasEncodingAttr::OpIdx::OperandB ? offsetOuter - : offsetInner); - unsigned offsetN = - (opIdx != DpasEncodingAttr::OpIdx::OperandB ? 
offsetInner - : offsetOuter); - - LDBG("Block load iterator: inner: " - << inner << ", outer:" << outer << ", loadInner:" << loadInner - << ", loadOuter:" << loadOuter << " offset: [" << offsetM - << ", " << offsetN << "]"); - - Value offsetY = b.i32_val(0); - Value pred; - if (llMask) { - assert(masks.size() && "Invalid size of the masks."); - pred = targetInfo.shuffleIdx(rewriter, loc, - masks[{offsetM, offsetN}], 0); - // We leverage the GPU block I/O hardware out-of-bound protection - // feature by setting the offset to an invalid value when 'pred' - // is false (the HW will not read out-of-bounds values). Later on, - // after issuing the 2d block read operation, we will select the - // result of the load only if the mask evaluate to true, otherwise - // we will use 'other'. - offsetY = b.select(pred, offsetY, baseHeight); - } - - // Use the top-left address of the block to load the data. - Value addrElem = - b.bitcast(ptrs[{offsetM, offsetN}], ptr_ty(ctx, 1 /*global*/)); - addrElem = targetInfo.shuffleIdx(rewriter, loc, addrElem, 0); - - Value ret = rewriter.create( - loc, load2DGenXType, - /*ptr*/ addrElem, - /*base_width*/ baseWidth, - /*base_height*/ baseHeight, - /*base_pitch*/ pitch, - /*x*/ b.i32_val(0), - /*y*/ offsetY, - /*elem_size_in_bits*/ elemSizeInBits, - /*tile_width*/ tileWidth, - /*tile_height*/ tileHeight, - /*v_blocks*/ vBlocks, - /*transpose*/ false, - /*vnni_transform*/ - (usePackedType && opIdx == DpasEncodingAttr::OpIdx::OperandB && - !isTransposeRequired && originalElemBits != 32)); - - // When strides[0] is 0, we only want to load the first row, so we - // set the base height to be 1. If tile height is bigger than 1, - // then only the first row contain valid data. To ensure the entire - // tile is filled with valid data, we must replicate the first row - // throughout the tile. - if (baseHeightInt < tileHeight && baseHeightInt == 1) { - unsigned numIndicesPerMatrix = numValuesPerLoad / vBlocks; - SmallVector shuffleIndices(numValuesPerLoad); - - // Create a vector to store the data of the first index of each - // matrix. - VectorType vecTy = vec_ty(loadResultElemType, vBlocks); - Value firstIndexVec = b.undef(vecTy); - - for (unsigned valueIndex = 0; valueIndex < numValuesPerLoad; - ++valueIndex) { - unsigned firstIndexVecIdx = valueIndex / numIndicesPerMatrix; - // Handle case where an index spans two rows. 
- if (valueIndex % numIndicesPerMatrix == 0) { - Value oldVal = b.extract_element(ret, b.i32_val(valueIndex)); - Value newVal = oldVal; - if (tileWidth < threadsPerWarp) { - assert(tileWidth * 2 == threadsPerWarp && - "Expecting tileWidth to be 2x threadsPerWarp"); - Value threadId = getThreadId(rewriter, loc); - newVal = targetInfo.shuffleIdx( - rewriter, loc, oldVal, - b.urem(threadId, b.i32_val(tileWidth))); - } - firstIndexVec = - b.insert_element(firstIndexVec.getType(), firstIndexVec, - newVal, b.i32_val(firstIndexVecIdx)); - } - - shuffleIndices[valueIndex] = firstIndexVecIdx; - } - DenseI32ArrayAttr attr = - rewriter.getDenseI32ArrayAttr(shuffleIndices); - ret = rewriter.create( - loc, load2DGenXType, firstIndexVec, firstIndexVec, attr); - } - - if (others.size()) { - assert(masks.size() == others.size() && - "The mask value has to be provided when " - "the other value is provided."); - VectorType vecTy = - vec_ty(eltTy, numValuesPerLoad * packedElemsNum); - - Value v = b.undef(vecTy); - unsigned nWords = 0; - for (int vblk = 0; vblk < vBlocks; ++vblk) - for (int i = 0; i < tileHeight; ++i) { - unsigned numColPerPackedValue = - opIdx == DpasEncodingAttr::OpIdx::OperandA - ? packedElemsNum - : 1; - unsigned numPackedValuesPerRow = mlir::ceil( - (tileWidth / numColPerPackedValue), threadsPerWarp); - for (int col = 0; col < numPackedValuesPerRow; ++col) { - for (int packedCol = 0; packedCol < numColPerPackedValue; - ++packedCol) { - unsigned N = packedCol + - col * threadsPerWarp * numColPerPackedValue + - vblk * tileWidth + offsetN; - unsigned M = i + offsetM; - Value falseVal = others[{M, N}]; - Value sVal = createIndexAttrConstant( - rewriter, loc, typeConverter->getIndexType(), - nWords++); - v = b.insert_element(vecTy, v, falseVal, sVal); - } - } - } - Value others = b.bitcast(v, load2DGenXType); - ret = b.select(pred, ret, others); - } - - unsigned numOperandsM = opIdx != DpasEncodingAttr::OpIdx::OperandB - ? numOperandsOuterDimPerLoad - : numOperandsInnerDimPerLoad; - unsigned numOperandsN = opIdx != DpasEncodingAttr::OpIdx::OperandB - ? numOperandsInnerDimPerLoad - : numOperandsOuterDimPerLoad; - - // Split the return matrix by large 2d block io size into multiple - // DPAS operands. 
- assert(numOperandsN >= vBlocks && - "numOperandsN has to be >= vBlocks"); - unsigned numOperandsPerVBlockN = numOperandsN / vBlocks; - for (int vblk = 0; vblk < vBlocks; ++vblk) - for (int row = 0; row < numOperandsM; ++row) - for (int col = 0; col < numOperandsPerVBlockN; ++col) { - - unsigned operandStartOffset = (vblk * numOperandsM + row) * - numOperandsPerVBlockN * - packedElemsPerLanePerDPASInst; - - SmallVector indices(packedElemsPerLanePerDPASInst); - for (int elemIdx = 0; elemIdx < packedElemsPerLanePerDPASInst; - ++elemIdx) { - indices[elemIdx] = operandStartOffset + - elemIdx * numOperandsPerVBlockN + col; - } - - LLVM_DEBUG({ - DBGS() << "shuffle idx: ["; - for (int elemIdx = 0; - elemIdx < packedElemsPerLanePerDPASInst; ++elemIdx) { - llvm::dbgs() << indices[elemIdx] << ", "; - } - llvm::dbgs() << "]\n"; - }); - - DenseI32ArrayAttr attr = - rewriter.getDenseI32ArrayAttr(indices); - Value loadVal = rewriter.create( - loc, packedDPASOperandType, ret, ret, attr); - - // Save the decomposed vals to the map; - switch (opIdx) { - case DpasEncodingAttr::OpIdx::OperandC: - case DpasEncodingAttr::OpIdx::OperandA: { - unsigned o = outer * numLoadPerOutRepCluster * - numOperandsOuterDimPerLoad + - loadOuter * numOperandsOuterDimPerLoad + row; - unsigned i = inner * numLoadPerInnerRepCluster * - numOperandsInnerDimPerLoad + - loadInner * numOperandsInnerDimPerLoad + - vblk * numOperandsPerVBlockN + col; - - LDBG("insert: [" << o << ", " << i << "]"); - loadVals[{o, i}] = - b.bitcast(loadVal, unpackedDPASOperandType); - } break; - case DpasEncodingAttr::OpIdx::OperandB: { - unsigned o = outer * numLoadPerOutRepCluster * - numOperandsOuterDimPerLoad + - loadOuter * numOperandsOuterDimPerLoad + - vblk * numOperandsPerVBlockN + col; - unsigned i = inner * numOperandsInnerDimPerLoad + row; - LDBG("insert: [" << o << ", " << i << "]"); - loadVals[{o, i}] = - b.bitcast(loadVal, unpackedDPASOperandType); - } break; - default: { - llvm_unreachable("unknown DPAS operands index type."); - } break; - } - } - } - } - } - } - - // Step 5: Unpack the load values. - // Extract the value returned by the load ops. And put the values in the - // expected order for the layout. 
- SmallVector unpackedLoadedVals; - for (int outer = 0; outer < numReps[dimOuter]; ++outer) { - for (int inner = 0; inner < numReps[dimInner]; ++inner) { - for (int repOuter = 0; repOuter < repCluster[dimOuter]; ++repOuter) { - for (int repInner = 0; repInner < repCluster[dimInner]; ++repInner) { - unsigned o = outer * repCluster[dimOuter] + repOuter; - unsigned i = inner * repCluster[dimInner] + repInner; - LDBG("extract: [" << o << ", " << i << "]"); - Value loadVal = loadVals.at({o, i}); - VectorType loadTy = cast(loadVal.getType()); - for (int i = 0; i < loadTy.getNumElements(); ++i) { - auto val = b.extract_element(loadVal, b.i32_val(i)); - unpackedLoadedVals.push_back(val); - } - loadVals.erase({o, i}); - } - } - } - } - - assert(loadVals.empty() && "not all loaded values is unpacked."); - - Type llvmResultStructTy = typeConverter->convertType(op.getType()); - Value resultStruct = packLLElements(loc, typeConverter, unpackedLoadedVals, - rewriter, llvmResultStructTy); - rewriter.replaceOp(op, {resultStruct}); - - return success(); - } - LogicalResult matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { @@ -2524,18 +1900,6 @@ struct LoadOpToBlockIOConversion if (tileHeight * tileWidth * packedElemSizeInBits / 8 < GRF_SIZE) vBlocks = 1; - // TODO: use the axis info to general the handling for both regular pointer - // and block pointer. - const bool memoryRowMajor = isMemoryRowMajor(op); - // FIXME: Add support of column major. - if (!memoryRowMajor) - return failure(); - - unsigned contiguousDim = memoryRowMajor ? 1 : 0; - const bool isTransposeRequired = contiguousDim != colDim; - if (isTransposeRequired) - return matchAndRewriteTranspose(op, adaptor, rewriter); - Location loc = op.getLoc(); auto b = TritonLLVMOpBuilder(loc, rewriter); MLIRContext *ctx = rewriter.getContext(); @@ -2664,6 +2028,44 @@ struct LoadOpToBlockIOConversion } } + // TODO: use the axis info to general the handling for both regular pointer + // and block pointer. + const bool memoryRowMajor = isMemoryRowMajor(op); + unsigned contiguousDim = memoryRowMajor ? 1 : 0; + const bool isTransposeRequired = contiguousDim != colDim; + + if (isTransposeRequired) { + if (numPackedVals > 1) + return failure(); + if (elemSizeInBits > 32) + return failure(); + if (tileWidth > 32) + return failure(); // tileWidth is limited to 32 for transpose 2d load. + + vBlocks = 1; + + // use the d32 for transpose 2d load. + packedElemSizeInBits = 32; + numPackedVals = packedElemSizeInBits / elemSizeInBits; + + // Improve this. The current 2D block load only transposes the matrix at + // i32 granularity. We still need to perform an additional in-register + // transpose from i32 -> (N × ElemSizeInBits) tiles, using the tile width. + // At the moment, we can only achieve this using a bitcast operation, + // which implicitly uses the sub-group size as the transpose width. To + // optimize further, we should implement this with inline VISA + // instructions. + if (numPackedVals > 1 && tileWidth != threadsPerWarp) + return failure(); + tileHeight = std::min(tileHeight / numPackedVals, 8); + + if (tileHeight * tileWidth < threadsPerWarp) + return failure(); // The tile size is not large enough for IGC scalar + // backend vectorization. 
+ // Transpose the width and height of the tile.
+ std::swap(tileHeight, tileWidth);
+ }
+
 int64_t numElemsPerLoad = mlir::ceil(
 tileHeight * tileWidth * numPackedVals * vBlocks, (int)threadsPerWarp);
 unsigned numValuesPerLoad = mlir::ceil((int)numElemsPerLoad, numPackedVals);
@@ -2743,8 +2145,6 @@ struct LoadOpToBlockIOConversion
 }
 } break;
 case DpasEncodingAttr::OpIdx::OperandB: {
- assert(numPackedVals == 1 &&
- "invalid number of packed values for DPAS operand B.");
 unsigned elemsPerLanePerDPASInst =
 product(dpasLayout.getDPASInstShapeB()) / threadsPerWarp;
 // Block 2D contain at least one DotOp B.
@@ -2754,6 +2154,9 @@ struct LoadOpToBlockIOConversion
 if (tileHeight >= (opsPerChannel * sysDepth) &&
 ((opsPerChannel == 4 && elemSizeInBits == 8) ||
 (opsPerChannel == 2 && elemSizeInBits == 16))) {
+ assert((!isTransposeRequired || opsPerChannel == numPackedVals) &&
+ "invalid opsPerChannel for transposed DotOp B");
 // Use the VNNI packing format for DotOp B layout.
 numValuesPerLoad = numElemsPerLoad / opsPerChannel;
 packedType = i32_ty;
@@ -2815,8 +2218,8 @@ struct LoadOpToBlockIOConversion
 /*tile_width*/ tileWidth,
 /*tile_height*/ tileHeight,
 /*v_blocks*/ vBlocks,
- /*transpose*/ false,
- /*vnni_transform*/ useVNNIFormat);
+ /*transpose*/ isTransposeRequired,
+ /*vnni_transform*/ !isTransposeRequired && useVNNIFormat);
 // When strides[0] is 0, we only want to load the first row, so we
 // set the base height to be 1. If tile height is bigger than 1,
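
The packing scheme checked by the transposed-load tests above (and described in the comments added to LoadStoreOpToLLVM.cpp) works at i32 granularity: for sub-32-bit element types the lowering shuffles the returned vector<8xi32> into 32-bit chunks and bitcasts them back to the element type. Below is a minimal NumPy sketch of that reinterpretation step for one lane of the f16 case; it models only the vector<4xi32> -> vector<8xf16> bitcast, not the hardware lane mapping nor the actual lowering code.

import numpy as np

# One lane of the f16 case: the d32 transposed block load hands the lane 32-bit
# values, each of which packs two f16 elements that were adjacent along the
# contiguous dimension in memory.
f16_vals = np.arange(8, dtype=np.float16)   # stand-in for the packed source values
lane_i32 = f16_vals.view(np.uint32)         # what the d32 load returns: 4 x i32
lane_f16 = lane_i32.view(np.float16)        # the llvm.bitcast: vector<4xi32> -> vector<8xf16>

assert lane_i32.shape == (4,) and lane_f16.shape == (8,)
# The bitcast only reinterprets bits within the lane; it does not reorder
# elements, so the effective transpose width stays tied to the sub-group size,
# as the "Improve this" comment in LoadStoreOpToLLVM.cpp points out.
assert np.array_equal(lane_f16, f16_vals)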