Skip to content

Implementation of matmul, tensordot, and vecdot per array API #1490

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 49 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
863bd50
Adds ThreeOffsets_CombinedIndexer
ndgrigorian Dec 22, 2023
7cb2a53
Remove unused `elementwise_functions.cpp`
ndgrigorian Jan 2, 2024
af41424
Implements `matmul`, `vecdot`, and `tensordot`
ndgrigorian Jan 2, 2024
39cf672
Tweaks to `matmul` and gemm kernels
ndgrigorian Jan 8, 2024
5b32a53
Remove double-counting of batch offset in gemm batch tree reduction
ndgrigorian Jan 8, 2024
e25b9a7
Fixes missing dependency in vecdot
ndgrigorian Jan 9, 2024
58bb4ab
Run test_matmul_simple2 in Windows before full test suite
ndgrigorian Jan 9, 2024
60f1d21
Test removing test_matmul_simple leaving only test_matmul_simple2
ndgrigorian Jan 9, 2024
ad53472
Fix incorrect comments throughtout gemm kernels
ndgrigorian Jan 9, 2024
144ac0f
Drastically reduced parameters used for gemm kernels which thread over k
ndgrigorian Jan 9, 2024
00cbc35
Test removal of k-threading gemm kernel which writes to multiple outp…
ndgrigorian Jan 10, 2024
2b84743
Refactors `gemm_tree_impl`
ndgrigorian Jan 10, 2024
01ff619
Reverse order of numeric types passed to test_matmul_simple2
ndgrigorian Jan 10, 2024
1b86161
Refactors `gemm_contig_tree_impl`
ndgrigorian Jan 11, 2024
459c0ef
Refactoring `gemm_batch_tree` functions
ndgrigorian Jan 11, 2024
eaa048a
Test reversing data types for `test_matmul_strided`
ndgrigorian Jan 11, 2024
6c57d2b
pre-commit fixes in `gemm.hpp`
ndgrigorian Jan 11, 2024
7875f38
Check if malloc_device return nullptr (#1493)
oleksandr-pavlyk Jan 11, 2024
f0079cf
Add step to Linux conda package workflow to run `test_matmul_strided`…
ndgrigorian Jan 11, 2024
e05b805
Remove unnecessary comments
ndgrigorian Jan 11, 2024
cb06ded
Adds a fast-path for empty (k = 0) gemm kernels
ndgrigorian Jan 11, 2024
a06bb2d
Adds logic that avoids certain kernels on CPU that are known to be pr…
ndgrigorian Jan 12, 2024
00ec8e6
Also access memory if indices are in range
oleksandr-pavlyk Jan 15, 2024
d97a9c2
Simplified computation of m_id/gr_id in kernels
oleksandr-pavlyk Jan 15, 2024
1dc2541
Change generic kernels to work for any value of m_groups, not just m_…
oleksandr-pavlyk Jan 15, 2024
d930b2e
Remove work-arounds/special-casing for CPUs
oleksandr-pavlyk Jan 15, 2024
7a277cb
Extended test_matmul_strided, reverted work-arounds
oleksandr-pavlyk Jan 15, 2024
686276a
Revert remaining gemm work-arounds
ndgrigorian Jan 16, 2024
303e7db
Revert tuning down of `gemm` kernel parameters
ndgrigorian Jan 16, 2024
9be7dca
Merge branch 'master' into feature/matmul-vecdot-tensordot
ndgrigorian Jan 16, 2024
0e14ba8
Removed logically dead code from _linear_algebra_functions.py
oleksandr-pavlyk Jan 16, 2024
05a71ee
Added more tests to improve coverage of _linear_algebra_functions
oleksandr-pavlyk Jan 16, 2024
7e428e0
Fixed "UnboundLocalError: local variable 'buf1_dt' referenced before …
oleksandr-pavlyk Jan 16, 2024
1e689d7
More tests to improve coverage
oleksandr-pavlyk Jan 16, 2024
35cc458
Removed more dead branches in _linear_algebra_functions.py
ndgrigorian Jan 16, 2024
d8659d4
`tensordot` now properly handles negative `axes`
ndgrigorian Jan 16, 2024
11b710c
Adds `test_tensordot_type_matrix` to `test_usm_ndarray_linalg.py`
ndgrigorian Jan 16, 2024
2e448dc
Addresses flaws in gemm tree kernel logic
ndgrigorian Jan 17, 2024
71ef294
Implements `__matmul__`, `__imatmul__`, and `__rmatmul__` operators f…
ndgrigorian Jan 17, 2024
4ccb6fd
Makes usm_ndarray operator argument names consistent
ndgrigorian Jan 17, 2024
877c762
Test changes for `tensordot`
ndgrigorian Jan 17, 2024
8ad7ca2
Reverts running certain `matmul` tests under gdb
ndgrigorian Jan 17, 2024
dc34e1d
Fix to typo in `test_tensordot_promotion`
ndgrigorian Jan 17, 2024
d03d16e
Removes unnecessary input type checks in `matmul`
ndgrigorian Jan 17, 2024
15fa952
More tests added to `test_usm_linalg.py`
ndgrigorian Jan 17, 2024
3ce9b59
Use result_type with tensors to take device capability into account
oleksandr-pavlyk Jan 17, 2024
03c36eb
Use order keyword in test of type promotion for matmul
oleksandr-pavlyk Jan 17, 2024
1eaadb6
Make generic k-threaded kernels handle arbitrary m_groups
ndgrigorian Jan 18, 2024
879b8bb
Adjusted dispatch logic for gemm kernels
ndgrigorian Jan 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions dpctl/tensor/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,15 @@ set(_tensor_sorting_impl_sources
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_sorting.cpp
${_sorting_sources}
)
set(_linalg_sources
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/elementwise_functions_type_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linalg_functions/dot.cpp
)
set(_tensor_linalg_impl_sources
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_linalg.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/simplify_iteration_space.cpp
${_linalg_sources}
)

set(_py_trgts)

Expand All @@ -179,6 +188,11 @@ pybind11_add_module(${python_module_name} MODULE ${_tensor_sorting_impl_sources}
add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_sorting_impl_sources})
list(APPEND _py_trgts ${python_module_name})

set(python_module_name _tensor_linalg_impl)
pybind11_add_module(${python_module_name} MODULE ${_tensor_linalg_impl_sources})
add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_linalg_impl_sources})
list(APPEND _py_trgts ${python_module_name})

set(_clang_prefix "")
if (WIN32)
set(_clang_prefix "/clang:")
Expand All @@ -193,6 +207,7 @@ list(APPEND _no_fast_math_sources
${_elementwise_sources}
${_reduction_sources}
${_sorting_sources}
${_linalg_sources}
)

foreach(_src_fn ${_no_fast_math_sources})
Expand Down
10 changes: 9 additions & 1 deletion dpctl/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,12 @@
from dpctl.tensor._device import Device
from dpctl.tensor._dlpack import from_dlpack
from dpctl.tensor._indexing_functions import extract, nonzero, place, put, take
from dpctl.tensor._linear_algebra_functions import matrix_transpose
from dpctl.tensor._linear_algebra_functions import (
matmul,
matrix_transpose,
tensordot,
vecdot,
)
from dpctl.tensor._manipulation_functions import (
broadcast_arrays,
broadcast_to,
Expand Down Expand Up @@ -356,4 +361,7 @@
"unique_counts",
"unique_inverse",
"unique_values",
"matmul",
"tensordot",
"vecdot",
]
Loading