@@ -229,14 +229,11 @@ static bool isUniformShape(Value *V) {
229
229
if (!I)
230
230
return true ;
231
231
232
+ if (I->isBinaryOp ())
233
+ return true ;
234
+
232
235
switch (I->getOpcode ()) {
233
- case Instruction::FAdd:
234
- case Instruction::FSub:
235
- case Instruction::FMul: // Scalar multiply.
236
236
case Instruction::FNeg:
237
- case Instruction::Add:
238
- case Instruction::Mul:
239
- case Instruction::Sub:
240
237
return true ;
241
238
default :
242
239
return false ;
@@ -2154,28 +2151,9 @@ class LowerMatrixIntrinsics {
2154
2151
2155
2152
Builder.setFastMathFlags (getFastMathFlags (Inst));
2156
2153
2157
- // Helper to perform binary op on vectors.
2158
- auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
2159
- switch (Inst->getOpcode ()) {
2160
- case Instruction::Add:
2161
- return Builder.CreateAdd (LHS, RHS);
2162
- case Instruction::Mul:
2163
- return Builder.CreateMul (LHS, RHS);
2164
- case Instruction::Sub:
2165
- return Builder.CreateSub (LHS, RHS);
2166
- case Instruction::FAdd:
2167
- return Builder.CreateFAdd (LHS, RHS);
2168
- case Instruction::FMul:
2169
- return Builder.CreateFMul (LHS, RHS);
2170
- case Instruction::FSub:
2171
- return Builder.CreateFSub (LHS, RHS);
2172
- default :
2173
- llvm_unreachable (" Unsupported binary operator for matrix" );
2174
- }
2175
- };
2176
-
2177
2154
for (unsigned I = 0 ; I < Shape.getNumVectors (); ++I)
2178
- Result.addVector (BuildVectorOp (A.getVector (I), B.getVector (I)));
2155
+ Result.addVector (Builder.CreateBinOp (Inst->getOpcode (), A.getVector (I),
2156
+ B.getVector (I)));
2179
2157
2180
2158
finalizeLowering (Inst,
2181
2159
Result.addNumComputeOps (getNumOps (Result.getVectorTy ()) *
0 commit comments