Skip to content

Commit 53098ce

Browse files
committed
[Matrix] Optimize static extracts with ShapeInfo
For ExtractElementInsts with static indices that extract from a Matrix, use the known layout of the Rows/Columns to look through the shuffles that embedInVector creates, which in some cases allows us to delete them.
1 parent 7984bb9 commit 53098ce

File tree

3 files changed

+68
-14
lines changed

3 files changed

+68
-14
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "llvm/IR/DebugInfoMetadata.h"
3535
#include "llvm/IR/Function.h"
3636
#include "llvm/IR/IRBuilder.h"
37+
#include "llvm/IR/Instruction.h"
3738
#include "llvm/IR/Instructions.h"
3839
#include "llvm/IR/IntrinsicInst.h"
3940
#include "llvm/IR/MatrixBuilder.h"
@@ -568,6 +569,7 @@ class LowerMatrixIntrinsics {
568569
return M;
569570

570571
MatrixVal = M.embedInVector(Builder);
572+
Inst2ColumnMatrix[MatrixVal] = M;
571573
}
572574

573575
// Otherwise split MatrixVal.
@@ -632,7 +634,7 @@ class LowerMatrixIntrinsics {
632634
default:
633635
return isUniformShape(II);
634636
}
635-
return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
637+
return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V) || isa<ExtractElementInst>(V);
636638
}
637639

638640
/// Propagate the shape information of instructions to their users.
@@ -1083,6 +1085,18 @@ class LowerMatrixIntrinsics {
10831085
Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
10841086
}
10851087

1088+
// Fifth, lower instructions which can make use of shape information, but do
1089+
// not have shapes themselves.
1090+
for (auto *BB : RPOT)
1091+
for (Instruction &Inst : *BB) {
1092+
IRBuilder<> Builder(&Inst);
1093+
1094+
Value *Op1;
1095+
uint64_t Index;
1096+
if (match(&Inst, m_ExtractElt(m_Value(Op1), m_ConstantInt(Index))))
1097+
Changed |= VisitExtractElt(cast<ExtractElementInst>(&Inst), Index);
1098+
}
1099+
10861100
if (ORE) {
10871101
RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
10881102
RemarkGen.emitRemarks();
@@ -1364,8 +1378,10 @@ class LowerMatrixIntrinsics {
13641378
Value *Flattened = nullptr;
13651379
for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
13661380
if (!ShapeMap.contains(U.getUser())) {
1367-
if (!Flattened)
1381+
if (!Flattened) {
13681382
Flattened = Matrix.embedInVector(Builder);
1383+
Inst2ColumnMatrix[Flattened] = Matrix;
1384+
}
13691385
U.set(Flattened);
13701386
}
13711387
}
@@ -2142,6 +2158,30 @@ class LowerMatrixIntrinsics {
21422158
return true;
21432159
}
21442160

2161+
/// Lower an extractelement with a constant index.
///
/// If the index is statically out of bounds for a fixed-length vector operand,
/// all uses of \p Inst are replaced with poison and \p Inst is queued in
/// ToRemove. Otherwise, if the vector operand has an entry in
/// Inst2ColumnMatrix, \p Inst is rewritten in place to extract from the
/// individual vector of that matrix which holds the element, and the old
/// operand is queued in ToRemove once it has no remaining uses.
///
/// \param Inst  the extractelement instruction to lower.
/// \param Index the constant element index extracted by \p Inst.
/// \returns true if the IR was changed, false otherwise.
bool VisitExtractElt(ExtractElementInst *Inst, uint64_t Index) {
  Value *Op0 = Inst->getOperand(0);
  auto *VTy = cast<VectorType>(Op0->getType());
  const auto EC = VTy->getElementCount();

  // Valid indices are [0, NumElts), so Index == NumElts is already out of
  // bounds; letting it through would make the getVector() call below index
  // one past the last vector of the matrix. The fold is restricted to
  // fixed-length vectors because for scalable vectors the known minimum is
  // only a lower bound on the runtime length, so an index at or above it is
  // not necessarily out of bounds.
  if (!EC.isScalable() && Index >= EC.getKnownMinValue()) {
    Inst->replaceAllUsesWith(PoisonValue::get(VTy->getElementType()));
    ToRemove.push_back(Inst);
    return true;
  }

  // Only operands with a recorded matrix can be looked through.
  auto It = Inst2ColumnMatrix.find(Op0);
  if (It == Inst2ColumnMatrix.end())
    return false;

  const MatrixTy &M = It->second;
  const uint64_t Stride = M.getStride();

  // Index / Stride selects the vector of the matrix, Index % Stride the
  // element within that vector.
  IRBuilder<> Builder(Inst);
  Inst->setOperand(0, M.getVector(Index / Stride));
  Inst->setOperand(1, Builder.getInt32(Index % Stride));

  // The old vector operand may now be dead; queue it for removal if so.
  if (Op0->use_empty() && isa<Instruction>(Op0))
    ToRemove.push_back(cast<Instruction>(Op0));
  return true;
}
2184+
21452185
/// Lower binary operators, if shape information is available.
21462186
bool VisitBinaryOperator(BinaryOperator *Inst) {
21472187
auto I = ShapeMap.find(Inst);

llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,28 @@ define float @extract_static(ptr %in, ptr %out) {
66
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
77
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
88
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
9-
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD1]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
10-
; CHECK-NEXT: [[EXTRACT:%.*]] = extractelement <4 x float> [[TMP1]], i32 0
9+
; CHECK-NEXT: [[EXTRACT:%.*]] = extractelement <2 x float> [[COL_LOAD1]], i32 1
1110
; CHECK-NEXT: ret float [[EXTRACT]]
1211
;
1312
%inv = load <4 x float>, ptr %in
1413
%invt = call <4 x float> @llvm.matrix.transpose(<4 x float> %inv, i32 2, i32 2)
1514
%invtt = call <4 x float> @llvm.matrix.transpose(<4 x float> %invt, i32 2, i32 2)
16-
%extract = extractelement <4 x float> %invtt, i32 0
15+
%extract = extractelement <4 x float> %invtt, i32 3
16+
ret float %extract
17+
}
18+
19+
; Extracts at constant index 5 from a <4 x float> value, which is out of
; bounds. The CHECK lines expect the function to return poison while the
; column loads and the shufflevector that rebuilds the flat vector remain.
define float @extract_static_outofbounds(ptr %in, ptr %out) {
20+
; CHECK-LABEL: @extract_static_outofbounds(
21+
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
22+
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
23+
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
24+
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD1]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
25+
; CHECK-NEXT: ret float poison
26+
;
27+
%inv = load <4 x float>, ptr %in
28+
%invt = call <4 x float> @llvm.matrix.transpose(<4 x float> %inv, i32 2, i32 2)
29+
%invtt = call <4 x float> @llvm.matrix.transpose(<4 x float> %invt, i32 2, i32 2)
30+
%extract = extractelement <4 x float> %invtt, i32 5
1731
ret float %extract
1832
}
1933

llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts.ll

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,11 @@ define void @multiply_ntt(ptr %A, ptr %B, ptr %C, ptr %R) {
5555
; REMARK-NEXT: Function: multiply_ntt
5656
; REMARK-NEXT: Args:
5757
; REMARK-NEXT: - String: 'Lowered with '
58-
; REMARK-NEXT: - NumStores: '4'
58+
; REMARK-NEXT: - NumStores: '0'
5959
; REMARK-NEXT: - String: ' stores, '
60-
; REMARK-NEXT: - NumLoads: '10'
60+
; REMARK-NEXT: - NumLoads: '3'
6161
; REMARK-NEXT: - String: ' loads, '
62-
; REMARK-NEXT: - NumComputeOps: '38'
62+
; REMARK-NEXT: - NumComputeOps: '0'
6363
; REMARK-NEXT: - String: ' compute ops, '
6464
; REMARK-NEXT: - NumExposedTransposes: '0'
6565
; REMARK-NEXT: - String: ' exposed transposes'
@@ -443,11 +443,11 @@ define void @multiply_nt_t(ptr %A, ptr %B, ptr %C) {
443443
; REMARK-NEXT: Function: multiply_nt_t
444444
; REMARK-NEXT: Args:
445445
; REMARK-NEXT: - String: 'Lowered with '
446-
; REMARK-NEXT: - NumStores: '4'
446+
; REMARK-NEXT: - NumStores: '0'
447447
; REMARK-NEXT: - String: ' stores, '
448-
; REMARK-NEXT: - NumLoads: '9'
448+
; REMARK-NEXT: - NumLoads: '3'
449449
; REMARK-NEXT: - String: ' loads, '
450-
; REMARK-NEXT: - NumComputeOps: '20'
450+
; REMARK-NEXT: - NumComputeOps: '0'
451451
; REMARK-NEXT: - String: ' compute ops, '
452452
; REMARK-NEXT: - NumExposedTransposes: '0'
453453
; REMARK-NEXT: - String: ' exposed transposes'
@@ -578,11 +578,11 @@ define void @multiply_ntt_t(ptr %A, ptr %B, ptr %C, ptr %R) {
578578
; REMARK-NEXT: Function: multiply_ntt_t
579579
; REMARK-NEXT: Args:
580580
; REMARK-NEXT: - String: 'Lowered with '
581-
; REMARK-NEXT: - NumStores: '6'
581+
; REMARK-NEXT: - NumStores: '0'
582582
; REMARK-NEXT: - String: ' stores, '
583-
; REMARK-NEXT: - NumLoads: '18'
583+
; REMARK-NEXT: - NumLoads: '6'
584584
; REMARK-NEXT: - String: ' loads, '
585-
; REMARK-NEXT: - NumComputeOps: '60'
585+
; REMARK-NEXT: - NumComputeOps: '0'
586586
; REMARK-NEXT: - String: ' compute ops, '
587587
; REMARK-NEXT: - NumExposedTransposes: '0'
588588
; REMARK-NEXT: - String: ' exposed transposes'

0 commit comments

Comments
 (0)