Skip to content

Commit fdc3e42

Browse files
committed
[Matrix] Propagate shape information through Select insts
1 parent 79ae407 commit fdc3e42

File tree

2 files changed

+73
-1
lines changed

2 files changed

+73
-1
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,15 @@ computeShapeInfoForInst(Instruction *I,
269269
return OpShape->second;
270270
}
271271

272+
if (isa<SelectInst>(I)) {
273+
auto OpShape = ShapeMap.find(I->getOperand(1));
274+
if (OpShape != ShapeMap.end())
275+
return OpShape->second;
276+
OpShape = ShapeMap.find(I->getOperand(2));
277+
if (OpShape != ShapeMap.end())
278+
return OpShape->second;
279+
}
280+
272281
if (isUniformShape(I)) {
273282
// Find the first operand that has a known shape and use that.
274283
for (auto &Op : I->operands()) {
@@ -623,7 +632,7 @@ class LowerMatrixIntrinsics {
623632
default:
624633
return false;
625634
}
626-
return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
635+
return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V) || isa<SelectInst>(V);
627636
}
628637

629638
/// Propagate the shape information of instructions to their users.
@@ -710,6 +719,12 @@ class LowerMatrixIntrinsics {
710719
} else if (isa<StoreInst>(V)) {
711720
// Nothing to do. We forward-propagated to this so we would just
712721
// backward propagate to an instruction with an already known shape.
722+
} else if (auto *Select = dyn_cast<SelectInst>(V)) {
723+
ShapeInfo Shape = ShapeMap[V];
724+
if (setShapeInfo(Select->getOperand(1), Shape))
725+
pushInstruction(Select, WorkList);
726+
if (setShapeInfo(Select->getOperand(2), Shape))
727+
pushInstruction(Select, WorkList);
713728
} else if (isUniformShape(V)) {
714729
// Propagate to all operands.
715730
ShapeInfo Shape = ShapeMap[V];
@@ -1068,6 +1083,8 @@ class LowerMatrixIntrinsics {
10681083
Changed |= VisitBinaryOperator(BinOp);
10691084
if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
10701085
Changed |= VisitUnaryOperator(UnOp);
1086+
if (auto *Select = dyn_cast<SelectInst>(Inst))
1087+
Changed |= VisitSelectInst(Select);
10711088
if (match(Inst, m_Load(m_Value(Op1))))
10721089
Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
10731090
else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
@@ -2198,6 +2215,35 @@ class LowerMatrixIntrinsics {
21982215
return true;
21992216
}
22002217

2218+
/// Lower selects, if shape information is available.
2219+
bool VisitSelectInst(SelectInst *Inst) {
2220+
auto I = ShapeMap.find(Inst);
2221+
if (I == ShapeMap.end())
2222+
return false;
2223+
2224+
Value *Cond = Inst->getOperand(0);
2225+
Value *OpA = Inst->getOperand(1);
2226+
Value *OpB = Inst->getOperand(2);
2227+
2228+
IRBuilder<> Builder(Inst);
2229+
ShapeInfo &Shape = I->second;
2230+
2231+
MatrixTy Result;
2232+
MatrixTy A = getMatrix(OpA, Shape, Builder);
2233+
MatrixTy B = getMatrix(OpB, Shape, Builder);
2234+
2235+
for (unsigned I = 0; I < Shape.getNumVectors(); ++I) {
2236+
auto *Sel = Builder.CreateSelect(Cond, A.getVector(I), B.getVector(I));
2237+
Result.addVector(Sel);
2238+
}
2239+
2240+
finalizeLowering(Inst,
2241+
Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2242+
Result.getNumVectors()),
2243+
Builder);
2244+
return true;
2245+
}
2246+
22012247
/// Helper to linearize a matrix expression tree into a string. Currently
22022248
/// matrix expressions are linarized by starting at an expression leaf and
22032249
/// linearizing bottom up.
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
2+
; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
3+
4+
define void @select_2x2(i1 %cond, ptr %lhs, ptr %rhs, ptr %out) {
5+
; CHECK-LABEL: @select_2x2(
6+
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[LHS:%.*]], align 16
7+
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 2
8+
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
9+
; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <2 x float>, ptr [[RHS:%.*]], align 16
10+
; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr float, ptr [[RHS]], i64 2
11+
; CHECK-NEXT: [[COL_LOAD4:%.*]] = load <2 x float>, ptr [[VEC_GEP3]], align 8
12+
; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[COND:%.*]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD2]]
13+
; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[COND]], <2 x float> [[COL_LOAD1]], <2 x float> [[COL_LOAD4]]
14+
; CHECK-NEXT: store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
15+
; CHECK-NEXT: [[VEC_GEP5:%.*]] = getelementptr float, ptr [[OUT]], i64 2
16+
; CHECK-NEXT: store <2 x float> [[TMP2]], ptr [[VEC_GEP5]], align 8
17+
; CHECK-NEXT: ret void
18+
;
19+
%lhsv = load <4 x float>, ptr %lhs
20+
%rhsv = load <4 x float>, ptr %rhs
21+
%op = select i1 %cond, <4 x float> %lhsv, <4 x float> %rhsv
22+
%opt = call <4 x float> @llvm.matrix.transpose(<4 x float> %op, i32 2, i32 2)
23+
%optt = call <4 x float> @llvm.matrix.transpose(<4 x float> %opt, i32 2, i32 2)
24+
store <4 x float> %optt, ptr %out
25+
ret void
26+
}

0 commit comments

Comments
 (0)