8
8
from collections .abc import Sequence
9
9
from typing import TYPE_CHECKING , Any , NamedTuple , cast
10
10
11
- from ._helpers import _check_device , array_namespace
11
+ from ._helpers import _device_ctx , array_namespace
12
12
from ._helpers import device as _get_device
13
13
from ._helpers import is_cupy_namespace
14
14
from ._typing import Array , Device , DType , Namespace
@@ -33,8 +33,8 @@ def arange(
33
33
device : Device | None = None ,
34
34
** kwargs : object ,
35
35
) -> Array :
36
- _check_device (xp , device )
37
- return xp .arange (start , stop = stop , step = step , dtype = dtype , ** kwargs )
36
+ with _device_ctx (xp , device ):
37
+ return xp .arange (start , stop = stop , step = step , dtype = dtype , ** kwargs )
38
38
39
39
40
40
def empty (
@@ -45,8 +45,8 @@ def empty(
45
45
device : Device | None = None ,
46
46
** kwargs : object ,
47
47
) -> Array :
48
- _check_device (xp , device )
49
- return xp .empty (shape , dtype = dtype , ** kwargs )
48
+ with _device_ctx (xp , device ):
49
+ return xp .empty (shape , dtype = dtype , ** kwargs )
50
50
51
51
52
52
def empty_like (
@@ -58,8 +58,8 @@ def empty_like(
58
58
device : Device | None = None ,
59
59
** kwargs : object ,
60
60
) -> Array :
61
- _check_device (xp , device )
62
- return xp .empty_like (x , dtype = dtype , ** kwargs )
61
+ with _device_ctx (xp , device , like = x ):
62
+ return xp .empty_like (x , dtype = dtype , ** kwargs )
63
63
64
64
65
65
def eye (
@@ -73,8 +73,8 @@ def eye(
73
73
device : Device | None = None ,
74
74
** kwargs : object ,
75
75
) -> Array :
76
- _check_device (xp , device )
77
- return xp .eye (n_rows , M = n_cols , k = k , dtype = dtype , ** kwargs )
76
+ with _device_ctx (xp , device ):
77
+ return xp .eye (n_rows , M = n_cols , k = k , dtype = dtype , ** kwargs )
78
78
79
79
80
80
def full (
@@ -86,8 +86,8 @@ def full(
86
86
device : Device | None = None ,
87
87
** kwargs : object ,
88
88
) -> Array :
89
- _check_device (xp , device )
90
- return xp .full (shape , fill_value , dtype = dtype , ** kwargs )
89
+ with _device_ctx (xp , device ):
90
+ return xp .full (shape , fill_value , dtype = dtype , ** kwargs )
91
91
92
92
93
93
def full_like (
@@ -100,8 +100,8 @@ def full_like(
100
100
device : Device | None = None ,
101
101
** kwargs : object ,
102
102
) -> Array :
103
- _check_device (xp , device )
104
- return xp .full_like (x , fill_value , dtype = dtype , ** kwargs )
103
+ with _device_ctx (xp , device , like = x ):
104
+ return xp .full_like (x , fill_value , dtype = dtype , ** kwargs )
105
105
106
106
107
107
def linspace (
@@ -116,8 +116,8 @@ def linspace(
116
116
endpoint : bool = True ,
117
117
** kwargs : object ,
118
118
) -> Array :
119
- _check_device (xp , device )
120
- return xp .linspace (start , stop , num , dtype = dtype , endpoint = endpoint , ** kwargs )
119
+ with _device_ctx (xp , device ):
120
+ return xp .linspace (start , stop , num , dtype = dtype , endpoint = endpoint , ** kwargs )
121
121
122
122
123
123
def ones (
@@ -128,8 +128,8 @@ def ones(
128
128
device : Device | None = None ,
129
129
** kwargs : object ,
130
130
) -> Array :
131
- _check_device (xp , device )
132
- return xp .ones (shape , dtype = dtype , ** kwargs )
131
+ with _device_ctx (xp , device ):
132
+ return xp .ones (shape , dtype = dtype , ** kwargs )
133
133
134
134
135
135
def ones_like (
@@ -141,8 +141,8 @@ def ones_like(
141
141
device : Device | None = None ,
142
142
** kwargs : object ,
143
143
) -> Array :
144
- _check_device (xp , device )
145
- return xp .ones_like (x , dtype = dtype , ** kwargs )
144
+ with _device_ctx (xp , device , like = x ):
145
+ return xp .ones_like (x , dtype = dtype , ** kwargs )
146
146
147
147
148
148
def zeros (
@@ -153,8 +153,8 @@ def zeros(
153
153
device : Device | None = None ,
154
154
** kwargs : object ,
155
155
) -> Array :
156
- _check_device (xp , device )
157
- return xp .zeros (shape , dtype = dtype , ** kwargs )
156
+ with _device_ctx (xp , device ):
157
+ return xp .zeros (shape , dtype = dtype , ** kwargs )
158
158
159
159
160
160
def zeros_like (
@@ -166,8 +166,8 @@ def zeros_like(
166
166
device : Device | None = None ,
167
167
** kwargs : object ,
168
168
) -> Array :
169
- _check_device (xp , device )
170
- return xp .zeros_like (x , dtype = dtype , ** kwargs )
169
+ with _device_ctx (xp , device , like = x ):
170
+ return xp .zeros_like (x , dtype = dtype , ** kwargs )
171
171
172
172
173
173
# np.unique() is split into four functions in the array API:
0 commit comments