Skip to content

Commit 44bf2af

Browse files
authored
Merge pull request #140 from asmeurer/torch-linalg-fixes
Some PyTorch fixes
2 parents 8b9e0c0 + 3eb826d commit 44bf2af

File tree

3 files changed

+44
-2
lines changed

3 files changed

+44
-2
lines changed

.github/workflows/array-api-tests-torch.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,5 @@ jobs:
1010
# Proper linalg testing will require
1111
# https://github.com/data-apis/array-api-tests/pull/101
1212
pytest-extra-args: "--disable-extension linalg"
13+
extra-env-vars: |
14+
ARRAY_API_TESTS_SKIP_DTYPES=uint16,uint32,uint64

.github/workflows/array-api-tests.yml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,10 @@ on:
2727
skips-file-extra:
2828
required: false
2929
type: string
30-
30+
extra-env-vars:
31+
required: false
32+
type: string
33+
description: "Multiline string of environment variables to set for the test run."
3134

3235
env:
3336
PYTEST_ARGS: "--max-examples 200 -v -rxXfE --ci ${{ inputs.pytest-extra-args }} --hypothesis-disable-deadline"
@@ -54,6 +57,11 @@ jobs:
5457
uses: actions/setup-python@v5
5558
with:
5659
python-version: ${{ matrix.python-version }}
60+
- name: Set Extra Environment Variables
61+
# Set additional environment variables if provided
62+
if: inputs.extra-env-vars
63+
run: |
64+
echo "${{ inputs.extra-env-vars }}" >> $GITHUB_ENV
5765
- name: Install dependencies
5866
# NumPy 1.21 doesn't support Python 3.11. There doesn't seem to be a way
5967
# to put this in the numpy 1.21 config file.

array_api_compat/torch/linalg.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,22 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
6060

6161
def solve(x1: array, x2: array, /, **kwargs) -> array:
6262
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
63+
# Torch tries to emulate NumPy 1 solve behavior by using batched 1-D solve
64+
# whenever
65+
# 1. x1.ndim - 1 == x2.ndim
66+
# 2. x1.shape[:-1] == x2.shape
67+
#
68+
# See linalg_solve_is_vector_rhs in
69+
# aten/src/ATen/native/LinearAlgebraUtils.h and
70+
# TORCH_META_FUNC(_linalg_solve_ex) in
71+
# aten/src/ATen/native/BatchLinearAlgebra.cpp in the PyTorch source code.
72+
#
73+
# The easiest way to work around this is to prepend a size 1 dimension to
74+
# x2, since x2 is already one dimension less than x1.
75+
#
76+
# See https://github.com/pytorch/pytorch/issues/52915
77+
if x2.ndim != 1 and x1.ndim - 1 == x2.ndim and x1.shape[:-1] == x2.shape:
78+
x2 = x2[None]
6379
return torch.linalg.solve(x1, x2, **kwargs)
6480

6581
# torch.trace doesn't support the offset argument and doesn't support stacking
@@ -78,7 +94,23 @@ def vector_norm(
7894
) -> array:
7995
# torch.vector_norm incorrectly treats axis=() the same as axis=None
8096
if axis == ():
81-
keepdims = True
97+
out = kwargs.get('out')
98+
if out is None:
99+
dtype = None
100+
if x.dtype == torch.complex64:
101+
dtype = torch.float32
102+
elif x.dtype == torch.complex128:
103+
dtype = torch.float64
104+
105+
out = torch.zeros_like(x, dtype=dtype)
106+
107+
# The norm of a single scalar works out to abs(x) in every case except
108+
# for ord=0, which is x != 0.
109+
if ord == 0:
110+
out[:] = (x != 0)
111+
else:
112+
out[:] = torch.abs(x)
113+
return out
82114
return torch.linalg.vector_norm(x, ord=ord, axis=axis, keepdim=keepdims, **kwargs)
83115

84116
__all__ = linalg_all + ['outer', 'matmul', 'matrix_transpose', 'tensordot',

0 commit comments

Comments
 (0)