Skip to content

Commit af41424

Browse files
committed
Implements matmul, vecdot, and tensordot
These three functions are implemented through a common `py_dot` binding, which is also part of a new tensor submodule `_tensor_linalg_impl`
1 parent 7cb2a53 commit af41424

File tree

11 files changed

+13681
-2
lines changed

11 files changed

+13681
-2
lines changed

dpctl/tensor/CMakeLists.txt

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,15 @@ set(_tensor_reductions_impl_sources
148148
${_boolean_reduction_sources}
149149
${_reduction_sources}
150150
)
151+
set(_linalg_sources
152+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/elementwise_functions_type_utils.cpp
153+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linalg_functions/dot.cpp
154+
)
155+
set(_tensor_linalg_impl_sources
156+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_linalg.cpp
157+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/simplify_iteration_space.cpp
158+
${_linalg_sources}
159+
)
151160

152161
set(_py_trgts)
153162

@@ -166,6 +175,11 @@ pybind11_add_module(${python_module_name} MODULE ${_tensor_reductions_impl_sourc
166175
add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_reductions_impl_sources})
167176
list(APPEND _py_trgts ${python_module_name})
168177

178+
set(python_module_name _tensor_linalg_impl)
179+
pybind11_add_module(${python_module_name} MODULE ${_tensor_linalg_impl_sources})
180+
add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_linalg_impl_sources})
181+
list(APPEND _py_trgts ${python_module_name})
182+
169183
set(_clang_prefix "")
170184
if (WIN32)
171185
set(_clang_prefix "/clang:")
@@ -179,6 +193,7 @@ set(_no_fast_math_sources
179193
list(APPEND _no_fast_math_sources
180194
${_elementwise_sources}
181195
${_reduction_sources}
196+
${_linalg_sources}
182197
)
183198

184199
foreach(_src_fn ${_no_fast_math_sources})

dpctl/tensor/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,12 @@
6060
from dpctl.tensor._device import Device
6161
from dpctl.tensor._dlpack import from_dlpack
6262
from dpctl.tensor._indexing_functions import extract, nonzero, place, put, take
63-
from dpctl.tensor._linear_algebra_functions import matrix_transpose
63+
from dpctl.tensor._linear_algebra_functions import (
64+
matmul,
65+
matrix_transpose,
66+
tensordot,
67+
vecdot,
68+
)
6469
from dpctl.tensor._manipulation_functions import (
6570
broadcast_arrays,
6671
broadcast_to,
@@ -343,4 +348,7 @@
343348
"__array_namespace_info__",
344349
"reciprocal",
345350
"angle",
351+
"matmul",
352+
"tensordot",
353+
"vecdot",
346354
]

0 commit comments

Comments
 (0)