Skip to content

Commit 1fc49ff

Browse files
mirza-halilcevicpcf000krzysz00
authored
[MLIR][AMDGPU] Add OCP FP8 support for new hardware (llvm#127728)
(Continuing from llvm#106160) This PR addresses remaining review comments from the original PR. Original PR Description --- Upcoming hardware (gfx12 and some future gfx9) will support the OCP 8-bit float formats for their matrix multiplication intrinsics and conversion operations, retaining existing opcodes and compiler builtins. This commit adds support for these types to the MLIR wrappers around such operations, ensuring that the OCP types aren't used to generate those builtins on hardware that doesn't expect that format and, conversely, to ensure that the pre-OCP formats aren't used on new hardware. --------- Signed-off-by: Mirza Halilcevic <mirza.halilcevic@amd.com> Co-authored-by: Paul Fuqua <pf@acm.org> Co-authored-by: Krzysztof Drewniak <Krzysztof.Drewniak@amd.com>
1 parent 08dc81b commit 1fc49ff

File tree

11 files changed

+424
-35
lines changed

11 files changed

+424
-35
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,8 @@ class AMDGPU_Op<string mnemonic, list<Trait> traits = []> :
8383

8484
def AMDGPU_ExtPackedFp8Op :
8585
AMDGPU_Op<"ext_packed_fp8", [Pure]>,
86-
Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ,
87-
VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ]>]>:$source,
86+
Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN,
87+
VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>:$source,
8888
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$index)>,
8989
Results<(outs F32:$res)> {
9090
let summary = "Extend one of a vector of packed fp8 values to a float";
@@ -110,8 +110,8 @@ def AMDGPU_PackedTrunc2xFp8Op :
110110
Arguments<(ins F32:$sourceA,
111111
Optional<F32>:$sourceB,
112112
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<1>]>:$wordIndex,
113-
Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>>:$existing)>,
114-
Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> {
113+
Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>>:$existing)>,
114+
Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>:$res)> {
115115
let summary = "Round two floats into a packed vector of 8-bit floats";
116116
let description = [{
117117
Round the inputs `sourceA` and `sourceB` (which is undefined if not
@@ -137,8 +137,8 @@ def AMDGPU_PackedStochRoundFp8Op :
137137
Arguments<(ins F32:$source,
138138
I32:$stochiasticParam,
139139
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$storeIndex,
140-
Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>>:$existing)>,
141-
Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> {
140+
Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>>:$existing)>,
141+
Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>:$res)> {
142142
let summary = "Round float stochiastically into a packed vector of 8-bit floats";
143143
let description = [{
144144
Round the input `source`, adding in `stochiasticParam`, and place it into
@@ -651,7 +651,7 @@ def MFMAInTypes : AnyTypeOf<[F32, F64, I32, I64,
651651
VectorOfLengthAndType<[4], [F16]>,
652652
VectorOfLengthAndType<[2, 4], [BF16]>,
653653
VectorOfLengthAndType<[4, 8], [I8]>,
654-
VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ]>]>;
654+
VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>;
655655
def MFMAOutTypes : AnyTypeOf<[F64,
656656
VectorOfLengthAndType<[4, 16, 32], [F32]>,
657657
VectorOfLengthAndType<[4, 16, 32], [I32]>,

mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@ struct Chipset {
4949
#undef DEFINE_COMP_OPERATOR
5050
};
5151

52+
inline bool hasOcpFp8(const Chipset &chipset) {
53+
return (chipset.majorVersion == 9 && chipset.minorVersion >= 5) ||
54+
chipset.majorVersion >= 12;
55+
}
56+
5257
} // namespace mlir::amdgpu
5358

5459
#endif

mlir/include/mlir/IR/Types.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,9 @@ class Type {
132132
bool isF64() const;
133133
bool isF80() const;
134134
bool isF128() const;
135+
/// Return true if this is a float type (with the specified width).
136+
bool isFloat() const;
137+
bool isFloat(unsigned width) const;
135138

136139
/// Return true if this is an integer type (with the specified width).
137140
bool isInteger() const;

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,20 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
601601
}
602602
}
603603

604+
/// Return true if `type` is the E5M2 variant of an 8-bit float that is
605+
/// supported by the `_bf8` instructions on the given `chipset`.
606+
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type) {
607+
return (chipset == kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
608+
(hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
609+
}
610+
611+
/// Return true if `type` is the E4M3 variant (FNUZ or OCP FN) of an 8-bit float that is
612+
/// supported by the `_fp8` instructions on the given `chipset`.
613+
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type) {
614+
return (chipset == kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
615+
(hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
616+
}
617+
604618
/// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
605619
/// if one exists. This includes checking to ensure the intrinsic is supported
606620
/// on the architecture you are compiling for.
@@ -697,40 +711,38 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
697711
return ROCDL::mfma_f64_4x4x4f64::getOperationName();
698712
}
699713

700-
if (isa<Float8E5M2FNUZType>(sourceElem) && destElem.isF32() &&
701-
chipset >= kGfx942) {
714+
if (destElem.isF32() && typeIsExpectedBf8ForChipset(chipset, sourceElem)) {
702715
// Known to be correct because there are no scalar f8 instructions and
703716
// because a length mismatch will have been caught by the verifier.
704717
Type sourceBElem =
705718
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
706719
if (m == 16 && n == 16 && k == 32 && b == 1) {
707-
if (isa<Float8E5M2FNUZType>(sourceBElem))
720+
if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
708721
return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
709-
if (isa<Float8E4M3FNUZType>(sourceBElem))
722+
if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
710723
return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
711724
}
712725
if (m == 32 && n == 32 && k == 16 && b == 1) {
713-
if (isa<Float8E5M2FNUZType>(sourceBElem))
726+
if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
714727
return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
715-
if (isa<Float8E4M3FNUZType>(sourceBElem))
728+
if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
716729
return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
717730
}
718731
}
719732

720-
if (isa<Float8E4M3FNUZType>(sourceElem) && destElem.isF32() &&
721-
chipset >= kGfx942) {
733+
if (destElem.isF32() && typeIsExpectedFp8ForChipset(chipset, sourceElem)) {
722734
Type sourceBElem =
723735
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
724736
if (m == 16 && n == 16 && k == 32 && b == 1) {
725-
if (isa<Float8E5M2FNUZType>(sourceBElem))
737+
if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
726738
return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
727-
if (isa<Float8E4M3FNUZType>(sourceBElem))
739+
if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
728740
return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
729741
}
730742
if (m == 32 && n == 32 && k == 16 && b == 1) {
731-
if (isa<Float8E5M2FNUZType>(sourceBElem))
743+
if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
732744
return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
733-
if (isa<Float8E4M3FNUZType>(sourceBElem))
745+
if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
734746
return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
735747
}
736748
}
@@ -936,7 +948,7 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
936948
ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
937949
ConversionPatternRewriter &rewriter) const {
938950
Location loc = op.getLoc();
939-
if (chipset.majorVersion != 9 || chipset < kGfx942)
951+
if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
940952
return rewriter.notifyMatchFailure(
941953
loc, "Fp8 conversion instructions are not available on target "
942954
"architecture and their emulation is not implemented");
@@ -966,10 +978,10 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
966978
}
967979
Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
968980
Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
969-
if (isa<Float8E5M2FNUZType>(sourceElemType)) {
981+
if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
970982
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
971983
wordSel);
972-
} else if (isa<Float8E4M3FNUZType>(sourceElemType)) {
984+
} else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
973985
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
974986
wordSel);
975987
}
@@ -980,7 +992,7 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
980992
PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
981993
ConversionPatternRewriter &rewriter) const {
982994
Location loc = op.getLoc();
983-
if (chipset.majorVersion != 9 || chipset < kGfx942)
995+
if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
984996
return rewriter.notifyMatchFailure(
985997
loc, "Fp8 conversion instructions are not available on target "
986998
"architecture and their emulation is not implemented");
@@ -1001,10 +1013,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
10011013
Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());
10021014

10031015
Value result;
1004-
if (isa<Float8E5M2FNUZType>(resultElemType))
1016+
if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
10051017
result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
10061018
existing, wordSel);
1007-
else if (isa<Float8E4M3FNUZType>(resultElemType))
1019+
else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
10081020
result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
10091021
existing, wordSel);
10101022

@@ -1017,7 +1029,7 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
10171029
PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
10181030
ConversionPatternRewriter &rewriter) const {
10191031
Location loc = op.getLoc();
1020-
if (chipset.majorVersion != 9 || chipset < kGfx942)
1032+
if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
10211033
return rewriter.notifyMatchFailure(
10221034
loc, "Fp8 conversion instructions are not available on target "
10231035
"architecture and their emulation is not implemented");
@@ -1036,10 +1048,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
10361048
Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());
10371049

10381050
Value result;
1039-
if (isa<Float8E5M2FNUZType>(resultElemType))
1051+
if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
10401052
result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
10411053
existing, byteSel);
1042-
else if (isa<Float8E4M3FNUZType>(resultElemType))
1054+
else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
10431055
result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
10441056
existing, byteSel);
10451057

mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ using namespace mlir;
3030
using namespace mlir::amdgpu;
3131

3232
namespace {
33+
// Define commonly used chipset versions for convenience.
34+
constexpr Chipset kGfx942 = Chipset(9, 4, 2);
35+
3336
struct ArithToAMDGPUConversionPass final
3437
: impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> {
3538
using impl::ArithToAMDGPUConversionPassBase<
@@ -41,6 +44,10 @@ struct ArithToAMDGPUConversionPass final
4144
struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
4245
using OpRewritePattern::OpRewritePattern;
4346

47+
Chipset chipset;
48+
ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset)
49+
: OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {}
50+
4451
LogicalResult match(arith::ExtFOp op) const override;
4552
void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override;
4653
};
@@ -68,6 +75,14 @@ struct TruncfToFloat16RewritePattern final
6875

6976
} // end namespace
7077

78+
static LogicalResult isSupportedF8(Type elementType, Chipset chipset) {
79+
if (chipset == kGfx942)
80+
return success(isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(elementType));
81+
if (hasOcpFp8(chipset))
82+
return success(isa<Float8E4M3FNType, Float8E5M2Type>(elementType));
83+
return failure();
84+
}
85+
7186
static Value castF32To(Type elementType, Value f32, Location loc,
7287
PatternRewriter &rewriter) {
7388
if (elementType.isF32())
@@ -86,7 +101,7 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
86101
return failure();
87102
inType = inVecType.getElementType();
88103
}
89-
return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(inType));
104+
return isSupportedF8(inType, chipset);
90105
}
91106

92107
void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
@@ -219,7 +234,8 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
219234
if (inType && inType.getWidth() <= 8 && saturateFP8)
220235
// Conversion between 8-bit floats is not supported with truncation enabled.
221236
return failure();
222-
return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(outType));
237+
238+
return isSupportedF8(outType, chipset);
223239
}
224240

225241
void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
@@ -365,7 +381,7 @@ void mlir::arith::populateArithToAMDGPUConversionPatterns(
365381
bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
366382

367383
if (convertFP8Arithmetic) {
368-
patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext());
384+
patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset);
369385
patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
370386
saturateFP8Truncf, chipset);
371387
}
@@ -384,7 +400,7 @@ void ArithToAMDGPUConversionPass::runOnOperation() {
384400
}
385401

386402
bool convertFP8Arithmetic =
387-
maybeChipset->majorVersion == 9 && *maybeChipset >= Chipset(9, 4, 2);
403+
*maybeChipset == kGfx942 || hasOcpFp8(*maybeChipset);
388404
arith::populateArithToAMDGPUConversionPatterns(
389405
patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
390406
*maybeChipset);

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,14 +341,14 @@ LogicalResult MFMAOp::verify() {
341341
}
342342

343343
Type sourceBType = getSourceB().getType();
344-
if (isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(sourceElem)) {
344+
if (sourceElem.isFloat(8)) {
345345
int64_t sourceBLen = 1;
346346
Type sourceBElem = sourceBType;
347347
if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
348348
sourceBLen = sourceBVector.getNumElements();
349349
sourceBElem = sourceBVector.getElementType();
350350
}
351-
if (!isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(sourceBElem))
351+
if (!sourceBElem.isFloat(8))
352352
return emitOpError("expected both source operands to have f8 elements");
353353
if (sourceLen != sourceBLen)
354354
return emitOpError(

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -686,7 +686,8 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
686686

687687
bool TosaValidation::isValidElementType(Type type) {
688688
if (isa<FloatType>(type)) {
689-
return type.isF32() || type.isF16() || type.isBF16();
689+
return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
690+
Float8E5M2Type>(type);
690691
} else if (auto intTy = dyn_cast<IntegerType>(type)) {
691692
if (intTy.isSignless()) {
692693
switch (intTy.getWidth()) {

mlir/lib/IR/Types.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,15 @@ bool Type::isF64() const { return llvm::isa<Float64Type>(*this); }
4242
bool Type::isF80() const { return llvm::isa<Float80Type>(*this); }
4343
bool Type::isF128() const { return llvm::isa<Float128Type>(*this); }
4444

45+
bool Type::isFloat() const { return llvm::isa<FloatType>(*this); }
46+
47+
/// Return true if this is a float type with the specified width.
48+
bool Type::isFloat(unsigned width) const {
49+
if (auto fltTy = llvm::dyn_cast<FloatType>(*this))
50+
return fltTy.getWidth() == width;
51+
return false;
52+
}
53+
4554
bool Type::isIndex() const { return llvm::isa<IndexType>(*this); }
4655

4756
bool Type::isInteger() const { return llvm::isa<IntegerType>(*this); }

0 commit comments

Comments
 (0)