Skip to content

Commit 68b0aaa

Browse files
committed
Revert "Revert "[mlir] Reuse the code between getMixed*s() funcs in ViewLikeInterface.cpp.""
This reverts commit e78d763. Differential Revision: https://reviews.llvm.org/D130706
1 parent e78d763 commit 68b0aaa

File tree

3 files changed

+61
-35
lines changed

3 files changed

+61
-35
lines changed

mlir/include/mlir/Interfaces/ViewLikeInterface.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@ struct Range {
3131

3232
class OffsetSizeAndStrideOpInterface;
3333

34+
/// Return a vector of OpFoldResults given the special value
35+
/// that indicates whether of the value is dynamic or not.
36+
SmallVector<OpFoldResult, 4> getMixedValues(ArrayAttr staticValues,
37+
ValueRange dynamicValues,
38+
int64_t dynamicValueIndicator);
39+
3440
/// Return a vector of all the static or dynamic offsets of the op from provided
3541
/// external static and dynamic offsets.
3642
SmallVector<OpFoldResult, 4> getMixedOffsets(OffsetSizeAndStrideOpInterface op,
@@ -49,6 +55,13 @@ SmallVector<OpFoldResult, 4> getMixedStrides(OffsetSizeAndStrideOpInterface op,
4955
ArrayAttr staticStrides,
5056
ValueRange strides);
5157

58+
/// Decompose a vector of mixed static or dynamic values into the corresponding
59+
/// pair of arrays. This is the inverse function of `getMixedValues`.
60+
std::pair<ArrayAttr, SmallVector<Value>>
61+
decomposeMixedValues(Builder &b,
62+
const SmallVectorImpl<OpFoldResult> &mixedValues,
63+
const int64_t dynamicValueIndicator);
64+
5265
/// Decompose a vector of mixed static or dynamic strides/offsets into the
5366
/// corresponding pair of arrays. This is the inverse function of
5467
/// `getMixedStrides` and `getMixedOffsets`.

mlir/include/mlir/Interfaces/ViewLikeInterface.td

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,30 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
237237
return ::mlir::ShapedType::isDynamicStrideOrOffset(v.getSExtValue());
238238
}]
239239
>,
240-
240+
StaticInterfaceMethod<
241+
/*desc=*/"Return constant that indicates the offset is dynamic",
242+
/*retTy=*/"int64_t",
243+
/*methodName=*/"getDynamicOffsetIndicator",
244+
/*args=*/(ins),
245+
/*methodBody=*/"",
246+
/*defaultImpl=*/[{ return ::mlir::ShapedType::kDynamicStrideOrOffset; }]
247+
>,
248+
StaticInterfaceMethod<
249+
/*desc=*/"Return constant that indicates the size is dynamic",
250+
/*retTy=*/"int64_t",
251+
/*methodName=*/"getDynamicSizeIndicator",
252+
/*args=*/(ins),
253+
/*methodBody=*/"",
254+
/*defaultImpl=*/[{ return ::mlir::ShapedType::kDynamicSize; }]
255+
>,
256+
StaticInterfaceMethod<
257+
/*desc=*/"Return constant that indicates the stride is dynamic",
258+
/*retTy=*/"int64_t",
259+
/*methodName=*/"getDynamicStrideIndicator",
260+
/*args=*/(ins),
261+
/*methodBody=*/"",
262+
/*defaultImpl=*/[{ return ::mlir::ShapedType::kDynamicStrideOrOffset; }]
263+
>,
241264
InterfaceMethod<
242265
/*desc=*/[{
243266
Assert the offset `idx` is a static constant and return its value.

mlir/lib/Interfaces/ViewLikeInterface.cpp

Lines changed: 24 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -180,61 +180,50 @@ bool mlir::detail::sameOffsetsSizesAndStrides(
180180
}
181181

182182
SmallVector<OpFoldResult, 4>
183-
mlir::getMixedOffsets(OffsetSizeAndStrideOpInterface op,
184-
ArrayAttr staticOffsets, ValueRange offsets) {
183+
mlir::getMixedValues(ArrayAttr staticValues, ValueRange dynamicValues,
184+
int64_t dynamicValueIndicator) {
185185
SmallVector<OpFoldResult, 4> res;
186+
res.reserve(staticValues.size());
186187
unsigned numDynamic = 0;
187-
unsigned count = static_cast<unsigned>(staticOffsets.size());
188+
unsigned count = static_cast<unsigned>(staticValues.size());
188189
for (unsigned idx = 0; idx < count; ++idx) {
189-
if (op.isDynamicOffset(idx))
190-
res.push_back(offsets[numDynamic++]);
191-
else
192-
res.push_back(staticOffsets[idx]);
190+
APInt value = staticValues[idx].cast<IntegerAttr>().getValue();
191+
res.push_back(value.getSExtValue() == dynamicValueIndicator
192+
? OpFoldResult{dynamicValues[numDynamic++]}
193+
: OpFoldResult{staticValues[idx]});
193194
}
194195
return res;
195196
}
196197

198+
SmallVector<OpFoldResult, 4>
199+
mlir::getMixedOffsets(OffsetSizeAndStrideOpInterface op,
200+
ArrayAttr staticOffsets, ValueRange offsets) {
201+
return getMixedValues(staticOffsets, offsets, op.getDynamicOffsetIndicator());
202+
}
203+
197204
SmallVector<OpFoldResult, 4>
198205
mlir::getMixedSizes(OffsetSizeAndStrideOpInterface op, ArrayAttr staticSizes,
199206
ValueRange sizes) {
200-
SmallVector<OpFoldResult, 4> res;
201-
unsigned numDynamic = 0;
202-
unsigned count = static_cast<unsigned>(staticSizes.size());
203-
for (unsigned idx = 0; idx < count; ++idx) {
204-
if (op.isDynamicSize(idx))
205-
res.push_back(sizes[numDynamic++]);
206-
else
207-
res.push_back(staticSizes[idx]);
208-
}
209-
return res;
207+
return getMixedValues(staticSizes, sizes, op.getDynamicSizeIndicator());
210208
}
211209

212210
SmallVector<OpFoldResult, 4>
213211
mlir::getMixedStrides(OffsetSizeAndStrideOpInterface op,
214212
ArrayAttr staticStrides, ValueRange strides) {
215-
SmallVector<OpFoldResult, 4> res;
216-
unsigned numDynamic = 0;
217-
unsigned count = static_cast<unsigned>(staticStrides.size());
218-
for (unsigned idx = 0; idx < count; ++idx) {
219-
if (op.isDynamicStride(idx))
220-
res.push_back(strides[numDynamic++]);
221-
else
222-
res.push_back(staticStrides[idx]);
223-
}
224-
return res;
213+
return getMixedValues(staticStrides, strides, op.getDynamicStrideIndicator());
225214
}
226215

227-
static std::pair<ArrayAttr, SmallVector<Value>>
228-
decomposeMixedImpl(OpBuilder &b,
229-
const SmallVectorImpl<OpFoldResult> &mixedValues,
230-
const int64_t dynamicValuePlaceholder) {
216+
std::pair<ArrayAttr, SmallVector<Value>>
217+
mlir::decomposeMixedValues(Builder &b,
218+
const SmallVectorImpl<OpFoldResult> &mixedValues,
219+
const int64_t dynamicValueIndicator) {
231220
SmallVector<int64_t> staticValues;
232221
SmallVector<Value> dynamicValues;
233222
for (const auto &it : mixedValues) {
234223
if (it.is<Attribute>()) {
235224
staticValues.push_back(it.get<Attribute>().cast<IntegerAttr>().getInt());
236225
} else {
237-
staticValues.push_back(ShapedType::kDynamicStrideOrOffset);
226+
staticValues.push_back(dynamicValueIndicator);
238227
dynamicValues.push_back(it.get<Value>());
239228
}
240229
}
@@ -243,11 +232,12 @@ decomposeMixedImpl(OpBuilder &b,
243232

244233
std::pair<ArrayAttr, SmallVector<Value>> mlir::decomposeMixedStridesOrOffsets(
245234
OpBuilder &b, const SmallVectorImpl<OpFoldResult> &mixedValues) {
246-
return decomposeMixedImpl(b, mixedValues, ShapedType::kDynamicStrideOrOffset);
235+
return decomposeMixedValues(b, mixedValues,
236+
ShapedType::kDynamicStrideOrOffset);
247237
}
248238

249239
std::pair<ArrayAttr, SmallVector<Value>>
250240
mlir::decomposeMixedSizes(OpBuilder &b,
251241
const SmallVectorImpl<OpFoldResult> &mixedValues) {
252-
return decomposeMixedImpl(b, mixedValues, ShapedType::kDynamicSize);
242+
return decomposeMixedValues(b, mixedValues, ShapedType::kDynamicSize);
253243
}

0 commit comments

Comments
 (0)