Skip to content

Commit 637f71d

Browse files
committed
Merge branch 'jroelofs/lower-matrix-fdiv' into jroelofs/lower-matrix-fabs
2 parents 41eeb1e + 39c585c commit 637f71d

File tree

2 files changed

+416
-33
lines changed

2 files changed

+416
-33
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 5 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,9 @@ static bool isUniformShape(Value *V) {
229229
if (!I)
230230
return true;
231231

232+
if (I->isBinaryOp())
233+
return true;
234+
232235
if (auto *II = dyn_cast<IntrinsicInst>(V))
233236
switch (II->getIntrinsicID()) {
234237
case Intrinsic::abs:
@@ -239,14 +242,7 @@ static bool isUniformShape(Value *V) {
239242
}
240243

241244
switch (I->getOpcode()) {
242-
case Instruction::FAdd:
243-
case Instruction::FSub:
244-
case Instruction::FMul: // Scalar multiply.
245-
case Instruction::FDiv:
246245
case Instruction::FNeg:
247-
case Instruction::Add:
248-
case Instruction::Mul:
249-
case Instruction::Sub:
250246
return true;
251247
default:
252248
return false;
@@ -2167,30 +2163,9 @@ class LowerMatrixIntrinsics {
21672163

21682164
Builder.setFastMathFlags(getFastMathFlags(Inst));
21692165

2170-
// Helper to perform binary op on vectors.
2171-
auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
2172-
switch (Inst->getOpcode()) {
2173-
case Instruction::Add:
2174-
return Builder.CreateAdd(LHS, RHS);
2175-
case Instruction::Mul:
2176-
return Builder.CreateMul(LHS, RHS);
2177-
case Instruction::Sub:
2178-
return Builder.CreateSub(LHS, RHS);
2179-
case Instruction::FAdd:
2180-
return Builder.CreateFAdd(LHS, RHS);
2181-
case Instruction::FMul:
2182-
return Builder.CreateFMul(LHS, RHS);
2183-
case Instruction::FDiv:
2184-
return Builder.CreateFDiv(LHS, RHS);
2185-
case Instruction::FSub:
2186-
return Builder.CreateFSub(LHS, RHS);
2187-
default:
2188-
llvm_unreachable("Unsupported binary operator for matrix");
2189-
}
2190-
};
2191-
21922166
for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
2193-
Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));
2167+
Result.addVector(Builder.CreateBinOp(Inst->getOpcode(), A.getVector(I),
2168+
B.getVector(I)));
21942169

21952170
finalizeLowering(Inst,
21962171
Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *

0 commit comments

Comments
 (0)