Skip to content

Commit 0105f65

Browse files
authored
[mlir][tosa] Fix mul op verifier when input types don't match result (#141617)
This commit fixes a crash when operand types are not integer, but the result is. While this isn't valid, the verifier should not crash.
1 parent 7605198 commit 0105f65

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1785,10 +1785,10 @@ LogicalResult tosa::MulOp::verify() {
17851785
// specification.
17861786
if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
17871787
IntegerType lhsIntType =
1788-
cast<IntegerType>(getElementTypeOrSelf(getInput1()));
1788+
dyn_cast<IntegerType>(getElementTypeOrSelf(getInput1()));
17891789
IntegerType rhsIntType =
1790-
cast<IntegerType>(getElementTypeOrSelf(getInput2()));
1791-
if (lhsIntType != rhsIntType)
1790+
dyn_cast<IntegerType>(getElementTypeOrSelf(getInput2()));
1791+
if (!lhsIntType || !rhsIntType || lhsIntType != rhsIntType)
17921792
return emitOpError("requires the same element type for all operands");
17931793

17941794
// Though the spec requires the element type of result to be i32, a more

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -677,6 +677,16 @@ func.func @test_mul_type_mismatch(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1
677677

678678
// -----
679679

680+
// CHECK-LABEL: test_mul_int_type_mismatch
681+
func.func @test_mul_int_type_mismatch(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xi32> {
682+
%shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
683+
// expected-error@+1 {{'tosa.mul' op requires the same element type for all operands}}
684+
%3 = tosa.mul %arg0, %arg1, %shift : (tensor<1xf32>, tensor<1xf32>, tensor<1xi8>) -> tensor<1xi32>
685+
return %3 : tensor<1xi32>
686+
}
687+
688+
// -----
689+
680690
// CHECK-LABEL: test_mul_invalid_shift
681691
func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
682692
%shift = "tosa.const"() {values = dense<1> : tensor<1xi8>} : () -> tensor<1xi8>

0 commit comments

Comments
 (0)