Skip to content

Commit b0a5c4a

Browse files
authored
SPV_NV_shader_atomic_fp16_vector (KhronosGroup#5581)
1 parent 55cb398 commit b0a5c4a

File tree

6 files changed

+206
-22
lines changed

6 files changed

+206
-22
lines changed

DEPS

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ vars = {
1313
'protobuf_revision': 'v21.12',
1414

1515
're2_revision': 'b4c6fe091b74b65f706ff9c9ff369b396c2a3177',
16-
'spirv_headers_revision': 'd3c2a6fa95ad463ca8044d7fc45557db381a6a64',
16+
'spirv_headers_revision': '05cc486580771e4fa7ddc89f5c9ee1e97382689a',
1717
}
1818

1919
deps = {

source/val/validate_atomics.cpp

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -144,12 +144,13 @@ spv_result_t AtomicsPass(ValidationState_t& _, const Instruction* inst) {
144144
case spv::Op::OpAtomicFlagClear: {
145145
const uint32_t result_type = inst->type_id();
146146

147-
// All current atomics only are scalar result
148147
// Validate return type first so can just check if pointer type is same
149148
// (if applicable)
150149
if (HasReturnType(opcode)) {
151150
if (HasOnlyFloatReturnType(opcode) &&
152-
!_.IsFloatScalarType(result_type)) {
151+
(!(_.HasCapability(spv::Capability::AtomicFloat16VectorNV) &&
152+
_.IsFloat16Vector2Or4Type(result_type)) &&
153+
!_.IsFloatScalarType(result_type))) {
153154
return _.diag(SPV_ERROR_INVALID_DATA, inst)
154155
<< spvOpcodeString(opcode)
155156
<< ": expected Result Type to be float scalar type";
@@ -160,6 +161,9 @@ spv_result_t AtomicsPass(ValidationState_t& _, const Instruction* inst) {
160161
<< ": expected Result Type to be integer scalar type";
161162
} else if (HasIntOrFloatReturnType(opcode) &&
162163
!_.IsFloatScalarType(result_type) &&
164+
!(opcode == spv::Op::OpAtomicExchange &&
165+
_.HasCapability(spv::Capability::AtomicFloat16VectorNV) &&
166+
_.IsFloat16Vector2Or4Type(result_type)) &&
163167
!_.IsIntScalarType(result_type)) {
164168
return _.diag(SPV_ERROR_INVALID_DATA, inst)
165169
<< spvOpcodeString(opcode)
@@ -222,12 +226,21 @@ spv_result_t AtomicsPass(ValidationState_t& _, const Instruction* inst) {
222226

223227
if (opcode == spv::Op::OpAtomicFAddEXT) {
224228
// result type being float checked already
225-
if ((_.GetBitWidth(result_type) == 16) &&
226-
(!_.HasCapability(spv::Capability::AtomicFloat16AddEXT))) {
227-
return _.diag(SPV_ERROR_INVALID_DATA, inst)
228-
<< spvOpcodeString(opcode)
229-
<< ": float add atomics require the AtomicFloat32AddEXT "
230-
"capability";
229+
if (_.GetBitWidth(result_type) == 16) {
230+
if (_.IsFloat16Vector2Or4Type(result_type)) {
231+
if (!_.HasCapability(spv::Capability::AtomicFloat16VectorNV))
232+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
233+
<< spvOpcodeString(opcode)
234+
<< ": float vector atomics require the "
235+
"AtomicFloat16VectorNV capability";
236+
} else {
237+
if (!_.HasCapability(spv::Capability::AtomicFloat16AddEXT)) {
238+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
239+
<< spvOpcodeString(opcode)
240+
<< ": float add atomics require the AtomicFloat32AddEXT "
241+
"capability";
242+
}
243+
}
231244
}
232245
if ((_.GetBitWidth(result_type) == 32) &&
233246
(!_.HasCapability(spv::Capability::AtomicFloat32AddEXT))) {
@@ -245,12 +258,21 @@ spv_result_t AtomicsPass(ValidationState_t& _, const Instruction* inst) {
245258
}
246259
} else if (opcode == spv::Op::OpAtomicFMinEXT ||
247260
opcode == spv::Op::OpAtomicFMaxEXT) {
248-
if ((_.GetBitWidth(result_type) == 16) &&
249-
(!_.HasCapability(spv::Capability::AtomicFloat16MinMaxEXT))) {
250-
return _.diag(SPV_ERROR_INVALID_DATA, inst)
251-
<< spvOpcodeString(opcode)
252-
<< ": float min/max atomics require the "
253-
"AtomicFloat16MinMaxEXT capability";
261+
if (_.GetBitWidth(result_type) == 16) {
262+
if (_.IsFloat16Vector2Or4Type(result_type)) {
263+
if (!_.HasCapability(spv::Capability::AtomicFloat16VectorNV))
264+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
265+
<< spvOpcodeString(opcode)
266+
<< ": float vector atomics require the "
267+
"AtomicFloat16VectorNV capability";
268+
} else {
269+
if (!_.HasCapability(spv::Capability::AtomicFloat16MinMaxEXT)) {
270+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
271+
<< spvOpcodeString(opcode)
272+
<< ": float min/max atomics require the "
273+
"AtomicFloat16MinMaxEXT capability";
274+
}
275+
}
254276
}
255277
if ((_.GetBitWidth(result_type) == 32) &&
256278
(!_.HasCapability(spv::Capability::AtomicFloat32MinMaxEXT))) {

source/val/validate_image.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1118,7 +1118,10 @@ spv_result_t ValidateImageTexelPointer(ValidationState_t& _,
11181118
const auto ptr_type = result_type->GetOperandAs<uint32_t>(2);
11191119
const auto ptr_opcode = _.GetIdOpcode(ptr_type);
11201120
if (ptr_opcode != spv::Op::OpTypeInt && ptr_opcode != spv::Op::OpTypeFloat &&
1121-
ptr_opcode != spv::Op::OpTypeVoid) {
1121+
ptr_opcode != spv::Op::OpTypeVoid &&
1122+
!(ptr_opcode == spv::Op::OpTypeVector &&
1123+
_.HasCapability(spv::Capability::AtomicFloat16VectorNV) &&
1124+
_.IsFloat16Vector2Or4Type(ptr_type))) {
11221125
return _.diag(SPV_ERROR_INVALID_DATA, inst)
11231126
<< "Expected Result Type to be OpTypePointer whose Type operand "
11241127
"must be a scalar numerical type or OpTypeVoid";
@@ -1142,7 +1145,14 @@ spv_result_t ValidateImageTexelPointer(ValidationState_t& _,
11421145
<< "Corrupt image type definition";
11431146
}
11441147

1145-
if (info.sampled_type != ptr_type) {
1148+
if (info.sampled_type != ptr_type &&
1149+
!(_.HasCapability(spv::Capability::AtomicFloat16VectorNV) &&
1150+
_.IsFloat16Vector2Or4Type(ptr_type) &&
1151+
_.GetIdOpcode(info.sampled_type) == spv::Op::OpTypeFloat &&
1152+
((_.GetDimension(ptr_type) == 2 &&
1153+
info.format == spv::ImageFormat::Rg16f) ||
1154+
(_.GetDimension(ptr_type) == 4 &&
1155+
info.format == spv::ImageFormat::Rgba16f)))) {
11461156
return _.diag(SPV_ERROR_INVALID_DATA, inst)
11471157
<< "Expected Image 'Sampled Type' to be the same as the Type "
11481158
"pointed to by Result Type";
@@ -1213,7 +1223,10 @@ spv_result_t ValidateImageTexelPointer(ValidationState_t& _,
12131223
(info.format != spv::ImageFormat::R64ui) &&
12141224
(info.format != spv::ImageFormat::R32f) &&
12151225
(info.format != spv::ImageFormat::R32i) &&
1216-
(info.format != spv::ImageFormat::R32ui)) {
1226+
(info.format != spv::ImageFormat::R32ui) &&
1227+
!((info.format == spv::ImageFormat::Rg16f ||
1228+
info.format == spv::ImageFormat::Rgba16f) &&
1229+
_.HasCapability(spv::Capability::AtomicFloat16VectorNV))) {
12171230
return _.diag(SPV_ERROR_INVALID_DATA, inst)
12181231
<< _.VkErrorID(4658)
12191232
<< "Expected the Image Format in Image to be R64i, R64ui, R32f, "

source/val/validation_state.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -954,6 +954,20 @@ bool ValidationState_t::IsFloatVectorType(uint32_t id) const {
954954
return false;
955955
}
956956

957+
bool ValidationState_t::IsFloat16Vector2Or4Type(uint32_t id) const {
958+
const Instruction* inst = FindDef(id);
959+
assert(inst);
960+
961+
if (inst->opcode() == spv::Op::OpTypeVector) {
962+
uint32_t vectorDim = GetDimension(id);
963+
return IsFloatScalarType(GetComponentType(id)) &&
964+
(vectorDim == 2 || vectorDim == 4) &&
965+
(GetBitWidth(GetComponentType(id)) == 16);
966+
}
967+
968+
return false;
969+
}
970+
957971
bool ValidationState_t::IsFloatScalarOrVectorType(uint32_t id) const {
958972
const Instruction* inst = FindDef(id);
959973
if (!inst) {

source/val/validation_state.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -602,6 +602,7 @@ class ValidationState_t {
602602
bool IsVoidType(uint32_t id) const;
603603
bool IsFloatScalarType(uint32_t id) const;
604604
bool IsFloatVectorType(uint32_t id) const;
605+
bool IsFloat16Vector2Or4Type(uint32_t id) const;
605606
bool IsFloatScalarOrVectorType(uint32_t id) const;
606607
bool IsFloatMatrixType(uint32_t id) const;
607608
bool IsIntScalarType(uint32_t id) const;

test/val/val_atomics_test.cpp

Lines changed: 138 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,8 @@ TEST_F(ValidateAtomics, AtomicAddFloatVulkan) {
318318
EXPECT_THAT(
319319
getDiagnosticString(),
320320
HasSubstr("Opcode AtomicFAddEXT requires one of these capabilities: "
321-
"AtomicFloat32AddEXT AtomicFloat64AddEXT AtomicFloat16AddEXT"));
321+
"AtomicFloat16VectorNV AtomicFloat32AddEXT AtomicFloat64AddEXT "
322+
"AtomicFloat16AddEXT"));
322323
}
323324

324325
TEST_F(ValidateAtomics, AtomicMinFloatVulkan) {
@@ -331,7 +332,8 @@ TEST_F(ValidateAtomics, AtomicMinFloatVulkan) {
331332
EXPECT_THAT(
332333
getDiagnosticString(),
333334
HasSubstr("Opcode AtomicFMinEXT requires one of these capabilities: "
334-
"AtomicFloat32MinMaxEXT AtomicFloat64MinMaxEXT AtomicFloat16MinMaxEXT"));
335+
"AtomicFloat16VectorNV AtomicFloat32MinMaxEXT "
336+
"AtomicFloat64MinMaxEXT AtomicFloat16MinMaxEXT"));
335337
}
336338

337339
TEST_F(ValidateAtomics, AtomicMaxFloatVulkan) {
@@ -343,8 +345,10 @@ TEST_F(ValidateAtomics, AtomicMaxFloatVulkan) {
343345
ASSERT_EQ(SPV_ERROR_INVALID_CAPABILITY, ValidateInstructions());
344346
EXPECT_THAT(
345347
getDiagnosticString(),
346-
HasSubstr("Opcode AtomicFMaxEXT requires one of these capabilities: "
347-
"AtomicFloat32MinMaxEXT AtomicFloat64MinMaxEXT AtomicFloat16MinMaxEXT"));
348+
HasSubstr(
349+
"Opcode AtomicFMaxEXT requires one of these capabilities: "
350+
"AtomicFloat16VectorNV AtomicFloat32MinMaxEXT AtomicFloat64MinMaxEXT "
351+
"AtomicFloat16MinMaxEXT"));
348352
}
349353

350354
TEST_F(ValidateAtomics, AtomicAddFloatVulkanWrongType1) {
@@ -2713,6 +2717,136 @@ TEST_F(ValidateAtomics, IIncrementBadPointerDataType) {
27132717
"value of type Result Type"));
27142718
}
27152719

2720+
// With AtomicFloat16VectorNV declared, FMin/FMax/FAdd/Exchange atomics on 2-
// and 4-component 16-bit float vectors are expected to validate successfully
// for the Vulkan 1.0 environment.
TEST_F(ValidateAtomics, AtomicFloat16VectorSuccess) {
  // Module preamble: the capability under test and the extension naming it.
  const std::string capabilities_and_extensions =
      "OpCapability Float16\n"
      "OpCapability AtomicFloat16VectorNV\n"
      "OpExtension \"SPV_NV_shader_atomic_fp16_vector\"\n";

  // f16 vec2/vec4 types, constants, and Workgroup variables to operate on.
  const std::string types_and_variables = R"(
%f16 = OpTypeFloat 16
%f16vec2 = OpTypeVector %f16 2
%f16vec4 = OpTypeVector %f16 4

%f16_1 = OpConstant %f16 1
%f16vec2_1 = OpConstantComposite %f16vec2 %f16_1 %f16_1
%f16vec4_1 = OpConstantComposite %f16vec4 %f16_1 %f16_1 %f16_1 %f16_1

%f16vec2_ptr = OpTypePointer Workgroup %f16vec2
%f16vec4_ptr = OpTypePointer Workgroup %f16vec4
%f16vec2_var = OpVariable %f16vec2_ptr Workgroup
%f16vec4_var = OpVariable %f16vec4_ptr Workgroup
)";

  // One instruction per supported opcode, for each vector width.
  const std::string atomic_ops = R"(
%val3 = OpAtomicFMinEXT %f16vec2 %f16vec2_var %device %relaxed %f16vec2_1
%val4 = OpAtomicFMaxEXT %f16vec2 %f16vec2_var %device %relaxed %f16vec2_1
%val8 = OpAtomicFAddEXT %f16vec2 %f16vec2_var %device %relaxed %f16vec2_1
%val9 = OpAtomicExchange %f16vec2 %f16vec2_var %device %relaxed %f16vec2_1

%val11 = OpAtomicFMinEXT %f16vec4 %f16vec4_var %device %relaxed %f16vec4_1
%val12 = OpAtomicFMaxEXT %f16vec4 %f16vec4_var %device %relaxed %f16vec4_1
%val18 = OpAtomicFAddEXT %f16vec4 %f16vec4_var %device %relaxed %f16vec4_1
%val19 = OpAtomicExchange %f16vec4 %f16vec4_var %device %relaxed %f16vec4_1

)";

  const std::string spirv = GenerateShaderComputeCode(
      atomic_ops, capabilities_and_extensions, types_and_variables);
  CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_0);
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0));
}
2758+
2759+
// Shared SPIR-V definitions for the 3-component negative tests below: a
// 3-wide f16 vector type, a constant of that type, and a Workgroup variable
// holding it.
static constexpr char Float16Vector3Defs[] = R"(
%f16 = OpTypeFloat 16
%f16vec3 = OpTypeVector %f16 3

%f16_1 = OpConstant %f16 1
%f16vec3_1 = OpConstantComposite %f16vec3 %f16_1 %f16_1 %f16_1

%f16vec3_ptr = OpTypePointer Workgroup %f16vec3
%f16vec3_var = OpVariable %f16vec3_ptr Workgroup
)";
2769+
2770+
// OpAtomicFMinEXT on a 3-component f16 vector is expected to be rejected even
// though AtomicFloat16VectorNV is declared.
TEST_F(ValidateAtomics, AtomicFloat16Vector3MinFail) {
  const std::string capabilities_and_extensions =
      "OpCapability Float16\n"
      "OpCapability AtomicFloat16VectorNV\n"
      "OpExtension \"SPV_NV_shader_atomic_fp16_vector\"\n";
  const std::string vec3_definitions = Float16Vector3Defs;
  const std::string atomic_op = R"(
%val11 = OpAtomicFMinEXT %f16vec3 %f16vec3_var %device %relaxed %f16vec3_1
)";

  const std::string spirv = GenerateShaderComputeCode(
      atomic_op, capabilities_and_extensions, vec3_definitions);
  CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_0);

  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  const std::string expected_message =
      "AtomicFMinEXT: expected Result Type to be float scalar type";
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_message));
}
2789+
2790+
// OpAtomicFMaxEXT on a 3-component f16 vector is expected to be rejected even
// though AtomicFloat16VectorNV is declared.
TEST_F(ValidateAtomics, AtomicFloat16Vector3MaxFail) {
  const std::string capabilities_and_extensions =
      "OpCapability Float16\n"
      "OpCapability AtomicFloat16VectorNV\n"
      "OpExtension \"SPV_NV_shader_atomic_fp16_vector\"\n";
  const std::string vec3_definitions = Float16Vector3Defs;
  const std::string atomic_op = R"(
%val12 = OpAtomicFMaxEXT %f16vec3 %f16vec3_var %device %relaxed %f16vec3_1
)";

  const std::string spirv = GenerateShaderComputeCode(
      atomic_op, capabilities_and_extensions, vec3_definitions);
  CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_0);

  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  const std::string expected_message =
      "AtomicFMaxEXT: expected Result Type to be float scalar type";
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_message));
}
2809+
2810+
// OpAtomicFAddEXT on a 3-component f16 vector is expected to be rejected even
// though AtomicFloat16VectorNV is declared.
TEST_F(ValidateAtomics, AtomicFloat16Vector3AddFail) {
  const std::string capabilities_and_extensions =
      "OpCapability Float16\n"
      "OpCapability AtomicFloat16VectorNV\n"
      "OpExtension \"SPV_NV_shader_atomic_fp16_vector\"\n";
  const std::string vec3_definitions = Float16Vector3Defs;
  const std::string atomic_op = R"(
%val18 = OpAtomicFAddEXT %f16vec3 %f16vec3_var %device %relaxed %f16vec3_1
)";

  const std::string spirv = GenerateShaderComputeCode(
      atomic_op, capabilities_and_extensions, vec3_definitions);
  CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_0);

  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  const std::string expected_message =
      "AtomicFAddEXT: expected Result Type to be float scalar type";
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_message));
}
2829+
2830+
// OpAtomicExchange on a 3-component f16 vector is expected to be rejected
// even though AtomicFloat16VectorNV is declared.
TEST_F(ValidateAtomics, AtomicFloat16Vector3ExchangeFail) {
  const std::string capabilities_and_extensions =
      "OpCapability Float16\n"
      "OpCapability AtomicFloat16VectorNV\n"
      "OpExtension \"SPV_NV_shader_atomic_fp16_vector\"\n";
  const std::string vec3_definitions = Float16Vector3Defs;
  const std::string atomic_op = R"(
%val19 = OpAtomicExchange %f16vec3 %f16vec3_var %device %relaxed %f16vec3_1
)";

  const std::string spirv = GenerateShaderComputeCode(
      atomic_op, capabilities_and_extensions, vec3_definitions);
  CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_0);

  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  const std::string expected_message =
      "AtomicExchange: expected Result Type to be integer or "
      "float scalar type";
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_message));
}
2849+
27162850
} // namespace
27172851
} // namespace val
27182852
} // namespace spvtools

0 commit comments

Comments
 (0)