Skip to content

Commit a201ba1

Browse files
authored
[mlir][Vector] Add support for 0-d shapes in extract-shape_cast folder (#116650)
The extract <-> shape cast folder was conservatively asserting and failing on 0-d vectors. This pr fixes this. This pr also adds more tests for 0d cases and updates related tests to better reflect what they test.
1 parent 1afb81d commit a201ba1

File tree

2 files changed

+24
-17
lines changed

2 files changed

+24
-17
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1756,11 +1756,6 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
17561756
if (!shapeCastOp)
17571757
return Value();
17581758

1759-
// 0-D vectors not supported.
1760-
assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1761-
if (hasZeroDimVectors(shapeCastOp))
1762-
return Value();
1763-
17641759
// Get the nth dimension size starting from lowest dimension.
17651760
auto getDimReverse = [](VectorType type, int64_t n) {
17661761
return type.getShape().take_back(n + 1).front();

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -782,30 +782,42 @@ func.func @fold_extract_shapecast(%arg0 : vector<5x1x3x2xf32>,
782782

783783
// -----
784784

785-
// CHECK-LABEL: fold_extract_shapecast_negative
786-
// CHECK: %[[V:.*]] = vector.shape_cast %{{.*}} : vector<16xf32> to vector<2x4x2xf32>
787-
// CHECK: %[[R:.*]] = vector.extract %[[V]][1] : vector<4x2xf32> from vector<2x4x2xf32>
788-
// CHECK: return %[[R]] : vector<4x2xf32>
789-
func.func @fold_extract_shapecast_negative(%arg0 : vector<16xf32>) -> vector<4x2xf32> {
790-
%0 = vector.shape_cast %arg0 : vector<16xf32> to vector<2x4x2xf32>
791-
%r = vector.extract %0[1] : vector<4x2xf32> from vector<2x4x2xf32>
792-
return %r : vector<4x2xf32>
785+
// CHECK-LABEL: fold_extract_shapecast_0d_result
786+
// CHECK-SAME: %[[IN:.*]]: vector<1x1x1xf32>
787+
// CHECK: %[[R:.*]] = vector.extract %[[IN]][0, 0, 0] : f32 from vector<1x1x1xf32>
788+
// CHECK: return %[[R]] : f32
789+
func.func @fold_extract_shapecast_0d_result(%arg0 : vector<1x1x1xf32>) -> f32 {
790+
%0 = vector.shape_cast %arg0 : vector<1x1x1xf32> to vector<f32>
791+
%r = vector.extract %0[] : f32 from vector<f32>
792+
return %r : f32
793793
}
794794

795795
// -----
796796

797-
// CHECK-LABEL: dont_fold_0d_extract_shapecast
798-
// CHECK: %[[V:.*]] = vector.shape_cast %{{.*}} : vector<f32> to vector<1xf32>
799-
// CHECK: %[[R:.*]] = vector.extract %[[V]][0] : f32 from vector<1xf32>
797+
// CHECK-LABEL: fold_extract_shapecast_0d_source
798+
// CHECK-SAME: %[[IN:.*]]: vector<f32>
799+
// CHECK: %[[R:.*]] = vector.extract %[[IN]][] : f32 from vector<f32>
800800
// CHECK: return %[[R]] : f32
801-
func.func @dont_fold_0d_extract_shapecast(%arg0 : vector<f32>) -> f32 {
801+
func.func @fold_extract_shapecast_0d_source(%arg0 : vector<f32>) -> f32 {
802802
%0 = vector.shape_cast %arg0 : vector<f32> to vector<1xf32>
803803
%r = vector.extract %0[0] : f32 from vector<1xf32>
804804
return %r : f32
805805
}
806806

807807
// -----
808808

809+
// CHECK-LABEL: fold_extract_shapecast_negative
810+
// CHECK: %[[V:.*]] = vector.shape_cast %{{.*}} : vector<16xf32> to vector<2x4x2xf32>
811+
// CHECK: %[[R:.*]] = vector.extract %[[V]][1] : vector<4x2xf32> from vector<2x4x2xf32>
812+
// CHECK: return %[[R]] : vector<4x2xf32>
813+
func.func @fold_extract_shapecast_negative(%arg0 : vector<16xf32>) -> vector<4x2xf32> {
814+
%0 = vector.shape_cast %arg0 : vector<16xf32> to vector<2x4x2xf32>
815+
%r = vector.extract %0[1] : vector<4x2xf32> from vector<2x4x2xf32>
816+
return %r : vector<4x2xf32>
817+
}
818+
819+
// -----
820+
809821
// CHECK-LABEL: fold_extract_shapecast_to_shapecast
810822
// CHECK-SAME: (%[[ARG:.+]]: vector<3x4xf32>)
811823
// CHECK: %[[R:.+]] = vector.shape_cast %[[ARG]] : vector<3x4xf32> to vector<12xf32>

0 commit comments

Comments
 (0)