Skip to content

Commit 0b63d4f

Browse files
authored
Adds nin and nout properties to unary and binary element-wise classes (#1478)
* Implements `nin` and `nout` for element-wise funcs `nin` and `nout` properties return the number of arguments to the function treated as inputs or outputs, respectively * Adds tests for `nin` and `nout` properties
1 parent 7e82c21 commit 0b63d4f

File tree

2 files changed

+52
-0
lines changed

2 files changed

+52
-0
lines changed

dpctl/tensor/_elementwise_common.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,20 @@ def get_type_promotion_path_acceptance_function(self):
134134
"""
135135
return self.acceptance_fn_
136136

137+
@property
138+
def nin(self):
139+
"""
140+
Returns the number of arguments treated as inputs.
141+
"""
142+
return 1
143+
144+
@property
145+
def nout(self):
146+
"""
147+
Returns the number of arguments treated as outputs.
148+
"""
149+
return 1
150+
137151
@property
138152
def types(self):
139153
"""Returns information about types supported by
@@ -575,6 +589,20 @@ def get_type_promotion_path_acceptance_function(self):
575589
"""
576590
return self.acceptance_fn_
577591

592+
@property
593+
def nin(self):
594+
"""
595+
Returns the number of arguments treated as inputs.
596+
"""
597+
return 2
598+
599+
@property
600+
def nout(self):
601+
"""
602+
Returns the number of arguments treated as outputs.
603+
"""
604+
return 1
605+
578606
@property
579607
def types(self):
580608
"""Returns information about types supported by

dpctl/tests/elementwise/test_elementwise_classes.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,27 @@ def test_binary_class_str_repr():
7878
kl_n = binary_fn.__name__
7979
assert kl_n in s
8080
assert kl_n in r
81+
82+
83+
def test_unary_class_nin():
84+
nin = unary_fn.nin
85+
assert isinstance(nin, int)
86+
assert nin == 1
87+
88+
89+
def test_binary_class_nin():
90+
nin = binary_fn.nin
91+
assert isinstance(nin, int)
92+
assert nin == 2
93+
94+
95+
def test_unary_class_nout():
96+
nout = unary_fn.nout
97+
assert isinstance(nout, int)
98+
assert nout == 1
99+
100+
101+
def test_binary_class_nout():
102+
nout = binary_fn.nout
103+
assert isinstance(nout, int)
104+
assert nout == 1

0 commit comments

Comments
 (0)