
Commit 961e954

[AArch64][SVE] Add more folds to make use of gather/scatter with 32-bit indices
In AArch64ISelLowering.cpp this patch implements the following fold:

  GEP(%ptr, (stepvector(A) + splat(%offset)) << splat(B))
    into
  GEP(%ptr + (%offset << B), stepvector(A << B))

The transform simplifies the index operand so that it can be expressed as i32 elements, which allows a single gather/scatter instruction to be used instead of two.

Patch by Paul Walker (@paulwalker-arm).

Depends on D117900

Differential Revision: https://reviews.llvm.org/D118345
1 parent 8ada962 commit 961e954
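Editor's note: the per-lane address arithmetic behind this fold is easiest to check with plain scalars. The sketch below is not LLVM code; every name and value in it (BasePtr, Scale, Step, Offset, Shift, the lane count) is an illustrative assumption chosen to mirror the description above, with Scale standing in for the scaling that the SVE gather/scatter addressing mode applies implicitly.

// Minimal standalone sketch: check that the rewritten base pointer and stride
// reproduce, for every lane, the address implied by the original
// (stepvector(A) + splat(offset)) << splat(B) index. All values are made up.
#include <cassert>
#include <cstdint>

int main() {
  const int64_t BasePtr = 0x1000; // hypothetical base address
  const int64_t Scale = 4;        // element size applied by the addressing mode
  const int64_t Step = 1;         // step-vector stride A
  const int64_t Offset = 5;       // splatted offset
  const int64_t Shift = 3;        // splatted shift amount B

  for (int64_t Lane = 0; Lane < 16; ++Lane) {
    // Address implied by the original 64-bit index.
    int64_t Original = BasePtr + Scale * ((Step * Lane + Offset) << Shift);
    // Address after the fold: adjusted base plus a simple strided index.
    int64_t NewBase = BasePtr + ((Offset * Scale) << Shift);
    int64_t NewStride = Step << Shift;
    int64_t Folded = NewBase + Scale * (NewStride * Lane);
    assert(Original == Folded && "fold must preserve per-lane addresses");
  }
  return 0;
}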

2 files changed (+106, -0 lines)

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 23 additions & 0 deletions
@@ -16387,6 +16387,29 @@ static bool findMoreOptimalIndexType(const MaskedGatherScatterSDNode *N,
     }
   }
 
+  // Index = shl(step(const) + splat(offset), splat(shift))
+  if (Index.getOpcode() == ISD::SHL &&
+      Index.getOperand(0).getOpcode() == ISD::ADD &&
+      Index.getOperand(0).getOperand(0).getOpcode() == ISD::STEP_VECTOR) {
+    SDValue Add = Index.getOperand(0);
+    SDValue ShiftOp = Index.getOperand(1);
+    SDValue StepOp = Add.getOperand(0);
+    SDValue OffsetOp = Add.getOperand(1);
+    if (auto *Shift =
+            dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(ShiftOp)))
+      if (auto Offset = DAG.getSplatValue(OffsetOp)) {
+        int64_t Step =
+            cast<ConstantSDNode>(StepOp.getOperand(0))->getSExtValue();
+        // Stride does not scale explicitly by 'Scale', because it happens in
+        // the gather/scatter addressing mode.
+        Stride = Step << Shift->getSExtValue();
+        // BasePtr = BasePtr + ((Offset * Scale) << Shift)
+        Offset = DAG.getNode(ISD::MUL, DL, MVT::i64, Offset, N->getScale());
+        Offset = DAG.getNode(ISD::SHL, DL, MVT::i64, Offset, SDValue(Shift, 0));
+        BasePtr = DAG.getNode(ISD::ADD, DL, MVT::i64, BasePtr, Offset);
+      }
+  }
+
 
   // Return early because no supported pattern is found.
   if (Stride == 0)
     return false;
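Editor's note: the reason the fold enables a 32-bit index is that, once the splatted %offset has been folded into the base pointer, the remaining index is a plain step vector whose largest lane value is bounded by the element count times the stride. The sketch below illustrates that kind of bound check under assumptions; it is not the actual logic of findMoreOptimalIndexType, and the helper name strideFitsInI32Index and the parameters MinNumElts and MaxVScale are hypothetical.

// Rough sketch: with a constant stride and an upper bound on vscale, the
// largest lane index is (MinNumElts * MaxVScale - 1) * Stride; the i32-indexed
// gather/scatter form is only safe when that value fits in a signed i32.
#include <cstdint>
#include <iostream>
#include <limits>

bool strideFitsInI32Index(int64_t Stride, unsigned MinNumElts,
                          unsigned MaxVScale) {
  if (Stride == 0) // mirrors "no supported pattern is found"
    return false;
  int64_t LastIndex =
      (static_cast<int64_t>(MinNumElts) * MaxVScale - 1) * Stride;
  return Stride >= std::numeric_limits<int32_t>::min() &&
         Stride <= std::numeric_limits<int32_t>::max() &&
         LastIndex >= std::numeric_limits<int32_t>::min() &&
         LastIndex <= std::numeric_limits<int32_t>::max();
}

int main() {
  // E.g. a <vscale x 4 x ...> access with vscale_range(1, 16) and stride 8.
  std::cout << strideFitsInI32Index(/*Stride=*/8, /*MinNumElts=*/4,
                                    /*MaxVScale=*/16)
            << "\n"; // prints 1: the 32-bit indexed form is usable
  return 0;
}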

llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll

Lines changed: 83 additions & 0 deletions
@@ -201,9 +201,92 @@ define void @scatter_i8_index_stride_too_big(i8* %base, i64 %offset, <vscale x 4
   ret void
 }
 
+; Ensure the resulting load is "vscale x 4" wide, despite the offset giving the
+; impression the gather must be split due to its <vscale x 4 x i64> offset.
+; gather_i8(base, index(offset, 8 * sizeof(i8)))
+define <vscale x 4 x i8> @gather_8i8_index_offset_8([8 x i8]* %base, i64 %offset, <vscale x 4 x i1> %pg) #0 {
+; CHECK-LABEL: gather_8i8_index_offset_8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: add x8, x0, x1, lsl #3
+; CHECK-NEXT: index z0.s, #0, #8
+; CHECK-NEXT: ld1sb { z0.s }, p0/z, [x8, z0.s, sxtw]
+; CHECK-NEXT: ret
+  %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+  %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+  %t2 = add <vscale x 4 x i64> %t1, %step
+  %t3 = getelementptr [8 x i8], [8 x i8]* %base, <vscale x 4 x i64> %t2
+  %t4 = bitcast <vscale x 4 x [8 x i8]*> %t3 to <vscale x 4 x i8*>
+  %load = call <vscale x 4 x i8> @llvm.masked.gather.nxv4i8(<vscale x 4 x i8*> %t4, i32 4, <vscale x 4 x i1> %pg, <vscale x 4 x i8> undef)
+  ret <vscale x 4 x i8> %load
+}
+
+; Ensure the resulting load is "vscale x 4" wide, despite the offset giving the
+; impression the gather must be split due to its <vscale x 4 x i64> offset.
+; gather_f32(base, index(offset, 8 * sizeof(float)))
+define <vscale x 4 x float> @gather_f32_index_offset_8([8 x float]* %base, i64 %offset, <vscale x 4 x i1> %pg) #0 {
+; CHECK-LABEL: gather_f32_index_offset_8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov w8, #32
+; CHECK-NEXT: add x9, x0, x1, lsl #5
+; CHECK-NEXT: index z0.s, #0, w8
+; CHECK-NEXT: ld1w { z0.s }, p0/z, [x9, z0.s, sxtw]
+; CHECK-NEXT: ret
+  %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+  %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+  %t2 = add <vscale x 4 x i64> %t1, %step
+  %t3 = getelementptr [8 x float], [8 x float]* %base, <vscale x 4 x i64> %t2
+  %t4 = bitcast <vscale x 4 x [8 x float]*> %t3 to <vscale x 4 x float*>
+  %load = call <vscale x 4 x float> @llvm.masked.gather.nxv4f32(<vscale x 4 x float*> %t4, i32 4, <vscale x 4 x i1> %pg, <vscale x 4 x float> undef)
+  ret <vscale x 4 x float> %load
+}
+
+; Ensure the resulting store is "vscale x 4" wide, despite the offset giving the
+; impression the scatter must be split due to its <vscale x 4 x i64> offset.
+; scatter_i8(base, index(offset, 8 * sizeof(i8)))
+define void @scatter_i8_index_offset_8([8 x i8]* %base, i64 %offset, <vscale x 4 x i1> %pg, <vscale x 4 x i8> %data) #0 {
+; CHECK-LABEL: scatter_i8_index_offset_8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: add x8, x0, x1, lsl #3
+; CHECK-NEXT: index z1.s, #0, #8
+; CHECK-NEXT: st1b { z0.s }, p0, [x8, z1.s, sxtw]
+; CHECK-NEXT: ret
+  %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+  %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+  %t2 = add <vscale x 4 x i64> %t1, %step
+  %t3 = getelementptr [8 x i8], [8 x i8]* %base, <vscale x 4 x i64> %t2
+  %t4 = bitcast <vscale x 4 x [8 x i8]*> %t3 to <vscale x 4 x i8*>
+  call void @llvm.masked.scatter.nxv4i8(<vscale x 4 x i8> %data, <vscale x 4 x i8*> %t4, i32 2, <vscale x 4 x i1> %pg)
+  ret void
+}
+
+; Ensure the resulting store is "vscale x 4" wide, despite the offset giving the
+; impression the scatter must be split due to its <vscale x 4 x i64> offset.
+; scatter_f16(base, index(offset, 8 * sizeof(half)))
+define void @scatter_f16_index_offset_8([8 x half]* %base, i64 %offset, <vscale x 4 x i1> %pg, <vscale x 4 x half> %data) #0 {
+; CHECK-LABEL: scatter_f16_index_offset_8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov w8, #16
+; CHECK-NEXT: add x9, x0, x1, lsl #4
+; CHECK-NEXT: index z1.s, #0, w8
+; CHECK-NEXT: st1h { z0.s }, p0, [x9, z1.s, sxtw]
+; CHECK-NEXT: ret
+  %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+  %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+  %t2 = add <vscale x 4 x i64> %t1, %step
+  %t3 = getelementptr [8 x half], [8 x half]* %base, <vscale x 4 x i64> %t2
+  %t4 = bitcast <vscale x 4 x [8 x half]*> %t3 to <vscale x 4 x half*>
+  call void @llvm.masked.scatter.nxv4f16(<vscale x 4 x half> %data, <vscale x 4 x half*> %t4, i32 2, <vscale x 4 x i1> %pg)
+  ret void
+}
+
 
 attributes #0 = { "target-features"="+sve" vscale_range(1, 16) }
 
+declare <vscale x 4 x float> @llvm.masked.gather.nxv4f32(<vscale x 4 x float*>, i32, <vscale x 4 x i1>, <vscale x 4 x float>)
 
 declare <vscale x 4 x i8> @llvm.masked.gather.nxv4i8(<vscale x 4 x i8*>, i32, <vscale x 4 x i1>, <vscale x 4 x i8>)
 declare void @llvm.masked.scatter.nxv4i8(<vscale x 4 x i8>, <vscale x 4 x i8*>, i32, <vscale x 4 x i1>)
