Skip to content

Commit df48dfa

Browse files
[AArch64] Add custom lowering of nxv32i1 get.active.lane.mask nodes (#141969)
performActiveLaneMaskCombine already tries to combine a single get.active.lane.mask where the low and high halves of the result are extracted into a single whilelo which operates on a predicate pair. If the get.active.lane.mask node requires splitting, multiple nodes are created with saturating adds to increment the starting index. We cannot combine these into a single whilelo_x2 at this point unless we know the add will not overflow. This patch adds custom lowering for the node if the return type is nxv32i1, as this can be replaced with a whilelo_x2 using legal types. Anything wider than nxv32i1 will still require splitting first.
1 parent b80024e commit df48dfa

File tree

3 files changed

+149
-31
lines changed

3 files changed

+149
-31
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1511,6 +1511,10 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
15111511
setOperationAction(ISD::GET_ACTIVE_LANE_MASK, VT, Legal);
15121512
}
15131513

1514+
if (Subtarget->hasSVE2p1() ||
1515+
(Subtarget->hasSME2() && Subtarget->isStreaming()))
1516+
setOperationAction(ISD::GET_ACTIVE_LANE_MASK, MVT::nxv32i1, Custom);
1517+
15141518
for (auto VT : {MVT::v16i8, MVT::v8i8, MVT::v4i16, MVT::v2i32})
15151519
setOperationAction(ISD::GET_ACTIVE_LANE_MASK, VT, Custom);
15161520
}
@@ -17981,7 +17985,7 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
1798117985
/*IsEqual=*/false))
1798217986
return While;
1798317987

17984-
if (!ST->hasSVE2p1())
17988+
if (!ST->hasSVE2p1() && !(ST->hasSME2() && ST->isStreaming()))
1798517989
return SDValue();
1798617990

1798717991
if (!N->hasNUsesOfValue(2, 0))
@@ -27138,6 +27142,37 @@ void AArch64TargetLowering::ReplaceExtractSubVectorResults(
2713827142
Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, Half));
2713927143
}
2714027144

27145+
void AArch64TargetLowering::ReplaceGetActiveLaneMaskResults(
27146+
SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
27147+
assert((Subtarget->hasSVE2p1() ||
27148+
(Subtarget->hasSME2() && Subtarget->isStreaming())) &&
27149+
"Custom lower of get.active.lane.mask missing required feature.");
27150+
27151+
assert(N->getValueType(0) == MVT::nxv32i1 &&
27152+
"Unexpected result type for get.active.lane.mask");
27153+
27154+
SDLoc DL(N);
27155+
SDValue Idx = N->getOperand(0);
27156+
SDValue TC = N->getOperand(1);
27157+
27158+
assert(Idx.getValueType().getFixedSizeInBits() <= 64 &&
27159+
"Unexpected operand type for get.active.lane.mask");
27160+
27161+
if (Idx.getValueType() != MVT::i64) {
27162+
Idx = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, Idx);
27163+
TC = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, TC);
27164+
}
27165+
27166+
SDValue ID =
27167+
DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo_x2, DL, MVT::i64);
27168+
EVT HalfVT = N->getValueType(0).getHalfNumVectorElementsVT(*DAG.getContext());
27169+
auto WideMask =
27170+
DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {HalfVT, HalfVT}, {ID, Idx, TC});
27171+
27172+
Results.push_back(DAG.getNode(ISD::CONCAT_VECTORS, DL, N->getValueType(0),
27173+
{WideMask.getValue(0), WideMask.getValue(1)}));
27174+
}
27175+
2714127176
// Create an even/odd pair of X registers holding integer value V.
2714227177
static SDValue createGPRPairNode(SelectionDAG &DAG, SDValue V) {
2714327178
SDLoc dl(V.getNode());
@@ -27524,6 +27559,9 @@ void AArch64TargetLowering::ReplaceNodeResults(
2752427559
// CONCAT_VECTORS -- but delegate to common code for result type
2752527560
// legalisation
2752627561
return;
27562+
case ISD::GET_ACTIVE_LANE_MASK:
27563+
ReplaceGetActiveLaneMaskResults(N, Results, DAG);
27564+
return;
2752727565
case ISD::INTRINSIC_WO_CHAIN: {
2752827566
EVT VT = N->getValueType(0);
2752927567

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -822,6 +822,9 @@ class AArch64TargetLowering : public TargetLowering {
822822
void ReplaceExtractSubVectorResults(SDNode *N,
823823
SmallVectorImpl<SDValue> &Results,
824824
SelectionDAG &DAG) const;
825+
void ReplaceGetActiveLaneMaskResults(SDNode *N,
826+
SmallVectorImpl<SDValue> &Results,
827+
SelectionDAG &DAG) const;
825828

826829
bool shouldNormalizeToSelectSequence(LLVMContext &, EVT) const override;
827830

llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll

Lines changed: 107 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
22
; RUN: llc -mattr=+sve < %s | FileCheck %s -check-prefix CHECK-SVE
3-
; RUN: llc -mattr=+sve2p1 < %s | FileCheck %s -check-prefix CHECK-SVE2p1
3+
; RUN: llc -mattr=+sve2p1 < %s | FileCheck %s -check-prefix CHECK-SVE2p1-SME2 -check-prefix CHECK-SVE2p1
4+
; RUN: llc -mattr=+sve -mattr=+sme2 -force-streaming < %s | FileCheck %s -check-prefix CHECK-SVE2p1-SME2 -check-prefix CHECK-SME2
45
target triple = "aarch64-linux"
56

67
; Test combining of getActiveLaneMask with a pair of extract_vector operations.
@@ -13,12 +14,12 @@ define void @test_2x8bit_mask_with_32bit_index_and_trip_count(i32 %i, i32 %n) #0
1314
; CHECK-SVE-NEXT: punpkhi p1.h, p1.b
1415
; CHECK-SVE-NEXT: b use
1516
;
16-
; CHECK-SVE2p1-LABEL: test_2x8bit_mask_with_32bit_index_and_trip_count:
17-
; CHECK-SVE2p1: // %bb.0:
18-
; CHECK-SVE2p1-NEXT: mov w8, w1
19-
; CHECK-SVE2p1-NEXT: mov w9, w0
20-
; CHECK-SVE2p1-NEXT: whilelo { p0.h, p1.h }, x9, x8
21-
; CHECK-SVE2p1-NEXT: b use
17+
; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_32bit_index_and_trip_count:
18+
; CHECK-SVE2p1-SME2: // %bb.0:
19+
; CHECK-SVE2p1-SME2-NEXT: mov w8, w1
20+
; CHECK-SVE2p1-SME2-NEXT: mov w9, w0
21+
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.h, p1.h }, x9, x8
22+
; CHECK-SVE2p1-SME2-NEXT: b use
2223
%r = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i32(i32 %i, i32 %n)
2324
%v0 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 0)
2425
%v1 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 8)
@@ -34,10 +35,10 @@ define void @test_2x8bit_mask_with_64bit_index_and_trip_count(i64 %i, i64 %n) #0
3435
; CHECK-SVE-NEXT: punpkhi p1.h, p1.b
3536
; CHECK-SVE-NEXT: b use
3637
;
37-
; CHECK-SVE2p1-LABEL: test_2x8bit_mask_with_64bit_index_and_trip_count:
38-
; CHECK-SVE2p1: // %bb.0:
39-
; CHECK-SVE2p1-NEXT: whilelo { p0.h, p1.h }, x0, x1
40-
; CHECK-SVE2p1-NEXT: b use
38+
; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_64bit_index_and_trip_count:
39+
; CHECK-SVE2p1-SME2: // %bb.0:
40+
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.h, p1.h }, x0, x1
41+
; CHECK-SVE2p1-SME2-NEXT: b use
4142
%r = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 %i, i64 %n)
4243
%v0 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 0)
4344
%v1 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 8)
@@ -53,12 +54,12 @@ define void @test_edge_case_2x1bit_mask(i64 %i, i64 %n) #0 {
5354
; CHECK-SVE-NEXT: punpkhi p1.h, p1.b
5455
; CHECK-SVE-NEXT: b use
5556
;
56-
; CHECK-SVE2p1-LABEL: test_edge_case_2x1bit_mask:
57-
; CHECK-SVE2p1: // %bb.0:
58-
; CHECK-SVE2p1-NEXT: whilelo p1.d, x0, x1
59-
; CHECK-SVE2p1-NEXT: punpklo p0.h, p1.b
60-
; CHECK-SVE2p1-NEXT: punpkhi p1.h, p1.b
61-
; CHECK-SVE2p1-NEXT: b use
57+
; CHECK-SVE2p1-SME2-LABEL: test_edge_case_2x1bit_mask:
58+
; CHECK-SVE2p1-SME2: // %bb.0:
59+
; CHECK-SVE2p1-SME2-NEXT: whilelo p1.d, x0, x1
60+
; CHECK-SVE2p1-SME2-NEXT: punpklo p0.h, p1.b
61+
; CHECK-SVE2p1-SME2-NEXT: punpkhi p1.h, p1.b
62+
; CHECK-SVE2p1-SME2-NEXT: b use
6263
%r = call <vscale x 2 x i1> @llvm.get.active.lane.mask.nxv2i1.i64(i64 %i, i64 %n)
6364
%v0 = call <vscale x 1 x i1> @llvm.vector.extract.nxv1i1.nxv2i1.i64(<vscale x 2 x i1> %r, i64 0)
6465
%v1 = call <vscale x 1 x i1> @llvm.vector.extract.nxv1i1.nxv2i1.i64(<vscale x 2 x i1> %r, i64 1)
@@ -74,10 +75,10 @@ define void @test_boring_case_2x2bit_mask(i64 %i, i64 %n) #0 {
7475
; CHECK-SVE-NEXT: punpkhi p1.h, p1.b
7576
; CHECK-SVE-NEXT: b use
7677
;
77-
; CHECK-SVE2p1-LABEL: test_boring_case_2x2bit_mask:
78-
; CHECK-SVE2p1: // %bb.0:
79-
; CHECK-SVE2p1-NEXT: whilelo { p0.d, p1.d }, x0, x1
80-
; CHECK-SVE2p1-NEXT: b use
78+
; CHECK-SVE2p1-SME2-LABEL: test_boring_case_2x2bit_mask:
79+
; CHECK-SVE2p1-SME2: // %bb.0:
80+
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.d, p1.d }, x0, x1
81+
; CHECK-SVE2p1-SME2-NEXT: b use
8182
%r = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64 %i, i64 %n)
8283
%v0 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv4i1.i64(<vscale x 4 x i1> %r, i64 0)
8384
%v1 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv4i1.i64(<vscale x 4 x i1> %r, i64 2)
@@ -96,22 +97,22 @@ define void @test_partial_extract(i64 %i, i64 %n) #0 {
9697
; CHECK-SVE-NEXT: punpklo p1.h, p2.b
9798
; CHECK-SVE-NEXT: b use
9899
;
99-
; CHECK-SVE2p1-LABEL: test_partial_extract:
100-
; CHECK-SVE2p1: // %bb.0:
101-
; CHECK-SVE2p1-NEXT: whilelo p0.h, x0, x1
102-
; CHECK-SVE2p1-NEXT: punpklo p1.h, p0.b
103-
; CHECK-SVE2p1-NEXT: punpkhi p2.h, p0.b
104-
; CHECK-SVE2p1-NEXT: punpklo p0.h, p1.b
105-
; CHECK-SVE2p1-NEXT: punpklo p1.h, p2.b
106-
; CHECK-SVE2p1-NEXT: b use
100+
; CHECK-SVE2p1-SME2-LABEL: test_partial_extract:
101+
; CHECK-SVE2p1-SME2: // %bb.0:
102+
; CHECK-SVE2p1-SME2-NEXT: whilelo p0.h, x0, x1
103+
; CHECK-SVE2p1-SME2-NEXT: punpklo p1.h, p0.b
104+
; CHECK-SVE2p1-SME2-NEXT: punpkhi p2.h, p0.b
105+
; CHECK-SVE2p1-SME2-NEXT: punpklo p0.h, p1.b
106+
; CHECK-SVE2p1-SME2-NEXT: punpklo p1.h, p2.b
107+
; CHECK-SVE2p1-SME2-NEXT: b use
107108
%r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %i, i64 %n)
108109
%v0 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 0)
109110
%v1 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 4)
110111
tail call void @use(<vscale x 2 x i1> %v0, <vscale x 2 x i1> %v1)
111112
ret void
112113
}
113114

114-
;; Negative test for when extracting a fixed-length vector.
115+
; Negative test for when extracting a fixed-length vector.
115116
define void @test_fixed_extract(i64 %i, i64 %n) #0 {
116117
; CHECK-SVE-LABEL: test_fixed_extract:
117118
; CHECK-SVE: // %bb.0:
@@ -144,13 +145,89 @@ define void @test_fixed_extract(i64 %i, i64 %n) #0 {
144145
; CHECK-SVE2p1-NEXT: mov v1.s[1], w11
145146
; CHECK-SVE2p1-NEXT: // kill: def $d1 killed $d1 killed $q1
146147
; CHECK-SVE2p1-NEXT: b use
148+
;
149+
; CHECK-SME2-LABEL: test_fixed_extract:
150+
; CHECK-SME2: // %bb.0:
151+
; CHECK-SME2-NEXT: whilelo p0.h, x0, x1
152+
; CHECK-SME2-NEXT: cset w8, mi
153+
; CHECK-SME2-NEXT: mov z0.h, p0/z, #1 // =0x1
154+
; CHECK-SME2-NEXT: mov z1.h, z0.h[1]
155+
; CHECK-SME2-NEXT: mov z2.h, z0.h[5]
156+
; CHECK-SME2-NEXT: mov z3.h, z0.h[4]
157+
; CHECK-SME2-NEXT: fmov s0, w8
158+
; CHECK-SME2-NEXT: zip1 z0.s, z0.s, z1.s
159+
; CHECK-SME2-NEXT: zip1 z1.s, z3.s, z2.s
160+
; CHECK-SME2-NEXT: // kill: def $d0 killed $d0 killed $z0
161+
; CHECK-SME2-NEXT: // kill: def $d1 killed $d1 killed $z1
162+
; CHECK-SME2-NEXT: b use
147163
%r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %i, i64 %n)
148164
%v0 = call <2 x i1> @llvm.vector.extract.v2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 0)
149165
%v1 = call <2 x i1> @llvm.vector.extract.v2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 4)
150166
tail call void @use(<2 x i1> %v0, <2 x i1> %v1)
151167
ret void
152168
}
153169

170+
; Illegal Types
171+
172+
define void @test_2x16bit_mask_with_32bit_index_and_trip_count(i32 %i, i32 %n) #0 {
173+
; CHECK-SVE-LABEL: test_2x16bit_mask_with_32bit_index_and_trip_count:
174+
; CHECK-SVE: // %bb.0:
175+
; CHECK-SVE-NEXT: rdvl x8, #1
176+
; CHECK-SVE-NEXT: adds w8, w0, w8
177+
; CHECK-SVE-NEXT: csinv w8, w8, wzr, lo
178+
; CHECK-SVE-NEXT: whilelo p0.b, w0, w1
179+
; CHECK-SVE-NEXT: whilelo p1.b, w8, w1
180+
; CHECK-SVE-NEXT: b use
181+
;
182+
; CHECK-SVE2p1-SME2-LABEL: test_2x16bit_mask_with_32bit_index_and_trip_count:
183+
; CHECK-SVE2p1-SME2: // %bb.0:
184+
; CHECK-SVE2p1-SME2-NEXT: mov w8, w1
185+
; CHECK-SVE2p1-SME2-NEXT: mov w9, w0
186+
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.b, p1.b }, x9, x8
187+
; CHECK-SVE2p1-SME2-NEXT: b use
188+
%r = call <vscale x 32 x i1> @llvm.get.active.lane.mask.nxv32i1.i32(i32 %i, i32 %n)
189+
%v0 = call <vscale x 16 x i1> @llvm.vector.extract.nxv16i1.nxv32i1.i64(<vscale x 32 x i1> %r, i64 0)
190+
%v1 = call <vscale x 16 x i1> @llvm.vector.extract.nxv16i1.nxv32i1.i64(<vscale x 32 x i1> %r, i64 16)
191+
tail call void @use(<vscale x 16 x i1> %v0, <vscale x 16 x i1> %v1)
192+
ret void
193+
}
194+
195+
define void @test_2x32bit_mask_with_32bit_index_and_trip_count(i32 %i, i32 %n) #0 {
196+
; CHECK-SVE-LABEL: test_2x32bit_mask_with_32bit_index_and_trip_count:
197+
; CHECK-SVE: // %bb.0:
198+
; CHECK-SVE-NEXT: rdvl x8, #2
199+
; CHECK-SVE-NEXT: rdvl x9, #1
200+
; CHECK-SVE-NEXT: adds w8, w0, w8
201+
; CHECK-SVE-NEXT: csinv w8, w8, wzr, lo
202+
; CHECK-SVE-NEXT: adds w10, w8, w9
203+
; CHECK-SVE-NEXT: csinv w10, w10, wzr, lo
204+
; CHECK-SVE-NEXT: whilelo p3.b, w10, w1
205+
; CHECK-SVE-NEXT: adds w9, w0, w9
206+
; CHECK-SVE-NEXT: csinv w9, w9, wzr, lo
207+
; CHECK-SVE-NEXT: whilelo p0.b, w0, w1
208+
; CHECK-SVE-NEXT: whilelo p1.b, w9, w1
209+
; CHECK-SVE-NEXT: whilelo p2.b, w8, w1
210+
; CHECK-SVE-NEXT: b use
211+
;
212+
; CHECK-SVE2p1-SME2-LABEL: test_2x32bit_mask_with_32bit_index_and_trip_count:
213+
; CHECK-SVE2p1-SME2: // %bb.0:
214+
; CHECK-SVE2p1-SME2-NEXT: rdvl x8, #2
215+
; CHECK-SVE2p1-SME2-NEXT: mov w9, w1
216+
; CHECK-SVE2p1-SME2-NEXT: mov w10, w0
217+
; CHECK-SVE2p1-SME2-NEXT: adds w8, w0, w8
218+
; CHECK-SVE2p1-SME2-NEXT: csinv w8, w8, wzr, lo
219+
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.b, p1.b }, x10, x9
220+
; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.b, p3.b }, x8, x9
221+
; CHECK-SVE2p1-SME2-NEXT: b use
222+
%r = call <vscale x 64 x i1> @llvm.get.active.lane.mask.nxv64i1.i32(i32 %i, i32 %n)
223+
%v0 = call <vscale x 16 x i1> @llvm.vector.extract.nxv16i1.nxv64i1.i64(<vscale x 64 x i1> %r, i64 0)
224+
%v1 = call <vscale x 16 x i1> @llvm.vector.extract.nxv16i1.nxv64i1.i64(<vscale x 64 x i1> %r, i64 16)
225+
%v2 = call <vscale x 16 x i1> @llvm.vector.extract.nxv16i1.nxv64i1.i64(<vscale x 64 x i1> %r, i64 32)
226+
%v3 = call <vscale x 16 x i1> @llvm.vector.extract.nxv16i1.nxv64i1.i64(<vscale x 64 x i1> %r, i64 48)
227+
tail call void @use(<vscale x 16 x i1> %v0, <vscale x 16 x i1> %v1, <vscale x 16 x i1> %v2, <vscale x 16 x i1> %v3)
228+
ret void
229+
}
230+
154231
declare void @use(...)
155232

156233
attributes #0 = { nounwind }

0 commit comments

Comments
 (0)