diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 1cfb866db0b51..e41be8cbc1aa1 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -19,12 +19,14 @@
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/ValueRange.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include
 
 using namespace mlir;
 using namespace mlir::scf;
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
index a2f03f1e1056e..c8960039a6ce1 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
@@ -14,7 +14,12 @@
 
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
+#include "llvm/ADT/STLExtras.h"
 
 namespace mlir {
 #define GEN_PASS_DEF_SCFFORALLTOFORLOOP
@@ -35,16 +40,108 @@ mlir::scf::forallToForLoop(RewriterBase &rewriter, scf::ForallOp forallOp,
   SmallVector<Value> lbs = forallOp.getLowerBound(rewriter);
   SmallVector<Value> ubs = forallOp.getUpperBound(rewriter);
   SmallVector<Value> steps = forallOp.getStep(rewriter);
-  LoopNest loopNest = scf::buildLoopNest(rewriter, loc, lbs, ubs, steps);
+  SmallVector<Value> iterArgs;
+  for (auto result : forallOp->getResults()) {
+    iterArgs.push_back(forallOp.getTiedOpOperand(result)->get());
+  }
+
+  InParallelOp threadReduction =
+      cast<InParallelOp>(forallOp.getBody()->getTerminator());
+  SmallVector<tensor::ParallelInsertSliceOp> regionArgToSlice;
+  for (auto &op : threadReduction.getBody()->getOperations()) {
+    auto parallelInsert = dyn_cast<tensor::ParallelInsertSliceOp>(op);
+    if (!parallelInsert) {
+      return op.emitOpError() << "expected parallel insert slice op";
+    }
+    regionArgToSlice.push_back(parallelInsert);
+  }
+
+  function_ref<ValueVector(OpBuilder &, Location, ValueRange, ValueRange)>
+      build = [&](OpBuilder &rewriter, Location loc, ValueRange ivs,
+                  ValueRange regionArgs) -> ValueVector {
+    SmallVector<Value> res;
+    for (auto [i, val] : llvm::enumerate(regionArgs)) {
+      tensor::ParallelInsertSliceOp sliceOp = regionArgToSlice[i];
+
+      // Map new induction variables where applicable.
+      SmallVector<OpFoldResult> sliceOpOffsets = sliceOp.getMixedOffsets();
+      for (OpFoldResult &offset : sliceOpOffsets) {
+        if (offset.is<Value>()) {
+          Value dynamicOffset = offset.get<Value>();
+          SmallVector<Value> originalInductionVars =
+              forallOp.getInductionVars();
+          auto *it = llvm::find(originalInductionVars, dynamicOffset);
+          if (it != originalInductionVars.end()) {
+            size_t index = std::distance(originalInductionVars.begin(), it);
+            offset = ivs[index];
+          }
+        }
+      }
+
+      SmallVector<OpFoldResult> sliceOpSizes = sliceOp.getMixedSizes();
+      for (OpFoldResult &size : sliceOpSizes) {
+        if (size.is<Value>()) {
+          Value dynamicSize = size.get<Value>();
+          SmallVector<Value> originalInductionVars =
+              forallOp.getInductionVars();
+          auto *it = llvm::find(originalInductionVars, dynamicSize);
+          if (it != originalInductionVars.end()) {
+            size_t index = std::distance(originalInductionVars.begin(), it);
+            size = ivs[index];
+          }
+        }
+      }
+
+      SmallVector<OpFoldResult> sliceOpStrides = sliceOp.getMixedStrides();
+      for (OpFoldResult &stride : sliceOpStrides) {
+        if (stride.is<Value>()) {
+          Value dynamicStride = stride.get<Value>();
+          SmallVector<Value> originalInductionVars =
+              forallOp.getInductionVars();
+          auto *it = llvm::find(originalInductionVars, dynamicStride);
+          if (it != originalInductionVars.end()) {
+            size_t index = std::distance(originalInductionVars.begin(), it);
+            stride = ivs[index];
+          }
+        }
+      }
+
+      res.push_back(rewriter.create<tensor::InsertSliceOp>(
+          sliceOp->getLoc(), sliceOp.getSource(), val, sliceOpOffsets,
+          sliceOpSizes, sliceOpStrides));
+    }
+    return res;
+  };
+  // Now create the new loop nest, with the innermost loop body receiving the
+  // tensor insert_slice ops.
+  LoopNest loopNest =
+      scf::buildLoopNest(rewriter, loc, lbs, ubs, steps, iterArgs, build);
 
   SmallVector<Value> ivs = llvm::map_to_vector(
       loopNest.loops, [](scf::ForOp loop) { return loop.getInductionVar(); });
+  rewriter.replaceAllOpUsesWith(forallOp,
+                                {loopNest.loops.front()->getResults()});
+  // Erase the parallel inserts and associated shared outputs.
+  for (tensor::ParallelInsertSliceOp insertSlice :
+       llvm::make_early_inc_range(regionArgToSlice)) {
+    auto loopBlockArg = dyn_cast<BlockArgument>(insertSlice.getDest());
+    if (!loopBlockArg || loopBlockArg.getOwner()->getParentOp() != forallOp) {
+      return insertSlice->emitOpError()
+             << "expected destination to be block argument in loop";
+    }
+    rewriter.eraseOp(insertSlice);
+    rewriter.modifyOpInPlace(forallOp, [&]() {
+      forallOp.getBody()->eraseArgument(loopBlockArg.getArgNumber());
+    });
+  }
+  rewriter.eraseOp(forallOp.getTerminator());
+
   Block *innermostBlock = loopNest.loops.back().getBody();
-  rewriter.eraseOp(forallOp.getBody()->getTerminator());
+
   rewriter.inlineBlockBefore(forallOp.getBody(), innermostBlock,
-                             innermostBlock->getTerminator()->getIterator(),
-                             ivs);
+                             innermostBlock->front().getIterator(), ivs);
   rewriter.eraseOp(forallOp);
 
   if (results) {
diff --git a/mlir/test/Dialect/SCF/forall-to-for.mlir b/mlir/test/Dialect/SCF/forall-to-for.mlir
index e7d183fb9d2b5..4d8390f0b62c4 100644
--- a/mlir/test/Dialect/SCF/forall-to-for.mlir
+++ b/mlir/test/Dialect/SCF/forall-to-for.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-forall-to-for))' -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-forall-to-for,canonicalize))' -split-input-file | FileCheck %s
 
 func.func private @callee(%i: index, %j: index)
 
@@ -55,3 +55,40 @@ func.func @nested(%ub1: index, %ub2: index, %ub3: index, %ub4: index) {
   }
   return
 }
+
+// -----
+
+func.func @nested_with_result() -> tensor<4x2xf32> {
+  %c2 = arith.constant 2 : index
+  %c4 = arith.constant 4 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<4x2xf32>
+  %res = scf.forall (%arg0, %arg1) in (%c4, %c2) shared_outs(%o = %0) -> (tensor<4x2xf32>) {
+    %1 = tensor.empty() : tensor<1x1xf32>
+    %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<1x1xf32>) -> tensor<1x1xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %2 into %o[%arg0, %arg1] [1, 1] [1, 1] :
+        tensor<1x1xf32> into tensor<4x2xf32>
+    }
+  }
+  return %res: tensor<4x2xf32>
+}
+
+// CHECK-LABEL: func.func @nested_with_result() -> tensor<4x2xf32> {
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[C4:.*]] = arith.constant 4 : index
+// CHECK: %[[FILL:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[REDUCED_RES:.*]] = tensor.empty() : tensor<4x2xf32>
+// CHECK: %[[OUTER:.*]] = scf.for %[[IV_OUTER:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[OUTER_RES:.*]] = %[[REDUCED_RES]]) -> (tensor<4x2xf32>) {
+// CHECK: %[[INNER:.*]] = scf.for %[[IV_INNER:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[INNER_RES:.*]] = %[[OUTER_RES]]) -> (tensor<4x2xf32>) {
+// CHECK: %[[ITERATION_TENS:.*]] = tensor.empty() : tensor<1x1xf32>
+// CHECK: %[[ITERATION_RES:.*]] = linalg.fill ins(%[[FILL]] : f32) outs(%[[ITERATION_TENS]] : tensor<1x1xf32>) -> tensor<1x1xf32>
+// CHECK: %[[UPDATED_RES:.*]] = tensor.insert_slice %[[ITERATION_RES]] into %[[INNER_RES]]{{\[}}%[[IV_OUTER]], %[[IV_INNER]]] [1, 1] [1, 1] : tensor<1x1xf32> into tensor<4x2xf32>
+// CHECK: scf.yield %[[UPDATED_RES]] : tensor<4x2xf32>
+// CHECK: }
+// CHECK: scf.yield %[[INNER]] : tensor<4x2xf32>
+// CHECK: }
+// CHECK: return %[[OUTER]] : tensor<4x2xf32>
+// CHECK: }
\ No newline at end of file
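
For context beyond the patch itself: the entry point changed here, `scf::forallToForLoop`, can also be driven programmatically. Below is a minimal sketch, not part of the patch, of invoking it from a rewrite pattern; the pattern name and wiring are hypothetical, and only the `forallToForLoop` declaration is assumed from `mlir/Dialect/SCF/Transforms/Transforms.h`.

```cpp
// Sketch only: illustrates one way the updated scf::forallToForLoop helper
// could be exercised from a rewrite pattern. The pattern name is hypothetical.
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"

namespace {
struct ForallToForPattern : public mlir::OpRewritePattern<mlir::scf::ForallOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::scf::ForallOp forallOp,
                  mlir::PatternRewriter &rewriter) const override {
    // The helper erases the scf.forall and materializes the scf.for nest,
    // threading the shared_outs through iter_args as added by this patch.
    return mlir::scf::forallToForLoop(rewriter, forallOp,
                                      /*results=*/nullptr);
  }
};
} // namespace
```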