Commit 0516f49

Add linalg.mmt4d named op
This op performs matrix-matrix-transpose multiplication of 4-D inputs as follows:

```
C[m1, n1, m0, n0] = sum_{k1, k0}(A[m1, k1, m0, k0] * B[n1, k1, n0, k0])
```

Reviewed By: Benoit

Differential Revision: https://reviews.llvm.org/D105244
1 parent e86fe36 commit 0516f49
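The reduction above can be cross-checked against a plain NumPy reference. This is an illustrative sketch only (NumPy and the dimension sizes are assumptions for the example, not part of the change):

```python
import numpy as np

# Outer tile counts (M1, N1, K1) and inner tile sizes (M0, N0, K0), chosen arbitrarily.
M1, N1, K1, M0, N0, K0 = 2, 3, 4, 5, 6, 7
A = np.random.rand(M1, K1, M0, K0)
B = np.random.rand(N1, K1, N0, K0)

# C[m1, n1, m0, n0] = sum_{k1, k0}(A[m1, k1, m0, k0] * B[n1, k1, n0, k0])
C = np.einsum("mkab,nkcb->mnac", A, B)
assert C.shape == (M1, N1, M0, N0)

# Spot-check one entry against the direct double sum over (k1, k0).
ref = sum(A[1, k1, 2, k0] * B[0, k1, 3, k0]
          for k1 in range(K1) for k0 in range(K0))
assert np.isclose(C[1, 0, 2, 3], ref)
```

Note that both operands index the contraction dims `(k1, k0)` in the same positions, which is the "transposed right hand side" convention the op's name encodes.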

File tree

2 files changed: +93 −0 lines


mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

Lines changed: 73 additions & 0 deletions
```diff
@@ -62,6 +62,79 @@ structured_op: !LinalgStructuredOpConfig
             - !ScalarExpression
               scalar_arg: B
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: mmt4d
+  cpp_class_name: Mmt4DOp
+  doc: |-
+    Performs a matrix-matrix-transpose multiplication of two 4D inputs.
+
+    Differences from linalg.matmul:
+    * The right hand side is transposed, whence the 't' in 'mmt'.
+    * The input and output tensors have a 4D shape instead of a 2D shape. They
+      are interpreted as 2D matrices with one level of 2D tile subdivision,
+      whence the 2+2=4 dimensions. The inner tile dimensions are identified with
+      '0' suffixes below, for instance the LHS matrix shape (M, K, M0, K0) reads
+      as: MxK tiles, each of shape M0xK0.
+  implements:
+  - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: lhs
+    usage: InputOperand
+    type_var: LhsType
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1, s2, s3)>
+  - !LinalgOperandDefConfig
+    name: rhs
+    usage: InputOperand
+    type_var: RhsType
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4, s1, s5, s3)>
+  - !LinalgOperandDefConfig
+    name: accum
+    usage: OutputOperand
+    type_var: AccumType
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s4, s2, s5)>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d0, d4, d1,
+      d5)>
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d2, d4, d3,
+      d5)>
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d0, d2, d1,
+      d3)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: accum
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: accum
+        - !ScalarExpression
+          scalar_apply:
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: AccumType
+                operands:
+                - !ScalarExpression
+                  scalar_arg: lhs
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: AccumType
+                operands:
+                - !ScalarExpression
+                  scalar_arg: rhs
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: batch_matmul
   cpp_class_name: BatchMatmulOp
```
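The `static_indexing_maps` above line up with the operand `shape_map`s once the iteration dims `(d0..d5)` are read in the op's domain order `(m, m0, n, n0, k, k0)`. A small sanity check of that correspondence (plain Python, not MLIR API usage):

```python
# Iteration dims (d0..d5) in the op's domain order (m, m0, n, n0, k, k0),
# written here as the symbolic size each dim ranges over.
dims = {"d0": "M", "d1": "M0", "d2": "N", "d3": "N0", "d4": "K", "d5": "K0"}

# The three static_indexing_maps from the YAML, as result-dim tuples.
lhs_access   = tuple(dims[d] for d in ("d0", "d4", "d1", "d5"))
rhs_access   = tuple(dims[d] for d in ("d2", "d4", "d3", "d5"))
accum_access = tuple(dims[d] for d in ("d0", "d2", "d1", "d3"))

# These must match the operand shapes implied by the shape_maps
# with (s0..s5) = (M, K, M0, K0, N, N0).
assert lhs_access == ("M", "K", "M0", "K0")
assert rhs_access == ("N", "K", "N0", "K0")
assert accum_access == ("M", "N", "M0", "N0")
```

Both input maps select `(d4, d5)` — the two `reduction` iterators — in the same positions, which is exactly the contraction structure `LinalgContractionOpInterface` requires.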

mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py

Lines changed: 20 additions & 0 deletions
```diff
@@ -21,6 +21,26 @@ def matmul(
   C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
 
 
+@linalg_structured_op
+def mmt4d(lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0),
+          rhs=TensorDef(TV.RhsType, S.N, S.K, S.N0, S.K0),
+          accum=TensorDef(TV.AccumType, S.M, S.N, S.M0, S.N0,
+                          output=True)):
+  """Performs a matrix-matrix-transpose multiplication of two 4D inputs.
+
+  Differences from linalg.matmul:
+  * The right hand side is transposed, whence the 't' in 'mmt'.
+  * The input and output tensors have a 4D shape instead of a 2D shape. They
+    are interpreted as 2D matrices with one level of 2D tile subdivision,
+    whence the 2+2=4 dimensions. The inner tile dimensions are identified with
+    '0' suffixes below, for instance the LHS matrix shape (M, K, M0, K0) reads
+    as: MxK tiles, each of shape M0xK0.
+  """
+  domain(D.m, D.m0, D.n, D.n0, D.k, D.k0)
+  implements(ContractionOpInterface)
+  accum[D.m, D.n, D.m0, D.n0] += cast(TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * cast(TV.AccumType, rhs[D.n, D.k, D.n0, D.k0])
+
+
 @linalg_structured_op
 def batch_matmul(
     A=TensorDef(T1, Batch, S.M, S.K),
```
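To illustrate the docstring's tile interpretation, the following NumPy sketch (NumPy, the packing helpers, and the dimension sizes are illustrative assumptions, not part of this change) shows that mmt4d on packed tiles reproduces a plain 2D matmul:

```python
import numpy as np

# Outer tile counts and inner tile sizes, chosen arbitrarily.
M1, N1, K1, M0, N0, K0 = 2, 3, 4, 5, 6, 7

# A plain 2D matmul: A2d is (M1*M0, K1*K0), B2d is (K1*K0, N1*N0).
A2d = np.random.rand(M1 * M0, K1 * K0)
B2d = np.random.rand(K1 * K0, N1 * N0)
C2d = A2d @ B2d

# Pack into the mmt4d layouts: LHS (M1, K1, M0, K0), RHS (N1, K1, N0, K0).
lhs = A2d.reshape(M1, M0, K1, K0).transpose(0, 2, 1, 3)
rhs = B2d.reshape(K1, K0, N1, N0).transpose(2, 0, 3, 1)

# The mmt4d accumulation, written as one einsum over the (k1, k0) reduction.
acc = np.einsum("mkab,nkcb->mnac", lhs, rhs)

# Unpack (M1, N1, M0, N0) back to (M1*M0, N1*N0) and compare.
C_unpacked = acc.transpose(0, 2, 1, 3).reshape(M1 * M0, N1 * N0)
assert np.allclose(C_unpacked, C2d)
```

The RHS pack step transposes `B2d` into row-major `(N1, K1, N0, K0)` tiles, which is the layout motivation for the 't' in the op's name: both operands then walk the reduction dims contiguously.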
