
Commit 11ce647

Author: Clément Pinard
Use packed_accessor
Provides a way to index tensors inside CUDA kernels.
1 parent 4a86842 commit 11ce647
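
This commit replaces raw data<scalar_t>() pointers and hand-rolled index arithmetic with torch::PackedTensorAccessor objects, which carry sizes and strides onto the device so kernels can index tensors with plain [i][j] subscripts. A minimal sketch of the pattern, independent of this file (the scale_kernel / scale_inplace names and the 2D grid layout are illustrative assumptions, not part of the commit):

#include <torch/extension.h>

// Hypothetical kernel: scales every element of a 2D tensor in place.
// The packed accessor carries sizes and strides, so indexing is x[row][col]
// instead of manual "row * stride + col" pointer arithmetic.
template <typename scalar_t>
__global__ void scale_kernel(
    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> x,
    scalar_t factor) {
  const int row = blockIdx.y;
  const int col = blockIdx.x * blockDim.x + threadIdx.x;
  if (col < x.size(1)) {
    x[row][col] *= factor;
  }
}

void scale_inplace(torch::Tensor x, double factor) {
  const int threads = 256;
  const dim3 blocks((x.size(1) + threads - 1) / threads, x.size(0));
  AT_DISPATCH_FLOATING_TYPES(x.type(), "scale_inplace", ([&] {
    scale_kernel<scalar_t><<<blocks, threads>>>(
        x.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
        static_cast<scalar_t>(factor));
  }));
}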


cuda/lltm_cuda_kernel.cu

Lines changed: 68 additions & 73 deletions
@@ -37,64 +37,59 @@ __device__ __forceinline__ scalar_t d_elu(scalar_t z, scalar_t alpha = 1.0) {
 
 template <typename scalar_t>
 __global__ void lltm_cuda_forward_kernel(
-    const scalar_t* __restrict__ gates,
-    const scalar_t* __restrict__ old_cell,
-    scalar_t* __restrict__ new_h,
-    scalar_t* __restrict__ new_cell,
-    scalar_t* __restrict__ input_gate,
-    scalar_t* __restrict__ output_gate,
-    scalar_t* __restrict__ candidate_cell,
-    size_t state_size) {
-  const int column = blockIdx.x * blockDim.x + threadIdx.x;
-  const int index = blockIdx.y * state_size + column;
-  const int gates_row = blockIdx.y * (state_size * 3);
-  if (column < state_size) {
-    input_gate[index] = sigmoid(gates[gates_row + column]);
-    output_gate[index] = sigmoid(gates[gates_row + state_size + column]);
-    candidate_cell[index] = elu(gates[gates_row + 2 * state_size + column]);
-    new_cell[index] =
-        old_cell[index] + candidate_cell[index] * input_gate[index];
-    new_h[index] = tanh(new_cell[index]) * output_gate[index];
+    const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gates,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_cell,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_h,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell) {
+  // batch index
+  const int n = blockIdx.y;
+  // column index
+  const int c = blockIdx.x * blockDim.x + threadIdx.x;
+  if (c < gates.size(2)){
+    input_gate[n][c] = sigmoid(gates[n][0][c]);
+    output_gate[n][c] = sigmoid(gates[n][1][c]);
+    candidate_cell[n][c] = elu(gates[n][2][c]);
+    new_cell[n][c] =
+        old_cell[n][c] + candidate_cell[n][c] * input_gate[n][c];
+    new_h[n][c] = tanh(new_cell[n][c]) * output_gate[n][c];
   }
 }
 
 template <typename scalar_t>
 __global__ void lltm_cuda_backward_kernel(
-    scalar_t* __restrict__ d_old_cell,
-    scalar_t* __restrict__ d_gates,
-    const scalar_t* __restrict__ grad_h,
-    const scalar_t* __restrict__ grad_cell,
-    const scalar_t* __restrict__ new_cell,
-    const scalar_t* __restrict__ input_gate,
-    const scalar_t* __restrict__ output_gate,
-    const scalar_t* __restrict__ candidate_cell,
-    const scalar_t* __restrict__ gate_weights,
-    size_t state_size) {
-  const int column = blockIdx.x * blockDim.x + threadIdx.x;
-  const int index = blockIdx.y * state_size + column;
-  const int gates_row = blockIdx.y * (state_size * 3);
-  if (column < state_size) {
-    const auto d_output_gate = tanh(new_cell[index]) * grad_h[index];
-    const auto d_tanh_new_cell = output_gate[index] * grad_h[index];
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_old_cell,
+    torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> d_gates,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_h,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_cell,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell,
+    const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gate_weights) {
+  // batch index
+  const int n = blockIdx.y;
+  // column index
+  const int c = blockIdx.x * blockDim.x + threadIdx.x;
+  if (c < d_gates.size(2)){
+    const auto d_output_gate = tanh(new_cell[n][c]) * grad_h[n][c];
+    const auto d_tanh_new_cell = output_gate[n][c] * grad_h[n][c];
     const auto d_new_cell =
-        d_tanh(new_cell[index]) * d_tanh_new_cell + grad_cell[index];
+        d_tanh(new_cell[n][c]) * d_tanh_new_cell + grad_cell[n][c];
 
 
-    d_old_cell[index] = d_new_cell;
-    const auto d_candidate_cell = input_gate[index] * d_new_cell;
-    const auto d_input_gate = candidate_cell[index] * d_new_cell;
+    d_old_cell[n][c] = d_new_cell;
+    const auto d_candidate_cell = input_gate[n][c] * d_new_cell;
+    const auto d_input_gate = candidate_cell[n][c] * d_new_cell;
 
-
-    const auto input_gate_index = gates_row + column;
-    const auto output_gate_index = gates_row + state_size + column;
-    const auto candidate_cell_index = gates_row + 2 * state_size + column;
-
-    d_gates[input_gate_index] =
-        d_input_gate * d_sigmoid(gate_weights[input_gate_index]);
-    d_gates[output_gate_index] =
-        d_output_gate * d_sigmoid(gate_weights[output_gate_index]);
-    d_gates[candidate_cell_index] =
-        d_candidate_cell * d_elu(gate_weights[candidate_cell_index]);
+    d_gates[n][0][c] =
+        d_input_gate * d_sigmoid(gate_weights[n][0][c]);
+    d_gates[n][1][c] =
+        d_output_gate * d_sigmoid(gate_weights[n][1][c]);
+    d_gates[n][2][c] =
+        d_candidate_cell * d_elu(gate_weights[n][2][c]);
   }
 }
 } // namespace
@@ -106,11 +101,12 @@ std::vector<torch::Tensor> lltm_cuda_forward(
     torch::Tensor old_h,
     torch::Tensor old_cell) {
   auto X = torch::cat({old_h, input}, /*dim=*/1);
-  auto gates = torch::addmm(bias, X, weights.transpose(0, 1));
+  auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1));
 
   const auto batch_size = old_cell.size(0);
   const auto state_size = old_cell.size(1);
 
+  auto gates = gate_weights.reshape({batch_size, 3, state_size});
   auto new_h = torch::zeros_like(old_cell);
   auto new_cell = torch::zeros_like(old_cell);
   auto input_gate = torch::zeros_like(old_cell);
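
For reference, gate_weights comes out of addmm with the three gates concatenated along the feature dimension, and the reshape only exposes that gate dimension to the kernel. A shape sketch with hypothetical sizes (batch_size = 4, state_size = 128, so 3 * state_size = 384):

// gate_weights = torch::addmm(bias, X, weights.transpose(0, 1))   -> [4, 384]
// gates = gate_weights.reshape({batch_size, 3, state_size})        -> [4, 3, 128]
// In the kernels, gates[n][0][c], gates[n][1][c] and gates[n][2][c] are the
// input-gate, output-gate and candidate-cell pre-activations for batch row n,
// column c, replacing the old gates_row + k * state_size + column arithmetic.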
@@ -122,14 +118,13 @@ std::vector<torch::Tensor> lltm_cuda_forward(
 
   AT_DISPATCH_FLOATING_TYPES(gates.type(), "lltm_forward_cuda", ([&] {
     lltm_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
-        gates.data<scalar_t>(),
-        old_cell.data<scalar_t>(),
-        new_h.data<scalar_t>(),
-        new_cell.data<scalar_t>(),
-        input_gate.data<scalar_t>(),
-        output_gate.data<scalar_t>(),
-        candidate_cell.data<scalar_t>(),
-        state_size);
+        gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
+        old_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        new_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        new_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        input_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        output_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        candidate_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
   }));
 
   return {new_h, new_cell, input_gate, output_gate, candidate_cell, X, gates};
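
The launch configuration is outside the changed lines, so the diff does not show how blocks and threads are defined; the kernels assume a 2D grid with one y-block per batch element and enough x-blocks to cover the state_size columns. A plausible definition consistent with that indexing (an assumption, not part of this commit):

// Assumed launch configuration: blockIdx.y selects the batch element,
// blockIdx.x * blockDim.x + threadIdx.x selects the state column.
const int threads = 1024;
const dim3 blocks((state_size + threads - 1) / threads, batch_size);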
@@ -143,10 +138,10 @@ std::vector<torch::Tensor> lltm_cuda_backward(
     torch::Tensor output_gate,
     torch::Tensor candidate_cell,
     torch::Tensor X,
-    torch::Tensor gate_weights,
+    torch::Tensor gates,
     torch::Tensor weights) {
   auto d_old_cell = torch::zeros_like(new_cell);
-  auto d_gates = torch::zeros_like(gate_weights);
+  auto d_gates = torch::zeros_like(gates);
 
   const auto batch_size = new_cell.size(0);
   const auto state_size = new_cell.size(1);
@@ -156,22 +151,22 @@ std::vector<torch::Tensor> lltm_cuda_backward(
 
   AT_DISPATCH_FLOATING_TYPES(X.type(), "lltm_forward_cuda", ([&] {
     lltm_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
-        d_old_cell.data<scalar_t>(),
-        d_gates.data<scalar_t>(),
-        grad_h.data<scalar_t>(),
-        grad_cell.data<scalar_t>(),
-        new_cell.data<scalar_t>(),
-        input_gate.data<scalar_t>(),
-        output_gate.data<scalar_t>(),
-        candidate_cell.data<scalar_t>(),
-        gate_weights.data<scalar_t>(),
-        state_size);
+        d_old_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        d_gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
+        grad_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        grad_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        new_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        input_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        output_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        candidate_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>());
   }));
 
-  auto d_weights = d_gates.t().mm(X);
-  auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/true);
+  auto d_gate_weights = d_gates.flatten(1, 2);
+  auto d_weights = d_gate_weights.t().mm(X);
+  auto d_bias = d_gate_weights.sum(/*dim=*/0, /*keepdim=*/true);
 
-  auto d_X = d_gates.mm(weights);
+  auto d_X = d_gate_weights.mm(weights);
   auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
   auto d_input = d_X.slice(/*dim=*/1, state_size);
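
Because d_gates is now produced as a [batch_size, 3, state_size] tensor, it is flattened back to two dimensions before the matrix products. A shape sketch (input_features is a name introduced here for illustration, standing for the width of the original input):

// d_gates        : [batch_size, 3, state_size]
// d_gate_weights : [batch_size, 3 * state_size]              (flatten(1, 2))
// X              : [batch_size, state_size + input_features]
// d_weights = d_gate_weights.t().mm(X)    -> [3 * state_size, state_size + input_features]
// d_bias    = d_gate_weights.sum(0, true) -> [1, 3 * state_size]
// d_X       = d_gate_weights.mm(weights)  -> [batch_size, state_size + input_features]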
