Skip to content

Commit a2edfde

Browse files
committed
added support for broadcasting random.randint bounds
(recent numpy versions now support that)
1 parent da9fea0 commit a2edfde

File tree

1 file changed

+24
-5
lines changed

1 file changed

+24
-5
lines changed

larray/random.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,12 +95,31 @@ def randint(low, high=None, axes=None, dtype='l', meta=None):
9595
a\b b0 b1 b2
9696
a0 0 3 1
9797
a1 4 0 1
98+
99+
With varying low and high (each depending on a different axis)
100+
101+
>>> low = la.sequence('a=a0,a1')
102+
>>> low
103+
a a0 a1
104+
0 1
105+
>>> high = la.sequence('b=b0..b2', initial=3)
106+
>>> high
107+
b b0 b1 b2
108+
3 4 5
109+
110+
In other words, we want to generate values between low and high (high included) for each cell. Let's
111+
note that low..high:
112+
113+
a\b b0 b1 b2
114+
a0 0..2 0..3 0..4
115+
a1 1..2 1..3 1..4
116+
117+
>>> la.random.randint(low, high) # doctest: +SKIP
118+
a\b b0 b1 b2
119+
a0 0 2 2
120+
a1 2 3 4
98121
"""
99-
# TODO: support broadcasting arguments when np.randint supports it (https://github.com/numpy/numpy/issues/6745)
100-
# to do that, uncommenting the following code should be enough:
101-
# return generic_random(np.random.randint, (low, high), axes, meta)
102-
axes = AxisCollection(axes)
103-
return Array(np.random.randint(low, high, axes.shape, dtype), axes, meta=meta)
122+
return generic_random(np.random.randint, (low, high), axes, meta)
104123

105124

106125
def normal(loc=0.0, scale=1.0, axes=None, meta=None):

0 commit comments

Comments
 (0)