Skip to content

Commit d05a6fb

Browse files
One of the algorithms for parallel matrix multiplication (#241)
1 parent 797ec26 commit d05a6fb

File tree

3 files changed

+98
-4
lines changed

3 files changed

+98
-4
lines changed

pydatastructs/linear_data_structures/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from .algorithms import (
2424
merge_sort_parallel,
2525
brick_sort,
26-
brick_sort_parallel
26+
brick_sort_parallel,
27+
matrix_multiply_parallel
2728
)
2829
__all__.extend(algorithms.__all__)

pydatastructs/linear_data_structures/algorithms.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
__all__ = [
88
'merge_sort_parallel',
99
'brick_sort',
10-
'brick_sort_parallel'
10+
'brick_sort_parallel',
11+
'matrix_multiply_parallel'
1112
]
1213

1314
def _merge(array, sl, el, sr, er, end, comp):
@@ -233,3 +234,70 @@ def brick_sort_parallel(array, num_threads, **kwargs):
233234

234235
if _check_type(array, DynamicArray):
235236
array._modify(force=True)
237+
238+
def _matrix_multiply_helper(m1, m2, row, col):
239+
s = 0
240+
for i in range(len(m1)):
241+
s += m1[row][i] * m2[i][col]
242+
return s
243+
244+
def matrix_multiply_parallel(matrix_1, matrix_2, num_threads):
245+
"""
246+
Implements concurrent Matrix multiplication
247+
248+
Parameters
249+
==========
250+
251+
matrix_1: Any matrix representation
252+
Left matrix
253+
254+
matrix_2: Any matrix representation
255+
Right matrix
256+
257+
num_threads: int
258+
The maximum number of threads
259+
to be used for multiplication.
260+
261+
Raises
262+
======
263+
264+
ValueError
265+
When the columns in matrix_1 are not equal to the rows in matrix_2
266+
267+
Returns
268+
=======
269+
270+
C: list
271+
The result of matrix multiplication.
272+
273+
Examples
274+
========
275+
276+
>>> from pydatastructs import matrix_multiply_parallel
277+
>>> I = [[1, 1, 0], [0, 1, 0], [0, 0, 1]]
278+
>>> J = [[2, 1, 2], [1, 2, 1], [2, 2, 2]]
279+
>>> matrix_multiply_parallel(I, J, num_threads=5)
280+
[[3, 3, 3], [1, 2, 1], [2, 2, 2]]
281+
282+
References
283+
==========
284+
.. [1] https://www3.nd.edu/~zxu2/acms60212-40212/Lec-07-3.pdf
285+
"""
286+
row_matrix_1, col_matrix_1 = len(matrix_1), len(matrix_1[0])
287+
row_matrix_2, col_matrix_2 = len(matrix_2), len(matrix_2[0])
288+
289+
if col_matrix_1 != row_matrix_2:
290+
raise ValueError("Matrix size mismatch: %s * %s"%(
291+
(row_matrix_1, col_matrix_1), (row_matrix_2, col_matrix_2)))
292+
293+
C = [[None for i in range(col_matrix_1)] for j in range(row_matrix_2)]
294+
295+
with ThreadPoolExecutor(max_workers=num_threads) as Executor:
296+
for i in range(row_matrix_1):
297+
for j in range(col_matrix_2):
298+
C[i][j] = Executor.submit(_matrix_multiply_helper,
299+
matrix_1,
300+
matrix_2,
301+
i, j).result()
302+
303+
return C

pydatastructs/linear_data_structures/tests/test_algorithms.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from pydatastructs import (
22
merge_sort_parallel, DynamicOneDimensionalArray,
3-
OneDimensionalArray, brick_sort, brick_sort_parallel)
4-
3+
OneDimensionalArray, brick_sort, brick_sort_parallel,
4+
matrix_multiply_parallel)
5+
from pydatastructs.utils.raises_util import raises
56
import random
67

78
def _test_common_sort(sort, *args, **kwargs):
@@ -48,3 +49,27 @@ def test_brick_sort():
4849

4950
def test_brick_sort_parallel():
5051
_test_common_sort(brick_sort_parallel, num_threads=3)
52+
53+
def test_matrix_multiply_parallel():
54+
ODA = OneDimensionalArray
55+
56+
expected_result = [[3, 3, 3], [1, 2, 1], [2, 2, 2]]
57+
58+
I = ODA(ODA, [ODA(int, [1, 1, 0]), ODA(int, [0, 1, 0]), ODA(int, [0, 0, 1])])
59+
J = ODA(ODA, [ODA(int, [2, 1, 2]), ODA(int, [1, 2, 1]), ODA(int, [2, 2, 2])])
60+
output = matrix_multiply_parallel(I, J, num_threads=5)
61+
assert expected_result == output
62+
63+
I = [[1, 1, 0], [0, 1, 0], [0, 0, 1]]
64+
J = [[2, 1, 2], [1, 2, 1], [2, 2, 2]]
65+
output = matrix_multiply_parallel(I, J, num_threads=5)
66+
assert expected_result == output
67+
68+
I = [[1, 1, 0, 1], [0, 1, 0, 1], [0, 0, 1, 1]]
69+
J = [[2, 1, 2], [1, 2, 1], [2, 2, 2]]
70+
assert raises(ValueError, lambda: matrix_multiply_parallel(I, J, num_threads=5))
71+
72+
I = [[1, 1, 0], [0, 1, 0], [0, 0, 1]]
73+
J = [[2, 1, 2], [1, 2, 1], [2, 2, 2]]
74+
output = matrix_multiply_parallel(I, J, num_threads=1)
75+
assert expected_result == output

0 commit comments

Comments
 (0)