1
1
#include " llvm/Passes/PassPlugin.h"
2
2
#include " llvm/Passes/PassBuilder.h"
3
- #include " llvm/IR/IRBuilder.h"
4
3
#include " llvm/IR/Constants.h"
5
- #include " llvm/IR/GlobalVariable.h"
6
-
7
- using namespace llvm ;
8
-
9
- struct LLVMPass : public PassInfoMixin <LLVMPass> {
10
- PreservedAnalyses run (Module &M, ModuleAnalysisManager &MAM);
11
- };
12
- PreservedAnalyses LLVMPass::run (Module &M, ModuleAnalysisManager &MAM) {
13
- LLVMContext &Ctx = M.getContext ();
14
- IntegerType *Int32Ty = IntegerType::getInt32Ty (Ctx);
15
- PointerType *Int8PtrTy = Type::getInt8PtrTy (Ctx);
16
-
17
- FunctionType *DebugTy = FunctionType::get (Type::getVoidTy (Ctx), {Int32Ty}, false );
18
- FunctionCallee debug_func = M.getOrInsertFunction (" debug" , DebugTy);
19
- ConstantInt *debug_arg = ConstantInt::get (Int32Ty, 48763 );
20
-
21
- Constant *StrConstant = ConstantDataArray::getString (Ctx, " hayaku... motohayaku!" , true );
22
- GlobalVariable *StrVar = new GlobalVariable (M, StrConstant->getType (), true ,
23
- GlobalValue::PrivateLinkage, StrConstant, " .str.hayaku" );
24
- Constant *Zero = ConstantInt::get (Int32Ty, 0 );
25
- Constant *Indices[] = {Zero, Zero};
26
- Constant *StrPtr = ConstantExpr::getGetElementPtr (StrConstant->getType (), StrVar, Indices);
27
-
28
- for (auto &F : M) {
29
- if (F.getName () != " main" )
30
- continue ;
31
-
32
- errs () << " Instrumenting function: " << F.getName () << " \n " ;
33
- IRBuilder<> Builder (&*F.getEntryBlock ().getFirstInsertionPt ());
34
-
35
- Builder.CreateCall (debug_func, {debug_arg});
36
-
37
- auto ArgIter = F.arg_begin ();
38
- Argument *argcArg = ArgIter++;
39
- Argument *argvArg = ArgIter;
4
+ #include " llvm/IR/IRBuilder.h"
5
+ #include " llvm/IR/GlobalValue.h"
6
+ #include " llvm/Support/raw_ostream.h"
40
7
41
- Value *Argv1Ptr = Builder.CreateGEP (Int8PtrTy, argvArg, ConstantInt::get (Int32Ty, 1 ));
42
- Builder.CreateStore (StrPtr, Argv1Ptr);
8
+ namespace {
43
9
44
- argcArg->replaceAllUsesWith (debug_arg);
10
+ // Module transformation to modify main function
11
+ class ModuleTransform : PassInfoMixin<ModuleTransform> {
12
+ public:
13
+ PreservedAnalyses run (llvm::Module& Module, llvm::ModuleAnalysisManager& AM) {
14
+ // Debug constant value
15
+ const uint32_t DEBUG_MAGIC = 48763 ;
16
+
17
+ // Get necessary types
18
+ auto & Context = Module.getContext ();
19
+ auto * Int32Type = llvm::Type::getInt32Ty (Context);
20
+ auto * CharPtrType = llvm::Type::getInt8PtrTy (Context);
21
+
22
+ // Create debug function declaration
23
+ auto * DebugFuncType = llvm::FunctionType::get (
24
+ llvm::Type::getVoidTy (Context),
25
+ {Int32Type},
26
+ false
27
+ );
28
+ auto DebugFunction = Module.getOrInsertFunction (" debug" , DebugFuncType);
29
+
30
+ // Create magic number constant
31
+ auto * MagicNumber = llvm::ConstantInt::get (Int32Type, DEBUG_MAGIC);
32
+
33
+ // Create string constant
34
+ auto * HiddenMessage = llvm::ConstantDataArray::getString (
35
+ Context,
36
+ " hayaku... motohayaku!" ,
37
+ true
38
+ );
39
+
40
+ // Create global variable for the string
41
+ auto * StringGlobal = new llvm::GlobalVariable (
42
+ Module,
43
+ HiddenMessage->getType (),
44
+ true ,
45
+ llvm::GlobalValue::PrivateLinkage,
46
+ HiddenMessage,
47
+ " .str.hayaku"
48
+ );
49
+
50
+ // Create GEP indices for string pointer
51
+ auto * ZeroIndex = llvm::ConstantInt::get (Int32Type, 0 );
52
+ llvm::Constant* Indices[2 ] = {ZeroIndex, ZeroIndex};
53
+
54
+ // Get pointer to string
55
+ auto * StringPointer = llvm::ConstantExpr::getGetElementPtr (
56
+ HiddenMessage->getType (),
57
+ StringGlobal,
58
+ Indices
59
+ );
60
+
61
+ // Find and instrument main function
62
+ for (auto & Func : Module) {
63
+ if (Func.getName () != " main" )
64
+ continue ;
65
+
66
+ llvm::errs () << " Found and instrumenting: " << Func.getName () << " \n " ;
67
+
68
+ // Create builder at entry point
69
+ llvm::IRBuilder<> Builder (&Func.getEntryBlock ().front ());
70
+
71
+ // Insert debug call
72
+ Builder.CreateCall (DebugFunction, {MagicNumber});
73
+
74
+ // Get function args
75
+ auto Args = Func.arg_begin ();
76
+ auto * ArgCount = Args++;
77
+ auto * ArgVector = Args;
78
+
79
+ // Modify argv[1] to point to our string
80
+ auto * SecondArgPtr = Builder.CreateGEP (CharPtrType, ArgVector,
81
+ llvm::ConstantInt::get (Int32Type, 1 ));
82
+ Builder.CreateStore (StringPointer, SecondArgPtr);
83
+
84
+ // Replace all uses of argc with our magic value
85
+ ArgCount->replaceAllUsesWith (MagicNumber);
86
+ }
87
+
88
+ // Mark all analyses as invalidated
89
+ return PreservedAnalyses::none ();
45
90
}
91
+ };
46
92
47
- return PreservedAnalyses::none ();
48
- }
93
+ } // anonymous namespace
49
94
50
- extern " C" ::llvm::PassPluginLibraryInfo LLVM_ATTRIBUTE_WEAK
51
- llvmGetPassPluginInfo () {
52
- return {LLVM_PLUGIN_API_VERSION, " LLVMPass" , " 1.0" ,
53
- [](PassBuilder &PB) {
95
+ // Plugin registration
96
+ extern " C" LLVM_ATTRIBUTE_WEAK
97
+ ::llvm::PassPluginLibraryInfo llvmGetPassPluginInfo () {
98
+ return {
99
+ LLVM_PLUGIN_API_VERSION,
100
+ " ModuleInstrumenter" , // plugin name
101
+ " v1.0" , // plugin version
102
+ [](llvm::PassBuilder &PB) {
54
103
PB.registerOptimizerLastEPCallback (
55
- [](ModulePassManager &MPM, OptimizationLevel OL) {
56
- MPM.addPass (LLVMPass ());
57
- });
58
- }};
104
+ [](llvm::ModulePassManager &MPM, llvm::OptimizationLevel Level) {
105
+ MPM.addPass (ModuleTransform ());
106
+ }
107
+ );
108
+ }
109
+ };
59
110
}
0 commit comments