-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[ConstantFolding] Add folding for [de]interleave2, insert and extract #141301
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[ConstantFolding] Add folding for [de]interleave2, insert and extract #141301
Conversation
@llvm/pr-subscribers-llvm-transforms @llvm/pr-subscribers-llvm-analysis Author: Nikolay Panchenko (npanchen) ChangesThe change adds folding for 4 vector intrinsics: Full diff: https://github.com/llvm/llvm-project/pull/141301.diff 2 Files Affected:
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index 412a0e8979193..d30f2fef69a54 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -1619,6 +1619,10 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) {
case Intrinsic::vector_reduce_smax:
case Intrinsic::vector_reduce_umin:
case Intrinsic::vector_reduce_umax:
+ case Intrinsic::vector_extract:
+ case Intrinsic::vector_insert:
+ case Intrinsic::vector_interleave2:
+ case Intrinsic::vector_deinterleave2:
// Target intrinsics
case Intrinsic::amdgcn_perm:
case Intrinsic::amdgcn_wave_reduce_umin:
@@ -3734,6 +3738,65 @@ static Constant *ConstantFoldFixedVectorCall(
}
return nullptr;
}
+ case Intrinsic::vector_extract: {
+ auto *Vec = dyn_cast<Constant>(Operands[0]);
+ auto *Idx = dyn_cast<ConstantInt>(Operands[1]);
+ if (!Vec || !Idx)
+ return nullptr;
+
+ unsigned NumElements = FVTy->getNumElements();
+ unsigned VecNumElements =
+ cast<FixedVectorType>(Vec->getType())->getNumElements();
+ // Extracting entire vector is nop
+ if (NumElements == VecNumElements)
+ return Vec;
+
+ unsigned StartingIndex = Idx->getZExtValue();
+ assert(StartingIndex + NumElements <= VecNumElements &&
+ "Cannot extract more elements than exist in the vector");
+ for (unsigned I = 0; I != NumElements; ++I)
+ Result[I] = Vec->getAggregateElement(StartingIndex + I);
+ return ConstantVector::get(Result);
+ }
+ case Intrinsic::vector_insert: {
+ auto *Vec = dyn_cast<Constant>(Operands[0]);
+ auto *SubVec = dyn_cast<Constant>(Operands[1]);
+ auto *Idx = dyn_cast<ConstantInt>(Operands[2]);
+ if (!Vec || !SubVec || !Idx)
+ return nullptr;
+
+ unsigned SubVecNumElements =
+ cast<FixedVectorType>(SubVec->getType())->getNumElements();
+ unsigned VecNumElements =
+ cast<FixedVectorType>(Vec->getType())->getNumElements();
+ unsigned IdxN = Idx->getZExtValue();
+ // Replacing entire vector with a subvec is nop
+ if (SubVecNumElements == VecNumElements)
+ return SubVec;
+
+ unsigned I = 0;
+ for (; I < IdxN; ++I)
+ Result[I] = Vec->getAggregateElement(I);
+ for (; I < IdxN + SubVecNumElements; ++I)
+ Result[I] = SubVec->getAggregateElement(I - IdxN);
+ for (; I < VecNumElements; ++I)
+ Result[I] = Vec->getAggregateElement(I);
+ return ConstantVector::get(Result);
+ }
+ case Intrinsic::vector_interleave2: {
+ auto *Vec0 = dyn_cast<Constant>(Operands[0]);
+ auto *Vec1 = dyn_cast<Constant>(Operands[1]);
+ if (!Vec0 || !Vec1)
+ return nullptr;
+
+ unsigned NumElements =
+ cast<FixedVectorType>(Vec0->getType())->getNumElements();
+ for (unsigned I = 0; I < NumElements; ++I) {
+ Result[2 * I] = Vec0->getAggregateElement(I);
+ Result[2 * I + 1] = Vec1->getAggregateElement(I);
+ }
+ return ConstantVector::get(Result);
+ }
default:
break;
}
@@ -3872,6 +3935,21 @@ ConstantFoldStructCall(StringRef Name, Intrinsic::ID IntrinsicID,
return nullptr;
return ConstantStruct::get(StTy, SinResult, CosResult);
}
+ case Intrinsic::vector_deinterleave2: {
+ auto *Vec = dyn_cast<Constant>(Operands[0]);
+ if (!Vec)
+ return nullptr;
+
+ unsigned NumElements =
+ cast<FixedVectorType>(Vec->getType())->getNumElements() / 2;
+ SmallVector<Constant *, 4> Res0(NumElements), Res1(NumElements);
+ for (unsigned I = 0; I < NumElements; ++I) {
+ Res0[I] = Vec->getAggregateElement(2 * I);
+ Res1[I] = Vec->getAggregateElement(2 * I + 1);
+ }
+ return ConstantStruct::get(StTy, ConstantVector::get(Res0),
+ ConstantVector::get(Res1));
+ }
default:
// TODO: Constant folding of vector intrinsics that fall through here does
// not work (e.g. overflow intrinsics)
diff --git a/llvm/test/Transforms/InstSimplify/ConstProp/vector-calls.ll b/llvm/test/Transforms/InstSimplify/ConstProp/vector-calls.ll
new file mode 100644
index 0000000000000..f0bf610fa52aa
--- /dev/null
+++ b/llvm/test/Transforms/InstSimplify/ConstProp/vector-calls.ll
@@ -0,0 +1,50 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes=instsimplify,verify -S | FileCheck %s
+
+define <3 x i32> @fold_vector_extract() {
+; CHECK-LABEL: define <3 x i32> @fold_vector_extract() {
+; CHECK-NEXT: ret <3 x i32> <i32 3, i32 4, i32 5>
+;
+ %1 = call <3 x i32> @llvm.vector.extract.v3i32.v8i32(<8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>, i64 3)
+ ret <3 x i32> %1
+}
+
+define <8 x i32> @fold_vector_extract_nop() {
+; CHECK-LABEL: define <8 x i32> @fold_vector_extract_nop() {
+; CHECK-NEXT: ret <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+;
+ %1 = call <8 x i32> @llvm.vector.extract.v3i32.v8i32(<8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>, i64 0)
+ ret <8 x i32> %1
+}
+
+define <8 x i32> @fold_vector_insert() {
+; CHECK-LABEL: define <8 x i32> @fold_vector_insert() {
+; CHECK-NEXT: ret <8 x i32> <i32 9, i32 10, i32 11, i32 12, i32 5, i32 6, i32 7, i32 8>
+;
+ %1 = call <8 x i32> @llvm.vector.insert.v8i32(<8 x i32> <i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8>, <4 x i32> <i32 9, i32 10, i32 11, i32 12>, i64 0)
+ ret <8 x i32> %1
+}
+
+define <8 x i32> @fold_vector_insert_nop() {
+; CHECK-LABEL: define <8 x i32> @fold_vector_insert_nop() {
+; CHECK-NEXT: ret <8 x i32> <i32 11, i32 12, i32 13, i32 14, i32 15, i32 16, i32 17, i32 18>
+;
+ %1 = call <8 x i32> @llvm.vector.insert.v8i32(<8 x i32> <i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8>, <8 x i32> <i32 11, i32 12, i32 13, i32 14, i32 15, i32 16, i32 17, i32 18>, i64 0)
+ ret <8 x i32> %1
+}
+
+define <8 x i32> @fold_vector_interleave2() {
+; CHECK-LABEL: define <8 x i32> @fold_vector_interleave2() {
+; CHECK-NEXT: ret <8 x i32> <i32 1, i32 5, i32 2, i32 6, i32 3, i32 7, i32 4, i32 8>
+;
+ %1 = call<8 x i32> @llvm.vector.interleave2.v8i32(<4 x i32> <i32 1, i32 2, i32 3, i32 4>, <4 x i32> <i32 5, i32 6, i32 7, i32 8>)
+ ret <8 x i32> %1
+}
+
+define {<4 x i32>, <4 x i32>} @fold_vector_deinterleav2() {
+; CHECK-LABEL: define { <4 x i32>, <4 x i32> } @fold_vector_deinterleav2() {
+; CHECK-NEXT: ret { <4 x i32>, <4 x i32> } { <4 x i32> <i32 1, i32 2, i32 3, i32 4>, <4 x i32> <i32 5, i32 6, i32 7, i32 8> }
+;
+ %1 = call {<4 x i32>, <4 x i32>} @llvm.vector.deinterleave2.v4i32.v8i32(<8 x i32> <i32 1, i32 5, i32 2, i32 6, i32 3, i32 7, i32 4, i32 8>)
+ ret {<4 x i32>, <4 x i32>} %1
+}
|
f0bc11d
to
69707e3
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add some constexpr tests? getAggregateElement
may return null.
@@ -0,0 +1,90 @@ | |||
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 | |||
; RUN: opt < %s -passes=instsimplify,verify -disable-verify -S | FileCheck %s |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
; RUN: opt < %s -passes=instsimplify,verify -disable-verify -S | FileCheck %s | |
; RUN: opt < %s -passes=instsimplify -S | FileCheck %s |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
disable-verify
needs to test poison value generation of vector.insert
, vector.extract
when writing/reading OOB
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These make sense to have, but I'm wondering if you were seeing any cases where we were using fixed length versions of these intrinsics. From my understanding these are only really used for scalable vectors in the loop vectorizer etc.
Maybe from the ComplexDeinterleaving pass?
The fixed versions of these intrinsics are generated by vectorizers (at least they exist in fuzzer-generated cases). |
Intrinsics I added here were explicitly generated from Mojo code. For example: https://github.com/modular/modular/blob/e7419034262164d41798b11876e40ca173bae28c/mojo/stdlib/stdlib/builtin/simd.mojo#L2328-L2332 I doubt LV can emit it, at least not with innermost loop vectorization. Perhaps SLP can do this. |
The change adds folding for 4 vector intrinsics: `interleave2`, `deinterleave2`, `vector_extract` and `vector_insert`. For the last 2 intrinsics the change does not use `ShuffleVector` fold mechanism as it's much simpler to construct result vector explicitly.
69707e3
to
55c7971
Compare
not sure how to trigger it, but just in case added bailout if it's nullptr, similar to the default code in the function |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Please wait for additional approval from other reviewers :)
const unsigned NonPoisonNumElements = | ||
std::min(StartingIndex + NumElements, VecNumElements); | ||
for (unsigned I = StartingIndex; I < NonPoisonNumElements; ++I) { | ||
Constant *Elt = Vec->getAggregateElement(I); | ||
if (!Elt) | ||
return nullptr; | ||
Result[I - StartingIndex] = Elt; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Stylistic nit, you could avoid NonPoisonNumElements
and the second loop if you handle it in the main loop
const unsigned NonPoisonNumElements = | |
std::min(StartingIndex + NumElements, VecNumElements); | |
for (unsigned I = StartingIndex; I < NonPoisonNumElements; ++I) { | |
Constant *Elt = Vec->getAggregateElement(I); | |
if (!Elt) | |
return nullptr; | |
Result[I - StartingIndex] = Elt; | |
} | |
for (unsigned I = 0; I < NumElements; ++I) { | |
// Out of bounds elements are poison | |
if (StartingIndex + I >= VecNumElements) { | |
Result[I] = PoisonValue::get(FVTy->getElementType()); | |
continue; | |
} | |
Constant *Elt = Vec->getAggregateElement(StartingIndex + I); | |
if (!Elt) | |
return nullptr; | |
Result[I] = Elt; | |
} |
|
||
// Make sure indices are in the range [0, VecNumElements), otherwise the | ||
// result is a poison value. | ||
if (IdxN >= VecNumElements || IdxN + SubVecNumElements > VecNumElements || |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SubVecNumElements is always >= 1 so if IdxN >= VecNumElements
then IdxN + SubVecNumElements > VecNumElements
. So I think you can remove the first check
if (IdxN >= VecNumElements || IdxN + SubVecNumElements > VecNumElements || | |
if (IdxN + SubVecNumElements > VecNumElements || |
// Make sure indices are in the range [0, VecNumElements), otherwise the | ||
// result is a poison value. | ||
if (IdxN >= VecNumElements || IdxN + SubVecNumElements > VecNumElements || | ||
(IdxN && (SubVecNumElements % IdxN) != 0)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't this be if the index isn't a multiple of the vector length? I.e.
(IdxN && (SubVecNumElements % IdxN) != 0)) | |
IdxN % SubVecNumElements) |
Is it possible to add a test for this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
that's not correct when IdxN = 3
and SubVecNumElements = 6
?
idx must be a constant multiple of subvec’s known minimum vector length
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IdxN must be 0,6,12,18... right? So 3 should be poison and 3%6=3 would be true for this check
auto *Vec0 = dyn_cast<Constant>(Operands[0]); | ||
auto *Vec1 = dyn_cast<Constant>(Operands[1]); | ||
if (!Vec0 || !Vec1) | ||
return nullptr; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Operands is an ArrayRef<Constant *>
already, are the dyn_casts redundant?
auto *Vec = dyn_cast<Constant>(Operands[0]); | ||
auto *SubVec = dyn_cast<Constant>(Operands[1]); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here, I think you can drop the dyn_casts
auto *Vec = dyn_cast<Constant>(Operands[0]); | |
auto *SubVec = dyn_cast<Constant>(Operands[1]); | |
Constant *Vec = Operands[0]; | |
Constant *SubVec = Operands[1]; |
auto *Vec = dyn_cast<Constant>(Operands[0]); | ||
if (!Vec) | ||
return nullptr; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
auto *Vec = dyn_cast<Constant>(Operands[0]); | |
if (!Vec) | |
return nullptr; | |
Constant *Vec = Operands[0]; |
return nullptr; | ||
|
||
unsigned NumElements = | ||
cast<VectorType>(Vec->getType())->getElementCount().getKnownMinValue() / |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If Vec can be scalable here, should we check if it's scalable and bail? Otherwise we're relying on getAggregateElement to return nullptr
The change adds folding for 4 vector intrinsics:
interleave2
,deinterleave2
,vector_extract
andvector_insert
. For the last 2 intrinsics the change does not useShuffleVector
fold mechanism as it's much simpler to construct result vector explicitly.