@@ -5,7 +5,8 @@
 import torch
 array = torch.Tensor
 from torch import dtype as Dtype
-from typing import Optional
+from typing import Optional, Union, Tuple, Literal
+inf = float('inf')
 
 from ._aliases import _fix_promotion, sum
 
@@ -66,8 +67,25 @@ def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array:
     # Use our wrapped sum to make sure it does upcasting correctly
     return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)
 
-__all__ = linalg_all + ['outer', 'trace', 'matmul', 'matrix_transpose', 'tensordot',
-                        'vecdot', 'solve']
+def vector_norm(
+    x: array,
+    /,
+    *,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    keepdims: bool = False,
+    ord: Union[int, float, Literal[inf, -inf]] = 2,
+    **kwargs,
+) -> array:
+    # torch.linalg.vector_norm incorrectly treats axis=() the same as
+    # axis=None; an empty axis tuple reduces over no axes, so the result
+    # is the elementwise magnitude of x ((x != 0) for ord=0)
+    if axis == ():
+        return torch.abs(x) if ord != 0 else (x != 0).to(x.real.dtype)
+    # torch.linalg.vector_norm calls the axis argument `dim`, not `axis`
+    return torch.linalg.vector_norm(x, ord=ord, dim=axis, keepdim=keepdims, **kwargs)
+
+__all__ = linalg_all + ['outer', 'matmul', 'matrix_transpose', 'tensordot',
+                        'cross', 'vecdot', 'solve', 'trace', 'vector_norm']
 
 _all_ignore = ['torch_linalg', 'sum']
 
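For context, here is a minimal usage sketch of the new wrapper. It assumes this file is the torch `linalg` compatibility layer of array-api-compat (so the module is importable as `array_api_compat.torch.linalg`); the tensors and the expected values in the comments are illustrative:

```python
import torch
# Assumed import path; adjust if the package layout differs.
from array_api_compat.torch import linalg

x = torch.tensor([[3.0, 4.0], [0.0, 5.0]])

# axis=None (the default) reduces over all elements: sqrt(9+16+0+25) ~= 7.071
print(linalg.vector_norm(x))

# Reducing along axis=1 with keepdims=True keeps a size-1 axis: shape (2, 1)
print(linalg.vector_norm(x, axis=1, keepdims=True))  # tensor([[5.], [5.]])

# axis=() reduces over no axes, so the result is the elementwise magnitude,
# with the same shape as x (this is the case the wrapper special-cases)
print(linalg.vector_norm(x, axis=()))  # tensor([[3., 4.], [0., 5.]])

# ord=0 with axis=() marks nonzero entries
print(linalg.vector_norm(x, axis=(), ord=0))  # tensor([[1., 1.], [0., 1.]])
```

The `axis=()` branch matters because `torch.linalg.vector_norm` would otherwise collapse the whole array, whereas the array API standard defines a reduction over an empty axis tuple as a no-op reduction.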