Skip to content

Commit 6e5f9bb

Browse files
authored
[CIR] Allow using different Int types together in Vec Shift Op (#141111)
Update the verification of ShiftOp for vectors to allow shift operations between signed and unsigned integer vectors, similar to LLVM IR. Issue #136487
1 parent 4a44e00 commit 6e5f9bb

File tree

4 files changed

+153
-21
lines changed

4 files changed

+153
-21
lines changed

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
#include "mlir/Interfaces/ControlFlowInterfaces.h"
1919
#include "mlir/Interfaces/FunctionImplementation.h"
20-
#include "mlir/Support/LogicalResult.h"
2120

2221
#include "clang/CIR/Dialect/IR/CIROpsDialect.cpp.inc"
2322
#include "clang/CIR/Dialect/IR/CIROpsEnums.cpp.inc"
@@ -1427,15 +1426,32 @@ OpFoldResult cir::SelectOp::fold(FoldAdaptor adaptor) {
14271426
//===----------------------------------------------------------------------===//
14281427
LogicalResult cir::ShiftOp::verify() {
14291428
mlir::Operation *op = getOperation();
1430-
mlir::Type resType = getResult().getType();
1431-
const bool isOp0Vec = mlir::isa<cir::VectorType>(op->getOperand(0).getType());
1432-
const bool isOp1Vec = mlir::isa<cir::VectorType>(op->getOperand(1).getType());
1433-
if (isOp0Vec != isOp1Vec)
1429+
auto op0VecTy = mlir::dyn_cast<cir::VectorType>(op->getOperand(0).getType());
1430+
auto op1VecTy = mlir::dyn_cast<cir::VectorType>(op->getOperand(1).getType());
1431+
if (!op0VecTy ^ !op1VecTy)
14341432
return emitOpError() << "input types cannot be one vector and one scalar";
1435-
if (isOp1Vec && op->getOperand(1).getType() != resType) {
1436-
return emitOpError() << "shift amount must have the type of the result "
1437-
<< "if it is vector shift";
1433+
1434+
if (op0VecTy) {
1435+
if (op0VecTy.getSize() != op1VecTy.getSize())
1436+
return emitOpError() << "input vector types must have the same size";
1437+
1438+
auto opResultTy = mlir::dyn_cast<cir::VectorType>(getResult().getType());
1439+
if (!opResultTy)
1440+
return emitOpError() << "the type of the result must be a vector "
1441+
<< "if it is vector shift";
1442+
1443+
auto op0VecEleTy = mlir::cast<cir::IntType>(op0VecTy.getElementType());
1444+
auto op1VecEleTy = mlir::cast<cir::IntType>(op1VecTy.getElementType());
1445+
if (op0VecEleTy.getWidth() != op1VecEleTy.getWidth())
1446+
return emitOpError()
1447+
<< "vector operands do not have the same elements sizes";
1448+
1449+
auto resVecEleTy = mlir::cast<cir::IntType>(opResultTy.getElementType());
1450+
if (op0VecEleTy.getWidth() != resVecEleTy.getWidth())
1451+
return emitOpError() << "vector operands and result type do not have the "
1452+
"same elements sizes";
14381453
}
1454+
14391455
return mlir::success();
14401456
}
14411457

clang/test/CIR/CodeGen/vector-ext.cpp

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,61 @@ void foo9() {
461461
// OGCG: %[[SHR:.*]] = ashr <4 x i32> %[[TMP_A]], %[[TMP_B]]
462462
// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
463463

464+
void foo10() {
465+
vi4 a = {1, 2, 3, 4};
466+
uvi4 b = {5u, 6u, 7u, 8u};
467+
468+
vi4 shl = a << b;
469+
uvi4 shr = b >> a;
470+
}
471+
472+
// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
473+
// CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>, ["b", init]
474+
// CIR: %[[SHL_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shl", init]
475+
// CIR: %[[SHR_RES:.*]] = cir.alloca !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>, ["shr", init]
476+
// CIR: %[[VEC_A_VAL:.*]] = cir.vec.create(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
477+
// CIR: cir.store{{.*}} %[[VEC_A_VAL]], %[[VEC_A]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
478+
// CIR: %[[VEC_B_VAL:.*]] = cir.vec.create(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !u32i, !u32i, !u32i, !u32i) : !cir.vector<4 x !u32i>
479+
// CIR: cir.store{{.*}} %[[VEC_B_VAL]], %[[VEC_B]] : !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>
480+
// CIR: %[[TMP_A:.*]] = cir.load{{.*}} %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
481+
// CIR: %[[TMP_B:.*]] = cir.load{{.*}} %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !u32i>>, !cir.vector<4 x !u32i>
482+
// CIR: %[[SHL:.*]] = cir.shift(left, %[[TMP_A]] : !cir.vector<4 x !s32i>, %[[TMP_B]] : !cir.vector<4 x !u32i>) -> !cir.vector<4 x !s32i>
483+
// CIR: cir.store{{.*}} %[[SHL]], %[[SHL_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
484+
// CIR: %[[TMP_B:.*]] = cir.load{{.*}} %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !u32i>>, !cir.vector<4 x !u32i>
485+
// CIR: %[[TMP_A:.*]] = cir.load{{.*}} %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
486+
// CIR: %[[SHR:.*]] = cir.shift(right, %[[TMP_B]] : !cir.vector<4 x !u32i>, %[[TMP_A]] : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !u32i>
487+
// CIR: cir.store{{.*}} %[[SHR]], %[[SHR_RES]] : !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>
488+
489+
// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
490+
// LLVM: %[[VEC_B:.*]] = alloca <4 x i32>, i64 1, align 16
491+
// LLVM: %[[SHL_RES:.*]] = alloca <4 x i32>, i64 1, align 16
492+
// LLVM: %[[SHR_RES:.*]] = alloca <4 x i32>, i64 1, align 16
493+
// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
494+
// LLVM: store <4 x i32> <i32 5, i32 6, i32 7, i32 8>, ptr %[[VEC_B]], align 16
495+
// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
496+
// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
497+
// LLVM: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], %[[TMP_B]]
498+
// LLVM: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
499+
// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
500+
// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
501+
// LLVM: %[[SHR:.*]] = lshr <4 x i32> %[[TMP_B]], %[[TMP_A]]
502+
// LLVM: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
503+
504+
// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
505+
// OGCG: %[[VEC_B:.*]] = alloca <4 x i32>, align 16
506+
// OGCG: %[[SHL_RES:.*]] = alloca <4 x i32>, align 16
507+
// OGCG: %[[SHR_RES:.*]] = alloca <4 x i32>, align 16
508+
// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
509+
// OGCG: store <4 x i32> <i32 5, i32 6, i32 7, i32 8>, ptr %[[VEC_B]], align 16
510+
// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
511+
// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
512+
// OGCG: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], %[[TMP_B]]
513+
// OGCG: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
514+
// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
515+
// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
516+
// OGCG: %[[SHR:.*]] = lshr <4 x i32> %[[TMP_B]], %[[TMP_A]]
517+
// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
518+
464519
void foo11() {
465520
vi4 a = {1, 2, 3, 4};
466521
vi4 b = {5, 6, 7, 8};
@@ -933,4 +988,3 @@ void foo14() {
933988
// OGCG: %[[TMP_B:.*]] = load <4 x float>, ptr %[[VEC_B]], align 16
934989
// OGCG: %[[GE:.*]] = fcmp oge <4 x float> %[[TMP_A]], %[[TMP_B]]
935990
// OGCG: %[[RES:.*]] = sext <4 x i1> %[[GE]] to <4 x i32>
936-

clang/test/CIR/CodeGen/vector.cpp

Lines changed: 57 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -396,19 +396,9 @@ void foo9() {
396396
// CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["b", init]
397397
// CIR: %[[SHL_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shl", init]
398398
// CIR: %[[SHR_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shr", init]
399-
// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
400-
// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
401-
// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
402-
// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
403-
// CIR: %[[VEC_A_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
404-
// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
399+
// CIR: %[[VEC_A_VAL:.*]] = cir.vec.create(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
405400
// CIR: cir.store{{.*}} %[[VEC_A_VAL]], %[[VEC_A]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
406-
// CIR: %[[CONST_5:.*]] = cir.const #cir.int<5> : !s32i
407-
// CIR: %[[CONST_6:.*]] = cir.const #cir.int<6> : !s32i
408-
// CIR: %[[CONST_7:.*]] = cir.const #cir.int<7> : !s32i
409-
// CIR: %[[CONST_8:.*]] = cir.const #cir.int<8> : !s32i
410-
// CIR: %[[VEC_B_VAL:.*]] = cir.vec.create(%[[CONST_5]], %[[CONST_6]], %[[CONST_7]], %[[CONST_8]] :
411-
// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
401+
// CIR: %[[VEC_B_VAL:.*]] = cir.vec.create(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
412402
// CIR: cir.store{{.*}} %[[VEC_B_VAL]], %[[VEC_B]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
413403
// CIR: %[[TMP_A:.*]] = cir.load{{.*}} %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
414404
// CIR: %[[TMP_B:.*]] = cir.load{{.*}} %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
@@ -449,6 +439,61 @@ void foo9() {
449439
// OGCG: %[[SHR:.*]] = ashr <4 x i32> %[[TMP_A]], %[[TMP_B]]
450440
// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
451441

442+
void foo10() {
443+
vi4 a = {1, 2, 3, 4};
444+
uvi4 b = {5u, 6u, 7u, 8u};
445+
446+
vi4 shl = a << b;
447+
uvi4 shr = b >> a;
448+
}
449+
450+
// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
451+
// CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>, ["b", init]
452+
// CIR: %[[SHL_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shl", init]
453+
// CIR: %[[SHR_RES:.*]] = cir.alloca !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>, ["shr", init]
454+
// CIR: %[[VEC_A_VAL:.*]] = cir.vec.create(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
455+
// CIR: cir.store{{.*}} %[[VEC_A_VAL]], %[[VEC_A]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
456+
// CIR: %[[VEC_B_VAL:.*]] = cir.vec.create(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !u32i, !u32i, !u32i, !u32i) : !cir.vector<4 x !u32i>
457+
// CIR: cir.store{{.*}} %[[VEC_B_VAL]], %[[VEC_B]] : !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>
458+
// CIR: %[[TMP_A:.*]] = cir.load{{.*}} %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
459+
// CIR: %[[TMP_B:.*]] = cir.load{{.*}} %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !u32i>>, !cir.vector<4 x !u32i>
460+
// CIR: %[[SHL:.*]] = cir.shift(left, %[[TMP_A]] : !cir.vector<4 x !s32i>, %[[TMP_B]] : !cir.vector<4 x !u32i>) -> !cir.vector<4 x !s32i>
461+
// CIR: cir.store{{.*}} %[[SHL]], %[[SHL_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
462+
// CIR: %[[TMP_B:.*]] = cir.load{{.*}} %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !u32i>>, !cir.vector<4 x !u32i>
463+
// CIR: %[[TMP_A:.*]] = cir.load{{.*}} %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
464+
// CIR: %[[SHR:.*]] = cir.shift(right, %[[TMP_B]] : !cir.vector<4 x !u32i>, %[[TMP_A]] : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !u32i>
465+
// CIR: cir.store{{.*}} %[[SHR]], %[[SHR_RES]] : !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>
466+
467+
// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
468+
// LLVM: %[[VEC_B:.*]] = alloca <4 x i32>, i64 1, align 16
469+
// LLVM: %[[SHL_RES:.*]] = alloca <4 x i32>, i64 1, align 16
470+
// LLVM: %[[SHR_RES:.*]] = alloca <4 x i32>, i64 1, align 16
471+
// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
472+
// LLVM: store <4 x i32> <i32 5, i32 6, i32 7, i32 8>, ptr %[[VEC_B]], align 16
473+
// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
474+
// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
475+
// LLVM: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], %[[TMP_B]]
476+
// LLVM: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
477+
// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
478+
// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
479+
// LLVM: %[[SHR:.*]] = lshr <4 x i32> %[[TMP_B]], %[[TMP_A]]
480+
// LLVM: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
481+
482+
// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
483+
// OGCG: %[[VEC_B:.*]] = alloca <4 x i32>, align 16
484+
// OGCG: %[[SHL_RES:.*]] = alloca <4 x i32>, align 16
485+
// OGCG: %[[SHR_RES:.*]] = alloca <4 x i32>, align 16
486+
// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
487+
// OGCG: store <4 x i32> <i32 5, i32 6, i32 7, i32 8>, ptr %[[VEC_B]], align 16
488+
// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
489+
// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
490+
// OGCG: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], %[[TMP_B]]
491+
// OGCG: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
492+
// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
493+
// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
494+
// OGCG: %[[SHR:.*]] = lshr <4 x i32> %[[TMP_B]], %[[TMP_A]]
495+
// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
496+
452497
void foo11() {
453498
vi4 a = {1, 2, 3, 4};
454499
vi4 b = {5, 6, 7, 8};
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// RUN: cir-opt %s -verify-diagnostics -split-input-file
2+
3+
!s32i = !cir.int<s, 32>
4+
5+
module {
6+
cir.func @foo() {
7+
%1 = cir.const #cir.int<1> : !s32i
8+
%2 = cir.const #cir.int<2> : !s32i
9+
%3 = cir.const #cir.int<3> : !s32i
10+
%4 = cir.const #cir.int<4> : !s32i
11+
%5 = cir.vec.create(%1, %2, %3, %4 : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
12+
%6 = cir.vec.create(%1, %2, %3, %4 : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
13+
// expected-error @below {{the type of the result must be a vector if it is vector shift}}
14+
%7 = cir.shift(left, %5 : !cir.vector<4 x !s32i>, %6 : !cir.vector<4 x !s32i>) -> !s32i
15+
cir.return
16+
}
17+
}

0 commit comments

Comments
 (0)