Skip to content

Commit 5852947

Browse files
committed
Add in-place and reflected operators to the array object stubs
1 parent 9957408 commit 5852947

File tree

3 files changed

+183
-4
lines changed

3 files changed

+183
-4
lines changed

array_api_tests/function_stubs/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99

1010
__all__ = []
1111

12-
from .array_object import __abs__, __add__, __and__, __bool__, __dlpack__, __dlpack_device__, __eq__, __float__, __floordiv__, __ge__, __getitem__, __gt__, __int__, __invert__, __le__, __len__, __lshift__, __lt__, __matmul__, __mod__, __mul__, __ne__, __neg__, __or__, __pos__, __pow__, __rshift__, __setitem__, __sub__, __truediv__, __xor__, dtype, device, ndim, shape, size, T
12+
from .array_object import __abs__, __add__, __and__, __bool__, __dlpack__, __dlpack_device__, __eq__, __float__, __floordiv__, __ge__, __getitem__, __gt__, __int__, __invert__, __le__, __len__, __lshift__, __lt__, __matmul__, __mod__, __mul__, __ne__, __neg__, __or__, __pos__, __pow__, __rshift__, __setitem__, __sub__, __truediv__, __xor__, __iadd__, __radd__, __iand__, __rand__, __ifloordiv__, __rfloordiv__, __ilshift__, __rlshift__, __imatmul__, __rmatmul__, __imod__, __rmod__, __imul__, __rmul__, __ior__, __ror__, __ipow__, __rpow__, __irshift__, __rrshift__, __isub__, __rsub__, __itruediv__, __rtruediv__, __ixor__, __rxor__, dtype, device, ndim, shape, size, T
1313

14-
__all__ += ['__abs__', '__add__', '__and__', '__bool__', '__dlpack__', '__dlpack_device__', '__eq__', '__float__', '__floordiv__', '__ge__', '__getitem__', '__gt__', '__int__', '__invert__', '__le__', '__len__', '__lshift__', '__lt__', '__matmul__', '__mod__', '__mul__', '__ne__', '__neg__', '__or__', '__pos__', '__pow__', '__rshift__', '__setitem__', '__sub__', '__truediv__', '__xor__', 'dtype', 'device', 'ndim', 'shape', 'size', 'T']
14+
__all__ += ['__abs__', '__add__', '__and__', '__bool__', '__dlpack__', '__dlpack_device__', '__eq__', '__float__', '__floordiv__', '__ge__', '__getitem__', '__gt__', '__int__', '__invert__', '__le__', '__len__', '__lshift__', '__lt__', '__matmul__', '__mod__', '__mul__', '__ne__', '__neg__', '__or__', '__pos__', '__pow__', '__rshift__', '__setitem__', '__sub__', '__truediv__', '__xor__', '__iadd__', '__radd__', '__iand__', '__rand__', '__ifloordiv__', '__rfloordiv__', '__ilshift__', '__rlshift__', '__imatmul__', '__rmatmul__', '__imod__', '__rmod__', '__imul__', '__rmul__', '__ior__', '__ror__', '__ipow__', '__rpow__', '__irshift__', '__rrshift__', '__isub__', '__rsub__', '__itruediv__', '__rtruediv__', '__ixor__', '__rxor__', 'dtype', 'device', 'ndim', 'shape', 'size', 'T']
1515

1616
from .constants import e, inf, nan, pi
1717

array_api_tests/function_stubs/array_object.py

Lines changed: 157 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,162 @@ def __xor__(x1: array, x2: array, /) -> array:
199199
"""
200200
pass
201201

202+
def __iadd__(x1: array, x2: array, /) -> array:
203+
"""
204+
Note: __iadd__ is a method of the array object.
205+
"""
206+
pass
207+
208+
def __radd__(x1: array, x2: array, /) -> array:
209+
"""
210+
Note: __radd__ is a method of the array object.
211+
"""
212+
pass
213+
214+
def __iand__(x1: array, x2: array, /) -> array:
215+
"""
216+
Note: __iand__ is a method of the array object.
217+
"""
218+
pass
219+
220+
def __rand__(x1: array, x2: array, /) -> array:
221+
"""
222+
Note: __rand__ is a method of the array object.
223+
"""
224+
pass
225+
226+
def __ifloordiv__(x1: array, x2: array, /) -> array:
227+
"""
228+
Note: __ifloordiv__ is a method of the array object.
229+
"""
230+
pass
231+
232+
def __rfloordiv__(x1: array, x2: array, /) -> array:
233+
"""
234+
Note: __rfloordiv__ is a method of the array object.
235+
"""
236+
pass
237+
238+
def __ilshift__(x1: array, x2: array, /) -> array:
239+
"""
240+
Note: __ilshift__ is a method of the array object.
241+
"""
242+
pass
243+
244+
def __rlshift__(x1: array, x2: array, /) -> array:
245+
"""
246+
Note: __rlshift__ is a method of the array object.
247+
"""
248+
pass
249+
250+
def __imatmul__(x1: array, x2: array, /) -> array:
251+
"""
252+
Note: __imatmul__ is a method of the array object.
253+
"""
254+
pass
255+
256+
def __rmatmul__(x1: array, x2: array, /) -> array:
257+
"""
258+
Note: __rmatmul__ is a method of the array object.
259+
"""
260+
pass
261+
262+
def __imod__(x1: array, x2: array, /) -> array:
263+
"""
264+
Note: __imod__ is a method of the array object.
265+
"""
266+
pass
267+
268+
def __rmod__(x1: array, x2: array, /) -> array:
269+
"""
270+
Note: __rmod__ is a method of the array object.
271+
"""
272+
pass
273+
274+
def __imul__(x1: array, x2: array, /) -> array:
275+
"""
276+
Note: __imul__ is a method of the array object.
277+
"""
278+
pass
279+
280+
def __rmul__(x1: array, x2: array, /) -> array:
281+
"""
282+
Note: __rmul__ is a method of the array object.
283+
"""
284+
pass
285+
286+
def __ior__(x1: array, x2: array, /) -> array:
287+
"""
288+
Note: __ior__ is a method of the array object.
289+
"""
290+
pass
291+
292+
def __ror__(x1: array, x2: array, /) -> array:
293+
"""
294+
Note: __ror__ is a method of the array object.
295+
"""
296+
pass
297+
298+
def __ipow__(x1: array, x2: array, /) -> array:
299+
"""
300+
Note: __ipow__ is a method of the array object.
301+
"""
302+
pass
303+
304+
def __rpow__(x1: array, x2: array, /) -> array:
305+
"""
306+
Note: __rpow__ is a method of the array object.
307+
"""
308+
pass
309+
310+
def __irshift__(x1: array, x2: array, /) -> array:
311+
"""
312+
Note: __irshift__ is a method of the array object.
313+
"""
314+
pass
315+
316+
def __rrshift__(x1: array, x2: array, /) -> array:
317+
"""
318+
Note: __rrshift__ is a method of the array object.
319+
"""
320+
pass
321+
322+
def __isub__(x1: array, x2: array, /) -> array:
323+
"""
324+
Note: __isub__ is a method of the array object.
325+
"""
326+
pass
327+
328+
def __rsub__(x1: array, x2: array, /) -> array:
329+
"""
330+
Note: __rsub__ is a method of the array object.
331+
"""
332+
pass
333+
334+
def __itruediv__(x1: array, x2: array, /) -> array:
335+
"""
336+
Note: __itruediv__ is a method of the array object.
337+
"""
338+
pass
339+
340+
def __rtruediv__(x1: array, x2: array, /) -> array:
341+
"""
342+
Note: __rtruediv__ is a method of the array object.
343+
"""
344+
pass
345+
346+
def __ixor__(x1: array, x2: array, /) -> array:
347+
"""
348+
Note: __ixor__ is a method of the array object.
349+
"""
350+
pass
351+
352+
def __rxor__(x1: array, x2: array, /) -> array:
353+
"""
354+
Note: __rxor__ is a method of the array object.
355+
"""
356+
pass
357+
202358
# Note: dtype is an attribute of the array object.
203359
dtype = None
204360

@@ -217,4 +373,4 @@ def __xor__(x1: array, x2: array, /) -> array:
217373
# Note: T is an attribute of the array object.
218374
T = None
219375

220-
__all__ = ['__abs__', '__add__', '__and__', '__bool__', '__dlpack__', '__dlpack_device__', '__eq__', '__float__', '__floordiv__', '__ge__', '__getitem__', '__gt__', '__int__', '__invert__', '__le__', '__len__', '__lshift__', '__lt__', '__matmul__', '__mod__', '__mul__', '__ne__', '__neg__', '__or__', '__pos__', '__pow__', '__rshift__', '__setitem__', '__sub__', '__truediv__', '__xor__', 'dtype', 'device', 'ndim', 'shape', 'size', 'T']
376+
__all__ = ['__abs__', '__add__', '__and__', '__bool__', '__dlpack__', '__dlpack_device__', '__eq__', '__float__', '__floordiv__', '__ge__', '__getitem__', '__gt__', '__int__', '__invert__', '__le__', '__len__', '__lshift__', '__lt__', '__matmul__', '__mod__', '__mul__', '__ne__', '__neg__', '__or__', '__pos__', '__pow__', '__rshift__', '__setitem__', '__sub__', '__truediv__', '__xor__', '__iadd__', '__radd__', '__iand__', '__rand__', '__ifloordiv__', '__rfloordiv__', '__ilshift__', '__rlshift__', '__imatmul__', '__rmatmul__', '__imod__', '__rmod__', '__imul__', '__rmul__', '__ior__', '__ror__', '__ipow__', '__rpow__', '__irshift__', '__rrshift__', '__isub__', '__rsub__', '__itruediv__', '__rtruediv__', '__ixor__', '__rxor__', 'dtype', 'device', 'ndim', 'shape', 'size', 'T']

generate_stubs.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import os
1414
import sys
1515
import ast
16+
import itertools
1617
from collections import defaultdict
1718

1819
import regex
@@ -24,6 +25,9 @@
2425
METHOD_RE = regex.compile(r'\(method-.*\)=\n#+ ?(.*\(.*\))')
2526
CONSTANT_RE = regex.compile(r'\(constant-.*\)=\n#+ ?(.*)')
2627
ATTRIBUTE_RE = regex.compile(r'\(attribute-.*\)=\n#+ ?(.*)')
28+
IN_PLACE_OPERATOR_RE = regex.compile(r'- `.*`. May be implemented via `__i(.*)__`.')
29+
REFLECTED_OPERATOR_RE = regex.compile(r'- `__r(.*)__`')
30+
2731
NAME_RE = regex.compile(r'(.*)\(.*\)')
2832

2933
STUB_FILE_HEADER = '''\
@@ -136,10 +140,16 @@ def main():
136140

137141
annotations = parse_annotations(text, verbose=not args.quiet)
138142

143+
if filename == 'array_object.md':
144+
in_place_operators = IN_PLACE_OPERATOR_RE.findall(text)
145+
reflected_operators = REFLECTED_OPERATOR_RE.findall(text)
146+
if sorted(in_place_operators) != sorted(reflected_operators):
147+
raise RuntimeError(f"Unexpected in-place or reflected operator(s): {set(in_place_operators).symmetric_difference(set(reflected_operators))}")
148+
139149
sigs = {}
140150
code = ""
141151
code += STUB_FILE_HEADER.format(filename=filename, title=title)
142-
for sig in functions + methods:
152+
for sig in itertools.chain(functions, methods):
143153
ismethod = sig in methods
144154
sig = sig.replace(r'\_', '_')
145155
func_name = NAME_RE.match(sig).group(1)
@@ -162,6 +172,19 @@ def {annotated_sig}:{doc}
162172
"""
163173
modules[module_name].append(func_name)
164174
sigs[func_name] = sig
175+
176+
if (filename == 'array_object.md' and func_name.startswith('__')
177+
and (op := func_name[2:-2]) in in_place_operators):
178+
normal_op = func_name
179+
iop = f'__i{op}__'
180+
rop = f'__r{op}__'
181+
for func_name in [iop, rop]:
182+
methods.append(sigs[normal_op].replace(normal_op, func_name))
183+
annotation = annotations[normal_op].copy()
184+
for k, v in annotation.items():
185+
annotation[k] = v.replace(normal_op, func_name)
186+
annotations[func_name] = annotation
187+
165188
for const in constants + attributes:
166189
if not args.quiet:
167190
print(f"Writing stub for {const}")

0 commit comments

Comments
 (0)