
Commit e43d64e
Authored by NexMingyanming and yanming

[RISCV] Sink vp.splat operands of VP intrinsic. (#133245)

This patch teaches the VP operand-sinking code to match `vp.splat`, sinking the `vp.splat` operand of a VP operation back into the same basic block as the VP operation. This facilitates the generation of .vx instructions and reduces vector register pressure.

Co-authored-by: yanming <ming.yan@terapines.com>

Parent commit: 031101c
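
To illustrate the intended effect, here is a minimal, hypothetical IR sketch (not taken from this patch; the function name, types, and values are illustrative). When a `vp.splat` defined away from its VP user is sunk to sit beside that user, instruction selection can fold the scalar operand into a .vx instruction instead of keeping a full vector register live across blocks:

define <vscale x 4 x i32> @sketch(<vscale x 4 x i32> %v, i32 %x, i32 %evl) {
entry:
  ; Before sinking: the splat is defined here, away from its user, so its
  ; result would have to be carried across the branch in a vector register.
  %s = call <vscale x 4 x i32> @llvm.experimental.vp.splat.nxv4i32(i32 %x, <vscale x 4 x i1> splat (i1 true), i32 %evl)
  br label %body

body:
  ; Sinking moves the vp.splat down into this block, immediately before its
  ; user, so ISel can select vadd.vx and keep %x in a scalar register.
  %r = call <vscale x 4 x i32> @llvm.vp.add.nxv4i32(<vscale x 4 x i32> %v, <vscale x 4 x i32> %s, <vscale x 4 x i1> splat (i1 true), i32 %evl)
  ret <vscale x 4 x i32> %r
}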

File tree: 2 files changed, 144 insertions(+), 7 deletions(-)


llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp

Lines changed: 15 additions & 7 deletions

@@ -2868,8 +2868,11 @@ bool RISCVTTIImpl::isProfitableToSinkOperands(
     if (!Op || any_of(Ops, [&](Use *U) { return U->get() == Op; }))
       continue;
 
-    // We are looking for a splat that can be sunk.
-    if (!match(Op, m_Shuffle(m_InsertElt(m_Undef(), m_Value(), m_ZeroInt()),
+    // We are looking for a splat/vp.splat that can be sunk.
+    bool IsVPSplat = match(Op, m_Intrinsic<Intrinsic::experimental_vp_splat>(
+                                   m_Value(), m_Value(), m_Value()));
+    if (!IsVPSplat &&
+        !match(Op, m_Shuffle(m_InsertElt(m_Undef(), m_Value(), m_ZeroInt()),
                              m_Undef(), m_ZeroMask())))
       continue;
 
@@ -2885,12 +2888,17 @@ bool RISCVTTIImpl::isProfitableToSinkOperands(
         return false;
     }
 
-    Use *InsertEltUse = &Op->getOperandUse(0);
     // Sink any fpexts since they might be used in a widening fp pattern.
-    auto *InsertElt = cast<InsertElementInst>(InsertEltUse);
-    if (isa<FPExtInst>(InsertElt->getOperand(1)))
-      Ops.push_back(&InsertElt->getOperandUse(1));
-    Ops.push_back(InsertEltUse);
+    if (IsVPSplat) {
+      if (isa<FPExtInst>(Op->getOperand(0)))
+        Ops.push_back(&Op->getOperandUse(0));
+    } else {
+      Use *InsertEltUse = &Op->getOperandUse(0);
+      auto *InsertElt = cast<InsertElementInst>(InsertEltUse);
+      if (isa<FPExtInst>(InsertElt->getOperand(1)))
+        Ops.push_back(&InsertElt->getOperandUse(1));
+      Ops.push_back(InsertEltUse);
+    }
     Ops.push_back(&OpIdx.value());
   }
   return true;
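
For reference, a hedged sketch of the two splat forms the matcher now recognizes (illustrative names and types, not taken from the patch). An ordinary splat is an `insertelement` into lane 0 broadcast by a zero-mask `shufflevector`, while `vp.splat` is a single intrinsic taking (scalar, mask, EVL); since the scalar is operand 0 of the intrinsic, the fpext check for widening patterns reads `Op->getOperand(0)` in that case:

define void @splat_forms(float %f, <vscale x 4 x i1> %m, i32 %evl) {
  ; Ordinary splat, matched by m_Shuffle(m_InsertElt(...), m_Undef(), m_ZeroMask()).
  %ins = insertelement <vscale x 4 x float> poison, float %f, i64 0
  %splat = shufflevector <vscale x 4 x float> %ins, <vscale x 4 x float> poison, <vscale x 4 x i32> zeroinitializer
  ; vp.splat, matched by m_Intrinsic<Intrinsic::experimental_vp_splat>.
  %vpsplat = call <vscale x 4 x float> @llvm.experimental.vp.splat.nxv4f32(float %f, <vscale x 4 x i1> %m, i32 %evl)
  ret void
}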

llvm/test/CodeGen/RISCV/rvv/sink-splat-operands.ll

Lines changed: 129 additions & 0 deletions

@@ -5890,3 +5890,132 @@ vector.body: ; preds = %vector.body, %entry
 for.cond.cleanup: ; preds = %vector.body
   ret void
 }
+
+define void @sink_vp_splat(ptr nocapture %out, ptr nocapture %in) {
+; CHECK-LABEL: sink_vp_splat:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    li a2, 0
+; CHECK-NEXT:    li a3, 1024
+; CHECK-NEXT:    li a4, 3
+; CHECK-NEXT:    lui a5, 1
+; CHECK-NEXT:  .LBB129_1: # %vector.body
+; CHECK-NEXT:    # =>This Loop Header: Depth=1
+; CHECK-NEXT:    # Child Loop BB129_2 Depth 2
+; CHECK-NEXT:    vsetvli a6, a3, e32, m4, ta, ma
+; CHECK-NEXT:    slli a7, a2, 2
+; CHECK-NEXT:    vmv.v.i v8, 0
+; CHECK-NEXT:    add t0, a1, a7
+; CHECK-NEXT:    li t1, 1024
+; CHECK-NEXT:  .LBB129_2: # %for.body424
+; CHECK-NEXT:    # Parent Loop BB129_1 Depth=1
+; CHECK-NEXT:    # => This Inner Loop Header: Depth=2
+; CHECK-NEXT:    vle32.v v12, (t0)
+; CHECK-NEXT:    addi t1, t1, -1
+; CHECK-NEXT:    vmacc.vx v8, a4, v12
+; CHECK-NEXT:    add t0, t0, a5
+; CHECK-NEXT:    bnez t1, .LBB129_2
+; CHECK-NEXT:  # %bb.3: # %vector.latch
+; CHECK-NEXT:    # in Loop: Header=BB129_1 Depth=1
+; CHECK-NEXT:    add a7, a0, a7
+; CHECK-NEXT:    sub a3, a3, a6
+; CHECK-NEXT:    vse32.v v8, (a7)
+; CHECK-NEXT:    add a2, a2, a6
+; CHECK-NEXT:    bnez a3, .LBB129_1
+; CHECK-NEXT:  # %bb.4: # %for.cond.cleanup
+; CHECK-NEXT:    ret
+entry:
+  br label %vector.body
+
+vector.body: ; preds = %vector.latch, %entry
+  %scalar.ind = phi i64 [ 0, %entry ], [ %next.ind, %vector.latch ]
+  %trip.count = phi i64 [ 1024, %entry ], [ %remaining.trip.count, %vector.latch ]
+  %evl = tail call i32 @llvm.experimental.get.vector.length.i64(i64 %trip.count, i32 8, i1 true)
+  %vp.splat1 = tail call <vscale x 8 x i32> @llvm.experimental.vp.splat.nxv8i32(i32 0, <vscale x 8 x i1> splat (i1 true), i32 %evl)
+  %vp.splat2 = tail call <vscale x 8 x i32> @llvm.experimental.vp.splat.nxv8i32(i32 3, <vscale x 8 x i1> splat (i1 true), i32 %evl)
+  %evl.cast = zext i32 %evl to i64
+  br label %for.body424
+
+for.body424: ; preds = %for.body424, %vector.body
+  %scalar.phi = phi i64 [ 0, %vector.body ], [ %indvars.iv.next27, %for.body424 ]
+  %vector.phi = phi <vscale x 8 x i32> [ %vp.splat1, %vector.body ], [ %vp.binary26, %for.body424 ]
+  %arrayidx625 = getelementptr inbounds [1024 x i32], ptr %in, i64 %scalar.phi, i64 %scalar.ind
+  %widen.load = tail call <vscale x 8 x i32> @llvm.vp.load.nxv8i32.p0(ptr %arrayidx625, <vscale x 8 x i1> splat (i1 true), i32 %evl)
+  %vp.binary = tail call <vscale x 8 x i32> @llvm.vp.mul.nxv8i32(<vscale x 8 x i32> %widen.load, <vscale x 8 x i32> %vp.splat2, <vscale x 8 x i1> splat (i1 true), i32 %evl)
+  %vp.binary26 = tail call <vscale x 8 x i32> @llvm.vp.add.nxv8i32(<vscale x 8 x i32> %vector.phi, <vscale x 8 x i32> %vp.binary, <vscale x 8 x i1> splat (i1 true), i32 %evl)
+  %indvars.iv.next27 = add nuw nsw i64 %scalar.phi, 1
+  %exitcond.not28 = icmp eq i64 %indvars.iv.next27, 1024
+  br i1 %exitcond.not28, label %vector.latch, label %for.body424
+
+vector.latch: ; preds = %for.body424
+  %arrayidx830 = getelementptr inbounds i32, ptr %out, i64 %scalar.ind
+  tail call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %vp.binary26, ptr %arrayidx830, <vscale x 8 x i1> splat (i1 true), i32 %evl)
+  %remaining.trip.count = sub nuw i64 %trip.count, %evl.cast
+  %next.ind = add i64 %scalar.ind, %evl.cast
+  %6 = icmp eq i64 %remaining.trip.count, 0
+  br i1 %6, label %for.cond.cleanup, label %vector.body
+
+for.cond.cleanup: ; preds = %vector.latch
+  ret void
+}
+
+define void @sink_vp_splat_vfwadd_wf(ptr nocapture %in, float %f) {
+; CHECK-LABEL: sink_vp_splat_vfwadd_wf:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    li a1, 0
+; CHECK-NEXT:    li a2, 1024
+; CHECK-NEXT:    lui a3, 2
+; CHECK-NEXT:  .LBB130_1: # %vector.body
+; CHECK-NEXT:    # =>This Loop Header: Depth=1
+; CHECK-NEXT:    # Child Loop BB130_2 Depth 2
+; CHECK-NEXT:    vsetvli a4, a2, e8, m1, ta, ma
+; CHECK-NEXT:    slli a5, a1, 3
+; CHECK-NEXT:    add a5, a0, a5
+; CHECK-NEXT:    li a6, 1024
+; CHECK-NEXT:  .LBB130_2: # %for.body419
+; CHECK-NEXT:    # Parent Loop BB130_1 Depth=1
+; CHECK-NEXT:    # => This Inner Loop Header: Depth=2
+; CHECK-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
+; CHECK-NEXT:    vle64.v v8, (a5)
+; CHECK-NEXT:    addi a6, a6, -1
+; CHECK-NEXT:    vfwadd.wf v8, v8, fa0
+; CHECK-NEXT:    vse64.v v8, (a5)
+; CHECK-NEXT:    add a5, a5, a3
+; CHECK-NEXT:    bnez a6, .LBB130_2
+; CHECK-NEXT:  # %bb.3: # %vector.latch
+; CHECK-NEXT:    # in Loop: Header=BB130_1 Depth=1
+; CHECK-NEXT:    sub a2, a2, a4
+; CHECK-NEXT:    add a1, a1, a4
+; CHECK-NEXT:    bnez a2, .LBB130_1
+; CHECK-NEXT:  # %bb.4: # %for.cond.cleanup
+; CHECK-NEXT:    ret
+entry:
+  %conv = fpext float %f to double
+  br label %vector.body
+
+vector.body: ; preds = %vector.latch, %entry
+  %scalar.ind = phi i64 [ 0, %entry ], [ %next.ind, %vector.latch ]
+  %trip.count = phi i64 [ 1024, %entry ], [ %remaining.trip.count, %vector.latch ]
+  %evl = call i32 @llvm.experimental.get.vector.length.i64(i64 %trip.count, i32 8, i1 true)
+  %vp.splat = call <vscale x 8 x double> @llvm.experimental.vp.splat.nxv8f64(double %conv, <vscale x 8 x i1> splat (i1 true), i32 %evl)
+  %evl.cast = zext i32 %evl to i64
+  br label %for.body419
+
+for.body419: ; preds = %for.body419, %vector.body
+  %scalar.phi = phi i64 [ 0, %vector.body ], [ %indvars.iv.next21, %for.body419 ]
+  %arrayidx620 = getelementptr inbounds [1024 x double], ptr %in, i64 %scalar.phi, i64 %scalar.ind
+  %widen.load = call <vscale x 8 x double> @llvm.vp.load.nxv8f64.p0(ptr %arrayidx620, <vscale x 8 x i1> splat (i1 true), i32 %evl)
+  %vp.binary = call <vscale x 8 x double> @llvm.vp.fadd.nxv8f64(<vscale x 8 x double> %widen.load, <vscale x 8 x double> %vp.splat, <vscale x 8 x i1> splat (i1 true), i32 %evl)
+  call void @llvm.vp.store.nxv8f64.p0(<vscale x 8 x double> %vp.binary, ptr %arrayidx620, <vscale x 8 x i1> splat (i1 true), i32 %evl)
+  %indvars.iv.next21 = add nuw nsw i64 %scalar.phi, 1
+  %exitcond.not22 = icmp eq i64 %indvars.iv.next21, 1024
+  br i1 %exitcond.not22, label %vector.latch, label %for.body419
+
+vector.latch: ; preds = %for.body419
+  %remaining.trip.count = sub nuw i64 %trip.count, %evl.cast
+  %next.ind = add i64 %scalar.ind, %evl.cast
+  %cond = icmp eq i64 %remaining.trip.count, 0
+  br i1 %cond, label %for.cond.cleanup, label %vector.body
+
+for.cond.cleanup: ; preds = %vector.latch
+  ret void
+}
