|
22 | 22 | #include "flang/Parser/parse-tree.h"
|
23 | 23 | #include "flang/Semantics/expression.h"
|
24 | 24 | #include "flang/Semantics/tools.h"
|
25 |
| -#include "mlir/Dialect/OpenACC/OpenACC.h" |
26 | 25 | #include "llvm/Frontend/OpenACC/ACC.h.inc"
|
27 | 26 |
|
28 | 27 | // Special value for * passed in device_type or gang clauses.
|
@@ -526,6 +525,132 @@ genPrivatizations(const Fortran::parser::AccObjectList &objectList,
|
526 | 525 | }
|
527 | 526 | }
|
528 | 527 |
|
| 528 | +/// Return the corresponding enum value for the mlir::acc::ReductionOperator |
| 529 | +/// from the parser representation. |
| 530 | +static mlir::acc::ReductionOperator |
| 531 | +getReductionOperator(const Fortran::parser::AccReductionOperator &op) { |
| 532 | + switch (op.v) { |
| 533 | + case Fortran::parser::AccReductionOperator::Operator::Plus: |
| 534 | + return mlir::acc::ReductionOperator::AccAdd; |
| 535 | + case Fortran::parser::AccReductionOperator::Operator::Multiply: |
| 536 | + return mlir::acc::ReductionOperator::AccMul; |
| 537 | + case Fortran::parser::AccReductionOperator::Operator::Max: |
| 538 | + return mlir::acc::ReductionOperator::AccMax; |
| 539 | + case Fortran::parser::AccReductionOperator::Operator::Min: |
| 540 | + return mlir::acc::ReductionOperator::AccMin; |
| 541 | + case Fortran::parser::AccReductionOperator::Operator::Iand: |
| 542 | + return mlir::acc::ReductionOperator::AccIand; |
| 543 | + case Fortran::parser::AccReductionOperator::Operator::Ior: |
| 544 | + return mlir::acc::ReductionOperator::AccIor; |
| 545 | + case Fortran::parser::AccReductionOperator::Operator::Ieor: |
| 546 | + return mlir::acc::ReductionOperator::AccXor; |
| 547 | + case Fortran::parser::AccReductionOperator::Operator::And: |
| 548 | + return mlir::acc::ReductionOperator::AccLand; |
| 549 | + case Fortran::parser::AccReductionOperator::Operator::Or: |
| 550 | + return mlir::acc::ReductionOperator::AccLor; |
| 551 | + case Fortran::parser::AccReductionOperator::Operator::Eqv: |
| 552 | + return mlir::acc::ReductionOperator::AccEqv; |
| 553 | + case Fortran::parser::AccReductionOperator::Operator::Neqv: |
| 554 | + return mlir::acc::ReductionOperator::AccNeqv; |
| 555 | + } |
| 556 | + llvm_unreachable("unexpected reduction operator"); |
| 557 | +} |
| 558 | + |
| 559 | +static mlir::Value genReductionInitValue(mlir::OpBuilder &builder, |
| 560 | + mlir::Location loc, mlir::Type ty, |
| 561 | + mlir::acc::ReductionOperator op) { |
| 562 | + if (op != mlir::acc::ReductionOperator::AccAdd) |
| 563 | + TODO(loc, "reduction operator"); |
| 564 | + |
| 565 | + unsigned initValue = 0; |
| 566 | + |
| 567 | + if (ty.isIntOrIndex()) |
| 568 | + return builder.create<mlir::arith::ConstantOp>( |
| 569 | + loc, ty, builder.getIntegerAttr(ty, initValue)); |
| 570 | + if (mlir::isa<mlir::FloatType>(ty)) |
| 571 | + return builder.create<mlir::arith::ConstantOp>( |
| 572 | + loc, ty, builder.getFloatAttr(ty, initValue)); |
| 573 | + TODO(loc, "reduction type"); |
| 574 | +} |
| 575 | + |
| 576 | +static mlir::Value genCombiner(mlir::OpBuilder &builder, mlir::Location loc, |
| 577 | + mlir::acc::ReductionOperator op, mlir::Type ty, |
| 578 | + mlir::Value value1, mlir::Value value2) { |
| 579 | + if (op == mlir::acc::ReductionOperator::AccAdd) { |
| 580 | + if (ty.isIntOrIndex()) |
| 581 | + return builder.create<mlir::arith::AddIOp>(loc, value1, value2); |
| 582 | + if (mlir::isa<mlir::FloatType>(ty)) |
| 583 | + return builder.create<mlir::arith::AddFOp>(loc, value1, value2); |
| 584 | + TODO(loc, "reduction add type"); |
| 585 | + } |
| 586 | + TODO(loc, "reduction operator"); |
| 587 | +} |
| 588 | + |
| 589 | +mlir::acc::ReductionRecipeOp Fortran::lower::createOrGetReductionRecipe( |
| 590 | + mlir::OpBuilder &builder, llvm::StringRef recipeName, mlir::Location loc, |
| 591 | + mlir::Type ty, mlir::acc::ReductionOperator op) { |
| 592 | + mlir::ModuleOp mod = |
| 593 | + builder.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>(); |
| 594 | + if (auto recipe = mod.lookupSymbol<mlir::acc::ReductionRecipeOp>(recipeName)) |
| 595 | + return recipe; |
| 596 | + |
| 597 | + auto crtPos = builder.saveInsertionPoint(); |
| 598 | + mlir::OpBuilder modBuilder(mod.getBodyRegion()); |
| 599 | + auto recipe = |
| 600 | + modBuilder.create<mlir::acc::ReductionRecipeOp>(loc, recipeName, ty, op); |
| 601 | + builder.createBlock(&recipe.getInitRegion(), recipe.getInitRegion().end(), |
| 602 | + {ty}, {loc}); |
| 603 | + builder.setInsertionPointToEnd(&recipe.getInitRegion().back()); |
| 604 | + mlir::Value initValue = genReductionInitValue(builder, loc, ty, op); |
| 605 | + builder.create<mlir::acc::YieldOp>(loc, initValue); |
| 606 | + |
| 607 | + builder.createBlock(&recipe.getCombinerRegion(), |
| 608 | + recipe.getCombinerRegion().end(), {ty, ty}, {loc, loc}); |
| 609 | + builder.setInsertionPointToEnd(&recipe.getCombinerRegion().back()); |
| 610 | + mlir::Value v1 = recipe.getCombinerRegion().front().getArgument(0); |
| 611 | + mlir::Value v2 = recipe.getCombinerRegion().front().getArgument(1); |
| 612 | + mlir::Value combinedValue = genCombiner(builder, loc, op, ty, v1, v2); |
| 613 | + builder.create<mlir::acc::YieldOp>(loc, combinedValue); |
| 614 | + builder.restoreInsertionPoint(crtPos); |
| 615 | + return recipe; |
| 616 | +} |
| 617 | + |
| 618 | +static void |
| 619 | +genReductions(const Fortran::parser::AccObjectListWithReduction &objectList, |
| 620 | + Fortran::lower::AbstractConverter &converter, |
| 621 | + Fortran::semantics::SemanticsContext &semanticsContext, |
| 622 | + Fortran::lower::StatementContext &stmtCtx, |
| 623 | + llvm::SmallVectorImpl<mlir::Value> &reductionOperands, |
| 624 | + llvm::SmallVector<mlir::Attribute> &reductionRecipes) { |
| 625 | + fir::FirOpBuilder &builder = converter.getFirOpBuilder(); |
| 626 | + const auto &objects = std::get<Fortran::parser::AccObjectList>(objectList.t); |
| 627 | + const auto &op = |
| 628 | + std::get<Fortran::parser::AccReductionOperator>(objectList.t); |
| 629 | + mlir::acc::ReductionOperator mlirOp = getReductionOperator(op); |
| 630 | + for (const auto &accObject : objects.v) { |
| 631 | + llvm::SmallVector<mlir::Value> bounds; |
| 632 | + std::stringstream asFortran; |
| 633 | + mlir::Location operandLocation = genOperandLocation(converter, accObject); |
| 634 | + mlir::Value baseAddr = gatherDataOperandAddrAndBounds( |
| 635 | + converter, builder, semanticsContext, stmtCtx, accObject, |
| 636 | + operandLocation, asFortran, bounds); |
| 637 | + |
| 638 | + if (!fir::isa_trivial(fir::unwrapRefType(baseAddr.getType()))) |
| 639 | + TODO(operandLocation, "reduction with unsupported type"); |
| 640 | + |
| 641 | + mlir::Type ty = fir::unwrapRefType(baseAddr.getType()); |
| 642 | + std::string recipeName = fir::getTypeAsString( |
| 643 | + ty, converter.getKindMap(), |
| 644 | + ("reduction_" + stringifyReductionOperator(mlirOp)).str()); |
| 645 | + mlir::acc::ReductionRecipeOp recipe = |
| 646 | + Fortran::lower::createOrGetReductionRecipe(builder, recipeName, |
| 647 | + operandLocation, ty, mlirOp); |
| 648 | + reductionRecipes.push_back(mlir::SymbolRefAttr::get( |
| 649 | + builder.getContext(), recipe.getSymName().str())); |
| 650 | + reductionOperands.push_back(baseAddr); |
| 651 | + } |
| 652 | +} |
| 653 | + |
529 | 654 | static void
|
530 | 655 | addOperands(llvm::SmallVectorImpl<mlir::Value> &operands,
|
531 | 656 | llvm::SmallVectorImpl<int32_t> &operandSegments,
|
@@ -666,7 +791,7 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
|
666 | 791 | mlir::Value gangStatic;
|
667 | 792 | llvm::SmallVector<mlir::Value, 2> tileOperands, privateOperands,
|
668 | 793 | reductionOperands;
|
669 |
| - llvm::SmallVector<mlir::Attribute> privatizations; |
| 794 | + llvm::SmallVector<mlir::Attribute> privatizations, reductionRecipes; |
670 | 795 | bool hasGang = false, hasVector = false, hasWorker = false;
|
671 | 796 |
|
672 | 797 | for (const Fortran::parser::AccClause &clause : accClauseList.v) {
|
@@ -735,10 +860,11 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
|
735 | 860 | &clause.u)) {
|
736 | 861 | genPrivatizations(privateClause->v, converter, semanticsContext, stmtCtx,
|
737 | 862 | privateOperands, privatizations);
|
738 |
| - } else if (std::get_if<Fortran::parser::AccClause::Reduction>(&clause.u)) { |
739 |
| - // Reduction clause is left out for the moment as the clause will probably |
740 |
| - // end up having its own operation. |
741 |
| - TODO(clauseLocation, "OpenACC compute construct reduction lowering"); |
| 863 | + } else if (const auto *reductionClause = |
| 864 | + std::get_if<Fortran::parser::AccClause::Reduction>( |
| 865 | + &clause.u)) { |
| 866 | + genReductions(reductionClause->v, converter, semanticsContext, stmtCtx, |
| 867 | + reductionOperands, reductionRecipes); |
742 | 868 | }
|
743 | 869 | }
|
744 | 870 |
|
@@ -767,6 +893,10 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
|
767 | 893 | loopOp.setPrivatizationsAttr(
|
768 | 894 | mlir::ArrayAttr::get(builder.getContext(), privatizations));
|
769 | 895 |
|
| 896 | + if (!reductionRecipes.empty()) |
| 897 | + loopOp.setReductionRecipesAttr( |
| 898 | + mlir::ArrayAttr::get(builder.getContext(), reductionRecipes)); |
| 899 | + |
770 | 900 | // Lower clauses mapped to attributes
|
771 | 901 | for (const Fortran::parser::AccClause &clause : accClauseList.v) {
|
772 | 902 | if (const auto *collapseClause =
|
|
0 commit comments