@@ -581,7 +581,11 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
581581 %arg6 : vector <8 x i16 >,
582582 %arg7 : vector <2 xi32 >,
583583 %arg8 : vector <4 xi32 >,
584- %arg9 : vector <16 xi32 >) -> vector <4 x f32 > {
584+ %arg9 : vector <16 xi32 >,
585+ %arg10 : vector <16 x f16 >,
586+ %arg11 : vector <8 x bf16 >,
587+ %arg12 : vector <16 x bf16 >,
588+ %arg13 : vector <8 x i32 >) -> vector <4 x f32 > {
585589 %csti32 = llvm.mlir.constant (42 : i32 ) : i32
586590
587591 // CHECK-LABEL: rocdl.smfmac
@@ -651,12 +655,81 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
651655 (vector <2 xi32 >, vector <4 xi32 >, vector <16 xf32 >,
652656 i32 , i32 , i32 ) -> vector <16 xf32 >
653657
654-
655658 // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.fp8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
656659 %r13 = rocdl.smfmac.f32.32x32x32.fp8.fp8 %arg7 , %arg8 , %arg4 , %csti32 , %csti32 , %csti32 :
657660 (vector <2 xi32 >, vector <4 xi32 >, vector <16 xf32 >,
658661 i32 , i32 , i32 ) -> vector <16 xf32 >
659662
663+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.f16(<8 x half> %{{.*}}, <16 x half> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
664+ %r14 = rocdl.smfmac.f32.16x16x64.f16 %arg2 , %arg10 , %arg3 , %csti32 , %csti32 , %csti32 :
665+ (vector <8 xf16 >, vector <16 xf16 >, vector <4 xf32 >,
666+ i32 , i32 , i32 ) -> vector <4 xf32 >
667+
668+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.f16(<8 x half> %{{.*}}, <16 x half> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
669+ %r15 = rocdl.smfmac.f32.32x32x32.f16 %arg2 , %arg10 , %arg4 , %csti32 , %csti32 , %csti32 :
670+ (vector <8 xf16 >, vector <16 xf16 >, vector <16 xf32 >,
671+ i32 , i32 , i32 ) -> vector <16 xf32 >
672+
673+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.bf16(<8 x bfloat> %{{.*}}, <16 x bfloat> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
674+ %r16 = rocdl.smfmac.f32.16x16x64.bf16 %arg11 , %arg12 , %arg3 , %csti32 , %csti32 , %csti32 :
675+ (vector <8 xbf16 >, vector <16 xbf16 >, vector <4 xf32 >,
676+ i32 , i32 , i32 ) -> vector <4 xf32 >
677+
678+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.bf16(<8 x bfloat> %{{.*}}, <16 x bfloat> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
679+ %r17 = rocdl.smfmac.f32.32x32x32.bf16 %arg11 , %arg12 , %arg4 , %csti32 , %csti32 , %csti32 :
680+ (vector <8 xbf16 >, vector <16 xbf16 >, vector <16 xf32 >,
681+ i32 , i32 , i32 ) -> vector <16 xf32 >
682+
683+ // CHECK: call <4 x i32> @llvm.amdgcn.smfmac.i32.16x16x128.i8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 42, i32 42, i32 42)
684+ %r18 = rocdl.smfmac.i32.16x16x128.i8 %arg8 , %arg13 , %arg8 , %csti32 , %csti32 , %csti32 :
685+ (vector <4 xi32 >, vector <8 xi32 >, vector <4 xi32 >,
686+ i32 , i32 , i32 ) -> vector <4 xi32 >
687+
688+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.bf8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
689+ %r19 = rocdl.smfmac.f32.16x16x128.bf8.bf8 %arg8 , %arg13 , %arg3 , %csti32 , %csti32 , %csti32 :
690+ (vector <4 xi32 >, vector <8 xi32 >, vector <4 xf32 >,
691+ i32 , i32 , i32 ) -> vector <4 xf32 >
692+
693+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.bf8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
694+ %r20 = rocdl.smfmac.f32.16x16x128.bf8.fp8 %arg8 , %arg13 , %arg3 , %csti32 , %csti32 , %csti32 :
695+ (vector <4 xi32 >, vector <8 xi32 >, vector <4 xf32 >,
696+ i32 , i32 , i32 ) -> vector <4 xf32 >
697+
698+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.fp8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
699+ %r21 = rocdl.smfmac.f32.16x16x128.fp8.bf8 %arg8 , %arg13 , %arg3 , %csti32 , %csti32 , %csti32 :
700+ (vector <4 xi32 >, vector <8 xi32 >, vector <4 xf32 >,
701+ i32 , i32 , i32 ) -> vector <4 xf32 >
702+
703+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.fp8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
704+ %r22 = rocdl.smfmac.f32.16x16x128.fp8.fp8 %arg8 , %arg13 , %arg3 , %csti32 , %csti32 , %csti32 :
705+ (vector <4 xi32 >, vector <8 xi32 >, vector <4 xf32 >,
706+ i32 , i32 , i32 ) -> vector <4 xf32 >
707+
708+ // CHECK: call <16 x i32> @llvm.amdgcn.smfmac.i32.32x32x64.i8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x i32> %{{.*}}, i32 42, i32 42, i32 42)
709+ %r23 = rocdl.smfmac.i32.32x32x64.i8 %arg8 , %arg13 , %arg9 , %csti32 , %csti32 , %csti32 :
710+ (vector <4 xi32 >, vector <8 xi32 >, vector <16 xi32 >,
711+ i32 , i32 , i32 ) -> vector <16 xi32 >
712+
713+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.bf8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
714+ %r24 = rocdl.smfmac.f32.32x32x64.bf8.bf8 %arg8 , %arg13 , %arg4 , %csti32 , %csti32 , %csti32 :
715+ (vector <4 xi32 >, vector <8 xi32 >, vector <16 xf32 >,
716+ i32 , i32 , i32 ) -> vector <16 xf32 >
717+
718+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.bf8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
719+ %r25 = rocdl.smfmac.f32.32x32x64.bf8.fp8 %arg8 , %arg13 , %arg4 , %csti32 , %csti32 , %csti32 :
720+ (vector <4 xi32 >, vector <8 xi32 >, vector <16 xf32 >,
721+ i32 , i32 , i32 ) -> vector <16 xf32 >
722+
723+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.fp8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
724+ %r26 = rocdl.smfmac.f32.32x32x64.fp8.bf8 %arg8 , %arg13 , %arg4 , %csti32 , %csti32 , %csti32 :
725+ (vector <4 xi32 >, vector <8 xi32 >, vector <16 xf32 >,
726+ i32 , i32 , i32 ) -> vector <16 xf32 >
727+
728+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.fp8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
729+ %r27 = rocdl.smfmac.f32.32x32x64.fp8.fp8 %arg8 , %arg13 , %arg4 , %csti32 , %csti32 , %csti32 :
730+ (vector <4 xi32 >, vector <8 xi32 >, vector <16 xf32 >,
731+ i32 , i32 , i32 ) -> vector <16 xf32 >
732+
660733 llvm.return %r0 : vector <4 x f32 >
661734}
662735
0 commit comments