@@ -23,7 +23,9 @@ def test_array_namespace(library, api_version, use_compat):
23
23
if library == "ndonnx" and api_version in ("2021.12" , "2022.12" ):
24
24
pytest .skip ("Unsupported API version" )
25
25
26
- namespace = array_namespace (array , api_version = api_version , use_compat = use_compat )
26
+ with warnings .catch_warnings ():
27
+ warnings .simplefilter ('ignore' , UserWarning )
28
+ namespace = array_namespace (array , api_version = api_version , use_compat = use_compat )
27
29
28
30
if use_compat is False or use_compat is None and library not in wrapped_libraries :
29
31
if library == "jax.numpy" and use_compat is None :
@@ -45,10 +47,13 @@ def test_array_namespace(library, api_version, use_compat):
45
47
46
48
if library == "numpy" :
47
49
# check that the same namespace is returned for NumPy scalars
48
- scalar_namespace = array_namespace (
49
- xp .float64 (0.0 ), api_version = api_version , use_compat = use_compat
50
- )
51
- assert scalar_namespace == namespace
50
+ with warnings .catch_warnings ():
51
+ warnings .simplefilter ('ignore' , UserWarning )
52
+
53
+ scalar_namespace = array_namespace (
54
+ xp .float64 (0.0 ), api_version = api_version , use_compat = use_compat
55
+ )
56
+ assert scalar_namespace == namespace
52
57
53
58
# Check that array_namespace works even if jax.experimental.array_api
54
59
# hasn't been imported yet (it monkeypatches __array_namespace__
@@ -97,7 +102,9 @@ def test_api_version_torch():
97
102
torch = import_ ("torch" )
98
103
x = torch .asarray ([1 , 2 ])
99
104
torch_ = import_ ("torch" , wrapper = True )
100
- assert array_namespace (x , api_version = "2023.12" ) == torch_
105
+ with warnings .catch_warnings ():
106
+ warnings .simplefilter ('ignore' , UserWarning )
107
+ assert array_namespace (x , api_version = "2023.12" ) == torch_
101
108
assert array_namespace (x , api_version = None ) == torch_
102
109
assert array_namespace (x ) == torch_
103
110
# Should issue a warning
0 commit comments