Skip to content

Commit 1f1c725

Browse files
authored
[Matrix] Propagate shape information through all binops (#141705)
They all have vector variants, so the obvious "find and replace" does the trick.
1 parent 6e5f9bb commit 1f1c725

File tree

2 files changed

+439
-27
lines changed

2 files changed

+439
-27
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 5 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -229,14 +229,11 @@ static bool isUniformShape(Value *V) {
229229
if (!I)
230230
return true;
231231

232+
if (I->isBinaryOp())
233+
return true;
234+
232235
switch (I->getOpcode()) {
233-
case Instruction::FAdd:
234-
case Instruction::FSub:
235-
case Instruction::FMul: // Scalar multiply.
236236
case Instruction::FNeg:
237-
case Instruction::Add:
238-
case Instruction::Mul:
239-
case Instruction::Sub:
240237
return true;
241238
default:
242239
return false;
@@ -2154,28 +2151,9 @@ class LowerMatrixIntrinsics {
21542151

21552152
Builder.setFastMathFlags(getFastMathFlags(Inst));
21562153

2157-
// Helper to perform binary op on vectors.
2158-
auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
2159-
switch (Inst->getOpcode()) {
2160-
case Instruction::Add:
2161-
return Builder.CreateAdd(LHS, RHS);
2162-
case Instruction::Mul:
2163-
return Builder.CreateMul(LHS, RHS);
2164-
case Instruction::Sub:
2165-
return Builder.CreateSub(LHS, RHS);
2166-
case Instruction::FAdd:
2167-
return Builder.CreateFAdd(LHS, RHS);
2168-
case Instruction::FMul:
2169-
return Builder.CreateFMul(LHS, RHS);
2170-
case Instruction::FSub:
2171-
return Builder.CreateFSub(LHS, RHS);
2172-
default:
2173-
llvm_unreachable("Unsupported binary operator for matrix");
2174-
}
2175-
};
2176-
21772154
for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
2178-
Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));
2155+
Result.addVector(Builder.CreateBinOp(Inst->getOpcode(), A.getVector(I),
2156+
B.getVector(I)));
21792157

21802158
finalizeLowering(Inst,
21812159
Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *

0 commit comments

Comments
 (0)