Use PackedTensorAccessor32 in cpp extension tutorial #644

Merged: 4 commits, Oct 1, 2019
89 changes: 45 additions & 44 deletions advanced_source/cpp_extension.rst
@@ -946,7 +946,8 @@ without having to convert to a single pointer:
Accessor objects have a relatively high-level interface, with ``.size()`` and
``.stride()`` methods and multi-dimensional indexing. The ``.accessor<>``
interface is designed to access data efficiently on CPU tensors. The equivalent
-for cuda tensors is the ``packed_accessor<>``, which produces a Packed Accessor.
+for CUDA tensors is ``packed_accessor64<>`` or ``packed_accessor32<>``, which
+produce Packed Accessors with 64-bit or 32-bit integer indexing, respectively.

The fundamental difference from an Accessor is that a Packed Accessor copies size
and stride data inside its structure instead of pointing to it. It allows us
@@ -957,34 +958,34 @@ We can design a function that takes Packed Accessors instead of pointers.
.. code-block:: cpp

__global__ void lltm_cuda_forward_kernel(
-    const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gates,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_cell,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_h,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell)
+    const torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> gates,
+    const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> old_cell,
+    torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> new_h,
+    torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> new_cell,
+    torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> input_gate,
+    torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> output_gate,
+    torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> candidate_cell)

Let's decompose the template used here. The first two arguments ``scalar_t`` and
``2`` are the same as for the regular Accessor. The argument
``torch::RestrictPtrTraits`` indicates that the ``__restrict__`` keyword must be
-used. Finally, the argument ``size_t`` indicates that sizes and strides must be
-stored in a ``size_t`` integer. This is important as by default ``int64_t`` is
-used and can make the kernel slower.
+used. Note also that we've used the ``PackedAccessor32`` variant, which stores the
+sizes and strides in an ``int32_t``. This is important as using the 64-bit
+variant (``PackedAccessor64``) can make the kernel slower.

The function declaration becomes

.. code-block:: cpp

template <typename scalar_t>
__global__ void lltm_cuda_forward_kernel(
-    const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gates,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_cell,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_h,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell) {
+    const torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> gates,
+    const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> old_cell,
+    torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> new_h,
+    torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> new_cell,
+    torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> input_gate,
+    torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> output_gate,
+    torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> candidate_cell) {
//batch index
const int n = blockIdx.y;
// column index
@@ -1000,7 +1001,7 @@ The function declaration becomes
}

The implementation is much more readable! This function is then called by
-creating Packed Accessors with the ``.packed_accessor<>`` method within the
+creating Packed Accessors with the ``.packed_accessor32<>`` method within the
host function.

.. code-block:: cpp
@@ -1029,13 +1030,13 @@ host function.

AT_DISPATCH_FLOATING_TYPES(gates.type(), "lltm_forward_cuda", ([&] {
lltm_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
-        gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
-        old_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        new_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        new_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        input_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        output_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        candidate_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
+        gates.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>(),
+        old_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        new_h.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        new_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        input_gate.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        output_gate.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        candidate_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>());
}));

return {new_h, new_cell, input_gate, output_gate, candidate_cell, X, gates};
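
One practical caveat about the 32-bit variant: it only applies when every index
fits in an ``int32_t``. A minimal host-side guard along the following lines (the
helper name ``can_use_32bit_indexing`` is made up, and checking ``numel()`` is
only sufficient for contiguous tensors) keeps the fast path for typical sizes
and leaves very large tensors to ``packed_accessor64<>``.

.. code-block:: cpp

  #include <cstdint>
  #include <limits>

  // Made-up helper: true when the element count of `t` fits in an int32_t,
  // which is enough for contiguous tensors to index safely with 32-bit math.
  bool can_use_32bit_indexing(const torch::Tensor& t) {
    return t.numel() <= static_cast<int64_t>(std::numeric_limits<int32_t>::max());
  }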
@@ -1048,15 +1049,15 @@ on it:

template <typename scalar_t>
__global__ void lltm_cuda_backward_kernel(
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_old_cell,
-    torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> d_gates,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_h,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_cell,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell,
-    const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gate_weights) {
+    torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> d_old_cell,
+    torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> d_gates,
+    const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> grad_h,
+    const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> grad_cell,
+    const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> new_cell,
+    const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> input_gate,
+    const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> output_gate,
+    const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> candidate_cell,
+    const torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> gate_weights) {
//batch index
const int n = blockIdx.y;
// column index
@@ -1102,15 +1103,15 @@ on it:

AT_DISPATCH_FLOATING_TYPES(X.type(), "lltm_forward_cuda", ([&] {
lltm_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
-        d_old_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        d_gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
-        grad_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        grad_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        new_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        input_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        output_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        candidate_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>());
+        d_old_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        d_gates.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>(),
+        grad_h.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        grad_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        new_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        input_gate.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        output_gate.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        candidate_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        gates.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>());
}));

auto d_gate_weights = d_gates.reshape({batch_size, 3*state_size});