@@ -782,30 +782,42 @@ func.func @fold_extract_shapecast(%arg0 : vector<5x1x3x2xf32>,
782
782
783
783
// -----
784
784
785
- // CHECK-LABEL: fold_extract_shapecast_negative
786
- // CHECK: %[[V :.*]] = vector.shape_cast %{{.*}} : vector<16xf32> to vector<2x4x2xf32 >
787
- // CHECK: %[[R:.*]] = vector.extract %[[V ]][1 ] : vector<4x2xf32> from vector<2x4x2xf32 >
788
- // CHECK: return %[[R]] : vector<4x2xf32>
789
- func.func @fold_extract_shapecast_negative (%arg0 : vector <16 x f32 >) -> vector < 4 x 2 x f32 > {
790
- %0 = vector.shape_cast %arg0 : vector <16 x f32 > to vector <2 x 4 x 2 x f32 >
791
- %r = vector.extract %0 [1 ] : vector < 4 x 2 x f32 > from vector <2 x 4 x 2 x f32 >
792
- return %r : vector < 4 x 2 x f32 >
785
+ // CHECK-LABEL: fold_extract_shapecast_0d_result
786
+ // CHECK-SAME: %[[IN :.*]]: vector<1x1x1xf32 >
787
+ // CHECK: %[[R:.*]] = vector.extract %[[IN ]][0, 0, 0 ] : f32 from vector<1x1x1xf32 >
788
+ // CHECK: return %[[R]] : f32
789
+ func.func @fold_extract_shapecast_0d_result (%arg0 : vector <1 x 1 x 1 x f32 >) -> f32 {
790
+ %0 = vector.shape_cast %arg0 : vector <1 x 1 x 1 x f32 > to vector <f32 >
791
+ %r = vector.extract %0 [] : f32 from vector <f32 >
792
+ return %r : f32
793
793
}
794
794
795
795
// -----
796
796
797
- // CHECK-LABEL: dont_fold_0d_extract_shapecast
798
- // CHECK: %[[V :.*]] = vector.shape_cast %{{.*}} : vector<f32> to vector<1xf32 >
799
- // CHECK: %[[R:.*]] = vector.extract %[[V ]][0 ] : f32 from vector<1xf32 >
797
+ // CHECK-LABEL: fold_extract_shapecast_0d_source
798
+ // CHECK-SAME: %[[IN :.*]]: vector<f32>
799
+ // CHECK: %[[R:.*]] = vector.extract %[[IN ]][] : f32 from vector<f32 >
800
800
// CHECK: return %[[R]] : f32
801
- func.func @dont_fold_0d_extract_shapecast (%arg0 : vector <f32 >) -> f32 {
801
+ func.func @fold_extract_shapecast_0d_source (%arg0 : vector <f32 >) -> f32 {
802
802
%0 = vector.shape_cast %arg0 : vector <f32 > to vector <1 xf32 >
803
803
%r = vector.extract %0 [0 ] : f32 from vector <1 xf32 >
804
804
return %r : f32
805
805
}
806
806
807
807
// -----
808
808
809
+ // CHECK-LABEL: fold_extract_shapecast_negative
810
+ // CHECK: %[[V:.*]] = vector.shape_cast %{{.*}} : vector<16xf32> to vector<2x4x2xf32>
811
+ // CHECK: %[[R:.*]] = vector.extract %[[V]][1] : vector<4x2xf32> from vector<2x4x2xf32>
812
+ // CHECK: return %[[R]] : vector<4x2xf32>
813
+ func.func @fold_extract_shapecast_negative (%arg0 : vector <16 xf32 >) -> vector <4 x2 xf32 > {
814
+ %0 = vector.shape_cast %arg0 : vector <16 xf32 > to vector <2 x4 x2 xf32 >
815
+ %r = vector.extract %0 [1 ] : vector <4 x2 xf32 > from vector <2 x4 x2 xf32 >
816
+ return %r : vector <4 x2 xf32 >
817
+ }
818
+
819
+ // -----
820
+
809
821
// CHECK-LABEL: fold_extract_shapecast_to_shapecast
810
822
// CHECK-SAME: (%[[ARG:.+]]: vector<3x4xf32>)
811
823
// CHECK: %[[R:.+]] = vector.shape_cast %[[ARG]] : vector<3x4xf32> to vector<12xf32>
0 commit comments