@@ -62,6 +62,79 @@ structured_op: !LinalgStructuredOpConfig
62
62
- !ScalarExpression
63
63
scalar_arg : B
64
64
--- !LinalgOpConfig
65
+ metadata : !LinalgOpMetadata
66
+ name : mmt4d
67
+ cpp_class_name : Mmt4DOp
68
+ doc : |-
69
+ Performs a matrix-matrix-transpose multiplication of two 4D inputs.
70
+
71
+ Differences from linalg.matmul:
72
+ * The right hand side is transposed, whence the 't' in 'mmt'.
73
+ * The input and output tensors have a 4D shape instead of a 2D shape. They
74
+ are interpreted as 2D matrices with one level of 2D tile subdivision,
75
+ whence the 2+2=4 dimensions. The inner tile dimensions are identified with
76
+ '0' suffixes below, for instance the LHS matrix shape (M, K, M0, K0) reads
77
+ as: MxK tiles, each of shape M0xK0.
78
+ implements :
79
+ - LinalgContractionOpInterface
80
+ structured_op : !LinalgStructuredOpConfig
81
+ args :
82
+ - !LinalgOperandDefConfig
83
+ name : lhs
84
+ usage : InputOperand
85
+ type_var : LhsType
86
+ shape_map : affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1, s2, s3)>
87
+ - !LinalgOperandDefConfig
88
+ name : rhs
89
+ usage : InputOperand
90
+ type_var : RhsType
91
+ shape_map : affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4, s1, s5, s3)>
92
+ - !LinalgOperandDefConfig
93
+ name : accum
94
+ usage : OutputOperand
95
+ type_var : AccumType
96
+ shape_map : affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s4, s2, s5)>
97
+ indexing_maps : !LinalgIndexingMapsConfig
98
+ static_indexing_maps :
99
+ - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d0, d4, d1,
100
+ d5)>
101
+ - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d2, d4, d3,
102
+ d5)>
103
+ - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d0, d2, d1,
104
+ d3)>
105
+ iterator_types :
106
+ - parallel
107
+ - parallel
108
+ - parallel
109
+ - parallel
110
+ - reduction
111
+ - reduction
112
+ assignments :
113
+ - !ScalarAssign
114
+ arg : accum
115
+ value : !ScalarExpression
116
+ scalar_apply :
117
+ fn_name : add
118
+ operands :
119
+ - !ScalarExpression
120
+ scalar_arg : accum
121
+ - !ScalarExpression
122
+ scalar_apply :
123
+ fn_name : mul
124
+ operands :
125
+ - !ScalarExpression
126
+ symbolic_cast :
127
+ type_var : AccumType
128
+ operands :
129
+ - !ScalarExpression
130
+ scalar_arg : lhs
131
+ - !ScalarExpression
132
+ symbolic_cast :
133
+ type_var : AccumType
134
+ operands :
135
+ - !ScalarExpression
136
+ scalar_arg : rhs
137
+ --- !LinalgOpConfig
65
138
metadata : !LinalgOpMetadata
66
139
name : batch_matmul
67
140
cpp_class_name : BatchMatmulOp
0 commit comments