@@ -44,6 +44,7 @@ class InlineFunctionCallVisitor : public PassUtils::PassVisitor<InlineFunctionCa
44
44
ASR::expr_t * function_result_var;
45
45
46
46
bool from_inline_function_call, inlining_function;
47
+ bool fixed_duplicated_expr_stmt;
47
48
48
49
// Stores the local variables corresponding to the ones
49
50
// present in function symbol table.
@@ -56,16 +57,18 @@ class InlineFunctionCallVisitor : public PassUtils::PassVisitor<InlineFunctionCa
56
57
57
58
ASR::ExprStmtDuplicator node_duplicator;
58
59
60
+ SymbolTable* current_routine_scope;
61
+
59
62
public:
60
63
61
64
bool function_inlined;
62
65
63
66
InlineFunctionCallVisitor (Allocator &al_, const std::string& rl_path_, bool inline_external_symbol_calls_)
64
67
: PassVisitor(al_, nullptr ),
65
68
rl_path (rl_path_), function_result_var(nullptr ),
66
- from_inline_function_call(false ), inlining_function(false ),
69
+ from_inline_function_call(false ), inlining_function(false ), fixed_duplicated_expr_stmt( false ),
67
70
current_routine(" " ), inline_external_symbol_calls(inline_external_symbol_calls_),
68
- node_duplicator(al_), function_inlined(false )
71
+ node_duplicator(al_), current_routine_scope( nullptr ), function_inlined(false )
69
72
{
70
73
pass_result.reserve (al, 1 );
71
74
}
@@ -85,14 +88,28 @@ class InlineFunctionCallVisitor : public PassUtils::PassVisitor<InlineFunctionCa
85
88
86
89
void visit_Var (const ASR::Var_t& x) {
87
90
ASR::Var_t& xx = const_cast <ASR::Var_t&>(x);
88
- ASR::Variable_t* x_var = ASR::down_cast<ASR::Variable_t>(x.m_v );
89
- std::string x_var_name = std::string (x_var->m_name );
90
- if ( arg2value.find (x_var_name) != arg2value.end () ) {
91
- x_var = ASR::down_cast<ASR::Variable_t>(arg2value[x_var_name]);
92
- if ( current_scope->scope .find (std::string (x_var->m_name )) != current_scope->scope .end () ) {
93
- xx.m_v = arg2value[x_var_name];
91
+ std::string x_var_name = std::string (ASRUtils::symbol_name (x.m_v ));
92
+
93
+ // If anything is not local to a function being inlined
94
+ // then do not inline the function by setting
95
+ // fixed_duplicated_expr_stmt to false.
96
+ // To be supported later.
97
+ if ( current_routine_scope &&
98
+ current_routine_scope->scope .find (x_var_name) == current_routine_scope->scope .end () ) {
99
+ fixed_duplicated_expr_stmt = false ;
100
+ return ;
101
+ }
102
+ if ( x.m_v ->type == ASR::symbolType::Variable ) {
103
+ ASR::Variable_t* x_var = ASR::down_cast<ASR::Variable_t>(x.m_v );
104
+ if ( arg2value.find (x_var_name) != arg2value.end () ) {
105
+ x_var = ASR::down_cast<ASR::Variable_t>(arg2value[x_var_name]);
106
+ if ( current_scope->scope .find (std::string (x_var->m_name )) != current_scope->scope .end () ) {
107
+ xx.m_v = arg2value[x_var_name];
108
+ }
109
+ x_var = ASR::down_cast<ASR::Variable_t>(x.m_v );
94
110
}
95
- x_var = ASR::down_cast<ASR::Variable_t>(x.m_v );
111
+ } else {
112
+ fixed_duplicated_expr_stmt = false ;
96
113
}
97
114
}
98
115
@@ -166,10 +183,13 @@ class InlineFunctionCallVisitor : public PassUtils::PassVisitor<InlineFunctionCa
166
183
167
184
// Avoid inlining current function call if its a recursion.
168
185
ASR::Function_t* func = ASR::down_cast<ASR::Function_t>(routine);
169
- if ( std::string (func->m_name ) == current_routine ) {
186
+ if ( ASRUtils::is_intrinsic_function2 (func) ||
187
+ std::string (func->m_name ) == current_routine ) {
170
188
return ;
171
189
}
172
190
191
+ current_routine_scope = func->m_symtab ;
192
+
173
193
ASR::expr_t * return_var = nullptr ;
174
194
// The following prepares arg2value map for inlining the
175
195
// current function call. Variables are created in the current
@@ -259,7 +279,12 @@ class InlineFunctionCallVisitor : public PassUtils::PassVisitor<InlineFunctionCa
259
279
// in the current scope. See, `visit_Var` to know how replacement occurs.
260
280
for ( size_t i = 0 ; i < exprs_to_be_visited.size () && success; i++ ) {
261
281
ASR::expr_t * value = exprs_to_be_visited[i].first ;
282
+ fixed_duplicated_expr_stmt = true ;
262
283
visit_expr (*value);
284
+ if ( !fixed_duplicated_expr_stmt ) {
285
+ success = false ;
286
+ break ;
287
+ }
263
288
ASR::symbol_t * variable = exprs_to_be_visited[i].second ;
264
289
ASR::expr_t * var = LFortran::ASRUtils::EXPR (ASR::make_Var_t (al, variable->base .loc , variable));
265
290
ASR::stmt_t * assign_stmt = ASRUtils::STMT (ASR::make_Assignment_t (al, var->base .loc , var, value, nullptr ));
@@ -280,32 +305,45 @@ class InlineFunctionCallVisitor : public PassUtils::PassVisitor<InlineFunctionCa
280
305
}
281
306
282
307
if ( success ) {
283
- // If duplication is successfull then fill the
284
- // pass result with assignment statements
285
- // (for local variables in the loop just below)
286
- // and the function body (the next loop).
287
- for ( size_t i = 0 ; i < pass_result_local.size (); i++ ) {
288
- pass_result.push_back (al, pass_result_local[i]);
289
- }
290
308
// Set inlining_function to true so that we inline
291
309
// only one function at a time.
292
310
inlining_function = true ;
293
- for ( size_t i = 0 ; i < func->n_body ; i++ ) {
311
+ for ( size_t i = 0 ; i < func->n_body && success; i++ ) {
312
+ fixed_duplicated_expr_stmt = true ;
294
313
visit_stmt (*func_copy[i]);
295
- pass_result.push_back (al, func_copy[i]);
314
+ success = success && fixed_duplicated_expr_stmt;
315
+ }
316
+
317
+ if ( success ) {
318
+ // If duplication is successfull then fill the
319
+ // pass result with assignment statements
320
+ // (for local variables in the loop just below)
321
+ // and the function body (the next loop).
322
+ for ( size_t i = 0 ; i < pass_result_local.size (); i++ ) {
323
+ pass_result.push_back (al, pass_result_local[i]);
324
+ }
325
+
326
+ for ( size_t i = 0 ; i < func->n_body ; i++ ) {
327
+ pass_result.push_back (al, func_copy[i]);
328
+ }
296
329
}
297
330
inlining_function = false ;
331
+ current_routine_scope = nullptr ;
298
332
function_result_var = return_var;
299
- } else {
333
+ }
334
+
335
+ if (!success) {
300
336
// If not successfull then delete all the local variables
301
337
// created for the purpose of inlining the current function call.
302
338
for ( auto & itr : arg2value ) {
303
339
ASR::Variable_t* auxiliary_var = ASR::down_cast<ASR::Variable_t>(itr.second );
304
340
current_scope->scope .erase (std::string (auxiliary_var->m_name ));
305
341
}
342
+ function_result_var = nullptr ;
306
343
}
307
344
// At least one function is inlined
308
345
function_inlined = success;
346
+ success = false ;
309
347
// Clear up the arg2value to avoid corruption
310
348
// of any kind.
311
349
arg2value.clear ();
0 commit comments