Skip to content

[HLSL] Implement WaveReadLaneAt intrinsic #111010

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Oct 16, 2024
Merged
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -4733,6 +4733,12 @@ def HLSLWaveIsFirstLane : LangBuiltin<"HLSL_LANG"> {
let Prototype = "bool()";
}

def HLSLWaveReadLaneAt : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_wave_read_lane_at"];
let Attributes = [NoThrow, Const];
let Prototype = "void(...)";
}

def HLSLClamp : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_elementwise_clamp"];
let Attributes = [NoThrow, Const];
Expand Down
2 changes: 2 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -9232,6 +9232,8 @@ def err_typecheck_cond_incompatible_operands : Error<
def err_typecheck_expect_scalar_or_vector : Error<
"invalid operand of type %0 where %1 or "
"a vector of such type is required">;
def err_typecheck_expect_any_scalar_or_vector : Error<
"invalid operand of type %0 where a scalar or vector is required">;
def err_typecheck_expect_flt_or_vector : Error<
"invalid operand of type %0 where floating, complex or "
"a vector of such types is required">;
Expand Down
18 changes: 18 additions & 0 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18887,6 +18887,24 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveIsFirstLaneIntrinsic();
return EmitRuntimeCall(Intrinsic::getDeclaration(&CGM.getModule(), ID));
}
case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
// Due to the use of variadic arguments we must explicitly retreive them and
// create our function type.
Value *OpExpr = EmitScalarExpr(E->getArg(0));
Value *OpIndex = EmitScalarExpr(E->getArg(1));
llvm::FunctionType *FT = llvm::FunctionType::get(
OpExpr->getType(), ArrayRef{OpExpr->getType(), OpIndex->getType()},
false);

// Get overloaded name
std::string Name =
Intrinsic::getName(CGM.getHLSLRuntime().getWaveReadLaneAtIntrinsic(),
ArrayRef{OpExpr->getType()}, &CGM.getModule());
return EmitRuntimeCall(CGM.CreateRuntimeFunction(FT, Name, {},
/*Local=*/false,
/*AssumeConvergent=*/true),
ArrayRef{OpExpr, OpIndex}, "hlsl.wave.readlane");
}
case Builtin::BI__builtin_hlsl_elementwise_sign: {
auto *Arg0 = E->getArg(0);
Value *Op0 = EmitScalarExpr(Arg0);
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CodeGen/CGHLSLRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ class CGHLSLRuntime {
GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot)
GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane)

//===----------------------------------------------------------------------===//
// End of reserved area for HLSL intrinsic getters.
Expand Down
80 changes: 80 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -2067,6 +2067,86 @@ _HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_is_first_lane)
__attribute__((convergent)) bool WaveIsFirstLane();

//===----------------------------------------------------------------------===//
// WaveReadLaneAt builtins
//===----------------------------------------------------------------------===//

// \brief Returns the value of the expression for the given lane index within
// the specified wave.

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
__attribute__((convergent)) bool WaveReadLaneAt(bool, int32_t);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
__attribute__((convergent)) bool2 WaveReadLaneAt(bool2, int32_t);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
__attribute__((convergent)) bool3 WaveReadLaneAt(bool3, int32_t);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
__attribute__((convergent)) bool4 WaveReadLaneAt(bool4, int32_t);

#ifdef __HLSL_ENABLE_16_BIT
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
__attribute__((convergent)) int16_t WaveReadLaneAt(int16_t, int32_t);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
__attribute__((convergent)) int16_t2 WaveReadLaneAt(int16_t2, int32_t);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
__attribute__((convergent)) int16_t3 WaveReadLaneAt(int16_t3, int32_t);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
__attribute__((convergent)) int16_t4 WaveReadLaneAt(int16_t4, int32_t);
#endif

_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
__attribute__((convergent)) half WaveReadLaneAt(half, int32_t);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
__attribute__((convergent)) half2 WaveReadLaneAt(half2, int32_t);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
__attribute__((convergent)) half3 WaveReadLaneAt(half3, int32_t);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
__attribute__((convergent)) half4 WaveReadLaneAt(half4, int32_t);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
__attribute__((convergent)) int WaveReadLaneAt(int, int32_t);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
__attribute__((convergent)) int2 WaveReadLaneAt(int2, int32_t);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
__attribute__((convergent)) int3 WaveReadLaneAt(int3, int32_t);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
__attribute__((convergent)) int4 WaveReadLaneAt(int4, int32_t);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
__attribute__((convergent)) float WaveReadLaneAt(float, int32_t);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
__attribute__((convergent)) float2 WaveReadLaneAt(float2, int32_t);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
__attribute__((convergent)) float3 WaveReadLaneAt(float3, int32_t);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
__attribute__((convergent)) float4 WaveReadLaneAt(float4, int32_t);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
__attribute__((convergent)) int64_t WaveReadLaneAt(int64_t, int32_t);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
__attribute__((convergent)) int64_t2 WaveReadLaneAt(int64_t2, int32_t);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
__attribute__((convergent)) int64_t3 WaveReadLaneAt(int64_t3, int32_t);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
__attribute__((convergent)) int64_t4 WaveReadLaneAt(int64_t4, int32_t);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
__attribute__((convergent)) double WaveReadLaneAt(double, int32_t);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
__attribute__((convergent)) double2 WaveReadLaneAt(double2, int32_t);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
__attribute__((convergent)) double3 WaveReadLaneAt(double3, int32_t);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
__attribute__((convergent)) double4 WaveReadLaneAt(double4, int32_t);

//===----------------------------------------------------------------------===//
// sign builtins
//===----------------------------------------------------------------------===//
Expand Down
39 changes: 39 additions & 0 deletions clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1751,6 +1751,22 @@ static bool CheckScalarOrVector(Sema *S, CallExpr *TheCall, QualType Scalar,
return false;
}

static bool CheckAnyScalarOrVector(Sema *S, CallExpr *TheCall,
unsigned ArgIndex) {
assert(TheCall->getNumArgs() >= ArgIndex);
QualType ArgType = TheCall->getArg(ArgIndex)->getType();
auto *VTy = ArgType->getAs<VectorType>();
// not the scalar or vector<scalar>
if (!(ArgType->isScalarType() ||
(VTy && VTy->getElementType()->isScalarType()))) {
S->Diag(TheCall->getArg(0)->getBeginLoc(),
diag::err_typecheck_expect_any_scalar_or_vector)
<< ArgType;
return true;
}
return false;
}

static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
assert(TheCall->getNumArgs() == 3);
Expr *Arg1 = TheCall->getArg(1);
Expand Down Expand Up @@ -1992,6 +2008,29 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
break;
}
case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
if (SemaRef.checkArgCount(TheCall, 2))
return true;

// Ensure index parameter type can be interpreted as a uint
ExprResult Index = TheCall->getArg(1);
QualType ArgTyIndex = Index.get()->getType();
if (!ArgTyIndex->isIntegerType()) {
SemaRef.Diag(TheCall->getArg(1)->getBeginLoc(),
diag::err_typecheck_convert_incompatible)
<< ArgTyIndex << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0;
return true;
}

// Ensure input expr type is a scalar/vector and the same as the return type
if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
return true;

ExprResult Expr = TheCall->getArg(0);
QualType ArgTyExpr = Expr.get()->getType();
TheCall->setType(ArgTyExpr);
break;
}
case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
if (SemaRef.checkArgCount(TheCall, 0))
return true;
Expand Down
74 changes: 74 additions & 0 deletions clang/test/CodeGenHLSL/builtins/WaveReadLaneAt.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -fnative-half-type -triple \
// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -fnative-half-type -triple \
// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV

// Test basic lowering to runtime function call for int values.

// CHECK-LABEL: test_int
int test_int(int expr, uint idx) {
// CHECK-SPIRV: %[[#entry_tok0:]] = call token @llvm.experimental.convergence.entry()
// CHECK-SPIRV: %[[RET:.*]] = call [[TY:.*]] @llvm.spv.wave.readlane.i32([[TY]] %[[#]], i32 %[[#]]) [ "convergencectrl"(token %[[#entry_tok0]]) ]
// CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.readlane.i32([[TY]] %[[#]], i32 %[[#]])
// CHECK: ret [[TY]] %[[RET]]
return WaveReadLaneAt(expr, idx);
}

// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.readlane.i32([[TY]], i32) #[[#attr:]]
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.readlane.i32([[TY]], i32) #[[#attr:]]

#ifdef __HLSL_ENABLE_16_BIT
// CHECK-LABEL: test_int16
int16_t test_int16(int16_t expr, uint idx) {
// CHECK-SPIRV: %[[#entry_tok1:]] = call token @llvm.experimental.convergence.entry()
// CHECK-SPIRV: %[[RET:.*]] = call [[TY:.*]] @llvm.spv.wave.readlane.i16([[TY]] %[[#]], i32 %[[#]]) [ "convergencectrl"(token %[[#entry_tok1]]) ]
// CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.readlane.i16([[TY]] %[[#]], i32 %[[#]])
// CHECK: ret [[TY]] %[[RET]]
return WaveReadLaneAt(expr, idx);
}

// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.readlane.i16([[TY]], i32) #[[#attr:]]
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.readlane.i16([[TY]], i32) #[[#attr:]]
#endif

// Test basic lowering to runtime function call with array and float values.

// CHECK-LABEL: test_half
half test_half(half expr, uint idx) {
// CHECK-SPIRV: %[[#entry_tok2:]] = call token @llvm.experimental.convergence.entry()
// CHECK-SPIRV: %[[RET:.*]] = call [[TY:.*]] @llvm.spv.wave.readlane.f16([[TY]] %[[#]], i32 %[[#]]) [ "convergencectrl"(token %[[#entry_tok2]]) ]
// CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.readlane.f16([[TY]] %[[#]], i32 %[[#]])
// CHECK: ret [[TY]] %[[RET]]
return WaveReadLaneAt(expr, idx);
}

// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.readlane.f16([[TY]], i32) #[[#attr:]]
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.readlane.f16([[TY]], i32) #[[#attr:]]

// CHECK-LABEL: test_double
double test_double(double expr, uint idx) {
// CHECK-SPIRV: %[[#entry_tok3:]] = call token @llvm.experimental.convergence.entry()
// CHECK-SPIRV: %[[RET:.*]] = call [[TY:.*]] @llvm.spv.wave.readlane.f64([[TY]] %[[#]], i32 %[[#]]) [ "convergencectrl"(token %[[#entry_tok3]]) ]
// CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.readlane.f64([[TY]] %[[#]], i32 %[[#]])
// CHECK: ret [[TY]] %[[RET]]
return WaveReadLaneAt(expr, idx);
}

// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.readlane.f64([[TY]], i32) #[[#attr:]]
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.readlane.f64([[TY]], i32) #[[#attr:]]

// CHECK-LABEL: test_floatv4
float4 test_floatv4(float4 expr, uint idx) {
// CHECK-SPIRV: %[[#entry_tok4:]] = call token @llvm.experimental.convergence.entry()
// CHECK-SPIRV: %[[RET1:.*]] = call [[TY1:.*]] @llvm.spv.wave.readlane.v4f32([[TY1]] %[[#]], i32 %[[#]]) [ "convergencectrl"(token %[[#entry_tok4]]) ]
// CHECK-DXIL: %[[RET1:.*]] = call [[TY1:.*]] @llvm.dx.wave.readlane.v4f32([[TY1]] %[[#]], i32 %[[#]])
// CHECK: ret [[TY1]] %[[RET1]]
return WaveReadLaneAt(expr, idx);
}

// CHECK-DXIL: declare [[TY1]] @llvm.dx.wave.readlane.v4f32([[TY1]], i32) #[[#attr]]
// CHECK-SPIRV: declare [[TY1]] @llvm.spv.wave.readlane.v4f32([[TY1]], i32) #[[#attr]]

// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}
38 changes: 38 additions & 0 deletions clang/test/SemaHLSL/BuiltIns/WaveReadLaneAt-errors.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify

bool test_too_few_arg() {
return __builtin_hlsl_wave_read_lane_at();
// expected-error@-1 {{too few arguments to function call, expected 2, have 0}}
}

float2 test_too_few_arg_1(float2 p0) {
return __builtin_hlsl_wave_read_lane_at(p0);
// expected-error@-1 {{too few arguments to function call, expected 2, have 1}}
}

float2 test_too_many_arg(float2 p0) {
return __builtin_hlsl_wave_read_lane_at(p0, p0, p0);
// expected-error@-1 {{too many arguments to function call, expected 2, have 3}}
}

float3 test_index_double_type_check(float3 p0, double idx) {
return __builtin_hlsl_wave_read_lane_at(p0, idx);
// expected-error@-1 {{passing 'double' to parameter of incompatible type 'unsigned int'}}
}

float3 test_index_int3_type_check(float3 p0, int3 idxs) {
return __builtin_hlsl_wave_read_lane_at(p0, idxs);
// expected-error@-1 {{passing 'int3' (aka 'vector<int, 3>') to parameter of incompatible type 'unsigned int'}}
}

struct S { float f; };

float3 test_index_S_type_check(float3 p0, S idx) {
return __builtin_hlsl_wave_read_lane_at(p0, idx);
// expected-error@-1 {{passing 'S' to parameter of incompatible type 'unsigned int'}}
}

S test_expr_struct_type_check(S p0, int idx) {
return __builtin_hlsl_wave_read_lane_at(p0, idx);
// expected-error@-1 {{invalid operand of type 'S' where a scalar or vector is required}}
}
1 change: 1 addition & 0 deletions llvm/include/llvm/IR/IntrinsicsDirectX.td
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_
def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
def int_dx_wave_getlaneindex : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
def int_dx_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>], [IntrNoMem]>;
def int_dx_radians : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/IR/IntrinsicsSPIRV.td
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ let TargetPrefix = "spv" in {
[llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
[IntrNoMem, Commutative] >;
def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
def int_spv_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
def int_spv_radians : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;

Expand Down
10 changes: 10 additions & 0 deletions llvm/lib/Target/DirectX/DXIL.td
Original file line number Diff line number Diff line change
Expand Up @@ -802,6 +802,16 @@ def WaveIsFirstLane : DXILOp<110, waveIsFirstLane> {
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
}

def WaveReadLaneAt: DXILOp<117, waveReadLaneAt> {
let Doc = "returns the value from the specified lane";
let LLVMIntrinsic = int_dx_wave_readlane;
let arguments = [OverloadTy, Int32Ty];
let result = OverloadTy;
let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy, DoubleTy, Int1Ty, Int16Ty, Int32Ty, Int64Ty]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
}

def WaveGetLaneIndex : DXILOp<111, waveGetLaneIndex> {
let Doc = "returns the index of the current lane in the wave";
let LLVMIntrinsic = int_dx_wave_getlaneindex;
Expand Down
26 changes: 26 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectSpvThreadId(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

bool selectWaveReadLaneAt(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

bool selectUnmergeValues(MachineInstr &I) const;

void selectHandleFromBinding(Register &ResVReg, const SPIRVType *ResType,
Expand Down Expand Up @@ -417,6 +420,7 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,

case TargetOpcode::G_INTRINSIC:
case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
case TargetOpcode::G_INTRINSIC_CONVERGENT:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there anything that distinguishes a G_INTRINSIC_CONVERGENT vs a G_INTRINSIC? Seems like we are handling them the same way so wondering what happens with that extra information about convergence?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Afaict, not within this section of the codebase, it is used elsewhere though eg GlobalISel. But I think @Keenuts would be able to better elaborate.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The extra convergence information is being used when we generate the merge instructions. They should not change how the instrinsic itself is generated. The convergence token help us distinguish between:

int i;
for(i = 0; i < N; ++i) {
  if (cond(i)) {
    // The "active" lanes are those where cond was true in the same iteration.
    WaveReadLaneAt(i);
    break;
  }
}


for(i = 0; i < N; ++i) {
  if (cond(i))
    break;
}
// All lanes are active because they all reconverged at the end of the loop.
WaveReadLaneAt(i);

case TargetOpcode::G_INTRINSIC_CONVERGENT_W_SIDE_EFFECTS:
return selectIntrinsic(ResVReg, ResType, I);
case TargetOpcode::G_BITREVERSE:
Expand Down Expand Up @@ -1758,6 +1762,26 @@ bool SPIRVInstructionSelector::selectSign(Register ResVReg,
return Result;
}

bool SPIRVInstructionSelector::selectWaveReadLaneAt(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
assert(I.getNumOperands() == 4);
assert(I.getOperand(2).isReg());
assert(I.getOperand(3).isReg());
MachineBasicBlock &BB = *I.getParent();

// IntTy is used to define the execution scope, set to 3 to denote a
// cross-lane interaction equivalent to a SPIR-V subgroup.
SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
return BuildMI(BB, I, I.getDebugLoc(),
TII.get(SPIRV::OpGroupNonUniformShuffle))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(GR.getOrCreateConstInt(3, I, IntTy, TII))
.addUse(I.getOperand(2).getReg())
.addUse(I.getOperand(3).getReg());
}

bool SPIRVInstructionSelector::selectBitreverse(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
Expand Down Expand Up @@ -2541,6 +2565,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(GR.getOrCreateConstInt(3, I, IntTy, TII));
}
case Intrinsic::spv_wave_readlane:
return selectWaveReadLaneAt(ResVReg, ResType, I);
case Intrinsic::spv_step:
return selectExtInst(ResVReg, ResType, I, CL::step, GL::Step);
case Intrinsic::spv_radians:
Expand Down
Loading