From 8b6dac9a231762716396be4e99d2886bdfa43ce5 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 2 Feb 2024 20:53:21 -0700 Subject: [PATCH] Allow broadcasting in cross() --- array_api_strict/linalg.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/array_api_strict/linalg.py b/array_api_strict/linalg.py index b4b21c0..78e9ec4 100644 --- a/array_api_strict/linalg.py +++ b/array_api_strict/linalg.py @@ -73,9 +73,6 @@ def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in cross') - # Note: this is different from np.cross(), which broadcasts - if x1.shape != x2.shape: - raise ValueError('x1 and x2 must have the same shape') if x1.ndim == 0: raise ValueError('cross() requires arrays of dimension at least 1') # Note: this is different from np.cross(), which allows dimension 2