Skip to content

Commit 10a854e

Browse files
committed
simplify the lowering a lot
1 parent 786f5e8 commit 10a854e

File tree

3 files changed

+44
-46
lines changed

3 files changed

+44
-46
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 35 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -560,7 +560,6 @@ class LowerMatrixIntrinsics {
560560
return M;
561561

562562
MatrixVal = M.embedInVector(Builder);
563-
Inst2ColumnMatrix[MatrixVal] = M;
564563
}
565564

566565
// Otherwise split MatrixVal.
@@ -1082,10 +1081,7 @@ class LowerMatrixIntrinsics {
10821081
for (Instruction &Inst : *BB) {
10831082
IRBuilder<> Builder(&Inst);
10841083

1085-
Value *Op1;
1086-
uint64_t Index;
1087-
if (match(&Inst, m_ExtractElt(m_Value(Op1), m_ConstantInt(Index))))
1088-
Changed |= VisitExtractElt(cast<ExtractElementInst>(&Inst), Index);
1084+
10891085
}
10901086

10911087
if (ORE) {
@@ -1354,6 +1350,28 @@ class LowerMatrixIntrinsics {
13541350
return Builder.CreateAdd(Sum, Mul);
13551351
}
13561352

1353+
/// Rewrite the constant-index extractelement \p Inst so that it no longer
/// reads from the flattened vector operand.
///
/// \param Inst  The extractelement instruction to rewrite.
/// \param Index The constant element index \p Inst extracts.
/// \returns true if \p Inst was erased or had its operands rewritten, false
///          if its vector operand has no entry in Inst2ColumnMatrix and the
///          instruction was left untouched.
///
/// Note that on the out-of-bounds path \p Inst is erased immediately, so the
/// caller must not touch it after a true return.
bool VisitExtractElt(ExtractElementInst *Inst, uint64_t Index) {
  Value *Op0 = Inst->getOperand(0);
  auto *VTy = cast<VectorType>(Op0->getType());

  // Valid indices are [0, NumElts), so Index == NumElts is already out of
  // bounds and must be folded to poison here; otherwise the getVector() call
  // below would be asked for a vector one past the last one. For scalable
  // vectors the known minimum is only a lower bound on the runtime length, so
  // the extract cannot be proven out of bounds statically and is not folded.
  const ElementCount EC = VTy->getElementCount();
  if (!EC.isScalable() && Index >= EC.getKnownMinValue()) {
    Inst->replaceAllUsesWith(PoisonValue::get(VTy->getElementType()));
    Inst->eraseFromParent();
    return true;
  }

  auto *I = Inst2ColumnMatrix.find(Op0);
  if (I == Inst2ColumnMatrix.end())
    return false;

  const MatrixTy &M = I->second;

  // Split the flat index into (which vector of M, which lane inside it) using
  // M's stride, then retarget the extract at that vector directly.
  IRBuilder<> Builder(Inst);
  Inst->setOperand(0, M.getVector(Index / M.getStride()));
  Inst->setOperand(1, Builder.getInt32(Index % M.getStride()));
  return true;
}
1374+
13571375
/// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
13581376
/// users with shape information, there's nothing to do: they will use the
13591377
/// cached value when they are lowered. For other users, \p Matrix is
@@ -1368,13 +1386,18 @@ class LowerMatrixIntrinsics {
13681386
ToRemove.push_back(Inst);
13691387
Value *Flattened = nullptr;
13701388
for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
1371-
if (!ShapeMap.contains(U.getUser())) {
1372-
if (!Flattened) {
1373-
Flattened = Matrix.embedInVector(Builder);
1374-
Inst2ColumnMatrix[Flattened] = Matrix;
1375-
}
1376-
U.set(Flattened);
1377-
}
1389+
if (ShapeMap.contains(U.getUser()))
1390+
continue;
1391+
1392+
Value *Op1;
1393+
uint64_t Index;
1394+
if (match(U.getUser(), m_ExtractElt(m_Value(Op1), m_ConstantInt(Index))))
1395+
if (VisitExtractElt(cast<ExtractElementInst>(U.getUser()), Index))
1396+
continue;
1397+
1398+
if (!Flattened)
1399+
Flattened = Matrix.embedInVector(Builder);
1400+
U.set(Flattened);
13781401
}
13791402
}
13801403

@@ -2149,30 +2172,6 @@ class LowerMatrixIntrinsics {
21492172
return true;
21502173
}
21512174

2152-
bool VisitExtractElt(ExtractElementInst *Inst, uint64_t Index) {
2153-
Value *Op0 = Inst->getOperand(0);
2154-
auto *VTy = cast<VectorType>(Op0->getType());
2155-
2156-
if (VTy->getElementCount().getKnownMinValue() < Index) {
2157-
Inst->replaceAllUsesWith(PoisonValue::get(VTy->getElementType()));
2158-
ToRemove.push_back(Inst);
2159-
return true;
2160-
}
2161-
2162-
auto *I = Inst2ColumnMatrix.find(Op0);
2163-
if (I == Inst2ColumnMatrix.end())
2164-
return false;
2165-
2166-
const MatrixTy &M = I->second;
2167-
2168-
IRBuilder<> Builder(Inst);
2169-
Inst->setOperand(0, M.getVector(Index / M.getStride()));
2170-
Inst->setOperand(1, Builder.getInt32(Index % M.getStride()));
2171-
if (Op0->use_empty() && isa<Instruction>(Op0))
2172-
ToRemove.push_back(cast<Instruction>(Op0));
2173-
return true;
2174-
}
2175-
21762175
/// Lower binary operators, if shape information is available.
21772176
bool VisitBinaryOperator(BinaryOperator *Inst) {
21782177
auto I = ShapeMap.find(Inst);

llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ define float @extract_static_outofbounds(ptr %in, ptr %out) {
2121
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
2222
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
2323
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
24-
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD1]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
2524
; CHECK-NEXT: ret float poison
2625
;
2726
%inv = load <4 x float>, ptr %in

llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts.ll

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,11 @@ define void @multiply_ntt(ptr %A, ptr %B, ptr %C, ptr %R) {
5555
; REMARK-NEXT: Function: multiply_ntt
5656
; REMARK-NEXT: Args:
5757
; REMARK-NEXT: - String: 'Lowered with '
58-
; REMARK-NEXT: - NumStores: '0'
58+
; REMARK-NEXT: - NumStores: '4'
5959
; REMARK-NEXT: - String: ' stores, '
60-
; REMARK-NEXT: - NumLoads: '3'
60+
; REMARK-NEXT: - NumLoads: '10'
6161
; REMARK-NEXT: - String: ' loads, '
62-
; REMARK-NEXT: - NumComputeOps: '0'
62+
; REMARK-NEXT: - NumComputeOps: '38'
6363
; REMARK-NEXT: - String: ' compute ops, '
6464
; REMARK-NEXT: - NumExposedTransposes: '0'
6565
; REMARK-NEXT: - String: ' exposed transposes'
@@ -443,11 +443,11 @@ define void @multiply_nt_t(ptr %A, ptr %B, ptr %C) {
443443
; REMARK-NEXT: Function: multiply_nt_t
444444
; REMARK-NEXT: Args:
445445
; REMARK-NEXT: - String: 'Lowered with '
446-
; REMARK-NEXT: - NumStores: '0'
446+
; REMARK-NEXT: - NumStores: '4'
447447
; REMARK-NEXT: - String: ' stores, '
448-
; REMARK-NEXT: - NumLoads: '3'
448+
; REMARK-NEXT: - NumLoads: '9'
449449
; REMARK-NEXT: - String: ' loads, '
450-
; REMARK-NEXT: - NumComputeOps: '0'
450+
; REMARK-NEXT: - NumComputeOps: '20'
451451
; REMARK-NEXT: - String: ' compute ops, '
452452
; REMARK-NEXT: - NumExposedTransposes: '0'
453453
; REMARK-NEXT: - String: ' exposed transposes'
@@ -578,11 +578,11 @@ define void @multiply_ntt_t(ptr %A, ptr %B, ptr %C, ptr %R) {
578578
; REMARK-NEXT: Function: multiply_ntt_t
579579
; REMARK-NEXT: Args:
580580
; REMARK-NEXT: - String: 'Lowered with '
581-
; REMARK-NEXT: - NumStores: '0'
581+
; REMARK-NEXT: - NumStores: '6'
582582
; REMARK-NEXT: - String: ' stores, '
583-
; REMARK-NEXT: - NumLoads: '6'
583+
; REMARK-NEXT: - NumLoads: '18'
584584
; REMARK-NEXT: - String: ' loads, '
585-
; REMARK-NEXT: - NumComputeOps: '0'
585+
; REMARK-NEXT: - NumComputeOps: '60'
586586
; REMARK-NEXT: - String: ' compute ops, '
587587
; REMARK-NEXT: - NumExposedTransposes: '0'
588588
; REMARK-NEXT: - String: ' exposed transposes'

0 commit comments

Comments
 (0)