Skip to content

Commit fb3bb9d

Browse files
committed
Don't wrap vecdot, isdtype, and vector_norm if they are already defined
1 parent 405c205 commit fb3bb9d

File tree

4 files changed

+36
-6
lines changed

4 files changed

+36
-6
lines changed

array_api_compat/cupy/_aliases.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,17 @@
6161
matmul = get_xp(cp)(_aliases.matmul)
6262
matrix_transpose = get_xp(cp)(_aliases.matrix_transpose)
6363
tensordot = get_xp(cp)(_aliases.tensordot)
64-
vecdot = get_xp(cp)(_aliases.vecdot)
65-
isdtype = get_xp(cp)(_aliases.isdtype)
64+
65+
# These functions are completely new here. If the library already has them
66+
# (i.e., numpy 2.0), use the library version instead of our wrapper.
67+
if hasattr(cp, 'vecdot'):
68+
vecdot = get_xp(cp)(_aliases.vecdot)
69+
else:
70+
vecdot = cp.vecdot
71+
if hasattr(cp, 'isdtype'):
72+
isdtype = cp.isdtype
73+
else:
74+
isdtype = get_xp(cp)(_aliases.isdtype)
6675

6776
__all__ = _aliases.__all__ + ['asarray', 'asarray_cupy', 'bool', 'acos',
6877
'acosh', 'asin', 'asinh', 'atan', 'atan2',

array_api_compat/cupy/linalg.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,16 @@
2929
pinv = get_xp(cp)(_linalg.pinv)
3030
matrix_norm = get_xp(cp)(_linalg.matrix_norm)
3131
svdvals = get_xp(cp)(_linalg.svdvals)
32-
vector_norm = get_xp(cp)(_linalg.vector_norm)
3332
diagonal = get_xp(cp)(_linalg.diagonal)
3433
trace = get_xp(cp)(_linalg.trace)
3534

35+
# These functions are completely new here. If the library already has them
36+
# (i.e., numpy 2.0), use the library version instead of our wrapper.
37+
if hasattr(cp.linalg, 'vector_norm'):
38+
vector_norm = cp.linalg.vector_norm
39+
else:
40+
vector_norm = get_xp(cp)(_linalg.vector_norm)
41+
3642
__all__ = linalg_all + _linalg.__all__
3743

3844
del get_xp

array_api_compat/numpy/_aliases.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,17 @@
6161
matmul = get_xp(np)(_aliases.matmul)
6262
matrix_transpose = get_xp(np)(_aliases.matrix_transpose)
6363
tensordot = get_xp(np)(_aliases.tensordot)
64-
vecdot = get_xp(np)(_aliases.vecdot)
65-
isdtype = get_xp(np)(_aliases.isdtype)
64+
65+
# These functions are completely new here. If the library already has them
66+
# (i.e., numpy 2.0), use the library version instead of our wrapper.
67+
if hasattr(np, 'vecdot'):
68+
vecdot = get_xp(np)(_aliases.vecdot)
69+
else:
70+
vecdot = np.vecdot
71+
if hasattr(np, 'isdtype'):
72+
isdtype = np.isdtype
73+
else:
74+
isdtype = get_xp(np)(_aliases.isdtype)
6675

6776
__all__ = _aliases.__all__ + ['asarray', 'asarray_numpy', 'bool', 'acos',
6877
'acosh', 'asin', 'asinh', 'atan', 'atan2',

array_api_compat/numpy/linalg.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,16 @@
2222
pinv = get_xp(np)(_linalg.pinv)
2323
matrix_norm = get_xp(np)(_linalg.matrix_norm)
2424
svdvals = get_xp(np)(_linalg.svdvals)
25-
vector_norm = get_xp(np)(_linalg.vector_norm)
2625
diagonal = get_xp(np)(_linalg.diagonal)
2726
trace = get_xp(np)(_linalg.trace)
2827

28+
# These functions are completely new here. If the library already has them
29+
# (i.e., numpy 2.0), use the library version instead of our wrapper.
30+
if hasattr(np.linalg, 'vector_norm'):
31+
vector_norm = np.linalg.vector_norm
32+
else:
33+
vector_norm = get_xp(np)(_linalg.vector_norm)
34+
2935
__all__ = linalg_all + _linalg.__all__
3036

3137
del get_xp

0 commit comments

Comments
 (0)