Skip to content

Commit a07b422

Browse files
authored
[mlir][linalg] Fix SemiFunctionType custom parsing crash on missing () (#110365)
The `SemiFunctionType` allows printing/parsing a set of argument and result types, where there is always exactly one argument type and zero or more result types. If there are no result types, the argument type can be written without enclosing parens in the assembly. If there is at least one result type, the parens are mandatory. This patch fixes a bug where omitting the parens around the argument types for a `SemiFunctionType` with non-optional result Types would crash the parser. It introduces a `bool` argument `resultOptional` to the parser and printer which, when `false`, correctly enforces the parens around argument types, otherwise printing an error. Fix #109128
1 parent 30213e9 commit a07b422

File tree

6 files changed

+49
-25
lines changed

6 files changed

+49
-25
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -541,9 +541,10 @@ def MatchStructuredRankOp : Op<Transform_Dialect, "match.structured.rank", [
541541

542542
let arguments = (ins TransformHandleTypeInterface:$operand_handle);
543543
let results = (outs TransformParamTypeInterface:$rank);
544-
let assemblyFormat =
545-
"$operand_handle attr-dict `:`"
546-
"custom<SemiFunctionType>(type($operand_handle), type($rank))";
544+
let assemblyFormat = [{
545+
$operand_handle attr-dict `:`
546+
custom<SemiFunctionType>(type($operand_handle), type($rank), "false")
547+
}];
547548

548549
let extraClassDeclaration = SingleOpMatcher.extraDeclaration;
549550
}

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -418,9 +418,10 @@ def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
418418

419419
let arguments = (ins TransformHandleTypeInterface:$target);
420420
let results = (outs TransformHandleTypeInterface:$transformed);
421-
let assemblyFormat =
422-
"$target attr-dict `:` "
423-
"custom<SemiFunctionType>(type($target), type($transformed))";
421+
let assemblyFormat = [{
422+
$target attr-dict `:`
423+
custom<SemiFunctionType>(type($target), type($transformed), "false")
424+
}];
424425

425426
let extraClassDeclaration = [{
426427
::mlir::DiagnosedSilenceableFailure applyToOne(
@@ -455,9 +456,10 @@ def SpecializeOp : Op<Transform_Dialect, "structured.specialize",
455456

456457
let arguments = (ins TransformHandleTypeInterface:$target);
457458
let results = (outs TransformHandleTypeInterface:$transformed);
458-
let assemblyFormat =
459-
"$target attr-dict `:` "
460-
"custom<SemiFunctionType>(type($target), type($transformed))";
459+
let assemblyFormat = [{
460+
$target attr-dict `:`
461+
custom<SemiFunctionType>(type($target), type($transformed), "false")
462+
}];
461463

462464
let extraClassDeclaration = [{
463465
::mlir::DiagnosedSilenceableFailure applyToOne(
@@ -500,7 +502,7 @@ def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
500502
let assemblyFormat = [{
501503
$target
502504
(`iterator_interchange` `=` $iterator_interchange^)? attr-dict
503-
`:` custom<SemiFunctionType>(type($target), type($transformed))
505+
`:` custom<SemiFunctionType>(type($target), type($transformed), "false")
504506
}];
505507
let hasVerifier = 1;
506508

@@ -1233,9 +1235,10 @@ def PromoteOp : Op<Transform_Dialect, "structured.promote",
12331235
OptionalAttr<I64Attr>:$alignment);
12341236
let results = (outs TransformHandleTypeInterface:$transformed);
12351237

1236-
let assemblyFormat =
1237-
"$target attr-dict `:`"
1238-
"custom<SemiFunctionType>(type($target), type($transformed))";
1238+
let assemblyFormat = [{
1239+
$target attr-dict `:`
1240+
custom<SemiFunctionType>(type($target), type($transformed), "false")
1241+
}];
12391242

12401243
let extraClassDeclaration = [{
12411244
::mlir::DiagnosedSilenceableFailure applyToOne(
@@ -1269,9 +1272,10 @@ def ReplaceOp : Op<Transform_Dialect, "structured.replace",
12691272
let arguments = (ins TransformHandleTypeInterface:$target);
12701273
let results = (outs TransformHandleTypeInterface:$replacement);
12711274
let regions = (region SizedRegion<1>:$bodyRegion);
1272-
let assemblyFormat =
1273-
"$target attr-dict-with-keyword regions `:` "
1274-
"custom<SemiFunctionType>(type($target), type($replacement))";
1275+
let assemblyFormat = [{
1276+
$target attr-dict-with-keyword regions `:`
1277+
custom<SemiFunctionType>(type($target), type($replacement), "false")
1278+
}];
12751279
let hasVerifier = 1;
12761280
}
12771281

@@ -1310,9 +1314,10 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
13101314
let arguments = (ins TransformHandleTypeInterface:$target);
13111315
let results = (outs TransformHandleTypeInterface:$result);
13121316

1313-
let assemblyFormat =
1314-
"$target attr-dict `:`"
1315-
"custom<SemiFunctionType>(type($target), type($result))";
1317+
let assemblyFormat = [{
1318+
$target attr-dict `:`
1319+
custom<SemiFunctionType>(type($target), type($result), "false")
1320+
}];
13161321

13171322
let extraClassDeclaration = [{
13181323
::mlir::DiagnosedSilenceableFailure applyToOne(

mlir/include/mlir/Dialect/Linalg/TransformOps/Syntax.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class Operation;
3030
/// the argument type in absence of result types, and does not accept the
3131
/// trailing `-> ()` construct, which makes the syntax nicer for operations.
3232
ParseResult parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
33-
Type &resultType);
33+
Type &resultType, bool resultOptional = true);
3434
ParseResult parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
3535
SmallVectorImpl<Type> &resultTypes);
3636

@@ -40,7 +40,8 @@ ParseResult parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
4040
void printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
4141
Type argumentType, TypeRange resultType);
4242
void printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
43-
Type argumentType, Type resultType);
43+
Type argumentType, Type resultType,
44+
bool resultOptional = true);
4445
} // namespace mlir
4546

4647
#endif // MLIR_DIALECT_LINALG_TRANSFORMOPS_SYNTAX_H

mlir/include/mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.td

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,10 @@ def MatchSparseInOut : Op<Transform_Dialect, "sparse_tensor.match.sparse_inout",
3232
let arguments = (ins TransformHandleTypeInterface:$target);
3333
let results = (outs TransformHandleTypeInterface:$result);
3434

35-
let assemblyFormat = "$target attr-dict `:` custom<SemiFunctionType>(type($target), type($result))";
35+
let assemblyFormat = [{
36+
$target attr-dict `:`
37+
custom<SemiFunctionType>(type($target), type($result), "false")
38+
}];
3639
let extraClassDeclaration = SingleOpMatcher.extraDeclaration # [{
3740
::mlir::Value getOperandHandle() { return getTarget(); }
3841
}];

mlir/lib/Dialect/Linalg/TransformOps/Syntax.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,13 @@
1212
using namespace mlir;
1313

1414
ParseResult mlir::parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
15-
Type &resultType) {
15+
Type &resultType, bool resultOptional) {
1616
argumentType = resultType = nullptr;
17-
bool hasLParen = parser.parseOptionalLParen().succeeded();
17+
18+
bool hasLParen = resultOptional ? parser.parseOptionalLParen().succeeded()
19+
: parser.parseLParen().succeeded();
20+
if (!resultOptional && !hasLParen)
21+
return failure();
1822
if (parser.parseType(argumentType).failed())
1923
return failure();
2024
if (!hasLParen)
@@ -69,7 +73,9 @@ void mlir::printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
6973
}
7074

7175
void mlir::printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
72-
Type argumentType, Type resultType) {
76+
Type argumentType, Type resultType,
77+
bool resultOptional) {
78+
assert(resultOptional || resultType != nullptr);
7379
return printSemiFunctionType(printer, op, argumentType,
7480
resultType ? TypeRange(resultType)
7581
: TypeRange());

mlir/test/Dialect/Linalg/transform-ops-invalid.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,11 @@ transform.sequence failures(propagate) {
9292
transform.structured.vectorize %arg0 vector_sizes [%0 : !transform.param<i64>, 2] : !transform.any_op, !transform.param<i64>
9393

9494
}
95+
96+
// -----
97+
98+
transform.sequence failures(propagate) {
99+
^bb0(%arg0: !transform.any_op):
100+
// expected-error@below {{expected '('}}
101+
%res = transform.structured.generalize %arg0 : !transform.any_op -> !transform.any_op
102+
}

0 commit comments

Comments
 (0)