diff --git a/src/libasr/ASR.asdl b/src/libasr/ASR.asdl index 99e7bb8c03..2f739b3c0c 100644 --- a/src/libasr/ASR.asdl +++ b/src/libasr/ASR.asdl @@ -107,8 +107,8 @@ symbol intent intent, expr? symbolic_value, expr? value, storage_type storage, ttype type, abi abi, access access, presence presence, bool value_attr) | ClassType(symbol_table symtab, identifier name, abi abi, access access) - | ClassProcedure(symbol_table parent_symtab, identifier name, identifier proc_name, - symbol proc, abi abi) + | ClassProcedure(symbol_table parent_symtab, identifier name, identifier? self_argument, + identifier proc_name, symbol proc, abi abi) | AssociateBlock(symbol_table symtab, identifier name, stmt* body) | Block(symbol_table symtab, identifier name, stmt* body) @@ -185,6 +185,7 @@ stmt | FileOpen(int label, expr? newunit, expr? filename, expr? status) | FileClose(int label, expr? unit, expr? iostat, expr? iomsg, expr? err, expr? status) | FileRead(int label, expr? unit, expr? fmt, expr? iomsg, expr? iostat, expr? id, expr* values) + | FileBackspace(int label, expr? unit, expr? iostat, expr? err) | FileRewind(int label, expr? unit, expr? iostat, expr? err) | FileInquire(int label, expr? unit, expr? file, expr? iostat, expr? err, expr? exist, expr? opened, expr? number, expr? named, @@ -206,6 +207,7 @@ stmt | Flush(int label, expr unit, expr? err, expr? iomsg, expr? iostat) | ListAppend(expr a, expr ele) | AssociateBlockCall(symbol m) + | SelectType(type_stmt* body, stmt* default) | CPtrToPointer(expr cptr, expr ptr, expr? shape) | BlockCall(int label, symbol m) | SetInsert(expr a, expr ele) @@ -285,6 +287,7 @@ expr | ArrayMatMul(expr matrix_a, expr matrix_b, ttype type, expr? value) | ArrayPack(expr array, expr mask, expr? vector, ttype type, expr? value) | ArrayReshape(expr array, expr shape, ttype type, expr? value) + | ArrayMaxloc(expr array, expr? dim, expr? mask, expr? kind, expr? back, ttype type, expr? value) | BitCast(expr source, expr mold, expr? size, ttype type, expr? value) | StructInstanceMember(expr v, symbol m, ttype type, expr? value) @@ -310,9 +313,13 @@ expr | SetPop(expr a, ttype type, expr? value) | IntegerBitLen(expr a, ttype type, expr? value) | Ichar(expr arg, ttype type, expr? value) + | Iachar(expr arg, ttype type, expr? value) | SizeOfType(ttype arg, ttype type, expr? value) + | PointerNullConstant(ttype type) + | PointerAssociated(expr ptr, expr? tgt, ttype type, expr? value) + -- `len` in Character: -- >=0 ... the length of the string, known at compile time @@ -406,6 +413,8 @@ do_loop_head = (expr? v, expr? start, expr? end, expr? increment) case_stmt = CaseStmt(expr* test, stmt* body) | CaseStmt_Range(expr? start, expr? end, stmt* body) +type_stmt = TypeStmt(symbol sym, stmt* body) + enumtype = IntegerConsecutiveFromZero | IntegerUnique | IntegerNotUnique | NonInteger } diff --git a/src/libasr/asdl_cpp.py b/src/libasr/asdl_cpp.py index a88639b7cf..a6e13e333c 100644 --- a/src/libasr/asdl_cpp.py +++ b/src/libasr/asdl_cpp.py @@ -1196,7 +1196,7 @@ def visitModule(self, mod): self.emit("private:") self.emit( "Struct& self() { return static_cast(*this); }", 1) self.emit("public:") - self.emit( "std::string s, indtd = \"\";", 1) + self.emit( "std::string s, indented = \"\";", 1) self.emit( "bool use_colors;", 1) self.emit( "bool indent;", 1) self.emit( "int indent_level = 0, indent_spaces = 4;", 1) @@ -1204,12 +1204,12 @@ def visitModule(self, mod): self.emit( "PickleBaseVisitor() : use_colors(false), indent(false) { s.reserve(100000); }", 1) self.emit( "void inc_indent() {", 1) self.emit( "indent_level++;", 2) - self.emit( "indtd = std::string(indent_level*indent_spaces, ' ');",2) + self.emit( "indented = std::string(indent_level*indent_spaces, ' ');",2) self.emit( "}",1) self.emit( "void dec_indent() {", 1) self.emit( "indent_level--;", 2) self.emit( "LFORTRAN_ASSERT(indent_level >= 0);", 2) - self.emit( "indtd = std::string(indent_level*indent_spaces, ' ');",2) + self.emit( "indented = std::string(indent_level*indent_spaces, ' ');",2) self.emit( "}",1) self.mod = mod super(PickleVisitorVisitor, self).visitModule(mod) @@ -1264,23 +1264,28 @@ def make_visitor(self, name, fields, cons): self.emit( 's.append(color(style::reset));', 3) self.emit( '}', 2) if len(fields) > 0: - self.emit( 's.append(" ");', 2) if name not in symbol: self.emit( 'if(indent) {', 2) self.emit( 'inc_indent();', 3) - self.emit( 's.append("\\n" + indtd);', 3) + self.emit( 's.append("\\n" + indented);', 3) + self.emit( '} else {', 2) + self.emit( 's.append(" ");', 3) self.emit( '}', 2) + else: + self.emit( 's.append(" ");', 2) self.used = False for n, field in enumerate(fields): self.visitField(field, cons) if n < len(fields) - 1: - self.emit( 's.append(" ");', 2) if name not in symbol: - self.emit( 'if(indent) s.append("\\n" + indtd);', 2) + self.emit( 'if(indent) s.append("\\n" + indented);', 2) + self.emit( 'else s.append(" ");', 2) + else: + self.emit( 's.append(" ");', 2) if name not in symbol and cons and len(fields) > 0: self.emit( 'if(indent) {', 2) self.emit( 'dec_indent();', 3) - self.emit( 's.append("\\n" + indtd);', 3) + self.emit( 's.append("\\n" + indented);', 3) self.emit( '}', 2) self.emit( 's.append(")");', 2) if not self.used: @@ -1326,10 +1331,10 @@ def visitField(self, field, cons): self.emit("self().visit_%s(*x.m_%s[i]);" % (field.type, field.name), level+1) else: self.emit("self().visit_%s(x.m_%s[i]);" % (field.type, field.name), level+1) - self.emit(' if (i < x.n_%s-1) {' % (field.name), level+1) - self.emit(' if (indent) s.append("\\n" + indtd);', level+2) - self.emit(' else s.append(" ");', level+2) - self.emit(' };', level+1) + self.emit(' if (i < x.n_%s-1) {' % (field.name), level) + self.emit(' if (indent) s.append("\\n" + indented);', level) + self.emit(' else s.append(" ");', level) + self.emit(' };', level) self.emit("}", level) self.emit('s.append("]");', level) elif field.opt: @@ -1347,11 +1352,11 @@ def visitField(self, field, cons): level = 2 self.emit('s.append("[");', level) self.emit("for (size_t i=0; iget_counter());' % field.name, level) - self.emit('if(indent) s.append("\\n" + indtd);', level) + self.emit('if(indent) s.append("\\n" + indented);', level) self.emit('else s.append(" ");', level) self.emit( 's.append("{");', level) self.emit('if(indent) {', level) - self.emit( 'inc_indent();', level+1) - self.emit( 's.append("\\n" + indtd);', level+1) + self.emit(' inc_indent();', level) + self.emit(' s.append("\\n" + indented);', level) self.emit('}', level) self.emit('{', level) self.emit(' size_t i = 0;', level) self.emit(' for (auto &a : x.m_%s->get_scope()) {' % field.name, level) self.emit(' s.append(a.first + ":");', level) self.emit(' if(indent) {', level) - self.emit(' inc_indent();', level+1) - self.emit(' s.append("\\n" + indtd);', level+1) + self.emit(' inc_indent();', level) + self.emit(' s.append("\\n" + indented);', level) + self.emit(' } else {', level) + self.emit(' s.append(" ");', level) self.emit(' }', level) - self.emit(' else s.append(" ");', level) self.emit(' this->visit_symbol(*a.second);', level) - self.emit(' if (i < x.m_%s->get_scope().size()-1) { ' % field.name, level) - self.emit(' s.append(", ");', level) - self.emit(' }', level) - self.emit(' if(indent) {', level) - self.emit(' dec_indent();', level+1) - self.emit(' s.append("\\n" + indtd);', level+1) + self.emit(' if(indent) dec_indent();', level) + self.emit(' if (i < x.m_%s->get_scope().size()-1) {' % field.name, level) + self.emit(' s.append(",");', level) + self.emit(' if(indent) s.append("\\n" + indented);', level) + self.emit(' else s.append(" ");', level) self.emit(' }', level) self.emit(' i++;', level) self.emit(' }', level) self.emit('}', level) self.emit('if(indent) {', level) self.emit( 'dec_indent();', level+1) - self.emit( 's.append("\\n" + indtd);', level+1) + self.emit( 's.append("\\n" + indented);', level+1) self.emit('}', level) self.emit('s.append("})");', level) self.emit('if(indent) dec_indent();', level) @@ -1467,6 +1473,274 @@ def visitField(self, field, cons): else: self.emit('s.append("Unimplemented' + field.type + '");', 2) +class JsonVisitorVisitor(ASDLVisitor): + + def visitModule(self, mod): + self.emit("/" + "*"*78 + "/") + self.emit("// Json Visitor base class") + self.emit("") + self.emit("template ") + self.emit("class JsonBaseVisitor : public BaseVisitor") + self.emit("{") + self.emit("private:") + self.emit( "Struct& self() { return static_cast(*this); }", 1) + self.emit("public:") + self.emit( "std::string s, indtd = \"\";", 1) + self.emit( "int indent_level = 0, indent_spaces = 4;", 1) + # Storing a reference to LocationManager like this isn't ideal. + # One must make sure JsonBaseVisitor isn't reused in a case where AST/ASR has changed + # but lm wasn't updated correspondingly. + # If LocationManager becomes needed in any of the other visitors, it should be + # passed by reference into all the visit functions instead of storing the reference here. + self.emit( "LocationManager &lm;", 1) + self.emit("public:") + self.emit( "JsonBaseVisitor(LocationManager &lmref) : lm(lmref) {", 1); + self.emit( "s.reserve(100000);", 2) + self.emit( "}", 1) + self.emit( "void inc_indent() {", 1) + self.emit( "indent_level++;", 2) + self.emit( "indtd = std::string(indent_level*indent_spaces, ' ');",2) + self.emit( "}",1) + self.emit( "void dec_indent() {", 1) + self.emit( "indent_level--;", 2) + self.emit( "LFORTRAN_ASSERT(indent_level >= 0);", 2) + self.emit( "indtd = std::string(indent_level*indent_spaces, ' ');",2) + self.emit( "}",1) + self.emit( "void append_location(std::string &s, uint32_t first, uint32_t last) {", 1) + self.emit( 's.append("\\"loc\\": {");', 2); + self.emit( 'inc_indent();', 2) + self.emit( 's.append("\\n" + indtd);', 2) + self.emit( 's.append("\\"first\\": " + std::to_string(first));', 2) + self.emit( 's.append(",\\n" + indtd);', 2) + self.emit( 's.append("\\"last\\": " + std::to_string(last));', 2) + self.emit( '') + self.emit( 'uint32_t first_line = 0, first_col = 0;', 2) + self.emit( 'std::string first_filename;', 2) + self.emit( 'uint32_t last_line = 0, last_col = 0;', 2) + self.emit( 'std::string last_filename;', 2) + self.emit( '') + self.emit( 'lm.pos_to_linecol(first, first_line, first_col, first_filename);', 2) + self.emit( 'lm.pos_to_linecol(last, last_line, last_col, last_filename);', 2) + self.emit( '') + self.emit( 's.append(",\\n" + indtd);', 2) + self.emit( 's.append("\\"first_filename\\": \\"" + first_filename + "\\"");', 2) + self.emit( 's.append(",\\n" + indtd);', 2) + self.emit( 's.append("\\"first_line\\": " + std::to_string(first_line));', 2) + self.emit( 's.append(",\\n" + indtd);', 2) + self.emit( 's.append("\\"first_column\\": " + std::to_string(first_col));', 2) + self.emit( 's.append(",\\n" + indtd);', 2) + self.emit( 's.append("\\"last_filename\\": \\"" + last_filename + "\\"");', 2) + self.emit( 's.append(",\\n" + indtd);', 2) + self.emit( 's.append("\\"last_line\\": " + std::to_string(last_line));', 2) + self.emit( 's.append(",\\n" + indtd);', 2) + self.emit( 's.append("\\"last_column\\": " + std::to_string(last_col));', 2) + self.emit( '') + self.emit( 'dec_indent();', 2) + self.emit( 's.append("\\n" + indtd);', 2) + self.emit( 's.append("}");', 2) + self.emit( '}', 1) + + self.mod = mod + super(JsonVisitorVisitor, self).visitModule(mod) + self.emit("};") + + def visitType(self, tp): + super(JsonVisitorVisitor, self).visitType(tp, tp.name) + + def visitSum(self, sum, *args): + assert isinstance(sum, asdl.Sum) + if is_simple_sum(sum): + name = args[0] + "Type" + self.make_simple_sum_visitor(name, sum.types) + else: + for tp in sum.types: + self.visit(tp, *args) + + def visitProduct(self, prod, name): + self.make_visitor(name, prod.fields, False) + + def visitConstructor(self, cons, _): + self.make_visitor(cons.name, cons.fields, True) + + def make_visitor(self, name, fields, cons): + self.emit("void visit_%s(const %s_t &x) {" % (name, name), 1) + self.emit( 's.append("{");', 2) + self.emit( 'inc_indent(); s.append("\\n" + indtd);', 2) + self.emit( 's.append("\\"node\\": \\"%s\\"");' % name, 2) + self.emit( 's.append(",\\n" + indtd);', 2) + self.emit( 's.append("\\"fields\\": {");', 2) + if len(fields) > 0: + self.emit('inc_indent(); s.append("\\n" + indtd);', 2) + for n, field in enumerate(fields): + self.visitField(field, cons) + if n < len(fields) - 1: + self.emit('s.append(",\\n" + indtd);', 2) + self.emit('dec_indent(); s.append("\\n" + indtd);', 2) + self.emit( 's.append("}");', 2) + self.emit( 's.append(",\\n" + indtd);', 2) + if name in products: + self.emit( 'append_location(s, x.loc.first, x.loc.last);', 2) + else: + self.emit( 'append_location(s, x.base.base.loc.first, x.base.base.loc.last);', 2) + + self.emit( 'dec_indent(); s.append("\\n" + indtd);', 2) + self.emit( 's.append("}");', 2) + self.emit( 'if ((bool&)x) { } // Suppress unused warning', 2) + self.emit("}", 1) + + def make_simple_sum_visitor(self, name, types): + self.emit("void visit_%s(const %s &x) {" % (name, name), 1) + self.emit( 'switch (x) {', 2) + for tp in types: + self.emit( 'case (%s::%s) : {' % (name, tp.name), 3) + self.emit( 's.append("\\"%s\\"");' % (tp.name), 4) + self.emit( ' break; }',3) + self.emit( '}', 2) + self.emit("}", 1) + + def visitField(self, field, cons): + self.emit('s.append("\\"%s\\": ");' % field.name, 2) + if (field.type not in asdl.builtin_types and + field.type not in self.data.simple_types): + self.used = True + level = 2 + if field.type in products: + if field.opt: + template = "self().visit_%s(*x.m_%s);" % (field.type, field.name) + else: + template = "self().visit_%s(x.m_%s);" % (field.type, field.name) + else: + template = "self().visit_%s(*x.m_%s);" % (field.type, field.name) + if field.seq: + self.emit('s.append("[");', level) + self.emit('if (x.n_%s > 0) {' % field.name, level) + self.emit( 'inc_indent(); s.append("\\n" + indtd);', level+1) + self.emit( "for (size_t i=0; i 0) {' % field.name, level) + self.emit( 'inc_indent(); s.append("\\n" + indtd);', level+1) + self.emit( "for (size_t i=0; i 0) {' % field.name, level) + self.emit( 'inc_indent(); s.append("\\n" + indtd);', level+1) + self.emit( "for (size_t i=0; iget_counter());' % field.name, level) + else: + level = 2 + self.emit('s.append("{");', level) + self.emit('inc_indent(); s.append("\\n" + indtd);', level) + self.emit('s.append("\\"node\\": \\"SymbolTable" + x.m_%s->get_counter() +"\\"");' % field.name, level) + self.emit('s.append(",\\n" + indtd);', level) + self.emit('s.append("\\"fields\\": {");', level) + self.emit('if (x.m_%s->get_scope().size() > 0) {' % field.name, level) + self.emit( 'inc_indent(); s.append("\\n" + indtd);', level+1) + self.emit( 'size_t i = 0;', level+1) + self.emit( 'for (auto &a : x.m_%s->get_scope()) {' % field.name, level+1) + self.emit( 's.append("\\"" + a.first + "\\": ");', level+2) + self.emit( 'this->visit_symbol(*a.second);', level+2) + self.emit( 'if (i < x.m_%s->get_scope().size()-1) { ' % field.name, level+2) + self.emit( ' s.append(",\\n" + indtd);', level+3) + self.emit( '}', level+2) + self.emit( 'i++;', level+2) + self.emit( '}', level+1) + self.emit( 'dec_indent(); s.append("\\n" + indtd);', level+1) + self.emit('}', level) + self.emit('s.append("}");', level) + self.emit('dec_indent(); s.append("\\n" + indtd);', level) + self.emit('s.append("}");', level) + elif field.type == "string" and not field.seq: + if field.opt: + self.emit("if (x.m_%s) {" % field.name, 2) + self.emit( 's.append("\\"" + std::string(x.m_%s) + "\\"");' % field.name, 3) + self.emit("} else {", 2) + self.emit( 's.append("[]");', 3) + self.emit("}", 2) + else: + self.emit('s.append("\\"" + std::string(x.m_%s) + "\\"");' % field.name, 2) + elif field.type == "int" and not field.seq: + if field.opt: + self.emit("if (x.m_%s) {" % field.name, 2) + self.emit( 's.append(std::to_string(x.m_%s));' % field.name, 3) + self.emit("} else {", 2) + self.emit( 's.append("[]");', 3) + self.emit("}", 2) + else: + self.emit('s.append(std::to_string(x.m_%s));' % field.name, 2) + elif field.type == "float" and not field.seq and not field.opt: + self.emit('s.append(std::to_string(x.m_%s));' % field.name, 2) + elif field.type == "bool" and not field.seq and not field.opt: + self.emit("if (x.m_%s) {" % field.name, 2) + self.emit( 's.append("true");', 3) + self.emit("} else {", 2) + self.emit( 's.append("false");', 3) + self.emit("}", 2) + elif field.type in self.data.simple_types: + if field.opt: + self.emit('s.append("\\"Unimplementedopt\\"");', 2) + else: + self.emit('visit_%sType(x.m_%s);' \ + % (field.type, field.name), 2) + else: + self.emit('s.append("\\"Unimplemented%s\\"");' % field.type, 2) class SerializationVisitorVisitor(ASDLVisitor): @@ -2202,7 +2476,7 @@ def add_masks(fields, node): visitors = [ASTNodeVisitor0, ASTNodeVisitor1, ASTNodeVisitor, ASTVisitorVisitor1, ASTVisitorVisitor1b, ASTVisitorVisitor2, ASTWalkVisitorVisitor, TreeVisitorVisitor, PickleVisitorVisitor, - SerializationVisitorVisitor, DeserializationVisitorVisitor] + JsonVisitorVisitor, SerializationVisitorVisitor, DeserializationVisitorVisitor] def main(argv): diff --git a/src/libasr/asr_utils.cpp b/src/libasr/asr_utils.cpp index 46f44f2e5c..7628e5ad71 100644 --- a/src/libasr/asr_utils.cpp +++ b/src/libasr/asr_utils.cpp @@ -386,6 +386,8 @@ bool use_overloaded(ASR::expr_t* left, ASR::expr_t* right, ASR::binopType op, std::string& intrinsic_op_name, SymbolTable* curr_scope, ASR::asr_t*& asr, Allocator &al, const Location& loc, + std::set& current_function_dependencies, + Vec& current_module_dependencies, const std::function err) { ASR::ttype_t *left_type = LFortran::ASRUtils::expr_type(left); ASR::ttype_t *right_type = LFortran::ASRUtils::expr_type(right); @@ -430,6 +432,10 @@ bool use_overloaded(ASR::expr_t* left, ASR::expr_t* right, } else { return_type = ASRUtils::expr_type(func->m_return_var); } + current_function_dependencies.insert(matched_func_name); + if( ASR::is_a(*a_name) ) { + current_module_dependencies.push_back(al, ASR::down_cast(a_name)->m_module_name); + } asr = ASR::make_FunctionCall_t(al, loc, a_name, sym, a_args.p, 2, return_type, @@ -492,167 +498,6 @@ bool is_op_overloaded(ASR::binopType op, std::string& intrinsic_op_name, return result; } -bool use_overloaded_assignment(ASR::expr_t* target, ASR::expr_t* value, - SymbolTable* curr_scope, ASR::asr_t*& asr, - Allocator &al, const Location& loc, - const std::function err) { - ASR::ttype_t *target_type = LFortran::ASRUtils::expr_type(target); - ASR::ttype_t *value_type = LFortran::ASRUtils::expr_type(value); - bool found = false; - ASR::symbol_t* sym = curr_scope->resolve_symbol("~assign"); - if (sym) { - ASR::symbol_t* orig_sym = ASRUtils::symbol_get_past_external(sym); - ASR::CustomOperator_t* gen_proc = ASR::down_cast(orig_sym); - for( size_t i = 0; i < gen_proc->n_procs && !found; i++ ) { - ASR::symbol_t* proc = gen_proc->m_procs[i]; - ASR::Function_t* subrout = ASR::down_cast(proc); - std::string matched_subrout_name = ""; - if( subrout->n_args == 2 ) { - ASR::ttype_t* target_arg_type = ASRUtils::expr_type(subrout->m_args[0]); - ASR::ttype_t* value_arg_type = ASRUtils::expr_type(subrout->m_args[1]); - if( target_arg_type->type == target_type->type && - value_arg_type->type == value_type->type ) { - found = true; - Vec a_args; - a_args.reserve(al, 2); - ASR::call_arg_t target_arg, value_arg; - target_arg.loc = target->base.loc, target_arg.m_value = target; - a_args.push_back(al, target_arg); - value_arg.loc = value->base.loc, value_arg.m_value = value; - a_args.push_back(al, value_arg); - std::string subrout_name = to_lower(subrout->m_name); - if( curr_scope->resolve_symbol(subrout_name) ) { - matched_subrout_name = subrout_name; - } else { - std::string mangled_name = subrout_name + "@~assign"; - matched_subrout_name = mangled_name; - } - ASR::symbol_t *a_name = curr_scope->resolve_symbol(matched_subrout_name); - if( a_name == nullptr ) { - err("Unable to resolve matched subroutine for assignment overloading, " + matched_subrout_name, loc); - } - asr = ASR::make_SubroutineCall_t(al, loc, a_name, sym, - a_args.p, 2, nullptr); - } - } - } - } - return found; -} - -bool use_overloaded(ASR::expr_t* left, ASR::expr_t* right, - ASR::cmpopType op, std::string& intrinsic_op_name, - SymbolTable* curr_scope, ASR::asr_t*& asr, - Allocator &al, const Location& loc, - const std::function err) { - ASR::ttype_t *left_type = LFortran::ASRUtils::expr_type(left); - ASR::ttype_t *right_type = LFortran::ASRUtils::expr_type(right); - bool found = false; - if( is_op_overloaded(op, intrinsic_op_name, curr_scope) ) { - ASR::symbol_t* sym = curr_scope->resolve_symbol(intrinsic_op_name); - ASR::symbol_t* orig_sym = ASRUtils::symbol_get_past_external(sym); - ASR::CustomOperator_t* gen_proc = ASR::down_cast(orig_sym); - for( size_t i = 0; i < gen_proc->n_procs && !found; i++ ) { - ASR::symbol_t* proc = gen_proc->m_procs[i]; - switch(proc->type) { - case ASR::symbolType::Function: { - ASR::Function_t* func = ASR::down_cast(proc); - std::string matched_func_name = ""; - if( func->n_args == 2 ) { - ASR::ttype_t* left_arg_type = ASRUtils::expr_type(func->m_args[0]); - ASR::ttype_t* right_arg_type = ASRUtils::expr_type(func->m_args[1]); - if( left_arg_type->type == left_type->type && - right_arg_type->type == right_type->type ) { - found = true; - Vec a_args; - a_args.reserve(al, 2); - ASR::call_arg_t left_call_arg, right_call_arg; - left_call_arg.loc = left->base.loc, left_call_arg.m_value = left; - a_args.push_back(al, left_call_arg); - right_call_arg.loc = right->base.loc, right_call_arg.m_value = right; - a_args.push_back(al, right_call_arg); - std::string func_name = to_lower(func->m_name); - if( curr_scope->resolve_symbol(func_name) ) { - matched_func_name = func_name; - } else { - std::string mangled_name = func_name + "@" + intrinsic_op_name; - matched_func_name = mangled_name; - } - ASR::symbol_t* a_name = curr_scope->resolve_symbol(matched_func_name); - if( a_name == nullptr ) { - err("Unable to resolve matched function for operator overloading, " + matched_func_name, loc); - } - ASR::ttype_t *return_type = nullptr; - if( func->m_elemental && func->n_args == 1 && ASRUtils::is_array(ASRUtils::expr_type(a_args[0].m_value)) ) { - return_type = ASRUtils::duplicate_type(al, ASRUtils::expr_type(a_args[0].m_value)); - } else { - return_type = ASRUtils::expr_type(func->m_return_var); - } - asr = ASR::make_FunctionCall_t(al, loc, a_name, sym, - a_args.p, 2, - return_type, - nullptr, nullptr); - } - } - break; - } - default: { - err("While overloading binary operators only functions can be used", - proc->base.loc); - } - } - } - } - return found; -} - -bool is_op_overloaded(ASR::cmpopType op, std::string& intrinsic_op_name, - SymbolTable* curr_scope) { - bool result = true; - switch(op) { - case ASR::cmpopType::Eq: { - if(intrinsic_op_name != "~eq") { - result = false; - } - break; - } - case ASR::cmpopType::NotEq: { - if(intrinsic_op_name != "~noteq") { - result = false; - } - break; - } - case ASR::cmpopType::Lt: { - if(intrinsic_op_name != "~lt") { - result = false; - } - break; - } - case ASR::cmpopType::LtE: { - if(intrinsic_op_name != "~lte") { - result = false; - } - break; - } - case ASR::cmpopType::Gt: { - if(intrinsic_op_name != "~gt") { - result = false; - } - break; - } - case ASR::cmpopType::GtE: { - if(intrinsic_op_name != "~gte") { - result = false; - } - break; - } - } - if( result && curr_scope->resolve_symbol(intrinsic_op_name) == nullptr ) { - result = false; - } - return result; -} - bool is_parent(ASR::StructType_t* a, ASR::StructType_t* b) { ASR::symbol_t* current_parent = b->m_parent; while( current_parent ) { @@ -804,6 +649,240 @@ bool types_equal(const ASR::ttype_t &a, const ASR::ttype_t &b) { return false; } +void process_overloaded_assignment_function(ASR::symbol_t* proc, ASR::expr_t* target, ASR::expr_t* value, + ASR::ttype_t* target_type, ASR::ttype_t* value_type, bool& found, Allocator& al, const Location& target_loc, + const Location& value_loc, SymbolTable* curr_scope, std::set& current_function_dependencies, + Vec& current_module_dependencies, ASR::asr_t*& asr, ASR::symbol_t* sym, const Location& loc, ASR::expr_t* expr_dt, + const std::function err, char* pass_arg=nullptr) { + ASR::Function_t* subrout = ASR::down_cast(proc); + std::string matched_subrout_name = ""; + if( subrout->n_args == 2 ) { + ASR::ttype_t* target_arg_type = ASRUtils::expr_type(subrout->m_args[0]); + ASR::ttype_t* value_arg_type = ASRUtils::expr_type(subrout->m_args[1]); + if( ASRUtils::types_equal(*target_arg_type, *target_type) && + ASRUtils::types_equal(*value_arg_type, *value_type) ) { + std::string arg0_name = ASRUtils::symbol_name(ASR::down_cast(subrout->m_args[0])->m_v); + std::string arg1_name = ASRUtils::symbol_name(ASR::down_cast(subrout->m_args[1])->m_v); + if( pass_arg != nullptr ) { + std::string pass_arg_str = std::string(pass_arg); + if( arg0_name != pass_arg_str && arg1_name != pass_arg_str ) { + err(pass_arg_str + " argument is not present in " + std::string(subrout->m_name), + proc->base.loc); + } + if( (arg0_name == pass_arg_str && target != expr_dt) ) { + err(std::string(subrout->m_name) + " is not a procedure of " + + ASRUtils::type_to_str(target_type), + loc); + } + if( (arg1_name == pass_arg_str && value != expr_dt) ) { + err(std::string(subrout->m_name) + " is not a procedure of " + + ASRUtils::type_to_str(value_type), + loc); + } + } + found = true; + Vec a_args; + a_args.reserve(al, 2); + ASR::call_arg_t target_arg, value_arg; + target_arg.loc = target_loc, target_arg.m_value = target; + a_args.push_back(al, target_arg); + value_arg.loc = value_loc, value_arg.m_value = value; + a_args.push_back(al, value_arg); + std::string subrout_name = to_lower(subrout->m_name); + if( curr_scope->resolve_symbol(subrout_name) ) { + matched_subrout_name = subrout_name; + } else { + std::string mangled_name = subrout_name + "@~assign"; + matched_subrout_name = mangled_name; + } + ASR::symbol_t *a_name = curr_scope->resolve_symbol(matched_subrout_name); + if( a_name == nullptr ) { + err("Unable to resolve matched subroutine for assignment overloading, " + matched_subrout_name, loc); + } + current_function_dependencies.insert(matched_subrout_name); + if( ASR::is_a(*a_name) ) { + current_module_dependencies.push_back(al, ASR::down_cast(a_name)->m_module_name); + } + asr = ASR::make_SubroutineCall_t(al, loc, a_name, sym, + a_args.p, 2, nullptr); + } + } +} + +bool use_overloaded_assignment(ASR::expr_t* target, ASR::expr_t* value, + SymbolTable* curr_scope, ASR::asr_t*& asr, + Allocator &al, const Location& loc, + std::set& current_function_dependencies, + Vec& current_module_dependencies, + const std::function err) { + ASR::ttype_t *target_type = LFortran::ASRUtils::expr_type(target); + ASR::ttype_t *value_type = LFortran::ASRUtils::expr_type(value); + bool found = false; + ASR::symbol_t* sym = curr_scope->resolve_symbol("~assign"); + ASR::expr_t* expr_dt = nullptr; + if( !sym ) { + if( ASR::is_a(*target_type) ) { + ASR::StructType_t* target_struct = ASR::down_cast( + ASRUtils::symbol_get_past_external(ASR::down_cast(target_type)->m_derived_type)); + sym = target_struct->m_symtab->resolve_symbol("~assign"); + expr_dt = target; + } else if( ASR::is_a(*value_type) ) { + ASR::StructType_t* value_struct = ASR::down_cast( + ASRUtils::symbol_get_past_external(ASR::down_cast(value_type)->m_derived_type)); + sym = value_struct->m_symtab->resolve_symbol("~assign"); + expr_dt = value; + } + } + if (sym) { + ASR::symbol_t* orig_sym = ASRUtils::symbol_get_past_external(sym); + ASR::CustomOperator_t* gen_proc = ASR::down_cast(orig_sym); + for( size_t i = 0; i < gen_proc->n_procs && !found; i++ ) { + ASR::symbol_t* proc = gen_proc->m_procs[i]; + switch( proc->type ) { + case ASR::symbolType::Function: { + process_overloaded_assignment_function(proc, target, value, target_type, + value_type, found, al, target->base.loc, value->base.loc, curr_scope, + current_function_dependencies, current_module_dependencies, asr, sym, + loc, expr_dt, err); + break; + } + case ASR::symbolType::ClassProcedure: { + ASR::ClassProcedure_t* class_proc = ASR::down_cast(proc); + ASR::symbol_t* proc_func = ASR::down_cast(proc)->m_proc; + process_overloaded_assignment_function(proc_func, target, value, target_type, + value_type, found, al, target->base.loc, value->base.loc, curr_scope, + current_function_dependencies, current_module_dependencies, asr, proc_func, loc, + expr_dt, err, class_proc->m_self_argument); + break; + } + default: { + err("Only functions and class procedures can be used for generic assignment statement", loc); + } + } + } + } + return found; +} + +bool use_overloaded(ASR::expr_t* left, ASR::expr_t* right, + ASR::cmpopType op, std::string& intrinsic_op_name, + SymbolTable* curr_scope, ASR::asr_t*& asr, + Allocator &al, const Location& loc, + std::set& current_function_dependencies, + Vec& current_module_dependencies, + const std::function err) { + ASR::ttype_t *left_type = LFortran::ASRUtils::expr_type(left); + ASR::ttype_t *right_type = LFortran::ASRUtils::expr_type(right); + bool found = false; + if( is_op_overloaded(op, intrinsic_op_name, curr_scope) ) { + ASR::symbol_t* sym = curr_scope->resolve_symbol(intrinsic_op_name); + ASR::symbol_t* orig_sym = ASRUtils::symbol_get_past_external(sym); + ASR::CustomOperator_t* gen_proc = ASR::down_cast(orig_sym); + for( size_t i = 0; i < gen_proc->n_procs && !found; i++ ) { + ASR::symbol_t* proc = gen_proc->m_procs[i]; + switch(proc->type) { + case ASR::symbolType::Function: { + ASR::Function_t* func = ASR::down_cast(proc); + std::string matched_func_name = ""; + if( func->n_args == 2 ) { + ASR::ttype_t* left_arg_type = ASRUtils::expr_type(func->m_args[0]); + ASR::ttype_t* right_arg_type = ASRUtils::expr_type(func->m_args[1]); + if( left_arg_type->type == left_type->type && + right_arg_type->type == right_type->type ) { + found = true; + Vec a_args; + a_args.reserve(al, 2); + ASR::call_arg_t left_call_arg, right_call_arg; + left_call_arg.loc = left->base.loc, left_call_arg.m_value = left; + a_args.push_back(al, left_call_arg); + right_call_arg.loc = right->base.loc, right_call_arg.m_value = right; + a_args.push_back(al, right_call_arg); + std::string func_name = to_lower(func->m_name); + if( curr_scope->resolve_symbol(func_name) ) { + matched_func_name = func_name; + } else { + std::string mangled_name = func_name + "@" + intrinsic_op_name; + matched_func_name = mangled_name; + } + ASR::symbol_t* a_name = curr_scope->resolve_symbol(matched_func_name); + if( a_name == nullptr ) { + err("Unable to resolve matched function for operator overloading, " + matched_func_name, loc); + } + ASR::ttype_t *return_type = nullptr; + if( func->m_elemental && func->n_args == 1 && ASRUtils::is_array(ASRUtils::expr_type(a_args[0].m_value)) ) { + return_type = ASRUtils::duplicate_type(al, ASRUtils::expr_type(a_args[0].m_value)); + } else { + return_type = ASRUtils::expr_type(func->m_return_var); + } + current_function_dependencies.insert(matched_func_name); + if( ASR::is_a(*a_name) ) { + current_module_dependencies.push_back(al, ASR::down_cast(a_name)->m_module_name); + } + asr = ASR::make_FunctionCall_t(al, loc, a_name, sym, + a_args.p, 2, + return_type, + nullptr, nullptr); + } + } + break; + } + default: { + err("While overloading binary operators only functions can be used", + proc->base.loc); + } + } + } + } + return found; +} + +bool is_op_overloaded(ASR::cmpopType op, std::string& intrinsic_op_name, + SymbolTable* curr_scope) { + bool result = true; + switch(op) { + case ASR::cmpopType::Eq: { + if(intrinsic_op_name != "~eq") { + result = false; + } + break; + } + case ASR::cmpopType::NotEq: { + if(intrinsic_op_name != "~noteq") { + result = false; + } + break; + } + case ASR::cmpopType::Lt: { + if(intrinsic_op_name != "~lt") { + result = false; + } + break; + } + case ASR::cmpopType::LtE: { + if(intrinsic_op_name != "~lte") { + result = false; + } + break; + } + case ASR::cmpopType::Gt: { + if(intrinsic_op_name != "~gt") { + result = false; + } + break; + } + case ASR::cmpopType::GtE: { + if(intrinsic_op_name != "~gte") { + result = false; + } + break; + } + } + if( result && curr_scope->resolve_symbol(intrinsic_op_name) == nullptr ) { + result = false; + } + return result; +} + template bool argument_types_match(const Vec& args, const T &sub) { diff --git a/src/libasr/asr_utils.h b/src/libasr/asr_utils.h index e6d276cd01..d3a5ee0d87 100644 --- a/src/libasr/asr_utils.h +++ b/src/libasr/asr_utils.h @@ -103,6 +103,13 @@ static inline ASR::ttype_t* symbol_type(const ASR::symbol_t *f) case ASR::symbolType::EnumType: { return ASR::down_cast(f)->m_type; } + case ASR::symbolType::ExternalSymbol: { + return symbol_type(ASRUtils::symbol_get_past_external(f)); + } + case ASR::symbolType::Function: { + return ASRUtils::expr_type( + ASR::down_cast(f)->m_return_var); + } default: { throw LCompilersException("Cannot return type of, " + std::to_string(f->type) + " symbol."); @@ -175,6 +182,52 @@ static inline ASR::abiType expr_abi(ASR::expr_t* e) { } } +static inline char *symbol_name(const ASR::symbol_t *f) +{ + switch (f->type) { + case ASR::symbolType::Program: { + return ASR::down_cast(f)->m_name; + } + case ASR::symbolType::Module: { + return ASR::down_cast(f)->m_name; + } + case ASR::symbolType::Function: { + return ASR::down_cast(f)->m_name; + } + case ASR::symbolType::GenericProcedure: { + return ASR::down_cast(f)->m_name; + } + case ASR::symbolType::StructType: { + return ASR::down_cast(f)->m_name; + } + case ASR::symbolType::EnumType: { + return ASR::down_cast(f)->m_name; + } + case ASR::symbolType::UnionType: { + return ASR::down_cast(f)->m_name; + } + case ASR::symbolType::Variable: { + return ASR::down_cast(f)->m_name; + } + case ASR::symbolType::ExternalSymbol: { + return ASR::down_cast(f)->m_name; + } + case ASR::symbolType::ClassProcedure: { + return ASR::down_cast(f)->m_name; + } + case ASR::symbolType::CustomOperator: { + return ASR::down_cast(f)->m_name; + } + case ASR::symbolType::AssociateBlock: { + return ASR::down_cast(f)->m_name; + } + case ASR::symbolType::Block: { + return ASR::down_cast(f)->m_name; + } + default : throw LCompilersException("Not implemented"); + } +} + static inline std::string type_to_str(const ASR::ttype_t *t) { switch (t->type) { @@ -206,7 +259,7 @@ static inline std::string type_to_str(const ASR::ttype_t *t) return "list"; } case ASR::ttypeType::Struct: { - return "derived type"; + return ASRUtils::symbol_name(ASR::down_cast(t)->m_derived_type); } case ASR::ttypeType::Union: { return "union"; @@ -267,52 +320,6 @@ static inline ASR::expr_t* expr_value(ASR::expr_t *f) return ASR::expr_value0(f); } -static inline char *symbol_name(const ASR::symbol_t *f) -{ - switch (f->type) { - case ASR::symbolType::Program: { - return ASR::down_cast(f)->m_name; - } - case ASR::symbolType::Module: { - return ASR::down_cast(f)->m_name; - } - case ASR::symbolType::Function: { - return ASR::down_cast(f)->m_name; - } - case ASR::symbolType::GenericProcedure: { - return ASR::down_cast(f)->m_name; - } - case ASR::symbolType::StructType: { - return ASR::down_cast(f)->m_name; - } - case ASR::symbolType::EnumType: { - return ASR::down_cast(f)->m_name; - } - case ASR::symbolType::UnionType: { - return ASR::down_cast(f)->m_name; - } - case ASR::symbolType::Variable: { - return ASR::down_cast(f)->m_name; - } - case ASR::symbolType::ExternalSymbol: { - return ASR::down_cast(f)->m_name; - } - case ASR::symbolType::ClassProcedure: { - return ASR::down_cast(f)->m_name; - } - case ASR::symbolType::CustomOperator: { - return ASR::down_cast(f)->m_name; - } - case ASR::symbolType::AssociateBlock: { - return ASR::down_cast(f)->m_name; - } - case ASR::symbolType::Block: { - return ASR::down_cast(f)->m_name; - } - default : throw LCompilersException("Not implemented"); - } -} - static inline std::pair symbol_dependencies(const ASR::symbol_t *f) { switch (f->type) { @@ -1231,6 +1238,8 @@ bool use_overloaded(ASR::expr_t* left, ASR::expr_t* right, ASR::binopType op, std::string& intrinsic_op_name, SymbolTable* curr_scope, ASR::asr_t*& asr, Allocator &al, const Location& loc, + std::set& current_function_dependencies, + Vec& current_module_dependencies, const std::function err); bool is_op_overloaded(ASR::binopType op, std::string& intrinsic_op_name, @@ -1240,6 +1249,8 @@ bool use_overloaded(ASR::expr_t* left, ASR::expr_t* right, ASR::cmpopType op, std::string& intrinsic_op_name, SymbolTable* curr_scope, ASR::asr_t*& asr, Allocator &al, const Location& loc, + std::set& current_function_dependencies, + Vec& current_module_dependencies, const std::function err); bool is_op_overloaded(ASR::cmpopType op, std::string& intrinsic_op_name, @@ -1248,6 +1259,8 @@ bool is_op_overloaded(ASR::cmpopType op, std::string& intrinsic_op_name, bool use_overloaded_assignment(ASR::expr_t* target, ASR::expr_t* value, SymbolTable* curr_scope, ASR::asr_t*& asr, Allocator &al, const Location& loc, + std::set& current_function_dependencies, + Vec& /*current_module_dependencies*/, const std::function err); void set_intrinsic(ASR::symbol_t* sym); @@ -1318,6 +1331,27 @@ static inline bool is_generic(ASR::ttype_t &x) { } } +static inline bool is_generic_function(ASR::symbol_t *x) { + ASR::symbol_t* x2 = symbol_get_past_external(x); + switch (x2->type) { + case ASR::symbolType::Function: { + ASR::Function_t *func_sym = ASR::down_cast(x2); + return func_sym->n_type_params > 0 && !func_sym->m_is_restriction; + } + default: return false; + } +} + +static inline bool is_restriction_function(ASR::symbol_t *x) { + ASR::symbol_t* x2 = symbol_get_past_external(x); + switch (x2->type) { + case ASR::symbolType::Function: { + ASR::Function_t *func_sym = ASR::down_cast(x2); + return func_sym->m_is_restriction; + } + default: return false; + } +} static inline int get_body_size(ASR::symbol_t* s) { int n_body = 0; @@ -1678,7 +1712,12 @@ inline bool is_same_type_pointer(ASR::ttype_t* source, ASR::ttype_t* dest) { source = dest; dest = temp; } - bool res = source->type == ASR::down_cast(dest)->m_type->type; + dest = ASR::down_cast(dest)->m_type; + if( (ASR::is_a(*source) || ASR::is_a(*source)) && + (ASR::is_a(*dest) || ASR::is_a(*dest)) ) { + return true; + } + bool res = source->type == dest->type; return res; } @@ -1923,12 +1962,15 @@ class ReplaceArgVisitor: public ASR::BaseExprReplacer { Vec& orig_args; + std::set& current_function_dependencies; + public: ReplaceArgVisitor(Allocator& al_, SymbolTable* current_scope_, - ASR::Function_t* orig_func_, Vec& orig_args_) : + ASR::Function_t* orig_func_, Vec& orig_args_, + std::set& current_function_dependencies_) : al(al_), current_scope(current_scope_), orig_func(orig_func_), - orig_args(orig_args_) + orig_args(orig_args_), current_function_dependencies(current_function_dependencies_) {} void replace_FunctionCall(ASR::FunctionCall_t* x) { @@ -1984,6 +2026,7 @@ class ReplaceArgVisitor: public ASR::BaseExprReplacer { replace_expr(x->m_args[i].m_value); current_expr = current_expr_copy_; } + current_function_dependencies.insert(std::string(ASRUtils::symbol_name(new_es))); x->m_name = new_es; } diff --git a/src/libasr/asr_verify.cpp b/src/libasr/asr_verify.cpp index 3fe01476ee..d9e3bb57be 100644 --- a/src/libasr/asr_verify.cpp +++ b/src/libasr/asr_verify.cpp @@ -205,6 +205,13 @@ class VerifyVisitor : public BaseWalkVisitor current_symtab = parent_symtab; } + void visit_BlockCall(const BlockCall_t& x) { + require(x.m_m != nullptr, "Block call made to inexisting block"); + require(symtab_in_scope(current_symtab, x.m_m), + "Block " + std::string(ASRUtils::symbol_name(x.m_m)) + + " should resolve in current scope."); + } + void visit_Module(const Module_t &x) { module_dependencies.clear(); module_dependencies.reserve(x.n_dependencies); @@ -235,11 +242,13 @@ class VerifyVisitor : public BaseWalkVisitor "A module dependency must be a valid string"); } for( auto& dep: module_dependencies ) { - require(present(x.m_dependencies, x.n_dependencies, dep), - "Module " + std::string(x.m_name) + - " dependencies must contain " + dep + - " because a function present in it is getting called in " - + std::string(x.m_name) + "."); + if( dep != x.m_name ) { + require(present(x.m_dependencies, x.n_dependencies, dep), + "Module " + std::string(x.m_name) + + " dependencies must contain " + dep + + " because a function present in it is getting called in " + + std::string(x.m_name) + "."); + } } current_symtab = parent_symtab; } @@ -265,6 +274,44 @@ class VerifyVisitor : public BaseWalkVisitor BaseWalkVisitor::visit_Assignment(x); } + void visit_ClassProcedure(const ClassProcedure_t &x) { + require(x.m_name != nullptr, + "The ClassProcedure::m_name cannot be nullptr"); + require(x.m_proc != nullptr, + "The ClassProcedure::m_proc cannot be nullptr"); + require(x.m_proc_name != nullptr, + "The ClassProcedure::m_proc_name cannot be nullptr"); + + SymbolTable *symtab = x.m_parent_symtab; + require(symtab != nullptr, + "ClassProcedure::m_parent_symtab cannot be nullptr"); + require(symtab->get_symbol(std::string(x.m_name)) != nullptr, + "ClassProcedure '" + std::string(x.m_name) + "' not found in parent_symtab symbol table"); + symbol_t *symtab_sym = symtab->get_symbol(std::string(x.m_name)); + const symbol_t *current_sym = &x.base; + require(symtab_sym == current_sym, + "ClassProcedure's parent symbol table does not point to it"); + require(id_symtab_map.find(symtab->counter) != id_symtab_map.end(), + "ClassProcedure::m_parent_symtab must be present in the ASR (" + + std::string(x.m_name) + ")"); + + ASR::Function_t* x_m_proc = ASR::down_cast(x.m_proc); + if( x.m_self_argument ) { + bool arg_found = false; + std::string self_arg_name = std::string(x.m_self_argument); + for( size_t i = 0; i < x_m_proc->n_args; i++ ) { + std::string arg_name = std::string(ASRUtils::symbol_name( + ASR::down_cast(x_m_proc->m_args[i])->m_v)); + if( self_arg_name == arg_name ) { + arg_found = true; + break ; + } + } + require(arg_found, self_arg_name + " must be present in " + + std::string(x.m_name) + " procedures."); + } + } + void visit_Function(const Function_t &x) { function_dependencies.clear(); function_dependencies.reserve(x.n_dependencies); @@ -333,23 +380,24 @@ class VerifyVisitor : public BaseWalkVisitor if( ASR::is_a(*a.second) || ASR::is_a(*a.second) || ASR::is_a(*a.second) || - ASR::is_a(*a.second) ) { + ASR::is_a(*a.second) || + ASR::is_a(*a.second) ) { continue ; } ASR::ttype_t* var_type = ASRUtils::type_get_past_pointer(ASRUtils::symbol_type(a.second)); char* aggregate_type_name = nullptr; + ASR::symbol_t* sym = nullptr; if( ASR::is_a(*var_type) ) { - ASR::symbol_t* sym = ASR::down_cast(var_type)->m_derived_type; + sym = ASR::down_cast(var_type)->m_derived_type; aggregate_type_name = ASRUtils::symbol_name(sym); } else if( ASR::is_a(*var_type) ) { - ASR::symbol_t* sym = ASR::down_cast(var_type)->m_enum_type; + sym = ASR::down_cast(var_type)->m_enum_type; aggregate_type_name = ASRUtils::symbol_name(sym); } else if( ASR::is_a(*var_type) ) { - ASR::symbol_t* sym = ASR::down_cast(var_type)->m_union_type; + sym = ASR::down_cast(var_type)->m_union_type; aggregate_type_name = ASRUtils::symbol_name(sym); } - if( aggregate_type_name && - !current_symtab->get_symbol(std::string(aggregate_type_name)) ) { + if( aggregate_type_name && ASRUtils::symbol_parent_symtab(sym) != current_symtab ) { struct_dependencies.push_back(std::string(aggregate_type_name)); require(present(x.m_dependencies, x.n_dependencies, std::string(aggregate_type_name)), std::string(x.m_name) + " depends on " + std::string(aggregate_type_name) @@ -619,11 +667,7 @@ class VerifyVisitor : public BaseWalkVisitor } SymbolTable *get_dt_symtab(ASR::expr_t *dt) { - require_impl(ASR::is_a(*dt), - "m_dt must point to a Var", dt->base.loc); - ASR::Var_t *var = ASR::down_cast(dt); - ASR::Variable_t *v = ASR::down_cast(var->m_v); - ASR::ttype_t *t2 = ASRUtils::type_get_past_pointer(v->m_type); + ASR::ttype_t *t2 = ASRUtils::type_get_past_pointer(ASRUtils::expr_type(dt)); ASR::symbol_t *type_sym=nullptr; switch (t2->type) { case (ASR::ttypeType::Struct): { @@ -660,11 +704,7 @@ class VerifyVisitor : public BaseWalkVisitor } ASR::symbol_t *get_parent_type_dt(ASR::expr_t *dt) { - require_impl(ASR::is_a(*dt), - "m_dt must point to a Var", dt->base.loc); - ASR::Var_t *var = ASR::down_cast(dt); - ASR::Variable_t *v = ASR::down_cast(var->m_v); - ASR::ttype_t *t2 = ASRUtils::type_get_past_pointer(v->m_type); + ASR::ttype_t *t2 = ASRUtils::type_get_past_pointer(ASRUtils::expr_type(dt)); ASR::symbol_t *type_sym=nullptr; ASR::symbol_t *parent = nullptr; switch (t2->type) { @@ -692,6 +732,10 @@ class VerifyVisitor : public BaseWalkVisitor return parent; } + void visit_PointerNullConstant(const PointerNullConstant_t& x) { + require(x.m_type != nullptr, "null() must have a type"); + } + void visit_FunctionCall(const FunctionCall_t &x) { require(x.m_name, "FunctionCall::m_name must be present"); diff --git a/src/libasr/bigint.h b/src/libasr/bigint.h index 40825158a6..d58628df2e 100644 --- a/src/libasr/bigint.h +++ b/src/libasr/bigint.h @@ -30,7 +30,13 @@ namespace BigInt { * in int64_t. * * To check if the integer has a pointer tag, we check that the first two bits - * (1-2) are equal to 01: + * (1-2) are equal to 01. + * + * If the first bit is 0, then it can either be a positive integer or a + * pointer. We check the second bit, if it is 1, then it is a pointer (shifted + * by 2), if it is 0, then is is a positive integer, represented by the rest of + * the 62 bits. If the first bit is 1, then it is a negative integer, + * represented by the full 64 bits in 2's complement representation. */ // Returns true if "i" is a pointer and false if "i" is an integer @@ -104,7 +110,7 @@ inline static bool is_int64(std::string str_repr) { if( str_repr.size() > str_int64.size() ) { return false; } - + if( str_repr.size() < str_int64.size() ) { return true; } diff --git a/src/libasr/codegen/asr_to_c_cpp.h b/src/libasr/codegen/asr_to_c_cpp.h index 95e3b8a9be..138154e93b 100644 --- a/src/libasr/codegen/asr_to_c_cpp.h +++ b/src/libasr/codegen/asr_to_c_cpp.h @@ -1139,11 +1139,13 @@ R"(#include case 8: src = "(double)(" + src + ")"; break; default: throw CodeGenError("Cast IntegerToReal: Unsupported Kind " + std::to_string(dest_kind)); } + last_expr_precedence = 2; break; } case (ASR::cast_kindType::RealToInteger) : { int dest_kind = ASRUtils::extract_kind_from_ttype_t(x.m_type); src = "(int" + std::to_string(dest_kind * 8) + "_t)(" + src + ")"; + last_expr_precedence = 2; break; } case (ASR::cast_kindType::RealToReal) : { @@ -1166,6 +1168,7 @@ R"(#include } else { src = "std::complex(" + src + ")"; } + last_expr_precedence = 2; break; } case (ASR::cast_kindType::ComplexToReal) : { @@ -1175,6 +1178,7 @@ R"(#include } else { src = "std::real(" + src + ")"; } + last_expr_precedence = 2; break; } case (ASR::cast_kindType::RealToComplex) : { @@ -1184,18 +1188,22 @@ R"(#include } else { src = "std::complex(" + src + ")"; } + last_expr_precedence = 2; break; } case (ASR::cast_kindType::LogicalToInteger) : { src = "(int)(" + src + ")"; + last_expr_precedence = 2; break; } case (ASR::cast_kindType::LogicalToCharacter) : { - src = src + " ? \"True\" : \"False\""; + src = "(" + src + " ? \"True\" : \"False\")"; + last_expr_precedence = 2; break; } case (ASR::cast_kindType::IntegerToLogical) : { src = "(bool)(" + src + ")"; + last_expr_precedence = 2; break; } case (ASR::cast_kindType::LogicalToReal) : { @@ -1205,18 +1213,22 @@ R"(#include case 8: src = "(double)(" + src + ")"; break; default: throw CodeGenError("Cast LogicalToReal: Unsupported Kind " + std::to_string(dest_kind)); } + last_expr_precedence = 2; break; } case (ASR::cast_kindType::RealToLogical) : { src = "(bool)(" + src + ")"; + last_expr_precedence = 2; break; } case (ASR::cast_kindType::CharacterToLogical) : { src = "(bool)(strlen(" + src + ") > 0)"; + last_expr_precedence = 2; break; } case (ASR::cast_kindType::ComplexToLogical) : { src = "(bool)(" + src + ")"; + last_expr_precedence = 2; break; } case (ASR::cast_kindType::IntegerToCharacter) : { @@ -1235,6 +1247,7 @@ R"(#include } else { src = "std::to_string(" + src + ")"; } + last_expr_precedence = 2; break; } case (ASR::cast_kindType::CharacterToInteger) : { @@ -1243,6 +1256,7 @@ R"(#include } else { src = "std::stoi(" + src + ")"; } + last_expr_precedence = 2; break; } case (ASR::cast_kindType::RealToCharacter) : { @@ -1258,12 +1272,12 @@ R"(#include } else { src = "std::to_string(" + src + ")"; } + last_expr_precedence = 2; break; } default : throw CodeGenError("Cast kind " + std::to_string(x.m_kind) + " not implemented", x.base.base.loc); } - last_expr_precedence = 2; } void visit_IntegerBitLen(const ASR::IntegerBitLen_t& x) { @@ -1475,7 +1489,7 @@ R"(#include src += ASRUtils::binop_to_str_python(x.m_op); if (right_precedence == 3) { src += "(" + right + ")"; - } else if (x.m_op == ASR::binopType::Sub) { + } else if (x.m_op == ASR::binopType::Sub || x.m_op == ASR::binopType::Div) { if (right_precedence < last_expr_precedence) { src += right; } else { @@ -1580,9 +1594,77 @@ R"(#include src = out; } - void visit_Select(const ASR::Select_t &/*x*/) { - std::string indent(indentation_level*indentation_spaces, ' '); - std::string out = indent + "// FIXME: select case()\n"; + void visit_Select(const ASR::Select_t& x) + { + std::string indent(indentation_level * indentation_spaces, ' '); + this->visit_expr(*x.m_test); + std::string var = std::move(src); + std::string out = indent + "if ("; + + for (size_t i = 0; i < x.n_body; i++) { + if (i > 0) + out += indent + "else if ("; + ASR::case_stmt_t* stmt = x.m_body[i]; + if (stmt->type == ASR::case_stmtType::CaseStmt) { + ASR::CaseStmt_t* case_stmt = ASR::down_cast(stmt); + for (size_t j = 0; j < case_stmt->n_test; j++) { + if (j > 0) + out += " || "; + this->visit_expr(*case_stmt->m_test[j]); + out += var + " == " + src; + } + out += ") {\n"; + indentation_level += 1; + for (size_t j = 0; j < case_stmt->n_body; j++) { + this->visit_stmt(*case_stmt->m_body[j]); + out += src; + } + out += indent + "}\n"; + indentation_level -= 1; + } else { + ASR::CaseStmt_Range_t* case_stmt_range + = ASR::down_cast(stmt); + std::string left, right; + if (case_stmt_range->m_start) { + this->visit_expr(*case_stmt_range->m_start); + left = std::move(src); + } + if (case_stmt_range->m_end) { + this->visit_expr(*case_stmt_range->m_end); + right = std::move(src); + } + if (left.empty() && right.empty()) { + diag.codegen_error_label( + "Empty range in select statement", { x.base.base.loc }, ""); + throw Abort(); + } + if (left.empty()) { + out += var + " <= " + right; + } else if (right.empty()) { + out += var + " >= " + left; + } else { + out += left + " <= " + var + " <= " + right; + } + out += ") {\n"; + indentation_level += 1; + for (size_t j = 0; j < case_stmt_range->n_body; j++) { + this->visit_stmt(*case_stmt_range->m_body[j]); + out += src; + } + out += indent + "}\n"; + indentation_level -= 1; + } + } + if (x.n_default) { + out += indent + "else {\n"; + indentation_level += 1; + for (size_t i = 0; i < x.n_default; i++) { + this->visit_stmt(*x.m_default[i]); + out += src; + } + out += indent + "}\n"; + indentation_level -= 1; + } src = out; } diff --git a/src/libasr/codegen/asr_to_cpp.cpp b/src/libasr/codegen/asr_to_cpp.cpp index ea7a005d20..da61503a49 100644 --- a/src/libasr/codegen/asr_to_cpp.cpp +++ b/src/libasr/codegen/asr_to_cpp.cpp @@ -214,8 +214,9 @@ class ASRToCPPVisitor : public BaseCCPPVisitor { std::string sub; bool use_ref = (v.m_intent == LFortran::ASRUtils::intent_out || - - v.m_intent == LFortran::ASRUtils::intent_inout); + v.m_intent == LFortran::ASRUtils::intent_inout || + v.m_intent == LFortran::ASRUtils::intent_unspecified + ); bool is_array = ASRUtils::is_array(v.m_type); bool dummy = LFortran::ASRUtils::is_arg_dummy(v.m_intent); if (ASRUtils::is_pointer(v.m_type)) { diff --git a/src/libasr/codegen/asr_to_julia.cpp b/src/libasr/codegen/asr_to_julia.cpp index aab8b156a8..477947d7b9 100644 --- a/src/libasr/codegen/asr_to_julia.cpp +++ b/src/libasr/codegen/asr_to_julia.cpp @@ -1490,22 +1490,26 @@ class ASRToJuliaVisitor : public ASR::BaseVisitor throw CodeGenError("Cast IntegerToReal: Unsupported Kind " + std::to_string(dest_kind)); } + last_expr_precedence = julia_prec::Base; break; } case (ASR::cast_kindType::RealToInteger): { int dest_kind = ASRUtils::extract_kind_from_ttype_t(x.m_type); src = "trunc" + broadcast + "(Int" + std::to_string(dest_kind * 8) + ", " + src + ")"; + last_expr_precedence = julia_prec::Base; break; } case (ASR::cast_kindType::RealToReal): { // In Julia, we do not need to cast float to float explicitly: // src = src; + // last_expr_precedence = last_expr_precedence; break; } case (ASR::cast_kindType::IntegerToInteger): { // In Julia, we do not need to cast int <-> long long explicitly: // src = src; + // last_expr_precedence = last_expr_precedence; break; } case (ASR::cast_kindType::ComplexToComplex): { @@ -1513,29 +1517,33 @@ class ASRToJuliaVisitor : public ASR::BaseVisitor } case (ASR::cast_kindType::IntegerToComplex): { src = "complex" + broadcast + "(" + src + ")"; + last_expr_precedence = julia_prec::Base; break; } case (ASR::cast_kindType::ComplexToReal): { src = "real" + broadcast + "(" + src + ")"; + last_expr_precedence = julia_prec::Base; break; } case (ASR::cast_kindType::RealToComplex): { src = "complex" + broadcast + "(" + src + ")"; + last_expr_precedence = julia_prec::Base; break; } case (ASR::cast_kindType::LogicalToInteger): { src = "Int32" + broadcast + "(" + src + ")"; + last_expr_precedence = julia_prec::Base; break; } case (ASR::cast_kindType::IntegerToLogical): { src = "Bool" + broadcast + "(" + src + ")"; + last_expr_precedence = julia_prec::Base; break; } default: throw CodeGenError("Cast kind " + std::to_string(x.m_kind) + " not implemented", x.base.base.loc); } - last_expr_precedence = julia_prec::Base; } void visit_IntegerCompare(const ASR::IntegerCompare_t& x) diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index 1703904504..c413523457 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -820,7 +820,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor int member_idx = 0; for( auto itr = scope.begin(); itr != scope.end(); itr++ ) { if (!ASR::is_a(*itr->second) && - !ASR::is_a(*itr->second)) { + !ASR::is_a(*itr->second) && + !ASR::is_a(*itr->second)) { ASR::Variable_t* member = ASR::down_cast(itr->second); llvm::Type* mem_type = nullptr; switch( member->m_type->type ) { @@ -1487,6 +1488,46 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor tmp = builder->CreateCall(fn, {int_val}); } + void visit_Ichar(const ASR::Ichar_t &x) { + if (x.m_value) { + this->visit_expr_wrapper(x.m_value, true); + return; + } + this->visit_expr(*x.m_arg); + llvm::Value *c = tmp; + std::string runtime_func_name = "_lfortran_ichar"; + llvm::Function *fn = module->getFunction(runtime_func_name); + if (!fn) { + llvm::FunctionType *function_type = llvm::FunctionType::get( + llvm::Type::getInt32Ty(context), { + llvm::Type::getInt8PtrTy(context) + }, false); + fn = llvm::Function::Create(function_type, + llvm::Function::ExternalLinkage, runtime_func_name, *module); + } + tmp = builder->CreateCall(fn, {c}); + } + + void visit_Iachar(const ASR::Iachar_t &x) { + if (x.m_value) { + this->visit_expr_wrapper(x.m_value, true); + return; + } + this->visit_expr(*x.m_arg); + llvm::Value *c = tmp; + std::string runtime_func_name = "_lfortran_iachar"; + llvm::Function *fn = module->getFunction(runtime_func_name); + if (!fn) { + llvm::FunctionType *function_type = llvm::FunctionType::get( + llvm::Type::getInt32Ty(context), { + llvm::Type::getInt8PtrTy(context) + }, false); + fn = llvm::Function::Create(function_type, + llvm::Function::ExternalLinkage, runtime_func_name, *module); + } + tmp = builder->CreateCall(fn, {c}); + } + void visit_ListAppend(const ASR::ListAppend_t& x) { ASR::List_t* asr_list = ASR::down_cast(ASRUtils::expr_type(x.m_a)); int64_t ptr_loads_copy = ptr_loads; @@ -1729,6 +1770,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor ASR::ttype_t* x_mv_type = ASRUtils::expr_type(x.m_v); bool is_argument = false; llvm::Value* array = nullptr; + bool is_data_only = false; if( ASR::is_a(*x.m_v) ) { ASR::Variable_t *v = ASRUtils::EXPR2VAR(x.m_v); if( ASR::is_a(*ASRUtils::get_contained_type(v->m_type)) ) { @@ -1737,8 +1779,22 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor der_type_name = ASRUtils::symbol_name(ASRUtils::symbol_get_past_external(der_type->m_derived_type)); } uint32_t v_h = get_hash((ASR::asr_t*)v); - LFORTRAN_ASSERT(llvm_symtab.find(v_h) != llvm_symtab.end()); - array = llvm_symtab[v_h]; + if (llvm_symtab.find(v_h) == llvm_symtab.end()) { + LFORTRAN_ASSERT(std::find(nested_globals.begin(), + nested_globals.end(), v_h) != nested_globals.end()); + auto finder = std::find(nested_globals.begin(), + nested_globals.end(), v_h); + llvm::Constant *ptr = module->getOrInsertGlobal(nested_desc_name, + nested_global_struct); + int idx = std::distance(nested_globals.begin(), finder); + std::vector idx_vec = { + llvm::ConstantInt::get(context, llvm::APInt(32, 0)), + llvm::ConstantInt::get(context, llvm::APInt(32, idx))}; + array = CreateLoad(CreateGEP(ptr, idx_vec)); + is_data_only = true; + } else { + array = llvm_symtab[v_h]; + } is_argument = (v->m_intent == ASRUtils::intent_in) || (v->m_intent == ASRUtils::intent_out) || (v->m_intent == ASRUtils::intent_inout) @@ -1797,7 +1853,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor ASR::is_a(*x.m_v))) { array = CreateLoad(array); } - bool is_data_only = is_argument && !ASRUtils::is_dimension_empty(m_dims, n_dims); + is_data_only = is_data_only || (is_argument && !ASRUtils::is_dimension_empty(m_dims, n_dims)); is_data_only = is_data_only || is_bindc_array; Vec llvm_diminfo; llvm_diminfo.reserve(al, 2 * x.n_args + 1); @@ -2713,12 +2769,20 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor if( ASR::is_a(*item.second) || ASR::is_a(*item.second) || ASR::is_a(*item.second) || - ASR::is_a(*item.second) ) { + ASR::is_a(*item.second) || + ASR::is_a(*item.second) ) { continue ; } ASR::ttype_t* symbol_type = ASRUtils::symbol_type(item.second); int idx = name2memidx[struct_type_name][item.first]; llvm::Value* ptr_member = llvm_utils->create_gep(ptr, idx); + if( ASR::is_a(*item.second) ) { + ASR::Variable_t* v = ASR::down_cast(item.second); + if( v->m_symbolic_value ) { + visit_expr(*v->m_symbolic_value); + LLVM::CreateStore(*builder, tmp, ptr_member); + } + } if( ASRUtils::is_array(symbol_type) ) { // Assume that struct member array is not allocatable ASR::dimension_t* m_dims = nullptr; @@ -2846,6 +2910,9 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor nested_global_struct); int idx = std::distance(nested_globals.begin(), finder); + if( is_array_type || is_malloc_array_type ) { + target_var = CreateLoad(target_var); + } builder->CreateStore(target_var, llvm_utils->create_gep(ptr, idx)); } @@ -4195,6 +4262,50 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor int idx = std::distance(nested_globals.begin(), finder); builder->CreateStore(target, llvm_utils->create_gep(ptr, idx)); } + if (is_a(*x.m_target)) { + ASR::ArrayItem_t *asr_target0 = ASR::down_cast(x.m_target); + if (is_a(*asr_target0->m_v)) { + ASR::Variable_t *asr_target = ASRUtils::EXPR2VAR(asr_target0->m_v); + h = get_hash((ASR::asr_t*)asr_target); + auto finder = std::find(nested_globals.begin(), + nested_globals.end(), h); + if (finder != nested_globals.end()) { + // This is used since array pass use array item visit + llvm::Constant *ptr = module->getOrInsertGlobal(nested_desc_name, + nested_global_struct); + int idx = std::distance(nested_globals.begin(), finder); + std::vector idx_vec = { + llvm::ConstantInt::get(context, llvm::APInt(32, 0)), + llvm::ConstantInt::get(context, llvm::APInt(32, idx))}; + llvm::Value* array = CreateGEP(ptr, idx_vec); + std::vector indices; + for( size_t r = 0; r < asr_target0->n_args; r++ ) { + ASR::array_index_t curr_idx = asr_target0->m_args[r]; + uint64_t ptr_loads_copy = ptr_loads; + ptr_loads = 2; + this->visit_expr_wrapper(curr_idx.m_right, true); + ptr_loads = ptr_loads_copy; + indices.push_back(tmp); + } + ASR::dimension_t* m_dims; + ASRUtils::extract_dimensions_from_ttype( + ASRUtils::expr_type(asr_target0->m_v), m_dims); + Vec llvm_diminfo; + llvm_diminfo.reserve(al, 2 * asr_target0->n_args + 1); + for( size_t idim = 0; idim < asr_target0->n_args; idim++ ) { + this->visit_expr_wrapper(m_dims[idim].m_start, true); + llvm::Value* dim_start = tmp; + this->visit_expr_wrapper(m_dims[idim].m_length, true); + llvm::Value* dim_size = tmp; + llvm_diminfo.push_back(al, dim_start); + llvm_diminfo.push_back(al, dim_size); + } + tmp = arr_descr->get_single_element(array, indices, asr_target0->n_args, + true, false, llvm_diminfo.p); + builder->CreateStore(target, tmp); + } + } + } } void visit_AssociateBlockCall(const ASR::AssociateBlockCall_t& x) { @@ -5941,7 +6052,20 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } } else if ( x_abi == ASR::abiType::BindC ) { if( arr_descr->is_array(ASRUtils::get_contained_type(arg->m_type)) ) { - tmp = CreateLoad(arr_descr->get_pointer_to_data(tmp)); + // TODO: we need a dedicated and robust + // function that determines from ASR only + // if a given array is represented by + // a descriptor or with just a pointer. + // Until then we use the following heuristic: + bool arg_is_using_descriptor = true; + if (LLVMArrUtils::is_explicit_shape(arg)) { + if (arg->m_intent != intent_local) { + arg_is_using_descriptor = false; + } + } + if (arg_is_using_descriptor) { + tmp = CreateLoad(arr_descr->get_pointer_to_data(tmp)); + } } else { if (orig_arg->m_abi == ASR::abiType::BindC && orig_arg->m_value_attr) { @@ -6063,7 +6187,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor ptr_loads = !LLVM::is_llvm_struct(arg_type); this->visit_expr_wrapper(x.m_args[i].m_value); if( x_abi == ASR::abiType::BindC ) { - if( ASR::is_a(*x.m_args[i].m_value) || + if( (ASR::is_a(*x.m_args[i].m_value) && + orig_arg_intent == ASR::intentType::In) || ASR::is_a(*x.m_args[i].m_value) || (ASR::is_a(*arg_type) && ASR::is_a(*x.m_args[i].m_value)) ) { @@ -6167,13 +6292,17 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor && orig_arg->m_value_attr) { use_value = true; } + if (ASR::is_a(*x.m_args[i].m_value)) { + use_value = true; + } if (!use_value) { // Create alloca to get a pointer, but do it // at the beginning of the function to avoid // using alloca inside a loop, which would // run out of stack - if( ASR::is_a(*x.m_args[i].m_value) || - ASR::is_a(*x.m_args[i].m_value) ) { + if( (ASR::is_a(*x.m_args[i].m_value) || + ASR::is_a(*x.m_args[i].m_value)) + && value->getType()->isPointerTy()) { value = CreateLoad(value); } if( !ASR::is_a(*arg_type) ) { diff --git a/src/libasr/codegen/asr_to_wasm.cpp b/src/libasr/codegen/asr_to_wasm.cpp index 3a2740c08f..aa365bf6bb 100644 --- a/src/libasr/codegen/asr_to_wasm.cpp +++ b/src/libasr/codegen/asr_to_wasm.cpp @@ -547,7 +547,8 @@ class ASRToWASMVisitor : public ASR::BaseVisitor { m_var_name_idx_map[get_hash((ASR::asr_t *)arg)] = s->no_of_variables++; if (arg->m_intent == ASR::intentType::Out || - arg->m_intent == ASR::intentType::InOut) { + arg->m_intent == ASR::intentType::InOut || + arg->m_intent == ASR::intentType::Unspecified) { s->referenced_vars.push_back(m_al, arg); } } @@ -564,7 +565,8 @@ class ASRToWASMVisitor : public ASR::BaseVisitor { for (size_t i = 0; i < x.n_args; i++) { ASR::Variable_t *arg = ASRUtils::EXPR2VAR(x.m_args[i]); if (arg->m_intent == ASR::intentType::Out || - arg->m_intent == ASR::intentType::InOut) { + arg->m_intent == ASR::intentType::InOut || + arg->m_intent == ASR::intentType::Unspecified) { emit_var_type(m_type_section, arg); } } @@ -1624,14 +1626,15 @@ class ASRToWASMVisitor : public ASR::BaseVisitor { // sym_name = "_xx_lcompilers_changed_exit_xx"; // } - Vec intent_out_passed_vars; - intent_out_passed_vars.reserve(m_al, s->n_args); + Vec vars_passed_by_refs; + vars_passed_by_refs.reserve(m_al, s->n_args); if (x.n_args == s->n_args) { for (size_t i = 0; i < x.n_args; i++) { ASR::Variable_t *arg = ASRUtils::EXPR2VAR(s->m_args[i]); - if (arg->m_intent == ASRUtils::intent_out) { - intent_out_passed_vars.push_back( - m_al, ASRUtils::EXPR2VAR(x.m_args[i].m_value)); + if (arg->m_intent == ASRUtils::intent_out || + arg->m_intent == ASRUtils::intent_inout || + arg->m_intent == ASRUtils::intent_unspecified) { + vars_passed_by_refs.push_back(m_al, x.m_args[i].m_value); } visit_expr(*x.m_args[i].m_value); } @@ -1645,13 +1648,24 @@ class ASRToWASMVisitor : public ASR::BaseVisitor { m_func_name_idx_map.end()) wasm::emit_call(m_code_section, m_al, m_func_name_idx_map[get_hash((ASR::asr_t *)s)]->index); - for (auto return_var : intent_out_passed_vars) { - LFORTRAN_ASSERT( + for (auto return_expr : vars_passed_by_refs) { + if (ASR::is_a(*return_expr)) { + auto return_var = ASRUtils::EXPR2VAR(return_expr); + LFORTRAN_ASSERT( m_var_name_idx_map.find(get_hash((ASR::asr_t *)return_var)) != m_var_name_idx_map.end()); - wasm::emit_set_local( - m_code_section, m_al, - m_var_name_idx_map[get_hash((ASR::asr_t *)return_var)]); + wasm::emit_set_local( + m_code_section, m_al, + m_var_name_idx_map[get_hash((ASR::asr_t *)return_var)]); + } else if (ASR::is_a(*return_expr)) { + // emit_memory_store(ASRUtils::EXPR(return_var)); + + throw CodeGenError( + "Passing array elements as arguments (with intent out, " + "inout, unspecified) to Subroutines is not yet supported"); + } else { + LFORTRAN_ASSERT(false); + } } } diff --git a/src/libasr/codegen/llvm_array_utils.cpp b/src/libasr/codegen/llvm_array_utils.cpp index bd34175ac6..4298b89417 100644 --- a/src/libasr/codegen/llvm_array_utils.cpp +++ b/src/libasr/codegen/llvm_array_utils.cpp @@ -33,10 +33,7 @@ namespace LFortran { is_ok = false; break; } - if( (m_dims[r].m_length != nullptr && - m_dims[r].m_length->type != ASR::exprType::IntegerConstant) || - (m_dims[r].m_start != nullptr && - m_dims[r].m_start->type != ASR::exprType::IntegerConstant) ) { + if( m_dims[r].m_length == nullptr ) { is_ok = false; break; } diff --git a/src/libasr/pass/class_constructor.cpp b/src/libasr/pass/class_constructor.cpp index 62921aba4f..de3ebcf313 100644 --- a/src/libasr/pass/class_constructor.cpp +++ b/src/libasr/pass/class_constructor.cpp @@ -21,24 +21,47 @@ class ClassConstructorVisitor : public PassUtils::PassVisitor(x); if( x.m_value->type == ASR::exprType::StructTypeConstructor ) { is_constructor_present = true; - result_var = x.m_target; - visit_expr(*x.m_value); + if( x.m_overloaded == nullptr ) { + result_var = x.m_target; + visit_expr(*x.m_value); + } else { + std::string result_var_name = current_scope->get_unique_name("temp_struct_var__"); + result_var = PassUtils::create_auxiliary_variable(x.m_value->base.loc, result_var_name, + al, current_scope, ASRUtils::expr_type(x.m_target)); + visit_expr(*x.m_value); + ASR::stmt_t* x_m_overloaded = x.m_overloaded; + if( ASR::is_a(*x.m_overloaded) ) { + ASR::SubroutineCall_t* assign_call = ASR::down_cast(xx.m_overloaded); + Vec assign_call_args; + assign_call_args.reserve(al, 2); + assign_call_args.push_back(al, assign_call->m_args[0]); + ASR::call_arg_t arg_1; + arg_1.loc = assign_call->m_args[1].loc; + arg_1.m_value = result_var; + assign_call_args.push_back(al, arg_1); + x_m_overloaded = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x_m_overloaded->base.loc, + assign_call->m_name, assign_call->m_original_name, assign_call_args.p, + assign_call_args.size(), assign_call->m_dt)); + } + pass_result.push_back(al, ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, x.m_target, + result_var, x_m_overloaded))); + } } + } void visit_StructTypeConstructor(const ASR::StructTypeConstructor_t &x) { - is_init_constructor = true; if( x.n_args == 0 ) { remove_original_stmt = true; } diff --git a/src/libasr/pass/instantiate_template.cpp b/src/libasr/pass/instantiate_template.cpp index 88255cb0a8..25acedf90c 100644 --- a/src/libasr/pass/instantiate_template.cpp +++ b/src/libasr/pass/instantiate_template.cpp @@ -30,7 +30,7 @@ class FunctionInstantiator : public ASR::BaseExprStmtDuplicator(x->m_symtab->parent); + current_scope = al.make_new(func_scope); Vec args; args.reserve(al, x->n_args); @@ -150,7 +150,7 @@ class FunctionInstantiator : public ASR::BaseExprStmtDuplicatorm_static, nullptr, 0, nullptr, 0, false); ASR::symbol_t *t = ASR::down_cast(result); - x->m_symtab->parent->add_symbol(new_func_name, t); + func_scope->add_symbol(new_func_name, t); return result; } @@ -199,15 +199,7 @@ class FunctionInstantiator : public ASR::BaseExprStmtDuplicatorm_target); - ASR::ttype_t *target_type = substitute_type(ASRUtils::expr_type(x->m_target)); ASR::expr_t *value = duplicate_expr(x->m_value); - if (ASRUtils::is_real(*target_type) && ASR::is_a(*x->m_value)) { - ASR::IntegerConstant_t *int_value = ASR::down_cast(x->m_value); - if (int_value->m_n == 0) { - value = ASRUtils::EXPR(ASR::make_RealConstant_t(al, value->base.loc, 0, - ASRUtils::duplicate_type(al, target_type))); - } - } ASR::stmt_t *overloaded = duplicate_stmt(x->m_overloaded); return ASR::make_Assignment_t(al, x->base.base.loc, target, value, overloaded); } @@ -257,39 +249,45 @@ class FunctionInstantiator : public ASR::BaseExprStmtDuplicatorm_value); ASR::expr_t* dt = duplicate_expr(x->m_dt); std::string call_name = ASRUtils::symbol_name(x->m_name); - //for (ASR::Function_t* rt: rts) { - // if (call_name.compare(rt->m_name) == 0) { - if (rt_subs.find(call_name) == rt_subs.end()) { - if (call_name.compare("add") == 0) { - ASR::expr_t* left_arg = duplicate_expr(x->m_args[0].m_value); - ASR::expr_t* right_arg = duplicate_expr(x->m_args[1].m_value); - ASR::ttype_t* left_type = substitute_type(ASRUtils::expr_type(left_arg)); - ASR::ttype_t* right_type = substitute_type(ASRUtils::expr_type(right_arg)); - if ((ASRUtils::is_integer(*left_type) && ASRUtils::is_integer(*right_type)) || - (ASRUtils::is_real(*left_type) && ASRUtils::is_real(*right_type))) { - return make_BinOp_helper(left_arg, right_arg, ASR::binopType::Add, x->base.base.loc); - } else { - throw SemanticError("Intrinsic plus not yet supported for this type", x->base.base.loc); - } - } else if (call_name.compare("zero") == 0) { - ASR::expr_t* arg = duplicate_expr(x->m_args[0].m_value); - ASR::ttype_t* arg_type = substitute_type(ASRUtils::expr_type(arg)); - if (ASRUtils::is_integer(*arg_type)) { - return ASR::make_IntegerConstant_t(al, x->base.base.loc, 0, arg_type); - } else if (ASRUtils::is_real(*arg_type)) { - return ASR::make_RealConstant_t(al, x->base.base.loc, 0, arg_type); - } - } else if (call_name.compare("div") == 0) { - ASR::expr_t* left_arg = duplicate_expr(x->m_args[0].m_value); - ASR::expr_t* right_arg = duplicate_expr(x->m_args[1].m_value); - return make_BinOp_helper(left_arg, right_arg, ASR::binopType::Div, x->base.base.loc); - } - LFORTRAN_ASSERT(false); // should never happen + if ((name && ASRUtils::is_restriction_function(name) && rt_subs.find(call_name) == rt_subs.end()) || + !name) { + if (call_name.compare("add") == 0) { + ASR::expr_t* left_arg = duplicate_expr(x->m_args[0].m_value); + ASR::expr_t* right_arg = duplicate_expr(x->m_args[1].m_value); + ASR::ttype_t* left_type = substitute_type(ASRUtils::expr_type(left_arg)); + ASR::ttype_t* right_type = substitute_type(ASRUtils::expr_type(right_arg)); + if ((ASRUtils::is_integer(*left_type) && ASRUtils::is_integer(*right_type)) || + (ASRUtils::is_real(*left_type) && ASRUtils::is_real(*right_type))) { + return make_BinOp_helper(left_arg, right_arg, ASR::binopType::Add, x->base.base.loc); + } else { + throw SemanticError("Intrinsic plus not yet supported for this type", x->base.base.loc); + } + } else if (call_name.compare("zero") == 0) { + ASR::expr_t* arg = duplicate_expr(x->m_args[0].m_value); + ASR::ttype_t* arg_type = substitute_type(ASRUtils::expr_type(arg)); + if (ASRUtils::is_integer(*arg_type)) { + return ASR::make_IntegerConstant_t(al, x->base.base.loc, 0, arg_type); + } else if (ASRUtils::is_real(*arg_type)) { + return ASR::make_RealConstant_t(al, x->base.base.loc, 0, arg_type); } - name = rt_subs[call_name]; - // } - //} - // TODO: Nested generic function call + } else if (call_name.compare("div") == 0) { + ASR::expr_t* left_arg = duplicate_expr(x->m_args[0].m_value); + ASR::expr_t* right_arg = duplicate_expr(x->m_args[1].m_value); + return make_BinOp_helper(left_arg, right_arg, ASR::binopType::Div, x->base.base.loc); + } + LFORTRAN_ASSERT(false); // should never happen + name = rt_subs[call_name]; + } + if (ASRUtils::is_restriction_function(name)) { + name = rt_subs[call_name]; + } else if (ASRUtils::is_generic_function(name)) { + std::string nested_func_name = "__lfortran_generic_" + sym_name; + ASR::symbol_t* name2 = ASRUtils::symbol_get_past_external(name); + ASR::Function_t* func = ASR::down_cast(name2); + FunctionInstantiator nested_tf(al, subs, rt_subs, func_scope, nested_func_name); + ASR::asr_t* nested_generic_func = nested_tf.instantiate_Function(func); + name = ASR::down_cast(nested_generic_func); + } dependencies.insert(std::string(ASRUtils::symbol_name(name))); return ASR::make_FunctionCall_t(al, x->base.base.loc, name, x->m_original_name, args.p, args.size(), type, value, dt); @@ -427,7 +425,17 @@ class FunctionInstantiator : public ASR::BaseExprStmtDuplicator subs, - std::map rt_subs, SymbolTable *current_scope, + std::map& rt_subs, SymbolTable *current_scope, + std::string new_func_name, ASR::symbol_t *sym) { + ASR::symbol_t* sym2 = ASRUtils::symbol_get_past_external(sym); + ASR::Function_t* func = ASR::down_cast(sym2); + FunctionInstantiator tf(al, subs, rt_subs, current_scope, new_func_name); + ASR::asr_t *new_function = tf.instantiate_Function(func); + return ASR::down_cast(new_function); +} + +ASR::symbol_t* pass_instantiate_generic_function(Allocator &al, std::map subs, + std::map& rt_subs, SymbolTable *current_scope, std::string new_func_name, ASR::Function_t *func) { FunctionInstantiator tf(al, subs, rt_subs, current_scope, new_func_name); ASR::asr_t *new_function = tf.instantiate_Function(func); diff --git a/src/libasr/pass/instantiate_template.h b/src/libasr/pass/instantiate_template.h index 41448f26cd..fd1b086427 100644 --- a/src/libasr/pass/instantiate_template.h +++ b/src/libasr/pass/instantiate_template.h @@ -10,9 +10,13 @@ namespace LFortran { * contain any type parameters and restrictions. No type checking * is executed here */ - ASR::symbol_t* pass_instantiate_generic_function(Allocator &al, - std::map subs, std::map rt_subs, - SymbolTable *current_scope, std::string new_func_name, ASR::Function_t *func); + ASR::symbol_t* pass_instantiate_generic_function(Allocator &al, + std::map subs, std::map& rt_subs, + SymbolTable *current_scope, std::string new_func_name, ASR::symbol_t *sym); + + ASR::symbol_t* pass_instantiate_generic_function(Allocator &al, + std::map subs, std::map& rt_subs, + SymbolTable *current_scope, std::string new_func_name, ASR::Function_t *sym); } -#endif // LFORTRAN_PASS_TEMPLATE_VISITOR_H \ No newline at end of file +#endif // LFORTRAN_PASS_TEMPLATE_VISITOR_H diff --git a/src/libasr/pass/nested_vars.cpp b/src/libasr/pass/nested_vars.cpp index 9d066838c8..90d118c371 100644 --- a/src/libasr/pass/nested_vars.cpp +++ b/src/libasr/pass/nested_vars.cpp @@ -228,7 +228,7 @@ class NestedVarVisitor : public ASR::BaseWalkVisitor } void visit_FunctionCall(const ASR::FunctionCall_t &x) { - if (nesting_depth == 1){ + if (nesting_depth) { // Have to save all the calls out and make sure they are not solely // to the nested function ASR::Function_t *s = ASR::down_cast( @@ -238,12 +238,17 @@ class NestedVarVisitor : public ASR::BaseWalkVisitor calls_to.end()){ calls_to.push_back(call_hash); } + for (size_t i=0; i(*x.m_args[i].m_value)) { + visit_Var(*ASR::down_cast(x.m_args[i].m_value)); + } + } calls_out = true; } } void visit_SubroutineCall(const ASR::SubroutineCall_t &x) { - if (nesting_depth == 1){ + if (nesting_depth) { ASR::Function_t *s = ASR::down_cast( LFortran::ASRUtils::symbol_get_past_external(x.m_name)); uint32_t call_hash = get_hash((ASR::asr_t*)s); @@ -251,6 +256,11 @@ class NestedVarVisitor : public ASR::BaseWalkVisitor calls_to.end()){ calls_to.push_back(call_hash); } + for (size_t i=0; i(*x.m_args[i].m_value)) { + visit_Var(*ASR::down_cast(x.m_args[i].m_value)); + } + } calls_out = true; } } diff --git a/src/libasr/pass/pass_manager.h b/src/libasr/pass/pass_manager.h index e1ebf1e4ce..c342e8ca81 100644 --- a/src/libasr/pass/pass_manager.h +++ b/src/libasr/pass/pass_manager.h @@ -128,7 +128,7 @@ namespace LCompilers { "forall", "select_case", "inline_function_calls", - "unused_functions", + "unused_functions" }; _with_optimization_passes = { diff --git a/src/libasr/pass/pass_utils.h b/src/libasr/pass/pass_utils.h index 7a52a0020b..0eaa93ec02 100644 --- a/src/libasr/pass/pass_utils.h +++ b/src/libasr/pass/pass_utils.h @@ -191,6 +191,39 @@ namespace LFortran { transform_stmts(xx.m_body, xx.n_body); } + void visit_If(const ASR::If_t& x) { + ASR::If_t &xx = const_cast(x); + self().visit_expr(*xx.m_test); + transform_stmts(xx.m_body, xx.n_body); + transform_stmts(xx.m_orelse, xx.n_orelse); + } + + void visit_CaseStmt(const ASR::CaseStmt_t& x) { + ASR::CaseStmt_t &xx = const_cast(x); + for (size_t i=0; i(x); + if (xx.m_start) + self().visit_expr(*xx.m_start); + if (xx.m_end) + self().visit_expr(*xx.m_end); + transform_stmts(xx.m_body, xx.n_body); + } + + void visit_Select(const ASR::Select_t& x) { + ASR::Select_t &xx = const_cast(x); + self().visit_expr(*xx.m_test); + for (size_t i=0; i diff --git a/src/libasr/pass/print_arr.cpp b/src/libasr/pass/print_arr.cpp index ec271c0c46..a9e4c3db07 100644 --- a/src/libasr/pass/print_arr.cpp +++ b/src/libasr/pass/print_arr.cpp @@ -25,6 +25,21 @@ The function `pass_replace_print_arr` transforms the ASR tree in-place. do i = 1, 3 print *, y(i) end do + + +Converts: + a: not_array + b: array + c: not_array + d: not_array + print *, a, b(1:10), c, d + +to: + print *, a + do i = 1, 10 + print *, b(i) + end do + print *, c, d */ class PrintArrVisitor : public PassUtils::PassVisitor @@ -38,41 +53,78 @@ class PrintArrVisitor : public PassUtils::PassVisitor } + ASR::stmt_t* print_array_using_doloop(ASR::expr_t *arr_expr, const Location &loc) { + int n_dims = PassUtils::get_rank(arr_expr); + Vec idx_vars; + PassUtils::create_idx_vars(idx_vars, n_dims, loc, al, current_scope); + ASR::stmt_t* doloop = nullptr; + ASR::stmt_t* empty_print_endl = LFortran::ASRUtils::STMT(ASR::make_Print_t(al, loc, + nullptr, nullptr, 0, nullptr, nullptr)); + for( int i = n_dims - 1; i >= 0; i-- ) { + ASR::do_loop_head_t head; + head.m_v = idx_vars[i]; + head.m_start = PassUtils::get_bound(arr_expr, i + 1, "lbound", al); + head.m_end = PassUtils::get_bound(arr_expr, i + 1, "ubound", al); + head.m_increment = nullptr; + head.loc = head.m_v->base.loc; + Vec doloop_body; + doloop_body.reserve(al, 1); + if( doloop == nullptr ) { + ASR::expr_t* ref = PassUtils::create_array_ref(arr_expr, idx_vars, al); + Vec print_args; + print_args.reserve(al, 1); + print_args.push_back(al, ref); + ASR::stmt_t* print_stmt = LFortran::ASRUtils::STMT(ASR::make_Print_t(al, loc, nullptr, + print_args.p, print_args.size(), nullptr, nullptr)); + doloop_body.push_back(al, print_stmt); + } else { + doloop_body.push_back(al, doloop); + doloop_body.push_back(al, empty_print_endl); + } + doloop = LFortran::ASRUtils::STMT(ASR::make_DoLoop_t(al, loc, head, doloop_body.p, doloop_body.size())); + } + return doloop; + } + void visit_Print(const ASR::Print_t& x) { - if( x.n_values == 1 && PassUtils::is_array(x.m_values[0]) ) { - ASR::expr_t* arr_expr = x.m_values[0]; - - int n_dims = PassUtils::get_rank(arr_expr); - Vec idx_vars; - PassUtils::create_idx_vars(idx_vars, n_dims, x.base.base.loc, al, current_scope); - ASR::stmt_t* doloop = nullptr; - ASR::stmt_t* empty_print_endl = LFortran::ASRUtils::STMT(ASR::make_Print_t(al, x.base.base.loc, - nullptr, nullptr, 0, nullptr, nullptr)); - for( int i = n_dims - 1; i >= 0; i-- ) { - ASR::do_loop_head_t head; - head.m_v = idx_vars[i]; - head.m_start = PassUtils::get_bound(arr_expr, i + 1, "lbound", al); - head.m_end = PassUtils::get_bound(arr_expr, i + 1, "ubound", al); - head.m_increment = nullptr; - head.loc = head.m_v->base.loc; - Vec doloop_body; - doloop_body.reserve(al, 1); - if( doloop == nullptr ) { - ASR::expr_t* ref = PassUtils::create_array_ref(arr_expr, idx_vars, al); - Vec print_args; - print_args.reserve(al, 1); - print_args.push_back(al, ref); - ASR::stmt_t* print_stmt = LFortran::ASRUtils::STMT(ASR::make_Print_t(al, x.base.base.loc, nullptr, - print_args.p, print_args.size(), nullptr, nullptr)); - doloop_body.push_back(al, print_stmt); - } else { - doloop_body.push_back(al, doloop); - doloop_body.push_back(al, empty_print_endl); + std::vector print_body; + ASR::stmt_t* empty_print_endl = LFortran::ASRUtils::STMT(ASR::make_Print_t(al, x.base.base.loc, + nullptr, nullptr, 0, nullptr, nullptr)); + ASR::stmt_t* print_stmt; + for (size_t i=0; i(*ASRUtils::expr_type(x.m_values[i])) && + PassUtils::is_array(x.m_values[i])) { + if (print_body.size() > 0) { + Vec body; + body.reserve(al, print_body.size()); + for (size_t j=0; j 0) { + Vec body; + body.reserve(al, print_body.size()); + for (size_t j=0; j { ASRUtils::ExprStmtDuplicator expr_duplicator(al); expr_duplicator.allow_procedure_calls = true; ASRUtils::ReplaceArgVisitor arg_replacer(al, current_scope, orig_func, - orig_args); + orig_args, dependencies); for( size_t i = 0; i < exprs.size(); i++ ) { ASR::expr_t* expri = exprs[i]; if (expri) { @@ -1336,7 +1336,7 @@ class CommonVisitor : public AST::BaseVisitor { * arguments. If not, then instantiate a new function. */ ASR::symbol_t* get_generic_function(std::map subs, - std::map rt_subs, ASR::Function_t *func) { + std::map& rt_subs, ASR::Function_t *func) { int new_function_num; ASR::symbol_t *t; std::string func_name = func->m_name; @@ -1371,7 +1371,7 @@ class CommonVisitor : public AST::BaseVisitor { std::string new_func_name = "__lpython_generic_" + func_name + "_" + std::to_string(new_function_num); generic_func_subs[new_func_name] = subs; - t = pass_instantiate_generic_function(al, subs, rt_subs, current_scope, + t = pass_instantiate_generic_function(al, subs, rt_subs, func->m_symtab->parent, new_func_name, func); dependencies.erase(func_name); dependencies.insert(new_func_name);