Skip to content

Commit 9b7cf5b

Browse files
authored
[Flang] Allow Intrinsic simplification with min/maxloc dim and scalar result (#76194)
This adjusts the existing fir minloc/maxloc generation code to handle functions with dim=1 that produce a scalar result. This should allow us to get the same benefits as the existing generated minmax reductions. This is a recommit of #75820 with the type name added to the name of the generated function.
1 parent 21a0335 commit 9b7cf5b

File tree

2 files changed

+70
-13
lines changed

2 files changed

+70
-13
lines changed

flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1162,11 +1162,14 @@ void SimplifyIntrinsicsPass::simplifyMinMaxlocReduction(
11621162

11631163
mlir::Operation::operand_range args = call.getArgs();
11641164

1165-
mlir::Value back = args[6];
1165+
mlir::SymbolRefAttr callee = call.getCalleeAttr();
1166+
mlir::StringRef funcNameBase = callee.getLeafReference().getValue();
1167+
bool isDim = funcNameBase.ends_with("Dim");
1168+
mlir::Value back = args[isDim ? 7 : 6];
11661169
if (isTrueOrNotConstant(back))
11671170
return;
11681171

1169-
mlir::Value mask = args[5];
1172+
mlir::Value mask = args[isDim ? 6 : 5];
11701173
mlir::Value maskDef = findMaskDef(mask);
11711174

11721175
// maskDef is set to NULL when the defining op is not one we accept.
@@ -1175,10 +1178,8 @@ void SimplifyIntrinsicsPass::simplifyMinMaxlocReduction(
11751178
if (maskDef == NULL)
11761179
return;
11771180

1178-
mlir::SymbolRefAttr callee = call.getCalleeAttr();
1179-
mlir::StringRef funcNameBase = callee.getLeafReference().getValue();
11801181
unsigned rank = getDimCount(args[1]);
1181-
if (funcNameBase.ends_with("Dim") || !(rank > 0))
1182+
if ((isDim && rank != 1) || !(rank > 0))
11821183
return;
11831184

11841185
fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)};
@@ -1219,6 +1220,8 @@ void SimplifyIntrinsicsPass::simplifyMinMaxlocReduction(
12191220

12201221
llvm::raw_string_ostream nameOS(funcName);
12211222
outType.print(nameOS);
1223+
if (isDim)
1224+
nameOS << '_' << inputType;
12221225
nameOS << '_' << fmfString;
12231226

12241227
auto typeGenerator = [rank](fir::FirOpBuilder &builder) {
@@ -1234,7 +1237,7 @@ void SimplifyIntrinsicsPass::simplifyMinMaxlocReduction(
12341237
mlir::func::FuncOp newFunc =
12351238
getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator);
12361239
builder.create<fir::CallOp>(loc, newFunc,
1237-
mlir::ValueRange{args[0], args[1], args[5]});
1240+
mlir::ValueRange{args[0], args[1], mask});
12381241
call->dropAllReferences();
12391242
call->erase();
12401243
}

flang/test/Transforms/simplifyintrinsics.fir

Lines changed: 61 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2115,13 +2115,13 @@ func.func @_QPtestminloc_doesntwork1d_back(%arg0: !fir.ref<!fir.array<10xi32>> {
21152115
// CHECK-NOT: fir.call @_FortranAMinlocInteger4x1_i32_contract_simplified({{.*}}) fastmath<contract> : (!fir.ref<!fir.box<none>>, !fir.box<none>, !fir.box<none>) -> ()
21162116

21172117
// -----
2118-
// Check Minloc is not simplified when DIM arg is set
2118+
// Check Minloc is simplified when DIM arg is set so long as the result is scalar
21192119

2120-
func.func @_QPtestminloc_doesntwork1d_dim(%arg0: !fir.ref<!fir.array<10xi32>> {fir.bindc_name = "a"}) -> !fir.array<1xi32> {
2120+
func.func @_QPtestminloc_1d_dim(%arg0: !fir.ref<!fir.array<10xi32>> {fir.bindc_name = "a"}) -> !fir.array<1xi32> {
21212121
%0 = fir.alloca !fir.box<!fir.heap<i32>>
21222122
%c10 = arith.constant 10 : index
21232123
%c1 = arith.constant 1 : index
2124-
%1 = fir.alloca !fir.array<1xi32> {bindc_name = "testminloc_doesntwork1d_dim", uniq_name = "_QFtestminloc_doesntwork1d_dimEtestminloc_doesntwork1d_dim"}
2124+
%1 = fir.alloca !fir.array<1xi32> {bindc_name = "testminloc_1d_dim", uniq_name = "_QFtestminloc_1d_dimEtestminloc_1d_dim"}
21252125
%2 = fir.shape %c1 : (index) -> !fir.shape<1>
21262126
%3 = fir.array_load %1(%2) : (!fir.ref<!fir.array<1xi32>>, !fir.shape<1>) -> !fir.array<1xi32>
21272127
%4 = fir.shape %c10 : (index) -> !fir.shape<1>
@@ -2156,11 +2156,65 @@ func.func @_QPtestminloc_doesntwork1d_dim(%arg0: !fir.ref<!fir.array<10xi32>> {f
21562156
%21 = fir.load %1 : !fir.ref<!fir.array<1xi32>>
21572157
return %21 : !fir.array<1xi32>
21582158
}
2159-
// CHECK-LABEL: func.func @_QPtestminloc_doesntwork1d_dim(
2159+
// CHECK-LABEL: func.func @_QPtestminloc_1d_dim(
21602160
// CHECK-SAME: %[[ARR:.*]]: !fir.ref<!fir.array<10xi32>> {fir.bindc_name = "a"}) -> !fir.array<1xi32> {
2161-
// CHECK-NOT: fir.call @_FortranAMinlocDimx1_i32_contract_simplified({{.*}}) fastmath<contract> : (!fir.ref<!fir.box<none>>, !fir.box<none>, !fir.box<none>) -> ()
2162-
// CHECK: fir.call @_FortranAMinlocDim({{.*}}) fastmath<contract> : (!fir.ref<!fir.box<none>>, !fir.box<none>, i32, i32, !fir.ref<i8>, i32, !fir.box<none>, i1) -> none
2163-
// CHECK-NOT: fir.call @_FortranAMinlocDimx1_i32_contract_simplified({{.*}}) fastmath<contract> : (!fir.ref<!fir.box<none>>, !fir.box<none>, !fir.box<none>) -> ()
2161+
// CHECK: fir.call @_FortranAMinlocDimx1_i32_i32_contract_simplified({{.*}}) fastmath<contract> : (!fir.ref<!fir.box<none>>, !fir.box<none>, !fir.box<none>) -> ()
2162+
2163+
// CHECK-LABEL: func.func private @_FortranAMinlocDimx1_i32_i32_contract_simplified(%arg0: !fir.ref<!fir.box<none>>, %arg1: !fir.box<none>, %arg2: !fir.box<none>) attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
2164+
// CHECK-NEXT: %[[V0:.*]] = fir.alloca i32
2165+
// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32
2166+
// CHECK-NEXT: %c1 = arith.constant 1 : index
2167+
// CHECK-NEXT: %[[V1:.*]] = fir.allocmem !fir.array<1xi32>
2168+
// CHECK-NEXT: %[[V2:.*]] = fir.shape %c1 : (index) -> !fir.shape<1>
2169+
// CHECK-NEXT: %[[V3:.*]] = fir.embox %[[V1]](%[[V2]]) : (!fir.heap<!fir.array<1xi32>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<1xi32>>>
2170+
// CHECK-NEXT: %c0 = arith.constant 0 : index
2171+
// CHECK-NEXT: %[[V4:.*]] = fir.coordinate_of %[[V3]], %c0 : (!fir.box<!fir.heap<!fir.array<1xi32>>>, index) -> !fir.ref<i32>
2172+
// CHECK-NEXT: fir.store %c0_i32 to %[[V4]] : !fir.ref<i32>
2173+
// CHECK-NEXT: %c0_0 = arith.constant 0 : index
2174+
// CHECK-NEXT: %[[V5:.*]] = fir.convert %arg1 : (!fir.box<none>) -> !fir.box<!fir.array<?xi32>>
2175+
// CHECK-NEXT: %c1_i32 = arith.constant 1 : i32
2176+
// CHECK-NEXT: %c0_i32_1 = arith.constant 0 : i32
2177+
// CHECK-NEXT: fir.store %c0_i32_1 to %[[V0]] : !fir.ref<i32>
2178+
// CHECK-NEXT: %c2147483647_i32 = arith.constant 2147483647 : i32
2179+
// CHECK-NEXT: %c1_2 = arith.constant 1 : index
2180+
// CHECK-NEXT: %c0_3 = arith.constant 0 : index
2181+
// CHECK-NEXT: %[[V6:.*]]:3 = fir.box_dims %[[V5]], %c0_3 : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
2182+
// CHECK-NEXT: %[[V7:.*]] = arith.subi %[[V6]]#1, %c1_2 : index
2183+
// CHECK-NEXT: %[[V8:.*]] = fir.do_loop %arg3 = %c0_0 to %[[V7]] step %c1_2 iter_args(%arg4 = %c2147483647_i32) -> (i32) {
2184+
// CHECK-NEXT: fir.store %c1_i32 to %[[V0]] : !fir.ref<i32>
2185+
// CHECK-NEXT: %[[V12:.*]] = fir.coordinate_of %[[V5]], %arg3 : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
2186+
// CHECK-NEXT: %[[V13:.*]] = fir.load %[[V12]] : !fir.ref<i32>
2187+
// CHECK-NEXT: %[[V14:.*]] = arith.cmpi slt, %[[V13]], %arg4 : i32
2188+
// CHECK-NEXT: %[[V15:.*]] = fir.if %[[V14]] -> (i32) {
2189+
// CHECK-NEXT: %c1_i32_4 = arith.constant 1 : i32
2190+
// CHECK-NEXT: %c0_5 = arith.constant 0 : index
2191+
// CHECK-NEXT: %[[V16:.*]] = fir.coordinate_of %[[V3]], %c0_5 : (!fir.box<!fir.heap<!fir.array<1xi32>>>, index) -> !fir.ref<i32>
2192+
// CHECK-NEXT: %[[V17:.*]] = fir.convert %arg3 : (index) -> i32
2193+
// CHECK-NEXT: %[[V18:.*]] = arith.addi %[[V17]], %c1_i32_4 : i32
2194+
// CHECK-NEXT: fir.store %[[V18]] to %[[V16]] : !fir.ref<i32>
2195+
// CHECK-NEXT: fir.result %[[V13]] : i32
2196+
// CHECK-NEXT: } else {
2197+
// CHECK-NEXT: fir.result %arg4 : i32
2198+
// CHECK-NEXT: }
2199+
// CHECK-NEXT: fir.result %[[V15]] : i32
2200+
// CHECK-NEXT: }
2201+
// CHECK-NEXT: %[[V9:.*]] = fir.load %[[V0]] : !fir.ref<i32>
2202+
// CHECK-NEXT: %[[V10:.*]] = arith.cmpi eq, %[[V9]], %c1_i32 : i32
2203+
// CHECK-NEXT: fir.if %[[V10]] {
2204+
// CHECK-NEXT: %c2147483647_i32_4 = arith.constant 2147483647 : i32
2205+
// CHECK-NEXT: %[[V12]] = arith.cmpi eq, %c2147483647_i32_4, %[[V8]] : i32
2206+
// CHECK-NEXT: fir.if %[[V12]] {
2207+
// CHECK-NEXT: %c0_5 = arith.constant 0 : index
2208+
// CHECK-NEXT: %[[V13]] = fir.coordinate_of %[[V3]], %c0_5 : (!fir.box<!fir.heap<!fir.array<1xi32>>>, index) -> !fir.ref<i32>
2209+
// CHECK-NEXT: fir.store %c1_i32 to %[[V13]] : !fir.ref<i32>
2210+
// CHECK-NEXT: }
2211+
// CHECK-NEXT: }
2212+
// CHECK-NEXT: %[[V11:.*]] = fir.convert %arg0 : (!fir.ref<!fir.box<none>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<1xi32>>>>
2213+
// CHECK-NEXT: fir.store %[[V3]] to %[[V11]] : !fir.ref<!fir.box<!fir.heap<!fir.array<1xi32>>>>
2214+
// CHECK-NEXT: return
2215+
// CHECK-NEXT: }
2216+
2217+
21642218

21652219
// -----
21662220
// Check Minloc is not simplified when dimension of inputArr is unknown

0 commit comments

Comments
 (0)