Skip to content

Commit c7d4d45

Browse files
committed
fixed compilation error when when not using CUB
1 parent b2f8eea commit c7d4d45

File tree

1 file changed

+19
-19
lines changed

1 file changed

+19
-19
lines changed

ggml/src/ggml-cuda/ssm-scan.cu

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,10 @@ __global__ void __launch_bounds__(splitD, 2)
5555
const int stride_s0 = src0_nb1 / sizeof(float);
5656
const int stride_A = src3_nb1 / sizeof(float);
5757
#pragma unroll
58-
for (int j = 0; j < N; ++j)
58+
for (size_t n = 0; n < N; ++n)
5959
{
60-
regA[j] = A_block[threadIdx.x * stride_A + j];
61-
regs0[j] = s0_block[threadIdx.x * stride_s0 + j];
60+
regA[n] = A_block[threadIdx.x * stride_A + n];
61+
regs0[n] = s0_block[threadIdx.x * stride_s0 + n];
6262
}
6363
#endif
6464

@@ -80,11 +80,11 @@ __global__ void __launch_bounds__(splitD, 2)
8080

8181
float sumf = 0.0f;
8282
#pragma unroll
83-
for (int j = 0; j < N; j++)
83+
for (size_t n = 0; n < N; n++)
8484
{
85-
float state = regs0[j] * expf(dt_soft_plus * regA[j]) + smemB[j] * x_dt;
86-
sumf += state * smemC[j];
87-
regs0[j] = state;
85+
float state = regs0[n] * expf(dt_soft_plus * regA[n]) + smemB[n] * x_dt;
86+
sumf += state * smemC[n];
87+
regs0[n] = state;
8888
}
8989
y_block[i * stride_y + threadIdx.x] = sumf;
9090
}
@@ -94,9 +94,9 @@ __global__ void __launch_bounds__(splitD, 2)
9494
#else
9595
const int stride_s = stride_s0;
9696
#pragma unroll
97-
for (int j = 0; j < N; ++j)
97+
for (size_t n = 0; n < N; ++n)
9898
{
99-
s_block[threadIdx.x * stride_s + j] = regs0[j];
99+
s_block[threadIdx.x * stride_s + n] = regs0[n];
100100
}
101101
#endif
102102
}
@@ -140,10 +140,10 @@ __global__ void __launch_bounds__(splitD, 2)
140140
const int stride_s0 = src0_nb1 / sizeof(float);
141141
const int stride_A = src3_nb1 / sizeof(float);
142142
#pragma unroll
143-
for (int j = 0; j < N; ++j)
143+
for (size_t n = 0; n < N; ++n)
144144
{
145-
regA[j] = A_block[threadIdx.x * stride_A + j];
146-
regs0[j] = s0_block[threadIdx.x * stride_s0 + j];
145+
regA[n] = A_block[threadIdx.x * stride_A + n];
146+
regs0[n] = s0_block[threadIdx.x * stride_s0 + n];
147147
}
148148
#endif
149149

@@ -163,23 +163,23 @@ __global__ void __launch_bounds__(splitD, 2)
163163
float x_dt = x_block[threadIdx.x] * dt_soft_plus;
164164
float sumf = 0.0f;
165165
#pragma unroll
166-
for (int j = 0; j < N; j++)
166+
for (size_t n = 0; n < N; n++)
167167
{
168-
float state = regs0[j] * expf(dt_soft_plus * regA[j]) + smemB[j] * x_dt;
169-
sumf += state * smemC[j];
170-
regs0[j] = state;
168+
float state = regs0[n] * expf(dt_soft_plus * regA[n]) + smemB[n] * x_dt;
169+
sumf += state * smemC[n];
170+
regs0[n] = state;
171171
}
172172
y_block[threadIdx.x] = sumf;
173173
}
174174

175175
#ifdef USE_CUB
176176
BlockStoreS(block_store_tempS).Store(s_block, regs0);
177177
#else
178-
const int stride_s = s0;
178+
const int stride_s = stride_s0;
179179
#pragma unroll
180-
for (int j = 0; j < N; ++j)
180+
for (size_t n = 0; n < N; ++n)
181181
{
182-
s_block[threadIdx.x * stride_s + j] = regs0[j];
182+
s_block[threadIdx.x * stride_s + n] = regs0[n];
183183
}
184184
#endif
185185
}

0 commit comments

Comments
 (0)