Skip to content

Commit 909415c

Browse files
authored
Merge pull request #955 from Smit-create/list_ret
Support returning list from functions in LLVM backend
2 parents 61b307e + 71e062e commit 909415c

File tree

6 files changed

+149
-39
lines changed

6 files changed

+149
-39
lines changed

integration_tests/test_list_01.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,21 @@
11
from ltypes import f64, i32
22

3+
def fill_list_i32(size: i32) -> list[i32]:
4+
aarg: list[i32] = [0, 1, 2, 3, 4]
5+
i: i32
6+
for i in range(10):
7+
aarg.append(i + 5)
8+
return aarg
9+
10+
311
def test_list_01():
4-
a: list[i32] = [0, 1, 2, 3, 4]
12+
a: list[i32] = []
513
f: list[f64] = [1.0, 2.0, 3.0, 4.0, 5.0]
614
i: i32
715

16+
a = fill_list_i32(10)
17+
818
for i in range(10):
9-
a.append(i + 5)
1019
f.append(float(i + 6))
1120

1221
for i in range(15):

integration_tests/test_list_02.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,22 @@
11
from ltypes import i32
22

3-
def test_list_01():
3+
def fill_list_str(size: i32) -> list[str]:
44
a: list[str] = ["0_str", "1_str"]
5+
i: i32
6+
for i in range(size):
7+
a.append(str(i + 2) + "_str")
8+
return a
9+
10+
def test_list_01():
11+
a: list[str] = []
512
b: list[str]
613
string: str = "string_"
714
b = [string, string]
815
i: i32
916

17+
a = fill_list_str(10)
18+
1019
for i in range(10):
11-
a.append(str(i + 2) + "_str")
1220
b.append(string + str(i + 2))
1321

1422

integration_tests/test_list_03.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,7 @@ def test_list_01(n: i32) -> i32:
1010
sum += a[i]
1111
return sum
1212

13-
def test_list_02(n: i32) -> i32:
14-
x: list[i32] = [50, 1]
15-
13+
def test_list_insert_02(x: list[i32], n: i32) -> list[i32]:
1614
i: i32
1715
imod: i32
1816
for i in range(n):
@@ -24,7 +22,15 @@ def test_list_02(n: i32) -> i32:
2422
elif imod == 2:
2523
x.insert(len(x)//2, i + n + 2)
2624

25+
return x
26+
27+
def test_list_02(n: i32) -> i32:
28+
x: list[i32] = [50, 1]
2729
acc: i32 = 0
30+
i: i32
31+
32+
x = test_list_insert_02(x, n)
33+
2834
for i in range(n):
2935
acc += x[i]
3036
return acc
@@ -49,6 +55,7 @@ def test_list_02_string():
4955
for i in range(50):
5056
assert x[i] == y[i]
5157

58+
5259
def verify():
5360
assert test_list_01(11) == 55
5461
assert test_list_02(50) == 3628

integration_tests/test_list_05.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,38 @@ def check_list_of_tuples(l: list[tuple[i32, f64, str]], sign: i32):
1818
assert t[2] == string
1919
assert l[i][2] == string
2020

21-
def test_list_of_tuples():
21+
def fill_list_of_tuples(size: i32) -> list[tuple[i32, f64, str]]:
2222
l1: list[tuple[i32, f64, str]] = []
2323
t: tuple[i32, f64, str]
24-
size: i32 = 20
2524
i: i32
26-
string: str
2725

2826
for i in range(size):
2927
t = (i, float(i), str(i) + "_str")
3028
l1.append(t)
3129

30+
return l1
31+
32+
def insert_tuples_into_list(l: list[tuple[i32, f64, str]], size: i32) -> list[tuple[i32, f64, str]]:
33+
i: i32
34+
string: str
35+
t: tuple[i32, f64, str]
36+
37+
for i in range(size//2, size):
38+
string = str(i) + "_str"
39+
t = (i, float(i), string)
40+
l.insert(i, t)
41+
42+
return l
43+
44+
def test_list_of_tuples():
45+
l1: list[tuple[i32, f64, str]] = []
46+
t: tuple[i32, f64, str]
47+
size: i32 = 20
48+
i: i32
49+
string: str
50+
51+
l1 = fill_list_of_tuples(size)
52+
3253
check_list_of_tuples(l1, 1)
3354

3455
for i in range(size//2):
@@ -39,10 +60,7 @@ def test_list_of_tuples():
3960
assert t[1] == size//2 - 1
4061
assert t[2] == str(size//2 - 1) + "_str"
4162

42-
for i in range(size//2, size):
43-
string = str(i) + "_str"
44-
t = (i, float(i), string)
45-
l1.insert(i, t)
63+
l1 = insert_tuples_into_list(l1, size)
4664

4765
check_list_of_tuples(l1, 1)
4866

integration_tests/test_list_07.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,28 @@
11
from ltypes import c64, i32
22
from copy import deepcopy
33

4+
def generate_complex_tensors(mat: list[list[c64]], vec: list[c64]) -> list[tuple[list[list[c64]], list[c64]]]:
5+
tensor: tuple[list[list[c64]], list[c64]]
6+
tensors: list[tuple[list[list[c64]], list[c64]]] = []
7+
rows: i32 = len(mat)
8+
cols: i32 = len(vec)
9+
i: i32; j: i32; k: i32
10+
11+
tensor = (deepcopy(mat), deepcopy(vec))
12+
13+
for k in range(2 * rows):
14+
tensors.append(deepcopy(tensor))
15+
for i in range(rows):
16+
for j in range(cols):
17+
mat[i][j] += complex(1.0, 2.0)
18+
19+
for i in range(cols):
20+
vec[i] += complex(1.0, 2.0)
21+
22+
tensor = (deepcopy(mat), deepcopy(vec))
23+
24+
return tensors
25+
426
def test_tuple_with_lists():
527
mat: list[list[c64]] = []
628
vec: list[c64] = []
@@ -42,18 +64,7 @@ def test_tuple_with_lists():
4264
for i in range(cols):
4365
assert tensor[1][i] - vec[i] == -complex(0, 2.0)
4466

45-
tensor = (deepcopy(mat), deepcopy(vec))
46-
47-
for k in range(2 * rows):
48-
tensors.append(deepcopy(tensor))
49-
for i in range(rows):
50-
for j in range(cols):
51-
mat[i][j] += complex(1.0, 2.0)
52-
53-
for i in range(cols):
54-
vec[i] += complex(1.0, 2.0)
55-
56-
tensor = (deepcopy(mat), deepcopy(vec))
67+
tensors = generate_complex_tensors(mat, vec)
5768

5869
for k in range(2 * rows):
5970
for i in range(rows):

src/libasr/codegen/asr_to_llvm.cpp

Lines changed: 70 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1134,10 +1134,19 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
11341134

11351135
void visit_ListConstant(const ASR::ListConstant_t& x) {
11361136
ASR::List_t* list_type = ASR::down_cast<ASR::List_t>(x.m_type);
1137-
llvm::Type* llvm_el_type = get_el_type(list_type->m_type);
1137+
bool is_array_type_local = false, is_malloc_array_type_local = false;
1138+
bool is_list_local = false;
1139+
ASR::dimension_t* m_dims_local = nullptr;
1140+
int n_dims_local = -1, a_kind_local = -1;
1141+
llvm::Type* llvm_el_type = get_type_from_ttype_t(list_type->m_type,
1142+
ASR::storage_typeType::Default, is_array_type_local,
1143+
is_malloc_array_type_local, is_list_local, m_dims_local,
1144+
n_dims_local, a_kind_local);
11381145
std::string type_code = ASRUtils::get_type_code(list_type->m_type);
11391146
int32_t type_size = -1;
1140-
if( ASR::is_a<ASR::Character_t>(*list_type->m_type) ) {
1147+
if( ASR::is_a<ASR::Character_t>(*list_type->m_type) ||
1148+
LLVM::is_llvm_struct(list_type->m_type) ||
1149+
ASR::is_a<ASR::Complex_t>(*list_type->m_type) ) {
11411150
llvm::DataLayout data_layout(module.get());
11421151
type_size = data_layout.getTypeAllocSize(llvm_el_type);
11431152
} else {
@@ -2720,6 +2729,31 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
27202729
return_type = tuple_api->get_tuple_type(type_code, llvm_el_types);
27212730
break;
27222731
}
2732+
case (ASR::ttypeType::List) : {
2733+
bool is_array_type = false, is_malloc_array_type = false;
2734+
bool is_list = true;
2735+
ASR::dimension_t *m_dims = nullptr;
2736+
ASR::storage_typeType m_storage = ASR::storage_typeType::Default;
2737+
int n_dims = 0, a_kind = -1;
2738+
ASR::List_t* asr_list = ASR::down_cast<ASR::List_t>(return_var_type0);
2739+
llvm::Type* el_llvm_type = get_type_from_ttype_t(asr_list->m_type, m_storage,
2740+
is_array_type,
2741+
is_malloc_array_type,
2742+
is_list, m_dims, n_dims,
2743+
a_kind);
2744+
int32_t type_size = -1;
2745+
if( LLVM::is_llvm_struct(asr_list->m_type) ||
2746+
ASR::is_a<ASR::Character_t>(*asr_list->m_type) ||
2747+
ASR::is_a<ASR::Complex_t>(*asr_list->m_type) ) {
2748+
llvm::DataLayout data_layout(module.get());
2749+
type_size = data_layout.getTypeAllocSize(el_llvm_type);
2750+
} else {
2751+
type_size = a_kind;
2752+
}
2753+
std::string el_type_code = ASRUtils::get_type_code(asr_list->m_type);
2754+
return_type = list_api->get_list_type(el_llvm_type, el_type_code, type_size);
2755+
break;
2756+
}
27232757
default :
27242758
LFORTRAN_ASSERT(false);
27252759
throw CodeGenError("Type not implemented");
@@ -3122,15 +3156,11 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
31223156
this->visit_expr(*x.m_target);
31233157
llvm::Value* target_tuple = tmp;
31243158
ptr_loads = ptr_loads_copy;
3125-
if( ASR::is_a<ASR::FunctionCall_t>(*x.m_value) ) {
3126-
builder->CreateStore(value_tuple, target_tuple);
3127-
} else {
3128-
ASR::Tuple_t* value_tuple_type = ASR::down_cast<ASR::Tuple_t>(asr_value_type);
3129-
std::string type_code = ASRUtils::get_type_code(value_tuple_type->m_type,
3130-
value_tuple_type->n_type);
3131-
tuple_api->tuple_deepcopy(value_tuple, target_tuple,
3132-
value_tuple_type, *module);
3133-
}
3159+
ASR::Tuple_t* value_tuple_type = ASR::down_cast<ASR::Tuple_t>(asr_value_type);
3160+
std::string type_code = ASRUtils::get_type_code(value_tuple_type->m_type,
3161+
value_tuple_type->n_type);
3162+
tuple_api->tuple_deepcopy(value_tuple, target_tuple,
3163+
value_tuple_type, *module);
31343164
}
31353165
return ;
31363166
} else if( is_target_dict && is_value_dict ) {
@@ -5333,6 +5363,33 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
53335363
tmp = builder->CreateOr(arg1, arg2);
53345364
}
53355365

5366+
llvm::Value* CreatePointerToStructReturnValue(llvm::FunctionType* fnty,
5367+
llvm::Value* return_value,
5368+
ASR::ttype_t* asr_return_type) {
5369+
if( !LLVM::is_llvm_struct(asr_return_type) ) {
5370+
return return_value;
5371+
}
5372+
5373+
// Call to LLVM APIs not needed to fetch the return type of the function.
5374+
// We can use asr_return_type as well but anyways for compactness I did it here.
5375+
llvm::Value* pointer_to_struct = builder->CreateAlloca(fnty->getReturnType(), nullptr);
5376+
LLVM::CreateStore(*builder, return_value, pointer_to_struct);
5377+
return pointer_to_struct;
5378+
}
5379+
5380+
llvm::Value* CreateCallUtil(llvm::FunctionType* fnty, llvm::Function* fn,
5381+
std::vector<llvm::Value*>& args,
5382+
ASR::ttype_t* asr_return_type) {
5383+
llvm::Value* return_value = builder->CreateCall(fn, args);
5384+
return CreatePointerToStructReturnValue(fnty, return_value,
5385+
asr_return_type);
5386+
}
5387+
5388+
llvm::Value* CreateCallUtil(llvm::Function* fn, std::vector<llvm::Value*>& args,
5389+
ASR::ttype_t* asr_return_type) {
5390+
return CreateCallUtil(fn->getFunctionType(), fn, args, asr_return_type);
5391+
}
5392+
53365393
void visit_FunctionCall(const ASR::FunctionCall_t &x) {
53375394
if( ASRUtils::is_intrinsic_optimization(x.m_name) ) {
53385395
ASR::Function_t* routine = ASR::down_cast<ASR::Function_t>(
@@ -5426,8 +5483,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
54265483
std::string m_name = std::string(((ASR::Function_t*)(&(x.m_name->base)))->m_name);
54275484
std::vector<llvm::Value *> args2 = convert_call_args(x, m_name);
54285485
args.insert(args.end(), args2.begin(), args2.end());
5486+
ASR::ttype_t *return_var_type0 = EXPR2VAR(s->m_return_var)->m_type;
54295487
if (s->m_abi == ASR::abiType::BindC) {
5430-
ASR::ttype_t *return_var_type0 = EXPR2VAR(s->m_return_var)->m_type;
54315488
if (is_a<ASR::Complex_t>(*return_var_type0)) {
54325489
int a_kind = down_cast<ASR::Complex_t>(return_var_type0)->m_kind;
54335490
if (a_kind == 8) {
@@ -5447,7 +5504,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
54475504
tmp = builder->CreateCall(fn, args);
54485505
}
54495506
} else {
5450-
tmp = builder->CreateCall(fn, args);
5507+
tmp = CreateCallUtil(fn, args, return_var_type0);
54515508
}
54525509
}
54535510
if (s->m_abi == ASR::abiType::BindC) {

0 commit comments

Comments
 (0)