@@ -229,6 +229,9 @@ static bool isUniformShape(Value *V) {
229
229
if (!I)
230
230
return true ;
231
231
232
+ if (I->isBinaryOp ())
233
+ return true ;
234
+
232
235
if (auto *II = dyn_cast<IntrinsicInst>(V))
233
236
switch (II->getIntrinsicID ()) {
234
237
case Intrinsic::abs:
@@ -239,14 +242,7 @@ static bool isUniformShape(Value *V) {
239
242
}
240
243
241
244
switch (I->getOpcode ()) {
242
- case Instruction::FAdd:
243
- case Instruction::FSub:
244
- case Instruction::FMul: // Scalar multiply.
245
- case Instruction::FDiv:
246
245
case Instruction::FNeg:
247
- case Instruction::Add:
248
- case Instruction::Mul:
249
- case Instruction::Sub:
250
246
return true ;
251
247
default :
252
248
return false ;
@@ -2167,30 +2163,9 @@ class LowerMatrixIntrinsics {
2167
2163
2168
2164
Builder.setFastMathFlags (getFastMathFlags (Inst));
2169
2165
2170
- // Helper to perform binary op on vectors.
2171
- auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
2172
- switch (Inst->getOpcode ()) {
2173
- case Instruction::Add:
2174
- return Builder.CreateAdd (LHS, RHS);
2175
- case Instruction::Mul:
2176
- return Builder.CreateMul (LHS, RHS);
2177
- case Instruction::Sub:
2178
- return Builder.CreateSub (LHS, RHS);
2179
- case Instruction::FAdd:
2180
- return Builder.CreateFAdd (LHS, RHS);
2181
- case Instruction::FMul:
2182
- return Builder.CreateFMul (LHS, RHS);
2183
- case Instruction::FDiv:
2184
- return Builder.CreateFDiv (LHS, RHS);
2185
- case Instruction::FSub:
2186
- return Builder.CreateFSub (LHS, RHS);
2187
- default :
2188
- llvm_unreachable (" Unsupported binary operator for matrix" );
2189
- }
2190
- };
2191
-
2192
2166
for (unsigned I = 0 ; I < Shape.getNumVectors (); ++I)
2193
- Result.addVector (BuildVectorOp (A.getVector (I), B.getVector (I)));
2167
+ Result.addVector (Builder.CreateBinOp (Inst->getOpcode (), A.getVector (I),
2168
+ B.getVector (I)));
2194
2169
2195
2170
finalizeLowering (Inst,
2196
2171
Result.addNumComputeOps (getNumOps (Result.getVectorTy ()) *
0 commit comments