Skip to content

Commit 6ec169d

Browse files
authored
[SandboxIR] Implement BinaryOperator (#104121)
This patch implements sandboxir::BinaryOperator mirroring llvm::BinaryOperator.
1 parent 2adc012 commit 6ec169d

File tree

4 files changed

+435
-24
lines changed

4 files changed

+435
-24
lines changed

llvm/include/llvm/SandboxIR/SandboxIR.h

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ class CastInst;
130130
class PtrToIntInst;
131131
class BitCastInst;
132132
class AllocaInst;
133+
class BinaryOperator;
133134
class AtomicCmpXchgInst;
134135

135136
/// Iterator for the `Use` edges of a User's operands.
@@ -249,6 +250,7 @@ class Value {
249250
friend class InvokeInst; // For getting `Val`.
250251
friend class CallBrInst; // For getting `Val`.
251252
friend class GetElementPtrInst; // For getting `Val`.
253+
friend class BinaryOperator; // For getting `Val`.
252254
friend class AtomicCmpXchgInst; // For getting `Val`.
253255
friend class AllocaInst; // For getting `Val`.
254256
friend class CastInst; // For getting `Val`.
@@ -630,6 +632,7 @@ class Instruction : public sandboxir::User {
630632
friend class InvokeInst; // For getTopmostLLVMInstruction().
631633
friend class CallBrInst; // For getTopmostLLVMInstruction().
632634
friend class GetElementPtrInst; // For getTopmostLLVMInstruction().
635+
friend class BinaryOperator; // For getTopmostLLVMInstruction().
633636
friend class AtomicCmpXchgInst; // For getTopmostLLVMInstruction().
634637
friend class AllocaInst; // For getTopmostLLVMInstruction().
635638
friend class CastInst; // For getTopmostLLVMInstruction().
@@ -1432,6 +1435,86 @@ class GetElementPtrInst final
14321435
// TODO: Add missing member functions.
14331436
};
14341437

1438+
class BinaryOperator : public SingleLLVMInstructionImpl<llvm::BinaryOperator> {
1439+
static Opcode getBinOpOpcode(llvm::Instruction::BinaryOps BinOp) {
1440+
switch (BinOp) {
1441+
case llvm::Instruction::Add:
1442+
return Opcode::Add;
1443+
case llvm::Instruction::FAdd:
1444+
return Opcode::FAdd;
1445+
case llvm::Instruction::Sub:
1446+
return Opcode::Sub;
1447+
case llvm::Instruction::FSub:
1448+
return Opcode::FSub;
1449+
case llvm::Instruction::Mul:
1450+
return Opcode::Mul;
1451+
case llvm::Instruction::FMul:
1452+
return Opcode::FMul;
1453+
case llvm::Instruction::UDiv:
1454+
return Opcode::UDiv;
1455+
case llvm::Instruction::SDiv:
1456+
return Opcode::SDiv;
1457+
case llvm::Instruction::FDiv:
1458+
return Opcode::FDiv;
1459+
case llvm::Instruction::URem:
1460+
return Opcode::URem;
1461+
case llvm::Instruction::SRem:
1462+
return Opcode::SRem;
1463+
case llvm::Instruction::FRem:
1464+
return Opcode::FRem;
1465+
case llvm::Instruction::Shl:
1466+
return Opcode::Shl;
1467+
case llvm::Instruction::LShr:
1468+
return Opcode::LShr;
1469+
case llvm::Instruction::AShr:
1470+
return Opcode::AShr;
1471+
case llvm::Instruction::And:
1472+
return Opcode::And;
1473+
case llvm::Instruction::Or:
1474+
return Opcode::Or;
1475+
case llvm::Instruction::Xor:
1476+
return Opcode::Xor;
1477+
case llvm::Instruction::BinaryOpsEnd:
1478+
llvm_unreachable("Bad BinOp!");
1479+
}
1480+
llvm_unreachable("Unhandled BinOp!");
1481+
}
1482+
BinaryOperator(llvm::BinaryOperator *BinOp, Context &Ctx)
1483+
: SingleLLVMInstructionImpl(ClassID::BinaryOperator,
1484+
getBinOpOpcode(BinOp->getOpcode()), BinOp,
1485+
Ctx) {}
1486+
friend class Context; // For constructor.
1487+
1488+
public:
1489+
static Value *create(Instruction::Opcode Op, Value *LHS, Value *RHS,
1490+
BBIterator WhereIt, BasicBlock *WhereBB, Context &Ctx,
1491+
const Twine &Name = "");
1492+
static Value *create(Instruction::Opcode Op, Value *LHS, Value *RHS,
1493+
Instruction *InsertBefore, Context &Ctx,
1494+
const Twine &Name = "");
1495+
static Value *create(Instruction::Opcode Op, Value *LHS, Value *RHS,
1496+
BasicBlock *InsertAtEnd, Context &Ctx,
1497+
const Twine &Name = "");
1498+
1499+
static Value *createWithCopiedFlags(Instruction::Opcode Op, Value *LHS,
1500+
Value *RHS, Value *CopyFrom,
1501+
BBIterator WhereIt, BasicBlock *WhereBB,
1502+
Context &Ctx, const Twine &Name = "");
1503+
static Value *createWithCopiedFlags(Instruction::Opcode Op, Value *LHS,
1504+
Value *RHS, Value *CopyFrom,
1505+
Instruction *InsertBefore, Context &Ctx,
1506+
const Twine &Name = "");
1507+
static Value *createWithCopiedFlags(Instruction::Opcode Op, Value *LHS,
1508+
Value *RHS, Value *CopyFrom,
1509+
BasicBlock *InsertAtEnd, Context &Ctx,
1510+
const Twine &Name = "");
1511+
/// For isa/dyn_cast.
1512+
static bool classof(const Value *From) {
1513+
return From->getSubclassID() == ClassID::BinaryOperator;
1514+
}
1515+
void swapOperands() { swapOperandsInternal(0, 1); }
1516+
};
1517+
14351518
class AtomicCmpXchgInst
14361519
: public SingleLLVMInstructionImpl<llvm::AtomicCmpXchgInst> {
14371520
AtomicCmpXchgInst(llvm::AtomicCmpXchgInst *Atomic, Context &Ctx)
@@ -1876,6 +1959,8 @@ class Context {
18761959
friend CallBrInst; // For createCallBrInst()
18771960
GetElementPtrInst *createGetElementPtrInst(llvm::GetElementPtrInst *I);
18781961
friend GetElementPtrInst; // For createGetElementPtrInst()
1962+
BinaryOperator *createBinaryOperator(llvm::BinaryOperator *I);
1963+
friend BinaryOperator; // For createBinaryOperator()
18791964
AtomicCmpXchgInst *createAtomicCmpXchgInst(llvm::AtomicCmpXchgInst *I);
18801965
friend AtomicCmpXchgInst; // For createAtomicCmpXchgInst()
18811966
AllocaInst *createAllocaInst(llvm::AllocaInst *I);

llvm/include/llvm/SandboxIR/SandboxIRValues.def

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,19 +32,39 @@ DEF_USER(Constant, Constant)
3232
#define OPCODES(...)
3333
#endif
3434
// clang-format off
35-
// ClassID, Opcode(s), Class
36-
DEF_INSTR(Opaque, OP(Opaque), OpaqueInst)
35+
// ClassID, Opcode(s), Class
36+
DEF_INSTR(Opaque, OP(Opaque), OpaqueInst)
3737
DEF_INSTR(ExtractElement, OP(ExtractElement), ExtractElementInst)
38-
DEF_INSTR(InsertElement, OP(InsertElement), InsertElementInst)
39-
DEF_INSTR(Select, OP(Select), SelectInst)
40-
DEF_INSTR(Br, OP(Br), BranchInst)
41-
DEF_INSTR(Load, OP(Load), LoadInst)
42-
DEF_INSTR(Store, OP(Store), StoreInst)
43-
DEF_INSTR(Ret, OP(Ret), ReturnInst)
44-
DEF_INSTR(Call, OP(Call), CallInst)
45-
DEF_INSTR(Invoke, OP(Invoke), InvokeInst)
46-
DEF_INSTR(CallBr, OP(CallBr), CallBrInst)
47-
DEF_INSTR(GetElementPtr, OP(GetElementPtr), GetElementPtrInst)
38+
DEF_INSTR(InsertElement, OP(InsertElement), InsertElementInst)
39+
DEF_INSTR(Select, OP(Select), SelectInst)
40+
DEF_INSTR(Br, OP(Br), BranchInst)
41+
DEF_INSTR(Load, OP(Load), LoadInst)
42+
DEF_INSTR(Store, OP(Store), StoreInst)
43+
DEF_INSTR(Ret, OP(Ret), ReturnInst)
44+
DEF_INSTR(Call, OP(Call), CallInst)
45+
DEF_INSTR(Invoke, OP(Invoke), InvokeInst)
46+
DEF_INSTR(CallBr, OP(CallBr), CallBrInst)
47+
DEF_INSTR(GetElementPtr, OP(GetElementPtr), GetElementPtrInst)
48+
DEF_INSTR(BinaryOperator, OPCODES( \
49+
OP(Add) \
50+
OP(FAdd) \
51+
OP(Sub) \
52+
OP(FSub) \
53+
OP(Mul) \
54+
OP(FMul) \
55+
OP(UDiv) \
56+
OP(SDiv) \
57+
OP(FDiv) \
58+
OP(URem) \
59+
OP(SRem) \
60+
OP(FRem) \
61+
OP(Shl) \
62+
OP(LShr) \
63+
OP(AShr) \
64+
OP(And) \
65+
OP(Or) \
66+
OP(Xor) \
67+
), BinaryOperator)
4868
DEF_INSTR(AtomicCmpXchg, OP(AtomicCmpXchg), AtomicCmpXchgInst)
4969
DEF_INSTR(Alloca, OP(Alloca), AllocaInst)
5070
DEF_INSTR(Cast, OPCODES(\

llvm/lib/SandboxIR/SandboxIR.cpp

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1219,6 +1219,107 @@ static llvm::Instruction::CastOps getLLVMCastOp(Instruction::Opcode Opc) {
12191219
}
12201220
}
12211221

1222+
/// \Returns the LLVM opcode that corresponds to \p Opc.
1223+
static llvm::Instruction::BinaryOps getLLVMBinaryOp(Instruction::Opcode Opc) {
1224+
switch (Opc) {
1225+
case Instruction::Opcode::Add:
1226+
return static_cast<llvm::Instruction::BinaryOps>(llvm::Instruction::Add);
1227+
case Instruction::Opcode::FAdd:
1228+
return static_cast<llvm::Instruction::BinaryOps>(llvm::Instruction::FAdd);
1229+
case Instruction::Opcode::Sub:
1230+
return static_cast<llvm::Instruction::BinaryOps>(llvm::Instruction::Sub);
1231+
case Instruction::Opcode::FSub:
1232+
return static_cast<llvm::Instruction::BinaryOps>(llvm::Instruction::FSub);
1233+
case Instruction::Opcode::Mul:
1234+
return static_cast<llvm::Instruction::BinaryOps>(llvm::Instruction::Mul);
1235+
case Instruction::Opcode::FMul:
1236+
return static_cast<llvm::Instruction::BinaryOps>(llvm::Instruction::FMul);
1237+
case Instruction::Opcode::UDiv:
1238+
return static_cast<llvm::Instruction::BinaryOps>(llvm::Instruction::UDiv);
1239+
case Instruction::Opcode::SDiv:
1240+
return static_cast<llvm::Instruction::BinaryOps>(llvm::Instruction::SDiv);
1241+
case Instruction::Opcode::FDiv:
1242+
return static_cast<llvm::Instruction::BinaryOps>(llvm::Instruction::FDiv);
1243+
case Instruction::Opcode::URem:
1244+
return static_cast<llvm::Instruction::BinaryOps>(llvm::Instruction::URem);
1245+
case Instruction::Opcode::SRem:
1246+
return static_cast<llvm::Instruction::BinaryOps>(llvm::Instruction::SRem);
1247+
case Instruction::Opcode::FRem:
1248+
return static_cast<llvm::Instruction::BinaryOps>(llvm::Instruction::FRem);
1249+
case Instruction::Opcode::Shl:
1250+
return static_cast<llvm::Instruction::BinaryOps>(llvm::Instruction::Shl);
1251+
case Instruction::Opcode::LShr:
1252+
return static_cast<llvm::Instruction::BinaryOps>(llvm::Instruction::LShr);
1253+
case Instruction::Opcode::AShr:
1254+
return static_cast<llvm::Instruction::BinaryOps>(llvm::Instruction::AShr);
1255+
case Instruction::Opcode::And:
1256+
return static_cast<llvm::Instruction::BinaryOps>(llvm::Instruction::And);
1257+
case Instruction::Opcode::Or:
1258+
return static_cast<llvm::Instruction::BinaryOps>(llvm::Instruction::Or);
1259+
case Instruction::Opcode::Xor:
1260+
return static_cast<llvm::Instruction::BinaryOps>(llvm::Instruction::Xor);
1261+
default:
1262+
llvm_unreachable("Not a binary op!");
1263+
}
1264+
}
1265+
Value *BinaryOperator::create(Instruction::Opcode Op, Value *LHS, Value *RHS,
1266+
BBIterator WhereIt, BasicBlock *WhereBB,
1267+
Context &Ctx, const Twine &Name) {
1268+
auto &Builder = Ctx.getLLVMIRBuilder();
1269+
if (WhereIt == WhereBB->end())
1270+
Builder.SetInsertPoint(cast<llvm::BasicBlock>(WhereBB->Val));
1271+
else
1272+
Builder.SetInsertPoint((*WhereIt).getTopmostLLVMInstruction());
1273+
llvm::Value *NewV =
1274+
Builder.CreateBinOp(getLLVMBinaryOp(Op), LHS->Val, RHS->Val, Name);
1275+
if (auto *NewBinOp = dyn_cast<llvm::BinaryOperator>(NewV))
1276+
return Ctx.createBinaryOperator(NewBinOp);
1277+
assert(isa<llvm::Constant>(NewV) && "Expected constant");
1278+
return Ctx.getOrCreateConstant(cast<llvm::Constant>(NewV));
1279+
}
1280+
1281+
Value *BinaryOperator::create(Instruction::Opcode Op, Value *LHS, Value *RHS,
1282+
Instruction *InsertBefore, Context &Ctx,
1283+
const Twine &Name) {
1284+
return create(Op, LHS, RHS, InsertBefore->getIterator(),
1285+
InsertBefore->getParent(), Ctx, Name);
1286+
}
1287+
1288+
Value *BinaryOperator::create(Instruction::Opcode Op, Value *LHS, Value *RHS,
1289+
BasicBlock *InsertAtEnd, Context &Ctx,
1290+
const Twine &Name) {
1291+
return create(Op, LHS, RHS, InsertAtEnd->end(), InsertAtEnd, Ctx, Name);
1292+
}
1293+
1294+
Value *BinaryOperator::createWithCopiedFlags(Instruction::Opcode Op, Value *LHS,
1295+
Value *RHS, Value *CopyFrom,
1296+
BBIterator WhereIt,
1297+
BasicBlock *WhereBB, Context &Ctx,
1298+
const Twine &Name) {
1299+
1300+
Value *NewV = create(Op, LHS, RHS, WhereIt, WhereBB, Ctx, Name);
1301+
if (auto *NewBO = dyn_cast<BinaryOperator>(NewV))
1302+
cast<llvm::BinaryOperator>(NewBO->Val)->copyIRFlags(CopyFrom->Val);
1303+
return NewV;
1304+
}
1305+
1306+
Value *BinaryOperator::createWithCopiedFlags(Instruction::Opcode Op, Value *LHS,
1307+
Value *RHS, Value *CopyFrom,
1308+
Instruction *InsertBefore,
1309+
Context &Ctx, const Twine &Name) {
1310+
return createWithCopiedFlags(Op, LHS, RHS, CopyFrom,
1311+
InsertBefore->getIterator(),
1312+
InsertBefore->getParent(), Ctx, Name);
1313+
}
1314+
1315+
Value *BinaryOperator::createWithCopiedFlags(Instruction::Opcode Op, Value *LHS,
1316+
Value *RHS, Value *CopyFrom,
1317+
BasicBlock *InsertAtEnd,
1318+
Context &Ctx, const Twine &Name) {
1319+
return createWithCopiedFlags(Op, LHS, RHS, CopyFrom, InsertAtEnd->end(),
1320+
InsertAtEnd, Ctx, Name);
1321+
}
1322+
12221323
void AtomicCmpXchgInst::setSyncScopeID(SyncScope::ID SSID) {
12231324
Ctx.getTracker()
12241325
.emplaceIfTracking<GenericSetter<&AtomicCmpXchgInst::getSyncScopeID,
@@ -1628,6 +1729,29 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
16281729
new GetElementPtrInst(LLVMGEP, *this));
16291730
return It->second.get();
16301731
}
1732+
case llvm::Instruction::Add:
1733+
case llvm::Instruction::FAdd:
1734+
case llvm::Instruction::Sub:
1735+
case llvm::Instruction::FSub:
1736+
case llvm::Instruction::Mul:
1737+
case llvm::Instruction::FMul:
1738+
case llvm::Instruction::UDiv:
1739+
case llvm::Instruction::SDiv:
1740+
case llvm::Instruction::FDiv:
1741+
case llvm::Instruction::URem:
1742+
case llvm::Instruction::SRem:
1743+
case llvm::Instruction::FRem:
1744+
case llvm::Instruction::Shl:
1745+
case llvm::Instruction::LShr:
1746+
case llvm::Instruction::AShr:
1747+
case llvm::Instruction::And:
1748+
case llvm::Instruction::Or:
1749+
case llvm::Instruction::Xor: {
1750+
auto *LLVMBinaryOperator = cast<llvm::BinaryOperator>(LLVMV);
1751+
It->second = std::unique_ptr<BinaryOperator>(
1752+
new BinaryOperator(LLVMBinaryOperator, *this));
1753+
return It->second.get();
1754+
}
16311755
case llvm::Instruction::AtomicCmpXchg: {
16321756
auto *LLVMAtomicCmpXchg = cast<llvm::AtomicCmpXchgInst>(LLVMV);
16331757
It->second = std::unique_ptr<AtomicCmpXchgInst>(
@@ -1751,6 +1875,10 @@ Context::createGetElementPtrInst(llvm::GetElementPtrInst *I) {
17511875
std::unique_ptr<GetElementPtrInst>(new GetElementPtrInst(I, *this));
17521876
return cast<GetElementPtrInst>(registerValue(std::move(NewPtr)));
17531877
}
1878+
BinaryOperator *Context::createBinaryOperator(llvm::BinaryOperator *I) {
1879+
auto NewPtr = std::unique_ptr<BinaryOperator>(new BinaryOperator(I, *this));
1880+
return cast<BinaryOperator>(registerValue(std::move(NewPtr)));
1881+
}
17541882
AtomicCmpXchgInst *
17551883
Context::createAtomicCmpXchgInst(llvm::AtomicCmpXchgInst *I) {
17561884
auto NewPtr =

0 commit comments

Comments
 (0)