diff --git a/integration_tests/run_tests.py b/integration_tests/run_tests.py index 1bb1b1535d..02cf7e1dbc 100755 --- a/integration_tests/run_tests.py +++ b/integration_tests/run_tests.py @@ -18,6 +18,12 @@ "test_math1.py" ] +# At present we run these tests on cpython, later we should also move to lpython +test_cpython = [ + "test_generics_01.py" +] + + def main(): print("Compiling...") for pyfile in tests: @@ -30,7 +36,9 @@ def main(): if r != 0: print("Command '%s' failed." % cmd) sys.exit(1) + print("Running...") + python_path="src/runtime/ltypes" for pyfile in tests: basename = os.path.splitext(pyfile)[0] cmd = "integration_tests/%s" % (basename) @@ -39,7 +47,16 @@ def main(): if r != 0: print("Command '%s' failed." % cmd) sys.exit(1) - python_path="src/runtime/ltypes" + cmd = "PYTHONPATH=%s python integration_tests/%s" % (python_path, + pyfile) + print("+ " + cmd) + r = os.system(cmd) + if r != 0: + print("Command '%s' failed." % cmd) + sys.exit(1) + + print("Running cpython tests...") + for pyfile in test_cpython: cmd = "PYTHONPATH=%s python integration_tests/%s" % (python_path, pyfile) print("+ " + cmd) diff --git a/integration_tests/test_generics_01.py b/integration_tests/test_generics_01.py new file mode 100644 index 0000000000..47f18245b8 --- /dev/null +++ b/integration_tests/test_generics_01.py @@ -0,0 +1,30 @@ +from ltypes import overload + +@overload +def foo(a: int, b: int) -> int: + return a*b + +@overload +def foo(a: int) -> int: + return a**2 + +@overload +def foo(a: str) -> str: + return "lpython-" + a + +@overload +def test(a: int) -> int: + return a + 10 + +@overload +def test(a: bool) -> int: + if a: + return 10 + return -10 + + +assert foo(2) == 4 +assert foo(2, 10) == 20 +assert foo("hello") == "lpython-hello" +assert test(10) == 20 +assert test(False) == -test(True) and test(True) == 10 diff --git a/src/runtime/ltypes/ltypes.py b/src/runtime/ltypes/ltypes.py index 215b58b5ac..98b014dcbc 100644 --- a/src/runtime/ltypes/ltypes.py +++ b/src/runtime/ltypes/ltypes.py @@ -1,6 +1,53 @@ +from inspect import getfullargspec, getcallargs + +# data-types + i32 = [] i64 = [] f32 = [] f64 = [] c32 = [] c64 = [] + +# overloading support + +class OverloadedFunction: + """ + A wrapper class for allowing overloading. + """ + global_map = {} + + def __init__(self, func): + self.func_name = func.__name__ + f_list = self.global_map.get(func.__name__, []) + f_list.append((func, getfullargspec(func))) + self.global_map[func.__name__] = f_list + + def __call__(self, *args, **kwargs): + func_map_list = self.global_map.get(self.func_name, False) + if not func_map_list: + raise Exception("Function not defined") + for item in func_map_list: + func, key = item + try: + # This might fail for the cases when arguments don't match + ann_dict = getcallargs(func, *args, **kwargs) + except TypeError: + continue + flag = True + for k, v in ann_dict.items(): + if not key.annotations.get(k, False): + flag = False + break + else: + if type(v) != key.annotations.get(k): + flag = False + break + if flag: + return func(*args, **kwargs) + raise Exception("Function not found with matching signature") + + +def overload(f): + overloaded_f = OverloadedFunction(f) + return overloaded_f