Skip to content

Commit 4ce84b0

Browse files
abialas1ThomasRaoux
authored andcommitted
[mlir][spirv] Add GroupNonUniformBroadcastOp
Added GroupNonUniformBroadcastOp to spirv dialect. Differential Revision: https://reviews.llvm.org/D87688
1 parent 027d47d commit 4ce84b0

File tree

5 files changed

+158
-12
lines changed

5 files changed

+158
-12
lines changed

mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3256,6 +3256,7 @@ def SPV_OC_OpGroupBroadcast : I32EnumAttrCase<"OpGroupBroadcast", 263
32563256
def SPV_OC_OpNoLine : I32EnumAttrCase<"OpNoLine", 317>;
32573257
def SPV_OC_OpModuleProcessed : I32EnumAttrCase<"OpModuleProcessed", 330>;
32583258
def SPV_OC_OpGroupNonUniformElect : I32EnumAttrCase<"OpGroupNonUniformElect", 333>;
3259+
def SPV_OC_OpGroupNonUniformBroadcast : I32EnumAttrCase<"OpGroupNonUniformBroadcast", 337>;
32593260
def SPV_OC_OpGroupNonUniformBallot : I32EnumAttrCase<"OpGroupNonUniformBallot", 339>;
32603261
def SPV_OC_OpGroupNonUniformIAdd : I32EnumAttrCase<"OpGroupNonUniformIAdd", 349>;
32613262
def SPV_OC_OpGroupNonUniformFAdd : I32EnumAttrCase<"OpGroupNonUniformFAdd", 350>;
@@ -3323,16 +3324,16 @@ def SPV_OpcodeAttr :
33233324
SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn,
33243325
SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpGroupBroadcast,
33253326
SPV_OC_OpNoLine, SPV_OC_OpModuleProcessed, SPV_OC_OpGroupNonUniformElect,
3326-
SPV_OC_OpGroupNonUniformBallot, SPV_OC_OpGroupNonUniformIAdd,
3327-
SPV_OC_OpGroupNonUniformFAdd, SPV_OC_OpGroupNonUniformIMul,
3328-
SPV_OC_OpGroupNonUniformFMul, SPV_OC_OpGroupNonUniformSMin,
3329-
SPV_OC_OpGroupNonUniformUMin, SPV_OC_OpGroupNonUniformFMin,
3330-
SPV_OC_OpGroupNonUniformSMax, SPV_OC_OpGroupNonUniformUMax,
3331-
SPV_OC_OpGroupNonUniformFMax, SPV_OC_OpSubgroupBallotKHR,
3332-
SPV_OC_OpTypeCooperativeMatrixNV, SPV_OC_OpCooperativeMatrixLoadNV,
3333-
SPV_OC_OpCooperativeMatrixStoreNV, SPV_OC_OpCooperativeMatrixMulAddNV,
3334-
SPV_OC_OpCooperativeMatrixLengthNV, SPV_OC_OpSubgroupBlockReadINTEL,
3335-
SPV_OC_OpSubgroupBlockWriteINTEL
3327+
SPV_OC_OpGroupNonUniformBroadcast, SPV_OC_OpGroupNonUniformBallot,
3328+
SPV_OC_OpGroupNonUniformIAdd, SPV_OC_OpGroupNonUniformFAdd,
3329+
SPV_OC_OpGroupNonUniformIMul, SPV_OC_OpGroupNonUniformFMul,
3330+
SPV_OC_OpGroupNonUniformSMin, SPV_OC_OpGroupNonUniformUMin,
3331+
SPV_OC_OpGroupNonUniformFMin, SPV_OC_OpGroupNonUniformSMax,
3332+
SPV_OC_OpGroupNonUniformUMax, SPV_OC_OpGroupNonUniformFMax,
3333+
SPV_OC_OpSubgroupBallotKHR, SPV_OC_OpTypeCooperativeMatrixNV,
3334+
SPV_OC_OpCooperativeMatrixLoadNV, SPV_OC_OpCooperativeMatrixStoreNV,
3335+
SPV_OC_OpCooperativeMatrixMulAddNV, SPV_OC_OpCooperativeMatrixLengthNV,
3336+
SPV_OC_OpSubgroupBlockReadINTEL, SPV_OC_OpSubgroupBlockWriteINTEL
33363337
]>;
33373338

33383339
// End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!

mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,77 @@ def SPV_GroupNonUniformBallotOp : SPV_Op<"GroupNonUniformBallot", []> {
105105

106106
// -----
107107

108+
def SPV_GroupNonUniformBroadcastOp : SPV_Op<"GroupNonUniformBroadcast",
109+
[NoSideEffect, AllTypesMatch<["value", "result"]>]> {
110+
let summary = [{
111+
Return the Value of the invocation identified by the id Id to all active
112+
invocations in the group.
113+
}];
114+
115+
let description = [{
116+
Result Type must be a scalar or vector of floating-point type, integer
117+
type, or Boolean type.
118+
119+
Execution must be Workgroup or Subgroup Scope.
120+
121+
The type of Value must be the same as Result Type.
122+
123+
Id must be a scalar of integer type, whose Signedness operand is 0.
124+
125+
Before version 1.5, Id must come from a constant instruction. Starting
126+
with version 1.5, Id must be dynamically uniform.
127+
128+
The resulting value is undefined if Id is an inactive invocation, or is
129+
greater than or equal to the size of the group.
130+
131+
<!-- End of AutoGen section -->
132+
133+
```
134+
scope ::= `"Workgroup"` | `"Subgroup"`
135+
integer-float-scalar-vector-type ::= integer-type | float-type |
136+
`vector<` integer-literal `x` integer-type `>` |
137+
`vector<` integer-literal `x` float-type `>`
138+
group-non-uniform-broadcast-op ::= ssa-id `=`
139+
`spv.GroupNonUniformBroadcast` scope ssa_use,
140+
ssa_use `:` integer-float-scalar-vector-type `,` integer-type
141+
```mlir
142+
143+
#### Example:
144+
145+
```
146+
%scalar_value = ... : f32
147+
%vector_value = ... : vector<4xf32>
148+
%id = ... : i32
149+
%0 = spv.GroupNonUniformBroadcast "Subgroup" %scalar_value, %id : f32, i32
150+
%1 = spv.GroupNonUniformBroadcast "Workgroup" %vector_value, %id :
151+
vector<4xf32>, i32
152+
```
153+
}];
154+
155+
let availability = [
156+
MinVersion<SPV_V_1_3>,
157+
MaxVersion<SPV_V_1_5>,
158+
Extension<[]>,
159+
Capability<[SPV_C_GroupNonUniformBallot]>
160+
];
161+
162+
let arguments = (ins
163+
SPV_ScopeAttr:$execution_scope,
164+
SPV_Type:$value,
165+
SPV_Integer:$id
166+
);
167+
168+
let results = (outs
169+
SPV_Type:$result
170+
);
171+
172+
let assemblyFormat = [{
173+
$execution_scope operands attr-dict `:` type($value) `,` type($id)
174+
}];
175+
}
176+
177+
// -----
178+
108179
def SPV_GroupNonUniformElectOp : SPV_Op<"GroupNonUniformElect", []> {
109180
let summary = [{
110181
Result is true only in the active invocation with the lowest id in the
@@ -368,8 +439,8 @@ def SPV_GroupNonUniformFMulOp :
368439
def SPV_GroupNonUniformIAddOp :
369440
SPV_GroupNonUniformArithmeticOp<"GroupNonUniformIAdd", SPV_Integer, []> {
370441
let summary = [{
371-
An integer add group operation of all Value operands contributed active
372-
by invocations in the group.
442+
An integer add group operation of all Value operands contributed by
443+
active invocations in the group.
373444
}];
374445

375446
let description = [{

mlir/lib/Dialect/SPIRV/SPIRVOps.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/Dialect/SPIRV/SPIRVAttributes.h"
1717
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
1818
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
19+
#include "mlir/Dialect/SPIRV/TargetAndABI.h"
1920
#include "mlir/IR/Builders.h"
2021
#include "mlir/IR/Function.h"
2122
#include "mlir/IR/FunctionImplementation.h"
@@ -2043,6 +2044,32 @@ static LogicalResult verify(spirv::GroupNonUniformBallotOp ballotOp) {
20432044
return success();
20442045
}
20452046

2047+
//===----------------------------------------------------------------------===//
2048+
// spv.GroupNonUniformBroadcast
2049+
//===----------------------------------------------------------------------===//
2050+
2051+
static LogicalResult verify(spirv::GroupNonUniformBroadcastOp broadcastOp) {
2052+
spirv::Scope scope = broadcastOp.execution_scope();
2053+
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2054+
return broadcastOp.emitOpError(
2055+
"execution scope must be 'Workgroup' or 'Subgroup'");
2056+
2057+
// SPIR-V spec: "Before version 1.5, Id must come from a
2058+
// constant instruction.
2059+
auto targetEnv = spirv::getDefaultTargetEnv(broadcastOp.getContext());
2060+
if (auto spirvModule = broadcastOp.getParentOfType<spirv::ModuleOp>())
2061+
targetEnv = spirv::lookupTargetEnvOrDefault(spirvModule);
2062+
2063+
if (targetEnv.getVersion() < spirv::Version::V_1_5) {
2064+
auto *idOp = broadcastOp.id().getDefiningOp();
2065+
if (!idOp || !isa<spirv::ConstantOp, // for normal constant
2066+
spirv::ReferenceOfOp>(idOp)) // for spec constant
2067+
return broadcastOp.emitOpError("id must be the result of a constant op");
2068+
}
2069+
2070+
return success();
2071+
}
2072+
20462073
//===----------------------------------------------------------------------===//
20472074
// spv.SubgroupBlockReadINTEL
20482075
//===----------------------------------------------------------------------===//

mlir/test/Dialect/SPIRV/Serialization/non-uniform-ops.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,14 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
88
spv.ReturnValue %0: vector<4xi32>
99
}
1010

11+
// CHECK-LABEL: @group_non_uniform_broadcast
12+
spv.func @group_non_uniform_broadcast(%value: f32) -> f32 "None" {
13+
%one = spv.constant 1 : i32
14+
// CHECK: spv.GroupNonUniformBroadcast "Subgroup" %{{.*}}, %{{.*}} : f32, i32
15+
%0 = spv.GroupNonUniformBroadcast "Subgroup" %value, %one : f32, i32
16+
spv.ReturnValue %0: f32
17+
}
18+
1119
// CHECK-LABEL: @group_non_uniform_elect
1220
spv.func @group_non_uniform_elect() -> i1 "None" {
1321
// CHECK: %{{.+}} = spv.GroupNonUniformElect "Workgroup" : i1

mlir/test/Dialect/SPIRV/non-uniform-ops.mlir

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,45 @@ func @group_non_uniform_ballot(%predicate: i1) -> vector<4xsi32> {
2828

2929
// -----
3030

31+
//===----------------------------------------------------------------------===//
32+
// spv.NonUniformGroupBroadcast
33+
//===----------------------------------------------------------------------===//
34+
35+
func @group_non_uniform_broadcast_scalar(%value: f32) -> f32 {
36+
%one = spv.constant 1 : i32
37+
// CHECK: spv.GroupNonUniformBroadcast "Workgroup" %{{.*}}, %{{.*}} : f32, i32
38+
%0 = spv.GroupNonUniformBroadcast "Workgroup" %value, %one : f32, i32
39+
return %0: f32
40+
}
41+
42+
// -----
43+
44+
func @group_non_uniform_broadcast_vector(%value: vector<4xf32>) -> vector<4xf32> {
45+
%one = spv.constant 1 : i32
46+
// CHECK: spv.GroupNonUniformBroadcast "Subgroup" %{{.*}}, %{{.*}} : vector<4xf32>, i32
47+
%0 = spv.GroupNonUniformBroadcast "Subgroup" %value, %one : vector<4xf32>, i32
48+
return %0: vector<4xf32>
49+
}
50+
51+
// -----
52+
53+
func @group_non_uniform_broadcast_negative_scope(%value: f32, %localid: i32 ) -> f32 {
54+
%one = spv.constant 1 : i32
55+
// expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
56+
%0 = spv.GroupNonUniformBroadcast "Device" %value, %one : f32, i32
57+
return %0: f32
58+
}
59+
60+
// -----
61+
62+
func @group_non_uniform_broadcast_negative_non_const(%value: f32, %localid: i32) -> f32 {
63+
// expected-error @+1 {{id must be the result of a constant op}}
64+
%0 = spv.GroupNonUniformBroadcast "Subgroup" %value, %localid : f32, i32
65+
return %0: f32
66+
}
67+
68+
// -----
69+
3170
//===----------------------------------------------------------------------===//
3271
// spv.GroupNonUniformElect
3372
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)