Skip to content

Commit fae10c3

Browse files
authored
Merge pull request #1126 from czgdp1807/unions01
Supporting Union in C and LLVM backend
2 parents 98f5c67 + 00c1643 commit fae10c3

File tree

12 files changed

+487
-68
lines changed

12 files changed

+487
-68
lines changed

integration_tests/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,9 @@ RUN(NAME structs_05 LABELS llvm c)
252252
RUN(NAME enum_01 LABELS cpython llvm c)
253253
RUN(NAME enum_02 LABELS cpython llvm)
254254
RUN(NAME enum_03 LABELS cpython llvm c)
255+
RUN(NAME union_01 LABELS cpython llvm c)
256+
RUN(NAME union_02 LABELS llvm c)
257+
RUN(NAME union_03 LABELS cpython llvm c)
255258
RUN(NAME test_str_to_int LABELS cpython llvm)
256259
RUN(NAME test_platform LABELS cpython llvm)
257260
RUN(NAME test_vars_01 LABELS cpython llvm)

integration_tests/union_01.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from ltypes import Union, i32, i64, f64, f32, ccall, union
2+
3+
@ccall
4+
@union
5+
class u_type(Union):
6+
integer32: i32
7+
real32: f32
8+
real64: f64
9+
integer64: i64
10+
11+
def test_union():
12+
unionobj: u_type = u_type()
13+
unionobj.integer32 = 1
14+
print(unionobj.integer32)
15+
assert unionobj.integer32 == 1
16+
17+
unionobj.real32 = 2.0
18+
print(unionobj.real32)
19+
assert abs(unionobj.real32 - 2.0) <= 1e-6
20+
21+
unionobj.real64 = 3.5
22+
print(unionobj.real64)
23+
assert abs(unionobj.real64 - 3.5) <= 1e-12
24+
25+
unionobj.integer64 = 4
26+
print(unionobj.integer64)
27+
assert unionobj.integer64 == 4
28+
29+
test_union()

integration_tests/union_02.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from ltypes import i32, f64, i64, dataclass, ccall, union, Union
2+
3+
@dataclass
4+
class A:
5+
ax: i32
6+
ay: f64
7+
8+
@dataclass
9+
class B:
10+
bx: i64
11+
by: f64
12+
13+
@dataclass
14+
class C:
15+
cx: i64
16+
cy: f64
17+
cz: f64
18+
19+
@ccall
20+
@union
21+
class D(Union):
22+
a: A
23+
b: B
24+
c: C
25+
26+
def test_struct_union():
27+
d: D = D()
28+
29+
aobj: A = A(0, 1.0)
30+
bobj: B = B(int(2), 7.0)
31+
cobj: C = C(int(5), 13.0, 8.0)
32+
33+
d.a = aobj
34+
print(d.a.ax, d.a.ay)
35+
assert d.a.ax == 0
36+
assert abs(d.a.ay - 1.0) <= 1e-12
37+
38+
d.b = bobj
39+
print(d.b.bx, d.b.by)
40+
assert d.b.bx == int(2)
41+
assert abs(d.b.by - 7.0) <= 1e-12
42+
43+
d.c = cobj
44+
print(d.c.cx, d.c.cy, d.c.cz)
45+
assert d.c.cx == 5
46+
assert abs(d.c.cy - 13.0) <= 1e-12
47+
assert abs(d.c.cz - 8.0) <= 1e-12
48+
49+
test_struct_union()

integration_tests/union_03.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from ltypes import Union, i32, i64, f64, f32, union
2+
3+
# without ccall but same as union_01.py
4+
@union
5+
class u_type(Union):
6+
integer32: i32
7+
real32: f32
8+
real64: f64
9+
integer64: i64
10+
11+
def test_union():
12+
unionobj: u_type = u_type()
13+
unionobj.integer32 = 1
14+
print(unionobj.integer32)
15+
assert unionobj.integer32 == 1
16+
17+
unionobj.real32 = 2.0
18+
print(unionobj.real32)
19+
assert abs(unionobj.real32 - 2.0) <= 1e-6
20+
21+
unionobj.real64 = 3.5
22+
print(unionobj.real64)
23+
assert abs(unionobj.real64 - 3.5) <= 1e-12
24+
25+
unionobj.integer64 = 4
26+
print(unionobj.integer64)
27+
assert unionobj.integer64 == 4
28+
29+
test_union()

src/libasr/ASR.asdl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ symbol
9999
abi abi, access access, symbol? parent)
100100
| EnumType(symbol_table symtab, identifier name, identifier* members,
101101
abi abi, access access, ttype type, symbol? parent)
102+
| UnionType(symbol_table symtab, identifier name, identifier* members,
103+
abi abi, access access, symbol? parent)
102104
| Variable(symbol_table parent_symtab, identifier name, intent intent,
103105
expr? symbolic_value, expr? value, storage_type storage, ttype type,
104106
abi abi, access access, presence presence, bool value_attr)
@@ -221,6 +223,7 @@ expr
221223
ttype type, expr? value, expr? dt)
222224
| DerivedTypeConstructor(symbol dt_sym, expr* args, ttype type, expr? value)
223225
| EnumTypeConstructor(symbol dt_sym, expr* args, ttype type, expr? value)
226+
| UnionTypeConstructor(symbol dt_sym, expr* args, ttype type, expr? value)
224227
| ImpliedDoLoop(expr* values, expr var, expr start, expr end,
225228
expr? increment, ttype type, expr? value)
226229
| IntegerConstant(int n, ttype type)
@@ -281,6 +284,7 @@ expr
281284

282285
| BitCast(expr source, expr mold, expr? size, ttype type, expr? value)
283286
| DerivedRef(expr v, symbol m, ttype type, expr? value)
287+
| UnionRef(expr v, symbol m, ttype type, expr? value)
284288
| EnumName(symbol v, ttype type, expr? value)
285289
| EnumValue(symbol v, ttype type, expr? value)
286290
| OverloadedCompare(expr left, cmpop op, expr right, ttype type, expr? value, expr overloaded)
@@ -331,6 +335,7 @@ ttype
331335
| Tuple(ttype* type)
332336
| Derived(symbol derived_type, dimension* dims)
333337
| Enum(symbol enum_type, dimension *dims)
338+
| Union(symbol union_type, dimension *dims)
334339
| Class(symbol class_type, dimension* dims)
335340
| Dict(ttype key_type, ttype value_type)
336341
| Pointer(ttype type)

src/libasr/asr_utils.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,9 @@ static inline std::string type_to_str(const ASR::ttype_t *t)
138138
case ASR::ttypeType::Derived: {
139139
return "derived type";
140140
}
141+
case ASR::ttypeType::Union: {
142+
return "union";
143+
}
141144
case ASR::ttypeType::CPtr: {
142145
return "type(c_ptr)";
143146
}
@@ -211,6 +214,9 @@ static inline char *symbol_name(const ASR::symbol_t *f)
211214
case ASR::symbolType::EnumType: {
212215
return ASR::down_cast<ASR::EnumType_t>(f)->m_name;
213216
}
217+
case ASR::symbolType::UnionType: {
218+
return ASR::down_cast<ASR::UnionType_t>(f)->m_name;
219+
}
214220
case ASR::symbolType::Variable: {
215221
return ASR::down_cast<ASR::Variable_t>(f)->m_name;
216222
}
@@ -254,6 +260,9 @@ static inline SymbolTable *symbol_parent_symtab(const ASR::symbol_t *f)
254260
case ASR::symbolType::EnumType: {
255261
return ASR::down_cast<ASR::EnumType_t>(f)->m_symtab->parent;
256262
}
263+
case ASR::symbolType::UnionType: {
264+
return ASR::down_cast<ASR::UnionType_t>(f)->m_symtab->parent;
265+
}
257266
case ASR::symbolType::Variable: {
258267
return ASR::down_cast<ASR::Variable_t>(f)->m_parent_symtab;
259268
}
@@ -299,6 +308,9 @@ static inline SymbolTable *symbol_symtab(const ASR::symbol_t *f)
299308
case ASR::symbolType::EnumType: {
300309
return ASR::down_cast<ASR::EnumType_t>(f)->m_symtab;
301310
}
311+
case ASR::symbolType::UnionType: {
312+
return ASR::down_cast<ASR::UnionType_t>(f)->m_symtab;
313+
}
302314
case ASR::symbolType::Variable: {
303315
return nullptr;
304316
//throw LCompilersException("Variable does not have a symtab");
@@ -1079,6 +1091,12 @@ inline int extract_dimensions_from_ttype(ASR::ttype_t *x,
10791091
m_dims = Enum_type->m_dims;
10801092
break;
10811093
}
1094+
case ASR::ttypeType::Union: {
1095+
ASR::Union_t* Union_type = ASR::down_cast<ASR::Union_t>(x);
1096+
n_dims = Union_type->n_dims;
1097+
m_dims = Union_type->m_dims;
1098+
break;
1099+
}
10821100
case ASR::ttypeType::Class: {
10831101
ASR::Class_t* Class_type = ASR::down_cast<ASR::Class_t>(x);
10841102
n_dims = Class_type->n_dims;
@@ -1522,6 +1540,10 @@ static inline ASR::ttype_t* get_contained_type(ASR::ttype_t* asr_type) {
15221540
ASR::EnumType_t* enum_type = ASR::down_cast<ASR::EnumType_t>(enum_asr->m_enum_type);
15231541
return enum_type->m_type;
15241542
}
1543+
case ASR::ttypeType::Pointer: {
1544+
ASR::Pointer_t* pointer_asr = ASR::down_cast<ASR::Pointer_t>(asr_type);
1545+
return pointer_asr->m_type;
1546+
}
15251547
default: {
15261548
return asr_type;
15271549
}

src/libasr/asr_verify.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ class VerifyVisitor : public BaseWalkVisitor<VerifyVisitor>
260260
}
261261

262262
template <typename T>
263-
void visit_DerivedTypeEnumType(const T &x) {
263+
void visit_DerivedTypeEnumTypeUnionType(const T &x) {
264264
SymbolTable *parent_symtab = current_symtab;
265265
current_symtab = x.m_symtab;
266266
require(x.m_symtab != nullptr,
@@ -281,11 +281,11 @@ class VerifyVisitor : public BaseWalkVisitor<VerifyVisitor>
281281
}
282282

283283
void visit_DerivedType(const DerivedType_t& x) {
284-
visit_DerivedTypeEnumType(x);
284+
visit_DerivedTypeEnumTypeUnionType(x);
285285
}
286286

287287
void visit_EnumType(const EnumType_t& x) {
288-
visit_DerivedTypeEnumType(x);
288+
visit_DerivedTypeEnumTypeUnionType(x);
289289
require(x.m_type != nullptr,
290290
"The common type of Enum cannot be nullptr. " +
291291
std::string(x.m_name) + " doesn't seem to follow this rule.");
@@ -303,6 +303,10 @@ class VerifyVisitor : public BaseWalkVisitor<VerifyVisitor>
303303
}
304304
}
305305

306+
void visit_UnionType(const UnionType_t& x) {
307+
visit_DerivedTypeEnumTypeUnionType(x);
308+
}
309+
306310
void visit_Variable(const Variable_t &x) {
307311
SymbolTable *symtab = x.m_parent_symtab;
308312
require(symtab != nullptr,

src/libasr/codegen/asr_to_c.cpp

Lines changed: 62 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ class ASRToCVisitor : public BaseCCPPVisitor<ASRToCVisitor>
145145

146146
std::string convert_variable_decl(const ASR::Variable_t &v,
147147
bool pre_initialise_derived_type=true,
148+
bool use_ptr_for_derived_type=true,
148149
bool use_static=true)
149150
{
150151
std::string sub;
@@ -176,7 +177,11 @@ class ASRToCVisitor : public BaseCCPPVisitor<ASRToCVisitor>
176177
ASR::Derived_t *t = ASR::down_cast<ASR::Derived_t>(t2);
177178
std::string der_type_name = ASRUtils::symbol_name(t->m_derived_type);
178179
std::string dims = convert_dims_c(t->n_dims, t->m_dims);
179-
sub = format_type_c(dims, "struct " + der_type_name + "*",
180+
std::string ptr_char = "*";
181+
if( !use_ptr_for_derived_type ) {
182+
ptr_char.clear();
183+
}
184+
sub = format_type_c(dims, "struct " + der_type_name + ptr_char,
180185
v.m_name, use_ref, dummy);
181186
} else {
182187
diag.codegen_error_label("Type number '"
@@ -268,7 +273,11 @@ class ASRToCVisitor : public BaseCCPPVisitor<ASRToCVisitor>
268273
sub += "=" + init;
269274
}
270275
sub += ";\n";
271-
sub += indent + format_type_c("", "struct " + der_type_name + "*", v.m_name, use_ref, dummy);
276+
std::string ptr_char = "*";
277+
if( !use_ptr_for_derived_type ) {
278+
ptr_char.clear();
279+
}
280+
sub += indent + format_type_c("", "struct " + der_type_name + ptr_char, v.m_name, use_ref, dummy);
272281
if( t->n_dims != 0 ) {
273282
sub += " = " + value_var_name;
274283
} else {
@@ -282,7 +291,34 @@ class ASRToCVisitor : public BaseCCPPVisitor<ASRToCVisitor>
282291
use_ref = false;
283292
dims = "";
284293
}
285-
sub = format_type_c(dims, "struct " + der_type_name + "*",
294+
std::string ptr_char = "*";
295+
if( !use_ptr_for_derived_type ) {
296+
ptr_char.clear();
297+
}
298+
sub = format_type_c(dims, "struct " + der_type_name + ptr_char,
299+
v.m_name, use_ref, dummy);
300+
}
301+
} else if (ASR::is_a<ASR::Union_t>(*v.m_type)) {
302+
std::string indent(indentation_level*indentation_spaces, ' ');
303+
ASR::Union_t *t = ASR::down_cast<ASR::Union_t>(v.m_type);
304+
std::string der_type_name = ASRUtils::symbol_name(t->m_union_type);
305+
if( is_array ) {
306+
dims = convert_dims_c(t->n_dims, t->m_dims, true);
307+
std::string encoded_type_name = "x" + der_type_name;
308+
std::string type_name = std::string("union ") + der_type_name;
309+
generate_array_decl(sub, std::string(v.m_name), type_name, dims,
310+
encoded_type_name, t->m_dims, t->n_dims,
311+
use_ref, dummy,
312+
v.m_intent != ASRUtils::intent_in &&
313+
v.m_intent != ASRUtils::intent_inout);
314+
} else {
315+
dims = convert_dims_c(t->n_dims, t->m_dims);
316+
if( v.m_intent == ASRUtils::intent_in ||
317+
v.m_intent == ASRUtils::intent_inout ) {
318+
use_ref = false;
319+
dims = "";
320+
}
321+
sub = format_type_c(dims, "union " + der_type_name,
286322
v.m_name, use_ref, dummy);
287323
}
288324
} else if (ASR::is_a<ASR::CPtr_t>(*v.m_type)) {
@@ -357,6 +393,8 @@ R"(
357393
array_types_decls += "struct " + item.first + ";\n\n";
358394
} else if (ASR::is_a<ASR::EnumType_t>(*item.second)) {
359395
array_types_decls += "enum " + item.first + ";\n\n";
396+
} else if (ASR::is_a<ASR::UnionType_t>(*item.second)) {
397+
array_types_decls += "union " + item.first + ";\n\n";
360398
}
361399
}
362400

@@ -369,7 +407,8 @@ R"(
369407

370408
for (auto &item : x.m_global_scope->get_scope()) {
371409
if (ASR::is_a<ASR::DerivedType_t>(*item.second) ||
372-
ASR::is_a<ASR::EnumType_t>(*item.second)) {
410+
ASR::is_a<ASR::EnumType_t>(*item.second) ||
411+
ASR::is_a<ASR::UnionType_t>(*item.second)) {
373412
visit_symbol(*item.second);
374413
array_types_decls += src;
375414
}
@@ -481,22 +520,34 @@ R"(
481520
indentation_level -= 2;
482521
}
483522

484-
void visit_DerivedType(const ASR::DerivedType_t& x) {
523+
template <typename T>
524+
void visit_AggregateTypeUtil(const T& x, std::string c_type_name) {
485525
std::string indent(indentation_level*indentation_spaces, ' ');
486526
indentation_level += 1;
487-
std::string open_struct = indent + "struct " + std::string(x.m_name) + " {\n";
527+
std::string open_struct = indent + c_type_name + " " + std::string(x.m_name) + " {\n";
488528
std::string body = "";
489529
indent.push_back(' ');
490530
for( size_t i = 0; i < x.n_members; i++ ) {
491531
ASR::symbol_t* member = x.m_symtab->get_symbol(x.m_members[i]);
492532
LFORTRAN_ASSERT(ASR::is_a<ASR::Variable_t>(*member));
493-
body += indent + convert_variable_decl(*ASR::down_cast<ASR::Variable_t>(member), false) + ";\n";
533+
body += indent + convert_variable_decl(
534+
*ASR::down_cast<ASR::Variable_t>(member),
535+
false,
536+
(c_type_name != "union")) + ";\n";
494537
}
495538
indentation_level -= 1;
496539
std::string end_struct = "};\n\n";
497540
array_types_decls += open_struct + body + end_struct;
498541
}
499542

543+
void visit_DerivedType(const ASR::DerivedType_t& x) {
544+
visit_AggregateTypeUtil(x, "struct");
545+
}
546+
547+
void visit_UnionType(const ASR::UnionType_t& x) {
548+
visit_AggregateTypeUtil(x, "union");
549+
}
550+
500551
void visit_EnumType(const ASR::EnumType_t& x) {
501552
if( !ASR::is_a<ASR::Integer_t>(*x.m_type) ) {
502553
throw CodeGenError("C backend only supports integer valued Enum. " +
@@ -568,6 +619,10 @@ R"(
568619
src = "(enum " + std::string(enum_type->m_name) + ") (" + src + ")";
569620
}
570621

622+
void visit_UnionTypeConstructor(const ASR::UnionTypeConstructor_t& /*x*/) {
623+
624+
}
625+
571626
void visit_EnumValue(const ASR::EnumValue_t& x) {
572627
ASR::Variable_t* enum_var = ASR::down_cast<ASR::Variable_t>(x.m_v);
573628
src = std::string(enum_var->m_name);

0 commit comments

Comments
 (0)