Skip to content

Commit ae5751f

Browse files
committed
[Matrix] Optimize static extracts with ShapeInfo
For ExtractElementInsts with static indices that extract from a Matrix, use the known layout of the Rows/Columns, avoiding some of the shuffles that embedInVector creates.
1 parent 1f1c725 commit ae5751f

File tree

2 files changed

+83
-6
lines changed

2 files changed

+83
-6
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "llvm/IR/DebugInfoMetadata.h"
3535
#include "llvm/IR/Function.h"
3636
#include "llvm/IR/IRBuilder.h"
37+
#include "llvm/IR/Instruction.h"
3738
#include "llvm/IR/Instructions.h"
3839
#include "llvm/IR/IntrinsicInst.h"
3940
#include "llvm/IR/MatrixBuilder.h"
@@ -623,7 +624,7 @@ class LowerMatrixIntrinsics {
623624
default:
624625
return false;
625626
}
626-
return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
627+
return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V) || isa<ExtractElementInst>(V);
627628
}
628629

629630
/// Propagate the shape information of instructions to their users.
@@ -1337,6 +1338,28 @@ class LowerMatrixIntrinsics {
13371338
return Builder.CreateAdd(Sum, Mul);
13381339
}
13391340

1341+
bool VisitExtractElt(ExtractElementInst *Inst, uint64_t Index) {
1342+
Value *Op0 = Inst->getOperand(0);
1343+
auto *VTy = cast<VectorType>(Op0->getType());
1344+
1345+
if (VTy->getElementCount().getKnownMinValue() < Index) {
1346+
Inst->replaceAllUsesWith(PoisonValue::get(VTy->getElementType()));
1347+
Inst->eraseFromParent();
1348+
return true;
1349+
}
1350+
1351+
auto *I = Inst2ColumnMatrix.find(Op0);
1352+
if (I == Inst2ColumnMatrix.end())
1353+
return false;
1354+
1355+
const MatrixTy &M = I->second;
1356+
1357+
IRBuilder<> Builder(Inst);
1358+
Inst->setOperand(0, M.getVector(Index / M.getStride()));
1359+
Inst->setOperand(1, Builder.getInt32(Index % M.getStride()));
1360+
return true;
1361+
}
1362+
13401363
/// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
13411364
/// users with shape information, there's nothing to do: they will use the
13421365
/// cached value when they are lowered. For other users, \p Matrix is
@@ -1351,11 +1374,18 @@ class LowerMatrixIntrinsics {
13511374
ToRemove.push_back(Inst);
13521375
Value *Flattened = nullptr;
13531376
for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
1354-
if (!ShapeMap.contains(U.getUser())) {
1355-
if (!Flattened)
1356-
Flattened = Matrix.embedInVector(Builder);
1357-
U.set(Flattened);
1358-
}
1377+
if (ShapeMap.contains(U.getUser()))
1378+
continue;
1379+
1380+
Value *Op1;
1381+
uint64_t Index;
1382+
if (match(U.getUser(), m_ExtractElt(m_Value(Op1), m_ConstantInt(Index))))
1383+
if (VisitExtractElt(cast<ExtractElementInst>(U.getUser()), Index))
1384+
continue;
1385+
1386+
if (!Flattened)
1387+
Flattened = Matrix.embedInVector(Builder);
1388+
U.set(Flattened);
13591389
}
13601390
}
13611391

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
2+
; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
3+
4+
; Constant in-bounds index (3) into a 2x2 matrix value: the CHECK lines expect
; the extract to read lane 1 of the second <2 x float> load directly, with no
; shufflevector rebuilding the flat <4 x float> vector.
define float @extract_static(ptr %in, ptr %out) {
5+
; CHECK-LABEL: @extract_static(
6+
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
7+
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
8+
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
9+
; CHECK-NEXT: [[EXTRACT:%.*]] = extractelement <2 x float> [[COL_LOAD1]], i32 1
10+
; CHECK-NEXT: ret float [[EXTRACT]]
11+
;
12+
%inv = load <4 x float>, ptr %in
13+
%invt = call <4 x float> @llvm.matrix.transpose(<4 x float> %inv, i32 2, i32 2)
14+
%invtt = call <4 x float> @llvm.matrix.transpose(<4 x float> %invt, i32 2, i32 2)
15+
%extract = extractelement <4 x float> %invtt, i32 3
16+
ret float %extract
17+
}
18+
19+
; Constant index (5) past the end of the <4 x float> matrix value: the CHECK
; lines expect the extract to be replaced by poison, so the function returns
; `float poison` and no extractelement remains.
define float @extract_static_outofbounds(ptr %in, ptr %out) {
20+
; CHECK-LABEL: @extract_static_outofbounds(
21+
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
22+
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
23+
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
24+
; CHECK-NEXT: ret float poison
25+
;
26+
%inv = load <4 x float>, ptr %in
27+
%invt = call <4 x float> @llvm.matrix.transpose(<4 x float> %inv, i32 2, i32 2)
28+
%invtt = call <4 x float> @llvm.matrix.transpose(<4 x float> %invt, i32 2, i32 2)
29+
%extract = extractelement <4 x float> %invtt, i32 5
30+
ret float %extract
31+
}
32+
33+
; Runtime index (%idx): the static-index shortcut does not apply, so the CHECK
; lines expect the two <2 x float> loads to be concatenated with a
; shufflevector and the extract to operate on the resulting <4 x float>.
define float @extract_dynamic(ptr %in, i32 %idx, ptr %out) {
34+
; CHECK-LABEL: @extract_dynamic(
35+
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
36+
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
37+
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
38+
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD1]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
39+
; CHECK-NEXT: [[EXTRACT:%.*]] = extractelement <4 x float> [[TMP1]], i32 [[IDX:%.*]]
40+
; CHECK-NEXT: ret float [[EXTRACT]]
41+
;
42+
%inv = load <4 x float>, ptr %in
43+
%invt = call <4 x float> @llvm.matrix.transpose(<4 x float> %inv, i32 2, i32 2)
44+
%invtt = call <4 x float> @llvm.matrix.transpose(<4 x float> %invt, i32 2, i32 2)
45+
%extract = extractelement <4 x float> %invtt, i32 %idx
46+
ret float %extract
47+
}

0 commit comments

Comments
 (0)