diff --git a/advanced_source/cpp_extension.rst b/advanced_source/cpp_extension.rst
index 6db82bc7711..e0e7a034267 100644
--- a/advanced_source/cpp_extension.rst
+++ b/advanced_source/cpp_extension.rst
@@ -946,7 +946,8 @@ without having to convert to a single pointer:
 Accessor objects have a relatively high level interface, with ``.size()`` and
 ``.stride()`` methods and multi-dimensional indexing. The ``.accessor<>``
 interface is designed to access data efficiently on cpu tensor. The equivalent
-for cuda tensors is the ``packed_accessor<>``, which produces a Packed Accessor.
+for cuda tensors is ``packed_accessor64<>`` or ``packed_accessor32<>``, which
+produces a Packed Accessor with either 64-bit or 32-bit integer indexing.
 
 The fundamental difference with Accessor is that a Packed Accessor copies size
 and stride data inside of its structure instead of pointing to it. It allows us
@@ -957,20 +958,20 @@ We can design a function that takes Packed Accessors instead of pointers.
 .. code-block:: cpp
 
   __global__ void lltm_cuda_forward_kernel(
-      const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gates,
-      const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_cell,
-      torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_h,
-      torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell,
-      torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate,
-      torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate,
-      torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell)
+      const torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> gates,
+      const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> old_cell,
+      torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> new_h,
+      torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> new_cell,
+      torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> input_gate,
+      torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> output_gate,
+      torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> candidate_cell)
 
 Let's decompose the template used here. the first two arguments ``scalar_t``
 and ``2`` are the same as regular Accessor. The argument
 ``torch::RestrictPtrTraits`` indicates that the ``__restrict__`` keyword must be
-used. Finally, the argument ``size_t`` indicates that sizes and strides must be
-stored in a ``size_t`` integer. This is important as by default ``int64_t`` is
-used and can make the kernel slower.
+used. Note also that we've used the ``PackedTensorAccessor32`` variant which
+stores the sizes and strides in an ``int32_t``. This is important as using the
+64-bit variant (``PackedTensorAccessor64``) can make the kernel slower.
 
 The function declaration becomes
 
@@ -978,13 +979,13 @@ The function declaration becomes
 
   template <typename scalar_t>
   __global__ void lltm_cuda_forward_kernel(
-      const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gates,
-      const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_cell,
-      torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_h,
-      torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell,
-      torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate,
-      torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate,
-      torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell) {
+      const torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> gates,
+      const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> old_cell,
+      torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> new_h,
+      torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> new_cell,
+      torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> input_gate,
+      torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> output_gate,
+      torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> candidate_cell) {
     //batch index
     const int n = blockIdx.y;
     // column index
@@ -1000,7 +1001,7 @@ The function declaration becomes
   }
 
 The implementation is much more readable! This function is then called by
-creating Packed Accessors with the ``.packed_accessor<>`` method within the
+creating Packed Accessors with the ``.packed_accessor32<>`` method within the
 host function.
 
 .. code-block:: cpp
@@ -1029,13 +1030,13 @@ host function.
 
     AT_DISPATCH_FLOATING_TYPES(gates.type(), "lltm_forward_cuda", ([&] {
       lltm_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
-          gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
-          old_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-          new_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-          new_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-          input_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-          output_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-          candidate_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
+          gates.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>(),
+          old_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+          new_h.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+          new_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+          input_gate.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+          output_gate.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+          candidate_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>());
     }));
 
     return {new_h, new_cell, input_gate, output_gate, candidate_cell, X, gates};
@@ -1048,15 +1049,15 @@ on it:
 
   template <typename scalar_t>
   __global__ void lltm_cuda_backward_kernel(
-      torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_old_cell,
-      torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> d_gates,
-      const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_h,
-      const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_cell,
-      const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell,
-      const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate,
-      const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate,
-      const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell,
-      const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gate_weights) {
+      torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> d_old_cell,
+      torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> d_gates,
+      const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> grad_h,
+      const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> grad_cell,
+      const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> new_cell,
+      const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> input_gate,
+      const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> output_gate,
+      const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> candidate_cell,
+      const torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> gate_weights) {
     //batch index
     const int n = blockIdx.y;
     // column index
@@ -1102,15 +1103,15 @@ on it:
 
     AT_DISPATCH_FLOATING_TYPES(X.type(), "lltm_forward_cuda", ([&] {
       lltm_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
-          d_old_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-          d_gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
-          grad_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-          grad_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-          new_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-          input_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-          output_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-          candidate_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-          gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>());
+          d_old_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+          d_gates.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>(),
+          grad_h.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+          grad_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+          new_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+          input_gate.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+          output_gate.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+          candidate_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+          gates.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>());
     }));
 
     auto d_gate_weights = d_gates.reshape({batch_size, 3*state_size});
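
For orientation, here is a minimal sketch of the pattern this patch switches to,
shown end to end on a toy kernel rather than the LLTM code: the host function
creates 32-bit Packed Accessors with ``.packed_accessor32<>`` and the kernel
indexes them directly. The kernel and function names (``square_kernel``,
``square_cuda``) are hypothetical and not part of the tutorial.

.. code-block:: cpp

  // Toy example (hypothetical, not from the LLTM tutorial): squares every
  // element of a 2-D CUDA tensor through 32-bit Packed Accessors.
  // Compile as a .cu file, e.g. via torch.utils.cpp_extension.CUDAExtension.
  #include <torch/extension.h>

  template <typename scalar_t>
  __global__ void square_kernel(
      const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> input,
      torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> output) {
    const int row = blockIdx.y;
    const int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (col < input.size(1)) {
      // The accessor carries its own (32-bit) sizes and strides, so
      // multi-dimensional indexing works directly inside the kernel.
      output[row][col] = input[row][col] * input[row][col];
    }
  }

  torch::Tensor square_cuda(torch::Tensor input) {
    auto output = torch::empty_like(input);
    const int threads = 256;
    // One block row per batch row, enough block columns to cover all columns.
    const dim3 blocks((input.size(1) + threads - 1) / threads, input.size(0));
    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "square_cuda", ([&] {
      square_kernel<scalar_t><<<blocks, threads>>>(
          input.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
          output.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>());
    }));
    return output;
  }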