@@ -407,7 +407,7 @@ def power(self, a, exponents):
407
407
"""
408
408
raise NotImplementedError ()
409
409
410
- def norm (self , a ):
410
+ def norm (self , a , axis = None ):
411
411
r"""
412
412
Computes the matrix frobenius norm.
413
413
@@ -627,7 +627,7 @@ def diag(self, a, k=0):
627
627
"""
628
628
raise NotImplementedError ()
629
629
630
- def unique (self , a ):
630
+ def unique (self , a , return_inverse = False ):
631
631
r"""
632
632
Finds unique elements of given tensor.
633
633
@@ -1087,8 +1087,8 @@ def sqrt(self, a):
1087
1087
def power (self , a , exponents ):
1088
1088
return np .power (a , exponents )
1089
1089
1090
- def norm (self , a ):
1091
- return np .sqrt ( np . sum ( np . square ( a )) )
1090
+ def norm (self , a , axis = None ):
1091
+ return np .linalg . norm ( a , axis = axis )
1092
1092
1093
1093
def any (self , a ):
1094
1094
return np .any (a )
@@ -1164,8 +1164,8 @@ def meshgrid(self, a, b):
1164
1164
def diag (self , a , k = 0 ):
1165
1165
return np .diag (a , k )
1166
1166
1167
- def unique (self , a ):
1168
- return np .unique (a )
1167
+ def unique (self , a , return_inverse = False ):
1168
+ return np .unique (a , return_inverse = return_inverse )
1169
1169
1170
1170
def logsumexp (self , a , axis = None ):
1171
1171
return special .logsumexp (a , axis = axis )
@@ -1461,8 +1461,8 @@ def sqrt(self, a):
1461
1461
def power (self , a , exponents ):
1462
1462
return jnp .power (a , exponents )
1463
1463
1464
- def norm (self , a ):
1465
- return jnp .sqrt ( jnp . sum ( jnp . square ( a )) )
1464
+ def norm (self , a , axis = None ):
1465
+ return jnp .linalg . norm ( a , axis = axis )
1466
1466
1467
1467
def any (self , a ):
1468
1468
return jnp .any (a )
@@ -1535,8 +1535,8 @@ def meshgrid(self, a, b):
1535
1535
def diag (self , a , k = 0 ):
1536
1536
return jnp .diag (a , k )
1537
1537
1538
- def unique (self , a ):
1539
- return jnp .unique (a )
1538
+ def unique (self , a , return_inverse = False ):
1539
+ return jnp .unique (a , return_inverse = return_inverse )
1540
1540
1541
1541
def logsumexp (self , a , axis = None ):
1542
1542
return jspecial .logsumexp (a , axis = axis )
@@ -1881,8 +1881,8 @@ def sqrt(self, a):
1881
1881
def power (self , a , exponents ):
1882
1882
return torch .pow (a , exponents )
1883
1883
1884
- def norm (self , a ):
1885
- return torch .sqrt ( torch . sum ( torch . square ( a )) )
1884
+ def norm (self , a , axis = None ):
1885
+ return torch .linalg . norm ( a , dim = axis )
1886
1886
1887
1887
def any (self , a ):
1888
1888
return torch .any (a )
@@ -1986,8 +1986,8 @@ def meshgrid(self, a, b):
1986
1986
def diag (self , a , k = 0 ):
1987
1987
return torch .diag (a , diagonal = k )
1988
1988
1989
- def unique (self , a ):
1990
- return torch .unique (a )
1989
+ def unique (self , a , return_inverse = False ):
1990
+ return torch .unique (a , return_inverse = return_inverse )
1991
1991
1992
1992
def logsumexp (self , a , axis = None ):
1993
1993
if axis is not None :
@@ -2306,8 +2306,8 @@ def power(self, a, exponents):
2306
2306
def dot (self , a , b ):
2307
2307
return cp .dot (a , b )
2308
2308
2309
- def norm (self , a ):
2310
- return cp .sqrt ( cp . sum ( cp . square ( a )) )
2309
+ def norm (self , a , axis = None ):
2310
+ return cp .linalg . norm ( a , axis = axis )
2311
2311
2312
2312
def any (self , a ):
2313
2313
return cp .any (a )
@@ -2383,8 +2383,8 @@ def meshgrid(self, a, b):
2383
2383
def diag (self , a , k = 0 ):
2384
2384
return cp .diag (a , k )
2385
2385
2386
- def unique (self , a ):
2387
- return cp .unique (a )
2386
+ def unique (self , a , return_inverse = False ):
2387
+ return cp .unique (a , return_inverse = return_inverse )
2388
2388
2389
2389
def logsumexp (self , a , axis = None ):
2390
2390
# Taken from
@@ -2717,8 +2717,8 @@ def sqrt(self, a):
2717
2717
def power (self , a , exponents ):
2718
2718
return tnp .power (a , exponents )
2719
2719
2720
- def norm (self , a ):
2721
- return tf .math .reduce_euclidean_norm (a )
2720
+ def norm (self , a , axis = None ):
2721
+ return tf .math .reduce_euclidean_norm (a , axis = axis )
2722
2722
2723
2723
def any (self , a ):
2724
2724
return tnp .any (a )
@@ -2790,8 +2790,10 @@ def meshgrid(self, a, b):
2790
2790
def diag (self , a , k = 0 ):
2791
2791
return tnp .diag (a , k )
2792
2792
2793
- def unique (self , a ):
2794
- return tf .sort (tf .unique (tf .reshape (a , [- 1 ]))[0 ])
2793
+ def unique (self , a , return_inverse = False ):
2794
+ y , idx = tf .unique (tf .reshape (a , [- 1 ]))
2795
+ sort_idx = tf .argsort (y )
2796
+ return y [sort_idx ] if not return_inverse else (y [sort_idx ], idx [sort_idx ])
2795
2797
2796
2798
def logsumexp (self , a , axis = None ):
2797
2799
return tf .math .reduce_logsumexp (a , axis = axis )
0 commit comments