@@ -269,6 +269,15 @@ computeShapeInfoForInst(Instruction *I,
269
269
return OpShape->second ;
270
270
}
271
271
272
+ if (isa<SelectInst>(I)) {
273
+ auto OpShape = ShapeMap.find (I->getOperand (1 ));
274
+ if (OpShape != ShapeMap.end ())
275
+ return OpShape->second ;
276
+ OpShape = ShapeMap.find (I->getOperand (2 ));
277
+ if (OpShape != ShapeMap.end ())
278
+ return OpShape->second ;
279
+ }
280
+
272
281
if (isUniformShape (I)) {
273
282
// Find the first operand that has a known shape and use that.
274
283
for (auto &Op : I->operands ()) {
@@ -623,7 +632,7 @@ class LowerMatrixIntrinsics {
623
632
default :
624
633
return false ;
625
634
}
626
- return isUniformShape (V) || isa<StoreInst>(V) || isa<LoadInst>(V);
635
+ return isUniformShape (V) || isa<StoreInst>(V) || isa<LoadInst>(V) || isa<SelectInst>(V) ;
627
636
}
628
637
629
638
// / Propagate the shape information of instructions to their users.
@@ -710,6 +719,12 @@ class LowerMatrixIntrinsics {
710
719
} else if (isa<StoreInst>(V)) {
711
720
// Nothing to do. We forward-propagated to this so we would just
712
721
// backward propagate to an instruction with an already known shape.
722
+ } else if (auto *Select = dyn_cast<SelectInst>(V)) {
723
+ ShapeInfo Shape = ShapeMap[V];
724
+ if (setShapeInfo (Select->getOperand (1 ), Shape))
725
+ pushInstruction (Select, WorkList);
726
+ if (setShapeInfo (Select->getOperand (2 ), Shape))
727
+ pushInstruction (Select, WorkList);
713
728
} else if (isUniformShape (V)) {
714
729
// Propagate to all operands.
715
730
ShapeInfo Shape = ShapeMap[V];
@@ -1068,6 +1083,8 @@ class LowerMatrixIntrinsics {
1068
1083
Changed |= VisitBinaryOperator (BinOp);
1069
1084
if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
1070
1085
Changed |= VisitUnaryOperator (UnOp);
1086
+ if (auto *Select = dyn_cast<SelectInst>(Inst))
1087
+ Changed |= VisitSelectInst (Select);
1071
1088
if (match (Inst, m_Load (m_Value (Op1))))
1072
1089
Changed |= VisitLoad (cast<LoadInst>(Inst), Op1, Builder);
1073
1090
else if (match (Inst, m_Store (m_Value (Op1), m_Value (Op2))))
@@ -2198,6 +2215,35 @@ class LowerMatrixIntrinsics {
2198
2215
return true ;
2199
2216
}
2200
2217
2218
+ // / Lower selects, if shape information is available.
2219
+ bool VisitSelectInst (SelectInst *Inst) {
2220
+ auto I = ShapeMap.find (Inst);
2221
+ if (I == ShapeMap.end ())
2222
+ return false ;
2223
+
2224
+ Value *Cond = Inst->getOperand (0 );
2225
+ Value *OpA = Inst->getOperand (1 );
2226
+ Value *OpB = Inst->getOperand (2 );
2227
+
2228
+ IRBuilder<> Builder (Inst);
2229
+ ShapeInfo &Shape = I->second ;
2230
+
2231
+ MatrixTy Result;
2232
+ MatrixTy A = getMatrix (OpA, Shape, Builder);
2233
+ MatrixTy B = getMatrix (OpB, Shape, Builder);
2234
+
2235
+ for (unsigned I = 0 ; I < Shape.getNumVectors (); ++I) {
2236
+ auto *Sel = Builder.CreateSelect (Cond, A.getVector (I), B.getVector (I));
2237
+ Result.addVector (Sel);
2238
+ }
2239
+
2240
+ finalizeLowering (Inst,
2241
+ Result.addNumComputeOps (getNumOps (Result.getVectorTy ()) *
2242
+ Result.getNumVectors ()),
2243
+ Builder);
2244
+ return true ;
2245
+ }
2246
+
2201
2247
// / Helper to linearize a matrix expression tree into a string. Currently
2202
2248
// / matrix expressions are linarized by starting at an expression leaf and
2203
2249
// / linearizing bottom up.
0 commit comments