Skip to content

[mlir] Use arith max or min ops instead of cmp + select #82178

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,7 @@ using namespace mlir::affine;
using namespace mlir::vector;

/// Given a range of values, emit the code that reduces them with "min" or "max"
/// depending on the provided comparison predicate. The predicate defines which
/// comparison to perform, "lt" for "min", "gt" for "max" and is used for the
/// `cmpi` operation followed by the `select` operation:
///
/// %cond = arith.cmpi "predicate" %v0, %v1
/// %result = select %cond, %v0, %v1
/// depending on the provided comparison predicate, sgt for max and slt for min.
///
/// Multiple values are scanned in a linear sequence. This creates a data
/// dependences that wouldn't exist in a tree reduction, but is easier to
Expand All @@ -48,13 +43,16 @@ static Value buildMinMaxReductionSeq(Location loc,
arith::CmpIPredicate predicate,
ValueRange values, OpBuilder &builder) {
assert(!values.empty() && "empty min/max chain");
assert(predicate == arith::CmpIPredicate::sgt ||
predicate == arith::CmpIPredicate::slt);

auto valueIt = values.begin();
Value value = *valueIt++;
for (; valueIt != values.end(); ++valueIt) {
auto cmpOp = builder.create<arith::CmpIOp>(loc, predicate, value, *valueIt);
value = builder.create<arith::SelectOp>(loc, cmpOp.getResult(), value,
*valueIt);
if (predicate == arith::CmpIPredicate::sgt)
value = builder.create<arith::MaxSIOp>(loc, value, *valueIt);
else
value = builder.create<arith::MinSIOp>(loc, value, *valueIt);
}

return value;
Expand Down
8 changes: 2 additions & 6 deletions mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,7 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
// Find the maximum rank
Value maxRank = ranks.front();
for (Value v : llvm::drop_begin(ranks, 1)) {
Value rankIsGreater =
lb.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, v, maxRank);
maxRank = lb.create<arith::SelectOp>(rankIsGreater, v, maxRank);
maxRank = lb.create<arith::MaxUIOp>(v, maxRank);
}

// Calculate the difference of ranks and the maximum rank for later offsets.
Expand Down Expand Up @@ -262,9 +260,7 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
// Find the maximum rank
Value maxRank = ranks.front();
for (Value v : llvm::drop_begin(ranks, 1)) {
Value rankIsGreater =
lb.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, v, maxRank);
maxRank = lb.create<arith::SelectOp>(rankIsGreater, v, maxRank);
maxRank = lb.create<arith::MaxUIOp>(v, maxRank);
}

// Calculate the difference of ranks and the maximum rank for later offsets.
Expand Down
20 changes: 5 additions & 15 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,8 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
auto zero = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(elementTy));
auto cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
args[0], zero);
auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]);
return rewriter.create<arith::SelectOp>(loc, cmp, args[0], neg);
return rewriter.create<arith::MaxSIOp>(loc, args[0], neg);
}

// tosa::AddOp
Expand Down Expand Up @@ -348,9 +346,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
}

if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
auto predicate = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sgt, args[0], args[1]);
return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
}

// tosa::MinimumOp
Expand All @@ -359,9 +355,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
}

if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
auto predicate = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, args[0], args[1]);
return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
}

// tosa::CeilOp
Expand Down Expand Up @@ -1000,19 +994,15 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op,
}

if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
auto predicate = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, args[0], args[1]);
return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
}

if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
}

if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
auto predicate = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sgt, args[0], args[1]);
return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
}

if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
Expand Down
9 changes: 2 additions & 7 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -845,10 +845,7 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
auto padVal = rewriter.create<arith::ConstantIndexOp>(loc, pad);
Value dpos = rewriter.create<arith::SubIOp>(loc, pos, padVal);

Value cmp = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, dpos, zero);
Value offset =
rewriter.create<arith::SelectOp>(loc, cmp, dpos, zero);
Value offset = rewriter.create<arith::MinSIOp>(loc, dpos, zero);
return rewriter.create<arith::AddIOp>(loc, valid, offset)
->getResult(0);
};
Expand All @@ -868,9 +865,7 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
// Determine how much padding was included.
val = padFn(val, left, pad[i * 2]);
val = padFn(val, right, pad[i * 2 + 1]);
Value cmp = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, val, one);
return rewriter.create<arith::SelectOp>(loc, cmp, one, val);
return rewriter.create<arith::MaxSIOp>(loc, one, val);
};

// Compute the indices from either end.
Expand Down
6 changes: 2 additions & 4 deletions mlir/lib/Dialect/SCF/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -791,10 +791,8 @@ static Loops stripmineSink(scf::ForOp forOp, Value factor,
// Insert newForOp before the terminator of `t`.
auto b = OpBuilder::atBlockTerminator((t.getBody()));
Value stepped = b.create<arith::AddIOp>(t.getLoc(), iv, forOp.getStep());
Value less = b.create<arith::CmpIOp>(t.getLoc(), arith::CmpIPredicate::slt,
forOp.getUpperBound(), stepped);
Value ub = b.create<arith::SelectOp>(t.getLoc(), less,
forOp.getUpperBound(), stepped);
Value ub =
b.create<arith::MinSIOp>(t.getLoc(), forOp.getUpperBound(), stepped);

// Splice [begin, begin + nOps - 1) into `newForOp` and replace uses.
auto newForOp = b.create<scf::ForOp>(t.getLoc(), iv, ub, originalStep);
Expand Down
9 changes: 2 additions & 7 deletions mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,8 @@ Value mlir::tosa::clampFloatHelper(Location loc, Value arg, Value min,

Value mlir::tosa::clampIntHelper(Location loc, Value arg, Value min, Value max,
OpBuilder &rewriter) {
auto smallerThanMin =
rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, arg, min);
auto minOrArg =
rewriter.create<arith::SelectOp>(loc, smallerThanMin, min, arg);
auto largerThanMax =
rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, max, arg);
return rewriter.create<arith::SelectOp>(loc, largerThanMax, max, minOrArg);
auto minOrArg = rewriter.create<arith::MaxSIOp>(loc, min, arg);
return rewriter.create<arith::MinSIOp>(loc, max, minOrArg);
}

bool mlir::tosa::validIntegerRange(IntegerType ty, int64_t value) {
Expand Down
42 changes: 16 additions & 26 deletions mlir/test/Conversion/AffineToStandard/lower-affine.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -371,16 +371,14 @@ func.func @if_for() {
// CHECK-NEXT: %[[c1:.*]] = arith.constant 1 : index
// CHECK-NEXT: for %{{.*}} = %[[c0]] to %[[c42]] step %[[c1]] {
// CHECK-NEXT: %[[cm1:.*]] = arith.constant -1 : index
// CHECK-NEXT: %[[a:.*]] = arith.muli %{{.*}}, %[[cm1]] : index
// CHECK-NEXT: %[[b:.*]] = arith.addi %[[a]], %{{.*}} : index
// CHECK-NEXT: %[[c:.*]] = arith.cmpi sgt, %{{.*}}, %[[b]] : index
// CHECK-NEXT: %[[d:.*]] = arith.select %[[c]], %{{.*}}, %[[b]] : index
// CHECK-NEXT: %[[mul0:.*]] = arith.muli %{{.*}}, %[[cm1]] : index
// CHECK-NEXT: %[[add0:.*]] = arith.addi %[[mul0]], %{{.*}} : index
// CHECK-NEXT: %[[max:.*]] = arith.maxsi %{{.*}}, %[[add0]] : index
// CHECK-NEXT: %[[c10:.*]] = arith.constant 10 : index
// CHECK-NEXT: %[[e:.*]] = arith.addi %{{.*}}, %[[c10]] : index
// CHECK-NEXT: %[[f:.*]] = arith.cmpi slt, %{{.*}}, %[[e]] : index
// CHECK-NEXT: %[[g:.*]] = arith.select %[[f]], %{{.*}}, %[[e]] : index
// CHECK-NEXT: %[[add1:.*]] = arith.addi %{{.*}}, %[[c10]] : index
// CHECK-NEXT: %[[min:.*]] = arith.minsi %{{.*}}, %[[add1]] : index
// CHECK-NEXT: %[[c1_0:.*]] = arith.constant 1 : index
// CHECK-NEXT: for %{{.*}} = %[[d]] to %[[g]] step %[[c1_0]] {
// CHECK-NEXT: for %{{.*}} = %[[max]] to %[[min]] step %[[c1_0]] {
// CHECK-NEXT: call @body2(%{{.*}}, %{{.*}}) : (index, index) -> ()
// CHECK-NEXT: }
// CHECK-NEXT: }
Expand All @@ -397,25 +395,19 @@ func.func @loop_min_max(%N : index) {

#map_7_values = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5, d6)>

// Check that the "min" (cmpi slt + select) reduction sequence is emitted
// Check that the "min" reduction sequence is emitted
// correctly for an affine map with 7 results.

// CHECK-LABEL: func @min_reduction_tree
// CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[c01:.+]] = arith.cmpi slt, %{{.*}}, %{{.*}} : index
// CHECK-NEXT: %[[r01:.+]] = arith.select %[[c01]], %{{.*}}, %{{.*}} : index
// CHECK-NEXT: %[[c012:.+]] = arith.cmpi slt, %[[r01]], %{{.*}} : index
// CHECK-NEXT: %[[r012:.+]] = arith.select %[[c012]], %[[r01]], %{{.*}} : index
// CHECK-NEXT: %[[c0123:.+]] = arith.cmpi slt, %[[r012]], %{{.*}} : index
// CHECK-NEXT: %[[r0123:.+]] = arith.select %[[c0123]], %[[r012]], %{{.*}} : index
// CHECK-NEXT: %[[c01234:.+]] = arith.cmpi slt, %[[r0123]], %{{.*}} : index
// CHECK-NEXT: %[[r01234:.+]] = arith.select %[[c01234]], %[[r0123]], %{{.*}} : index
// CHECK-NEXT: %[[c012345:.+]] = arith.cmpi slt, %[[r01234]], %{{.*}} : index
// CHECK-NEXT: %[[r012345:.+]] = arith.select %[[c012345]], %[[r01234]], %{{.*}} : index
// CHECK-NEXT: %[[c0123456:.+]] = arith.cmpi slt, %[[r012345]], %{{.*}} : index
// CHECK-NEXT: %[[r0123456:.+]] = arith.select %[[c0123456]], %[[r012345]], %{{.*}} : index
// CHECK-NEXT: %[[min:.+]] = arith.minsi %{{.*}}, %{{.*}} : index
// CHECK-NEXT: %[[min_0:.+]] = arith.minsi %[[min]], %{{.*}} : index
// CHECK-NEXT: %[[min_1:.+]] = arith.minsi %[[min_0]], %{{.*}} : index
// CHECK-NEXT: %[[min_2:.+]] = arith.minsi %[[min_1]], %{{.*}} : index
// CHECK-NEXT: %[[min_3:.+]] = arith.minsi %[[min_2]], %{{.*}} : index
// CHECK-NEXT: %[[min_4:.+]] = arith.minsi %[[min_3]], %{{.*}} : index
// CHECK-NEXT: %[[c1:.*]] = arith.constant 1 : index
// CHECK-NEXT: for %{{.*}} = %[[c0]] to %[[r0123456]] step %[[c1]] {
// CHECK-NEXT: for %{{.*}} = %[[c0]] to %[[min_4]] step %[[c1]] {
// CHECK-NEXT: call @body(%{{.*}}) : (index) -> ()
// CHECK-NEXT: }
// CHECK-NEXT: return
Expand Down Expand Up @@ -690,8 +682,7 @@ func.func @affine_min(%arg0: index, %arg1: index) -> index{
// CHECK: %[[Cm2:.*]] = arith.constant -1
// CHECK: %[[neg2:.*]] = arith.muli %[[ARG0]], %[[Cm2:.*]]
// CHECK: %[[second:.*]] = arith.addi %[[ARG1]], %[[neg2]]
// CHECK: %[[cmp:.*]] = arith.cmpi slt, %[[first]], %[[second]]
// CHECK: arith.select %[[cmp]], %[[first]], %[[second]]
// CHECK: arith.minsi %[[first]], %[[second]]
%0 = affine.min affine_map<(d0,d1) -> (d0 - d1, d1 - d0)>(%arg0, %arg1)
return %0 : index
}
Expand All @@ -705,8 +696,7 @@ func.func @affine_max(%arg0: index, %arg1: index) -> index{
// CHECK: %[[Cm2:.*]] = arith.constant -1
// CHECK: %[[neg2:.*]] = arith.muli %[[ARG0]], %[[Cm2:.*]]
// CHECK: %[[second:.*]] = arith.addi %[[ARG1]], %[[neg2]]
// CHECK: %[[cmp:.*]] = arith.cmpi sgt, %[[first]], %[[second]]
// CHECK: arith.select %[[cmp]], %[[first]], %[[second]]
// CHECK: arith.maxsi %[[first]], %[[second]]
%0 = affine.max affine_map<(d0,d1) -> (d0 - d1, d1 - d0)>(%arg0, %arg1)
return %0 : index
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -554,8 +554,7 @@ func.func @collapse_shape_dynamic(%arg0 : memref<1x2x?xf32>) -> memref<1x?xf32>
// CHECK: %[[SIZE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE1]] : i64 to index
// CHECK: %[[FINAL_SIZE1:.*]] = builtin.unrealized_conversion_cast %[[SIZE1_TO_IDX]] : index to i64
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
// CHECK: %[[IS_MIN_STRIDE1:.*]] = llvm.icmp "slt" %[[STRIDE1]], %[[C1]] : i64
// CHECK: %[[MIN_STRIDE1:.*]] = llvm.select %[[IS_MIN_STRIDE1]], %[[STRIDE1]], %[[C1]] : i1, i64
// CHECK: %[[MIN_STRIDE1:.*]] = llvm.intr.smin(%[[STRIDE1]], %[[C1]]) : (i64, i64) -> i64
// CHECK: %[[MIN_STRIDE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[MIN_STRIDE1]] : i64 to index
// CHECK: %[[MIN_STRIDE1:.*]] = builtin.unrealized_conversion_cast %[[MIN_STRIDE1_TO_IDX]] : index to i64
// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
Expand Down
18 changes: 6 additions & 12 deletions mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -377,10 +377,8 @@ func.func @try_is_broadcastable (%a : tensor<2xindex>, %b : tensor<3xindex>, %c
// CHECK: %[[RANK0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<2xindex>
// CHECK: %[[RANK1:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<3xindex>
// CHECK: %[[RANK2:.*]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<2xindex>
// CHECK: %[[CMP0:.*]] = arith.cmpi ugt, %[[RANK1]], %[[RANK0]] : index
// CHECK: %[[LARGER_DIM:.*]] = arith.select %[[CMP0]], %[[RANK1]], %[[RANK0]] : index
// CHECK: %[[CMP1:.*]] = arith.cmpi ugt, %[[RANK2]], %[[LARGER_DIM]] : index
// CHECK: %[[MAX_RANK:.*]] = arith.select %[[CMP1]], %[[RANK2]], %[[LARGER_DIM]] : index
// CHECK: %[[MAX0:.*]] = arith.maxui %[[RANK1]], %[[RANK0]] : index
// CHECK: %[[MAX_RANK:.*]] = arith.maxui %[[RANK2]], %[[MAX0]] : index
// CHECK: %[[DIM_DIFF0:.*]] = arith.subi %[[MAX_RANK]], %[[RANK0]] : index
// CHECK: %[[DIM_DIFF1:.*]] = arith.subi %[[MAX_RANK]], %[[RANK1]] : index
// CHECK: %[[DIM_DIFF2:.*]] = arith.subi %[[MAX_RANK]], %[[RANK2]] : index
Expand Down Expand Up @@ -467,10 +465,8 @@ func.func @broadcast(%a : tensor<2xindex>, %b : tensor<3xindex>, %c : tensor<2xi
// CHECK: %[[RANK0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<2xindex>
// CHECK: %[[RANK1:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<3xindex>
// CHECK: %[[RANK2:.*]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<2xindex>
// CHECK: %[[CMP0:.*]] = arith.cmpi ugt, %[[RANK1]], %[[RANK0]] : index
// CHECK: %[[LARGER_DIM:.*]] = arith.select %[[CMP0]], %[[RANK1]], %[[RANK0]] : index
// CHECK: %[[CMP1:.*]] = arith.cmpi ugt, %[[RANK2]], %[[LARGER_DIM]] : index
// CHECK: %[[MAX_RANK:.*]] = arith.select %[[CMP1]], %[[RANK2]], %[[LARGER_DIM]] : index
// CHECK: %[[MAX0:.*]] = arith.maxui %[[RANK1]], %[[RANK0]] : index
// CHECK: %[[MAX_RANK:.*]] = arith.maxui %[[RANK2]], %[[MAX0]] : index
// CHECK: %[[DIM_DIFF0:.*]] = arith.subi %[[MAX_RANK]], %[[RANK0]] : index
// CHECK: %[[DIM_DIFF1:.*]] = arith.subi %[[MAX_RANK]], %[[RANK1]] : index
// CHECK: %[[DIM_DIFF2:.*]] = arith.subi %[[MAX_RANK]], %[[RANK2]] : index
Expand Down Expand Up @@ -559,10 +555,8 @@ func.func @broadcast_3_shapes_different_extents(%a : tensor<2xindex>,
// CHECK: %[[RANK0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<2xindex>
// CHECK: %[[RANK1:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<3xindex>
// CHECK: %[[RANK2:.*]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<2xindex>
// CHECK: %[[CMP0:.*]] = arith.cmpi ugt, %[[RANK1]], %[[RANK0]] : index
// CHECK: %[[LARGER_DIM:.*]] = arith.select %[[CMP0]], %[[RANK1]], %[[RANK0]] : index
// CHECK: %[[CMP1:.*]] = arith.cmpi ugt, %[[RANK2]], %[[LARGER_DIM]] : index
// CHECK: %[[MAX_RANK:.*]] = arith.select %[[CMP1]], %[[RANK2]], %[[LARGER_DIM]] : index
// CHECK: %[[MAX0:.*]] = arith.maxui %[[RANK1]], %[[RANK0]] : index
// CHECK: %[[MAX_RANK:.*]] = arith.maxui %[[RANK2]], %[[MAX0]] : index
// CHECK: %[[DIM_DIFF0:.*]] = arith.subi %[[MAX_RANK]], %[[RANK0]] : index
// CHECK: %[[DIM_DIFF1:.*]] = arith.subi %[[MAX_RANK]], %[[RANK1]] : index
// CHECK: %[[DIM_DIFF2:.*]] = arith.subi %[[MAX_RANK]], %[[RANK2]] : index
Expand Down
Loading