Skip to content

Commit e4fac35

Browse files
committed
Create temporary variable for storing structs returned by value
1 parent 450a4a4 commit e4fac35

File tree

2 files changed

+59
-22
lines changed

2 files changed

+59
-22
lines changed

integration_tests/test_list_01.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,21 @@
11
from ltypes import f64, i32
22

3+
def fill_list_i32(size: i32) -> list[i32]:
4+
aarg: list[i32] = [0, 1, 2, 3, 4]
5+
i: i32
6+
for i in range(10):
7+
aarg.append(i + 5)
8+
return aarg
9+
10+
311
def test_list_01():
4-
a: list[i32] = [0, 1, 2, 3, 4]
12+
a: list[i32] = []
513
f: list[f64] = [1.0, 2.0, 3.0, 4.0, 5.0]
614
i: i32
715

16+
a = fill_list_i32(10)
17+
818
for i in range(10):
9-
a.append(i + 5)
1019
f.append(float(i + 6))
1120

1221
for i in range(15):

src/libasr/codegen/asr_to_llvm.cpp

Lines changed: 48 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2983,15 +2983,11 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
29832983
this->visit_expr(*x.m_value);
29842984
llvm::Value* value_list = tmp;
29852985
ptr_loads = ptr_loads_copy;
2986-
if (ASR::is_a<ASR::FunctionCall_t>(*x.m_value)) {
2987-
builder->CreateStore(value_list, target_list);
2988-
} else {
2989-
ASR::List_t* value_asr_list = ASR::down_cast<ASR::List_t>(
2990-
ASRUtils::expr_type(x.m_value));
2991-
std::string value_type_code = ASRUtils::get_type_code(value_asr_list->m_type);
2992-
list_api->list_deepcopy(value_list, target_list,
2993-
value_asr_list, *module);
2994-
}
2986+
ASR::List_t* value_asr_list = ASR::down_cast<ASR::List_t>(
2987+
ASRUtils::expr_type(x.m_value));
2988+
std::string value_type_code = ASRUtils::get_type_code(value_asr_list->m_type);
2989+
list_api->list_deepcopy(value_list, target_list,
2990+
value_asr_list, *module);
29952991
return ;
29962992
} else if( is_target_tuple && is_value_tuple ) {
29972993
uint64_t ptr_loads_copy = ptr_loads;
@@ -3023,15 +3019,11 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
30233019
this->visit_expr(*x.m_target);
30243020
llvm::Value* target_tuple = tmp;
30253021
ptr_loads = ptr_loads_copy;
3026-
if( ASR::is_a<ASR::FunctionCall_t>(*x.m_value) ) {
3027-
builder->CreateStore(value_tuple, target_tuple);
3028-
} else {
3029-
ASR::Tuple_t* value_tuple_type = ASR::down_cast<ASR::Tuple_t>(asr_value_type);
3030-
std::string type_code = ASRUtils::get_type_code(value_tuple_type->m_type,
3031-
value_tuple_type->n_type);
3032-
tuple_api->tuple_deepcopy(value_tuple, target_tuple,
3033-
value_tuple_type, *module);
3034-
}
3022+
ASR::Tuple_t* value_tuple_type = ASR::down_cast<ASR::Tuple_t>(asr_value_type);
3023+
std::string type_code = ASRUtils::get_type_code(value_tuple_type->m_type,
3024+
value_tuple_type->n_type);
3025+
tuple_api->tuple_deepcopy(value_tuple, target_tuple,
3026+
value_tuple_type, *module);
30353027
}
30363028
return ;
30373029
}
@@ -5222,6 +5214,41 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
52225214
tmp = builder->CreateOr(arg1, arg2);
52235215
}
52245216

5217+
llvm::Value* CreatePointerToStructReturnValue(llvm::FunctionType* fnty,
5218+
llvm::Value* return_value,
5219+
ASR::ttype_t* asr_return_type) {
5220+
if( !LLVM::is_llvm_struct(asr_return_type) ) {
5221+
return return_value;
5222+
}
5223+
5224+
// Call to LLVM APIs not needed to fetch the return type of the function.
5225+
// We can use asr_return_type as well but anyways for compactness I did it here.
5226+
llvm::Value* pointer_to_struct = builder->CreateAlloca(fnty->getReturnType(), nullptr);
5227+
LLVM::CreateStore(*builder, return_value, pointer_to_struct);
5228+
return pointer_to_struct;
5229+
}
5230+
5231+
llvm::Value* CreateCallUtil(llvm::FunctionType* fnty, llvm::Value* fn,
5232+
std::vector<llvm::Value*>& args,
5233+
ASR::ttype_t* asr_return_type) {
5234+
llvm::Value* return_value = builder->CreateCall(fnty, fn, args);
5235+
return CreatePointerToStructReturnValue(fnty, return_value,
5236+
asr_return_type);
5237+
}
5238+
5239+
llvm::Value* CreateCallUtil(llvm::FunctionType* fnty, llvm::Function* fn,
5240+
std::vector<llvm::Value*>& args,
5241+
ASR::ttype_t* asr_return_type) {
5242+
llvm::Value* return_value = builder->CreateCall(fn, args);
5243+
return CreatePointerToStructReturnValue(fnty, return_value,
5244+
asr_return_type);
5245+
}
5246+
5247+
llvm::Value* CreateCallUtil(llvm::Function* fn, std::vector<llvm::Value*>& args,
5248+
ASR::ttype_t* asr_return_type) {
5249+
return CreateCallUtil(fn->getFunctionType(), fn, args, asr_return_type);
5250+
}
5251+
52255252
void visit_FunctionCall(const ASR::FunctionCall_t &x) {
52265253
if( ASRUtils::is_intrinsic_optimization(x.m_name) ) {
52275254
ASR::Function_t* routine = ASR::down_cast<ASR::Function_t>(
@@ -5306,7 +5333,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
53065333
llvm::FunctionType* fntype = llvm_symtab_fn[h]->getFunctionType();
53075334
std::string m_name = std::string(((ASR::Function_t*)(&(x.m_name->base)))->m_name);
53085335
args = convert_call_args(x, m_name);
5309-
tmp = builder->CreateCall(fntype, fn, args);
5336+
tmp = CreateCallUtil(fntype, fn, args, x.m_type);
53105337
} else if (llvm_symtab_fn.find(h) == llvm_symtab_fn.end()) {
53115338
throw CodeGenError("Function code not generated for '"
53125339
+ std::string(s->m_name) + "'");
@@ -5315,8 +5342,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
53155342
std::string m_name = std::string(((ASR::Function_t*)(&(x.m_name->base)))->m_name);
53165343
std::vector<llvm::Value *> args2 = convert_call_args(x, m_name);
53175344
args.insert(args.end(), args2.begin(), args2.end());
5345+
ASR::ttype_t *return_var_type0 = EXPR2VAR(s->m_return_var)->m_type;
53185346
if (s->m_abi == ASR::abiType::BindC) {
5319-
ASR::ttype_t *return_var_type0 = EXPR2VAR(s->m_return_var)->m_type;
53205347
if (is_a<ASR::Complex_t>(*return_var_type0)) {
53215348
int a_kind = down_cast<ASR::Complex_t>(return_var_type0)->m_kind;
53225349
if (a_kind == 8) {
@@ -5337,6 +5364,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
53375364
}
53385365
} else {
53395366
tmp = builder->CreateCall(fn, args);
5367+
tmp = CreateCallUtil(fn, args, return_var_type0);
53405368
}
53415369
}
53425370
if (s->m_abi == ASR::abiType::BindC) {

0 commit comments

Comments
 (0)