Skip to content

Commit 12175bc

Browse files
authored
[mlir][spirv] Support coop matrix in spirv.CompositeConstruct (#66399)
Also improve the documentation (code and website).
1 parent 571e4f2 commit 12175bc

File tree

3 files changed

+57
-23
lines changed

3 files changed

+57
-23
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,15 @@ def SPIRV_CompositeConstructOp : SPIRV_Op<"CompositeConstruct", [Pure]> {
5353
#### Example:
5454

5555
```mlir
56-
%0 = spirv.CompositeConstruct %1, %2, %3 : vector<3xf32>
56+
%a = spirv.CompositeConstruct %1, %2, %3 : vector<3xf32>
57+
%b = spirv.CompositeConstruct %a, %1 : (vector<3xf32>, f32) -> vector<4xf32>
58+
59+
%c = spirv.CompositeConstruct %1 :
60+
(f32) -> !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>
61+
62+
%d = spirv.CompositeConstruct %a, %4, %5 :
63+
(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>) ->
64+
!spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)>
5765
```
5866
}];
5967

mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "mlir/IR/Operation.h"
3030
#include "mlir/IR/TypeUtilities.h"
3131
#include "mlir/Interfaces/FunctionImplementation.h"
32+
#include "mlir/Support/LogicalResult.h"
3233
#include "llvm/ADT/APFloat.h"
3334
#include "llvm/ADT/APInt.h"
3435
#include "llvm/ADT/ArrayRef.h"
@@ -363,31 +364,35 @@ LogicalResult spirv::AddressOfOp::verify() {
363364
//===----------------------------------------------------------------------===//
364365

365366
LogicalResult spirv::CompositeConstructOp::verify() {
366-
auto cType = llvm::cast<spirv::CompositeType>(getType());
367367
operand_range constituents = this->getConstituents();
368368

369-
if (auto coopType = llvm::dyn_cast<spirv::CooperativeMatrixNVType>(cType)) {
370-
if (constituents.size() != 1)
371-
return emitOpError("has incorrect number of operands: expected ")
372-
<< "1, but provided " << constituents.size();
373-
if (coopType.getElementType() != constituents.front().getType())
374-
return emitOpError("operand type mismatch: expected operand type ")
375-
<< coopType.getElementType() << ", but provided "
376-
<< constituents.front().getType();
377-
return success();
378-
}
369+
// There are 4 cases with varying verification rules:
370+
// 1. Cooperative Matrices (1 constituent)
371+
// 2. Structs (1 constituent for each member)
372+
// 3. Arrays (1 constituent for each array element)
373+
// 4. Vectors (1 constituent (sub-)element for each vector element)
374+
375+
auto coopElementType =
376+
llvm::TypeSwitch<Type, Type>(getType())
377+
.Case<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType,
378+
spirv::JointMatrixINTELType>(
379+
[](auto coopType) { return coopType.getElementType(); })
380+
.Default([](Type) { return nullptr; });
379381

380-
if (auto jointType = llvm::dyn_cast<spirv::JointMatrixINTELType>(cType)) {
382+
// Case 1. -- matrices.
383+
if (coopElementType) {
381384
if (constituents.size() != 1)
382385
return emitOpError("has incorrect number of operands: expected ")
383386
<< "1, but provided " << constituents.size();
384-
if (jointType.getElementType() != constituents.front().getType())
387+
if (coopElementType != constituents.front().getType())
385388
return emitOpError("operand type mismatch: expected operand type ")
386-
<< jointType.getElementType() << ", but provided "
389+
<< coopElementType << ", but provided "
387390
<< constituents.front().getType();
388391
return success();
389392
}
390393

394+
// Case 2./3./4. -- number of constituents matches the number of elements.
395+
auto cType = llvm::cast<spirv::CompositeType>(getType());
391396
if (constituents.size() == cType.getNumElements()) {
392397
for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
393398
if (constituents[index].getType() != cType.getElementType(index)) {
@@ -399,8 +404,7 @@ LogicalResult spirv::CompositeConstructOp::verify() {
399404
return success();
400405
}
401406

402-
// If not constructing a cooperative matrix type, then we must be constructing
403-
// a vector type.
407+
// Case 4. -- check that all constituents add up tp the expected vector type.
404408
auto resultType = llvm::dyn_cast<VectorType>(cType);
405409
if (!resultType)
406410
return emitOpError(

mlir/test/Dialect/SPIRV/IR/composite-ops.mlir

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,32 +4,36 @@
44
// spirv.CompositeConstruct
55
//===----------------------------------------------------------------------===//
66

7+
// CHECK-LABEL: func @composite_construct_vector
78
func.func @composite_construct_vector(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xf32> {
89
// CHECK: spirv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : (f32, f32, f32) -> vector<3xf32>
910
%0 = spirv.CompositeConstruct %arg0, %arg1, %arg2 : (f32, f32, f32) -> vector<3xf32>
1011
return %0: vector<3xf32>
1112
}
1213

13-
// -----
14-
14+
// CHECK-LABEL: func @composite_construct_struct
1515
func.func @composite_construct_struct(%arg0: vector<3xf32>, %arg1: !spirv.array<4xf32>, %arg2 : !spirv.struct<(f32)>) -> !spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)> {
1616
// CHECK: spirv.CompositeConstruct
1717
%0 = spirv.CompositeConstruct %arg0, %arg1, %arg2 : (vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>) -> !spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)>
1818
return %0: !spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)>
1919
}
2020

21-
// -----
22-
2321
// CHECK-LABEL: func @composite_construct_mixed_scalar_vector
2422
func.func @composite_construct_mixed_scalar_vector(%arg0: f32, %arg1: f32, %arg2 : vector<2xf32>) -> vector<4xf32> {
2523
// CHECK: spirv.CompositeConstruct %{{.+}}, %{{.+}}, %{{.+}} : (f32, vector<2xf32>, f32) -> vector<4xf32>
2624
%0 = spirv.CompositeConstruct %arg0, %arg2, %arg1 : (f32, vector<2xf32>, f32) -> vector<4xf32>
2725
return %0: vector<4xf32>
2826
}
2927

30-
// -----
28+
// CHECK-LABEL: func @composite_construct_coopmatrix_khr
29+
func.func @composite_construct_coopmatrix_khr(%arg0 : f32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> {
30+
// CHECK: spirv.CompositeConstruct {{%.*}} : (f32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
31+
%0 = spirv.CompositeConstruct %arg0 : (f32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
32+
return %0: !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
33+
}
3134

32-
func.func @composite_construct_NV.coopmatrix(%arg0 : f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> {
35+
// CHECK-LABEL: func @composite_construct_coopmatrix_nv
36+
func.func @composite_construct_coopmatrix_nv(%arg0 : f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> {
3337
// CHECK: spirv.CompositeConstruct {{%.*}} : (f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup>
3438
%0 = spirv.CompositeConstruct %arg0 : (f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup>
3539
return %0: !spirv.NV.coopmatrix<8x16xf32, Subgroup>
@@ -53,6 +57,24 @@ func.func @composite_construct_invalid_operand_type(%arg0: f32, %arg1: f32, %arg
5357

5458
// -----
5559

60+
func.func @composite_construct_khr_coopmatrix_incorrect_operand_count(%arg0 : f32, %arg1 : f32) ->
61+
!spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> {
62+
// expected-error @+1 {{has incorrect number of operands: expected 1, but provided 2}}
63+
%0 = spirv.CompositeConstruct %arg0, %arg1 : (f32, f32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
64+
return %0: !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
65+
}
66+
67+
// -----
68+
69+
func.func @composite_construct_khr_coopmatrix_incorrect_element_type(%arg0 : i32) ->
70+
!spirv.coopmatrix<8x16xf32, Subgroup, MatrixB> {
71+
// expected-error @+1 {{operand type mismatch: expected operand type 'f32', but provided 'i32'}}
72+
%0 = spirv.CompositeConstruct %arg0 : (i32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixB>
73+
return %0: !spirv.coopmatrix<8x16xf32, Subgroup, MatrixB>
74+
}
75+
76+
// -----
77+
5678
func.func @composite_construct_NV.coopmatrix_incorrect_operand_count(%arg0 : f32, %arg1 : f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> {
5779
// expected-error @+1 {{has incorrect number of operands: expected 1, but provided 2}}
5880
%0 = spirv.CompositeConstruct %arg0, %arg1 : (f32, f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup>

0 commit comments

Comments
 (0)