Skip to content

Commit 59ceb7d

Browse files
committed
[flang][openacc] Initial reduction clause lowering
Add initial support to lower reduction clause to its representation in MLIR. This patch adds support for addition of integer and real scalar types. Other operators and types will be added with follow up patches. Reviewed By: razvanlupusoru Differential Revision: https://reviews.llvm.org/D151564
1 parent 2ccb074 commit 59ceb7d

File tree

4 files changed

+198
-12
lines changed

4 files changed

+198
-12
lines changed

flang/include/flang/Lower/OpenACC.h

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
#ifndef FORTRAN_LOWER_OPENACC_H
1414
#define FORTRAN_LOWER_OPENACC_H
1515

16+
#include "mlir/Dialect/OpenACC/OpenACC.h"
17+
1618
namespace llvm {
1719
class StringRef;
1820
}
@@ -21,9 +23,6 @@ namespace mlir {
2123
class Location;
2224
class Type;
2325
class OpBuilder;
24-
namespace acc {
25-
class PrivateRecipeOp;
26-
}
2726
} // namespace mlir
2827

2928
namespace Fortran {
@@ -57,6 +56,12 @@ mlir::acc::PrivateRecipeOp createOrGetPrivateRecipe(mlir::OpBuilder &,
5756
llvm::StringRef,
5857
mlir::Location, mlir::Type);
5958

59+
/// Get a acc.reduction.recipe op for the given type or create it if it does not
60+
/// exist yet.
61+
mlir::acc::ReductionRecipeOp
62+
createOrGetReductionRecipe(mlir::OpBuilder &, llvm::StringRef, mlir::Location,
63+
mlir::Type, mlir::acc::ReductionOperator);
64+
6065
} // namespace lower
6166
} // namespace Fortran
6267

flang/lib/Lower/OpenACC.cpp

Lines changed: 136 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
#include "flang/Parser/parse-tree.h"
2323
#include "flang/Semantics/expression.h"
2424
#include "flang/Semantics/tools.h"
25-
#include "mlir/Dialect/OpenACC/OpenACC.h"
2625
#include "llvm/Frontend/OpenACC/ACC.h.inc"
2726

2827
// Special value for * passed in device_type or gang clauses.
@@ -526,6 +525,132 @@ genPrivatizations(const Fortran::parser::AccObjectList &objectList,
526525
}
527526
}
528527

528+
/// Return the corresponding enum value for the mlir::acc::ReductionOperator
529+
/// from the parser representation.
530+
static mlir::acc::ReductionOperator
531+
getReductionOperator(const Fortran::parser::AccReductionOperator &op) {
532+
switch (op.v) {
533+
case Fortran::parser::AccReductionOperator::Operator::Plus:
534+
return mlir::acc::ReductionOperator::AccAdd;
535+
case Fortran::parser::AccReductionOperator::Operator::Multiply:
536+
return mlir::acc::ReductionOperator::AccMul;
537+
case Fortran::parser::AccReductionOperator::Operator::Max:
538+
return mlir::acc::ReductionOperator::AccMax;
539+
case Fortran::parser::AccReductionOperator::Operator::Min:
540+
return mlir::acc::ReductionOperator::AccMin;
541+
case Fortran::parser::AccReductionOperator::Operator::Iand:
542+
return mlir::acc::ReductionOperator::AccIand;
543+
case Fortran::parser::AccReductionOperator::Operator::Ior:
544+
return mlir::acc::ReductionOperator::AccIor;
545+
case Fortran::parser::AccReductionOperator::Operator::Ieor:
546+
return mlir::acc::ReductionOperator::AccXor;
547+
case Fortran::parser::AccReductionOperator::Operator::And:
548+
return mlir::acc::ReductionOperator::AccLand;
549+
case Fortran::parser::AccReductionOperator::Operator::Or:
550+
return mlir::acc::ReductionOperator::AccLor;
551+
case Fortran::parser::AccReductionOperator::Operator::Eqv:
552+
return mlir::acc::ReductionOperator::AccEqv;
553+
case Fortran::parser::AccReductionOperator::Operator::Neqv:
554+
return mlir::acc::ReductionOperator::AccNeqv;
555+
}
556+
llvm_unreachable("unexpected reduction operator");
557+
}
558+
559+
static mlir::Value genReductionInitValue(mlir::OpBuilder &builder,
560+
mlir::Location loc, mlir::Type ty,
561+
mlir::acc::ReductionOperator op) {
562+
if (op != mlir::acc::ReductionOperator::AccAdd)
563+
TODO(loc, "reduction operator");
564+
565+
unsigned initValue = 0;
566+
567+
if (ty.isIntOrIndex())
568+
return builder.create<mlir::arith::ConstantOp>(
569+
loc, ty, builder.getIntegerAttr(ty, initValue));
570+
if (mlir::isa<mlir::FloatType>(ty))
571+
return builder.create<mlir::arith::ConstantOp>(
572+
loc, ty, builder.getFloatAttr(ty, initValue));
573+
TODO(loc, "reduction type");
574+
}
575+
576+
static mlir::Value genCombiner(mlir::OpBuilder &builder, mlir::Location loc,
577+
mlir::acc::ReductionOperator op, mlir::Type ty,
578+
mlir::Value value1, mlir::Value value2) {
579+
if (op == mlir::acc::ReductionOperator::AccAdd) {
580+
if (ty.isIntOrIndex())
581+
return builder.create<mlir::arith::AddIOp>(loc, value1, value2);
582+
if (mlir::isa<mlir::FloatType>(ty))
583+
return builder.create<mlir::arith::AddFOp>(loc, value1, value2);
584+
TODO(loc, "reduction add type");
585+
}
586+
TODO(loc, "reduction operator");
587+
}
588+
589+
mlir::acc::ReductionRecipeOp Fortran::lower::createOrGetReductionRecipe(
590+
mlir::OpBuilder &builder, llvm::StringRef recipeName, mlir::Location loc,
591+
mlir::Type ty, mlir::acc::ReductionOperator op) {
592+
mlir::ModuleOp mod =
593+
builder.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>();
594+
if (auto recipe = mod.lookupSymbol<mlir::acc::ReductionRecipeOp>(recipeName))
595+
return recipe;
596+
597+
auto crtPos = builder.saveInsertionPoint();
598+
mlir::OpBuilder modBuilder(mod.getBodyRegion());
599+
auto recipe =
600+
modBuilder.create<mlir::acc::ReductionRecipeOp>(loc, recipeName, ty, op);
601+
builder.createBlock(&recipe.getInitRegion(), recipe.getInitRegion().end(),
602+
{ty}, {loc});
603+
builder.setInsertionPointToEnd(&recipe.getInitRegion().back());
604+
mlir::Value initValue = genReductionInitValue(builder, loc, ty, op);
605+
builder.create<mlir::acc::YieldOp>(loc, initValue);
606+
607+
builder.createBlock(&recipe.getCombinerRegion(),
608+
recipe.getCombinerRegion().end(), {ty, ty}, {loc, loc});
609+
builder.setInsertionPointToEnd(&recipe.getCombinerRegion().back());
610+
mlir::Value v1 = recipe.getCombinerRegion().front().getArgument(0);
611+
mlir::Value v2 = recipe.getCombinerRegion().front().getArgument(1);
612+
mlir::Value combinedValue = genCombiner(builder, loc, op, ty, v1, v2);
613+
builder.create<mlir::acc::YieldOp>(loc, combinedValue);
614+
builder.restoreInsertionPoint(crtPos);
615+
return recipe;
616+
}
617+
618+
static void
619+
genReductions(const Fortran::parser::AccObjectListWithReduction &objectList,
620+
Fortran::lower::AbstractConverter &converter,
621+
Fortran::semantics::SemanticsContext &semanticsContext,
622+
Fortran::lower::StatementContext &stmtCtx,
623+
llvm::SmallVectorImpl<mlir::Value> &reductionOperands,
624+
llvm::SmallVector<mlir::Attribute> &reductionRecipes) {
625+
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
626+
const auto &objects = std::get<Fortran::parser::AccObjectList>(objectList.t);
627+
const auto &op =
628+
std::get<Fortran::parser::AccReductionOperator>(objectList.t);
629+
mlir::acc::ReductionOperator mlirOp = getReductionOperator(op);
630+
for (const auto &accObject : objects.v) {
631+
llvm::SmallVector<mlir::Value> bounds;
632+
std::stringstream asFortran;
633+
mlir::Location operandLocation = genOperandLocation(converter, accObject);
634+
mlir::Value baseAddr = gatherDataOperandAddrAndBounds(
635+
converter, builder, semanticsContext, stmtCtx, accObject,
636+
operandLocation, asFortran, bounds);
637+
638+
if (!fir::isa_trivial(fir::unwrapRefType(baseAddr.getType())))
639+
TODO(operandLocation, "reduction with unsupported type");
640+
641+
mlir::Type ty = fir::unwrapRefType(baseAddr.getType());
642+
std::string recipeName = fir::getTypeAsString(
643+
ty, converter.getKindMap(),
644+
("reduction_" + stringifyReductionOperator(mlirOp)).str());
645+
mlir::acc::ReductionRecipeOp recipe =
646+
Fortran::lower::createOrGetReductionRecipe(builder, recipeName,
647+
operandLocation, ty, mlirOp);
648+
reductionRecipes.push_back(mlir::SymbolRefAttr::get(
649+
builder.getContext(), recipe.getSymName().str()));
650+
reductionOperands.push_back(baseAddr);
651+
}
652+
}
653+
529654
static void
530655
addOperands(llvm::SmallVectorImpl<mlir::Value> &operands,
531656
llvm::SmallVectorImpl<int32_t> &operandSegments,
@@ -666,7 +791,7 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
666791
mlir::Value gangStatic;
667792
llvm::SmallVector<mlir::Value, 2> tileOperands, privateOperands,
668793
reductionOperands;
669-
llvm::SmallVector<mlir::Attribute> privatizations;
794+
llvm::SmallVector<mlir::Attribute> privatizations, reductionRecipes;
670795
bool hasGang = false, hasVector = false, hasWorker = false;
671796

672797
for (const Fortran::parser::AccClause &clause : accClauseList.v) {
@@ -735,10 +860,11 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
735860
&clause.u)) {
736861
genPrivatizations(privateClause->v, converter, semanticsContext, stmtCtx,
737862
privateOperands, privatizations);
738-
} else if (std::get_if<Fortran::parser::AccClause::Reduction>(&clause.u)) {
739-
// Reduction clause is left out for the moment as the clause will probably
740-
// end up having its own operation.
741-
TODO(clauseLocation, "OpenACC compute construct reduction lowering");
863+
} else if (const auto *reductionClause =
864+
std::get_if<Fortran::parser::AccClause::Reduction>(
865+
&clause.u)) {
866+
genReductions(reductionClause->v, converter, semanticsContext, stmtCtx,
867+
reductionOperands, reductionRecipes);
742868
}
743869
}
744870

@@ -767,6 +893,10 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
767893
loopOp.setPrivatizationsAttr(
768894
mlir::ArrayAttr::get(builder.getContext(), privatizations));
769895

896+
if (!reductionRecipes.empty())
897+
loopOp.setReductionRecipesAttr(
898+
mlir::ArrayAttr::get(builder.getContext(), reductionRecipes));
899+
770900
// Lower clauses mapped to attributes
771901
for (const Fortran::parser::AccClause &clause : accClauseList.v) {
772902
if (const auto *collapseClause =
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
! This test checks lowering of OpenACC reduction clause.
2+
3+
! RUN: bbc -fopenacc -emit-fir %s -o - | FileCheck %s
4+
5+
! CHECK-LABEL: acc.reduction.recipe @reduction_add_f32 : f32 reduction_operator <add> init {
6+
! CHECK: ^bb0(%{{.*}}: f32):
7+
! CHECK: %[[INIT:.*]] = arith.constant 0.000000e+00 : f32
8+
! CHECK: acc.yield %[[INIT]] : f32
9+
! CHECK: } combiner {
10+
! CHECK: ^bb0(%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32):
11+
! CHECK: %[[COMBINED:.*]] = arith.addf %[[ARG0]], %[[ARG1]] {{.*}} : f32
12+
! CHECK: acc.yield %[[COMBINED]] : f32
13+
! CHECK: }
14+
15+
! CHECK-LABEL: acc.reduction.recipe @reduction_add_i32 : i32 reduction_operator <add> init {
16+
! CHECK: ^bb0(%{{.*}}: i32):
17+
! CHECK: %[[INIT:.*]] = arith.constant 0 : i32
18+
! CHECK: acc.yield %[[INIT]] : i32
19+
! CHECK: } combiner {
20+
! CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32):
21+
! CHECK: %[[COMBINED:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i32
22+
! CHECK: acc.yield %[[COMBINED]] : i32
23+
! CHECK: }
24+
25+
subroutine acc_reduction_add_int(a, b)
26+
integer :: a(100)
27+
integer :: i, b
28+
29+
!$acc loop reduction(+:b)
30+
do i = 1, 100
31+
b = b + a(i)
32+
end do
33+
end subroutine
34+
35+
! CHECK-LABEL: func.func @_QPacc_reduction_add_int(
36+
! CHECK-SAME: %{{.*}}: !fir.ref<!fir.array<100xi32>> {fir.bindc_name = "a"}, %[[B:.*]]: !fir.ref<i32> {fir.bindc_name = "b"})
37+
! CHECK: acc.loop reduction(@reduction_add_i32 -> %[[B]] : !fir.ref<i32>) {
38+
39+
subroutine acc_reduction_add_float(a, b)
40+
real :: a(100), b
41+
integer :: i
42+
43+
!$acc loop reduction(+:b)
44+
do i = 1, 100
45+
b = b + a(i)
46+
end do
47+
end subroutine
48+
49+
! CHECK-LABEL: func.func @_QPacc_reduction_add_float(
50+
! CHECK-SAME: %{{.*}}: !fir.ref<!fir.array<100xf32>> {fir.bindc_name = "a"}, %[[B:.*]]: !fir.ref<f32> {fir.bindc_name = "b"})
51+
! CHECK: acc.loop reduction(@reduction_add_f32 -> %[[B]] : !fir.ref<f32>)

mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,7 @@ template <typename Op>
498498
static LogicalResult
499499
checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes,
500500
mlir::OperandRange operands, llvm::StringRef operandName,
501-
llvm::StringRef symbolName) {
501+
llvm::StringRef symbolName, bool checkOperandType = true) {
502502
if (!operands.empty()) {
503503
if (!attributes || attributes->size() != operands.size())
504504
return op->emitOpError()
@@ -527,7 +527,7 @@ checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes,
527527
<< "expected symbol reference " << symbolRef << " to point to a "
528528
<< operandName << " declaration";
529529

530-
if (decl.getType() && decl.getType() != varType)
530+
if (checkOperandType && decl.getType() && decl.getType() != varType)
531531
return op->emitOpError() << "expected " << operandName << " (" << varType
532532
<< ") to be the same type as " << operandName
533533
<< " declaration (" << decl.getType() << ")";
@@ -751,7 +751,7 @@ LogicalResult acc::LoopOp::verify() {
751751

752752
if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
753753
*this, getReductionRecipes(), getReductionOperands(), "reduction",
754-
"reductions")))
754+
"reductions", false)))
755755
return failure();
756756

757757
// Check non-empty body().

0 commit comments

Comments
 (0)