diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp index 412a0e8979193..40302fbc8ee52 100644 --- a/llvm/lib/Analysis/ConstantFolding.cpp +++ b/llvm/lib/Analysis/ConstantFolding.cpp @@ -3780,7 +3780,35 @@ static Constant *ConstantFoldScalableVectorCall( default: break; } - return nullptr; + + // If trivially vectorizable, try folding it via the scalar call if all + // operands are splats. + + // TODO: ConstantFoldFixedVectorCall should probably check this too? + if (!isTriviallyVectorizable(IntrinsicID)) + return nullptr; + + SmallVector SplatOps; + for (auto [I, Op] : enumerate(Operands)) { + if (isVectorIntrinsicWithScalarOpAtArg(IntrinsicID, I, /*TTI=*/nullptr)) { + SplatOps.push_back(Op); + continue; + } + // TODO: Should getSplatValue return a poison scalar for a poison vector? + if (isa(Op)) { + SplatOps.push_back(PoisonValue::get(Op->getType()->getScalarType())); + continue; + } + Constant *Splat = Op->getSplatValue(); + if (!Splat) + return nullptr; + SplatOps.push_back(Splat); + } + Constant *Folded = ConstantFoldScalarCall( + Name, IntrinsicID, SVTy->getElementType(), SplatOps, TLI, Call); + if (!Folded) + return nullptr; + return ConstantVector::getSplat(SVTy->getElementCount(), Folded); } static std::pair diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp index b2087d3651143..fa453309b34ee 100644 --- a/llvm/lib/IR/Constants.cpp +++ b/llvm/lib/IR/Constants.cpp @@ -1507,7 +1507,9 @@ Constant *ConstantVector::getSplat(ElementCount EC, Constant *V) { if (V->isNullValue()) return ConstantAggregateZero::get(VTy); - else if (isa(V)) + if (isa(V)) + return PoisonValue::get(VTy); + if (isa(V)) return UndefValue::get(VTy); Type *IdxTy = Type::getInt64Ty(VTy->getContext()); diff --git a/llvm/test/Transforms/InstSimplify/ConstProp/abs.ll b/llvm/test/Transforms/InstSimplify/ConstProp/abs.ll index 37233b0f29342..615ab10248b2a 100644 --- a/llvm/test/Transforms/InstSimplify/ConstProp/abs.ll +++ b/llvm/test/Transforms/InstSimplify/ConstProp/abs.ll @@ -43,3 +43,11 @@ define <8 x i8> @vec_const() { %r = call <8 x i8> @llvm.abs.v8i8(<8 x i8> , i1 1) ret <8 x i8> %r } + +define @scalable_vec_const() { +; CHECK-LABEL: @scalable_vec_const( +; CHECK-NEXT: ret splat (i8 42) +; + %r = call @llvm.abs( splat (i8 -42), i1 1) + ret %r +} diff --git a/llvm/test/Transforms/InstSimplify/ConstProp/fma.ll b/llvm/test/Transforms/InstSimplify/ConstProp/fma.ll index d3ade92a6db05..2f56c2df0ca8f 100644 --- a/llvm/test/Transforms/InstSimplify/ConstProp/fma.ll +++ b/llvm/test/Transforms/InstSimplify/ConstProp/fma.ll @@ -16,6 +16,14 @@ define double @PR20832() { ret double %1 } +define @scalable_vector() { +; CHECK-LABEL: @scalable_vector( +; CHECK-NEXT: ret splat (double 5.600000e+01) +; + %1 = call @llvm.fma( splat (double 7.0), splat (double 8.0), splat (double 0.0)) + ret %1 +} + ; Test builtin fma with all finite non-zero constants. define double @test_all_finite() { ; CHECK-LABEL: @test_all_finite( diff --git a/llvm/test/Transforms/InstSimplify/exp10.ll b/llvm/test/Transforms/InstSimplify/exp10.ll index a546bb1255d85..c415c419aad84 100644 --- a/llvm/test/Transforms/InstSimplify/exp10.ll +++ b/llvm/test/Transforms/InstSimplify/exp10.ll @@ -109,8 +109,7 @@ define <2 x float> @exp10_zero_vector() { define @exp10_zero_scalable_vector() { ; CHECK-LABEL: define @exp10_zero_scalable_vector() { -; CHECK-NEXT: [[RET:%.*]] = call @llvm.exp10.nxv2f32( zeroinitializer) -; CHECK-NEXT: ret [[RET]] +; CHECK-NEXT: ret splat (float 1.000000e+00) ; %ret = call @llvm.exp10.nxv2f32( zeroinitializer) ret %ret