Skip to content

Commit c5ab722

Browse files
committed
Merge branch 'pickrandom-allow-any-array)' of https://github.com/KonradLinkowski/mathjs into develop
2 parents 7575156 + a5cbb6a commit c5ab722

File tree

2 files changed

+33
-18
lines changed

2 files changed

+33
-18
lines changed

src/function/probability/pickRandom.js

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import { factory } from '../../utils/factory'
22
import { isNumber } from '../../utils/is'
3-
import { arraySize } from '../../utils/array'
43
import { createRng } from './util/seededRNG'
4+
import { flatten } from '../../utils/array'
55

66
const name = 'pickRandom'
77
const dependencies = ['typed', 'config', '?on']
@@ -76,15 +76,11 @@ export const createPickRandom = /* #__PURE__ */ factory(name, dependencies, ({ t
7676
number = 1
7777
}
7878

79-
possibles = possibles.valueOf() // get Array
79+
possibles = flatten(possibles.valueOf()).valueOf() // get Array
8080
if (weights) {
8181
weights = weights.valueOf() // get Array
8282
}
8383

84-
if (arraySize(possibles).length > 1) {
85-
throw new Error('Only one dimensional vectors supported')
86-
}
87-
8884
let totalWeights = 0
8985

9086
if (typeof weights !== 'undefined') {

test/unit-tests/function/probability/pickRandom.test.js

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import assert from 'assert'
22
import { filter, times } from 'lodash'
33
import math from '../../../../src/bundleAny'
4+
import { flatten } from '../../../../src/utils/array'
45

56
const math2 = math.create({ randomSeed: 'test2' })
67
const pickRandom = math2.pickRandom
@@ -10,12 +11,6 @@ describe('pickRandom', function () {
1011
assert.strictEqual(typeof math.pickRandom, 'function')
1112
})
1213

13-
it('should throw an error when providing a multi dimensional matrix', function () {
14-
assert.throws(function () {
15-
pickRandom(math.matrix([[1, 2], [3, 4]]))
16-
}, /Only one dimensional vectors supported/)
17-
})
18-
1914
it('should throw an error if the length of the weights does not match the length of the possibles', function () {
2015
const possibles = [11, 22, 33, 44, 55]
2116
const weights = [1, 5, 2, 4]
@@ -68,19 +63,19 @@ describe('pickRandom', function () {
6863
const weights = [1, 5, 2, 4, 6]
6964
const number = 5
7065

71-
assert.strictEqual(pickRandom(possibles, number), possibles)
72-
assert.strictEqual(pickRandom(possibles, number, weights), possibles)
73-
assert.strictEqual(pickRandom(possibles, weights, number), possibles)
66+
pickRandom(possibles, number).forEach((element, index) => assert.strictEqual(element, possibles[index]))
67+
pickRandom(possibles, number, weights).forEach((element, index) => assert.strictEqual(element, possibles[index]))
68+
pickRandom(possibles, weights, number).forEach((element, index) => assert.strictEqual(element, possibles[index]))
7469
})
7570

7671
it('should return the given array if the given number is greater than its length', function () {
7772
const possibles = [11, 22, 33, 44, 55]
7873
const weights = [1, 5, 2, 4, 6]
7974
const number = 6
8075

81-
assert.strictEqual(pickRandom(possibles, number), possibles)
82-
assert.strictEqual(pickRandom(possibles, number, weights), possibles)
83-
assert.strictEqual(pickRandom(possibles, weights, number), possibles)
76+
pickRandom(possibles, number).forEach((element, index) => assert.strictEqual(element, possibles[index]))
77+
pickRandom(possibles, number, weights).forEach((element, index) => assert.strictEqual(element, possibles[index]))
78+
pickRandom(possibles, weights, number).forEach((element, index) => assert.strictEqual(element, possibles[index]))
8479
})
8580

8681
it('should return an empty array if the given number is 0', function () {
@@ -117,6 +112,30 @@ describe('pickRandom', function () {
117112
assert.strictEqual(pickRandom(possibles, weights, number).length, number)
118113
})
119114

115+
it('should pick a number from the given multi dimensional array following an uniform distribution', function () {
116+
const possibles = [[11, 12], [22, 23], [33, 34], [44, 45], [55, 56]]
117+
const picked = []
118+
119+
times(1000, () => picked.push(pickRandom(possibles)))
120+
121+
flatten(possibles).forEach(possible => {
122+
const count = filter(flatten(picked), val => val === possible).length
123+
assert.strictEqual(math.round(count / picked.length, 1), 0.1)
124+
})
125+
})
126+
127+
it('should pick a value from the given multi dimensional array following an uniform distribution', function () {
128+
// just to be sure that works for any kind of array
129+
const possibles = [[[11], [12]], ['test', 45], 'another test', 10, false, [1.3, 4.5, true]]
130+
const picked = []
131+
132+
times(1000, () => picked.push(pickRandom(possibles)))
133+
flatten(possibles).forEach(possible => {
134+
const count = filter(picked, val => val === possible).length
135+
assert.strictEqual(math.round(count / picked.length, 1), 0.1)
136+
})
137+
})
138+
120139
it('should pick a value from the given array following an uniform distribution if only possibles are passed', function () {
121140
const possibles = [11, 22, 33, 44, 55]
122141
const picked = []

0 commit comments

Comments
 (0)