Skip to content

Commit 6f29642

Browse files
peterbell10 and pytorchmergebot
authored and committed
Remove Tensor.h includes from spdiags cpu kernel (#84500)
This file uses `Tensor::operator[]` in the middle of a `cpu_kernel`, which is not allowed because it relies on the thread-local dispatcher state. Instead, we should just do the stride calculations. Pull Request resolved: #84500. Approved by: https://github.com/ezyang
1 parent 1a16b25 commit 6f29642

File tree

3 files changed

+22
-28
lines changed

3 files changed

+22
-28
lines changed

aten/src/ATen/native/cpu/SparseFactories.cpp

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,25 @@
1+
#define TORCH_ASSERT_NO_OPERATORS
2+
#include <ATen/native/sparse/SparseFactories.h>
3+
14
#include <ATen/Dispatch.h>
2-
#include <ATen/SparseTensorImpl.h>
3-
#include <ATen/SparseTensorUtils.h>
4-
#include <ATen/TensorIndexing.h>
55
#include <ATen/TensorIterator.h>
6-
#include <ATen/core/ATen_fwd.h>
7-
#include <ATen/core/Tensor.h>
6+
#include <ATen/core/TensorBase.h>
87
#include <ATen/native/cpu/Loops.h>
9-
#include <ATen/native/sparse/SparseFactories.h>
10-
#include <c10/core/Scalar.h>
11-
#include <c10/util/ArrayRef.h>
8+
#include <c10/core/ScalarType.h>
129
#include <c10/util/Exception.h>
1310

14-
#ifndef AT_PER_OPERATOR_HEADERS
15-
#include <ATen/Functions.h>
16-
#include <ATen/NativeFunctions.h>
17-
#else
18-
#include <ATen/ops/sparse_coo_tensor.h>
19-
#endif
20-
2111
namespace at {
2212
namespace native {
23-
using namespace at::sparse;
2413

2514
namespace {
2615
void _spdiags_kernel_cpu(
2716
TensorIterator& iter,
28-
const Tensor& diagonals,
29-
Tensor& values,
30-
Tensor& indices) {
31-
auto* row_index_write_ptr = indices[0].data_ptr<int64_t>();
32-
auto* col_index_write_ptr = indices[1].data_ptr<int64_t>();
17+
const TensorBase& diagonals,
18+
TensorBase& values,
19+
TensorBase& indices) {
20+
auto* row_index_write_ptr = indices.data_ptr<int64_t>();
21+
auto* col_index_write_ptr = row_index_write_ptr + indices.stride(0);
22+
const int64_t diagonals_index_stride = diagonals.stride(0);
3323
const int64_t diagonals_read_stride = diagonals.stride(1);
3424
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
3525
at::ScalarType::BFloat16,
@@ -39,7 +29,9 @@ void _spdiags_kernel_cpu(
3929
diagonals.scalar_type(),
4030
"spdiags_cpu",
4131
[&] {
42-
auto* values_write_ptr = values.data_ptr<scalar_t>();
32+
auto* const values_write_ptr = values.data_ptr<scalar_t>();
33+
const auto* const diagonals_ptr = diagonals.data_ptr<scalar_t>();
34+
4335
cpu_kernel(
4436
iter,
4537
[&](int64_t diag_index,
@@ -52,8 +44,9 @@ void _spdiags_kernel_cpu(
5244
auto* vals_start = values_write_ptr + out_offset;
5345
const int64_t first_col = std::max<int64_t>(diag_offset, 0);
5446
const int64_t first_row = first_col - diag_offset;
55-
auto* data_read = diagonals[diag_index].data_ptr<scalar_t>() +
56-
first_col * diagonals_read_stride;
47+
auto* data_read = (diagonals_ptr +
48+
diagonals_index_stride * diag_index +
49+
first_col * diagonals_read_stride);
5750
for (int64_t i = 0; i < n_out; ++i) {
5851
rows_start[i] = first_row + i;
5952
cols_start[i] = first_col + i;

aten/src/ATen/native/sparse/SparseFactories.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include <ATen/Dispatch.h>
2+
#include <ATen/TensorIterator.h>
23
#include <ATen/native/sparse/SparseFactories.h>
34

45
#ifndef AT_PER_OPERATOR_HEADERS

aten/src/ATen/native/sparse/SparseFactories.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
#pragma once
2-
#include <ATen/TensorIterator.h>
3-
#include <ATen/core/ATen_fwd.h>
4-
#include <ATen/core/Tensor.h>
52
#include <ATen/native/DispatchStub.h>
63

74
namespace at {
5+
struct TensorIterator;
6+
class TensorBase;
7+
88
namespace native {
99

1010
using spdiags_kernel_fn_t =
11-
void (*)(TensorIterator&, const Tensor&, Tensor&, Tensor&);
11+
void (*)(TensorIterator&, const TensorBase&, TensorBase&, TensorBase&);
1212

1313
DECLARE_DISPATCH(spdiags_kernel_fn_t, spdiags_kernel_stub);
1414
} // namespace native

0 commit comments

Comments
 (0)