Skip to content

Commit b775044

Browse files
committed
Add overload decorator in ltypes
1 parent 27ee768 commit b775044

File tree

1 file changed

+47
-0
lines changed

1 file changed

+47
-0
lines changed

src/runtime/ltypes/ltypes.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,53 @@
1+
from inspect import getfullargspec, getcallargs
2+
3+
# data-types
4+
15
i32 = []
26
i64 = []
37
f32 = []
48
f64 = []
59
c32 = []
610
c64 = []
11+
12+
# overloading support
13+
14+
class OverloadedFunction:
15+
"""
16+
A wrapper class for allowing overloading.
17+
"""
18+
global_map = {}
19+
20+
def __init__(self, func):
21+
self.func_name = func.__name__
22+
f_list = self.global_map.get(func.__name__, [])
23+
f_list.append((func, getfullargspec(func)))
24+
self.global_map[func.__name__] = f_list
25+
26+
def __call__(self, *args, **kwargs):
27+
func_map_list = self.global_map.get(self.func_name, False)
28+
if not func_map_list:
29+
raise Exception("Function not defined")
30+
for item in func_map_list:
31+
func, key = item
32+
try:
33+
# This might fail for the cases when arguments don't match
34+
ann_dict = getcallargs(func, *args, **kwargs)
35+
except TypeError:
36+
continue
37+
flag = True
38+
for k, v in ann_dict.items():
39+
if not key.annotations.get(k, False):
40+
flag = False
41+
break
42+
else:
43+
if type(v) != key.annotations.get(k):
44+
flag = False
45+
break
46+
if flag:
47+
return func(*args, **kwargs)
48+
raise Exception("Function not found with matching signature")
49+
50+
51+
def overload(f):
52+
overloaded_f = OverloadedFunction(f)
53+
return overloaded_f

0 commit comments

Comments
 (0)