1
+ #define TORCH_ASSERT_NO_OPERATORS
2
+ #include < ATen/native/sparse/SparseFactories.h>
3
+
1
4
#include < ATen/Dispatch.h>
2
- #include < ATen/SparseTensorImpl.h>
3
- #include < ATen/SparseTensorUtils.h>
4
- #include < ATen/TensorIndexing.h>
5
5
#include < ATen/TensorIterator.h>
6
- #include < ATen/core/ATen_fwd.h>
7
- #include < ATen/core/Tensor.h>
6
+ #include < ATen/core/TensorBase.h>
8
7
#include < ATen/native/cpu/Loops.h>
9
- #include < ATen/native/sparse/SparseFactories.h>
10
- #include < c10/core/Scalar.h>
11
- #include < c10/util/ArrayRef.h>
8
+ #include < c10/core/ScalarType.h>
12
9
#include < c10/util/Exception.h>
13
10
14
- #ifndef AT_PER_OPERATOR_HEADERS
15
- #include < ATen/Functions.h>
16
- #include < ATen/NativeFunctions.h>
17
- #else
18
- #include < ATen/ops/sparse_coo_tensor.h>
19
- #endif
20
-
21
11
namespace at {
22
12
namespace native {
23
- using namespace at ::sparse;
24
13
25
14
namespace {
26
15
void _spdiags_kernel_cpu (
27
16
TensorIterator& iter,
28
- const Tensor& diagonals,
29
- Tensor& values,
30
- Tensor& indices) {
31
- auto * row_index_write_ptr = indices[0 ].data_ptr <int64_t >();
32
- auto * col_index_write_ptr = indices[1 ].data_ptr <int64_t >();
17
+ const TensorBase& diagonals,
18
+ TensorBase& values,
19
+ TensorBase& indices) {
20
+ auto * row_index_write_ptr = indices.data_ptr <int64_t >();
21
+ auto * col_index_write_ptr = row_index_write_ptr + indices.stride (0 );
22
+ const int64_t diagonals_index_stride = diagonals.stride (0 );
33
23
const int64_t diagonals_read_stride = diagonals.stride (1 );
34
24
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4 (
35
25
at::ScalarType::BFloat16,
@@ -39,7 +29,9 @@ void _spdiags_kernel_cpu(
39
29
diagonals.scalar_type (),
40
30
" spdiags_cpu" ,
41
31
[&] {
42
- auto * values_write_ptr = values.data_ptr <scalar_t >();
32
+ auto * const values_write_ptr = values.data_ptr <scalar_t >();
33
+ const auto * const diagonals_ptr = diagonals.data_ptr <scalar_t >();
34
+
43
35
cpu_kernel (
44
36
iter,
45
37
[&](int64_t diag_index,
@@ -52,8 +44,9 @@ void _spdiags_kernel_cpu(
52
44
auto * vals_start = values_write_ptr + out_offset;
53
45
const int64_t first_col = std::max<int64_t >(diag_offset, 0 );
54
46
const int64_t first_row = first_col - diag_offset;
55
- auto * data_read = diagonals[diag_index].data_ptr <scalar_t >() +
56
- first_col * diagonals_read_stride;
47
+ auto * data_read = (diagonals_ptr +
48
+ diagonals_index_stride * diag_index +
49
+ first_col * diagonals_read_stride);
57
50
for (int64_t i = 0 ; i < n_out; ++i) {
58
51
rows_start[i] = first_row + i;
59
52
cols_start[i] = first_col + i;
0 commit comments