From bc1063ae3e29134946bd28fe7ca1545bc8538bf1 Mon Sep 17 00:00:00 2001 From: Smit-create Date: Mon, 28 Feb 2022 20:37:30 +0530 Subject: [PATCH 01/14] add overload support in lpython --- src/lpython/semantics/python_ast_to_asr.cpp | 65 ++++++++++++++++++--- 1 file changed, 56 insertions(+), 9 deletions(-) diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index 3dd332ad1e..e5bab33f0d 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -445,14 +445,14 @@ ASR::symbol_t* import_from_module(Allocator &al, ASR::Module_t *m, SymbolTable * throw SemanticError("Only Subroutines, Functions and Variables are currently supported in 'import'", loc); } - // should not reach here + LFORTRAN_ASSERT(false); return nullptr; } class SymbolTableVisitor : public CommonVisitor { public: SymbolTable *global_scope; - std::map> generic_procedures; + std::map> generic_procedures, overload_defs; std::map>> generic_class_procedures; std::map> defined_op_procs; std::map> class_procedures; @@ -534,12 +534,23 @@ class SymbolTableVisitor : public CommonVisitor { Vec args; args.reserve(al, x.m_args.n_args); current_procedure_abi_type = ASR::abiType::Source; - if (x.n_decorator_list == 1) { - AST::expr_t *dec = x.m_decorator_list[0]; - if (AST::is_a(*dec)) { - std::string name = AST::down_cast(dec)->m_id; - if (name == "ccall") { - current_procedure_abi_type = ASR::abiType::BindC; + bool overload = false; + if (x.n_decorator_list > 0) { + for(size_t i=0; i(*dec)) { + std::string name = AST::down_cast(dec)->m_id; + if (name == "ccall") { + current_procedure_abi_type = ASR::abiType::BindC; + } else if (name == "overload") { + overload = true; + } else { + throw SemanticError("Decorator: " + name + " is not supported", + x.base.base.loc); + } + } else { + throw SemanticError("Unsupported Decorator type", + x.base.base.loc); } } } @@ -578,6 +589,15 @@ class SymbolTableVisitor : public CommonVisitor { var))); } std::string sym_name = x.m_name; + if (overload) { + std::string overload_number; + if (overload_defs.find(sym_name) == overload_defs.end()){ + overload_number = "0"; + } else { + overload_number = std::to_string(overload_defs[sym_name].size()); + } + sym_name = "__lpython_overloaded_" + overload_number + "__" + sym_name; + } if (parent_scope->scope.find(sym_name) != parent_scope->scope.end()) { throw SemanticError("Subroutine already defined", tmp->loc); } @@ -633,6 +653,9 @@ class SymbolTableVisitor : public CommonVisitor { } parent_scope->scope[sym_name] = ASR::down_cast(tmp); current_scope = parent_scope; + if (overload) { + overload_defs[x.m_name].push_back(sym_name); + } } void visit_ImportFrom(const AST::ImportFrom_t &x) { @@ -809,9 +832,33 @@ class BodyVisitor : public CommonVisitor { v.n_body = body.size(); } + ASR::symbol_t* overloaddef_find_helper(std::string func_name, Vec args, + const Location &loc) { + for(auto &t: overload_defs[func_name]) { + bool ok = ASRUtils::select_func_subrout(t, args, loc, + [&](const std::string &msg, const Location &loc) { throw SemanticError(msg, loc); }); + if (ok) { + return t; + } + } + return nullptr; + } + void visit_FunctionDef(const AST::FunctionDef_t &x) { SymbolTable *old_scope = current_scope; - ASR::symbol_t *t = current_scope->scope[x.m_name]; + ASR::symbol_t *t = nullptr; + if (overload_defs.find(x.m_name) != overload_defs.end()) { + Vec args; + args.reserve(al, x.m_args.n_args); + for (size_t i=0; iscope[x.m_name]; + } if (ASR::is_a(*t)) { handle_fn(x, *ASR::down_cast(t)); } else if (ASR::is_a(*t)) { From 8e1771c8b936ee0ae0e9dd4b7cf84b809f76fb93 Mon Sep 17 00:00:00 2001 From: Smit-create Date: Tue, 1 Mar 2022 23:56:30 +0530 Subject: [PATCH 02/14] overload support now compiles --- src/lpython/semantics/python_ast_to_asr.cpp | 127 +++++++++++++++++++- 1 file changed, 123 insertions(+), 4 deletions(-) diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index e5bab33f0d..b34f30a809 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -139,6 +139,121 @@ ASR::Module_t* load_module(Allocator &al, SymbolTable *symtab, return mod2; } +bool types_equal(const ASR::ttype_t &a, const ASR::ttype_t &b) { + // TODO: If anyone of the input or argument is derived type then + // add support for checking member wise types and do not compare + // directly. From stdlib_string len(pattern) error. + if (b.type == ASR::ttypeType::Derived || b.type == ASR::ttypeType::Class) { + return true; + } + if (a.type == b.type) { + // TODO: check dims + // TODO: check all types + switch (a.type) { + case (ASR::ttypeType::Integer) : { + ASR::Integer_t *a2 = ASR::down_cast(&a); + ASR::Integer_t *b2 = ASR::down_cast(&b); + if (a2->m_kind == b2->m_kind) { + return true; + } else { + return false; + } + break; + } + case (ASR::ttypeType::Real) : { + ASR::Real_t *a2 = ASR::down_cast(&a); + ASR::Real_t *b2 = ASR::down_cast(&b); + if (a2->m_kind == b2->m_kind) { + return true; + } else { + return false; + } + break; + } + case (ASR::ttypeType::Complex) : { + ASR::Complex_t *a2 = ASR::down_cast(&a); + ASR::Complex_t *b2 = ASR::down_cast(&b); + if (a2->m_kind == b2->m_kind) { + return true; + } else { + return false; + } + break; + } + case (ASR::ttypeType::Logical) : { + ASR::Logical_t *a2 = ASR::down_cast(&a); + ASR::Logical_t *b2 = ASR::down_cast(&b); + if (a2->m_kind == b2->m_kind) { + return true; + } else { + return false; + } + break; + } + case (ASR::ttypeType::Character) : { + ASR::Character_t *a2 = ASR::down_cast(&a); + ASR::Character_t *b2 = ASR::down_cast(&b); + if (a2->m_kind == b2->m_kind) { + return true; + } else { + return false; + } + break; + } + default : return false; + } + } + return false; +} + +template +bool argument_types_match(const Vec &args, + const T &sub) { + if (args.size() <= sub.n_args) { + size_t i; + for (i = 0; i < args.size(); i++) { + ASR::Variable_t *v = LFortran::ASRUtils::EXPR2VAR(sub.m_args[i]); + ASR::ttype_t *arg1 = args[i]; + ASR::ttype_t *arg2 = v->m_type; + if (!types_equal(*arg1, *arg2)) { + return false; + } + } + for( ; i < sub.n_args; i++ ) { + ASR::Variable_t *v = LFortran::ASRUtils::EXPR2VAR(sub.m_args[i]); + if( v->m_presence != ASR::presenceType::Optional ) { + return false; + } + } + return true; + } else { + return false; + } +} + +bool select_func_subrout(const ASR::symbol_t* proc, const Vec &args, + const Location& loc, const std::function err) { + bool result = false; + if (ASR::is_a(*proc)) { + ASR::Subroutine_t *sub + = ASR::down_cast(proc); + if (argument_types_match(args, *sub)) { + result = true; + } + } else if (ASR::is_a(*proc)) { + ASR::Function_t *fn + = ASR::down_cast(proc); + if (argument_types_match(args, *fn)) { + result = true; + } + } else { + err("Only Subroutine and Function supported in generic procedure", loc); + } + return result; +} + +std::map> overload_definitons; + template class CommonVisitor : public AST::BaseVisitor { public: @@ -156,6 +271,7 @@ class CommonVisitor : public AST::BaseVisitor { // The main module is stored directly in TranslationUnit, other modules are Modules bool main_module; PythonIntrinsicProcedures intrinsic_procedures; + std::map> overload_defs; CommonVisitor(Allocator &al, SymbolTable *symbol_table, diag::Diagnostics &diagnostics, bool main_module) @@ -452,7 +568,7 @@ ASR::symbol_t* import_from_module(Allocator &al, ASR::Module_t *m, SymbolTable * class SymbolTableVisitor : public CommonVisitor { public: SymbolTable *global_scope; - std::map> generic_procedures, overload_defs; + std::map> generic_procedures; std::map>> generic_class_procedures; std::map> defined_op_procs; std::map> class_procedures; @@ -752,6 +868,7 @@ Result symbol_table_visitor(Allocator &al, const AST::Module_t &ast SymbolTableVisitor v(al, nullptr, diagnostics, main_module); try { v.visit_Module(ast); + overload_definitons = v.overload_defs; } catch (const SemanticError &e) { Error error; diagnostics.diagnostics.push_back(e.d); @@ -835,10 +952,11 @@ class BodyVisitor : public CommonVisitor { ASR::symbol_t* overloaddef_find_helper(std::string func_name, Vec args, const Location &loc) { for(auto &t: overload_defs[func_name]) { - bool ok = ASRUtils::select_func_subrout(t, args, loc, - [&](const std::string &msg, const Location &loc) { throw SemanticError(msg, loc); }); + ASR::symbol_t *st = current_scope->scope[t]; + bool ok = select_func_subrout(st, args, loc, + [&](const std::string &msg, const Location &l) { throw SemanticError(msg, l); }); if (ok) { - return t; + return st; } } return nullptr; @@ -2316,6 +2434,7 @@ Result body_visitor(Allocator &al, { BodyVisitor b(al, unit, diagnostics, main_module); try { + b.overload_defs = overload_definitons; b.visit_Module(ast); } catch (const SemanticError &e) { Error error; From b4dbf618dee0da79395f9f15137585bf2b30cee7 Mon Sep 17 00:00:00 2001 From: Smit-create Date: Wed, 2 Mar 2022 16:18:00 +0530 Subject: [PATCH 03/14] add overload support in Call --- src/lpython/semantics/python_ast_to_asr.cpp | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index b34f30a809..4b00282b2c 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -952,7 +952,12 @@ class BodyVisitor : public CommonVisitor { ASR::symbol_t* overloaddef_find_helper(std::string func_name, Vec args, const Location &loc) { for(auto &t: overload_defs[func_name]) { - ASR::symbol_t *st = current_scope->scope[t]; + SymbolTable *symtab = current_scope; + while (symtab!= nullptr && symtab->scope.find(t) == symtab->scope.end()) { + symtab = symtab->parent; + } + LFORTRAN_ASSERT(symtab != nullptr); + ASR::symbol_t *st = symtab->scope[t]; bool ok = select_func_subrout(st, args, loc, [&](const std::string &msg, const Location &l) { throw SemanticError(msg, l); }); if (ok) { @@ -2275,6 +2280,14 @@ class BodyVisitor : public CommonVisitor { ASR::symbol_t *s = current_scope->resolve_symbol(call_name); + if (!s && overload_defs.find(call_name)!=overload_defs.end()) { + Vec args_type; + args_type.reserve(al, x.n_args); + for(size_t i=0; i Date: Wed, 2 Mar 2022 16:18:17 +0530 Subject: [PATCH 04/14] add overload test in lpython --- integration_tests/run_tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integration_tests/run_tests.py b/integration_tests/run_tests.py index d144ea8d5a..0fe78a7cd6 100755 --- a/integration_tests/run_tests.py +++ b/integration_tests/run_tests.py @@ -29,11 +29,11 @@ "test_math1.py", "test_math_02.py", "test_c_interop_01.py", + "test_generics_01.py", ] # CPython tests only test_cpython = [ - "test_generics_01.py", "test_builtin_bin.py", "test_builtin_hex.py", "test_builtin_oct.py" From ca3ce1bba91a9f05bcd2c9092265380e3dab84c3 Mon Sep 17 00:00:00 2001 From: Smit-create Date: Wed, 2 Mar 2022 16:19:39 +0530 Subject: [PATCH 05/14] update test file --- integration_tests/test_generics_01.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/integration_tests/test_generics_01.py b/integration_tests/test_generics_01.py index 5419ad8a5b..1f39689f3a 100644 --- a/integration_tests/test_generics_01.py +++ b/integration_tests/test_generics_01.py @@ -23,8 +23,12 @@ def test(a: bool) -> i32: return -10 -assert foo(2) == 4 -assert foo(2, 10) == 20 -assert foo("hello") == "lpython-hello" -assert test(10) == 20 -assert test(False) == -test(True) and test(True) == 10 +def check(): + assert foo(2) == 4 + assert foo(2, 10) == 20 + # Following assert has LLVM string len issue: gh-175 + # assert foo("hello") == "lpython-hello" + assert test(10) == 20 + assert test(False) == -test(True) and test(True) == 10 + +check() From 9a832d8622fdab0ccc85724924a7a6174ccbedc7 Mon Sep 17 00:00:00 2001 From: Smit-create Date: Wed, 2 Mar 2022 16:22:09 +0530 Subject: [PATCH 06/14] Use check_equal_type --- src/lpython/semantics/python_ast_to_asr.cpp | 69 +-------------------- 1 file changed, 1 insertion(+), 68 deletions(-) diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index 4b00282b2c..cffc87651d 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -139,73 +139,6 @@ ASR::Module_t* load_module(Allocator &al, SymbolTable *symtab, return mod2; } -bool types_equal(const ASR::ttype_t &a, const ASR::ttype_t &b) { - // TODO: If anyone of the input or argument is derived type then - // add support for checking member wise types and do not compare - // directly. From stdlib_string len(pattern) error. - if (b.type == ASR::ttypeType::Derived || b.type == ASR::ttypeType::Class) { - return true; - } - if (a.type == b.type) { - // TODO: check dims - // TODO: check all types - switch (a.type) { - case (ASR::ttypeType::Integer) : { - ASR::Integer_t *a2 = ASR::down_cast(&a); - ASR::Integer_t *b2 = ASR::down_cast(&b); - if (a2->m_kind == b2->m_kind) { - return true; - } else { - return false; - } - break; - } - case (ASR::ttypeType::Real) : { - ASR::Real_t *a2 = ASR::down_cast(&a); - ASR::Real_t *b2 = ASR::down_cast(&b); - if (a2->m_kind == b2->m_kind) { - return true; - } else { - return false; - } - break; - } - case (ASR::ttypeType::Complex) : { - ASR::Complex_t *a2 = ASR::down_cast(&a); - ASR::Complex_t *b2 = ASR::down_cast(&b); - if (a2->m_kind == b2->m_kind) { - return true; - } else { - return false; - } - break; - } - case (ASR::ttypeType::Logical) : { - ASR::Logical_t *a2 = ASR::down_cast(&a); - ASR::Logical_t *b2 = ASR::down_cast(&b); - if (a2->m_kind == b2->m_kind) { - return true; - } else { - return false; - } - break; - } - case (ASR::ttypeType::Character) : { - ASR::Character_t *a2 = ASR::down_cast(&a); - ASR::Character_t *b2 = ASR::down_cast(&b); - if (a2->m_kind == b2->m_kind) { - return true; - } else { - return false; - } - break; - } - default : return false; - } - } - return false; -} - template bool argument_types_match(const Vec &args, const T &sub) { @@ -215,7 +148,7 @@ bool argument_types_match(const Vec &args, ASR::Variable_t *v = LFortran::ASRUtils::EXPR2VAR(sub.m_args[i]); ASR::ttype_t *arg1 = args[i]; ASR::ttype_t *arg2 = v->m_type; - if (!types_equal(*arg1, *arg2)) { + if (!ASRUtils::check_equal_type(arg1, arg2)) { return false; } } From 32bede6832bb6d849f913e320d94ec6ee608e377 Mon Sep 17 00:00:00 2001 From: Smit-create Date: Wed, 2 Mar 2022 22:39:07 +0530 Subject: [PATCH 07/14] use generic procedure --- src/lpython/semantics/python_ast_to_asr.cpp | 115 +++++++++++--------- 1 file changed, 64 insertions(+), 51 deletions(-) diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index cffc87651d..9d9dac1425 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -140,13 +140,13 @@ ASR::Module_t* load_module(Allocator &al, SymbolTable *symtab, } template -bool argument_types_match(const Vec &args, +bool argument_types_match(const Vec &args, const T &sub) { if (args.size() <= sub.n_args) { size_t i; for (i = 0; i < args.size(); i++) { ASR::Variable_t *v = LFortran::ASRUtils::EXPR2VAR(sub.m_args[i]); - ASR::ttype_t *arg1 = args[i]; + ASR::ttype_t *arg1 = ASRUtils::expr_type(args[i]); ASR::ttype_t *arg2 = v->m_type; if (!ASRUtils::check_equal_type(arg1, arg2)) { return false; @@ -164,7 +164,7 @@ bool argument_types_match(const Vec &args, } } -bool select_func_subrout(const ASR::symbol_t* proc, const Vec &args, +bool select_func_subrout(const ASR::symbol_t* proc, const Vec &args, const Location& loc, const std::function err) { bool result = false; if (ASR::is_a(*proc)) { @@ -185,8 +185,7 @@ bool select_func_subrout(const ASR::symbol_t* proc, const Vec &ar return result; } -std::map> overload_definitons; - +std::map ast_overload; template class CommonVisitor : public AST::BaseVisitor { public: @@ -204,7 +203,7 @@ class CommonVisitor : public AST::BaseVisitor { // The main module is stored directly in TranslationUnit, other modules are Modules bool main_module; PythonIntrinsicProcedures intrinsic_procedures; - std::map> overload_defs; + std::map> overload_defs; CommonVisitor(Allocator &al, SymbolTable *symbol_table, diag::Diagnostics &diagnostics, bool main_module) @@ -571,7 +570,9 @@ class SymbolTableVisitor : public CommonVisitor { for (size_t i=0; i { std::string overload_number; if (overload_defs.find(sym_name) == overload_defs.end()){ overload_number = "0"; + Vec v; + v.reserve(al, 1); + overload_defs[sym_name] = v; } else { overload_number = std::to_string(overload_defs[sym_name].size()); } @@ -700,10 +704,22 @@ class SymbolTableVisitor : public CommonVisitor { s_access, deftype, bindc_name, is_pure, is_module); } - parent_scope->scope[sym_name] = ASR::down_cast(tmp); + ASR::symbol_t * t = ASR::down_cast(tmp); + parent_scope->scope[sym_name] = t; current_scope = parent_scope; if (overload) { - overload_defs[x.m_name].push_back(sym_name); + overload_defs[x.m_name].push_back(al, t); + ast_overload[(int64_t)&x] = t; + } + } + + void create_GenericProcedure(const Location &loc) { + for(auto &p: overload_defs) { + std::string def_name = p.first; + tmp = ASR::make_GenericProcedure_t(al, loc, current_scope, s2c(al, def_name), + p.second.p, p.second.size(), ASR::accessType::Public); + ASR::symbol_t *t = ASR::down_cast(tmp); + current_scope->scope[def_name] = t; } } @@ -801,7 +817,6 @@ Result symbol_table_visitor(Allocator &al, const AST::Module_t &ast SymbolTableVisitor v(al, nullptr, diagnostics, main_module); try { v.visit_Module(ast); - overload_definitons = v.overload_defs; } catch (const SemanticError &e) { Error error; diagnostics.diagnostics.push_back(e.d); @@ -882,44 +897,24 @@ class BodyVisitor : public CommonVisitor { v.n_body = body.size(); } - ASR::symbol_t* overloaddef_find_helper(std::string func_name, Vec args, - const Location &loc) { - for(auto &t: overload_defs[func_name]) { - SymbolTable *symtab = current_scope; - while (symtab!= nullptr && symtab->scope.find(t) == symtab->scope.end()) { - symtab = symtab->parent; - } - LFORTRAN_ASSERT(symtab != nullptr); - ASR::symbol_t *st = symtab->scope[t]; - bool ok = select_func_subrout(st, args, loc, - [&](const std::string &msg, const Location &l) { throw SemanticError(msg, l); }); - if (ok) { - return st; - } - } - return nullptr; - } - void visit_FunctionDef(const AST::FunctionDef_t &x) { SymbolTable *old_scope = current_scope; - ASR::symbol_t *t = nullptr; - if (overload_defs.find(x.m_name) != overload_defs.end()) { - Vec args; - args.reserve(al, x.m_args.n_args); - for (size_t i=0; iscope[x.m_name]; - } + ASR::symbol_t *t = t = current_scope->scope[x.m_name]; if (ASR::is_a(*t)) { handle_fn(x, *ASR::down_cast(t)); } else if (ASR::is_a(*t)) { ASR::Function_t *f = ASR::down_cast(t); handle_fn(x, *f); + } else if (ASR::is_a(*t)) { + ASR::symbol_t *s = ast_overload[(int64_t)&x]; + if (ASR::is_a(*s)) { + handle_fn(x, *ASR::down_cast(s)); + } else if (ASR::is_a(*s)) { + ASR::Function_t *f = ASR::down_cast(s); + handle_fn(x, *f); + } else { + LFORTRAN_ASSERT(false); + } } else { LFORTRAN_ASSERT(false); } @@ -2211,15 +2206,13 @@ class BodyVisitor : public CommonVisitor { x.base.base.loc); } - ASR::symbol_t *s = current_scope->resolve_symbol(call_name); - - if (!s && overload_defs.find(call_name)!=overload_defs.end()) { - Vec args_type; - args_type.reserve(al, x.n_args); - for(size_t i=0; iresolve_symbol(call_name), *s_generic = nullptr; + if (s->type == ASR::symbolType::GenericProcedure){ + ASR::GenericProcedure_t *p = ASR::down_cast(s); + int idx = select_generic_procedure(args, *p, x.base.base.loc); + // Create ExternalSymbol for procedures in different modules. + s_generic = s; + s = p->m_procs[idx]; } if (!s) { @@ -2366,6 +2359,27 @@ class BodyVisitor : public CommonVisitor { x.base.base.loc); } } + int select_generic_procedure(const Vec &args, + const ASR::GenericProcedure_t &p, Location loc) { + for (size_t i=0; i < p.n_procs; i++) { + + if( ASR::is_a(*p.m_procs[i]) ) { + ASR::ClassProcedure_t *clss_fn + = ASR::down_cast(p.m_procs[i]); + const ASR::symbol_t *proc = ASRUtils::symbol_get_past_external(clss_fn->m_proc); + if( select_func_subrout(proc, args, loc, + [&](const std::string &msg, const Location &loc) { throw SemanticError(msg, loc); }) + ){ + return i; + } + } else { + if( select_func_subrout(p.m_procs[i], args, loc, [&](const std::string &msg, const Location &loc) { throw SemanticError(msg, loc); }) ) { + return i; + } + } + } + throw SemanticError("Arguments do not match for any generic procedure", loc); + } void visit_ImportFrom(const AST::ImportFrom_t &/*x*/) { // Handled by SymbolTableVisitor already @@ -2380,7 +2394,6 @@ Result body_visitor(Allocator &al, { BodyVisitor b(al, unit, diagnostics, main_module); try { - b.overload_defs = overload_definitons; b.visit_Module(ast); } catch (const SemanticError &e) { Error error; From bf74e61176dcbbcaefbde360359cd7dcd655dad1 Mon Sep 17 00:00:00 2001 From: Smit-create Date: Wed, 2 Mar 2022 23:41:42 +0530 Subject: [PATCH 08/14] fix failing tests --- src/lpython/semantics/python_ast_to_asr.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index 9d9dac1425..496a58c4a4 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -2207,7 +2207,7 @@ class BodyVisitor : public CommonVisitor { } ASR::symbol_t *s = current_scope->resolve_symbol(call_name), *s_generic = nullptr; - if (s->type == ASR::symbolType::GenericProcedure){ + if (s!=nullptr && s->type == ASR::symbolType::GenericProcedure) { ASR::GenericProcedure_t *p = ASR::down_cast(s); int idx = select_generic_procedure(args, *p, x.base.base.loc); // Create ExternalSymbol for procedures in different modules. @@ -2350,10 +2350,10 @@ class BodyVisitor : public CommonVisitor { value = intrinsic_procedures.comptime_eval(call_name, al, x.base.base.loc, args); } tmp = ASR::make_FunctionCall_t(al, x.base.base.loc, stemp, - nullptr, args.p, args.size(), nullptr, 0, a_type, value, nullptr); + s_generic, args.p, args.size(), nullptr, 0, a_type, value, nullptr); } else if(ASR::is_a(*s)) { tmp = ASR::make_SubroutineCall_t(al, x.base.base.loc, stemp, - nullptr, args.p, args.size(), nullptr); + s_generic, args.p, args.size(), nullptr); } else { throw SemanticError("Unsupported call type for " + call_name, x.base.base.loc); From fadfaa25aba4f84eb11eb144e0af00f6548e1902 Mon Sep 17 00:00:00 2001 From: Smit-create Date: Thu, 3 Mar 2022 18:21:24 +0530 Subject: [PATCH 09/14] move overload_defs to symbolvisitor --- src/lpython/semantics/python_ast_to_asr.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index 496a58c4a4..a81a29cc37 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -203,7 +203,6 @@ class CommonVisitor : public AST::BaseVisitor { // The main module is stored directly in TranslationUnit, other modules are Modules bool main_module; PythonIntrinsicProcedures intrinsic_procedures; - std::map> overload_defs; CommonVisitor(Allocator &al, SymbolTable *symbol_table, diag::Diagnostics &diagnostics, bool main_module) @@ -517,6 +516,7 @@ class SymbolTableVisitor : public CommonVisitor { std::map assgn; ASR::symbol_t *current_module_sym; std::vector excluded_from_symtab; + std::map> overload_defs; SymbolTableVisitor(Allocator &al, SymbolTable *symbol_table, From b6f5d2274dd83cf4a9d1c2b6e1ca65b2b355dd2a Mon Sep 17 00:00:00 2001 From: Smit-create Date: Fri, 4 Mar 2022 00:26:13 +0530 Subject: [PATCH 10/14] cleaups: use function from ASRUtils namespace --- src/lpython/semantics/python_ast_to_asr.cpp | 71 +-------------------- 1 file changed, 3 insertions(+), 68 deletions(-) diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index a81a29cc37..3666dbf677 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -139,53 +139,8 @@ ASR::Module_t* load_module(Allocator &al, SymbolTable *symtab, return mod2; } -template -bool argument_types_match(const Vec &args, - const T &sub) { - if (args.size() <= sub.n_args) { - size_t i; - for (i = 0; i < args.size(); i++) { - ASR::Variable_t *v = LFortran::ASRUtils::EXPR2VAR(sub.m_args[i]); - ASR::ttype_t *arg1 = ASRUtils::expr_type(args[i]); - ASR::ttype_t *arg2 = v->m_type; - if (!ASRUtils::check_equal_type(arg1, arg2)) { - return false; - } - } - for( ; i < sub.n_args; i++ ) { - ASR::Variable_t *v = LFortran::ASRUtils::EXPR2VAR(sub.m_args[i]); - if( v->m_presence != ASR::presenceType::Optional ) { - return false; - } - } - return true; - } else { - return false; - } -} - -bool select_func_subrout(const ASR::symbol_t* proc, const Vec &args, - const Location& loc, const std::function err) { - bool result = false; - if (ASR::is_a(*proc)) { - ASR::Subroutine_t *sub - = ASR::down_cast(proc); - if (argument_types_match(args, *sub)) { - result = true; - } - } else if (ASR::is_a(*proc)) { - ASR::Function_t *fn - = ASR::down_cast(proc); - if (argument_types_match(args, *fn)) { - result = true; - } - } else { - err("Only Subroutine and Function supported in generic procedure", loc); - } - return result; -} - std::map ast_overload; + template class CommonVisitor : public AST::BaseVisitor { public: @@ -2209,7 +2164,8 @@ class BodyVisitor : public CommonVisitor { ASR::symbol_t *s = current_scope->resolve_symbol(call_name), *s_generic = nullptr; if (s!=nullptr && s->type == ASR::symbolType::GenericProcedure) { ASR::GenericProcedure_t *p = ASR::down_cast(s); - int idx = select_generic_procedure(args, *p, x.base.base.loc); + int idx = ASRUtils::select_generic_procedure(args, *p, x.base.base.loc, + [&](const std::string &msg, const Location &loc) { throw SemanticError(msg, loc); }); // Create ExternalSymbol for procedures in different modules. s_generic = s; s = p->m_procs[idx]; @@ -2359,27 +2315,6 @@ class BodyVisitor : public CommonVisitor { x.base.base.loc); } } - int select_generic_procedure(const Vec &args, - const ASR::GenericProcedure_t &p, Location loc) { - for (size_t i=0; i < p.n_procs; i++) { - - if( ASR::is_a(*p.m_procs[i]) ) { - ASR::ClassProcedure_t *clss_fn - = ASR::down_cast(p.m_procs[i]); - const ASR::symbol_t *proc = ASRUtils::symbol_get_past_external(clss_fn->m_proc); - if( select_func_subrout(proc, args, loc, - [&](const std::string &msg, const Location &loc) { throw SemanticError(msg, loc); }) - ){ - return i; - } - } else { - if( select_func_subrout(p.m_procs[i], args, loc, [&](const std::string &msg, const Location &loc) { throw SemanticError(msg, loc); }) ) { - return i; - } - } - } - throw SemanticError("Arguments do not match for any generic procedure", loc); - } void visit_ImportFrom(const AST::ImportFrom_t &/*x*/) { // Handled by SymbolTableVisitor already From 823cada7d88535e3ee37bb4e40be8f96d2c487ec Mon Sep 17 00:00:00 2001 From: Smit-create Date: Fri, 4 Mar 2022 00:29:08 +0530 Subject: [PATCH 11/14] fix typo --- integration_tests/run_tests.py | 2 +- src/lpython/semantics/python_ast_to_asr.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/integration_tests/run_tests.py b/integration_tests/run_tests.py index 0fe78a7cd6..1bea39b346 100755 --- a/integration_tests/run_tests.py +++ b/integration_tests/run_tests.py @@ -39,7 +39,7 @@ "test_builtin_oct.py" ] -CUR_DIR = ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__))) +CUR_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__))) def main(): if not os.path.exists(os.path.join(CUR_DIR, 'tmp')): diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index 3666dbf677..a4678cd72e 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -854,7 +854,7 @@ class BodyVisitor : public CommonVisitor { void visit_FunctionDef(const AST::FunctionDef_t &x) { SymbolTable *old_scope = current_scope; - ASR::symbol_t *t = t = current_scope->scope[x.m_name]; + ASR::symbol_t *t = current_scope->scope[x.m_name]; if (ASR::is_a(*t)) { handle_fn(x, *ASR::down_cast(t)); } else if (ASR::is_a(*t)) { From 223a64d54613f191f142ebfbd36c757e3a681c44 Mon Sep 17 00:00:00 2001 From: Smit-create Date: Sat, 5 Mar 2022 01:11:41 +0530 Subject: [PATCH 12/14] clear ast_overload before visiting symbol table --- src/lpython/semantics/python_ast_to_asr.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index a4678cd72e..1e754b3d3b 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -769,6 +769,7 @@ class SymbolTableVisitor : public CommonVisitor { Result symbol_table_visitor(Allocator &al, const AST::Module_t &ast, diag::Diagnostics &diagnostics, bool main_module) { + ast_overload.clear(); SymbolTableVisitor v(al, nullptr, diagnostics, main_module); try { v.visit_Module(ast); From 2781c6ccafd02bfa83469adc8435f94c6022d343 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ond=C5=99ej=20=C4=8Cert=C3=ADk?= Date: Wed, 9 Mar 2022 10:36:24 -0700 Subject: [PATCH 13/14] Enable a string test for overload --- integration_tests/test_generics_01.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/integration_tests/test_generics_01.py b/integration_tests/test_generics_01.py index 1f39689f3a..442c63e24a 100644 --- a/integration_tests/test_generics_01.py +++ b/integration_tests/test_generics_01.py @@ -26,8 +26,7 @@ def test(a: bool) -> i32: def check(): assert foo(2) == 4 assert foo(2, 10) == 20 - # Following assert has LLVM string len issue: gh-175 - # assert foo("hello") == "lpython-hello" + assert foo("hello") == "lpython-hello" assert test(10) == 20 assert test(False) == -test(True) and test(True) == 10 From ec200f50edc833882923ec5598116a3a81bb48db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ond=C5=99ej=20=C4=8Cert=C3=ADk?= Date: Wed, 9 Mar 2022 10:42:36 -0700 Subject: [PATCH 14/14] Pass the ast_overload as arguments to visitors --- src/lpython/semantics/python_ast_to_asr.cpp | 37 +++++++++++++-------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index 1e754b3d3b..a3a3ff15b8 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -139,7 +139,6 @@ ASR::Module_t* load_module(Allocator &al, SymbolTable *symtab, return mod2; } -std::map ast_overload; template class CommonVisitor : public AST::BaseVisitor { @@ -158,10 +157,13 @@ class CommonVisitor : public AST::BaseVisitor { // The main module is stored directly in TranslationUnit, other modules are Modules bool main_module; PythonIntrinsicProcedures intrinsic_procedures; + std::map &ast_overload; CommonVisitor(Allocator &al, SymbolTable *symbol_table, - diag::Diagnostics &diagnostics, bool main_module) - : diag{diagnostics}, al{al}, current_scope{symbol_table}, main_module{main_module} { + diag::Diagnostics &diagnostics, bool main_module, + std::map &ast_overload) + : diag{diagnostics}, al{al}, current_scope{symbol_table}, main_module{main_module}, + ast_overload{ast_overload} { current_module_dependencies.reserve(al, 4); } @@ -475,8 +477,9 @@ class SymbolTableVisitor : public CommonVisitor { SymbolTableVisitor(Allocator &al, SymbolTable *symbol_table, - diag::Diagnostics &diagnostics, bool main_module) - : CommonVisitor(al, symbol_table, diagnostics, main_module), is_derived_type{false} {} + diag::Diagnostics &diagnostics, bool main_module, + std::map &ast_overload) + : CommonVisitor(al, symbol_table, diagnostics, main_module, ast_overload), is_derived_type{false} {} ASR::symbol_t* resolve_symbol(const Location &loc, const std::string &sub_name) { @@ -767,10 +770,10 @@ class SymbolTableVisitor : public CommonVisitor { }; Result symbol_table_visitor(Allocator &al, const AST::Module_t &ast, - diag::Diagnostics &diagnostics, bool main_module) + diag::Diagnostics &diagnostics, bool main_module, + std::map &ast_overload) { - ast_overload.clear(); - SymbolTableVisitor v(al, nullptr, diagnostics, main_module); + SymbolTableVisitor v(al, nullptr, diagnostics, main_module, ast_overload); try { v.visit_Module(ast); } catch (const SemanticError &e) { @@ -792,8 +795,9 @@ class BodyVisitor : public CommonVisitor { ASR::asr_t *asr; Vec *current_body; - BodyVisitor(Allocator &al, ASR::asr_t *unit, diag::Diagnostics &diagnostics, bool main_module) - : CommonVisitor(al, nullptr, diagnostics, main_module), asr{unit} {} + BodyVisitor(Allocator &al, ASR::asr_t *unit, diag::Diagnostics &diagnostics, + bool main_module, std::map &ast_overload) + : CommonVisitor(al, nullptr, diagnostics, main_module, ast_overload), asr{unit} {} // Transforms statements to a list of ASR statements // In addition, it also inserts the following nodes if needed: @@ -2326,9 +2330,10 @@ class BodyVisitor : public CommonVisitor { Result body_visitor(Allocator &al, const AST::Module_t &ast, diag::Diagnostics &diagnostics, - ASR::asr_t *unit, bool main_module) + ASR::asr_t *unit, bool main_module, + std::map &ast_overload) { - BodyVisitor b(al, unit, diagnostics, main_module); + BodyVisitor b(al, unit, diagnostics, main_module, ast_overload); try { b.visit_Module(ast); } catch (const SemanticError &e) { @@ -2362,10 +2367,13 @@ std::string pickle_python(AST::ast_t &ast, bool colors, bool indent) { Result python_ast_to_asr(Allocator &al, AST::ast_t &ast, diag::Diagnostics &diagnostics, bool main_module) { + std::map ast_overload; + AST::Module_t *ast_m = AST::down_cast2(&ast); ASR::asr_t *unit; - auto res = symbol_table_visitor(al, *ast_m, diagnostics, main_module); + auto res = symbol_table_visitor(al, *ast_m, diagnostics, main_module, + ast_overload); if (res.ok) { unit = res.result; } else { @@ -2374,7 +2382,8 @@ Result python_ast_to_asr(Allocator &al, ASR::TranslationUnit_t *tu = ASR::down_cast2(unit); LFORTRAN_ASSERT(asr_verify(*tu)); - auto res2 = body_visitor(al, *ast_m, diagnostics, unit, main_module); + auto res2 = body_visitor(al, *ast_m, diagnostics, unit, main_module, + ast_overload); if (res2.ok) { tu = res2.result; } else {