From 323c9a8d2de459a0f81f32f7537e8b2e087ffc00 Mon Sep 17 00:00:00 2001 From: Jon Roelofs Date: Wed, 28 May 2025 09:30:50 -0700 Subject: [PATCH 1/5] [Matrix] Optimize static extracts with ShapeInfo For ExtractElementInsts with static indices that extract from a Matrix, use the known layout of the Rows/Columns, avoiding some of the shuffles that embedInVector creates. --- .../Scalar/LowerMatrixIntrinsics.cpp | 43 ++++++++++++++--- .../LowerMatrixIntrinsics/extract.ll | 47 +++++++++++++++++++ 2 files changed, 84 insertions(+), 6 deletions(-) create mode 100644 llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 756a72e6d97bc..8b322afd9b6e4 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -34,6 +34,7 @@ #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/MatrixBuilder.h" @@ -623,7 +624,8 @@ class LowerMatrixIntrinsics { default: return false; } - return isUniformShape(V) || isa(V) || isa(V); + return isUniformShape(V) || isa(V) || isa(V) || + isa(V); } /// Propagate the shape information of instructions to their users. @@ -1337,6 +1339,28 @@ class LowerMatrixIntrinsics { return Builder.CreateAdd(Sum, Mul); } + bool VisitExtractElt(ExtractElementInst *Inst, uint64_t Index) { + Value *Op0 = Inst->getOperand(0); + auto *VTy = cast(Op0->getType()); + + if (VTy->getElementCount().getKnownMinValue() < Index) { + Inst->replaceAllUsesWith(PoisonValue::get(VTy->getElementType())); + Inst->eraseFromParent(); + return true; + } + + auto *I = Inst2ColumnMatrix.find(Op0); + if (I == Inst2ColumnMatrix.end()) + return false; + + const MatrixTy &M = I->second; + + IRBuilder<> Builder(Inst); + Inst->setOperand(0, M.getVector(Index / M.getStride())); + Inst->setOperand(1, Builder.getInt32(Index % M.getStride())); + return true; + } + /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For /// users with shape information, there's nothing to do: they will use the /// cached value when they are lowered. For other users, \p Matrix is @@ -1351,11 +1375,18 @@ class LowerMatrixIntrinsics { ToRemove.push_back(Inst); Value *Flattened = nullptr; for (Use &U : llvm::make_early_inc_range(Inst->uses())) { - if (!ShapeMap.contains(U.getUser())) { - if (!Flattened) - Flattened = Matrix.embedInVector(Builder); - U.set(Flattened); - } + if (ShapeMap.contains(U.getUser())) + continue; + + Value *Op1; + uint64_t Index; + if (match(U.getUser(), m_ExtractElt(m_Value(Op1), m_ConstantInt(Index)))) + if (VisitExtractElt(cast(U.getUser()), Index)) + continue; + + if (!Flattened) + Flattened = Matrix.embedInVector(Builder); + U.set(Flattened); } } diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll new file mode 100644 index 0000000000000..0bac9492d654a --- /dev/null +++ b/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll @@ -0,0 +1,47 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s + +define float @extract_static(ptr %in, ptr %out) { +; CHECK-LABEL: @extract_static( +; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16 +; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2 +; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8 +; CHECK-NEXT: [[EXTRACT:%.*]] = extractelement <2 x float> [[COL_LOAD1]], i32 1 +; CHECK-NEXT: ret float [[EXTRACT]] +; + %inv = load <4 x float>, ptr %in + %invt = call <4 x float> @llvm.matrix.transpose(<4 x float> %inv, i32 2, i32 2) + %invtt = call <4 x float> @llvm.matrix.transpose(<4 x float> %invt, i32 2, i32 2) + %extract = extractelement <4 x float> %invtt, i32 3 + ret float %extract +} + +define float @extract_static_outofbounds(ptr %in, ptr %out) { +; CHECK-LABEL: @extract_static_outofbounds( +; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16 +; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2 +; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8 +; CHECK-NEXT: ret float poison +; + %inv = load <4 x float>, ptr %in + %invt = call <4 x float> @llvm.matrix.transpose(<4 x float> %inv, i32 2, i32 2) + %invtt = call <4 x float> @llvm.matrix.transpose(<4 x float> %invt, i32 2, i32 2) + %extract = extractelement <4 x float> %invtt, i32 5 + ret float %extract +} + +define float @extract_dynamic(ptr %in, i32 %idx, ptr %out) { +; CHECK-LABEL: @extract_dynamic( +; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16 +; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2 +; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8 +; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD1]], <4 x i32> +; CHECK-NEXT: [[EXTRACT:%.*]] = extractelement <4 x float> [[TMP1]], i32 [[IDX:%.*]] +; CHECK-NEXT: ret float [[EXTRACT]] +; + %inv = load <4 x float>, ptr %in + %invt = call <4 x float> @llvm.matrix.transpose(<4 x float> %inv, i32 2, i32 2) + %invtt = call <4 x float> @llvm.matrix.transpose(<4 x float> %invt, i32 2, i32 2) + %extract = extractelement <4 x float> %invtt, i32 %idx + ret float %extract +} From c962ce9a79028d9514951d8ec07a818114aeea92 Mon Sep 17 00:00:00 2001 From: Jon Roelofs Date: Thu, 29 May 2025 11:07:13 -0700 Subject: [PATCH 2/5] use colum major load intrinsic for shape info --- .../LowerMatrixIntrinsics/extract.ll | 30 ++++++++----------- 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll index 0bac9492d654a..db5444ca036ae 100644 --- a/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll +++ b/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll @@ -3,45 +3,39 @@ define float @extract_static(ptr %in, ptr %out) { ; CHECK-LABEL: @extract_static( -; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16 +; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <2 x float>, ptr [[IN:%.*]], align 4 ; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2 -; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8 +; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <2 x float>, ptr [[VEC_GEP]], align 4 ; CHECK-NEXT: [[EXTRACT:%.*]] = extractelement <2 x float> [[COL_LOAD1]], i32 1 ; CHECK-NEXT: ret float [[EXTRACT]] ; - %inv = load <4 x float>, ptr %in - %invt = call <4 x float> @llvm.matrix.transpose(<4 x float> %inv, i32 2, i32 2) - %invtt = call <4 x float> @llvm.matrix.transpose(<4 x float> %invt, i32 2, i32 2) - %extract = extractelement <4 x float> %invtt, i32 3 + %inv = call <4 x float> @llvm.matrix.column.major.load(ptr %in, i64 2, i1 1, i32 2, i32 2) + %extract = extractelement <4 x float> %inv, i32 3 ret float %extract } define float @extract_static_outofbounds(ptr %in, ptr %out) { ; CHECK-LABEL: @extract_static_outofbounds( -; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16 +; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <2 x float>, ptr [[IN:%.*]], align 4 ; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2 -; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8 +; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <2 x float>, ptr [[VEC_GEP]], align 4 ; CHECK-NEXT: ret float poison ; - %inv = load <4 x float>, ptr %in - %invt = call <4 x float> @llvm.matrix.transpose(<4 x float> %inv, i32 2, i32 2) - %invtt = call <4 x float> @llvm.matrix.transpose(<4 x float> %invt, i32 2, i32 2) - %extract = extractelement <4 x float> %invtt, i32 5 + %inv = call <4 x float> @llvm.matrix.column.major.load(ptr %in, i64 2, i1 1, i32 2, i32 2) + %extract = extractelement <4 x float> %inv, i32 5 ret float %extract } define float @extract_dynamic(ptr %in, i32 %idx, ptr %out) { ; CHECK-LABEL: @extract_dynamic( -; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16 +; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <2 x float>, ptr [[IN:%.*]], align 4 ; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2 -; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8 +; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <2 x float>, ptr [[VEC_GEP]], align 4 ; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD1]], <4 x i32> ; CHECK-NEXT: [[EXTRACT:%.*]] = extractelement <4 x float> [[TMP1]], i32 [[IDX:%.*]] ; CHECK-NEXT: ret float [[EXTRACT]] ; - %inv = load <4 x float>, ptr %in - %invt = call <4 x float> @llvm.matrix.transpose(<4 x float> %inv, i32 2, i32 2) - %invtt = call <4 x float> @llvm.matrix.transpose(<4 x float> %invt, i32 2, i32 2) - %extract = extractelement <4 x float> %invtt, i32 %idx + %inv = call <4 x float> @llvm.matrix.column.major.load(ptr %in, i64 2, i1 1, i32 2, i32 2) + %extract = extractelement <4 x float> %inv, i32 %idx ret float %extract } From 169ed5650b3583d0d4046aafb35958593f267803 Mon Sep 17 00:00:00 2001 From: Jon Roelofs Date: Thu, 29 May 2025 15:48:58 -0700 Subject: [PATCH 3/5] extractelemtn shouldn't supportShapeInfo --- llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 8b322afd9b6e4..341b4cc5d75c4 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -624,8 +624,7 @@ class LowerMatrixIntrinsics { default: return false; } - return isUniformShape(V) || isa(V) || isa(V) || - isa(V); + return isUniformShape(V) || isa(V) || isa(V); } /// Propagate the shape information of instructions to their users. From 04f9707e3ad3b4f560b1c5fd333a7ad02c5455cc Mon Sep 17 00:00:00 2001 From: Jon Roelofs Date: Fri, 30 May 2025 13:48:50 -0700 Subject: [PATCH 4/5] refactor as tryLower --- llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 341b4cc5d75c4..3d684ef82ca7c 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -1338,7 +1338,11 @@ class LowerMatrixIntrinsics { return Builder.CreateAdd(Sum, Mul); } - bool VisitExtractElt(ExtractElementInst *Inst, uint64_t Index) { + bool tryLowerExtractElement(ExtractElementInst *Inst) { + uint64_t Index; + if (!match(Inst->getOperand(1), m_ConstantInt(Index))) + return false; + Value *Op0 = Inst->getOperand(0); auto *VTy = cast(Op0->getType()); @@ -1377,10 +1381,8 @@ class LowerMatrixIntrinsics { if (ShapeMap.contains(U.getUser())) continue; - Value *Op1; - uint64_t Index; - if (match(U.getUser(), m_ExtractElt(m_Value(Op1), m_ConstantInt(Index)))) - if (VisitExtractElt(cast(U.getUser()), Index)) + if (auto *Extract = dyn_cast(U.getUser())) + if (tryLowerExtractElement(Extract)) continue; if (!Flattened) From cb11115bb32e8d396446b408bf98917c6535dd98 Mon Sep 17 00:00:00 2001 From: Jon Roelofs Date: Tue, 10 Jun 2025 13:09:56 -0700 Subject: [PATCH 5/5] clang-format --- llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index ecd1b6c916353..5b721205258d3 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -36,8 +36,8 @@ #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Instruction.h" #include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/MatrixBuilder.h"