Skip to content

Commit 288f05f

Browse files
authored
[NFC][MLIR][Linalg] Refactor linalg.matmul tablegen ODS and related C++ code. (#116377)
This commit refactors part of the code in preparation for the migration of other *matmul* variants from OpDSL to ODS. Moves getDefaultIndexingmaps() helper into the MatmulOp class.
1 parent b7ddb97 commit 288f05f

File tree

2 files changed

+38
-63
lines changed

2 files changed

+38
-63
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -622,24 +622,17 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
622622
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
623623
[{
624624
buildMatmulOp($_builder, $_state, std::nullopt, inputs, outputs,
625-
attributes, MatmulOp::getRegionBuilder());
625+
attributes, MatmulOp::getRegionBuilder(),
626+
MatmulOp::getDefaultIndexingMaps($_builder.getContext()));
626627
}]>,
627628
OpBuilder<
628629
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
629630
"ValueRange":$outputs,
630631
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
631632
[{
632633
buildMatmulOp($_builder, $_state, resultTensorTypes,
633-
inputs, outputs, attributes, MatmulOp::getRegionBuilder());
634-
}]>,
635-
OpBuilder<
636-
(ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
637-
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
638-
[{
639-
$_state.addOperands(operands);
640-
$_state.addAttributes(attributes);
641-
$_state.addTypes(resultTensorTypes);
642-
(void)$_state.addRegion();
634+
inputs, outputs, attributes, MatmulOp::getRegionBuilder(),
635+
MatmulOp::getDefaultIndexingMaps($_builder.getContext()));
643636
}]>,
644637
OpBuilder<
645638
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
@@ -648,7 +641,8 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
648641
[{
649642
$_state.addAttribute("cast", cast);
650643
buildMatmulOp($_builder, $_state, resultTensorTypes, inputs, outputs,
651-
attributes, MatmulOp::getRegionBuilder());
644+
attributes, MatmulOp::getRegionBuilder(),
645+
MatmulOp::getDefaultIndexingMaps($_builder.getContext()));
652646
}]>
653647

654648
];
@@ -664,7 +658,7 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
664658
Block &block, ArrayRef<NamedAttribute> attrs);
665659

666660
/// Returns a list of AffineMap with the typical matmul indexing charactristic.
667-
SmallVector<AffineMap> getDefaultIndexingMaps();
661+
static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
668662

669663
/// Returns true if the given broadcast map \p bcastMap is valid for this op.
670664
bool isValidLhsRhsBroadcastMap(AffineMap bcastMap);

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 31 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -155,27 +155,6 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
155155
// iterator_types is an auto-generated method.
156156
}
157157

158-
/// Helper to create a typical indexing map for MatmulOp. Returns a list of
159-
/// AffineMap.
160-
static SmallVector<AffineMap, 3>
161-
getDefaultIndexingMapsForMatmul(MLIRContext *context) {
162-
AffineExpr d0, d1, d2;
163-
SmallVector<AffineMap, 3> indexingMaps;
164-
bindDims(context, d0, d1, d2);
165-
indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
166-
indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
167-
indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
168-
return indexingMaps;
169-
}
170-
171-
/// Wrapper to return the typical indexing map array attribute for MatmulOp.
172-
static SmallVector<Attribute>
173-
getDefaultMatmulIndexingMapAttr(MLIRContext *context) {
174-
return llvm::map_to_vector(
175-
getDefaultIndexingMapsForMatmul(context),
176-
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
177-
}
178-
179158
/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
180159
/// The result types are derived automatically if `resultTensorTypes` is none.
181160
/// The body of the operation is filled using `regionBuilder`. All ods-gen
@@ -208,24 +187,18 @@ static void buildStructuredOp(OpBuilder &b, OperationState &state,
208187
state.attributes.getAttrs(), regionBuilder);
209188
}
210189

211-
static void
212-
buildMatmulOp(OpBuilder &b, OperationState &state,
213-
std::optional<TypeRange> resultTensorTypes, ValueRange inputs,
214-
ValueRange outputs, ArrayRef<NamedAttribute> attributes,
215-
RegionBuilderFn regionBuilder,
216-
std::optional<ArrayRef<AffineMap>> indexingMaps = std::nullopt) {
217-
// Initialize indexingMaps, for MatmulOp.
190+
static void buildMatmulOp(OpBuilder &b, OperationState &state,
191+
std::optional<TypeRange> resultTensorTypes,
192+
ValueRange inputs, ValueRange outputs,
193+
ArrayRef<NamedAttribute> attributes,
194+
RegionBuilderFn regionBuilder,
195+
ArrayRef<AffineMap> indexingMaps) {
196+
// Initialize indexingMaps attribute, for MatmulOp.
218197
SmallVector<Attribute, 3> indexingMapsAttrVal;
219-
if (indexingMaps.has_value()) {
220-
for (mlir::AffineMap map : *indexingMaps) {
221-
// Convert each AffineMap to an AffineMapAttr
222-
indexingMapsAttrVal.push_back(AffineMapAttr::get(map));
223-
}
224-
state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
225-
} else {
226-
indexingMapsAttrVal = getDefaultMatmulIndexingMapAttr(b.getContext());
227-
state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
228-
}
198+
indexingMapsAttrVal = llvm::map_to_vector(
199+
MatmulOp::getDefaultIndexingMaps(b.getContext()),
200+
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
201+
state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
229202
return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
230203
attributes, regionBuilder);
231204
}
@@ -3457,7 +3430,7 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
34573430
unsigned opIndex) {
34583431
SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
34593432
SmallVector<AffineMap, 3> defaultIndexingMaps =
3460-
matmulOp.getDefaultIndexingMaps();
3433+
matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
34613434

34623435
auto opIndexingMap = opIndexingMaps[opIndex];
34633436
auto defaultIndexingMap = defaultIndexingMaps[opIndex];
@@ -3484,6 +3457,17 @@ namespace linalg {
34843457
// MatMulOp
34853458
//===----------------------------------------------------------------------===//
34863459

3460+
/// Returns a list of AffineMap with the typical matmul indexing charactristic.
3461+
SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
3462+
AffineExpr d0, d1, d2;
3463+
SmallVector<AffineMap, 3> indexingMaps;
3464+
bindDims(context, d0, d1, d2);
3465+
indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
3466+
indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
3467+
indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
3468+
return indexingMaps;
3469+
}
3470+
34873471
SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
34883472
return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
34893473
utils::IteratorType::parallel,
@@ -3501,7 +3485,8 @@ bool MatmulOp::hasDynamicIndexingMaps() { return true; }
35013485
/// Check if the op has broadcast and/or transpose semantic. Returns true if
35023486
/// the user defined indexing maps are not equal to default map.
35033487
bool MatmulOp::hasUserDefinedMaps() {
3504-
SmallVector<AffineMap, 3> defaultMaps = getDefaultIndexingMaps();
3488+
SmallVector<AffineMap, 3> defaultMaps =
3489+
getDefaultIndexingMaps(this->getContext());
35053490
SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
35063491
return defaultMaps != explicitMaps;
35073492
}
@@ -3535,13 +3520,6 @@ void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
35353520
helper.yieldOutputs(yields);
35363521
}
35373522

3538-
/// Returns a list of AffineMap with the typical matmul indexing
3539-
/// charactristic.
3540-
SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps() {
3541-
MLIRContext *context = this->getContext();
3542-
return getDefaultIndexingMapsForMatmul(context);
3543-
}
3544-
35453523
/// Returns true if the given broadcast map \p bcastMap is valid for this op.
35463524
bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
35473525
assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
@@ -3578,7 +3556,9 @@ ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
35783556
}
35793557
// Initialize indexingMaps, if not supplied explicitly.
35803558
if (indexingMapsAttr.empty()) {
3581-
indexingMapsAttr = getDefaultMatmulIndexingMapAttr(result.getContext());
3559+
indexingMapsAttr = llvm::map_to_vector(
3560+
MatmulOp::getDefaultIndexingMaps(parser.getContext()),
3561+
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
35823562
}
35833563
result.addAttribute("indexing_maps",
35843564
parser.getBuilder().getArrayAttr(indexingMapsAttr));
@@ -3592,8 +3572,9 @@ void MatmulOp::print(OpAsmPrinter &p) {
35923572
printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
35933573
elidedAttrs);
35943574

3595-
SmallVector<Attribute, 3> indexingMaps =
3596-
getDefaultMatmulIndexingMapAttr(getContext());
3575+
SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
3576+
MatmulOp::getDefaultIndexingMaps(getContext()),
3577+
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
35973578
if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
35983579
p << " indexing_maps = [";
35993580
llvm::interleaveComma(getIndexingMaps(), p,

0 commit comments

Comments
 (0)