|
| 1 | +#include <libasr/asr.h> |
| 2 | +#include <libasr/containers.h> |
| 3 | +#include <libasr/exception.h> |
| 4 | +#include <libasr/asr_utils.h> |
| 5 | +#include <libasr/asr_verify.h> |
| 6 | +#include <libasr/pass/pass_utils.h> |
| 7 | +#include <libasr/pass/print_list.h> |
| 8 | + |
| 9 | +namespace LFortran { |
| 10 | + |
| 11 | +/* |
| 12 | +This ASR pass replaces print list with print every value, |
| 13 | +comma_space, brackets and newline. The function |
| 14 | +`pass_replace_print_list` transforms the ASR tree in-place. |
| 15 | +
|
| 16 | +Converts: |
| 17 | +
|
| 18 | + print(l, sep="pqr", end="xyz") # l is a list |
| 19 | +
|
| 20 | +to: |
| 21 | +
|
| 22 | + print("[", end="") |
| 23 | + for i in range(len(l)): |
| 24 | + print(l[i], end="") |
| 25 | + if i < len(l) - 1: |
| 26 | + print(", ", end="") |
| 27 | + print("]", sep="pqr", end="xyz") |
| 28 | +
|
| 29 | +for nested lists it transforms to: |
| 30 | +
|
| 31 | + print("[", end="") |
| 32 | + for i in range(len(l)): |
| 33 | + # print(l[i], end="") is replaced by the following code |
| 34 | +
|
| 35 | + print("[", end="") |
| 36 | + for i1 in range(len(l[i])): |
| 37 | + print(l[i][i1], end="") |
| 38 | + if i1 < len(l[i]) - 1: |
| 39 | + print(", ", end="") |
| 40 | + print("]") |
| 41 | +
|
| 42 | + if i < len(l) - 1: |
| 43 | + print(", ", end="") |
| 44 | + print("]", sep="pqr", end="xyz") |
| 45 | +
|
| 46 | +Note: In code, the variable `i` is named as `__list_iterator` |
| 47 | +*/ |
| 48 | + |
| 49 | +class PrintListVisitor |
| 50 | + : public PassUtils::PassVisitor<PrintListVisitor> { |
| 51 | + private: |
| 52 | + std::string rl_path; |
| 53 | + |
| 54 | + public: |
| 55 | + PrintListVisitor(Allocator &al, const std::string &rl_path_) |
| 56 | + : PassVisitor(al, nullptr), rl_path(rl_path_) { |
| 57 | + pass_result.reserve(al, 1); |
| 58 | + } |
| 59 | + |
| 60 | + void visit_Print(const ASR::Print_t &x) { |
| 61 | + if (x.n_values == 1 && |
| 62 | + ASR::is_a<ASR::List_t>(*ASRUtils::expr_type(x.m_values[0]))) { |
| 63 | + ASR::List_t *listC = |
| 64 | + ASR::down_cast<ASR::List_t>(ASRUtils::expr_type(x.m_values[0])); |
| 65 | + |
| 66 | + ASR::ttype_t *int_type = ASRUtils::TYPE( |
| 67 | + ASR::make_Integer_t(al, x.base.base.loc, 4, nullptr, 0)); |
| 68 | + ASR::ttype_t *bool_type = ASRUtils::TYPE( |
| 69 | + ASR::make_Logical_t(al, x.base.base.loc, 4, nullptr, 0)); |
| 70 | + ASR::ttype_t *str_type_len_0 = ASRUtils::TYPE(ASR::make_Character_t( |
| 71 | + al, x.base.base.loc, 1, 0, nullptr, nullptr, 0)); |
| 72 | + ASR::ttype_t *str_type_len_1 = ASRUtils::TYPE(ASR::make_Character_t( |
| 73 | + al, x.base.base.loc, 1, 1, nullptr, nullptr, 0)); |
| 74 | + ASR::ttype_t *str_type_len_2 = ASRUtils::TYPE(ASR::make_Character_t( |
| 75 | + al, x.base.base.loc, 1, 2, nullptr, nullptr, 0)); |
| 76 | + ASR::expr_t *comma_space = |
| 77 | + ASRUtils::EXPR(ASR::make_StringConstant_t( |
| 78 | + al, x.base.base.loc, s2c(al, ", "), str_type_len_2)); |
| 79 | + ASR::expr_t *single_quote = |
| 80 | + ASRUtils::EXPR(ASR::make_StringConstant_t( |
| 81 | + al, x.base.base.loc, s2c(al, "'"), str_type_len_1)); |
| 82 | + ASR::expr_t *empty_str = ASRUtils::EXPR(ASR::make_StringConstant_t( |
| 83 | + al, x.base.base.loc, s2c(al, ""), str_type_len_0)); |
| 84 | + ASR::expr_t *open_bracket = |
| 85 | + ASRUtils::EXPR(ASR::make_StringConstant_t( |
| 86 | + al, x.base.base.loc, s2c(al, "["), str_type_len_1)); |
| 87 | + ASR::expr_t *close_bracket = |
| 88 | + ASRUtils::EXPR(ASR::make_StringConstant_t( |
| 89 | + al, x.base.base.loc, s2c(al, "]"), str_type_len_1)); |
| 90 | + |
| 91 | + std::string list_iter_var_name; |
| 92 | + ASR::symbol_t *list_iter_variable; |
| 93 | + ASR::expr_t *list_iter_var; |
| 94 | + { |
| 95 | + list_iter_var_name = |
| 96 | + current_scope->get_unique_name("__list_iterator"); |
| 97 | + list_iter_variable = |
| 98 | + ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t( |
| 99 | + al, x.base.base.loc, current_scope, |
| 100 | + s2c(al, list_iter_var_name), ASR::intentType::Local, |
| 101 | + nullptr, nullptr, ASR::storage_typeType::Default, |
| 102 | + int_type, ASR::abiType::Source, ASR::accessType::Public, |
| 103 | + ASR::presenceType::Required, false)); |
| 104 | + |
| 105 | + current_scope->add_symbol(list_iter_var_name, |
| 106 | + list_iter_variable); |
| 107 | + list_iter_var = ASRUtils::EXPR( |
| 108 | + ASR::make_Var_t(al, x.base.base.loc, list_iter_variable)); |
| 109 | + } |
| 110 | + |
| 111 | + ASR::expr_t *list_item = ASRUtils::EXPR( |
| 112 | + ASR::make_ListItem_t(al, x.base.base.loc, x.m_values[0], |
| 113 | + list_iter_var, listC->m_type, nullptr)); |
| 114 | + ASR::expr_t *list_len = ASRUtils::EXPR(ASR::make_ListLen_t( |
| 115 | + al, x.base.base.loc, x.m_values[0], int_type, nullptr)); |
| 116 | + ASR::expr_t *constant_one = ASRUtils::EXPR( |
| 117 | + ASR::make_IntegerConstant_t(al, x.base.base.loc, 1, int_type)); |
| 118 | + ASR::expr_t *list_len_minus_one = |
| 119 | + ASRUtils::EXPR(ASR::make_IntegerBinOp_t( |
| 120 | + al, x.base.base.loc, list_len, ASR::binopType::Sub, |
| 121 | + constant_one, int_type, nullptr)); |
| 122 | + ASR::expr_t *compare_cond = |
| 123 | + ASRUtils::EXPR(ASR::make_IntegerCompare_t( |
| 124 | + al, x.base.base.loc, list_iter_var, ASR::cmpopType::Lt, |
| 125 | + list_len_minus_one, bool_type, nullptr)); |
| 126 | + |
| 127 | + Vec<ASR::expr_t *> v1, v2, v3, v4; |
| 128 | + v1.reserve(al, 1); |
| 129 | + v3.reserve(al, 1); |
| 130 | + v4.reserve(al, 1); |
| 131 | + v1.push_back(al, open_bracket); |
| 132 | + v3.push_back(al, close_bracket); |
| 133 | + v4.push_back(al, comma_space); |
| 134 | + |
| 135 | + if (ASR::is_a<ASR::Character_t>(*listC->m_type)) { |
| 136 | + v2.reserve(al, 3); |
| 137 | + v2.push_back(al, single_quote); |
| 138 | + v2.push_back(al, list_item); |
| 139 | + v2.push_back(al, single_quote); |
| 140 | + } else { |
| 141 | + v2.reserve(al, 1); |
| 142 | + v2.push_back(al, list_item); |
| 143 | + } |
| 144 | + |
| 145 | + ASR::stmt_t *print_open_bracket = LFortran::ASRUtils::STMT( |
| 146 | + ASR::make_Print_t(al, x.base.base.loc, nullptr, v1.p, v1.size(), |
| 147 | + nullptr, empty_str)); |
| 148 | + ASR::stmt_t *print_comma_space = ASRUtils::STMT( |
| 149 | + ASR::make_Print_t(al, x.base.base.loc, nullptr, v4.p, v4.size(), |
| 150 | + empty_str, empty_str)); |
| 151 | + ASR::stmt_t *print_item = ASRUtils::STMT( |
| 152 | + ASR::make_Print_t(al, x.base.base.loc, nullptr, v2.p, v2.size(), |
| 153 | + empty_str, empty_str)); |
| 154 | + ASR::stmt_t *print_close_bracket = LFortran::ASRUtils::STMT( |
| 155 | + ASR::make_Print_t(al, x.base.base.loc, nullptr, v3.p, v3.size(), |
| 156 | + x.m_separator, x.m_end)); |
| 157 | + |
| 158 | + Vec<ASR::stmt_t *> if_body; |
| 159 | + if_body.reserve(al, 1); |
| 160 | + if_body.push_back(al, print_comma_space); |
| 161 | + |
| 162 | + ASR::stmt_t *if_cond = ASRUtils::STMT( |
| 163 | + ASR::make_If_t(al, x.base.base.loc, compare_cond, if_body.p, |
| 164 | + if_body.size(), nullptr, 0)); |
| 165 | + |
| 166 | + ASR::do_loop_head_t loop_head; |
| 167 | + Vec<ASR::stmt_t *> loop_body; |
| 168 | + { |
| 169 | + loop_head.loc = x.base.base.loc; |
| 170 | + loop_head.m_v = list_iter_var; |
| 171 | + loop_head.m_start = ASRUtils::EXPR(ASR::make_IntegerConstant_t( |
| 172 | + al, x.base.base.loc, 0, int_type)); |
| 173 | + loop_head.m_end = list_len_minus_one; |
| 174 | + loop_head.m_increment = |
| 175 | + ASRUtils::EXPR(ASR::make_IntegerConstant_t( |
| 176 | + al, x.base.base.loc, 1, int_type)); |
| 177 | + |
| 178 | + if (!ASR::is_a<ASR::List_t>(*listC->m_type)) { |
| 179 | + loop_body.reserve(al, 2); |
| 180 | + loop_body.push_back(al, print_item); |
| 181 | + } else { |
| 182 | + visit_Print(*ASR::down_cast<ASR::Print_t>(print_item)); |
| 183 | + loop_body.from_pointer_n_copy(al, pass_result.p, pass_result.size()); |
| 184 | + pass_result.n = 0; |
| 185 | + } |
| 186 | + loop_body.push_back(al, if_cond); |
| 187 | + } |
| 188 | + |
| 189 | + ASR::stmt_t *loop = ASRUtils::STMT(ASR::make_DoLoop_t( |
| 190 | + al, x.base.base.loc, loop_head, loop_body.p, loop_body.size())); |
| 191 | + |
| 192 | + { |
| 193 | + pass_result.push_back(al, print_open_bracket); |
| 194 | + pass_result.push_back(al, loop); |
| 195 | + pass_result.push_back(al, print_close_bracket); |
| 196 | + } |
| 197 | + } |
| 198 | + } |
| 199 | +}; |
| 200 | + |
| 201 | +void pass_replace_print_list( |
| 202 | + Allocator &al, ASR::TranslationUnit_t &unit, |
| 203 | + const LCompilers::PassOptions &pass_options) { |
| 204 | + std::string rl_path = pass_options.runtime_library_dir; |
| 205 | + PrintListVisitor v(al, rl_path); |
| 206 | + v.visit_TranslationUnit(unit); |
| 207 | + LFORTRAN_ASSERT(asr_verify(unit)); |
| 208 | +} |
| 209 | + |
| 210 | +} // namespace LFortran |
0 commit comments