Skip to content

Commit af91f1b

Browse files
authored
Merge pull request #311 from czgdp1807/inline_pass_fix
Fixed inline function calls pass to inline simple functions
2 parents 44b247e + 418bf20 commit af91f1b

File tree

3 files changed

+66
-21
lines changed

3 files changed

+66
-21
lines changed

src/bin/lpython.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include <libasr/pass/arr_slice.h>
2525
#include <libasr/pass/print_arr.h>
2626
#include <libasr/pass/unused_functions.h>
27+
#include <libasr/pass/inline_function_calls.h>
2728
#include <libasr/asr_utils.h>
2829
#include <libasr/asr_verify.h>
2930
#include <libasr/modfile.h>
@@ -49,6 +50,7 @@ enum Backend {
4950
enum ASRPass {
5051
do_loops, global_stmts, implied_do_loops, array_op,
5152
arr_slice, print_arr, class_constructor, unused_functions,
53+
inline_function_calls
5254
};
5355

5456
std::string remove_extension(const std::string& filename) {
@@ -722,6 +724,8 @@ int main(int argc, char *argv[])
722724
passes.push_back(ASRPass::implied_do_loops);
723725
} else if (arg_pass == "array_op") {
724726
passes.push_back(ASRPass::array_op);
727+
} else if (arg_pass == "inline_function_calls") {
728+
passes.push_back(ASRPass::inline_function_calls);
725729
} else if (arg_pass == "class_constructor") {
726730
passes.push_back(ASRPass::class_constructor);
727731
} else if (arg_pass == "print_arr") {

src/libasr/asr_utils.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,10 @@ static inline bool is_intrinsic_function2(const ASR::Function_t *fn) {
385385
ASR::symbol_t *sym = (ASR::symbol_t*)fn;
386386
ASR::Module_t *m = get_sym_module0(sym);
387387
if (m != nullptr) {
388-
if (m->m_intrinsic) return true;
388+
if (m->m_intrinsic ||
389+
fn->m_abi == ASR::abiType::Intrinsic) {
390+
return true;
391+
}
389392
}
390393
return false;
391394
}

src/libasr/pass/inline_function_calls.cpp

Lines changed: 58 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ class InlineFunctionCallVisitor : public PassUtils::PassVisitor<InlineFunctionCa
4444
ASR::expr_t* function_result_var;
4545

4646
bool from_inline_function_call, inlining_function;
47+
bool fixed_duplicated_expr_stmt;
4748

4849
// Stores the local variables corresponding to the ones
4950
// present in function symbol table.
@@ -56,16 +57,18 @@ class InlineFunctionCallVisitor : public PassUtils::PassVisitor<InlineFunctionCa
5657

5758
ASR::ExprStmtDuplicator node_duplicator;
5859

60+
SymbolTable* current_routine_scope;
61+
5962
public:
6063

6164
bool function_inlined;
6265

6366
InlineFunctionCallVisitor(Allocator &al_, const std::string& rl_path_, bool inline_external_symbol_calls_)
6467
: PassVisitor(al_, nullptr),
6568
rl_path(rl_path_), function_result_var(nullptr),
66-
from_inline_function_call(false), inlining_function(false),
69+
from_inline_function_call(false), inlining_function(false), fixed_duplicated_expr_stmt(false),
6770
current_routine(""), inline_external_symbol_calls(inline_external_symbol_calls_),
68-
node_duplicator(al_), function_inlined(false)
71+
node_duplicator(al_), current_routine_scope(nullptr), function_inlined(false)
6972
{
7073
pass_result.reserve(al, 1);
7174
}
@@ -85,14 +88,28 @@ class InlineFunctionCallVisitor : public PassUtils::PassVisitor<InlineFunctionCa
8588

8689
void visit_Var(const ASR::Var_t& x) {
8790
ASR::Var_t& xx = const_cast<ASR::Var_t&>(x);
88-
ASR::Variable_t* x_var = ASR::down_cast<ASR::Variable_t>(x.m_v);
89-
std::string x_var_name = std::string(x_var->m_name);
90-
if( arg2value.find(x_var_name) != arg2value.end() ) {
91-
x_var = ASR::down_cast<ASR::Variable_t>(arg2value[x_var_name]);
92-
if( current_scope->scope.find(std::string(x_var->m_name)) != current_scope->scope.end() ) {
93-
xx.m_v = arg2value[x_var_name];
91+
std::string x_var_name = std::string(ASRUtils::symbol_name(x.m_v));
92+
93+
// If anything is not local to a function being inlined
94+
// then do not inline the function by setting
95+
// fixed_duplicated_expr_stmt to false.
96+
// To be supported later.
97+
if( current_routine_scope &&
98+
current_routine_scope->scope.find(x_var_name) == current_routine_scope->scope.end() ) {
99+
fixed_duplicated_expr_stmt = false;
100+
return ;
101+
}
102+
if( x.m_v->type == ASR::symbolType::Variable ) {
103+
ASR::Variable_t* x_var = ASR::down_cast<ASR::Variable_t>(x.m_v);
104+
if( arg2value.find(x_var_name) != arg2value.end() ) {
105+
x_var = ASR::down_cast<ASR::Variable_t>(arg2value[x_var_name]);
106+
if( current_scope->scope.find(std::string(x_var->m_name)) != current_scope->scope.end() ) {
107+
xx.m_v = arg2value[x_var_name];
108+
}
109+
x_var = ASR::down_cast<ASR::Variable_t>(x.m_v);
94110
}
95-
x_var = ASR::down_cast<ASR::Variable_t>(x.m_v);
111+
} else {
112+
fixed_duplicated_expr_stmt = false;
96113
}
97114
}
98115

@@ -166,10 +183,13 @@ class InlineFunctionCallVisitor : public PassUtils::PassVisitor<InlineFunctionCa
166183

167184
// Avoid inlining current function call if its a recursion.
168185
ASR::Function_t* func = ASR::down_cast<ASR::Function_t>(routine);
169-
if( std::string(func->m_name) == current_routine ) {
186+
if( ASRUtils::is_intrinsic_function2(func) ||
187+
std::string(func->m_name) == current_routine ) {
170188
return ;
171189
}
172190

191+
current_routine_scope = func->m_symtab;
192+
173193
ASR::expr_t* return_var = nullptr;
174194
// The following prepares arg2value map for inlining the
175195
// current function call. Variables are created in the current
@@ -259,7 +279,12 @@ class InlineFunctionCallVisitor : public PassUtils::PassVisitor<InlineFunctionCa
259279
// in the current scope. See, `visit_Var` to know how replacement occurs.
260280
for( size_t i = 0; i < exprs_to_be_visited.size() && success; i++ ) {
261281
ASR::expr_t* value = exprs_to_be_visited[i].first;
282+
fixed_duplicated_expr_stmt = true;
262283
visit_expr(*value);
284+
if( !fixed_duplicated_expr_stmt ) {
285+
success = false;
286+
break;
287+
}
263288
ASR::symbol_t* variable = exprs_to_be_visited[i].second;
264289
ASR::expr_t* var = LFortran::ASRUtils::EXPR(ASR::make_Var_t(al, variable->base.loc, variable));
265290
ASR::stmt_t* assign_stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, var->base.loc, var, value, nullptr));
@@ -280,32 +305,45 @@ class InlineFunctionCallVisitor : public PassUtils::PassVisitor<InlineFunctionCa
280305
}
281306

282307
if( success ) {
283-
// If duplication is successfull then fill the
284-
// pass result with assignment statements
285-
// (for local variables in the loop just below)
286-
// and the function body (the next loop).
287-
for( size_t i = 0; i < pass_result_local.size(); i++ ) {
288-
pass_result.push_back(al, pass_result_local[i]);
289-
}
290308
// Set inlining_function to true so that we inline
291309
// only one function at a time.
292310
inlining_function = true;
293-
for( size_t i = 0; i < func->n_body; i++ ) {
311+
for( size_t i = 0; i < func->n_body && success; i++ ) {
312+
fixed_duplicated_expr_stmt = true;
294313
visit_stmt(*func_copy[i]);
295-
pass_result.push_back(al, func_copy[i]);
314+
success = success && fixed_duplicated_expr_stmt;
315+
}
316+
317+
if( success ) {
318+
// If duplication is successfull then fill the
319+
// pass result with assignment statements
320+
// (for local variables in the loop just below)
321+
// and the function body (the next loop).
322+
for( size_t i = 0; i < pass_result_local.size(); i++ ) {
323+
pass_result.push_back(al, pass_result_local[i]);
324+
}
325+
326+
for( size_t i = 0; i < func->n_body; i++ ) {
327+
pass_result.push_back(al, func_copy[i]);
328+
}
296329
}
297330
inlining_function = false;
331+
current_routine_scope = nullptr;
298332
function_result_var = return_var;
299-
} else {
333+
}
334+
335+
if (!success) {
300336
// If not successfull then delete all the local variables
301337
// created for the purpose of inlining the current function call.
302338
for( auto& itr : arg2value ) {
303339
ASR::Variable_t* auxiliary_var = ASR::down_cast<ASR::Variable_t>(itr.second);
304340
current_scope->scope.erase(std::string(auxiliary_var->m_name));
305341
}
342+
function_result_var = nullptr;
306343
}
307344
// At least one function is inlined
308345
function_inlined = success;
346+
success = false;
309347
// Clear up the arg2value to avoid corruption
310348
// of any kind.
311349
arg2value.clear();

0 commit comments

Comments
 (0)