Skip to content

Commit dd73a1e

Browse files
committed
review comments:
- change intrinsic names to `wave_readlane` to align with AMD implementations - update semantic check to ensure that only a scalar/vector is allowed - add test case to illustrate this
1 parent 49edfec commit dd73a1e

File tree

12 files changed

+68
-35
lines changed

12 files changed

+68
-35
lines changed

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9232,6 +9232,8 @@ def err_typecheck_cond_incompatible_operands : Error<
92329232
def err_typecheck_expect_scalar_or_vector : Error<
92339233
"invalid operand of type %0 where %1 or "
92349234
"a vector of such type is required">;
9235+
def err_typecheck_expect_any_scalar_or_vector : Error<
9236+
"invalid operand of type %0 where a scalar or vector is required">;
92359237
def err_typecheck_expect_flt_or_vector : Error<
92369238
"invalid operand of type %0 where floating, complex or "
92379239
"a vector of such types is required">;

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18903,7 +18903,7 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
1890318903
return EmitRuntimeCall(CGM.CreateRuntimeFunction(FT, Name, {},
1890418904
/*Local=*/false,
1890518905
/*AssumeConvergent=*/true),
18906-
ArrayRef{OpExpr, OpIndex}, "hlsl.wave.readlaneAt");
18906+
ArrayRef{OpExpr, OpIndex}, "hlsl.wave.readlane");
1890718907
}
1890818908
case Builtin::BI__builtin_hlsl_elementwise_sign: {
1890918909
auto *Arg0 = E->getArg(0);

clang/lib/CodeGen/CGHLSLRuntime.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ class CGHLSLRuntime {
8989
GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot)
9090
GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
9191
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
92-
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlaneat)
92+
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane)
9393

9494
//===----------------------------------------------------------------------===//
9595
// End of reserved area for HLSL intrinsic getters.

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1751,6 +1751,22 @@ static bool CheckScalarOrVector(Sema *S, CallExpr *TheCall, QualType Scalar,
17511751
return false;
17521752
}
17531753

1754+
static bool CheckAnyScalarOrVector(Sema *S, CallExpr *TheCall,
1755+
unsigned ArgIndex) {
1756+
assert(TheCall->getNumArgs() >= ArgIndex);
1757+
QualType ArgType = TheCall->getArg(ArgIndex)->getType();
1758+
auto *VTy = ArgType->getAs<VectorType>();
1759+
// not the scalar or vector<scalar>
1760+
if (!(ArgType->isScalarType() ||
1761+
(VTy && VTy->getElementType()->isScalarType()))) {
1762+
S->Diag(TheCall->getArg(0)->getBeginLoc(),
1763+
diag::err_typecheck_expect_any_scalar_or_vector)
1764+
<< ArgType;
1765+
return true;
1766+
}
1767+
return false;
1768+
}
1769+
17541770
static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
17551771
assert(TheCall->getNumArgs() == 3);
17561772
Expr *Arg1 = TheCall->getArg(1);
@@ -2006,7 +2022,10 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
20062022
return true;
20072023
}
20082024

2009-
// Ensure return type is the same as the input expr type
2025+
// Ensure input expr type is a scalar/vector and the same as the return type
2026+
if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
2027+
return true;
2028+
20102029
ExprResult Expr = TheCall->getArg(0);
20112030
QualType ArgTyExpr = Expr.get()->getType();
20122031
TheCall->setType(ArgTyExpr);

clang/test/CodeGenHLSL/builtins/WaveReadLaneAt.hlsl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,27 +10,27 @@
1010
// CHECK-LABEL: test_int
1111
int test_int(int expr, uint idx) {
1212
// CHECK-SPIRV: %[[#entry_tok:]] = call token @llvm.experimental.convergence.entry()
13-
// CHECK-SPIRV: %[[RET:.*]] = call [[TY:.*]] @llvm.spv.wave.readlaneat.i32([[TY]] %[[#]], i32 %[[#]]) [ "convergencectrl"(token %[[#entry_tok]]) ]
14-
// CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.readlaneat.i32([[TY]] %[[#]], i32 %[[#]])
13+
// CHECK-SPIRV: %[[RET:.*]] = call [[TY:.*]] @llvm.spv.wave.readlane.i32([[TY]] %[[#]], i32 %[[#]]) [ "convergencectrl"(token %[[#entry_tok]]) ]
14+
// CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.readlane.i32([[TY]] %[[#]], i32 %[[#]])
1515
// CHECK: ret [[TY]] %[[RET]]
1616
return WaveReadLaneAt(expr, idx);
1717
}
1818

19-
// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.readlaneat.i32([[TY]], i32) #[[#attr:]]
20-
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.readlaneat.i32([[TY]], i32) #[[#attr:]]
19+
// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.readlane.i32([[TY]], i32) #[[#attr:]]
20+
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.readlane.i32([[TY]], i32) #[[#attr:]]
2121

2222
// Test basic lowering to runtime function call with array and float value.
2323

2424
// CHECK-LABEL: test_floatv4
2525
float4 test_floatv4(float4 expr, uint idx) {
2626
// CHECK-SPIRV: %[[#entry_tok1:]] = call token @llvm.experimental.convergence.entry()
27-
// CHECK-SPIRV: %[[RET1:.*]] = call [[TY1:.*]] @llvm.spv.wave.readlaneat.v4f32([[TY1]] %[[#]], i32 %[[#]]) [ "convergencectrl"(token %[[#entry_tok1]]) ]
28-
// CHECK-DXIL: %[[RET1:.*]] = call [[TY1:.*]] @llvm.dx.wave.readlaneat.v4f32([[TY1]] %[[#]], i32 %[[#]])
27+
// CHECK-SPIRV: %[[RET1:.*]] = call [[TY1:.*]] @llvm.spv.wave.readlane.v4f32([[TY1]] %[[#]], i32 %[[#]]) [ "convergencectrl"(token %[[#entry_tok1]]) ]
28+
// CHECK-DXIL: %[[RET1:.*]] = call [[TY1:.*]] @llvm.dx.wave.readlane.v4f32([[TY1]] %[[#]], i32 %[[#]])
2929
// CHECK: ret [[TY1]] %[[RET1]]
3030
return WaveReadLaneAt(expr, idx);
3131
}
3232

33-
// CHECK-DXIL: declare [[TY1]] @llvm.dx.wave.readlaneat.v4f32([[TY1]], i32) #[[#attr]]
34-
// CHECK-SPIRV: declare [[TY1]] @llvm.spv.wave.readlaneat.v4f32([[TY1]], i32) #[[#attr]]
33+
// CHECK-DXIL: declare [[TY1]] @llvm.dx.wave.readlane.v4f32([[TY1]], i32) #[[#attr]]
34+
// CHECK-SPIRV: declare [[TY1]] @llvm.spv.wave.readlane.v4f32([[TY1]], i32) #[[#attr]]
3535

3636
// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}

clang/test/SemaHLSL/BuiltIns/WaveReadLaneAt-errors.hlsl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,24 @@ float2 test_too_many_arg(float2 p0) {
1515
// expected-error@-1 {{too many arguments to function call, expected 2, have 3}}
1616
}
1717

18-
float3 test_index_type_check(float3 p0, double idx) {
18+
float3 test_index_double_type_check(float3 p0, double idx) {
1919
return __builtin_hlsl_wave_read_lane_at(p0, idx);
2020
// expected-error@-1 {{passing 'double' to parameter of incompatible type 'unsigned int'}}
2121
}
2222

23-
float3 test_index_type_check(float3 p0, int3 idxs) {
23+
float3 test_index_int3_type_check(float3 p0, int3 idxs) {
2424
return __builtin_hlsl_wave_read_lane_at(p0, idxs);
2525
// expected-error@-1 {{passing 'int3' (aka 'vector<int, 3>') to parameter of incompatible type 'unsigned int'}}
2626
}
27+
28+
struct S { float f; };
29+
30+
float3 test_index_S_type_check(float3 p0, S idx) {
31+
return __builtin_hlsl_wave_read_lane_at(p0, idx);
32+
// expected-error@-1 {{passing 'S' to parameter of incompatible type 'unsigned int'}}
33+
}
34+
35+
S test_expr_struct_type_check(S p0, int idx) {
36+
return __builtin_hlsl_wave_read_lane_at(p0, idx);
37+
// expected-error@-1 {{invalid operand of type 'S' where a scalar or vector is required}}
38+
}

llvm/include/llvm/IR/IntrinsicsDirectX.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_
8585
def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
8686
def int_dx_wave_getlaneindex : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent, IntrNoMem]>;
8787
def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
88-
def int_dx_wave_readlaneat : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
88+
def int_dx_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
8989
def int_dx_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
9090
def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>], [IntrNoMem]>;
9191
def int_dx_radians : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;

llvm/include/llvm/IR/IntrinsicsSPIRV.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ let TargetPrefix = "spv" in {
8383
[llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
8484
[IntrNoMem, Commutative] >;
8585
def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
86-
def int_spv_wave_readlaneat : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
86+
def int_spv_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
8787
def int_spv_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
8888
def int_spv_radians : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
8989

llvm/lib/Target/DirectX/DXIL.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -804,7 +804,7 @@ def WaveIsFirstLane : DXILOp<110, waveIsFirstLane> {
804804

805805
def WaveReadLaneAt: DXILOp<117, waveReadLaneAt> {
806806
let Doc = "returns the value from the specified lane";
807-
let LLVMIntrinsic = int_dx_wave_readlaneat;
807+
let LLVMIntrinsic = int_dx_wave_readlane;
808808
let arguments = [OverloadTy, Int32Ty];
809809
let result = OverloadTy;
810810
let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy, DoubleTy, Int1Ty, Int16Ty, Int32Ty]>];

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2565,7 +2565,7 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
25652565
.addUse(GR.getSPIRVTypeID(ResType))
25662566
.addUse(GR.getOrCreateConstInt(3, I, IntTy, TII));
25672567
}
2568-
case Intrinsic::spv_wave_readlaneat:
2568+
case Intrinsic::spv_wave_readlane:
25692569
return selectWaveReadLaneAt(ResVReg, ResType, I);
25702570
case Intrinsic::spv_step:
25712571
return selectExtInst(ResVReg, ResType, I, CL::step, GL::Step);

llvm/test/CodeGen/DirectX/WaveReadLaneAt.ll

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,49 +5,49 @@
55
define noundef half @wave_rla_half(half noundef %expr, i32 noundef %idx) {
66
entry:
77
; CHECK: call half @dx.op.waveReadLaneAt.f16(i32 117, half %expr, i32 %idx)
8-
%ret = call half @llvm.dx.wave.readlaneat.f16(half %expr, i32 %idx)
8+
%ret = call half @llvm.dx.wave.readlane.f16(half %expr, i32 %idx)
99
ret half %ret
1010
}
1111

1212
define noundef float @wave_rla_float(float noundef %expr, i32 noundef %idx) {
1313
entry:
1414
; CHECK: call float @dx.op.waveReadLaneAt.f32(i32 117, float %expr, i32 %idx)
15-
%ret = call float @llvm.dx.wave.readlaneat(float %expr, i32 %idx)
15+
%ret = call float @llvm.dx.wave.readlane(float %expr, i32 %idx)
1616
ret float %ret
1717
}
1818

1919
define noundef double @wave_rla_double(double noundef %expr, i32 noundef %idx) {
2020
entry:
2121
; CHECK: call double @dx.op.waveReadLaneAt.f64(i32 117, double %expr, i32 %idx)
22-
%ret = call double @llvm.dx.wave.readlaneat(double %expr, i32 %idx)
22+
%ret = call double @llvm.dx.wave.readlane(double %expr, i32 %idx)
2323
ret double %ret
2424
}
2525

2626
define noundef i1 @wave_rla_i1(i1 noundef %expr, i32 noundef %idx) {
2727
entry:
2828
; CHECK: call i1 @dx.op.waveReadLaneAt.i1(i32 117, i1 %expr, i32 %idx)
29-
%ret = call i1 @llvm.dx.wave.readlaneat.i1(i1 %expr, i32 %idx)
29+
%ret = call i1 @llvm.dx.wave.readlane.i1(i1 %expr, i32 %idx)
3030
ret i1 %ret
3131
}
3232

3333
define noundef i16 @wave_rla_i16(i16 noundef %expr, i32 noundef %idx) {
3434
entry:
3535
; CHECK: call i16 @dx.op.waveReadLaneAt.i16(i32 117, i16 %expr, i32 %idx)
36-
%ret = call i16 @llvm.dx.wave.readlaneat.i16(i16 %expr, i32 %idx)
36+
%ret = call i16 @llvm.dx.wave.readlane.i16(i16 %expr, i32 %idx)
3737
ret i16 %ret
3838
}
3939

4040
define noundef i32 @wave_rla_i32(i32 noundef %expr, i32 noundef %idx) {
4141
entry:
4242
; CHECK: call i32 @dx.op.waveReadLaneAt.i32(i32 117, i32 %expr, i32 %idx)
43-
%ret = call i32 @llvm.dx.wave.readlaneat.i32(i32 %expr, i32 %idx)
43+
%ret = call i32 @llvm.dx.wave.readlane.i32(i32 %expr, i32 %idx)
4444
ret i32 %ret
4545
}
4646

47-
declare half @llvm.dx.wave.readlaneat.f16(half, i32)
48-
declare float @llvm.dx.wave.readlaneat.f32(float, i32)
49-
declare double @llvm.dx.wave.readlaneat.f64(double, i32)
47+
declare half @llvm.dx.wave.readlane.f16(half, i32)
48+
declare float @llvm.dx.wave.readlane.f32(float, i32)
49+
declare double @llvm.dx.wave.readlane.f64(double, i32)
5050

51-
declare i1 @llvm.dx.wave.readlaneat.i1(i1, i32)
52-
declare i16 @llvm.dx.wave.readlaneat.i16(i16, i32)
53-
declare i32 @llvm.dx.wave.readlaneat.i32(i32, i32)
51+
declare i1 @llvm.dx.wave.readlane.i1(i1, i32)
52+
declare i16 @llvm.dx.wave.readlane.i16(i16, i32)
53+
declare i32 @llvm.dx.wave.readlane.i32(i32, i32)

llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveReadLaneAt.ll

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
define float @test_1(float %fexpr, i32 %idx) {
1515
entry:
1616
; CHECK: %[[#fret:]] = OpGroupNonUniformShuffle %[[#f32]] %[[#fexpr]] %[[#idx1]] %[[#scope]]
17-
%0 = call float @llvm.spv.wave.readlaneat.f32(float %fexpr, i32 %idx)
17+
%0 = call float @llvm.spv.wave.readlane.f32(float %fexpr, i32 %idx)
1818
ret float %0
1919
}
2020

@@ -23,7 +23,7 @@ entry:
2323
define i32 @test_2(i32 %iexpr, i32 %idx) {
2424
entry:
2525
; CHECK: %[[#iret:]] = OpGroupNonUniformShuffle %[[#uint]] %[[#iexpr]] %[[#idx2]] %[[#scope]]
26-
%0 = call i32 @llvm.spv.wave.readlaneat.i32(i32 %iexpr, i32 %idx)
26+
%0 = call i32 @llvm.spv.wave.readlane.i32(i32 %iexpr, i32 %idx)
2727
ret i32 %0
2828
}
2929

@@ -32,10 +32,10 @@ entry:
3232
define <4 x i1> @test_3(<4 x i1> %vbexpr, i32 %idx) {
3333
entry:
3434
; CHECK: %[[#vbret:]] = OpGroupNonUniformShuffle %[[#v4_bool]] %[[#vbexpr]] %[[#idx3]] %[[#scope]]
35-
%0 = call <4 x i1> @llvm.spv.wave.readlaneat.v4i1(<4 x i1> %vbexpr, i32 %idx)
35+
%0 = call <4 x i1> @llvm.spv.wave.readlane.v4i1(<4 x i1> %vbexpr, i32 %idx)
3636
ret <4 x i1> %0
3737
}
3838

39-
declare float @llvm.spv.wave.readlaneat.f32(float, i32)
40-
declare i32 @llvm.spv.wave.readlaneat.i32(i32, i32)
41-
declare <4 x i1> @llvm.spv.wave.readlaneat.v4i1(<4 x i1>, i32)
39+
declare float @llvm.spv.wave.readlane.f32(float, i32)
40+
declare i32 @llvm.spv.wave.readlane.i32(i32, i32)
41+
declare <4 x i1> @llvm.spv.wave.readlane.v4i1(<4 x i1>, i32)

0 commit comments

Comments
 (0)