@@ -560,7 +560,6 @@ class LowerMatrixIntrinsics {
560
560
return M;
561
561
562
562
MatrixVal = M.embedInVector (Builder);
563
- Inst2ColumnMatrix[MatrixVal] = M;
564
563
}
565
564
566
565
// Otherwise split MatrixVal.
@@ -1082,10 +1081,7 @@ class LowerMatrixIntrinsics {
1082
1081
for (Instruction &Inst : *BB) {
1083
1082
IRBuilder<> Builder (&Inst);
1084
1083
1085
- Value *Op1;
1086
- uint64_t Index;
1087
- if (match (&Inst, m_ExtractElt (m_Value (Op1), m_ConstantInt (Index))))
1088
- Changed |= VisitExtractElt (cast<ExtractElementInst>(&Inst), Index);
1084
+
1089
1085
}
1090
1086
1091
1087
if (ORE) {
@@ -1354,6 +1350,28 @@ class LowerMatrixIntrinsics {
1354
1350
return Builder.CreateAdd (Sum, Mul);
1355
1351
}
1356
1352
1353
+ bool VisitExtractElt (ExtractElementInst *Inst, uint64_t Index) {
1354
+ Value *Op0 = Inst->getOperand (0 );
1355
+ auto *VTy = cast<VectorType>(Op0->getType ());
1356
+
1357
+ if (VTy->getElementCount ().getKnownMinValue () < Index) {
1358
+ Inst->replaceAllUsesWith (PoisonValue::get (VTy->getElementType ()));
1359
+ Inst->eraseFromParent ();
1360
+ return true ;
1361
+ }
1362
+
1363
+ auto *I = Inst2ColumnMatrix.find (Op0);
1364
+ if (I == Inst2ColumnMatrix.end ())
1365
+ return false ;
1366
+
1367
+ const MatrixTy &M = I->second ;
1368
+
1369
+ IRBuilder<> Builder (Inst);
1370
+ Inst->setOperand (0 , M.getVector (Index / M.getStride ()));
1371
+ Inst->setOperand (1 , Builder.getInt32 (Index % M.getStride ()));
1372
+ return true ;
1373
+ }
1374
+
1357
1375
// / Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
1358
1376
// / users with shape information, there's nothing to do: they will use the
1359
1377
// / cached value when they are lowered. For other users, \p Matrix is
@@ -1368,13 +1386,18 @@ class LowerMatrixIntrinsics {
1368
1386
ToRemove.push_back (Inst);
1369
1387
Value *Flattened = nullptr ;
1370
1388
for (Use &U : llvm::make_early_inc_range (Inst->uses ())) {
1371
- if (!ShapeMap.contains (U.getUser ())) {
1372
- if (!Flattened) {
1373
- Flattened = Matrix.embedInVector (Builder);
1374
- Inst2ColumnMatrix[Flattened] = Matrix;
1375
- }
1376
- U.set (Flattened);
1377
- }
1389
+ if (ShapeMap.contains (U.getUser ()))
1390
+ continue ;
1391
+
1392
+ Value *Op1;
1393
+ uint64_t Index;
1394
+ if (match (U.getUser (), m_ExtractElt (m_Value (Op1), m_ConstantInt (Index))))
1395
+ if (VisitExtractElt (cast<ExtractElementInst>(U.getUser ()), Index))
1396
+ continue ;
1397
+
1398
+ if (!Flattened)
1399
+ Flattened = Matrix.embedInVector (Builder);
1400
+ U.set (Flattened);
1378
1401
}
1379
1402
}
1380
1403
@@ -2149,30 +2172,6 @@ class LowerMatrixIntrinsics {
2149
2172
return true ;
2150
2173
}
2151
2174
2152
- bool VisitExtractElt (ExtractElementInst *Inst, uint64_t Index) {
2153
- Value *Op0 = Inst->getOperand (0 );
2154
- auto *VTy = cast<VectorType>(Op0->getType ());
2155
-
2156
- if (VTy->getElementCount ().getKnownMinValue () < Index) {
2157
- Inst->replaceAllUsesWith (PoisonValue::get (VTy->getElementType ()));
2158
- ToRemove.push_back (Inst);
2159
- return true ;
2160
- }
2161
-
2162
- auto *I = Inst2ColumnMatrix.find (Op0);
2163
- if (I == Inst2ColumnMatrix.end ())
2164
- return false ;
2165
-
2166
- const MatrixTy &M = I->second ;
2167
-
2168
- IRBuilder<> Builder (Inst);
2169
- Inst->setOperand (0 , M.getVector (Index / M.getStride ()));
2170
- Inst->setOperand (1 , Builder.getInt32 (Index % M.getStride ()));
2171
- if (Op0->use_empty () && isa<Instruction>(Op0))
2172
- ToRemove.push_back (cast<Instruction>(Op0));
2173
- return true ;
2174
- }
2175
-
2176
2175
// / Lower binary operators, if shape information is available.
2177
2176
bool VisitBinaryOperator (BinaryOperator *Inst) {
2178
2177
auto I = ShapeMap.find (Inst);
0 commit comments