@@ -180,61 +180,50 @@ bool mlir::detail::sameOffsetsSizesAndStrides(
180
180
}
181
181
182
182
SmallVector<OpFoldResult, 4 >
183
- mlir::getMixedOffsets (OffsetSizeAndStrideOpInterface op ,
184
- ArrayAttr staticOffsets, ValueRange offsets ) {
183
+ mlir::getMixedValues (ArrayAttr staticValues, ValueRange dynamicValues ,
184
+ int64_t dynamicValueIndicator ) {
185
185
SmallVector<OpFoldResult, 4 > res;
186
+ res.reserve (staticValues.size ());
186
187
unsigned numDynamic = 0 ;
187
- unsigned count = static_cast <unsigned >(staticOffsets .size ());
188
+ unsigned count = static_cast <unsigned >(staticValues .size ());
188
189
for (unsigned idx = 0 ; idx < count; ++idx) {
189
- if (op. isDynamicOffset (idx))
190
- res.push_back (offsets[numDynamic++]);
191
- else
192
- res. push_back (staticOffsets [idx]);
190
+ APInt value = staticValues[idx]. cast <IntegerAttr>(). getValue ();
191
+ res.push_back (value. getSExtValue () == dynamicValueIndicator
192
+ ? OpFoldResult{dynamicValues[numDynamic++]}
193
+ : OpFoldResult{staticValues [idx]} );
193
194
}
194
195
return res;
195
196
}
196
197
198
+ SmallVector<OpFoldResult, 4 >
199
+ mlir::getMixedOffsets (OffsetSizeAndStrideOpInterface op,
200
+ ArrayAttr staticOffsets, ValueRange offsets) {
201
+ return getMixedValues (staticOffsets, offsets, op.getDynamicOffsetIndicator ());
202
+ }
203
+
197
204
SmallVector<OpFoldResult, 4 >
198
205
mlir::getMixedSizes (OffsetSizeAndStrideOpInterface op, ArrayAttr staticSizes,
199
206
ValueRange sizes) {
200
- SmallVector<OpFoldResult, 4 > res;
201
- unsigned numDynamic = 0 ;
202
- unsigned count = static_cast <unsigned >(staticSizes.size ());
203
- for (unsigned idx = 0 ; idx < count; ++idx) {
204
- if (op.isDynamicSize (idx))
205
- res.push_back (sizes[numDynamic++]);
206
- else
207
- res.push_back (staticSizes[idx]);
208
- }
209
- return res;
207
+ return getMixedValues (staticSizes, sizes, op.getDynamicSizeIndicator ());
210
208
}
211
209
212
210
SmallVector<OpFoldResult, 4 >
213
211
mlir::getMixedStrides (OffsetSizeAndStrideOpInterface op,
214
212
ArrayAttr staticStrides, ValueRange strides) {
215
- SmallVector<OpFoldResult, 4 > res;
216
- unsigned numDynamic = 0 ;
217
- unsigned count = static_cast <unsigned >(staticStrides.size ());
218
- for (unsigned idx = 0 ; idx < count; ++idx) {
219
- if (op.isDynamicStride (idx))
220
- res.push_back (strides[numDynamic++]);
221
- else
222
- res.push_back (staticStrides[idx]);
223
- }
224
- return res;
213
+ return getMixedValues (staticStrides, strides, op.getDynamicStrideIndicator ());
225
214
}
226
215
227
- static std::pair<ArrayAttr, SmallVector<Value>>
228
- decomposeMixedImpl (OpBuilder &b,
229
- const SmallVectorImpl<OpFoldResult> &mixedValues,
230
- const int64_t dynamicValuePlaceholder ) {
216
+ std::pair<ArrayAttr, SmallVector<Value>>
217
+ mlir::decomposeMixedValues (Builder &b,
218
+ const SmallVectorImpl<OpFoldResult> &mixedValues,
219
+ const int64_t dynamicValueIndicator ) {
231
220
SmallVector<int64_t > staticValues;
232
221
SmallVector<Value> dynamicValues;
233
222
for (const auto &it : mixedValues) {
234
223
if (it.is <Attribute>()) {
235
224
staticValues.push_back (it.get <Attribute>().cast <IntegerAttr>().getInt ());
236
225
} else {
237
- staticValues.push_back (ShapedType:: kDynamicStrideOrOffset );
226
+ staticValues.push_back (dynamicValueIndicator );
238
227
dynamicValues.push_back (it.get <Value>());
239
228
}
240
229
}
@@ -243,11 +232,12 @@ decomposeMixedImpl(OpBuilder &b,
243
232
244
233
std::pair<ArrayAttr, SmallVector<Value>> mlir::decomposeMixedStridesOrOffsets (
245
234
OpBuilder &b, const SmallVectorImpl<OpFoldResult> &mixedValues) {
246
- return decomposeMixedImpl (b, mixedValues, ShapedType::kDynamicStrideOrOffset );
235
+ return decomposeMixedValues (b, mixedValues,
236
+ ShapedType::kDynamicStrideOrOffset );
247
237
}
248
238
249
239
std::pair<ArrayAttr, SmallVector<Value>>
250
240
mlir::decomposeMixedSizes (OpBuilder &b,
251
241
const SmallVectorImpl<OpFoldResult> &mixedValues) {
252
- return decomposeMixedImpl (b, mixedValues, ShapedType::kDynamicSize );
242
+ return decomposeMixedValues (b, mixedValues, ShapedType::kDynamicSize );
253
243
}
0 commit comments