Skip to content

[Matrix] Optimize static extracts with ShapeInfo #141815

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

jroelofs
Copy link
Contributor

For ExtractElementInsts with static indices that extract from a Matrix, use the known layout of the Rows/Columns to look through the shuffles that embedInVector creates, which in some cases allows us to delete them.

@jroelofs jroelofs requested review from fhahn and anemet May 28, 2025 18:16
Copy link

github-actions bot commented May 28, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@jroelofs jroelofs force-pushed the jroelofs/lower-matrix-extract branch from 10a854e to ae5751f Compare May 28, 2025 19:02
For ExtractElementInsts with static indices that extract from a Matrix, use the
known layout of the Rows/Columns, avoiding some of the shuffles that
embedInVector creates.
@jroelofs jroelofs force-pushed the jroelofs/lower-matrix-extract branch from ae5751f to 323c9a8 Compare May 28, 2025 19:03
@jroelofs jroelofs marked this pull request as ready for review May 28, 2025 19:04
@llvmbot
Copy link
Member

llvmbot commented May 28, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Jon Roelofs (jroelofs)

Changes

For ExtractElementInsts with static indices that extract from a Matrix, use the known layout of the Rows/Columns to look through the shuffles that embedInVector creates, which in some cases allows us to delete them.


Full diff: https://github.com/llvm/llvm-project/pull/141815.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp (+37-6)
  • (added) llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll (+47)
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 756a72e6d97bc..8b322afd9b6e4 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -34,6 +34,7 @@
 #include "llvm/IR/DebugInfoMetadata.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/MatrixBuilder.h"
@@ -623,7 +624,8 @@ class LowerMatrixIntrinsics {
       default:
         return false;
       }
-    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
+    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V) ||
+           isa<ExtractElementInst>(V);
   }
 
   /// Propagate the shape information of instructions to their users.
@@ -1337,6 +1339,28 @@ class LowerMatrixIntrinsics {
     return Builder.CreateAdd(Sum, Mul);
   }
 
+  bool VisitExtractElt(ExtractElementInst *Inst, uint64_t Index) {
+    Value *Op0 = Inst->getOperand(0);
+    auto *VTy = cast<VectorType>(Op0->getType());
+
+    if (VTy->getElementCount().getKnownMinValue() < Index) {
+      Inst->replaceAllUsesWith(PoisonValue::get(VTy->getElementType()));
+      Inst->eraseFromParent();
+      return true;
+    }
+
+    auto *I = Inst2ColumnMatrix.find(Op0);
+    if (I == Inst2ColumnMatrix.end())
+      return false;
+
+    const MatrixTy &M = I->second;
+
+    IRBuilder<> Builder(Inst);
+    Inst->setOperand(0, M.getVector(Index / M.getStride()));
+    Inst->setOperand(1, Builder.getInt32(Index % M.getStride()));
+    return true;
+  }
+
   /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
   /// users with shape information, there's nothing to do: they will use the
   /// cached value when they are lowered. For other users, \p Matrix is
@@ -1351,11 +1375,18 @@ class LowerMatrixIntrinsics {
     ToRemove.push_back(Inst);
     Value *Flattened = nullptr;
     for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
-      if (!ShapeMap.contains(U.getUser())) {
-        if (!Flattened)
-          Flattened = Matrix.embedInVector(Builder);
-        U.set(Flattened);
-      }
+      if (ShapeMap.contains(U.getUser()))
+        continue;
+
+      Value *Op1;
+      uint64_t Index;
+      if (match(U.getUser(), m_ExtractElt(m_Value(Op1), m_ConstantInt(Index))))
+        if (VisitExtractElt(cast<ExtractElementInst>(U.getUser()), Index))
+          continue;
+
+      if (!Flattened)
+        Flattened = Matrix.embedInVector(Builder);
+      U.set(Flattened);
     }
   }
 
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll
new file mode 100644
index 0000000000000..0bac9492d654a
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll
@@ -0,0 +1,47 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
+
+define float @extract_static(ptr %in, ptr %out) {
+; CHECK-LABEL: @extract_static(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[EXTRACT:%.*]] = extractelement <2 x float> [[COL_LOAD1]], i32 1
+; CHECK-NEXT:    ret float [[EXTRACT]]
+;
+  %inv = load <4 x float>, ptr %in
+  %invt  = call <4 x float> @llvm.matrix.transpose(<4 x float> %inv, i32 2, i32 2)
+  %invtt = call <4 x float> @llvm.matrix.transpose(<4 x float> %invt, i32 2, i32 2)
+  %extract = extractelement <4 x float> %invtt, i32 3
+  ret float %extract
+}
+
+define float @extract_static_outofbounds(ptr %in, ptr %out) {
+; CHECK-LABEL: @extract_static_outofbounds(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    ret float poison
+;
+  %inv = load <4 x float>, ptr %in
+  %invt  = call <4 x float> @llvm.matrix.transpose(<4 x float> %inv, i32 2, i32 2)
+  %invtt = call <4 x float> @llvm.matrix.transpose(<4 x float> %invt, i32 2, i32 2)
+  %extract = extractelement <4 x float> %invtt, i32 5
+  ret float %extract
+}
+
+define float @extract_dynamic(ptr %in, i32 %idx, ptr %out) {
+; CHECK-LABEL: @extract_dynamic(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD1]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[EXTRACT:%.*]] = extractelement <4 x float> [[TMP1]], i32 [[IDX:%.*]]
+; CHECK-NEXT:    ret float [[EXTRACT]]
+;
+  %inv = load <4 x float>, ptr %in
+  %invt  = call <4 x float> @llvm.matrix.transpose(<4 x float> %inv, i32 2, i32 2)
+  %invtt = call <4 x float> @llvm.matrix.transpose(<4 x float> %invt, i32 2, i32 2)
+  %extract = extractelement <4 x float> %invtt, i32 %idx
+  ret float %extract
+}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants