@@ -54,3 +54,105 @@ func.func @rank_reducing_parallel_insert_of_collapse_shape(
  }
  return %1 : tensor<?x?x?x?xf32>
}
+
+ // -----
+
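+ // The expand_shape only adds unit ("padding") dimensions, so it folds away
+ // and the source tensor is inserted directly, with unit sizes in the slice.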
+ // CHECK-LABEL: func @insert_of_padding_expand_shape(
+ // CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>
+ // CHECK-SAME: %[[d:.*]]: tensor<?x?x?x?xf32>
+ // CHECK-SAME: %[[x:[a-zA-Z0-9_]+]]: index
+ // CHECK-SAME: %[[y:[a-zA-Z0-9_]+]]: index
+ // CHECK: %[[insert:.*]] = tensor.insert_slice %[[t]] into %[[d]][%[[x]], %[[y]], 0, 0] [1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?x?xf32>
+ // CHECK: return %[[insert]]
+ func.func @insert_of_padding_expand_shape(
+     %t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index)
+   -> tensor<?x?x?x?xf32> {
+   %c0 = arith.constant 0 : index
+   %c1 = arith.constant 1 : index
+   %sz0 = tensor.dim %t, %c0 : tensor<?x?xf32>
+   %sz1 = tensor.dim %t, %c1 : tensor<?x?xf32>
+   %0 = tensor.expand_shape %t [[0, 1], [2, 3]] output_shape [1, %sz0, 1, %sz1]
+       : tensor<?x?xf32> into tensor<1x?x1x?xf32>
+   %1 = tensor.insert_slice %0 into %d[%x, %y, 0, 0][1, %sz0, 1, %sz1][1, 1, 1, 1]
+       : tensor<1x?x1x?xf32> into tensor<?x?x?x?xf32>
+   return %1 : tensor<?x?x?x?xf32>
+ }
+
+ // -----
+
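+ // The expand_shape also introduces %sz, a dimension that is not a static
+ // unit, so it cannot fold away; the insert_slice consumes the expanded tensor.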
+ // CHECK-LABEL: func @insert_of_non_padding_expand_shape(
+ // CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>
+ // CHECK-SAME: %[[d:.*]]: tensor<?x?x?x?xf32>
+ // CHECK-SAME: %[[x:[a-zA-Z0-9_]+]]: index
+ // CHECK-SAME: %[[y:[a-zA-Z0-9_]+]]: index
+ // CHECK-SAME: %[[sz:[a-zA-Z0-9_]+]]: index
+ // CHECK: %[[expand:.*]] = tensor.expand_shape %[[t]] {{\[}}[0, 1], [2]] output_shape [%[[sz]], %{{.*}}, %{{.*}}] : tensor<?x?xf32> into tensor<?x?x?xf32>
+ // CHECK: %[[insert:.*]] = tensor.insert_slice %[[expand]] into %[[d]][%[[x]], %[[y]], 0, 0] [%[[sz]], 1, %{{.*}}, %{{.*}}] [1, 1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+ // CHECK: return %[[insert]]
+ func.func @insert_of_non_padding_expand_shape(
+     %t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index, %sz: index)
+   -> tensor<?x?x?x?xf32> {
+   %c0 = arith.constant 0 : index
+   %c1 = arith.constant 1 : index
+   %sz0 = tensor.dim %t, %c0 : tensor<?x?xf32>
+   %sz1 = tensor.dim %t, %c1 : tensor<?x?xf32>
+   %0 = tensor.expand_shape %t [[0, 1], [2]] output_shape [%sz, %sz0, %sz1]
+       : tensor<?x?xf32> into tensor<?x?x?xf32>
+   %1 = tensor.insert_slice %0 into %d[%x, %y, 0, 0][%sz, 1, %sz0, %sz1][1, 1, 1, 1]
+       : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+   return %1 : tensor<?x?x?x?xf32>
+ }
+
+ // -----
+
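+ // Same as @insert_of_padding_expand_shape, but the fold applies to a
+ // tensor.parallel_insert_slice inside an scf.forall.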
+ // CHECK-LABEL: func @parallel_insert_of_padding_expand_shape(
+ // CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>
+ // CHECK-SAME: %[[d:.*]]: tensor<?x?x?x?xf32>
+ // CHECK-SAME: %[[x:[a-zA-Z0-9_]+]]: index
+ // CHECK-SAME: %[[y:[a-zA-Z0-9_]+]]: index
+ // CHECK: tensor.parallel_insert_slice %[[t]] into %{{.*}}[%{{.*}}, %{{.*}}, 0, 0] [1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?x?xf32>
+ func.func @parallel_insert_of_padding_expand_shape(
+     %t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index)
+   -> tensor<?x?x?x?xf32> {
+   %c0 = arith.constant 0 : index
+   %c1 = arith.constant 1 : index
+   %sz0 = tensor.dim %t, %c0 : tensor<?x?xf32>
+   %sz1 = tensor.dim %t, %c1 : tensor<?x?xf32>
+   %0 = tensor.expand_shape %t [[0, 1], [2, 3]] output_shape [1, %sz0, 1, %sz1]
+       : tensor<?x?xf32> into tensor<1x?x1x?xf32>
+   %1 = scf.forall (%i, %j) in (%x, %y) shared_outs(%o = %d) -> (tensor<?x?x?x?xf32>) {
+     scf.forall.in_parallel {
+       tensor.parallel_insert_slice %0 into %o[%i, %j, 0, 0][1, %sz0, 1, %sz1][1, 1, 1, 1]
+           : tensor<1x?x1x?xf32> into tensor<?x?x?x?xf32>
+     }
+   }
+   return %1 : tensor<?x?x?x?xf32>
+ }
+
+ // -----
+
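+ // Same as @insert_of_non_padding_expand_shape: the non-unit %sz dimension
+ // prevents the fold, also for tensor.parallel_insert_slice.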
+ // CHECK-LABEL: func @parallel_insert_of_non_padding_expand_shape(
+ // CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>
+ // CHECK-SAME: %[[d:.*]]: tensor<?x?x?x?xf32>
+ // CHECK-SAME: %[[x:[a-zA-Z0-9_]+]]: index
+ // CHECK-SAME: %[[y:[a-zA-Z0-9_]+]]: index
+ // CHECK-SAME: %[[sz:[a-zA-Z0-9_]+]]: index
+ // CHECK: %[[expand:.*]] = tensor.expand_shape %[[t]] {{\[}}[0, 1], [2]] output_shape [%[[sz]], %{{.*}}, %{{.*}}] : tensor<?x?xf32> into tensor<?x?x?xf32>
+ // CHECK: tensor.parallel_insert_slice %[[expand]] into %{{.*}}[%{{.*}}, %{{.*}}, 0, 0] [%[[sz]], 1, %{{.*}}, %{{.*}}] [1, 1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+ func.func @parallel_insert_of_non_padding_expand_shape(
+     %t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index, %sz: index)
+   -> tensor<?x?x?x?xf32> {
+   %c0 = arith.constant 0 : index
+   %c1 = arith.constant 1 : index
+   %sz0 = tensor.dim %t, %c0 : tensor<?x?xf32>
+   %sz1 = tensor.dim %t, %c1 : tensor<?x?xf32>
+   %0 = tensor.expand_shape %t [[0, 1], [2]] output_shape [%sz, %sz0, %sz1]
+       : tensor<?x?xf32> into tensor<?x?x?xf32>
+   %1 = scf.forall (%i, %j) in (%x, %y) shared_outs(%o = %d) -> (tensor<?x?x?x?xf32>) {
+     scf.forall.in_parallel {
+       tensor.parallel_insert_slice %0 into %o[%i, %j, 0, 0][%sz, 1, %sz0, %sz1][1, 1, 1, 1]
+           : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+     }
+   }
+   return %1 : tensor<?x?x?x?xf32>
+ }