[ConstantFolding] Fold intrinsics of scalable vectors with splatted operands #141845
@llvm/pr-subscribers-llvm-analysis @llvm/pr-subscribers-llvm-transforms
Author: Luke Lau (lukel97)
Changes
As noted in #141821 (comment), whilst we currently constant fold intrinsics of fixed-length vectors via their scalar counterpart, we don't do the same for scalable vectors.
This handles the scalable vector case when the operands are splats.
One weird snag in ConstantVector::getSplat was that it produced an undef if passed in poison, so this also contains a fix by checking for PoisonValue before UndefValue.
Full diff: https://github.com/llvm/llvm-project/pull/141845.diff
5 Files Affected:
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index 412a0e8979193..40302fbc8ee52 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -3780,7 +3780,35 @@ static Constant *ConstantFoldScalableVectorCall(
   default:
     break;
   }
-  return nullptr;
+
+  // If trivially vectorizable, try folding it via the scalar call if all
+  // operands are splats.
+
+  // TODO: ConstantFoldFixedVectorCall should probably check this too?
+  if (!isTriviallyVectorizable(IntrinsicID))
+    return nullptr;
+
+  SmallVector<Constant *, 4> SplatOps;
+  for (auto [I, Op] : enumerate(Operands)) {
+    if (isVectorIntrinsicWithScalarOpAtArg(IntrinsicID, I, /*TTI=*/nullptr)) {
+      SplatOps.push_back(Op);
+      continue;
+    }
+    // TODO: Should getSplatValue return a poison scalar for a poison vector?
+    if (isa<PoisonValue>(Op)) {
+      SplatOps.push_back(PoisonValue::get(Op->getType()->getScalarType()));
+      continue;
+    }
+    Constant *Splat = Op->getSplatValue();
+    if (!Splat)
+      return nullptr;
+    SplatOps.push_back(Splat);
+  }
+  Constant *Folded = ConstantFoldScalarCall(
+      Name, IntrinsicID, SVTy->getElementType(), SplatOps, TLI, Call);
+  if (!Folded)
+    return nullptr;
+  return ConstantVector::getSplat(SVTy->getElementCount(), Folded);
 }
 static std::pair<Constant *, Constant *>
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index b2087d3651143..fa453309b34ee 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -1507,7 +1507,9 @@ Constant *ConstantVector::getSplat(ElementCount EC, Constant *V) {
   if (V->isNullValue())
     return ConstantAggregateZero::get(VTy);
-  else if (isa<UndefValue>(V))
+  if (isa<PoisonValue>(V))
+    return PoisonValue::get(VTy);
+  if (isa<UndefValue>(V))
     return UndefValue::get(VTy);
   Type *IdxTy = Type::getInt64Ty(VTy->getContext());
diff --git a/llvm/test/Transforms/InstSimplify/ConstProp/abs.ll b/llvm/test/Transforms/InstSimplify/ConstProp/abs.ll
index 37233b0f29342..615ab10248b2a 100644
--- a/llvm/test/Transforms/InstSimplify/ConstProp/abs.ll
+++ b/llvm/test/Transforms/InstSimplify/ConstProp/abs.ll
@@ -43,3 +43,11 @@ define <8 x i8> @vec_const() {
   %r = call <8 x i8> @llvm.abs.v8i8(<8 x i8> <i8 -127, i8 -126, i8 -42, i8 -1, i8 0, i8 1, i8 42, i8 127>, i1 1)
   ret <8 x i8> %r
 }
+
+define <vscale x 1 x i8> @scalable_vec_const() {
+; CHECK-LABEL: @scalable_vec_const(
+; CHECK-NEXT: ret <vscale x 1 x i8> splat (i8 42)
+;
+  %r = call <vscale x 1 x i8> @llvm.abs(<vscale x 1 x i8> splat (i8 -42), i1 1)
+  ret <vscale x 1 x i8> %r
+}
diff --git a/llvm/test/Transforms/InstSimplify/ConstProp/fma.ll b/llvm/test/Transforms/InstSimplify/ConstProp/fma.ll
index d3ade92a6db05..2f56c2df0ca8f 100644
--- a/llvm/test/Transforms/InstSimplify/ConstProp/fma.ll
+++ b/llvm/test/Transforms/InstSimplify/ConstProp/fma.ll
@@ -16,6 +16,14 @@ define double @PR20832() {
   ret double %1
 }
+define <vscale x 1 x double> @scalable_vector() {
+; CHECK-LABEL: @scalable_vector(
+; CHECK-NEXT: ret <vscale x 1 x double> splat (double 5.600000e+01)
+;
+  %1 = call <vscale x 1 x double> @llvm.fma(<vscale x 1 x double> splat (double 7.0), <vscale x 1 x double> splat (double 8.0), <vscale x 1 x double> splat (double 0.0))
+  ret <vscale x 1 x double> %1
+}
+
 ; Test builtin fma with all finite non-zero constants.
 define double @test_all_finite() {
 ; CHECK-LABEL: @test_all_finite(
diff --git a/llvm/test/Transforms/InstSimplify/exp10.ll b/llvm/test/Transforms/InstSimplify/exp10.ll
index a546bb1255d85..c415c419aad84 100644
--- a/llvm/test/Transforms/InstSimplify/exp10.ll
+++ b/llvm/test/Transforms/InstSimplify/exp10.ll
@@ -109,8 +109,7 @@ define <2 x float> @exp10_zero_vector() {
define <vscale x 2 x float> @exp10_zero_scalable_vector() {
; CHECK-LABEL: define <vscale x 2 x float> @exp10_zero_scalable_vector() {
-; CHECK-NEXT: [[RET:%.*]] = call <vscale x 2 x float> @llvm.exp10.nxv2f32(<vscale x 2 x float> zeroinitializer)
-; CHECK-NEXT: ret <vscale x 2 x float> [[RET]]
+; CHECK-NEXT: ret <vscale x 2 x float> splat (float 1.000000e+00)
;
%ret = call <vscale x 2 x float> @llvm.exp10.nxv2f32(<vscale x 2 x float> zeroinitializer)
ret <vscale x 2 x float> %ret
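The new tail of ConstantFoldScalableVectorCall and the getSplat fix can be modelled outside LLVM. What follows is a minimal, self-contained sketch and not the patch itself: `Kind`, `Scalar`, `ScalableVec`, `makeSplat`, `foldSplatCall` and `foldFma` are hypothetical names invented for illustration, and the model leaves out operands that stay scalar (the isVectorIntrinsicWithScalarOpAtArg case).

```cpp
#include <cmath>
#include <functional>
#include <optional>
#include <vector>

// Toy stand-ins for LLVM constants. These types are hypothetical and only
// model the cases the patch distinguishes.
enum class Kind { Value, Undef, Poison };

struct Scalar {
  Kind kind = Kind::Value;
  double value = 0.0;
  bool isPoison() const { return kind == Kind::Poison; }
  // Mirrors isa<UndefValue>, which also matches poison because PoisonValue
  // derives from UndefValue in LLVM.
  bool isUndefOrPoison() const { return kind != Kind::Value; }
};

// A scalable vector constant. Its length is unknown at compile time, so the
// model only distinguishes splats from everything else.
struct ScalableVec {
  Kind kind = Kind::Value;
  std::optional<Scalar> splat; // set when every lane holds the same value
};

// Models ConstantVector::getSplat after the fix: poison is tested first,
// otherwise the broader undef check would turn it into undef.
ScalableVec makeSplat(const Scalar &s) {
  if (s.isPoison())
    return {Kind::Poison, std::nullopt};
  if (s.isUndefOrPoison())
    return {Kind::Undef, std::nullopt};
  return {Kind::Value, s};
}

using ScalarFold =
    std::function<std::optional<Scalar>(const std::vector<Scalar> &)>;

// Models the added code path: reduce every operand to its splat scalar,
// fold once on scalars, then splat the result back to the vector type.
// Simplification of this sketch: undef vectors are treated as non-splats.
std::optional<ScalableVec> foldSplatCall(const std::vector<ScalableVec> &ops,
                                         const ScalarFold &foldScalar) {
  std::vector<Scalar> scalars;
  for (const ScalableVec &op : ops) {
    if (op.kind == Kind::Poison) { // the explicit PoisonValue special case
      scalars.push_back({Kind::Poison, 0.0});
      continue;
    }
    if (!op.splat)
      return std::nullopt; // not a splat: give up, as the patch does
    scalars.push_back(*op.splat);
  }
  std::optional<Scalar> folded = foldScalar(scalars);
  if (!folded)
    return std::nullopt;
  return makeSplat(*folded);
}

// Hypothetical scalar folder standing in for ConstantFoldScalarCall on fma.
// It propagates poison and declines to fold anything else that is not a
// plain value.
std::optional<Scalar> foldFma(const std::vector<Scalar> &ops) {
  if (ops.size() != 3)
    return std::nullopt;
  for (const Scalar &s : ops) {
    if (s.isPoison())
      return Scalar{Kind::Poison, 0.0};
    if (s.kind != Kind::Value)
      return std::nullopt;
  }
  return Scalar{Kind::Value, std::fma(ops[0].value, ops[1].value, ops[2].value)};
}
```

With splat operands 7.0, 8.0 and 0.0 the model follows the same shape as the fma.ll test above: the call is folded once on scalars and the scalar result is splatted again. A non-splat operand makes the fold bail out, and a poison operand yields a poison vector rather than an undef one.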
LGTM
SplatOps.push_back(Op);
continue;
}
// TODO: Should getSplatValue return a poison scalar for a poison vector?
It should...
Will do this in a follow up, seems to be NFC?
This is a follow up from llvm#141845. I'm not sure if this is actually NFC, but it doesn't seem to affect any of the in-tree tests. I went through the users of getSplatValue to see if anything could be cleaned up, but nothing immediately stuck out.
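To make the TODO under discussion concrete, here is a small sketch of what the operand loop could look like if getSplatValue reported a poison scalar for a poison vector. This is an assumption about the follow-up based on the TODO text, not a copy of it, and `ToyKind`, `ToyScalar`, `ToyVector`, `splatValue` and `collectSplats` are hypothetical names that do not exist in LLVM.

```cpp
#include <optional>
#include <vector>

// Hypothetical toy types; they are not LLVM classes.
enum class ToyKind { Value, Poison };

struct ToyScalar {
  ToyKind kind;
  double value;
};

struct ToyVector {
  ToyKind kind;
  std::optional<ToyScalar> splat; // set only for non-poison splats
};

// Assumed behaviour: a poison vector reports a poison scalar as its splat
// value instead of reporting "not a splat".
std::optional<ToyScalar> splatValue(const ToyVector &v) {
  if (v.kind == ToyKind::Poison)
    return ToyScalar{ToyKind::Poison, 0.0};
  return v.splat;
}

// Under that assumption the operand loop needs no separate poison branch.
std::optional<std::vector<ToyScalar>>
collectSplats(const std::vector<ToyVector> &ops) {
  std::vector<ToyScalar> out;
  for (const ToyVector &op : ops) {
    std::optional<ToyScalar> s = splatValue(op);
    if (!s)
      return std::nullopt; // a non-splat operand still blocks the fold
    out.push_back(*s);
  }
  return out;
}
```

The design point is that the poison handling moves into one helper, so callers that only ask "is this a splat, and of what?" get a consistent answer for poison vectors.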
…#141870) This is a follow up from llvm#141845. TargetTransformInfo::getOperandInfo needs to be updated to check for undef values as otherwise a splat is considered a constant, and some RISC-V cost model tests will start adding a cost to materialize the constant.