Skip to content

Commit 9d9bb7b

Browse files
authored
Fix complex abs corner cases. (#88373)
The current implementation fails for very small and very large values. For example, (0, -inf) should return inf, but it returns -inf. This ports the logic used in XLA. Tested with XLA's exhaustive_binary_test_f32_f64.
1 parent 717d3f3 commit 9d9bb7b

File tree

3 files changed

+162
-274
lines changed

3 files changed

+162
-274
lines changed

mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp

Lines changed: 16 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ namespace mlir {
2626
using namespace mlir;
2727

2828
namespace {
29-
// The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
29+
3030
struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
3131
using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
3232

@@ -35,49 +35,27 @@ struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
3535
ConversionPatternRewriter &rewriter) const override {
3636
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
3737

38-
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
38+
arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
3939

4040
Type elementType = op.getType();
41-
Value arg = adaptor.getComplex();
42-
43-
Value zero =
44-
b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
4541
Value one = b.create<arith::ConstantOp>(elementType,
4642
b.getFloatAttr(elementType, 1.0));
4743

48-
Value real = b.create<complex::ReOp>(elementType, arg);
49-
Value imag = b.create<complex::ImOp>(elementType, arg);
50-
51-
Value realIsZero =
52-
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
53-
Value imagIsZero =
54-
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
44+
Value real = b.create<complex::ReOp>(adaptor.getComplex());
45+
Value imag = b.create<complex::ImOp>(adaptor.getComplex());
46+
Value absReal = b.create<math::AbsFOp>(real, fmf);
47+
Value absImag = b.create<math::AbsFOp>(imag, fmf);
5548

56-
// Real > Imag
57-
Value imagDivReal = b.create<arith::DivFOp>(imag, real, fmf.getValue());
58-
Value imagSq =
59-
b.create<arith::MulFOp>(imagDivReal, imagDivReal, fmf.getValue());
60-
Value imagSqPlusOne = b.create<arith::AddFOp>(imagSq, one, fmf.getValue());
61-
Value imagSqrt = b.create<math::SqrtOp>(imagSqPlusOne, fmf.getValue());
62-
Value realAbs = b.create<math::AbsFOp>(real, fmf.getValue());
63-
Value absImag = b.create<arith::MulFOp>(imagSqrt, realAbs, fmf.getValue());
64-
65-
// Real <= Imag
66-
Value realDivImag = b.create<arith::DivFOp>(real, imag, fmf.getValue());
67-
Value realSq =
68-
b.create<arith::MulFOp>(realDivImag, realDivImag, fmf.getValue());
69-
Value realSqPlusOne = b.create<arith::AddFOp>(realSq, one, fmf.getValue());
70-
Value realSqrt = b.create<math::SqrtOp>(realSqPlusOne, fmf.getValue());
71-
Value imagAbs = b.create<math::AbsFOp>(imag, fmf.getValue());
72-
Value absReal = b.create<arith::MulFOp>(realSqrt, imagAbs, fmf.getValue());
73-
74-
rewriter.replaceOpWithNewOp<arith::SelectOp>(
75-
op, realIsZero, imagAbs,
76-
b.create<arith::SelectOp>(
77-
imagIsZero, realAbs,
78-
b.create<arith::SelectOp>(
79-
b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, real, imag),
80-
absImag, absReal)));
49+
Value max = b.create<arith::MaximumFOp>(absReal, absImag, fmf);
50+
Value min = b.create<arith::MinimumFOp>(absReal, absImag, fmf);
51+
Value ratio = b.create<arith::DivFOp>(min, max, fmf);
52+
Value ratioSq = b.create<arith::MulFOp>(ratio, ratio, fmf);
53+
Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmf);
54+
Value sqrt = b.create<math::SqrtOp>(ratioSqPlusOne, fmf);
55+
Value result = b.create<arith::MulFOp>(max, sqrt, fmf);
56+
Value isNaN =
57+
b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, result, result, fmf);
58+
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, min, result);
8159

8260
return success();
8361
}

0 commit comments

Comments
 (0)