@@ -946,7 +946,8 @@ without having to convert to a single pointer:
Accessor objects have a relatively high-level interface, with ``.size()`` and
``.stride()`` methods and multi-dimensional indexing. The ``.accessor<>``
interface is designed to access data efficiently on CPU tensors. The equivalent
- for cuda tensors is the ``packed_accessor<>``, which produces a Packed Accessor.
+ for CUDA tensors are ``packed_accessor64<>`` and ``packed_accessor32<>``, which
+ produce Packed Accessors with either 64-bit or 32-bit integer indexing.

The fundamental difference with Accessor is that a Packed Accessor copies size
and stride data inside of its structure instead of pointing to it. It allows us
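To make the difference concrete, here is a minimal host-side sketch; the tensor and variable names are illustrative and not taken from the tutorial's example. A plain accessor is created for a CPU tensor, while its CUDA copy gets a packed accessor with 32-bit indexing.

.. code-block:: cpp

   #include <torch/extension.h>

   void accessor_example() {
     torch::Tensor tensor = torch::rand({8, 16});

     // CPU: the accessor refers back to the tensor's size/stride metadata.
     auto cpu_accessor = tensor.accessor<float, 2>();
     float first = cpu_accessor[0][0];

     // CUDA: the packed accessor copies sizes and strides by value so it can
     // be passed into a kernel; the 32-bit variant indexes with int32_t.
     auto cuda_tensor = tensor.to(torch::kCUDA);
     auto gpu_accessor =
         cuda_tensor.packed_accessor32<float, 2, torch::RestrictPtrTraits>();
   }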
@@ -957,34 +958,34 @@ We can design a function that takes Packed Accessors instead of pointers.
.. code-block:: cpp

__global__ void lltm_cuda_forward_kernel(
- const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gates,
- const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_cell,
- torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_h,
- torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell,
- torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate,
- torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate,
- torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell)
+ const torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> gates,
+ const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> old_cell,
+ torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> new_h,
+ torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> new_cell,
+ torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> input_gate,
+ torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> output_gate,
+ torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> candidate_cell)

Let's decompose the template used here. The first two arguments ``scalar_t`` and
``2`` are the same as for a regular Accessor. The argument
``torch::RestrictPtrTraits`` indicates that the ``__restrict__`` keyword must be
- used. Finally, the argument ``size_t`` indicates that sizes and strides must be
- stored in a ``size_t`` integer. This is important as by default ``int64_t`` is
- used and can make the kernel slower.
+ used. Note also that we've used the ``PackedTensorAccessor32`` variant which stores the
+ sizes and strides in an ``int32_t``. This is important as using the 64-bit
+ variant (``PackedTensorAccessor64``) can make the kernel slower.

The function declaration becomes

.. code-block:: cpp

template <typename scalar_t>
__global__ void lltm_cuda_forward_kernel(
- const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gates,
- const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_cell,
- torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_h,
- torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell,
- torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate,
- torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate,
- torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell) {
+ const torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> gates,
+ const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> old_cell,
+ torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> new_h,
+ torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> new_cell,
+ torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> input_gate,
+ torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> output_gate,
+ torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> candidate_cell) {
//batch index
const int n = blockIdx.y;
// column index
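For context, a minimal sketch of how such a packed accessor is indexed inside a kernel body follows; the kernel name and the ``out`` accessor are illustrative and not part of the tutorial's kernel.

.. code-block:: cpp

   #include <torch/extension.h>

   // Copies the first gate of each (batch, column) pair into a 2-D output.
   template <typename scalar_t>
   __global__ void copy_first_gate_kernel(
       const torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> gates,
       torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> out) {
     const int n = blockIdx.y;                             // batch index
     const int c = blockIdx.x * blockDim.x + threadIdx.x;  // column index
     if (c < gates.size(2)) {
       // Sizes and strides live inside the accessor, so multi-dimensional
       // indexing needs no manual pointer arithmetic.
       out[n][c] = gates[n][0][c];
     }
   }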
@@ -1000,7 +1001,7 @@ The function declaration becomes
}

The implementation is much more readable! This function is then called by
- creating Packed Accessors with the ``.packed_accessor<>`` method within the
+ creating Packed Accessors with the ``.packed_accessor32<>`` method within the
host function.

.. code-block:: cpp
@@ -1029,13 +1030,13 @@ host function.

AT_DISPATCH_FLOATING_TYPES(gates.type(), "lltm_forward_cuda", ([&] {
lltm_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
- gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
- old_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
- new_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
- new_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
- input_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
- output_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
- candidate_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
+ gates.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>(),
+ old_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+ new_h.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+ new_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+ input_gate.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+ output_gate.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+ candidate_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>());
}));

return {new_h, new_cell, input_gate, output_gate, candidate_cell, X, gates};
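The ``blocks`` and ``threads`` values passed to this launch are defined earlier in the host function, outside this hunk. One typical setup, consistent with the kernel reading the batch index from ``blockIdx.y`` and the column index from ``blockIdx.x * blockDim.x + threadIdx.x``, might look like the sketch below; the exact thread count is an assumption, not taken from this diff.

.. code-block:: cpp

   // Sizes are taken from one of the 2-D state tensors used in this snippet.
   const auto batch_size = old_cell.size(0);
   const auto state_size = old_cell.size(1);

   // One thread per column of the state, one block row per batch element.
   const int threads = 1024;
   const dim3 blocks((state_size + threads - 1) / threads, batch_size);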
@@ -1048,15 +1049,15 @@ on it:

template <typename scalar_t>
__global__ void lltm_cuda_backward_kernel(
- torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_old_cell,
- torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> d_gates,
- const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_h,
- const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_cell,
- const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell,
- const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate,
- const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate,
- const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell,
- const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gate_weights) {
+ torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> d_old_cell,
+ torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> d_gates,
+ const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> grad_h,
+ const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> grad_cell,
+ const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> new_cell,
+ const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> input_gate,
+ const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> output_gate,
+ const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> candidate_cell,
+ const torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> gate_weights) {
//batch index
const int n = blockIdx.y;
// column index
@@ -1102,15 +1103,15 @@ on it:

AT_DISPATCH_FLOATING_TYPES(X.type(), "lltm_forward_cuda", ([&] {
lltm_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
- d_old_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
- d_gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
- grad_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
- grad_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
- new_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
- input_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
- output_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
- candidate_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
- gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>());
+ d_old_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+ d_gates.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>(),
+ grad_h.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+ grad_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+ new_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+ input_gate.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+ output_gate.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+ candidate_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+ gates.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>());
}));

auto d_gate_weights = d_gates.reshape({batch_size, 3*state_size});