Skip to content

Commit 00b3a18

Browse files
authored
[mlir][rocdl] add gfx950 smfmac instructions to rocdl dialect (#171737)
Signed-off-by: Eric Feng <Eric.Feng@amd.com>
1 parent e94bf71 commit 00b3a18

File tree

3 files changed

+165
-3
lines changed

3 files changed

+165
-3
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -592,6 +592,21 @@ def ROCDL_smfmac_f32_32x32x32_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.b
592592
def ROCDL_smfmac_f32_32x32x32_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.bf8.fp8">;
593593
def ROCDL_smfmac_f32_32x32x32_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.fp8.bf8">;
594594
def ROCDL_smfmac_f32_32x32x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.fp8.fp8">;
595+
// New in gfx950.
596+
def ROCDL_smfmac_f32_16x16x64_bf16 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.bf16">;
597+
def ROCDL_smfmac_f32_16x16x64_f16 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.f16">;
598+
def ROCDL_smfmac_i32_16x16x128_i8 : ROCDL_Mfma_IntrOp<"smfmac.i32.16x16x128.i8">;
599+
def ROCDL_smfmac_f32_16x16x128_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x128.bf8.bf8">;
600+
def ROCDL_smfmac_f32_16x16x128_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x128.bf8.fp8">;
601+
def ROCDL_smfmac_f32_16x16x128_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x128.fp8.bf8">;
602+
def ROCDL_smfmac_f32_16x16x128_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x128.fp8.fp8">;
603+
def ROCDL_smfmac_f32_32x32x32_bf16 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.bf16">;
604+
def ROCDL_smfmac_f32_32x32x32_f16 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.f16">;
605+
def ROCDL_smfmac_i32_32x32x64_i8 : ROCDL_Mfma_IntrOp<"smfmac.i32.32x32x64.i8">;
606+
def ROCDL_smfmac_f32_32x32x64_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x64.bf8.bf8">;
607+
def ROCDL_smfmac_f32_32x32x64_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x64.bf8.fp8">;
608+
def ROCDL_smfmac_f32_32x32x64_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x64.fp8.bf8">;
609+
def ROCDL_smfmac_f32_32x32x64_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x64.fp8.fp8">;
595610

596611

597612
//===---------------------------------------------------------------------===//

mlir/test/Dialect/LLVMIR/rocdl.mlir

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,11 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
341341
%arg6 : vector<8 x i16>,
342342
%arg7 : vector<2xi32>,
343343
%arg8 : vector<4xi32>,
344-
%arg9 : vector<16xi32>) -> vector<4 x f32> {
344+
%arg9 : vector<16xi32>,
345+
%arg10 : vector<16 x f16>,
346+
%arg11 : vector<8 x bf16>,
347+
%arg12 : vector<16 x bf16>,
348+
%arg13 : vector<8 x i32>) -> vector<4 x f32> {
345349
%csti32 = llvm.mlir.constant(42 : i32) : i32
346350

347351
// CHECK-LABEL: rocdl.smfmac
@@ -415,6 +419,76 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
415419
(vector<2xi32>, vector<4xi32>, vector<16xf32>,
416420
i32, i32, i32) -> vector<16xf32>
417421

422+
// CHECK: rocdl.smfmac.f32.16x16x64.f16 %{{.*}} : (vector<8xf16>, vector<16xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
423+
%r14 = rocdl.smfmac.f32.16x16x64.f16 %arg2, %arg10, %arg3, %csti32, %csti32, %csti32 :
424+
(vector<8xf16>, vector<16xf16>, vector<4xf32>,
425+
i32, i32, i32) -> vector<4xf32>
426+
427+
// CHECK: rocdl.smfmac.f32.32x32x32.f16 %{{.*}} : (vector<8xf16>, vector<16xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
428+
%r15 = rocdl.smfmac.f32.32x32x32.f16 %arg2, %arg10, %arg4, %csti32, %csti32, %csti32 :
429+
(vector<8xf16>, vector<16xf16>, vector<16xf32>,
430+
i32, i32, i32) -> vector<16xf32>
431+
432+
// CHECK: rocdl.smfmac.f32.16x16x64.bf16 %{{.*}} : (vector<8xbf16>, vector<16xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
433+
%r16 = rocdl.smfmac.f32.16x16x64.bf16 %arg11, %arg12, %arg3, %csti32, %csti32, %csti32 :
434+
(vector<8xbf16>, vector<16xbf16>, vector<4xf32>,
435+
i32, i32, i32) -> vector<4xf32>
436+
437+
// CHECK: rocdl.smfmac.f32.32x32x32.bf16 %{{.*}} : (vector<8xbf16>, vector<16xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
438+
%r17 = rocdl.smfmac.f32.32x32x32.bf16 %arg11, %arg12, %arg4, %csti32, %csti32, %csti32 :
439+
(vector<8xbf16>, vector<16xbf16>, vector<16xf32>,
440+
i32, i32, i32) -> vector<16xf32>
441+
442+
// CHECK: rocdl.smfmac.i32.16x16x128.i8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
443+
%r18 = rocdl.smfmac.i32.16x16x128.i8 %arg8, %arg13, %arg8, %csti32, %csti32, %csti32 :
444+
(vector<4xi32>, vector<8xi32>, vector<4xi32>,
445+
i32, i32, i32) -> vector<4xi32>
446+
447+
// CHECK: rocdl.smfmac.f32.16x16x128.bf8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
448+
%r19 = rocdl.smfmac.f32.16x16x128.bf8.bf8 %arg8, %arg13, %arg3, %csti32, %csti32, %csti32 :
449+
(vector<4xi32>, vector<8xi32>, vector<4xf32>,
450+
i32, i32, i32) -> vector<4xf32>
451+
452+
// CHECK: rocdl.smfmac.f32.16x16x128.bf8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
453+
%r20 = rocdl.smfmac.f32.16x16x128.bf8.fp8 %arg8, %arg13, %arg3, %csti32, %csti32, %csti32 :
454+
(vector<4xi32>, vector<8xi32>, vector<4xf32>,
455+
i32, i32, i32) -> vector<4xf32>
456+
457+
// CHECK: rocdl.smfmac.f32.16x16x128.fp8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
458+
%r21 = rocdl.smfmac.f32.16x16x128.fp8.bf8 %arg8, %arg13, %arg3, %csti32, %csti32, %csti32 :
459+
(vector<4xi32>, vector<8xi32>, vector<4xf32>,
460+
i32, i32, i32) -> vector<4xf32>
461+
462+
// CHECK: rocdl.smfmac.f32.16x16x128.fp8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
463+
%r22 = rocdl.smfmac.f32.16x16x128.fp8.fp8 %arg8, %arg13, %arg3, %csti32, %csti32, %csti32 :
464+
(vector<4xi32>, vector<8xi32>, vector<4xf32>,
465+
i32, i32, i32) -> vector<4xf32>
466+
467+
// CHECK: rocdl.smfmac.i32.32x32x64.i8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
468+
%r23 = rocdl.smfmac.i32.32x32x64.i8 %arg8, %arg13, %arg9, %csti32, %csti32, %csti32 :
469+
(vector<4xi32>, vector<8xi32>, vector<16xi32>,
470+
i32, i32, i32) -> vector<16xi32>
471+
472+
// CHECK: rocdl.smfmac.f32.32x32x64.bf8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
473+
%r24 = rocdl.smfmac.f32.32x32x64.bf8.bf8 %arg8, %arg13, %arg4, %csti32, %csti32, %csti32 :
474+
(vector<4xi32>, vector<8xi32>, vector<16xf32>,
475+
i32, i32, i32) -> vector<16xf32>
476+
477+
// CHECK: rocdl.smfmac.f32.32x32x64.bf8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
478+
%r25 = rocdl.smfmac.f32.32x32x64.bf8.fp8 %arg8, %arg13, %arg4, %csti32, %csti32, %csti32 :
479+
(vector<4xi32>, vector<8xi32>, vector<16xf32>,
480+
i32, i32, i32) -> vector<16xf32>
481+
482+
// CHECK: rocdl.smfmac.f32.32x32x64.fp8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
483+
%r26 = rocdl.smfmac.f32.32x32x64.fp8.bf8 %arg8, %arg13, %arg4, %csti32, %csti32, %csti32 :
484+
(vector<4xi32>, vector<8xi32>, vector<16xf32>,
485+
i32, i32, i32) -> vector<16xf32>
486+
487+
// CHECK: rocdl.smfmac.f32.32x32x64.fp8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
488+
%r27 = rocdl.smfmac.f32.32x32x64.fp8.fp8 %arg8, %arg13, %arg4, %csti32, %csti32, %csti32 :
489+
(vector<4xi32>, vector<8xi32>, vector<16xf32>,
490+
i32, i32, i32) -> vector<16xf32>
491+
418492
llvm.return %r0 : vector<4 x f32>
419493
}
420494

mlir/test/Target/LLVMIR/rocdl.mlir

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -581,7 +581,11 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
581581
%arg6 : vector<8 x i16>,
582582
%arg7 : vector<2xi32>,
583583
%arg8 : vector<4xi32>,
584-
%arg9 : vector<16xi32>) -> vector<4 x f32> {
584+
%arg9 : vector<16xi32>,
585+
%arg10 : vector<16 x f16>,
586+
%arg11 : vector<8 x bf16>,
587+
%arg12 : vector<16 x bf16>,
588+
%arg13 : vector<8 x i32>) -> vector<4 x f32> {
585589
%csti32 = llvm.mlir.constant(42 : i32) : i32
586590

587591
// CHECK-LABEL: rocdl.smfmac
@@ -651,12 +655,81 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
651655
(vector<2xi32>, vector<4xi32>, vector<16xf32>,
652656
i32, i32, i32) -> vector<16xf32>
653657

654-
655658
// CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.fp8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
656659
%r13 = rocdl.smfmac.f32.32x32x32.fp8.fp8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
657660
(vector<2xi32>, vector<4xi32>, vector<16xf32>,
658661
i32, i32, i32) -> vector<16xf32>
659662

663+
// CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.f16(<8 x half> %{{.*}}, <16 x half> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
664+
%r14 = rocdl.smfmac.f32.16x16x64.f16 %arg2, %arg10, %arg3, %csti32, %csti32, %csti32 :
665+
(vector<8xf16>, vector<16xf16>, vector<4xf32>,
666+
i32, i32, i32) -> vector<4xf32>
667+
668+
// CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.f16(<8 x half> %{{.*}}, <16 x half> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
669+
%r15 = rocdl.smfmac.f32.32x32x32.f16 %arg2, %arg10, %arg4, %csti32, %csti32, %csti32 :
670+
(vector<8xf16>, vector<16xf16>, vector<16xf32>,
671+
i32, i32, i32) -> vector<16xf32>
672+
673+
// CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.bf16(<8 x bfloat> %{{.*}}, <16 x bfloat> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
674+
%r16 = rocdl.smfmac.f32.16x16x64.bf16 %arg11, %arg12, %arg3, %csti32, %csti32, %csti32 :
675+
(vector<8xbf16>, vector<16xbf16>, vector<4xf32>,
676+
i32, i32, i32) -> vector<4xf32>
677+
678+
// CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.bf16(<8 x bfloat> %{{.*}}, <16 x bfloat> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
679+
%r17 = rocdl.smfmac.f32.32x32x32.bf16 %arg11, %arg12, %arg4, %csti32, %csti32, %csti32 :
680+
(vector<8xbf16>, vector<16xbf16>, vector<16xf32>,
681+
i32, i32, i32) -> vector<16xf32>
682+
683+
// CHECK: call <4 x i32> @llvm.amdgcn.smfmac.i32.16x16x128.i8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 42, i32 42, i32 42)
684+
%r18 = rocdl.smfmac.i32.16x16x128.i8 %arg8, %arg13, %arg8, %csti32, %csti32, %csti32 :
685+
(vector<4xi32>, vector<8xi32>, vector<4xi32>,
686+
i32, i32, i32) -> vector<4xi32>
687+
688+
// CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.bf8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
689+
%r19 = rocdl.smfmac.f32.16x16x128.bf8.bf8 %arg8, %arg13, %arg3, %csti32, %csti32, %csti32 :
690+
(vector<4xi32>, vector<8xi32>, vector<4xf32>,
691+
i32, i32, i32) -> vector<4xf32>
692+
693+
// CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.bf8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
694+
%r20 = rocdl.smfmac.f32.16x16x128.bf8.fp8 %arg8, %arg13, %arg3, %csti32, %csti32, %csti32 :
695+
(vector<4xi32>, vector<8xi32>, vector<4xf32>,
696+
i32, i32, i32) -> vector<4xf32>
697+
698+
// CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.fp8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
699+
%r21 = rocdl.smfmac.f32.16x16x128.fp8.bf8 %arg8, %arg13, %arg3, %csti32, %csti32, %csti32 :
700+
(vector<4xi32>, vector<8xi32>, vector<4xf32>,
701+
i32, i32, i32) -> vector<4xf32>
702+
703+
// CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.fp8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
704+
%r22 = rocdl.smfmac.f32.16x16x128.fp8.fp8 %arg8, %arg13, %arg3, %csti32, %csti32, %csti32 :
705+
(vector<4xi32>, vector<8xi32>, vector<4xf32>,
706+
i32, i32, i32) -> vector<4xf32>
707+
708+
// CHECK: call <16 x i32> @llvm.amdgcn.smfmac.i32.32x32x64.i8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x i32> %{{.*}}, i32 42, i32 42, i32 42)
709+
%r23 = rocdl.smfmac.i32.32x32x64.i8 %arg8, %arg13, %arg9, %csti32, %csti32, %csti32 :
710+
(vector<4xi32>, vector<8xi32>, vector<16xi32>,
711+
i32, i32, i32) -> vector<16xi32>
712+
713+
// CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.bf8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
714+
%r24 = rocdl.smfmac.f32.32x32x64.bf8.bf8 %arg8, %arg13, %arg4, %csti32, %csti32, %csti32 :
715+
(vector<4xi32>, vector<8xi32>, vector<16xf32>,
716+
i32, i32, i32) -> vector<16xf32>
717+
718+
// CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.bf8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
719+
%r25 = rocdl.smfmac.f32.32x32x64.bf8.fp8 %arg8, %arg13, %arg4, %csti32, %csti32, %csti32 :
720+
(vector<4xi32>, vector<8xi32>, vector<16xf32>,
721+
i32, i32, i32) -> vector<16xf32>
722+
723+
// CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.fp8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
724+
%r26 = rocdl.smfmac.f32.32x32x64.fp8.bf8 %arg8, %arg13, %arg4, %csti32, %csti32, %csti32 :
725+
(vector<4xi32>, vector<8xi32>, vector<16xf32>,
726+
i32, i32, i32) -> vector<16xf32>
727+
728+
// CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.fp8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
729+
%r27 = rocdl.smfmac.f32.32x32x64.fp8.fp8 %arg8, %arg13, %arg4, %csti32, %csti32, %csti32 :
730+
(vector<4xi32>, vector<8xi32>, vector<16xf32>,
731+
i32, i32, i32) -> vector<16xf32>
732+
660733
llvm.return %r0 : vector<4 x f32>
661734
}
662735

0 commit comments

Comments
 (0)