@@ -601,6 +601,20 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
601
601
}
602
602
}
603
603
604
+ // / Return true if `type` is the E5M2 variant of an 8-bit float that is
605
+ // / supported by the `_bf8` instructions on the given `chipset`.
606
+ static bool typeIsExpectedBf8ForChipset (Chipset chipset, Type type) {
607
+ return (chipset == kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
608
+ (hasOcpFp8 (chipset) && isa<Float8E5M2Type>(type));
609
+ }
610
+
611
+ // / Return true if `type` is the E4M3FN variant of an 8-bit float that is
612
+ // / supported by the `_fp8` instructions on the given `chipset`.
613
+ static bool typeIsExpectedFp8ForChipset (Chipset chipset, Type type) {
614
+ return (chipset == kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
615
+ (hasOcpFp8 (chipset) && isa<Float8E4M3FNType>(type));
616
+ }
617
+
604
618
// / Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
605
619
// / if one exists. This includes checking to ensure the intrinsic is supported
606
620
// / on the architecture you are compiling for.
@@ -697,40 +711,38 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
697
711
return ROCDL::mfma_f64_4x4x4f64::getOperationName ();
698
712
}
699
713
700
- if (isa<Float8E5M2FNUZType>(sourceElem) && destElem.isF32 () &&
701
- chipset >= kGfx942 ) {
714
+ if (destElem.isF32 () && typeIsExpectedBf8ForChipset (chipset, sourceElem)) {
702
715
// Known to be correct because there are no scalar f8 instructions and
703
716
// because a length mismatch will have been caught by the verifier.
704
717
Type sourceBElem =
705
718
cast<VectorType>(mfma.getSourceB ().getType ()).getElementType ();
706
719
if (m == 16 && n == 16 && k == 32 && b == 1 ) {
707
- if (isa<Float8E5M2FNUZType>( sourceBElem))
720
+ if (typeIsExpectedBf8ForChipset (chipset, sourceBElem))
708
721
return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName ();
709
- if (isa<Float8E4M3FNUZType>( sourceBElem))
722
+ if (typeIsExpectedFp8ForChipset (chipset, sourceBElem))
710
723
return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName ();
711
724
}
712
725
if (m == 32 && n == 32 && k == 16 && b == 1 ) {
713
- if (isa<Float8E5M2FNUZType>( sourceBElem))
726
+ if (typeIsExpectedBf8ForChipset (chipset, sourceBElem))
714
727
return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName ();
715
- if (isa<Float8E4M3FNUZType>( sourceBElem))
728
+ if (typeIsExpectedFp8ForChipset (chipset, sourceBElem))
716
729
return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName ();
717
730
}
718
731
}
719
732
720
- if (isa<Float8E4M3FNUZType>(sourceElem) && destElem.isF32 () &&
721
- chipset >= kGfx942 ) {
733
+ if (destElem.isF32 () && typeIsExpectedFp8ForChipset (chipset, sourceElem)) {
722
734
Type sourceBElem =
723
735
cast<VectorType>(mfma.getSourceB ().getType ()).getElementType ();
724
736
if (m == 16 && n == 16 && k == 32 && b == 1 ) {
725
- if (isa<Float8E5M2FNUZType>( sourceBElem))
737
+ if (typeIsExpectedBf8ForChipset (chipset, sourceBElem))
726
738
return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName ();
727
- if (isa<Float8E4M3FNUZType>( sourceBElem))
739
+ if (typeIsExpectedFp8ForChipset (chipset, sourceBElem))
728
740
return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName ();
729
741
}
730
742
if (m == 32 && n == 32 && k == 16 && b == 1 ) {
731
- if (isa<Float8E5M2FNUZType>( sourceBElem))
743
+ if (typeIsExpectedBf8ForChipset (chipset, sourceBElem))
732
744
return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName ();
733
- if (isa<Float8E4M3FNUZType>( sourceBElem))
745
+ if (typeIsExpectedFp8ForChipset (chipset, sourceBElem))
734
746
return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName ();
735
747
}
736
748
}
@@ -936,7 +948,7 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
936
948
ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
937
949
ConversionPatternRewriter &rewriter) const {
938
950
Location loc = op.getLoc ();
939
- if (chipset. majorVersion != 9 || chipset < kGfx942 )
951
+ if (!( chipset == kGfx942 || hasOcpFp8 ( chipset)) )
940
952
return rewriter.notifyMatchFailure (
941
953
loc, " Fp8 conversion instructions are not available on target "
942
954
" architecture and their emulation is not implemented" );
@@ -966,10 +978,10 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
966
978
}
967
979
Value i32Source = rewriter.create <LLVM::BitcastOp>(loc, i32 , source);
968
980
Value wordSel = createI32Constant (rewriter, loc, op.getIndex ());
969
- if (isa<Float8E5M2FNUZType>( sourceElemType)) {
981
+ if (typeIsExpectedBf8ForChipset (chipset, sourceElemType)) {
970
982
rewriter.replaceOpWithNewOp <ROCDL::CvtF32Bf8Op>(op, f32 , i32Source,
971
983
wordSel);
972
- } else if (isa<Float8E4M3FNUZType>( sourceElemType)) {
984
+ } else if (typeIsExpectedFp8ForChipset (chipset, sourceElemType)) {
973
985
rewriter.replaceOpWithNewOp <ROCDL::CvtF32Fp8Op>(op, f32 , i32Source,
974
986
wordSel);
975
987
}
@@ -980,7 +992,7 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
980
992
PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
981
993
ConversionPatternRewriter &rewriter) const {
982
994
Location loc = op.getLoc ();
983
- if (chipset. majorVersion != 9 || chipset < kGfx942 )
995
+ if (!( chipset == kGfx942 || hasOcpFp8 ( chipset)) )
984
996
return rewriter.notifyMatchFailure (
985
997
loc, " Fp8 conversion instructions are not available on target "
986
998
" architecture and their emulation is not implemented" );
@@ -1001,10 +1013,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
1001
1013
Value wordSel = createI1Constant (rewriter, loc, op.getWordIndex ());
1002
1014
1003
1015
Value result;
1004
- if (isa<Float8E5M2FNUZType>( resultElemType))
1016
+ if (typeIsExpectedBf8ForChipset (chipset, resultElemType))
1005
1017
result = rewriter.create <ROCDL::CvtPkBf8F32Op>(loc, i32 , sourceA, sourceB,
1006
1018
existing, wordSel);
1007
- else if (isa<Float8E4M3FNUZType>( resultElemType))
1019
+ else if (typeIsExpectedFp8ForChipset (chipset, resultElemType))
1008
1020
result = rewriter.create <ROCDL::CvtPkFp8F32Op>(loc, i32 , sourceA, sourceB,
1009
1021
existing, wordSel);
1010
1022
@@ -1017,7 +1029,7 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
1017
1029
PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
1018
1030
ConversionPatternRewriter &rewriter) const {
1019
1031
Location loc = op.getLoc ();
1020
- if (chipset. majorVersion != 9 || chipset < kGfx942 )
1032
+ if (!( chipset == kGfx942 || hasOcpFp8 ( chipset)) )
1021
1033
return rewriter.notifyMatchFailure (
1022
1034
loc, " Fp8 conversion instructions are not available on target "
1023
1035
" architecture and their emulation is not implemented" );
@@ -1036,10 +1048,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
1036
1048
Value byteSel = createI32Constant (rewriter, loc, op.getStoreIndex ());
1037
1049
1038
1050
Value result;
1039
- if (isa<Float8E5M2FNUZType>( resultElemType))
1051
+ if (typeIsExpectedBf8ForChipset (chipset, resultElemType))
1040
1052
result = rewriter.create <ROCDL::CvtSrBf8F32Op>(loc, i32 , source, stoch,
1041
1053
existing, byteSel);
1042
- else if (isa<Float8E4M3FNUZType>( resultElemType))
1054
+ else if (typeIsExpectedFp8ForChipset (chipset, resultElemType))
1043
1055
result = rewriter.create <ROCDL::CvtSrFp8F32Op>(loc, i32 , source, stoch,
1044
1056
existing, byteSel);
1045
1057
0 commit comments