@@ -155,27 +155,6 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion,
155
155
// iterator_types is an auto-generated method.
156
156
}
157
157
158
- // / Helper to create a typical indexing map for MatmulOp. Returns a list of
159
- // / AffineMap.
160
- static SmallVector<AffineMap, 3 >
161
- getDefaultIndexingMapsForMatmul (MLIRContext *context) {
162
- AffineExpr d0, d1, d2;
163
- SmallVector<AffineMap, 3 > indexingMaps;
164
- bindDims (context, d0, d1, d2);
165
- indexingMaps.push_back (AffineMap::get (3 , 0 , {d0, d2}, context));
166
- indexingMaps.push_back (AffineMap::get (3 , 0 , {d2, d1}, context));
167
- indexingMaps.push_back (AffineMap::get (3 , 0 , {d0, d1}, context));
168
- return indexingMaps;
169
- }
170
-
171
- // / Wrapper to return the typical indexing map array attribute for MatmulOp.
172
- static SmallVector<Attribute>
173
- getDefaultMatmulIndexingMapAttr (MLIRContext *context) {
174
- return llvm::map_to_vector (
175
- getDefaultIndexingMapsForMatmul (context),
176
- [](AffineMap map) -> Attribute { return AffineMapAttr::get (map); });
177
- }
178
-
179
158
// / Creates a structured operation given `inputs`, `outputs`, and `attributes`.
180
159
// / The result types are derived automatically if `resultTensorTypes` is none.
181
160
// / The body of the operation is filled using `regionBuilder`. All ods-gen
@@ -208,24 +187,18 @@ static void buildStructuredOp(OpBuilder &b, OperationState &state,
208
187
state.attributes .getAttrs (), regionBuilder);
209
188
}
210
189
211
- static void
212
- buildMatmulOp (OpBuilder &b, OperationState &state ,
213
- std::optional<TypeRange> resultTensorTypes, ValueRange inputs,
214
- ValueRange outputs, ArrayRef<NamedAttribute> attributes,
215
- RegionBuilderFn regionBuilder,
216
- std::optional< ArrayRef<AffineMap>> indexingMaps = std::nullopt ) {
217
- // Initialize indexingMaps, for MatmulOp.
190
+ static void buildMatmulOp (OpBuilder &b, OperationState &state,
191
+ std::optional<TypeRange> resultTensorTypes ,
192
+ ValueRange inputs, ValueRange outputs ,
193
+ ArrayRef<NamedAttribute> attributes,
194
+ RegionBuilderFn regionBuilder,
195
+ ArrayRef<AffineMap> indexingMaps) {
196
+ // Initialize indexingMaps attribute , for MatmulOp.
218
197
SmallVector<Attribute, 3 > indexingMapsAttrVal;
219
- if (indexingMaps.has_value ()) {
220
- for (mlir::AffineMap map : *indexingMaps) {
221
- // Convert each AffineMap to an AffineMapAttr
222
- indexingMapsAttrVal.push_back (AffineMapAttr::get (map));
223
- }
224
- state.addAttribute (" indexing_maps" , b.getArrayAttr (indexingMapsAttrVal));
225
- } else {
226
- indexingMapsAttrVal = getDefaultMatmulIndexingMapAttr (b.getContext ());
227
- state.addAttribute (" indexing_maps" , b.getArrayAttr (indexingMapsAttrVal));
228
- }
198
+ indexingMapsAttrVal = llvm::map_to_vector (
199
+ MatmulOp::getDefaultIndexingMaps (b.getContext ()),
200
+ [](AffineMap map) -> Attribute { return AffineMapAttr::get (map); });
201
+ state.addAttribute (" indexing_maps" , b.getArrayAttr (indexingMapsAttrVal));
229
202
return buildStructuredOp (b, state, resultTensorTypes, inputs, outputs,
230
203
attributes, regionBuilder);
231
204
}
@@ -3457,7 +3430,7 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
3457
3430
unsigned opIndex) {
3458
3431
SmallVector<AffineMap, 3 > opIndexingMaps = matmulOp.getIndexingMapsArray ();
3459
3432
SmallVector<AffineMap, 3 > defaultIndexingMaps =
3460
- matmulOp.getDefaultIndexingMaps ();
3433
+ matmulOp.getDefaultIndexingMaps (matmulOp-> getContext () );
3461
3434
3462
3435
auto opIndexingMap = opIndexingMaps[opIndex];
3463
3436
auto defaultIndexingMap = defaultIndexingMaps[opIndex];
@@ -3484,6 +3457,17 @@ namespace linalg {
3484
3457
// MatMulOp
3485
3458
// ===----------------------------------------------------------------------===//
3486
3459
3460
+ // / Returns a list of AffineMap with the typical matmul indexing charactristic.
3461
+ SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps (MLIRContext *context) {
3462
+ AffineExpr d0, d1, d2;
3463
+ SmallVector<AffineMap, 3 > indexingMaps;
3464
+ bindDims (context, d0, d1, d2);
3465
+ indexingMaps.push_back (AffineMap::get (3 , 0 , {d0, d2}, context));
3466
+ indexingMaps.push_back (AffineMap::get (3 , 0 , {d2, d1}, context));
3467
+ indexingMaps.push_back (AffineMap::get (3 , 0 , {d0, d1}, context));
3468
+ return indexingMaps;
3469
+ }
3470
+
3487
3471
SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray () {
3488
3472
return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
3489
3473
utils::IteratorType::parallel,
@@ -3501,7 +3485,8 @@ bool MatmulOp::hasDynamicIndexingMaps() { return true; }
3501
3485
// / Check if the op has broadcast and/or transpose semantic. Returns true if
3502
3486
// / the user defined indexing maps are not equal to default map.
3503
3487
bool MatmulOp::hasUserDefinedMaps () {
3504
- SmallVector<AffineMap, 3 > defaultMaps = getDefaultIndexingMaps ();
3488
+ SmallVector<AffineMap, 3 > defaultMaps =
3489
+ getDefaultIndexingMaps (this ->getContext ());
3505
3490
SmallVector<AffineMap, 3 > explicitMaps = getIndexingMapsArray ();
3506
3491
return defaultMaps != explicitMaps;
3507
3492
}
@@ -3535,13 +3520,6 @@ void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
3535
3520
helper.yieldOutputs (yields);
3536
3521
}
3537
3522
3538
- // / Returns a list of AffineMap with the typical matmul indexing
3539
- // / charactristic.
3540
- SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps () {
3541
- MLIRContext *context = this ->getContext ();
3542
- return getDefaultIndexingMapsForMatmul (context);
3543
- }
3544
-
3545
3523
// / Returns true if the given broadcast map \p bcastMap is valid for this op.
3546
3524
bool MatmulOp::isValidLhsRhsBroadcastMap (AffineMap bcastMap) {
3547
3525
assert (bcastMap.getNumResults () == 1 && " Expected single result dim expr." );
@@ -3578,7 +3556,9 @@ ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
3578
3556
}
3579
3557
// Initialize indexingMaps, if not supplied explicitly.
3580
3558
if (indexingMapsAttr.empty ()) {
3581
- indexingMapsAttr = getDefaultMatmulIndexingMapAttr (result.getContext ());
3559
+ indexingMapsAttr = llvm::map_to_vector (
3560
+ MatmulOp::getDefaultIndexingMaps (parser.getContext ()),
3561
+ [](AffineMap map) -> Attribute { return AffineMapAttr::get (map); });
3582
3562
}
3583
3563
result.addAttribute (" indexing_maps" ,
3584
3564
parser.getBuilder ().getArrayAttr (indexingMapsAttr));
@@ -3592,8 +3572,9 @@ void MatmulOp::print(OpAsmPrinter &p) {
3592
3572
printNamedStructuredOp (p, getOperation (), getInputs (), getOutputs (),
3593
3573
elidedAttrs);
3594
3574
3595
- SmallVector<Attribute, 3 > indexingMaps =
3596
- getDefaultMatmulIndexingMapAttr (getContext ());
3575
+ SmallVector<Attribute, 3 > indexingMaps = llvm::map_to_vector (
3576
+ MatmulOp::getDefaultIndexingMaps (getContext ()),
3577
+ [](AffineMap map) -> Attribute { return AffineMapAttr::get (map); });
3597
3578
if (!llvm::equal (getIndexingMaps (), indexingMaps)) {
3598
3579
p << " indexing_maps = [" ;
3599
3580
llvm::interleaveComma (getIndexingMaps (), p,
0 commit comments