-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[Matrix] Optimize static extracts with ShapeInfo #141815
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
✅ With the latest revision this PR passed the C/C++ code formatter. |
10a854e
to
ae5751f
Compare
For ExtractElementInsts with static indices that extract from a Matrix, use the known layout of the Rows/Columns, avoiding some of the shuffles that embedInVector creates.
ae5751f
to
323c9a8
Compare
@llvm/pr-subscribers-llvm-transforms Author: Jon Roelofs (jroelofs) ChangesFor ExtractElementInsts with static indices that extract from a Matrix, use the known layout of the Rows/Columns to look through the shuffles that embedInVector creates, which in some cases allows us to delete them. Full diff: https://github.com/llvm/llvm-project/pull/141815.diff 2 Files Affected:
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 756a72e6d97bc..8b322afd9b6e4 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -34,6 +34,7 @@
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/MatrixBuilder.h"
@@ -623,7 +624,8 @@ class LowerMatrixIntrinsics {
default:
return false;
}
- return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
+ return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V) ||
+ isa<ExtractElementInst>(V);
}
/// Propagate the shape information of instructions to their users.
@@ -1337,6 +1339,28 @@ class LowerMatrixIntrinsics {
return Builder.CreateAdd(Sum, Mul);
}
+ bool VisitExtractElt(ExtractElementInst *Inst, uint64_t Index) {
+ Value *Op0 = Inst->getOperand(0);
+ auto *VTy = cast<VectorType>(Op0->getType());
+
+ if (VTy->getElementCount().getKnownMinValue() < Index) {
+ Inst->replaceAllUsesWith(PoisonValue::get(VTy->getElementType()));
+ Inst->eraseFromParent();
+ return true;
+ }
+
+ auto *I = Inst2ColumnMatrix.find(Op0);
+ if (I == Inst2ColumnMatrix.end())
+ return false;
+
+ const MatrixTy &M = I->second;
+
+ IRBuilder<> Builder(Inst);
+ Inst->setOperand(0, M.getVector(Index / M.getStride()));
+ Inst->setOperand(1, Builder.getInt32(Index % M.getStride()));
+ return true;
+ }
+
/// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
/// users with shape information, there's nothing to do: they will use the
/// cached value when they are lowered. For other users, \p Matrix is
@@ -1351,11 +1375,18 @@ class LowerMatrixIntrinsics {
ToRemove.push_back(Inst);
Value *Flattened = nullptr;
for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
- if (!ShapeMap.contains(U.getUser())) {
- if (!Flattened)
- Flattened = Matrix.embedInVector(Builder);
- U.set(Flattened);
- }
+ if (ShapeMap.contains(U.getUser()))
+ continue;
+
+ Value *Op1;
+ uint64_t Index;
+ if (match(U.getUser(), m_ExtractElt(m_Value(Op1), m_ConstantInt(Index))))
+ if (VisitExtractElt(cast<ExtractElementInst>(U.getUser()), Index))
+ continue;
+
+ if (!Flattened)
+ Flattened = Matrix.embedInVector(Builder);
+ U.set(Flattened);
}
}
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll
new file mode 100644
index 0000000000000..0bac9492d654a
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll
@@ -0,0 +1,47 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
+
+define float @extract_static(ptr %in, ptr %out) {
+; CHECK-LABEL: @extract_static(
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT: [[EXTRACT:%.*]] = extractelement <2 x float> [[COL_LOAD1]], i32 1
+; CHECK-NEXT: ret float [[EXTRACT]]
+;
+ %inv = load <4 x float>, ptr %in
+ %invt = call <4 x float> @llvm.matrix.transpose(<4 x float> %inv, i32 2, i32 2)
+ %invtt = call <4 x float> @llvm.matrix.transpose(<4 x float> %invt, i32 2, i32 2)
+ %extract = extractelement <4 x float> %invtt, i32 3
+ ret float %extract
+}
+
+define float @extract_static_outofbounds(ptr %in, ptr %out) {
+; CHECK-LABEL: @extract_static_outofbounds(
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT: ret float poison
+;
+ %inv = load <4 x float>, ptr %in
+ %invt = call <4 x float> @llvm.matrix.transpose(<4 x float> %inv, i32 2, i32 2)
+ %invtt = call <4 x float> @llvm.matrix.transpose(<4 x float> %invt, i32 2, i32 2)
+ %extract = extractelement <4 x float> %invtt, i32 5
+ ret float %extract
+}
+
+define float @extract_dynamic(ptr %in, i32 %idx, ptr %out) {
+; CHECK-LABEL: @extract_dynamic(
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD1]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT: [[EXTRACT:%.*]] = extractelement <4 x float> [[TMP1]], i32 [[IDX:%.*]]
+; CHECK-NEXT: ret float [[EXTRACT]]
+;
+ %inv = load <4 x float>, ptr %in
+ %invt = call <4 x float> @llvm.matrix.transpose(<4 x float> %inv, i32 2, i32 2)
+ %invtt = call <4 x float> @llvm.matrix.transpose(<4 x float> %invt, i32 2, i32 2)
+ %extract = extractelement <4 x float> %invtt, i32 %idx
+ ret float %extract
+}
|
For ExtractElementInsts with static indices that extract from a Matrix, use the known layout of the Rows/Columns to look through the shuffles that embedInVector creates, which in some cases allows us to delete them.