Skip to content

Commit 2d25371

Browse files
Support passing dict as return value of functions (#1659)
1 parent fb0428c commit 2d25371

File tree

6 files changed

+83
-1
lines changed

6 files changed

+83
-1
lines changed

integration_tests/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,8 @@ RUN(NAME test_dict_02 LABELS cpython llvm c)
288288
RUN(NAME test_dict_03 LABELS cpython llvm)
289289
RUN(NAME test_dict_04 LABELS cpython llvm)
290290
RUN(NAME test_dict_05 LABELS cpython llvm)
291+
RUN(NAME test_dict_06 LABELS cpython llvm c)
292+
RUN(NAME test_dict_07 LABELS cpython llvm)
291293
RUN(NAME test_for_loop LABELS cpython llvm c)
292294
RUN(NAME modules_01 LABELS cpython llvm c wasm wasm_x86 wasm_x64)
293295
RUN(NAME modules_02 LABELS cpython llvm c wasm wasm_x86 wasm_x64)

integration_tests/test_dict_06.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from lpython import f64, i32
2+
3+
def fill_rollnumber2cpi(size: i32) -> dict[i32, f64]:
4+
i : i32
5+
rollnumber2cpi: dict[i32, f64] = {}
6+
7+
rollnumber2cpi[0] = 1.1
8+
for i in range(1000, 1000 + size):
9+
rollnumber2cpi[i] = float(i/100.0 + 5.0)
10+
11+
return rollnumber2cpi
12+
13+
def test_dict():
14+
i: i32
15+
size: i32 = 200
16+
rollnumber2cpi: dict[i32, f64] = fill_rollnumber2cpi(size)
17+
18+
for i in range(1000 + size - 1, 1001, -1):
19+
assert abs(rollnumber2cpi[i] - i/100.0 - 5.0) <= 1e-12
20+
21+
assert abs(rollnumber2cpi[0] - 1.1) <= 1e-12
22+
assert len(rollnumber2cpi) == 201
23+
24+
test_dict()

integration_tests/test_dict_07.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
def fill_smalltocapital() -> dict[str, str]:
2+
return {'a': 'A', 'b': 'B', 'c': 'C', 'd': 'D','e': 'E',
3+
'f': 'F', 'g': 'G', 'h': 'H', 'i': 'I','j': 'J',
4+
'k': 'K', 'l': 'L', 'm': 'M', 'n': 'N','o': 'O',
5+
'p': 'P', 'q': 'Q', 'r': 'R', 's': 'S','t': 'T',
6+
'u': 'U', 'v': 'V', 'w': 'W', 'x': 'X','y': 'Y',
7+
'z': 'Z'}
8+
9+
def test_dict():
10+
i : i32
11+
smalltocaps: dict[str, str]
12+
smalltocaps = fill_smalltocapital()
13+
14+
assert len(smalltocaps) == 26
15+
for i in range(97, 97 + 26):
16+
assert smalltocaps[chr(i)] == chr(i - 32)
17+
18+
test_dict()

src/libasr/codegen/asr_to_c_cpp.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,9 @@ R"(#include <stdio.h>
442442
} else if (ASR::is_a<ASR::TypeParameter_t>(*return_var->m_type)) {
443443
has_typevar = true;
444444
return "";
445+
} else if (ASR::is_a<ASR::Dict_t>(*return_var->m_type)) {
446+
ASR::Dict_t* dict_type = ASR::down_cast<ASR::Dict_t>(return_var->m_type);
447+
sub = c_ds_api->get_dict_type(dict_type) + " ";
445448
} else {
446449
throw CodeGenError("Return type not supported in function '" +
447450
std::string(x.m_name) +
@@ -661,6 +664,16 @@ R"(#include <stdio.h>
661664
const_name + " = " + src + ";\n");
662665
src = const_name;
663666
return;
667+
} else if( ASR::is_a<ASR::Dict_t>(*x.m_type) ) {
668+
ASR::Dict_t* dict_type = ASR::down_cast<ASR::Dict_t>(x.m_type);
669+
const_name += std::to_string(const_vars_count);
670+
const_vars_count += 1;
671+
const_name = current_scope->get_unique_name(const_name);
672+
std::string indent(indentation_level*indentation_spaces, ' ');
673+
tmp_buffer_src.push_back(check_tmp_buffer() + indent + c_ds_api->get_dict_type(dict_type) +
674+
" " + const_name + " = " + src + ";\n");
675+
src = const_name;
676+
return;
664677
}
665678
src = check_tmp_buffer() + src;
666679
}
@@ -797,6 +810,9 @@ R"(#include <stdio.h>
797810
}
798811
src = check_tmp_buffer() + src_tmp;
799812
return;
813+
} else if (ASR::is_a<ASR::DictItem_t>(*x.m_target)) {
814+
self().visit_DictItem(*ASR::down_cast<ASR::DictItem_t>(x.m_target));
815+
target = src;
800816
} else {
801817
LCOMPILERS_ASSERT(false)
802818
}

src/libasr/codegen/asr_to_llvm.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4084,6 +4084,27 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
40844084
return_type = list_api->get_list_type(el_llvm_type, el_type_code, type_size);
40854085
break;
40864086
}
4087+
case (ASR::ttypeType::Dict) : {
4088+
ASR::Dict_t* asr_dict = ASR::down_cast<ASR::Dict_t>(return_var_type0);
4089+
std::string key_type_code = ASRUtils::get_type_code(asr_dict->m_key_type);
4090+
std::string value_type_code = ASRUtils::get_type_code(asr_dict->m_value_type);
4091+
4092+
bool is_local_array_type = false, is_local_malloc_array_type = false;
4093+
bool is_local_list = false;
4094+
ASR::dimension_t* local_m_dims = nullptr;
4095+
ASR::storage_typeType local_m_storage = ASR::storage_typeType::Default;
4096+
int local_n_dims = 0, local_a_kind = -1;
4097+
4098+
llvm::Type* key_llvm_type = get_type_from_ttype_t(asr_dict->m_key_type, local_m_storage,is_local_array_type, is_local_malloc_array_type,is_local_list, local_m_dims, local_n_dims,local_a_kind);
4099+
llvm::Type* value_llvm_type = get_type_from_ttype_t(asr_dict->m_value_type, local_m_storage,is_local_array_type, is_local_malloc_array_type,is_local_list, local_m_dims, local_n_dims,local_a_kind);
4100+
int32_t key_type_size = get_type_size(asr_dict->m_key_type, key_llvm_type, local_a_kind);
4101+
int32_t value_type_size = get_type_size(asr_dict->m_value_type, value_llvm_type, local_a_kind);
4102+
4103+
set_dict_api(asr_dict);
4104+
4105+
return_type = llvm_utils->dict_api->get_dict_type(key_type_code, value_type_code, key_type_size,value_type_size, key_llvm_type, value_llvm_type);
4106+
break;
4107+
}
40874108
default :
40884109
throw CodeGenError("Type not implemented " + std::to_string(return_var_type));
40894110
}

src/libasr/codegen/llvm_utils.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,8 @@ namespace LCompilers {
9595
return ASR::is_a<ASR::Tuple_t>(*asr_type) ||
9696
ASR::is_a<ASR::List_t>(*asr_type) ||
9797
ASR::is_a<ASR::Struct_t>(*asr_type) ||
98-
ASR::is_a<ASR::Class_t>(*asr_type);
98+
ASR::is_a<ASR::Class_t>(*asr_type)||
99+
ASR::is_a<ASR::Dict_t>(*asr_type);
99100
}
100101
}
101102

0 commit comments

Comments
 (0)