@@ -55,10 +55,10 @@ __global__ void __launch_bounds__(splitD, 2)
55
55
const int stride_s0 = src0_nb1 / sizeof (float );
56
56
const int stride_A = src3_nb1 / sizeof (float );
57
57
#pragma unroll
58
- for (int j = 0 ; j < N; ++j )
58
+ for (size_t n = 0 ; n < N; ++n )
59
59
{
60
- regA[j ] = A_block[threadIdx .x * stride_A + j ];
61
- regs0[j ] = s0_block[threadIdx .x * stride_s0 + j ];
60
+ regA[n ] = A_block[threadIdx .x * stride_A + n ];
61
+ regs0[n ] = s0_block[threadIdx .x * stride_s0 + n ];
62
62
}
63
63
#endif
64
64
@@ -80,11 +80,11 @@ __global__ void __launch_bounds__(splitD, 2)
80
80
81
81
float sumf = 0 .0f ;
82
82
#pragma unroll
83
- for (int j = 0 ; j < N; j ++)
83
+ for (size_t n = 0 ; n < N; n ++)
84
84
{
85
- float state = regs0[j ] * expf (dt_soft_plus * regA[j ]) + smemB[j ] * x_dt;
86
- sumf += state * smemC[j ];
87
- regs0[j ] = state;
85
+ float state = regs0[n ] * expf (dt_soft_plus * regA[n ]) + smemB[n ] * x_dt;
86
+ sumf += state * smemC[n ];
87
+ regs0[n ] = state;
88
88
}
89
89
y_block[i * stride_y + threadIdx .x ] = sumf;
90
90
}
@@ -94,9 +94,9 @@ __global__ void __launch_bounds__(splitD, 2)
94
94
#else
95
95
const int stride_s = stride_s0;
96
96
#pragma unroll
97
- for (int j = 0 ; j < N; ++j )
97
+ for (size_t n = 0 ; n < N; ++n )
98
98
{
99
- s_block[threadIdx .x * stride_s + j ] = regs0[j ];
99
+ s_block[threadIdx .x * stride_s + n ] = regs0[n ];
100
100
}
101
101
#endif
102
102
}
@@ -140,10 +140,10 @@ __global__ void __launch_bounds__(splitD, 2)
140
140
const int stride_s0 = src0_nb1 / sizeof (float );
141
141
const int stride_A = src3_nb1 / sizeof (float );
142
142
#pragma unroll
143
- for (int j = 0 ; j < N; ++j )
143
+ for (size_t n = 0 ; n < N; ++n )
144
144
{
145
- regA[j ] = A_block[threadIdx .x * stride_A + j ];
146
- regs0[j ] = s0_block[threadIdx .x * stride_s0 + j ];
145
+ regA[n ] = A_block[threadIdx .x * stride_A + n ];
146
+ regs0[n ] = s0_block[threadIdx .x * stride_s0 + n ];
147
147
}
148
148
#endif
149
149
@@ -163,23 +163,23 @@ __global__ void __launch_bounds__(splitD, 2)
163
163
float x_dt = x_block[threadIdx .x ] * dt_soft_plus;
164
164
float sumf = 0 .0f ;
165
165
#pragma unroll
166
- for (int j = 0 ; j < N; j ++)
166
+ for (size_t n = 0 ; n < N; n ++)
167
167
{
168
- float state = regs0[j ] * expf (dt_soft_plus * regA[j ]) + smemB[j ] * x_dt;
169
- sumf += state * smemC[j ];
170
- regs0[j ] = state;
168
+ float state = regs0[n ] * expf (dt_soft_plus * regA[n ]) + smemB[n ] * x_dt;
169
+ sumf += state * smemC[n ];
170
+ regs0[n ] = state;
171
171
}
172
172
y_block[threadIdx .x ] = sumf;
173
173
}
174
174
175
175
#ifdef USE_CUB
176
176
BlockStoreS (block_store_tempS).Store (s_block, regs0);
177
177
#else
178
- const int stride_s = s0 ;
178
+ const int stride_s = stride_s0 ;
179
179
#pragma unroll
180
- for (int j = 0 ; j < N; ++j )
180
+ for (size_t n = 0 ; n < N; ++n )
181
181
{
182
- s_block[threadIdx .x * stride_s + j ] = regs0[j ];
182
+ s_block[threadIdx .x * stride_s + n ] = regs0[n ];
183
183
}
184
184
#endif
185
185
}
0 commit comments