Commit e0791cf

Merge pull request #644 from peterbell10/packed_accessor
Use PackedTensorAccessor32 in cpp extension tutorial
2 parents ea34cb0 + 07839ae commit e0791cf

1 file changed: +45 −44 lines changed
advanced_source/cpp_extension.rst

@@ -946,7 +946,8 @@ without having to convert to a single pointer:
 Accessor objects have a relatively high level interface, with ``.size()`` and
 ``.stride()`` methods and multi-dimensional indexing. The ``.accessor<>``
 interface is designed to access data efficiently on CPU tensors. The equivalent
-for cuda tensors is the ``packed_accessor<>``, which produces a Packed Accessor.
+for CUDA tensors is ``packed_accessor64<>`` or ``packed_accessor32<>``, which
+produce Packed Accessors with either 64-bit or 32-bit integer indexing.

 The fundamental difference with Accessor is that a Packed Accessor copies size
 and stride data inside of its structure instead of pointing to it. It allows us
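
This by-value copying is what lets a Packed Accessor cross the host/device boundary. As a minimal standalone sketch of the difference (``foo`` is a hypothetical 2-D CUDA tensor, not part of this diff):

.. code-block:: cpp

   #include <torch/extension.h>

   torch::Tensor foo = torch::rand({4, 8}, torch::kCUDA);

   // A plain accessor keeps pointers into the tensor's metadata, so it only
   // works on the host (and only for CPU tensors):
   // auto acc = foo.accessor<float, 2>();

   // A packed accessor copies sizes and strides into its own structure, so
   // it can be passed by value straight into a __global__ kernel:
   auto packed = foo.packed_accessor32<float, 2, torch::RestrictPtrTraits>();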
@@ -957,34 +958,34 @@ We can design a function that takes Packed Accessors instead of pointers.
 .. code-block:: cpp

   __global__ void lltm_cuda_forward_kernel(
-      const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gates,
-      const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_cell,
-      torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_h,
-      torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell,
-      torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate,
-      torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate,
-      torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell)
+      const torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> gates,
+      const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> old_cell,
+      torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> new_h,
+      torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> new_cell,
+      torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> input_gate,
+      torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> output_gate,
+      torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> candidate_cell)

 Let's decompose the template used here. The first two arguments ``scalar_t`` and
 ``2`` are the same as regular Accessor. The argument
 ``torch::RestrictPtrTraits`` indicates that the ``__restrict__`` keyword must be
-used. Finally, the argument ``size_t`` indicates that sizes and strides must be
-stored in a ``size_t`` integer. This is important as by default ``int64_t`` is
-used and can make the kernel slower.
+used. Note also that we've used the ``PackedAccessor32`` variant, which stores the
+sizes and strides in an ``int32_t``. This is important as using the 64-bit
+variant (``PackedAccessor64``) can make the kernel slower.

 The function declaration becomes

 .. code-block:: cpp

   template <typename scalar_t>
   __global__ void lltm_cuda_forward_kernel(
-      const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gates,
-      const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_cell,
-      torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_h,
-      torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell,
-      torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate,
-      torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate,
-      torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell) {
+      const torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> gates,
+      const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> old_cell,
+      torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> new_h,
+      torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> new_cell,
+      torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> input_gate,
+      torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> output_gate,
+      torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> candidate_cell) {
     //batch index
     const int n = blockIdx.y;
     // column index
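
Inside the kernel, a Packed Accessor then indexes like a nested array and carries its own sizes, so no separate size arguments are needed. A self-contained sketch on a made-up kernel (``add_one_kernel`` is hypothetical, not part of this diff):

.. code-block:: cpp

   #include <torch/extension.h>

   template <typename scalar_t>
   __global__ void add_one_kernel(
       torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> data) {
     const int n = blockIdx.y;                            // batch index
     const int c = blockIdx.x * blockDim.x + threadIdx.x; // column index
     if (c < data.size(1)) {  // sizes travel with the accessor itself
       data[n][c] += static_cast<scalar_t>(1);
     }
   }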
@@ -1000,7 +1001,7 @@ The function declaration becomes
   }

 The implementation is much more readable! This function is then called by
-creating Packed Accessors with the ``.packed_accessor<>`` method within the
+creating Packed Accessors with the ``.packed_accessor32<>`` method within the
 host function.

 .. code-block:: cpp
@@ -1029,13 +1030,13 @@ host function.

   AT_DISPATCH_FLOATING_TYPES(gates.type(), "lltm_forward_cuda", ([&] {
     lltm_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
-        gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
-        old_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        new_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        new_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        input_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        output_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        candidate_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
+        gates.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>(),
+        old_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        new_h.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        new_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        input_gate.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        output_gate.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        candidate_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>());
   }));

   return {new_h, new_cell, input_gate, output_gate, candidate_cell, X, gates};
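
One caveat of the 32-bit variant worth guarding on the host side: every index must fit in an ``int32_t``. A defensive check along these lines (an assumed addition for illustration, not part of this diff) fails loudly instead of overflowing:

.. code-block:: cpp

   #include <torch/extension.h>
   #include <cstdint>
   #include <limits>

   // Before calling .packed_accessor32<>() on a tensor such as `gates`
   // above, verify that 32-bit indexing is actually safe for it:
   TORCH_CHECK(gates.numel() <= std::numeric_limits<int32_t>::max(),
               "gates too large for 32-bit indexing; use packed_accessor64");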
@@ -1048,15 +1049,15 @@ on it:

   template <typename scalar_t>
   __global__ void lltm_cuda_backward_kernel(
-      torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_old_cell,
-      torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> d_gates,
-      const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_h,
-      const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_cell,
-      const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell,
-      const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate,
-      const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate,
-      const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell,
-      const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gate_weights) {
+      torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> d_old_cell,
+      torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> d_gates,
+      const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> grad_h,
+      const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> grad_cell,
+      const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> new_cell,
+      const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> input_gate,
+      const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> output_gate,
+      const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> candidate_cell,
+      const torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> gate_weights) {
     //batch index
     const int n = blockIdx.y;
     // column index
@@ -1102,15 +1103,15 @@ on it:

   AT_DISPATCH_FLOATING_TYPES(X.type(), "lltm_forward_cuda", ([&] {
     lltm_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
-        d_old_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        d_gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
-        grad_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        grad_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        new_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        input_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        output_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        candidate_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>());
+        d_old_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        d_gates.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>(),
+        grad_h.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        grad_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        new_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        input_gate.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        output_gate.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        candidate_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        gates.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>());
   }));

   auto d_gate_weights = d_gates.reshape({batch_size, 3*state_size});
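
In both call sites above, the accessor flavour is chosen inside an ``AT_DISPATCH_FLOATING_TYPES`` block, which instantiates its lambda once per floating-point dtype and defines ``scalar_t`` within it. A minimal sketch of that pattern (``example_op`` and its commented-out kernel launch are hypothetical):

.. code-block:: cpp

   #include <torch/extension.h>

   void example_op(torch::Tensor input) {
     AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "example_op", ([&] {
       // scalar_t is defined by the macro for each dispatched dtype.
       auto acc = input.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>();
       // A kernel templated on scalar_t would be launched here, e.g.
       // example_kernel<scalar_t><<<blocks, threads>>>(acc);
     }));
   }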
