Skip to content

Support passing dict as return value of functions #1659

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Apr 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,8 @@ RUN(NAME test_dict_02 LABELS cpython llvm c)
RUN(NAME test_dict_03 LABELS cpython llvm)
RUN(NAME test_dict_04 LABELS cpython llvm)
RUN(NAME test_dict_05 LABELS cpython llvm)
RUN(NAME test_dict_06 LABELS cpython llvm c)
RUN(NAME test_dict_07 LABELS cpython llvm)
RUN(NAME test_for_loop LABELS cpython llvm c)
RUN(NAME modules_01 LABELS cpython llvm c wasm wasm_x86 wasm_x64)
RUN(NAME modules_02 LABELS cpython llvm c wasm wasm_x86 wasm_x64)
Expand Down
24 changes: 24 additions & 0 deletions integration_tests/test_dict_06.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from lpython import f64, i32

def fill_rollnumber2cpi(size: i32) -> dict[i32, f64]:
i : i32
rollnumber2cpi: dict[i32, f64] = {}

rollnumber2cpi[0] = 1.1
for i in range(1000, 1000 + size):
rollnumber2cpi[i] = float(i/100.0 + 5.0)

return rollnumber2cpi

def test_dict():
i: i32
size: i32 = 200
rollnumber2cpi: dict[i32, f64] = fill_rollnumber2cpi(size)

for i in range(1000 + size - 1, 1001, -1):
assert abs(rollnumber2cpi[i] - i/100.0 - 5.0) <= 1e-12

assert abs(rollnumber2cpi[0] - 1.1) <= 1e-12
assert len(rollnumber2cpi) == 201

test_dict()
18 changes: 18 additions & 0 deletions integration_tests/test_dict_07.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
def fill_smalltocapital() -> dict[str, str]:
return {'a': 'A', 'b': 'B', 'c': 'C', 'd': 'D','e': 'E',
'f': 'F', 'g': 'G', 'h': 'H', 'i': 'I','j': 'J',
'k': 'K', 'l': 'L', 'm': 'M', 'n': 'N','o': 'O',
'p': 'P', 'q': 'Q', 'r': 'R', 's': 'S','t': 'T',
'u': 'U', 'v': 'V', 'w': 'W', 'x': 'X','y': 'Y',
'z': 'Z'}

def test_dict():
i : i32
smalltocaps: dict[str, str]
smalltocaps = fill_smalltocapital()

assert len(smalltocaps) == 26
for i in range(97, 97 + 26):
assert smalltocaps[chr(i)] == chr(i - 32)

test_dict()
16 changes: 16 additions & 0 deletions src/libasr/codegen/asr_to_c_cpp.h
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,9 @@ R"(#include <stdio.h>
} else if (ASR::is_a<ASR::TypeParameter_t>(*return_var->m_type)) {
has_typevar = true;
return "";
} else if (ASR::is_a<ASR::Dict_t>(*return_var->m_type)) {
ASR::Dict_t* dict_type = ASR::down_cast<ASR::Dict_t>(return_var->m_type);
sub = c_ds_api->get_dict_type(dict_type) + " ";
} else {
throw CodeGenError("Return type not supported in function '" +
std::string(x.m_name) +
Expand Down Expand Up @@ -661,6 +664,16 @@ R"(#include <stdio.h>
const_name + " = " + src + ";\n");
src = const_name;
return;
} else if( ASR::is_a<ASR::Dict_t>(*x.m_type) ) {
ASR::Dict_t* dict_type = ASR::down_cast<ASR::Dict_t>(x.m_type);
const_name += std::to_string(const_vars_count);
const_vars_count += 1;
const_name = current_scope->get_unique_name(const_name);
std::string indent(indentation_level*indentation_spaces, ' ');
tmp_buffer_src.push_back(check_tmp_buffer() + indent + c_ds_api->get_dict_type(dict_type) +
" " + const_name + " = " + src + ";\n");
src = const_name;
return;
}
src = check_tmp_buffer() + src;
}
Expand Down Expand Up @@ -797,6 +810,9 @@ R"(#include <stdio.h>
}
src = check_tmp_buffer() + src_tmp;
return;
} else if (ASR::is_a<ASR::DictItem_t>(*x.m_target)) {
self().visit_DictItem(*ASR::down_cast<ASR::DictItem_t>(x.m_target));
target = src;
} else {
LCOMPILERS_ASSERT(false)
}
Expand Down
21 changes: 21 additions & 0 deletions src/libasr/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4084,6 +4084,27 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
return_type = list_api->get_list_type(el_llvm_type, el_type_code, type_size);
break;
}
case (ASR::ttypeType::Dict) : {
ASR::Dict_t* asr_dict = ASR::down_cast<ASR::Dict_t>(return_var_type0);
std::string key_type_code = ASRUtils::get_type_code(asr_dict->m_key_type);
std::string value_type_code = ASRUtils::get_type_code(asr_dict->m_value_type);

bool is_local_array_type = false, is_local_malloc_array_type = false;
bool is_local_list = false;
ASR::dimension_t* local_m_dims = nullptr;
ASR::storage_typeType local_m_storage = ASR::storage_typeType::Default;
int local_n_dims = 0, local_a_kind = -1;

llvm::Type* key_llvm_type = get_type_from_ttype_t(asr_dict->m_key_type, local_m_storage,is_local_array_type, is_local_malloc_array_type,is_local_list, local_m_dims, local_n_dims,local_a_kind);
llvm::Type* value_llvm_type = get_type_from_ttype_t(asr_dict->m_value_type, local_m_storage,is_local_array_type, is_local_malloc_array_type,is_local_list, local_m_dims, local_n_dims,local_a_kind);
int32_t key_type_size = get_type_size(asr_dict->m_key_type, key_llvm_type, local_a_kind);
int32_t value_type_size = get_type_size(asr_dict->m_value_type, value_llvm_type, local_a_kind);

set_dict_api(asr_dict);

return_type = llvm_utils->dict_api->get_dict_type(key_type_code, value_type_code, key_type_size,value_type_size, key_llvm_type, value_llvm_type);
break;
}
default :
throw CodeGenError("Type not implemented " + std::to_string(return_var_type));
}
Expand Down
3 changes: 2 additions & 1 deletion src/libasr/codegen/llvm_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ namespace LCompilers {
return ASR::is_a<ASR::Tuple_t>(*asr_type) ||
ASR::is_a<ASR::List_t>(*asr_type) ||
ASR::is_a<ASR::Struct_t>(*asr_type) ||
ASR::is_a<ASR::Class_t>(*asr_type);
ASR::is_a<ASR::Class_t>(*asr_type)||
ASR::is_a<ASR::Dict_t>(*asr_type);
}
}

Expand Down