Skip to content

Commit a6c4ca8

Browse files
authored
[CIR] Upstream insert op for VectorType (#139146)
This change adds an insert op for VectorType Issue #136487
1 parent 377a047 commit a6c4ca8

File tree

8 files changed

+398
-10
lines changed

8 files changed

+398
-10
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1969,6 +1969,42 @@ def VecCreateOp : CIR_Op<"vec.create", [Pure]> {
19691969
let hasVerifier = 1;
19701970
}
19711971

1972+
//===----------------------------------------------------------------------===//
1973+
// VecInsertOp
1974+
//===----------------------------------------------------------------------===//
1975+
1976+
def VecInsertOp : CIR_Op<"vec.insert", [Pure,
1977+
TypesMatchWith<"argument type matches vector element type", "vec", "value",
1978+
"cast<VectorType>($_self).getElementType()">,
1979+
AllTypesMatch<["result", "vec"]>]> {
1980+
1981+
let summary = "Insert one element into a vector object";
1982+
let description = [{
1983+
The `cir.vec.insert` operation produces a new vector by replacing
1984+
the element of the input vector at `index` with `value`.
1985+
1986+
```mlir
1987+
%value = cir.const #cir.int<5> : !s32i
1988+
%index = cir.const #cir.int<2> : !s32i
1989+
%vec_tmp = cir.load %0 : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
1990+
%new_vec = cir.vec.insert %value, %vec_tmp[%index : !s32i] : !cir.vector<4 x !s32i>
1991+
```
1992+
}];
1993+
1994+
let arguments = (ins
1995+
CIR_VectorType:$vec,
1996+
AnyType:$value,
1997+
CIR_AnyFundamentalIntType:$index
1998+
);
1999+
2000+
let results = (outs CIR_VectorType:$result);
2001+
2002+
let assemblyFormat = [{
2003+
$value `,` $vec `[` $index `:` type($index) `]` attr-dict `:`
2004+
qualified(type($vec))
2005+
}];
2006+
}
2007+
19722008
//===----------------------------------------------------------------------===//
19732009
// VecExtractOp
19742010
//===----------------------------------------------------------------------===//

clang/lib/CIR/CodeGen/CIRGenExpr.cpp

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,17 @@ Address CIRGenFunction::emitPointerWithAlignment(const Expr *expr,
205205
void CIRGenFunction::emitStoreThroughLValue(RValue src, LValue dst,
206206
bool isInit) {
207207
if (!dst.isSimple()) {
208+
if (dst.isVectorElt()) {
209+
// Read/modify/write the vector, inserting the new element
210+
const mlir::Location loc = dst.getVectorPointer().getLoc();
211+
const mlir::Value vector =
212+
builder.createLoad(loc, dst.getVectorAddress().getPointer());
213+
const mlir::Value newVector = builder.create<cir::VecInsertOp>(
214+
loc, vector, src.getScalarVal(), dst.getVectorIdx());
215+
builder.createStore(loc, newVector, dst.getVectorAddress().getPointer());
216+
return;
217+
}
218+
208219
cgm.errorNYI(dst.getPointer().getLoc(),
209220
"emitStoreThroughLValue: non-simple lvalue");
210221
return;
@@ -418,6 +429,13 @@ RValue CIRGenFunction::emitLoadOfLValue(LValue lv, SourceLocation loc) {
418429
if (lv.isSimple())
419430
return RValue::get(emitLoadOfScalar(lv, loc));
420431

432+
if (lv.isVectorElt()) {
433+
const mlir::Value load =
434+
builder.createLoad(getLoc(loc), lv.getVectorAddress().getPointer());
435+
return RValue::get(builder.create<cir::VecExtractOp>(getLoc(loc), load,
436+
lv.getVectorIdx()));
437+
}
438+
421439
cgm.errorNYI(loc, "emitLoadOfLValue");
422440
return RValue::get(nullptr);
423441
}
@@ -638,12 +656,6 @@ static Address emitArraySubscriptPtr(CIRGenFunction &cgf,
638656

639657
LValue
640658
CIRGenFunction::emitArraySubscriptExpr(const clang::ArraySubscriptExpr *e) {
641-
if (e->getBase()->getType()->isVectorType() &&
642-
!isa<ExtVectorElementExpr>(e->getBase())) {
643-
cgm.errorNYI(e->getSourceRange(), "emitArraySubscriptExpr: VectorType");
644-
return LValue::makeAddr(Address::invalid(), e->getType(), LValueBaseInfo());
645-
}
646-
647659
if (isa<ExtVectorElementExpr>(e->getBase())) {
648660
cgm.errorNYI(e->getSourceRange(),
649661
"emitArraySubscriptExpr: ExtVectorElementExpr");
@@ -666,18 +678,28 @@ CIRGenFunction::emitArraySubscriptExpr(const clang::ArraySubscriptExpr *e) {
666678
assert((e->getIdx() == e->getLHS() || e->getIdx() == e->getRHS()) &&
667679
"index was neither LHS nor RHS");
668680

669-
auto emitIdxAfterBase = [&]() -> mlir::Value {
681+
auto emitIdxAfterBase = [&](bool promote) -> mlir::Value {
670682
const mlir::Value idx = emitScalarExpr(e->getIdx());
671683

672684
// Extend or truncate the index type to 32 or 64-bits.
673685
auto ptrTy = mlir::dyn_cast<cir::PointerType>(idx.getType());
674-
if (ptrTy && mlir::isa<cir::IntType>(ptrTy.getPointee()))
686+
if (promote && ptrTy && ptrTy.isPtrTo<cir::IntType>())
675687
cgm.errorNYI(e->getSourceRange(),
676688
"emitArraySubscriptExpr: index type cast");
677689
return idx;
678690
};
679691

680-
const mlir::Value idx = emitIdxAfterBase();
692+
// If the base is a vector type, then we are forming a vector element
693+
// with this subscript.
694+
if (e->getBase()->getType()->isVectorType() &&
695+
!isa<ExtVectorElementExpr>(e->getBase())) {
696+
const mlir::Value idx = emitIdxAfterBase(/*promote=*/false);
697+
const LValue lhs = emitLValue(e->getBase());
698+
return LValue::makeVectorElt(lhs.getAddress(), idx, e->getBase()->getType(),
699+
lhs.getBaseInfo());
700+
}
701+
702+
const mlir::Value idx = emitIdxAfterBase(/*promote=*/true);
681703
if (const Expr *array = getSimpleArrayDecayOperand(e->getBase())) {
682704
LValue arrayLV;
683705
if (const auto *ase = dyn_cast<ArraySubscriptExpr>(array))

clang/lib/CIR/CodeGen/CIRGenValue.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ class LValue {
116116
// this is the alignment of the whole vector)
117117
unsigned alignment;
118118
mlir::Value v;
119+
mlir::Value vectorIdx; // Index for vector subscript
119120
mlir::Type elementType;
120121
LValueBaseInfo baseInfo;
121122

@@ -136,6 +137,7 @@ class LValue {
136137

137138
public:
138139
bool isSimple() const { return lvType == Simple; }
140+
bool isVectorElt() const { return lvType == VectorElt; }
139141
bool isBitField() const { return lvType == BitField; }
140142

141143
// TODO: Add support for volatile
@@ -176,6 +178,31 @@ class LValue {
176178
r.initialize(t, t.getQualifiers(), address.getAlignment(), baseInfo);
177179
return r;
178180
}
181+
182+
Address getVectorAddress() const {
183+
return Address(getVectorPointer(), elementType, getAlignment());
184+
}
185+
186+
mlir::Value getVectorPointer() const {
187+
assert(isVectorElt());
188+
return v;
189+
}
190+
191+
mlir::Value getVectorIdx() const {
192+
assert(isVectorElt());
193+
return vectorIdx;
194+
}
195+
196+
static LValue makeVectorElt(Address vecAddress, mlir::Value index,
197+
clang::QualType t, LValueBaseInfo baseInfo) {
198+
LValue r;
199+
r.lvType = VectorElt;
200+
r.v = vecAddress.getPointer();
201+
r.elementType = vecAddress.getElementType();
202+
r.vectorIdx = index;
203+
r.initialize(t, t.getQualifiers(), vecAddress.getAlignment(), baseInfo);
204+
return r;
205+
}
179206
};
180207

181208
/// An aggregate value slot.

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1646,7 +1646,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
16461646
CIRToLLVMTrapOpLowering,
16471647
CIRToLLVMUnaryOpLowering,
16481648
CIRToLLVMVecCreateOpLowering,
1649-
CIRToLLVMVecExtractOpLowering
1649+
CIRToLLVMVecExtractOpLowering,
1650+
CIRToLLVMVecInsertOpLowering
16501651
// clang-format on
16511652
>(converter, patterns.getContext());
16521653

@@ -1763,6 +1764,14 @@ mlir::LogicalResult CIRToLLVMVecExtractOpLowering::matchAndRewrite(
17631764
return mlir::success();
17641765
}
17651766

1767+
mlir::LogicalResult CIRToLLVMVecInsertOpLowering::matchAndRewrite(
1768+
cir::VecInsertOp op, OpAdaptor adaptor,
1769+
mlir::ConversionPatternRewriter &rewriter) const {
1770+
rewriter.replaceOpWithNewOp<mlir::LLVM::InsertElementOp>(
1771+
op, adaptor.getVec(), adaptor.getValue(), adaptor.getIndex());
1772+
return mlir::success();
1773+
}
1774+
17661775
std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
17671776
return std::make_unique<ConvertCIRToLLVMPass>();
17681777
}

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,16 @@ class CIRToLLVMVecExtractOpLowering
322322
mlir::ConversionPatternRewriter &) const override;
323323
};
324324

325+
class CIRToLLVMVecInsertOpLowering
326+
: public mlir::OpConversionPattern<cir::VecInsertOp> {
327+
public:
328+
using mlir::OpConversionPattern<cir::VecInsertOp>::OpConversionPattern;
329+
330+
mlir::LogicalResult
331+
matchAndRewrite(cir::VecInsertOp op, OpAdaptor,
332+
mlir::ConversionPatternRewriter &) const override;
333+
};
334+
325335
} // namespace direct
326336
} // namespace cir
327337

clang/test/CIR/CodeGen/vector-ext.cpp

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,3 +213,126 @@ void foo4() {
213213
// OGCG: %[[TMP2:.*]] = load i32, ptr %[[IDX]], align 4
214214
// OGCG: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP1]], i32 %[[TMP2]]
215215
// OGCG: store i32 %[[ELE]], ptr %[[INIT]], align 4
216+
217+
void foo5() {
218+
vi4 a = { 1, 2, 3, 4 };
219+
220+
a[2] = 5;
221+
}
222+
223+
// CIR: %[[VEC:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
224+
// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
225+
// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
226+
// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
227+
// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
228+
// CIR: %[[VEC_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
229+
// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
230+
// CIR: cir.store %[[VEC_VAL]], %[[VEC]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
231+
// CIR: %[[CONST_VAL:.*]] = cir.const #cir.int<5> : !s32i
232+
// CIR: %[[CONST_IDX:.*]] = cir.const #cir.int<2> : !s32i
233+
// CIR: %[[TMP:.*]] = cir.load %[[VEC]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
234+
// CIR: %[[NEW_VEC:.*]] = cir.vec.insert %[[CONST_VAL]], %[[TMP]][%[[CONST_IDX]] : !s32i] : !cir.vector<4 x !s32i>
235+
// CIR: cir.store %[[NEW_VEC]], %[[VEC]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
236+
237+
// LLVM: %[[VEC:.*]] = alloca <4 x i32>, i64 1, align 16
238+
// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
239+
// LLVM: %[[TMP:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
240+
// LLVM: %[[NEW_VEC:.*]] = insertelement <4 x i32> %[[TMP]], i32 5, i32 2
241+
// LLVM: store <4 x i32> %[[NEW_VEC]], ptr %[[VEC]], align 16
242+
243+
// OGCG: %[[VEC:.*]] = alloca <4 x i32>, align 16
244+
// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
245+
// OGCG: %[[TMP:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
246+
// OGCG: %[[NEW_VEC:.*]] = insertelement <4 x i32> %[[TMP]], i32 5, i32 2
247+
// OGCG: store <4 x i32> %[[NEW_VEC]], ptr %[[VEC]], align 16
248+
249+
void foo6() {
250+
vi4 a = { 1, 2, 3, 4 };
251+
int idx = 2;
252+
int value = 5;
253+
a[idx] = value;
254+
}
255+
256+
// CIR: %[[VEC:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
257+
// CIR: %[[IDX:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["idx", init]
258+
// CIR: %[[VAL:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["value", init]
259+
// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
260+
// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
261+
// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
262+
// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
263+
// CIR: %[[VEC_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
264+
// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
265+
// CIR: cir.store %[[VEC_VAL]], %[[VEC]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
266+
// CIR: %[[CONST_IDX:.*]] = cir.const #cir.int<2> : !s32i
267+
// CIR: cir.store %[[CONST_IDX]], %[[IDX]] : !s32i, !cir.ptr<!s32i>
268+
// CIR: %[[CONST_VAL:.*]] = cir.const #cir.int<5> : !s32i
269+
// CIR: cir.store %[[CONST_VAL]], %[[VAL]] : !s32i, !cir.ptr<!s32i>
270+
// CIR: %[[TMP1:.*]] = cir.load %[[VAL]] : !cir.ptr<!s32i>, !s32i
271+
// CIR: %[[TMP2:.*]] = cir.load %[[IDX]] : !cir.ptr<!s32i>, !s32i
272+
// CIR: %[[TMP3:.*]] = cir.load %0 : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
273+
// CIR: %[[NEW_VEC:.*]] = cir.vec.insert %[[TMP1]], %[[TMP3]][%[[TMP2]] : !s32i] : !cir.vector<4 x !s32i>
274+
// CIR: cir.store %[[NEW_VEC]], %[[VEC]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
275+
276+
// LLVM: %[[VEC:.*]] = alloca <4 x i32>, i64 1, align 16
277+
// LLVM: %[[IDX:.*]] = alloca i32, i64 1, align 4
278+
// LLVM: %[[VAL:.*]] = alloca i32, i64 1, align 4
279+
// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %1, align 16
280+
// LLVM: store i32 2, ptr %[[IDX]], align 4
281+
// LLVM: store i32 5, ptr %[[VAL]], align 4
282+
// LLVM: %[[TMP1:.*]] = load i32, ptr %[[VAL]], align 4
283+
// LLVM: %[[TMP2:.*]] = load i32, ptr %[[IDX]], align 4
284+
// LLVM: %[[TMP3:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
285+
// LLVM: %[[NEW_VEC:.*]] = insertelement <4 x i32> %[[TMP3]], i32 %[[TMP1]], i32 %[[TMP2]]
286+
// LLVM: store <4 x i32> %[[NEW_VEC]], ptr %[[VEC]], align 16
287+
288+
// OGCG: %[[VEC:.*]] = alloca <4 x i32>, align 16
289+
// OGCG: %[[IDX:.*]] = alloca i32, align 4
290+
// OGCG: %[[VAL:.*]] = alloca i32, align 4
291+
// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
292+
// OGCG: store i32 2, ptr %[[IDX]], align 4
293+
// OGCG: store i32 5, ptr %[[VAL]], align 4
294+
// OGCG: %[[TMP1:.*]] = load i32, ptr %[[VAL]], align 4
295+
// OGCG: %[[TMP2:.*]] = load i32, ptr %[[IDX]], align 4
296+
// OGCG: %[[TMP3:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
297+
// OGCG: %[[NEW_VEC:.*]] = insertelement <4 x i32> %[[TMP3]], i32 %[[TMP1]], i32 %[[TMP2]]
298+
// OGCG: store <4 x i32> %[[NEW_VEC]], ptr %[[VEC]], align 16
299+
300+
void foo7() {
301+
vi4 a = {1, 2, 3, 4};
302+
a[2] += 5;
303+
}
304+
305+
// CIR: %[[VEC:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
306+
// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
307+
// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
308+
// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
309+
// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
310+
// CIR: %[[VEC_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
311+
// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
312+
// CIR: cir.store %[[VEC_VAL]], %[[VEC]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
313+
// CIR: %[[CONST_VAL:.*]] = cir.const #cir.int<5> : !s32i
314+
// CIR: %[[CONST_IDX:.*]] = cir.const #cir.int<2> : !s32i
315+
// CIR: %[[TMP:.*]] = cir.load %[[VEC]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
316+
// CIR: %[[ELE:.*]] = cir.vec.extract %[[TMP]][%[[CONST_IDX]] : !s32i] : !cir.vector<4 x !s32i>
317+
// CIR: %[[RES:.*]] = cir.binop(add, %[[ELE]], %[[CONST_VAL]]) nsw : !s32i
318+
// CIR: %[[TMP2:.*]] = cir.load %[[VEC]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
319+
// CIR: %[[NEW_VEC:.*]] = cir.vec.insert %[[RES]], %[[TMP2]][%[[CONST_IDX]] : !s32i] : !cir.vector<4 x !s32i>
320+
// CIR: cir.store %[[NEW_VEC]], %[[VEC]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
321+
322+
// LLVM: %[[VEC:.*]] = alloca <4 x i32>, i64 1, align 16
323+
// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
324+
// LLVM: %[[TMP:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
325+
// LLVM: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP]], i32 2
326+
// LLVM: %[[RES:.*]] = add nsw i32 %[[ELE]], 5
327+
// LLVM: %[[TMP2:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
328+
// LLVM: %[[NEW_VEC:.*]] = insertelement <4 x i32> %[[TMP2]], i32 %[[RES]], i32 2
329+
// LLVM: store <4 x i32> %[[NEW_VEC]], ptr %[[VEC]], align 16
330+
331+
// OGCG: %[[VEC:.*]] = alloca <4 x i32>, align 16
332+
// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
333+
// OGCG: %[[TMP:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
334+
// OGCG: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP]], i32 2
335+
// OGCG: %[[RES:.*]] = add nsw i32 %[[ELE]], 5
336+
// OGCG: %[[TMP2:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
337+
// OGCG: %[[NEW_VEC:.*]] = insertelement <4 x i32> %[[TMP2]], i32 %[[RES]], i32 2
338+
// OGCG: store <4 x i32> %[[NEW_VEC]], ptr %[[VEC]], align 16

0 commit comments

Comments
 (0)