@@ -26,7 +26,7 @@ namespace mlir {
26
26
using namespace mlir ;
27
27
28
28
namespace {
29
- // The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
29
+
30
30
struct AbsOpConversion : public OpConversionPattern <complex::AbsOp> {
31
31
using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
32
32
@@ -35,49 +35,27 @@ struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
35
35
ConversionPatternRewriter &rewriter) const override {
36
36
mlir::ImplicitLocOpBuilder b (op.getLoc (), rewriter);
37
37
38
- arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr ();
38
+ arith::FastMathFlags fmf = op.getFastMathFlagsAttr (). getValue ();
39
39
40
40
Type elementType = op.getType ();
41
- Value arg = adaptor.getComplex ();
42
-
43
- Value zero =
44
- b.create <arith::ConstantOp>(elementType, b.getZeroAttr (elementType));
45
41
Value one = b.create <arith::ConstantOp>(elementType,
46
42
b.getFloatAttr (elementType, 1.0 ));
47
43
48
- Value real = b.create <complex::ReOp>(elementType, arg);
49
- Value imag = b.create <complex::ImOp>(elementType, arg);
50
-
51
- Value realIsZero =
52
- b.create <arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
53
- Value imagIsZero =
54
- b.create <arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
44
+ Value real = b.create <complex::ReOp>(adaptor.getComplex ());
45
+ Value imag = b.create <complex::ImOp>(adaptor.getComplex ());
46
+ Value absReal = b.create <math::AbsFOp>(real, fmf);
47
+ Value absImag = b.create <math::AbsFOp>(imag, fmf);
55
48
56
- // Real > Imag
57
- Value imagDivReal = b.create <arith::DivFOp>(imag, real, fmf.getValue ());
58
- Value imagSq =
59
- b.create <arith::MulFOp>(imagDivReal, imagDivReal, fmf.getValue ());
60
- Value imagSqPlusOne = b.create <arith::AddFOp>(imagSq, one, fmf.getValue ());
61
- Value imagSqrt = b.create <math::SqrtOp>(imagSqPlusOne, fmf.getValue ());
62
- Value realAbs = b.create <math::AbsFOp>(real, fmf.getValue ());
63
- Value absImag = b.create <arith::MulFOp>(imagSqrt, realAbs, fmf.getValue ());
64
-
65
- // Real <= Imag
66
- Value realDivImag = b.create <arith::DivFOp>(real, imag, fmf.getValue ());
67
- Value realSq =
68
- b.create <arith::MulFOp>(realDivImag, realDivImag, fmf.getValue ());
69
- Value realSqPlusOne = b.create <arith::AddFOp>(realSq, one, fmf.getValue ());
70
- Value realSqrt = b.create <math::SqrtOp>(realSqPlusOne, fmf.getValue ());
71
- Value imagAbs = b.create <math::AbsFOp>(imag, fmf.getValue ());
72
- Value absReal = b.create <arith::MulFOp>(realSqrt, imagAbs, fmf.getValue ());
73
-
74
- rewriter.replaceOpWithNewOp <arith::SelectOp>(
75
- op, realIsZero, imagAbs,
76
- b.create <arith::SelectOp>(
77
- imagIsZero, realAbs,
78
- b.create <arith::SelectOp>(
79
- b.create <arith::CmpFOp>(arith::CmpFPredicate::OGT, real, imag),
80
- absImag, absReal)));
49
+ Value max = b.create <arith::MaximumFOp>(absReal, absImag, fmf);
50
+ Value min = b.create <arith::MinimumFOp>(absReal, absImag, fmf);
51
+ Value ratio = b.create <arith::DivFOp>(min, max, fmf);
52
+ Value ratioSq = b.create <arith::MulFOp>(ratio, ratio, fmf);
53
+ Value ratioSqPlusOne = b.create <arith::AddFOp>(ratioSq, one, fmf);
54
+ Value sqrt = b.create <math::SqrtOp>(ratioSqPlusOne, fmf);
55
+ Value result = b.create <arith::MulFOp>(max, sqrt, fmf);
56
+ Value isNaN =
57
+ b.create <arith::CmpFOp>(arith::CmpFPredicate::UNO, result, result, fmf);
58
+ rewriter.replaceOpWithNewOp <arith::SelectOp>(op, isNaN, min, result);
81
59
82
60
return success ();
83
61
}
0 commit comments