@@ -60,6 +60,22 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
 
 def solve(x1: array, x2: array, /, **kwargs) -> array:
     x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
+    # Torch tries to emulate NumPy 1 solve behavior by using batched 1-D solve
+    # whenever
+    # 1. x1.ndim - 1 == x2.ndim
+    # 2. x1.shape[:-1] == x2.shape
+    #
+    # See linalg_solve_is_vector_rhs in
+    # aten/src/ATen/native/LinearAlgebraUtils.h and
+    # TORCH_META_FUNC(_linalg_solve_ex) in
+    # aten/src/ATen/native/BatchLinearAlgebra.cpp in the PyTorch source code.
+    #
+    # The easiest way to work around this is to prepend a size 1 dimension to
+    # x2, since x2 is already one dimension less than x1.
+    #
+    # See https://github.com/pytorch/pytorch/issues/52915
+    if x2.ndim != 1 and x1.ndim - 1 == x2.ndim and x1.shape[:-1] == x2.shape:
+        x2 = x2[None]
     return torch.linalg.solve(x1, x2, **kwargs)
 
 # torch.trace doesn't support the offset argument and doesn't support stacking
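
To make the ambiguity concrete, here is a small sketch (not part of this diff; the tensor names, shapes, and values are illustrative assumptions): with x1 of shape (2, 2, 2) and x2 of shape (2, 2), torch's heuristic treats x2 as a batch of two length-2 vectors, while the array API standard treats it as a single 2x2 matrix broadcast against the batch, which is exactly what prepending a size-1 dimension requests.

import torch

# Illustrative shapes and values only (not from the diff).
A = torch.randn(2, 2, 2)   # batch of two 2x2 matrices
b = torch.randn(2, 2)      # ambiguous: two length-2 vectors, or one 2x2 matrix?

# Torch's NumPy 1 heuristic: b is a batch of vectors, so the result has
# shape (2, 2).
vec_result = torch.linalg.solve(A, b)

# Array API / NumPy 2 interpretation: b is one 2x2 matrix broadcast over the
# batch, giving a result of shape (2, 2, 2). The wrapped solve() above reaches
# this behavior by rewriting b as b[None] before calling torch.
mat_result = torch.linalg.solve(A, b[None])
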
@@ -78,7 +94,23 @@ def vector_norm(
 ) -> array:
     # torch.vector_norm incorrectly treats axis=() the same as axis=None
     if axis == ():
-        keepdims = True
+        out = kwargs.get('out')
+        if out is None:
+            dtype = None
+            if x.dtype == torch.complex64:
+                dtype = torch.float32
+            elif x.dtype == torch.complex128:
+                dtype = torch.float64
+
+            out = torch.zeros_like(x, dtype=dtype)
+
+        # The norm of a single scalar works out to abs(x) in every case except
+        # for ord=0, which is x != 0.
+        if ord == 0:
+            out[:] = (x != 0)
+        else:
+            out[:] = torch.abs(x)
+        return out
     return torch.linalg.vector_norm(x, ord=ord, dim=axis, keepdim=keepdims, **kwargs)
 
 __all__ = linalg_all + ['outer', 'matmul', 'matrix_transpose', 'tensordot',
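
As a rough illustration of the axis=() branch (again not part of the diff; the input values are made up): reducing over an empty axis tuple leaves every element in place, so each element's "norm" is its absolute value, except for ord=0, where it is whether the element is nonzero. torch.linalg.vector_norm with dim=() instead reduces over all axes, as if dim were None, which is why the wrapper special-cases it.

import torch

# Illustrative input only (not from the diff).
x = torch.tensor([3.0, -4.0, 0.0])

# axis=() under the array API: nothing is reduced, so the result is the
# elementwise absolute value.
expected = torch.abs(x)               # tensor([3., 4., 0.])

# For ord=0 the "norm" of each scalar is whether it is nonzero.
expected_ord0 = (x != 0).to(x.dtype)  # tensor([1., 1., 0.])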