@@ -37,64 +37,59 @@ __device__ __forceinline__ scalar_t d_elu(scalar_t z, scalar_t alpha = 1.0) {

template <typename scalar_t>
__global__ void lltm_cuda_forward_kernel(
-    const scalar_t* __restrict__ gates,
-    const scalar_t* __restrict__ old_cell,
-    scalar_t* __restrict__ new_h,
-    scalar_t* __restrict__ new_cell,
-    scalar_t* __restrict__ input_gate,
-    scalar_t* __restrict__ output_gate,
-    scalar_t* __restrict__ candidate_cell,
-    size_t state_size) {
-  const int column = blockIdx.x * blockDim.x + threadIdx.x;
-  const int index = blockIdx.y * state_size + column;
-  const int gates_row = blockIdx.y * (state_size * 3);
-  if (column < state_size) {
-    input_gate[index] = sigmoid(gates[gates_row + column]);
-    output_gate[index] = sigmoid(gates[gates_row + state_size + column]);
-    candidate_cell[index] = elu(gates[gates_row + 2 * state_size + column]);
-    new_cell[index] =
-        old_cell[index] + candidate_cell[index] * input_gate[index];
-    new_h[index] = tanh(new_cell[index]) * output_gate[index];
+    const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gates,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_cell,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_h,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell) {
+  // batch index
+  const int n = blockIdx.y;
+  // column index
+  const int c = blockIdx.x * blockDim.x + threadIdx.x;
+  if (c < gates.size(2)) {
+    input_gate[n][c] = sigmoid(gates[n][0][c]);
+    output_gate[n][c] = sigmoid(gates[n][1][c]);
+    candidate_cell[n][c] = elu(gates[n][2][c]);
+    new_cell[n][c] =
+        old_cell[n][c] + candidate_cell[n][c] * input_gate[n][c];
+    new_h[n][c] = tanh(new_cell[n][c]) * output_gate[n][c];
  }
}

template <typename scalar_t>
__global__ void lltm_cuda_backward_kernel(
-    scalar_t* __restrict__ d_old_cell,
-    scalar_t* __restrict__ d_gates,
-    const scalar_t* __restrict__ grad_h,
-    const scalar_t* __restrict__ grad_cell,
-    const scalar_t* __restrict__ new_cell,
-    const scalar_t* __restrict__ input_gate,
-    const scalar_t* __restrict__ output_gate,
-    const scalar_t* __restrict__ candidate_cell,
-    const scalar_t* __restrict__ gate_weights,
-    size_t state_size) {
-  const int column = blockIdx.x * blockDim.x + threadIdx.x;
-  const int index = blockIdx.y * state_size + column;
-  const int gates_row = blockIdx.y * (state_size * 3);
-  if (column < state_size) {
-    const auto d_output_gate = tanh(new_cell[index]) * grad_h[index];
-    const auto d_tanh_new_cell = output_gate[index] * grad_h[index];
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_old_cell,
+    torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> d_gates,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_h,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_cell,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell,
+    const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gate_weights) {
+  // batch index
+  const int n = blockIdx.y;
+  // column index
+  const int c = blockIdx.x * blockDim.x + threadIdx.x;
+  if (c < d_gates.size(2)) {
+    const auto d_output_gate = tanh(new_cell[n][c]) * grad_h[n][c];
+    const auto d_tanh_new_cell = output_gate[n][c] * grad_h[n][c];
    const auto d_new_cell =
-        d_tanh(new_cell[index]) * d_tanh_new_cell + grad_cell[index];
+        d_tanh(new_cell[n][c]) * d_tanh_new_cell + grad_cell[n][c];


-    d_old_cell[index] = d_new_cell;
-    const auto d_candidate_cell = input_gate[index] * d_new_cell;
-    const auto d_input_gate = candidate_cell[index] * d_new_cell;
+    d_old_cell[n][c] = d_new_cell;
+    const auto d_candidate_cell = input_gate[n][c] * d_new_cell;
+    const auto d_input_gate = candidate_cell[n][c] * d_new_cell;

-
-    const auto input_gate_index = gates_row + column;
-    const auto output_gate_index = gates_row + state_size + column;
-    const auto candidate_cell_index = gates_row + 2 * state_size + column;
-
-    d_gates[input_gate_index] =
-        d_input_gate * d_sigmoid(gate_weights[input_gate_index]);
-    d_gates[output_gate_index] =
-        d_output_gate * d_sigmoid(gate_weights[output_gate_index]);
-    d_gates[candidate_cell_index] =
-        d_candidate_cell * d_elu(gate_weights[candidate_cell_index]);
+    d_gates[n][0][c] =
+        d_input_gate * d_sigmoid(gate_weights[n][0][c]);
+    d_gates[n][1][c] =
+        d_output_gate * d_sigmoid(gate_weights[n][1][c]);
+    d_gates[n][2][c] =
+        d_candidate_cell * d_elu(gate_weights[n][2][c]);
  }
}

} // namespace
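
Both kernels read the batch index from blockIdx.y and the state column from blockIdx.x * blockDim.x + threadIdx.x, so the host code has to launch a 2-D grid shaped that way. A minimal sketch of that launch geometry, with an illustrative block size of 1024 threads (the actual threads/blocks definitions sit outside this diff and are unchanged by it):

  const int threads = 1024;  // illustrative block size, not taken from this diff
  // x dimension covers the state_size columns, y dimension covers the batch rows
  const dim3 blocks((state_size + threads - 1) / threads, batch_size);
  // lltm_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(...);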
@@ -106,11 +101,12 @@ std::vector<torch::Tensor> lltm_cuda_forward(
    torch::Tensor old_h,
    torch::Tensor old_cell) {
  auto X = torch::cat({old_h, input}, /*dim=*/1);
-  auto gates = torch::addmm(bias, X, weights.transpose(0, 1));
+  auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1));

  const auto batch_size = old_cell.size(0);
  const auto state_size = old_cell.size(1);

+  auto gates = gate_weights.reshape({batch_size, 3, state_size});
  auto new_h = torch::zeros_like(old_cell);
  auto new_cell = torch::zeros_like(old_cell);
  auto input_gate = torch::zeros_like(old_cell);
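
The reshape added here regroups the flat addmm output of shape [batch_size, 3 * state_size] into [batch_size, 3, state_size], so the kernel can pick a gate block with an ordinary index instead of computing flat offsets. A sketch of the equivalence (shape comments are illustrative; the addmm output is contiguous, so the reshape is a view of the same storage):

  auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1));  // [batch_size, 3 * state_size]
  auto gates = gate_weights.reshape({batch_size, 3, state_size});      // same storage, [batch_size, 3, state_size]
  // gates[n][g][c] addresses the element the old kernel reached as
  // gates[gates_row + g * state_size + c], with gates_row = n * 3 * state_size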
@@ -122,14 +118,13 @@ std::vector<torch::Tensor> lltm_cuda_forward(

  AT_DISPATCH_FLOATING_TYPES(gates.type(), "lltm_forward_cuda", ([&] {
    lltm_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
-        gates.data<scalar_t>(),
-        old_cell.data<scalar_t>(),
-        new_h.data<scalar_t>(),
-        new_cell.data<scalar_t>(),
-        input_gate.data<scalar_t>(),
-        output_gate.data<scalar_t>(),
-        candidate_cell.data<scalar_t>(),
-        state_size);
+        gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
+        old_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        new_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        new_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        input_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        output_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        candidate_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
  }));

  return {new_h, new_cell, input_gate, output_gate, candidate_cell, X, gates};
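
Each .packed_accessor<scalar_t, N, torch::RestrictPtrTraits, size_t>() call bundles the tensor's data pointer, sizes, and strides into an object the kernel can index like a multi-dimensional array. A self-contained sketch of the same pattern on a toy 2-D kernel (fill_ones_kernel and its launcher are hypothetical names used only for illustration):

#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>

template <typename scalar_t>
__global__ void fill_ones_kernel(
    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> out) {
  const int n = blockIdx.y;                             // row index
  const int c = blockIdx.x * blockDim.x + threadIdx.x;  // column index
  if (c < out.size(1)) {
    out[n][c] = static_cast<scalar_t>(1);               // bounds and strides come from the accessor
  }
}

void fill_ones(torch::Tensor t) {
  const int threads = 256;
  const dim3 blocks((t.size(1) + threads - 1) / threads, t.size(0));
  AT_DISPATCH_FLOATING_TYPES(t.type(), "fill_ones", ([&] {
    fill_ones_kernel<scalar_t><<<blocks, threads>>>(
        t.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
  }));
}

Later PyTorch releases also offer packed_accessor32 / PackedTensorAccessor32 variants, which use 32-bit indexing and are generally cheaper on the device when the tensor is known to fit.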
@@ -143,10 +138,10 @@ std::vector<torch::Tensor> lltm_cuda_backward(
    torch::Tensor output_gate,
    torch::Tensor candidate_cell,
    torch::Tensor X,
-    torch::Tensor gate_weights,
+    torch::Tensor gates,
    torch::Tensor weights) {
  auto d_old_cell = torch::zeros_like(new_cell);
-  auto d_gates = torch::zeros_like(gate_weights);
+  auto d_gates = torch::zeros_like(gates);

  const auto batch_size = new_cell.size(0);
  const auto state_size = new_cell.size(1);
@@ -156,22 +151,22 @@ std::vector<torch::Tensor> lltm_cuda_backward(

  AT_DISPATCH_FLOATING_TYPES(X.type(), "lltm_forward_cuda", ([&] {
    lltm_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
-        d_old_cell.data<scalar_t>(),
-        d_gates.data<scalar_t>(),
-        grad_h.data<scalar_t>(),
-        grad_cell.data<scalar_t>(),
-        new_cell.data<scalar_t>(),
-        input_gate.data<scalar_t>(),
-        output_gate.data<scalar_t>(),
-        candidate_cell.data<scalar_t>(),
-        gate_weights.data<scalar_t>(),
-        state_size);
+        d_old_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        d_gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
+        grad_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        grad_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        new_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        input_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        output_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        candidate_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>());
  }));

-  auto d_weights = d_gates.t().mm(X);
-  auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/true);
+  auto d_gate_weights = d_gates.flatten(1, 2);
+  auto d_weights = d_gate_weights.t().mm(X);
+  auto d_bias = d_gate_weights.sum(/*dim=*/0, /*keepdim=*/true);

-  auto d_X = d_gates.mm(weights);
+  auto d_X = d_gate_weights.mm(weights);
  auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
  auto d_input = d_X.slice(/*dim=*/1, state_size);
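
Since d_gates now keeps the kernel-friendly [batch_size, 3, state_size] layout, it is flattened back to [batch_size, 3 * state_size] before the matrix products that accumulate the weight, bias, and input gradients. Rough shape bookkeeping, assuming the LLTM's usual weights shape of [3 * state_size, state_size + input_features] (that tensor's definition is not shown in this diff):

  auto d_gate_weights = d_gates.flatten(1, 2);                    // [batch_size, 3 * state_size]
  auto d_weights = d_gate_weights.t().mm(X);                      // [3 * state_size, state_size + input_features]
  auto d_bias = d_gate_weights.sum(/*dim=*/0, /*keepdim=*/true);  // [1, 3 * state_size]
  auto d_X = d_gate_weights.mm(weights);                          // [batch_size, state_size + input_features]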