Skip to content

Commit 4040d3f

Browse files
committed
[Matrix] Propagate shape information through Select insts
1 parent 79ae407 commit 4040d3f

File tree

2 files changed

+116
-1
lines changed

2 files changed

+116
-1
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,15 @@ computeShapeInfoForInst(Instruction *I,
269269
return OpShape->second;
270270
}
271271

272+
if (isa<SelectInst>(I)) {
273+
auto OpShape = ShapeMap.find(I->getOperand(1));
274+
if (OpShape != ShapeMap.end())
275+
return OpShape->second;
276+
OpShape = ShapeMap.find(I->getOperand(2));
277+
if (OpShape != ShapeMap.end())
278+
return OpShape->second;
279+
}
280+
272281
if (isUniformShape(I)) {
273282
// Find the first operand that has a known shape and use that.
274283
for (auto &Op : I->operands()) {
@@ -623,7 +632,8 @@ class LowerMatrixIntrinsics {
623632
default:
624633
return false;
625634
}
626-
return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
635+
return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V) ||
636+
isa<SelectInst>(V);
627637
}
628638

629639
/// Propagate the shape information of instructions to their users.
@@ -710,6 +720,12 @@ class LowerMatrixIntrinsics {
710720
} else if (isa<StoreInst>(V)) {
711721
// Nothing to do. We forward-propagated to this so we would just
712722
// backward propagate to an instruction with an already known shape.
723+
} else if (auto *Select = dyn_cast<SelectInst>(V)) {
724+
ShapeInfo Shape = ShapeMap[V];
725+
if (setShapeInfo(Select->getOperand(1), Shape))
726+
pushInstruction(Select, WorkList);
727+
if (setShapeInfo(Select->getOperand(2), Shape))
728+
pushInstruction(Select, WorkList);
713729
} else if (isUniformShape(V)) {
714730
// Propagate to all operands.
715731
ShapeInfo Shape = ShapeMap[V];
@@ -1068,6 +1084,8 @@ class LowerMatrixIntrinsics {
10681084
Changed |= VisitBinaryOperator(BinOp);
10691085
if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
10701086
Changed |= VisitUnaryOperator(UnOp);
1087+
if (auto *Select = dyn_cast<SelectInst>(Inst))
1088+
Changed |= VisitSelectInst(Select);
10711089
if (match(Inst, m_Load(m_Value(Op1))))
10721090
Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
10731091
else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
@@ -2198,6 +2216,35 @@ class LowerMatrixIntrinsics {
21982216
return true;
21992217
}
22002218

2219+
/// Lower selects, if shape information is available.
2220+
bool VisitSelectInst(SelectInst *Inst) {
2221+
auto I = ShapeMap.find(Inst);
2222+
if (I == ShapeMap.end())
2223+
return false;
2224+
2225+
Value *Cond = Inst->getOperand(0);
2226+
Value *OpA = Inst->getOperand(1);
2227+
Value *OpB = Inst->getOperand(2);
2228+
2229+
IRBuilder<> Builder(Inst);
2230+
ShapeInfo &Shape = I->second;
2231+
2232+
MatrixTy Result;
2233+
MatrixTy A = getMatrix(OpA, Shape, Builder);
2234+
MatrixTy B = getMatrix(OpB, Shape, Builder);
2235+
2236+
for (unsigned I = 0; I < Shape.getNumVectors(); ++I) {
2237+
auto *Sel = Builder.CreateSelect(Cond, A.getVector(I), B.getVector(I));
2238+
Result.addVector(Sel);
2239+
}
2240+
2241+
finalizeLowering(Inst,
2242+
Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2243+
Result.getNumVectors()),
2244+
Builder);
2245+
return true;
2246+
}
2247+
22012248
/// Helper to linearize a matrix expression tree into a string. Currently
22022249
/// matrix expressions are linarized by starting at an expression leaf and
22032250
/// linearizing bottom up.
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
2+
; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
3+
4+
define void @select_2x2_bot(i1 %cond, ptr %lhs, ptr %rhs, ptr %out) {
5+
; CHECK-LABEL: @select_2x2_bot(
6+
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[LHS:%.*]], align 16
7+
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 2
8+
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
9+
; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <2 x float>, ptr [[RHS:%.*]], align 16
10+
; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr float, ptr [[RHS]], i64 2
11+
; CHECK-NEXT: [[COL_LOAD4:%.*]] = load <2 x float>, ptr [[VEC_GEP3]], align 8
12+
; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[COND:%.*]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD2]]
13+
; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[COND]], <2 x float> [[COL_LOAD1]], <2 x float> [[COL_LOAD4]]
14+
; CHECK-NEXT: store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 4
15+
; CHECK-NEXT: [[VEC_GEP5:%.*]] = getelementptr float, ptr [[OUT]], i64 2
16+
; CHECK-NEXT: store <2 x float> [[TMP2]], ptr [[VEC_GEP5]], align 4
17+
; CHECK-NEXT: ret void
18+
;
19+
%lhsv = load <4 x float>, ptr %lhs
20+
%rhsv = load <4 x float>, ptr %rhs
21+
%op = select i1 %cond, <4 x float> %lhsv, <4 x float> %rhsv
22+
call void @llvm.matrix.column.major.store(<4 x float> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
23+
ret void
24+
}
25+
26+
define void @select_2x2_lhs(i1 %cond, ptr %lhs, ptr %rhs, ptr %out) {
27+
; CHECK-LABEL: @select_2x2_lhs(
28+
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[LHS:%.*]], align 4
29+
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 2
30+
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 4
31+
; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <2 x float>, ptr [[RHS:%.*]], align 16
32+
; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr float, ptr [[RHS]], i64 2
33+
; CHECK-NEXT: [[COL_LOAD4:%.*]] = load <2 x float>, ptr [[VEC_GEP3]], align 8
34+
; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[COND:%.*]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD2]]
35+
; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[COND]], <2 x float> [[COL_LOAD1]], <2 x float> [[COL_LOAD4]]
36+
; CHECK-NEXT: store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
37+
; CHECK-NEXT: [[VEC_GEP5:%.*]] = getelementptr float, ptr [[OUT]], i64 2
38+
; CHECK-NEXT: store <2 x float> [[TMP2]], ptr [[VEC_GEP5]], align 8
39+
; CHECK-NEXT: ret void
40+
;
41+
%lhsv = call <4 x float> @llvm.matrix.column.major.load(ptr %lhs, i64 2, i1 false, i32 2, i32 2)
42+
%rhsv = load <4 x float>, ptr %rhs
43+
%op = select i1 %cond, <4 x float> %lhsv, <4 x float> %rhsv
44+
store <4 x float> %op, ptr %out
45+
ret void
46+
}
47+
48+
define void @select_2x2_rhs(i1 %cond, ptr %lhs, ptr %rhs, ptr %out) {
49+
; CHECK-LABEL: @select_2x2_rhs(
50+
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[RHS:%.*]], align 16
51+
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[RHS]], i64 2
52+
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
53+
; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <2 x float>, ptr [[RHS1:%.*]], align 4
54+
; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr float, ptr [[RHS1]], i64 2
55+
; CHECK-NEXT: [[COL_LOAD4:%.*]] = load <2 x float>, ptr [[VEC_GEP3]], align 4
56+
; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[COND:%.*]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD2]]
57+
; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[COND]], <2 x float> [[COL_LOAD1]], <2 x float> [[COL_LOAD4]]
58+
; CHECK-NEXT: store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
59+
; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr float, ptr [[OUT]], i64 2
60+
; CHECK-NEXT: store <2 x float> [[TMP2]], ptr [[VEC_GEP2]], align 8
61+
; CHECK-NEXT: ret void
62+
;
63+
%lhsv = load <4 x float>, ptr %lhs
64+
%rhsv = call <4 x float> @llvm.matrix.column.major.load(ptr %rhs, i64 2, i1 false, i32 2, i32 2)
65+
%op = select i1 %cond, <4 x float> %lhsv, <4 x float> %rhsv
66+
store <4 x float> %op, ptr %out
67+
ret void
68+
}

0 commit comments

Comments
 (0)