Skip to content

Commit e8c2877

Browse files
committed
[mlir] Reuse the code between getMixed*s() funcs in ViewLikeInterface.cpp.
Differential Revision: https://reviews.llvm.org/D130706
1 parent c09d323 commit e8c2877

File tree

4 files changed

+80
-50
lines changed

4 files changed

+80
-50
lines changed

mlir/include/mlir/Dialect/Utils/StaticValueUtils.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,19 @@ void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
3939
SmallVectorImpl<int64_t> &staticVec,
4040
int64_t sentinel);
4141

42+
/// Return a vector of OpFoldResults given the special value
43+
/// that indicates whether of the value is dynamic or not.
44+
SmallVector<OpFoldResult, 4> getMixedValues(ArrayAttr staticValues,
45+
ValueRange dynamicValues,
46+
int64_t dynamicValueIndicator);
47+
48+
/// Decompose a vector of mixed static or dynamic values into the corresponding
49+
/// pair of arrays. This is the inverse function of `getMixedValues`.
50+
std::pair<ArrayAttr, SmallVector<Value>>
51+
decomposeMixedValues(Builder &b,
52+
const SmallVectorImpl<OpFoldResult> &mixedValues,
53+
const int64_t dynamicValueIndicator);
54+
4255
/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
4356
SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr);
4457

mlir/include/mlir/Interfaces/ViewLikeInterface.td

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,30 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
237237
return ::mlir::ShapedType::isDynamicStrideOrOffset(v.getSExtValue());
238238
}]
239239
>,
240-
240+
StaticInterfaceMethod<
241+
/*desc=*/"Return constant that indicates the offset is dynamic",
242+
/*retTy=*/"int64_t",
243+
/*methodName=*/"getDynamicOffsetIndicator",
244+
/*args=*/(ins),
245+
/*methodBody=*/"",
246+
/*defaultImpl=*/[{ return ::mlir::ShapedType::kDynamicStrideOrOffset; }]
247+
>,
248+
StaticInterfaceMethod<
249+
/*desc=*/"Return constant that indicates the size is dynamic",
250+
/*retTy=*/"int64_t",
251+
/*methodName=*/"getDynamicSizeIndicator",
252+
/*args=*/(ins),
253+
/*methodBody=*/"",
254+
/*defaultImpl=*/[{ return ::mlir::ShapedType::kDynamicSize; }]
255+
>,
256+
StaticInterfaceMethod<
257+
/*desc=*/"Return constant that indicates the stride is dynamic",
258+
/*retTy=*/"int64_t",
259+
/*methodName=*/"getDynamicStrideIndicator",
260+
/*args=*/(ins),
261+
/*methodBody=*/"",
262+
/*defaultImpl=*/[{ return ::mlir::ShapedType::kDynamicStrideOrOffset; }]
263+
>,
241264
InterfaceMethod<
242265
/*desc=*/[{
243266
Assert the offset `idx` is a static constant and return its value.

mlir/lib/Dialect/Utils/StaticValueUtils.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Dialect/Utils/StaticValueUtils.h"
10+
#include "mlir/IR/Builders.h"
1011
#include "mlir/IR/Matchers.h"
1112
#include "mlir/Support/LLVM.h"
1213
#include "llvm/ADT/APSInt.h"
@@ -109,4 +110,40 @@ bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2) {
109110
auto v1 = ofr1.dyn_cast<Value>(), v2 = ofr2.dyn_cast<Value>();
110111
return v1 && v1 == v2;
111112
}
113+
114+
/// Return a vector of OpFoldResults given the special value
115+
/// that indicates whether of the value is dynamic or not.
116+
SmallVector<OpFoldResult, 4> getMixedValues(ArrayAttr staticValues,
117+
ValueRange dynamicValues,
118+
int64_t dynamicValueIndicator) {
119+
SmallVector<OpFoldResult, 4> res;
120+
res.reserve(staticValues.size());
121+
unsigned numDynamic = 0;
122+
unsigned count = static_cast<unsigned>(staticValues.size());
123+
for (unsigned idx = 0; idx < count; ++idx) {
124+
APInt value = staticValues[idx].cast<IntegerAttr>().getValue();
125+
res.push_back(value.getSExtValue() == dynamicValueIndicator
126+
? OpFoldResult{dynamicValues[numDynamic++]}
127+
: OpFoldResult{staticValues[idx]});
128+
}
129+
return res;
130+
}
131+
132+
std::pair<ArrayAttr, SmallVector<Value>>
133+
decomposeMixedValues(Builder &b,
134+
const SmallVectorImpl<OpFoldResult> &mixedValues,
135+
const int64_t dynamicValueIndicator) {
136+
SmallVector<int64_t> staticValues;
137+
SmallVector<Value> dynamicValues;
138+
for (const auto &it : mixedValues) {
139+
if (it.is<Attribute>()) {
140+
staticValues.push_back(it.get<Attribute>().cast<IntegerAttr>().getInt());
141+
} else {
142+
staticValues.push_back(dynamicValueIndicator);
143+
dynamicValues.push_back(it.get<Value>());
144+
}
145+
}
146+
return {b.getI64ArrayAttr(staticValues), dynamicValues};
147+
}
148+
112149
} // namespace mlir

mlir/lib/Interfaces/ViewLikeInterface.cpp

Lines changed: 6 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -182,72 +182,29 @@ bool mlir::detail::sameOffsetsSizesAndStrides(
182182
SmallVector<OpFoldResult, 4>
183183
mlir::getMixedOffsets(OffsetSizeAndStrideOpInterface op,
184184
ArrayAttr staticOffsets, ValueRange offsets) {
185-
SmallVector<OpFoldResult, 4> res;
186-
unsigned numDynamic = 0;
187-
unsigned count = static_cast<unsigned>(staticOffsets.size());
188-
for (unsigned idx = 0; idx < count; ++idx) {
189-
if (op.isDynamicOffset(idx))
190-
res.push_back(offsets[numDynamic++]);
191-
else
192-
res.push_back(staticOffsets[idx]);
193-
}
194-
return res;
185+
return getMixedValues(staticOffsets, offsets, op.getDynamicOffsetIndicator());
195186
}
196187

197188
SmallVector<OpFoldResult, 4>
198189
mlir::getMixedSizes(OffsetSizeAndStrideOpInterface op, ArrayAttr staticSizes,
199190
ValueRange sizes) {
200-
SmallVector<OpFoldResult, 4> res;
201-
unsigned numDynamic = 0;
202-
unsigned count = static_cast<unsigned>(staticSizes.size());
203-
for (unsigned idx = 0; idx < count; ++idx) {
204-
if (op.isDynamicSize(idx))
205-
res.push_back(sizes[numDynamic++]);
206-
else
207-
res.push_back(staticSizes[idx]);
208-
}
209-
return res;
191+
return getMixedValues(staticSizes, sizes, op.getDynamicSizeIndicator());
210192
}
211193

212194
SmallVector<OpFoldResult, 4>
213195
mlir::getMixedStrides(OffsetSizeAndStrideOpInterface op,
214196
ArrayAttr staticStrides, ValueRange strides) {
215-
SmallVector<OpFoldResult, 4> res;
216-
unsigned numDynamic = 0;
217-
unsigned count = static_cast<unsigned>(staticStrides.size());
218-
for (unsigned idx = 0; idx < count; ++idx) {
219-
if (op.isDynamicStride(idx))
220-
res.push_back(strides[numDynamic++]);
221-
else
222-
res.push_back(staticStrides[idx]);
223-
}
224-
return res;
225-
}
226-
227-
static std::pair<ArrayAttr, SmallVector<Value>>
228-
decomposeMixedImpl(OpBuilder &b,
229-
const SmallVectorImpl<OpFoldResult> &mixedValues,
230-
const int64_t dynamicValuePlaceholder) {
231-
SmallVector<int64_t> staticValues;
232-
SmallVector<Value> dynamicValues;
233-
for (const auto &it : mixedValues) {
234-
if (it.is<Attribute>()) {
235-
staticValues.push_back(it.get<Attribute>().cast<IntegerAttr>().getInt());
236-
} else {
237-
staticValues.push_back(ShapedType::kDynamicStrideOrOffset);
238-
dynamicValues.push_back(it.get<Value>());
239-
}
240-
}
241-
return {b.getI64ArrayAttr(staticValues), dynamicValues};
197+
return getMixedValues(staticStrides, strides, op.getDynamicStrideIndicator());
242198
}
243199

244200
std::pair<ArrayAttr, SmallVector<Value>> mlir::decomposeMixedStridesOrOffsets(
245201
OpBuilder &b, const SmallVectorImpl<OpFoldResult> &mixedValues) {
246-
return decomposeMixedImpl(b, mixedValues, ShapedType::kDynamicStrideOrOffset);
202+
return decomposeMixedValues(b, mixedValues,
203+
ShapedType::kDynamicStrideOrOffset);
247204
}
248205

249206
std::pair<ArrayAttr, SmallVector<Value>>
250207
mlir::decomposeMixedSizes(OpBuilder &b,
251208
const SmallVectorImpl<OpFoldResult> &mixedValues) {
252-
return decomposeMixedImpl(b, mixedValues, ShapedType::kDynamicSize);
209+
return decomposeMixedValues(b, mixedValues, ShapedType::kDynamicSize);
253210
}

0 commit comments

Comments
 (0)