Skip to content

Commit 79ae407

Browse files
authored
[ConstantFolding] Fold intrinsics of scalable vectors with splatted operands (llvm#141845)
As noted in llvm#141821 (comment), whilst we currently constant fold intrinsics of fixed-length vectors via their scalar counterpart, we don't do the same for scalable vectors. This handles the scalable vector case when the operands are splats. One weird snag in ConstantVector::getSplat was that it produced a undef if passed in poison, so this also contains a fix by checking for PoisonValue before UndefValue.
1 parent 8199f18 commit 79ae407

File tree

5 files changed

+49
-4
lines changed

5 files changed

+49
-4
lines changed

llvm/lib/Analysis/ConstantFolding.cpp

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3780,7 +3780,35 @@ static Constant *ConstantFoldScalableVectorCall(
37803780
default:
37813781
break;
37823782
}
3783-
return nullptr;
3783+
3784+
// If trivially vectorizable, try folding it via the scalar call if all
3785+
// operands are splats.
3786+
3787+
// TODO: ConstantFoldFixedVectorCall should probably check this too?
3788+
if (!isTriviallyVectorizable(IntrinsicID))
3789+
return nullptr;
3790+
3791+
SmallVector<Constant *, 4> SplatOps;
3792+
for (auto [I, Op] : enumerate(Operands)) {
3793+
if (isVectorIntrinsicWithScalarOpAtArg(IntrinsicID, I, /*TTI=*/nullptr)) {
3794+
SplatOps.push_back(Op);
3795+
continue;
3796+
}
3797+
// TODO: Should getSplatValue return a poison scalar for a poison vector?
3798+
if (isa<PoisonValue>(Op)) {
3799+
SplatOps.push_back(PoisonValue::get(Op->getType()->getScalarType()));
3800+
continue;
3801+
}
3802+
Constant *Splat = Op->getSplatValue();
3803+
if (!Splat)
3804+
return nullptr;
3805+
SplatOps.push_back(Splat);
3806+
}
3807+
Constant *Folded = ConstantFoldScalarCall(
3808+
Name, IntrinsicID, SVTy->getElementType(), SplatOps, TLI, Call);
3809+
if (!Folded)
3810+
return nullptr;
3811+
return ConstantVector::getSplat(SVTy->getElementCount(), Folded);
37843812
}
37853813

37863814
static std::pair<Constant *, Constant *>

llvm/lib/IR/Constants.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1507,7 +1507,9 @@ Constant *ConstantVector::getSplat(ElementCount EC, Constant *V) {
15071507

15081508
if (V->isNullValue())
15091509
return ConstantAggregateZero::get(VTy);
1510-
else if (isa<UndefValue>(V))
1510+
if (isa<PoisonValue>(V))
1511+
return PoisonValue::get(VTy);
1512+
if (isa<UndefValue>(V))
15111513
return UndefValue::get(VTy);
15121514

15131515
Type *IdxTy = Type::getInt64Ty(VTy->getContext());

llvm/test/Transforms/InstSimplify/ConstProp/abs.ll

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,11 @@ define <8 x i8> @vec_const() {
4343
%r = call <8 x i8> @llvm.abs.v8i8(<8 x i8> <i8 -127, i8 -126, i8 -42, i8 -1, i8 0, i8 1, i8 42, i8 127>, i1 1)
4444
ret <8 x i8> %r
4545
}
46+
47+
define <vscale x 1 x i8> @scalable_vec_const() {
48+
; CHECK-LABEL: @scalable_vec_const(
49+
; CHECK-NEXT: ret <vscale x 1 x i8> splat (i8 42)
50+
;
51+
%r = call <vscale x 1 x i8> @llvm.abs(<vscale x 1 x i8> splat (i8 -42), i1 1)
52+
ret <vscale x 1 x i8> %r
53+
}

llvm/test/Transforms/InstSimplify/ConstProp/fma.ll

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,14 @@ define double @PR20832() {
1616
ret double %1
1717
}
1818

19+
define <vscale x 1 x double> @scalable_vector() {
20+
; CHECK-LABEL: @scalable_vector(
21+
; CHECK-NEXT: ret <vscale x 1 x double> splat (double 5.600000e+01)
22+
;
23+
%1 = call <vscale x 1 x double> @llvm.fma(<vscale x 1 x double> splat (double 7.0), <vscale x 1 x double> splat (double 8.0), <vscale x 1 x double> splat (double 0.0))
24+
ret <vscale x 1 x double> %1
25+
}
26+
1927
; Test builtin fma with all finite non-zero constants.
2028
define double @test_all_finite() {
2129
; CHECK-LABEL: @test_all_finite(

llvm/test/Transforms/InstSimplify/exp10.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,7 @@ define <2 x float> @exp10_zero_vector() {
109109

110110
define <vscale x 2 x float> @exp10_zero_scalable_vector() {
111111
; CHECK-LABEL: define <vscale x 2 x float> @exp10_zero_scalable_vector() {
112-
; CHECK-NEXT: [[RET:%.*]] = call <vscale x 2 x float> @llvm.exp10.nxv2f32(<vscale x 2 x float> zeroinitializer)
113-
; CHECK-NEXT: ret <vscale x 2 x float> [[RET]]
112+
; CHECK-NEXT: ret <vscale x 2 x float> splat (float 1.000000e+00)
114113
;
115114
%ret = call <vscale x 2 x float> @llvm.exp10.nxv2f32(<vscale x 2 x float> zeroinitializer)
116115
ret <vscale x 2 x float> %ret

0 commit comments

Comments
 (0)