diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 05fd87ed5807e..5eb94f9a57dde 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7270,6 +7270,13 @@ static void fixReductionScalarResumeWhenVectorizingEpilog(
   // created a bc.merge.rdx Phi after the main vector body. Ensure that we carry
   // over the incoming values correctly.
   using namespace VPlanPatternMatch;
+  if (EpiRedResult->getNumUsers() == 1 &&
+      isa<VPInstruction>(*EpiRedResult->user_begin())) {
+    EpiRedResult = cast<VPInstruction>(*EpiRedResult->user_begin());
+    assert((EpiRedResult->getOpcode() == Instruction::SExt ||
+            EpiRedResult->getOpcode() == Instruction::ZExt) &&
+           "can only have SExt/ZExt users");
+  }
   assert(count_if(EpiRedResult->users(), IsaPred) == 1 &&
          "ResumePhi must have a single user");
   auto *EpiResumePhiVPI =
@@ -9204,28 +9211,6 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
       PhiR->setOperand(1, NewExitingVPV);
     }
 
-    // If the vector reduction can be performed in a smaller type, we truncate
-    // then extend the loop exit value to enable InstCombine to evaluate the
-    // entire expression in the smaller type.
-    if (MinVF.isVector() && PhiTy != RdxDesc.getRecurrenceType() &&
-        !RecurrenceDescriptor::isAnyOfRecurrenceKind(
-            RdxDesc.getRecurrenceKind())) {
-      assert(!PhiR->isInLoop() && "Unexpected truncated inloop reduction!");
-      Type *RdxTy = RdxDesc.getRecurrenceType();
-      auto *Trunc =
-          new VPWidenCastRecipe(Instruction::Trunc, NewExitingVPV, RdxTy);
-      auto *Extnd =
-          RdxDesc.isSigned()
-              ? new VPWidenCastRecipe(Instruction::SExt, Trunc, PhiTy)
-              : new VPWidenCastRecipe(Instruction::ZExt, Trunc, PhiTy);
-
-      Trunc->insertAfter(NewExitingVPV->getDefiningRecipe());
-      Extnd->insertAfter(Trunc);
-      if (PhiR->getOperand(1) == NewExitingVPV)
-        PhiR->setOperand(1, Extnd->getVPSingleValue());
-      NewExitingVPV = Extnd;
-    }
-
     // We want code in the middle block to appear to execute on the location of
     // the scalar loop's latch terminator because: (a) it is all compiler
     // generated, (b) these instructions are always executed after evaluating
@@ -9263,6 +9248,31 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
           Builder.createNaryOp(VPInstruction::ComputeReductionResult,
                                {PhiR, NewExitingVPV}, Flags, ExitDL);
     }
+    // If the vector reduction can be performed in a smaller type, we truncate
+    // then extend the loop exit value to enable InstCombine to evaluate the
+    // entire expression in the smaller type.
+    if (MinVF.isVector() && PhiTy != RdxDesc.getRecurrenceType() &&
+        !RecurrenceDescriptor::isAnyOfRecurrenceKind(
+            RdxDesc.getRecurrenceKind())) {
+      assert(!PhiR->isInLoop() && "Unexpected truncated inloop reduction!");
+      Type *RdxTy = RdxDesc.getRecurrenceType();
+      auto *Trunc =
+          new VPWidenCastRecipe(Instruction::Trunc, NewExitingVPV, RdxTy);
+      Instruction::CastOps ExtendOpc =
+          RdxDesc.isSigned() ? Instruction::SExt : Instruction::ZExt;
+      auto *Extnd = new VPWidenCastRecipe(ExtendOpc, Trunc, PhiTy);
+      Trunc->insertAfter(NewExitingVPV->getDefiningRecipe());
+      Extnd->insertAfter(Trunc);
+      if (PhiR->getOperand(1) == NewExitingVPV)
+        PhiR->setOperand(1, Extnd->getVPSingleValue());
+
+      // Update ComputeReductionResult with the truncated exiting value and
+      // extend its result.
+      FinalReductionResult->setOperand(1, Trunc);
+      FinalReductionResult =
+          Builder.createScalarCast(ExtendOpc, FinalReductionResult, PhiTy, {});
+    }
+
     // Update all users outside the vector region.
     OrigExitingVPV->replaceUsesWithIf(
         FinalReductionResult, [FinalReductionResult](VPUser &User, unsigned) {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 2aa5dd1b48c00..2604a5f4971f2 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -647,7 +647,6 @@ Value *VPInstruction::generate(VPTransformState &State) {
     // FIXME: The cross-recipe dependency on VPReductionPHIRecipe is temporary
    // and will be removed by breaking up the recipe further.
     auto *PhiR = cast<VPReductionPHIRecipe>(getOperand(0));
-    auto *OrigPhi = cast<PHINode>(PhiR->getUnderlyingValue());
     // Get its reduction variable descriptor.
     const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();
 
@@ -655,7 +654,6 @@ Value *VPInstruction::generate(VPTransformState &State) {
     assert(!RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK) &&
            "should be handled by ComputeFindLastIVResult");
 
-    Type *PhiTy = OrigPhi->getType();
     // The recipe's operands are the reduction phi, followed by one operand for
     // each part of the reduction.
     unsigned UF = getNumOperands() - 1;
@@ -667,15 +665,6 @@ Value *VPInstruction::generate(VPTransformState &State) {
     if (hasFastMathFlags())
       Builder.setFastMathFlags(getFastMathFlags());
 
-    // If the vector reduction can be performed in a smaller type, we truncate
-    // then extend the loop exit value to enable InstCombine to evaluate the
-    // entire expression in the smaller type.
-    // TODO: Handle this in truncateToMinBW.
-    if (State.VF.isVector() && PhiTy != RdxDesc.getRecurrenceType()) {
-      Type *RdxVecTy = VectorType::get(RdxDesc.getRecurrenceType(), State.VF);
-      for (unsigned Part = 0; Part < UF; ++Part)
-        RdxParts[Part] = Builder.CreateTrunc(RdxParts[Part], RdxVecTy);
-    }
     // Reduce all of the unrolled parts into a single vector.
     Value *ReducedPartRdx = RdxParts[0];
     if (PhiR->isOrdered()) {
@@ -699,14 +688,13 @@ Value *VPInstruction::generate(VPTransformState &State) {
       // TODO: Support in-order reductions based on the recurrence descriptor.
       // All ops in the reduction inherit fast-math-flags from the recurrence
       // descriptor.
-      ReducedPartRdx = createSimpleReduction(Builder, ReducedPartRdx, RK);
-
-      // If the reduction can be performed in a smaller type, we need to extend
-      // the reduction to the wider type before we branch to the original loop.
-      if (PhiTy != RdxDesc.getRecurrenceType())
-        ReducedPartRdx = RdxDesc.isSigned()
-                             ? Builder.CreateSExt(ReducedPartRdx, PhiTy)
-                             : Builder.CreateZExt(ReducedPartRdx, PhiTy);
+      if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK)) {
+        auto *OrigPhi = cast<PHINode>(PhiR->getUnderlyingValue());
+        ReducedPartRdx =
+            createAnyOfReduction(Builder, ReducedPartRdx,
+                                 RdxDesc.getRecurrenceStartValue(), OrigPhi);
+      } else
+        ReducedPartRdx = createSimpleReduction(Builder, ReducedPartRdx, RK);
     }
     return ReducedPartRdx;
   }
@@ -1052,6 +1040,7 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
 void VPInstructionWithType::execute(VPTransformState &State) {
   State.setDebugLocFrom(getDebugLoc());
   switch (getOpcode()) {
+  case Instruction::SExt:
   case Instruction::ZExt:
   case Instruction::Trunc: {
     Value *Op = State.get(getOperand(0), VPLane(0));
diff --git a/llvm/test/Transforms/LoopVectorize/X86/cost-model.ll b/llvm/test/Transforms/LoopVectorize/X86/cost-model.ll
index 7c42c3d9cd52e..a6ac9c2886a92 100644
--- a/llvm/test/Transforms/LoopVectorize/X86/cost-model.ll
+++ b/llvm/test/Transforms/LoopVectorize/X86/cost-model.ll
@@ -1167,8 +1167,7 @@ define i32 @narrowed_reduction(ptr %a, i1 %cmp) #0 {
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 16
 ; CHECK-NEXT:    br i1 true, label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP28:![0-9]+]]
 ; CHECK:       middle.block:
-; CHECK-NEXT:    [[TMP10:%.*]] = trunc <16 x i32> [[TMP7]] to <16 x i1>
-; CHECK-NEXT:    [[TMP20:%.*]] = call i1 @llvm.vector.reduce.or.v16i1(<16 x i1> [[TMP10]])
+; CHECK-NEXT:    [[TMP20:%.*]] = call i1 @llvm.vector.reduce.or.v16i1(<16 x i1> [[TMP5]])
 ; CHECK-NEXT:    [[TMP21:%.*]] = zext i1 [[TMP20]] to i32
 ; CHECK-NEXT:    br i1 true, label [[EXIT:%.*]], label [[VEC_EPILOG_PH]]
 ; CHECK:       scalar.ph:
diff --git a/llvm/test/Transforms/LoopVectorize/epilog-vectorization-reductions.ll b/llvm/test/Transforms/LoopVectorize/epilog-vectorization-reductions.ll
index 0a2bb8d5682f2..c101d6a19aa2e 100644
--- a/llvm/test/Transforms/LoopVectorize/epilog-vectorization-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/epilog-vectorization-reductions.ll
@@ -208,8 +208,7 @@ define i16 @reduction_or_trunc(ptr noalias nocapture %ptr) {
 ; CHECK-NEXT:    [[TMP8:%.*]] = icmp eq i32 [[INDEX_NEXT]], 256
 ; CHECK-NEXT:    br i1 [[TMP8]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
 ; CHECK:       middle.block:
-; CHECK-NEXT:    [[TMP9:%.*]] = trunc <4 x i32> [[TMP7]] to <4 x i16>
-; CHECK-NEXT:    [[TMP10:%.*]] = call i16 @llvm.vector.reduce.or.v4i16(<4 x i16> [[TMP9]])
+; CHECK-NEXT:    [[TMP10:%.*]] = call i16 @llvm.vector.reduce.or.v4i16(<4 x i16> [[TMP6]])
 ; CHECK-NEXT:    [[TMP11:%.*]] = zext i16 [[TMP10]] to i32
 ; CHECK-NEXT:    br i1 true, label [[FOR_END:%.*]], label [[VEC_EPILOG_ITER_CHECK:%.*]]
 ; CHECK:       vec.epilog.iter.check:
@@ -234,8 +233,7 @@ define i16 @reduction_or_trunc(ptr noalias nocapture %ptr) {
 ; CHECK-NEXT:    [[TMP21:%.*]] = icmp eq i32 [[INDEX_NEXT4]], 256
 ; CHECK-NEXT:    br i1 [[TMP21]], label [[VEC_EPILOG_MIDDLE_BLOCK:%.*]], label [[VEC_EPILOG_VECTOR_BODY]], !llvm.loop [[LOOP9:![0-9]+]]
 ; CHECK:       vec.epilog.middle.block:
-; CHECK-NEXT:    [[TMP22:%.*]] = trunc <4 x i32> [[TMP20]] to <4 x i16>
-; CHECK-NEXT:    [[TMP23:%.*]] = call i16 @llvm.vector.reduce.or.v4i16(<4 x i16> [[TMP22]])
+; CHECK-NEXT:    [[TMP23:%.*]] = call i16 @llvm.vector.reduce.or.v4i16(<4 x i16> [[TMP19]])
 ; CHECK-NEXT:    [[TMP24:%.*]] = zext i16 [[TMP23]] to i32
 ; CHECK-NEXT:    br i1 true, label [[FOR_END]], label [[VEC_EPILOG_SCALAR_PH]]
 ; CHECK:       vec.epilog.scalar.ph:
diff --git a/llvm/test/Transforms/LoopVectorize/reduction-small-size.ll b/llvm/test/Transforms/LoopVectorize/reduction-small-size.ll
index 796c1d116aa19..13cc1b657d231 100644
--- a/llvm/test/Transforms/LoopVectorize/reduction-small-size.ll
+++ b/llvm/test/Transforms/LoopVectorize/reduction-small-size.ll
@@ -25,8 +25,7 @@ define i8 @PR34687(i1 %c, i32 %x, i32 %n) {
 ; CHECK-NEXT:    [[TMP5:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[TMP5]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
 ; CHECK:       middle.block:
-; CHECK-NEXT:    [[TMP6:%.*]] = trunc <4 x i32> [[TMP4]] to <4 x i8>
-; CHECK-NEXT:    [[TMP7:%.*]] = call i8 @llvm.vector.reduce.add.v4i8(<4 x i8> [[TMP6]])
+; CHECK-NEXT:    [[TMP7:%.*]] = call i8 @llvm.vector.reduce.add.v4i8(<4 x i8> [[TMP3]])
 ; CHECK-NEXT:    [[TMP8:%.*]] = zext i8 [[TMP7]] to i32
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i32 [[N]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_END:%.*]], label [[SCALAR_PH]]
@@ -104,8 +103,7 @@ define i8 @PR34687_no_undef(i1 %c, i32 %x, i32 %n) {
 ; CHECK-NEXT:    [[TMP7:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[TMP7]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
 ; CHECK:       middle.block:
-; CHECK-NEXT:    [[TMP8:%.*]] = trunc <4 x i32> [[TMP6]] to <4 x i8>
-; CHECK-NEXT:    [[TMP9:%.*]] = call i8 @llvm.vector.reduce.add.v4i8(<4 x i8> [[TMP8]])
+; CHECK-NEXT:    [[TMP9:%.*]] = call i8 @llvm.vector.reduce.add.v4i8(<4 x i8> [[TMP5]])
 ; CHECK-NEXT:    [[TMP10:%.*]] = zext i8 [[TMP9]] to i32
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i32 [[N]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_END:%.*]], label [[SCALAR_PH]]
@@ -183,8 +181,7 @@ define i32 @PR35734(i32 %x, i32 %y) {
 ; CHECK-NEXT:    [[TMP7:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[TMP7]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP6:![0-9]+]]
 ; CHECK:       middle.block:
-; CHECK-NEXT:    [[TMP8:%.*]] = trunc <4 x i32> [[TMP6]] to <4 x i1>
-; CHECK-NEXT:    [[TMP9:%.*]] = call i1 @llvm.vector.reduce.add.v4i1(<4 x i1> [[TMP8]])
+; CHECK-NEXT:    [[TMP9:%.*]] = call i1 @llvm.vector.reduce.add.v4i1(<4 x i1> [[TMP5]])
 ; CHECK-NEXT:    [[TMP10:%.*]] = sext i1 [[TMP9]] to i32
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i32 [[TMP1]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_END:%.*]], label [[SCALAR_PH]]
diff --git a/llvm/test/Transforms/LoopVectorize/scalable-reduction-inloop.ll b/llvm/test/Transforms/LoopVectorize/scalable-reduction-inloop.ll
index 223acfa2e3a25..6e251cb7b0ad3 100644
--- a/llvm/test/Transforms/LoopVectorize/scalable-reduction-inloop.ll
+++ b/llvm/test/Transforms/LoopVectorize/scalable-reduction-inloop.ll
@@ -28,9 +28,7 @@ define i8 @reduction_add_trunc(ptr noalias nocapture %A) {
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], [[TMP31]]
 ; CHECK-NEXT:    [[TMP32:%.*]] = icmp eq i32 [[INDEX_NEXT]], {{%.*}}
 ; CHECK:       middle.block:
-; CHECK-NEXT:    [[TMP37:%.*]] = trunc <vscale x 8 x i32> [[TMP34]] to <vscale x 8 x i8>
-; CHECK-NEXT:    [[TMP38:%.*]] = trunc <vscale x 8 x i32> [[TMP36]] to <vscale x 8 x i8>
-; CHECK-NEXT:    [[BIN_RDX:%.*]] = add <vscale x 8 x i8> [[TMP38]], [[TMP37]]
+; CHECK-NEXT:    [[BIN_RDX:%.*]] = add <vscale x 8 x i8> [[TMP35]], [[TMP33]]
 ; CHECK-NEXT:    [[TMP39:%.*]] = call i8 @llvm.vector.reduce.add.nxv8i8(<vscale x 8 x i8> [[BIN_RDX]])
 ; CHECK-NEXT:    [[TMP40:%.*]] = zext i8 [[TMP39]] to i32
 ; CHECK-NEXT:    %cmp.n = icmp eq i32 256, %n.vec
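
Illustrative sketch (not part of the patch): hand-written LLVM IR for the middle block of a narrowed i8 add reduction, based on the CHECK-line changes above; %wide.exit and %narrow.exit are placeholder names. Previously the i32 exiting vector was truncated in the middle block before being reduced:

  %narrow = trunc <4 x i32> %wide.exit to <4 x i8>
  %red = call i8 @llvm.vector.reduce.add.v4i8(<4 x i8> %narrow)
  %res = zext i8 %red to i32

With the truncate modeled as a VPWidenCastRecipe inside the vector loop, the exiting value is already <4 x i8>, so the middle block only reduces and then extends the scalar result:

  %red = call i8 @llvm.vector.reduce.add.v4i8(<4 x i8> %narrow.exit)
  %res = zext i8 %red to i32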